├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── Setup.hs ├── adev.cabal ├── examples ├── Figure2.hs └── ParticleFilterExample.hs ├── figures ├── adev-diagram.png └── example.png ├── package.yaml ├── src └── Numeric │ └── ADEV │ ├── Class.hs │ ├── Diff.hs │ ├── Distributions.hs │ └── Interp.hs ├── stack.yaml ├── stack.yaml.lock └── test └── Spec.hs /.gitignore: -------------------------------------------------------------------------------- 1 | .stack-work/ 2 | *~ -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog for `adev` 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to the 7 | [Haskell Package Versioning Policy](https://pvp.haskell.org/). 8 | 9 | ## Unreleased 10 | 11 | ## 0.1.0.0 - 2022-12-08 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Alex Lew (c) 2022 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | * Neither the name of Alex Lew nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ADEV 2 | 3 | This repository contains the Haskell prototype that accompanies the paper "[ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs](https://popl23.sigplan.org/details/POPL-2023-popl-research-papers/5/ADEV-Sound-Automatic-Differentiation-of-Expected-Values-of-Probabilistic-Programs)". **See [ADEV.jl](https://github.com/probcomp/ADEV.jl) for an experimental port to Julia.** 4 | 5 | ## Overview 6 | 7 | ![Overview of ADEV](https://github.com/probcomp/adev/blob/main/figures/adev-diagram.png) 8 | ADEV is a method of automatically differentiating loss functions defined as *expected values* of probabilistic processes. ADEV users define a _probabilistic program_ $t$, which, given a parameter of type $\mathbb{R}$ (or a subtype), outputs a value of type $\widetilde{\mathbb{R}}$, 9 | which represents probabilistic estimators of losses. 
We translate $t$ to a new probabilistic program $s$, 10 | whose expected return value is the derivative of $t$’s expected return value. Running $s$ yields provably unbiased 11 | estimates $x_i$ of the loss's derivative, which can be used in the inner loop of stochastic optimization algorithms like Adam or stochastic gradient descent. 12 | 13 | ADEV goes beyond standard AD by explicitly supporting probabilistic primitives, like `flip`, for flipping a coin. If these probabilistic constructs are ignored, standard AD may produce incorrect results, as this figure from our paper illustrates: 14 | ![Optimizing an example loss function using ADEV](https://github.com/probcomp/adev/blob/main/figures/example.png) 15 | In this example, standard AD 16 | fails to account for the parameter $\theta$'s effect on the *probability* of entering each branch. ADEV, by contrast, correctly accounts 17 | for the probabilistic effects, generating code similar to what a practitioner might hand-derive. Correct 18 | gradients are often crucial for downstream applications, e.g. optimization via stochastic gradient descent. 19 | 20 | ADEV compositionally supports various gradient estimation strategies from the literature, including: 21 | - Reparameterization trick (Kingma & Welling 2014) 22 | - Score function estimator (Ranganath et al. 2014) 23 | - Baselines as control variates (Mnih and Gregor 2014) 24 | - Multi-sample estimators that Storchastic supports (e.g. leave-one-out baselines) (van Krieken et al. 2021) 25 | - Variance reduction via dependency tracking (Schulman et al. 2015) 26 | - Special estimators for differentiable particle filtering (Ścibior et al. 2021) 27 | - Implicit reparameterization (Figurnov et al. 2018) 28 | - Measure-valued derivatives (Heidergott and Vázquez-Abad 2000) 29 | - Reparameterized rejection sampling (Naesseth et al. 2017) 30 | 31 | 32 | ## Haskell Example 33 | 34 | ADEV extends forward-mode automatic differentiation to support *probabilistic programs*. 
Consider the following example: 35 | 36 | ```haskell 37 | import Numeric.ADEV.Class (ADEV(..)) 38 | import Numeric.ADEV.Interp () 39 | import Numeric.ADEV.Diff (diff) 40 | import Control.Monad (replicateM) 41 | import Control.Monad.Bayes.Sampler.Strict (sampleIO) 42 | 43 | -- Define a loss function l as the expected 44 | -- value of a probabilistic process. 45 | l theta = expect $ do 46 | b <- flip_reinforce theta 47 | if b then 48 | return 0 49 | else 50 | return (-theta / 2) 51 | 52 | -- Take its derivative. 53 | l' = diff l 54 | 55 | -- Helper function for computing averages 56 | mean xs = sum xs / (realToFrac $ length xs) 57 | 58 | -- Estimating the loss and its derivative 59 | -- by averaging many samples 60 | estimate_loss = fmap mean (replicateM 1000 (l 0.4)) 61 | estimate_deriv = fmap mean (replicateM 1000 (l' 0.4)) 62 | 63 | main = do 64 | loss <- sampleIO estimate_loss 65 | deriv <- sampleIO estimate_deriv 66 | print (loss, deriv) 67 | ``` 68 | 69 | **Defining a loss.** The function `l` is defined to be the *expected value* of a probabilistic process, using `expect`. The process in question involves flipping a coin, whose probability of heads is `theta`, and returning either `0` or `-theta / 2`, depending on the coin flip's result. 70 | 71 | **Differentiating.** ADEV's `diff` operator converts such a loss into a new function `l'` representing its derivative, with respect to the input parameter `theta`. 72 | 73 | **Running the estimators.** Operationally, neither `l` nor `l'` computes exact expectations (or derivatives of expectations): instead, they represent _unbiased estimators_ of the desired values, which can be run using `sampleIO`. 74 | On one run, the above code printed `(-0.122, -0.10)`, which are very close to the correct values of $-0.12$ and $-0.1$. 75 | 76 | **Composing `expect` with other operators.** Note that ADEV also provides primitives for manipulating expected values, e.g. `exp_` for taking their exponentials. 
For example, the code `fmap mean (replicateM 1000 (exp_ (l 0.4)))` yielded `0.881` on a sample run, close to the true value of $e^{-0.12} \approx 0.887$. This is the exponential of the expected value, not the expected value of the exponential, which is slightly different, and would yield $0.6 \times e^{-0.2} + 0.4 \times e^0 \approx 0.891$. 77 | 78 | **Optimization.** We can use ADEV's estimated derivatives to implement a stochastic optimization algorithm: 79 | 80 | ```haskell 81 | sgd loss eta x0 steps = 82 | if steps == 0 then 83 | return [x0] 84 | else do 85 | v <- diff loss x0 86 | let x1 = x0 - eta * v 87 | xs <- sgd loss eta x1 (steps - 1) 88 | return (x0:xs) 89 | ``` 90 | 91 | Running `sampleIO $ sgd l 0.2 0.2 100` finds the value of $\theta$ that minimizes $l$, namely $\theta = 0.5$. 92 | 93 | ## Haskell Encoding of ADEV Programs 94 | 95 | In the ADEV paper, the program `l` above would have type $\mathbb{R} \to \widetilde{\mathbb{R}}$. 96 | In Haskell, its type is `ADEV p m r => r -> m r`. Why? 97 | 98 | In general, expressions in the ADEV source language are represented by Haskell expressions with polymorphic type `ADEV p m r => ...`, where the `...` is a Haskell type that uses the three type variables `p`, `m`, and `r` as follows: 99 | 100 | * `r` represents real numbers, $\mathbb{R}$ in the ADEV paper. (The type of positive reals, $\mathbb{R}_{>0}$, is represented as `Log r`.) 101 | * `m r` represents estimated real numbers, $\widetilde{\mathbb{R}}$ in the ADEV paper. 102 | * `p m a` represents probabilistic programs returning `a`, $P~a$ in the ADEV paper. 103 | 104 | Below, we show how `l`'s type relates to the types of its sub-expressions: 105 | ```haskell 106 | -- `flip_reinforce` takes a real parameter and 107 | -- probabilistically outputs a Boolean. 108 | flip_reinforce :: ADEV p m r => r -> p m Bool 109 | 110 | -- Using `do`, we can build a larger computation that 111 | -- uses the result of a flip to compute a real. 
112 | -- Its type reflects that it still takes a real parameter 113 | -- as input, but now probabilistically outputs a real. 114 | prog :: ADEV p m r => r -> p m r 115 | prog theta = do 116 | b <- flip_reinforce theta 117 | if b then 118 | return 0 119 | else 120 | return (-theta/2) 121 | 122 | -- The `expect` operation turns a probabilistic computation 123 | -- over reals (type P R) into an estimator of its expected 124 | -- value (type R~). 125 | expect :: ADEV p m r => p m r -> m r 126 | 127 | -- By composing expect and prog, we get l from above. 128 | l :: ADEV p m r => r -> m r 129 | l = expect . prog 130 | ``` 131 | 132 | ## Implementation 133 | 134 | To understand ADEV's implementation, it is useful to first skim the ADEV paper, which explains how ADEV modularly extends standard forward-mode AD with support for probabilistic primitives. The Haskell code is a relatively direct encoding of the ideas described in the paper. Briefly: 135 | 136 | * All the primitives in the ADEV language, including those introduced by the extensions from Appendix B, are encoded as methods of the `ADEV` typeclass, in the [Numeric.ADEV.Class](src/Numeric/ADEV/Class.hs) module. This is like a 'specification' that a specific interpreter of the ADEV language can satisfy. It leaves open what concrete Haskell types will be used to represent the ADEV types of real numbers $\mathbb{R}$, estimated reals $\widetilde{\mathbb{R}}$, and monadic probabilistic programs $P~\tau$ — it uses the type variables `r`, `m r`, and `p m tau` for this purpose. 137 | 138 | * The [Numeric.ADEV.Interp](src/Numeric/ADEV/Interp.hs) module provides one instance of the `ADEV` typeclass, implementing the standard semantics of an ADEV term. 
The type variables `p`, `m`, and `r` are instantiated so that the type of reals `r` is interpreted as `Double`, the type `m r` of *estimated* reals is interpreted as the type `m Double` for some `MonadDistribution` `m` (where `MonadDistribution` is the [monad-bayes](https://github.com/tweag/monad-bayes) typeclass for probabilistic programs), and the type of probabilistic programs `p m a` is interpreted as `WriterT Sum m a` (the `Sum` maintains an accumulated loss, and is described in Appendix B.2 of the ADEV paper). 139 | 140 | * The [Numeric.ADEV.Diff](src/Numeric/ADEV/Diff.hs) module provides built-in derivatives for each primitive. These are organized into a second instance of the `ADEV` typeclass, where now the type `r` of reals is interpreted as `ForwardDouble`, representing forward-mode dual numbers $\mathcal{D}\\{\mathbb{R}\\}$ from the paper; the type `m r` of estimated reals is interpreted as the type `m ForwardDouble` for some `MonadDistribution` `m`, which implements the type $\widetilde{\mathbb{R}}\_\mathcal{D}$ of estimated dual numbers from the paper; and the type `p m tau` of probabilistic programs is interpreted as `ContT ForwardDouble m tau`, i.e., the type of *higher-order functions* that transform an input `loss_to_go : tau -> m ForwardDouble` into an estimated dual-number loss of type `m ForwardDouble` (this implements the type $P\_\mathcal{D}~\tau$ from the paper). 141 | 142 | 143 | ## Installing ADEV 144 | 1. Install `stack` (https://docs.haskellstack.org/en/stable/install_and_upgrade/). 145 | 2. Clone this repository. 146 | 3. Run the examples using `stack run ExampleName`, where `ExampleName.hs` is the name of a file from the `examples` directory. 147 | 4. Or: Run `stack ghci` to enter a REPL. 
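
Once the examples run, the coin-flip loss from the Haskell Example section can be checked by hand. The sketch below is not part of the library and uses only `base`: it enumerates both coin outcomes to compute the exact expected loss and the exact expectation of a one-sample estimator of the form `l' + l * logpdf'`, which is the shape of the `flip_reinforce` rule in `Numeric.ADEV.Diff`. All names in it (`lossValue`, `score`, `reinforceEstimate`, and so on) are hypothetical helpers introduced for illustration.

```haskell
module Main (main) where

-- Value returned by the process for a given coin outcome
-- (heads returns 0, tails returns -theta / 2).
lossValue :: Double -> Bool -> Double
lossValue theta b = if b then 0 else -theta / 2

-- Derivative of that value in theta, holding the outcome fixed.
lossValue' :: Double -> Bool -> Double
lossValue' _ b = if b then 0 else -0.5

-- Probability of each outcome when heads has probability theta.
prob :: Double -> Bool -> Double
prob theta b = if b then theta else 1 - theta

-- Derivative in theta of the log-probability of the outcome.
score :: Double -> Bool -> Double
score theta b = if b then 1 / theta else -1 / (1 - theta)

-- One-sample estimate for a fixed outcome: l' + l * dlogp.
reinforceEstimate :: Double -> Bool -> Double
reinforceEstimate theta b =
  lossValue' theta b + lossValue theta b * score theta b

-- Exact expectations, obtained by enumerating both outcomes.
expectedLoss :: Double -> Double
expectedLoss theta =
  sum [prob theta b * lossValue theta b | b <- [True, False]]

expectedEstimate :: Double -> Double
expectedEstimate theta =
  sum [prob theta b * reinforceEstimate theta b | b <- [True, False]]

main :: IO ()
main = print (expectedLoss 0.4, expectedEstimate 0.4)
```

At `theta = 0.4` the two numbers should agree, up to floating-point rounding, with the values $-0.12$ and $-0.1$ quoted earlier, and `expectedEstimate 0.5` should be approximately zero, consistent with the minimizer $\theta = 0.5$.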
148 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /adev.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.35.0. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: adev 8 | version: 0.1.0.0 9 | description: Please see the README on GitHub at 10 | homepage: https://github.com/alex-lew/adev#readme 11 | bug-reports: https://github.com/alex-lew/adev/issues 12 | author: Alex Lew 13 | maintainer: alexlew@mit.edu 14 | copyright: Copyright (c) 2022 15 | license: BSD3 16 | license-file: LICENSE 17 | build-type: Simple 18 | extra-source-files: 19 | README.md 20 | CHANGELOG.md 21 | 22 | source-repository head 23 | type: git 24 | location: https://github.com/alex-lew/adev 25 | 26 | library 27 | exposed-modules: 28 | Numeric.ADEV.Class 29 | Numeric.ADEV.Diff 30 | Numeric.ADEV.Distributions 31 | Numeric.ADEV.Interp 32 | other-modules: 33 | Paths_adev 34 | hs-source-dirs: 35 | src 36 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints 37 | build-depends: 38 | ad ==4.5.2.* 39 | , base >=4.7 && <5 40 | , log-domain >=0.12 && <0.14 41 | , monad-bayes >=1.1.0 42 | , mtl ==2.2.2.* 43 | , transformers ==0.5.6.2.* 44 | , vector >=0.12.3.1 && <0.12.4 45 | default-language: Haskell2010 46 | 47 | executable Figure2 48 | main-is: Figure2.hs 49 | hs-source-dirs: 50 | examples 51 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields 
-Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N 52 | build-depends: 53 | ad ==4.5.2.* 54 | , adev 55 | , base >=4.7 && <5 56 | , log-domain >=0.12 && <0.14 57 | , monad-bayes >=1.1.0 58 | , mtl ==2.2.2.* 59 | , transformers ==0.5.6.2.* 60 | , vector >=0.12.3.1 && <0.12.4 61 | default-language: Haskell2010 62 | 63 | executable ParticleFilterExample 64 | main-is: ParticleFilterExample.hs 65 | hs-source-dirs: 66 | examples 67 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N 68 | build-depends: 69 | ad ==4.5.2.* 70 | , adev 71 | , base >=4.7 && <5 72 | , log-domain >=0.12 && <0.14 73 | , monad-bayes >=1.1.0 74 | , mtl ==2.2.2.* 75 | , transformers ==0.5.6.2.* 76 | , vector >=0.12.3.1 && <0.12.4 77 | default-language: Haskell2010 78 | 79 | test-suite adev-test 80 | type: exitcode-stdio-1.0 81 | main-is: Spec.hs 82 | other-modules: 83 | Paths_adev 84 | hs-source-dirs: 85 | test 86 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N 87 | build-depends: 88 | ad ==4.5.2.* 89 | , adev 90 | , base >=4.7 && <5 91 | , log-domain >=0.12 && <0.14 92 | , monad-bayes >=1.1.0 93 | , mtl ==2.2.2.* 94 | , transformers ==0.5.6.2.* 95 | , vector >=0.12.3.1 && <0.12.4 96 | default-language: Haskell2010 97 | -------------------------------------------------------------------------------- /examples/Figure2.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Numeric.ADEV.Class 4 | import Numeric.ADEV.Diff (diff) 5 | import Numeric.ADEV.Interp () 6 | import Numeric.AD.Mode.Forward.Double (ForwardDouble) 7 | import Control.Monad.Bayes.Class (MonadDistribution) 8 | import 
Control.Monad.Bayes.Sampler.Strict (sampleIO) 9 | 10 | l :: ADEV p m r => r -> m r 11 | l theta = expect $ do 12 | b <- flip_reinforce theta 13 | if b then 14 | return 0 15 | else 16 | return (-theta / 2) 17 | 18 | sgd :: MonadDistribution m => (ForwardDouble -> m ForwardDouble) -> Double -> Double -> Int -> m [Double] 19 | sgd loss eta x0 steps = 20 | if steps == 0 then 21 | return [x0] 22 | else do 23 | v <- diff loss x0 24 | let x1 = x0 - eta * v 25 | xs <- sgd loss eta x1 (steps - 1) 26 | return (x0:xs) 27 | 28 | main :: IO () 29 | main = do 30 | vs <- sampleIO $ sgd l 0.2 0.2 100 31 | print vs -------------------------------------------------------------------------------- /examples/ParticleFilterExample.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Numeric.ADEV.Class 4 | import Numeric.ADEV.Diff (diff) 5 | import Numeric.ADEV.Interp () 6 | import Numeric.ADEV.Distributions (normalD) 7 | import Numeric.AD.Mode.Forward.Double (ForwardDouble) 8 | import Control.Monad.Bayes.Class (MonadDistribution) 9 | import Control.Monad.Bayes.Sampler.Strict (sampleIO) 10 | import Numeric.Log (Log(..)) 11 | 12 | -- smc :: ([a] -> Log r) -> D m r a -> (a -> D m r a) -> ([a] -> m r) -> Int -> Int -> m r 13 | 14 | dens :: (RealFrac r, Floating r) => D m r Double -> Double -> Log r 15 | dens (D _ f) x = f x 16 | 17 | normalDensity :: (RealFrac r, Floating r) => r -> r -> Double -> Log r 18 | normalDensity mu sig x = Exp $ -log(sig) - log(2*pi) / 2 - ((realToFrac x)-mu)^2/(2*sig^2) 19 | 20 | ys = [undefined, 1,2,3,4,5] 21 | 22 | l :: (MonadDistribution m, RealFloat r, Floating r, ADEV p m r) => r -> m r 23 | l theta = smc p q0 q f 2 1000 24 | where 25 | p xs = let xys = zip (map realToFrac xs) (reverse (take (length xs) ys)) in 26 | pxys xys 27 | pxys [] = undefined 28 | pxys [(x, y)] = normalDensity 0 (exp theta) (realToFrac x) 29 | pxys ((x,y):((xprev,yprev):xys)) = normalDensity (realToFrac xprev) 
(exp theta) (realToFrac x) * normalDensity (realToFrac x) 1 y * pxys ((xprev,yprev):xys) 30 | q0 = normalD 0 (exp theta) 31 | q x = normalD (realToFrac x) (exp theta) 32 | f xs = return 1 33 | 34 | sga :: MonadDistribution m => (ForwardDouble -> m ForwardDouble) -> Double -> Double -> Int -> m [Double] 35 | sga loss eta x0 steps = 36 | if steps == 0 then 37 | return [x0] 38 | else do 39 | v <- diff loss x0 40 | let x1 = x0 + eta * v 41 | xs <- sga loss eta x1 (steps - 1) 42 | return (x0:xs) 43 | 44 | main :: IO () 45 | main = do 46 | vs <- sampleIO $ sga l 10.0 0.0 500 47 | print (vs) -------------------------------------------------------------------------------- /figures/adev-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probcomp/adev/57452f319d8edf098e9d6e8235c6f6864d434f90/figures/adev-diagram.png -------------------------------------------------------------------------------- /figures/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probcomp/adev/57452f319d8edf098e9d6e8235c6f6864d434f90/figures/example.png -------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: adev 2 | version: 0.1.0.0 3 | github: "alex-lew/adev" 4 | license: BSD3 5 | author: "Alex Lew" 6 | maintainer: "alexlew@mit.edu" 7 | copyright: "Copyright (c) 2022" 8 | 9 | extra-source-files: 10 | - README.md 11 | - CHANGELOG.md 12 | 13 | # Metadata used when publishing your package 14 | # synopsis: Short description of your package 15 | # category: probabilistic programming 16 | 17 | # To avoid duplicated efforts in documentation and dealing with the 18 | # complications of embedding Haddock markup inside cabal files, it is 19 | # common to point users to the README.md file. 
20 | description: Please see the README on GitHub at 21 | 22 | dependencies: 23 | - base >= 4.7 && < 5 24 | - monad-bayes >=1.1.0 25 | - log-domain >=0.12 && <0.14 26 | - mtl >=2.2.2 && <2.2.3 27 | - ad >=4.5.2 && <4.5.3 28 | - transformers >=0.5.6.2 && <0.5.6.3 29 | - vector >=0.12.3.1 && <0.12.4 30 | 31 | ghc-options: 32 | - -Wall 33 | - -Wcompat 34 | - -Widentities 35 | - -Wincomplete-record-updates 36 | - -Wincomplete-uni-patterns 37 | - -Wmissing-export-lists 38 | - -Wmissing-home-modules 39 | - -Wpartial-fields 40 | - -Wredundant-constraints 41 | 42 | library: 43 | source-dirs: src 44 | 45 | executables: 46 | Figure2: 47 | main: Figure2.hs 48 | source-dirs: examples 49 | other-modules: [] 50 | ghc-options: 51 | - -threaded 52 | - -rtsopts 53 | - -with-rtsopts=-N 54 | dependencies: 55 | - adev 56 | ParticleFilterExample: 57 | main: ParticleFilterExample.hs 58 | source-dirs: examples 59 | other-modules: [] 60 | ghc-options: 61 | - -threaded 62 | - -rtsopts 63 | - -with-rtsopts=-N 64 | dependencies: 65 | - adev 66 | 67 | tests: 68 | adev-test: 69 | main: Spec.hs 70 | source-dirs: test 71 | ghc-options: 72 | - -threaded 73 | - -rtsopts 74 | - -with-rtsopts=-N 75 | dependencies: 76 | - adev 77 | -------------------------------------------------------------------------------- /src/Numeric/ADEV/Class.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-} 2 | 3 | module Numeric.ADEV.Class ( 4 | D(..), C(..), ADEV(..)) where 5 | 6 | import Numeric.Log 7 | 8 | -- | Type of density-carrying distributions. 9 | data D m r a = D (m a) (a -> Log r) 10 | 11 | -- | Type of CDF-carrying distributions over the reals, 12 | -- for implicit reparameterization. 
13 | data C m r = C (m Double) (Double -> Log Double) (r -> Log r) 14 | 15 | -- --------------------------------------------------------------------------- 16 | -- | A typeclass for ADEV programs, parameterized by: 17 | -- 18 | -- * @r@ - the type used to represent real numbers 19 | -- * @m@ - the monad used to encode randomness (so that @m r@ is the type of 20 | -- unbiasedly estimated real numbers) 21 | -- * @p@ - the type used for monadic probabilistic programming (so 22 | -- that @p m a@ is a probabilistic program returning @a@) 23 | class (RealFrac r, Monad (p m), Monad m) => ADEV p m r | p -> r, r -> p where 24 | -- | Sample a random uniform value between 0 and 1. 25 | sample :: p m r 26 | -- | Add a real value into a running cost accumulator. 27 | -- When a @p m r@ is passed to @expect@, the result is 28 | -- an estimator of the expected cost *plus* the expected 29 | -- return value. 30 | add_cost :: r -> p m () 31 | -- | Flip a coin with a specified probability of heads. 32 | -- Uses enumeration (costly but low-variance) to estimate 33 | -- gradients. 34 | flip_enum :: r -> p m Bool 35 | -- | Flip a coin with a specified probability of heads. 36 | -- Uses the REINFORCE estimator (cheaper but higher-variance) 37 | -- for gradients. 38 | flip_reinforce :: r -> p m Bool 39 | -- | Generate from a normal distribution. Uses the REPARAM gradient estimator. 40 | normal_reparam :: r -> r -> p m r 41 | -- | Generate from a normal distribution. Uses the REINFORCE gradient estimator. 42 | normal_reinforce :: r -> r -> p m r 43 | -- | Estimate the expectation of a probabilistic computation. 
44 | expect :: p m r -> m r 45 | -- | Combinator DSL for estimators 46 | plus_ :: m r -> m r -> m r 47 | times_ :: m r -> m r -> m r 48 | exp_ :: m r -> m r 49 | minibatch_ :: Int -> Int -> (Int -> m r) -> m r 50 | exact_ :: r -> m r 51 | -- | Baselines for controlling variance 52 | baseline :: p m r -> r -> m r 53 | -- | Automatic construction of new REINFORCE estimators 54 | reinforce :: D m r a -> p m a 55 | -- | Storchastic leave_one_out estimator 56 | leave_one_out :: Int -> D m r a -> p m a 57 | -- | Differentiable particle filter, accepting: 58 | -- * @p@: a density function for the target measure. 59 | -- * @q0@: an initial proposal for the particle filter. 60 | -- * @q@: a transition proposal for the particle filter. 61 | -- * @f@: an unbiased estimator of an integrand to estimate 62 | -- * @n@: the number of SMC steps to run 63 | -- * @k@: the number of particles to use 64 | -- Returns an SMC estimator of the integral 65 | smc :: ([a] -> Log r) -> D m r a -> (a -> D m r a) -> ([a] -> m r) -> Int -> Int -> m r 66 | -- | Importance sampling gradient estimator 67 | importance :: D m r a -> D m r a -> p m a 68 | -- | Implicit reparameterization for real-valued distributions 69 | -- differentiable with CDFs (e.g., mixtures of Gaussians) 70 | implicit_reparam :: C m r -> p m r 71 | -- | Sample from a Poisson distribution, using a measure-valued derivative. 72 | poisson_weak :: Log r -> p m Int 73 | -- | Gradients through rejection sampling for density-carrying distributions. 
74 | reparam_reject :: D m r a -> (a -> b) -> (D m r b) -> (D m r b) -> Log r -> p m b 75 | 76 | 77 | -------------------------------------------------------------------------------- /src/Numeric/ADEV/Diff.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances, FunctionalDependencies #-} 2 | 3 | module Numeric.ADEV.Diff ( 4 | ADEV(..), diff 5 | ) where 6 | 7 | import Numeric.ADEV.Class 8 | import Numeric.ADEV.Interp() 9 | import Control.Monad.Bayes.Class ( 10 | MonadDistribution, 11 | uniform, 12 | uniformD, 13 | logCategorical, 14 | poisson, 15 | bernoulli, 16 | normal) 17 | import Control.Monad.Cont (ContT(..)) 18 | import Numeric.AD.Internal.Forward.Double ( 19 | ForwardDouble, 20 | bundle, 21 | primal, 22 | tangent) 23 | import Control.Monad (replicateM, mapM) 24 | import Numeric.Log (Log(..)) 25 | import qualified Numeric.Log as Log (sum) 26 | import Data.List (zipWith4) 27 | import qualified Data.Vector as V 28 | 29 | split :: ForwardDouble -> (Double, Double) 30 | split dx = (primal dx, tangent dx) 31 | 32 | -- | ADEV translation of an ADEV program. 33 | -- Implements the built-in derivative for every ADEV primitive. 34 | -- * Reals are interpreted as ForwardDoubles, pairs of Doubles. 35 | -- * Underlying randomness is provided by a monad @m@ satisfying the 36 | -- monad-bayes @MonadDistribution@ interface. 37 | -- * ADEV probabilistic programs are represented by the monad 38 | -- @ContT ForwardDouble m@: they know how to transform estimators of 39 | -- losses and loss derivatives into estimators of *expected* losses and 40 | -- loss derivatives, where the expectation is taken over the probabilistic 41 | -- program in question. 
42 | instance MonadDistribution m => ADEV (ContT ForwardDouble) m ForwardDouble where 43 | sample = ContT $ \dloss -> do 44 | u <- uniform 0 1 45 | dloss (bundle u 0) 46 | 47 | flip_enum dp = ContT $ \dloss -> do 48 | dl1 <- dloss True 49 | dl2 <- dloss False 50 | return (dp * dl1 + (1 - dp) * dl2) 51 | 52 | flip_reinforce dp = ContT $ \dloss -> do 53 | b <- bernoulli (primal dp) 54 | (l, l') <- fmap split (dloss b) 55 | let logpdf' = tangent (log $ if b then dp else 1 - dp) 56 | return (bundle l (l' + l * logpdf')) 57 | 58 | normal_reparam dmu dsig = do 59 | deps <- stdnorm 60 | return $ (deps * dsig) + dmu 61 | where 62 | stdnorm = ContT $ \dloss -> do 63 | eps <- normal 0 1 64 | dloss (bundle eps 0) 65 | 66 | normal_reinforce dmu dsig = ContT $ \dloss -> do 67 | x <- normal (primal dmu) (primal dsig) 68 | let dx = bundle x 0 69 | (l, l') <- fmap split (dloss dx) 70 | let logpdf' = tangent $ (-1 * log dsig) - 0.5 * ((dx - dmu) / dsig)^2 71 | return (bundle l (l' + l * logpdf')) 72 | 73 | add_cost dcost = ContT $ \dloss -> do 74 | dl <- dloss () 75 | return (dl + dcost) 76 | 77 | expect prog = runContT prog return 78 | 79 | plus_ estimate_da estimate_db = do -- different from paper's estimator 80 | da <- estimate_da 81 | db <- estimate_db 82 | return (da + db) 83 | 84 | times_ estimate_da estimate_db = do 85 | da <- estimate_da 86 | db <- estimate_db 87 | return (da * db) 88 | 89 | exp_ estimate_dx = do 90 | (x, x') <- (fmap split estimate_dx) 91 | s <- exp_ (fmap primal estimate_dx) 92 | return (bundle x (s * x')) 93 | 94 | minibatch_ n m estimate_df = do 95 | indices <- replicateM m (uniformD [1..n]) 96 | dfs <- mapM (\i -> estimate_df i) indices 97 | return $ (sum dfs) * (fromIntegral n / fromIntegral m) 98 | 99 | exact_ = return 100 | 101 | baseline dp db = do 102 | dl <- runContT dp (\dx -> return (dx - db)) 103 | return (dl + db) 104 | 105 | reinforce (D dsamp dpdf) = ContT $ \dloss -> do 106 | x <- dsamp 107 | (l, l') <- fmap split (dloss x) 108 | let 
logpdf' = tangent $ ln (dpdf x) 109 | return (bundle l (l' + l * logpdf')) 110 | 111 | leave_one_out m (D dsamp dpdf) = ContT $ \dloss -> do 112 | xs <- replicateM m dsamp 113 | dlosses <- mapM dloss xs 114 | let (ls, l's) = unzip (map split dlosses) 115 | -- For each l, average the other ls to get a baseline 116 | let bs = map (\i -> (sum (take i ls) + sum (drop (i + 1) ls)) / (fromIntegral (m - 1))) [0..m-1] 117 | let logpdfs = map (tangent . ln . dpdf) xs 118 | return $ bundle (sum ls / fromIntegral m) (sum (zipWith4 (\l l' b lpdf -> l' + (l - b) * lpdf) ls l's bs logpdfs) / fromIntegral m) 119 | 120 | implicit_reparam (C samp pdf dcdf) = ContT $ \dloss -> do 121 | x <- samp 122 | let f' = tangent $ (exp . ln . dcdf) (bundle x 0) 123 | let p = (exp . ln . pdf) x 124 | dloss (bundle x (-f' / p)) 125 | 126 | poisson_weak drate = ContT $ \dloss -> do 127 | let (rate, rate') = split (exp (ln drate)) 128 | x_neg <- poisson rate 129 | let x_pos = x_neg + 1 130 | y_neg <- dloss x_neg 131 | y_pos <- dloss x_pos 132 | let grad = primal y_pos - primal y_neg 133 | return (bundle (primal y_neg) (grad * rate')) 134 | 135 | reparam_reject (D s spdf) h (D p ppdf) (D q qpdf) m = ContT $ \dloss -> 136 | runContT (reinforce dpi) (dloss . h) 137 | where 138 | pi = do 139 | eps <- s 140 | let x = h eps 141 | let w = exp ((primal (ln (ppdf x))) - (primal (ln (qpdf x)))) 142 | u <- uniform 0 1 143 | if u < w then return eps else pi 144 | dpi_density deps = spdf deps * ppdf (h deps) / qpdf (h deps) 145 | dpi = D pi dpi_density 146 | 147 | smc dp (D q0samp q0dens) dq df n k = do 148 | particles <- iterateM step init n 149 | values <- mapM (\(v, w) -> do 150 | (f, f') <- fmap split (df v) 151 | let logpdf' = tangent $ ln (dp v) 152 | return (bundle f (exp (ln w) * (f' + f * logpdf')))) particles 153 | return $ sum values / fromIntegral k 154 | where 155 | iterateM k m n = if n == 0 then m else do 156 | x <- m 157 | iterateM k (k x) (n - 1) 158 | pp = Exp . primal . ln . 
dp 159 | qq0 = Exp . primal . ln . q0dens 160 | init = replicateM k (do 161 | x <- q0samp 162 | return ([x], pp [x] / qq0 x)) 163 | resample particles = do 164 | let weights = map snd particles 165 | let total_weight = Log.sum weights 166 | let normed_weights = map (\w -> w / total_weight) weights 167 | indices <- replicateM k (logCategorical (V.fromList normed_weights)) 168 | let new_weights = replicate k (total_weight / fromIntegral k) 169 | return $ zip (map (\i -> fst (particles !! i)) indices) new_weights 170 | propagate particle = do 171 | let (v, w) = particle 172 | let (D qs qd) = dq (head v) 173 | let qqd = Exp . primal . ln . qd 174 | v' <- qs 175 | return (v':v, w * (pp (v':v) / pp v) / qqd v') 176 | step particles = do 177 | particles <- resample particles 178 | mapM propagate particles 179 | 180 | diff :: MonadDistribution m => (ForwardDouble -> m ForwardDouble) -> Double -> m Double 181 | diff f x = do 182 | df <- f (bundle x 1) 183 | return (tangent df) 184 | -------------------------------------------------------------------------------- /src/Numeric/ADEV/Distributions.hs: -------------------------------------------------------------------------------- 1 | module Numeric.ADEV.Distributions (normalD, geometricD) where 2 | 3 | import Numeric.ADEV.Class 4 | 5 | import Control.Monad.Bayes.Class ( 6 | MonadDistribution, 7 | geometric, 8 | normal) 9 | 10 | import Numeric.Log (Log(..)) 11 | 12 | normalD :: (MonadDistribution m, RealFrac r, Floating r) => r -> r -> D m r Double 13 | normalD mu sig = D (normal (realToFrac mu) (realToFrac sig)) (\x -> Exp $ -log(sig) - log(2*pi) / 2 - ((realToFrac x)-mu)^2/(2*sig^2)) 14 | 15 | geometricD :: (MonadDistribution m, RealFrac r, Floating r) => r -> D m r Int 16 | geometricD p = D (geometric (realToFrac p)) (\x -> Exp $ log(p) + (fromIntegral $ x-1) * log(1-p)) -------------------------------------------------------------------------------- /src/Numeric/ADEV/Interp.hs: 
--------------------------------------------------------------------------------
{-# LANGUAGE FlexibleInstances, ImportQualifiedPost, FunctionalDependencies #-}

module Numeric.ADEV.Interp where

import Numeric.ADEV.Class
import Control.Monad.Bayes.Class (
  MonadDistribution,
  uniform,
  uniformD,
  logCategorical,
  score,
  bernoulli,
  poisson,
  normal)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Writer.Lazy (WriterT(..), tell)
import Data.Monoid (Sum(..))
import Control.Monad (replicateM, mapM)
import Data.Vector qualified as V
import Numeric.Log (Log(..))
import qualified Numeric.Log as Log (sum)

-- | Standard, non-AD semantics of an ADEV program.
--   * Reals are represented as Doubles.
--   * Randomness comes from an underlying monad @m@ satisfying
--     the monad-bayes @MonadDistribution@ interface.
--   * The ADEV probability monad is interpreted as @WriterT (Sum Double) m@,
--     i.e. a probabilistic computation that accumulates an additive loss.
instance MonadDistribution m => ADEV (WriterT (Sum Double)) m Double where
  sample = uniform 0 1
  flip_enum = bernoulli
  flip_reinforce = bernoulli
  normal_reparam = normal
  normal_reinforce = normal
  add_cost w = tell (Sum w)
  expect f = do {(x, w) <- runWriterT f; return (x + getSum w)}
  exact_ = return
  plus_ esta estb = do
    a <- esta
    b <- estb
    return (a + b)
  times_ esta estb = do
    a <- esta
    b <- estb
    return (a * b)
  -- Unbiased estimate of exp of the expected value of estx: draw
  -- n ~ Poisson(rate), multiply n independent estimates, each scaled
  -- by 1 / rate, and correct by exp rate.
  exp_ estx = do
    n <- poisson rate
    xs <- replicateM n estx
    return $ exp rate * product (map (\x -> x / rate) xs)
    where rate = 2
  -- Estimate the sum of f over [1..n] from m uniformly drawn indices,
  -- rescaled by n / m.
  minibatch_ n m f = do
    indices <- replicateM m (uniformD [1..n])
    vals <- mapM f indices
    return $ (fromIntegral n / fromIntegral m) * (sum vals)
  baseline p b = expect p
  reinforce (D sampler density) = lift sampler
  leave_one_out n (D sampler density) = lift sampler
  smc p (D q0samp q0dens) q f n k = do
    particles <- iterateM step init n
    values <- mapM (\(v, w) -> do
      x <- f v
      return (x * exp (ln w))) particles
    return $ sum values / fromIntegral k
    where
      iterateM k m n = if n == 0 then m else do
        x <- m
        iterateM k (k x) (n - 1)
      init = replicateM k (do
        x <- q0samp
        return ([x], p [x] / q0dens x))
      resample particles = do
        let weights = map snd particles
        let total_weight = Log.sum weights
        let normed_weights = map (\w -> w / total_weight) weights
        indices <- replicateM k (logCategorical (V.fromList normed_weights))
        let new_weights = replicate k (total_weight / fromIntegral k)
        return $ zip (map (\i -> fst (particles !! i)) indices) new_weights
      propagate particle = do
        let (v, w) = particle
        let (D qs qd) = q (head v)
        v' <- qs
        return (v':v, w * (p (v':v) / p v) / qd v')
      step particles = do
        particles <- resample particles
        mapM propagate particles
  importance (D samp _) _ = lift samp
  implicit_reparam (C samp pdf cdf) = lift samp
  poisson_weak (Exp rate) = poisson (exp rate)
  reparam_reject s h (D p ppdf) (D q qpdf) m = do
    x <- lift q
    let w = ppdf x / (m * qpdf x)
    u <- uniform 0 1
    if log u < ln w then do
      return x
    else
      reparam_reject s h (D p ppdf) (D q qpdf) m
--------------------------------------------------------------------------------
/stack.yaml:
--------------------------------------------------------------------------------
# This file was automatically generated by 'stack init'
#
# Some commonly used options have been documented as comments in this file.
# For advanced use and comprehensive documentation of the format, please see:
# https://docs.haskellstack.org/en/stable/yaml_configuration/

# Resolver to choose a 'specific' stackage snapshot or a compiler version.
# A snapshot resolver dictates the compiler version and the set of packages
# to be used for project dependencies. For example:
#
# resolver: lts-3.5
# resolver: nightly-2015-09-21
# resolver: ghc-7.10.2
#
# The location of a snapshot can be provided as a file or url. Stack assumes
# a snapshot provided as a file might change, whereas a url resource does not.
#
# resolver: ./custom-snapshot.yaml
# resolver: https://example.com/snapshots/2018-01-01.yaml
resolver:
  url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/3.yaml

# User packages to be built.
# Various formats can be used as shown in the example below.
#
# packages:
# - some-directory
# - https://example.com/foo/bar/baz-0.0.2.tar.gz
#   subdirs:
#   - auto-update
#   - wai
packages:
- .
# Dependency packages to be pulled from upstream that are not in the resolver.
# These entries can reference officially published versions as well as
# forks / in-progress versions pinned to a git hash. For example:
#
extra-deps:
- monad-bayes-1.1.0
# - acme-missiles-0.3
# - git: https://github.com/commercialhaskell/stack.git
#   commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a
#
# extra-deps: []

# Override default flag values for local packages and extra-deps
# flags: {}

# Extra package databases containing global packages
# extra-package-dbs: []

# Control whether we use the GHC we find on the path
# system-ghc: true
#
# Require a specific version of stack, using version ranges
# require-stack-version: -any # Default
# require-stack-version: ">=2.9"
#
# Override the architecture used by stack, especially useful on Windows
# arch: i386
# arch: x86_64
#
# Extra directories used by stack for building
# extra-include-dirs: [/path/to/dir]
# extra-lib-dirs: [/path/to/dir]
#
# Allow a newer minor version of GHC than the snapshot specifies
# compiler-check: newer-minor
--------------------------------------------------------------------------------
/stack.yaml.lock:
--------------------------------------------------------------------------------
# This file was autogenerated by Stack.
# You should not edit this file by hand.
# For more information, please see the documentation at:
#   https://docs.haskellstack.org/en/stable/lock_files

packages:
- completed:
    hackage: monad-bayes-1.1.0@sha256:8929887b2883e553b928dcc9b1326171c87b6aa26f11800dc8c55b119a9e9649,6123
    pantry-tree:
      sha256: bf7f9b1351226a957c7ebd0c42316505be713690cd9d44425bd9cfd494a94161
      size: 3568
  original:
    hackage: monad-bayes-1.1.0
snapshots:
- completed:
    sha256: 03cec7d96ed78877b03b5c2bc5e31015b47f69af2fe6a62d994a42b7a43c5805
    size: 648659
    url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/3.yaml
  original:
    url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/3.yaml
--------------------------------------------------------------------------------
/test/Spec.hs:
--------------------------------------------------------------------------------
main :: IO ()
main = putStrLn "Test suite not yet implemented"

--------------------------------------------------------------------------------