├── .gitignore ├── LICENSE ├── README.md ├── Setup.hs ├── app └── Main.hs ├── package.yaml ├── src ├── BasicExamples.hs └── TraceTypes.hs ├── stack.yaml └── trace-types.cabal /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | dist-* 3 | cabal-dev 4 | *.o 5 | *.hi 6 | *.hie 7 | *.chi 8 | *.chs.h 9 | *.dyn_o 10 | *.dyn_hi 11 | .hpc 12 | .hsenv 13 | .cabal-sandbox/ 14 | cabal.sandbox.config 15 | *.prof 16 | *.aux 17 | *.hp 18 | *.eventlog 19 | .stack-work/ 20 | cabal.project.local 21 | cabal.project.local~ 22 | .HTF/ 23 | .ghc.environment.* 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 MIT Probabilistic Computing Project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # haskell-trace-types 2 | 3 | Haskell prototype of the system described in [Trace Types and Denotational Semantics for Sound Programmable Inference in Probabilistic Languages](https://dl.acm.org/doi/10.1145/3371087). 4 | 5 | Run `stack ghci` to start an interactive prompt 6 | with definitions from `TraceTypes` and `BasicExamples` 7 | available. 
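For example, once the prompt is up you can draw a trace from the `weight_model` program defined in `BasicExamples` (a minimal sketch reusing the `sampleIO . runWeighted` pattern from `app/Main.hs`; the sampled record differs from run to run):

```haskell
λ> (t, w) <- sampleIO . runWeighted . traceSampler $ weight_model
λ> t   -- a record with #weight and #measurement fields
λ> w   -- the accumulated weight, as a Log Double
```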
8 | 9 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /app/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators, TypeApplications, ScopedTypeVariables, MultiParamTypeClasses, OverloadedLabels, InstanceSigs, FlexibleContexts, AllowAmbiguousTypes, FlexibleInstances, RankNTypes, UndecidableInstances, ConstraintKinds, OverlappingInstances, TypeFamilyDependencies #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE PartialTypeSignatures #-} 5 | {-# LANGUAGE RecordWildCards, Rank2Types, ConstraintKinds #-} 6 | 7 | module Main where 8 | 9 | import TraceTypes 10 | import Control.Monad.Bayes.Sampler 11 | import Control.Monad.Bayes.Weighted 12 | import Data.Row.Records 13 | 14 | run = sampleIO . runWeighted 15 | 16 | main :: IO () 17 | main = do (b, w) <- run . traceSampler $ sample #b (brn $ mkUnit 0.4) 18 | putStrLn (show b) 19 | -------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-trace-types 2 | version: 0.1.0.0 3 | github: "probcomp/haskell-trace-types" 4 | license: MIT 5 | author: "Alex Lew" 6 | maintainer: "alexlew@mit.edu" 7 | copyright: "2019 Alex Lew" 8 | 9 | extra-source-files: 10 | - README.md 11 | 12 | # To avoid duplicated efforts in documentation and dealing with the 13 | # complications of embedding Haddock markup inside cabal files, it is 14 | # common to point users to the README.md file. 
15 | description: Please see the README on GitHub at 16 | 17 | dependencies: 18 | - base >= 4.7 && < 5 19 | - unconstrained 20 | - text 21 | - hashable 22 | - vector-sized 23 | - unordered-containers 24 | - constraints 25 | - row-types 26 | - deepseq 27 | - random 28 | - containers 29 | - monad-bayes 30 | - log-domain 31 | - finite-typelits 32 | - data-default 33 | - statistics 34 | 35 | library: 36 | source-dirs: src 37 | 38 | executables: 39 | trace-types-exe: 40 | main: Main.hs 41 | source-dirs: app 42 | ghc-options: 43 | - -threaded 44 | - -rtsopts 45 | - -with-rtsopts=-N 46 | dependencies: 47 | - trace-types 48 | 49 | -------------------------------------------------------------------------------- /src/BasicExamples.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators, TypeApplications, ScopedTypeVariables, MultiParamTypeClasses, OverloadedLabels, InstanceSigs, FlexibleContexts, FlexibleInstances, RankNTypes, UndecidableInstances, ConstraintKinds, OverlappingInstances, TypeFamilyDependencies #-} 2 | {-# LANGUAGE RebindableSyntax #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE PartialTypeSignatures, NoMonomorphismRestriction #-} 6 | {-# LANGUAGE RecordWildCards, Rank2Types, ConstraintKinds #-} 7 | 8 | module BasicExamples where 9 | 10 | import TraceTypes 11 | import Control.Monad.Bayes.Class 12 | import Control.Monad.Bayes.Weighted (runWeighted) 13 | import Control.Monad.Bayes.Sampler 14 | import Data.Row.Records (Rec, (.==), type (.==), (.!), (.+), type (.+), Empty) 15 | import Data.Row.Internal (LT(..), Row(..)) 16 | import qualified Data.Vector.Sized as V 17 | import Data.Vector.Sized (Vector) 18 | import Prelude hiding ((>>=), (>>), return) 19 | import Numeric.Log 20 | import Control.Monad.Bayes.Inference.SMC 21 | import Control.Monad.Bayes.Population 22 | import Control.Monad.Bayes.Sequential 23 | 24 | 25 | f >>= g = pBind f g 26 | 27 | return :: Monad m => a -> P m Empty a 28 | return = pReturn 29 | 30 | f >> g = f >>= \_ -> g 31 | 32 | ------------------------------- 33 | ------------------------------- 34 | -- -- 35 | -- EXAMPLES FROM FIGURE 1 -- 36 | -- Valid & Invalid Inference -- 37 | -- -- 38 | ------------------------------- 39 | ------------------------------- 40 | 41 | 42 | -- weight_model :: MonadInfer m => P m ("weight" .== Log Double .+ "measurement" .== Double) Double 43 | weight_model = do 44 | w <- sample #weight (gmma 2 1) 45 | sample #measurement (nrm (realToFrac w) 0.2) 46 | 47 | obs = #measurement .== 0.5 48 | 49 | 50 | -- Importance Sampling 51 | 52 | -- q1 :: MonadInfer m => P m ("weight" .== Unit) Unit 53 | q1 = sample #weight unif 54 | 55 | -- appliedQ1 :: Type Error 56 | -- appliedQ1 = importance weight_model obs q1 57 | 58 | -- q1' :: MonadInfer m => P m ("weight" .== Log Double) (Log Double) 59 | q1' = sample #weight (gmma 2 0.25) 60 | 61 | -- appliedQ1' :: MonadInfer m => m (Rec ("weight" .== Log Double)) 62 | appliedQ1' = importance weight_model obs q1' 63 | 64 | -- MCMC 65 | 66 | -- k1 :: MonadInfer m => K m ("weight" .== Log Double) ("weight" .== Log Double) 67 | k1 = mh (\t -> sample #weight (truncnrm (t .! #weight) 0.2)) 68 | 69 | -- k2 :: MonadInfer m => K m ("weight" .== Log Double) ("weight" .== Log Double) 70 | k2 = mh (\t -> sample #weight (truncnrm (t .! #weight) 1.0)) 71 | 72 | -- shouldUseK1 :: Rec ("weight" .== Log Double) -> Bool 73 | shouldUseK1 t = t .! 
#weight <= 2 74 | 75 | -- k :: Type Error 76 | -- k = seqK (ifK shouldUseK1 k1) 77 | -- (ifK (not . shouldUseK1) k2) 78 | 79 | -- k' :: MonadInfer m => K m ("weight" .== Log Double) ("weight" .== Log Double) 80 | k' = mh (\t -> let w = t .! #weight in 81 | sample #weight (truncnrm w (if w < 2 then 0.2 else 1))) 82 | 83 | -- appliedK' :: MonadInfer m => Rec ("weight" .== Log Double) -> m (Rec ("weight" .== Log Double)) 84 | appliedK' = k' (observe weight_model obs) 85 | 86 | -- SVI 87 | 88 | -- q2 :: MonadInfer m => (Double, Log Double) -> P m ("weight" .== Double) Double 89 | q2 (a, b) = sample #weight (nrm a b) 90 | 91 | -- appliedQ2 :: Type Error 92 | -- appliedQ2 = svi weight_model obs q2 93 | 94 | -- q2' :: MonadInfer m => (Log Double, Log Double) -> P m ("weight" .== Log Double) (Log Double) 95 | q2' (a, b) = sample #weight (truncnrm a b) 96 | 97 | -- appliedQ2' :: MonadInfer m => (Log Double, Log Double) -> m (Log Double, Log Double) 98 | appliedQ2' = svi weight_model obs q2' 99 | 100 | 101 | 102 | 103 | ------------------------------ 104 | ------------------------------ 105 | -- -- 106 | -- EXAMPLES FROM FIGURE 3 -- 107 | -- Control Flow Constructs -- 108 | -- -- 109 | ------------------------------ 110 | ------------------------------ 111 | 112 | 113 | -- simple1 :: MonadInfer m => P m ("x" .== Double .+ "z" .== Double) Double 114 | simple1 = do 115 | x <- sample #x (nrm 0 1) 116 | z <- sample #z (nrm x 1) 117 | return $ x + z 118 | 119 | -- simple2 :: MonadInfer m => P m ("x" .== Double .+ "z" .== Double) Double 120 | simple2 = do 121 | z <- sample #z (nrm 0 1) 122 | x <- sample #x (nrm z 1) 123 | return $ x + z 124 | 125 | -- branch1 :: MonadInfer m => P m ("b" .== Bool .+ "p" .== Unit .+ "coin" .== Bool) Bool 126 | branch1 = do 127 | isBiased <- sample #b (brn $ mkUnit 0.1) 128 | p <- if isBiased then 129 | sample #p (bta 1 2) 130 | else 131 | sample #p (unif) 132 | sample #coin (brn p) 133 | 134 | -- branch2 135 | -- :: MonadInfer m 136 | -- => P m ("p" .== Either (Rec ("isLow" .== Bool)) (Rec Empty) .+ "coin" .== Bool) Bool 137 | branch2 = do 138 | p <- withProbability #p (mkUnit 0.1) (do { 139 | isLow <- sample #isLow (brn (mkUnit 0.5)); 140 | return (if isLow then (mkUnit 0.01) else (mkUnit 0.99)) 141 | }) (return $ mkUnit 0.5) 142 | sample #coin (brn p) 143 | 144 | -- loop1 :: MonadInfer m => P m ("pts" .== [Rec ("x" .== Double .+ "y" .== Double)]) [Double] 145 | loop1 = do 146 | pts <- forRandomRange #pts (pois 3) (\i -> do { 147 | x <- sample #x (nrm (fromIntegral i) 1); 148 | sample #y (nrm x 1) 149 | }) 150 | return $ map (\y -> 2 * y) pts 151 | 152 | -- loop2 :: MonadInfer m => P m ("pts" .== Vector 3 (Rec ("y" .== Double))) [Double] 153 | loop2 = do 154 | pts <- forEach #pts (V.generate id :: Vector 3 Int) (\x -> do { 155 | sample #y (nrm (fromIntegral x) 1) 156 | }) 157 | return $ map (\y -> 2 * y) (V.toList pts) 158 | 159 | 160 | 161 | ------------------------------ 162 | ------------------------------ 163 | -- -- 164 | -- EXAMPLES FROM FIGURE 4 -- 165 | -- End-to-End Example -- 166 | -- -- 167 | ------------------------------ 168 | ------------------------------ 169 | 170 | xs = V.generate (\i -> 171 | case i of 172 | 0 -> -3.0 173 | 1 -> -2.0 174 | 2 -> -1.0 175 | 3 -> 0.0 176 | 4 -> 1.0 177 | 5 -> 2.0 178 | 6 -> 3.0) :: Vector 7 Double 179 | 180 | 181 | ys = V.generate (\i -> 182 | case i of 183 | 0 -> #y .== 3.1 184 | 1 -> #y .== 1.8 185 | 2 -> #y .== 1.1 186 | 3 -> #y .== (-0.2) 187 | 4 -> #y .== (-1.2) 188 | 5 -> #y .== (-2.1) 189 | 6 -> #y .== (-3.0)) :: 
Vector 7 (Rec ("y" .== Double)) 190 | 191 | -- prior 192 | -- :: MonadInfer m 193 | -- => P m ("noise" .== Log Double .+ "coeffs" .== [Rec ("c" .== Double)]) (Double -> Double, Log Double) 194 | prior = do z <- sample #noise (gmma 1 1) 195 | terms <- forRandomRange #coeffs (geom $ mkUnit 0.4) (\i -> do { 196 | coeff <- sample #c (nrm 0 1); 197 | return (\x -> coeff * (x ^ i)) 198 | }) 199 | let f = \x -> foldr (\t y -> y + (t x)) 0 terms 200 | return (f, z) 201 | 202 | -- p :: MonadInfer m 203 | -- => P m ("noise" .== Log Double .+ "coeffs" .== [Rec ("c" .== Double)] .+ "data" .== Vector 7 (Rec ("y" .== Double))) (Vector 7 Double) 204 | p = do (f, z) <- prior 205 | forEach #data xs (\x -> sample #y (nrm (f x) z)) 206 | 207 | 208 | data NNResult = NN { nnPi :: Unit, nnMu :: Double, nnSig :: Log Double, nnNoise :: Log Double } 209 | data QState = Q { qN :: Int, qResids :: Vector 7 Double, qNN:: NNResult } 210 | nn = undefined :: Vector 384 Double -> Vector 7 (Double, Double) -> NNResult 211 | 212 | -- q :: MonadInfer m 213 | -- => (Vector 384 Double, Log Double) 214 | -- -> Rec ("data" .== Vector 7 (Rec ("y" .== Double))) 215 | -- -> P m ("noise" .== Log Double .+ "coeffs" .== [Rec ("c" .== Double)]) (Log Double) 216 | q (theta, sigma) obs = 217 | do let ys = V.map (\d -> d .! #y) (obs .! #data) 218 | let nnInputs n resids = V.zip (V.map (\x -> x^n) xs) resids 219 | let initialState = Q { qN = 0, qResids = ys, qNN = nn theta (nnInputs 0 ys) } 220 | finalState <- while #coeffs initialState (nnPi . qNN) (mkUnit 0.99) (\s -> do { 221 | c <- sample #c $ nrm (nnMu (qNN s)) (nnSig (qNN s)); 222 | let newResids = V.map (\(x, y) -> y - c * (x ^ (qN s))) (V.zip xs (qResids s)) in 223 | return (Q {qN = qN s + 1, qResids = newResids, qNN = nn theta (nnInputs (qN s + 1) newResids)}) 224 | }) 225 | sample #noise (lognormal (nnNoise (qNN finalState)) sigma) 226 | 227 | -- sampleCurve = sampleIO $ runWeighted $ sampler $ traced p 228 | 229 | -- j = trainAmortized p q 230 | 231 | -- h = importance p (#data .== ys) prior 232 | 233 | 234 | 235 | 236 | 237 | 238 | ------------------------------ 239 | ------------------------------ 240 | -- -- 241 | -- HIDDEN MARKOV MODEL SMC -- 242 | -- -- 243 | ------------------------------ 244 | ------------------------------ 245 | 246 | 247 | 248 | -- simple_hmm_init :: MonadInfer m => P m ("init" .== Double) Double 249 | simple_hmm_init = sample #init (nrm 0 10) 250 | 251 | -- simple_hmm_trans :: MonadInfer m => Double -> P m ("x" .== Double .+ "y" .== Double) Double 252 | simple_hmm_trans state = do 253 | newState <- sample #x (nrm state 1) 254 | noisyObs <- sample #y (nrm newState 0.2) 255 | return newState 256 | 257 | hmm_steps = [#y .== 3.2, #y .== 5.1, #y .== 6.8, #y .== 8.9, #y .== 10.3] 258 | 259 | -- normalNormalPosterior :: MonadInfer m => (Double, Double) -> Double -> Double -> D m Double 260 | normalNormalPosterior (mu1, sigma1) sigma2 obs = 261 | let var = 1 / ((1 / sigma1^2) + (1 / sigma2^2)) 262 | mu = var * (mu1 / sigma1^2 + obs / sigma2^2) 263 | in nrm mu (Exp (log var / 2)) 264 | 265 | -- hmm_proposal :: MonadInfer m => Rec ("y" .== Double) -> Double -> P m ("x" .== Double) Double 266 | hmm_proposal o x = sample #x $ normalNormalPosterior (x, 1) 0.2 (o .! 
#y) 267 | 268 | -- hmm_proposals :: MonadInfer m => [(Rec ("y" .== Double), Double -> P m ("x" .== Double) Double)] 269 | hmm_proposals = fmap (\o -> (o, hmm_proposal o)) hmm_steps 270 | 271 | -- hmm_init_proposal :: MonadInfer m => P m ("init" .== Double) Double 272 | hmm_init_proposal = sample #init (nrm 0 10) 273 | 274 | -- customInference :: Int -> IO (Log Double) 275 | customInference n = sampleIO $ evidence $ smcMultinomial 5 n $ 276 | particleFilter simple_hmm_init hmm_init_proposal simple_hmm_trans hmm_proposals 277 | 278 | 279 | 280 | -------------------------------------------------------------------------------- /src/TraceTypes.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds, NamedFieldPuns, GADTs, KindSignatures, TypeFamilies, TypeOperators, TypeApplications, ScopedTypeVariables, MultiParamTypeClasses, OverloadedLabels, InstanceSigs, FlexibleContexts, AllowAmbiguousTypes, FlexibleInstances, RankNTypes, UndecidableInstances, ConstraintKinds, OverlappingInstances, TypeFamilyDependencies #-} 2 | 3 | module TraceTypes where 4 | 5 | import Data.Row.Records hiding ( Disjoint 6 | , sequence 7 | , map 8 | , zip 9 | ) 10 | import Data.Row.Internal ( Subset 11 | , type (.\\) 12 | , type (.\/) 13 | ) 14 | import Data.Default 15 | import Numeric.Log as Log 16 | import qualified Data.Vector.Sized as V 17 | import Data.Vector.Sized ( Vector ) 18 | import GHC.OverloadedLabels 19 | import Data.Finite 20 | import GHC.TypeLits 21 | import Control.Monad 22 | import Control.Monad.Bayes.Class as Bayes 23 | import Statistics.Distribution (logDensity, complCumulative) 24 | import Statistics.Distribution.Gamma 25 | import Statistics.Distribution.Beta 26 | import Statistics.Distribution.Normal 27 | 28 | -- The "Constraint" type is intended only 29 | -- for internal use, and is not user-facing. 30 | -- A Constraint of a certain record type is 31 | -- a record where each field type `a` has been 32 | -- replaced with the type `Maybe a`; a value of 33 | -- `Nothing` indicates an unconstrained value, 34 | -- whereas `Just x` indicates an observed (i.e. 35 | -- constrained) value. 36 | type Constraint r = Rec (Map Maybe r) 37 | 38 | -- (ConcatTo l r s) means that l, r, and s are all 39 | -- valid record types, and that l .+ r = r .+ l = s. 40 | type ConcatTo l r s 41 | = ( WellBehaved l 42 | , WellBehaved r 43 | , WellBehaved s 44 | , Subset l s 45 | , Subset r s 46 | , (l .+ r) ≈ s 47 | , (r .+ l) ≈ s 48 | , s .\\ l ≈ r 49 | , s .\\ r ≈ l 50 | ) 51 | 52 | -- Are the label sets of l and r disjoint? 53 | type Disjoint l r 54 | = ( ConcatTo l r (l .+ r) 55 | , ConcatTo (Map Maybe l) (Map Maybe r) (Map Maybe (l .+ r)) 56 | , Forall (Map Maybe l) Default 57 | , Forall (Map Maybe r) Default 58 | ) 59 | 60 | -- Three main types: probability distributions (D), unnormalized density-carrying measures (U), 61 | -- and probabilistic programs (P). 
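-- A value of type `P m t a` packages three views of a probabilistic program
-- over traces of record type `t` with return type `a`:
--   * `traced`    : its density-carrying distribution over full traces,
--   * `retval`    : the return value computed from a trace,
--   * `constrain` : an unnormalized measure over traces that agree with a
--                   `Constraint t`, i.e. with observations at some addresses.
-- `D` pairs a sampler with a (normalized) density; `U` has the same shape,
-- but its density need not be normalized.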
62 | data U m a = U 63 | { usampler :: m a 64 | , udensity :: a -> Log Double 65 | } 66 | data D m a = D 67 | { dsampler :: m a 68 | , ddensity :: a -> Log Double 69 | } 70 | data P m t a = P 71 | { traced :: D m (Rec t) 72 | , retval :: Rec t -> a 73 | , constrain :: Constraint t -> U m (Rec t) 74 | } 75 | 76 | class DensityCarrying f m where 77 | sampler :: f m a -> m a 78 | density :: f m a -> a -> Log Double 79 | 80 | 81 | instance DensityCarrying U m where 82 | sampler = usampler 83 | density = udensity 84 | 85 | instance DensityCarrying D m where 86 | sampler = dsampler 87 | density = ddensity 88 | 89 | dist sampler density = D { dsampler = sampler, ddensity = density } 90 | udist sampler density = U { usampler = sampler, udensity = density } 91 | 92 | -- DEFAULT INFERENCE 93 | observe 94 | :: forall m a t s 95 | . (Disjoint t s, Monad m) 96 | => P m (t .+ s) a 97 | -> Rec t 98 | -> U m (Rec s) 99 | observe p obs = 100 | let 101 | -- Turn the observation record `obs` of type `t` into a Constraint of 102 | -- type `t .+ s`, by filling in the `s` piece of the record with `Nothing`. 103 | constraints :: Constraint (t .+ s) 104 | constraints = map' Just obs .+ (default' @Default def :: Constraint s) 105 | 106 | -- Obtain a constrained distribution over traces. 107 | u :: U m (Rec (t .+ s)) 108 | u = constrain p constraints 109 | in 110 | -- Modify the distribution u so that it is over only the unobserved values 111 | -- (Rec s), not full-but-constrained traces (Rec (t .+ s)). 112 | udist (fmap restrict $ sampler u) (density u . ((.+) obs)) 113 | 114 | traceSampler = sampler . traced 115 | traceDensity = density . traced 116 | 117 | -- UNIT INTERVAL TYPE 118 | 119 | newtype Unit = Sigmoid {logit :: Double} 120 | 121 | fromUnit :: Unit -> Double 122 | fromUnit (Sigmoid x) = 1.0 / (1.0 + exp (-x)) 123 | 124 | mkUnit :: Double -> Unit 125 | mkUnit u = if 0 < u && u < 1 126 | then Sigmoid (log (u / (1.0 - u))) 127 | else error "Number must be between 0 and 1." 
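-- Illustrative round trip (exact up to floating-point rounding):
--   fromUnit (mkUnit 0.25)  ==  0.25
-- Arguments outside the open interval (0, 1), e.g. `mkUnit 1.5`, hit the
-- error branch above.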
128 | 129 | instance Show Unit where 130 | show x = show (fromUnit x) 131 | 132 | ------------------------------ 133 | -- STPL LANGUAGE CONSTRUCTS -- 134 | ------------------------------ 135 | 136 | -- RETURN 137 | 138 | pReturn :: Monad m => a -> P m Empty a 139 | pReturn x = 140 | let sampler = return Data.Row.Records.empty 141 | density _ = 1 142 | traced = dist sampler density 143 | constrain _ = udist sampler density 144 | retval _ = x 145 | in P { traced, retval, constrain } 146 | 147 | 148 | -- BIND 149 | 150 | pBind :: (Disjoint t s, Monad m) => P m t a -> (a -> P m s b) -> P m (t .+ s) b 151 | pBind f g = 152 | let 153 | bindSampler = do 154 | fTrace <- sampler $ traced f 155 | let x = retval f fTrace 156 | gTrace <- sampler $ traced (g x) 157 | return (fTrace .+ gTrace) 158 | 159 | bindDensity tr = 160 | let fDensity = density (traced f) (restrict tr) 161 | x = retval f (restrict tr) 162 | gDensity = density (traced $ g x) (restrict tr) 163 | in fDensity * gDensity 164 | 165 | bindTraced = dist bindSampler bindDensity 166 | 167 | bindRet tr = let x = retval f (restrict tr) in retval (g x) (restrict tr) 168 | 169 | constrainedSampler c = do 170 | fTrace <- sampler $ constrain f (restrict c) 171 | let x = retval f fTrace 172 | gTrace <- sampler $ constrain (g x) (restrict c) 173 | return (fTrace .+ gTrace) 174 | 175 | constrainedDensity c tr = 176 | let fDensity = density (constrain f (restrict c)) (restrict tr) 177 | x = retval f (restrict tr) 178 | gDensity = density (constrain (g x) (restrict c)) (restrict tr) 179 | in fDensity * gDensity 180 | 181 | bindConstrain c = udist (constrainedSampler c) (constrainedDensity c) 182 | in 183 | P { traced = bindTraced, retval = bindRet, constrain = bindConstrain } 184 | 185 | 186 | -- SAMPLE 187 | 188 | singleLabelConstrain 189 | :: (KnownSymbol l, MonadInfer m) 190 | => Label l 191 | -> D m (Rec (l .== a)) 192 | -> Constraint (l .== a) 193 | -> U m (Rec (l .== a)) 194 | singleLabelConstrain lbl traced c = case c .! lbl of 195 | Just x -> 196 | let t = lbl .== x 197 | constrainedSampler = do 198 | score $ density traced t 199 | return t 200 | constrainedDensity _ = density traced t 201 | in udist constrainedSampler constrainedDensity 202 | 203 | Nothing -> udist (sampler traced) (density traced) 204 | 205 | 206 | sample :: (KnownSymbol l, MonadInfer m) => Label l -> D m a -> P m (l .== a) a 207 | sample lbl d = 208 | let sampleTraced = dist 209 | (do 210 | x <- sampler d 211 | return (lbl .== x) 212 | ) 213 | (\t -> density d (t .! lbl)) 214 | 215 | sampleRet t = t .! lbl 216 | 217 | sampleConstrain = singleLabelConstrain lbl sampleTraced 218 | in P { traced = sampleTraced 219 | , retval = sampleRet 220 | , constrain = sampleConstrain 221 | } 222 | 223 | 224 | -- WITH PROBABILITY 225 | 226 | withProbability 227 | :: (KnownSymbol l, MonadInfer m) 228 | => Label l 229 | -> Unit 230 | -> P m t a 231 | -> P m s a 232 | -> P m (l .== Either (Rec t) (Rec s)) a 233 | withProbability lbl logit f g = 234 | let p = fromUnit logit 235 | 236 | wpSampler = do 237 | r <- random 238 | if r < p 239 | then do 240 | t <- sampler $ traced f 241 | return (lbl .== Left t) 242 | else do 243 | t <- sampler $ traced g 244 | return (lbl .== Right t) 245 | 246 | wpDensity tr = case tr .! lbl of 247 | Left t -> Exp (log p) * density (traced f) t 248 | Right t -> Exp (log (1.0 - p)) * density (traced g) t 249 | 250 | wpTraced = dist wpSampler wpDensity 251 | 252 | wpRet t = case t .! 
lbl of 253 | Left t -> retval f t 254 | Right t -> retval g t 255 | 256 | wpConstrain = singleLabelConstrain lbl wpTraced 257 | in P { traced = wpTraced, retval = wpRet, constrain = wpConstrain } 258 | 259 | -- FOR EACH 260 | 261 | forEach 262 | :: (KnownSymbol l, KnownNat n, MonadInfer m) 263 | => Label l 264 | -> Vector n a 265 | -> (a -> P m t b) 266 | -> P m (l .== Vector n (Rec t)) (Vector n b) 267 | forEach lbl xs body = 268 | let 269 | forEachSampler = do 270 | traces <- V.mapM (sampler . traced . body) xs 271 | return (lbl .== traces) 272 | 273 | forEachDensity t = V.product 274 | (V.map (\(x, tr) -> density (traced $ body x) tr) (V.zip xs (t .! lbl))) 275 | 276 | forEachTraced = dist forEachSampler forEachDensity 277 | 278 | forEachRet t = V.map (\(x, tr) -> retval (body x) tr) (V.zip xs (t .! lbl)) 279 | 280 | forEachConstrain = singleLabelConstrain lbl forEachTraced 281 | in 282 | P { traced = forEachTraced 283 | , retval = forEachRet 284 | , constrain = forEachConstrain 285 | } 286 | 287 | 288 | -- FOR ... IN RANDOM RANGE ... 289 | 290 | forRandomRange 291 | :: (KnownSymbol l, MonadInfer m) 292 | => Label l 293 | -> D m Int 294 | -> (Int -> P m t a) 295 | -> P m (l .== [Rec t]) [a] 296 | forRandomRange lbl d body = 297 | let 298 | -- Sample a trace of the for loop. 299 | forSampler = do 300 | n <- sampler d 301 | traces <- sequence $ map (sampler . traced . body) [0 .. n - 1] 302 | return (lbl .== traces) 303 | 304 | -- Compute the density of a trace of the for loop. 305 | forDensity t = 306 | density d (length $ t .! lbl) 307 | * product 308 | (map (\(tr, i) -> density (traced $ body i) tr) 309 | (zip (t .! lbl) [0 ..]) 310 | ) 311 | 312 | -- The for loop's density-carrying distribution over traces 313 | forTraced = dist forSampler forDensity 314 | 315 | -- Compute the return value of the for loop, given a trace of an execution. 316 | -- The return value is just a list of the return values from each iteration. 317 | forRet t = map (\(tr, i) -> retval (body i) tr) (zip (t .! lbl) [0 ..]) 318 | 319 | -- Constrain the random choices made inside the for loop. 320 | forConstrain = singleLabelConstrain lbl forTraced 321 | in 322 | P { traced = forTraced, retval = forRet, constrain = forConstrain } 323 | 324 | {- WHILE: 325 | 326 | A stochastic while loop. 327 | Takes as input: 328 | - lbl : The label at which to trace the while loop's random choices. 329 | - init : The initial state of the while loop. 330 | - pi : A function computing a probability of continuing, based on the current state. 331 | - pi_max : An upper bound beyond which the probability of continuing is truncated. 332 | - body : The body of the while loop, which simulates a new state based on the current state. 333 | 334 | The resulting probabilistic program has type P m (l .== [Rec t]) a: 335 | - The trace contains, at the specified label, a list of subtraces, one for each iteration. 336 | - The return value is the final state at the finish of the while loop. 337 | 338 | -} 339 | 340 | while 341 | :: (KnownSymbol l, MonadInfer m) 342 | => Label l 343 | -> a 344 | -> (a -> Unit) 345 | -> Unit 346 | -> (a -> P m t a) 347 | -> P m (l .== [Rec t]) a 348 | while lbl init pi pi_max body = 349 | let 350 | -- Sample a trace of the while loop. 
351 | whileSampler state = do 352 | shouldContinue <- bernoulli $ min (fromUnit (pi state)) (fromUnit pi_max) 353 | if shouldContinue 354 | then do 355 | nextTrace <- sampler (traced $ body state) 356 | let nextState = retval (body state) nextTrace 357 | restOfTraces <- whileSampler nextState 358 | return (lbl .== (nextTrace : (restOfTraces .! lbl))) 359 | else return (lbl .== []) 360 | 361 | -- Given traces from each iteration, and an initial state, 362 | -- compute the final state. 363 | retvalFromTraces init ts = case ts of 364 | [] -> init 365 | t : ts -> retvalFromTraces (retval (body init) t) ts 366 | 367 | -- Given a trace of the entire while loop, compute the final 368 | -- state. 369 | whileRet t = retvalFromTraces init (t .! lbl) 370 | 371 | -- Given an initial state and traces from each iteration, 372 | -- compute the density. 373 | densityFromTraces init ts = case ts of 374 | [] -> Exp (log (1.0 - (fromUnit (pi init)))) 375 | t : ts -> 376 | Exp (log (fromUnit (pi init))) 377 | * (density (traced $ body init) t) 378 | * densityFromTraces (retval (body init) t) ts 379 | 380 | -- Given a trace of the entire while loop, compute the density. 381 | whileDensity t = densityFromTraces init (t .! lbl) 382 | 383 | -- The while loop's density-carrying distribution over traces. 384 | whileTraced = dist (whileSampler init) whileDensity 385 | 386 | -- Constrain the choices made inside the loop. 387 | whileConstrain = singleLabelConstrain lbl whileTraced 388 | 389 | in P { traced = whileTraced, retval = whileRet, constrain = whileConstrain } 390 | 391 | ------------------- 392 | -- DISTRIBUTIONS -- 393 | ------------------- 394 | 395 | nrm :: MonadInfer m => Double -> Log Double -> D m Double 396 | nrm mu (Exp sigln) = dist (normal mu (exp sigln)) (normalPdf mu (exp sigln)) 397 | 398 | brn :: MonadSample m => Unit -> D m Bool 399 | brn logit = 400 | let p = fromUnit logit 401 | in dist (bernoulli p) (\b -> Exp . log $ if b then p else (1.0 - p)) 402 | 403 | gmma :: MonadSample m => Log Double -> Log Double -> D m (Log Double) 404 | gmma shape scale = 405 | let toDouble (Exp l) = exp l 406 | toLogDouble d = Exp (log d) 407 | in 408 | dist (fmap toLogDouble (gamma (toDouble shape) (toDouble scale))) 409 | (Exp . (logDensity $ Statistics.Distribution.Gamma.gammaDistr 410 | (toDouble shape) (toDouble scale)) 411 | . toDouble) 412 | 413 | geom :: MonadSample m => Unit -> D m Int 414 | geom p' = 415 | let p = fromUnit p' 416 | in dist (geometric p) 417 | (\n -> Exp (fromIntegral n * (log p) + (log (1.0 - p)))) 418 | 419 | pois :: MonadSample m => Int -> D m Int 420 | pois = undefined 421 | 422 | bta :: MonadSample m => Log Double -> Log Double -> D m Unit 423 | bta (Exp lna) (Exp lnb) = 424 | let 425 | a = exp lna 426 | b = exp lnb 427 | in 428 | dist (fmap mkUnit (beta a b)) 429 | (Exp . (logDensity $ Statistics.Distribution.Beta.betaDistr a b) 430 | . fromUnit) 431 | 432 | unif :: MonadSample m => D m Unit 433 | unif = dist (fmap mkUnit Bayes.random) (\x -> 1.0 :: Log Double) 434 | 435 | cat :: (KnownNat n, MonadSample m) => Vector n Unit -> D m (Finite n) 436 | cat probs = undefined 437 | 438 | lognormal :: forall m. 
MonadInfer m => Log Double -> Log Double -> D m (Log Double) 439 | lognormal (Exp logmu) sigma = 440 | dist (fmap Exp $ sampler $ nrm logmu sigma) 441 | (\(Exp logx) -> density (nrm logmu sigma :: D m Double) logx / (Exp logx)) 442 | 443 | 444 | truncnrm :: MonadInfer m => Log Double -> Log Double -> D m (Log Double) 445 | truncnrm (Exp logmu) (Exp logsigma) = 446 | let mu = exp logmu 447 | sigma = exp logsigma 448 | unbounded = nrm mu (Exp logsigma) 449 | 450 | truncNormSampler = do 451 | x <- sampler unbounded 452 | if x < 0 then truncNormSampler else return (Exp $ log x) 453 | 454 | truncNormPdf (Exp logy) = 455 | let y = exp logy 456 | z = complCumulative (normalDistr mu sigma) 0 457 | in (density unbounded y / (Exp (log z))) 458 | in dist truncNormSampler truncNormPdf 459 | 460 | 461 | 462 | ---------------------------- 463 | -- PROGRAMMABLE INFERENCE -- 464 | ---------------------------- 465 | 466 | -- IMPORTANCE SAMPLING 467 | 468 | importance 469 | :: (Disjoint t s, MonadInfer m) 470 | => P m (t .+ s) a 471 | -> Rec t 472 | -> P m s b 473 | -> m (Rec s) 474 | importance p t q = 475 | let target = observe p t 476 | in do 477 | tr <- sampler $ traced q 478 | score ((density target tr) / (density (traced q) tr)) 479 | return tr 480 | 481 | -- SMC 482 | 483 | unroll 484 | :: (MonadInfer m, Disjoint t s) 485 | => P m i a 486 | -> (a -> P m (t .+ s) a) 487 | -> [Rec t] 488 | -> m (Rec i, [Rec s]) 489 | unroll init next steps = do 490 | tInit <- sampler $ traced init 491 | let processStep soFar t = do { 492 | (s, results) <- soFar; 493 | tNext <- sampler $ observe (next s) t; 494 | return (retval (next s) (tNext .+ t), tNext : results) 495 | } 496 | (s, l) <- foldl processStep (return (retval init tInit, [])) steps 497 | return (tInit, reverse l) 498 | 499 | particleFilter 500 | :: (MonadInfer m, Disjoint t s, 501 | Disjoint Empty i, (Empty .+ i) ≈ i) 502 | => P m i a 503 | -> P m i b 504 | -> (a -> P m (t .+ s) a) 505 | -> [(Rec t, a -> P m s c)] 506 | -> m (Rec i, [Rec s]) 507 | particleFilter init qInit next steps = do 508 | tInit <- importance init Data.Row.Records.empty qInit 509 | let processStep soFar (t, q) = do { 510 | (s, results) <- soFar; 511 | tNext <- importance (next s) t (q s); 512 | return (retval (next s) (tNext .+ t), tNext : results) 513 | } 514 | (s, l) <- foldl processStep (return (retval init tInit, [])) steps 515 | return (tInit, reverse l) 516 | 517 | 518 | -- MCMC 519 | 520 | type K m t (s :: Row *) = U m (Rec t) -> Rec t -> m (Rec t) 521 | 522 | mh 523 | :: forall s t m a 524 | . 
(Disjoint t s, MonadInfer m) 525 | => (Rec (t .+ s) -> P m t a) 526 | -> K m (t .+ s) t 527 | mh q p old = 528 | let proposal = do 529 | new <- sampler (traced $ q old) 530 | return $ new .+ (restrict old :: Rec s) 531 | rho (x, y) = 532 | ((density p y) * (density (traced $ q y) (restrict x))) 533 | / ((density p x) * (density (traced $ q x) (restrict y))) 534 | in do 535 | proposed <- proposal 536 | r <- Bayes.random 537 | if Exp (log r) < (min 1 (rho (old, proposed))) 538 | then return proposed 539 | else return old 540 | 541 | seqK :: Monad m => K m t s -> K m t r -> K m t (s .\/ r) 542 | seqK k1 k2 p old = do 543 | t' <- k1 p old 544 | k2 p t' 545 | 546 | mixK :: MonadSample m => Double -> K m t s -> K m t r -> K m t (s .\/ r) 547 | mixK p k1 k2 model old = do 548 | r <- Bayes.random 549 | if r < p then k1 model old else k2 model old 550 | 551 | repeatK :: Monad m => Int -> K m t s -> K m t s 552 | repeatK n k p old = if n == 0 553 | then return old 554 | else do 555 | t <- k p old 556 | repeatK (n - 1) k p t 557 | 558 | ifK 559 | :: (Disjoint t s, Monad m) 560 | => (Rec t -> Bool) 561 | -> K m (t .+ s) s 562 | -> K m (t .+ s) s 563 | ifK b k p old = if (b (restrict old)) then k p old else return old 564 | 565 | -- VARIATIONAL (STUBBED) 566 | 567 | svi 568 | :: (Disjoint t s, Monad m) 569 | => P m (t .+ s) a 570 | -> Rec t 571 | -> (theta -> P m s b) 572 | -> theta 573 | -> m theta 574 | svi p t q theta = undefined 575 | -- do tr <- traceSampler $ q theta 576 | -- let sco = density (observe p t) tr 577 | -- let gradient = grad (\th -> traceDensity (q th) tr) theta 578 | -- return theta + stepSize * sco * gradient 579 | 580 | trainAmortized 581 | :: (Disjoint t s, Monad m) 582 | => P m (t .+ s) a 583 | -> (theta -> Rec t -> P m s b) 584 | -> theta 585 | -> m theta 586 | trainAmortized p q theta = undefined 587 | -- do tr <- traceSampler p 588 | -- let gradient = grad (\th -> traceDensity (q th (restrict tr)) (restrict tr)) theta 589 | -- return theta + step * gradient 590 | 591 | 592 | ifThenElse c t f = if c then t else f 593 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # 15 | # The location of a snapshot can be provided as a file or url. Stack assumes 16 | # a snapshot provided as a file might change, whereas a url resource does not. 17 | # 18 | # resolver: ./custom-snapshot.yaml 19 | # resolver: https://example.com/snapshots/2018-01-01.yaml 20 | resolver: lts-9.0 21 | 22 | # User packages to be built. 23 | # Various formats can be used as shown in the example below. 24 | # 25 | # packages: 26 | # - some-directory 27 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 28 | # subdirs: 29 | # - auto-update 30 | # - wai 31 | packages: 32 | - . 33 | # Dependency packages to be pulled from upstream that are not in the resolver. 
34 | # These entries can reference officially published versions as well as 35 | # forks / in-progress versions pinned to a git hash. For example: 36 | # 37 | # extra-deps: 38 | # - acme-missiles-0.3 39 | # - git: https://github.com/commercialhaskell/stack.git 40 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 41 | # 42 | # extra-deps: [] 43 | extra-deps: 44 | - github: adscib/monad-bayes 45 | commit: f55d9fa9d24d169d53bb03598306ee8c46b5fc11 46 | - row-types-0.3.0.0@sha256:a2d26ef268d38e76b82504e564ba771c0d03793e5452d05660e60803ae45d086,3283 47 | - generic-lens-1.2.0.1@sha256:faecffd513340cd570866aa8f3e0e139d3f9240ecd3eb766357cdf47f550004c,6550 48 | - unconstrained-0.1.0.2@sha256:e3be5ea44350c5ce9d5f54d0f42c9b5b0639eb0343e9bf7802fb579788a2507c,645 49 | 50 | # Override default flag values for local packages and extra-deps 51 | # flags: {} 52 | 53 | # Extra package databases containing global packages 54 | # extra-package-dbs: [] 55 | 56 | # Control whether we use the GHC we find on the path 57 | # system-ghc: true 58 | # 59 | # Require a specific version of stack, using version ranges 60 | # require-stack-version: -any # Default 61 | # require-stack-version: ">=2.1" 62 | # 63 | # Override the architecture used by stack, especially useful on Windows 64 | # arch: i386 65 | # arch: x86_64 66 | # 67 | # Extra directories used by stack for building 68 | # extra-include-dirs: [/path/to/dir] 69 | # extra-lib-dirs: [/path/to/dir] 70 | # 71 | # Allow a newer minor version of GHC than the snapshot specifies 72 | # compiler-check: newer-minor 73 | -------------------------------------------------------------------------------- /trace-types.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.31.2. 
4 | -- 5 | -- see: https://github.com/sol/hpack 6 | -- 7 | -- hash: 673a08b8cd5a2cf3bd909493ef897f91a9f716b904300f884748452a463c34d7 8 | 9 | name: trace-types 10 | version: 0.1.0.0 11 | description: Please see the README on GitHub at 12 | homepage: https://github.com/githubuser/trace-types#readme 13 | bug-reports: https://github.com/githubuser/trace-types/issues 14 | author: Author name here 15 | maintainer: example@example.com 16 | copyright: 2019 Author name here 17 | license: BSD3 18 | license-file: LICENSE 19 | build-type: Simple 20 | extra-source-files: 21 | README.md 22 | 23 | source-repository head 24 | type: git 25 | location: https://github.com/githubuser/trace-types 26 | 27 | library 28 | exposed-modules: 29 | BasicExamples 30 | TraceTypes 31 | other-modules: 32 | Paths_trace_types 33 | hs-source-dirs: 34 | src 35 | build-depends: 36 | base >=4.7 && <5 37 | , constraints 38 | , containers 39 | , data-default 40 | , deepseq 41 | , finite-typelits 42 | , hashable 43 | , log-domain 44 | , monad-bayes 45 | , random 46 | , row-types 47 | , statistics 48 | , text 49 | , unconstrained 50 | , unordered-containers 51 | , vector-sized 52 | default-language: Haskell2010 53 | 54 | executable trace-types-exe 55 | main-is: Main.hs 56 | other-modules: 57 | Paths_trace_types 58 | hs-source-dirs: 59 | app 60 | ghc-options: -threaded -rtsopts -with-rtsopts=-N 61 | build-depends: 62 | base >=4.7 && <5 63 | , constraints 64 | , containers 65 | , data-default 66 | , deepseq 67 | , finite-typelits 68 | , hashable 69 | , log-domain 70 | , monad-bayes 71 | , random 72 | , row-types 73 | , statistics 74 | , text 75 | , trace-types 76 | , unconstrained 77 | , unordered-containers 78 | , vector-sized 79 | default-language: Haskell2010 80 | 81 | --------------------------------------------------------------------------------