├── Setup.hs ├── .gitignore ├── README.md ├── app └── Main.hs ├── LICENSE ├── package.yaml ├── haskell-trace-types.cabal ├── stack.yaml └── src ├── BasicExamples.hs └── TraceTypes.hs /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple (defaultMain) 2 | 3 | -- | Build driver for a package with build-type: Simple; delegates entirely to Cabal. 4 | main :: IO () 5 | main = defaultMain 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | dist-* 3 | cabal-dev 4 | *.o 5 | *.hi 6 | *.hie 7 | *.chi 8 | *.chs.h 9 | *.dyn_o 10 | *.dyn_hi 11 | .hpc 12 | .hsenv 13 | .cabal-sandbox/ 14 | cabal.sandbox.config 15 | *.prof 16 | *.aux 17 | *.hp 18 | *.eventlog 19 | .stack-work/ 20 | cabal.project.local 21 | cabal.project.local~ 22 | .HTF/ 23 | .ghc.environment.* 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # haskell-trace-types 2 | 3 | Haskell prototype of the system described in [Trace Types and Denotational Semantics for Sound Programmable Inference in Probabilistic Languages](https://dl.acm.org/doi/10.1145/3371087). 4 | 5 | Run `stack ghci` to start an interactive prompt 6 | with definitions from `TraceTypes` and `BasicExamples` 7 | available.
8 | 9 | -------------------------------------------------------------------------------- /app/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators, TypeApplications, ScopedTypeVariables, MultiParamTypeClasses, OverloadedLabels, InstanceSigs, FlexibleContexts, AllowAmbiguousTypes, FlexibleInstances, RankNTypes, UndecidableInstances, ConstraintKinds, TypeFamilyDependencies #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE PartialTypeSignatures #-} 5 | {-# LANGUAGE RecordWildCards, Rank2Types, ConstraintKinds #-} 6 | 7 | -- | Minimal executable for the trace-types prototype: samples one trace of a 8 | -- program with a single Bernoulli(0.4) choice labelled @b@ and prints it. 9 | module Main where 10 | 11 | import TraceTypes 12 | import Control.Monad.Bayes.Sampler.Lazy (runSamplerTIO) 13 | import Control.Monad.Bayes.Weighted (runWeightedT) 14 | import Data.Row.Records 15 | import GHC.OverloadedLabels (fromLabel) 16 | 17 | -- | Run a weighted sampler in IO, returning the sampled value paired with 18 | -- its accumulated weight. 19 | run = runSamplerTIO . runWeightedT 20 | 21 | -- | Draw one trace and print it; the weight returned by 'run' is not needed. 22 | main :: IO () 23 | main = do (b, _) <- run . traceSampler $ sample #b (brn $ mkUnit 0.4) 24 | print b 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 MIT Probabilistic Computing Project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: haskell-trace-types 2 | version: 0.1.0.0 3 | github: "probcomp/haskell-trace-types" 4 | license: MIT 5 | author: "Alex Lew" 6 | maintainer: "alexlew@mit.edu" 7 | copyright: "2019-2025 Alex Lew" 8 | 9 | extra-source-files: 10 | - README.md 11 | 12 | # To avoid duplicated efforts in documentation and dealing with the 13 | # complications of embedding Haddock markup inside cabal files, it is 14 | # common to point users to the README.md file. 
15 | description: Please see the README on GitHub at <https://github.com/probcomp/haskell-trace-types#readme> 16 | 17 | dependencies: 18 | - base >= 4.7 && < 5 19 | - unconstrained 20 | - text 21 | - hashable 22 | - vector-sized 23 | - unordered-containers 24 | - constraints 25 | - row-types 26 | - deepseq 27 | - random 28 | - containers 29 | - monad-bayes 30 | - log-domain 31 | - finite-typelits 32 | - data-default 33 | - statistics 34 | 35 | library: 36 | source-dirs: src 37 | 38 | executables: 39 | trace-types-exe: 40 | main: Main.hs 41 | source-dirs: app 42 | ghc-options: 43 | - -threaded 44 | - -rtsopts 45 | - -with-rtsopts=-N 46 | dependencies: 47 | - haskell-trace-types 48 | 49 | -------------------------------------------------------------------------------- /haskell-trace-types.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.38.0. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: haskell-trace-types 8 | version: 0.1.0.0 9 | description: Please see the README on GitHub at <https://github.com/probcomp/haskell-trace-types#readme> 10 | homepage: https://github.com/probcomp/haskell-trace-types#readme 11 | bug-reports: https://github.com/probcomp/haskell-trace-types/issues 12 | author: Alex Lew 13 | maintainer: alexlew@mit.edu 14 | copyright: 2019-2025 Alex Lew 15 | license: MIT 16 | license-file: LICENSE 17 | build-type: Simple 18 | extra-source-files: 19 | README.md 20 | 21 | source-repository head 22 | type: git 23 | location: https://github.com/probcomp/haskell-trace-types 24 | 25 | library 26 | exposed-modules: 27 | BasicExamples 28 | TraceTypes 29 | other-modules: 30 | Paths_haskell_trace_types 31 | hs-source-dirs: 32 | src 33 | build-depends: 34 | base >=4.7 && <5 35 | , constraints 36 | , containers 37 | , data-default 38 | , deepseq 39 | , finite-typelits 40 | , hashable 41 | , log-domain 42 | , monad-bayes 43 | , random 44 | , row-types 45 | , statistics 46 | , text 47 | , unconstrained 48 | , unordered-containers
49 | , vector-sized 50 | default-language: Haskell2010 51 | 52 | executable trace-types-exe 53 | main-is: Main.hs 54 | other-modules: 55 | Paths_haskell_trace_types 56 | hs-source-dirs: 57 | app 58 | ghc-options: -threaded -rtsopts -with-rtsopts=-N 59 | build-depends: 60 | base >=4.7 && <5 61 | , constraints 62 | , containers 63 | , data-default 64 | , deepseq 65 | , finite-typelits 66 | , hashable 67 | , haskell-trace-types 68 | , log-domain 69 | , monad-bayes 70 | , random 71 | , row-types 72 | , statistics 73 | , text 74 | , unconstrained 75 | , unordered-containers 76 | , vector-sized 77 | default-language: Haskell2010 78 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # 15 | # The location of a snapshot can be provided as a file or url. Stack assumes 16 | # a snapshot provided as a file might change, whereas a url resource does not. 17 | # 18 | # resolver: ./custom-snapshot.yaml 19 | # resolver: https://example.com/snapshots/2018-01-01.yaml 20 | resolver: lts-23.27 21 | 22 | # User packages to be built. 23 | # Various formats can be used as shown in the example below. 
-- [dump] stack.yaml lines 24-70: `packages: [.]`, the extra-deps that pin
-- monad-bayes-1.3.0.4 and brick-2.5 by hash, and the commented-out stack
-- options. The physical line ends with the opening of the first LANGUAGE
-- pragma of src/BasicExamples.hs.
24 | # 25 | # packages: 26 | # - some-directory 27 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 28 | # subdirs: 29 | # - auto-update 30 | # - wai 31 | packages: 32 | - . 33 | # Dependency packages to be pulled from upstream that are not in the resolver. 34 | # These entries can reference officially published versions as well as 35 | # forks / in-progress versions pinned to a git hash. For example: 36 | # 37 | # extra-deps: 38 | # - acme-missiles-0.3 39 | # - git: https://github.com/commercialhaskell/stack.git 40 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 41 | # 42 | # extra-deps: [] 43 | extra-deps: 44 | - monad-bayes-1.3.0.4@sha256:101a60697c2bf0fae60157284bb8127e6a84e37c82c562cd7a262b50a34f3b5d,6650 45 | - brick-2.5@sha256:f80444d8009883013a2dac934618d86c104f4d68c44171c7aab519418d28aa77,17632 46 | 47 | # Override default flag values for local packages and extra-deps 48 | # flags: {} 49 | 50 | # Extra package databases containing global packages 51 | # extra-package-dbs: [] 52 | 53 | # Control whether we use the GHC we find on the path 54 | # system-ghc: true 55 | # 56 | # Require a specific version of stack, using version ranges 57 | # require-stack-version: -any # Default 58 | # require-stack-version: ">=2.1" 59 | # 60 | # Override the architecture used by stack, especially useful on Windows 61 | # arch: i386 62 | # arch: x86_64 63 | # 64 | # Extra directories used by stack for building 65 | # extra-include-dirs: [/path/to/dir] 66 | # extra-lib-dirs: [/path/to/dir] 67 | # 68 | # Allow a newer minor version of GHC than the snapshot specifies 69 | # compiler-check: newer-minor 70 | -------------------------------------------------------------------------------- /src/BasicExamples.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators, TypeApplications, ScopedTypeVariables, MultiParamTypeClasses, OverloadedLabels, InstanceSigs,
-- [dump] src/BasicExamples.hs lines 1-59: pragmas (including RebindableSyntax
-- and NoMonomorphismRestriction), imports, and local definitions of >>=, >>
-- and return in terms of pBind/pReturn. Every do-block below therefore builds
-- a trace-typed program `P m t a` rather than an ordinary monadic action.
-- Then come the Figure 1 model `weight_model` (#weight ~ gmma 2 1,
-- #measurement ~ nrm weight 0.2), the observation `obs` (#measurement = 0.5),
-- and the first importance-sampling proposal `q1`.
FlexibleContexts, FlexibleInstances, RankNTypes, UndecidableInstances, ConstraintKinds, TypeFamilyDependencies #-} 2 | {-# LANGUAGE RebindableSyntax #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE PartialTypeSignatures, NoMonomorphismRestriction #-} 6 | {-# LANGUAGE RecordWildCards, Rank2Types, ConstraintKinds #-} 7 | 8 | module BasicExamples where 9 | 10 | import TraceTypes 11 | import Control.Monad.Bayes.Class 12 | import Control.Monad.Bayes.Weighted (runWeightedT) 13 | import Control.Monad.Bayes.Sampler.Lazy (runSamplerTIO) 14 | import Data.Row.Records (Rec, (.==), type (.==), (.!), (.+), type (.+), Empty) 15 | import Data.Row (Label) 16 | import GHC.OverloadedLabels (fromLabel) 17 | import Data.Row.Internal (LT(..), Row(..)) 18 | import qualified Data.Vector.Sized as V 19 | import Data.Vector.Sized (Vector) 20 | import Prelude hiding ((>>=), (>>), return) 21 | import Numeric.Log 22 | import Control.Monad.Bayes.Inference.SMC (smcPush, SMCConfig(..)) 23 | import Control.Monad.Bayes.Population (resampleMultinomial, evidence) 24 | 25 | 26 | f >>= g = pBind f g 27 | 28 | return :: Monad m => a -> P m Empty a 29 | return = pReturn 30 | 31 | f >> g = f >>= \_ -> g 32 | 33 | ------------------------------- 34 | ------------------------------- 35 | -- -- 36 | -- EXAMPLES FROM FIGURE 1 -- 37 | -- Valid & Invalid Inference -- 38 | -- -- 39 | ------------------------------- 40 | ------------------------------- 41 | 42 | 43 | -- weight_model :: MonadMeasure m => P m ("weight" .== Log Double .+ "measurement" .== Double) Double 44 | weight_model = do 45 | w <- sample #weight (gmma 2 1) 46 | sample #measurement (nrm (realToFrac w) 0.2) 47 | 48 | obs = #measurement .== 0.5 49 | 50 | 51 | -- Importance Sampling 52 | 53 | -- q1 :: MonadMeasure m => P m ("weight" .== Unit) Unit 54 | q1 = sample #weight unif 55 | 56 | -- appliedQ1 :: Type Error 57 | -- appliedQ1 = importance weight_model obs q1 58 | 59 | -- q1' :: MonadMeasure m => P m
-- [dump] lines 59-114 cover the rest of the Figure 1 examples and open Figure 3.
--   * `q1'` is a gmma 2 0.25 proposal, used by `appliedQ1'`.
--   * `k1`/`k2` are kernels built with `mh` from truncated-normal proposals of
--     width 0.2 and 1.0.
--   * `k'` chooses the width from the current weight inside a single kernel.
--   * `q2`/`q2'` are the SVI guides.
-- Entries whose commented signature reads `Type Error` (appliedQ1, k,
-- appliedQ2) are kept commented out. NOTE(review): presumably these are the
-- paper's invalid-inference examples that trace types reject; confirm against
-- Figure 1.
("weight" .== Log Double) (Log Double) 60 | q1' = sample #weight (gmma 2 0.25) 61 | 62 | -- appliedQ1' :: MonadMeasure m => m (Rec ("weight" .== Log Double)) 63 | appliedQ1' = importance weight_model obs q1' 64 | 65 | -- MCMC 66 | 67 | -- k1 :: MonadMeasure m => K m ("weight" .== Log Double) ("weight" .== Log Double) 68 | k1 = mh (\t -> sample #weight (truncnrm (t .! #weight) 0.2)) 69 | 70 | -- k2 :: MonadMeasure m => K m ("weight" .== Log Double) ("weight" .== Log Double) 71 | k2 = mh (\t -> sample #weight (truncnrm (t .! #weight) 1.0)) 72 | 73 | -- shouldUseK1 :: Rec ("weight" .== Log Double) -> Bool 74 | shouldUseK1 t = t .! #weight <= 2 75 | 76 | -- k :: Type Error 77 | -- k = seqK (ifK shouldUseK1 k1) 78 | -- (ifK (not . shouldUseK1) k2) 79 | 80 | -- k' :: MonadMeasure m => K m ("weight" .== Log Double) ("weight" .== Log Double) 81 | k' = mh (\t -> let w = t .! #weight in 82 | sample #weight (truncnrm w (if w < 2 then 0.2 else 1))) 83 | 84 | -- appliedK' :: MonadMeasure m => Rec ("weight" .== Log Double) -> m (Rec ("weight" .== Log Double)) 85 | appliedK' = k' (observe weight_model obs) 86 | 87 | -- SVI 88 | 89 | -- q2 :: MonadMeasure m => (Double, Log Double) -> P m ("weight" .== Double) Double 90 | q2 (a, b) = sample #weight (nrm a b) 91 | 92 | -- appliedQ2 :: Type Error 93 | -- appliedQ2 = svi weight_model obs q2 94 | 95 | -- q2' :: MonadMeasure m => (Log Double, Log Double) -> P m ("weight" .== Log Double) (Log Double) 96 | q2' (a, b) = sample #weight (truncnrm a b) 97 | 98 | -- appliedQ2' :: MonadMeasure m => (Log Double, Log Double) -> m (Log Double, Log Double) 99 | appliedQ2' = svi weight_model obs q2' 100 | 101 | 102 | 103 | 104 | ------------------------------ 105 | ------------------------------ 106 | -- -- 107 | -- EXAMPLES FROM FIGURE 3 -- 108 | -- Control Flow Constructs -- 109 | -- -- 110 | ------------------------------ 111 | ------------------------------ 112 | 113 | 114 | -- simple1 :: MonadMeasure m => P m ("x" .== Double .+ "z" .== Double)
-- [dump] lines 114-180: the Figure 3 control-flow examples.
--   * `simple1` and `simple2` sample #x and #z in opposite orders but share
--     one trace type.
--   * `branch1` branches with if/then/else; `branch2` uses withProbability.
--   * `loop1` uses forRandomRange with an iteration count drawn from `pois 3`.
--   * `loop2` uses forEach over a Vector 3.
-- The physical line ends with the Figure 4 inputs `xs`: the points -3.0 .. 3.0
-- as a Vector 7 Double.
Double 115 | simple1 = do 116 | x <- sample #x (nrm 0 1) 117 | z <- sample #z (nrm x 1) 118 | return $ x + z 119 | 120 | -- simple2 :: MonadMeasure m => P m ("x" .== Double .+ "z" .== Double) Double 121 | simple2 = do 122 | z <- sample #z (nrm 0 1) 123 | x <- sample #x (nrm z 1) 124 | return $ x + z 125 | 126 | -- branch1 :: MonadMeasure m => P m ("b" .== Bool .+ "p" .== Unit .+ "coin" .== Bool) Bool 127 | branch1 = do 128 | isBiased <- sample #b (brn $ mkUnit 0.1) 129 | p <- if isBiased then 130 | sample #p (bta 1 2) 131 | else 132 | sample #p (unif) 133 | sample #coin (brn p) 134 | 135 | -- branch2 136 | -- :: MonadMeasure m 137 | -- => P m ("p" .== Either (Rec ("isLow" .== Bool)) (Rec Empty) .+ "coin" .== Bool) Bool 138 | branch2 = do 139 | p <- withProbability #p (mkUnit 0.1) (do { 140 | isLow <- sample #isLow (brn (mkUnit 0.5)); 141 | return (if isLow then (mkUnit 0.01) else (mkUnit 0.99)) 142 | }) (return $ mkUnit 0.5) 143 | sample #coin (brn p) 144 | 145 | -- loop1 :: MonadMeasure m => P m ("pts" .== [Rec ("x" .== Double .+ "y" .== Double)]) [Double] 146 | loop1 = do 147 | pts <- forRandomRange #pts (pois 3) (\i -> do { 148 | x <- sample #x (nrm (fromIntegral i) 1); 149 | sample #y (nrm x 1) 150 | }) 151 | return $ map (\y -> 2 * y) pts 152 | 153 | -- loop2 :: MonadMeasure m => P m ("pts" .== Vector 3 (Rec ("y" .== Double))) [Double] 154 | loop2 = do 155 | pts <- forEach #pts (V.generate fromIntegral :: Vector 3 Int) (\x -> do { 156 | sample #y (nrm (fromIntegral x) 1) 157 | }) 158 | return $ map (\y -> 2 * y) (V.toList pts) 159 | 160 | 161 | 162 | ------------------------------ 163 | ------------------------------ 164 | -- -- 165 | -- EXAMPLES FROM FIGURE 4 -- 166 | -- End-to-End Example -- 167 | -- -- 168 | ------------------------------ 169 | ------------------------------ 170 | 171 | xs = V.generate (\i -> 172 | case i of 173 | 0 -> -3.0 174 | 1 -> -2.0 175 | 2 -> -1.0 176 | 3 -> 0.0 177 | 4 -> 1.0 178 | 5 -> 2.0 179 | 6 -> 3.0) :: Vector 7 Double 180 |
-- [dump] src/BasicExamples.hs lines 181-221:
--   * `ys`: the Figure 4 observations, one #y value for each of the seven `xs`.
--   * `prior`: the regression prior. #noise ~ gmma 1 1, and a geometric number
--     of #coeffs, each #c ~ nrm 0 1, summed into a polynomial f.
--   * `p`: the model. It adds #data, with #y ~ nrm (f x) noise for every x.
--   * The NNResult/QState records.
--   * `nn`: a placeholder network, defined as `undefined`.
--   * The start of the amortized proposal `q`.
181 | 182 | ys = V.generate (\i -> 183 | case i of 184 | 0 -> #y .== 3.1 185 | 1 -> #y .== 1.8 186 | 2 -> #y .== 1.1 187 | 3 -> #y .== (-0.2) 188 | 4 -> #y .== (-1.2) 189 | 5 -> #y .== (-2.1) 190 | 6 -> #y .== (-3.0)) :: Vector 7 (Rec ("y" .== Double)) 191 | 192 | -- prior 193 | -- :: MonadMeasure m 194 | -- => P m ("noise" .== Log Double .+ "coeffs" .== [Rec ("c" .== Double)]) (Double -> Double, Log Double) 195 | prior = do z <- sample #noise (gmma 1 1) 196 | terms <- forRandomRange #coeffs (geom $ mkUnit 0.4) (\i -> do { 197 | coeff <- sample #c (nrm 0 1); 198 | return (\x -> coeff * (x ^ i)) 199 | }) 200 | let f = \x -> foldr (\t y -> y + (t x)) 0 terms 201 | return (f, z) 202 | 203 | -- p :: MonadMeasure m 204 | -- => P m ("noise" .== Log Double .+ "coeffs" .== [Rec ("c" .== Double)] .+ "data" .== Vector 7 (Rec ("y" .== Double))) (Vector 7 Double) 205 | p = do (f, z) <- BasicExamples.prior 206 | forEach #data xs (\x -> sample #y (nrm (f x) z)) 207 | 208 | 209 | data NNResult = NN { nnPi :: Unit, nnMu :: Double, nnSig :: Log Double, nnNoise :: Log Double } 210 | data QState = Q { qN :: Int, qResids :: Vector 7 Double, qNN:: NNResult } 211 | nn = undefined :: Vector 384 Double -> Vector 7 (Double, Double) -> NNResult 212 | 213 | -- q :: MonadMeasure m 214 | -- => (Vector 384 Double, Log Double) 215 | -- -> Rec ("data" .== Vector 7 (Rec ("y" .== Double))) 216 | -- -> P m ("noise" .== Log Double .+ "coeffs" .== [Rec ("c" .== Double)]) (Log Double) 217 | q (theta, sigma) obs = 218 | do let ys = V.map (\d -> d .! #y) (obs .! #data) 219 | let nnInputs n resids = V.zip (V.map (\x -> x^n) xs) resids 220 | let initialState = Q { qN = 0, qResids = ys, qNN = nn theta (nnInputs 0 ys) } 221 | finalState <- while #coeffs initialState (nnPi .
-- [dump] lines 221-267:
--   * The rest of `q`. It runs a `while` loop at #coeffs that continues with
--     the network's nnPi, capped at 0.99. Each iteration samples one #c and
--     subtracts its term from the residuals; afterwards `q` samples #noise
--     from a lognormal.
--   * Commented-out drivers.
--   * The HMM example:
--       - `simple_hmm_init`.
--       - `simple_hmm_trans`, with #x ~ nrm state 1 and #y ~ nrm #x 0.2.
--       - The observations `hmm_steps`.
--       - `normalNormalPosterior`, the conjugate posterior of a normal prior
--         under a normal likelihood with known standard deviations.
-- The physical line ends inside `hmm_proposal`.
qNN) (mkUnit 0.99) (\s -> do { 222 | c <- sample #c $ nrm (nnMu (qNN s)) (nnSig (qNN s)); 223 | let newResids = V.map (\(x, y) -> y - c * (x ^ (qN s))) (V.zip xs (qResids s)) in 224 | return (Q {qN = qN s + 1, qResids = newResids, qNN = nn theta (nnInputs (qN s + 1) newResids)}) 225 | }) 226 | sample #noise (lognormal (nnNoise (qNN finalState)) sigma) 227 | 228 | -- sampleCurve = sampleIO $ runWeightedT $ sampler $ traced p 229 | 230 | -- j = trainAmortized p q 231 | 232 | -- h = importance p (#data .== ys) prior 233 | 234 | 235 | 236 | 237 | 238 | 239 | ------------------------------ 240 | ------------------------------ 241 | -- -- 242 | -- HIDDEN MARKOV MODEL SMC -- 243 | -- -- 244 | ------------------------------ 245 | ------------------------------ 246 | 247 | 248 | 249 | -- simple_hmm_init :: MonadMeasure m => P m ("init" .== Double) Double 250 | simple_hmm_init = sample #init (nrm 0 10) 251 | 252 | -- simple_hmm_trans :: MonadMeasure m => Double -> P m ("x" .== Double .+ "y" .== Double) Double 253 | simple_hmm_trans state = do 254 | newState <- sample #x (nrm state 1) 255 | noisyObs <- sample #y (nrm newState 0.2) 256 | return newState 257 | 258 | hmm_steps = [#y .== 3.2, #y .== 5.1, #y .== 6.8, #y .== 8.9, #y .== 10.3] 259 | 260 | -- normalNormalPosterior :: MonadMeasure m => (Double, Double) -> Double -> Double -> D m Double 261 | normalNormalPosterior (mu1, sigma1) sigma2 obs = 262 | let var = 1 / ((1 / sigma1^2) + (1 / sigma2^2)) 263 | mu = var * (mu1 / sigma1^2 + obs / sigma2^2) 264 | in nrm mu (Exp (log var / 2)) 265 | 266 | -- hmm_proposal :: MonadMeasure m => Rec ("y" .== Double) -> Double -> P m ("x" .== Double) Double 267 | hmm_proposal o x = sample #x $ normalNormalPosterior (x, 1) 0.2 (o .!
-- [dump] lines 267-285:
--   * `hmm_proposal` proposes #x from that posterior, given the previous state
--     and the current #y.
--   * `hmm_proposals` pairs each observation with its proposal.
--   * `customInference n` runs smcPush with multinomial resampling over
--     `particleFilter`, which is defined outside these lines. It returns the
--     result of `evidence`.
-- NOTE(review): `customInference` passes `n` as numSteps and fixes
-- numParticles = 5. Confirm that `n` is not meant to be the particle count.
-- The physical line then opens src/TraceTypes.hs (pragma, module header,
-- most imports).
#y) 268 | 269 | -- hmm_proposals :: MonadMeasure m => [(Rec ("y" .== Double), Double -> P m ("x" .== Double) Double)] 270 | hmm_proposals = fmap (\o -> (o, hmm_proposal o)) hmm_steps 271 | 272 | -- hmm_init_proposal :: MonadMeasure m => P m ("init" .== Double) Double 273 | hmm_init_proposal = sample #init (nrm 0 10) 274 | 275 | -- customInference :: Int -> IO (Log Double) 276 | customInference n = runSamplerTIO $ evidence $ smcPush config $ 277 | particleFilter simple_hmm_init hmm_init_proposal simple_hmm_trans hmm_proposals 278 | where config = SMCConfig 279 | { resampler = resampleMultinomial 280 | , numSteps = n 281 | , numParticles = 5 282 | } 283 | 284 | 285 | 286 | -------------------------------------------------------------------------------- /src/TraceTypes.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds, NamedFieldPuns, GADTs, KindSignatures, TypeFamilies, TypeOperators, TypeApplications, ScopedTypeVariables, MultiParamTypeClasses, OverloadedLabels, InstanceSigs, FlexibleContexts, AllowAmbiguousTypes, FlexibleInstances, RankNTypes, UndecidableInstances, ConstraintKinds, TypeFamilyDependencies #-} 2 | 3 | module TraceTypes where 4 | 5 | import Data.Kind (Type) 6 | import Data.Row.Records hiding ( Disjoint 7 | , sequence 8 | , map 9 | , zip 10 | ) 11 | import Data.Row.Internal ( Subset 12 | , type (.\\) 13 | , type (.\/) 14 | ) 15 | import Data.Default 16 | import Numeric.Log as Log 17 | import qualified Data.Vector.Sized as V 18 | import Data.Vector.Sized ( Vector ) 19 | import GHC.OverloadedLabels 20 | import Data.Finite 21 | import GHC.TypeLits 22 | import Control.Monad 23 | import Control.Monad.Bayes.Class as Bayes 24 | import Control.Monad.Bayes.Class ( MonadDistribution 25 | , MonadFactor 26 | , MonadMeasure 27 | ) 28 | import Statistics.Distribution (logDensity, complCumulative) 29 | import Statistics.Distribution.Gamma 30 | import Statistics.Distribution.Beta 31 | import
-- [dump] src/TraceTypes.hs lines 31-66:
--   * The last import.
--   * The internal `Constraint` type: every field wrapped in Maybe, where
--     Nothing means unconstrained and Just x means observed.
--   * `ConcatTo`: l .+ r and r .+ l both equal s, and removing either part
--     from s leaves the other.
--   * `Disjoint`: concatenation holds both for the records and for their
--     Maybe-wrapped constraint versions, whose fields need Default instances.
Statistics.Distribution.Normal 32 | 33 | -- The "Constraint" type is intended only 34 | -- for internal use, and is not user-facing. 35 | -- A Constraint of a certain record type is 36 | -- a record where each field type `a` has been 37 | -- replaced with the type `Maybe a`; a value of 38 | -- `Nothing` indicates an unconstrained value, 39 | -- whereas `Just x` indicates an observed (i.e. 40 | -- constrained) value. 41 | type Constraint r = Rec (Map Maybe r) 42 | 43 | -- (ConcatTo l r s) means that l, r, and s are all 44 | -- valid record types, and that l .+ r = r .+ l = s. 45 | type ConcatTo l r s 46 | = ( WellBehaved l 47 | , WellBehaved r 48 | , WellBehaved s 49 | , Subset l s 50 | , Subset r s 51 | , (l .+ r) ≈ s 52 | , (r .+ l) ≈ s 53 | , s .\\ l ≈ r 54 | , s .\\ r ≈ l 55 | ) 56 | 57 | -- Are the label sets of l and r disjoint? 58 | type Disjoint l r 59 | = ( ConcatTo l r (l .+ r) 60 | , ConcatTo (Map Maybe l) (Map Maybe r) (Map Maybe (l .+ r)) 61 | , Forall (Map Maybe l) Default 62 | , Forall (Map Maybe r) Default 63 | ) 64 | 65 | -- Three main types: probability distributions (D), unnormalized density-carrying measures (U), 66 | -- and probabilistic programs (P).
-- | An unnormalized, density-carrying measure: a sampler together with the
-- (unnormalized) density of the values it produces. 'constrain' and
-- 'observe' return values of this type.
data U m a = U
  { usampler :: m a             -- ^ draw a value (may call 'score')
  , udensity :: a -> Log Double -- ^ unnormalized density of a value
  }

-- | A probability distribution that carries its own density.
data D m a = D
  { dsampler :: m a             -- ^ draw a value
  , ddensity :: a -> Log Double -- ^ density of a value
  }

-- | A probabilistic program with trace type @t@ (a row of labelled random
-- choices) and return type @a@.
data P m t a = P
  { traced    :: D m (Rec t)
    -- ^ distribution over complete traces
  , retval    :: Rec t -> a
    -- ^ return value determined by a complete trace
  , constrain :: Constraint t -> U m (Rec t)
    -- ^ measure over traces whose 'Just' constraint fields are fixed
  }

-- | Anything that pairs a sampler with a density function.
class DensityCarrying f m where
  sampler :: f m a -> m a
  density :: f m a -> a -> Log Double


instance DensityCarrying U m where
  sampler = usampler
  density = udensity

instance DensityCarrying D m where
  sampler = dsampler
  density = ddensity

-- | Build a 'D' from a sampler and its density.
dist :: m a -> (a -> Log Double) -> D m a
dist sampler density = D { dsampler = sampler, ddensity = density }

-- | Build a 'U' from a sampler and its unnormalized density.
udist :: m a -> (a -> Log Double) -> U m a
udist sampler density = U { usampler = sampler, udensity = density }

-- DEFAULT INFERENCE

-- | Condition a program with trace type @t .+ s@ on observed values for its
-- @t@ choices, giving an unnormalized measure over the remaining @s@ choices.
observe
  :: forall m a t s
   . (Disjoint t s, Monad m)
  => P m (t .+ s) a
  -> Rec t
  -> U m (Rec s)
observe p obs =
  let
    -- Turn the observation record `obs` of type `t` into a Constraint of
    -- type `t .+ s`, by filling in the `s` piece of the record with `Nothing`.
    constraints :: Constraint (t .+ s)
    constraints = map' Just obs .+ (default' @Default def :: Constraint s)

    -- Obtain a constrained distribution over traces.
    u :: U m (Rec (t .+ s))
    u = constrain p constraints
  in
    -- Modify the distribution u so that it is over only the unobserved values
    -- (Rec s), not full-but-constrained traces (Rec (t .+ s)).
    udist (fmap restrict $ sampler u) (density u . ((.+) obs))

-- | Sample a complete trace of a program.
traceSampler :: P m t a -> m (Rec t)
traceSampler = sampler . traced

-- | Density of a complete trace under a program's trace distribution.
traceDensity :: P m t a -> Rec t -> Log Double
traceDensity = density . traced

-- UNIT INTERVAL TYPE

-- | A number in the open interval (0, 1), stored as its logit.
newtype Unit = Sigmoid {logit :: Double}

-- | The value in (0, 1) represented by a 'Unit' (logistic sigmoid of the logit).
fromUnit :: Unit -> Double
fromUnit (Sigmoid x) = 1.0 / (1.0 + exp (-x))

-- | Smart constructor for 'Unit'; calls 'error' unless @0 < u < 1@.
mkUnit :: Double -> Unit
mkUnit u = if 0 < u && u < 1
  then Sigmoid (log (u / (1.0 - u)))
  else error "Number must be between 0 and 1."

-- | Shows the represented probability, not the logit.
instance Show Unit where
  show x = show (fromUnit x)

------------------------------
-- STPL LANGUAGE CONSTRUCTS --
------------------------------

-- RETURN

-- | The program that makes no random choices and returns @x@: its only trace
-- is the empty record, with density 1.
pReturn :: Monad m => a -> P m Empty a
pReturn x =
  let sampler     = return Data.Row.Records.empty
      density _   = 1
      traced      = dist sampler density
      constrain _ = udist sampler density
      retval _    = x
  in  P { traced, retval, constrain }


-- BIND

-- | Sequence two programs: run @f@, pass its return value to @g@, and
-- concatenate the two label-disjoint traces.
pBind :: (Disjoint t s, Monad m) => P m t a -> (a -> P m s b) -> P m (t .+ s) b
pBind f g =
  let
    bindSampler = do
      fTrace <- sampler $ traced f
      let x = retval f fTrace
      gTrace <- sampler $ traced (g x)
      return (fTrace .+ gTrace)

    -- Density of a joint trace: split it into its f and g parts with
    -- 'restrict' and multiply the two densities.
    bindDensity tr =
      let fDensity = density (traced f) (restrict tr)
          x        = retval f (restrict tr)
          gDensity = density (traced $ g x) (restrict tr)
      in  fDensity * gDensity

    bindTraced = dist bindSampler bindDensity

    bindRet tr = let x = retval f (restrict tr) in retval (g x) (restrict tr)

    -- Constraints are split between f and g in the same way as traces.
    constrainedSampler c = do
      fTrace <- sampler $ constrain f (restrict c)
      let x = retval f fTrace
      gTrace <- sampler $ constrain (g x) (restrict c)
      return (fTrace .+ gTrace)

    constrainedDensity c tr =
      let fDensity = density (constrain f (restrict c)) (restrict tr)
          x        = retval f (restrict tr)
          gDensity = density (constrain (g x) (restrict c)) (restrict tr)
      in  fDensity * gDensity

    bindConstrain c = udist (constrainedSampler c) (constrainedDensity c)
  in
    P { traced = bindTraced, retval = bindRet, constrain = bindConstrain }


-- SAMPLE

-- | Constrain a program whose whole trace lives under the single label @lbl@.
-- If the constraint observes that label, the measure always yields the
-- observed trace and 'score's its density. Otherwise it is the unconstrained
-- trace distribution.
singleLabelConstrain
  :: (KnownSymbol l, MonadMeasure m)
  => Label l
  -> D m (Rec (l .== a))
  -> Constraint (l .== a)
  -> U m (Rec (l .== a))
singleLabelConstrain lbl traced c = case c .! lbl of
  Just x ->
    let t = lbl .== x
        constrainedSampler = do
          score $ density traced t
          return t
        constrainedDensity _ = density traced t
    in  udist constrainedSampler constrainedDensity

  Nothing -> udist (sampler traced) (density traced)


-- | Draw one value from @d@ and record it at label @lbl@; the value is also
-- the program's return value.
sample :: (KnownSymbol l, MonadMeasure m) => Label l -> D m a -> P m (l .== a) a
sample lbl d =
  let sampleTraced = dist (fmap (lbl .==) (sampler d)) (\t -> density d (t .! lbl))

      sampleRet t = t .! lbl

      sampleConstrain = singleLabelConstrain lbl sampleTraced
  in  P { traced    = sampleTraced
        , retval    = sampleRet
        , constrain = sampleConstrain
        }


-- WITH PROBABILITY

-- | With probability @fromUnit logit@ run @f@, otherwise @g@. The trace stores
-- the chosen branch's subtrace at @lbl@, as 'Left' (for @f@) or 'Right' (for @g@).
withProbability
  :: (KnownSymbol l, MonadMeasure m)
  => Label l
  -> Unit
  -> P m t a
  -> P m s a
  -> P m (l .== Either (Rec t) (Rec s)) a
withProbability lbl logit f g =
  let p = fromUnit logit

      wpSampler = do
        r <- random
        if r < p
          then do
            t <- sampler $ traced f
            return (lbl .== Left t)
          else do
            t <- sampler $ traced g
            return (lbl .== Right t)

      wpDensity tr = case tr .! lbl of
        Left  t -> Exp (log p) * density (traced f) t
        Right t -> Exp (log (1.0 - p)) * density (traced g) t

      wpTraced = dist wpSampler wpDensity

      wpRet t = case t .! lbl of
        Left  t' -> retval f t'
        Right t' -> retval g t'

      wpConstrain = singleLabelConstrain lbl wpTraced
  in  P { traced = wpTraced, retval = wpRet, constrain = wpConstrain }

-- FOR EACH

-- | Run @body@ once per element of a fixed-length vector, recording the vector
-- of per-element subtraces at @lbl@ and returning the vector of results.
forEach
  :: (KnownSymbol l, KnownNat n, MonadMeasure m)
  => Label l
  -> Vector n a
  -> (a -> P m t b)
  -> P m (l .== Vector n (Rec t)) (Vector n b)
forEach lbl xs body =
  let
    forEachSampler = do
      traces <- V.mapM (sampler . traced . body) xs
      return (lbl .== traces)

    forEachDensity t = V.product
      (V.map (\(x, tr) -> density (traced $ body x) tr) (V.zip xs (t .! lbl)))

    forEachTraced = dist forEachSampler forEachDensity

    forEachRet t = V.map (\(x, tr) -> retval (body x) tr) (V.zip xs (t .! lbl))

    forEachConstrain = singleLabelConstrain lbl forEachTraced
  in
    P { traced    = forEachTraced
      , retval    = forEachRet
      , constrain = forEachConstrain
      }


-- FOR ... IN RANDOM RANGE ...

-- | Draw an iteration count @n@ from @d@ and run @body@ on the indices
-- @0 .. n - 1@, recording the list of subtraces at @lbl@.
forRandomRange
  :: (KnownSymbol l, MonadMeasure m)
  => Label l
  -> D m Int
  -> (Int -> P m t a)
  -> P m (l .== [Rec t]) [a]
forRandomRange lbl d body =
  let
    -- Sample a trace of the for loop.
    forSampler = do
      n      <- sampler d
      traces <- sequence $ map (sampler . traced . body) [0 .. n - 1]
      return (lbl .== traces)

    -- Compute the density of a trace of the for loop: the density of the
    -- iteration count times the density of each iteration's subtrace.
    forDensity t =
      density d (length $ t .! lbl)
        * product
            (map (\(tr, i) -> density (traced $ body i) tr)
                 (zip (t .! lbl) [0 ..]))

    -- The for loop's density-carrying distribution over traces
    forTraced = dist forSampler forDensity

    -- Compute the return value of the for loop, given a trace of an execution.
    -- The return value is just a list of the return values from each iteration.
    forRet t = map (\(tr, i) -> retval (body i) tr) (zip (t .! lbl) [0 ..])

    -- Constrain the random choices made inside the for loop.
    forConstrain = singleLabelConstrain lbl forTraced
  in
    P { traced = forTraced, retval = forRet, constrain = forConstrain }

{- WHILE:

A stochastic while loop.
Takes as input:
  - lbl    : The label at which to trace the while loop's random choices.
  - init   : The initial state of the while loop.
  - pi     : A function computing a probability of continuing, based on the current state.
  - pi_max : An upper bound beyond which the probability of continuing is truncated.
  - body   : The body of the while loop, which simulates a new state based on the current state.

The resulting probabilistic program has type P m (l .== [Rec t]) a:
  - The trace contains, at the specified label, a list of subtraces, one for each iteration.
  - The return value is the final state at the finish of the while loop.

The sampler and the trace density both use the truncated continuation
probability min (pi s) pi_max, so the density is that of the distribution
actually sampled.

-}

while
  :: (KnownSymbol l, MonadMeasure m)
  => Label l
  -> a
  -> (a -> Unit)
  -> Unit
  -> (a -> P m t a)
  -> P m (l .== [Rec t]) a
while lbl init pi pi_max body =
  let
    -- Probability of running another iteration from a given state: the
    -- user-supplied probability truncated at pi_max. Shared by the sampler
    -- and the density so that both describe the same distribution.
    continueProb state = min (fromUnit (pi state)) (fromUnit pi_max)

    -- Sample a trace of the while loop.
    whileSampler state = do
      shouldContinue <- bernoulli (continueProb state)
      if shouldContinue
        then do
          nextTrace <- sampler (traced $ body state)
          let nextState = retval (body state) nextTrace
          restOfTraces <- whileSampler nextState
          return (lbl .== (nextTrace : (restOfTraces .! lbl)))
        else return (lbl .== [])

    -- Given traces from each iteration, and an initial state,
    -- compute the final state.
    retvalFromTraces init ts = case ts of
      []     -> init
      t : ts -> retvalFromTraces (retval (body init) t) ts

    -- Given a trace of the entire while loop, compute the final
    -- state.
    whileRet t = retvalFromTraces init (t .! lbl)

    -- Given an initial state and traces from each iteration, compute the
    -- density. Each completed iteration contributes continueProb times its
    -- subtrace density, and stopping contributes 1 - continueProb.
    densityFromTraces init ts = case ts of
      []     -> Exp (log (1.0 - continueProb init))
      t : ts ->
        Exp (log (continueProb init))
          * density (traced $ body init) t
          * densityFromTraces (retval (body init) t) ts

    -- Given a trace of the entire while loop, compute the density.
    whileDensity t = densityFromTraces init (t .! lbl)

    -- The while loop's density-carrying distribution over traces.
    whileTraced = dist (whileSampler init) whileDensity

    -- Constrain the choices made inside the loop.
    whileConstrain = singleLabelConstrain lbl whileTraced
  in
    P { traced = whileTraced, retval = whileRet, constrain = whileConstrain }

-------------------
-- DISTRIBUTIONS --
-------------------

-- | Normal distribution with mean @mu@ and standard deviation given by the
-- positive 'Log Double' second argument.
nrm :: MonadMeasure m => Double -> Log Double -> D m Double
nrm mu (Exp sigln) = dist (normal mu (exp sigln)) (normalPdf mu (exp sigln))

-- | Bernoulli distribution that is 'True' with probability @fromUnit logit@.
brn :: MonadDistribution m => Unit -> D m Bool
brn logit =
  let p = fromUnit logit
  in  dist (bernoulli p) (\b -> Exp . log $ if b then p else (1.0 - p))

-- | Gamma distribution with the given shape and scale, over positive reals
-- represented as 'Log Double'.
gmma :: MonadDistribution m => Log Double -> Log Double -> D m (Log Double)
gmma shape scale =
  let toDouble (Exp l) = exp l
      toLogDouble d = Exp (log d)
  in  dist (fmap toLogDouble (gamma (toDouble shape) (toDouble scale)))
           (Exp . logDensity (Statistics.Distribution.Gamma.gammaDistr
                                (toDouble shape) (toDouble scale))
                . toDouble)

-- | Geometric distribution on {0, 1, 2, ...} with per-trial success
-- probability @p@: P(n) = (1 - p)^n * p. This is the parameterisation of
-- monad-bayes' 'geometric', which is the sampler used here. The density must
-- use the same parameterisation, or the importance weights built from it
-- would be wrong. Negative counts have density 0.
geom :: MonadDistribution m => Unit -> D m Int
geom p' =
  let p = fromUnit p'
      geomDensity n
        | n < 0     = 0
        | otherwise = Exp (fromIntegral n * log (1.0 - p) + log p)
  in  dist (geometric p) geomDensity

-- | Poisson distribution with the given nonnegative (integer) rate.
-- Negative counts have density 0; a rate of 0 puts all mass on 0.
pois :: MonadDistribution m => Int -> D m Int
pois rate =
  let lambda = fromIntegral rate :: Double
      -- log k!, for k >= 0.
      logFactorial k = foldr (\i acc -> log (fromIntegral i) + acc) 0 [1 .. k]
      poisDensity k
        | k < 0       = 0
        | lambda == 0 = if k == 0 then 1 else 0
        | otherwise   =
            Exp (fromIntegral k * log lambda - lambda - logFactorial k)
  in  dist (poisson lambda) poisDensity

-- | Beta distribution with positive parameters @a@ and @b@, over 'Unit'.
bta :: MonadDistribution m => Log Double -> Log Double -> D m Unit
bta (Exp lna) (Exp lnb) =
  let
    a = exp lna
    b = exp lnb
  in
    dist (fmap mkUnit (beta a b))
         (Exp . logDensity (Statistics.Distribution.Beta.betaDistr a b) . fromUnit)

-- | Uniform distribution on (0, 1), with density 1.
unif :: MonadDistribution m => D m Unit
unif = dist (fmap mkUnit Bayes.random) (\_ -> 1.0 :: Log Double)

-- | Categorical distribution over @Finite n@ (not yet implemented).
cat :: (KnownNat n, MonadDistribution m) => Vector n Unit -> D m (Finite n)
cat _ = error "TraceTypes.cat: not implemented"

-- | Log-normal distribution: @log x@ is normal with mean @logmu@ (the first
-- argument is @Exp logmu@) and standard deviation @sigma@.
lognormal :: forall m. MonadMeasure m => Log Double -> Log Double -> D m (Log Double)
lognormal (Exp logmu) sigma =
  dist (fmap Exp $ sampler $ nrm logmu sigma)
       (\(Exp logx) -> density (nrm logmu sigma :: D m Double) logx / (Exp logx))


-- | Normal distribution with mean @mu@ and standard deviation @sigma@ (both
-- given as positive 'Log Double's), truncated to the nonnegative reals.
truncnrm :: MonadMeasure m => Log Double -> Log Double -> D m (Log Double)
truncnrm (Exp logmu) (Exp logsigma) =
  let mu        = exp logmu
      sigma     = exp logsigma
      unbounded = nrm mu (Exp logsigma)

      -- Rejection sampling: redraw until the normal draw is nonnegative.
      truncNormSampler = do
        x <- sampler unbounded
        if x < 0 then truncNormSampler else return (Exp $ log x)

      -- Normal density renormalised by z = P(X > 0).
      truncNormPdf (Exp logy) =
        let y = exp logy
            z = complCumulative (normalDistr mu sigma) 0
        in  (density unbounded y / (Exp (log z)))
  in  dist truncNormSampler truncNormPdf



----------------------------
-- PROGRAMMABLE INFERENCE --
----------------------------

-- IMPORTANCE SAMPLING

-- | One importance-sampling step targeting @p@ conditioned on @t@. It draws
-- the unobserved choices from the proposal @q@, 'score's by target density
-- over proposal density, and returns the proposed choices.
importance
  :: (Disjoint t s, MonadMeasure m)
  => P m (t .+ s) a
  -> Rec t
  -> P m s b
  -> m (Rec s)
importance p t q =
  let target = observe p t
  in  do
        tr <- sampler $ traced q
        score ((density target tr) / (density (traced q) tr))
        return tr

-- SMC

-- | Sample a trace of @init@, then fold over the observation records in
-- @steps@, running @next@ from the current state conditioned on each
-- observation via 'observe'.
unroll
  :: (MonadMeasure m, Disjoint t s)
  => P m i a
  -> (a -> P m (t .+ s) a)
  -> [Rec t]
  -> m (Rec i, [Rec s])
unroll init next steps = do
  tInit <- sampler $ traced init
  let processStep soFar t = do {
    (s, results) <- soFar;
    tNext <- sampler $ observe (next s) t;
    return (retval (next s) (tNext .+ t), tNext : results)
  }
  (s, l) <- foldl processStep
(return (retval init tInit, [])) steps 502 | return (tInit, reverse l) 503 | 504 | particleFilter 505 | :: (MonadMeasure m, Disjoint t s, 506 | Disjoint Empty i, (Empty .+ i) ≈ i) 507 | => P m i a 508 | -> P m i b 509 | -> (a -> P m (t .+ s) a) 510 | -> [(Rec t, a -> P m s c)] 511 | -> m (Rec i, [Rec s]) 512 | particleFilter init qInit next steps = do 513 | tInit <- importance init Data.Row.Records.empty qInit 514 | let processStep soFar (t, q) = do { 515 | (s, results) <- soFar; 516 | tNext <- importance (next s) t (q s); 517 | return (retval (next s) (tNext .+ t), tNext : results) 518 | } 519 | (s, l) <- foldl processStep (return (retval init tInit, [])) steps 520 | return (tInit, reverse l) 521 | 522 | 523 | -- MCMC 524 | 525 | type K m t (s :: Row Type) = U m (Rec t) -> Rec t -> m (Rec t) 526 | 527 | mh 528 | :: forall s t m a 529 | . (Disjoint t s, MonadMeasure m) 530 | => (Rec (t .+ s) -> P m t a) 531 | -> K m (t .+ s) t 532 | mh q p old = 533 | let proposal = do 534 | new <- sampler (traced $ q old) 535 | return $ new .+ (restrict old :: Rec s) 536 | rho (x, y) = 537 | ((density p y) * (density (traced $ q y) (restrict x))) 538 | / ((density p x) * (density (traced $ q x) (restrict y))) 539 | in do 540 | proposed <- proposal 541 | r <- Bayes.random 542 | if Exp (log r) < (min 1 (rho (old, proposed))) 543 | then return proposed 544 | else return old 545 | 546 | seqK :: Monad m => K m t s -> K m t r -> K m t (s .\/ r) 547 | seqK k1 k2 p old = do 548 | t' <- k1 p old 549 | k2 p t' 550 | 551 | mixK :: MonadDistribution m => Double -> K m t s -> K m t r -> K m t (s .\/ r) 552 | mixK p k1 k2 model old = do 553 | r <- Bayes.random 554 | if r < p then k1 model old else k2 model old 555 | 556 | repeatK :: Monad m => Int -> K m t s -> K m t s 557 | repeatK n k p old = if n == 0 558 | then return old 559 | else do 560 | t <- k p old 561 | repeatK (n - 1) k p t 562 | 563 | ifK 564 | :: (Disjoint t s, Monad m) 565 | => (Rec t -> Bool) 566 | -> K m (t .+ s) s 567 | -> K m 
(t .+ s) s 568 | ifK b k p old = if (b (restrict old)) then k p old else return old 569 | 570 | -- VARIATIONAL (STUBBED) 571 | 572 | svi 573 | :: (Disjoint t s, Monad m) 574 | => P m (t .+ s) a 575 | -> Rec t 576 | -> (theta -> P m s b) 577 | -> theta 578 | -> m theta 579 | svi p t q theta = undefined 580 | -- do tr <- traceSampler $ q theta 581 | -- let sco = density (observe p t) tr 582 | -- let gradient = grad (\th -> traceDensity (q th) tr) theta 583 | -- return theta + stepSize * sco * gradient 584 | 585 | trainAmortized 586 | :: (Disjoint t s, Monad m) 587 | => P m (t .+ s) a 588 | -> (theta -> Rec t -> P m s b) 589 | -> theta 590 | -> m theta 591 | trainAmortized p q theta = undefined 592 | -- do tr <- traceSampler p 593 | -- let gradient = grad (\th -> traceDensity (q th (restrict tr)) (restrict tr)) theta 594 | -- return theta + step * gradient 595 | 596 | 597 | ifThenElse c t f = if c then t else f 598 | --------------------------------------------------------------------------------