├── .envrc ├── .github └── workflows │ ├── flake-ci.yml │ └── haskell.yml ├── .gitignore ├── .tintin.yml ├── .travis.yml ├── Build.hs ├── CHANGELOG.md ├── LICENSE ├── README.md ├── Setup.hs ├── backprop.cabal ├── bench └── bench.hs ├── doc ├── 01-getting-started.md ├── 02-a-detailed-look.md ├── 03-manipulating-bvars.md ├── 04-the-backprop-typeclass.md ├── 05-applications.md ├── 06-manual-gradients.md ├── 07-performance.md ├── 08-equipping-your-library.md ├── 09-comparisons.md └── index.md ├── doctest └── doctest.hs ├── flake.lock ├── flake.nix ├── fourmolu.yaml ├── renders ├── backprop-mnist.md ├── backprop-mnist.pdf ├── extensible-neural.md └── extensible-neural.pdf ├── samples ├── backprop-mnist.lhs └── extensible-neural.lhs ├── src ├── Data │ └── Type │ │ └── Util.hs ├── Numeric │ ├── Backprop.hs │ └── Backprop │ │ ├── Class.hs │ │ ├── Explicit.hs │ │ ├── Internal.hs │ │ ├── Num.hs │ │ └── Op.hs └── Prelude │ ├── Backprop.hs │ └── Backprop │ ├── Explicit.hs │ └── Num.hs └── test └── Spec.hs /.envrc: -------------------------------------------------------------------------------- 1 | nix_direnv_manual_reload 2 | watch_file "*.cabal" 3 | use flake 4 | -------------------------------------------------------------------------------- /.github/workflows/flake-ci.yml: -------------------------------------------------------------------------------- 1 | name: "Flake CI" 2 | on: 3 | pull_request: 4 | push: 5 | jobs: 6 | checks: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Free Disk Space 10 | uses: insightsengineering/free-disk-space@v1.1.0 11 | - uses: actions/checkout@v3 12 | - uses: webfactory/ssh-agent@v0.9.0 13 | with: 14 | ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} 15 | - uses: cachix/install-nix-action@v22 16 | with: 17 | nix_path: nixpkgs=channel:nixos-unstable 18 | github_access_token: ${{ secrets.GITHUB_TOKEN }} 19 | extra_nix_config: | 20 | trusted-public-keys = hydra.iohk.io:f/Ea+s+dFdN+3Y/G+FDgSq+a5NEWhJGzdjvKNGv0/EQ= cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY= loony-tools:pr9m4BkM/5/eSTZlkQyRt57Jz7OMBxNSUiMC4FkcNfk= 21 | allow-import-from-derivation = true 22 | auto-optimise-store = true 23 | substituters = https://hydra.iohk.io https://cache.nixos.org/ https://cache.iog.io https://cache.zw3rk.com https://mstksg.cachix.org 24 | - uses: cachix/cachix-action@v13 25 | with: 26 | name: mstksg 27 | authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' 28 | - run: nix flake check --show-trace 29 | 30 | cache: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - name: Free Disk Space 34 | uses: insightsengineering/free-disk-space@v1.1.0 35 | - uses: actions/checkout@v4.1.1 36 | - uses: webfactory/ssh-agent@v0.9.0 37 | with: 38 | ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} 39 | - uses: cachix/install-nix-action@v22 40 | with: 41 | nix_path: nixpkgs=channel:nixos-unstable 42 | github_access_token: ${{ secrets.GITHUB_TOKEN }} 43 | extra_nix_config: | 44 | trusted-public-keys = hydra.iohk.io:f/Ea+s+dFdN+3Y/G+FDgSq+a5NEWhJGzdjvKNGv0/EQ= cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY= loony-tools:pr9m4BkM/5/eSTZlkQyRt57Jz7OMBxNSUiMC4FkcNfk= 45 | allow-import-from-derivation = true 46 | auto-optimise-store = true 47 | substituters = https://hydra.iohk.io https://cache.nixos.org/ https://cache.iog.io https://cache.zw3rk.com https://mstksg.cachix.org 48 | - uses: cachix/cachix-action@v13 49 | with: 50 | name: mstksg 51 | authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' 52 | - run: nix build --show-trace 53 | - run: nix develop --show-trace 54 | 55 | every-compiler: 56 
| runs-on: ubuntu-latest 57 | steps: 58 | - name: Free Disk Space 59 | uses: insightsengineering/free-disk-space@v1.1.0 60 | - uses: actions/checkout@v3 61 | - uses: webfactory/ssh-agent@v0.9.0 62 | with: 63 | ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} 64 | - uses: cachix/install-nix-action@v22 65 | with: 66 | nix_path: nixpkgs=channel:nixos-unstable 67 | github_access_token: ${{ secrets.GITHUB_TOKEN }} 68 | extra_nix_config: | 69 | trusted-public-keys = hydra.iohk.io:f/Ea+s+dFdN+3Y/G+FDgSq+a5NEWhJGzdjvKNGv0/EQ= cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY= loony-tools:pr9m4BkM/5/eSTZlkQyRt57Jz7OMBxNSUiMC4FkcNfk= 70 | allow-import-from-derivation = true 71 | auto-optimise-store = true 72 | substituters = https://hydra.iohk.io https://cache.nixos.org/ https://cache.iog.io https://cache.zw3rk.com https://mstksg.cachix.org 73 | - uses: cachix/cachix-action@v13 74 | with: 75 | name: mstksg 76 | authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' 77 | - run: nix build .#everyCompiler 78 | 79 | -------------------------------------------------------------------------------- /.github/workflows/haskell.yml: -------------------------------------------------------------------------------- 1 | # Haskell stack project Github Actions template 2 | # https://gist.github.com/mstksg/11f753d891cee5980326a8ea8c865233 3 | # 4 | # To use, mainly change the list in 'plans' and modify 'include' for 5 | # any OS package manager deps. 6 | # 7 | # Currently not working for cabal-install >= 3 8 | # 9 | # Based on https://raw.githubusercontent.com/commercialhaskell/stack/stable/doc/travis-complex.yml 10 | # 11 | # TODO: 12 | # * cache (https://github.com/actions/cache) 13 | # * support for cabal-install >= 3 14 | 15 | name: Haskell Stack Project CI 16 | 17 | on: 18 | push: 19 | schedule: 20 | - cron: "0 0 * * 1" 21 | 22 | jobs: 23 | build: 24 | strategy: 25 | matrix: 26 | os: [ubuntu-latest, macOS-latest] 27 | # use this to specify what resolvers and ghc to use 28 | plan: 29 | # - { build: stack, resolver: "--resolver lts-9" } 30 | # - { build: stack, resolver: "--resolver lts-11" } 31 | - { build: stack, resolver: "--resolver lts-12" } 32 | - { build: stack, resolver: "--resolver lts-13" } 33 | - { build: stack, resolver: "--resolver lts-14" } 34 | - { build: stack, resolver: "--resolver nightly" } 35 | - { build: stack, resolver: "" } 36 | # - { build: cabal, ghc: 8.0.2, cabal-install: "1.24" } 37 | # - { build: cabal, ghc: 8.2.2, cabal-install: "2.0" } 38 | - { build: cabal, ghc: 8.4.4, cabal-install: "2.2" } 39 | - { build: cabal, ghc: 8.6.5, cabal-install: "2.4" } 40 | - { build: cabal, ghc: 8.8.1, cabal-install: "2.4" } # currently not working for >= 3.0 41 | # use this to include any dependencies from OS package managers 42 | include: 43 | # - os: macOS-latest 44 | # brew: anybrewdeps 45 | - os: ubuntu-latest 46 | apt-get: libblas-dev liblapack-dev 47 | 48 | exclude: 49 | - os: macOS-latest 50 | plan: 51 | build: cabal 52 | 53 | runs-on: ${{ matrix.os }} 54 | steps: 55 | - name: Install OS Packages 56 | uses: mstksg/get-package@v1 57 | with: 58 | apt-get: ${{ matrix.apt-get }} 59 | brew: ${{ matrix.brew }} 60 | - uses: actions/checkout@v1 61 | 62 | - name: Setup stack 63 | uses: mstksg/setup-stack@v1 64 | 65 | - name: Setup cabal-install 66 | uses: actions/setup-haskell@v1 67 | with: 68 | ghc-version: ${{ matrix.plan.ghc }} 69 | cabal-version: ${{ matrix.plan.cabal-install }} 70 | if: matrix.plan.build == 'cabal' 71 | 72 | - name: Install dependencies 73 | run: | 74 | set -ex 75 | case "$BUILD" 
in 76 | stack) 77 | stack --no-terminal --install-ghc $ARGS test --bench --only-dependencies 78 | ;; 79 | cabal) 80 | cabal --version 81 | cabal update 82 | PACKAGES=$(stack --install-ghc query locals | grep '^ *path' | sed 's@^ *path:@@') 83 | cabal install --only-dependencies --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES 84 | ;; 85 | esac 86 | set +ex 87 | env: 88 | ARGS: ${{ matrix.plan.resolver }} 89 | BUILD: ${{ matrix.plan.build }} 90 | 91 | - name: Build 92 | run: | 93 | set -ex 94 | case "$BUILD" in 95 | stack) 96 | stack --no-terminal $ARGS test --bench --no-run-benchmarks --haddock --no-haddock-deps 97 | ;; 98 | cabal) 99 | PACKAGES=$(stack --install-ghc query locals | grep '^ *path' | sed 's@^ *path:@@') 100 | cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES 101 | 102 | ORIGDIR=$(pwd) 103 | for dir in $PACKAGES 104 | do 105 | cd $dir 106 | cabal check || [ "$CABALVER" == "1.16" ] 107 | cabal sdist 108 | PKGVER=$(cabal info . | awk '{print $2;exit}') 109 | SRC_TGZ=$PKGVER.tar.gz 110 | cd dist 111 | tar zxfv "$SRC_TGZ" 112 | cd "$PKGVER" 113 | cabal configure --enable-tests --ghc-options -O0 114 | cabal build 115 | if [ "$CABALVER" = "1.16" ] || [ "$CABALVER" = "1.18" ]; then 116 | cabal test 117 | else 118 | cabal test --show-details=streaming --log=/dev/stdout 119 | fi 120 | cd $ORIGDIR 121 | done 122 | ;; 123 | esac 124 | set +ex 125 | env: 126 | ARGS: ${{ matrix.plan.resolver }} 127 | BUILD: ${{ matrix.plan.build }} 128 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.stack-work 2 | /.shake 3 | /.build 4 | /.direnv 5 | /wiki 6 | /tags 7 | /TAGS 8 | /data 9 | /samples-exe 10 | /dist-newstyle 11 | 12 | /bench-results 13 | /prof-results 14 | /bench-prof 15 | *.prof 16 | *.prof.folded 17 | *.prof.html 18 | /samples-prof 19 | *.dyn_o 20 | *.o 21 | *.hi 22 | /result 23 | -------------------------------------------------------------------------------- /.tintin.yml: -------------------------------------------------------------------------------- 1 | color: darkOrange 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | script: 2 | - | 3 | set -ex 4 | case "$BUILD" in 5 | stack) 6 | stack --no-terminal $ARGS test --bench --no-run-benchmarks --haddock --no-haddock-deps 7 | ;; 8 | cabal) 9 | cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES 10 | 11 | ORIGDIR=$(pwd) 12 | for dir in $PACKAGES 13 | do 14 | cd $dir 15 | cabal check || [ "$CABALVER" == "1.16" ] 16 | cabal sdist 17 | PKGVER=$(cabal info . 
| awk '{print $2;exit}') 18 | SRC_TGZ=$PKGVER.tar.gz 19 | cd dist 20 | tar zxfv "$SRC_TGZ" 21 | cd "$PKGVER" 22 | cabal configure --enable-tests --ghc-options -O0 23 | cabal build 24 | if [ "$CABALVER" = "1.16" ] || [ "$CABALVER" = "1.18" ]; then 25 | cabal test 26 | else 27 | cabal test --show-details=streaming --log=/dev/stdout 28 | fi 29 | cd $ORIGDIR 30 | done 31 | ;; 32 | esac 33 | set +ex 34 | matrix: 35 | include: 36 | - env: BUILD=cabal GHCVER=8.4.4 CABALVER=2.2 HAPPYVER=1.19.5 ALEXVER=3.1.7 37 | addons: 38 | apt: 39 | sources: 40 | - hvr-ghc 41 | packages: 42 | - cabal-install-2.2 43 | - ghc-8.4.4 44 | - happy-1.19.5 45 | - alex-3.1.7 46 | - libblas-dev 47 | - liblapack-dev 48 | compiler: ': #GHC 8.4.4' 49 | - env: BUILD=cabal GHCVER=8.6.5 CABALVER=2.4 HAPPYVER=1.19.5 ALEXVER=3.1.7 50 | addons: 51 | apt: 52 | sources: 53 | - hvr-ghc 54 | packages: 55 | - cabal-install-2.4 56 | - ghc-8.6.5 57 | - happy-1.19.5 58 | - alex-3.1.7 59 | - libblas-dev 60 | - liblapack-dev 61 | compiler: ': #GHC 8.6.5' 62 | - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 63 | addons: 64 | apt: 65 | sources: 66 | - hvr-ghc 67 | packages: 68 | - cabal-install-head 69 | - ghc-head 70 | - happy-1.19.5 71 | - alex-3.1.7 72 | - libblas-dev 73 | - liblapack-dev 74 | compiler: ': #GHC HEAD' 75 | - env: BUILD=stack ARGS="" 76 | addons: 77 | apt: 78 | packages: 79 | - libgmp-dev 80 | - libblas-dev 81 | - liblapack-dev 82 | compiler: ': #stack default' 83 | - env: BUILD=stack ARGS="--resolver lts-12" 84 | addons: 85 | apt: 86 | packages: 87 | - libgmp-dev 88 | - libblas-dev 89 | - liblapack-dev 90 | compiler: ': #stack 8.4.4' 91 | - env: BUILD=stack ARGS="--resolver lts-13" 92 | addons: 93 | apt: 94 | packages: 95 | - libgmp-dev 96 | - libblas-dev 97 | - liblapack-dev 98 | compiler: ': #stack 8.6.5' 99 | - env: BUILD=stack ARGS="--resolver nightly" 100 | addons: 101 | apt: 102 | packages: 103 | - libgmp-dev 104 | - libblas-dev 105 | - liblapack-dev 106 | compiler: ': #stack nightly' 107 | - env: BUILD=stack ARGS="" 108 | os: osx 109 | compiler: ': #stack default osx' 110 | - env: BUILD=stack ARGS="--resolver lts-12" 111 | os: osx 112 | compiler: ': #stack 8.4.4 osx' 113 | - env: BUILD=stack ARGS="--resolver lts-13" 114 | os: osx 115 | compiler: ': #stack 8.6.5 osx' 116 | - env: BUILD=stack ARGS="--resolver nightly" 117 | os: osx 118 | compiler: ': #stack nightly osx' 119 | allow_failures: 120 | - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 121 | - env: BUILD=stack ARGS="--resolver nightly" 122 | install: 123 | - echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo 124 | '?')]" 125 | - if [ -f configure.ac ]; then autoreconf -i; fi 126 | - | 127 | set -ex 128 | case "$BUILD" in 129 | stack) 130 | # Add in extra-deps for older snapshots, as necessary 131 | # 132 | # This is disabled by default, as relying on the solver like this can 133 | # make builds unreliable. Instead, if you have this situation, it's 134 | # recommended that you maintain multiple stack-lts-X.yaml files. 135 | 136 | #stack --no-terminal --install-ghc $ARGS test --bench --dry-run || ( \ 137 | # stack --no-terminal $ARGS build cabal-install && \ 138 | # stack --no-terminal $ARGS solver --update-config) 139 | 140 | # Build the dependencies 141 | stack --no-terminal --install-ghc $ARGS test --bench --only-dependencies 142 | ;; 143 | cabal) 144 | cabal --version 145 | travis_retry cabal update 146 | 147 | # Get the list of packages from the stack.yaml file. 
Note that 148 | # this will also implicitly run hpack as necessary to generate 149 | # the .cabal files needed by cabal-install. 150 | PACKAGES=$(stack --install-ghc query locals | grep '^ *path' | sed 's@^ *path:@@') 151 | 152 | cabal install --only-dependencies --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES 153 | ;; 154 | esac 155 | set +ex 156 | cache: 157 | directories: 158 | - $HOME/.ghc 159 | - $HOME/.cabal 160 | - $HOME/.stack 161 | - $TRAVIS_BUILD_DIR/.stack-work 162 | before_install: 163 | - unset CC 164 | - CABALARGS="" 165 | - if [ "x$GHCVER" = "xhead" ]; then CABALARGS=--allow-newer; fi 166 | - export PATH=/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:$HOME/.local/bin:/opt/alex/$ALEXVER/bin:/opt/happy/$HAPPYVER/bin:$HOME/.cabal/bin:$PATH 167 | - mkdir -p ~/.local/bin 168 | - | 169 | if [ `uname` = "Darwin" ] 170 | then 171 | travis_retry curl --insecure -L https://get.haskellstack.org/stable/osx-x86_64.tar.gz | tar xz --strip-components=1 --include '*/stack' -C ~/.local/bin 172 | else 173 | travis_retry curl -L https://get.haskellstack.org/stable/linux-x86_64.tar.gz | tar xz --wildcards --strip-components=1 -C ~/.local/bin '*/stack' 174 | fi 175 | 176 | # Use the more reliable S3 mirror of Hackage 177 | mkdir -p $HOME/.cabal 178 | echo 'remote-repo: hackage.haskell.org:http://hackage.fpcomplete.com/' > $HOME/.cabal/config 179 | echo 'remote-repo-cache: $HOME/.cabal/packages' >> $HOME/.cabal/config 180 | language: generic 181 | sudo: false 182 | 183 | -------------------------------------------------------------------------------- /Build.hs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env stack 2 | -- stack --install-ghc runghc --package shake-0.16.4 --stack-yaml stack.yaml 3 | 4 | import Development.Shake 5 | import Development.Shake.FilePath 6 | import System.Directory 7 | 8 | opts = 9 | shakeOptions 10 | { shakeFiles = ".shake" 11 | , shakeVersion = "1.0" 12 | , shakeVerbosity = Normal 13 | , shakeThreads = 1 14 | } 15 | 16 | data Doc = Lab 17 | 18 | main :: IO () 19 | main = 20 | getDirectoryFilesIO "samples" ["/*.lhs", "/*.hs"] >>= \allSamps -> 21 | getDirectoryFilesIO "src" ["//*.hs"] >>= \allSrc -> 22 | shakeArgs opts $ do 23 | want ["all"] 24 | 25 | "all" 26 | ~> need ["pdf", "md", "gentags", "install", "exe"] 27 | 28 | "pdf" 29 | ~> need 30 | [ "renders" takeFileName f -<.> ".pdf" 31 | | f <- allSamps 32 | , takeExtension f == ".lhs" 33 | ] 34 | 35 | "md" 36 | ~> need 37 | [ "renders" takeFileName f -<.> ".md" 38 | | f <- allSamps 39 | , takeExtension f == ".lhs" 40 | ] 41 | 42 | "exe" 43 | ~> need (map (\f -> "samples-exe" dropExtension f) allSamps) 44 | 45 | "haddocks" ~> do 46 | need $ ("src" ) <$> allSrc 47 | cmd "jle-git-haddocks" 48 | 49 | "install" ~> do 50 | need $ ("src" ) <$> allSrc 51 | cmd "stack install" 52 | 53 | "install-profile" ~> do 54 | need $ ("src" ) <$> allSrc 55 | cmd "stack install --profile" 56 | 57 | "gentags" 58 | ~> need ["tags", "TAGS"] 59 | 60 | ["renders/*.pdf", "renders/*.md"] |%> \f -> do 61 | let src = "samples" takeFileName f -<.> "lhs" 62 | need [src] 63 | liftIO $ createDirectoryIfMissing True "renders" 64 | cmd 65 | "pandoc" 66 | "-V geometry:margin=1in" 67 | "-V fontfamily:palatino,cmtt" 68 | "-V links-as-notes" 69 | "-s" 70 | "--highlight-style tango" 71 | "--reference-links" 72 | "--reference-location block" 73 | "-o" 74 | f 75 | src 76 | 77 | "samples-exe/*" %> \f -> do 78 | need ["install"] 79 | 
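-- locate the single .hs or .lhs source in samples/ matching this executable's name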
[src] <- getDirectoryFiles "samples" $ (takeFileName f <.>) <$> ["hs", "lhs"] 80 | liftIO $ do 81 | createDirectoryIfMissing True "samples-exe" 82 | createDirectoryIfMissing True ".build" 83 | removeFilesAfter "samples" ["/*.o"] 84 | cmd 85 | "stack ghc" 86 | "--stack-yaml stack.yaml" 87 | "--package finite-typelits" 88 | "--package hmatrix-backprop" 89 | "--package hmatrix-vector-sized" 90 | "--package microlens-th" 91 | "--package mnist-idx" 92 | "--package mwc-random" 93 | "--package one-liner" 94 | "--package one-liner-instances" 95 | "--package random" 96 | "--package singletons" 97 | "--package split" 98 | "--package vector-sized" 99 | "--" 100 | ("samples" src) 101 | "-o" 102 | f 103 | "-hidir .build" 104 | "-Wall" 105 | "-O2" 106 | 107 | "profile" ~> do 108 | need $ do 109 | s <- ["manual", "bp-lens", "bp-hkd", "hybrid"] 110 | e <- ["prof.html", "svg"] 111 | return $ "bench-prof/bench-" ++ s <.> e 112 | 113 | "bench-prof/bench" %> \f -> do 114 | let src = "bench" takeFileName f <.> ".hs" 115 | need ["install-profile", src] 116 | unit $ 117 | cmd 118 | "stack install" 119 | "--profile" 120 | "--stack-yaml stack.yaml" 121 | [ "lens" 122 | , "hmatrix" 123 | , "one-liner-instances" 124 | , "split" 125 | , "criterion" 126 | ] 127 | unit $ 128 | cmd 129 | "stack ghc" 130 | "--profile" 131 | "--stack-yaml stack.yaml" 132 | src 133 | "--" 134 | "-o" 135 | f 136 | "-hidir .build" 137 | "-O2" 138 | "-prof" 139 | "-fexternal-interpreter" 140 | 141 | "bench-prof/bench-*.prof" %> \f -> do 142 | need ["bench-prof/bench"] 143 | let b = drop 6 $ takeBaseName f 144 | unit $ 145 | cmd 146 | "./bench-prof/bench" 147 | ("gradient/" ++ b) 148 | "+RTS" 149 | "-p" 150 | cmd "mv" "bench.prof" f 151 | 152 | "**/*.prof.html" %> \f -> do 153 | let src = f -<.> "" 154 | need [src] 155 | cmd "profiteur" src 156 | 157 | "**/*.prof.folded" %> \f -> do 158 | let src = f -<.> "" 159 | need [src] 160 | Stdout out <- cmd "cat" [src] 161 | cmd 162 | (Stdin out) 163 | (FileStdout f) 164 | "ghc-prof-flamegraph" 165 | 166 | "bench-prof/*.svg" %> \f -> do 167 | let src = f -<.> "prof.folded" 168 | need [src] 169 | cmd 170 | (FileStdout f) 171 | "flamegraph.pl" 172 | "--width 2000" 173 | src 174 | 175 | ["tags", "TAGS"] &%> \_ -> do 176 | need (("src" ) <$> allSrc) 177 | cmd "hasktags" "src/" 178 | 179 | "clean" ~> do 180 | unit $ cmd "stack clean" 181 | removeFilesAfter ".shake" ["//*"] 182 | removeFilesAfter ".build" ["//*"] 183 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Justin Le (c) 2020 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | * Neither the name of Justin Le nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [backprop][docs] 2 | ================ 3 | 4 | [![backprop on Hackage](https://img.shields.io/hackage/v/backprop.svg?maxAge=86400)](https://hackage.haskell.org/package/backprop) 5 | [![backprop on Stackage LTS 11](http://stackage.org/package/backprop/badge/lts-11)](http://stackage.org/lts-11/package/backprop) 6 | [![backprop on Stackage Nightly](http://stackage.org/package/backprop/badge/nightly)](http://stackage.org/nightly/package/backprop) 7 | [![Build Status](https://travis-ci.org/mstksg/backprop.svg?branch=master)](https://travis-ci.org/mstksg/backprop) 8 | 9 | [![Join the chat at https://gitter.im/haskell-backprop/Lobby](https://badges.gitter.im/haskell-backprop/Lobby.svg)](https://gitter.im/haskell-backprop/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 10 | 11 | [**Documentation and Walkthrough**][docs] 12 | 13 | [docs]: https://backprop.jle.im 14 | 15 | Automatic *heterogeneous* back-propagation. 16 | 17 | Write your functions to compute your result, and the library will automatically 18 | generate functions to compute your gradient. 19 | 20 | Differs from [ad][] by offering full heterogeneity -- each intermediate step 21 | and the resulting value can have different types (matrices, vectors, scalars, 22 | lists, etc.). 23 | 24 | [ad]: http://hackage.haskell.org/package/ad 25 | 26 | Useful for applications in [differentiable programming][dp] and deep learning 27 | for creating and training numerical models, especially as described in this 28 | blog post on [a purely functional typed approach to trainable models][models]. 29 | Overall, intended for the implementation of gradient descent and other numeric 30 | optimization techniques. Comparable to the python library [autograd][]. 31 | 32 | [dp]: https://www.facebook.com/yann.lecun/posts/10155003011462143 33 | [models]: https://blog.jle.im/entry/purely-functional-typed-models-1.html 34 | [autograd]: https://github.com/HIPS/autograd 35 | 36 | Currently up on [hackage][], with haddock documentation! However, a proper 37 | library introduction and usage tutorial [is available here][docs]. See also my 38 | [introductory blog post][blog]. You can also find help or support on the 39 | [gitter channel][gitter]. 
40 | 41 | [hackage]: http://hackage.haskell.org/package/backprop 42 | [blog]: https://blog.jle.im/entry/introducing-the-backprop-library.html 43 | [gitter]: https://gitter.im/haskell-backprop/Lobby 44 | 45 | If you want to provide *backprop* for users of your library, see this [guide 46 | to equipping your library with backprop][library]. 47 | 48 | [library]: https://backprop.jle.im/08-equipping-your-library.html 49 | 50 | 51 | MNIST Digit Classifier Example 52 | ------------------------------ 53 | 54 | My [blog post][blog] introduces the concepts in this library in the context of 55 | training a handwritten digit classifier. I recommend reading that first. 56 | 57 | There are some [literate haskell examples][mnist-lhs] in the source, too 58 | ([rendered as pdf here][mnist-pdf]), which can be built (if [stack][] is 59 | installed) using: 60 | 61 | [mnist-lhs]: https://github.com/mstksg/backprop/blob/master/samples/backprop-mnist.lhs 62 | [mnist-pdf]: https://github.com/mstksg/backprop/blob/master/renders/backprop-mnist.pdf 63 | [stack]: http://haskellstack.org/ 64 | 65 | ```bash 66 | $ ./Build.hs exe 67 | ``` 68 | 69 | There is a follow-up tutorial on using the library with more advanced types, 70 | with extensible neural networks a la [this blog post][dep-types], [available as 71 | literate haskell][neural-lhs] and also [rendered as a PDF][neural-pdf]. 72 | 73 | [dep-types]: https://blog.jle.im/entries/series/+practical-dependent-types-in-haskell.html 74 | [neural-lhs]: https://github.com/mstksg/backprop/blob/master/samples/extensible-neural.lhs 75 | [neural-pdf]: https://github.com/mstksg/backprop/blob/master/renders/extensible-neural.pdf 76 | 77 | Brief example 78 | ------------- 79 | 80 | (This is a really brief version of [the documentation walkthrough][docs] and my 81 | [blog post][blog]) 82 | 83 | The quick example below runs a neural network with one hidden layer, 84 | parameterized by two weight matrices and two bias vectors, and calculates its 85 | squared error with respect to a target `targ`. 86 | Vector/matrix types are from the *hmatrix* package. 87 | 88 | Let's make a data type to store our parameters, with convenient accessors using 89 | *[lens][]*: 90 | 91 | [lens]: http://hackage.haskell.org/package/lens 92 | 93 | ```haskell 94 | import Numeric.LinearAlgebra.Static.Backprop 95 | 96 | data Network = Net { _weight1 :: L 20 100 97 | , _bias1 :: R 20 98 | , _weight2 :: L 5 20 99 | , _bias2 :: R 5 100 | } 101 | 102 | makeLenses ''Network 103 | ``` 104 | 105 | (`R n` is an n-length vector, `L m n` is an m-by-n matrix, etc., `#>` is 106 | matrix-vector multiplication) 107 | 108 | "Running" a network on an input vector might look like this: 109 | 110 | ```haskell 111 | runNet net x = z 112 | where 113 | y = logistic $ (net ^^. weight1) #> x + (net ^^. bias1) 114 | z = logistic $ (net ^^. weight2) #> y + (net ^^. bias2) 115 | 116 | logistic :: Floating a => a -> a 117 | logistic x = 1 / (1 + exp (-x)) 118 | ``` 119 | 120 | And that's it! `runNet` is now backpropagatable!
121 | 122 | We can "run" it using `evalBP2` (the two-argument version of `evalBP`): 123 | 124 | ```haskell 125 | evalBP2 runNet :: Network -> R 100 -> R 5 126 | ``` 127 | 128 | If we write a function to compute errors: 129 | 130 | ```haskell 131 | squaredError target output = error `dot` error 132 | where 133 | error = target - output 134 | ``` 135 | 136 | we can "test" our networks: 137 | 138 | ```haskell 139 | netError target input net = squaredError (auto target) 140 | (runNet net (auto input)) 141 | ``` 142 | 143 | This can be run, again: 144 | 145 | ```haskell 146 | evalBP (netError myTarget myVector) :: Network -> Double 147 | ``` 148 | 149 | Now, we just wrote a *normal function to compute the error of our network*. 150 | With the *backprop* library, we now have a way to *compute the gradient*, 151 | as well! 152 | 153 | ```haskell 154 | gradBP (netError myTarget myVector) :: Network -> Network 155 | ``` 156 | 157 | Now, we can perform gradient descent! 158 | 159 | ```haskell 160 | gradDescent 161 | :: R 100 162 | -> R 5 163 | -> Network 164 | -> Network 165 | gradDescent x targ n0 = n0 - 0.1 * gradient 166 | where 167 | gradient = gradBP (netError targ x) n0 168 | ``` 169 | 170 | Ta dah! We were able to compute the gradient of our error function just by 171 | saying how to compute *the error itself*. 172 | 173 | For a more fleshed out example, see [the documentation][docs], my [blog 174 | post][blog] and the [MNIST tutorial][mnist-lhs] (also [rendered as a 175 | pdf][mnist-pdf]). 176 | 177 | Benchmarks and Performance 178 | -------------------------- 179 | 180 | Here are some basic benchmarks comparing the library's automatic 181 | differentiation process to "manual" differentiation by hand. When using the 182 | [MNIST tutorial][bench] as an example: 183 | 184 | [bench]: https://github.com/mstksg/backprop/blob/master/bench/bench.hs 185 | 186 | ![benchmarks](https://i.imgur.com/rLUx4x4.png) 187 | 188 | Here we compare: 189 | 190 | 1. "Manual" differentiation of a 784 x 300 x 100 x 10 fully-connected 191 | feed-forward ANN. 192 | 2. Automatic differentiation using *backprop* and the lens-based accessor 193 | interface 194 | 3. Automatic differentiation using *backprop* and the "higher-kinded 195 | data"-based pattern matching interface 196 | 4. A hybrid approach that manually provides gradients for individual layers 197 | but uses automatic differentiation for chaining the layers together. 198 | 199 | We can see that simply *running* the network and functions (using `evalBP`) 200 | incurs virtually zero overhead. This means that library authors could actually 201 | export *only* backprop-lifted functions, and users would be able to use them 202 | without losing any performance. 203 | 204 | As for computing gradients, there is some associated overhead, from three 205 | main sources. Of these, the building of the computational graph and the 206 | Wengert Tape wind up being negligible. For more information, see [a detailed 207 | look at performance, overhead, and optimization techniques][performance] in the 208 | documentation. 209 | 210 | [performance]: https://backprop.jle.im/07-performance.html 211 | 212 | Note that the manual and hybrid modes almost overlap in the range of their 213 | random variances. 214 | 215 | Comparisons 216 | ----------- 217 | 218 | *backprop* can be compared and contrasted to many other similar libraries with 219 | some overlap: 220 | 221 | 1.
The *[ad][]* library (and variants like *[diffhask][]*) supports automatic 222 | differentiation, but only for *homogeneous*/*monomorphic* situations. All 223 | values in a computation must be of the same type --- so, your computation 224 | might be the manipulation of `Double`s through a `Double -> Double` 225 | function. 226 | 227 | *backprop* allows you to mix matrices, vectors, doubles, integers, and even 228 | key-value maps as a part of your computation, and they will all be 229 | backpropagated properly with the help of the `Backprop` typeclass. 230 | 231 | 2. The *[autograd][]* library is a very close equivalent to *backprop*, 232 | implemented in Python for Python applications. The difference between 233 | *backprop* and *autograd* is mostly the difference between Haskell and 234 | Python --- static types with type inference, purity, etc. 235 | 236 | 3. There is a link between *backprop* and deep learning/neural network 237 | libraries like *[tensorflow][]*, *[caffe][]*, and *[theano][]*, which all 238 | support some form of heterogeneous automatic differentiation. Haskell 239 | libraries doing similar things include *[grenade][]*. 240 | 241 | These are all frameworks for working with neural networks or other 242 | gradient-based optimizations --- they include things like built-in 243 | optimizers, tools for managing training data, and built-in models to use out 244 | of the box. *backprop* could be used as a *part* of such a framework, like 245 | I described in my [A Purely Functional Typed Approach to Trainable 246 | Models][models] blog series; however, the *backprop* library itself does 247 | not provide any built-in models or optimizers or automated data processing 248 | pipelines. 249 | 250 | [diffhask]: https://hackage.haskell.org/package/diffhask 251 | [tensorflow]: https://www.tensorflow.org/ 252 | [caffe]: http://caffe.berkeleyvision.org/ 253 | [theano]: http://www.deeplearning.net/software/theano/ 254 | [grenade]: http://hackage.haskell.org/package/grenade 255 | 256 | See [the documentation][comparisons] for a more detailed look. 257 | 258 | [comparisons]: https://backprop.jle.im/09-comparisons.html 259 | 260 | Todo 261 | ---- 262 | 263 | 1. Benchmark against competing back-propagation libraries like *ad*, and 264 | auto-differentiating tensor libraries like *[grenade][grenade-gh]* 265 | 266 | [grenade-gh]: https://github.com/HuwCampbell/grenade 267 | 268 | 2. Write tests! 269 | 270 | 3. Explore opportunities for parallelization. There are some naive ways of 271 | directly parallelizing right now, but potential overhead should be 272 | investigated. 273 | 274 | 4. Some open questions: 275 | 276 | a. Is it possible to support constructors with existential types? 277 | 278 | b. How to support "monadic" operations that depend on results of previous 279 | operations? (`ApBP` already exists for situations that don't) 280 | 281 | c. What needs to be done to allow us to automatically do second- and 282 | third-order differentiation, as well? This might be useful for certain 283 | ODE solvers which rely on second-order gradients and Hessians.
284 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | 3 | main = defaultMain 4 | -------------------------------------------------------------------------------- /backprop.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | name: backprop 3 | version: 0.2.7.2 4 | synopsis: Heterogeneous automatic differentiation 5 | description: 6 | Write your functions to compute your result, and the library will 7 | automatically generate functions to compute your gradient. 8 | . 9 | Implements heterogeneous reverse-mode automatic differentiation, commonly 10 | known as "backpropagation". 11 | . 12 | See <https://backprop.jle.im> for official introduction and documentation. 13 | 14 | category: Math 15 | homepage: https://backprop.jle.im 16 | bug-reports: https://github.com/mstksg/backprop/issues 17 | author: Justin Le 18 | maintainer: justin@jle.im 19 | copyright: (c) Justin Le 2018 20 | license: BSD3 21 | license-file: LICENSE 22 | build-type: Simple 23 | tested-with: GHC >=8.4 24 | extra-source-files: 25 | Build.hs 26 | CHANGELOG.md 27 | doc/01-getting-started.md 28 | doc/02-a-detailed-look.md 29 | doc/03-manipulating-bvars.md 30 | doc/04-the-backprop-typeclass.md 31 | doc/05-applications.md 32 | doc/06-manual-gradients.md 33 | doc/07-performance.md 34 | doc/08-equipping-your-library.md 35 | doc/09-comparisons.md 36 | doc/index.md 37 | README.md 38 | renders/backprop-mnist.md 39 | renders/backprop-mnist.pdf 40 | renders/extensible-neural.md 41 | renders/extensible-neural.pdf 42 | samples/backprop-mnist.lhs 43 | samples/extensible-neural.lhs 44 | 45 | source-repository head 46 | type: git 47 | location: https://github.com/mstksg/backprop 48 | 49 | flag vinyl_0_14 50 | manual: False 51 | default: True 52 | 53 | library 54 | exposed-modules: 55 | Numeric.Backprop 56 | Numeric.Backprop.Class 57 | Numeric.Backprop.Explicit 58 | Numeric.Backprop.Internal 59 | Numeric.Backprop.Num 60 | Numeric.Backprop.Op 61 | Prelude.Backprop 62 | Prelude.Backprop.Explicit 63 | Prelude.Backprop.Num 64 | 65 | other-modules: Data.Type.Util 66 | hs-source-dirs: src 67 | ghc-options: 68 | -Wall -Wcompat -Wincomplete-record-updates -Wredundant-constraints 69 | -Wunused-packages 70 | 71 | build-depends: 72 | base >=4.7 && <5 73 | , containers 74 | , deepseq 75 | , microlens 76 | , reflection 77 | , transformers 78 | , vector 79 | , vinyl >=0.9.1 80 | 81 | default-language: Haskell2010 82 | 83 | if flag(vinyl_0_14) 84 | build-depends: vinyl >=0.14.2 85 | 86 | else 87 | build-depends: vinyl <0.14 88 | 89 | benchmark backprop-mnist-bench 90 | type: exitcode-stdio-1.0 91 | main-is: bench.hs 92 | other-modules: Paths_backprop 93 | hs-source-dirs: bench 94 | ghc-options: 95 | -Wall -Wcompat -Wincomplete-record-updates -Wredundant-constraints 96 | -threaded -rtsopts -with-rtsopts=-N -O2 -Wunused-packages 97 | 98 | build-depends: 99 | backprop 100 | , base >=4.7 && <5 101 | , criterion 102 | , deepseq 103 | , directory 104 | , hmatrix >=0.18 105 | , microlens 106 | , microlens-th 107 | , mwc-random 108 | , time 109 | , vector 110 | 111 | default-language: Haskell2010 112 | -------------------------------------------------------------------------------- /bench/bench.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE DeriveGeneric
#-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE FlexibleInstances #-} 6 | {-# LANGUAGE GADTs #-} 7 | {-# LANGUAGE LambdaCase #-} 8 | {-# LANGUAGE PolyKinds #-} 9 | {-# LANGUAGE ScopedTypeVariables #-} 10 | {-# LANGUAGE StandaloneDeriving #-} 11 | {-# LANGUAGE TemplateHaskell #-} 12 | {-# LANGUAGE TypeApplications #-} 13 | {-# LANGUAGE TypeFamilies #-} 14 | {-# LANGUAGE ViewPatterns #-} 15 | {-# OPTIONS_GHC -fno-warn-orphans #-} 16 | 17 | import Control.DeepSeq 18 | import Criterion.Main 19 | import Criterion.Types 20 | import Data.Char 21 | import Data.Functor.Identity 22 | import Data.Time 23 | import qualified Data.Vector as V 24 | import GHC.Generics (Generic) 25 | import GHC.TypeLits 26 | import Lens.Micro 27 | import Lens.Micro.TH 28 | import Numeric.Backprop 29 | import Numeric.Backprop.Class 30 | import qualified Numeric.LinearAlgebra as HM 31 | import Numeric.LinearAlgebra.Static 32 | import System.Directory 33 | import qualified System.Random.MWC as MWC 34 | 35 | type family HKD f a where 36 | HKD Identity a = a 37 | HKD f a = f a 38 | 39 | data Layer' i o f 40 | = Layer 41 | { _lWeights :: !(HKD f (L o i)) 42 | , _lBiases :: !(HKD f (R o)) 43 | } 44 | deriving (Generic) 45 | 46 | type Layer i o = Layer' i o Identity 47 | 48 | deriving instance (KnownNat i, KnownNat o) => Show (Layer i o) 49 | instance NFData (Layer i o) 50 | 51 | makeLenses ''Layer' 52 | 53 | data Network' i h1 h2 o f 54 | = Net 55 | { _nLayer1 :: !(HKD f (Layer i h1)) 56 | , _nLayer2 :: !(HKD f (Layer h1 h2)) 57 | , _nLayer3 :: !(HKD f (Layer h2 o)) 58 | } 59 | deriving (Generic) 60 | 61 | type Network i h1 h2 o = Network' i h1 h2 o Identity 62 | 63 | deriving instance (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => Show (Network i h1 h2 o) 64 | instance NFData (Network i h1 h2 o) 65 | 66 | makeLenses ''Network' 67 | 68 | main :: IO () 69 | main = do 70 | g <- 71 | MWC.initialize 72 | . V.fromList 73 | . map (fromIntegral . 
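-- seeding mwc-random from the character codes of "hello world" keeps the benchmark inputs reproducible across runs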
ord) 74 | $ "hello world" 75 | test0 <- MWC.uniformR @(R 784, R 10) ((0, 0), (1, 1)) g 76 | net0 <- MWC.uniformR @(Network 784 300 100 10) (-0.5, 0.5) g 77 | t <- getZonedTime 78 | let tstr = formatTime defaultTimeLocale "%Y%m%d-%H%M%S" t 79 | createDirectoryIfMissing True "bench-results" 80 | defaultMainWith 81 | defaultConfig 82 | { reportFile = Just $ "bench-results/mnist-bench_" ++ tstr ++ ".html" 83 | , timeLimit = 10 84 | } 85 | [ bgroup 86 | "gradient" 87 | [ let runTest x y = gradNetManual x y net0 88 | in bench "manual" $ nf (uncurry runTest) test0 89 | , let runTest x y = gradBP (netErr x y) net0 90 | in bench "bp-lens" $ nf (uncurry runTest) test0 91 | , let runTest x y = gradBP (netErrHKD x y) net0 92 | in bench "bp-hkd" $ nf (uncurry runTest) test0 93 | , let runTest x y = gradBP (\n' -> netErrHybrid n' y x) net0 94 | in bench "hybrid" $ nf (uncurry runTest) test0 95 | ] 96 | , bgroup 97 | "descent" 98 | [ let runTest x y = trainStepManual 0.02 x y net0 99 | in bench "manual" $ nf (uncurry runTest) test0 100 | , let runTest x y = trainStep 0.02 x y net0 101 | in bench "bp-lens" $ nf (uncurry runTest) test0 102 | , let runTest x y = trainStepHKD 0.02 x y net0 103 | in bench "bp-hkd" $ nf (uncurry runTest) test0 104 | , let runTest x y = trainStepHybrid 0.02 x y net0 105 | in bench "hybrid" $ nf (uncurry runTest) test0 106 | ] 107 | , bgroup 108 | "run" 109 | [ let runTest = runNetManual net0 110 | in bench "manual" $ nf runTest (fst test0) 111 | , let runTest x = evalBP (`runNetwork` x) net0 112 | in bench "bp-lens" $ nf runTest (fst test0) 113 | , let runTest x = evalBP (`runNetworkHKD` x) net0 114 | in bench "bp-hkd" $ nf runTest (fst test0) 115 | , let runTest x = evalBP (`runNetHybrid` x) net0 116 | in bench "hybrid" $ nf runTest (fst test0) 117 | ] 118 | ] 119 | 120 | -- ------------------------------ 121 | -- - "Backprop" Lens Mode - 122 | -- ------------------------------ 123 | 124 | runLayer :: 125 | (KnownNat i, KnownNat o, Reifies s W) => 126 | BVar s (Layer i o) -> 127 | BVar s (R i) -> 128 | BVar s (R o) 129 | runLayer l x = (l ^^. lWeights) #>! x + (l ^^. lBiases) 130 | {-# INLINE runLayer #-} 131 | 132 | softMax :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s (R n) 133 | softMax x = konst' (1 / sumElements' expx) * expx 134 | where 135 | expx = exp x 136 | {-# INLINE softMax #-} 137 | 138 | runNetwork :: 139 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o, Reifies s W) => 140 | BVar s (Network i h1 h2 o) -> 141 | R i -> 142 | BVar s (R o) 143 | runNetwork n = 144 | softMax 145 | . runLayer (n ^^. nLayer3) 146 | . logistic 147 | . runLayer (n ^^. nLayer2) 148 | . logistic 149 | . runLayer (n ^^. nLayer1) 150 | . auto 151 | {-# INLINE runNetwork #-} 152 | 153 | crossEntropy :: 154 | (KnownNat n, Reifies s W) => 155 | R n -> 156 | BVar s (R n) -> 157 | BVar s Double 158 | crossEntropy t r = negate $ log r <.>! auto t 159 | {-# INLINE crossEntropy #-} 160 | 161 | netErr :: 162 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o, Reifies s W) => 163 | R i -> 164 | R o -> 165 | BVar s (Network i h1 h2 o) -> 166 | BVar s Double 167 | netErr x t n = crossEntropy t (runNetwork n x) 168 | {-# INLINE netErr #-} 169 | 170 | trainStep :: 171 | forall i h1 h2 o. 
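-- one gradient-descent step for the lens-based interface: move the network against the gradient of the cross-entropy error, scaled by the learning rate r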
172 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => 173 | Double -> 174 | R i -> 175 | R o -> 176 | Network i h1 h2 o -> 177 | Network i h1 h2 o 178 | trainStep r !x !t !n = n - realToFrac r * gradBP (netErr x t) n 179 | {-# INLINE trainStep #-} 180 | 181 | -- ------------------------------ 182 | -- - "Backprop" HKD Mode - 183 | -- ------------------------------ 184 | 185 | runLayerHKD :: 186 | (KnownNat i, KnownNat o, Reifies s W) => 187 | BVar s (Layer i o) -> 188 | BVar s (R i) -> 189 | BVar s (R o) 190 | runLayerHKD (splitBV -> Layer w b) x = w #>! x + b 191 | {-# INLINE runLayerHKD #-} 192 | 193 | runNetworkHKD :: 194 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o, Reifies s W) => 195 | BVar s (Network i h1 h2 o) -> 196 | R i -> 197 | BVar s (R o) 198 | runNetworkHKD (splitBV -> Net l1 l2 l3) = 199 | softMax 200 | . runLayerHKD l3 201 | . logistic 202 | . runLayerHKD l2 203 | . logistic 204 | . runLayerHKD l1 205 | . auto 206 | {-# INLINE runNetworkHKD #-} 207 | 208 | netErrHKD :: 209 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o, Reifies s W) => 210 | R i -> 211 | R o -> 212 | BVar s (Network i h1 h2 o) -> 213 | BVar s Double 214 | netErrHKD x t n = crossEntropy t (runNetworkHKD n x) 215 | {-# INLINE netErrHKD #-} 216 | 217 | trainStepHKD :: 218 | forall i h1 h2 o. 219 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => 220 | Double -> 221 | R i -> 222 | R o -> 223 | Network i h1 h2 o -> 224 | Network i h1 h2 o 225 | trainStepHKD r !x !t !n = n - realToFrac r * gradBP (netErrHKD x t) n 226 | {-# INLINE trainStepHKD #-} 227 | 228 | -- ------------------------------ 229 | -- - "Manual" Mode - 230 | -- ------------------------------ 231 | 232 | runLayerManual :: 233 | (KnownNat i, KnownNat o) => 234 | Layer i o -> 235 | R i -> 236 | R o 237 | runLayerManual l x = (l ^. lWeights) #> x + (l ^. lBiases) 238 | {-# INLINE runLayerManual #-} 239 | 240 | softMaxManual :: KnownNat n => R n -> R n 241 | softMaxManual x = konst (1 / sumElements expx) * expx 242 | where 243 | expx = exp x 244 | {-# INLINE softMaxManual #-} 245 | 246 | runNetManual :: 247 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => 248 | Network i h1 h2 o -> 249 | R i -> 250 | R o 251 | runNetManual n = 252 | softMaxManual 253 | . runLayerManual (n ^. nLayer3) 254 | . logistic 255 | . runLayerManual (n ^. nLayer2) 256 | . logistic 257 | . runLayerManual (n ^. nLayer1) 258 | {-# INLINE runNetManual #-} 259 | 260 | gradNetManual :: 261 | forall i h1 h2 o. 
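-- hand-derived reverse pass: the forward intermediates (y1..z3, o0..o2) are saved, and each dEdFoo below is the partial derivative of the error with respect to Foo, pushed backwards through the softmax, both logistic layers, and each affine step by the chain rule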
262 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => 263 | R i -> 264 | R o -> 265 | Network i h1 h2 o -> 266 | Network i h1 h2 o 267 | gradNetManual x t (Net (Layer w1 b1) (Layer w2 b2) (Layer w3 b3)) = 268 | let y1 = w1 #> x 269 | z1 = y1 + b1 270 | x2 = logistic z1 271 | y2 = w2 #> x2 272 | z2 = y2 + b2 273 | x3 = logistic z2 274 | y3 = w3 #> x3 275 | z3 = y3 + b3 276 | o0 = exp z3 277 | o1 = HM.sumElements (extract o0) 278 | o2 = o0 / konst o1 279 | -- o3 = - (log o2 <.> t) 280 | dEdO3 = 1 281 | dEdO2 = - (dEdO3 * t / o2) 282 | dEdO1 = -((dEdO2 <.> o0) / (o1 ** 2)) 283 | dEdO0 = konst dEdO1 + dEdO2 / konst o1 284 | dEdZ3 = dEdO0 * o0 285 | dEdY3 = dEdZ3 286 | dEdX3 = tr w3 #> dEdY3 287 | dEdZ2 = dEdX3 * (x3 * (1 - x3)) 288 | dEdY2 = dEdZ2 289 | dEdX2 = tr w2 #> dEdY2 290 | dEdZ1 = dEdX2 * (x2 * (1 - x2)) 291 | dEdY1 = dEdZ1 292 | dEdB3 = dEdZ3 293 | dEdW3 = dEdY3 `outer` x3 294 | dEdB2 = dEdZ2 295 | dEdW2 = dEdY2 `outer` x2 296 | dEdB1 = dEdZ1 297 | dEdW1 = dEdY1 `outer` x 298 | in Net (Layer dEdW1 dEdB1) (Layer dEdW2 dEdB2) (Layer dEdW3 dEdB3) 299 | {-# INLINE gradNetManual #-} 300 | 301 | trainStepManual :: 302 | forall i h1 h2 o. 303 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => 304 | Double -> 305 | R i -> 306 | R o -> 307 | Network i h1 h2 o -> 308 | Network i h1 h2 o 309 | trainStepManual r !x !t !n = 310 | let gN = gradNetManual x t n 311 | in n - (realToFrac r * gN) 312 | 313 | -- ------------------------------ 314 | -- - "Hybrid" Mode - 315 | -- ------------------------------ 316 | 317 | layerOp :: (KnownNat i, KnownNat o) => Op '[Layer i o, R i] (R o) 318 | layerOp = op2 $ \(Layer w b) x -> 319 | ( w #> x + b 320 | , \g -> (Layer (g `outer` x) g, tr w #> g) 321 | ) 322 | {-# INLINE layerOp #-} 323 | 324 | logisticOp :: 325 | Floating a => 326 | Op '[a] a 327 | logisticOp = op1 $ \x -> 328 | let lx = logistic x 329 | in (lx, \g -> lx * (1 - lx) * g) 330 | {-# INLINE logisticOp #-} 331 | 332 | softMaxOp :: 333 | KnownNat n => 334 | Op '[R n] (R n) 335 | softMaxOp = op1 $ \x -> 336 | let expx = exp x 337 | tot = sumElements expx 338 | invtot = 1 / tot 339 | res = konst invtot * expx 340 | in ( res 341 | , \g -> res - konst (invtot ** 2) * exp (2 * x) * g 342 | ) 343 | {-# INLINE softMaxOp #-} 344 | 345 | softMaxCrossEntropyOp :: 346 | KnownNat n => 347 | R n -> 348 | Op '[R n] Double 349 | softMaxCrossEntropyOp targ = op1 $ \x -> 350 | let expx = exp x 351 | sm = konst (1 / sumElements expx) * expx 352 | ce = negate $ log sm <.> targ 353 | in ( ce 354 | , \g -> (sm - targ) * konst g 355 | ) 356 | {-# INLINE softMaxCrossEntropyOp #-} 357 | 358 | runNetHybrid :: 359 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o, Reifies s W) => 360 | BVar s (Network i h1 h2 o) -> 361 | R i -> 362 | BVar s (R o) 363 | runNetHybrid n = 364 | liftOp1 softMaxOp 365 | . liftOp2 layerOp (n ^^. nLayer3) 366 | . liftOp1 logisticOp 367 | . liftOp2 layerOp (n ^^. nLayer2) 368 | . liftOp1 logisticOp 369 | . liftOp2 layerOp (n ^^. nLayer1) 370 | . auto 371 | {-# INLINE runNetHybrid #-} 372 | 373 | netErrHybrid :: 374 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o, Reifies s W) => 375 | BVar s (Network i h1 h2 o) -> 376 | R o -> 377 | R i -> 378 | BVar s Double 379 | netErrHybrid n t = 380 | liftOp1 (softMaxCrossEntropyOp t) 381 | . liftOp2 layerOp (n ^^. nLayer3) 382 | . liftOp1 logisticOp 383 | . liftOp2 layerOp (n ^^. nLayer2) 384 | . liftOp1 logisticOp 385 | . liftOp2 layerOp (n ^^. nLayer1) 386 | . 
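-- auto lifts the constant input vector into a BVar; constants receive no gradient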
auto 387 | {-# INLINE netErrHybrid #-} 388 | 389 | trainStepHybrid :: 390 | forall i h1 h2 o. 391 | (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => 392 | Double -> 393 | R i -> 394 | R o -> 395 | Network i h1 h2 o -> 396 | Network i h1 h2 o 397 | trainStepHybrid r !x !t !n = 398 | let gN = gradBP (\n' -> netErrHybrid n' t x) n 399 | in n - (realToFrac r * gN) 400 | {-# INLINE trainStepHybrid #-} 401 | 402 | -- ------------------------------ 403 | -- - Operations - 404 | -- ------------------------------ 405 | 406 | infixr 8 #>! 407 | (#>!) :: 408 | (KnownNat m, KnownNat n, Reifies s W) => 409 | BVar s (L m n) -> 410 | BVar s (R n) -> 411 | BVar s (R m) 412 | (#>!) = liftOp2 . op2 $ \m v -> 413 | (m #> v, \g -> (g `outer` v, tr m #> g)) 414 | {-# INLINE (#>!) #-} 415 | 416 | infixr 8 <.>! 417 | (<.>!) :: 418 | (KnownNat n, Reifies s W) => 419 | BVar s (R n) -> 420 | BVar s (R n) -> 421 | BVar s Double 422 | (<.>!) = liftOp2 . op2 $ \x y -> 423 | ( x <.> y 424 | , \g -> (konst g * y, x * konst g) 425 | ) 426 | {-# INLINE (<.>!) #-} 427 | 428 | konst' :: 429 | (KnownNat n, Reifies s W) => 430 | BVar s Double -> 431 | BVar s (R n) 432 | konst' = liftOp1 . op1 $ \c -> (konst c, HM.sumElements . extract) 433 | {-# INLINE konst' #-} 434 | 435 | sumElements :: KnownNat n => R n -> Double 436 | sumElements = HM.sumElements . extract 437 | {-# INLINE sumElements #-} 438 | 439 | sumElements' :: 440 | (KnownNat n, Reifies s W) => 441 | BVar s (R n) -> 442 | BVar s Double 443 | sumElements' = liftOp1 . op1 $ \x -> (sumElements x, konst) 444 | {-# INLINE sumElements' #-} 445 | 446 | logistic :: Floating a => a -> a 447 | logistic x = 1 / (1 + exp (-x)) 448 | {-# INLINE logistic #-} 449 | 450 | -- ------------------------------ 451 | -- - Instances - 452 | -- ------------------------------ 453 | 454 | instance (KnownNat i, KnownNat o) => Num (Layer i o) where 455 | Layer w1 b1 + Layer w2 b2 = Layer (w1 + w2) (b1 + b2) 456 | Layer w1 b1 - Layer w2 b2 = Layer (w1 - w2) (b1 - b2) 457 | Layer w1 b1 * Layer w2 b2 = Layer (w1 * w2) (b1 * b2) 458 | abs (Layer w b) = Layer (abs w) (abs b) 459 | signum (Layer w b) = Layer (signum w) (signum b) 460 | negate (Layer w b) = Layer (negate w) (negate b) 461 | fromInteger x = Layer (fromInteger x) (fromInteger x) 462 | 463 | instance (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => Num (Network i h1 h2 o) where 464 | Net a b c + Net d e f = Net (a + d) (b + e) (c + f) 465 | Net a b c - Net d e f = Net (a - d) (b - e) (c - f) 466 | Net a b c * Net d e f = Net (a * d) (b * e) (c * f) 467 | abs (Net a b c) = Net (abs a) (abs b) (abs c) 468 | signum (Net a b c) = Net (signum a) (signum b) (signum c) 469 | negate (Net a b c) = Net (negate a) (negate b) (negate c) 470 | fromInteger x = Net (fromInteger x) (fromInteger x) (fromInteger x) 471 | 472 | instance (KnownNat i, KnownNat o) => Fractional (Layer i o) where 473 | Layer w1 b1 / Layer w2 b2 = Layer (w1 / w2) (b1 / b2) 474 | recip (Layer w b) = Layer (recip w) (recip b) 475 | fromRational x = Layer (fromRational x) (fromRational x) 476 | 477 | instance (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => Fractional (Network i h1 h2 o) where 478 | Net a b c / Net d e f = Net (a / d) (b / e) (c / f) 479 | recip (Net a b c) = Net (recip a) (recip b) (recip c) 480 | fromRational x = Net (fromRational x) (fromRational x) (fromRational x) 481 | 482 | instance KnownNat n => MWC.Variate (R n) where 483 | uniform g = randomVector <$> MWC.uniform g <*> pure Uniform 484 | uniformR (l, h) g = (\x -> x * (h - l) + l) <$> 
MWC.uniform g 485 | 486 | instance (KnownNat m, KnownNat n) => MWC.Variate (L m n) where 487 | uniform g = uniformSample <$> MWC.uniform g <*> pure 0 <*> pure 1 488 | uniformR (l, h) g = (\x -> x * (h - l) + l) <$> MWC.uniform g 489 | 490 | instance (KnownNat i, KnownNat o) => MWC.Variate (Layer i o) where 491 | uniform g = Layer <$> MWC.uniform g <*> MWC.uniform g 492 | uniformR (l, h) g = (\x -> x * (h - l) + l) <$> MWC.uniform g 493 | 494 | instance (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => MWC.Variate (Network i h1 h2 o) where 495 | uniform g = Net <$> MWC.uniform g <*> MWC.uniform g <*> MWC.uniform g 496 | uniformR (l, h) g = (\x -> x * (h - l) + l) <$> MWC.uniform g 497 | 498 | instance Backprop (R n) where 499 | zero = zeroNum 500 | add = addNum 501 | one = oneNum 502 | 503 | instance (KnownNat n, KnownNat m) => Backprop (L m n) where 504 | zero = zeroNum 505 | add = addNum 506 | one = oneNum 507 | 508 | instance (KnownNat i, KnownNat o) => Backprop (Layer i o) 509 | instance (KnownNat i, KnownNat h1, KnownNat h2, KnownNat o) => Backprop (Network i h1 h2 o) 510 | -------------------------------------------------------------------------------- /doc/01-getting-started.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Getting Started 3 | --- 4 | 5 | Getting Started 6 | =============== 7 | 8 | ```haskell top hide 9 | {-# LANGUAGE DataKinds #-} 10 | {-# LANGUAGE DeriveGeneric #-} 11 | {-# LANGUAGE FlexibleContexts #-} 12 | {-# LANGUAGE TemplateHaskell #-} 13 | {-# LANGUAGE ViewPatterns #-} 14 | 15 | 16 | import Data.List 17 | import Debug.SimpleReflect 18 | import GHC.Generics (Generic) 19 | import GHC.TypeNats 20 | import Inliterate.Import 21 | import Lens.Micro 22 | import Lens.Micro.TH 23 | import Numeric.Backprop.Class 24 | import Numeric.LinearAlgebra.Static (L, R) 25 | import System.Random 26 | import qualified Numeric.LinearAlgebra.Static as H 27 | ``` 28 | 29 | *backprop* is a Haskell library **[available on hackage][haddock]**, so it can be 30 | added to your package however you normally depend on libraries. Be sure to add 31 | it to your cabal file's (or package.yaml's) `build-depends` field. 32 | 33 | Automatic Backpropagated Functions 34 | ---------------------------------- 35 | 36 | With *backprop*, you can write your functions in Haskell as normal functions: 37 | 38 | ```haskell top 39 | import Numeric.Backprop 40 | 41 | myFunc x = sqrt (x * 4) 42 | ``` 43 | 44 | They can be run with `evalBP`: 45 | 46 | ```haskell eval 47 | evalBP myFunc 9 :: Double 48 | ``` 49 | 50 | And...the twist? You can also get the gradient of your functions! 51 | 52 | ```haskell eval 53 | gradBP myFunc 9 :: Double 54 | ``` 55 | 56 | We can even be cute with the *[simple-reflect][]* library: 57 | 58 | [simple-reflect]: https://hackage.haskell.org/package/simple-reflect 59 | 60 | ```haskell top hide 61 | instance Backprop Expr where 62 | zero = zeroNum 63 | add = addNum 64 | one = oneNum 65 | instance AskInliterate Expr 66 | ``` 67 | 68 | ```haskell eval 69 | evalBP myFunc x :: Expr 70 | ``` 71 | 72 | 73 | ```haskell eval 74 | gradBP myFunc x :: Expr 75 | ``` 76 | 77 | 78 | And that's the gist of the entire library: write your functions to compute your 79 | things, and `gradBP` will give you the gradients and derivatives of those 80 | functions.
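To make that gist concrete, here's a minimal sketch of gradient descent built on `gradBP` (the name `minimize` and the fixed learning rate are just illustrative choices here, and the signature needs `RankNTypes`):

```haskell
-- repeatedly step against the gradient computed by gradBP,
-- producing a stream of successively refined guesses
minimize
    :: (forall s. Reifies s W => BVar s Double -> BVar s Double)
    -> Double   -- ^ learning rate
    -> Double   -- ^ initial guess
    -> [Double]
minimize f rate = iterate (\x -> x - rate * gradBP f x)
```

Each element of the resulting list is one step further downhill on `f`.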
81 | 82 | ### Multiple Same-Type Inputs 83 | 84 | Multiple inputs of the same type can be handled with `sequenceVar`: 85 | 86 | ```haskell top 87 | funcOnList (sequenceVar->[x,y,z]) = sqrt (x / y) * z 88 | ``` 89 | 90 | ```haskell eval 91 | evalBP funcOnList [3,5,-2] :: Double 92 | ``` 93 | 94 | ```haskell eval 95 | gradBP funcOnList [3,5,-2] :: [Double] 96 | ``` 97 | 98 | Heterogeneous Backprop 99 | ---------------------- 100 | 101 | But the real magic happens when you mix and match types. Let's make a simple 102 | type representing a feed-forward fully connected artificial neural network with 103 | 100 inputs, a single hidden layer of 20 nodes, and 5 outputs: 104 | 105 | ```haskell top 106 | data Net = N { _nWeights1 :: L 20 100 107 | , _nBias1 :: R 20 108 | , _nWeights2 :: L 5 20 109 | , _nBias2 :: R 5 110 | } 111 | deriving (Show, Generic) 112 | 113 | instance Backprop Net 114 | 115 | -- requires -XTemplateHaskell 116 | makeLenses ''Net 117 | ``` 118 | 119 | using the `L m n` type from the *[hmatrix][]* library to represent an m-by-n 120 | matrix, and the `R n` type to represent an n-vector. 121 | 122 | [hmatrix]: http://hackage.haskell.org/package/hmatrix 123 | 124 | We can write a function to "run" the network on a `R 100` and get an `R 5` 125 | back, using `^^.` for lens access and `#>` from the *[hmatrix-backprop][]* library for 126 | matrix-vector multiplication: 127 | 128 | [hmatrix-backprop]: http://hackage.haskell.org/package/hmatrix-backprop 129 | 130 | ```haskell top hide 131 | instance Backprop (R n) where 132 | zero = zeroNum 133 | add = addNum 134 | one = oneNum 135 | 136 | instance (KnownNat n, KnownNat m) => Backprop (L n m) where 137 | zero = zeroNum 138 | add = addNum 139 | one = oneNum 140 | 141 | (#>) 142 | :: (KnownNat n, KnownNat m, Reifies s W) 143 | => BVar s (L n m) -> BVar s (R m) -> BVar s (R n) 144 | (#>) = liftOp2 . op2 $ \xs y -> 145 | ( xs H.#> y 146 | , \d -> (d `H.outer` y, H.tr xs H.#> d) 147 | ) 148 | 149 | dot :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s (R n) -> BVar s Double 150 | dot = liftOp2 . op2 $ \x y -> 151 | ( x `H.dot` y 152 | , \d -> let d' = H.konst d 153 | in (d' * y, x * d') 154 | ) 155 | ``` 156 | 157 | ```haskell top 158 | runNet net x = z 159 | where 160 | -- run first layer 161 | y = logistic $ (net ^^. nWeights1) #> x + (net ^^. nBias1) 162 | -- run second layer 163 | z = logistic $ (net ^^. nWeights2) #> y + (net ^^. nBias2) 164 | 165 | logistic :: Floating a => a -> a 166 | logistic x = 1 / (1 + exp (-x)) 167 | ``` 168 | 169 | We can *run* this with a network and input vector: 170 | 171 | ```haskell top hide 172 | myVector :: R 100 173 | myVector = H.randomVector 93752345 H.Uniform - 0.5 174 | 175 | myTarget :: R 5 176 | myTarget = H.randomVector 93752345 H.Uniform - 0.5 177 | 178 | myNet :: Net 179 | myNet = N (H.uniformSample 2394834 (-0.5) 0.5) 180 | (H.randomVector 84783451 H.Uniform - 0.5) 181 | (H.uniformSample 9293092 (-0.5) 0.5) 182 | (H.randomVector 64814524 H.Uniform - 0.5) 183 | 184 | instance KnownNat n => AskInliterate (R n) where 185 | askInliterate = answerWith (show . H.extract) 186 | instance AskInliterate Net where 187 | askInliterate = answerWith $ intercalate "\n" 188 | . ((++ ["-- ..."]) . map lim) 189 | . take 5 190 | . lines 191 | . show 192 | where 193 | lim = (++ " -- ...") . 
take 100 194 | ``` 195 | 196 | ```haskell eval 197 | evalBP2 runNet myNet myVector 198 | ``` 199 | 200 | But --- and here's the fun part --- if we write a "loss function" to evaluate 201 | "how badly" our network has done, using `dot` from the *hmatrix-backprop* 202 | library: 203 | 204 | ```haskell top 205 | squaredError target output = error `dot` error 206 | where 207 | error = target - output 208 | ``` 209 | 210 | we can "test" our networks: 211 | 212 | ```haskell top 213 | netError target input net = squaredError (auto target) 214 | (runNet net (auto input)) 215 | ``` 216 | 217 | (more on `auto` later) 218 | 219 | ```haskell eval 220 | evalBP (netError myTarget myVector) myNet 221 | ``` 222 | 223 | At this point, we've *written a normal function to compute the error of our 224 | network*. And, with the backprop library...we now have a way to compute the 225 | *gradient* of our network's error with respect to all of our weights! 226 | 227 | ```haskell eval 228 | gradBP (netError myTarget myVector) myNet 229 | ``` 230 | 231 | We can now use the gradient to "[train][]" our network to give the correct 232 | responses given a certain input! This can be done by computing the gradient 233 | for every expected input-output pair, and adjusting the network in the opposite 234 | direction of the gradient every time. 235 | 236 | [train]: https://blog.jle.im/entry/purely-functional-typed-models-1.html 237 | 238 | Main Idea 239 | --------- 240 | 241 | The main pattern of usage for this library is: 242 | 243 | 1. Write your function normally to compute something (like the loss function) 244 | 2. Use `gradBP` to automatically get the gradient of that something with 245 | respect to your inputs! 246 | 247 | In the case of optimizing models, you: 248 | 249 | 1. Write your function normally to compute the thing you want to minimize 250 | 2. Use `gradBP` to automatically get the gradient of the thing you want to 251 | minimize with respect to your inputs. Then, adjust your inputs according 252 | to this gradient until you get the perfect minimal result! 253 | 254 | Now that you've had a taste, let's **[look at the details][details]**. You can 255 | also just go ahead and **[jump into the haddock documentation][haddock]**! 256 | 257 | [details]: https://backprop.jle.im/02-a-detailed-look.html 258 | [haddock]: https://hackage.haskell.org/package/backprop 259 | -------------------------------------------------------------------------------- /doc/02-a-detailed-look.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: A Detailed Look 3 | --- 4 | 5 | A Detailed Look 6 | =============== 7 | 8 | ```haskell top hide 9 | {-# LANGUAGE FlexibleContexts #-} 10 | 11 | import Numeric.Backprop 12 | ``` 13 | 14 | So, what's really going on? 15 | 16 | The BVar 17 | -------- 18 | 19 | The entire library revolves around the `BVar`, a variable holding a 20 | "backpropagatable value". As you use a `BVar`, the *backprop* library will 21 | track how it is used and where you use it. 
You can use `evalBP` to simply get 22 | the result, but using `gradBP` will perform backpropagation ("reverse-mode 23 | [automatic differentiation][autodiff]") 24 | 25 | [autodiff]: https://en.wikipedia.org/wiki/Automatic_differentiation 26 | 27 | For example, we looked earlier at a function that computes the square root of a 28 | quadrupled number: 29 | 30 | ```haskell top 31 | myFunc :: Double 32 | -> Double 33 | myFunc x = sqrt (x * 4) 34 | ``` 35 | 36 | As we are using it, its type is "really": 37 | 38 | ```haskell top 39 | myFunc' :: Reifies s W 40 | => BVar s Double 41 | -> BVar s Double 42 | myFunc' x = sqrt (x * 4) 43 | ``` 44 | 45 | `myFunc'` takes a `BVar s Double` (a `BVar` containing a `Double`) and returns 46 | a new one that is the square root of the quadrupled number. You can think of 47 | the `Reifies s W` as being a necessary constraint that allows backpropagation 48 | to happen. 49 | 50 | `BVar`s have `Num`, `Fractional`, and `Floating` instances, and so can be used 51 | with addition, multiplication, square rooting, etc. The "most general" type of 52 | `myFunc` is `myFunc :: Floating a => a -> a`, and since `BVar s Double` has a 53 | `Floating` instance, you could even just use it directly as a backpropagatable 54 | function. 55 | 56 | This means you can basically treat a `BVar s Double` almost exactly like it was 57 | a `Double` --- you'll practically never tell the difference! `BVar`s also have 58 | `Ord` and `Eq` instances, so you can compare them and branch on the results, 59 | too. 60 | 61 | ```haskell top 62 | myAbs :: Reifies s W 63 | => BVar s Double 64 | -> BVar s Double 65 | myAbs x | x < 0 = negate x 66 | | otherwise = x 67 | ``` 68 | 69 | The goal of the `BVar` interface is that you should be able to treat a `BVar s 70 | a` (a `BVar` containing an `a`) as if it was an `a`, with no easily noticeable 71 | differences. 72 | 73 | Runners 74 | ------- 75 | 76 | The entire point of the library is to write your computation as a normal 77 | function taking a `BVar` (or many) and returning a single `BVar`. Just treat 78 | `BVar`s as if they actually were the value they are containing, and you can't 79 | go wrong. 80 | 81 | Once you do this, you can use `evalBP` to "run" the function itself: 82 | 83 | ```haskell 84 | evalBP :: (forall s. Reifies s W => BVar s a -> BVar s b) 85 | -> (a -> b) 86 | ``` 87 | 88 | This can be read as taking a `BVar s a -> BVar s b` and returning the `a -> b` 89 | that that function encodes. The RankN type there (the `forall s.`) is mostly 90 | there to prevent leakage of `BVar`s (same as it is used in *Control.Monad.ST* 91 | and `runST`). It ensures that no `BVar`s "escape" the function somehow. 92 | 93 | `evalBP` is extremely efficient, and usually carries virtually zero overhead 94 | over writing your function directly on your values without `BVar`s. 95 | 96 | *But*, the more interesting thing of course is computing the *gradient* of your 97 | function. This is done with `gradBP`: 98 | 99 | ```haskell 100 | gradBP :: (Backprop a, Backprop b) 101 | => (forall s. Reifies s W => BVar s a -> BVar s b) 102 | -> a 103 | -> a 104 | ``` 105 | 106 | Which takes a `BVar s a -> BVar s b` backpropagatable function and an input, 107 | and returns *the gradient at that input*. It gives the direction of greatest 108 | positive change (in the output) of your input, and also how much a variation in 109 | your input will affect your output. 110 | 111 | And that's all there is to it! 
Instead of `a -> b`'s, write `BVar s a -> BVar 112 | s b`'s to compute what you want to know the gradient of. These are normal 113 | functions, so you can use all of your favorite higher order functions and 114 | combinators (like `(.)`, `map`, etc.). And once you're done, use `gradBP` to 115 | compute that gradient. 116 | 117 | Note that `gradBP` requires a `Backprop` constraint on the input and output of 118 | your function. `Backprop` is essentially the typeclass of values that can be 119 | "backpropagated". For product types, this instance is automatically derivable. 120 | But writing your own custom instances for your own types is also fairly 121 | straightforward. More on this later! 122 | 123 | The rest of the package really is just ways to manipulate `BVar s a`s as if 124 | they were just `a`s, to make everything as smooth as possible. Let's move on 125 | to learning about **[ways to manipulate BVars][bvars]**! 126 | 127 | [bvars]: https://backprop.jle.im/03-manipulating-bvars.html 128 | -------------------------------------------------------------------------------- /doc/03-manipulating-bvars.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Manipulating BVars 3 | --- 4 | 5 | Manipulating BVars 6 | ================== 7 | 8 | ```haskell top hide 9 | {-# LANGUAGE DataKinds #-} 10 | {-# LANGUAGE DeriveGeneric #-} 11 | {-# LANGUAGE FlexibleContexts #-} 12 | {-# LANGUAGE FlexibleInstances #-} 13 | {-# LANGUAGE StandaloneDeriving #-} 14 | {-# LANGUAGE TemplateHaskell #-} 15 | {-# LANGUAGE TypeFamilies #-} 16 | {-# LANGUAGE ViewPatterns #-} 17 | 18 | 19 | import Data.Functor.Identity 20 | import GHC.Generics (Generic) 21 | import GHC.TypeNats 22 | import Lens.Micro 23 | import Lens.Micro.TH 24 | import Numeric.Backprop 25 | import Numeric.Backprop.Class 26 | import Numeric.LinearAlgebra.Static (L, R) 27 | import System.Random 28 | import qualified Numeric.LinearAlgebra.Static as H 29 | ``` 30 | 31 | The most important aspect of the usability of this library is allowing you to 32 | seamlessly manipulate `BVar s a`s as if they were just `a`s, without requiring 33 | you as the user to be able to recognize or acknowledge the difference. Here 34 | are some techniques to that end. 35 | 36 | Remember, a `BVar s a` is a `BVar` containing an `a` --- it's an `a` that, when 37 | used, keeps track of and propagates your gradient. 38 | 39 | Typeclass Interface 40 | ------------------- 41 | 42 | `BVar`s have `Num`, `Fractional`, `Floating`, `Eq`, and `Ord` instances. These 43 | instances are basically "lifted" to the `BVar` itself, so if you have a `BVar s 44 | Double`, you can use `(*)`, `sqrt`, `(>)`, etc. on it exactly as if it were 45 | just a `Double`. 46 | 47 | Constant Values 48 | --------------- 49 | 50 | If we don't *care* about a value's gradient, we can use `auto`: 51 | 52 | ```haskell 53 | auto :: a -> BVar s a 54 | ``` 55 | 56 | `auto x` basically gives you a `BVar` that contains just `x` alone. Useful for 57 | using with functions that expect `BVar`s, but you just have a specific value 58 | you want to use. 59 | 60 | Coercible 61 | --------- 62 | 63 | If `a` and `b` are `Coercible`, then so are `BVar s a` and `BVar s b`, using 64 | the `coerceVar` function. 
This is useful for "unwrapping" and "wrapping" 65 | `BVar`s of newtypes: 66 | 67 | 68 | ```haskell top 69 | newtype MyInt = MyInt Int 70 | 71 | getMyInt :: BVar s MyInt -> BVar s Int 72 | getMyInt = coerceVar 73 | ``` 74 | 75 | Accessing Contents 76 | ------------------ 77 | 78 | The following techniques can be used to access values inside `BVar`s: 79 | 80 | ### Traversable Containers 81 | 82 | One that we saw earlier was `sequenceVar`, which we used to turn a `BVar` 83 | containing a list into a list of `BVar`s: 84 | 85 | ```haskell 86 | sequenceVar :: (Backprop a, Reifies s W) 87 | => BVar s [a] 88 | -> [BVar s a] 89 | ``` 90 | 91 | If you have a `BVar` containing a list, you can get a list of `BVar`s of all of 92 | that list's elements. (`sequenceVar` actually works on all `Traversable` 93 | instances, not just lists.) This is very useful when combined with 94 | `-XViewPatterns`, as seen earlier. 95 | 96 | ### Records and Fields 97 | 98 | In practice, a lot of usage involves functions over the contents of records or 99 | data types with fields. The previous example, involving a simple ANN, 100 | demonstrates this: 101 | 102 | ```haskell top 103 | data Net = N { _nWeights1 :: L 20 100 104 | , _nBias1 :: R 20 105 | , _nWeights2 :: L 5 20 106 | , _nBias2 :: R 5 107 | } 108 | deriving (Show, Generic) 109 | 110 | instance Backprop Net -- can be automatically defined 111 | ``` 112 | 113 | To compute the result of this network (run on an `R 100`, a 100-vector) and get 114 | the output `R 5`, we need to do a matrix multiplication by the `_nWeights1` field, 115 | add the result to the `_nBias1` field...basically, the result is a function of 116 | linear algebra and related operations on the input and all of the contents of 117 | the `Net` data type. However, you can't directly use `_nWeights1`, since it 118 | takes a `Net`, not a `BVar s Net`. And you also can't directly pattern match on 119 | the `N` constructor. 120 | 121 | There are two main options for this: the lens interface, and the higher-kinded 122 | data interface. 123 | 124 | #### Lens Interface 125 | 126 | The most straightforward way to do this is the lens-based interface, using 127 | `viewVar` or `^^.`. 128 | 129 | If we make lenses for `Net` using the *[lens][]* or *[microlens-th][]* packages: 130 | 131 | [lens]: http://hackage.haskell.org/package/lens 132 | [microlens-th]: http://hackage.haskell.org/package/microlens-th 133 | 134 | ```haskell top 135 | -- requires -XTemplateHaskell 136 | makeLenses ''Net 137 | ``` 138 | 139 | or make them manually: 140 | 141 | ```haskell top 142 | nBias1' :: Functor f => (R 20 -> f (R 20)) -> Net -> f Net 143 | nBias1' f n = (\b -> n { _nBias1 = b }) <$> f (_nBias1 n) 144 | ``` 145 | 146 | then `^.` from the *lens* or *[microlens][]* packages lets you retrieve a field 147 | from a `Net`: 148 | 149 | [microlens]: http://hackage.haskell.org/package/microlens 150 | 151 | ```haskell 152 | (^. nWeights1) :: Net -> L 20 100 153 | (^. nBias1 ) :: Net -> R 20 154 | (^. nWeights2) :: Net -> L 5 20 155 | (^. nBias2 ) :: Net -> R 5 156 | ``` 157 | 158 | And, `^^.` from *backprop* (also aliased as `viewVar`) lets you do the same 159 | thing from a `BVar s Net` (a `BVar` containing your `Net`): 160 | 161 | ```haskell 162 | (^^. nWeights1) :: BVar s Net -> BVar s (L 20 100) 163 | (^^. nBias1 ) :: BVar s Net -> BVar s (R 20) 164 | (^^. nWeights2) :: BVar s Net -> BVar s (L 5 20) 165 | (^^. 
nBias2 ) :: BVar s Net -> BVar s (R 5) 166 | ``` 167 | 168 | ```haskell top hide 169 | instance Backprop (R n) where 170 | zero = zeroNum 171 | add = addNum 172 | one = oneNum 173 | 174 | instance (KnownNat n, KnownNat m) => Backprop (L n m) where 175 | zero = zeroNum 176 | add = addNum 177 | one = oneNum 178 | 179 | (#>) 180 | :: (KnownNat n, KnownNat m, Reifies s W) 181 | => BVar s (L n m) -> BVar s (R m) -> BVar s (R n) 182 | (#>) = liftOp2 . op2 $ \xs y -> 183 | ( xs H.#> y 184 | , \d -> (d `H.outer` y, H.tr xs H.#> d) 185 | ) 186 | 187 | dot :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s (R n) -> BVar s Double 188 | dot = liftOp2 . op2 $ \x y -> 189 | ( x `H.dot` y 190 | , \d -> let d' = H.konst d 191 | in (d' * y, x * d') 192 | ) 193 | ``` 194 | 195 | With our lenses and `^^.`, we can write our network running function. This 196 | time, I'll include the type! 197 | 198 | ```haskell top 199 | runNet :: Reifies s W 200 | => BVar s Net 201 | -> BVar s (R 100) 202 | -> BVar s (R 5) 203 | runNet net x = z 204 | where 205 | -- run first layer 206 | y = logistic $ (net ^^. nWeights1) #> x + (net ^^. nBias1) 207 | -- run second layer 208 | z = logistic $ (net ^^. nWeights2) #> y + (net ^^. nBias2) 209 | 210 | logistic :: Floating a => a -> a 211 | logistic x = 1 / (1 + exp (-x)) 212 | ``` 213 | 214 | Note that we are using versions of `#>` lifted for `BVar`s, from the 215 | *[hmatrix-backprop][]* library: 216 | 217 | ```haskell 218 | (#>) :: BVar s (L m n) -> BVar s (R n) -> BVar s (R m) 219 | ``` 220 | 221 | [hmatrix-backprop]: http://hackage.haskell.org/package/hmatrix-backprop 222 | 223 | #### Higher-Kinded Data Interface 224 | 225 | Using the lens-based interface, you can't directly pattern match and construct 226 | fields. To allow for direct pattern matching, there's another interface 227 | option involving the "Higher-Kinded Data" techniques described in [this 228 | article][hkd]. 229 | 230 | [hkd]: http://reasonablypolymorphic.com/blog/higher-kinded-data/ 231 | 232 | If we had a type family (that can be re-used for all of your data types): 233 | 234 | ```haskell top 235 | type family HKD f a where 236 | HKD Identity a = a 237 | HKD f a = f a 238 | ``` 239 | 240 | We can define a `Net` analogue (called `Met'` here, so it can coexist with the `Net` above) as: 241 | 242 | ```haskell top 243 | data Met' f = M { _mWeights1 :: HKD f (L 20 100) 244 | , _mBias1 :: HKD f (R 20) 245 | , _mWeights2 :: HKD f (L 5 20) 246 | , _mBias2 :: HKD f (R 5) 247 | } 248 | deriving Generic 249 | ``` 250 | 251 | Then our *original* type is: 252 | 253 | ```haskell top 254 | type Met = Met' Identity 255 | 256 | deriving instance Show Met 257 | instance Backprop Met 258 | ``` 259 | 260 | `Met` is the same as `Net` in every way -- it can be pattern matched on to get 261 | the `L 20 100`, etc. 
(the `Identity` disappears): 262 | 263 | ```haskell top 264 | getMetBias1 :: Met -> R 20 265 | getMetBias1 (M _ b _ _) = b 266 | ``` 267 | 268 | The benefit of this is that we can now directly pattern match on a `BVar s Met` 269 | to get the internal fields as `BVar`s using `splitBV` as a view pattern (or the 270 | `BV` pattern synonym): 271 | 272 | ```haskell top 273 | runMet :: Reifies s W 274 | => BVar s Met 275 | -> BVar s (R 100) 276 | -> BVar s (R 5) 277 | runMet (splitBV -> M w1 b1 w2 b2) x = z 278 | where 279 | -- run first layer 280 | y = logistic $ w1 #> x + b1 281 | -- run second layer 282 | z = logistic $ w2 #> y + b2 283 | 284 | runMetPS :: Reifies s W 285 | => BVar s Met 286 | -> BVar s (R 100) 287 | -> BVar s (R 5) 288 | runMetPS (BV (M w1 b1 w2 b2)) x = z 289 | where 290 | -- run first layer 291 | y = logistic $ w1 #> x + b1 292 | -- run second layer 293 | z = logistic $ w2 #> y + b2 294 | ``` 295 | 296 | Now, the `M w1 b1 w2 b2` pattern can be used to deconstruct *both* "normal" 297 | `Met`s, as well as a `BVar s Met` (with `splitBV` or `BV`). 298 | 299 | Note that this HKD access method is potentially less performant than lens 300 | access (by about 10-20%). 301 | 302 | ### Potential or Many Fields 303 | 304 | Some values "may" or "may not" contain a value for a given field. Examples 305 | include the nth item in a list or vector, or the `Just` of a `Maybe`. 306 | 307 | For these, the lens-based (prism-based/traversal-based) interface is the main way to access 308 | partial fields. You can use `(^^?)` or `previewVar` with any `Traversal`: 309 | 310 | ```haskell 311 | (^?) :: a -> Traversal' a b -> Maybe b 312 | (^^?) :: BVar s a -> Traversal' a b -> Maybe (BVar s b) 313 | ``` 314 | 315 | If the value in the `BVar` "has" that field, then you'll get a `Just` with the 316 | `BVar` of that field's contents. If it doesn't, you'll get a `Nothing`. 317 | 318 | You can use this with any prism or traversal, like using `_head` to get the 319 | first item in a list if it exists. 320 | 321 | If you have a type that might contain *many* values of a field (like a tree or 322 | list), you can use `(^^..)` or `toListOfVar`, which works on any `Traversal`: 323 | 324 | ```haskell 325 | (^..) :: a -> Traversal' a b -> [ b] 326 | (^^..) :: BVar s a -> Traversal' a b -> [BVar s b] 327 | ``` 328 | 329 | This can be used to implement `sequenceVar`, actually: 330 | 331 | ```haskell 332 | sequenceVar :: BVar s [a] -> [BVar s a] 333 | sequenceVar xs = xs ^^.. traverse 334 | ``` 335 | 336 | ### Tuples 337 | 338 | The `T2` pattern synonym is provided, which allows you to pattern match on a 339 | `BVar s (a, b)` to get a `BVar s a` and `BVar s b`. The `T3` pattern is also 340 | provided, which does the same thing for three-tuples. 341 | 342 | Note that `T2` and `T3` are *bidirectional* pattern synonyms, and can be used to 343 | construct as well as deconstruct. 344 | 345 | Combining BVars 346 | --------------- 347 | 348 | The following techniques can be used to "combine" `BVar`s: 349 | 350 | ### Foldable Containers 351 | 352 | The "opposite" of `sequenceVar` is `collectVar`, which takes a foldable 353 | container of `BVar`s and returns a `BVar` containing that foldable container of 354 | contents: 355 | 356 | ```haskell 357 | collectVar :: (Backprop a, Foldable t, Functor t, Reifies s W) 358 | => t (BVar s a) 359 | -> BVar s (t a) 360 | ``` 361 | 362 | ### Constructors 363 | 364 | Sometimes you would like to combine a bunch of `BVar`s into a `BVar` of a 365 | specific container or data type. 
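(As a hedged aside of our own: for plain container types, `sequenceVar` and `collectVar` from above already round-trip, so you can map a lifted function over every element and recombine:)

```haskell
-- a sketch, not from the original docs: scale every element of a
-- backpropagatable list by unpacking it to [BVar s a] and repacking it
doubleAll
    :: (Backprop a, Num a, Reifies s W)
    => BVar s [a]
    -> BVar s [a]
doubleAll = collectVar . map (* 2) . sequenceVar
```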
366 | 367 | #### isoVar 368 | 369 | The simplest way to do this is using the `isoVar`, `isoVar2`, etc. family of 370 | functions: 371 | 372 | ```haskell 373 | isoVar2 374 | :: (Backprop a, Backprop b, Backprop c, Reifies s W) 375 | => (a -> b -> c) 376 | -> (c -> (a, b)) 377 | -> BVar s a 378 | -> BVar s b 379 | -> BVar s c 380 | ``` 381 | 382 | So if we had a type like: 383 | 384 | ```haskell top 385 | data DoubleInt = DI Double Int 386 | ``` 387 | 388 | We can combine a `Double` and `Int` into a `DoubleInt` using `isoVar2`: 389 | 390 | ```haskell 391 | isoVar2 DI (\(DI x y) -> (x,y)) 392 | :: Reifies s W 393 | => BVar s Double 394 | -> BVar s Int 395 | -> BVar s DoubleInt 396 | ``` 397 | 398 | #### Higher-Kinded Data Interface 399 | 400 | You can also use the ["Higher Kinded Data"][hkd] interface. For our 401 | `Met` type above, you can use `joinBV`, or the `BV` pattern synonym: 402 | 403 | ```haskell top 404 | makeMet :: Reifies s W 405 | => BVar s (L 20 100) 406 | -> BVar s (R 20) 407 | -> BVar s (L 5 20) 408 | -> BVar s (R 5) 409 | -> BVar s Met 410 | makeMet w1 b1 w2 b2 = joinBV (M w1 b1 w2 b2) 411 | ``` 412 | 413 | ### Modifying fields 414 | 415 | If you just want to "set" a specific field, you can use the lens-based 416 | interface with `(.~~)` or `setVar`. For example, if we wanted to set the 417 | `_nWeights2` field of a `Net` to a new matrix, we can do: 418 | 419 | ```haskell 420 | myNet & nWeights2 .~~ newMatrix 421 | ``` 422 | 423 | or 424 | 425 | ```haskell 426 | setVar nWeights2 427 | :: Reifies s W 428 | => BVar s (L 5 20) 429 | -> BVar s Net 430 | -> BVar s Net 431 | ``` 432 | 433 | You can also use `(%~~)` or `overVar` to apply a *function* to a specific field 434 | inside your value. 435 | 436 | Prelude Modules 437 | --------------- 438 | 439 | Finally, the *Prelude.Backprop* module has a lot of your normal Prelude 440 | functions "lifted" to work on `BVar`s of values. For many situations, these 441 | aren't necessary, and normal Prelude functions will work just fine on `BVar`s 442 | of values (like `(.)`). However, it does have some convenient functions, like 443 | `minimum`, `foldl'`, `fmap`, `toList`, `fromIntegral`, `realToFrac`, etc. 444 | lifted to work on `BVar`s. This module is meant to be imported qualified. 445 | 446 | Moving On 447 | ========= 448 | 449 | Now that you know all about `BVar`s, you really can just **[jump into the 450 | haddocks][haddock]** and start writing programs. The next section of this 451 | documentation goes into more detail about **[the `Backprop` typeclass][class]**. 452 | 453 | [class]: https://backprop.jle.im/04-the-backprop-typeclass.html 454 | [haddock]: https://hackage.haskell.org/package/backprop 455 | -------------------------------------------------------------------------------- /doc/04-the-backprop-typeclass.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: The Backprop Typeclass 3 | --- 4 | 5 | The Backprop Typeclass 6 | ====================== 7 | 8 | ```haskell top hide 9 | {-# LANGUAGE FlexibleContexts #-} 10 | {-# LANGUAGE DataKinds #-} 11 | {-# LANGUAGE DeriveGeneric #-} 12 | 13 | import GHC.Generics (Generic) 14 | import GHC.TypeNats 15 | import Numeric.LinearAlgebra.Static (L, R) 16 | import Numeric.Backprop 17 | import Numeric.Backprop.Class 18 | import qualified Data.Vector as V 19 | ``` 20 | 21 | Most of the functions in this module require a `Backprop` constraint on values 22 | you wish to backpropagate. 
Even if you manage to get around it for the most 23 | part, `gradBP` (the actual function to compute gradients) requires it on both 24 | the inputs and outputs. Let's dig deeper into what it is, and how to define 25 | instances. 26 | 27 | The Class 28 | --------- 29 | 30 | The typeclass contains three methods: `zero`, `add`, and `one`: 31 | 32 | ```haskell 33 | class Backprop a where 34 | zero :: a -> a 35 | add :: a -> a -> a 36 | one :: a -> a 37 | ``` 38 | 39 | `zero` is "zero" in the verb sense -- it takes a value and "zeroes out" all 40 | components. For a vector, this means returning a zero vector of the same 41 | shape. For a list, this means replacing all of the items with zero and 42 | returning a list of the same length. 43 | 44 | `one` does the same thing but with one; the point of it is to be `one = gradBP 45 | id` --- the gradient of the identity function for your type. 46 | 47 | `add` is used to add together contributions in gradients, and is usually a 48 | component-wise addition. 49 | 50 | Instances are provided for most common data types where it makes sense. 51 | 52 | Custom Instances 53 | ---------------- 54 | 55 | ### Generics 56 | 57 | When defining your own custom types, if your custom type has *a single 58 | constructor* where all fields are instances of `Backprop`, then *GHC.Generics* 59 | can be used to write your instances automatically: 60 | 61 | ```haskell top hide 62 | instance Backprop (R n) where 63 | zero = zeroNum 64 | add = addNum 65 | one = oneNum 66 | 67 | instance (KnownNat n, KnownNat m) => Backprop (L n m) where 68 | zero = zeroNum 69 | add = addNum 70 | one = oneNum 71 | ``` 72 | 73 | ```haskell top 74 | data MyType = MkMyType Double [Float] (R 10) (L 20 10) (V.Vector Double) 75 | deriving Generic 76 | ``` 77 | 78 | Nice type. Since it has a single constructor and all of its fields are already 79 | `Backprop` instances, we can just write: 80 | 81 | ```haskell top 82 | instance Backprop MyType 83 | ``` 84 | 85 | and now your type can be backpropagated! 86 | 87 | ### Common Patterns 88 | 89 | For writing "primitive" `Backprop` instances (types that aren't product types), 90 | you can use the provided "helpers" from the *Numeric.Backprop.Class* module. 91 | 92 | If your type is a `Num` instance, you can use `zeroNum`, `addNum`, and 93 | `oneNum`: 94 | 95 | ```haskell 96 | instance Backprop Double where 97 | zero = zeroNum 98 | add = addNum 99 | one = oneNum 100 | ``` 101 | 102 | If your type is made using a `Functor` instance, you can use `zeroFunctor` and 103 | `oneFunctor`: 104 | 105 | ```haskell 106 | instance Backprop a => Backprop (V.Vector a) where 107 | zero = zeroFunctor 108 | add = undefined -- ?? (no Functor-based helper for add; see below) 109 | one = oneFunctor 110 | ``` 111 | 112 | And if your type has an `IsList` instance, you can use `addIsList`: 113 | 114 | ```haskell 115 | instance Backprop a => Backprop (V.Vector a) where 116 | zero = zeroFunctor 117 | add = addIsList 118 | one = oneFunctor 119 | ``` 120 | 121 | ### Completely Custom 122 | 123 | Completely custom instances are also possible; you just need to implement 124 | `zero`, `add`, and `one` as they make sense for your type. Just make sure that 125 | you obey [the laws][laws] for sane behavior! 126 | 127 | [laws]: http://hackage.haskell.org/package/backprop/docs/Numeric-Backprop-Class.html 128 | 129 | Moving On 130 | ========= 131 | 132 | At this point, feel free to **[jump into the haddocks][haddock]**, or read on 133 | further for **[a list of applications and resources][applications]**. 
134 | 135 | [haddock]: https://hackage.haskell.org/package/backprop 136 | [applications]: https://backprop.jle.im/05-applications.html 137 | -------------------------------------------------------------------------------- /doc/05-applications.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Applications and Resources 3 | --- 4 | 5 | Applications and Resources 6 | ========================== 7 | 8 | Congratulations! You are now a *backprop* master. Maybe you've even looked at 9 | the [haddocks][haddock], which have the technical run-down of all of the 10 | functions and types in this library. Now what? 11 | 12 | * Check out my [Introducing the backprop library][intro] blog post where I 13 | announced the library to the world. In it, I introduce the library by 14 | building and training a full artificial neural network with it, and use it 15 | to classify the famous MNIST handwritten digit data set. 16 | 17 | * If you want an even more high-level perspective and inspiration, check out 18 | my [A Purely Functional Typed Approach to Trainable Models][models] blog 19 | series, where I talk about how looking at modeling through the lens of 20 | differentiable programming with purely functional typed code can provide 21 | new insights and help you develop and train effective models. 22 | 23 | * While they are mostly re-phrasings of the two things above, I also have 24 | some [example projects as literate haskell files][lhs] on the github 25 | repository for the library. These are also [rendered as pdfs][renders] for 26 | easier reading. 27 | 28 | * If you're doing anything with linear algebra, why not check out the 29 | *[hmatrix-backprop][]* library, which provides the "backprop-lifted" 30 | operations that all of the above examples rely on for linear algebra 31 | operations? 32 | 33 | [haddock]: https://hackage.haskell.org/package/backprop 34 | [intro]: https://blog.jle.im/entry/introducing-the-backprop-library.html 35 | [models]: https://blog.jle.im/entry/purely-functional-typed-models-1.html 36 | [lhs]: https://github.com/mstksg/backprop/blob/master/samples 37 | [renders]: https://github.com/mstksg/backprop/tree/master/renders 38 | [hmatrix-backprop]: http://hackage.haskell.org/package/hmatrix-backprop 39 | 40 | This is the end of the "end-user" documentation for *backprop*! Everything 41 | else you need to know to use the library is in the **[haddocks on 42 | hackage][haddock]**. 43 | 44 | Check out the sidebar for more technical details on [writing manual 45 | gradients][manual-gradients], [optimization and performance][performance], and 46 | [equipping your library for backprop][equipping]! 
47 | 48 | [manual-gradients]: https://backprop.jle.im/06-manual-gradients.html 49 | [performance]: https://backprop.jle.im/07-performance.html 50 | [equipping]: https://backprop.jle.im/08-equipping-your-library.html 51 | -------------------------------------------------------------------------------- /doc/06-manual-gradients.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Manual Gradients 3 | --- 4 | 5 | Providing Hand-Written Gradients 6 | ================================ 7 | 8 | ```haskell top hide 9 | {-# LANGUAGE DataKinds #-} 10 | {-# LANGUAGE DeriveGeneric #-} 11 | {-# LANGUAGE FlexibleContexts #-} 12 | {-# LANGUAGE FlexibleInstances #-} 13 | {-# LANGUAGE StandaloneDeriving #-} 14 | {-# LANGUAGE TemplateHaskell #-} 15 | {-# LANGUAGE TypeFamilies #-} 16 | {-# LANGUAGE ViewPatterns #-} 17 | 18 | 19 | import Data.Functor.Identity 20 | import qualified Data.List 21 | import GHC.Generics (Generic) 22 | import GHC.TypeNats 23 | import Inliterate.Import 24 | import Lens.Micro 25 | import Lens.Micro.TH 26 | import Numeric.Backprop 27 | import Numeric.Backprop.Class 28 | import Numeric.LinearAlgebra.Static (L, R, konst) 29 | import System.Random 30 | import qualified Data.Vector as V 31 | import qualified Numeric.LinearAlgebra.Static as H 32 | import qualified Numeric.LinearAlgebra as HU 33 | ``` 34 | 35 | Providing hand-written gradients for operations can be useful if 36 | you are [performing low-level optimizations][performance] or [equipping your 37 | library for backprop][equipping]. 38 | 39 | [performance]: https://backprop.jle.im/07-performance.html 40 | [equipping]: https://backprop.jle.im/08-equipping-your-library.html 41 | 42 | Ideally, as an *end user*, you should never have to do this. The whole point 43 | of the *backprop* library is to allow you to use backpropagatable functions as 44 | normal functions, and to let you build complicated functions by simply 45 | composing normal Haskell functions, where the *backprop* library automatically 46 | infers your gradients. 47 | 48 | However, if you are writing a library, you probably need to provide "primitive" 49 | backpropagatable functions (like matrix-vector multiplication for a linear 50 | algebra library) for your users, so your users can then use those primitive 51 | functions to write their own code, without ever having to be aware of any 52 | gradients. 53 | 54 | If you are writing code and recognize some bottlenecks related to library 55 | overhead as [described in this post][performance], then you might also want to 56 | provide manual gradients. However, this should always be a 57 | last resort, as *figuring out* manual gradients is a tedious and error-prone 58 | process that can introduce subtle bugs in ways that don't always appear in 59 | testing. It also makes your code much more fragile and difficult to refactor 60 | and shuffle around (since you aren't using normal function composition and 61 | application anymore) and much harder to read. Only proceed if you decide that 62 | the huge cognitive costs are worth it. 63 | 64 | The Lifted Function 65 | ------------------- 66 | 67 | A lifted function of type 68 | 69 | ```haskell 70 | myFunc :: Reifies s W => BVar s a -> BVar s b 71 | ``` 72 | 73 | represents a backpropagatable function taking an `a` and returning a `b`.
It is 74 | represented as a function taking a `BVar` containing an `a` and returning a 75 | `BVar` containing a `b`; the `BVar s` with the `Reifies s W` is what allows for 76 | tracking of backpropagation. 77 | 78 | A `BVar s a -> BVar s b` is, really, under the hood: 79 | 80 | ```haskell 81 | type BVar s a -> BVar s b 82 | = a -> (b, b -> a) 83 | ``` 84 | 85 | That is, given an input `a`, you get: 86 | 87 | 1. A `b`, the result (the "forward pass") 88 | 2. A `b -> a`, the "scaled gradient" function. 89 | 90 | A full technical description is given in the documentation for [Numeric.Backprop.Op][op]. 91 | 92 | [op]: http://hackage.haskell.org/package/backprop/docs/Numeric-Backprop-Op.html 93 | 94 | The `b` result is simple enough; it's the result of your function. The "scaled 95 | gradient" function requires some elaboration. Let's say you are writing a 96 | lifted version of your function \\(y = f(x)\\) (whose derivative is 97 | \\(\frac{dy}{dx}\\)), and that your *final result* at the end of your 98 | computation is \\(z = g(f(x))\\) (whose derivative is \\(\frac{dz}{dx}\\)). In 99 | that case, because of the chain rule, \\(\frac{dz}{dx} = \frac{dz}{dy} 100 | \frac{dy}{dx}\\). 101 | 102 | The scaled gradient `b -> a` is the function which, *given* 103 | \\(\frac{dz}{dy}\\) `:: b`, *returns* \\(\frac{dz}{dx}\\) `:: a` (that is, 104 | returns \\(\frac{dz}{dy} \frac{dy}{dx}\\) `:: a`). 105 | 106 | For example, for the mathematical operation \\(y = f(x) = x^2\\), then, 107 | considering \\(z = g(f(x))\\), \\(\frac{dz}{dx} = \frac{dz}{dy} 2x\\). In fact, 108 | for all functions taking and returning scalars (just normal single numbers), 109 | \\(\frac{dz}{dx} = \frac{dz}{dy} f'(x)\\). 110 | 111 | Simple Example 112 | -------------- 113 | 114 | With that in mind, let's write a lifted "squared" operation, that takes `x` and 115 | returns `x^2`: 116 | 117 | ```haskell top 118 | square 119 | :: (Num a, Backprop a, Reifies s W) 120 | => BVar s a 121 | -> BVar s a 122 | square = liftOp1 . op1 $ \x -> 123 | ( x^2 , \dzdy -> dzdy * 2 * x) 124 | -- ^- actual result ^- scaled gradient function 125 | ``` 126 | 127 | We can write one for `sin`, as well. For \\(y = f(x) = \sin(x)\\), we consider 128 | \\(z = g(f(x))\\) to see \\(\frac{dz}{dx} = \frac{dz}{dy} \cos(x)\\). So, we 129 | have: 130 | 131 | ```haskell top 132 | liftedSin 133 | :: (Floating a, Backprop a, Reifies s W) 134 | => BVar s a 135 | -> BVar s a 136 | liftedSin = liftOp1 . op1 $ \x -> 137 | ( sin x, \dzdy -> dzdy * cos x ) 138 | ``` 139 | 140 | In general, for functions that take and return scalars: 141 | 142 | ```haskell 143 | liftedF 144 | :: (Num a, Backprop a, Reifies s W) 145 | => BVar s a 146 | -> BVar s a 147 | liftedF = liftOp1 . op1 $ \x -> 148 | ( f x, \dzdy -> dzdy * dfdx x ) 149 | ``` 150 | 151 | For an example of every single numeric function in base Haskell, see [the 152 | source of Op.hs][opsource] for the `Op` definitions for every method in `Num`, 153 | `Fractional`, and `Floating`. 154 | 155 | [opsource]: https://github.com/mstksg/backprop/blob/a7651b4549048a3aca73c79c6fbe07c3e8ee500e/src/Numeric/Backprop/Op.hs#L646-L787 156 | 157 | Non-trivial example 158 | ------------------- 159 | 160 | A simple non-trivial example is `sumElements`, which we can define to take the 161 | *hmatrix* library's `R n` type (an n-vector of `Double`). In this case, we 162 | have to think about \\(g(\mathrm{sum}(\mathbf{x}))\\). 
In this case, the types 163 | guide our thinking: 164 | 165 | ```haskell 166 | sumElements :: R n -> Double 167 | sumElementsScaledGrad :: R n -> Double -> R n 168 | ``` 169 | 170 | The simplest way for me to do this personally is to just take it element by 171 | element. 172 | 173 | 1. *Write out the functions in question, in a simple example* 174 | 175 | In our case: 176 | 177 | * \\(y = f(\langle a, b, c \rangle) = a + b + c\\) 178 | * \\(z = g(y) = g(a + b + c)\\) 179 | 180 | 2. *Identify the components in your gradient* 181 | 182 | In our case, we have to return a gradient \\(\langle \frac{\partial z}{\partial a}, 183 | \frac{\partial z}{\partial b}, \frac{\partial z}{\partial c} \rangle\\). 184 | 185 | 3. *Work out each component of the gradient until you start to notice a 186 | pattern* 187 | 188 | Let's start with \\(\frac{\partial z}{\partial a}\\). We need to find 189 | \\(\frac{\partial z}{\partial a}\\) in terms of \\(\frac{dz}{dy}\\): 190 | 191 | * Through the chain rule, \\(\frac{\partial z}{\partial a} = 192 | \frac{dz}{dy} \frac{\partial y}{\partial a}\\). 193 | * Because \\(y = a + b + c\\), we know that \\(\frac{\partial y}{\partial 194 | a} = 1\\). 195 | * Because \\(\frac{\partial y}{\partial a} = 1\\), we know that 196 | \\(\frac{\partial z}{\partial a} = \frac{dz}{dy} \times 1 = 197 | \frac{dz}{dy}\\). 198 | 199 | So, our expression of \\(\frac{\partial z}{\partial a}\\) in terms of 200 | \\(\frac{dz}{dy}\\) is simple -- it's simply \\(\frac{\partial z}{\partial 201 | a} = \frac{dz}{dy}\\). 202 | 203 | Now, let's look at \\(\frac{\partial z}{\partial b}\\). We need to find 204 | \\(\frac{\partial z}{\partial b}\\) in terms of \\(\frac{dz}{dy}\\). 205 | 206 | * Through the chain rule, \\(\frac{\partial z}{\partial b} = 207 | \frac{dz}{dy} \frac{\partial y}{\partial b}\\). 208 | * Because \\(y = a + b + c\\), we know that \\(\frac{\partial y}{\partial 209 | b} = 1\\). 210 | * Because \\(\frac{\partial y}{\partial b} = 1\\), we know that 211 | \\(\frac{\partial z}{\partial b} = \frac{dz}{dy} \times 1 = 212 | \frac{dz}{dy}\\). 213 | 214 | It looks like \\(\frac{\partial z}{\partial b} = \frac{\partial z}{\partial 215 | y}\\), as well. 216 | 217 | At this point, we start to notice a pattern. We can apply the same logic 218 | to see that \\(\frac{\partial z}{\partial c} = \frac{dz}{dy}\\). 219 | 220 | 4. *Write out the pattern* 221 | 222 | Extrapolating the pattern, \\(\frac{\partial z}{\partial q}\\), where 223 | \\(q\\) is *any* component, is always going to be a constant -- 224 | \\(\frac{dz}{dy}\\). 225 | 226 | So in the end: 227 | 228 | ```haskell top hide 229 | instance Backprop (R n) where 230 | zero = zeroNum 231 | add = addNum 232 | one = oneNum 233 | 234 | instance (KnownNat n, KnownNat m) => Backprop (L n m) where 235 | zero = zeroNum 236 | add = addNum 237 | one = oneNum 238 | 239 | sumElements :: KnownNat n => R n -> Double 240 | sumElements = HU.sumElements . H.extract 241 | ``` 242 | 243 | ```haskell top 244 | liftedSumElements 245 | :: (KnownNat n, Reifies s W) 246 | => BVar s (R n) 247 | -> BVar s Double 248 | liftedSumElements = liftOp1 . op1 $ \xs -> 249 | ( sumElements xs, \dzdy -> konst dzdy ) -- a constant vector 250 | ``` 251 | 252 | ### Multiple-argument functions 253 | 254 | Lifting multiple-argument functions is the same thing, except using `liftOp2` 255 | and `op2`, or `liftOpN` and `opN`. 
256 | 257 | A `BVar s a -> BVar s b -> BVar s c` is, really, under the hood: 258 | 259 | ```haskell 260 | type BVar s a -> BVar s b -> BVar s c = 261 | a -> b -> (c, c -> (a, b)) 262 | ``` 263 | 264 | That is, given inputs `a` and `b`, you get: 265 | 266 | 1. A `c`, the result (the "forward pass") 267 | 2. A `c -> (a, b)`, the "scaled gradient" function returning the gradient of 268 | both inputs. 269 | 270 | The `c` parameter of the scaled gradient is again \\(\frac{dz}{dy}\\), and the 271 | final `(a,b)` is a tuple of \\(\frac{\partial z}{\partial x_1}\\) and 272 | \\(\frac{\partial z}{\partial x_2}\\): how \\(\frac{dz}{dy}\\) affects both of 273 | the inputs. 274 | 275 | For a simple example, let's look at \\(x + y\\). Working it out: 276 | 277 | * \\(y = f(x_1, x_2) = x_1 + x_2\\) 278 | * \\(z = g(f(x_1, x_2)) = g(x_1 + x_2)\\) 279 | * Looking first for \\(\frac{\partial z}{\partial x_1}\\) in terms of 280 | \\(\frac{dz}{dy}\\): 281 | * \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial 282 | y}{\partial x_1}\\) (chain rule) 283 | * From \\(y = x_1 + x_2\\), we see that \\(\frac{\partial y}{\partial 284 | x_1} = 1\\) 285 | * Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \times 1 286 | = \frac{dz}{dy}\\). 287 | * Looking second for \\(\frac{\partial z}{\partial x_2}\\) in terms of 288 | \\(\frac{dz}{dy}\\): 289 | * \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial 290 | y}{\partial x_2}\\) (chain rule) 291 | * From \\(y = x_1 + x_2\\), we see that \\(\frac{\partial y}{\partial 292 | x_2} = 1\\) 293 | * Therefore, \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \times 1 294 | = \frac{dz}{dy}\\). 295 | * Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy}\\), and also 296 | \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy}\\). 297 | 298 | Putting it into code: 299 | 300 | ```haskell top 301 | add :: (Num a, Backprop a, Reifies s W) 302 | => BVar s a 303 | -> BVar s a 304 | -> BVar s a 305 | add = liftOp2 . op2 $ \x1 x2 -> 306 | ( x1 + x2, \dzdy -> (dzdy, dzdy) ) 307 | ``` 308 | 309 | Let's try our hand at multiplication, or \\(x * y\\): 310 | 311 | * \\(y = f(x_1, x_2) = x_1 x_2\\) 312 | * \\(z = g(f(x_1, x_2)) = g(x_1 x_2)\\) 313 | * Looking first for \\(\frac{\partial z}{\partial x_1}\\) in terms of 314 | \\(\frac{dz}{dy}\\): 315 | * \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial 316 | y}{\partial x_1}\\) (chain rule) 317 | * From \\(y = x_1 x_2\\), we see that \\(\frac{\partial y}{\partial x_1} 318 | = x_2\\) 319 | * Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} x_2\\). 320 | * Looking second for \\(\frac{\partial z}{\partial x_2}\\) in terms of 321 | \\(\frac{dz}{dy}\\): 322 | * \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial 323 | y}{\partial x_2}\\) (chain rule) 324 | * From \\(y = x_1 x_2\\), we see that \\(\frac{\partial y}{\partial x_2} 325 | = x_1\\) 326 | * Therefore, \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} x_1\\). 327 | * Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} x_2\\), and 328 | \\(\frac{\partial z}{\partial x_2} = x_1 \frac{dz}{dy}\\). 329 | 330 | In code: 331 | 332 | ```haskell top 333 | mul :: (Num a, Backprop a, Reifies s W) 334 | => BVar s a 335 | -> BVar s a 336 | -> BVar s a 337 | mul = liftOp2 . op2 $ \x1 x2 -> 338 | ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) ) 339 | ``` 340 | 341 | For non-trivial examples involving linear algebra, see the source for the *[hmatrix-backprop][]* library. 
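As one more sketch in the same style (our own, not from the library), division follows the same recipe: for \\(y = x_1 / x_2\\), \\(\frac{\partial y}{\partial x_1} = 1 / x_2\\) and \\(\frac{\partial y}{\partial x_2} = -x_1 / x_2^2\\):

```haskell
divide
    :: (Fractional a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
divide = liftOp2 . op2 $ \x1 x2 ->
    ( x1 / x2
    , \dzdy -> (dzdy / x2, -dzdy * x1 / (x2 * x2)) )
```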
342 | 343 | [hmatrix-backprop]: http://hackage.haskell.org/package/hmatrix-backprop 344 | 345 | Some examples, for the dot product between two vectors and for matrix-vector 346 | multiplication: 347 | 348 | ```haskell top 349 | -- import qualified Numeric.LinearAlgebra.Static as H 350 | 351 | -- | dot product between two vectors 352 | dot 353 | :: (KnownNat n, Reifies s W) 354 | => BVar s (R n) 355 | -> BVar s (R n) 356 | -> BVar s Double 357 | dot = liftOp2 . op2 $ \u v -> 358 | ( u `H.dot` v 359 | , \dzdy -> (H.konst dzdy * v, u * H.konst dzdy) 360 | ) 361 | 362 | 363 | -- | matrix-vector multiplication 364 | (#>) 365 | :: (KnownNat m, KnownNat n, Reifies s W) 366 | => BVar s (L m n) 367 | -> BVar s (R n) 368 | -> BVar s (R m) 369 | (#>) = liftOp2 . op2 $ \mat vec -> 370 | ( mat H.#> vec 371 | , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy) 372 | ) 373 | ``` 374 | 375 | Possibilities 376 | ------------- 377 | 378 | That's it for this introductory tutorial on lifting single operations. More 379 | information on the ways to apply these techniques to fully equip your library 380 | for backpropagation (including arguments with multiple results, taking 381 | advantage of isomorphisms, providing non-gradient functions) can be [found 382 | here][equipping]! 383 | -------------------------------------------------------------------------------- /doc/07-performance.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Performance & Optimizations 3 | --- 4 | 5 | Performance and Optimizations 6 | ============================= 7 | 8 | ```haskell top hide 9 | {-# LANGUAGE DataKinds #-} 10 | {-# LANGUAGE DeriveGeneric #-} 11 | {-# LANGUAGE FlexibleContexts #-} 12 | {-# LANGUAGE FlexibleInstances #-} 13 | {-# LANGUAGE StandaloneDeriving #-} 14 | {-# LANGUAGE TemplateHaskell #-} 15 | {-# LANGUAGE TypeFamilies #-} 16 | {-# LANGUAGE ViewPatterns #-} 17 | 18 | 19 | import GHC.Generics (Generic) 20 | import GHC.TypeNats 21 | import Inliterate.Import 22 | import Lens.Micro 23 | import Lens.Micro.TH 24 | import Numeric.Backprop 25 | import Numeric.Backprop.Class 26 | import Numeric.LinearAlgebra.Static (L, R) 27 | import qualified Numeric.LinearAlgebra as HU 28 | import qualified Numeric.LinearAlgebra.Static as H 29 | ``` 30 | 31 | We can use the [MNIST tutorial][bench] as an example to compare automatic 32 | differentiation with "manual" differentiation: 33 | 34 | [bench]: https://github.com/mstksg/backprop/blob/master/bench/bench.hs 35 | 36 | ![benchmarks](https://i.imgur.com/rLUx4x4.png) 37 | 38 | In the above, we compare: 39 | 40 | 1. "Manual" differentiation of a 784 x 300 x 100 x 10 fully-connected 41 | feed-forward ANN. 42 | 2. Automatic differentiation using *backprop* and the lens-based accessor 43 | interface 44 | 3. Automatic differentiation using *backprop* and the "higher-kinded 45 | data"-based pattern matching interface 46 | 4. A hybrid approach that manually provides gradients for individual layers 47 | but uses automatic differentiation for chaining the layers together. See 48 | the section "Dealing with Overhead from Redundant Updates" for details. 49 | 50 | Sources of Overhead 51 | ------------------- 52 | 53 | One immediate result is that simply *running* the network and functions (using 54 | `evalBP`) incurs virtually zero overhead. This means that library authors 55 | could actually export *only* backprop-lifted functions, and users would be able 56 | to use them without losing any performance. 
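For instance (a sketch of our own, with hypothetical names --- `dot` here is the backprop-lifted dot product used elsewhere in this documentation): a library could ship only the lifted `normVar`, and a user who never touches gradients recovers the plain version with `evalBP`:

```haskell
-- hypothetical library primitive, exported only in lifted form
normVar :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s Double
normVar v = sqrt (v `dot` v)

-- the plain function, recovered with (essentially) no overhead
norm :: KnownNat n => R n -> Double
norm = evalBP normVar
```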
57 | 58 | As for computing gradients, there exists some associated overhead. There are 59 | three main sources: 60 | 61 | 1. The construction and traversal of the [Wengert tape][] used to implement 62 | automatic differentiation. However, this overhead is typically negligible 63 | for backpropagating any numerical computations of non-trivial complexity. 64 | 65 | 2. Redundant updates of entire data types during gradient accumulation. This 66 | will be, **by far**, the *dominating* source of any overhead compared to manual 67 | differentiation for any numerical computation of non-trivial complexity. 68 | 69 | 3. Inefficiencies associated with "naive" differentiation, compared to manual 70 | symbolic differentiation. However, this inefficiency is typically 71 | negligible except in edge cases. 72 | 73 | [Wengert tape]: https://dl.acm.org/citation.cfm?doid=355586.364791 74 | 75 | In addition, usage of the "Higher-Kinded Data"-based pattern matching interface 76 | (over the lens-based accessor interface) seems to result in more efficient 77 | compiled code, but not by any significant amount. 78 | 79 | Optimization Techniques 80 | ----------------------- 81 | 82 | ### Dealing with Overhead from Redundant Updates 83 | 84 | By far the dominating source of overhead when using *backprop* is the redundant 85 | update of data type fields when accumulating gradients. 86 | 87 | #### Example 88 | 89 | That is, if we had a data type like: 90 | 91 | ```haskell top 92 | data MyType = MT { _mtX :: Double 93 | , _mtY :: Double 94 | , _mtZ :: Double 95 | } 96 | deriving (Show, Generic) 97 | 98 | makeLenses ''MyType 99 | 100 | instance Backprop MyType 101 | ``` 102 | 103 | ```haskell top hide 104 | instance AskInliterate MyType 105 | ``` 106 | 107 | 108 | and we *use* all three fields somehow: 109 | 110 | ```haskell top 111 | myFunc :: Reifies s W => BVar s MyType -> BVar s Double 112 | myFunc mt = (mt ^^. mtX) * (mt ^^. mtY) + (mt ^^. mtZ) 113 | ``` 114 | 115 | and we compute its gradient: 116 | 117 | ```haskell eval 118 | gradBP myFunc (MT 5 7 2) 119 | ``` 120 | 121 | The library will first compute the derivative of the first field, and embed it 122 | into `MyType`: 123 | 124 | ```haskell 125 | MT { _mtX = 7.0, _mtY = 0.0, _mtZ = 0.0 } 126 | ``` 127 | 128 | Then it'll compute the derivative of the second field and embed it: 129 | 130 | ```haskell 131 | MT { _mtX = 0.0, _mtY = 5.0, _mtZ = 0.0 } 132 | ``` 133 | 134 | And finally compute the derivative of the third field and embed it: 135 | 136 | ```haskell 137 | MT { _mtX = 0.0, _mtY = 0.0, _mtZ = 1.0 } 138 | ``` 139 | 140 | And it'll compute the final derivative by `add`-ing all three of those 141 | together. 142 | 143 | This is not too bad with `Double`s, but when you have huge matrices, there will 144 | be *six redundant additions of zero* for a data type with three fields...and 145 | those additions of zero matrices can incur a huge cost. 146 | 147 | In general, for a data type with \\(n\\) fields where you use \\(m\\) of those 148 | fields, you will have something on the order of \\(\mathcal{O}(n m)\\) 149 | redundant additions by zero. 150 | 151 | #### Mitigating 152 | 153 | One way to mitigate these redundant updates is to prefer data types with fewer 154 | fields if possible, or re-factor your data types into multiple "levels" of 155 | nesting, to reduce the number of redundant additions by zero. That is, instead 156 | of having a giant ten-field data type, have two five-field data types, and one 157 | type having a value of each type. 
This works well with recursive "linked 158 | list" data types too, as long as you write functions on your linked lists 159 | inductively. 160 | 161 | You can also be careful about how many times you use `^^.` (`viewVar`), because 162 | each usage site incurs another addition-by-zero in the gradient accumulation. 163 | If possible, refactor all of your `^^.` into a single binding, and share it 164 | within your expression, instead of using it again several times for the same 165 | field in the same expression. 166 | 167 | You can also use clever lenses to "simulate" having a data type with fewer 168 | fields than you actually have. For example, you can have a lens on the first 169 | two fields: 170 | 171 | ```haskell top 172 | mtXY :: Lens' MyType (Double, Double) 173 | mtXY f (MT x y z) = (\(x', y') -> MT x' y' z) <$> f (x, y) 174 | ``` 175 | 176 | This treats accessing both fields as effectively a single access to a single 177 | tuple field, and so cuts out an extra addition by zero. 178 | 179 | As a last resort, you can *completely eliminate* redundant additions by zero by 180 | providing *manual gradients* to functions using your data type. 181 | 182 | ```haskell top 183 | myFunc' :: Reifies s W => BVar s MyType -> BVar s Double 184 | myFunc' = liftOp1 . op1 $ \(MT x y z) -> 185 | ( (x * y) + z 186 | , \d -> MT (d * y) (x * d) d 187 | ) 188 | ``` 189 | 190 | ```haskell eval 191 | gradBP myFunc' (MT 5 7 2) 192 | ``` 193 | 194 | See the [writing manual gradients][manual-gradients] page for more information 195 | on exactly how to specify your operations with manual gradients. 196 | 197 | [manual-gradients]: https://backprop.jle.im/06-manual-gradients.html 198 | 199 | Once you do this, you can use `myFunc'` as a part of any larger computation; 200 | backpropagation will still work the same, and you avoid any redundant additions 201 | of zero: 202 | 203 | ```haskell eval 204 | gradBP (negate . sqrt . myFunc) (MT 5 7 2) 205 | ``` 206 | 207 | ```haskell eval 208 | gradBP (negate . sqrt . myFunc') (MT 5 7 2) 209 | ``` 210 | 211 | When you *use* `myFunc'` in a function, it will be efficiently backpropagated 212 | by the *backprop* library. 213 | 214 | This is useful for situations like optimizing artificial neural networks that 215 | are a composition of multiple "layers": you can manually specify the derivative 216 | of each layer, but let the *backprop* library take care of finding the 217 | derivative of *their composition*. This is exactly the "hybrid" mode mentioned 218 | in the benchmarks above. As the benchmark results show, this brings the 219 | manual and automatic backprop results almost to within the range of random variance 220 | of each other. 221 | 222 | However, I don't recommend doing this, except as a last resort for 223 | optimization. This is because: 224 | 225 | 1. The whole point of the *backprop* library is to allow you to never have to 226 | specify manual gradients 227 | 2. It is *very very easy* to make a mistake in your gradient computation and 228 | introduce subtle bugs 229 | 3. It is difficult to *modify* your function if you want to tweak what it 230 | returns. Compare changing the multiplication to division in the original 231 | `myFunc` vs. the manual `myFunc'` 232 | 4. It makes it harder to read and understand (and subsequently refactor) your 233 | code. 234 | 235 | However, this option is available as a low-level performance hack. 
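(One more sketch of our own, illustrating the sharing advice above: bind the result of `^^.` once and reuse the binding, instead of paying for a separate `viewVar` site at every use:)

```haskell
-- one viewVar site instead of three
cubeX :: Reifies s W => BVar s MyType -> BVar s Double
cubeX mt = x * x * x
  where
    x = mt ^^. mtX
```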
236 | 237 | ### Dealing with Overhead from Naive Differentiation 238 | 239 | [Automatic differentiation][ad] is a mechanical process that is nothing more 240 | than glorified book-keeping and accumulation. It essentially "hitches a ride" 241 | on your normal computation in order to automatically accumulate its gradient. 242 | It isn't aware of the analytical nature of computations, and cannot do any 243 | symbolic or analytical simplifications like re-associating additions or 244 | canceling out factors that humans might perform if manually differentiating. 245 | 246 | [ad]: https://en.wikipedia.org/wiki/Automatic_differentiation 247 | 248 | In most cases, this is "good enough" and will not be any significant source of 249 | inefficiency in the larger picture. At least, it won't be worth the cognitive 250 | overhead in squeezing out a one or two percent increase in performance. 251 | However, there are some edge cases where this might become a concern worth 252 | looking at. 253 | 254 | A common example is the composition of the [softmax][] activation function and 255 | the [cross-entropy][] error function often used in deep learning. Together, 256 | their derivatives are somewhat complex, computationally. However, the 257 | derivative of their *composition*, `crossEntropy x . softMax` actually has an 258 | extremely "simple" form, because of how some factors cancel out. To get around 259 | this, libraries like *tensorflow* offer an [optimized version of the 260 | composition with manually computed gradients][smce]. 261 | 262 | [softmax]: https://en.wikipedia.org/wiki/Softmax_function 263 | [cross-entropy]: https://en.wikipedia.org/wiki/Cross_entropy 264 | [smce]: https://www.tensorflow.org/api_docs/python/tf/losses/softmax_cross_entropy 265 | 266 | ```haskell top hide 267 | instance Backprop (R n) where 268 | zero = zeroNum 269 | add = addNum 270 | one = oneNum 271 | 272 | instance (KnownNat n, KnownNat m) => Backprop (L n m) where 273 | zero = zeroNum 274 | add = addNum 275 | one = oneNum 276 | 277 | instance KnownNat n => AskInliterate (R n) where 278 | askInliterate = answerWith (show . H.extract) 279 | 280 | konst 281 | :: (KnownNat n, Reifies s W) 282 | => BVar s Double 283 | -> BVar s (R n) 284 | konst = liftOp1 . op1 $ \x -> 285 | ( H.konst x 286 | , HU.sumElements . H.extract 287 | ) 288 | 289 | sumElements 290 | :: (KnownNat n, Reifies s W) 291 | => BVar s (R n) 292 | -> BVar s Double 293 | sumElements = liftOp1 . op1 $ \x -> 294 | ( HU.sumElements . H.extract $ x 295 | , H.konst 296 | ) 297 | 298 | dot 299 | :: (KnownNat n, Reifies s W) 300 | => BVar s (R n) 301 | -> BVar s (R n) 302 | -> BVar s Double 303 | dot = liftOp2 . op2 $ \x y -> 304 | ( x `H.dot` y 305 | , \d -> let d' = H.konst d 306 | in (d' * y, x * d') 307 | ) 308 | ``` 309 | 310 | ```haskell top 311 | -- import Numeric.LinearAlgebra.Static.Backprop 312 | 313 | softMax 314 | :: (KnownNat n, Reifies s W) 315 | => BVar s (R n) 316 | -> BVar s (R n) 317 | softMax x = konst (1 / totx) * expx 318 | where 319 | expx = exp x 320 | totx = sumElements expx 321 | 322 | crossEntropy 323 | :: (KnownNat n, Reifies s W) 324 | => R n 325 | -> BVar s (R n) 326 | -> BVar s Double 327 | crossEntropy x y = -(log y `dot` auto x) 328 | ``` 329 | 330 | (Note the usage of `auto :: a -> BVar s a` to lift a normal value into a `BVar`) 331 | 332 | Now, you can use `crossEntropy x . softMax` as a `BVar s (R n) -> BVar s Double` 333 | function, and the result and gradient would be correct. 
It would backpropagate
334 | the gradient of `crossEntropy` into `softMax`. However, you can take advantage
335 | of the fact that some factors in the result "cancel out", and you can
336 | drastically simplify the computation.
337 | 
338 | Written naively, their composition would be:
339 | 
340 | ```haskell top
341 | softMaxCrossEntropy
342 |     :: (KnownNat n, Reifies s W)
343 |     => R n
344 |     -> BVar s (R n)
345 |     -> BVar s Double
346 | softMaxCrossEntropy x y = -(log softMaxY `dot` auto x)
347 |   where
348 |     expy     = exp y
349 |     toty     = sumElements expy
350 |     softMaxY = konst (1 / toty) * expy
351 | ```
352 | 
353 | As you can probably guess, this has a decently complex gradient, just from all
354 | of the chained operations going on.
355 | 
356 | However, if you work things out with pencil and paper, you'll find a nice form
357 | for the gradient of the cross entropy composed with softmax, \\(f(x,y)\\):
358 | 
359 | \\[
360 | \nabla_y f(\mathbf{x}, \mathbf{y}) = \mathrm{softmax}(\mathbf{y}) - \mathbf{x}
361 | \\]
362 | 
363 | Basically, the gradient is just the target vector-subtracted from the result
364 | of `softMax`.
365 | 
366 | After computing the gradient by hand, we can write `softMaxCrossEntropy`
367 | with our manual gradient:
368 | 
369 | ```haskell top
370 | -- using the non-lifted interfaces
371 | -- import qualified Numeric.LinearAlgebra        as HU
372 | -- import qualified Numeric.LinearAlgebra.Static as H
373 | 
374 | softMaxCrossEntropy'
375 |     :: (KnownNat n, Reifies s W)
376 |     => R n
377 |     -> BVar s (R n)
378 |     -> BVar s Double
379 | softMaxCrossEntropy' x = liftOp1 . op1 $ \y ->
380 |     let expy     = exp y
381 |         toty     = HU.sumElements (H.extract expy)
382 |         softMaxY = H.konst (1 / toty) * expy
383 |         smce     = -(log softMaxY `H.dot` x)
384 |     in  ( smce
385 |         , \d -> H.konst d * (softMaxY - x)
386 |         )
387 | ```
388 | 
389 | Our gradient is now just `softMaxY - x`, which I can assure you is much, much
390 | simpler than the automatic differentiation-derived gradient. This is because a
391 | lot of factors show up in both the numerators and denominators and cancel out,
392 | and a lot of positive and negative terms also end up canceling out.
393 | 
394 | Again, refer to the [writing manual gradients][manual-gradients] page for more
395 | information on exactly how to specify your operations with manual gradients.
396 | 
397 | Once you do this, `softMaxCrossEntropy'` is now a function you can use normally
398 | and compose with other backpropagatable functions. You won't be able to
399 | functionally tell apart `crossEntropy x . softMax` from `softMaxCrossEntropy'`,
400 | and the two will behave identically, propagating gradients with other `BVar`
401 | functions:
402 | 
403 | ```haskell eval
404 | gradBP ((**2) . crossEntropy (H.vec3 1 0 0) . softMax) (H.vec3 0.9 0.2 0.3)
405 | ```
406 | 
407 | ```haskell eval
408 | gradBP ((**2) . softMaxCrossEntropy (H.vec3 1 0 0)) (H.vec3 0.9 0.2 0.3)
409 | ```
410 | 
411 | ```haskell eval
412 | gradBP ((**2) . softMaxCrossEntropy' (H.vec3 1 0 0)) (H.vec3 0.9 0.2 0.3)
413 | ```
414 | 
415 | `softMaxCrossEntropy'` will be more efficient in computing gradients.
416 | 
417 | Again, I don't recommend doing this in most cases, and this should always be a
418 | last resort. To me, this is even less warranted than the situation above
419 | (about redundant additions), because any losses due to naive AD should be
420 | negligible. Only do this *after profiling and benchmarking*, when you are
421 | *sure* that a particular function composition is causing your bottleneck.
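One pragmatic safeguard (a sketch, not a library utility): before committing to
a hand-written gradient, check that it agrees with the automatic one on a few
sample inputs.

```haskell
-- a sketch: the manual and automatic gradients should agree on any
-- sample input, up to floating-point noise
gradsAgree :: R 3 -> R 3 -> Bool
gradsAgree targ inp = HU.norm_2 (H.extract (g1 - g2)) < 1e-9
  where
    g1 = gradBP (crossEntropy targ . softMax) inp
    g2 = gradBP (softMaxCrossEntropy' targ) inp
```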
422 | Don't do this for any ol' composition you write, because:
423 | 
424 | 1. Again, the *whole point* of this library is to allow you to *avoid*
425 |    computing gradients by hand.
426 | 2. Computing gradients by hand is very tricky, and there are many places where
427 |    you could introduce a bug in a subtle way that might not be apparent even
428 |    through initial testing.
429 | 3. This is very fragile, and any future changes to your function will require
430 |    you to completely re-compute and re-write your giant lifted function.
431 | 4. Again, it makes your code much harder to read and understand.
432 | 
433 | But, if you profile and benchmark and conclude that a bad composition is the
434 | bottleneck, know that this path is available.
435 | 
436 | 
--------------------------------------------------------------------------------
/doc/08-equipping-your-library.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: Equipping your Library
3 | ---
4 | 
5 | Equipping your Library for Backprop
6 | ===================================
7 | 
8 | ```haskell top hide
9 | {-# LANGUAGE DataKinds #-}
10 | {-# LANGUAGE DeriveGeneric #-}
11 | {-# LANGUAGE FlexibleContexts #-}
12 | {-# LANGUAGE FlexibleInstances #-}
13 | {-# LANGUAGE StandaloneDeriving #-}
14 | {-# LANGUAGE TemplateHaskell #-}
15 | {-# LANGUAGE TypeFamilies #-}
16 | {-# LANGUAGE ViewPatterns #-}
17 | 
18 | 
19 | import Data.Functor.Identity
20 | import GHC.Generics (Generic)
21 | import Lens.Micro
22 | import Lens.Micro.TH
23 | import Numeric.Backprop
24 | import Numeric.Backprop.Class
25 | import System.Random
26 | import qualified Data.List
27 | import qualified Data.Vector as V
28 | ```
29 | 
30 | So you want your users to be able to use your numerical library with
31 | *backprop*, huh?
32 | 
33 | This page is specifically for library authors who want to allow their users to
34 | use their library operations and API with *backprop*. End-users of the
35 | *backprop* library should not have to worry about the contents of this page.
36 | 
37 | Equipping your library with backprop involves providing "backprop-aware"
38 | versions of your library functions. *In fact*, it is possible to build a
39 | library entirely by providing *only* backprop-aware versions of your functions,
40 | since you can use a backprop-aware function as a normal function with `evalBP`.
41 | Alternatively, you can provide the "backprop-aware" versions of all of your
42 | functions in a separate module.
43 | 
44 | Know Thy Types
45 | --------------
46 | 
47 | The most significant effort will be in lifting your library's functions. If
48 | you have a function:
49 | 
50 | ```haskell
51 | myFunc :: a -> b
52 | ```
53 | 
54 | Then its lifted version would have type:
55 | 
56 | ```haskell
57 | myFunc :: Reifies s W => BVar s a -> BVar s b
58 | ```
59 | 
60 | That is, instead of a function directly taking an `a` and returning a `b`, it's
61 | a function taking a `BVar` containing an `a`, and returning a `BVar` containing
62 | a `b`.
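For example (a purely hypothetical pairing -- `normalize` and the vector type
`R n` here are illustrative, not from any particular API), the original and
lifted signatures line up like this:

```haskell
-- hypothetical: an unlifted function from your library...
normalize   :: R n -> R n
-- ...and its lifted, backprop-aware counterpart
normalizeBP :: Reifies s W => BVar s (R n) -> BVar s (R n)
```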
63 | 
64 | Functions taking multiple arguments can be translated pretty straightforwardly:
65 | 
66 | ```haskell
67 | func1   :: a -> b -> c
68 | func1BP :: Reifies s W => BVar s a -> BVar s b -> BVar s c
69 | ```
70 | 
71 | And so can functions returning multiple results:
72 | 
73 | ```haskell
74 | func2   :: a -> (b, c)
75 | func2BP :: Reifies s W => BVar s a -> (BVar s b, BVar s c)
76 | ```
77 | 
78 | It is recommended (for ease of use with `-XTypeApplications`) that `Reifies s
79 | W` be the *final* constraint in all code you write.
80 | 
81 | Note that almost all operations involving `BVar`'d items require that the
82 | contents have a `Backprop` instance. Alternative APIs to backprop that
83 | require `Num` instances instead (or explicitly specified addition functions)
84 | are available in *Numeric.Backprop.Num* and *Numeric.Backprop.Explicit*.
85 | 
86 | The Easy Way
87 | ------------
88 | 
89 | `BVar` based functions are just normal functions, so they can be applied
90 | normally and passed as first-class values. If you can *utilize* functions
91 | that are already `BVar`'d/lifted, then you can just define your API in terms
92 | of those lifted functions. This is also how *users* are expected to use
93 | your library: they just use the lifted functions you provide to build their
94 | own lifted functions, using normal function application and
95 | composition.
96 | 
97 | Lifting operations manually
98 | ---------------------------
99 | 
100 | However, if no lifted primitive functions are available, then you do have to do
101 | some legwork to provide information on gradient computation for your types.
102 | Ideally, you would only need to do this for some minimal set of your
103 | operations, and then define the rest of them in terms of the functions you have
104 | already lifted.
105 | 
106 | A full tutorial on lifting your library functions [can be found
107 | here][manual-gradients]. It describes the usage of the `liftOp` and `op`
108 | family of functions to fully lift your single-argument single-result and
109 | multiple-argument single-result functions to be backpropagatable.
110 | 
111 | [manual-gradients]: https://backprop.jle.im/06-manual-gradients.html
112 | 
113 | ### Returning multiple items
114 | 
115 | As an extension of the [manual gradient tutorial][manual-gradients], we can
116 | consider functions that return multiple items.
117 | 
118 | You can always return tuples inside `BVar`s:
119 | 
120 | ```haskell top
121 | splitAt
122 |     :: (Backprop a, Reifies s W)
123 |     => Int
124 |     -> BVar s [a]
125 |     -> BVar s ([a], [a])
126 | splitAt n = liftOp1 . op1 $ \xs ->
127 |     let (ys, zs) = Data.List.splitAt n xs
128 |     in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)
129 |     -- assumes dys and dzs have the same lengths as ys and zs
130 | ```
131 | 
132 | This works as expected. However, it is recommended, for the benefit of your
133 | users, that you return a tuple of `BVar`s instead of a `BVar` of tuples:
134 | 
135 | ```haskell top
136 | splitAt'
137 |     :: (Backprop a, Reifies s W)
138 |     => Int
139 |     -> BVar s [a]
140 |     -> (BVar s [a], BVar s [a])
141 | splitAt' n xs = (yszs ^^. _1, yszs ^^. _2)
142 |   where
143 |     yszs = liftOp1 (op1 $ \xs' ->
144 |         let (ys, zs) = Data.List.splitAt n xs'
145 |         in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)
146 |       ) xs
147 | ```
148 | 
149 | This uses `_1` and `_2` from the *[microlens][]* or *[lens][]* packages.
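As a quick usage sketch (hypothetical code; it assumes *Prelude.Backprop*
imported qualified as `P`), the two halves returned by `splitAt'` are ordinary
first-class `BVar`s that can be used independently:

```haskell
-- a sketch: backpropagate through only the first half of the split
sumOfFront :: Reifies s W => BVar s [Double] -> BVar s Double
sumOfFront = P.sum . fst . splitAt' 2
```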
150 | The definition of `splitAt'` might also be cleaner if you take advantage of
151 | the `T2` or `T3` pattern synonyms:
152 | 
153 | [microlens]: http://hackage.haskell.org/package/microlens
154 | [lens]: http://hackage.haskell.org/package/lens
155 | 
156 | ```haskell top
157 | splitAt''
158 |     :: (Backprop a, Reifies s W)
159 |     => Int
160 |     -> BVar s [a]
161 |     -> (BVar s [a], BVar s [a])
162 | splitAt'' n xs = (ys, zs)
163 |   where
164 |     T2 ys zs = liftOp1 (op1 $ \xs' ->
165 |         let (ys, zs) = Data.List.splitAt n xs'
166 |         in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)
167 |       ) xs
168 | ```
169 | 
170 | ### Isomorphisms
171 | 
172 | If your function witnesses an isomorphism, there are handy combinators for
173 | making this easy to write. This is especially useful in the case of data
174 | constructors:
175 | 
176 | ```haskell top
177 | newtype Foo = MkFoo { getFoo :: Double }
178 |   deriving Generic
179 | 
180 | instance Backprop Foo
181 | 
182 | mkFoo
183 |     :: Reifies s W
184 |     => BVar s Double
185 |     -> BVar s Foo
186 | mkFoo = isoVar MkFoo getFoo
187 | 
188 | data Bar = MkBar { bar1 :: Double, bar2 :: Float }
189 |   deriving Generic
190 | 
191 | instance Backprop Bar
192 | 
193 | mkBar
194 |     :: Reifies s W
195 |     => BVar s Double
196 |     -> BVar s Float
197 |     -> BVar s Bar
198 | mkBar = isoVar2 MkBar (\b -> (bar1 b, bar2 b))
199 | ```
200 | 
201 | Note also that if you have a newtype (or any other two `Coercible` types),
202 | you can simply use `coerceVar`:
203 | 
204 | ```haskell top
205 | mkFoo'
206 |     :: BVar s Double
207 |     -> BVar s Foo
208 | mkFoo' = coerceVar -- requires no `Reifies s W` constraint
209 | ```
210 | 
211 | ### NoGrad
212 | 
213 | If you do decide to go to the extreme and provide *only* a BVar-based
214 | interface to your library (and no non-BVar based one), then you might run
215 | into a function for which you cannot define the gradient --
216 | maybe no gradient exists, or you haven't put in the time to write one. In this
217 | case, you can use `noGrad` and `noGrad1`:
218 | 
219 | ```haskell top
220 | negateNoGrad
221 |     :: (Num a, Backprop a, Reifies s W)
222 |     => BVar s a
223 |     -> BVar s a
224 | negateNoGrad = liftOp1 (noGrad1 negate)
225 | ```
226 | 
227 | This function can still be used with `evalBP` to get the correct answer. It
228 | can even be used with `gradBP` if the result is never used in the final answer.
229 | 
230 | However, if it *is* used in the final answer, then computing the gradient will
231 | throw a runtime exception.
232 | 
233 | Be sure to warn your users! Like any partial function, this is not recommended
234 | except in extreme circumstances.
235 | 
236 | Monadic Operations
237 | ------------------
238 | 
239 | This should all work if your operations are all "pure". However, what about
240 | the cases where your operations have to be performed in some Applicative or
241 | Monadic context?
242 | 
243 | For example, what if `add :: X -> X -> IO X`?
244 | 
245 | One option is to newtype-wrap your operations, and then give the wrapped type
246 | a `Backprop` instance:
247 | 
248 | ```haskell top hide
249 | data X
250 | 
251 | zeroForX :: X -> X
252 | zeroForX = undefined
253 | addForX :: X -> X -> IO X
254 | addForX = undefined
255 | oneForX :: X -> X
256 | oneForX = undefined
257 | ```
258 | 
259 | ```haskell top
260 | newtype IOX = IOX (IO X)
261 | 
262 | instance Backprop IOX where
263 |     zero (IOX x) = IOX (fmap zeroForX x)
264 |     -- or, depending on the type of `zeroForX`:
265 |     -- zero (IOX x) = IOX (zeroForX =<< x)
266 | 
267 |     add (IOX x) (IOX y) = IOX $ do
268 |         x' <- x
269 |         y' <- y
270 |         addForX x' y'
271 | 
272 |     one (IOX x) = IOX (fmap oneForX x)
273 | ```
274 | 
275 | And you can define your functions in terms of this:
276 | 
277 | ```haskell top
278 | addX
279 |     :: Reifies s W
280 |     => BVar s IOX
281 |     -> BVar s IOX
282 |     -> BVar s IOX
283 | addX = liftOp2 . op2 $ \(IOX x) (IOX y) ->
284 |     ( IOX (do x' <- x; y' <- y; addForX x' y')
285 |     , \dzdy -> (dzdy, dzdy)
286 |     )
287 | ```
288 | 
289 | This should work fine as long as you never "branch" on any *results* of your
290 | actions. You must never need to peek inside the *results* of an action in
291 | order to decide *what* operations to do next. In other words, this works if
292 | the operations you need to perform are all known and fixed beforehand, before
293 | any actions are performed. So, this means no access to the `Eq` or `Ord`
294 | instances of `BVar`s (unless your monad has `Eq` or `Ord` instances defined).
295 | 
296 | A newtype wrapper is provided to give you this behavior automatically -- it's
297 | `ABP`, from *Numeric.Backprop* and *Numeric.Backprop.Class*.
298 | 
299 | ```haskell
300 | type IOX = ABP IO X
301 | ```
302 | 
303 | However, this will not work if you need to do things like compare contents,
304 | etc. to decide what operations to use.
305 | 
306 | At the moment, this is not supported. Please open an issue if this becomes a
307 | problem for you!
308 | 
309 | Supporting Data Types
310 | ---------------------
311 | 
312 | Your library will probably have data types that you expect your users to use.
313 | To equip your data types for backpropagation, you can take a few steps.
314 | 
315 | ### Backprop Class
316 | 
317 | First of all, all of your library's types should have instances of the
318 | [`Backprop` typeclass][class]. This allows values of your type to be used in
319 | backpropagatable functions. See the [Backprop typeclass section][tcdocs] of
320 | this documentation for more information on writing a `Backprop` instance for
321 | your types.
322 | 
323 | [class]: https://hackage.haskell.org/package/backprop/docs/Numeric-Backprop-Class.html
324 | [tcdocs]: https://backprop.jle.im/04-the-backprop-typeclass.html
325 | 
326 | In short:
327 | 
328 | 1. If your type has a single constructor whose fields are all
329 |    instances of `Backprop`, you can just write `instance Backprop MyType`, and
330 |    the instance is generated automatically (as long as your type has a
331 |    `Generic` instance).
332 | 
333 |    ```haskell top
334 |    data MyType = MkMyType Double [Float] (R 10) (L 20 10) (V.Vector Double)
335 |      deriving Generic
336 | 
337 |    instance Backprop MyType
338 |    ```
339 | 
340 | 2. If your type is an instance of `Num`, you can use `zeroNum`, `addNum`, and
341 |    `oneNum` to get free definitions of the typeclass methods.
342 | 
343 |    ```haskell
344 |    instance Backprop Double where
345 |        zero = zeroNum
346 |        add  = addNum
347 |        one  = oneNum
348 |    ```
349 | 
350 | 3. If your type is an instance of `Functor`, you can use `zeroFunctor`
351 |    and `oneFunctor`:
352 | 
353 |    ```haskell
354 |    instance Backprop a => Backprop (V.Vector a) where
355 |        zero = zeroFunctor
356 |        add  = undefined -- ??
357 |        one  = oneFunctor
358 |    ```
359 | 
360 | 4. If your type has an `IsList` instance, you can use `addIsList`:
361 | 
362 |    ```haskell
363 |    instance Backprop a => Backprop (V.Vector a) where
364 |        zero = zeroFunctor
365 |        add  = addIsList
366 |        one  = oneFunctor
367 |    ```
368 | 
369 | For more details, see the [aforementioned documentation][tcdocs] or the [actual
370 | typeclass haddock documentation][class].
371 | 
372 | ### Accessors
373 | 
374 | If you have product types, users should be able to access values inside `BVar`s
375 | of your data type. There are two main ways to provide access: the lens-based
376 | interface and the higher-kinded-data-based interface.
377 | 
378 | The lens-based interface gives your users "getter" and "setter" functions for
379 | fields, and the higher-kinded-data-based interface lets your users pattern
380 | match on your data type's original constructor to get fields and construct
381 | values.
382 | 
383 | #### Lens-Based Interface
384 | 
385 | If you are defining a product type, like
386 | 
387 | ```haskell top
388 | data MyType = MT { _mtDouble  :: Double
389 |                  , _mtInt     :: Int
390 |                  , _mtDoubles :: [Double]
391 |                  }
392 | ```
393 | 
394 | Users who have a `BVar s MyType` can't normally access the fields inside,
395 | because you can't pattern match on a `BVar` directly, and the record accessors
396 | are unlifted functions like `MyType -> Int`. As a library maintainer, you can
397 | provide them *lenses* to the fields, either generated automatically using the
398 | *[lens][]* or *[microlens-th][]* packages:
399 | 
400 | [lens]: http://hackage.haskell.org/package/lens
401 | [microlens-th]: http://hackage.haskell.org/package/microlens-th
402 | 
403 | ```haskell top
404 | -- requires -XTemplateHaskell
405 | makeLenses ''MyType
406 | ```
407 | 
408 | or written by hand:
409 | 
410 | ```haskell top
411 | mtInt' :: Functor f => (Int -> f Int) -> MyType -> f MyType
412 | mtInt' f mt = (\i -> mt { _mtInt = i }) <$> f (_mtInt mt)
413 | ```
414 | 
415 | Now, users can use `^.` or `view` from the *lens* or *[microlens][]* packages
416 | to retrieve your fields:
417 | 
418 | [microlens]: http://hackage.haskell.org/package/microlens
419 | 
420 | ```haskell
421 | (^. mtDouble) :: MyType -> Double
422 | ```
423 | 
424 | And `(^^.)` and `viewVar` from *backprop* to retrieve fields from a `BVar`:
425 | 
426 | ```haskell
427 | (^^. mtDouble) :: BVar s MyType -> BVar s Double
428 | ```
429 | 
430 | They can also use `set` or `.~` to modify fields, and `setVar` and `.~~` to
431 | modify and "set" fields in a `BVar`:
432 | 
433 | ```haskell
434 | set mtDouble    :: Double        -> MyType        -> MyType
435 | setVar mtDouble :: BVar s Double -> BVar s MyType -> BVar s MyType
436 | ```
437 | 
438 | Likewise, `over` and `%~` can be used to apply a function to the contents of a
439 | field, and `overVar` and `%~~` can be used to apply backpropagatable functions
440 | over fields of a value in a `BVar`.
441 | 
442 | #### Higher-Kinded Data Interface
443 | 
444 | The alternative "Higher-Kinded Data" technique, inspired by [this
445 | article][hkd], allows your users to directly pattern match on `BVar`s of your
446 | types to get their contents.
447 | 448 | [hkd]: http://reasonablypolymorphic.com/blog/higher-kinded-data/ 449 | 450 | Doing this requires modifying the definition of your data types slightly. 451 | Instead of `MyType` above, we can make a type family that can be re-used for 452 | all of your data types: 453 | 454 | ```haskell top 455 | type family HKD f a where 456 | HKD Identity a = a 457 | HKD f a = f a 458 | ``` 459 | 460 | and define your data types in terms of this type family (remembering to derive 461 | `Generic`): 462 | 463 | ```haskell top 464 | data MyType2' f = MT2 { mt2Double :: HKD f Double 465 | , mt2Int :: HKD f Int 466 | , mt2Doubles :: HKD f [Double] 467 | } 468 | deriving Generic 469 | ``` 470 | 471 | Now your original data type can be recovered with `MyType2' Identity`, and can 472 | be pattern matched directly in the same way as the original type (the 473 | `Identity` disappears): 474 | 475 | ```haskell top 476 | type MyType2 = MyType2' Identity 477 | 478 | deriving instance Show MyType2 479 | instance Backprop MyType2 480 | 481 | getMT2Double :: MyType2 -> Double 482 | getMT2Double (MT2 d _ _) = d 483 | ``` 484 | 485 | But now, users can *pattern match* on a `BVar s MyType2` to get `BVar`s of the 486 | contents, with `splitBV` or the `BV` pattern synonym: 487 | 488 | ```haskell top 489 | getMT2DoubleBVar 490 | :: Reifies s W 491 | => BVar s MyType2 492 | -> BVar s Double 493 | getMT2DoubleBVar (splitBV -> MT2 d _ _) = d 494 | ``` 495 | 496 | Under `splitBV`, your users can pattern match on the `MT2` constructor and get 497 | the contents as `BVar`s. 498 | 499 | Note that HKD access through pattern matching is potentially less performant 500 | than access using lens (by about 10-20%). 501 | 502 | Users can also use `joinBV` (or the `BV` pattern synonym in constructor mode) 503 | to re-construct a `BVar` of `MyType2` in terms of `BVar`s of its contents using 504 | the `MT2` constructor: 505 | 506 | ```haskell top 507 | makeMyType2 508 | :: Reifies s W 509 | => BVar s Double 510 | -> BVar s Int 511 | -> BVar s [Double] 512 | -> BVar s MyType2 513 | makeMyType2 d i ds = joinBV $ MT2 d i ds 514 | ``` 515 | -------------------------------------------------------------------------------- /doc/09-comparisons.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Comparisons 3 | --- 4 | 5 | Comparisons 6 | =========== 7 | 8 | *backprop* can be compared and contrasted to many other similar libraries with 9 | some overlap: 10 | 11 | 1. The *[ad][]* library (and variants like *[diffhask][]*) support automatic 12 | differentiation, but only for *homogeneous*/*monomorphic* situations. All 13 | values in a computation must be of the same type --- so, your computation 14 | might be the manipulation of `Double`s through a `Double -> Double` 15 | function. 16 | 17 | *backprop* allows you to mix matrices, vectors, doubles, integers, and even 18 | key-value maps as a part of your computation, and they will all be 19 | backpropagated properly with the help of the `Backprop` typeclass. 20 | 21 | 2. The *[autograd][]* library is a very close equivalent to *backprop*, 22 | implemented in Python for Python applications. The difference between 23 | *backprop* and *autograd* is mostly the difference between Haskell and 24 | Python --- static types with type inference, purity, etc. 25 | 26 | 3. 
There is some overlap between *backprop* and deep learning/neural network
27 |    libraries like *[tensorflow][]*, *[caffe][]*, and *[theano][]*, which all
28 |    support some form of heterogeneous automatic differentiation. Haskell
29 |    libraries doing similar things include *[grenade][]*.
30 | 
31 |    These are all frameworks for working with neural networks or other
32 |    gradient-based optimizations --- they include things like built-in
33 |    optimizers, methods to automate training on data, and built-in models to
34 |    use out of the box. *backprop* could be used as a *part* of such a
35 |    framework, like I described in my [A Purely Functional Typed Approach to
36 |    Trainable Models][models] blog series; however, the *backprop* library
37 |    itself does not provide any built-in models or optimizers or automated
38 |    data processing pipelines.
39 | 
40 | [ad]: https://hackage.haskell.org/package/ad
41 | [diffhask]: https://hackage.haskell.org/package/diffhask
42 | [autograd]: https://github.com/HIPS/autograd
43 | [tensorflow]: https://www.tensorflow.org/
44 | [caffe]: http://caffe.berkeleyvision.org/
45 | [theano]: http://www.deeplearning.net/software/theano/
46 | [grenade]: http://hackage.haskell.org/package/grenade
47 | [models]: https://blog.jle.im/entry/purely-functional-typed-models-1.html
48 | 
--------------------------------------------------------------------------------
/doc/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: Home
3 | ---
4 | 
5 | Welcome to Backprop
6 | ===================
7 | 
8 | Automatic *heterogeneous* back-propagation.
9 | 
10 | *Write your functions normally* to compute your result, and the library will
11 | *automatically compute your gradient*!
12 | 
13 | ```haskell top hide
14 | import Numeric.Backprop
15 | ```
16 | 
17 | ```haskell eval
18 | gradBP (\x -> x^2 + 3) (9 :: Double)
19 | ```
20 | 
21 | Differs from [ad][] by offering full heterogeneity -- each intermediate step
22 | and the resulting value can have different types (matrices, vectors, scalars,
23 | lists, etc.).
24 | 
25 | [ad]: http://hackage.haskell.org/package/ad
26 | 
27 | ```haskell eval
28 | gradBP2 (\x xs -> sum (map (**2) (sequenceVar xs)) / x)
29 |     (9       :: Double  )
30 |     ([1,6,2] :: [Double])
31 | ```
32 | 
33 | Useful for applications in [differentiable programming][dp] and deep learning
34 | for creating and training numerical models, especially as described in this
35 | blog post on [a purely functional typed approach to trainable models][models].
36 | Overall, intended for the implementation of gradient descent and other numeric
37 | optimization techniques. Comparable to the Python library [autograd][].
38 | 
39 | [dp]: https://www.facebook.com/yann.lecun/posts/10155003011462143
40 | [models]: https://blog.jle.im/entry/purely-functional-typed-models-1.html
41 | [autograd]: https://github.com/HIPS/autograd
42 | 
43 | **[Get started][getting started]** with the introduction and walkthrough! Full
44 | technical documentation is also **[available on hackage][hackage]** if you want
45 | to skip the introduction and get right into using the library. Support is
46 | available on the **[gitter channel][gitter]**!
47 | 48 | [getting started]: https://backprop.jle.im/01-getting-started.html 49 | 50 | [hackage]: http://hackage.haskell.org/package/backprop 51 | [gitter]: https://gitter.im/haskell-backprop/Lobby 52 | 53 | [![Join the chat at https://gitter.im/haskell-backprop/Lobby](https://badges.gitter.im/haskell-backprop/Lobby.svg)](https://gitter.im/haskell-backprop/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 54 | 55 | [![backprop on Hackage](https://img.shields.io/hackage/v/backprop.svg?maxAge=86400)](https://hackage.haskell.org/package/backprop) 56 | [![backprop on Stackage LTS 11](http://stackage.org/package/backprop/badge/lts-11)](http://stackage.org/lts-11/package/backprop) 57 | [![backprop on Stackage Nightly](http://stackage.org/package/backprop/badge/nightly)](http://stackage.org/nightly/package/backprop) 58 | [![Build Status](https://travis-ci.org/mstksg/backprop.svg?branch=master)](https://travis-ci.org/mstksg/backprop) 59 | 60 | -------------------------------------------------------------------------------- /doctest/doctest.hs: -------------------------------------------------------------------------------- 1 | import System.FilePath.Glob (glob) 2 | import Test.DocTest (doctest) 3 | 4 | main :: IO () 5 | main = glob "src/**/*.hs" >>= doctest 6 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Basic Haskell Project Flake"; 3 | inputs = { 4 | haskellProjectFlake.url = "github:mstksg/haskell-project-flake"; 5 | nixpkgs.follows = "haskellProjectFlake/nixpkgs"; 6 | }; 7 | outputs = 8 | { self 9 | , nixpkgs 10 | , flake-utils 11 | , haskellProjectFlake 12 | }: 13 | flake-utils.lib.eachDefaultSystem (system: 14 | let 15 | name = "backprop"; 16 | pkgs = import nixpkgs { 17 | inherit system; 18 | overlays = [ haskellProjectFlake.overlays."${system}".default ]; 19 | }; 20 | project-flake = pkgs.haskell-project-flake 21 | { 22 | inherit name; 23 | src = ./.; 24 | # lapack seems to link badly on almost all versions so CI is 25 | # screwed until we figure out what's going on 26 | excludeCompilerMajors = [ "ghc94" "ghc913" ]; 27 | defaultCompiler = "ghc982"; 28 | }; 29 | in 30 | { 31 | packages = project-flake.packages; 32 | apps = project-flake.apps; 33 | checks = project-flake.checks; 34 | devShells = project-flake.devShells; 35 | legacyPackages."${name}" = project-flake; 36 | } 37 | ); 38 | } 39 | 40 | -------------------------------------------------------------------------------- /fourmolu.yaml: -------------------------------------------------------------------------------- 1 | column-limit: 100 2 | comma-style: leading 3 | fixities: [] 4 | function-arrows: trailing 5 | haddock-style: single-line 6 | haddock-style-module: null 7 | import-export-style: diff-friendly 8 | in-style: right-align 9 | indent-wheres: true 10 | indentation: 2 11 | let-style: inline 12 | newlines-between-decls: 1 13 | record-break-space: true 14 | reexports: [] 15 | respectful: true 16 | single-constraint-parens: never 17 | unicode: detect 18 | -------------------------------------------------------------------------------- /renders/backprop-mnist.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mstksg/backprop/4826ca2f95e1706f9ff8f452141a76cc0f652fbf/renders/backprop-mnist.pdf -------------------------------------------------------------------------------- 
/renders/extensible-neural.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mstksg/backprop/4826ca2f95e1706f9ff8f452141a76cc0f652fbf/renders/extensible-neural.pdf -------------------------------------------------------------------------------- /src/Data/Type/Util.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE GADTs #-} 3 | {-# LANGUAGE LambdaCase #-} 4 | {-# LANGUAGE PatternSynonyms #-} 5 | {-# LANGUAGE RankNTypes #-} 6 | {-# LANGUAGE ScopedTypeVariables #-} 7 | {-# LANGUAGE TupleSections #-} 8 | {-# LANGUAGE TypeFamilyDependencies #-} 9 | {-# LANGUAGE TypeInType #-} 10 | {-# LANGUAGE TypeOperators #-} 11 | {-# LANGUAGE UndecidableInstances #-} 12 | 13 | module Data.Type.Util ( 14 | runzipWith, 15 | rzipWithM_, 16 | Replicate, 17 | VecT (.., (:+)), 18 | vmap, 19 | withVec, 20 | vecToRec, 21 | fillRec, 22 | zipVecList, 23 | splitRec, 24 | p1, 25 | p2, 26 | s1, 27 | s2, 28 | ) where 29 | 30 | import Data.Bifunctor 31 | import Data.Functor.Identity 32 | import Data.Kind 33 | import Data.Proxy 34 | import Data.Vinyl.Core 35 | import Data.Vinyl.TypeLevel 36 | import GHC.Generics 37 | import Lens.Micro 38 | 39 | runzipWith :: 40 | forall f g h. 41 | () => 42 | (forall x. f x -> (g x, h x)) -> 43 | (forall xs. Rec f xs -> (Rec g xs, Rec h xs)) 44 | runzipWith f = go 45 | where 46 | go :: forall ys. Rec f ys -> (Rec g ys, Rec h ys) 47 | go = \case 48 | RNil -> (RNil, RNil) 49 | x :& xs -> 50 | let (y, z) = f x 51 | (ys, zs) = go xs 52 | in (y :& ys, z :& zs) 53 | {-# INLINE runzipWith #-} 54 | 55 | data VecT :: Nat -> (k -> Type) -> k -> Type where 56 | VNil :: VecT 'Z f a 57 | (:*) :: !(f a) -> VecT n f a -> VecT ('S n) f a 58 | 59 | pattern (:+) :: a -> VecT n Identity a -> VecT ('S n) Identity a 60 | pattern x :+ xs = Identity x :* xs 61 | 62 | vmap :: 63 | forall n f g a. 64 | () => 65 | (f a -> g a) -> VecT n f a -> VecT n g a 66 | vmap f = go 67 | where 68 | go :: VecT m f a -> VecT m g a 69 | go = \case 70 | VNil -> VNil 71 | x :* xs -> f x :* go xs 72 | {-# INLINE vmap #-} 73 | 74 | withVec :: 75 | [f a] -> 76 | (forall n. VecT n f a -> r) -> 77 | r 78 | withVec = \case 79 | [] -> \f -> f VNil 80 | x : xs -> \f -> withVec xs (f . (x :*)) 81 | {-# INLINE withVec #-} 82 | 83 | type family Replicate (n :: Nat) (a :: k) = (as :: [k]) | as -> n where 84 | Replicate 'Z a = '[] 85 | Replicate ('S n) a = a ': Replicate n a 86 | 87 | vecToRec :: 88 | VecT n f a -> 89 | Rec f (Replicate n a) 90 | vecToRec = \case 91 | VNil -> RNil 92 | x :* xs -> x :& vecToRec xs 93 | {-# INLINE vecToRec #-} 94 | 95 | fillRec :: 96 | forall f g as c. 97 | () => 98 | (forall a. f a -> c -> g a) -> 99 | Rec f as -> 100 | [c] -> 101 | Maybe (Rec g as) 102 | fillRec f = go 103 | where 104 | go :: Rec f bs -> [c] -> Maybe (Rec g bs) 105 | go = \case 106 | RNil -> \_ -> Just RNil 107 | x :& xs -> \case 108 | [] -> Nothing 109 | y : ys -> (f x y :&) <$> go xs ys 110 | {-# INLINE fillRec #-} 111 | 112 | rzipWithM_ :: 113 | forall h f g as. 114 | Applicative h => 115 | (forall a. f a -> g a -> h ()) -> 116 | Rec f as -> 117 | Rec g as -> 118 | h () 119 | rzipWithM_ f = go 120 | where 121 | go :: forall bs. Rec f bs -> Rec g bs -> h () 122 | go = \case 123 | RNil -> \case 124 | RNil -> pure () 125 | x :& xs -> \case 126 | y :& ys -> f x y *> go xs ys 127 | {-# INLINE rzipWithM_ #-} 128 | 129 | zipVecList :: 130 | forall a b c f g n. 
131 |   () =>
132 |   (f a -> Maybe b -> g c) ->
133 |   VecT n f a ->
134 |   [b] ->
135 |   VecT n g c
136 | zipVecList f = go
137 |   where
138 |     go :: VecT m f a -> [b] -> VecT m g c
139 |     go = \case
140 |       VNil -> const VNil
141 |       x :* xs -> \case
142 |         [] -> f x Nothing :* go xs []
143 |         y : ys -> f x (Just y) :* go xs ys
144 | {-# INLINE zipVecList #-}
145 | 
146 | splitRec ::
147 |   forall f as bs.
148 |   RecApplicative as =>
149 |   Rec f (as ++ bs) ->
150 |   (Rec f as, Rec f bs)
151 | splitRec = go (rpure Proxy)
152 |   where
153 |     go :: Rec Proxy as' -> Rec f (as' ++ bs) -> (Rec f as', Rec f bs)
154 |     go = \case
155 |       RNil -> (RNil,)
156 |       _ :& ps -> \case
157 |         x :& xs -> first (x :&) $ go ps xs
158 | {-# INLINE splitRec #-}
159 | 
160 | p1 :: Lens' ((f :*: g) a) (f a)
161 | p1 f (x :*: y) = (:*: y) <$> f x
162 | {-# INLINE p1 #-}
163 | 
164 | p2 :: Lens' ((f :*: g) a) (g a)
165 | p2 f (x :*: y) = (x :*:) <$> f y
166 | {-# INLINE p2 #-}
167 | 
168 | s1 :: Traversal' ((f :+: g) a) (f a)
169 | s1 f (L1 x) = L1 <$> f x
170 | s1 _ (R1 y) = pure (R1 y)
171 | {-# INLINE s1 #-}
172 | 
173 | s2 :: Traversal' ((f :+: g) a) (g a)
174 | s2 _ (L1 x) = pure (L1 x)
175 | s2 f (R1 y) = R1 <$> f y
176 | {-# INLINE s2 #-}
177 | 
--------------------------------------------------------------------------------
/src/Numeric/Backprop/Explicit.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE EmptyCase #-}
3 | {-# LANGUAGE FlexibleContexts #-}
4 | {-# LANGUAGE FlexibleInstances #-}
5 | {-# LANGUAGE FunctionalDependencies #-}
6 | {-# LANGUAGE GADTs #-}
7 | {-# LANGUAGE LambdaCase #-}
8 | {-# LANGUAGE RankNTypes #-}
9 | {-# LANGUAGE ScopedTypeVariables #-}
10 | {-# LANGUAGE TypeApplications #-}
11 | {-# LANGUAGE TypeOperators #-}
12 | {-# LANGUAGE UndecidableInstances #-}
13 | {-# OPTIONS_HADDOCK not-home #-}
14 | 
15 | -- |
16 | -- Module : Numeric.Backprop.Explicit
17 | -- Copyright : (c) Justin Le 2023
18 | -- License : BSD3
19 | --
20 | -- Maintainer : justin@jle.im
21 | -- Stability : experimental
22 | -- Portability : non-portable
23 | --
24 | -- Provides "explicit" versions of all of the functions in
25 | -- "Numeric.Backprop". Instead of relying on a 'Backprop' instance, allows
26 | -- you to manually provide 'zero', 'add', and 'one' on a per-value basis.
27 | --
28 | -- It is recommended you use "Numeric.Backprop" or "Numeric.Backprop.Num"
29 | -- instead, unless your type has no 'Num' instance, or else you want to
30 | -- avoid defining orphan 'Backprop' instances for external types. It can
31 | -- also be useful if you are mixing and matching styles.
32 | --
33 | -- See "Numeric.Backprop" for fuller documentation on using these
34 | -- functions.
35 | --
36 | -- WARNING: API of this module can be considered only "semi-stable"; while
37 | -- the API of "Numeric.Backprop" and "Numeric.Backprop.Num" are kept
38 | -- consistent, some argument order changes might happen in this module to
39 | -- reflect changes in underlying implementation.
40 | -- 41 | -- @since 0.2.0.0 42 | module Numeric.Backprop.Explicit ( 43 | -- * Types 44 | BVar, 45 | W, 46 | Backprop (..), 47 | ABP (..), 48 | NumBP (..), 49 | 50 | -- * Explicit 'zero', 'add', and 'one' 51 | ZeroFunc (..), 52 | zfNum, 53 | zfNums, 54 | zeroFunc, 55 | zeroFuncs, 56 | zfFunctor, 57 | AddFunc (..), 58 | afNum, 59 | afNums, 60 | addFunc, 61 | addFuncs, 62 | OneFunc (..), 63 | ofNum, 64 | ofNums, 65 | oneFunc, 66 | oneFuncs, 67 | ofFunctor, 68 | 69 | -- * Running 70 | backprop, 71 | evalBP, 72 | gradBP, 73 | backpropWith, 74 | 75 | -- ** Multiple inputs 76 | evalBP0, 77 | backprop2, 78 | evalBP2, 79 | gradBP2, 80 | backpropWith2, 81 | backpropN, 82 | evalBPN, 83 | gradBPN, 84 | backpropWithN, 85 | RPureConstrained, 86 | 87 | -- * Manipulating 'BVar' 88 | constVar, 89 | auto, 90 | coerceVar, 91 | viewVar, 92 | setVar, 93 | overVar, 94 | sequenceVar, 95 | collectVar, 96 | previewVar, 97 | toListOfVar, 98 | 99 | -- ** With Isomorphisms 100 | isoVar, 101 | isoVar2, 102 | isoVar3, 103 | isoVarN, 104 | 105 | -- ** With 'Op's 106 | liftOp, 107 | liftOp1, 108 | liftOp2, 109 | liftOp3, 110 | 111 | -- ** Generics 112 | splitBV, 113 | joinBV, 114 | BVGroup, 115 | 116 | -- * 'Op' 117 | Op (..), 118 | 119 | -- ** Creation 120 | op0, 121 | opConst, 122 | idOp, 123 | bpOp, 124 | 125 | -- *** Giving gradients directly 126 | op1, 127 | op2, 128 | op3, 129 | 130 | -- *** From Isomorphisms 131 | opCoerce, 132 | opTup, 133 | opIso, 134 | opIsoN, 135 | opLens, 136 | 137 | -- *** No gradients 138 | noGrad1, 139 | noGrad, 140 | 141 | -- * Utility 142 | Reifies, 143 | ) where 144 | 145 | import Data.Bifunctor 146 | import Data.Functor.Identity 147 | import Data.Reflection 148 | import Data.Type.Util 149 | import Data.Vinyl.Core 150 | import Data.Vinyl.TypeLevel 151 | import GHC.Generics as G 152 | import Lens.Micro 153 | import Numeric.Backprop.Class 154 | import Numeric.Backprop.Internal 155 | import Numeric.Backprop.Op 156 | import Unsafe.Coerce 157 | 158 | -- | 'ZeroFunc's for every item in a type level list based on their 159 | -- 'Num' instances 160 | -- 161 | -- @since 0.2.0.0 162 | zfNums :: RPureConstrained Num as => Rec ZeroFunc as 163 | zfNums = rpureConstrained @Num zfNum 164 | 165 | -- | 'zeroFunc' for instances of 'Functor' 166 | -- 167 | -- @since 0.2.1.0 168 | zfFunctor :: (Backprop a, Functor f) => ZeroFunc (f a) 169 | zfFunctor = ZF zeroFunctor 170 | {-# INLINE zfFunctor #-} 171 | 172 | -- | 'ZeroFunc's for every item in a type level list based on their 173 | -- 'Num' instances 174 | -- 175 | -- @since 0.2.0.0 176 | afNums :: RPureConstrained Num as => Rec AddFunc as 177 | afNums = rpureConstrained @Num afNum 178 | 179 | -- | 'ZeroFunc's for every item in a type level list based on their 180 | -- 'Num' instances 181 | -- 182 | -- @since 0.2.0.0 183 | ofNums :: RPureConstrained Num as => Rec OneFunc as 184 | ofNums = rpureConstrained @Num ofNum 185 | 186 | -- | 'OneFunc' for instances of 'Functor' 187 | -- 188 | -- @since 0.2.1.0 189 | ofFunctor :: (Backprop a, Functor f) => OneFunc (f a) 190 | ofFunctor = OF oneFunctor 191 | {-# INLINE ofFunctor #-} 192 | 193 | -- | Generate an 'ZeroFunc' for every type in a type-level list, if every 194 | -- type has an instance of 'Backprop'. 195 | -- 196 | -- @since 0.2.0.0 197 | zeroFuncs :: RPureConstrained Backprop as => Rec ZeroFunc as 198 | zeroFuncs = rpureConstrained @Backprop zeroFunc 199 | 200 | -- | Generate an 'AddFunc' for every type in a type-level list, if every 201 | -- type has an instance of 'Backprop'. 
202 | -- 203 | -- @since 0.2.0.0 204 | addFuncs :: RPureConstrained Backprop as => Rec AddFunc as 205 | addFuncs = rpureConstrained @Backprop addFunc 206 | 207 | -- | Generate an 'OneFunc' for every type in a type-level list, if every 208 | -- type has an instance of 'Backprop'. 209 | -- 210 | -- @since 0.2.0.0 211 | oneFuncs :: RPureConstrained Backprop as => Rec OneFunc as 212 | oneFuncs = rpureConstrained @Backprop oneFunc 213 | 214 | -- | Shorter alias for 'constVar', inspired by the /ad/ library. 215 | -- 216 | -- @since 0.2.0.0 217 | auto :: a -> BVar s a 218 | auto = constVar 219 | {-# INLINE auto #-} 220 | 221 | -- | 'Numeric.Backprop.backpropN', but with explicit 'zero' and 'one'. 222 | backpropN :: 223 | forall as b. 224 | () => 225 | Rec ZeroFunc as -> 226 | OneFunc b -> 227 | (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) -> 228 | Rec Identity as -> 229 | (b, Rec Identity as) 230 | backpropN zfs ob f xs = case backpropWithN zfs f xs of 231 | (y, g) -> (y, g (runOF ob y)) 232 | {-# INLINE backpropN #-} 233 | 234 | -- | 'Numeric.Backprop.backprop', but with explicit 'zero' and 'one'. 235 | backprop :: 236 | ZeroFunc a -> 237 | OneFunc b -> 238 | (forall s. Reifies s W => BVar s a -> BVar s b) -> 239 | a -> 240 | (b, a) 241 | backprop zfa ofb f = 242 | second (\case Identity x :& RNil -> x) 243 | . backpropN (zfa :& RNil) ofb (f . (\case x :& RNil -> x)) 244 | . (:& RNil) 245 | . Identity 246 | {-# INLINE backprop #-} 247 | 248 | -- | 'Numeric.Backprop.backpropWith', but with explicit 'zero'. 249 | -- 250 | -- Note that argument order changed in v0.2.4. 251 | backpropWith :: 252 | ZeroFunc a -> 253 | (forall s. Reifies s W => BVar s a -> BVar s b) -> 254 | a -> 255 | (b, b -> a) 256 | backpropWith zfa f = 257 | second ((\case Identity x :& RNil -> x) .) 258 | . backpropWithN (zfa :& RNil) (f . (\case x :& RNil -> x)) 259 | . (:& RNil) 260 | . Identity 261 | {-# INLINE backpropWith #-} 262 | 263 | -- | 'evalBP' but with no arguments. Useful when everything is just given 264 | -- through 'constVar'. 265 | evalBP0 :: (forall s. Reifies s W => BVar s a) -> a 266 | evalBP0 x = evalBPN (const x) RNil 267 | {-# INLINE evalBP0 #-} 268 | 269 | -- | Turn a function @'BVar' s a -> 'BVar' s b@ into the function @a -> b@ 270 | -- that it represents. 271 | -- 272 | -- Benchmarks show that this should have virtually no overhead over 273 | -- directly writing a @a -> b@. 'BVar' is, in this situation, a zero-cost 274 | -- abstraction, performance-wise. 275 | -- 276 | -- See documentation of 'Numeric.Backprop.backprop' for more information. 277 | evalBP :: (forall s. Reifies s W => BVar s a -> BVar s b) -> a -> b 278 | evalBP f = evalBPN (f . (\case x :& RNil -> x)) . (:& RNil) . Identity 279 | {-# INLINE evalBP #-} 280 | 281 | -- | 'Numeric.Backprop.gradBP', but with explicit 'zero' and 'one'. 282 | gradBP :: 283 | ZeroFunc a -> 284 | OneFunc b -> 285 | (forall s. Reifies s W => BVar s a -> BVar s b) -> 286 | a -> 287 | a 288 | gradBP zfa ofb f = snd . backprop zfa ofb f 289 | {-# INLINE gradBP #-} 290 | 291 | -- | 'Numeric.Backprop.gradBP', Nbut with explicit 'zero' and 'one'. 292 | gradBPN :: 293 | Rec ZeroFunc as -> 294 | OneFunc b -> 295 | (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) -> 296 | Rec Identity as -> 297 | Rec Identity as 298 | gradBPN zfas ofb f = snd . backpropN zfas ofb f 299 | {-# INLINE gradBPN #-} 300 | 301 | -- | 'Numeric.Backprop.backprop2', but with explicit 'zero' and 'one'. 
302 | backprop2 :: 303 | ZeroFunc a -> 304 | ZeroFunc b -> 305 | OneFunc c -> 306 | (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c) -> 307 | a -> 308 | b -> 309 | (c, (a, b)) 310 | backprop2 zfa zfb ofc f x y = 311 | second (\(Identity dx :& Identity dy :& RNil) -> (dx, dy)) $ 312 | backpropN 313 | (zfa :& zfb :& RNil) 314 | ofc 315 | (\(x' :& y' :& RNil) -> f x' y') 316 | (Identity x :& Identity y :& RNil) 317 | {-# INLINE backprop2 #-} 318 | 319 | -- | 'Numeric.Backprop.backpropWith2', but with explicit 'zero'. 320 | -- 321 | -- Note that argument order changed in v0.2.4. 322 | -- 323 | -- @since 0.2.0.0 324 | backpropWith2 :: 325 | ZeroFunc a -> 326 | ZeroFunc b -> 327 | (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c) -> 328 | a -> 329 | b -> 330 | (c, c -> (a, b)) 331 | backpropWith2 zfa zfb f x y = 332 | second ((\(Identity dx :& Identity dy :& RNil) -> (dx, dy)) .) $ 333 | backpropWithN 334 | (zfa :& zfb :& RNil) 335 | (\(x' :& y' :& RNil) -> f x' y') 336 | (Identity x :& Identity y :& RNil) 337 | {-# INLINE backpropWith2 #-} 338 | 339 | -- | 'evalBP' for a two-argument function. See 340 | -- 'Numeric.Backprop.backprop2' for notes. 341 | evalBP2 :: 342 | (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c) -> 343 | a -> 344 | b -> 345 | c 346 | evalBP2 f x y = 347 | evalBPN (\(x' :& y' :& RNil) -> f x' y') $ 348 | Identity x 349 | :& Identity y 350 | :& RNil 351 | {-# INLINE evalBP2 #-} 352 | 353 | -- | 'Numeric.Backprop.gradBP2' with explicit 'zero' and 'one'. 354 | gradBP2 :: 355 | ZeroFunc a -> 356 | ZeroFunc b -> 357 | OneFunc c -> 358 | (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c) -> 359 | a -> 360 | b -> 361 | (a, b) 362 | gradBP2 zfa zfb ofc f x = snd . backprop2 zfa zfb ofc f x 363 | {-# INLINE gradBP2 #-} 364 | 365 | -- | 'Numeric.Backprop.bpOp' with explicit 'zero'. 366 | bpOp :: 367 | Rec ZeroFunc as -> 368 | (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) -> 369 | Op as b 370 | bpOp zfs f = Op (backpropWithN zfs f) 371 | {-# INLINE bpOp #-} 372 | 373 | -- | 'Numeric.Backprop.overVar' with explicit 'add' and 'zero'. 374 | -- 375 | -- @since 0.2.4.0 376 | overVar :: 377 | Reifies s W => 378 | AddFunc a -> 379 | AddFunc b -> 380 | ZeroFunc a -> 381 | ZeroFunc b -> 382 | Lens' b a -> 383 | (BVar s a -> BVar s a) -> 384 | BVar s b -> 385 | BVar s b 386 | overVar afa afb zfa zfb l f x = setVar afa afb zfa l (f (viewVar afa zfb l x)) x 387 | {-# INLINE overVar #-} 388 | 389 | -- | 'Numeric.Backprop.isoVar' with explicit 'add' and 'zero'. 390 | isoVar :: 391 | Reifies s W => 392 | AddFunc a -> 393 | (a -> b) -> 394 | (b -> a) -> 395 | BVar s a -> 396 | BVar s b 397 | isoVar af f g = liftOp1 af (opIso f g) 398 | {-# INLINE isoVar #-} 399 | 400 | -- | 'Numeric.Backprop.isoVar2' with explicit 'add' and 'zero'. 401 | isoVar2 :: 402 | Reifies s W => 403 | AddFunc a -> 404 | AddFunc b -> 405 | (a -> b -> c) -> 406 | (c -> (a, b)) -> 407 | BVar s a -> 408 | BVar s b -> 409 | BVar s c 410 | isoVar2 afa afb f g = liftOp2 afa afb (opIso2 f g) 411 | {-# INLINE isoVar2 #-} 412 | 413 | -- | 'Numeric.Backprop.isoVar3' with explicit 'add' and 'zero'. 414 | isoVar3 :: 415 | Reifies s W => 416 | AddFunc a -> 417 | AddFunc b -> 418 | AddFunc c -> 419 | (a -> b -> c -> d) -> 420 | (d -> (a, b, c)) -> 421 | BVar s a -> 422 | BVar s b -> 423 | BVar s c -> 424 | BVar s d 425 | isoVar3 afa afb afc f g = liftOp3 afa afb afc (opIso3 f g) 426 | {-# INLINE isoVar3 #-} 427 | 428 | -- | 'Numeric.Backprop.isoVarN' with explicit 'add' and 'zero'. 
429 | isoVarN :: 430 | Reifies s W => 431 | Rec AddFunc as -> 432 | (Rec Identity as -> b) -> 433 | (b -> Rec Identity as) -> 434 | Rec (BVar s) as -> 435 | BVar s b 436 | isoVarN afs f g = liftOp afs (opIsoN f g) 437 | {-# INLINE isoVarN #-} 438 | 439 | -- | Helper class for generically "splitting" and "joining" 'BVar's into 440 | -- constructors. See 'Numeric.Backprop.splitBV' and 441 | -- 'Numeric.Backprop.joinBV'. 442 | -- 443 | -- See "Numeric.Backprop#hkd" for a tutorial on how to use this. 444 | -- 445 | -- Instances should be available for types made with one constructor whose 446 | -- fields are all instances of 'Backprop', with a 'Generic' instance. 447 | -- 448 | -- @since 0.2.2.0 449 | class BVGroup s as i o | o -> i, i -> as where 450 | -- | Helper method for generically "splitting" 'BVar's out of 451 | -- constructors inside a 'BVar'. See 'splitBV'. 452 | gsplitBV :: Rec AddFunc as -> Rec ZeroFunc as -> BVar s (i ()) -> o () 453 | 454 | -- | Helper method for generically "joining" 'BVar's inside 455 | -- a constructor into a 'BVar'. See 'joinBV'. 456 | gjoinBV :: Rec AddFunc as -> Rec ZeroFunc as -> o () -> BVar s (i ()) 457 | 458 | instance BVGroup s '[] (K1 i a) (K1 i (BVar s a)) where 459 | gsplitBV _ _ = K1 . coerceVar 460 | {-# INLINE gsplitBV #-} 461 | gjoinBV _ _ = coerceVar . unK1 462 | {-# INLINE gjoinBV #-} 463 | 464 | instance 465 | BVGroup s as i o => 466 | BVGroup s as (M1 p c i) (M1 p c o) 467 | where 468 | gsplitBV afs zfs = M1 . gsplitBV afs zfs . coerceVar @_ @(i ()) 469 | {-# INLINE gsplitBV #-} 470 | gjoinBV afs zfs = coerceVar @(i ()) . gjoinBV afs zfs . unM1 471 | {-# INLINE gjoinBV #-} 472 | 473 | instance BVGroup s '[] V1 V1 where 474 | gsplitBV _ _ = unsafeCoerce 475 | {-# INLINE gsplitBV #-} 476 | gjoinBV _ _ = \case {} 477 | {-# INLINE gjoinBV #-} 478 | 479 | instance BVGroup s '[] U1 U1 where 480 | gsplitBV _ _ _ = U1 481 | {-# INLINE gsplitBV #-} 482 | gjoinBV _ _ _ = constVar U1 483 | {-# INLINE gjoinBV #-} 484 | 485 | instance 486 | ( Reifies s W 487 | , BVGroup s as i1 o1 488 | , BVGroup s bs i2 o2 489 | , cs ~ (as ++ bs) 490 | , RecApplicative as 491 | ) => 492 | BVGroup s (i1 () ': i2 () ': cs) (i1 :*: i2) (o1 :*: o2) 493 | where 494 | gsplitBV (afa :& afb :& afs) (zfa :& zfb :& zfs) xy = x :*: y 495 | where 496 | (afas, afbs) = splitRec afs 497 | (zfas, zfbs) = splitRec zfs 498 | zfab = ZF $ \(xx :*: yy) -> runZF zfa xx :*: runZF zfb yy 499 | x = gsplitBV afas zfas . viewVar afa zfab p1 $ xy 500 | y = gsplitBV afbs zfbs . 
viewVar afb zfab p2 $ xy 501 | {-# INLINE gsplitBV #-} 502 | gjoinBV (afa :& afb :& afs) (_ :& _ :& zfs) (x :*: y) = 503 | isoVar2 504 | afa 505 | afb 506 | (:*:) 507 | unP 508 | (gjoinBV afas zfas x) 509 | (gjoinBV afbs zfbs y) 510 | where 511 | (afas, afbs) = splitRec afs 512 | (zfas, zfbs) = splitRec zfs 513 | unP (xx :*: yy) = (xx, yy) 514 | {-# INLINE gjoinBV #-} 515 | 516 | -- | This instance is possible but it is not clear when it would be useful 517 | instance 518 | ( Reifies s W 519 | , BVGroup s as i1 o1 520 | , BVGroup s bs i2 o2 521 | , cs ~ (as ++ bs) 522 | , RecApplicative as 523 | ) => 524 | BVGroup s (i1 () ': i2 () ': cs) (i1 :+: i2) (o1 :+: o2) 525 | where 526 | gsplitBV (afa :& afb :& afs) (zfa :& zfb :& zfs) xy = 527 | case previewVar afa zf s1 xy of 528 | Just x -> L1 $ gsplitBV afas zfas x 529 | Nothing -> case previewVar afb zf s2 xy of 530 | Just y -> R1 $ gsplitBV afbs zfbs y 531 | Nothing -> error "Numeric.Backprop.gsplitBV: Internal error occurred" 532 | where 533 | zf = ZF $ \case 534 | L1 xx -> L1 $ runZF zfa xx 535 | R1 yy -> R1 $ runZF zfb yy 536 | (afas, afbs) = splitRec afs 537 | (zfas, zfbs) = splitRec zfs 538 | {-# INLINE gsplitBV #-} 539 | gjoinBV (afa :& afb :& afs) (zfa :& zfb :& zfs) = \case 540 | L1 x -> 541 | liftOp1 542 | afa 543 | (op1 (\xx -> (L1 xx, \case L1 d -> d; R1 _ -> runZF zfa xx))) 544 | (gjoinBV afas zfas x) 545 | R1 y -> 546 | liftOp1 547 | afb 548 | (op1 (\yy -> (R1 yy, \case L1 _ -> runZF zfb yy; R1 d -> d))) 549 | (gjoinBV afbs zfbs y) 550 | where 551 | (afas, afbs) = splitRec afs 552 | (zfas, zfbs) = splitRec zfs 553 | {-# INLINE gjoinBV #-} 554 | 555 | -- | 'Numeric.Backprop.splitBV' with explicit 'add' and 'zero'. 556 | -- 557 | -- @since 0.2.2.0 558 | splitBV :: 559 | forall z f s as. 560 | ( Generic (z f) 561 | , Generic (z (BVar s)) 562 | , BVGroup s as (Rep (z f)) (Rep (z (BVar s))) 563 | , Reifies s W 564 | ) => 565 | AddFunc (Rep (z f) ()) -> 566 | Rec AddFunc as -> 567 | ZeroFunc (z f) -> 568 | Rec ZeroFunc as -> 569 | -- | 'BVar' of value 570 | BVar s (z f) -> 571 | -- | 'BVar's of fields 572 | z (BVar s) 573 | splitBV af afs zf zfs = 574 | G.to 575 | . gsplitBV afs zfs 576 | . viewVar af zf (lens (from @(z f) @()) (const G.to)) 577 | {-# INLINE splitBV #-} 578 | 579 | -- | 'Numeric.Backprop.joinBV' with explicit 'add' and 'zero'. 580 | -- 581 | -- @since 0.2.2.0 582 | joinBV :: 583 | forall z f s as. 584 | ( Generic (z f) 585 | , Generic (z (BVar s)) 586 | , BVGroup s as (Rep (z f)) (Rep (z (BVar s))) 587 | , Reifies s W 588 | ) => 589 | AddFunc (z f) -> 590 | Rec AddFunc as -> 591 | ZeroFunc (Rep (z f) ()) -> 592 | Rec ZeroFunc as -> 593 | -- | 'BVar's of fields 594 | z (BVar s) -> 595 | -- | 'BVar' of combined value 596 | BVar s (z f) 597 | joinBV af afs zf zfs = 598 | viewVar af zf (lens G.to (const from)) 599 | . gjoinBV afs zfs 600 | . 
from @(z (BVar s)) @()
601 | {-# INLINE joinBV #-}
602 | 
--------------------------------------------------------------------------------
/src/Numeric/Backprop/Num.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE DataKinds #-}
2 | {-# LANGUAGE FlexibleContexts #-}
3 | {-# LANGUAGE GADTs #-}
4 | {-# LANGUAGE RankNTypes #-}
5 | {-# OPTIONS_HADDOCK not-home #-}
6 | 
7 | -- |
8 | -- Module : Numeric.Backprop.Num
9 | -- Copyright : (c) Justin Le 2023
10 | -- License : BSD3
11 | --
12 | -- Maintainer : justin@jle.im
13 | -- Stability : experimental
14 | -- Portability : non-portable
15 | --
16 | -- Provides the exact same API as "Numeric.Backprop", except requiring
17 | -- 'Num' instances for all types involved instead of 'Backprop' instances.
18 | --
19 | -- This was the original API of the library (for version 0.1).
20 | --
21 | -- 'Num' is strictly more powerful than 'Backprop', and is a stronger
22 | -- constraint on types than is necessary for proper backpropagating. In
23 | -- particular, 'fromInteger' is a problem for many types, preventing useful
24 | -- backpropagation for lists, variable-length vectors (like "Data.Vector")
25 | -- and variable-size matrices from linear algebra libraries like /hmatrix/
26 | -- and /accelerate/.
27 | --
28 | -- However, this module might be useful in situations where you are working
29 | -- with external types with 'Num' instances, and you want to avoid writing
30 | -- orphan instances for external types.
31 | --
32 | -- If you have external types that are not 'Num' instances, consider
33 | -- instead "Numeric.Backprop.Explicit".
34 | --
35 | -- If you need a 'Num' instance for tuples, you can use the orphan
36 | -- instances in the <https://hackage.haskell.org/package/NumInstances NumInstances>
37 | -- package (in particular, "Data.NumInstances.Tuple") if you
38 | -- are writing an application and do not have to worry about orphan
39 | -- instances.
40 | --
41 | -- See "Numeric.Backprop" for fuller documentation on using these
42 | -- functions.
43 | -- 44 | -- @since 0.2.0.0 45 | module Numeric.Backprop.Num ( 46 | -- * Types 47 | BVar, 48 | W, 49 | 50 | -- * Running 51 | backprop, 52 | E.evalBP, 53 | gradBP, 54 | backpropWith, 55 | 56 | -- ** Multiple inputs 57 | E.evalBP0, 58 | backprop2, 59 | E.evalBP2, 60 | gradBP2, 61 | backpropWith2, 62 | backpropN, 63 | E.evalBPN, 64 | gradBPN, 65 | backpropWithN, 66 | 67 | -- * Manipulating 'BVar' 68 | E.constVar, 69 | E.auto, 70 | E.coerceVar, 71 | (^^.), 72 | (.~~), 73 | (%~~), 74 | (^^?), 75 | (^^..), 76 | (^^?!), 77 | viewVar, 78 | setVar, 79 | overVar, 80 | sequenceVar, 81 | collectVar, 82 | previewVar, 83 | toListOfVar, 84 | 85 | -- ** With Isomorphisms 86 | isoVar, 87 | isoVar2, 88 | isoVar3, 89 | isoVarN, 90 | 91 | -- ** With 'Op's 92 | liftOp, 93 | liftOp1, 94 | liftOp2, 95 | liftOp3, 96 | 97 | -- * 'Op' 98 | Op (..), 99 | 100 | -- ** Creation 101 | op0, 102 | opConst, 103 | idOp, 104 | bpOp, 105 | 106 | -- *** Giving gradients directly 107 | op1, 108 | op2, 109 | op3, 110 | 111 | -- *** From Isomorphisms 112 | opCoerce, 113 | opTup, 114 | opIso, 115 | opIsoN, 116 | opLens, 117 | 118 | -- *** No gradients 119 | noGrad1, 120 | noGrad, 121 | 122 | -- * Utility 123 | Reifies, 124 | ) where 125 | 126 | import Data.Functor.Identity 127 | import Data.Maybe 128 | import Data.Reflection 129 | import Data.Vinyl 130 | import Lens.Micro 131 | import Numeric.Backprop.Explicit (BVar, W) 132 | import qualified Numeric.Backprop.Explicit as E 133 | import Numeric.Backprop.Op 134 | 135 | -- | 'Numeric.Backprop.backpropN', but with 'Num' constraints instead of 136 | -- 'Backprop' constraints. 137 | -- 138 | -- The @'RPureConstrained' 'Num' as@ in the constraint says that every 139 | -- value in the type-level list @as@ must have a 'Num' instance. This 140 | -- means you can use, say, @'[Double, Float, Int]@, but not @'[Double, 141 | -- Bool, String]@. 142 | -- 143 | -- If you stick to /concerete/, monomorphic usage of this (with specific 144 | -- types, typed into source code, known at compile-time), then 145 | -- @'AllPureConstrained' 'Num' as@ should be fulfilled automatically. 146 | backpropN :: 147 | (RPureConstrained Num as, Num b) => 148 | (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) -> 149 | Rec Identity as -> 150 | (b, Rec Identity as) 151 | backpropN = E.backpropN E.zfNums E.ofNum 152 | {-# INLINE backpropN #-} 153 | 154 | -- | 'Numeric.Backprop.backpropWithN', but with 'Num' constraints instead 155 | -- of 'Backprop' constraints. 156 | -- 157 | -- See 'backpropN' for information on the 'AllConstrained' constraint. 158 | -- 159 | -- Note that argument order changed in v0.2.4. 160 | -- 161 | -- @since 0.2.0.0 162 | backpropWithN :: 163 | RPureConstrained Num as => 164 | (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) -> 165 | Rec Identity as -> 166 | (b, b -> Rec Identity as) 167 | backpropWithN = E.backpropWithN E.zfNums 168 | {-# INLINE backpropWithN #-} 169 | 170 | -- | 'Numeric.Backprop.backprop', but with 'Num' constraints instead of 171 | -- 'Backprop' constraints. 172 | -- 173 | -- See module documentation for "Numeric.Backprop.Num" for information on 174 | -- using this with tuples. 175 | backprop :: 176 | (Num a, Num b) => 177 | (forall s. Reifies s W => BVar s a -> BVar s b) -> 178 | a -> 179 | (b, a) 180 | backprop = E.backprop E.zfNum E.ofNum 181 | {-# INLINE backprop #-} 182 | 183 | -- | 'Numeric.Backprop.backpropWith', but with 'Num' constraints instead of 184 | -- 'Backprop' constraints. 
185 | --
186 | -- See module documentation for "Numeric.Backprop.Num" for information on
187 | -- using this with tuples.
188 | --
189 | -- Note that argument order changed in v0.2.4.
190 | --
191 | -- @since 0.2.0.0
192 | backpropWith ::
193 |   Num a =>
194 |   (forall s. Reifies s W => BVar s a -> BVar s b) ->
195 |   a ->
196 |   (b, b -> a)
197 | backpropWith = E.backpropWith E.zfNum
198 | {-# INLINE backpropWith #-}
199 |
200 | -- | 'Numeric.Backprop.gradBP', but with 'Num' constraints instead of
201 | -- 'Backprop' constraints.
202 | gradBP ::
203 |   (Num a, Num b) =>
204 |   (forall s. Reifies s W => BVar s a -> BVar s b) ->
205 |   a ->
206 |   a
207 | gradBP = E.gradBP E.zfNum E.ofNum
208 | {-# INLINE gradBP #-}
209 |
210 | -- | 'Numeric.Backprop.gradBPN', but with 'Num' constraints instead of
211 | -- 'Backprop' constraints.
212 | gradBPN ::
213 |   (RPureConstrained Num as, Num b) =>
214 |   (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) ->
215 |   Rec Identity as ->
216 |   Rec Identity as
217 | gradBPN = E.gradBPN E.zfNums E.ofNum
218 | {-# INLINE gradBPN #-}
219 |
220 | -- | 'Numeric.Backprop.backprop2', but with 'Num' constraints instead of
221 | -- 'Backprop' constraints.
222 | backprop2 ::
223 |   (Num a, Num b, Num c) =>
224 |   (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c) ->
225 |   a ->
226 |   b ->
227 |   (c, (a, b))
228 | backprop2 = E.backprop2 E.zfNum E.zfNum E.ofNum
229 | {-# INLINE backprop2 #-}
230 |
231 | -- | 'Numeric.Backprop.backpropWith2', but with 'Num' constraints instead of
232 | -- 'Backprop' constraints.
233 | --
234 | -- Note that argument order changed in v0.2.4.
235 | --
236 | -- @since 0.2.0.0
237 | backpropWith2 ::
238 |   (Num a, Num b) =>
239 |   (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c) ->
240 |   a ->
241 |   b ->
242 |   -- | Takes a function giving the gradient of the final result, given the output of the function
243 |   (c, c -> (a, b))
244 | backpropWith2 = E.backpropWith2 E.zfNum E.zfNum
245 | {-# INLINE backpropWith2 #-}
246 |
247 | -- | 'Numeric.Backprop.gradBP2', but with 'Num' constraints instead of
248 | -- 'Backprop' constraints.
249 | gradBP2 ::
250 |   (Num a, Num b, Num c) =>
251 |   (forall s. Reifies s W => BVar s a -> BVar s b -> BVar s c) ->
252 |   a ->
253 |   b ->
254 |   (a, b)
255 | gradBP2 = E.gradBP2 E.zfNum E.zfNum E.ofNum
256 | {-# INLINE gradBP2 #-}
257 |
258 | -- | 'Numeric.Backprop.bpOp', but with 'Num' constraints instead of
259 | -- 'Backprop' constraints.
260 | bpOp ::
261 |   RPureConstrained Num as =>
262 |   (forall s. Reifies s W => Rec (BVar s) as -> BVar s b) ->
263 |   Op as b
264 | bpOp = E.bpOp E.zfNums
265 | {-# INLINE bpOp #-}
266 |
267 | -- | 'Numeric.Backprop.^^.', but with 'Num' constraints instead of
268 | -- 'Backprop' constraints.
269 | (^^.) ::
270 |   forall b a s.
271 |   (Num a, Num b, Reifies s W) =>
272 |   BVar s b ->
273 |   Lens' b a ->
274 |   BVar s a
275 | x ^^. l = viewVar l x
276 |
277 | infixl 8 ^^.
278 | {-# INLINE (^^.) #-}
279 |
280 | -- | 'Numeric.Backprop.viewVar', but with 'Num' constraints instead of
281 | -- 'Backprop' constraints.
282 | viewVar ::
283 |   forall b a s.
284 |   (Num a, Num b, Reifies s W) =>
285 |   Lens' b a ->
286 |   BVar s b ->
287 |   BVar s a
288 | viewVar = E.viewVar E.afNum E.zfNum
289 | {-# INLINE viewVar #-}
290 |
291 | -- | 'Numeric.Backprop..~~', but with 'Num' constraints instead of
292 | -- 'Backprop' constraints.
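--
-- A sketch of typical usage, assuming some @fieldLens :: 'Lens'' b a@
-- where both @a@ and @b@ have 'Num' instances (all names here are
-- hypothetical, for illustration only):
--
-- @
-- -- fieldLens, newVar, oldVar are hypothetical
-- (fieldLens '.~~' newVar) oldVar :: 'BVar' s b
-- @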
293 | (.~~) ::
294 |   (Num a, Num b, Reifies s W) =>
295 |   Lens' b a ->
296 |   BVar s a ->
297 |   BVar s b ->
298 |   BVar s b
299 | l .~~ x = setVar l x
300 |
301 | infixl 8 .~~
302 | {-# INLINE (.~~) #-}
303 |
304 | -- | 'Numeric.Backprop.setVar', but with 'Num' constraints instead of
305 | -- 'Backprop' constraints.
306 | setVar ::
307 |   forall a b s.
308 |   (Num a, Num b, Reifies s W) =>
309 |   Lens' b a ->
310 |   BVar s a ->
311 |   BVar s b ->
312 |   BVar s b
313 | setVar = E.setVar E.afNum E.afNum E.zfNum
314 | {-# INLINE setVar #-}
315 |
316 | -- | 'Numeric.Backprop.%~~', but with 'Num' constraints instead of
317 | -- 'Backprop' constraints.
318 | --
319 | -- @since 0.2.4.0
320 | (%~~) ::
321 |   (Num a, Num b, Reifies s W) =>
322 |   Lens' b a ->
323 |   (BVar s a -> BVar s a) ->
324 |   BVar s b ->
325 |   BVar s b
326 | l %~~ f = overVar l f
327 |
328 | infixr 4 %~~
329 | {-# INLINE (%~~) #-}
330 |
331 | -- | 'Numeric.Backprop.overVar', but with 'Num' constraints instead of
332 | -- 'Backprop' constraints.
333 | --
334 | -- @since 0.2.4.0
335 | overVar ::
336 |   (Num a, Num b, Reifies s W) =>
337 |   Lens' b a ->
338 |   (BVar s a -> BVar s a) ->
339 |   BVar s b ->
340 |   BVar s b
341 | overVar = E.overVar E.afNum E.afNum E.zfNum E.zfNum
342 | {-# INLINE overVar #-}
343 |
344 | -- | 'Numeric.Backprop.^^?', but with 'Num' constraints instead of
345 | -- 'Backprop' constraints.
346 | --
347 | -- Note that many of the prisms automatically generated by the /lens/ package
348 | -- use tuples, which cannot work with this by default (because tuples do not
349 | -- have a 'Num' instance).
350 | --
351 | -- If you are writing an application or don't have to worry about orphan
352 | -- instances, you can pull in the orphan instances from
353 | -- <https://hackage.haskell.org/package/NumInstances NumInstances>.
354 | -- Alternatively, you can chain those prisms with conversions to the
355 | -- anonymous canonical strict tuple types in "Numeric.Backprop.Tuple",
356 | -- which do have 'Num' instances.
357 | --
358 | -- @
359 | -- myPrism                   :: 'Prism'' c (a, b)
360 | -- myPrism . 'iso' 'tupT2' 't2Tup' :: 'Prism'' c ('T2' a b)
361 | -- @
362 | (^^?) ::
363 |   forall b a s.
364 |   (Num b, Num a, Reifies s W) =>
365 |   BVar s b ->
366 |   Traversal' b a ->
367 |   Maybe (BVar s a)
368 | v ^^? t = previewVar t v
369 |
370 | infixl 8 ^^?
371 | {-# INLINE (^^?) #-}
372 |
373 | -- | 'Numeric.Backprop.^^?!', but with 'Num' constraints instead of
374 | -- 'Backprop' constraints.
375 | --
376 | -- Like 'Numeric.Backprop.^^?!', this is *UNSAFE*.
377 | --
378 | -- @since 0.2.1.0
379 | (^^?!) ::
380 |   forall b a s.
381 |   (Num b, Num a, Reifies s W) =>
382 |   BVar s b ->
383 |   Traversal' b a ->
384 |   BVar s a
385 | v ^^?! t = fromMaybe (error e) (previewVar t v)
386 |   where
387 |     e = "Numeric.Backprop.Num.^^?!: Empty traversal"
388 |
389 | infixl 8 ^^?!
390 | {-# INLINE (^^?!) #-}
391 |
392 | -- | 'Numeric.Backprop.previewVar', but with 'Num' constraints instead of
393 | -- 'Backprop' constraints.
394 | --
395 | -- See documentation for '^^?' for more information and important notes.
396 | previewVar ::
397 |   forall b a s.
398 |   (Num b, Num a, Reifies s W) =>
399 |   Traversal' b a ->
400 |   BVar s b ->
401 |   Maybe (BVar s a)
402 | previewVar = E.previewVar E.afNum E.zfNum
403 | {-# INLINE previewVar #-}
404 |
405 | -- | 'Numeric.Backprop.^^..', but with 'Num' constraints instead of
406 | -- 'Backprop' constraints.
407 | (^^..) ::
408 |   forall b a s.
409 |   (Num b, Num a, Reifies s W) =>
410 |   BVar s b ->
411 |   Traversal' b a ->
412 |   [BVar s a]
413 | v ^^.. t = toListOfVar t v
414 | {-# INLINE (^^..)
#-} 415 |
416 | -- | 'Numeric.Backprop.toListOfVar', but with 'Num' constraints instead of
417 | -- 'Backprop' constraints.
418 | toListOfVar ::
419 |   forall b a s.
420 |   (Num b, Num a, Reifies s W) =>
421 |   Traversal' b a ->
422 |   BVar s b ->
423 |   [BVar s a]
424 | toListOfVar = E.toListOfVar E.afNum E.zfNum
425 | {-# INLINE toListOfVar #-}
426 |
427 | -- | 'Numeric.Backprop.sequenceVar', but with 'Num' constraints instead of
428 | -- 'Backprop' constraints.
429 | --
430 | -- Since v0.2.4, requires only a 'Num' constraint on @a@, not on @t a@.
431 | sequenceVar ::
432 |   (Traversable t, Num a, Reifies s W) =>
433 |   BVar s (t a) ->
434 |   t (BVar s a)
435 | sequenceVar = E.sequenceVar E.afNum E.zfNum
436 | {-# INLINE sequenceVar #-}
437 |
438 | -- | 'Numeric.Backprop.collectVar', but with 'Num' constraints instead of
439 | -- 'Backprop' constraints.
440 | --
441 | -- Prior to v0.2.3, required a 'Num' constraint on @t a@.
442 | collectVar ::
443 |   (Foldable t, Functor t, Num a, Reifies s W) =>
444 |   t (BVar s a) ->
445 |   BVar s (t a)
446 | collectVar = E.collectVar E.afNum E.zfNum
447 | {-# INLINE collectVar #-}
448 |
449 | -- | 'Numeric.Backprop.liftOp', but with 'Num' constraints instead of
450 | -- 'Backprop' constraints.
451 | liftOp ::
452 |   (RPureConstrained Num as, Reifies s W) =>
453 |   Op as b ->
454 |   Rec (BVar s) as ->
455 |   BVar s b
456 | liftOp = E.liftOp E.afNums
457 | {-# INLINE liftOp #-}
458 |
459 | -- | 'Numeric.Backprop.liftOp1', but with 'Num' constraints instead of
460 | -- 'Backprop' constraints.
461 | liftOp1 ::
462 |   (Num a, Reifies s W) =>
463 |   Op '[a] b ->
464 |   BVar s a ->
465 |   BVar s b
466 | liftOp1 = E.liftOp1 E.afNum
467 | {-# INLINE liftOp1 #-}
468 |
469 | -- | 'Numeric.Backprop.liftOp2', but with 'Num' constraints instead of
470 | -- 'Backprop' constraints.
471 | liftOp2 ::
472 |   (Num a, Num b, Reifies s W) =>
473 |   Op '[a, b] c ->
474 |   BVar s a ->
475 |   BVar s b ->
476 |   BVar s c
477 | liftOp2 = E.liftOp2 E.afNum E.afNum
478 | {-# INLINE liftOp2 #-}
479 |
480 | -- | 'Numeric.Backprop.liftOp3', but with 'Num' constraints instead of
481 | -- 'Backprop' constraints.
482 | liftOp3 ::
483 |   (Num a, Num b, Num c, Reifies s W) =>
484 |   Op '[a, b, c] d ->
485 |   BVar s a ->
486 |   BVar s b ->
487 |   BVar s c ->
488 |   BVar s d
489 | liftOp3 = E.liftOp3 E.afNum E.afNum E.afNum
490 | {-# INLINE liftOp3 #-}
491 |
492 | -- | 'Numeric.Backprop.isoVar', but with 'Num' constraints instead of
493 | -- 'Backprop' constraints.
494 | isoVar ::
495 |   (Num a, Reifies s W) =>
496 |   (a -> b) ->
497 |   (b -> a) ->
498 |   BVar s a ->
499 |   BVar s b
500 | isoVar = E.isoVar E.afNum
501 | {-# INLINE isoVar #-}
502 |
503 | -- | 'Numeric.Backprop.isoVar2', but with 'Num' constraints instead of
504 | -- 'Backprop' constraints.
505 | isoVar2 ::
506 |   (Num a, Num b, Reifies s W) =>
507 |   (a -> b -> c) ->
508 |   (c -> (a, b)) ->
509 |   BVar s a ->
510 |   BVar s b ->
511 |   BVar s c
512 | isoVar2 = E.isoVar2 E.afNum E.afNum
513 | {-# INLINE isoVar2 #-}
514 |
515 | -- | 'Numeric.Backprop.isoVar3', but with 'Num' constraints instead of
516 | -- 'Backprop' constraints.
517 | isoVar3 ::
518 |   (Num a, Num b, Num c, Reifies s W) =>
519 |   (a -> b -> c -> d) ->
520 |   (d -> (a, b, c)) ->
521 |   BVar s a ->
522 |   BVar s b ->
523 |   BVar s c ->
524 |   BVar s d
525 | isoVar3 = E.isoVar3 E.afNum E.afNum E.afNum
526 | {-# INLINE isoVar3 #-}
527 |
528 | -- | 'Numeric.Backprop.isoVarN', but with 'Num' constraints instead of
529 | -- 'Backprop' constraints.
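--
-- For instance, a minimal sketch using the single-argument 'isoVar': an
-- isomorphism is given as a pair of inverse functions (negation is its own
-- inverse, so gradients flow back through it correctly):
--
-- @
-- 'isoVar' negate negate :: ('Num' a, 'Reifies' s 'W') => 'BVar' s a -> 'BVar' s a
-- @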
530 | isoVarN :: 531 | (RPureConstrained Num as, Reifies s W) => 532 | (Rec Identity as -> b) -> 533 | (b -> Rec Identity as) -> 534 | Rec (BVar s) as -> 535 | BVar s b 536 | isoVarN = E.isoVarN E.afNums 537 | {-# INLINE isoVarN #-} 538 | -------------------------------------------------------------------------------- /src/Prelude/Backprop.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | 3 | -- | 4 | -- Module : Prelude.Backprop 5 | -- Copyright : (c) Justin Le 2023 6 | -- License : BSD3 7 | -- 8 | -- Maintainer : justin@jle.im 9 | -- Stability : experimental 10 | -- Portability : non-portable 11 | -- 12 | -- Some lifted versions of common functions found in 'Prelude' (or /base/ 13 | -- in general). 14 | -- 15 | -- This module is intended to be a catch-all one, so feel free to suggest 16 | -- other functions or submit a PR if you think one would make sense. 17 | -- 18 | -- See "Prelude.Backprop.Num" for a version with 'Num' constraints instead 19 | -- of 'Backprop' constraints, and "Prelude.Backprop.Explicit" for a version 20 | -- allowing you to provide 'zero', 'add', and 'one' explicitly. 21 | -- 22 | -- @since 0.1.3.0 23 | module Prelude.Backprop ( 24 | -- * Foldable and Traversable 25 | sum, 26 | product, 27 | length, 28 | minimum, 29 | maximum, 30 | traverse, 31 | toList, 32 | mapAccumL, 33 | mapAccumR, 34 | foldr, 35 | foldl', 36 | 37 | -- * Functor and Applicative 38 | fmap, 39 | fmapConst, 40 | (<$>), 41 | (<$), 42 | ($>), 43 | pure, 44 | liftA2, 45 | liftA3, 46 | 47 | -- * Numeric 48 | fromIntegral, 49 | realToFrac, 50 | round, 51 | fromIntegral', 52 | 53 | -- * Misc 54 | E.coerce, 55 | ) where 56 | 57 | import Numeric.Backprop 58 | import qualified Numeric.Backprop.Explicit as E 59 | import qualified Prelude.Backprop.Explicit as E 60 | import Prelude ( 61 | Applicative, 62 | Foldable, 63 | Fractional (..), 64 | Functor, 65 | Num (..), 66 | Ord (..), 67 | Traversable, 68 | ) 69 | import qualified Prelude as P 70 | 71 | -- | Lifted 'P.sum'. More efficient than going through 'toList'. 72 | sum :: 73 | (Foldable t, Functor t, Backprop (t a), Num a, Reifies s W) => 74 | BVar s (t a) -> 75 | BVar s a 76 | sum = E.sum E.addFunc 77 | {-# INLINE sum #-} 78 | 79 | -- | Lifted 'P.pure'. 80 | pure :: 81 | (Foldable t, Applicative t, Backprop a, Reifies s W) => 82 | BVar s a -> 83 | BVar s (t a) 84 | pure = E.pure E.addFunc E.zeroFunc 85 | {-# INLINE pure #-} 86 | 87 | -- | Lifted 'P.product'. More efficient than going through 'toList'. 88 | product :: 89 | (Foldable t, Functor t, Backprop (t a), Fractional a, Reifies s W) => 90 | BVar s (t a) -> 91 | BVar s a 92 | product = E.product E.addFunc 93 | {-# INLINE product #-} 94 | 95 | -- | Lifted 'P.length'. More efficient than going through 'toList'. 96 | length :: 97 | (Foldable t, Backprop (t a), Num b, Reifies s W) => 98 | BVar s (t a) -> 99 | BVar s b 100 | length = E.length E.addFunc E.zeroFunc 101 | {-# INLINE length #-} 102 | 103 | -- | Lifted 'P.minimum'. Undefined for situations where 'P.minimum' would 104 | -- be undefined. More efficient than going through 'toList'. 105 | minimum :: 106 | (Foldable t, Functor t, Backprop a, Ord a, Backprop (t a), Reifies s W) => 107 | BVar s (t a) -> 108 | BVar s a 109 | minimum = E.minimum E.addFunc E.zeroFunc 110 | {-# INLINE minimum #-} 111 | 112 | -- | Lifted 'P.maximum'. Undefined for situations where 'P.maximum' would 113 | -- be undefined. More efficient than going through 'toList'. 
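--
-- For example (a sketch): the gradient flows only into the position of the
-- maximal element:
--
-- @
-- 'gradBP' 'maximum' [2, 7, 3 :: Double]   -- [0.0, 1.0, 0.0]
-- @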
114 | maximum ::
115 |   (Foldable t, Functor t, Backprop a, Ord a, Backprop (t a), Reifies s W) =>
116 |   BVar s (t a) ->
117 |   BVar s a
118 | maximum = E.maximum E.addFunc E.zeroFunc
119 | {-# INLINE maximum #-}
120 |
121 | -- | Lifted 'P.foldr'.  Essentially just 'toList' composed with a normal
122 | -- list 'P.foldr', and is only here for convenience.
123 | --
124 | -- @since 0.2.3.0
125 | foldr ::
126 |   (Traversable t, Backprop a, Reifies s W) =>
127 |   (BVar s a -> BVar s b -> BVar s b) ->
128 |   BVar s b ->
129 |   BVar s (t a) ->
130 |   BVar s b
131 | foldr = E.foldr E.addFunc E.zeroFunc
132 | {-# INLINE foldr #-}
133 |
134 | -- | Lifted 'P.foldl''.  Essentially just 'toList' composed with a normal
135 | -- list 'P.foldl'', and is only here for convenience.
136 | --
137 | -- @since 0.2.3.0
138 | foldl' ::
139 |   (Traversable t, Backprop a, Reifies s W) =>
140 |   (BVar s b -> BVar s a -> BVar s b) ->
141 |   BVar s b ->
142 |   BVar s (t a) ->
143 |   BVar s b
144 | foldl' = E.foldl' E.addFunc E.zeroFunc
145 | {-# INLINE foldl' #-}
146 |
147 | -- | Lifted 'P.fmap'.  Lifts backpropagatable functions to be
148 | -- backpropagatable functions on 'Traversable' 'Functor's.
149 | fmap ::
150 |   (Traversable f, Backprop a, Backprop b, Reifies s W) =>
151 |   (BVar s a -> BVar s b) ->
152 |   BVar s (f a) ->
153 |   BVar s (f b)
154 | fmap = E.fmap E.addFunc E.addFunc E.zeroFunc E.zeroFunc
155 | {-# INLINE fmap #-}
156 |
157 | -- | Efficient version of 'fmap' when used to "replace" all values in
158 | -- a 'Functor' value.
159 | --
160 | -- @
161 | -- 'fmapConst' x = 'fmap' ('P.const' x)
162 | -- @
163 | --
164 | -- but much more efficient.
165 | --
166 | -- @since 0.2.4.0
167 | fmapConst ::
168 |   (Functor f, Foldable f, Backprop b, Backprop (f a), Reifies s W) =>
169 |   BVar s b ->
170 |   BVar s (f a) ->
171 |   BVar s (f b)
172 | fmapConst = E.fmapConst E.addFunc E.addFunc E.zeroFunc E.zeroFunc
173 | {-# INLINE fmapConst #-}
174 |
175 | -- | Alias for 'fmap'.
176 | (<$>) ::
177 |   (Traversable f, Backprop a, Backprop b, Reifies s W) =>
178 |   (BVar s a -> BVar s b) ->
179 |   BVar s (f a) ->
180 |   BVar s (f b)
181 | (<$>) = fmap
182 |
183 | infixl 4 <$>
184 | {-# INLINE (<$>) #-}
185 |
186 | -- | Alias for 'fmapConst'.
187 | --
188 | -- @since 0.2.4.0
189 | (<$) ::
190 |   (Traversable f, Backprop b, Backprop (f a), Reifies s W) =>
191 |   BVar s b ->
192 |   BVar s (f a) ->
193 |   BVar s (f b)
194 | (<$) = fmapConst
195 |
196 | infixl 4 <$
197 | {-# INLINE (<$) #-}
198 |
199 | -- | Alias for @'flip' 'fmapConst'@.
200 | --
201 | -- @since 0.2.4.0
202 | ($>) ::
203 |   (Traversable f, Backprop b, Backprop (f a), Reifies s W) =>
204 |   BVar s (f a) ->
205 |   BVar s b ->
206 |   BVar s (f b)
207 | xs $> x = x <$ xs
208 |
209 | infixl 4 $>
210 | {-# INLINE ($>) #-}
211 |
212 | -- | Lifted 'P.traverse'.  Lifts backpropagatable functions to be
213 | -- backpropagatable functions on 'Traversable' 'Functor's.
214 | traverse ::
215 |   (Traversable t, Applicative f, Foldable f, Backprop a, Backprop b, Backprop (t b), Reifies s W) =>
216 |   (BVar s a -> f (BVar s b)) ->
217 |   BVar s (t a) ->
218 |   BVar s (f (t b))
219 | traverse =
220 |   E.traverse
221 |     E.addFunc
222 |     E.addFunc
223 |     E.addFunc
224 |     E.zeroFunc
225 |     E.zeroFunc
226 | {-# INLINE traverse #-}
227 |
228 | -- | Lifted 'P.liftA2'.  Lifts backpropagatable functions to be
229 | -- backpropagatable functions on 'Traversable' 'Applicative's.
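--
-- A minimal sketch (note that for lists, this uses the ordinary list
-- 'Applicative', i.e. all pairings, not an elementwise zip):
--
-- @
-- 'liftA2' (*) :: 'Reifies' s 'W' => 'BVar' s (Maybe Double) -> 'BVar' s (Maybe Double) -> 'BVar' s (Maybe Double)
-- @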
230 | liftA2 ::
231 |   ( Traversable f
232 |   , Applicative f
233 |   , Backprop a
234 |   , Backprop b
235 |   , Backprop c
236 |   , Reifies s W
237 |   ) =>
238 |   (BVar s a -> BVar s b -> BVar s c) ->
239 |   BVar s (f a) ->
240 |   BVar s (f b) ->
241 |   BVar s (f c)
242 | liftA2 =
243 |   E.liftA2
244 |     E.addFunc
245 |     E.addFunc
246 |     E.addFunc
247 |     E.zeroFunc
248 |     E.zeroFunc
249 |     E.zeroFunc
250 | {-# INLINE liftA2 #-}
251 |
252 | -- | Lifted 'P.liftA3'.  Lifts backpropagatable functions to be
253 | -- backpropagatable functions on 'Traversable' 'Applicative's.
254 | liftA3 ::
255 |   ( Traversable f
256 |   , Applicative f
257 |   , Backprop a
258 |   , Backprop b
259 |   , Backprop c
260 |   , Backprop d
261 |   , Reifies s W
262 |   ) =>
263 |   (BVar s a -> BVar s b -> BVar s c -> BVar s d) ->
264 |   BVar s (f a) ->
265 |   BVar s (f b) ->
266 |   BVar s (f c) ->
267 |   BVar s (f d)
268 | liftA3 =
269 |   E.liftA3
270 |     E.addFunc
271 |     E.addFunc
272 |     E.addFunc
273 |     E.addFunc
274 |     E.zeroFunc
275 |     E.zeroFunc
276 |     E.zeroFunc
277 |     E.zeroFunc
278 | {-# INLINE liftA3 #-}
279 |
280 | -- | Lifted conversion between two 'P.Integral' instances.
281 | --
282 | -- @since 0.2.1.0
283 | fromIntegral ::
284 |   (Backprop a, P.Integral a, P.Integral b, Reifies s W) =>
285 |   BVar s a ->
286 |   BVar s b
287 | fromIntegral = E.fromIntegral E.addFunc
288 | {-# INLINE fromIntegral #-}
289 |
290 | -- | Lifted conversion between two 'Fractional' and 'P.Real' instances.
291 | --
292 | -- @since 0.2.1.0
293 | realToFrac ::
294 |   (Backprop a, Fractional a, P.Real a, Fractional b, P.Real b, Reifies s W) =>
295 |   BVar s a ->
296 |   BVar s b
297 | realToFrac = E.realToFrac E.addFunc
298 | {-# INLINE realToFrac #-}
299 |
300 | -- | Lifted version of 'P.round'.
301 | --
302 | -- The gradient should technically diverge whenever the fractional part is 0.5,
303 | -- but does not do this for convenience reasons.
304 | --
305 | -- @since 0.2.3.0
306 | round ::
307 |   (P.RealFrac a, P.Integral b, Reifies s W) =>
308 |   BVar s a ->
309 |   BVar s b
310 | round = E.round E.afNum
311 | {-# INLINE round #-}
312 |
313 | -- | Lifted version of 'P.fromIntegral', defined to let you return
314 | -- 'P.RealFrac' instances as targets, instead of only other 'P.Integral's.
315 | -- Essentially the opposite of 'round'.
316 | --
317 | -- The gradient should technically diverge whenever the fractional part of
318 | -- the downstream gradient is 0.5, but does not do this for convenience
319 | -- reasons.
320 | --
321 | -- @since 0.2.3.0
322 | fromIntegral' ::
323 |   (P.Integral a, P.RealFrac b, Reifies s W) =>
324 |   BVar s a ->
325 |   BVar s b
326 | fromIntegral' = E.fromIntegral' E.afNum
327 | {-# INLINE fromIntegral' #-}
328 |
329 | -- | Lifted version of 'P.toList'.  Takes a 'BVar' of a 'Traversable' of
330 | -- items and returns a list of 'BVar's for each item.
331 | --
332 | -- You can use this to implement "lifted" versions of 'Foldable' methods
333 | -- like 'P.foldr', 'P.foldl'', etc.; however, 'sum', 'product', 'length',
334 | -- 'minimum', and 'maximum' have more efficient implementations than simply
335 | -- @'P.minimum' . 'toList'@.
336 | --
337 | -- @since 0.2.2.0
338 | toList ::
339 |   (Traversable t, Backprop a, Reifies s W) =>
340 |   BVar s (t a) ->
341 |   [BVar s a]
342 | toList = E.toList E.addFunc E.zeroFunc
343 | {-# INLINE toList #-}
344 |
345 | -- | Lifted version of 'P.mapAccumL'.
346 | --
347 | -- Prior to v0.2.3, required a 'Backprop' constraint on @t b@.
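--
-- For example, a sketch of a lifted running sum, pairing each element with
-- the total accumulated so far (the name @runningSums@ is hypothetical,
-- for illustration only):
--
-- @
-- runningSums :: 'Reifies' s 'W' => 'BVar' s [Double] -> 'BVar' s [Double]   -- hypothetical helper
-- runningSums = snd . 'mapAccumL' (\\acc x -> (acc + x, acc + x)) 0
-- @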
348 | --
349 | -- @since 0.2.2.0
350 | mapAccumL ::
351 |   (Traversable t, Backprop b, Backprop c, Reifies s W) =>
352 |   (BVar s a -> BVar s b -> (BVar s a, BVar s c)) ->
353 |   BVar s a ->
354 |   BVar s (t b) ->
355 |   (BVar s a, BVar s (t c))
356 | mapAccumL = E.mapAccumL E.addFunc E.addFunc E.zeroFunc E.zeroFunc
357 | {-# INLINE mapAccumL #-}
358 |
359 | -- | Lifted version of 'P.mapAccumR'.
360 | --
361 | -- Prior to v0.2.3, required a 'Backprop' constraint on @t b@.
362 | --
363 | -- @since 0.2.2.0
364 | mapAccumR ::
365 |   (Traversable t, Backprop b, Backprop c, Reifies s W) =>
366 |   (BVar s a -> BVar s b -> (BVar s a, BVar s c)) ->
367 |   BVar s a ->
368 |   BVar s (t b) ->
369 |   (BVar s a, BVar s (t c))
370 | mapAccumR = E.mapAccumR E.addFunc E.addFunc E.zeroFunc E.zeroFunc
371 | {-# INLINE mapAccumR #-}
372 |
-------------------------------------------------------------------------------- /src/Prelude/Backprop/Explicit.hs: --------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleContexts #-}
2 | {-# OPTIONS_HADDOCK not-home #-}
3 |
4 | -- |
5 | -- Module      : Prelude.Backprop.Explicit
6 | -- Copyright   : (c) Justin Le 2023
7 | -- License     : BSD3
8 | --
9 | -- Maintainer  : justin@jle.im
10 | -- Stability   : experimental
11 | -- Portability : non-portable
12 | --
13 | -- Provides "explicit" versions of all of the functions in
14 | -- "Prelude.Backprop".  Instead of relying on a 'Backprop' instance, allows
15 | -- you to manually provide 'zero', 'add', and 'one' on a per-value basis.
16 | --
17 | -- WARNING: API of this module can be considered only "semi-stable"; while
18 | -- the APIs of "Prelude.Backprop" and "Prelude.Backprop.Num" are kept
19 | -- consistent, some argument order changes might happen in this module to
20 | -- reflect changes in underlying implementation.
21 | --
22 | -- @since 0.2.0.0
23 | module Prelude.Backprop.Explicit (
24 |   -- * Foldable and Traversable
25 |   sum,
26 |   product,
27 |   length,
28 |   minimum,
29 |   maximum,
30 |   traverse,
31 |   toList,
32 |   mapAccumL,
33 |   mapAccumR,
34 |   foldr,
35 |   foldl',
36 |
37 |   -- * Functor and Applicative
38 |   fmap,
39 |   fmapConst,
40 |   pure,
41 |   liftA2,
42 |   liftA3,
43 |
44 |   -- * Numeric
45 |   fromIntegral,
46 |   realToFrac,
47 |   round,
48 |   fromIntegral',
49 |
50 |   -- * Misc
51 |   coerce,
52 | ) where
53 |
54 | import qualified Control.Applicative as P
55 | import Data.Bifunctor
56 | import qualified Data.Coerce as C
57 | import qualified Data.Foldable as P
58 | import qualified Data.Traversable as P
59 | import Numeric.Backprop.Explicit
60 | import Prelude (
61 |   Applicative,
62 |   Eq (..),
63 |   Foldable,
64 |   Fractional (..),
65 |   Functor,
66 |   Num (..),
67 |   Ord (..),
68 |   Traversable,
69 |   ($),
70 |   (.),
71 | )
72 | import qualified Prelude as P
73 |
74 | -- | 'Prelude.Backprop.sum', but taking explicit 'add' and 'zero'.
75 | sum ::
76 |   (Foldable t, Functor t, Num a, Reifies s W) =>
77 |   AddFunc (t a) ->
78 |   BVar s (t a) ->
79 |   BVar s a
80 | sum af = liftOp1 af . op1 $ \xs ->
81 |   ( P.sum xs
82 |   , (P.<$ xs)
83 |   )
84 | {-# INLINE sum #-}
85 |
86 | -- | 'Prelude.Backprop.pure', but taking explicit 'add' and 'zero'.
87 | pure ::
88 |   (Foldable t, Applicative t, Reifies s W) =>
89 |   AddFunc a ->
90 |   ZeroFunc a ->
91 |   BVar s a ->
92 |   BVar s (t a)
93 | pure af zfa = liftOp1 af .
op1 $ \x -> 94 | ( P.pure x 95 | , \d -> case P.toList d of 96 | [] -> runZF zfa x 97 | e : es -> P.foldl' (runAF af) e es 98 | ) 99 | {-# INLINE pure #-} 100 | 101 | -- | 'Prelude.Backprop.product', but taking explicit 'add' and 'zero'. 102 | product :: 103 | (Foldable t, Functor t, Fractional a, Reifies s W) => 104 | AddFunc (t a) -> 105 | BVar s (t a) -> 106 | BVar s a 107 | product af = liftOp1 af . op1 $ \xs -> 108 | let p = P.product xs 109 | in ( p 110 | , \d -> (\x -> p * d / x) P.<$> xs 111 | ) 112 | {-# INLINE product #-} 113 | 114 | -- | 'Prelude.Backprop.length', but taking explicit 'add' and 'zero'. 115 | length :: 116 | (Foldable t, Num b, Reifies s W) => 117 | AddFunc (t a) -> 118 | ZeroFunc (t a) -> 119 | BVar s (t a) -> 120 | BVar s b 121 | length af zfa = liftOp1 af . op1 $ \xs -> 122 | ( P.fromIntegral (P.length xs) 123 | , P.const (runZF zfa xs) 124 | ) 125 | {-# INLINE length #-} 126 | 127 | -- | 'Prelude.Backprop.minimum', but taking explicit 'add' and 'zero'. 128 | minimum :: 129 | (Foldable t, Functor t, Ord a, Reifies s W) => 130 | AddFunc (t a) -> 131 | ZeroFunc a -> 132 | BVar s (t a) -> 133 | BVar s a 134 | minimum af zf = liftOp1 af . op1 $ \xs -> 135 | let m = P.minimum xs 136 | in ( m 137 | , \d -> (\x -> if x == m then d else runZF zf x) P.<$> xs 138 | ) 139 | {-# INLINE minimum #-} 140 | 141 | -- | 'Prelude.Backprop.maximum', but taking explicit 'add' and 'zero'. 142 | maximum :: 143 | (Foldable t, Functor t, Ord a, Reifies s W) => 144 | AddFunc (t a) -> 145 | ZeroFunc a -> 146 | BVar s (t a) -> 147 | BVar s a 148 | maximum af zf = liftOp1 af . op1 $ \xs -> 149 | let m = P.maximum xs 150 | in ( m 151 | , \d -> (\x -> if x == m then d else runZF zf x) P.<$> xs 152 | ) 153 | {-# INLINE maximum #-} 154 | 155 | -- | 'Prelude.Backprop.foldr', but taking explicit 'add' and 'zero'. 156 | -- 157 | -- @since 0.2.3.0 158 | foldr :: 159 | (Traversable t, Reifies s W) => 160 | AddFunc a -> 161 | ZeroFunc a -> 162 | (BVar s a -> BVar s b -> BVar s b) -> 163 | BVar s b -> 164 | BVar s (t a) -> 165 | BVar s b 166 | foldr af z f x = P.foldr f x . toList af z 167 | {-# INLINE foldr #-} 168 | 169 | -- | 'Prelude.Backprop.foldl'', but taking explicit 'add' and 'zero'. 170 | -- 171 | -- @since 0.2.3.0 172 | foldl' :: 173 | (Traversable t, Reifies s W) => 174 | AddFunc a -> 175 | ZeroFunc a -> 176 | (BVar s b -> BVar s a -> BVar s b) -> 177 | BVar s b -> 178 | BVar s (t a) -> 179 | BVar s b 180 | foldl' af z f x = P.foldl' f x . toList af z 181 | {-# INLINE foldl' #-} 182 | 183 | -- | 'Prelude.Backprop.fmap', but taking explicit 'add' and 'zero'. 184 | fmap :: 185 | (Traversable f, Reifies s W) => 186 | AddFunc a -> 187 | AddFunc b -> 188 | ZeroFunc a -> 189 | ZeroFunc b -> 190 | (BVar s a -> BVar s b) -> 191 | BVar s (f a) -> 192 | BVar s (f b) 193 | fmap afa afb zfa zfb f = collectVar afb zfb . P.fmap f . sequenceVar afa zfa 194 | {-# INLINE fmap #-} 195 | 196 | -- | 'Prelude.Backprop.fmapConst', but taking explicit 'add' and 'zero'. 197 | -- 198 | -- @since 0.2.4.0 199 | fmapConst :: 200 | (Functor f, Foldable f, Reifies s W) => 201 | AddFunc (f a) -> 202 | AddFunc b -> 203 | ZeroFunc (f a) -> 204 | ZeroFunc b -> 205 | BVar s b -> 206 | BVar s (f a) -> 207 | BVar s (f b) 208 | fmapConst afa afb zfa zfb = liftOp2 afb afa . 
op2 $ \x xs -> 209 | ( x P.<$ xs 210 | , \d -> 211 | ( case P.toList d of 212 | [] -> runZF zfb x 213 | e : es -> P.foldl' (runAF afb) e es 214 | , runZF zfa xs 215 | ) 216 | ) 217 | {-# INLINE fmapConst #-} 218 | 219 | -- | 'Prelude.Backprop.traverse', but taking explicit 'add' and 'zero'. 220 | traverse :: 221 | (Traversable t, Applicative f, Foldable f, Reifies s W) => 222 | AddFunc a -> 223 | AddFunc b -> 224 | AddFunc (t b) -> 225 | ZeroFunc a -> 226 | ZeroFunc b -> 227 | (BVar s a -> f (BVar s b)) -> 228 | BVar s (t a) -> 229 | BVar s (f (t b)) 230 | traverse afa afb aftb zfa zfb f = 231 | collectVar aftb zftb 232 | . P.fmap (collectVar afb zfb) 233 | . P.traverse f 234 | . sequenceVar afa zfa 235 | where 236 | zftb = ZF $ P.fmap (runZF zfb) 237 | {-# INLINE zftb #-} 238 | {-# INLINE traverse #-} 239 | 240 | -- | 'Prelude.Backprop.liftA2', but taking explicit 'add' and 'zero'. 241 | liftA2 :: 242 | ( Traversable f 243 | , Applicative f 244 | , Reifies s W 245 | ) => 246 | AddFunc a -> 247 | AddFunc b -> 248 | AddFunc c -> 249 | ZeroFunc a -> 250 | ZeroFunc b -> 251 | ZeroFunc c -> 252 | (BVar s a -> BVar s b -> BVar s c) -> 253 | BVar s (f a) -> 254 | BVar s (f b) -> 255 | BVar s (f c) 256 | liftA2 afa afb afc zfa zfb zfc f x y = 257 | collectVar afc zfc $ 258 | f 259 | P.<$> sequenceVar afa zfa x 260 | P.<*> sequenceVar afb zfb y 261 | {-# INLINE liftA2 #-} 262 | 263 | -- | 'Prelude.Backprop.liftA3', but taking explicit 'add' and 'zero'. 264 | liftA3 :: 265 | ( Traversable f 266 | , Applicative f 267 | , Reifies s W 268 | ) => 269 | AddFunc a -> 270 | AddFunc b -> 271 | AddFunc c -> 272 | AddFunc d -> 273 | ZeroFunc a -> 274 | ZeroFunc b -> 275 | ZeroFunc c -> 276 | ZeroFunc d -> 277 | (BVar s a -> BVar s b -> BVar s c -> BVar s d) -> 278 | BVar s (f a) -> 279 | BVar s (f b) -> 280 | BVar s (f c) -> 281 | BVar s (f d) 282 | liftA3 afa afb afc afd zfa zfb zfc zfd f x y z = 283 | collectVar afd zfd $ 284 | f 285 | P.<$> sequenceVar afa zfa x 286 | P.<*> sequenceVar afb zfb y 287 | P.<*> sequenceVar afc zfc z 288 | {-# INLINE liftA3 #-} 289 | 290 | -- | Coerce items inside a 'BVar'. 291 | coerce :: C.Coercible a b => BVar s a -> BVar s b 292 | coerce = coerceVar 293 | {-# INLINE coerce #-} 294 | 295 | -- | 'Prelude.Backprop.fromIntegral', but taking explicit 'add' and 'zero'. 296 | -- 297 | -- @since 0.2.1.0 298 | fromIntegral :: 299 | (P.Integral a, P.Integral b, Reifies s W) => 300 | AddFunc a -> 301 | BVar s a -> 302 | BVar s b 303 | fromIntegral af = isoVar af P.fromIntegral P.fromIntegral 304 | {-# INLINE fromIntegral #-} 305 | 306 | -- | 'Prelude.Backprop.realToFrac', but taking explicit 'add' and 'zero'. 307 | -- 308 | -- @since 0.2.1.0 309 | realToFrac :: 310 | (Fractional a, P.Real a, Fractional b, P.Real b, Reifies s W) => 311 | AddFunc a -> 312 | BVar s a -> 313 | BVar s b 314 | realToFrac af = isoVar af P.realToFrac P.realToFrac 315 | {-# INLINE realToFrac #-} 316 | 317 | -- | 'Prelude.Backprop.round', but taking explicit 'add' and 'zero'. 318 | -- 319 | -- @since 0.2.3.0 320 | round :: 321 | (P.RealFrac a, P.Integral b, Reifies s W) => 322 | AddFunc a -> 323 | BVar s a -> 324 | BVar s b 325 | round af = isoVar af P.round P.fromIntegral 326 | {-# INLINE round #-} 327 | 328 | -- | 'Prelude.Backprop.fromIntegral'', but taking explicit 'add' and 329 | -- 'zero'. 
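--
-- A sketch of supplying the add function directly; @'AF' (+)@ wraps
-- ordinary numeric addition:
--
-- @
-- 'fromIntegral'' ('AF' (+)) :: 'Reifies' s 'W' => 'BVar' s Int -> 'BVar' s Double
-- @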
330 | --
331 | -- @since 0.2.3.0
332 | fromIntegral' ::
333 |   (P.Integral a, P.RealFrac b, Reifies s W) =>
334 |   AddFunc a ->
335 |   BVar s a ->
336 |   BVar s b
337 | fromIntegral' af = isoVar af P.fromIntegral P.round
338 | {-# INLINE fromIntegral' #-}
339 |
340 | -- | 'Prelude.Backprop.toList', but taking explicit 'add' and 'zero'.
341 | --
342 | -- @since 0.2.2.0
343 | toList ::
344 |   (Traversable t, Reifies s W) =>
345 |   AddFunc a ->
346 |   ZeroFunc a ->
347 |   BVar s (t a) ->
348 |   [BVar s a]
349 | toList af z = toListOfVar af (ZF (P.fmap (runZF z))) P.traverse
350 | {-# INLINE toList #-}
351 |
352 | -- | 'Prelude.Backprop.mapAccumL', but taking explicit 'add' and 'zero'.
353 | --
354 | -- @since 0.2.2.0
355 | mapAccumL ::
356 |   (Traversable t, Reifies s W) =>
357 |   AddFunc b ->
358 |   AddFunc c ->
359 |   ZeroFunc b ->
360 |   ZeroFunc c ->
361 |   (BVar s a -> BVar s b -> (BVar s a, BVar s c)) ->
362 |   BVar s a ->
363 |   BVar s (t b) ->
364 |   (BVar s a, BVar s (t c))
365 | mapAccumL afb afc zfb zfc f s =
366 |   second (collectVar afc zfc)
367 |     . P.mapAccumL f s
368 |     . sequenceVar afb zfb
369 | {-# INLINE mapAccumL #-}
370 |
371 | -- | 'Prelude.Backprop.mapAccumR', but taking explicit 'add' and 'zero'.
372 | --
373 | -- @since 0.2.2.0
374 | mapAccumR ::
375 |   (Traversable t, Reifies s W) =>
376 |   AddFunc b ->
377 |   AddFunc c ->
378 |   ZeroFunc b ->
379 |   ZeroFunc c ->
380 |   (BVar s a -> BVar s b -> (BVar s a, BVar s c)) ->
381 |   BVar s a ->
382 |   BVar s (t b) ->
383 |   (BVar s a, BVar s (t c))
384 | mapAccumR afb afc zfb zfc f s =
385 |   second (collectVar afc zfc)
386 |     . P.mapAccumR f s
387 |     . sequenceVar afb zfb
388 | {-# INLINE mapAccumR #-}
389 |
-------------------------------------------------------------------------------- /src/Prelude/Backprop/Num.hs: --------------------------------------------------------------------------------
1 | {-# LANGUAGE FlexibleContexts #-}
2 | {-# OPTIONS_HADDOCK not-home #-}
3 |
4 | -- |
5 | -- Module      : Prelude.Backprop.Num
6 | -- Copyright   : (c) Justin Le 2023
7 | -- License     : BSD3
8 | --
9 | -- Maintainer  : justin@jle.im
10 | -- Stability   : experimental
11 | -- Portability : non-portable
12 | --
13 | -- Provides the exact same API as "Prelude.Backprop", except requiring
14 | -- 'Num' instances for all types involved instead of 'Backprop' instances.
15 | --
16 | -- @since 0.2.0.0
17 | module Prelude.Backprop.Num (
18 |   -- * Foldable and Traversable
19 |   sum,
20 |   product,
21 |   length,
22 |   minimum,
23 |   maximum,
24 |   traverse,
25 |   toList,
26 |   mapAccumL,
27 |   mapAccumR,
28 |   foldr,
29 |   foldl',
30 |
31 |   -- * Functor and Applicative
32 |   fmap,
33 |   fmapConst,
34 |   (<$>),
35 |   (<$),
36 |   ($>),
37 |   pure,
38 |   liftA2,
39 |   liftA3,
40 |
41 |   -- * Numeric
42 |   fromIntegral,
43 |   realToFrac,
44 |   round,
45 |   fromIntegral',
46 |
47 |   -- * Misc
48 |   E.coerce,
49 | ) where
50 |
51 | import qualified Numeric.Backprop.Explicit as E
52 | import Numeric.Backprop.Num
53 | import qualified Prelude.Backprop.Explicit as E
54 | import Prelude (
55 |   Applicative,
56 |   Foldable,
57 |   Fractional (..),
58 |   Functor,
59 |   Num (..),
60 |   Ord (..),
61 |   Traversable,
62 | )
63 | import qualified Prelude as P
64 |
65 | -- | 'Prelude.Backprop.sum', but with 'Num' constraints instead of
66 | -- 'Backprop' constraints.
67 | sum ::
68 |   (Foldable t, Functor t, Num (t a), Num a, Reifies s W) =>
69 |   BVar s (t a) ->
70 |   BVar s a
71 | sum = E.sum E.afNum
72 | {-# INLINE sum #-}
73 |
74 | -- | 'Prelude.Backprop.pure', but with 'Num' constraints instead of
75 | -- 'Backprop' constraints.
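--
-- For example, a minimal sketch with lists, where 'P.pure' creates a
-- singleton and the gradient sums all incoming gradients:
--
-- @
-- 'pure' :: 'Reifies' s 'W' => 'BVar' s Double -> 'BVar' s [Double]
-- @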
76 | pure :: 77 | (Foldable t, Applicative t, Num a, Reifies s W) => 78 | BVar s a -> 79 | BVar s (t a) 80 | pure = E.pure E.afNum E.zfNum 81 | {-# INLINE pure #-} 82 | 83 | -- | 'Prelude.Backprop.product', but with 'Num' constraints instead of 84 | -- 'Backprop' constraints. 85 | product :: 86 | (Foldable t, Functor t, Num (t a), Fractional a, Reifies s W) => 87 | BVar s (t a) -> 88 | BVar s a 89 | product = E.product E.afNum 90 | {-# INLINE product #-} 91 | 92 | -- | 'Prelude.Backprop.length', but with 'Num' constraints instead of 93 | -- 'Backprop' constraints. 94 | length :: 95 | (Foldable t, Num (t a), Num b, Reifies s W) => 96 | BVar s (t a) -> 97 | BVar s b 98 | length = E.length E.afNum E.zfNum 99 | {-# INLINE length #-} 100 | 101 | -- | 'Prelude.Backprop.minimum', but with 'Num' constraints instead of 102 | -- 'Backprop' constraints. 103 | minimum :: 104 | (Foldable t, Functor t, Num a, Ord a, Num (t a), Reifies s W) => 105 | BVar s (t a) -> 106 | BVar s a 107 | minimum = E.minimum E.afNum E.zfNum 108 | {-# INLINE minimum #-} 109 | 110 | -- | 'Prelude.Backprop.maximum', but with 'Num' constraints instead of 111 | -- 'Backprop' constraints. 112 | maximum :: 113 | (Foldable t, Functor t, Num a, Ord a, Num (t a), Reifies s W) => 114 | BVar s (t a) -> 115 | BVar s a 116 | maximum = E.maximum E.afNum E.zfNum 117 | {-# INLINE maximum #-} 118 | 119 | -- | 'Prelude.Backprop.foldr', but with 'Num' constraints instead of 120 | -- 'Backprop' constraints. 121 | -- 122 | -- @since 0.2.3.0 123 | foldr :: 124 | (Traversable t, Num a, Reifies s W) => 125 | (BVar s a -> BVar s b -> BVar s b) -> 126 | BVar s b -> 127 | BVar s (t a) -> 128 | BVar s b 129 | foldr = E.foldr E.afNum E.zfNum 130 | {-# INLINE foldr #-} 131 | 132 | -- | 'Prelude.Backprop.foldl'', but with 'Num' constraints instead of 133 | -- 'Backprop' constraints. 134 | -- 135 | -- @since 0.2.3.0 136 | foldl' :: 137 | (Traversable t, Num a, Reifies s W) => 138 | (BVar s b -> BVar s a -> BVar s b) -> 139 | BVar s b -> 140 | BVar s (t a) -> 141 | BVar s b 142 | foldl' = E.foldl' E.afNum E.zfNum 143 | {-# INLINE foldl' #-} 144 | 145 | -- | 'Prelude.Backprop.fmap', but with 'Num' constraints instead of 146 | -- 'Backprop' constraints. 147 | fmap :: 148 | (Traversable f, Num a, Num b, Reifies s W) => 149 | (BVar s a -> BVar s b) -> 150 | BVar s (f a) -> 151 | BVar s (f b) 152 | fmap = E.fmap E.afNum E.afNum E.zfNum E.zfNum 153 | {-# INLINE fmap #-} 154 | 155 | -- | 'Prelude.Backprop.fmapConst', but with 'Num' constraints instead of 156 | -- 'Backprop' constraints. 157 | -- 158 | -- @since 0.2.4.0 159 | fmapConst :: 160 | (Functor f, Foldable f, Num b, Num (f a), Reifies s W) => 161 | BVar s b -> 162 | BVar s (f a) -> 163 | BVar s (f b) 164 | fmapConst = E.fmapConst E.afNum E.afNum E.zfNum E.zfNum 165 | {-# INLINE fmapConst #-} 166 | 167 | -- | Alias for 'fmap'. 168 | (<$>) :: 169 | (Traversable f, Num a, Num b, Reifies s W) => 170 | (BVar s a -> BVar s b) -> 171 | BVar s (f a) -> 172 | BVar s (f b) 173 | (<$>) = fmap 174 | 175 | infixl 4 <$> 176 | {-# INLINE (<$>) #-} 177 | 178 | -- | Alias for 'fmapConst'. 179 | -- 180 | -- @since 0.2.4.0 181 | (<$) :: 182 | (Functor f, Foldable f, Num b, Num (f a), Reifies s W) => 183 | BVar s b -> 184 | BVar s (f a) -> 185 | BVar s (f b) 186 | (<$) = fmapConst 187 | 188 | infixl 4 <$ 189 | {-# INLINE (<$) #-} 190 | 191 | -- | Alias for @'flip' 'fmapConst'@. 
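--
-- A minimal sketch: every element of @xs@ is replaced by @x@, keeping the
-- container's shape:
--
-- @
-- xs '$>' x
-- @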
192 | --
193 | -- @since 0.2.4.0
194 | ($>) ::
195 |   (Functor f, Foldable f, Num b, Num (f a), Reifies s W) =>
196 |   BVar s (f a) ->
197 |   BVar s b ->
198 |   BVar s (f b)
199 | xs $> x = x <$ xs
200 |
201 | infixl 4 $>
202 | {-# INLINE ($>) #-}
203 |
204 | -- | 'Prelude.Backprop.traverse', but with 'Num' constraints instead of
205 | -- 'Backprop' constraints.
206 | --
207 | -- See <https://hackage.haskell.org/package/vector-sized vector-sized> for
208 | -- a fixed-length vector type with a very appropriate 'Num' instance!
209 | traverse ::
210 |   (Traversable t, Applicative f, Foldable f, Num a, Num b, Num (t b), Reifies s W) =>
211 |   (BVar s a -> f (BVar s b)) ->
212 |   BVar s (t a) ->
213 |   BVar s (f (t b))
214 | traverse = E.traverse E.afNum E.afNum E.afNum E.zfNum E.zfNum
215 | {-# INLINE traverse #-}
216 |
217 | -- | 'Prelude.Backprop.liftA2', but with 'Num' constraints instead of
218 | -- 'Backprop' constraints.
219 | liftA2 ::
220 |   ( Traversable f
221 |   , Applicative f
222 |   , Num a
223 |   , Num b
224 |   , Num c
225 |   , Reifies s W
226 |   ) =>
227 |   (BVar s a -> BVar s b -> BVar s c) ->
228 |   BVar s (f a) ->
229 |   BVar s (f b) ->
230 |   BVar s (f c)
231 | liftA2 = E.liftA2 E.afNum E.afNum E.afNum E.zfNum E.zfNum E.zfNum
232 | {-# INLINE liftA2 #-}
233 |
234 | -- | 'Prelude.Backprop.liftA3', but with 'Num' constraints instead of
235 | -- 'Backprop' constraints.
236 | liftA3 ::
237 |   ( Traversable f
238 |   , Applicative f
239 |   , Num a
240 |   , Num b
241 |   , Num c
242 |   , Num d
243 |   , Reifies s W
244 |   ) =>
245 |   (BVar s a -> BVar s b -> BVar s c -> BVar s d) ->
246 |   BVar s (f a) ->
247 |   BVar s (f b) ->
248 |   BVar s (f c) ->
249 |   BVar s (f d)
250 | liftA3 =
251 |   E.liftA3
252 |     E.afNum
253 |     E.afNum
254 |     E.afNum
255 |     E.afNum
256 |     E.zfNum
257 |     E.zfNum
258 |     E.zfNum
259 |     E.zfNum
260 | {-# INLINE liftA3 #-}
261 |
262 | -- | 'Prelude.Backprop.fromIntegral', but with 'Num' constraints instead of
263 | -- 'Backprop' constraints.
264 | --
265 | -- @since 0.2.1.0
266 | fromIntegral ::
267 |   (P.Integral a, P.Integral b, Reifies s W) =>
268 |   BVar s a ->
269 |   BVar s b
270 | fromIntegral = E.fromIntegral E.afNum
271 | {-# INLINE fromIntegral #-}
272 |
273 | -- | 'Prelude.Backprop.realToFrac', but with 'Num' constraints instead of
274 | -- 'Backprop' constraints.
275 | --
276 | -- @since 0.2.1.0
277 | realToFrac ::
278 |   (Fractional a, P.Real a, Fractional b, P.Real b, Reifies s W) =>
279 |   BVar s a ->
280 |   BVar s b
281 | realToFrac = E.realToFrac E.afNum
282 | {-# INLINE realToFrac #-}
283 |
284 | -- | 'Prelude.Backprop.round', but with 'Num' constraints instead of
285 | -- 'Backprop' constraints.
286 | --
287 | -- @since 0.2.3.0
288 | round ::
289 |   (P.RealFrac a, P.Integral b, Reifies s W) =>
290 |   BVar s a ->
291 |   BVar s b
292 | round = E.round E.afNum
293 | {-# INLINE round #-}
294 |
295 | -- | 'Prelude.Backprop.fromIntegral'', but with 'Num' constraints instead
296 | -- of 'Backprop' constraints.
297 | --
298 | -- @since 0.2.3.0
299 | fromIntegral' ::
300 |   (P.Integral a, P.RealFrac b, Reifies s W) =>
301 |   BVar s a ->
302 |   BVar s b
303 | fromIntegral' = E.fromIntegral' E.afNum
304 | {-# INLINE fromIntegral' #-}
305 |
306 | -- | 'Prelude.Backprop.toList', but with 'Num' constraints instead of
307 | -- 'Backprop' constraints.
308 | --
309 | -- @since 0.2.2.0
310 | toList ::
311 |   (Traversable t, Num a, Reifies s W) =>
312 |   BVar s (t a) ->
313 |   [BVar s a]
314 | toList = E.toList E.afNum E.zfNum
315 | {-# INLINE toList #-}
316 |
317 | -- | 'Prelude.Backprop.mapAccumL', but with 'Num' constraints instead of
318 | -- 'Backprop' constraints.
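--
-- For example, a sketch computing exclusive prefix sums, where each
-- element is replaced by the sum of the elements before it (the name
-- @prefixSums@ is hypothetical, for illustration only):
--
-- @
-- prefixSums :: 'Reifies' s 'W' => 'BVar' s [Double] -> 'BVar' s [Double]   -- hypothetical helper
-- prefixSums = snd . 'mapAccumL' (\\acc x -> (acc + x, acc)) 0
-- @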
319 | -- 320 | -- Prior to v0.2.3, required a 'Num' constraint on @t b@. 321 | -- 322 | -- @since 0.2.2.0 323 | mapAccumL :: 324 | (Traversable t, Num b, Num c, Reifies s W) => 325 | (BVar s a -> BVar s b -> (BVar s a, BVar s c)) -> 326 | BVar s a -> 327 | BVar s (t b) -> 328 | (BVar s a, BVar s (t c)) 329 | mapAccumL = E.mapAccumL E.afNum E.afNum E.zfNum E.zfNum 330 | {-# INLINE mapAccumL #-} 331 | 332 | -- | 'Prelude.Backprop.mapAccumR', but with 'Num' constraints instead of 333 | -- 'Backprop' constraints. 334 | -- 335 | -- Prior to v0.2.3, required a 'Num' constraint on @t b@. 336 | -- 337 | -- @since 0.2.2.0 338 | mapAccumR :: 339 | (Traversable t, Num b, Num c, Reifies s W) => 340 | (BVar s a -> BVar s b -> (BVar s a, BVar s c)) -> 341 | BVar s a -> 342 | BVar s (t b) -> 343 | (BVar s a, BVar s (t c)) 344 | mapAccumR = E.mapAccumR E.afNum E.afNum E.zfNum E.zfNum 345 | {-# INLINE mapAccumR #-} 346 | -------------------------------------------------------------------------------- /test/Spec.hs: -------------------------------------------------------------------------------- 1 | main :: IO () 2 | main = putStrLn "Test suite not yet implemented" 3 | --------------------------------------------------------------------------------