├── .envrc ├── .github ├── actions │ └── nix-common-setup │ │ └── action.yml ├── dependabot.yml └── workflows │ ├── ci.yml │ ├── nix.yml │ └── update-flake-lock.yml ├── .gitignore ├── .hlint.yaml ├── CHANGELOG.md ├── LICENSE.md ├── MAINTAINERS.md ├── README.md ├── Setup.hs ├── benchmark ├── SSM.hs ├── Single.hs └── Speed.hs ├── default.nix ├── docs ├── docs │ ├── examples.md │ ├── images │ │ ├── code_example.png │ │ ├── haskell-logo.png │ │ ├── plot.png │ │ ├── priorpred.png │ │ ├── randomwalk.png │ │ └── regress.png │ ├── index.md │ ├── javascripts │ │ └── mathjax.js │ ├── notebooks │ │ ├── AdvancedSampling.html │ │ ├── Bayesian.html │ │ ├── ClassicalPhysics.html │ │ ├── Diagrams.html │ │ ├── Functional_PPLs.html │ │ ├── Histogram.html │ │ ├── Introduction.html │ │ ├── Ising.html │ │ ├── Lazy.html │ │ ├── Lenses.html │ │ ├── MCMC.html │ │ ├── Parsing.html │ │ ├── RealTimeInference.html │ │ ├── SMC.html │ │ ├── Sampling.html │ │ └── Streaming.html │ ├── probprog.md │ ├── tutorials.md │ └── usage.md ├── mkdocs.yml ├── netlify.toml ├── requirements.txt └── runtime.txt ├── flake.lock ├── flake.nix ├── kernels └── haskell.nix ├── models ├── BetaBin.hs ├── ConjugatePriors.hs ├── Dice.hs ├── HMM.hs ├── Helper.hs ├── LDA.hs ├── LogReg.hs ├── NestedInference.hs ├── NonlinearSSM.hs ├── NonlinearSSM │ └── Algorithms.hs ├── Sprinkler.hs └── StrictlySmallerSupport.hs ├── monad-bayes.cabal ├── notebooks ├── Implementation.ipynb ├── _build │ └── _page │ │ └── Introduction │ │ └── html │ │ └── _sphinx_design_static │ │ ├── design-style.b7bb847fb20b106c3d81b95245e65545.min.css │ │ └── design-tabs.js ├── examples │ ├── ClassicalPhysics.ipynb │ ├── Diagrams.ipynb │ ├── Histogram.ipynb │ ├── Ising.ipynb │ ├── Lenses.ipynb │ ├── Parsing.ipynb │ └── Streaming.ipynb ├── file.json ├── models │ ├── .ipynb_checkpoints │ │ └── LDA-checkpoint.ipynb │ └── LDA.ipynb ├── plotting.hs └── tutorials │ ├── AdvancedSampling.ipynb │ ├── Bayesian.ipynb │ ├── Introduction.ipynb │ ├── Lazy.ipynb │ ├── MCMC.ipynb │ ├── SMC.ipynb │ └── Sampling.ipynb ├── plots.py ├── profile.sh ├── regenerate_notebooks.sh ├── shell.nix ├── src ├── Control │ ├── Applicative │ │ └── List.hs │ └── Monad │ │ └── Bayes │ │ ├── Class.hs │ │ ├── Density │ │ ├── Free.hs │ │ └── State.hs │ │ ├── Enumerator.hs │ │ ├── Inference │ │ ├── Lazy │ │ │ ├── MH.hs │ │ │ └── WIS.hs │ │ ├── MCMC.hs │ │ ├── PMMH.hs │ │ ├── RMSMC.hs │ │ ├── SMC.hs │ │ ├── SMC2.hs │ │ └── TUI.hs │ │ ├── Integrator.hs │ │ ├── Population.hs │ │ ├── Sampler │ │ ├── Lazy.hs │ │ └── Strict.hs │ │ ├── Sequential │ │ └── Coroutine.hs │ │ ├── Traced.hs │ │ ├── Traced │ │ ├── Basic.hs │ │ ├── Common.hs │ │ ├── Dynamic.hs │ │ └── Static.hs │ │ └── Weighted.hs └── Math │ └── Integrators │ └── StormerVerlet.hs ├── stack.yaml └── test ├── Spec.hs ├── TestAdvanced.hs ├── TestBenchmarks.hs ├── TestDistribution.hs ├── TestEnumerator.hs ├── TestInference.hs ├── TestIntegrator.hs ├── TestPipes.hs ├── TestPopulation.hs ├── TestSSMFixtures.hs ├── TestSampler.hs ├── TestSequential.hs ├── TestStormerVerlet.hs ├── TestWeighted.hs └── fixtures ├── HMM10-MH.txt ├── HMM10-RMSMC.txt ├── HMM10-SMC.txt ├── LDA10-MH.txt ├── LDA10-RMSMC.txt ├── LDA10-SMC.txt ├── LR10-MH.txt ├── LR10-RMSMC.txt ├── LR10-SMC.txt ├── SSM-PMMH.txt ├── SSM-RMSMC.txt ├── SSM-RMSMCBasic.txt ├── SSM-RMSMCDynamic.txt ├── SSM-SMC.txt └── SSM-SMC2.txt /.envrc: -------------------------------------------------------------------------------- 1 | # Make sure you have direnv >= 2.30 2 | use flake --extra-experimental-features nix-command 
--extra-experimental-features flakes 3 | -------------------------------------------------------------------------------- /.github/actions/nix-common-setup/action.yml: -------------------------------------------------------------------------------- 1 | name: Setup Nix Environment 2 | inputs: 3 | CACHIX_AUTH_TOKEN: 4 | required: true 5 | description: 'Cachix Auth Token' 6 | runs: 7 | using: "composite" 8 | steps: 9 | 10 | - name: Installing Nix 11 | uses: cachix/install-nix-action@v22 12 | with: 13 | nix_path: nixpkgs=channel:nixos-unstable 14 | 15 | - uses: cachix/cachix-action@v12 16 | with: 17 | name: tweag-monad-bayes 18 | authToken: "${{ inputs.CACHIX_AUTH_TOKEN }}" 19 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | 4 | - package-ecosystem: github-actions 5 | directory: "/" 6 | schedule: 7 | interval: daily 8 | time: '00:00' 9 | timezone: UTC 10 | open-pull-requests-limit: 10 11 | commit-message: 12 | prefix: "chore" 13 | include: "scope" 14 | 15 | # By default, when `package-ecosystem: github-actions` is set, 16 | # dependabot only looks in the `.github/workflows` directory, 17 | # even when setting `directory: "/"`. 18 | # But we need to keep updating the common nix setup as well. 19 | # Hopefully the following works. 20 | - package-ecosystem: github-actions 21 | directory: "/.github/actions" 22 | schedule: 23 | interval: daily 24 | time: '00:00' 25 | timezone: UTC 26 | open-pull-requests-limit: 10 27 | commit-message: 28 | prefix: "chore-actions" 29 | include: "scope" 30 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - newdocs 6 | - master 7 | permissions: 8 | contents: write 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: 3.x 17 | - run: pip install mkdocs-material 18 | # - run: mkdocs gh-deploy --force 19 | -------------------------------------------------------------------------------- /.github/workflows/nix.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: "nix-build" 3 | 4 | on: 5 | pull_request: 6 | push: 7 | branches: 8 | - master 9 | 10 | jobs: 11 | 12 | lint: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Setup Nix Environment 17 | uses: ./.github/actions/nix-common-setup 18 | with: 19 | CACHIX_AUTH_TOKEN: ${{ secrets.CACHIX_AUTH_TOKEN }} 20 | - name: Lint 21 | run: nix --print-build-logs build .#pre-commit 22 | 23 | build: 24 | needs: lint 25 | strategy: 26 | matrix: 27 | include: 28 | - os: ubuntu-latest 29 | system: x86_64-linux 30 | - os: macos-latest 31 | system: x86_64-darwin 32 | runs-on: ${{ matrix.os }} 33 | steps: 34 | - uses: actions/checkout@v4 35 | - name: Setup Nix Environment 36 | uses: ./.github/actions/nix-common-setup 37 | with: 38 | CACHIX_AUTH_TOKEN: ${{ secrets.CACHIX_AUTH_TOKEN }} 39 | - name: Build 40 | run: nix --print-build-logs build .#packages.${{ matrix.system }}.monad-bayes 41 | - name: Development environment (package only) 42 | run: nix --print-build-logs develop .#packages.${{ matrix.system }}.monad-bayes --command echo Ready 43 | 44 | build-all-ghcs: 45 | needs: lint 46 | 
strategy: 47 | matrix: 48 | ghc: ["ghc902", "ghc927", "ghc945", "ghc964", "ghc982", "ghc9101"] 49 | include: 50 | - os: ubuntu-latest 51 | system: x86_64-linux 52 | - os: macos-latest 53 | system: x86_64-darwin 54 | runs-on: ${{ matrix.os }} 55 | steps: 56 | - uses: actions/checkout@v4 57 | - name: Setup Nix Environment 58 | uses: ./.github/actions/nix-common-setup 59 | with: 60 | CACHIX_AUTH_TOKEN: ${{ secrets.CACHIX_AUTH_TOKEN }} 61 | - name: Build 62 | run: nix --print-build-logs build .#packages.${{ matrix.system }}.monad-bayes-per-ghc.${{ matrix.ghc }} 63 | - name: Development environment (package only) 64 | run: nix --print-build-logs develop .#packages.${{ matrix.system }}.monad-bayes-per-ghc.${{ matrix.ghc }} --command echo Ready 65 | 66 | notebooks: 67 | needs: 68 | - build 69 | - build-all-ghcs 70 | strategy: 71 | matrix: 72 | include: 73 | - os: ubuntu-latest 74 | system: x86_64-linux 75 | # Jupyenv doesn't support Darwin yet, https://github.com/tweag/jupyenv/issues/388 76 | # - os: macos-latest 77 | # system: x86_64-darwin 78 | runs-on: ${{ matrix.os }} 79 | steps: 80 | - uses: actions/checkout@v4 81 | - name: Setup Nix Environment 82 | uses: ./.github/actions/nix-common-setup 83 | with: 84 | CACHIX_AUTH_TOKEN: ${{ secrets.CACHIX_AUTH_TOKEN }} 85 | - name: Development environment (complete) 86 | run: nix --print-build-logs develop --command echo Ready 87 | - name: Check whether notebook *.html files are up to date 88 | run: | 89 | ./regenerate_notebooks.sh 90 | git diff --exit-code || (echo "Please update notebooks by running regenerate_notebooks.sh, and inspecting and committing the result." && exit 1) 91 | -------------------------------------------------------------------------------- /.github/workflows/update-flake-lock.yml: -------------------------------------------------------------------------------- 1 | name: update-flake-lock 2 | on: 3 | workflow_dispatch: # allows manual triggering 4 | schedule: 5 | - cron: '0 0 * * 0' # runs weekly on Sunday at 00:00 6 | 7 | jobs: 8 | lockfile: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout repository 12 | uses: actions/checkout@v4 13 | - name: Install Nix 14 | uses: cachix/install-nix-action@v31 15 | with: 16 | extra_nix_config: | 17 | access-tokens = github.com=${{ secrets.GITHUB_TOKEN }} 18 | - name: Update flake.lock 19 | uses: DeterminateSystems/update-flake-lock@v25 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/data 2 | docs/site 3 | venv 4 | _cache 5 | docs/build 6 | *.csv 7 | *.pdf 8 | *.dat 9 | *.aux 10 | *.chi 11 | *.chs.h 12 | *.dyn_hi 13 | *.dyn_o 14 | *.eventlog 15 | *.hi 16 | *.hp 17 | stack.yaml.lock 18 | .vscode/* 19 | *.o 20 | *.prof 21 | .HTF/ 22 | .cabal-sandbox/ 23 | .ghc.environment.* 24 | .hpc 25 | .hsenv 26 | .vscode 27 | .envrc 28 | .stack-work/ 29 | cabal-dev 30 | cabal.project.local 31 | cabal.project.local~ 32 | cabal.sandbox.config 33 | dist 34 | dist-* 35 | /result 36 | /.pre-commit-config.yaml 37 | /.direnv 38 | venv/ 39 | .venv/ 40 | .ipynb_checkpoints/ 41 | .netlify/ 42 | .jupyter/ 43 | build/ 44 | cache/ 45 | -------------------------------------------------------------------------------- /.hlint.yaml: -------------------------------------------------------------------------------- 1 | - ignore: {} 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: 
-------------------------------------------------------------------------------- 1 | # 1.3.0.4 2 | 3 | - Allowed GHC 9.10 4 | - Updated some version bounds 5 | 6 | # 1.3.0.3 7 | 8 | - Relaxed some version bounds 9 | 10 | # 1.3.0.2 11 | 12 | - Relaxed some version bounds 13 | 14 | # 1.3.0.1 15 | 16 | - Support for GHC 9.8 17 | 18 | # 1.3.0 19 | 20 | - Support for GHC 9.6 21 | - Replaced transformers' `ListT` by 22 | (https://github.com/tweag/monad-bayes/pull/295) 23 | - Naming fixes for `Sampler` and `SamplerT` 24 | 25 | # 1.2.0 26 | 27 | - Renamed monad transformers idiomatically 28 | (https://github.com/tweag/monad-bayes/pull/295) 29 | 30 | # 1.1.1 31 | 32 | - add fixture tests for benchmark models 33 | - extensive documentation improvements 34 | - add `poissonPdf` 35 | - Fix TUI inference 36 | - Fix flaky test 37 | - Support GHC 9.4 38 | 39 | # 1.1.0 40 | 41 | - extensive notebook improvements 42 | 43 | # 1.0.0 (2022-09-10) 44 | 45 | - host website from repo 46 | - host notebooks from repo 47 | - use histogram-fill 48 | 49 | # 0.2.0 (2022-07-26) 50 | 51 | - rename various functions to match the names of the corresponding types (e.g. `Enumerator` goes with `enumerator`) 52 | - add configs as arguments to inference methods `smc` and `mcmc` 53 | - add rudimentary tests for all inference methods 54 | - put `mcmc` as inference method in new module `Control.Monad.Bayes.Inference.MCMC` 55 | - update history of changelog in line with semantic versioning conventions 56 | - bumped to GHC 9.2.3 57 | 58 | # 0.1.5 (2022-07-26) 59 | 60 | - Refactor of sampler to be parametric in the choice of a pair of IO monad and RNG 61 | 62 | # 0.1.4 (2022-06-15) 63 | 64 | Addition of new helper functions, plotting tools, tests, and Integrator monad. 65 | 66 | - helpers include: `toEmpirical` (list of samples to empirical distribution) and `toBins` (simple histogramming) 67 | - `Integrator` is an instance of `MonadDistribution` for numerical integration 68 | - `notebooks` now contains working notebook-based tutorials and examples 69 | - new tests, including with conjugate distributions to compare analytic solution against inferred posterior 70 | - `models` directory is cleaned up. New sequential models using `pipes` package to represent monadic streams 71 | 72 | # 0.1.3 (2022-06-08) 73 | 74 | Clean up of unused functions and broken code 75 | 76 | - remove unused functions in `Weighted` and `Population` 77 | - remove broken models in `models` 78 | - explicit imports 79 | - added some global language extensions 80 | 81 | # 0.1.2 (2022-06-08) 82 | 83 | Add documentation 84 | 85 | - docs written in markdown 86 | - docs built by sphinx 87 | 88 | # 0.1.1 (2020-04-08) 89 | 90 | - New exported function: `Control.Monad.Bayes.Class` now exports `discrete`. 91 | 92 | # 0.1.0 (2020-02-17) 93 | 94 | Initial release. 
95 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015-2020 Adam Scibior 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | # GHC compatibility and Cabal dependency version bounds 2 | 3 | ## Overview 4 | 5 | `monad-bayes` supports the three most recent [major versions][ghc-major] of 6 | GHC: 7 | - CI builds and tests against **all supported versions**. The CI setup is the 8 | source of truth for which GHC versions `monad-bayes` supports. 9 | - The local environment (e.g., stack.yaml) sets up **a supported version** of 10 | GHC. 11 | - The Cabal dependency version bounds for each dependency are as follows: 12 | The **lower bound** is taken from `cabal gen-bounds` run against the oldest 13 | supported GHC version. The **upper bound** is taken from `cabal gen-bounds` 14 | run against the newest supported GHC version. 15 | 16 | ## What to do when a new major GHC version has been released 17 | 18 | A **new major GHC version** has been released. Here's what you need to do: 19 | - **Add the new major GHC version** to the CI build matrix and **drop the 20 | oldest version** that was previously supported. 21 | - Make sure the **local environment** (e.g., stack.yaml) still sets up a 22 | supported version of GHC. If not, update it. 23 | - Update the Cabal **dependency bounds** as described above. 24 | 25 | ## How to release a new version 26 | 27 | - Open a separate branch `release` and a merge request. On this branch, do the following: 28 | - Update the file `CHANGELOG.md`, using the diff to the last release. 29 | - Increment the version in `monad-bayes.cabal`. See the [Hackage Package 30 | Versioning Policy][hackage-pvp].
31 | - Upload the package candidate sources: 32 | ```console 33 | $ dir=$(mktemp -d monad-bayes-sdist.XXXXXX) 34 | $ cabal v2-sdist --builddir="$dir" 35 | $ cabal upload --user=<user> --password=<password> "$dir/sdist/*.tar.gz" 36 | $ rm -rf "$dir" 37 | ``` 38 | - Upload the package candidate documentation: 39 | ```console 40 | $ dir=$(mktemp -d monad-bayes-docs.XXXXXX) 41 | $ cabal v2-haddock --builddir="$dir" --haddock-for-hackage --enable-doc 42 | $ cabal upload --documentation --user=<user> --password=<password> "$dir/*-docs.tar.gz" 43 | $ rm -rf "$dir" 44 | ``` 45 | - Check the candidate's Hackage page and make sure everything looks as expected. 46 | - When you're ready, and the CI passes for your merge request, repeat the above `cabal upload` commands (for sources and 47 | documentation), adding `--publish` so the uploads are no longer marked as 48 | candidates but as proper releases. 49 | - Merge the `release` branch. 50 | - Add a `git` tag in the form `vmajor.major.minor`, e.g. `v1.1.0`, to the commit that was uploaded: 51 | ```console 52 | git tag v1.2.3 53 | git push --tags 54 | ``` 55 | 56 | [ghc-major]: https://gitlab.haskell.org/ghc/ghc/wikis/working-conventions/releases#major-releases 57 | [hackage-pvp]: https://pvp.haskell.org/ 58 | 59 | 60 | ## Documentation 61 | 62 | The docs are built with MkDocs. Serve them locally with `mkdocs serve`. The site is served online with Netlify. 63 | 64 | # Benchmarking 65 | 66 | ## Quick benchmark 67 | 68 | * Run `cabal run single -- -m MODEL -a ALG` 69 | * For `MODEL`, insert e.g. `LR100` 70 | * For `ALG`, insert e.g. `SMC` 71 | * See `benchmark/Single.hs` for details 72 | 73 | ## Extensive benchmark 74 | 75 | * Run `cabal bench speed-bench` 76 | * It will run several benchmarks of differing complexity, and try to plot them using Python Pandas 77 | * Look at `samples.pdf` 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Monad-Bayes](https://monad-bayes.netlify.app/) 2 | 3 | A library for probabilistic programming in Haskell. 4 | 5 | 9 | 10 | [See the docs](https://monad-bayes.netlify.app/) for a user guide, notebook-style tutorials, an example gallery, and a detailed account of the implementation. 11 | 12 | 13 | 14 | 15 | 16 | Created by [Adam Scibior][adam-web] ([@adscib][adam-github]), documentation, website and newer features by [Reuben][reuben-web], maintained by [Tweag][tweagio]. 17 | 18 | ## Project status 19 | 20 | Now that `monad-bayes` has been released on Hackage, and the documentation and the API have been updated, we will focus on adding new features. See the GitHub issues to get a sense of what is being prepared, and please feel free to make requests. 21 | 22 | ## Background 23 | 24 | The basis for the code in this repository is the ICFP 2018 paper [2]. For the 25 | code associated with the Haskell2015 paper [1], see the [`haskell2015` 26 | tag][haskell2015-tag]. 27 | 28 | [1] Adam M. Ścibior, Zoubin Ghahramani, and Andrew D. Gordon. 2015. [Practical 29 | probabilistic programming with monads][haskell2015-doi]. In _Proceedings of the 30 | 2015 ACM SIGPLAN Symposium on Haskell_ (Haskell ’15), Association for Computing 31 | Machinery, Vancouver, BC, Canada, 165–176. 32 | 33 | [2] Adam M. Ścibior, Ohad Kammar, and Zoubin Ghahramani. 2018. [Functional 34 | programming for modular Bayesian inference][icfp2018-doi]. In _Proceedings of 35 | the ACM on Programming Languages_ Volume 2, ICFP (July 2018), 83:1–83:29. 36 | 37 | [3] Adam M.
Ścibior. 2019. [Formally justified and modular Bayesian inference 38 | for probabilistic programs][thesis-doi]. Thesis. University of Cambridge. 39 | 40 | ## Hacking 41 | 42 | 1. Install `stack` by following [these instructions][stack-install]. 43 | 44 | 2. Clone the repository using one of these URLs: 45 | ``` 46 | git clone git@github.com:tweag/monad-bayes.git 47 | git clone https://github.com/tweag/monad-bayes.git 48 | ``` 49 | 50 | Now you can use `stack build`, `stack test` and `stack ghci`. 51 | 52 | **To view the notebooks, go to the website**. To use the notebooks interactively: 53 | 54 | 1. Compile the source: `stack build` 55 | 2. If you do not have `nix` [install it](https://nixos.org/download.html). 56 | 3. Run `nix develop --system x86_64-darwin --extra-experimental-features nix-command --extra-experimental-features flakes` - this should open a nix shell. For Linux use `x86_64-linux` for `--system` option instead. 57 | 4. Run `jupyter-lab` from the nix shell to load the notebooks. 58 | 59 | Your mileage may vary. 60 | 61 | [adam-github]: https://github.com/adscib 62 | [adam-web]: https://www.cs.ubc.ca/~ascibior/ 63 | [reuben-web]: https://reubencohngordon.com/ 64 | [haskell2015-doi]: https://doi.org/10.1145/2804302.2804317 65 | [haskell2015-tag]: https://github.com/tweag/monad-bayes/tree/haskell2015 66 | [icfp2018-doi]: https://doi.org/10.1145/3236778 67 | [models]: https://github.com/tweag/monad-bayes/tree/master/models 68 | [stack-install]: https://docs.haskellstack.org/en/stable/install_and_upgrade/ 69 | [thesis-doi]: https://doi.org/10.17863/CAM.42233 70 | [tweagio]: https://tweag.io 71 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | 3 | main = defaultMain 4 | -------------------------------------------------------------------------------- /benchmark/SSM.hs: -------------------------------------------------------------------------------- 1 | module Main where 2 | 3 | import Control.Monad (forM_) 4 | import Control.Monad.Bayes.Inference.MCMC 5 | import Control.Monad.Bayes.Inference.PMMH as PMMH (pmmh) 6 | import Control.Monad.Bayes.Inference.RMSMC (rmsmcDynamic) 7 | import Control.Monad.Bayes.Inference.SMC 8 | import Control.Monad.Bayes.Inference.SMC2 as SMC2 (smc2) 9 | import Control.Monad.Bayes.Population 10 | import Control.Monad.Bayes.Population (resampleMultinomial, runPopulationT) 11 | import Control.Monad.Bayes.Sampler.Strict (sampleIO, sampleIOfixed, sampleWith) 12 | import Control.Monad.Bayes.Weighted (unweighted) 13 | import Control.Monad.IO.Class (MonadIO (liftIO)) 14 | import NonlinearSSM (generateData, model, param) 15 | import NonlinearSSM.Algorithms 16 | import System.Random.Stateful (mkStdGen, newIOGenM) 17 | 18 | main :: IO () 19 | main = sampleIOfixed $ do 20 | dat <- generateData t 21 | let ys = map snd dat 22 | forM_ [SMC, RMSMCDynamic, PMMH, SMC2] $ \alg -> do 23 | liftIO $ print alg 24 | result <- runAlgFixed ys alg 25 | liftIO $ putStrLn result 26 | -------------------------------------------------------------------------------- /benchmark/Single.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE ImportQualifiedPost #-} 3 | 4 | import Control.Applicative (Applicative (..)) 5 | import Control.Monad.Bayes.Sampler.Strict 6 | import Data.Time (diffUTCTime, getCurrentTime) 7 | import Helper 8 | import 
Options.Applicative 9 | ( ParserInfo, 10 | auto, 11 | execParser, 12 | fullDesc, 13 | help, 14 | info, 15 | long, 16 | maybeReader, 17 | option, 18 | short, 19 | ) 20 | -- Prelude exports liftA2 from GHC 9.6 on, see https://github.com/haskell/core-libraries-committee/blob/main/guides/export-lifta2-prelude.md 21 | -- import Control.Applicative further up can be removed once we don't support GHC <= 9.4 anymore 22 | 23 | import Prelude hiding (Applicative (..)) 24 | 25 | infer :: Model -> Alg -> IO () 26 | infer model alg = do 27 | x <- sampleIOfixed (runAlg model alg) 28 | print x 29 | 30 | opts :: ParserInfo (Model, Alg) 31 | opts = flip info fullDesc $ liftA2 (,) model alg 32 | where 33 | model = 34 | option 35 | (maybeReader parseModel) 36 | ( long "model" 37 | <> short 'm' 38 | <> help "Model" 39 | ) 40 | alg = 41 | option 42 | auto 43 | ( long "alg" 44 | <> short 'a' 45 | <> help "Inference algorithm" 46 | ) 47 | 48 | main :: IO () 49 | main = do 50 | (model, alg) <- execParser opts 51 | startTime <- getCurrentTime 52 | infer model alg 53 | endTime <- getCurrentTime 54 | print (diffUTCTime endTime startTime) 55 | -------------------------------------------------------------------------------- /benchmark/Speed.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE ImportQualifiedPost #-} 3 | {-# OPTIONS_GHC -Wall #-} 4 | 5 | module Main (main) where 6 | 7 | import Control.Monad.Bayes.Class (MonadMeasure) 8 | import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (MCMCConfig, numBurnIn, numMCMCSteps, proposal), Proposal (SingleSiteMH)) 9 | import Control.Monad.Bayes.Inference.RMSMC (rmsmcDynamic) 10 | import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smc) 11 | import Control.Monad.Bayes.Population (resampleSystematic, runPopulationT) 12 | import Control.Monad.Bayes.Sampler.Strict (SamplerIO, sampleIOfixed) 13 | import Control.Monad.Bayes.Traced (mh) 14 | import Control.Monad.Bayes.Weighted (unweighted) 15 | import Criterion.Main 16 | ( Benchmark, 17 | Benchmarkable, 18 | bench, 19 | defaultConfig, 20 | defaultMainWith, 21 | nfIO, 22 | ) 23 | import Criterion.Types (Config (csvFile, rawDataFile)) 24 | import Data.Functor (void) 25 | import Data.Maybe (listToMaybe) 26 | import Data.Text qualified as T 27 | import HMM qualified 28 | import LDA qualified 29 | import LogReg qualified 30 | import System.Directory (removeFile) 31 | import System.IO.Error (catchIOError, isDoesNotExistError) 32 | import System.Process.Typed (runProcess) 33 | 34 | data ProbProgSys = MonadBayes 35 | deriving stock (Show) 36 | 37 | data Model = LR [(Double, Bool)] | HMM [Double] | LDA [[T.Text]] 38 | 39 | instance Show Model where 40 | show (LR xs) = "LR" ++ show (length xs) 41 | show (HMM xs) = "HMM" ++ show (length xs) 42 | show (LDA xs) = "LDA" ++ show (maybe 0 length $ listToMaybe xs) 43 | 44 | buildModel :: (MonadMeasure m) => Model -> m String 45 | buildModel (LR dataset) = show <$> LogReg.logisticRegression dataset 46 | buildModel (HMM dataset) = show <$> HMM.hmm dataset 47 | buildModel (LDA dataset) = show <$> LDA.lda dataset 48 | 49 | modelLength :: Model -> Int 50 | modelLength (LR xs) = length xs 51 | modelLength (HMM xs) = length xs 52 | modelLength (LDA xs) = sum (map length xs) 53 | 54 | data Alg = MH Int | SMC Int | RMSMC Int Int 55 | 56 | instance Show Alg where 57 | show (MH n) = "MH" ++ show n 58 | show (SMC n) = "SMC" ++ show n 59 | show (RMSMC n t) = "RMSMC" ++ 
show n ++ "-" ++ show t 60 | 61 | runAlg :: Model -> Alg -> SamplerIO String 62 | runAlg model (MH n) = show <$> unweighted (mh n (buildModel model)) 63 | runAlg model (SMC n) = show <$> runPopulationT (smc SMCConfig {numSteps = (modelLength model), numParticles = n, resampler = resampleSystematic} (buildModel model)) 64 | runAlg model (RMSMC n t) = 65 | show 66 | <$> runPopulationT 67 | ( rmsmcDynamic 68 | MCMCConfig {numMCMCSteps = t, numBurnIn = 0, proposal = SingleSiteMH} 69 | SMCConfig {numSteps = modelLength model, numParticles = n, resampler = resampleSystematic} 70 | (buildModel model) 71 | ) 72 | 73 | prepareBenchmarkable :: ProbProgSys -> Model -> Alg -> Benchmarkable 74 | prepareBenchmarkable MonadBayes model alg = nfIO $ sampleIOfixed (runAlg model alg) 75 | 76 | prepareBenchmark :: ProbProgSys -> Model -> Alg -> Benchmark 77 | prepareBenchmark MonadBayes model alg = 78 | bench (show MonadBayes ++ sep ++ show model ++ sep ++ show alg) $ 79 | prepareBenchmarkable MonadBayes model alg 80 | where 81 | sep = "_" :: String 82 | 83 | -- | Checks if the requested benchmark is implemented. 84 | supported :: (ProbProgSys, Model, Alg) -> Bool 85 | supported (_, _, RMSMC _ _) = True 86 | supported _ = True 87 | 88 | systems :: [ProbProgSys] 89 | systems = 90 | [ MonadBayes 91 | ] 92 | 93 | lengthBenchmarks :: [(Double, Bool)] -> [Double] -> [[T.Text]] -> [Benchmark] 94 | lengthBenchmarks lrData hmmData ldaData = benchmarks 95 | where 96 | lrLengths = 10 : map (* 100) [1 :: Int .. 10] 97 | hmmLengths = 10 : map (* 100) [1 :: Int .. 10] 98 | ldaLengths = 5 : map (* 50) [1 :: Int .. 10] 99 | models = 100 | map (LR . (`take` lrData)) lrLengths 101 | ++ map (HMM . (`take` hmmData)) hmmLengths 102 | ++ map (\n -> LDA $ map (take n) ldaData) ldaLengths 103 | algs = 104 | [ MH 100, 105 | SMC 100, 106 | RMSMC 10 1 107 | ] 108 | benchmarks = map (uncurry3 (prepareBenchmark)) $ filter supported xs 109 | where 110 | uncurry3 f (x, y, z) = f x y z 111 | xs = do 112 | m <- models 113 | s <- systems 114 | a <- algs 115 | return (s, m, a) 116 | 117 | samplesBenchmarks :: [(Double, Bool)] -> [Double] -> [[T.Text]] -> [Benchmark] 118 | samplesBenchmarks lrData hmmData ldaData = benchmarks 119 | where 120 | lrLengths = [50 :: Int] 121 | hmmLengths = [20 :: Int] 122 | ldaLengths = [10 :: Int] 123 | models = 124 | map (LR . (`take` lrData)) lrLengths 125 | ++ map (HMM . (`take` hmmData)) hmmLengths 126 | ++ map (\n -> LDA $ map (take n) ldaData) ldaLengths 127 | algs = 128 | map (\x -> MH (100 * x)) [1 .. 10] 129 | ++ map (\x -> SMC (100 * x)) [1 .. 10] 130 | ++ map (\x -> RMSMC 10 (10 * x)) [1 .. 
10] 131 | benchmarks = map (uncurry3 (prepareBenchmark)) $ filter supported xs 132 | where 133 | uncurry3 f (x, y, z) = f x y z 134 | xs = do 135 | a <- algs 136 | s <- systems 137 | m <- models 138 | return (s, m, a) 139 | 140 | speedLengthCSV :: FilePath 141 | speedLengthCSV = "speed-length.csv" 142 | 143 | speedSamplesCSV :: FilePath 144 | speedSamplesCSV = "speed-samples.csv" 145 | 146 | rawDAT :: FilePath 147 | rawDAT = "raw.dat" 148 | 149 | cleanupLastRun :: IO () 150 | cleanupLastRun = mapM_ removeIfExists [speedLengthCSV, speedSamplesCSV, rawDAT] 151 | 152 | removeIfExists :: FilePath -> IO () 153 | removeIfExists file = do 154 | putStrLn $ "Removing: " ++ file 155 | catchIOError (removeFile file) $ \e -> 156 | if isDoesNotExistError e 157 | then putStrLn "Didn't find file, not removing" 158 | else ioError e 159 | 160 | main :: IO () 161 | main = do 162 | cleanupLastRun 163 | lrData <- sampleIOfixed (LogReg.syntheticData 1000) 164 | hmmData <- sampleIOfixed (HMM.syntheticData 1000) 165 | ldaData <- sampleIOfixed (LDA.syntheticData 5 1000) 166 | let configLength = defaultConfig {csvFile = Just speedLengthCSV, rawDataFile = Just rawDAT} 167 | defaultMainWith configLength (lengthBenchmarks lrData hmmData ldaData) 168 | let configSamples = defaultConfig {csvFile = Just speedSamplesCSV, rawDataFile = Just rawDAT} 169 | defaultMainWith configSamples (samplesBenchmarks lrData hmmData ldaData) 170 | void $ runProcess "python plots.py" 171 | -------------------------------------------------------------------------------- /default.nix: -------------------------------------------------------------------------------- 1 | ( 2 | import 3 | ( 4 | let 5 | lock = builtins.fromJSON (builtins.readFile ./flake.lock); 6 | in 7 | fetchTarball { 8 | url = "https://github.com/edolstra/flake-compat/archive/${lock.nodes.flake-compat.locked.rev}.tar.gz"; 9 | sha256 = lock.nodes.flake-compat.locked.narHash; 10 | } 11 | ) 12 | {src = ./.;} 13 | ) 14 | .defaultNix 15 | -------------------------------------------------------------------------------- /docs/docs/examples.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Example Gallery 3 | --- 4 | 5 | ## [Histograms](/notebooks/Histogram.html) 6 | 7 | ## [JSON (with `lens`)](/notebooks/Lenses.html) 8 | 9 | ## [Diagrams](/notebooks/Diagrams.html) 10 | 11 | ## [Probabilistic Parsing](/notebooks/Parsing.html) 12 | 13 | ## [Streams (with `pipes`)](/notebooks/Streaming.html) 14 | 15 | ## [Ising models](/notebooks/Ising.html) 16 | 17 | ## [Physics](/notebooks/ClassicalPhysics.html) 18 | 19 | -------------------------------------------------------------------------------- /docs/docs/images/code_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tweag/monad-bayes/c12f0e191ff6c8836273fd26b46f21b262480cf9/docs/docs/images/code_example.png -------------------------------------------------------------------------------- /docs/docs/images/haskell-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tweag/monad-bayes/c12f0e191ff6c8836273fd26b46f21b262480cf9/docs/docs/images/haskell-logo.png -------------------------------------------------------------------------------- /docs/docs/images/plot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tweag/monad-bayes/c12f0e191ff6c8836273fd26b46f21b262480cf9/docs/docs/images/plot.png -------------------------------------------------------------------------------- /docs/docs/images/priorpred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tweag/monad-bayes/c12f0e191ff6c8836273fd26b46f21b262480cf9/docs/docs/images/priorpred.png -------------------------------------------------------------------------------- /docs/docs/images/randomwalk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tweag/monad-bayes/c12f0e191ff6c8836273fd26b46f21b262480cf9/docs/docs/images/randomwalk.png -------------------------------------------------------------------------------- /docs/docs/images/regress.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tweag/monad-bayes/c12f0e191ff6c8836273fd26b46f21b262480cf9/docs/docs/images/regress.png -------------------------------------------------------------------------------- /docs/docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to Monad-Bayes 2 | 3 | Monad-Bayes is a library for **probabilistic programming** written in **Haskell**. 4 | 5 | **Define distributions** [as programs](/notebooks/Introduction.html) 6 | 7 | **Perform inference** [with a variety of standard methods](tutorials.md) [defined compositionally](http://approximateinference.org/accepted/ScibiorGhahramani2016.pdf) 8 | 9 | **Integrate with Haskell code** [like this](examples.md) because Monad-Bayes is just a library, not a separate language 10 | 11 | ## Example 12 | 13 | ```haskell 14 | model :: Distribution Double 15 | model = do 16 | x <- bernoulli 0.5 17 | normal (if x then (-3) else 3) 1 18 | 19 | image :: Distribution Plot 20 | image = fmap (plot . histogram 200) (replicateM 100000 model) 21 | 22 | sampler image 23 | ``` 24 | 25 | The program `model` is a mixture of Gaussians. Its type `Distribution Double` represents a distribution over reals. 26 | `image` is a program too: as its type shows, it is a distribution over plots. In particular, plots that arise from forming a 200 bin histogram out of 100000 independent identically distributed (iid) draws from `model`. 
27 | To sample from `image`, we simply write `sampler image`, with the result shown below: 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/docs/javascripts/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) -------------------------------------------------------------------------------- /docs/docs/tutorials.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Tutorials 3 | --- 4 | 5 | ## [Introduction to Monad-Bayes](/notebooks/Introduction.html) 6 | 7 | ## [Sampling from a distribution](/notebooks/Sampling.html) 8 | 9 | ## [Bayesian models](/notebooks/Bayesian.html) 10 | 11 | ## [Markov Chain Monte Carlo](/notebooks/MCMC.html) 12 | 13 | ## [Sequential Monte Carlo](/notebooks/SMC.html) 14 | 15 | ## [Lazy Sampling](/notebooks/Lazy.html) 16 | 17 | ## [Advanced Inference Methods](/notebooks/AdvancedSampling.html) 18 | 19 | -------------------------------------------------------------------------------- /docs/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Probabilistic Programming in Haskell 2 | theme: 3 | name: material 4 | features: 5 | - content.code.annotate 6 | - content.tooltips 7 | - navigation.sections 8 | - navigation.top 9 | - navigation.tracking 10 | - search.highlight 11 | - search.share 12 | - search.suggest 13 | - toc.follow 14 | palette: 15 | - scheme: default 16 | primary: indigo 17 | accent: indigo 18 | toggle: 19 | icon: material/brightness-7 20 | name: Switch to dark mode 21 | - scheme: slate 22 | primary: indigo 23 | accent: indigo 24 | toggle: 25 | icon: material/brightness-4 26 | name: Switch to light mode 27 | 28 | extra: 29 | social: 30 | - icon: fontawesome/brands/github 31 | link: https://github.com/tweag/monad-bayes 32 | 33 | markdown_extensions: 34 | - pymdownx.arithmatex: 35 | generic: true 36 | - admonition 37 | - attr_list 38 | - pymdownx.details 39 | - pymdownx.emoji: 40 | emoji_index: !!python/name:materialx.emoji.twemoji 41 | emoji_generator: !!python/name:materialx.emoji.to_svg 42 | - pymdownx.highlight: 43 | anchor_linenums: true 44 | - pymdownx.inlinehilite 45 | - pymdownx.snippets 46 | - pymdownx.superfences 47 | - pymdownx.tabbed: 48 | alternate_style: true 49 | - toc: 50 | permalink: true 51 | extra_javascript: 52 | - javascripts/mathjax.js 53 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 54 | 55 | 56 | nav: 57 | - 'index.md' 58 | - 'probprog.md' 59 | - 'tutorials.md' 60 | - 'examples.md' 61 | - 'usage.md' 62 | -------------------------------------------------------------------------------- /docs/netlify.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | command = "mkdocs build" 3 | publish = "site" -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material 3 | pymdown-extensions -------------------------------------------------------------------------------- /docs/runtime.txt: 
-------------------------------------------------------------------------------- 1 | 3.8 -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "A library for probabilistic programming in Haskell."; 3 | nixConfig = { 4 | extra-substituters = [ 5 | "https://tweag-monad-bayes.cachix.org" 6 | "https://tweag-jupyter.cachix.org" 7 | ]; 8 | extra-trusted-public-keys = [ 9 | "tweag-monad-bayes.cachix.org-1:tmmTZ+WvtUMpYWD4LAkfSuNKqSuJyL3N8ZVm/qYtqdc=" 10 | "tweag-jupyter.cachix.org-1:UtNH4Zs6hVUFpFBTLaA4ejYavPo5EFFqgd7G7FxGW9g=" 11 | ]; 12 | }; 13 | inputs = { 14 | nixpkgs.url = "nixpkgs/nixos-unstable"; 15 | flake-compat = { 16 | url = "github:edolstra/flake-compat"; 17 | flake = false; 18 | }; 19 | flake-utils.url = "github:numtide/flake-utils"; 20 | pre-commit-hooks = { 21 | url = "github:cachix/pre-commit-hooks.nix"; 22 | inputs = { 23 | nixpkgs.follows = "nixpkgs"; 24 | flake-utils.follows = "flake-utils"; 25 | }; 26 | }; 27 | jupyenv = { 28 | url = "github:tweag/jupyenv"; 29 | inputs = { 30 | flake-compat.follows = "flake-compat"; 31 | flake-utils.follows = "flake-utils"; 32 | }; 33 | }; 34 | }; 35 | outputs = { 36 | self, 37 | nixpkgs, 38 | jupyenv, 39 | flake-compat, 40 | flake-utils, 41 | pre-commit-hooks, 42 | } @ inputs: 43 | flake-utils.lib.eachSystem 44 | [ 45 | # Tier 1 - Tested in CI 46 | flake-utils.lib.system.x86_64-linux 47 | flake-utils.lib.system.x86_64-darwin 48 | # Tier 2 - Not tested in CI (at least for now) 49 | flake-utils.lib.system.aarch64-linux 50 | flake-utils.lib.system.aarch64-darwin 51 | ] 52 | ( 53 | system: let 54 | inherit (nixpkgs) lib; 55 | inherit (jupyenv.lib.${system}) mkJupyterlabNew; 56 | pkgs = import nixpkgs { 57 | inherit system; 58 | config.allowBroken = true; 59 | }; 60 | 61 | warnToUpdateNix = pkgs.lib.warn "Consider updating to Nix > 2.7 to remove this warning!"; 62 | src = lib.sourceByRegex self [ 63 | "^benchmark.*$" 64 | "^models.*$" 65 | "^monad-bayes\.cabal$" 66 | "^src.*$" 67 | "^test.*$" 68 | "^.*\.md" 69 | ]; 70 | 71 | monad-bayes-per-ghc = let 72 | opts = { 73 | name = "monad-bayes"; 74 | root = src; 75 | cabal2nixOptions = "--benchmark -fdev"; 76 | 77 | # https://github.com/tweag/monad-bayes/pull/256: Don't run tests on Mac because of machine precision issues 78 | modifier = drv: 79 | if system == "x86_64-linux" 80 | then drv 81 | else pkgs.haskell.lib.dontCheck drv; 82 | overrides = self: super: 83 | with pkgs.haskell.lib; 84 | { 85 | # Please check after flake.lock updates whether some of these overrides can be removed 86 | brick = super.brick_2_4; 87 | } 88 | // lib.optionalAttrs (lib.versionAtLeast super.ghc.version "9.10") { 89 | # Please check after flake.lock updates whether some of these overrides can be removed 90 | microstache = doJailbreak super.microstache; 91 | }; 92 | }; 93 | ghcs = [ 94 | # Always keep this up to date with the tested-with section in monad-bayes.cabal, 95 | # and the build-all-ghcs job in .github/workflows/nix.yml! 
96 | "ghc902" 97 | "ghc927" 98 | "ghc945" 99 | "ghc964" 100 | "ghc982" 101 | "ghc9101" 102 | ]; 103 | buildForVersion = ghcVersion: (builtins.getAttr ghcVersion pkgs.haskell.packages).developPackage opts; 104 | in 105 | lib.attrsets.genAttrs ghcs buildForVersion; 106 | 107 | monad-bayes = monad-bayes-per-ghc.ghc902; 108 | 109 | monad-bayes-all-ghcs = pkgs.linkFarm "monad-bayes-all-ghcs" monad-bayes-per-ghc; 110 | 111 | jupyterEnvironment = mkJupyterlabNew { 112 | imports = [ 113 | (import ./kernels/haskell.nix {inherit monad-bayes;}) 114 | ]; 115 | }; 116 | 117 | monad-bayes-dev = pkgs.mkShell { 118 | inputsFrom = [monad-bayes.env]; 119 | packages = with pre-commit-hooks.packages.${system}; [ 120 | alejandra 121 | cabal-fmt 122 | hlint 123 | ormolu 124 | jupyterEnvironment 125 | ]; 126 | shellHook = pre-commit.shellHook; 127 | }; 128 | pre-commit = pre-commit-hooks.lib.${system}.run { 129 | inherit src; 130 | hooks = { 131 | alejandra.enable = true; 132 | cabal-fmt.enable = true; 133 | hlint.enable = false; 134 | ormolu.enable = true; 135 | }; 136 | }; 137 | in rec { 138 | packages = { 139 | inherit monad-bayes monad-bayes-per-ghc monad-bayes-all-ghcs pre-commit jupyterEnvironment; 140 | }; 141 | packages.default = packages.monad-bayes; 142 | checks = {inherit monad-bayes pre-commit;}; 143 | devShells.default = monad-bayes-dev; 144 | # Needed for backwards compatibility with Nix versions <2.8 145 | defaultPackage = warnToUpdateNix packages.default; 146 | devShell = warnToUpdateNix devShells.default; 147 | } 148 | ); 149 | } 150 | -------------------------------------------------------------------------------- /kernels/haskell.nix: -------------------------------------------------------------------------------- 1 | {monad-bayes}: 2 | 3 | {pkgs, ...}: { 4 | kernel.haskell.monad-bayes = { 5 | enable = true; 6 | name = "monad-bayes"; 7 | displayName = "monad-bayes"; 8 | extraHaskellPackages = p: [ 9 | p.hvega 10 | p.lens 11 | p.log-domain 12 | p.katip 13 | p.ihaskell-hvega 14 | p.ihaskell-diagrams 15 | p.diagrams 16 | p.diagrams-cairo 17 | p.aeson 18 | p.lens 19 | p.lens-aeson 20 | p.pretty-simple 21 | p.monad-loops 22 | # hamilton seems unmaintained, unclear whether we can keep it. See https://github.com/tweag/monad-bayes/issues/300 23 | # p.hamilton 24 | p.hmatrix 25 | p.vector-sized 26 | p.linear 27 | p.recursion-schemes 28 | p.data-fix 29 | p.free 30 | p.comonad 31 | p.adjunctions 32 | p.distributive 33 | p.vector 34 | p.megaparsec 35 | p.histogram-fill 36 | # Strange build error which I can't fix: 37 | # > Configuring gloss-rendering-1.13.1.2... 38 | # > 39 | # > Setup: Encountered missing or private dependencies: 40 | # > bytestring >=0.11 && <0.12 41 | # p.gloss 42 | monad-bayes 43 | ]; 44 | }; 45 | } 46 | -------------------------------------------------------------------------------- /models/BetaBin.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | {-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} 3 | {-# OPTIONS_GHC -Wno-missing-export-lists #-} 4 | 5 | module BetaBin where 6 | 7 | -- The beta-binomial model in latent variable and urn model representations. 8 | -- The two formulations should be exactly equivalent, but only urn works with Dist. 
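-- ("Dist" above presumably refers to exact enumeration: 'latent' draws its
-- weight from the continuous 'uniform 0 1' and so cannot be enumerated,
-- whereas 'urn' only makes discrete 'bernoulli' choices whose weights are
-- updated deterministically, so it can.)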
9 | import Control.Monad (replicateM) 10 | import Control.Monad.Bayes.Class 11 | ( MonadDistribution (bernoulli, uniform), 12 | MonadMeasure, 13 | condition, 14 | ) 15 | import Control.Monad.State.Lazy (evalStateT, get, put) 16 | import Pipes ((<-<)) 17 | import Pipes.Prelude qualified as P hiding (show) 18 | 19 | -- | Beta-binomial model as an i.i.d. sequence conditionally on weight. 20 | latent :: (MonadDistribution m) => Int -> m [Bool] 21 | latent n = do 22 | weight <- uniform 0 1 23 | replicateM n (bernoulli weight) 24 | 25 | -- | Beta-binomial as a random process. 26 | -- Equivalent to the above by De Finetti's theorem. 27 | urn :: (MonadDistribution m) => Int -> m [Bool] 28 | urn n = flip evalStateT (1, 1) $ do 29 | replicateM n do 30 | (a, b) <- get 31 | let weight = a / (a + b) 32 | outcome <- bernoulli weight 33 | let (a', b') = if outcome then (a + 1, b) else (a, b + 1) 34 | put (a', b') 35 | return outcome 36 | 37 | -- | Beta-binomial as a random process. 38 | -- This time using the Pipes library, for a more pure functional style 39 | urnP :: (MonadDistribution m) => Int -> m [Bool] 40 | urnP n = P.toListM $ P.take n <-< P.unfoldr toss (1, 1) 41 | where 42 | toss (a, b) = do 43 | let weight = a / (a + b) 44 | outcome <- bernoulli weight 45 | let (a', b') = if outcome then (a + 1, b) else (a, b + 1) 46 | return $ Right (outcome, (a', b')) 47 | 48 | -- | A beta-binomial model where the first three states are True,True,False. 49 | -- The resulting distribution is on the remaining outcomes. 50 | cond :: (MonadMeasure m) => m [Bool] -> m [Bool] 51 | cond d = do 52 | ~(first : second : third : rest) <- d 53 | condition first 54 | condition second 55 | condition (not third) 56 | return rest 57 | 58 | -- | The final conditional model, abstracting the representation. 59 | model :: (MonadMeasure m) => (Int -> m [Bool]) -> Int -> m Int 60 | model repr n = fmap count $ cond $ repr (n + 3) 61 | where 62 | -- Post-processing by counting the number of True values. 63 | count :: [Bool] -> Int 64 | count = length . filter id 65 | -------------------------------------------------------------------------------- /models/ConjugatePriors.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | 5 | module ConjugatePriors where 6 | 7 | import Control.Applicative (Applicative (..)) 8 | import Control.Foldl (fold) 9 | import Control.Foldl qualified as F 10 | import Control.Monad.Bayes.Class (Bayesian (..), MonadDistribution (bernoulli, beta, gamma, normal), MonadMeasure, normalPdf) 11 | import Numeric.Log (Log (Exp)) 12 | -- Prelude exports liftA2 from GHC 9.6 on, see https://github.com/haskell/core-libraries-committee/blob/main/guides/export-lifta2-prelude.md 13 | -- import Control.Applicative further up can be removed once we don't support GHC <= 9.4 anymore 14 | 15 | import Prelude hiding (Applicative (..)) 16 | 17 | type GammaParams = (Double, Double) 18 | 19 | type BetaParams = (Double, Double) 20 | 21 | type NormalParams = (Double, Double) 22 | 23 | -- | Posterior on the precision of the normal after the points are observed 24 | gammaNormalAnalytic :: 25 | (MonadMeasure m, Foldable t, Functor t) => 26 | GammaParams -> 27 | t Double -> 28 | m Double 29 | 30 | -- | Exact posterior for the model. 31 | -- For derivation see Kevin Murphy's 32 | -- "Conjugate Bayesian analysis of the Gaussian distribution" 33 | -- section 4. 
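-- Concretely, with a zero-mean normal likelihood of unknown precision tau and
-- prior tau ~ Gamma(shape a, rate b), observing x_1 .. x_n gives
--   tau | x_1 .. x_n ~ Gamma(a + n/2, rate = b + (sum of x_i^2) / 2),
-- which is what a' and b' compute below; 'recip b'' turns the posterior rate
-- into the scale parameter that 'gamma' takes.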
34 | gammaNormalAnalytic (a, b) points = gamma a' (recip b') 35 | where 36 | a' = a + fromIntegral (length points) / 2 37 | b' = b + sum (fmap (** 2) points) / 2 38 | 39 | -- | Posterior on beta after the bernoulli sample 40 | betaBernoulliAnalytic :: (MonadMeasure m, Foldable t) => BetaParams -> t Bool -> m Double 41 | betaBernoulliAnalytic (a, b) points = beta a' b' 42 | where 43 | (n, s) = fold (liftA2 (,) F.length (F.premap (\case True -> 1; False -> 0) F.sum)) points 44 | a' = a + s 45 | b' = b + fromIntegral n - s 46 | 47 | bernoulliPdf :: (Floating a) => a -> Bool -> Log a 48 | bernoulliPdf p x = let numBool = if x then 1.0 else 0 in Exp $ log (p ** numBool * (1 - p) ** (1 - numBool)) 49 | 50 | betaBernoulli' :: (MonadMeasure m) => (Double, Double) -> Bayesian m Double Bool 51 | betaBernoulli' (a, b) = Bayesian (beta a b) bernoulli bernoulliPdf 52 | 53 | normalNormal' :: (MonadMeasure m) => Double -> (Double, Double) -> Bayesian m Double Double 54 | normalNormal' var (mu0, var0) = Bayesian (normal mu0 (sqrt var0)) (`normal` (sqrt var)) (`normalPdf` (sqrt var)) 55 | 56 | gammaNormal' :: (MonadMeasure m) => (Double, Double) -> Bayesian m Double Double 57 | gammaNormal' (a, b) = Bayesian (gamma a (recip b)) (normal 0 . sqrt . recip) (normalPdf 0 . sqrt . recip) 58 | 59 | normalNormalAnalytic :: 60 | (MonadMeasure m, Foldable t) => 61 | Double -> 62 | NormalParams -> 63 | t Double -> 64 | m Double 65 | normalNormalAnalytic sigma_2 (mu0, sigma0_2) points = normal mu' (sqrt sigma_2') 66 | where 67 | (n, s) = fold (liftA2 (,) F.length F.sum) points 68 | mu' = sigma_2' * (mu0 / sigma0_2 + s / sigma_2) 69 | sigma_2' = recip (recip sigma0_2 + fromIntegral n / sigma_2) 70 | -------------------------------------------------------------------------------- /models/Dice.hs: -------------------------------------------------------------------------------- 1 | module Dice (diceHard, diceSoft) where 2 | 3 | -- A toy model for dice rolling from http://dl.acm.org/citation.cfm?id=2804317 4 | -- Exact results can be obtained using Dist monad 5 | 6 | import Control.Applicative (Applicative (..)) 7 | import Control.Monad.Bayes.Class 8 | ( MonadDistribution (uniformD), 9 | MonadFactor (score), 10 | MonadMeasure, 11 | condition, 12 | ) 13 | -- Prelude exports liftA2 from GHC 9.6 on, see https://github.com/haskell/core-libraries-committee/blob/main/guides/export-lifta2-prelude.md 14 | -- import Control.Applicative further up can be removed once we don't support GHC <= 9.4 anymore 15 | 16 | import Prelude hiding (Applicative (..)) 17 | 18 | -- | A toss of a six-sided die. 19 | die :: (MonadDistribution m) => m Int 20 | die = uniformD [1 .. 6] 21 | 22 | -- | A sum of outcomes of n independent tosses of six-sided dice. 23 | dice :: (MonadDistribution m) => Int -> m Int 24 | dice 1 = die 25 | dice n = liftA2 (+) die (dice (n - 1)) 26 | 27 | -- | Toss of two dice where the output is greater than 4. 28 | diceHard :: (MonadMeasure m) => m Int 29 | diceHard = do 30 | result <- dice 2 31 | condition (result > 4) 32 | return result 33 | 34 | -- | Toss of two dice with an artificial soft constraint. 
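-- Unlike the hard 'condition' in 'diceHard', the 'score (1 / fromIntegral result)'
-- below merely reweights each outcome by the reciprocal of its sum, so smaller
-- totals become proportionally more likely rather than other outcomes being
-- ruled out entirely.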
35 | diceSoft :: (MonadMeasure m) => m Int 36 | diceSoft = do 37 | result <- dice 2 38 | score (1 / fromIntegral result) 39 | return result 40 | -------------------------------------------------------------------------------- /models/HMM.hs: -------------------------------------------------------------------------------- 1 | -- HMM from Anglican (https://bitbucket.org/probprog/anglican-white-paper) 2 | 3 | module HMM where 4 | 5 | import Control.Monad (replicateM, when) 6 | import Control.Monad.Bayes.Class 7 | ( MonadDistribution (categorical, normal, uniformD), 8 | MonadFactor, 9 | MonadMeasure, 10 | factor, 11 | normalPdf, 12 | ) 13 | import Control.Monad.Bayes.Enumerator (enumerateToDistribution) 14 | import Data.Maybe (fromJust, isJust) 15 | import Data.Vector (fromList) 16 | import Pipes (MFunctor (hoist), MonadTrans (lift), each, yield, (>->)) 17 | import Pipes.Core (Producer) 18 | import Pipes.Prelude qualified as Pipes 19 | 20 | -- | Observed values 21 | values :: [Double] 22 | values = 23 | [ 0.9, 24 | 0.8, 25 | 0.7, 26 | 0, 27 | -0.025, 28 | -5, 29 | -2, 30 | -0.1, 31 | 0, 32 | 0.13, 33 | 0.45, 34 | 6, 35 | 0.2, 36 | 0.3, 37 | -1, 38 | -1 39 | ] 40 | 41 | -- | The transition model. 42 | trans :: (MonadDistribution m) => Int -> m Int 43 | trans 0 = categorical $ fromList [0.1, 0.4, 0.5] 44 | trans 1 = categorical $ fromList [0.2, 0.6, 0.2] 45 | trans 2 = categorical $ fromList [0.15, 0.7, 0.15] 46 | trans _ = error "unreachable" 47 | 48 | -- | The emission model. 49 | emissionMean :: Int -> Double 50 | emissionMean 0 = -1 51 | emissionMean 1 = 1 52 | emissionMean 2 = 0 53 | emissionMean _ = error "unreachable" 54 | 55 | -- | Initial state distribution 56 | start :: (MonadDistribution m) => m Int 57 | start = uniformD [0, 1, 2] 58 | 59 | -- | Example HMM from http://dl.acm.org/citation.cfm?id=2804317 60 | hmm :: (MonadMeasure m) => [Double] -> m [Int] 61 | hmm dataset = f dataset (const . return) 62 | where 63 | expand x y = do 64 | x' <- trans x 65 | factor $ normalPdf (emissionMean x') 1 y 66 | return x' 67 | f [] k = start >>= k [] 68 | f (y : ys) k = f ys (\xs x -> expand x y >>= k (x : xs)) 69 | 70 | syntheticData :: (MonadDistribution m) => Int -> m [Double] 71 | syntheticData n = replicateM n syntheticPoint 72 | where 73 | syntheticPoint = uniformD [0, 1, 2] 74 | 75 | -- | Equivalent model, but using pipes for simplicity 76 | 77 | -- | Prior expressed as a stream 78 | hmmPrior :: (MonadDistribution m) => Producer Int m b 79 | hmmPrior = do 80 | x <- lift start 81 | yield x 82 | Pipes.unfoldr (fmap (Right . (\k -> (k, k))) . 
trans) x 83 | 84 | -- | Observations expressed as a stream 85 | hmmObservations :: (Functor m) => [a] -> Producer (Maybe a) m () 86 | hmmObservations dataset = each (Nothing : (Just <$> reverse dataset)) 87 | 88 | -- | Posterior expressed as a stream 89 | hmmPosterior :: (MonadMeasure m) => [Double] -> Producer Int m () 90 | hmmPosterior dataset = 91 | zipWithM 92 | hmmLikelihood 93 | hmmPrior 94 | (hmmObservations dataset) 95 | where 96 | hmmLikelihood :: (MonadFactor f) => (Int, Maybe Double) -> f () 97 | hmmLikelihood (l, o) = when (isJust o) (factor $ normalPdf (emissionMean l) 1 (fromJust o)) 98 | 99 | zipWithM f p1 p2 = Pipes.zip p1 p2 >-> Pipes.chain f >-> Pipes.map fst 100 | 101 | hmmPosteriorPredictive :: (MonadDistribution m) => [Double] -> Producer Double m () 102 | hmmPosteriorPredictive dataset = 103 | Pipes.hoist enumerateToDistribution (hmmPosterior dataset) 104 | >-> Pipes.mapM (\x -> normal (emissionMean x) 1) 105 | -------------------------------------------------------------------------------- /models/Helper.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE ImportQualifiedPost #-} 3 | 4 | module Helper where 5 | 6 | import Control.Monad.Bayes.Class (MonadMeasure) 7 | import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..), Proposal (SingleSiteMH)) 8 | import Control.Monad.Bayes.Inference.RMSMC (rmsmcBasic) 9 | import Control.Monad.Bayes.Inference.SMC 10 | ( SMCConfig (SMCConfig, numParticles, numSteps, resampler), 11 | smc, 12 | ) 13 | import Control.Monad.Bayes.Population 14 | import Control.Monad.Bayes.Sampler.Strict 15 | import Control.Monad.Bayes.Traced hiding (model) 16 | import Control.Monad.Bayes.Weighted 17 | import Control.Monad.ST (runST) 18 | import HMM qualified 19 | import LDA qualified 20 | import LogReg qualified 21 | 22 | data Model = LR Int | HMM Int | LDA (Int, Int) 23 | deriving stock (Show, Read) 24 | 25 | parseModel :: String -> Maybe Model 26 | parseModel s = 27 | case s of 28 | 'L' : 'R' : n -> Just $ LR (read n) 29 | 'H' : 'M' : 'M' : n -> Just $ HMM (read n) 30 | 'L' : 'D' : 'A' : n -> Just $ LDA (5, read n) 31 | _ -> Nothing 32 | 33 | serializeModel :: Model -> Maybe String 34 | serializeModel (LR n) = Just $ "LR" ++ show n 35 | serializeModel (HMM n) = Just $ "HMM" ++ show n 36 | serializeModel (LDA (5, n)) = Just $ "LDA" ++ show n 37 | serializeModel (LDA _) = Nothing 38 | 39 | data Alg = SMC | MH | RMSMC 40 | deriving stock (Read, Show, Eq, Ord, Enum, Bounded) 41 | 42 | getModel :: (MonadMeasure m) => Model -> (Int, m String) 43 | getModel model = (size model, program model) 44 | where 45 | size (LR n) = n 46 | size (HMM n) = n 47 | size (LDA (d, w)) = d * w 48 | program (LR n) = show <$> (LogReg.logisticRegression (runST $ sampleSTfixed (LogReg.syntheticData n))) 49 | program (HMM n) = show <$> (HMM.hmm (runST $ sampleSTfixed (HMM.syntheticData n))) 50 | program (LDA (d, w)) = show <$> (LDA.lda (runST $ sampleSTfixed (LDA.syntheticData d w))) 51 | 52 | runAlg :: Model -> Alg -> SamplerIO String 53 | runAlg model alg = 54 | case alg of 55 | SMC -> 56 | let n = 100 57 | (k, m) = getModel model 58 | in show <$> runPopulationT (smc SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic} m) 59 | MH -> 60 | let t = 100 61 | (_, m) = getModel model 62 | in show <$> unweighted (mh t m) 63 | RMSMC -> 64 | let n = 10 65 | t = 1 66 | (k, m) = getModel model 67 | in show <$> runPopulationT (rmsmcBasic MCMCConfig {numMCMCSteps = t, numBurnIn 
= 0, proposal = SingleSiteMH} (SMCConfig {numSteps = k, numParticles = n, resampler = resampleSystematic}) m) 68 | 69 | runAlgFixed :: Model -> Alg -> IO String 70 | runAlgFixed model alg = sampleIOfixed $ runAlg model alg 71 | -------------------------------------------------------------------------------- /models/LDA.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | 3 | -- LDA model from Anglican 4 | -- (https://bitbucket.org/probprog/anglican-white-paper) 5 | 6 | -- This model is just a toy/reference implementation. 7 | -- A more serious one would not store documents as lists of words. 8 | -- The point is just to showcase the model 9 | 10 | module LDA where 11 | 12 | import Control.Monad qualified as List (replicateM) 13 | import Control.Monad.Bayes.Class 14 | ( MonadDistribution (categorical, dirichlet, uniformD), 15 | MonadMeasure, 16 | factor, 17 | ) 18 | import Control.Monad.Bayes.Sampler.Strict (sampleIOfixed) 19 | import Control.Monad.Bayes.Traced (mh) 20 | import Control.Monad.Bayes.Weighted (unweighted) 21 | import Data.Map qualified as Map 22 | import Data.Text (Text, words) 23 | import Data.Vector as V (Vector, replicate, (!)) 24 | import Data.Vector qualified as V hiding (length, mapM, mapM_) 25 | import Numeric.Log (Log (Exp)) 26 | import Text.Pretty.Simple (pPrint) 27 | import Prelude hiding (words) 28 | 29 | vocabulary :: [Text] 30 | vocabulary = ["bear", "wolf", "python", "prolog"] 31 | 32 | topics :: [Text] 33 | topics = ["topic1", "topic2"] 34 | 35 | type Documents = [[Text]] 36 | 37 | documents :: Documents 38 | documents = 39 | [ words "bear wolf bear wolf bear wolf python wolf bear wolf", 40 | words "python prolog python prolog python prolog python prolog python prolog", 41 | words "bear wolf bear wolf bear wolf bear wolf bear wolf", 42 | words "python prolog python prolog python prolog python prolog python prolog", 43 | words "bear wolf bear python bear wolf bear wolf bear wolf" 44 | ] 45 | 46 | wordDistPrior :: (MonadDistribution m) => m (V.Vector Double) 47 | wordDistPrior = dirichlet $ V.replicate (length vocabulary) 1 48 | 49 | topicDistPrior :: (MonadDistribution m) => m (V.Vector Double) 50 | topicDistPrior = dirichlet $ V.replicate (length topics) 1 51 | 52 | wordIndex :: Map.Map Text Int 53 | wordIndex = Map.fromList $ zip vocabulary [0 ..] 54 | 55 | lda :: 56 | (MonadMeasure m) => 57 | Documents -> 58 | m (Map.Map Text (V.Vector (Text, Double)), [(Text, V.Vector (Text, Double))]) 59 | lda docs = do 60 | word_dist_for_topic <- do 61 | ts <- List.replicateM (length topics) wordDistPrior 62 | return $ Map.fromList $ zip topics ts 63 | let obs doc = do 64 | topic_dist <- topicDistPrior 65 | let f word = do 66 | topic <- (fmap (topics !!) . categorical) topic_dist 67 | factor $ (Exp . log) $ (word_dist_for_topic Map.! topic) V.! (wordIndex Map.! 
word) 68 | mapM_ f doc 69 | return topic_dist 70 | td <- mapM obs docs 71 | return 72 | ( fmap (V.zip (V.fromList vocabulary)) word_dist_for_topic, 73 | zip (fmap (foldr1 (\x y -> x <> " " <> y)) docs) (fmap (V.zip $ V.fromList ["topic1", "topic2"]) td) 74 | ) 75 | 76 | syntheticData :: (MonadDistribution m) => Int -> Int -> m [[Text]] 77 | syntheticData d w = List.replicateM d (List.replicateM w syntheticWord) 78 | where 79 | syntheticWord = uniformD vocabulary 80 | 81 | runLDA :: IO () 82 | runLDA = do 83 | s <- sampleIOfixed $ unweighted $ mh 1000 $ lda documents 84 | pPrint $ take 1 s 85 | -------------------------------------------------------------------------------- /models/LogReg.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BlockArguments #-} 2 | 3 | -- Logistic regression model from Anglican 4 | -- (https://bitbucket.org/probprog/anglican-white-paper) 5 | 6 | module LogReg (logisticRegression, syntheticData, xs, labels) where 7 | 8 | import Control.Monad (replicateM) 9 | import Control.Monad.Bayes.Class 10 | ( MonadDistribution (bernoulli, gamma, normal, uniform), 11 | MonadMeasure, 12 | factor, 13 | ) 14 | import Numeric.Log (Log (Exp)) 15 | 16 | logisticRegression :: (MonadMeasure m) => [(Double, Bool)] -> m Double 17 | logisticRegression dat = do 18 | m <- normal 0 1 19 | b <- normal 0 1 20 | sigma <- gamma 1 1 21 | let y x = normal (m * x + b) sigma 22 | sigmoid x = y x >>= \t -> return $ 1 / (1 + exp (-t)) 23 | obs x label = do 24 | p <- sigmoid x 25 | factor $ (Exp . log) $ if label then p else 1 - p 26 | mapM_ (uncurry obs) dat 27 | sigmoid 8 28 | 29 | -- make a synthetic dataset by randomly choosing input-label pairs 30 | syntheticData :: (MonadDistribution m) => Int -> m [(Double, Bool)] 31 | syntheticData n = replicateM n do 32 | x <- uniform (-1) 1 33 | label <- bernoulli 0.5 34 | return (x, label) 35 | 36 | -- a tiny test dataset, for sanity-checking 37 | xs :: [Double] 38 | xs = [-10, -5, 2, 6, 10] 39 | 40 | labels :: [Bool] 41 | labels = [False, False, True, True, True] 42 | -------------------------------------------------------------------------------- /models/NestedInference.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BlockArguments #-} 2 | 3 | module NestedInference where 4 | 5 | import Control.Monad.Bayes.Class (MonadDistribution (uniformD), MonadMeasure, factor) 6 | import Control.Monad.Bayes.Enumerator (mass) 7 | import Numeric.Log (Log (Exp)) 8 | 9 | data Utterance = ASquare | AShape deriving (Eq, Show, Ord) 10 | 11 | data State = Square | Circle deriving (Eq, Show, Ord) 12 | 13 | data Action = Speak Utterance | DoNothing deriving (Eq, Show, Ord) 14 | 15 | -- | uniformly likely to say any true utterance to convey the given state 16 | truthfulAgent :: (MonadDistribution m) => State -> m Action 17 | truthfulAgent state = uniformD case state of 18 | Square -> [Speak ASquare, Speak AShape, DoNothing] 19 | Circle -> [Speak AShape, DoNothing] 20 | 21 | -- | a listener which applies Bayes rule to infer the state 22 | -- given an observed action of the other agent 23 | listener :: (MonadMeasure m) => Action -> m State 24 | listener observedAction = do 25 | state <- uniformD [Square, Circle] 26 | factor $ log $ Exp $ mass (truthfulAgent state) observedAction 27 | return state 28 | 29 | -- | an agent which produces an action by reasoning about 30 | -- how the listener would interpret it 31 | informativeAgent :: (MonadMeasure m) => State -> m Action 32 | 
informativeAgent state = do 33 | utterance <- uniformD [Speak ASquare, Speak AShape, DoNothing] 34 | factor $ log $ Exp $ mass (listener utterance) state 35 | return utterance 36 | -------------------------------------------------------------------------------- /models/NonlinearSSM.hs: -------------------------------------------------------------------------------- 1 | module NonlinearSSM where 2 | 3 | import Control.Monad.Bayes.Class 4 | ( MonadDistribution (gamma, normal), 5 | MonadMeasure, 6 | factor, 7 | normalPdf, 8 | ) 9 | 10 | param :: (MonadDistribution m) => m (Double, Double) 11 | param = do 12 | let a = 0.01 13 | let b = 0.01 14 | precX <- gamma a b 15 | let sigmaX = 1 / sqrt precX 16 | precY <- gamma a b 17 | let sigmaY = 1 / sqrt precY 18 | return (sigmaX, sigmaY) 19 | 20 | mean :: Double -> Int -> Double 21 | mean x n = 0.5 * x + 25 * x / (1 + x * x) + 8 * cos (1.2 * fromIntegral n) 22 | 23 | -- | A nonlinear series model from Doucet et al. (2000) 24 | -- "On sequential Monte Carlo sampling methods" section VI.B 25 | model :: 26 | (MonadMeasure m) => 27 | -- | observed data 28 | [Double] -> 29 | -- | prior on the parameters 30 | (Double, Double) -> 31 | -- | list of latent states from t=1 32 | m [Double] 33 | model obs (sigmaX, sigmaY) = do 34 | let sq x = x * x 35 | simulate [] _ acc = return acc 36 | simulate (y : ys) x acc = do 37 | let n = length acc 38 | x' <- normal (mean x n) sigmaX 39 | factor $ normalPdf (sq x' / 20) sigmaY y 40 | simulate ys x' (x' : acc) 41 | x0 <- normal 0 (sqrt 5) 42 | xs <- simulate obs x0 [] 43 | return $ reverse xs 44 | 45 | generateData :: 46 | (MonadDistribution m) => 47 | -- | T 48 | Int -> 49 | -- | list of latent and observable states from t=1 50 | m [(Double, Double)] 51 | generateData t = do 52 | (sigmaX, sigmaY) <- param 53 | let sq x = x * x 54 | simulate 0 _ acc = return acc 55 | simulate k x acc = do 56 | let n = length acc 57 | x' <- normal (mean x n) sigmaX 58 | y' <- normal (sq x' / 20) sigmaY 59 | simulate (k - 1) x' ((x', y') : acc) 60 | x0 <- normal 0 (sqrt 5) 61 | xys <- simulate t x0 [] 62 | return $ reverse xys 63 | -------------------------------------------------------------------------------- /models/NonlinearSSM/Algorithms.hs: -------------------------------------------------------------------------------- 1 | module NonlinearSSM.Algorithms where 2 | 3 | import Control.Monad.Bayes.Class (MonadDistribution) 4 | import Control.Monad.Bayes.Inference.MCMC 5 | import Control.Monad.Bayes.Inference.PMMH as PMMH (pmmh) 6 | import Control.Monad.Bayes.Inference.RMSMC (rmsmc, rmsmcBasic, rmsmcDynamic) 7 | import Control.Monad.Bayes.Inference.SMC 8 | import Control.Monad.Bayes.Inference.SMC2 as SMC2 (smc2) 9 | import Control.Monad.Bayes.Population 10 | import Control.Monad.Bayes.Weighted (unweighted) 11 | import NonlinearSSM 12 | 13 | data Alg = SMC | RMSMC | RMSMCDynamic | RMSMCBasic | PMMH | SMC2 14 | deriving (Show, Read, Eq, Ord, Enum, Bounded) 15 | 16 | algs :: [Alg] 17 | algs = [minBound .. 
maxBound] 18 | 19 | type SSMData = [Double] 20 | 21 | t :: Int 22 | t = 5 23 | 24 | -- FIXME refactor such that it can be reused in ssm benchmark 25 | runAlgFixed :: (MonadDistribution m) => SSMData -> Alg -> m String 26 | runAlgFixed ys SMC = fmap show $ runPopulationT $ smc SMCConfig {numSteps = t, numParticles = 10, resampler = resampleMultinomial} (param >>= model ys) 27 | runAlgFixed ys RMSMC = 28 | fmap show $ 29 | runPopulationT $ 30 | rmsmc 31 | MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH} 32 | SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic} 33 | (param >>= model ys) 34 | runAlgFixed ys RMSMCBasic = 35 | fmap show $ 36 | runPopulationT $ 37 | rmsmcBasic 38 | MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH} 39 | SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic} 40 | (param >>= model ys) 41 | runAlgFixed ys RMSMCDynamic = 42 | fmap show $ 43 | runPopulationT $ 44 | rmsmcDynamic 45 | MCMCConfig {numMCMCSteps = 10, numBurnIn = 0, proposal = SingleSiteMH} 46 | SMCConfig {numSteps = t, numParticles = 10, resampler = resampleSystematic} 47 | (param >>= model ys) 48 | runAlgFixed ys PMMH = 49 | fmap show $ 50 | unweighted $ 51 | pmmh 52 | MCMCConfig {numMCMCSteps = 2, numBurnIn = 0, proposal = SingleSiteMH} 53 | SMCConfig {numSteps = t, numParticles = 3, resampler = resampleSystematic} 54 | param 55 | (model ys) 56 | runAlgFixed ys SMC2 = fmap show $ runPopulationT $ smc2 t 3 2 1 param (model ys) 57 | -------------------------------------------------------------------------------- /models/Sprinkler.hs: -------------------------------------------------------------------------------- 1 | module Sprinkler (hard, soft) where 2 | 3 | import Control.Monad (when) 4 | import Control.Monad.Bayes.Class 5 | 6 | hard :: (MonadMeasure m) => m Bool 7 | hard = do 8 | rain <- bernoulli 0.3 9 | sprinkler <- bernoulli $ if rain then 0.1 else 0.4 10 | wet <- bernoulli $ case (rain, sprinkler) of 11 | (True, True) -> 0.98 12 | (True, False) -> 0.8 13 | (False, True) -> 0.9 14 | (False, False) -> 0.0 15 | condition (not wet) 16 | return rain 17 | 18 | soft :: (MonadMeasure m) => m Bool 19 | soft = do 20 | rain <- bernoulli 0.3 21 | when rain (factor 0.2) 22 | sprinkler <- bernoulli $ if rain then 0.1 else 0.4 23 | when sprinkler (factor 0.1) 24 | return rain 25 | -------------------------------------------------------------------------------- /models/StrictlySmallerSupport.hs: -------------------------------------------------------------------------------- 1 | -- A model in which a random value switches between 2 | -- two distributions, one with a support strictly 3 | -- smaller than the other. 
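-- A minimal usage sketch ('enumerator' comes from Control.Monad.Bayes.Enumerator):
-- the model below has no conditioning, so it can be enumerated exactly,
--
-- >>> enumerator model
-- [(False,0.5),(True,0.5)]   (approximately; both outcomes are equally likely)
--
-- The point of interest for inference is that the support of the second draw
-- shrinks from four values to two depending on the first draw.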
4 | module StrictlySmallerSupport (model) where 5 | 6 | import Control.Monad.Bayes.Class 7 | 8 | model :: (MonadDistribution m) => m Bool 9 | model = do 10 | x <- bernoulli 0.5 11 | _ <- uniformD (if x then [1, 2] else [1, 2, 3, 4] :: [Int]) 12 | return x 13 | -------------------------------------------------------------------------------- /notebooks/_build/_page/Introduction/html/_sphinx_design_static/design-tabs.js: -------------------------------------------------------------------------------- 1 | var sd_labels_by_text = {}; 2 | 3 | function ready() { 4 | const li = document.getElementsByClassName("sd-tab-label"); 5 | for (const label of li) { 6 | syncId = label.getAttribute("data-sync-id"); 7 | if (syncId) { 8 | label.onclick = onLabelClick; 9 | if (!sd_labels_by_text[syncId]) { 10 | sd_labels_by_text[syncId] = []; 11 | } 12 | sd_labels_by_text[syncId].push(label); 13 | } 14 | } 15 | } 16 | 17 | function onLabelClick() { 18 | // Activate other inputs with the same sync id. 19 | syncId = this.getAttribute("data-sync-id"); 20 | for (label of sd_labels_by_text[syncId]) { 21 | if (label === this) continue; 22 | label.previousElementSibling.checked = true; 23 | } 24 | window.localStorage.setItem("sphinx-design-last-tab", syncId); 25 | } 26 | 27 | document.addEventListener("DOMContentLoaded", ready, false); 28 | -------------------------------------------------------------------------------- /notebooks/examples/Parsing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2a798ffe-30c0-43ec-9614-c84c2a433da5", 6 | "metadata": {}, 7 | "source": [ 8 | "# Probabilistic Parser Combinators" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "b14be2d4-ec16-46d3-8cd9-810a25f70495", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import Text.Megaparsec\n", 19 | "import Text.Megaparsec.Char\n", 20 | "import Data.Char\n", 21 | "import qualified Data.Text as T\n", 22 | "import Control.Monad.Bayes.Class\n", 23 | "import Control.Monad.Bayes.Sampler.Strict\n", 24 | "import Control.Monad.Bayes.Weighted\n", 25 | "import Control.Monad.Bayes.Population\n", 26 | "import Control.Monad.Bayes.Enumerator\n", 27 | "import Control.Monad.Bayes.Inference.SMC\n", 28 | "import Control.Monad.Trans (lift)\n", 29 | "import Control.Monad (join, replicateM)\n", 30 | "import Data.Void\n", 31 | "import Control.Monad.Bayes.Enumerator\n", 32 | "import Text.Pretty.Simple\n", 33 | "\n", 34 | ":e OverloadedStrings\n", 35 | ":e FlexibleContexts\n", 36 | ":e GADTs\n", 37 | ":e LambdaCase" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "97f7ae67-5a7f-4650-8530-72385d84b019", 43 | "metadata": {}, 44 | "source": [ 45 | "Probability interfaces nicely with parser combinators from libraries like `megaparsec`. A parser in this setting is roughly a function `[Char] -> m (a, [Char])`, in other words a function which (monadically) strips off a prefix of the input string and returns a result. \n", 46 | "\n", 47 | "To make this probabilistic, we simply set `m` to a probability monad. 
The result of parsing is then a distribution over possible parses (and possible parse failures).\n" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "id": "5aec194f-b147-456b-8be0-46c12f4ed495", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "alphabet = map (: []) ['h', 'i', 'x']\n", 58 | "\n", 59 | "noise x = do\n", 60 | " perturb <- lift $ bernoulli 0.1\n", 61 | " if perturb then lift $ uniformD alphabet else return x\n", 62 | "\n", 63 | "letter = do\n", 64 | " true <- lift $ uniformD [\"h\", \"i\",\"x\"]\n", 65 | " predicted <- noise true\n", 66 | " observed <- lookAhead (foldr1 (<|>) [\"h\",\"i\", \"x\"])\n", 67 | " lift . condition $ predicted == observed\n", 68 | " string observed\n", 69 | " return $ head true \n", 70 | " \n", 71 | "word = (do \n", 72 | " wd <- some letter\n", 73 | " lift $ factor (if wd `elem` [\"hi\", \"goodbye\"] then 100 else 1)\n", 74 | " return wd\n", 75 | " ) <* eof\n", 76 | "\n", 77 | "errorBundlePretty' :: (TraversableStream s, VisualStream s) => ParseErrorBundle s Void -> String \n", 78 | "errorBundlePretty' = errorBundlePretty\n", 79 | "\n", 80 | "\n", 81 | "run parser input = either (T.pack . errorBundlePretty' ) (T.pack . show) <$> runParserT parser \"\" input" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "id": "e64782da-ddcc-41bd-aa29-b8f2dcd27184", 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "text/plain": [ 93 | "[\n", 94 | " ( \"\"hi\"\", 0.7563333333333333 )\n", 95 | ",\n", 96 | " ( \"\"hx\"\", 0.20799999999999993 )\n", 97 | ",\n", 98 | " ( \"\"xx\"\", 1.5000000000000038 e- 2 )\n", 99 | ",\n", 100 | " ( \"\"hh\"\", 1.06666666666667 e- 2 )\n", 101 | ",\n", 102 | " ( \"\"ix\"\", 1.0000000000000014 e- 2 )\n", 103 | "]" 104 | ] 105 | }, 106 | "metadata": {}, 107 | "output_type": "display_data" 108 | } 109 | ], 110 | "source": [ 111 | "pPrintCustom = pPrintOpt CheckColorTty defaultOutputOptionsNoColor {outputOptionsCompact = True, outputOptionsIndentAmount = 2} \n", 112 | "\n", 113 | "runWordParser w = do\n", 114 | " x <- sampler \n", 115 | " . runPopulationT \n", 116 | " . 
smc SMCConfig {numSteps = 5, numParticles = 3000, resampler = resampleMultinomial} \n", 117 | " $ run word w\n", 118 | " pPrintCustom $ toEmpiricalWeighted x\n", 119 | " \n", 120 | "runWordParser \"hx\"" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 4, 126 | "id": "bff38aec-9249-4d0e-b173-20ccecda9ab3", 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "[\n", 133 | " ( \"\"hi\"\", 0.7813333333333331 )\n", 134 | ",\n", 135 | " ( \"\"ii\"\", 0.2046666666666667 )\n", 136 | ",\n", 137 | " ( \"\"xi\"\", 6.66666666666668 e- 3 )\n", 138 | ",\n", 139 | " ( \"\"ix\"\", 4.333333333333346 e- 3 )\n", 140 | ",\n", 141 | " ( \"\"ih\"\", 1.6666666666666711 e- 3 )\n", 142 | ",\n", 143 | " ( \"\"xh\"\", 1.0000000000000028 e- 3 )\n", 144 | ",\n", 145 | " ( \"\"xx\"\", 3.333333333333342 e- 4 )\n", 146 | "]" 147 | ] 148 | }, 149 | "metadata": {}, 150 | "output_type": "display_data" 151 | } 152 | ], 153 | "source": [ 154 | "runWordParser \"ii\"" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 5, 160 | "id": "7ad652ee-3e29-4cce-8d50-e1fd2e5c26ea", 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "data": { 165 | "text/plain": [ 166 | "[\n", 167 | " ( \"\"hii\"\", 0.8063333333333336 )\n", 168 | ",\n", 169 | " ( \"\"xii\"\", 3.9333333333333186 e- 2 )\n", 170 | ",\n", 171 | " ( \"\"hhi\"\", 3.533333333333321 e- 2 )\n", 172 | ",\n", 173 | " ( \"\"hix\"\", 2.966666666666659 e- 2 )\n", 174 | ",\n", 175 | " ( \"\"hih\"\", 2.8999999999999908 e- 2 )\n", 176 | ",\n", 177 | " ( \"\"hxi\"\", 2.633333333333325 e- 2 )\n", 178 | ",\n", 179 | " ( \"\"iii\"\", 1.666666666666663 e- 2 )\n", 180 | ",\n", 181 | " ( \"\"hxx\"\", 4.333333333333327 e- 3 )\n", 182 | ",\n", 183 | " ( \"\"xih\"\", 3.999999999999995 e- 3 )\n", 184 | ",\n", 185 | " ( \"\"ixi\"\", 3.6666666666666636 e- 3 )\n", 186 | ",\n", 187 | " ( \"\"hhx\"\", 2.666666666666665 e- 3 )\n", 188 | ",\n", 189 | " ( \"\"xxi\"\", 2.3333333333333314 e- 3 )\n", 190 | ",\n", 191 | " ( \"\"hhh\"\", 3.333333333333324 e- 4 )\n", 192 | "]" 193 | ] 194 | }, 195 | "metadata": {}, 196 | "output_type": "display_data" 197 | } 198 | ], 199 | "source": [ 200 | "runWordParser \"hii\"" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Haskell - nixpkgs", 207 | "language": "haskell", 208 | "name": "ihaskell_nixpkgs" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": "ihaskell", 212 | "file_extension": ".hs", 213 | "mimetype": "text/x-haskell", 214 | "name": "haskell", 215 | "pygments_lexer": "Haskell", 216 | "version": "9.0.2" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 5 221 | } 222 | -------------------------------------------------------------------------------- /notebooks/file.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "John", 3 | "isAlive": true, 4 | "age": 27, 5 | "height": 1.5, 6 | "address": { 7 | "streetAddress": "21 2nd Street", 8 | "id" : 5.4 9 | } 10 | 11 | } -------------------------------------------------------------------------------- /notebooks/models/.ipynb_checkpoints/LDA-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "id": "01d81119-f1fc-4bd1-9e31-c3477b8edebc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | ":e OverloadedStrings\n", 11 | ":l ../models/LDA.hs" 12 | ] 13 | }, 14 | { 15 | "cell_type": 
"code", 16 | "execution_count": 7, 17 | "id": "635da03a-4d61-453f-bbd1-cc81b1315200", 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "( fromList\n", 24 | " [\n", 25 | " ( \"topic1\"\n", 26 | " ,\n", 27 | " [\n", 28 | " ( \"bear\"\n", 29 | " , 0.34365781935415857\n", 30 | " )\n", 31 | " ,\n", 32 | " ( \"wolf\"\n", 33 | " , 0.2556245339526619\n", 34 | " )\n", 35 | " ,\n", 36 | " ( \"python\"\n", 37 | " , 0.24562951016031953\n", 38 | " )\n", 39 | " ,\n", 40 | " ( \"prolog\"\n", 41 | " , 0.15508813653285997\n", 42 | " )\n", 43 | " ]\n", 44 | " )\n", 45 | " ,\n", 46 | " ( \"topic2\"\n", 47 | " ,\n", 48 | " [\n", 49 | " ( \"bear\"\n", 50 | " , 3.910578485111217 e- 2\n", 51 | " )\n", 52 | " ,\n", 53 | " ( \"wolf\"\n", 54 | " , 0.4304465204440071\n", 55 | " )\n", 56 | " ,\n", 57 | " ( \"python\"\n", 58 | " , 0.1486640274645592\n", 59 | " )\n", 60 | " ,\n", 61 | " ( \"prolog\"\n", 62 | " , 0.3817836672403215\n", 63 | " )\n", 64 | " ]\n", 65 | " )\n", 66 | " ]\n", 67 | ",\n", 68 | " [\n", 69 | " ( \"bear wolf bear wolf bear wolf python wolf bear wolf\"\n", 70 | " ,\n", 71 | " [\n", 72 | " ( \"topic1\"\n", 73 | " , 0.9047101354542013\n", 74 | " )\n", 75 | " ,\n", 76 | " ( \"topic2\"\n", 77 | " , 9.528986454579876 e- 2\n", 78 | " )\n", 79 | " ]\n", 80 | " )\n", 81 | " ,\n", 82 | " ( \"python prolog python prolog python prolog python prolog python prolog\"\n", 83 | " ,\n", 84 | " [\n", 85 | " ( \"topic1\"\n", 86 | " , 9.413029740935884 e- 2\n", 87 | " )\n", 88 | " ,\n", 89 | " ( \"topic2\"\n", 90 | " , 0.9058697025906413\n", 91 | " )\n", 92 | " ]\n", 93 | " )\n", 94 | " ,\n", 95 | " ( \"bear wolf bear wolf bear wolf bear wolf bear wolf\"\n", 96 | " ,\n", 97 | " [\n", 98 | " ( \"topic1\"\n", 99 | " , 0.6835490762786085\n", 100 | " )\n", 101 | " ,\n", 102 | " ( \"topic2\"\n", 103 | " , 0.31645092372139155\n", 104 | " )\n", 105 | " ]\n", 106 | " )\n", 107 | " ,\n", 108 | " ( \"python prolog python prolog python prolog python prolog python prolog\"\n", 109 | " ,\n", 110 | " [\n", 111 | " ( \"topic1\"\n", 112 | " , 3.698500977423746 e- 2\n", 113 | " )\n", 114 | " ,\n", 115 | " ( \"topic2\"\n", 116 | " , 0.9630149902257626\n", 117 | " )\n", 118 | " ]\n", 119 | " )\n", 120 | " ,\n", 121 | " ( \"bear wolf bear python bear wolf bear wolf bear wolf\"\n", 122 | " ,\n", 123 | " [\n", 124 | " ( \"topic1\"\n", 125 | " , 0.9101537223185177\n", 126 | " )\n", 127 | " ,\n", 128 | " ( \"topic2\"\n", 129 | " , 8.98462776814825 e- 2\n", 130 | " )\n", 131 | " ]\n", 132 | " )\n", 133 | " ]\n", 134 | ")" 135 | ] 136 | }, 137 | "metadata": {}, 138 | "output_type": "display_data" 139 | } 140 | ], 141 | "source": [ 142 | "runLDA" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "ce857de6-2b0c-44ce-893f-a67ad39f3040", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Haskell - nixpkgs", 157 | "language": "haskell", 158 | "name": "ihaskell_nixpkgs" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": "ihaskell", 162 | "file_extension": ".hs", 163 | "mimetype": "text/x-haskell", 164 | "name": "haskell", 165 | "pygments_lexer": "Haskell", 166 | "version": "9.0.2" 167 | }, 168 | "toc-autonumbering": false 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 5 172 | } 173 | -------------------------------------------------------------------------------- /notebooks/models/LDA.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "id": "01d81119-f1fc-4bd1-9e31-c3477b8edebc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | ":e OverloadedStrings\n", 11 | ":l ../models/LDA.hs" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 7, 17 | "id": "635da03a-4d61-453f-bbd1-cc81b1315200", 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "( fromList\n", 24 | " [\n", 25 | " ( \"topic1\"\n", 26 | " ,\n", 27 | " [\n", 28 | " ( \"bear\"\n", 29 | " , 0.34365781935415857\n", 30 | " )\n", 31 | " ,\n", 32 | " ( \"wolf\"\n", 33 | " , 0.2556245339526619\n", 34 | " )\n", 35 | " ,\n", 36 | " ( \"python\"\n", 37 | " , 0.24562951016031953\n", 38 | " )\n", 39 | " ,\n", 40 | " ( \"prolog\"\n", 41 | " , 0.15508813653285997\n", 42 | " )\n", 43 | " ]\n", 44 | " )\n", 45 | " ,\n", 46 | " ( \"topic2\"\n", 47 | " ,\n", 48 | " [\n", 49 | " ( \"bear\"\n", 50 | " , 3.910578485111217 e- 2\n", 51 | " )\n", 52 | " ,\n", 53 | " ( \"wolf\"\n", 54 | " , 0.4304465204440071\n", 55 | " )\n", 56 | " ,\n", 57 | " ( \"python\"\n", 58 | " , 0.1486640274645592\n", 59 | " )\n", 60 | " ,\n", 61 | " ( \"prolog\"\n", 62 | " , 0.3817836672403215\n", 63 | " )\n", 64 | " ]\n", 65 | " )\n", 66 | " ]\n", 67 | ",\n", 68 | " [\n", 69 | " ( \"bear wolf bear wolf bear wolf python wolf bear wolf\"\n", 70 | " ,\n", 71 | " [\n", 72 | " ( \"topic1\"\n", 73 | " , 0.9047101354542013\n", 74 | " )\n", 75 | " ,\n", 76 | " ( \"topic2\"\n", 77 | " , 9.528986454579876 e- 2\n", 78 | " )\n", 79 | " ]\n", 80 | " )\n", 81 | " ,\n", 82 | " ( \"python prolog python prolog python prolog python prolog python prolog\"\n", 83 | " ,\n", 84 | " [\n", 85 | " ( \"topic1\"\n", 86 | " , 9.413029740935884 e- 2\n", 87 | " )\n", 88 | " ,\n", 89 | " ( \"topic2\"\n", 90 | " , 0.9058697025906413\n", 91 | " )\n", 92 | " ]\n", 93 | " )\n", 94 | " ,\n", 95 | " ( \"bear wolf bear wolf bear wolf bear wolf bear wolf\"\n", 96 | " ,\n", 97 | " [\n", 98 | " ( \"topic1\"\n", 99 | " , 0.6835490762786085\n", 100 | " )\n", 101 | " ,\n", 102 | " ( \"topic2\"\n", 103 | " , 0.31645092372139155\n", 104 | " )\n", 105 | " ]\n", 106 | " )\n", 107 | " ,\n", 108 | " ( \"python prolog python prolog python prolog python prolog python prolog\"\n", 109 | " ,\n", 110 | " [\n", 111 | " ( \"topic1\"\n", 112 | " , 3.698500977423746 e- 2\n", 113 | " )\n", 114 | " ,\n", 115 | " ( \"topic2\"\n", 116 | " , 0.9630149902257626\n", 117 | " )\n", 118 | " ]\n", 119 | " )\n", 120 | " ,\n", 121 | " ( \"bear wolf bear python bear wolf bear wolf bear wolf\"\n", 122 | " ,\n", 123 | " [\n", 124 | " ( \"topic1\"\n", 125 | " , 0.9101537223185177\n", 126 | " )\n", 127 | " ,\n", 128 | " ( \"topic2\"\n", 129 | " , 8.98462776814825 e- 2\n", 130 | " )\n", 131 | " ]\n", 132 | " )\n", 133 | " ]\n", 134 | ")" 135 | ] 136 | }, 137 | "metadata": {}, 138 | "output_type": "display_data" 139 | } 140 | ], 141 | "source": [ 142 | "runLDA" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "ce857de6-2b0c-44ce-893f-a67ad39f3040", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": "Haskell - nixpkgs", 157 | "language": "haskell", 158 | "name": "ihaskell_nixpkgs" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": "ihaskell", 162 | "file_extension": ".hs", 163 | "mimetype": "text/x-haskell", 164 | "name": 
"haskell", 165 | "pygments_lexer": "Haskell", 166 | "version": "9.0.2" 167 | }, 168 | "toc-autonumbering": false 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 5 172 | } 173 | -------------------------------------------------------------------------------- /notebooks/plotting.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | 5 | module Plotting where 6 | 7 | import Control.Arrow (first, second) 8 | import qualified Data.Text as T 9 | import Graphics.Vega.VegaLite hiding (filter, length) 10 | import IHaskell.Display.Hvega 11 | 12 | hist (xs, ys) = 13 | let enc = 14 | encoding 15 | . position X [PName "X", PmType Quantitative] 16 | . position Y [PName "Y", PmType Quantitative] 17 | 18 | dat = 19 | ( dataFromColumns [] 20 | . dataColumn "X" (Numbers xs) 21 | . dataColumn "Y" (Numbers ys) 22 | ) 23 | [] 24 | in toVegaLite 25 | [ dat, 26 | mark Bar [], 27 | enc [], 28 | width 400, 29 | height 400 30 | ] 31 | 32 | barplot (xs, ys) = 33 | let enc = 34 | encoding 35 | . position X [PName "X", PmType Nominal] 36 | . position Y [PName "Y", PmType Quantitative] 37 | 38 | dat = 39 | ( dataFromColumns [] 40 | . dataColumn "X" (Strings xs) 41 | . dataColumn "Y" (Numbers ys) 42 | ) 43 | [] 44 | in toVegaLite 45 | [ dat, 46 | mark Bar [], 47 | enc [], 48 | width 400, 49 | height 400 50 | ] 51 | 52 | scatterplot ((xs, ys), cs) cE f mode = 53 | let enc = 54 | encoding 55 | . position X [PName "X", PmType Quantitative] 56 | . position Y [PName "Y", PmType Quantitative] 57 | . cE 58 | 59 | dat = 60 | ( dataFromColumns [] 61 | . dataColumn "X" (Numbers xs) 62 | . dataColumn "Y" (Numbers ys) 63 | . dataColumn "Outlier" (f cs) 64 | ) 65 | [] 66 | in toVegaLite 67 | [ dat, 68 | mark mode [], 69 | enc [], 70 | width 400, 71 | height 400 72 | ] 73 | 74 | class Plottable a where 75 | plot :: a -> VegaLiteLab 76 | 77 | instance Plottable [((Double, Double), T.Text)] where 78 | plot ls = 79 | vlShow $ 80 | scatterplot 81 | (first unzip $ unzip ls) 82 | (color [MName "Outlier"]) 83 | (\cs -> (Strings (T.pack . show <$> cs))) 84 | Circle 85 | 86 | instance Plottable [((Double, Double), Double)] where 87 | plot ls = 88 | vlShow $ 89 | scatterplot 90 | (first unzip $ unzip (ls)) 91 | ( color 92 | [ MName "Outlier", 93 | MmType Quantitative, 94 | MScale 95 | [ SScheme "viridis" [] 96 | ] 97 | ] 98 | ) 99 | Numbers 100 | Circle 101 | 102 | instance Plottable ([Double], (Double, Double)) where 103 | plot ls = 104 | let cs = take (length $ fst ls) $ Prelude.repeat 1 105 | xs = fst ls 106 | (slope, intercept) = snd ls 107 | ys = (+ intercept) . (* slope) <$> xs 108 | in vlShow $ 109 | scatterplot 110 | ((xs, ys), cs) 111 | (color []) 112 | Numbers 113 | Line 114 | 115 | instance Plottable [(T.Text, Double)] where 116 | plot ls = vlShow $ barplot $ unzip ls 117 | 118 | instance Plottable [(Double, Double)] where 119 | plot ls = vlShow $ hist $ unzip ls 120 | 121 | instance Plottable ([Double], [Double]) where 122 | plot (xs, ys) = 123 | vlShow $ 124 | scatterplot ((xs, ys), replicate (length xs) 1) (color []) Numbers Line 125 | 126 | type Plot = VegaLiteLab 127 | -------------------------------------------------------------------------------- /profile.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # Generates profiling information for selected models and inference algorithms. 
3 | 4 | stack build --profile 5 | for model in LR100 HMM100 LDA50 6 | do 7 | for alg in SMC MH RMSMC 8 | do 9 | echo "Profiling $alg on $model" 10 | stack exec --profile -- example -a $alg -m $model +RTS -p >/dev/null 11 | file=$model-$alg 12 | mv example.prof $file.prof 13 | sed 's/no location info/no_location_info/' $file.prof | awk '$8 !~ /0\.0/' >$file-small.prof 14 | done 15 | done 16 | -------------------------------------------------------------------------------- /regenerate_notebooks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # Regenerates the notebook html files in `docs/docs/notebooks/` 3 | 4 | nix --print-build-logs develop --command jupyter-nbconvert --to html notebooks/examples/*.ipynb --output-dir docs/docs/notebooks/ 5 | -------------------------------------------------------------------------------- /shell.nix: -------------------------------------------------------------------------------- 1 | ( 2 | import 3 | ( 4 | let 5 | lock = builtins.fromJSON (builtins.readFile ./flake.lock); 6 | in 7 | fetchTarball { 8 | url = "https://github.com/edolstra/flake-compat/archive/${lock.nodes.flake-compat.locked.rev}.tar.gz"; 9 | sha256 = lock.nodes.flake-compat.locked.narHash; 10 | } 11 | ) 12 | {src = ./.;} 13 | ) 14 | .shellNix 15 | -------------------------------------------------------------------------------- /src/Control/Applicative/List.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE StandaloneDeriving #-} 2 | 3 | module Control.Applicative.List where 4 | 5 | -- base 6 | import Control.Applicative 7 | import Data.Functor.Compose 8 | 9 | -- * Applicative ListT 10 | 11 | -- | _Applicative_ transformer adding a list/nondeterminism/choice effect. 12 | -- It is not a valid monad transformer, but it is a valid 'Applicative'. 13 | newtype ListT m a = ListT {getListT :: Compose m [] a} 14 | deriving newtype (Functor, Applicative, Alternative) 15 | 16 | listT :: m [a] -> ListT m a 17 | listT = ListT . Compose 18 | 19 | lift :: (Functor m) => m a -> ListT m a 20 | lift = ListT . Compose . fmap pure 21 | 22 | runListT :: ListT m a -> m [a] 23 | runListT = getCompose . getListT 24 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Density/Free.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 3 | {-# LANGUAGE MultiParamTypeClasses #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | 6 | -- | 7 | -- Module : Control.Monad.Bayes.Density.Free 8 | -- Description : Free monad transformer over random sampling 9 | -- Copyright : (c) Adam Scibior, 2015-2020 10 | -- License : MIT 11 | -- Maintainer : leonhard.markert@tweag.io 12 | -- Stability : experimental 13 | -- Portability : GHC 14 | -- 15 | -- 'DensityT' is a free monad transformer over random sampling. 16 | module Control.Monad.Bayes.Density.Free 17 | ( DensityT (..), 18 | hoist, 19 | interpret, 20 | withRandomness, 21 | runDensityT, 22 | traced, 23 | ) 24 | where 25 | 26 | import Control.Monad.Bayes.Class (MonadDistribution (random)) 27 | import Control.Monad.RWS 28 | import Control.Monad.State (evalStateT) 29 | import Control.Monad.Trans.Free.Church (FT, MonadFree (..), hoistFT, iterT, iterTM, liftF) 30 | import Control.Monad.Writer (WriterT (..)) 31 | import Data.Functor.Identity (Identity, runIdentity) 32 | 33 | -- | Random sampling functor. 
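-- A 'Random' node asks for one number from 'random' and passes it to its
-- continuation, so running a 'DensityT' program amounts to answering a sequence
-- of such requests. For example, 'random >>= \x -> random >>= \y -> pure (x, y)'
-- makes two requests; 'withRandomness [0.1, 0.2]' (defined below) answers them
-- with 0.1 and 0.2 and returns (0.1, 0.2).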
34 | newtype SamF a = Random (Double -> a) deriving (Functor) 35 | 36 | -- | Free monad transformer over random sampling. 37 | -- 38 | -- Uses the Church-encoded version of the free monad for efficiency. 39 | newtype DensityT m a = DensityT {getDensityT :: FT SamF m a} 40 | deriving newtype (Functor, Applicative, Monad, MonadTrans) 41 | 42 | instance MonadFree SamF (DensityT m) where 43 | wrap = DensityT . wrap . fmap getDensityT 44 | 45 | instance (Monad m) => MonadDistribution (DensityT m) where 46 | random = DensityT $ liftF (Random id) 47 | 48 | -- | Hoist 'DensityT' through a monad transform. 49 | hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> DensityT m a -> DensityT n a 50 | hoist f (DensityT m) = DensityT (hoistFT f m) 51 | 52 | -- | Execute random sampling in the transformed monad. 53 | interpret :: (MonadDistribution m) => DensityT m a -> m a 54 | interpret (DensityT m) = iterT f m 55 | where 56 | f (Random k) = random >>= k 57 | 58 | -- | Execute computation with supplied values for random choices. 59 | withRandomness :: (Monad m) => [Double] -> DensityT m a -> m a 60 | withRandomness randomness (DensityT m) = evalStateT (iterTM f m) randomness 61 | where 62 | f (Random k) = do 63 | xs <- get 64 | case xs of 65 | [] -> error "DensityT: the list of randomness was too short" 66 | y : ys -> put ys >> k y 67 | 68 | -- | Execute computation with supplied values for a subset of random choices. 69 | -- Return the output value and a record of all random choices used, whether 70 | -- taken as input or drawn using the transformed monad. 71 | runDensityT :: (MonadDistribution m) => [Double] -> DensityT m a -> m (a, [Double]) 72 | runDensityT randomness (DensityT m) = 73 | runWriterT $ evalStateT (iterTM f $ hoistFT lift m) randomness 74 | where 75 | f (Random k) = do 76 | -- This block runs in StateT [Double] (WriterT [Double]) m. 77 | -- StateT propagates consumed randomness while WriterT records 78 | -- randomness used, whether old or new. 79 | xs <- get 80 | x <- case xs of 81 | [] -> random 82 | y : ys -> put ys >> return y 83 | tell [x] 84 | k x 85 | 86 | -- | Like 'density', but use an arbitrary sampling monad. 87 | traced :: (MonadDistribution m) => [Double] -> DensityT Identity a -> m (a, [Double]) 88 | traced randomness m = runDensityT randomness $ hoist (return . runIdentity) m 89 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Density/State.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE FlexibleInstances #-} 3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 4 | {-# LANGUAGE MultiParamTypeClasses #-} 5 | 6 | -- | 7 | -- Slower than Control.Monad.Bayes.Density.Free, so not used by default, 8 | -- but more elementary to understand. Just uses standard 9 | -- monad transformer techniques. 10 | module Control.Monad.Bayes.Density.State where 11 | 12 | import Control.Monad.Bayes.Class (MonadDistribution (random)) 13 | import Control.Monad.State (MonadState (get, put), StateT, evalStateT) 14 | import Control.Monad.Writer 15 | 16 | newtype DensityT m a = DensityT {getDensityT :: WriterT [Double] (StateT [Double] m) a} deriving newtype (Functor, Applicative, Monad) 17 | 18 | instance MonadTrans DensityT where 19 | lift = DensityT . lift . lift 20 | 21 | instance (Monad m) => MonadState [Double] (DensityT m) where 22 | get = DensityT $ lift $ get 23 | put = DensityT . lift . 
put 24 | 25 | instance (Monad m) => MonadWriter [Double] (DensityT m) where 26 | tell = DensityT . tell 27 | listen = DensityT . listen . getDensityT 28 | pass = DensityT . pass . getDensityT 29 | 30 | instance (MonadDistribution m) => MonadDistribution (DensityT m) where 31 | random = do 32 | trace <- get 33 | x <- case trace of 34 | [] -> random 35 | r : xs -> put xs >> pure r 36 | tell [x] 37 | pure x 38 | 39 | runDensityT :: (Monad m) => DensityT m b -> [Double] -> m (b, [Double]) 40 | runDensityT (DensityT m) = evalStateT (runWriterT m) 41 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Enumerator.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 3 | {-# LANGUAGE ImportQualifiedPost #-} 4 | 5 | -- | 6 | -- Module : Control.Monad.Bayes.Enumerator 7 | -- Description : Exhaustive enumeration of discrete random variables 8 | -- Copyright : (c) Adam Scibior, 2015-2020 9 | -- License : MIT 10 | -- Maintainer : leonhard.markert@tweag.io 11 | -- Stability : experimental 12 | -- Portability : GHC 13 | module Control.Monad.Bayes.Enumerator 14 | ( Enumerator, 15 | logExplicit, 16 | explicit, 17 | evidence, 18 | mass, 19 | compact, 20 | enumerator, 21 | enumerate, 22 | expectation, 23 | normalForm, 24 | toEmpirical, 25 | toEmpiricalWeighted, 26 | normalizeWeights, 27 | enumerateToDistribution, 28 | removeZeros, 29 | fromList, 30 | ) 31 | where 32 | 33 | import Control.Applicative (Alternative) 34 | import Control.Arrow (second) 35 | import Control.Monad (MonadPlus) 36 | import Control.Monad.Bayes.Class 37 | ( MonadDistribution (bernoulli, categorical, logCategorical, random), 38 | MonadFactor (..), 39 | MonadMeasure, 40 | ) 41 | import Control.Monad.Writer (WriterT (..)) 42 | import Data.AEq (AEq, (===), (~==)) 43 | import Data.List (sortOn) 44 | import Data.Map qualified as Map 45 | import Data.Maybe (fromMaybe) 46 | import Data.Monoid (Product (..)) 47 | import Data.Ord (Down (Down)) 48 | import Data.Vector qualified as VV 49 | import Data.Vector.Generic qualified as V 50 | import Numeric.Log as Log (Log (..), sum) 51 | 52 | -- | An exact inference transformer that integrates 53 | -- discrete random variables by enumerating all execution paths. 54 | newtype Enumerator a = Enumerator (WriterT (Product (Log Double)) [] a) 55 | deriving newtype (Functor, Applicative, Monad, Alternative, MonadPlus) 56 | 57 | instance MonadDistribution Enumerator where 58 | random = error "Infinitely supported random variables not supported in Enumerator" 59 | bernoulli p = fromList [(True, (Exp . log) p), (False, (Exp . log) (1 - p))] 60 | categorical v = fromList $ zip [0 ..] $ map (Exp . log) (V.toList v) 61 | 62 | instance MonadFactor Enumerator where 63 | score w = fromList [((), w)] 64 | 65 | instance MonadMeasure Enumerator 66 | 67 | -- | Construct Enumerator from a list of values and associated weights. 68 | fromList :: [(a, Log Double)] -> Enumerator a 69 | fromList = Enumerator . WriterT . map (second Product) 70 | 71 | -- | Returns the posterior as a list of weight-value pairs without any post-processing, 72 | -- such as normalization or aggregation 73 | logExplicit :: Enumerator a -> [(a, Log Double)] 74 | logExplicit (Enumerator m) = map (second getProduct) $ runWriterT m 75 | 76 | -- | Same as `toList`, only weights are converted from log-domain. 
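-- (Equivalently: 'logExplicit' with the weights brought out of the log domain.
-- The weights are still unnormalized and there is one entry per execution path;
-- 'enumerator' below additionally aggregates equal values and normalizes. For a
-- small conditioned model such as 'soft' from models/Sprinkler.hs, for instance,
-- 'explicit Sprinkler.soft' lists raw per-path weights, while
-- 'enumerator Sprinkler.soft' gives the normalized posterior over Bool.)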
77 | explicit :: Enumerator a -> [(a, Double)] 78 | explicit = map (second (exp . ln)) . logExplicit 79 | 80 | -- | Returns the model evidence, that is sum of all weights. 81 | evidence :: Enumerator a -> Log Double 82 | evidence = Log.sum . map snd . logExplicit 83 | 84 | -- | Normalized probability mass of a specific value. 85 | mass :: (Ord a) => Enumerator a -> a -> Double 86 | mass d = f 87 | where 88 | f a = fromMaybe 0 $ lookup a m 89 | m = enumerator d 90 | 91 | -- | Aggregate weights of equal values. 92 | -- The resulting list is sorted ascendingly according to values. 93 | compact :: (Num r, Ord a, Ord r) => [(a, r)] -> [(a, r)] 94 | compact = sortOn (Down . snd) . Map.toAscList . Map.fromListWith (+) 95 | 96 | -- | Aggregate and normalize of weights. 97 | -- The resulting list is sorted ascendingly according to values. 98 | -- 99 | -- > enumerator = compact . explicit 100 | enumerator, enumerate :: (Ord a) => Enumerator a -> [(a, Double)] 101 | enumerator d = filter ((/= 0) . snd) $ compact (zip xs ws) 102 | where 103 | (xs, ws) = second (map (exp . ln) . normalize) $ unzip (logExplicit d) 104 | 105 | -- | deprecated synonym 106 | enumerate = enumerator 107 | 108 | -- | Expectation of a given function computed using normalized weights. 109 | expectation :: (a -> Double) -> Enumerator a -> Double 110 | expectation f = Prelude.sum . map (\(x, w) -> f x * (exp . ln) w) . normalizeWeights . logExplicit 111 | 112 | normalize :: (Fractional b) => [b] -> [b] 113 | normalize xs = map (/ z) xs 114 | where 115 | z = Prelude.sum xs 116 | 117 | -- | Divide all weights by their sum. 118 | normalizeWeights :: (Fractional b) => [(a, b)] -> [(a, b)] 119 | normalizeWeights ls = zip xs ps 120 | where 121 | (xs, ws) = unzip ls 122 | ps = normalize ws 123 | 124 | -- | 'compact' followed by removing values with zero weight. 125 | normalForm :: (Ord a) => Enumerator a -> [(a, Double)] 126 | normalForm = filter ((/= 0) . snd) . compact . explicit 127 | 128 | toEmpirical :: (Fractional b, Ord a, Ord b) => [a] -> [(a, b)] 129 | toEmpirical ls = normalizeWeights $ compact (zip ls (repeat 1)) 130 | 131 | toEmpiricalWeighted :: (Fractional b, Ord a, Ord b) => [(a, b)] -> [(a, b)] 132 | toEmpiricalWeighted = normalizeWeights . compact 133 | 134 | enumerateToDistribution :: (MonadDistribution n) => Enumerator a -> n a 135 | enumerateToDistribution model = do 136 | let samples = logExplicit model 137 | let (support, logprobs) = unzip samples 138 | i <- logCategorical $ VV.fromList logprobs 139 | return $ support !! i 140 | 141 | removeZeros :: Enumerator a -> Enumerator a 142 | removeZeros (Enumerator (WriterT a)) = Enumerator $ WriterT $ filter ((\(Product x) -> x /= 0) . snd) a 143 | 144 | instance (Ord a) => Eq (Enumerator a) where 145 | p == q = normalForm p == normalForm q 146 | 147 | instance (Ord a) => AEq (Enumerator a) where 148 | p === q = xs == ys && ps === qs 149 | where 150 | (xs, ps) = unzip (normalForm p) 151 | (ys, qs) = unzip (normalForm q) 152 | p ~== q = xs == ys && ps ~== qs 153 | where 154 | (xs, ps) = unzip $ filter (not . (~== 0) . snd) $ normalForm p 155 | (ys, qs) = unzip $ filter (not . (~== 0) . 
snd) $ normalForm q 156 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/Lazy/MH.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | {-# LANGUAGE RankNTypes #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# OPTIONS_GHC -Wno-name-shadowing #-} 5 | 6 | module Control.Monad.Bayes.Inference.Lazy.MH where 7 | 8 | import Control.Monad.Bayes.Class (Log (ln)) 9 | import Control.Monad.Bayes.Sampler.Lazy 10 | ( Sampler, 11 | Tree (..), 12 | Trees (..), 13 | randomTree, 14 | runSampler, 15 | ) 16 | import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT) 17 | import Control.Monad.Extra (iterateM) 18 | import Control.Monad.State.Lazy (MonadState (get, put), runState) 19 | import System.Random (RandomGen (split), getStdGen, newStdGen) 20 | import System.Random qualified as R 21 | 22 | mh :: forall a. Double -> WeightedT Sampler a -> IO [(a, Log Double)] 23 | mh p m = do 24 | -- Top level: produce a stream of samples. 25 | -- Split the random number generator in two 26 | -- One part is used as the first seed for the simulation, 27 | -- and one part is used for the randomness in the MH algorithm. 28 | g <- newStdGen >> getStdGen 29 | let (g1, g2) = split g 30 | let t = randomTree g1 31 | let (x, w) = runSampler (runWeightedT m) t 32 | -- Now run step over and over to get a stream of (tree,result,weight)s. 33 | let (samples, _) = runState (iterateM step (t, x, w)) g2 34 | -- The stream of seeds is used to produce a stream of result/weight pairs. 35 | return $ map (\(_, x, w) -> (x, w)) samples 36 | where 37 | -- where 38 | {- NB There are three kinds of randomness in the step function. 39 | 1. The start tree 't', which is the source of randomness for simulating the 40 | program m to start with. This is sort-of the point in the "state space". 41 | 2. The randomness needed to propose a new tree ('g1') 42 | 3. The randomness needed to decide whether to accept or reject that ('g2') 43 | The tree t is an argument and result, 44 | but we use a state monad ('get'/'put') to deal with the other randomness '(g,g1,g2)' -} 45 | 46 | -- step :: RandomGen g => (Tree, a, Log Double) -> State g (Tree, a, Log Double) 47 | step (t, x, w) = do 48 | -- Randomly change some sites 49 | g <- get 50 | let (g1, g2) = split g 51 | let t' = mutateTree p g1 t 52 | -- Rerun the model with the new tree, to get a new 53 | -- weight w'. 54 | let (x', w') = runSampler (runWeightedT m) t' 55 | -- MH acceptance ratio. This is the probability of either 56 | -- returning the new seed or the old one. 57 | let ratio = w' / w 58 | let (r, g2') = R.random g2 59 | put g2' 60 | if r < min 1 (exp $ ln ratio) 61 | then return (t', x', w') 62 | else return (t, x, w) 63 | 64 | -- Replace the labels of a tree randomly, with probability p 65 | mutateTree :: forall g. 
(RandomGen g) => Double -> g -> Tree -> Tree 66 | mutateTree p g (Tree a ts) = 67 | let (a', g') = (R.random g :: (Double, g)) 68 | (a'', g'') = R.random g' 69 | in Tree 70 | { currentUniform = if a' < p then a'' else a, 71 | lazyUniforms = mutateTrees p g'' ts 72 | } 73 | 74 | mutateTrees :: (RandomGen g) => Double -> g -> Trees -> Trees 75 | mutateTrees p g (Trees t ts) = 76 | let (g1, g2) = split g 77 | in Trees 78 | { headTree = mutateTree p g1 t, 79 | tailTrees = mutateTrees p g2 ts 80 | } 81 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/Lazy/WIS.hs: -------------------------------------------------------------------------------- 1 | module Control.Monad.Bayes.Inference.Lazy.WIS where 2 | 3 | import Control.Monad (guard) 4 | import Control.Monad.Bayes.Sampler.Lazy (SamplerT, weightedSamples) 5 | import Control.Monad.Bayes.Weighted (WeightedT) 6 | import Data.Maybe (mapMaybe) 7 | import Numeric.Log (Log (Exp)) 8 | import System.Random (Random (randoms), getStdGen, newStdGen) 9 | 10 | -- | Weighted Importance Sampling 11 | 12 | -- | Likelihood weighted importance sampling first draws n weighted samples, 13 | -- and then samples a stream of results from that regarded as an empirical distribution 14 | lwis :: Int -> WeightedT (SamplerT IO) a -> IO [a] 15 | lwis n m = do 16 | xws <- weightedSamples m 17 | let xws' = take n $ accumulate xws 0 18 | let max' = snd $ last xws' 19 | _ <- newStdGen 20 | rs <- randoms <$> getStdGen 21 | return $ take 1 =<< fmap (\r -> mapMaybe (\(a, p) -> guard (p >= Exp (log r) * max') >> Just a) xws') rs 22 | where 23 | accumulate :: (Num t) => [(a, t)] -> t -> [(a, t)] 24 | accumulate ((x, w) : xws) a = (x, w + a) : (x, w + a) : accumulate xws (w + a) 25 | accumulate [] _ = [] 26 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/MCMC.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | 4 | -- | 5 | -- Module : Control.Monad.Bayes.Inference.MCMC 6 | -- Description : Markov Chain Monte Carlo (MCMC) 7 | -- Copyright : (c) Adam Scibior, 2015-2020 8 | -- License : MIT 9 | -- Maintainer : tweag.io 10 | -- Stability : experimental 11 | -- Portability : GHC 12 | module Control.Monad.Bayes.Inference.MCMC where 13 | 14 | import Control.Monad.Bayes.Class (MonadDistribution) 15 | import Control.Monad.Bayes.Traced.Basic qualified as Basic 16 | import Control.Monad.Bayes.Traced.Common 17 | ( MHResult (MHResult, trace), 18 | Trace (probDensity), 19 | burnIn, 20 | mhTransWithBool, 21 | ) 22 | import Control.Monad.Bayes.Traced.Dynamic qualified as Dynamic 23 | import Control.Monad.Bayes.Traced.Static qualified as Static 24 | import Control.Monad.Bayes.Weighted (WeightedT, unweighted) 25 | import Pipes ((>->)) 26 | import Pipes qualified as P 27 | import Pipes.Prelude qualified as P 28 | 29 | data Proposal = SingleSiteMH 30 | 31 | data MCMCConfig = MCMCConfig {proposal :: Proposal, numMCMCSteps :: Int, numBurnIn :: Int} 32 | 33 | defaultMCMCConfig :: MCMCConfig 34 | defaultMCMCConfig = MCMCConfig {proposal = SingleSiteMH, numMCMCSteps = 1, numBurnIn = 0} 35 | 36 | mcmc :: (MonadDistribution m) => MCMCConfig -> Static.TracedT (WeightedT m) a -> m [a] 37 | mcmc (MCMCConfig {..}) m = burnIn numBurnIn $ unweighted $ Static.mh numMCMCSteps m 38 | 39 | mcmcBasic :: (MonadDistribution m) => MCMCConfig -> Basic.TracedT (WeightedT m) a -> m [a] 
40 | mcmcBasic (MCMCConfig {..}) m = burnIn numBurnIn $ unweighted $ Basic.mh numMCMCSteps m 41 | 42 | mcmcDynamic :: (MonadDistribution m) => MCMCConfig -> Dynamic.TracedT (WeightedT m) a -> m [a] 43 | mcmcDynamic (MCMCConfig {..}) m = burnIn numBurnIn $ unweighted $ Dynamic.mh numMCMCSteps m 44 | 45 | -- -- | draw iid samples until you get one that has non-zero likelihood 46 | independentSamples :: (Monad m) => Static.TracedT m a -> P.Producer (MHResult a) m (Trace a) 47 | independentSamples (Static.TracedT _w d) = 48 | P.repeatM d 49 | >-> P.takeWhile' ((== 0) . probDensity) 50 | >-> P.map (MHResult False) 51 | 52 | -- | convert a probabilistic program into a producer of samples 53 | mcmcP :: (MonadDistribution m) => MCMCConfig -> Static.TracedT m a -> P.Producer (MHResult a) m () 54 | mcmcP MCMCConfig {..} m@(Static.TracedT w _) = do 55 | initialValue <- independentSamples m >-> P.drain 56 | ( P.unfoldr (fmap (Right . (\k -> (k, trace k))) . mhTransWithBool w) initialValue 57 | >-> P.drop numBurnIn 58 | ) 59 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/PMMH.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | 3 | -- | 4 | -- Module : Control.Monad.Bayes.Inference.PMMH 5 | -- Description : Particle Marginal Metropolis-Hastings (PMMH) 6 | -- Copyright : (c) Adam Scibior, 2015-2020 7 | -- License : MIT 8 | -- Maintainer : leonhard.markert@tweag.io 9 | -- Stability : experimental 10 | -- Portability : GHC 11 | -- 12 | -- Particle Marginal Metropolis-Hastings (PMMH) sampling. 13 | -- 14 | -- Christophe Andrieu, Arnaud Doucet, and Roman Holenstein. 2010. Particle Markov chain Monte Carlo Methods. /Journal of the Royal Statistical Society/ 72 (2010), 269-342. 15 | module Control.Monad.Bayes.Inference.PMMH 16 | ( pmmh, 17 | pmmhBayesianModel, 18 | ) 19 | where 20 | 21 | import Control.Monad.Bayes.Class (Bayesian (generative), MonadDistribution, MonadMeasure, prior) 22 | import Control.Monad.Bayes.Inference.MCMC (MCMCConfig, mcmc) 23 | import Control.Monad.Bayes.Inference.SMC (SMCConfig (), smc) 24 | import Control.Monad.Bayes.Population as Pop 25 | ( PopulationT, 26 | hoist, 27 | pushEvidence, 28 | runPopulationT, 29 | ) 30 | import Control.Monad.Bayes.Sequential.Coroutine (SequentialT) 31 | import Control.Monad.Bayes.Traced.Static (TracedT) 32 | import Control.Monad.Bayes.Weighted 33 | import Control.Monad.Trans (lift) 34 | import Numeric.Log (Log) 35 | 36 | -- | Particle Marginal Metropolis-Hastings sampling. 37 | pmmh :: 38 | (MonadDistribution m) => 39 | MCMCConfig -> 40 | SMCConfig (WeightedT m) -> 41 | TracedT (WeightedT m) a1 -> 42 | (a1 -> SequentialT (PopulationT (WeightedT m)) a2) -> 43 | m [[(a2, Log Double)]] 44 | pmmh mcmcConf smcConf param model = 45 | mcmc 46 | mcmcConf 47 | ( param 48 | >>= runPopulationT 49 | . pushEvidence 50 | . Pop.hoist lift 51 | . smc smcConf 52 | . model 53 | ) 54 | 55 | -- | Particle Marginal Metropolis-Hastings sampling from a Bayesian model 56 | pmmhBayesianModel :: 57 | (MonadMeasure m) => 58 | MCMCConfig -> 59 | SMCConfig (WeightedT m) -> 60 | (forall m'. 
(MonadMeasure m') => Bayesian m' a1 a2) -> 61 | m [[(a2, Log Double)]] 62 | pmmhBayesianModel mcmcConf smcConf bm = pmmh mcmcConf smcConf (prior bm) (generative bm) 63 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/RMSMC.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | 4 | -- | 5 | -- Module : Control.Monad.Bayes.Inference.RMSMC 6 | -- Description : Resample-Move Sequential Monte Carlo (RM-SMC) 7 | -- Copyright : (c) Adam Scibior, 2015-2020 8 | -- License : MIT 9 | -- Maintainer : leonhard.markert@tweag.io 10 | -- Stability : experimental 11 | -- Portability : GHC 12 | -- 13 | -- Resample-move Sequential Monte Carlo (RM-SMC) sampling. 14 | -- 15 | -- Walter Gilks and Carlo Berzuini. 2001. Following a moving target - Monte Carlo inference for dynamic Bayesian models. /Journal of the Royal Statistical Society/ 63 (2001), 127-146. 16 | module Control.Monad.Bayes.Inference.RMSMC 17 | ( rmsmc, 18 | rmsmcDynamic, 19 | rmsmcBasic, 20 | ) 21 | where 22 | 23 | import Control.Monad.Bayes.Class (MonadDistribution) 24 | import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..)) 25 | import Control.Monad.Bayes.Inference.SMC 26 | import Control.Monad.Bayes.Population 27 | ( PopulationT, 28 | spawn, 29 | withParticles, 30 | ) 31 | import Control.Monad.Bayes.Sequential.Coroutine as Seq 32 | import Control.Monad.Bayes.Sequential.Coroutine qualified as S 33 | import Control.Monad.Bayes.Traced.Basic qualified as TrBas 34 | import Control.Monad.Bayes.Traced.Dynamic qualified as TrDyn 35 | import Control.Monad.Bayes.Traced.Static as Tr 36 | ( TracedT, 37 | marginal, 38 | mhStep, 39 | ) 40 | import Control.Monad.Bayes.Traced.Static qualified as TrStat 41 | import Data.Monoid (Endo (..)) 42 | 43 | -- | Resample-move Sequential Monte Carlo. 44 | rmsmc :: 45 | (MonadDistribution m) => 46 | MCMCConfig -> 47 | SMCConfig m -> 48 | -- | model 49 | SequentialT (TracedT (PopulationT m)) a -> 50 | PopulationT m a 51 | rmsmc (MCMCConfig {..}) (SMCConfig {..}) = 52 | marginal 53 | . S.sequentially (composeCopies numMCMCSteps mhStep . TrStat.hoist resampler) numSteps 54 | . S.hoistFirst (TrStat.hoist (spawn numParticles >>)) 55 | 56 | -- | Resample-move Sequential Monte Carlo with a more efficient 57 | -- tracing representation. 58 | rmsmcBasic :: 59 | (MonadDistribution m) => 60 | MCMCConfig -> 61 | SMCConfig m -> 62 | -- | model 63 | SequentialT (TrBas.TracedT (PopulationT m)) a -> 64 | PopulationT m a 65 | rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) = 66 | TrBas.marginal 67 | . S.sequentially (composeCopies numMCMCSteps TrBas.mhStep . TrBas.hoist resampler) numSteps 68 | . S.hoistFirst (TrBas.hoist (withParticles numParticles)) 69 | 70 | -- | A variant of resample-move Sequential Monte Carlo 71 | -- where only random variables since last resampling are considered 72 | -- for rejuvenation. 73 | rmsmcDynamic :: 74 | (MonadDistribution m) => 75 | MCMCConfig -> 76 | SMCConfig m -> 77 | -- | model 78 | SequentialT (TrDyn.TracedT (PopulationT m)) a -> 79 | PopulationT m a 80 | rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) = 81 | TrDyn.marginal 82 | . S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps TrDyn.mhStep . TrDyn.hoist resampler) numSteps 83 | . S.hoistFirst (TrDyn.hoist (withParticles numParticles)) 84 | 85 | -- | Apply a function a given number of times. 
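-- For example, 'composeCopies 3 f' behaves like 'f . f . f'; above it is used to
-- apply 'mhStep' 'numMCMCSteps' times after each resampling step.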
86 | composeCopies :: Int -> (a -> a) -> (a -> a) 87 | composeCopies k = withEndo (mconcat . replicate k) 88 | 89 | withEndo :: (Endo a -> Endo b) -> (a -> a) -> b -> b 90 | withEndo f = appEndo . f . Endo 91 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/SMC.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | 4 | -- | 5 | -- Module : Control.Monad.Bayes.Inference.SMC 6 | -- Description : Sequential Monte Carlo (SMC) 7 | -- Copyright : (c) Adam Scibior, 2015-2020 8 | -- License : MIT 9 | -- Maintainer : leonhard.markert@tweag.io 10 | -- Stability : experimental 11 | -- Portability : GHC 12 | -- 13 | -- Sequential Monte Carlo (SMC) sampling. 14 | -- 15 | -- Arnaud Doucet and Adam M. Johansen. 2011. A tutorial on particle filtering and smoothing: fifteen years later. In /The Oxford Handbook of Nonlinear Filtering/, Dan Crisan and Boris Rozovskii (Eds.). Oxford University Press, Chapter 8. 16 | module Control.Monad.Bayes.Inference.SMC 17 | ( smc, 18 | smcPush, 19 | SMCConfig (..), 20 | ) 21 | where 22 | 23 | import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure) 24 | import Control.Monad.Bayes.Population 25 | ( PopulationT, 26 | pushEvidence, 27 | withParticles, 28 | ) 29 | import Control.Monad.Bayes.Sequential.Coroutine as Coroutine 30 | 31 | data SMCConfig m = SMCConfig 32 | { resampler :: forall x. PopulationT m x -> PopulationT m x, 33 | numSteps :: Int, 34 | numParticles :: Int 35 | } 36 | 37 | -- | Sequential importance resampling. 38 | -- Basically an SMC template that takes a custom resampler. 39 | smc :: 40 | (MonadDistribution m) => 41 | SMCConfig m -> 42 | Coroutine.SequentialT (PopulationT m) a -> 43 | PopulationT m a 44 | smc SMCConfig {..} = 45 | Coroutine.sequentially resampler numSteps 46 | . Coroutine.hoistFirst (withParticles numParticles) 47 | 48 | -- | Sequential Monte Carlo with multinomial resampling at each timestep. 49 | -- Weights are normalized at each timestep and the total weight is pushed 50 | -- as a score into the transformed monad. 51 | smcPush :: 52 | (MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a 53 | smcPush config = smc config {resampler = (pushEvidence . resampler config)} 54 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/SMC2.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 3 | 4 | -- | 5 | -- Module : Control.Monad.Bayes.Inference.SMC2 6 | -- Description : Sequential Monte Carlo squared (SMC²) 7 | -- Copyright : (c) Adam Scibior, 2015-2020 8 | -- License : MIT 9 | -- Maintainer : leonhard.markert@tweag.io 10 | -- Stability : experimental 11 | -- Portability : GHC 12 | -- 13 | -- Sequential Monte Carlo squared (SMC²) sampling. 14 | -- 15 | -- Nicolas Chopin, Pierre E. Jacob, and Omiros Papaspiliopoulos. 2013. SMC²: an efficient algorithm for sequential analysis of state space models. /Journal of the Royal Statistical Society Series B: Statistical Methodology/ 75 (2013), 397-426. Issue 3. 
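-- An illustrative way to run 'smc2' (an editor's sketch, not part of the original module; @paramPrior@, @model@ and the lowercase counts are assumed names). Its numeric arguments are, in order: the number of time steps, inner particles, outer particles, and MH transitions per step:
--
-- > sampler $ runPopulationT $ smc2 numTimeSteps numInner numOuter numMH paramPrior model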
16 | module Control.Monad.Bayes.Inference.SMC2 17 | ( smc2, 18 | SMC2, 19 | ) 20 | where 21 | 22 | import Control.Monad.Bayes.Class 23 | ( MonadDistribution (random), 24 | MonadFactor (..), 25 | MonadMeasure, 26 | MonadUniformRange (uniformR), 27 | ) 28 | import Control.Monad.Bayes.Inference.MCMC 29 | import Control.Monad.Bayes.Inference.RMSMC (rmsmc) 30 | import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush) 31 | import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT) 32 | import Control.Monad.Bayes.Sequential.Coroutine (SequentialT) 33 | import Control.Monad.Bayes.Traced 34 | import Control.Monad.Trans (MonadTrans (..)) 35 | import Numeric.Log (Log) 36 | 37 | -- | Helper monad transformer for preprocessing the model for 'smc2'. 38 | newtype SMC2 m a = SMC2 (SequentialT (TracedT (PopulationT m)) a) 39 | deriving newtype (Functor, Applicative, Monad) 40 | 41 | setup :: SMC2 m a -> SequentialT (TracedT (PopulationT m)) a 42 | setup (SMC2 m) = m 43 | 44 | instance MonadTrans SMC2 where 45 | lift = SMC2 . lift . lift . lift 46 | 47 | instance (MonadDistribution m) => MonadDistribution (SMC2 m) where 48 | random = lift random 49 | 50 | instance (MonadUniformRange m) => MonadUniformRange (SMC2 m) where 51 | uniformR l u = lift $ uniformR l u 52 | 53 | instance (Monad m) => MonadFactor (SMC2 m) where 54 | score = SMC2 . score 55 | 56 | instance (MonadDistribution m) => MonadMeasure (SMC2 m) 57 | 58 | -- | Sequential Monte Carlo squared. 59 | smc2 :: 60 | (MonadDistribution m) => 61 | -- | number of time steps 62 | Int -> 63 | -- | number of inner particles 64 | Int -> 65 | -- | number of outer particles 66 | Int -> 67 | -- | number of MH transitions 68 | Int -> 69 | -- | model parameters 70 | SequentialT (TracedT (PopulationT m)) b -> 71 | -- | model 72 | (b -> SequentialT (PopulationT (SMC2 m)) a) -> 73 | PopulationT m [(a, Log Double)] 74 | smc2 k n p t param m = 75 | rmsmc 76 | MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0} 77 | SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial} 78 | (param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . 
m) 79 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Inference/TUI.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE ImportQualifiedPost #-} 3 | {-# OPTIONS_GHC -Wno-type-defaults #-} 4 | 5 | module Control.Monad.Bayes.Inference.TUI where 6 | 7 | import Brick 8 | import Brick qualified as B 9 | import Brick.BChan qualified as B 10 | import Brick.Widgets.Border 11 | import Brick.Widgets.Border.Style 12 | import Brick.Widgets.Center 13 | import Brick.Widgets.ProgressBar qualified as B 14 | import Control.Arrow (Arrow (..)) 15 | import Control.Concurrent (forkIO) 16 | import Control.Foldl qualified as Fold 17 | import Control.Monad (void) 18 | import Control.Monad.Bayes.Enumerator (toEmpirical) 19 | import Control.Monad.Bayes.Inference.MCMC 20 | import Control.Monad.Bayes.Sampler.Strict (SamplerIO, sampleIO) 21 | import Control.Monad.Bayes.Traced (TracedT) 22 | import Control.Monad.Bayes.Traced.Common hiding (burnIn) 23 | import Control.Monad.Bayes.Weighted 24 | import Data.Maybe (listToMaybe) 25 | import Data.Scientific (FPFormat (Exponent), formatScientific, fromFloatDigits) 26 | import Data.Text qualified as T 27 | import Data.Text.Lazy qualified as TL 28 | import Data.Text.Lazy.IO qualified as TL 29 | import GHC.Float (double2Float) 30 | import Graphics.Vty 31 | import Graphics.Vty qualified as V 32 | import Graphics.Vty.Platform.Unix qualified as V 33 | import Numeric.Log (Log (ln)) 34 | import Pipes (runEffect, (>->)) 35 | import Pipes qualified as P 36 | import Pipes.Prelude qualified as P 37 | import Text.Pretty.Simple (pShow, pShowNoColor) 38 | 39 | data MCMCData a = MCMCData 40 | { numSteps :: Int, 41 | numSuccesses :: Int, 42 | samples :: [a], 43 | lk :: [Double], 44 | totalSteps :: Int 45 | } 46 | deriving stock (Show) 47 | 48 | -- | Brick is a terminal user interface (TUI) 49 | -- which we use to display inference algorithms in progress 50 | 51 | -- | draw the brick app 52 | drawUI :: ([a] -> Widget n) -> MCMCData a -> [Widget n] 53 | drawUI handleSamples state = [ui] 54 | where 55 | completionBar = 56 | updateAttrMap 57 | ( B.mapAttrNames 58 | [ (doneAttr, B.progressCompleteAttr), 59 | (toDoAttr, B.progressIncompleteAttr) 60 | ] 61 | ) 62 | $ toBar 63 | $ fromIntegral 64 | $ numSteps state 65 | 66 | likelihoodBar = 67 | updateAttrMap 68 | ( B.mapAttrNames 69 | [ (doneAttr, B.progressCompleteAttr), 70 | (toDoAttr, B.progressIncompleteAttr) 71 | ] 72 | ) 73 | $ B.progressBar 74 | (Just $ "Mean likelihood for last 1000 samples: " <> take 10 (maybe "(error)" show (listToMaybe $ lk state <> [0]))) 75 | (double2Float (Fold.fold Fold.mean $ take 1000 $ lk state) / double2Float (maximum $ 0 : lk state)) 76 | 77 | displayStep c = Just $ "Step " <> show c 78 | numFailures = numSteps state - numSuccesses state 79 | toBar v = B.progressBar (displayStep v) (v / fromIntegral (totalSteps state)) 80 | displaySuccessesAndFailures = 81 | withBorderStyle unicode $ 82 | borderWithLabel (str "Successes and failures") $ 83 | center (str (show $ numSuccesses state)) 84 | <+> vBorder 85 | <+> center (str (show numFailures)) 86 | warning = 87 | if numSteps state > 1000 && (fromIntegral (numSuccesses state) / fromIntegral (numSteps state)) < 0.1 88 | then withAttr (attrName "highlight") $ str "Warning: acceptance rate is rather low.\nThis probably means that your proposal isn't good." 
89 | else str "" 90 | 91 | ui = 92 | (str "Progress: " <+> completionBar) 93 | <=> (str "Likelihood: " <+> likelihoodBar) 94 | <=> str "\n" 95 | <=> displaySuccessesAndFailures 96 | <=> warning 97 | <=> handleSamples (samples state) 98 | 99 | noVisual :: b -> Widget n 100 | noVisual = const emptyWidget 101 | 102 | showEmpirical :: (Show a, Ord a) => [a] -> Widget n 103 | showEmpirical = 104 | txt 105 | . T.pack 106 | . TL.unpack 107 | . pShow 108 | . (fmap (second (formatScientific Exponent (Just 3) . fromFloatDigits))) 109 | . toEmpirical 110 | 111 | showVal :: (Show a) => [a] -> Widget n 112 | showVal = txt . T.pack . (\case [] -> ""; a -> maybe "(error)" show $ listToMaybe a) 113 | 114 | -- | handler for events received by the TUI 115 | appEvent :: B.BrickEvent n s -> B.EventM n s () 116 | appEvent (B.VtyEvent (V.EvKey (V.KChar 'q') [])) = B.halt 117 | appEvent (B.VtyEvent _) = pure () 118 | appEvent (B.AppEvent d) = put d 119 | appEvent _ = error "unknown event" 120 | 121 | doneAttr, toDoAttr :: B.AttrName 122 | doneAttr = B.attrName "theBase" <> B.attrName "done" 123 | toDoAttr = B.attrName "theBase" <> B.attrName "remaining" 124 | 125 | theMap :: B.AttrMap 126 | theMap = 127 | B.attrMap 128 | V.defAttr 129 | [ (B.attrName "theBase", bg V.brightBlack), 130 | (doneAttr, V.black `on` V.white), 131 | (toDoAttr, V.white `on` V.black), 132 | (attrName "highlight", fg yellow) 133 | ] 134 | 135 | tui :: (Show a) => Int -> TracedT (WeightedT SamplerIO) a -> ([a] -> Widget ()) -> IO () 136 | tui burnIn distribution visualizer = void do 137 | eventChan <- B.newBChan 10 138 | initialVty <- buildVty 139 | _ <- forkIO $ run (mcmcP MCMCConfig {numBurnIn = burnIn, proposal = SingleSiteMH, numMCMCSteps = -1} distribution) eventChan n 140 | samples <- 141 | B.customMain 142 | initialVty 143 | buildVty 144 | (Just eventChan) 145 | ( ( B.App 146 | { B.appDraw = drawUI visualizer, 147 | B.appChooseCursor = B.showFirstCursor, 148 | B.appHandleEvent = appEvent, 149 | B.appStartEvent = return (), 150 | B.appAttrMap = const theMap 151 | } 152 | ) 153 | ) 154 | (initialState n) 155 | TL.writeFile "data/tui_output.txt" (pShowNoColor samples) 156 | return samples 157 | where 158 | buildVty = V.mkVty V.defaultConfig 159 | n = 100000 160 | initialState n = MCMCData {numSteps = 0, samples = [], lk = [], numSuccesses = 0, totalSteps = n} 161 | 162 | run prod chan i = 163 | runEffect $ 164 | P.hoist (sampleIO . unweighted) prod 165 | >-> P.scan 166 | ( \mcmcdata@(MCMCData ns nsc smples lk _) a -> 167 | mcmcdata 168 | { numSteps = ns + 1, 169 | numSuccesses = nsc + if success a then 1 else 0, 170 | samples = output (trace a) : smples, 171 | lk = exp (ln (probDensity (trace a))) : lk 172 | } 173 | ) 174 | (initialState i) 175 | id 176 | >-> P.take i 177 | >-> P.mapM_ (B.writeBChan chan) 178 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Integrator.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE DerivingStrategies #-} 3 | {-# LANGUAGE FlexibleContexts #-} 4 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 5 | {-# LANGUAGE ImportQualifiedPost #-} 6 | {-# OPTIONS_GHC -Wno-type-defaults #-} 7 | {-# OPTIONS_GHC -Wno-unused-top-binds #-} 8 | 9 | -- | 10 | -- This is adapted from https://jtobin.io/giry-monad-implementation 11 | -- but brought into the monad-bayes framework (i.e. 
Integrator is an instance of MonadMeasure) 12 | -- It's largely for debugging other inference methods and didactic use, 13 | -- because brute force integration of measures is 14 | -- only practical for small programs 15 | module Control.Monad.Bayes.Integrator 16 | ( probability, 17 | variance, 18 | expectation, 19 | cdf, 20 | empirical, 21 | enumeratorWith, 22 | histogram, 23 | plotCdf, 24 | volume, 25 | normalize, 26 | Integrator, 27 | momentGeneratingFunction, 28 | cumulantGeneratingFunction, 29 | integrator, 30 | runIntegrator, 31 | ) 32 | where 33 | 34 | import Control.Applicative (Applicative (..)) 35 | import Control.Foldl (Fold) 36 | import Control.Foldl qualified as Foldl 37 | import Control.Monad.Bayes.Class (MonadDistribution (bernoulli, random, uniformD)) 38 | import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT) 39 | import Control.Monad.Cont 40 | ( Cont, 41 | ContT (ContT), 42 | cont, 43 | runCont, 44 | ) 45 | import Data.Foldable (Foldable (..)) 46 | import Data.Set (Set, elems) 47 | import Numeric.Integration.TanhSinh (Result (result), trap) 48 | import Numeric.Log (Log (ln)) 49 | import Statistics.Distribution qualified as Statistics 50 | import Statistics.Distribution.Uniform qualified as Statistics 51 | -- Prelude exports liftA2 from GHC 9.6 on, see https://github.com/haskell/core-libraries-committee/blob/main/guides/export-lifta2-prelude.md 52 | -- import Control.Applicative further up can be removed once we don't support GHC <= 9.4 anymore 53 | 54 | import Prelude hiding (Applicative (..), Foldable (..)) 55 | 56 | newtype Integrator a = Integrator {getIntegrator :: Cont Double a} 57 | deriving newtype (Functor, Applicative, Monad) 58 | 59 | runIntegrator :: (a -> Double) -> Integrator a -> Double 60 | runIntegrator f (Integrator a) = runCont a f 61 | 62 | integrator :: ((a -> Double) -> Double) -> Integrator a 63 | integrator = Integrator . cont 64 | 65 | instance MonadDistribution Integrator where 66 | random = fromDensityFunction $ Statistics.density $ Statistics.uniformDistr 0 1 67 | bernoulli p = Integrator $ cont (\f -> p * f True + (1 - p) * f False) 68 | uniformD ls = fromMassFunction (const (1 / fromIntegral (length ls))) ls 69 | 70 | fromDensityFunction :: (Double -> Double) -> Integrator Double 71 | fromDensityFunction d = Integrator $ 72 | cont $ \f -> 73 | integralWithQuadrature (\x -> f x * d x) 74 | where 75 | integralWithQuadrature = result . last . (\z -> trap z 0 1) 76 | 77 | fromMassFunction :: (Foldable f) => (a -> Double) -> f a -> Integrator a 78 | fromMassFunction f support = Integrator $ cont \g -> 79 | foldl' (\acc x -> acc + f x * g x) 0 support 80 | 81 | empirical :: (Foldable f) => f a -> Integrator a 82 | empirical = Integrator . cont . 
flip weightedAverage 83 | where 84 | weightedAverage :: (Foldable f, Fractional r) => (a -> r) -> f a -> r 85 | weightedAverage f = Foldl.fold (weightedAverageFold f) 86 | 87 | weightedAverageFold :: (Fractional r) => (a -> r) -> Fold a r 88 | weightedAverageFold f = Foldl.premap f averageFold 89 | 90 | averageFold :: (Fractional a) => Fold a a 91 | averageFold = (/) <$> Foldl.sum <*> Foldl.genericLength 92 | 93 | expectation :: Integrator Double -> Double 94 | expectation = runIntegrator id 95 | 96 | variance :: Integrator Double -> Double 97 | variance nu = runIntegrator (^ 2) nu - expectation nu ^ 2 98 | 99 | momentGeneratingFunction :: Integrator Double -> Double -> Double 100 | momentGeneratingFunction nu t = runIntegrator (\x -> exp (t * x)) nu 101 | 102 | cumulantGeneratingFunction :: Integrator Double -> Double -> Double 103 | cumulantGeneratingFunction nu = log . momentGeneratingFunction nu 104 | 105 | normalize :: WeightedT Integrator a -> Integrator a 106 | normalize m = 107 | let m' = runWeightedT m 108 | z = runIntegrator (ln . exp . snd) m' 109 | in do 110 | (x, d) <- runWeightedT m 111 | Integrator $ cont $ \f -> (f () * (ln $ exp d)) / z 112 | return x 113 | 114 | cdf :: Integrator Double -> Double -> Double 115 | cdf nu x = runIntegrator (negativeInfinity `to` x) nu 116 | where 117 | negativeInfinity :: Double 118 | negativeInfinity = negate (1 / 0) 119 | 120 | to :: (Num a, Ord a) => a -> a -> a -> a 121 | to a b k 122 | | k >= a && k <= b = 1 123 | | otherwise = 0 124 | 125 | volume :: Integrator Double -> Double 126 | volume = runIntegrator (const 1) 127 | 128 | containing :: (Num a, Eq b) => [b] -> b -> a 129 | containing xs x 130 | | x `elem` xs = 1 131 | | otherwise = 0 132 | 133 | instance (Num a) => Num (Integrator a) where 134 | (+) = liftA2 (+) 135 | (-) = liftA2 (-) 136 | (*) = liftA2 (*) 137 | abs = fmap abs 138 | signum = fmap signum 139 | fromInteger = pure . fromInteger 140 | 141 | probability :: (Ord a) => (a, a) -> Integrator a -> Double 142 | probability (lower, upper) = runIntegrator (\x -> if x < upper && x >= lower then 1 else 0) 143 | 144 | enumeratorWith :: (Ord a) => Set a -> Integrator a -> [(a, Double)] 145 | enumeratorWith ls meas = 146 | [ ( val, 147 | runIntegrator 148 | (\x -> if x == val then 1 else 0) 149 | meas 150 | ) 151 | | val <- elems ls 152 | ] 153 | 154 | histogram :: 155 | (Enum a, Ord a, Fractional a) => 156 | Int -> 157 | a -> 158 | WeightedT Integrator a -> 159 | [(a, Double)] 160 | histogram nBins binSize model = do 161 | x <- take nBins [1 ..] 162 | let transform k = (k - (fromIntegral nBins / 2)) * binSize 163 | return 164 | ( (fst) 165 | (transform x, transform (x + 1)), 166 | probability (transform x, transform (x + 1)) $ normalize model 167 | ) 168 | 169 | plotCdf :: Int -> Double -> Double -> Integrator Double -> [(Double, Double)] 170 | plotCdf nBins binSize middlePoint model = do 171 | x <- take nBins [1 ..] 
172 | let transform k = (k - (fromIntegral nBins / 2)) * binSize + middlePoint 173 | return (transform x, cdf model (transform x)) 174 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Sampler/Lazy.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BangPatterns #-} 2 | {-# LANGUAGE DeriveFunctor #-} 3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | 7 | -- | This is a port of the implementation of LazyPPL: https://lazyppl.bitbucket.io/ 8 | module Control.Monad.Bayes.Sampler.Lazy where 9 | 10 | import Control.Monad (ap) 11 | import Control.Monad.Bayes.Class (MonadDistribution (random)) 12 | import Control.Monad.Bayes.Weighted (WeightedT, runWeightedT) 13 | import Control.Monad.IO.Class 14 | import Control.Monad.Identity (Identity (runIdentity)) 15 | import Control.Monad.Trans 16 | import Numeric.Log (Log (..)) 17 | import System.Random 18 | ( RandomGen (split), 19 | getStdGen, 20 | newStdGen, 21 | ) 22 | import System.Random qualified as R 23 | 24 | -- | A 'Tree' is a lazy, infinitely wide and infinitely deep tree, labelled by Doubles. 25 | -- 26 | -- Our source of randomness will be a Tree, populated by uniform [0,1] choices for each label. 27 | -- Often people just use a list or stream instead of a tree. 28 | -- But a tree allows us to be lazy about how far we are going all the time. 29 | data Tree = Tree 30 | { currentUniform :: Double, 31 | lazyUniforms :: Trees 32 | } 33 | 34 | -- | An infinite stream of 'Tree's. 35 | data Trees = Trees 36 | { headTree :: Tree, 37 | tailTrees :: Trees 38 | } 39 | 40 | -- | A probability distribution over @a@ is a function 'Tree -> a'. 41 | -- The idea is that it uses up bits of the tree as it runs. 42 | type Sampler = SamplerT Identity 43 | 44 | runSampler :: Sampler a -> Tree -> a 45 | runSampler = (runIdentity .) . runSamplerT 46 | 47 | newtype SamplerT m a = SamplerT {runSamplerT :: Tree -> m a} 48 | deriving (Functor) 49 | 50 | -- | Split a tree in two (bijectively). 51 | splitTree :: Tree -> (Tree, Tree) 52 | splitTree (Tree r (Trees t ts)) = (t, Tree r ts) 53 | 54 | -- | Generate a tree with uniform random labels. 55 | -- 56 | -- Preliminary for the simulation methods. This uses 'split' to split a random seed. 57 | randomTree :: (RandomGen g) => g -> Tree 58 | randomTree g = let (a, g') = R.random g in Tree a (randomTrees g') 59 | 60 | randomTrees :: (RandomGen g) => g -> Trees 61 | randomTrees g = let (g1, g2) = split g in Trees (randomTree g1) (randomTrees g2) 62 | 63 | instance (Monad m) => Applicative (SamplerT m) where 64 | pure = lift . pure 65 | (<*>) = ap 66 | 67 | -- | Sequencing is done by splitting the tree 68 | -- and using different bits for different computations. 69 | instance (Monad m) => Monad (SamplerT m) where 70 | return = pure 71 | (SamplerT m) >>= f = SamplerT \g -> do 72 | let (g1, g2) = splitTree g 73 | a <- m g1 74 | let SamplerT m' = f a 75 | m' g2 76 | 77 | instance MonadTrans SamplerT where 78 | lift = SamplerT . const 79 | 80 | instance (MonadIO m) => MonadIO (SamplerT m) where 81 | liftIO = lift . liftIO 82 | 83 | -- | Sampling gets the label at the head of the tree and discards the rest. 84 | instance (Monad m) => MonadDistribution (SamplerT m) where 85 | random = SamplerT \(Tree r _) -> pure r 86 | 87 | -- | Runs a 'SamplerT' by creating a new 'StdGen'. 
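--
-- For example (an illustrative sketch added for exposition), a finite program can be run directly; every call reseeds from the global generator:
--
-- > import Control.Monad (replicateM)
-- > runSamplerTIO (replicateM 3 random) :: IO [Double]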
88 | runSamplerTIO :: (MonadIO m) => SamplerT m a -> m a 89 | runSamplerTIO m = liftIO newStdGen *> (runSamplerT m =<< randomTree <$> liftIO getStdGen) 90 | 91 | -- | Draw a stream of independent samples. 92 | independent :: (Monad m) => m a -> m [a] 93 | independent = sequence . repeat 94 | 95 | -- | Runs a probability measure and gets out a stream of @(result,weight)@ pairs 96 | weightedSamples :: (MonadIO m) => WeightedT (SamplerT m) a -> m [(a, Log Double)] 97 | weightedSamples = runSamplerTIO . sequence . repeat . runWeightedT 98 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Sampler/Strict.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE DerivingStrategies #-} 3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 4 | {-# LANGUAGE ImportQualifiedPost #-} 5 | {-# LANGUAGE UnboxedTuples #-} 6 | {-# LANGUAGE UndecidableInstances #-} 7 | 8 | -- | 9 | -- Module : Control.Monad.Bayes.Sampler 10 | -- Description : Pseudo-random sampling monads 11 | -- Copyright : (c) Adam Scibior, 2015-2020 12 | -- License : MIT 13 | -- Maintainer : leonhard.markert@tweag.io 14 | -- Stability : experimental 15 | -- Portability : GHC 16 | -- 17 | -- 'SamplerIO' and 'SamplerST' are instances of 'MonadDistribution'. Apply a 'MonadFactor' 18 | -- transformer to obtain a 'MonadMeasure' that can execute probabilistic models. 19 | module Control.Monad.Bayes.Sampler.Strict 20 | ( SamplerT (..), 21 | SamplerIO, 22 | SamplerST, 23 | sampleIO, 24 | sampleIOfixed, 25 | sampleWith, 26 | sampleSTfixed, 27 | sampleMean, 28 | sampler, 29 | ) 30 | where 31 | 32 | import Control.Foldl qualified as F hiding (random) 33 | import Control.Monad.Bayes.Class 34 | ( MonadDistribution 35 | ( bernoulli, 36 | beta, 37 | categorical, 38 | gamma, 39 | geometric, 40 | normal, 41 | random, 42 | uniform 43 | ), 44 | MonadUniformRange 45 | ( uniformR 46 | ), 47 | ) 48 | import Control.Monad.Primitive (PrimMonad) 49 | import Control.Monad.Reader (MonadIO, ReaderT (..)) 50 | import Control.Monad.ST (ST) 51 | import Control.Monad.Trans (MonadTrans) 52 | import Numeric.Log (Log (ln)) 53 | import System.Random.MWC.Distributions qualified as MWC 54 | import System.Random.Stateful (IOGenM (..), STGenM, StatefulGen, StdGen, initStdGen, mkStdGen, newIOGenM, newSTGenM, uniformDouble01M, uniformRM) 55 | 56 | -- | The sampling interpretation of a probabilistic program 57 | -- Here m is typically IO or ST 58 | newtype SamplerT g m a = SamplerT {runSamplerT :: ReaderT g m a} deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, PrimMonad) 59 | 60 | -- | convenient type synonym to show specializations of SamplerT 61 | -- to particular pairs of monad and RNG 62 | type SamplerIO = SamplerT (IOGenM StdGen) IO 63 | 64 | -- | convenient type synonym to show specializations of SamplerT 65 | -- to particular pairs of monad and RNG 66 | type SamplerST s = SamplerT (STGenM StdGen s) (ST s) 67 | 68 | instance (StatefulGen g m) => MonadDistribution (SamplerT g m) where 69 | random = SamplerT (ReaderT uniformDouble01M) 70 | 71 | uniform a b = SamplerT (ReaderT $ uniformRM (a, b)) 72 | normal m s = SamplerT (ReaderT (MWC.normal m s)) 73 | gamma shape scale = SamplerT (ReaderT $ MWC.gamma shape scale) 74 | beta a b = SamplerT (ReaderT $ MWC.beta a b) 75 | 76 | bernoulli p = SamplerT (ReaderT $ MWC.bernoulli p) 77 | categorical ps = SamplerT (ReaderT $ MWC.categorical ps) 78 | geometric p = SamplerT (ReaderT $ MWC.geometric0 
p) 79 | 80 | instance (StatefulGen g m) => MonadUniformRange (SamplerT g m) where 81 | uniformR l u = SamplerT (ReaderT $ uniformRM (l, u)) 82 | 83 | -- | Sample with a random number generator of your choice e.g. the one 84 | -- from `System.Random`. 85 | -- 86 | -- >>> import Control.Monad.Bayes.Class 87 | -- >>> import System.Random.Stateful hiding (random) 88 | -- >>> newIOGenM (mkStdGen 1729) >>= sampleWith random 89 | -- 4.690861245089605e-2 90 | sampleWith :: SamplerT g m a -> g -> m a 91 | sampleWith (SamplerT m) = runReaderT m 92 | 93 | -- | initialize random seed using system entropy, and sample 94 | sampleIO, sampler :: SamplerIO a -> IO a 95 | sampleIO x = initStdGen >>= newIOGenM >>= sampleWith x 96 | sampler = sampleIO 97 | 98 | -- | Run the sampler with a fixed random seed 99 | sampleIOfixed :: SamplerIO a -> IO a 100 | sampleIOfixed x = newIOGenM (mkStdGen 1729) >>= sampleWith x 101 | 102 | -- | Run the sampler with a fixed random seed 103 | sampleSTfixed :: SamplerST s b -> ST s b 104 | sampleSTfixed x = newSTGenM (mkStdGen 1729) >>= sampleWith x 105 | 106 | sampleMean :: [(Double, Log Double)] -> Double 107 | sampleMean samples = 108 | let z = F.premap (ln . exp . snd) F.sum 109 | w = (F.premap (\(x, y) -> x * ln (exp y)) F.sum) 110 | s = (/) <$> w <*> z 111 | in F.fold s samples 112 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Sequential/Coroutine.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | 5 | -- | 6 | -- Module : Control.Monad.Bayes.Sequential 7 | -- Description : Suspendable probabilistic computation 8 | -- Copyright : (c) Adam Scibior, 2015-2020 9 | -- License : MIT 10 | -- Maintainer : leonhard.markert@tweag.io 11 | -- Stability : experimental 12 | -- Portability : GHC 13 | -- 14 | -- 'SequentialT' represents a computation that can be suspended. 15 | module Control.Monad.Bayes.Sequential.Coroutine 16 | ( SequentialT, 17 | suspend, 18 | finish, 19 | advance, 20 | finished, 21 | hoistFirst, 22 | hoist, 23 | sequentially, 24 | sis, 25 | ) 26 | where 27 | 28 | import Control.Monad.Bayes.Class 29 | ( MonadDistribution (bernoulli, categorical, random), 30 | MonadFactor (..), 31 | MonadMeasure, 32 | ) 33 | import Control.Monad.Coroutine 34 | ( Coroutine (..), 35 | bounce, 36 | mapMonad, 37 | pogoStick, 38 | ) 39 | import Control.Monad.Coroutine.SuspensionFunctors 40 | ( Await (..), 41 | await, 42 | ) 43 | import Control.Monad.Trans (MonadIO, MonadTrans (..)) 44 | import Data.Either (isRight) 45 | 46 | -- | Represents a computation that can be suspended at certain points. 47 | -- The intermediate monadic effects can be extracted, which is particularly 48 | -- useful for implementation of Sequential Monte Carlo related methods. 49 | -- All the probabilistic effects are lifted from the transformed monad, but 50 | -- also `suspend` is inserted after each `factor`. 51 | newtype SequentialT m a = SequentialT {runSequentialT :: Coroutine (Await ()) m a} 52 | deriving newtype (Functor, Applicative, Monad, MonadTrans, MonadIO) 53 | 54 | extract :: Await () a -> a 55 | extract (Await f) = f () 56 | 57 | instance (MonadDistribution m) => MonadDistribution (SequentialT m) where 58 | random = lift random 59 | bernoulli = lift . bernoulli 60 | categorical = lift . categorical 61 | 62 | -- | Execution is 'suspend'ed after each 'score'. 
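--
-- This is what sequential importance (re)sampling exploits: between two 'score's the whole population can be transformed, e.g. resampled, via 'hoistFirst'. A small illustration (an editor's sketch, not part of the original module):
--
-- > twoFactors = do x <- bernoulli 0.5
-- >                 score 0.3
-- >                 score 0.7
-- >                 return x
--
-- has two suspension points: 'advance' runs it up to the first 'score', while 'finish' runs it to completion.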
63 | instance (MonadFactor m) => MonadFactor (SequentialT m) where 64 | score w = lift (score w) >> suspend 65 | 66 | instance (MonadMeasure m) => MonadMeasure (SequentialT m) 67 | 68 | -- | A point where the computation is paused. 69 | suspend :: (Monad m) => SequentialT m () 70 | suspend = SequentialT await 71 | 72 | -- | Remove the remaining suspension points. 73 | finish :: (Monad m) => SequentialT m a -> m a 74 | finish = pogoStick extract . runSequentialT 75 | 76 | -- | Execute to the next suspension point. 77 | -- If the computation is finished, do nothing. 78 | -- 79 | -- > finish = finish . advance 80 | advance :: (Monad m) => SequentialT m a -> SequentialT m a 81 | advance = SequentialT . bounce extract . runSequentialT 82 | 83 | -- | Return True if no more suspension points remain. 84 | finished :: (Monad m) => SequentialT m a -> m Bool 85 | finished = fmap isRight . resume . runSequentialT 86 | 87 | -- | Transform the inner monad. 88 | -- This operation only applies to computation up to the first suspension. 89 | hoistFirst :: (forall x. m x -> m x) -> SequentialT m a -> SequentialT m a 90 | hoistFirst f = SequentialT . Coroutine . f . resume . runSequentialT 91 | 92 | -- | Transform the inner monad. 93 | -- The transformation is applied recursively through all the suspension points. 94 | hoist :: 95 | (Monad m, Monad n) => 96 | (forall x. m x -> n x) -> 97 | SequentialT m a -> 98 | SequentialT n a 99 | hoist f = SequentialT . mapMonad f . runSequentialT 100 | 101 | -- | Apply a function a given number of times. 102 | composeCopies :: Int -> (a -> a) -> (a -> a) 103 | composeCopies k f = foldr (.) id (replicate k f) 104 | 105 | -- | Sequential importance sampling. 106 | -- Applies a given transformation after each time step. 107 | sequentially, 108 | sis :: 109 | (Monad m) => 110 | -- | transformation 111 | (forall x. m x -> m x) -> 112 | -- | number of time steps 113 | Int -> 114 | SequentialT m a -> 115 | m a 116 | sequentially f k = finish . composeCopies k (advance . 
hoistFirst f) 117 | 118 | -- | synonym 119 | sis = sequentially 120 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Traced.hs: -------------------------------------------------------------------------------- 1 | -- | 2 | -- Module : Control.Monad.Bayes.Traced 3 | -- Description : Distributions on execution traces 4 | -- Copyright : (c) Adam Scibior, 2015-2020 5 | -- License : MIT 6 | -- Maintainer : leonhard.markert@tweag.io 7 | -- Stability : experimental 8 | -- Portability : GHC 9 | module Control.Monad.Bayes.Traced 10 | ( module Control.Monad.Bayes.Traced.Static, 11 | ) 12 | where 13 | 14 | import Control.Monad.Bayes.Traced.Static 15 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Traced/Basic.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | 3 | -- | 4 | -- Module : Control.Monad.Bayes.Traced.Basic 5 | -- Description : Distributions on full execution traces of full programs 6 | -- Copyright : (c) Adam Scibior, 2015-2020 7 | -- License : MIT 8 | -- Maintainer : leonhard.markert@tweag.io 9 | -- Stability : experimental 10 | -- Portability : GHC 11 | module Control.Monad.Bayes.Traced.Basic 12 | ( TracedT, 13 | hoist, 14 | marginal, 15 | mhStep, 16 | mh, 17 | ) 18 | where 19 | 20 | import Control.Applicative (Applicative (..)) 21 | import Control.Monad.Bayes.Class 22 | ( MonadDistribution (random), 23 | MonadFactor (..), 24 | MonadMeasure, 25 | ) 26 | import Control.Monad.Bayes.Density.Free (DensityT) 27 | import Control.Monad.Bayes.Traced.Common 28 | ( Trace (..), 29 | bind, 30 | mhTrans', 31 | scored, 32 | singleton, 33 | ) 34 | import Control.Monad.Bayes.Weighted (WeightedT) 35 | import Data.Functor.Identity (Identity) 36 | import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) 37 | -- Prelude exports liftA2 from GHC 9.6 on, see https://github.com/haskell/core-libraries-committee/blob/main/guides/export-lifta2-prelude.md 38 | -- import Control.Applicative further up can be removed once we don't support GHC <= 9.4 anymore 39 | 40 | import Prelude hiding (Applicative (..)) 41 | 42 | -- | Tracing monad that records random choices made in the program. 43 | data TracedT m a = TracedT 44 | { -- | Run the program with a modified trace. 45 | model :: WeightedT (DensityT Identity) a, 46 | -- | Record trace and output. 47 | traceDist :: m (Trace a) 48 | } 49 | 50 | instance (Monad m) => Functor (TracedT m) where 51 | fmap f (TracedT m d) = TracedT (fmap f m) (fmap (fmap f) d) 52 | 53 | instance (Monad m) => Applicative (TracedT m) where 54 | pure x = TracedT (pure x) (pure (pure x)) 55 | (TracedT mf df) <*> (TracedT mx dx) = TracedT (mf <*> mx) (liftA2 (<*>) df dx) 56 | 57 | instance (Monad m) => Monad (TracedT m) where 58 | (TracedT mx dx) >>= f = TracedT my dy 59 | where 60 | my = mx >>= model . f 61 | dy = dx `bind` (traceDist . f) 62 | 63 | instance (MonadDistribution m) => MonadDistribution (TracedT m) where 64 | random = TracedT random (fmap singleton random) 65 | 66 | instance (MonadFactor m) => MonadFactor (TracedT m) where 67 | score w = TracedT (score w) (score w >> pure (scored w)) 68 | 69 | instance (MonadMeasure m) => MonadMeasure (TracedT m) 70 | 71 | hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a 72 | hoist f (TracedT m d) = TracedT m (f d) 73 | 74 | -- | Discard the trace and supporting infrastructure. 
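--
-- In other words, @marginal t@ is the underlying program with the tracing machinery stripped away. A full run of this variant is usually driven through 'mh' below; as an editor's sketch (assuming @model :: TracedT (WeightedT SamplerIO) Double@):
--
-- > sampler $ unweighted $ mh 1000 model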
75 | marginal :: (Monad m) => TracedT m a -> m a 76 | marginal (TracedT _ d) = fmap output d 77 | 78 | -- | A single step of the Trace Metropolis-Hastings algorithm. 79 | mhStep :: (MonadDistribution m) => TracedT m a -> TracedT m a 80 | mhStep (TracedT m d) = TracedT m d' 81 | where 82 | d' = d >>= mhTrans' m 83 | 84 | -- | Full run of the Trace Metropolis-Hastings algorithm with a specified 85 | -- number of steps. 86 | mh :: (MonadDistribution m) => Int -> TracedT m a -> m [a] 87 | mh n (TracedT m d) = fmap (map output . NE.toList) (f n) 88 | where 89 | f k 90 | | k <= 0 = fmap (:| []) d 91 | | otherwise = do 92 | (x :| xs) <- f (k - 1) 93 | y <- mhTrans' m x 94 | return (y :| x : xs) 95 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Traced/Common.hs: -------------------------------------------------------------------------------- 1 | -- | 2 | -- Module : Control.Monad.Bayes.Traced.Common 3 | -- Description : Numeric code for Trace MCMC 4 | -- Copyright : (c) Adam Scibior, 2015-2020 5 | -- License : MIT 6 | -- Maintainer : leonhard.markert@tweag.io 7 | -- Stability : experimental 8 | -- Portability : GHC 9 | module Control.Monad.Bayes.Traced.Common 10 | ( Trace (..), 11 | singleton, 12 | scored, 13 | bind, 14 | mhTrans, 15 | mhTransWithBool, 16 | mhTransFree, 17 | mhTrans', 18 | burnIn, 19 | MHResult (..), 20 | ) 21 | where 22 | 23 | import Control.Monad.Bayes.Class 24 | ( MonadDistribution (bernoulli, random), 25 | discrete, 26 | ) 27 | import Control.Monad.Bayes.Density.Free qualified as Free 28 | import Control.Monad.Bayes.Density.State qualified as State 29 | import Control.Monad.Bayes.Weighted as WeightedT 30 | ( WeightedT, 31 | hoist, 32 | runWeightedT, 33 | ) 34 | import Control.Monad.Writer (WriterT (WriterT, runWriterT)) 35 | import Data.Functor.Identity (Identity (runIdentity)) 36 | import Numeric.Log (Log, ln) 37 | import Statistics.Distribution.DiscreteUniform (discreteUniformAB) 38 | 39 | data MHResult a = MHResult 40 | { success :: Bool, 41 | trace :: Trace a 42 | } 43 | 44 | -- | Collection of random variables sampler during the program's execution. 45 | data Trace a = Trace 46 | { -- | Sequence of random variables sampler during the program's execution. 47 | variables :: [Double], 48 | -- 49 | output :: a, 50 | -- | The probability of observing this particular sequence. 51 | probDensity :: Log Double 52 | } 53 | 54 | instance Functor Trace where 55 | fmap f t = t {output = f (output t)} 56 | 57 | instance Applicative Trace where 58 | pure x = Trace {variables = [], output = x, probDensity = 1} 59 | tf <*> tx = 60 | Trace 61 | { variables = variables tf ++ variables tx, 62 | output = output tf (output tx), 63 | probDensity = probDensity tf * probDensity tx 64 | } 65 | 66 | instance Monad Trace where 67 | t >>= f = 68 | let t' = f (output t) 69 | in t' {variables = variables t ++ variables t', probDensity = probDensity t * probDensity t'} 70 | 71 | singleton :: Double -> Trace Double 72 | singleton u = Trace {variables = [u], output = u, probDensity = 1} 73 | 74 | scored :: Log Double -> Trace () 75 | scored w = Trace {variables = [], output = (), probDensity = w} 76 | 77 | bind :: (Monad m) => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b) 78 | bind dx f = do 79 | t1 <- dx 80 | t2 <- f (output t1) 81 | return $ t2 {variables = variables t1 ++ variables t2, probDensity = probDensity t1 * probDensity t2} 82 | 83 | -- | A single Metropolis-corrected transition of single-site Trace MCMC. 
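--
-- Concretely: pick one of the @n@ recorded sites uniformly at random, replace it with a fresh standard uniform, rerun the program on the perturbed trace (giving density @q@ and @n'@ sites), and accept the proposal with probability
--
-- > min 1 ((q * n) / (p * n'))
--
-- where @p@ is the density of the current trace. The ratio of trace lengths accounts for the asymmetry of the proposal when the number of random choices changes between runs.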
84 | mhTrans :: (MonadDistribution m) => (WeightedT (State.DensityT m)) a -> Trace a -> m (Trace a) 85 | mhTrans m t@Trace {variables = us, probDensity = p} = do 86 | let n = length us 87 | us' <- do 88 | i <- discrete $ discreteUniformAB 0 (n - 1) 89 | u' <- random 90 | case splitAt i us of 91 | (xs, _ : ys) -> return $ xs ++ (u' : ys) 92 | _ -> error "impossible" 93 | ((b, q), vs) <- State.runDensityT (runWeightedT m) us' 94 | let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs))) 95 | accept <- bernoulli ratio 96 | return $ if accept then Trace vs b q else t 97 | 98 | mhTransFree :: (MonadDistribution m) => WeightedT (Free.DensityT m) a -> Trace a -> m (Trace a) 99 | mhTransFree m t = trace <$> mhTransWithBool m t 100 | 101 | -- | A single Metropolis-corrected transition of single-site Trace MCMC. 102 | mhTransWithBool :: (MonadDistribution m) => WeightedT (Free.DensityT m) a -> Trace a -> m (MHResult a) 103 | mhTransWithBool m t@Trace {variables = us, probDensity = p} = do 104 | let n = length us 105 | us' <- do 106 | i <- discrete $ discreteUniformAB 0 (n - 1) 107 | u' <- random 108 | case splitAt i us of 109 | (xs, _ : ys) -> return $ xs ++ (u' : ys) 110 | _ -> error "impossible" 111 | ((b, q), vs) <- runWriterT $ runWeightedT $ WeightedT.hoist (WriterT . Free.runDensityT us') m 112 | let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs))) 113 | accept <- bernoulli ratio 114 | return if accept then MHResult True (Trace vs b q) else MHResult False t 115 | 116 | -- | A variant of 'mhTrans' with an external sampling monad. 117 | mhTrans' :: (MonadDistribution m) => WeightedT (Free.DensityT Identity) a -> Trace a -> m (Trace a) 118 | mhTrans' m = mhTransFree (WeightedT.hoist (Free.hoist (return . runIdentity)) m) 119 | 120 | -- | burn in an MCMC chain for n steps (which amounts to dropping samples of the end of the list) 121 | burnIn :: (Functor m) => Int -> m [a] -> m [a] 122 | burnIn n = fmap dropEnd 123 | where 124 | dropEnd ls = let len = length ls in take (len - n) ls 125 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Traced/Dynamic.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | 3 | -- | 4 | -- Module : Control.Monad.Bayes.Traced.Dynamic 5 | -- Description : Distributions on execution traces that can be dynamically frozen 6 | -- Copyright : (c) Adam Scibior, 2015-2020 7 | -- License : MIT 8 | -- Maintainer : leonhard.markert@tweag.io 9 | -- Stability : experimental 10 | -- Portability : GHC 11 | module Control.Monad.Bayes.Traced.Dynamic 12 | ( TracedT, 13 | hoist, 14 | marginal, 15 | freeze, 16 | mhStep, 17 | mh, 18 | ) 19 | where 20 | 21 | import Control.Monad (join) 22 | import Control.Monad.Bayes.Class 23 | ( MonadDistribution (random), 24 | MonadFactor (..), 25 | MonadMeasure, 26 | ) 27 | import Control.Monad.Bayes.Density.Free (DensityT) 28 | import Control.Monad.Bayes.Traced.Common 29 | ( Trace (..), 30 | bind, 31 | mhTransFree, 32 | scored, 33 | singleton, 34 | ) 35 | import Control.Monad.Bayes.Weighted (WeightedT) 36 | import Control.Monad.Trans (MonadTrans (..)) 37 | import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) 38 | 39 | -- | A tracing monad where only a subset of random choices are traced and this 40 | -- subset can be adjusted dynamically. 
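--
-- The key extra operation is 'freeze' below: it fixes all currently traced choices at their present values, so later 'mhStep's only propose changes to variables introduced afterwards. 'rmsmcDynamic' relies on exactly this; schematically (an editor's paraphrase, helper names as in the RMSMC module):
--
-- > freeze . composeCopies numMCMCSteps mhStep . hoist resampler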
41 | newtype TracedT m a = TracedT {runTraced :: m (WeightedT (DensityT m) a, Trace a)} 42 | 43 | pushM :: (Monad m) => m (WeightedT (DensityT m) a) -> WeightedT (DensityT m) a 44 | pushM = join . lift . lift 45 | 46 | instance (Monad m) => Functor (TracedT m) where 47 | fmap f (TracedT c) = TracedT $ do 48 | (m, t) <- c 49 | let m' = fmap f m 50 | let t' = fmap f t 51 | return (m', t') 52 | 53 | instance (Monad m) => Applicative (TracedT m) where 54 | pure x = TracedT $ pure (pure x, pure x) 55 | (TracedT cf) <*> (TracedT cx) = TracedT $ do 56 | (mf, tf) <- cf 57 | (mx, tx) <- cx 58 | return (mf <*> mx, tf <*> tx) 59 | 60 | instance (Monad m) => Monad (TracedT m) where 61 | (TracedT cx) >>= f = TracedT $ do 62 | (mx, tx) <- cx 63 | let m = mx >>= pushM . fmap fst . runTraced . f 64 | t <- return tx `bind` (fmap snd . runTraced . f) 65 | return (m, t) 66 | 67 | instance MonadTrans TracedT where 68 | lift m = TracedT $ fmap ((,) (lift $ lift m) . pure) m 69 | 70 | instance (MonadDistribution m) => MonadDistribution (TracedT m) where 71 | random = TracedT $ fmap ((,) random . singleton) random 72 | 73 | instance (MonadFactor m) => MonadFactor (TracedT m) where 74 | score w = TracedT $ fmap (score w,) (score w >> pure (scored w)) 75 | 76 | instance (MonadMeasure m) => MonadMeasure (TracedT m) 77 | 78 | hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a 79 | hoist f (TracedT c) = TracedT (f c) 80 | 81 | -- | Discard the trace and supporting infrastructure. 82 | marginal :: (Monad m) => TracedT m a -> m a 83 | marginal (TracedT c) = fmap (output . snd) c 84 | 85 | -- | Freeze all traced random choices to their current values and stop tracing 86 | -- them. 87 | freeze :: (Monad m) => TracedT m a -> TracedT m a 88 | freeze (TracedT c) = TracedT $ do 89 | (_, t) <- c 90 | let x = output t 91 | return (return x, pure x) 92 | 93 | -- | A single step of the Trace Metropolis-Hastings algorithm. 94 | mhStep :: (MonadDistribution m) => TracedT m a -> TracedT m a 95 | mhStep (TracedT c) = TracedT $ do 96 | (m, t) <- c 97 | t' <- mhTransFree m t 98 | return (m, t') 99 | 100 | -- | Full run of the Trace Metropolis-Hastings algorithm with a specified 101 | -- number of steps. 102 | mh :: (MonadDistribution m) => Int -> TracedT m a -> m [a] 103 | mh n (TracedT c) = do 104 | (m, t) <- c 105 | let f k 106 | | k <= 0 = return (t :| []) 107 | | otherwise = do 108 | (x :| xs) <- f (k - 1) 109 | y <- mhTransFree m x 110 | return (y :| x : xs) 111 | fmap (map output . 
NE.toList) (f n) 112 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Traced/Static.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | 4 | -- | 5 | -- Module : Control.Monad.Bayes.Traced.Static 6 | -- Description : Distributions on execution traces of full programs 7 | -- Copyright : (c) Adam Scibior, 2015-2020 8 | -- License : MIT 9 | -- Maintainer : leonhard.markert@tweag.io 10 | -- Stability : experimental 11 | -- Portability : GHC 12 | module Control.Monad.Bayes.Traced.Static 13 | ( TracedT (..), 14 | hoist, 15 | marginal, 16 | mhStep, 17 | mh, 18 | ) 19 | where 20 | 21 | import Control.Applicative (Applicative (..)) 22 | import Control.Monad.Bayes.Class 23 | ( MonadDistribution (random), 24 | MonadFactor (..), 25 | MonadMeasure, 26 | ) 27 | import Control.Monad.Bayes.Density.Free (DensityT) 28 | import Control.Monad.Bayes.Traced.Common 29 | ( Trace (..), 30 | bind, 31 | mhTransFree, 32 | scored, 33 | singleton, 34 | ) 35 | import Control.Monad.Bayes.Weighted (WeightedT) 36 | import Control.Monad.Trans (MonadTrans (..)) 37 | import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList) 38 | -- Prelude exports liftA2 from GHC 9.6 on, see https://github.com/haskell/core-libraries-committee/blob/main/guides/export-lifta2-prelude.md 39 | -- import Control.Applicative further up can be removed once we don't support GHC <= 9.4 anymore 40 | 41 | import Prelude hiding (Applicative (..)) 42 | 43 | -- | A tracing monad where only a subset of random choices are traced. 44 | -- 45 | -- The random choices that are not to be traced should be lifted from the 46 | -- transformed monad. 47 | data TracedT m a = TracedT 48 | { model :: WeightedT (DensityT m) a, 49 | traceDist :: m (Trace a) 50 | } 51 | 52 | instance (Monad m) => Functor (TracedT m) where 53 | fmap f (TracedT m d) = TracedT (fmap f m) (fmap (fmap f) d) 54 | 55 | instance (Monad m) => Applicative (TracedT m) where 56 | pure x = TracedT (pure x) (pure (pure x)) 57 | (TracedT mf df) <*> (TracedT mx dx) = TracedT (mf <*> mx) (liftA2 (<*>) df dx) 58 | 59 | instance (Monad m) => Monad (TracedT m) where 60 | (TracedT mx dx) >>= f = TracedT my dy 61 | where 62 | my = mx >>= model . f 63 | dy = dx `bind` (traceDist . f) 64 | 65 | instance MonadTrans TracedT where 66 | lift m = TracedT (lift $ lift m) (fmap pure m) 67 | 68 | instance (MonadDistribution m) => MonadDistribution (TracedT m) where 69 | random = TracedT random (fmap singleton random) 70 | 71 | instance (MonadFactor m) => MonadFactor (TracedT m) where 72 | score w = TracedT (score w) (score w >> pure (scored w)) 73 | 74 | instance (MonadMeasure m) => MonadMeasure (TracedT m) 75 | 76 | hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a 77 | hoist f (TracedT m d) = TracedT m (f d) 78 | 79 | -- | Discard the trace and supporting infrastructure. 80 | marginal :: (Monad m) => TracedT m a -> m a 81 | marginal (TracedT _ d) = fmap output d 82 | 83 | -- | A single step of the Trace Metropolis-Hastings algorithm. 
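--
-- The 'model' field (the weighted density) is reused to evaluate the proposal, and only 'traceDist' is advanced; iterating this step yields the chain that 'mh' below collects, e.g. (illustrative):
--
-- > marginal (mhStep (mhStep t))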
84 | mhStep :: (MonadDistribution m) => TracedT m a -> TracedT m a 85 | mhStep (TracedT m d) = TracedT m d' 86 | where 87 | d' = d >>= mhTransFree m 88 | 89 | -- $setup 90 | -- >>> import Control.Monad.Bayes.Class 91 | -- >>> import Control.Monad.Bayes.Sampler.Strict 92 | -- >>> import Control.Monad.Bayes.Weighted 93 | 94 | -- | Full run of the Trace Metropolis-Hastings algorithm with a specified 95 | -- number of steps. Newest samples are at the head of the list. 96 | -- 97 | -- For example: 98 | -- 99 | -- * I have forgotten what day it is. 100 | -- * There are ten buses per hour in the week and three buses per hour at the weekend. 101 | -- * I observe four buses in a given hour. 102 | -- * What is the probability that it is the weekend? 103 | -- 104 | -- >>> :{ 105 | -- let 106 | -- bus = do x <- bernoulli (2/7) 107 | -- let rate = if x then 3 else 10 108 | -- factor $ poissonPdf rate 4 109 | -- return x 110 | -- mhRunBusSingleObs = do 111 | -- let nSamples = 2 112 | -- sampleIOfixed $ unweighted $ mh nSamples bus 113 | -- in mhRunBusSingleObs 114 | -- :} 115 | -- [True,True,True] 116 | -- 117 | -- Of course, it will need to be run more than twice to get a reasonable estimate. 118 | mh :: (MonadDistribution m) => Int -> TracedT m a -> m [a] 119 | mh n (TracedT m d) = fmap (map output . NE.toList) (f n) 120 | where 121 | f k 122 | | k <= 0 = fmap (:| []) d 123 | | otherwise = do 124 | (x :| xs) <- f (k - 1) 125 | y <- mhTransFree m x 126 | return (y :| x : xs) 127 | -------------------------------------------------------------------------------- /src/Control/Monad/Bayes/Weighted.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | 5 | -- | 6 | -- Module : Control.Monad.Bayes.Weighted 7 | -- Description : Probability monad accumulating the likelihood 8 | -- Copyright : (c) Adam Scibior, 2015-2020 9 | -- License : MIT 10 | -- Maintainer : leonhard.markert@tweag.io 11 | -- Stability : experimental 12 | -- Portability : GHC 13 | -- 14 | -- 'WeightedT' is an instance of 'MonadFactor'. Apply a 'MonadDistribution' transformer to 15 | -- obtain a 'MonadMeasure' that can execute probabilistic models. 16 | module Control.Monad.Bayes.Weighted 17 | ( WeightedT, 18 | weightedT, 19 | extractWeight, 20 | unweighted, 21 | applyWeight, 22 | hoist, 23 | runWeightedT, 24 | ) 25 | where 26 | 27 | import Control.Monad.Bayes.Class 28 | ( MonadDistribution, 29 | MonadFactor (..), 30 | MonadMeasure, 31 | MonadUniformRange, 32 | factor, 33 | ) 34 | import Control.Monad.State (MonadIO, MonadTrans, StateT (..), lift, mapStateT, modify) 35 | import Numeric.Log (Log) 36 | 37 | -- | Execute the program using the prior distribution, while accumulating likelihood. 38 | newtype WeightedT m a = WeightedT (StateT (Log Double) m a) 39 | -- StateT is more efficient than WriterT 40 | deriving newtype (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadDistribution, MonadUniformRange) 41 | 42 | instance (Monad m) => MonadFactor (WeightedT m) where 43 | score w = WeightedT (modify (* w)) 44 | 45 | instance (MonadDistribution m) => MonadMeasure (WeightedT m) 46 | 47 | -- | Obtain an explicit value of the likelihood for a given value. 48 | runWeightedT :: WeightedT m a -> m (a, Log Double) 49 | runWeightedT (WeightedT m) = runStateT m 1 50 | 51 | -- | Compute the sample and discard the weight. 52 | -- 53 | -- This operation introduces bias. 
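--
-- Dropping the accumulated likelihood means the result is distributed according to the prior (the proposal) rather than the posterior, so this is only appropriate when the remaining weight carries no information that still needs correcting for, e.g. after a Metropolis-Hastings run as in the MCMC module. Note also the law (an editor's addition, which follows from @runWeightedT . weightedT = id@):
--
-- > unweighted . weightedT = fmap fst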
54 | unweighted :: (Functor m) => WeightedT m a -> m a 55 | unweighted = fmap fst . runWeightedT 56 | 57 | -- | Compute the weight and discard the sample. 58 | extractWeight :: (Functor m) => WeightedT m a -> m (Log Double) 59 | extractWeight = fmap snd . runWeightedT 60 | 61 | -- | Embed a random variable with explicitly given likelihood. 62 | -- 63 | -- > runWeightedT . weightedT = id 64 | weightedT :: (Monad m) => m (a, Log Double) -> WeightedT m a 65 | weightedT m = WeightedT $ do 66 | (x, w) <- lift m 67 | modify (* w) 68 | return x 69 | 70 | -- | Use the weight as a factor in the transformed monad. 71 | applyWeight :: (MonadFactor m) => WeightedT m a -> m a 72 | applyWeight m = do 73 | (x, w) <- runWeightedT m 74 | factor w 75 | return x 76 | 77 | -- | Apply a transformation to the transformed monad. 78 | hoist :: (forall x. m x -> n x) -> WeightedT m a -> WeightedT n a 79 | hoist t (WeightedT m) = WeightedT $ mapStateT t m 80 | -------------------------------------------------------------------------------- /src/Math/Integrators/StormerVerlet.hs: -------------------------------------------------------------------------------- 1 | module Math.Integrators.StormerVerlet 2 | ( integrateV, 3 | stormerVerlet2H, 4 | Integrator, 5 | ) 6 | where 7 | 8 | import Control.Lens 9 | import Control.Monad.Primitive 10 | import Data.Vector (Vector, (!)) 11 | import Data.Vector qualified as V 12 | import Data.Vector.Mutable 13 | import Linear (V2 (..)) 14 | 15 | -- | Integrator function: 16 | -- \( \Phi_h : y_0 \mapsto y_1 \) 17 | type Integrator a = 18 | -- | Step 19 | Double -> 20 | -- | Initial value 21 | a -> 22 | -- | Next value 23 | a 24 | 25 | -- | Störmer-Verlet integration scheme for systems of the form 26 | -- \(\mathbb{H}(p,q) = T(p) + V(q)\) 27 | stormerVerlet2H :: 28 | (Applicative f, Num (f a), Fractional a) => 29 | -- | Step size 30 | a -> 31 | -- | \(\frac{\partial H}{\partial q}\) 32 | (f a -> f a) -> 33 | -- | \(\frac{\partial H}{\partial p}\) 34 | (f a -> f a) -> 35 | -- | Current \((p, q)\) as a 2-dimensional vector 36 | V2 (f a) -> 37 | -- | New \((p, q)\) as a 2-dimensional vector 38 | V2 (f a) 39 | stormerVerlet2H hh nablaQ nablaP prev = 40 | V2 qNew pNew 41 | where 42 | h2 = hh / 2 43 | hhs = pure hh 44 | hh2s = pure h2 45 | qsPrev = prev ^. _1 46 | psPrev = prev ^. _2 47 | pp2 = psPrev - hh2s * nablaQ qsPrev 48 | qNew = qsPrev + hhs * nablaP pp2 49 | pNew = pp2 - hh2s * nablaQ qNew 50 | 51 | -- | 52 | -- Integrate an ODE using fixed steps set by a vector of time points, and return a vector 53 | -- of solutions corresponding to the requested times. 54 | -- It takes a vector of time points as a parameter and returns a vector of results. 55 | integrateV :: 56 | (PrimMonad m) => 57 | -- | Internal integrator 58 | Integrator a -> 59 | -- | initial value 60 | a -> 61 | -- | vector of time points 62 | Vector Double -> 63 | -- | vector of solutions 64 | m (Vector a) 65 | integrateV integrator initial times = do 66 | out <- new (V.length times) 67 | write out 0 initial 68 | compute initial 1 out 69 | V.unsafeFreeze out 70 | where 71 | compute y i out 72 | | i == V.length times = return () 73 | | otherwise = do 74 | let h = (times ! i) - (times ! (i - 1)) 75 | y' = integrator h y 76 | write out i y' 77 | compute y' (i + 1) out 78 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: nightly-2022-08-16 2 | packages: 3 | - "."
4 | flags: 5 | monad-bayes: 6 | dev: True 7 | 8 | -------------------------------------------------------------------------------- /test/Spec.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | 3 | import Data.AEq (AEq ((~==))) 4 | import Test.Hspec (context, describe, hspec, it, shouldBe) 5 | import Test.Hspec.QuickCheck (prop) 6 | import Test.QuickCheck (ioProperty, property, (==>)) 7 | import TestAdvanced qualified 8 | import TestBenchmarks qualified 9 | import TestDistribution qualified 10 | import TestEnumerator qualified 11 | import TestInference qualified 12 | import TestIntegrator qualified 13 | import TestPipes (hmms) 14 | import TestPipes qualified 15 | import TestPopulation qualified 16 | import TestSSMFixtures qualified 17 | import TestSampler qualified 18 | import TestSequential qualified 19 | import TestStormerVerlet qualified 20 | import TestWeighted qualified 21 | 22 | main :: IO () 23 | main = hspec do 24 | describe "Stormer Verlet" $ 25 | it "conserves energy" $ 26 | do 27 | p1 <- TestStormerVerlet.passed1 28 | p1 `shouldBe` True 29 | describe "Distribution" $ 30 | it "gives correct mean, variance and covariance" $ 31 | do 32 | p1 <- TestDistribution.passed1 33 | p1 `shouldBe` True 34 | p2 <- TestDistribution.passed2 35 | p2 `shouldBe` True 36 | p3 <- TestDistribution.passed3 37 | p3 `shouldBe` True 38 | describe "Weighted" $ 39 | it "accumulates likelihood correctly" $ 40 | do 41 | passed <- TestWeighted.passed 42 | passed `shouldBe` True 43 | describe "Enumerator" do 44 | it "sorts samples and aggregates weights" $ 45 | TestEnumerator.passed2 `shouldBe` True 46 | it "gives correct answer for the sprinkler model" $ 47 | TestEnumerator.passed3 `shouldBe` True 48 | it "computes expectation correctly" $ 49 | TestEnumerator.passed4 `shouldBe` True 50 | describe "Integrator Expectation" do 51 | prop "expectation numerically" $ 52 | \mean var -> 53 | var > 0 ==> property $ TestIntegrator.normalExpectation mean (sqrt var) ~== mean 54 | describe "Integrator Variance" do 55 | prop "variance numerically" $ 56 | \mean var -> 57 | -- Because of rounding issues, require the variance to be a bit bigger than 0 58 | -- See https://github.com/tweag/monad-bayes/issues/275 59 | var > 0.1 ==> property $ TestIntegrator.normalVariance mean (sqrt var) ~== var 60 | describe "SamplerT mean and variance" do 61 | it "gets right mean and variance" $ 62 | TestSampler.testMeanAndVariance `shouldBe` True 63 | describe "Integrator Volume" do 64 | prop "volume sums to 1" $ 65 | property $ \case 66 | [] -> True 67 | ls -> (TestIntegrator.volumeIsOne ls) 68 | 69 | describe "Integrator" do 70 | it "" $ 71 | all 72 | (== True) 73 | [ TestIntegrator.passed1, 74 | TestIntegrator.passed2, 75 | TestIntegrator.passed3, 76 | TestIntegrator.passed4, 77 | TestIntegrator.passed5, 78 | TestIntegrator.passed6, 79 | TestIntegrator.passed7, 80 | TestIntegrator.passed8, 81 | TestIntegrator.passed9, 82 | TestIntegrator.passed10, 83 | TestIntegrator.passed11, 84 | TestIntegrator.passed12, 85 | TestIntegrator.passed13, 86 | TestIntegrator.passed14 87 | ] 88 | `shouldBe` True 89 | 90 | describe "Population" do 91 | context "controlling population" do 92 | it "preserves the population when not explicitly altered" do 93 | popSize <- TestPopulation.popSize 94 | popSize `shouldBe` 5 95 | it "multiplies the number of samples when spawn invoked twice" do 96 | manySize <- TestPopulation.manySize 97 | manySize `shouldBe` 15 98 | it "correctly computes 
population average" $ 99 | TestPopulation.popAvgCheck `shouldBe` True 100 | context "distribution-preserving transformations" do 101 | it "collapse preserves the distribution" do 102 | TestPopulation.transCheck1 `shouldBe` True 103 | TestPopulation.transCheck2 `shouldBe` True 104 | it "resample preserves the distribution" do 105 | TestPopulation.resampleCheck 1 `shouldBe` True 106 | TestPopulation.resampleCheck 2 `shouldBe` True 107 | describe "Sequential" do 108 | it "stops at every factor" do 109 | TestSequential.checkTwoSync 0 `shouldBe` True 110 | TestSequential.checkTwoSync 1 `shouldBe` True 111 | TestSequential.checkTwoSync 2 `shouldBe` True 112 | it "preserves the distribution" $ 113 | TestSequential.checkPreserve `shouldBe` True 114 | it "produces correct intermediate weights" do 115 | TestSequential.checkSync 0 `shouldBe` True 116 | TestSequential.checkSync 1 `shouldBe` True 117 | TestSequential.checkSync 2 `shouldBe` True 118 | describe "SMC" do 119 | it "terminates" $ 120 | seq TestInference.checkTerminateSMC () `shouldBe` () 121 | it "preserves the distribution on the sprinkler model" $ 122 | TestInference.checkPreserveSMC `shouldBe` True 123 | prop "number of particles is equal to its second parameter" $ 124 | \observations particles -> 125 | observations >= 0 && particles >= 1 ==> ioProperty do 126 | checkParticles <- TestInference.checkParticles observations particles 127 | return $ checkParticles == particles 128 | describe "SMC with systematic resampling" $ 129 | prop "number of particles is equal to its second parameter" $ 130 | \observations particles -> 131 | observations >= 0 && particles >= 1 ==> ioProperty do 132 | checkParticles <- TestInference.checkParticlesSystematic observations particles 133 | return $ checkParticles == particles 134 | describe "Equivalent Expectations" do 135 | prop "Gamma Normal" $ 136 | ioProperty . TestInference.testGammaNormal 137 | prop "Normal Normal" $ 138 | \n -> ioProperty (TestInference.testNormalNormal [max (-3) $ min 3 n]) 139 | prop "Beta Bernoulli" $ 140 | ioProperty . 
TestInference.testBetaBernoulli 141 | describe "Pipes: Urn" do 142 | it "Distributions are equivalent" do 143 | TestPipes.urns 10 `shouldBe` True 144 | describe "Pipes: HMM" do 145 | prop "pipe model is equivalent to standard model" $ 146 | \num -> property $ hmms $ take 5 num 147 | 148 | describe "SMC with stratified resampling" $ 149 | prop "number of particles is equal to its second parameter" $ 150 | \observations particles -> 151 | observations >= 0 && particles >= 1 ==> ioProperty do 152 | checkParticles <- TestInference.checkParticlesStratified observations particles 153 | return $ checkParticles == particles 154 | 155 | describe "Expectation from all inference methods" $ 156 | it "gives correct answer for the sprinkler model" do 157 | passed1 <- TestAdvanced.passed1 158 | passed1 `shouldBe` True 159 | passed2 <- TestAdvanced.passed2 160 | passed2 `shouldBe` True 161 | passed3 <- TestAdvanced.passed3 162 | passed3 `shouldBe` True 163 | passed4 <- TestAdvanced.passed4 164 | passed4 `shouldBe` True 165 | passed5 <- TestAdvanced.passed5 166 | passed5 `shouldBe` True 167 | passed6 <- TestAdvanced.passed6 168 | passed6 `shouldBe` True 169 | passed7 <- TestAdvanced.passed7 170 | passed7 `shouldBe` True 171 | 172 | TestBenchmarks.test 173 | TestSSMFixtures.test 174 | -------------------------------------------------------------------------------- /test/TestAdvanced.hs: -------------------------------------------------------------------------------- 1 | module TestAdvanced where 2 | 3 | import Control.Arrow 4 | import Control.Monad (join) 5 | import Control.Monad.Bayes.Class 6 | import Control.Monad.Bayes.Enumerator 7 | import Control.Monad.Bayes.Inference.MCMC 8 | import Control.Monad.Bayes.Inference.PMMH 9 | import Control.Monad.Bayes.Inference.RMSMC 10 | import Control.Monad.Bayes.Inference.SMC 11 | import Control.Monad.Bayes.Inference.SMC2 12 | import Control.Monad.Bayes.Population 13 | import Control.Monad.Bayes.Sampler.Strict 14 | 15 | mcmcConfig :: MCMCConfig 16 | mcmcConfig = MCMCConfig {numMCMCSteps = 0, numBurnIn = 0, proposal = SingleSiteMH} 17 | 18 | smcConfig :: (MonadDistribution m) => SMCConfig m 19 | smcConfig = SMCConfig {numSteps = 0, numParticles = 1000, resampler = resampleMultinomial} 20 | 21 | passed1, passed2, passed3, passed4, passed5, passed6, passed7 :: IO Bool 22 | passed1 = do 23 | sample <- sampleIOfixed $ mcmc MCMCConfig {numMCMCSteps = 10000, numBurnIn = 5000, proposal = SingleSiteMH} random 24 | return $ abs (0.5 - (expectation id $ fromList $ toEmpirical sample)) < 0.01 25 | passed2 = do 26 | sample <- sampleIOfixed $ runPopulationT $ smc (SMCConfig {numSteps = 0, numParticles = 10000, resampler = resampleMultinomial}) random 27 | return $ close 0.5 sample 28 | passed3 = do 29 | sample <- sampleIOfixed $ runPopulationT $ rmsmcDynamic mcmcConfig smcConfig random 30 | return $ close 0.5 sample 31 | passed4 = do 32 | sample <- sampleIOfixed $ runPopulationT $ rmsmcBasic mcmcConfig smcConfig random 33 | return $ close 0.5 sample 34 | passed5 = do 35 | sample <- sampleIOfixed $ runPopulationT $ rmsmc mcmcConfig smcConfig random 36 | return $ close 0.5 sample 37 | passed6 = do 38 | sample <- 39 | fmap join $ 40 | sampleIOfixed $ 41 | pmmh 42 | mcmcConfig {numMCMCSteps = 100} 43 | smcConfig {numSteps = 0, numParticles = 100} 44 | random 45 | (normal 0) 46 | return $ close 0.0 sample 47 | 48 | close :: Double -> [(Double, Log Double)] -> Bool 49 | 50 | passed7 = do 51 | sample <- fmap join $ sampleIOfixed $ fmap (fmap (\(x, y) -> fmap (second (* y)) x)) $ 
runPopulationT $ smc2 0 100 100 100 random (normal 0) 52 | return $ close 0.0 sample 53 | 54 | close n sample = abs (n - (expectation id $ fromList $ toEmpiricalWeighted sample)) < 0.01 55 | -------------------------------------------------------------------------------- /test/TestBenchmarks.hs: -------------------------------------------------------------------------------- 1 | module TestBenchmarks where 2 | 3 | import Control.Monad (forM_) 4 | import Data.Maybe (fromJust) 5 | import Helper 6 | import Paths_monad_bayes (getDataDir) 7 | import System.IO (readFile') 8 | import System.IO.Error (catchIOError, isDoesNotExistError) 9 | import Test.Hspec 10 | 11 | fixtureToFilename :: Model -> Alg -> String 12 | fixtureToFilename model alg = "/test/fixtures/" ++ fromJust (serializeModel model) ++ "-" ++ show alg ++ ".txt" 13 | 14 | models :: [Model] 15 | models = [LR 10, HMM 10, LDA (5, 10)] 16 | 17 | algs :: [Alg] 18 | algs = [minBound .. maxBound] 19 | 20 | test :: SpecWith () 21 | test = describe "Benchmarks" $ forM_ models $ \model -> forM_ algs $ testFixture model 22 | 23 | testFixture :: Model -> Alg -> SpecWith () 24 | testFixture model alg = do 25 | dataDir <- runIO getDataDir 26 | let filename = dataDir <> fixtureToFilename model alg 27 | it ("should agree with the fixture " ++ filename) $ do 28 | fixture <- catchIOError (readFile' filename) $ \e -> 29 | if isDoesNotExistError e 30 | then return "" 31 | else ioError e 32 | sampled <- runAlgFixed model alg 33 | -- Reset in case of fixture update or creation 34 | writeFile filename sampled 35 | fixture `shouldBe` sampled 36 | -------------------------------------------------------------------------------- /test/TestDistribution.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | {-# LANGUAGE Trustworthy #-} 3 | 4 | module TestDistribution 5 | ( passed1, 6 | passed2, 7 | passed3, 8 | ) 9 | where 10 | 11 | import Control.Monad (replicateM) 12 | import Control.Monad.Bayes.Class (mvNormal) 13 | import Control.Monad.Bayes.Sampler.Strict 14 | import Data.Matrix (fromList) 15 | import Data.Vector qualified as V 16 | 17 | -- Test the sampled covariance is approximately the same as the 18 | -- specified covariance. 19 | passed1 :: IO Bool 20 | passed1 = sampleIOfixed $ do 21 | let mu = (V.fromList [0.0, 0.0]) 22 | sigma11 = 2.0 23 | sigma12 = 1.0 24 | bigSigma = (fromList 2 2 [sigma11, sigma12, sigma12, sigma11]) 25 | nSamples = 200000 26 | nSamples' = fromIntegral nSamples 27 | ss <- replicateM nSamples $ (mvNormal mu bigSigma) 28 | let xbar = (/ nSamples') $ sum $ fmap (V.! 0) ss 29 | ybar = (/ nSamples') $ sum $ fmap (V.! 1) ss 30 | let term1 = (/ nSamples') $ sum $ zipWith (*) (fmap (V.! 0) ss) (fmap (V.! 1) ss) 31 | let term2 = xbar * ybar 32 | return $ abs (sigma12 - (term1 - term2)) < 2e-2 33 | 34 | -- Test the sampled means are approximately the same as the specified 35 | -- means. 36 | passed2 :: IO Bool 37 | passed2 = sampleIOfixed $ do 38 | let mu = (V.fromList [0.0, 0.0]) 39 | sigma11 = 2.0 40 | sigma12 = 1.0 41 | bigSigma = (fromList 2 2 [sigma11, sigma12, sigma12, sigma11]) 42 | nSamples = 100000 43 | nSamples' = fromIntegral nSamples 44 | ss <- replicateM nSamples $ (mvNormal mu bigSigma) 45 | let xbar = (/ nSamples') $ sum $ fmap (V.! 0) ss 46 | ybar = (/ nSamples') $ sum $ fmap (V.! 1) ss 47 | return $ abs xbar < 1e-2 && abs ybar < 1e-2 48 | 49 | -- Test the sampled variances are approximately the same as the 50 | -- specified variances. 
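-- The checks in this module estimate moments by plugging in sample averages:
-- Cov[X,Y] ≈ E[XY] - E[X]E[Y] and Var[X] ≈ E[X²] - (E[X])².
-- A minimal standalone sketch of that covariance estimate, assuming samples come
-- as a list of vectors like `ss` above (hypothetical helper, not used by the tests):
--
--   sampleCov :: [V.Vector Double] -> Int -> Int -> Double
--   sampleCov ss i j =
--     let n = fromIntegral (length ss)
--         mean k = sum (fmap (V.! k) ss) / n
--      in sum (zipWith (*) (fmap (V.! i) ss) (fmap (V.! j) ss)) / n - mean i * mean j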
51 | passed3 :: IO Bool 52 | passed3 = sampleIOfixed $ do 53 | let mu = (V.fromList [0.0, 0.0]) 54 | sigma11 = 2.0 55 | sigma12 = 1.0 56 | bigSigma = (fromList 2 2 [sigma11, sigma12, sigma12, sigma11]) 57 | nSamples = 200000 58 | nSamples' = fromIntegral nSamples 59 | ss <- replicateM nSamples $ (mvNormal mu bigSigma) 60 | let xbar = (/ nSamples') $ sum $ fmap (V.! 0) ss 61 | ybar = (/ nSamples') $ sum $ fmap (V.! 1) ss 62 | let xbar2 = (/ nSamples') $ sum $ fmap (\x -> x * x) $ fmap (V.! 0) ss 63 | ybar2 = (/ nSamples') $ sum $ fmap (\x -> x * x) $ fmap (V.! 1) ss 64 | let xvar = xbar2 - xbar * xbar 65 | let yvar = ybar2 - ybar * ybar 66 | return $ abs (xvar - sigma11) < 1e-2 && abs (yvar - sigma11) < 2e-2 67 | -------------------------------------------------------------------------------- /test/TestEnumerator.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | 3 | module TestEnumerator (passed1, passed2, passed3, passed4) where 4 | 5 | import Control.Monad.Bayes.Class 6 | ( MonadDistribution (categorical, uniformD), 7 | ) 8 | import Control.Monad.Bayes.Enumerator 9 | ( enumerator, 10 | evidence, 11 | expectation, 12 | ) 13 | import Data.AEq (AEq ((~==))) 14 | import Data.Vector qualified as V 15 | import Numeric.Log (Log (ln)) 16 | import Sprinkler (hard, soft) 17 | 18 | unnorm :: (MonadDistribution m) => m Int 19 | unnorm = categorical $ V.fromList [0.5, 0.8] 20 | 21 | passed1 :: Bool 22 | passed1 = (exp . ln) (evidence unnorm) ~== 1 23 | 24 | agg :: (MonadDistribution m) => m Int 25 | agg = do 26 | x <- uniformD [0, 1] 27 | y <- uniformD [2, 1] 28 | return (x + y) 29 | 30 | passed2 :: Bool 31 | passed2 = enumerator agg ~== [(2, 0.5), (1, 0.25), (3, 0.25)] 32 | 33 | passed3 :: Bool 34 | passed3 = enumerator Sprinkler.hard ~== enumerator Sprinkler.soft 35 | 36 | passed4 :: Bool 37 | passed4 = 38 | expectation (^ (2 :: Int)) (fmap (fromIntegral . 
(+ 1)) $ categorical $ V.fromList [0.5, 0.5]) ~== 2.5 39 | -------------------------------------------------------------------------------- /test/TestInference.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ImportQualifiedPost #-} 2 | {-# LANGUAGE Rank2Types #-} 3 | {-# LANGUAGE TypeFamilies #-} 4 | {-# OPTIONS_GHC -Wno-missing-export-lists #-} 5 | 6 | module TestInference where 7 | 8 | import ConjugatePriors 9 | ( betaBernoulli', 10 | betaBernoulliAnalytic, 11 | gammaNormal', 12 | gammaNormalAnalytic, 13 | normalNormal', 14 | normalNormalAnalytic, 15 | ) 16 | import Control.Monad (replicateM) 17 | import Control.Monad.Bayes.Class (MonadMeasure, posterior) 18 | import Control.Monad.Bayes.Enumerator (enumerator) 19 | import Control.Monad.Bayes.Inference.SMC 20 | import Control.Monad.Bayes.Integrator (normalize) 21 | import Control.Monad.Bayes.Integrator qualified as Integrator 22 | import Control.Monad.Bayes.Population 23 | import Control.Monad.Bayes.Sampler.Strict (SamplerT, sampleIOfixed) 24 | import Control.Monad.Bayes.Sampler.Strict qualified as Sampler 25 | import Control.Monad.Bayes.Weighted (WeightedT) 26 | import Control.Monad.Bayes.Weighted qualified as WeightedT 27 | import Data.AEq (AEq ((~==))) 28 | import Numeric.Log (Log) 29 | import Sprinkler (soft) 30 | import System.Random.Stateful (IOGenM, StdGen) 31 | 32 | sprinkler :: (MonadMeasure m) => m Bool 33 | sprinkler = Sprinkler.soft 34 | 35 | -- | Count the number of particles produced by SMC 36 | checkParticles :: Int -> Int -> IO Int 37 | checkParticles observations particles = 38 | sampleIOfixed (fmap length (runPopulationT $ smc SMCConfig {numSteps = observations, numParticles = particles, resampler = resampleMultinomial} Sprinkler.soft)) 39 | 40 | checkParticlesSystematic :: Int -> Int -> IO Int 41 | checkParticlesSystematic observations particles = 42 | sampleIOfixed (fmap length (runPopulationT $ smc SMCConfig {numSteps = observations, numParticles = particles, resampler = resampleSystematic} Sprinkler.soft)) 43 | 44 | checkParticlesStratified :: Int -> Int -> IO Int 45 | checkParticlesStratified observations particles = 46 | sampleIOfixed (fmap length (runPopulationT $ smc SMCConfig {numSteps = observations, numParticles = particles, resampler = resampleStratified} Sprinkler.soft)) 47 | 48 | checkTerminateSMC :: IO [(Bool, Log Double)] 49 | checkTerminateSMC = sampleIOfixed (runPopulationT $ smc SMCConfig {numSteps = 2, numParticles = 5, resampler = resampleMultinomial} sprinkler) 50 | 51 | checkPreserveSMC :: Bool 52 | checkPreserveSMC = 53 | (enumerator . collapse . 
smc SMCConfig {numSteps = 2, numParticles = 2, resampler = resampleMultinomial}) sprinkler 54 | ~== enumerator sprinkler 55 | 56 | expectationNearNumeric :: 57 | WeightedT Integrator.Integrator Double -> 58 | WeightedT Integrator.Integrator Double -> 59 | Double 60 | expectationNearNumeric x y = 61 | let e1 = Integrator.expectation $ normalize x 62 | e2 = Integrator.expectation $ normalize y 63 | in (abs (e1 - e2)) 64 | 65 | expectationNearSampling :: 66 | WeightedT (SamplerT (IOGenM StdGen) IO) Double -> 67 | WeightedT (SamplerT (IOGenM StdGen) IO) Double -> 68 | IO Double 69 | expectationNearSampling x y = do 70 | e1 <- sampleIOfixed $ fmap Sampler.sampleMean $ replicateM 10 $ WeightedT.runWeightedT x 71 | e2 <- sampleIOfixed $ fmap Sampler.sampleMean $ replicateM 10 $ WeightedT.runWeightedT y 72 | return (abs (e1 - e2)) 73 | 74 | testNormalNormal :: [Double] -> IO Bool 75 | testNormalNormal n = do 76 | let e = 77 | expectationNearNumeric 78 | (posterior (normalNormal' 1 (1, 10)) n) 79 | (normalNormalAnalytic 1 (1, 10) n) 80 | return (e < 1e-0) 81 | 82 | testGammaNormal :: [Double] -> IO Bool 83 | testGammaNormal n = do 84 | let e = 85 | expectationNearNumeric 86 | (posterior (gammaNormal' (1, 1)) n) 87 | (gammaNormalAnalytic (1, 1) n) 88 | return (e < 1e-1) 89 | 90 | testBetaBernoulli :: [Bool] -> IO Bool 91 | testBetaBernoulli bs = do 92 | let e = 93 | expectationNearNumeric 94 | (posterior (betaBernoulli' (1, 1)) bs) 95 | (betaBernoulliAnalytic (1, 1) bs) 96 | return (e < 1e-1) 97 | -------------------------------------------------------------------------------- /test/TestIntegrator.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BlockArguments #-} 2 | 3 | module TestIntegrator where 4 | 5 | import Control.Monad (replicateM) 6 | import Control.Monad.Bayes.Class 7 | ( MonadDistribution (bernoulli, gamma, normal, random, uniformD), 8 | MonadFactor (score), 9 | MonadMeasure, 10 | condition, 11 | factor, 12 | normalPdf, 13 | ) 14 | import Control.Monad.Bayes.Integrator 15 | import Control.Monad.Bayes.Sampler.Strict 16 | import Control.Monad.Bayes.Weighted (runWeightedT) 17 | import Control.Monad.ST (runST) 18 | import Data.AEq (AEq ((~==))) 19 | import Data.List (sortOn) 20 | import Data.Set (fromList) 21 | import Numeric.Log (Log (Exp, ln)) 22 | import Sprinkler (hard, soft) 23 | import Statistics.Distribution (Distribution (cumulative)) 24 | import Statistics.Distribution.Normal (normalDistr) 25 | 26 | normalExpectation :: Double -> Double -> Double 27 | normalExpectation mean std = expectation (normal mean std) 28 | 29 | normalVariance :: Double -> Double -> Double 30 | normalVariance mean std = variance (normal mean std) 31 | 32 | volumeIsOne :: [Double] -> Bool 33 | volumeIsOne = (~== 1.0) . volume . 
uniformD 34 | 35 | agg :: (MonadDistribution m) => m Int 36 | agg = do 37 | x <- uniformD [0, 1] 38 | y <- uniformD [2, 1] 39 | return (x + y) 40 | 41 | within :: (Ord a, Num a) => a -> a -> a -> Bool 42 | within n x y = abs (x - y) < n 43 | 44 | passed1, 45 | passed2, 46 | passed3, 47 | passed4, 48 | passed5, 49 | passed6, 50 | passed7, 51 | passed8, 52 | passed9, 53 | passed10, 54 | passed11, 55 | passed12, 56 | passed13, 57 | passed14 :: 58 | Bool 59 | -- enumerator from Integrator works 60 | passed1 = 61 | sortOn fst (enumeratorWith (fromList [3, 1, 2]) agg) 62 | ~== sortOn fst [(2, 0.5), (1, 0.25), (3, 0.25)] 63 | -- hard and soft sprinklers are equivalent under enumerator from Integrator 64 | passed2 = 65 | enumeratorWith (fromList [True, False]) (normalize (Sprinkler.hard)) 66 | ~== enumeratorWith (fromList [True, False]) (normalize (Sprinkler.soft)) 67 | -- expectation is as expected 68 | passed3 = 69 | expectation (fmap ((** 2) . (+ 1)) $ uniformD [0, 1]) == 2.5 70 | -- distribution is normalized 71 | passed4 = volume (uniformD [1, 2]) ~== 1.0 72 | -- enumerator is as expected 73 | passed5 = 74 | sortOn fst (enumeratorWith (fromList [0, 1 :: Int]) (empirical [0 :: Int, 1, 1, 1])) 75 | == sortOn fst [(1, 0.75), (0, 0.25)] 76 | -- normalization works correctly for enumerator when there is conditioning 77 | passed6 = 78 | sortOn fst [(2, 0.5), (3, 0.5), (1, 0.0)] 79 | == sortOn 80 | fst 81 | ( enumeratorWith (fromList [1, 2, 3]) $ 82 | normalize $ do 83 | x <- uniformD [1 :: Int, 2, 3] 84 | condition (x > 1) 85 | return x 86 | ) 87 | -- soft factor statements work with enumerator and normalization 88 | passed7 = 89 | sortOn fst [(True, 0.75), (False, 0.25)] 90 | ~== sortOn 91 | fst 92 | ( enumeratorWith (fromList [True, False]) $ normalize do 93 | x <- bernoulli 0.5 94 | factor $ if x then 0.3 else 0.1 95 | return x 96 | ) 97 | -- volume of weight remains 1 98 | passed8 = 99 | 1 100 | == ( volume $ 101 | fmap (ln . exp . snd) $ runWeightedT do 102 | x <- bernoulli 0.5 103 | factor $ if x then 0.2 else 0.1 104 | return x 105 | ) 106 | -- normal probability in positive region is half 107 | passed9 = probability (1, 1000) (normal 1 10) - 0.5 < 0.05 108 | -- cdf as expected 109 | passed10 = cdf (normal 5 5) 5 - 0.5 < 0.05 110 | -- cdf as expected 111 | passed11 = 112 | (within 0.001) 113 | ( cdf 114 | ( do 115 | x <- normal 0 1 116 | return x 117 | ) 118 | 3 119 | ) 120 | (cumulative (normalDistr 0 1) 3) 121 | -- volume as expected 122 | passed12 = 123 | volume 124 | ( do 125 | x <- gamma 2 3 126 | return x 127 | ) 128 | ~== 1 129 | -- normalization preserves volume 130 | passed13 = 131 | (volume . 
normalize) 132 | ( do 133 | x <- gamma 2 3 134 | factor (normalPdf 0 1 x) 135 | return x 136 | ) 137 | ~== 1 138 | -- sampler and integrator agree on a non-trivial model 139 | passed14 = 140 | let sample = runST $ sampleSTfixed $ fmap sampleMean $ replicateM 10000 $ runWeightedT $ model1 141 | quadrature = expectation $ normalize $ model1 142 | in abs (sample - quadrature) < 0.01 143 | 144 | model1 :: (MonadMeasure m) => m Double 145 | model1 = do 146 | x <- random 147 | y <- random 148 | score (Exp $ log (f x + y)) 149 | return x 150 | where 151 | f x = cos (x ** 4) + x ** 3 152 | -------------------------------------------------------------------------------- /test/TestPipes.hs: -------------------------------------------------------------------------------- 1 | {-# OPTIONS_GHC -Wno-monomorphism-restriction #-} 2 | 3 | module TestPipes where 4 | 5 | import BetaBin (urn, urnP) 6 | import Control.Monad.Bayes.Class () 7 | import Control.Monad.Bayes.Enumerator (enumerator) 8 | import Data.AEq (AEq ((~==))) 9 | import Data.List (sort) 10 | import HMM (hmm, hmmPosterior) 11 | import Pipes.Prelude (toListM) 12 | 13 | urns :: Int -> Bool 14 | urns n = enumerator (urn n) ~== enumerator (urnP n) 15 | 16 | hmms :: [Double] -> Bool 17 | hmms observations = 18 | let hmmWithoutPipe = hmm observations 19 | hmmWithPipe = reverse . init <$> toListM (hmmPosterior observations) 20 | in -- Sort enumerator again although it is already sorted, see https://github.com/tweag/monad-bayes/issues/283 21 | sort (enumerator hmmWithPipe) ~== sort (enumerator hmmWithoutPipe) 22 | -------------------------------------------------------------------------------- /test/TestPopulation.hs: -------------------------------------------------------------------------------- 1 | module TestPopulation (weightedSampleSize, popSize, manySize, sprinkler, sprinklerExact, transCheck1, transCheck2, resampleCheck, popAvgCheck) where 2 | 3 | import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure) 4 | import Control.Monad.Bayes.Enumerator (enumerator, expectation) 5 | import Control.Monad.Bayes.Population as Population 6 | ( PopulationT, 7 | collapse, 8 | popAvg, 9 | pushEvidence, 10 | resampleMultinomial, 11 | runPopulationT, 12 | spawn, 13 | ) 14 | import Control.Monad.Bayes.Sampler.Strict (sampleIOfixed) 15 | import Data.AEq (AEq ((~==))) 16 | import Sprinkler (soft) 17 | 18 | weightedSampleSize :: (MonadDistribution m) => PopulationT m a -> m Int 19 | weightedSampleSize = fmap length . runPopulationT 20 | 21 | popSize :: IO Int 22 | popSize = 23 | sampleIOfixed (weightedSampleSize $ spawn 5 >> sprinkler) 24 | 25 | manySize :: IO Int 26 | manySize = 27 | sampleIOfixed (weightedSampleSize $ spawn 5 >> sprinkler >> spawn 3) 28 | 29 | sprinkler :: (MonadMeasure m) => m Bool 30 | sprinkler = Sprinkler.soft 31 | 32 | sprinklerExact :: [(Bool, Double)] 33 | sprinklerExact = enumerator Sprinkler.soft 34 | 35 | transCheck1 :: Bool 36 | transCheck1 = 37 | enumerator (collapse sprinkler) 38 | ~== sprinklerExact 39 | 40 | transCheck2 :: Bool 41 | transCheck2 = 42 | enumerator (collapse (spawn 2 >> sprinkler)) 43 | ~== sprinklerExact 44 | 45 | resampleCheck :: Int -> Bool 46 | resampleCheck n = 47 | (enumerator . collapse . 
resampleMultinomial) (spawn n >> sprinkler) 48 | ~== sprinklerExact 49 | 50 | popAvgCheck :: Bool 51 | popAvgCheck = expectation f Sprinkler.soft ~== expectation id (popAvg f $ pushEvidence Sprinkler.soft) 52 | where 53 | f True = 10 54 | f False = 4 55 | -------------------------------------------------------------------------------- /test/TestSSMFixtures.hs: -------------------------------------------------------------------------------- 1 | module TestSSMFixtures where 2 | 3 | import Control.Monad.Bayes.Sampler.Strict (sampleIOfixed) 4 | import NonlinearSSM 5 | import NonlinearSSM.Algorithms 6 | import Paths_monad_bayes (getDataDir) 7 | import System.IO (readFile') 8 | import System.IO.Error (catchIOError, isDoesNotExistError) 9 | import Test.Hspec 10 | 11 | fixtureToFilename :: Alg -> FilePath 12 | fixtureToFilename alg = "/test/fixtures/SSM-" ++ show alg ++ ".txt" 13 | 14 | testFixture :: Alg -> SpecWith () 15 | testFixture alg = do 16 | dataDir <- runIO getDataDir 17 | let filename = dataDir <> fixtureToFilename alg 18 | it ("should agree with the fixture " ++ filename) $ do 19 | ys <- sampleIOfixed $ generateData t 20 | fixture <- catchIOError (readFile' filename) $ \e -> 21 | if isDoesNotExistError e 22 | then return "" 23 | else ioError e 24 | sampled <- sampleIOfixed $ runAlgFixed (map fst ys) alg 25 | -- Reset in case of fixture update or creation 26 | writeFile filename sampled 27 | fixture `shouldBe` sampled 28 | 29 | test :: SpecWith () 30 | test = describe "TestSSMFixtures" $ mapM_ testFixture algs 31 | -------------------------------------------------------------------------------- /test/TestSampler.hs: -------------------------------------------------------------------------------- 1 | module TestSampler where 2 | 3 | import Control.Foldl qualified as Fold 4 | import Control.Monad (replicateM) 5 | import Control.Monad.Bayes.Class (MonadDistribution (normal)) 6 | import Control.Monad.Bayes.Sampler.Strict (sampleSTfixed) 7 | import Control.Monad.ST (runST) 8 | 9 | testMeanAndVariance :: Bool 10 | testMeanAndVariance = isDiff 11 | where 12 | m = runST (sampleSTfixed (foldWith Fold.mean (normal 2 4))) 13 | v = runST (sampleSTfixed (foldWith Fold.variance (normal 2 4))) 14 | foldWith f = fmap (Fold.fold f) . 
replicateM 100000 15 | isDiff = abs (2 - m) < 0.01 && abs (16 - v) < 0.1 16 | -------------------------------------------------------------------------------- /test/TestSequential.hs: -------------------------------------------------------------------------------- 1 | module TestSequential (twoSync, finishedTwoSync, checkTwoSync, checkPreserve, pFinished, isFinished, checkSync) where 2 | 3 | import Control.Monad.Bayes.Class 4 | ( MonadDistribution (uniformD), 5 | MonadMeasure, 6 | factor, 7 | ) 8 | import Control.Monad.Bayes.Enumerator as Dist (enumerator, mass) 9 | import Control.Monad.Bayes.Sequential.Coroutine (advance, finish, finished) 10 | import Data.AEq (AEq ((~==))) 11 | import Sprinkler (soft) 12 | 13 | twoSync :: (MonadMeasure m) => m Int 14 | twoSync = do 15 | x <- uniformD [0, 1] 16 | factor (fromIntegral x) 17 | y <- uniformD [0, 1] 18 | factor (fromIntegral y) 19 | return (x + y) 20 | 21 | finishedTwoSync :: (MonadMeasure m) => Int -> m Bool 22 | finishedTwoSync n = finished (run n twoSync) 23 | where 24 | run 0 d = d 25 | run k d = run (k - 1) (advance d) 26 | 27 | checkTwoSync :: Int -> Bool 28 | checkTwoSync 0 = mass (finishedTwoSync 0) False ~== 1 29 | checkTwoSync 1 = mass (finishedTwoSync 1) False ~== 1 30 | checkTwoSync 2 = mass (finishedTwoSync 2) True ~== 1 31 | checkTwoSync _ = error "Unexpected argument" 32 | 33 | sprinkler :: (MonadMeasure m) => m Bool 34 | sprinkler = Sprinkler.soft 35 | 36 | checkPreserve :: Bool 37 | checkPreserve = enumerator (finish sprinkler) ~== enumerator sprinkler 38 | 39 | pFinished :: Int -> Double 40 | pFinished 0 = 0.8267716535433071 41 | pFinished 1 = 0.9988062077198566 42 | pFinished 2 = 1 43 | pFinished _ = error "Unexpected argument" 44 | 45 | isFinished :: (MonadMeasure m) => Int -> m Bool 46 | isFinished n = finished (run n sprinkler) 47 | where 48 | run 0 d = d 49 | run k d = run (k - 1) (advance d) 50 | 51 | checkSync :: Int -> Bool 52 | checkSync n = mass (isFinished n) True ~== pFinished n 53 | -------------------------------------------------------------------------------- /test/TestStormerVerlet.hs: -------------------------------------------------------------------------------- 1 | module TestStormerVerlet 2 | ( passed1, 3 | ) 4 | where 5 | 6 | import Control.Lens 7 | import Control.Monad.ST 8 | import Data.Maybe (fromJust) 9 | import Data.Vector qualified as V 10 | import Linear qualified as L 11 | import Linear.V 12 | import Math.Integrators.StormerVerlet 13 | import Statistics.Function (square) 14 | 15 | gConst :: Double 16 | gConst = 6.67384e-11 17 | 18 | nStepsTwoPlanets :: Int 19 | nStepsTwoPlanets = 44 20 | 21 | stepTwoPlanets :: Double 22 | stepTwoPlanets = 24 * 60 * 60 * 100 23 | 24 | sunMass, jupiterMass :: Double 25 | sunMass = 1.9889e30 26 | jupiterMass = 1.8986e27 27 | 28 | jupiterPerihelion :: Double 29 | jupiterPerihelion = 7.405736e11 30 | 31 | jupiterV :: [Double] 32 | jupiterV = [-1.0965244901087316e02, -1.3710001990210707e04, 0.0] 33 | 34 | jupiterQ :: [Double] 35 | jupiterQ = [negate jupiterPerihelion, 0.0, 0.0] 36 | 37 | sunV :: [Double] 38 | sunV = [0.0, 0.0, 0.0] 39 | 40 | sunQ :: [Double] 41 | sunQ = [0.0, 0.0, 0.0] 42 | 43 | tm :: V.Vector Double 44 | tm = V.enumFromStepN 0 stepTwoPlanets nStepsTwoPlanets 45 | 46 | keplerP :: L.V2 (L.V3 Double) -> L.V2 (L.V3 Double) 47 | keplerP (L.V2 p1 p2) = L.V2 dHdP1 dHdP2 48 | where 49 | dHdP1 = p1 / pure jupiterMass 50 | dHdP2 = p2 / pure sunMass 51 | 52 | keplerQ :: L.V2 (L.V3 Double) -> L.V2 (L.V3 Double) 53 | keplerQ (L.V2 q1 q2) = L.V2 dHdQ1 dHdQ2 54 | 
where 55 | r = q2 L.^-^ q1 56 | ri = r `L.dot` r 57 | rr = ri * (sqrt ri) 58 | q1' = pure gConst * r / pure rr 59 | q2' = negate q1' 60 | dHdQ1 = q1' * pure sunMass * pure jupiterMass 61 | dHdQ2 = q2' * pure sunMass * pure jupiterMass 62 | 63 | listToV3 :: [a] -> L.V3 a 64 | listToV3 [x, y, z] = fromV . fromJust . fromVector . V.fromList $ [x, y, z] 65 | listToV3 xs = error $ "Only supply 3 elements not: " ++ show (length xs) 66 | 67 | initPQ2s :: L.V2 (L.V2 (L.V3 Double)) 68 | initPQ2s = 69 | L.V2 70 | (L.V2 (listToV3 jupiterQ) (listToV3 sunQ)) 71 | (L.V2 (pure jupiterMass * listToV3 jupiterV) (pure sunMass * listToV3 sunV)) 72 | 73 | result2 :: V.Vector (L.V2 (L.V2 (L.V3 Double))) 74 | result2 = runST $ integrateV (\h -> stormerVerlet2H (pure h) keplerQ keplerP) initPQ2s tm 75 | 76 | energy :: (L.V2 (L.V2 (L.V3 Double))) -> Double 77 | energy x = keJ + keS + peJ + peS 78 | where 79 | qs = x ^. _1 80 | ps = x ^. _2 81 | qJ = qs ^. _1 82 | qS = qs ^. _2 83 | pJ = ps ^. _1 84 | pS = ps ^. _2 85 | keJ = (* 0.5) $ (/ jupiterMass) $ sum $ fmap square pJ 86 | keS = (* 0.5) $ (/ sunMass) $ sum $ fmap square pS 87 | r = qJ L.^-^ qS 88 | ri = r `L.dot` r 89 | peJ = 0.5 * gConst * sunMass * jupiterMass / (sqrt ri) 90 | peS = 0.5 * gConst * sunMass * jupiterMass / (sqrt ri) 91 | 92 | energies :: V.Vector Double 93 | energies = fmap energy result2 94 | 95 | diffs :: V.Vector Double 96 | diffs = V.zipWith (\x y -> abs (x - y) / x) energies (V.tail energies) 97 | 98 | passed1 :: IO Bool 99 | passed1 = return $ V.all (< 1.0e-3) diffs 100 | -------------------------------------------------------------------------------- /test/TestWeighted.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE TypeFamilies #-} 2 | 3 | module TestWeighted (check, passed, result, model) where 4 | 5 | import Control.Monad (unless, when) 6 | import Control.Monad.Bayes.Class 7 | ( MonadDistribution (normal, uniformD), 8 | MonadMeasure, 9 | factor, 10 | ) 11 | import Control.Monad.Bayes.Sampler.Strict (sampleIOfixed) 12 | import Control.Monad.Bayes.Weighted (runWeightedT) 13 | import Data.AEq (AEq ((~==))) 14 | import Data.Bifunctor (second) 15 | import Numeric.Log (Log (Exp, ln)) 16 | 17 | model :: (MonadMeasure m) => m (Int, Double) 18 | model = do 19 | n <- uniformD [0, 1, 2] 20 | unless (n == 0) (factor 0.5) 21 | x <- if n == 0 then return 1 else normal 0 1 22 | when (n == 2) (factor $ (Exp . log) (x * x)) 23 | return (n, x) 24 | 25 | result :: (MonadDistribution m) => m ((Int, Double), Double) 26 | result = second (exp . 
ln) <$> runWeightedT model 27 | 28 | passed :: IO Bool 29 | passed = fmap check (sampleIOfixed result) 30 | 31 | check :: ((Int, Double), Double) -> Bool 32 | check ((0, 1), 1) = True 33 | check ((1, _), y) = y ~== 0.5 34 | check ((2, x), y) = y ~== 0.5 * x * x 35 | check _ = False 36 | -------------------------------------------------------------------------------- /test/fixtures/HMM10-MH.txt: -------------------------------------------------------------------------------- 1 | ["[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,2,1,1,1,1,1,1,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,1,1,2,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[1,2,1,1,1,1,1,2,2,0]","[2,2,1,1,1,1,1,2,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,2,1,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[2,1,2,1,1,2,1,1,2,0]","[1,1,2,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,2,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,2,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[1,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]","[0,1,1,1,1,0,1,1,2,0]"] -------------------------------------------------------------------------------- /test/fixtures/HMM10-RMSMC.txt: -------------------------------------------------------------------------------- 1 | 
[("[2,1,1,1,0,1,1,2,1,1]",2.4438034074800498e-8),("[1,1,1,1,1,2,1,1,1,0]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,1,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,2,0,2,1,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,1,2,1,2,1,2,1,1,2]",2.4438034074800498e-8),("[1,2,1,1,2,2,0,1,1,1]",2.4438034074800498e-8)] -------------------------------------------------------------------------------- /test/fixtures/HMM10-SMC.txt: -------------------------------------------------------------------------------- 1 | [("[1,1,0,1,1,2,1,1,1,2]",2.964810681340389e-9),("[1,1,0,1,1,2,1,1,1,2]",2.964810681340389e-9),("[1,1,0,1,1,2,1,1,1,2]",2.964810681340389e-9),("[1,2,0,2,1,2,1,1,1,0]",2.964810681340389e-9),("[1,2,0,2,1,2,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,0,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,0,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,0,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,2,0,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,2,0,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,0,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,0,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,0,1,1]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,1,1,2]",2.964810681340389e-9),("[1,2,0,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,2,0,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,2,1,1,0,2,2,1,2,0]",2.964810681340389e-9),("[1,2,1,1,0,2,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,0,2,2,1,2,0]",2.964810681340389e-9),("[1,2,2,1,1,1,1,1,2,0]",2.964810681340389e-9),("[1,2,2,1,1,1,1,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,1,0]",2.964810681340389e-9),("[1,1,1,2,1,0,0,1,1,0]",2.964810681340389e-9),("[1,1,1,2,1,0,0,1,1,0]",2.964810681340389e-9),("[1,1,1,2,1,0,0,1,1,0]",2.964810681340389e-9),("[1,1,1,1,2,0,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,2,0,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,2,2]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,2,2]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,1,1]",2.964810681340389e-9),("[2,1,1,1,1,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,2,0,1,2]",2.964810681340389e-9),("[1,1,1,1,2,1,2,0,1,2]",2.964810681340389e-9),("[2,1,1,1,2,1,2,0,1,2]",2.964810681340389e-9),("[1,1,1,2,0,2,1,2,1,0]",2.964810681340389e-9),("[1,1,1,2,0,1,1,1,1,2]",2.964810681340389e-9),("[1,1,1,2,0,1,1,1,1,2]",2.964810681340389e-9),("[1,1,1,1,0,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,0,1,1,1,1]",2.964810681340389e-9),("[1,1,1,1,1,1,1,2,1,0]",2.964810681340389e-9),("[1,2,0,1,1,1,1,2,1,0]",2.964810681340389e-9),("[1,2,0,1,1,1,1,2,1,0]",2.964810681340389e-9),("[1,1,1,1,2,1,2,1,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,1,1,2,2]",2.964810681340389e-9),("[1,1,2,1,1,1,1,1,2,2]",2.964810681340389e-9),("[1,1,2,
1,1,1,1,2,1,1]",2.964810681340389e-9),("[1,1,2,1,1,1,1,2,1,1]",2.964810681340389e-9),("[1,2,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,2,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,2,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,1,1,1,2,1,1,1,2,0]",2.964810681340389e-9),("[1,1,0,1,0,1,0,1,1,2]",2.964810681340389e-9),("[1,1,0,1,0,1,0,1,1,2]",2.964810681340389e-9),("[1,1,0,1,0,1,0,1,1,2]",2.964810681340389e-9),("[0,1,1,1,1,0,1,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,0,1,1,2,0]",2.964810681340389e-9),("[1,1,2,1,2,2,0,1,1,1]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,0,1,1,1,1,0]",2.964810681340389e-9),("[1,1,1,1,2,1,0,0,1,1]",2.964810681340389e-9),("[1,1,1,1,2,1,0,0,1,1]",2.964810681340389e-9),("[1,1,2,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,2,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,1,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,2,1,2]",2.964810681340389e-9),("[1,1,1,1,1,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,2,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,2,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,2,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,1,2,2,0,1,1,2]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,1,2,1,2,0]",2.964810681340389e-9),("[1,1,1,1,1,2,1,2,1,0]",2.964810681340389e-9),("[1,1,1,1,1,2,1,2,1,0]",2.964810681340389e-9),("[1,2,1,1,1,2,1,2,1,0]",2.964810681340389e-9),("[2,1,1,1,2,1,1,2,1,0]",2.964810681340389e-9),("[1,1,1,1,1,1,0,2,1,0]",2.964810681340389e-9),("[1,1,1,1,1,1,0,2,1,0]",2.964810681340389e-9),("[2,2,2,1,1,1,1,1,2,1]",2.964810681340389e-9)] -------------------------------------------------------------------------------- /test/fixtures/LR10-MH.txt: -------------------------------------------------------------------------------- 1 | 
["2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","2.950544941215132e-6","4.417248310074391e-5","4.417248310074391e-5","4.417248310074391e-5","4.417248310074391e-5","4.417248310074391e-5","4.417248310074391e-5","4.417248310074391e-5","4.417248310074391e-5","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","3.542845305464785e-4","1.283996726025695e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","5.8758436338617424e-6","0.21036341524185248","0.21036341524185248","0.21036341524185248","0.11759542598406439","0.11759542598406439","0.11759542598406439","0.11759542598406439","0.11759542598406439","0.11759542598406439","4.7312512236572256e-2","4.7312512236572256e-2","4.7312512236572256e-2","4.7312512236572256e-2","0.4947233726255189","0.4947233726255189","0.38932500660802716","0.38932500660802716","0.38932500660802716","0.38932500660802716","0.7207054705150296","0.6553804082004575","0.6553804082004575","0.6553804082004575","0.6553804082004575","0.8884089778243283","0.8884089778243283","0.8884089778243283","7.972233473120037e-3","7.972233473120037e-3","7.972233473120037e-3","7.972233473120037e-3","7.972233473120037e-3","7.972233473120037e-3","7.972233473120037e-3","7.972233473120037e-3","7.972233473120037e-3","0.9255022758580312","0.6388734801405944","2.2127051202286738e-5","2.2127051202286738e-5","5.340807350374084e-6","5.340807350374084e-6","5.340807350374084e-6","5.340807350374084e-6","5.340807350374084e-6","5.340807350374084e-6","5.340807350374084e-6","5.340807350374084e-6","5.340807350374084e-6","5.181558495552509e-6","5.181558495552509e-6","5.1274144757719514e-6","5.1274144757719514e-6","5.1274144757719514e-6","5.1274144757719514e-6","5.1274144757719514e-6","5.1274144757719514e-6"] -------------------------------------------------------------------------------- /test/fixtures/LR10-RMSMC.txt: -------------------------------------------------------------------------------- 1 | [("5.333997036871341e-6",6.253483531321334e-5),("4.7394126422869915e-6",6.253483531321334e-5),("7.798412010176118e-6",6.253483531321334e-5),("3.6537314189667804e-3",6.253483531321334e-5),("2.8719951155583437e-3",6.253483531321334e-5),("8.106430574751586e-3",6.253483531321334e-5),("0.8083297132912659",6.253483531321334e-5),("0.3178833271596258",6.253483531321334e-5),("0.3487591395231385",6.253483531321334e-5),("2.027192522287899e-2",6.253483531321334e-5)] -------------------------------------------------------------------------------- /test/fixtures/LR10-SMC.txt: -------------------------------------------------------------------------------- 1 | 
[("2.324831917751426e-5",8.615785653338886e-6),("4.0753397495054896e-8",8.615785653338886e-6),("7.274125904873279e-9",8.615785653338886e-6),("1.3743372716534289e-5",8.615785653338886e-6),("4.254983717992536e-6",8.615785653338886e-6),("1.2272127506266697e-6",8.615785653338886e-6),("1.214022726807993e-7",8.615785653338886e-6),("8.847449654106116e-8",8.615785653338886e-6),("1.6000818494452112e-3",8.615785653338886e-6),("2.0279043319035277e-3",8.615785653338886e-6),("3.5018249880037834e-5",8.615785653338886e-6),("0.16426780827561477",8.615785653338886e-6),("0.10773403426050324",8.615785653338886e-6),("0.3304703458482792",8.615785653338886e-6),("0.10913515995295375",8.615785653338886e-6),("3.8113707417415554e-7",8.615785653338886e-6),("1.1707029254682367e-4",8.615785653338886e-6),("2.047893831024706e-3",8.615785653338886e-6),("1.8725448138433734e-3",8.615785653338886e-6),("6.053601828693627e-2",8.615785653338886e-6),("0.1757594676547173",8.615785653338886e-6),("1.3511280746237259e-3",8.615785653338886e-6),("0.23723482733424164",8.615785653338886e-6),("0.16151469556914314",8.615785653338886e-6),("7.191080175022858e-3",8.615785653338886e-6),("1.6108325498387603e-7",8.615785653338886e-6),("1.6659054333722745e-7",8.615785653338886e-6),("4.28535218704211e-4",8.615785653338886e-6),("4.305836021545625e-4",8.615785653338886e-6),("4.7576965500979653e-4",8.615785653338886e-6),("4.4539756971784895e-4",8.615785653338886e-6),("3.96355147105004e-4",8.615785653338886e-6),("2.91792488057767e-3",8.615785653338886e-6),("0.1387157529569227",8.615785653338886e-6),("0.12679393576798442",8.615785653338886e-6),("0.23010548225976238",8.615785653338886e-6),("0.25223594655800285",8.615785653338886e-6),("6.97387528630698e-5",8.615785653338886e-6),("9.333831289852731e-5",8.615785653338886e-6),("8.375626668711192e-5",8.615785653338886e-6),("1.5213683176568195e-2",8.615785653338886e-6),("3.5712314627571923e-4",8.615785653338886e-6),("7.636217014509781e-4",8.615785653338886e-6),("2.0872090258618442e-4",8.615785653338886e-6),("5.196909883379355e-4",8.615785653338886e-6),("5.701798151154428e-4",8.615785653338886e-6),("5.093049208576314e-4",8.615785653338886e-6),("1.4278233878183234e-2",8.615785653338886e-6),("0.409063870280768",8.615785653338886e-6),("2.413899012911671e-2",8.615785653338886e-6),("4.9820455672153804e-2",8.615785653338886e-6),("2.3987373289798696e-3",8.615785653338886e-6),("2.3730997104438635e-3",8.615785653338886e-6),("1.5421674679073712e-2",8.615785653338886e-6),("6.127942201533583e-4",8.615785653338886e-6),("0.29100637723657274",8.615785653338886e-6),("6.466851911961999e-4",8.615785653338886e-6),("6.899429519546635e-4",8.615785653338886e-6),("6.796536628553891e-4",8.615785653338886e-6),("1.5687561944396283e-2",8.615785653338886e-6),("1.0114194783259966e-2",8.615785653338886e-6),("8.650936462496627e-4",8.615785653338886e-6),("8.659503706938266e-4",8.615785653338886e-6),("0.9866480964309671",8.615785653338886e-6),("0.9844641052458803",8.615785653338886e-6),("0.9877580766923669",8.615785653338886e-6),("0.3570814146622049",8.615785653338886e-6),("0.3333232242467457",8.615785653338886e-6),("0.28788395991944565",8.615785653338886e-6),("0.6235083846866266",8.615785653338886e-6),("2.9830083243998794e-2",8.615785653338886e-6),("0.6715098168079819",8.615785653338886e-6),("0.5092770566793119",8.615785653338886e-6),("0.14931145218944078",8.615785653338886e-6),("0.9476390701984468",8.615785653338886e-6),("0.19844654340805706",8.615785653338886e-6),("0.8692525556246279",8.615785653338886e-6),("0.2857288027412788",8.61578565
3338886e-6),("0.9128866108374706",8.615785653338886e-6),("0.5972931052418653",8.615785653338886e-6),("3.185136062093685e-2",8.615785653338886e-6),("0.7815143883956358",8.615785653338886e-6),("9.045285261453962e-4",8.615785653338886e-6),("1.0570963598552236e-4",8.615785653338886e-6),("0.9988704519081596",8.615785653338886e-6),("0.9983306238100842",8.615785653338886e-6),("0.999817180959221",8.615785653338886e-6),("0.999632323087958",8.615785653338886e-6),("0.9978593357863148",8.615785653338886e-6),("0.9932458970319142",8.615785653338886e-6),("0.9997633989820984",8.615785653338886e-6),("0.25474879877870016",8.615785653338886e-6),("0.12254684366442331",8.615785653338886e-6),("6.952221863096741e-2",8.615785653338886e-6),("0.38421927032809844",8.615785653338886e-6),("0.13255536658985864",8.615785653338886e-6),("0.1506084845751688",8.615785653338886e-6),("0.3502412842110343",8.615785653338886e-6),("1.28565002377979e-2",8.615785653338886e-6),("0.3782237084074611",8.615785653338886e-6)] -------------------------------------------------------------------------------- /test/fixtures/SSM-PMMH.txt: -------------------------------------------------------------------------------- 1 | [[([74405.69500410178,143777.3515026691,195290.64675632896,483878.28639985673,600603.4104497777],1.0),([74405.69500410178,143777.3515026691,195290.64675632896,483878.28639985673,600603.4104497777],1.0),([74405.69500410178,143777.3515026691,195290.64675632896,483878.28639985673,600603.4104497777],1.0)],[([157620.04097610444,26661.523636563594,321204.4219216401,421274.0528523404,487363.3134055787],1.0),([157620.04097610444,26661.523636563594,321204.4219216401,421274.0528523404,487363.3134055787],1.0),([157620.04097610444,26661.523636563594,321204.4219216401,421274.0528523404,487363.3134055787],1.0)],[([-1.2600621067470811e67,-1.3171618074660135e67,3.55155213532486e66,-9.486041679240111e66,-1.4476178361450074e67],1.0),([-1.2600621067470811e67,-1.3171618074660135e67,3.55155213532486e66,-9.486041679240111e66,-1.4476178361450074e67],1.0),([-1.2600621067470811e67,-1.3171618074660135e67,3.55155213532486e66,-9.486041679240111e66,-1.4476178361450074e67],1.0)]] -------------------------------------------------------------------------------- /test/fixtures/SSM-RMSMC.txt: -------------------------------------------------------------------------------- 1 | 
[([-2.660097878548362e12,4.2406657192899927e11,5.124397279021509e11,-1.5388049692223555e12,5.100591467413694e11],1.7934178371940385e-141),([85360.89249769927,82937.71043376798,236528.18865034508,184344.32920611685,257818.8194711436],1.7934178371940385e-141),([111456.98861817457,175718.5311026478,-123025.40752936345,-216315.66615773254,-216313.3788230968],1.7934178371940385e-141),([19634.90868499715,3982.228157420627,71064.25154723959,-61863.47069790381,-124170.21806752698],1.7934178371940385e-141),([-1.5475119433556266e13,7.505784271326075e13,1.8087989772377637e13,2.3121907140178836e13,6.957846697896459e13],1.7934178371940385e-141),([3.3114425601587117e13,-7.174858488587559e12,-5.3895766067097984e13,-5.574198391310134e13,-7.085806871739197e13],1.7934178371940385e-141),([1.7797637411011798e9,2.257078068208236e9,1.7822875201019692e9,-3.1059378475534713e8,-2.022387835329388e8],1.7934178371940385e-141),([-362.4859209624216,597.209862141266,354.2417236032639,699.0287190356412,-719.9079388224256],1.7934178371940385e-141),([-6962602.171312882,4072501.6094401204,-1935864.3227236385,6378383.05623946,2107530.9972696295],1.7934178371940385e-141),([535.0932822165735,229.4681249747279,-318.4034398226934,-425.6731795563692,-736.158075205745],1.7934178371940385e-141)] -------------------------------------------------------------------------------- /test/fixtures/SSM-RMSMCBasic.txt: -------------------------------------------------------------------------------- 1 | [([-2.660097878548362e12,4.2406657192899927e11,5.124397279021509e11,-1.5388049692223555e12,5.100591467413694e11],1.7934178371940385e-141),([85360.89249769927,82937.71043376798,236528.18865034508,184344.32920611685,257818.8194711436],1.7934178371940385e-141),([111456.98861817457,175718.5311026478,-123025.40752936345,-216315.66615773254,-216313.3788230968],1.7934178371940385e-141),([19634.90868499715,3982.228157420627,71064.25154723959,-61863.47069790381,-124170.21806752698],1.7934178371940385e-141),([-1.5475119433556266e13,7.505784271326075e13,1.8087989772377637e13,2.3121907140178836e13,6.957846697896459e13],1.7934178371940385e-141),([3.3114425601587117e13,-7.174858488587559e12,-5.3895766067097984e13,-5.574198391310134e13,-7.085806871739197e13],1.7934178371940385e-141),([1.7797637411011798e9,2.257078068208236e9,1.7822875201019692e9,-3.1059378475534713e8,-2.022387835329388e8],1.7934178371940385e-141),([-362.4859209624216,597.209862141266,354.2417236032639,699.0287190356412,-719.9079388224256],1.7934178371940385e-141),([-6962602.171312882,4072501.6094401204,-1935864.3227236385,6378383.05623946,2107530.9972696295],1.7934178371940385e-141),([535.0932822165735,229.4681249747279,-318.4034398226934,-425.6731795563692,-736.158075205745],1.7934178371940385e-141)] -------------------------------------------------------------------------------- /test/fixtures/SSM-RMSMCDynamic.txt: -------------------------------------------------------------------------------- 1 | 
[([61234.923743603955,79039.83817954235,354024.81636628765,225755.73057039993,-78843.37322818518],0.0),([61234.923743603955,-205024.66324964678,-438520.7645656072,-526045.6062936985,17959.08713788638],0.0),([61234.923743603955,130707.33129683959,260276.7204227042,538891.1815485102,432537.1717560617],0.0),([61234.923743603955,425968.16738967673,72802.89417099475,97062.29318414515,-90904.59187690681],0.0),([61234.923743603955,-80888.00179367141,122235.67304475381,48742.27626015559,-149682.32933231423],0.0),([61234.923743603955,43.833902800088254,-417728.0201965655,49565.634594935604,-303943.3354524304],0.0),([61234.923743603955,350501.69936972257,118986.06426751378,99950.78931739656,-60488.53431816819],0.0),([61234.923743603955,-117376.5868812376,116017.94360094423,378976.39475725644,74865.6296219704],0.0),([61234.923743603955,156368.9791422615,-586653.2615030725,-238480.82081038723,51581.15175237715],0.0),([61234.923743603955,-150776.59937461224,-30862.03908288705,200382.13919586508,-107135.36343350058],0.0)] -------------------------------------------------------------------------------- /test/fixtures/SSM-SMC.txt: -------------------------------------------------------------------------------- 1 | [([-1.6946443595984358e8,-2.0398900541476977e8,5.988104418627801e8,5.186441087015647e7,-1.1107580460544899e9],3.747925572660412e-147),([-1.6946443595984358e8,1.762322765772586e8,1.3143034131110222e9,2.917359439754021e7,-4.678360689283452e8],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,-1.0039010467494124e9,-4.5194462919398534e8],3.747925572660412e-147),([-1.6946443595984358e8,-2.0398900541476977e8,6.65407483202134e8,-1.3610874802534976e9,1.7804869696064534e9],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,-1.0039010467494124e9,-4.5194462919398534e8],3.747925572660412e-147),([-1.6946443595984358e8,-7.848111477226721e8,-1.536250656089418e9,-1.2593852525318892e9,9.33478070563457e8],3.747925572660412e-147),([-1.6946443595984358e8,1.762322765772586e8,1.3143034131110222e9,2.917359439754021e7,-4.678360689283452e8],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,7.201669253635451e8,-6.528627637915363e8],3.747925572660412e-147),([-1.6946443595984358e8,-4.978125179866476e8,-6.568379081060445e8,7.201669253635451e8,-6.528627637915363e8],3.747925572660412e-147),([-1.6946443595984358e8,-7.848111477226721e8,-1.536250656089418e9,-1.2593852525318892e9,9.33478070563457e8],3.747925572660412e-147)] -------------------------------------------------------------------------------- /test/fixtures/SSM-SMC2.txt: -------------------------------------------------------------------------------- 1 | [([([-9090.483553160731,-18364.240866577857,38447.317849110055,-3829.950678281628,-18689.938602553048],0.3333333333333341),([-9090.483553160731,-18364.240866577857,38447.317849110055,-3829.950678281628,-18689.938602553048],0.3333333333333341),([-9090.483553160731,-18364.240866577857,38447.317849110055,25131.836867847727,47603.03068211828],0.3333333333333341)],9.474658864966518e-180),([([-2.2418721864335723e9,2.4219211687208967e9,5.4463793547824545e9,7.074651672385337e8,2.595090695345872e8],0.3333333333333341),([-2.2418721864335723e9,2.4219211687208967e9,5.4463793547824545e9,7.074651672385337e8,2.595090695345872e8],0.3333333333333341),([-2.2418721864335723e9,2.4219211687208967e9,5.4463793547824545e9,7.074651672385337e8,2.595090695345872e8],0.3333333333333341)],9.474658864966518e-180)] 
--------------------------------------------------------------------------------