├── .gitignore ├── .travis.yml ├── Dockerfile ├── INSTALL.md ├── LICENSE ├── Makefile ├── README.md ├── bin ├── convert ├── convert-all ├── gnuplot-hist ├── gnuplot-hist.gp └── parse-all ├── deps.edn ├── doc ├── interaction.md ├── language.md └── metaprob-in-clojure.txt ├── docker-compose.yaml ├── project.clj ├── src └── metaprob │ ├── autotrace.cljc │ ├── code_handlers.cljc │ ├── distributions.cljc │ ├── examples │ ├── aide.clj │ ├── all.clj │ ├── curve_fitting.clj │ ├── earthquake.clj │ ├── flip_n_coins.clj │ ├── inference_on_gaussian.clj │ ├── long_test.clj │ ├── main.clj │ ├── multimixture_dsl.clj │ └── spelling_correction.clj │ ├── expander.cljc │ ├── generative_functions.cljc │ ├── inference.cljc │ ├── prelude.cljc │ └── trace.cljc ├── test └── metaprob │ ├── compositional_test.cljc │ ├── distributions_test.cljc │ ├── examples │ ├── all_test.clj │ ├── flip_n_coins_test.clj │ ├── inference_on_gaussian_test.clj │ └── main_test.clj │ ├── inference_test.cljc │ ├── prelude_test.cljc │ ├── syntax_test.cljc │ ├── test_runner.cljs │ └── trace_test.cljc └── tutorial ├── README.md ├── resources ├── plot-trace.js └── plotly-latest.min.js └── src └── metaprob └── tutorial └── jupyter.clj /.gitignore: -------------------------------------------------------------------------------- 1 | pom.xml 2 | pom.xml.asc 3 | /lib/ 4 | /classes/ 5 | /target/ 6 | /checkouts/ 7 | /results/ 8 | .nrepl-port 9 | .ipynb_checkpoints 10 | .idea 11 | TAGS 12 | parsings 13 | *.tmp 14 | *-tmp.png 15 | *~ 16 | .#* 17 | *# 18 | 19 | # Clojure 20 | *.jar 21 | *.class 22 | 23 | # ClojureScript 24 | /out/ 25 | /nashorn_code_cache/ 26 | 27 | # Leiningen 28 | .lein-deps-sum 29 | .lein-repl-history 30 | .lein-plugins/ 31 | .lein-failures 32 | .lein_classpath 33 | bin/lein 34 | /src/metaprob/tutorial/.ipynb_checkpoints/ 35 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | 
dist: trusty 2 | 3 | language: bash 4 | 5 | sudo: required 6 | 7 | services: 8 | - docker 9 | 10 | before_install: 11 | - make docker-build 12 | 13 | script: 14 | - make docker-test 15 | 16 | notifications: 17 | # Notification options are documented here: 18 | # https://docs.travis-ci.com/user/notifications/#configuring-slack-notifications 19 | slack: 20 | rooms: 21 | # Because this setting includes our Slack API token it is encrypted. For 22 | # more information on using encryption with Travis see the encryption 23 | # documentation here: https://docs.travis-ci.com/user/encryption-keys/ 24 | secure: B9/VRP+CoXvCE97atMXJ3eFKDbLTIxKTyIcJ7BId8pbCEYEBO+5gnmkAfrt+bP/A+MFDsXVOUw990ARSXxh4GI8ydA+KmqQZ39Zm4cB3UpaM6E24f9McO+5ZicFnHg5aSTipZm5dL8cubdv/f3S0G0kSrkAvODEQK7FHBQsf7bJh9TPXUFteWeGjmKtqIBn4R2oOKweBW9LyESwu/5Y7O2BqtD79Z7is9ejzgLqatq+7YqL1p67jRD9icw97ZNt4tT7NQwQzS94U9vgYbTp4qnRzFN5jrTiRfHuqMSsEAVwYZwkoLaHb6OXhU+deZ4FIggJhheuAHA/zHyle7IwkZY4lZFkj8hkwU8f972FqziuWrvPAtDE2xxJ9PW2ze8hpagu+Sm8LZSnKOEnerxrXLJhBTDOwDT8ARCGuO84G/60aE3ajQfiNAzSQvi5n94H0io1I9udlLX7UAG001eTilyzuJQ973b500KzyMbjgvQvzxXiKtoxP/JinHE/FNYmJm2V3csboQoOJ2+ur1I/vHzdmSrPQ6s14gI5Wf3N2JqkwKRbFjqylLgfihfCO5G09Wly8Vc1sZ8KpggFspimJvUZY+0OEZ79nCKD1iT81m43KCgO9tiDsG3xigVqfQRatkV29SdN3b4L+AkrQxaSKzgmoufU6N6nH68sTWIOkWFE= 25 | on_success: change 26 | on_failure: always 27 | on_pull_requests: false 28 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Docker Hub maintains an official Clojure library. Using it as our base saves 2 | # us from having to install Leiningen ourselves. 
3 | 4 | FROM clojure:lein-2.8.1 5 | 6 | # Install curl so we can use it to download the Clojure command line tools, 7 | # install time so we can measure how long it takes to run the examples, install 8 | # rlwrap for use with clj, and install pip so we can install jupyter, and 9 | # install cmake and xxd so we can build Planck. 10 | RUN apt-get update -qq \ 11 | && apt-get upgrade -qq \ 12 | && apt-get install -qq -y \ 13 | cmake \ 14 | curl \ 15 | nodejs \ 16 | time \ 17 | rlwrap \ 18 | python3-pip \ 19 | xxd 20 | 21 | # Install Node so we can run our tests in JVM-hosted Clojurescript mode. 22 | 23 | RUN ln -s /usr/bin/nodejs /usr/bin/node 24 | 25 | # Install the Clojure command line tools. These instructions are taken directly 26 | # from the Clojure "Getting Started" guide: 27 | # https://clojure.org/guides/getting_started 28 | 29 | ENV CLOJURE_VERSION 1.9.0.394 30 | RUN curl -O https://download.clojure.org/install/linux-install-${CLOJURE_VERSION}.sh \ 31 | && chmod +x linux-install-${CLOJURE_VERSION}.sh \ 32 | && ./linux-install-${CLOJURE_VERSION}.sh 33 | 34 | # Work around a bug in the Ubuntu package `ca-certificates-java` 35 | # (https://stackoverflow.com/a/33440168) 36 | RUN dpkg --purge --force-depends ca-certificates-java \ 37 | && apt-get install ca-certificates-java 38 | 39 | # Install Planck so we can run our tests in self-hosted mode. 40 | 41 | RUN apt-get update \ 42 | && apt-get install -y --no-install-recommends apt-utils \ 43 | && apt-get install -qq -y \ 44 | libjavascriptcoregtk-4.0 \ 45 | libglib2.0-dev \ 46 | libzip-dev \ 47 | libcurl4-gnutls-dev \ 48 | libicu-dev 49 | 50 | RUN git clone https://github.com/planck-repl/planck.git \ 51 | && cd planck \ 52 | && git fetch --all --tags \ 53 | && git checkout tags/2.21.0 \ 54 | && script/build --fast \ 55 | && script/install \ 56 | && planck -h \ 57 | && cd .. 58 | 59 | # Install jupyter. 60 | 61 | RUN pip3 install jupyter 62 | 63 | # Create a new user to run commands as per the best practice. 
64 | # https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#user 65 | # Use --no-log-init to work around the bug detailed there. 66 | 67 | RUN groupadd metaprob && \ 68 | useradd --no-log-init -m -g metaprob metaprob 69 | 70 | # Switch users early so files created by subsequent operations will be owned by the 71 | # runtime user. This also makes it so that commands will not be run as root. 72 | 73 | USER metaprob 74 | 75 | ENV METAPROB_DIR /home/metaprob/projects/metaprob-clojure 76 | RUN mkdir -p $METAPROB_DIR 77 | WORKDIR $METAPROB_DIR 78 | 79 | # Retrieve our dependencies now in order to reduce the time it takes for the 80 | # notebook to start when the image is run. 81 | 82 | COPY --chown=metaprob:metaprob ./deps.edn $METAPROB_DIR 83 | COPY --chown=metaprob:metaprob ./project.clj $METAPROB_DIR 84 | RUN clojure -e "(clojure-version)" 85 | 86 | # downgrade tornado. 87 | # see https://stackoverflow.com/questions/54963043/jupyter-notebook-no-connection-to-server-because-websocket-connection-fails 88 | 89 | USER root 90 | RUN pip3 uninstall -y tornado 91 | RUN pip3 install tornado==5.1.1 92 | 93 | USER metaprob 94 | 95 | RUN lein jupyter install-kernel 96 | 97 | # Copy in the rest of our source. 98 | 99 | COPY --chown=metaprob:metaprob . $METAPROB_DIR 100 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installing metaprob-in-clojure 2 | 3 | Setting up a development environment for Metaprob is still a somewhat 4 | manual process. Note that what follows are instructions for those 5 | interested in working on Metaprob itself or users wanting to use 6 | Metaprob for specific projects. 7 | 8 | If you're interested in learning and experimenting with Metaprob, we 9 | strongly recommend starting with the [Tutorial](tutorial/README.md) 10 | rather than following these instructions.
11 | 12 | ## Install Java 13 | 14 | You will need Java 1.8 or later in order to run Leiningen, Clojure, 15 | and Metaprob. Personally I (JAR) use JRE version 1.8.0_05. You need 16 | the full Java development environment (JDK), not just the JVM. 17 | 18 | Check your local Java version with 19 | 20 | java -version 21 | 22 | ## Clone the metaprob-clojure repository 23 | 24 | Clone this repository from github, and set the working directory to 25 | the root of the clone: 26 | 27 | git clone git@github.com:probcomp/metaprob-clojure.git 28 | cd metaprob-clojure 29 | 30 | Use `https` if the above method (ssh) doesn't work: 31 | 32 | git clone https://github.com/probcomp/metaprob-clojure.git 33 | cd metaprob-clojure 34 | 35 | ## Install the Clojure command line tools and Leiningen 36 | 37 | To use Metaprob and run the tutorials you will need to install both the Clojure 38 | command line tools and Leiningen. It is not necessary to separately install 39 | Clojure, because it will be retrieved as-needed by these tools. Metaprob 40 | requires Clojure 1.9 or later. 41 | 42 | ### MacOS 43 | 44 | On MacOS we recommend you first [install Homebrew](https://brew.sh/) and then 45 | use it to install both the Clojure command line tools and Leiningen like so: 46 | 47 | 1. `brew update` 48 | 2. `brew upgrade` 49 | 3. `brew install clojure` 50 | 4. `brew install leiningen` 51 | 52 | ### Linux 53 | 54 | To install the Clojure command line tools follow the instructions for your 55 | system in the [Clojure getting started 56 | guide](https://clojure.org/guides/getting_started#_clojure_installer_and_cli_tools). 57 | 58 | For a quick Leiningen installation, just do `make`: 59 | 60 | make 61 | 62 | This creates `bin/lein` which can be used as a shell command. If you 63 | want to be able to say just `lein`, you'll need to put it in your 64 | PATH. E.g.
if you have a `~/bin` directory in your PATH, try this: 65 | 66 | ln -sf $PWD/bin/lein ~/bin/ 67 | 68 | Leiningen keeps some state in the `~/.lein` directory. 69 | 70 | Full instructions for installing Leiningen are 71 | [here](https://leiningen.org/#install). 72 | 73 | The `make` that you run for this purpose also creates a file 74 | `.lein_classpath`, which is used to speed up Java invocation in some 75 | cases. 76 | 77 | 78 | ## Emacs setup 79 | 80 | It is possible to use metaprob-in-clojure exclusively from the shell, 81 | but running a REPL is better in a supervised environment where you 82 | have a transcript, can search, can get to source code easily, and so 83 | on. This observation is independent of your choice of text editor. 84 | There may be Clojure support for `vim` and `eclipse`; I 85 | don't know. Clojure support under emacs is pretty good, in case you 86 | know emacs or are willing to learn. 87 | 88 | Following is some information on setting up emacs for Clojure. What I 89 | suggest here is not necessarily right or best; it's just stuff I got 90 | off the Internet. 91 | 92 | Put the following in your `.emacs` file: 93 | 94 | ;; From http://clojure-doc.org/articles/tutorials/emacs.html. 95 | (require 'package) 96 | (add-to-list 'package-archives 97 | '("melpa-stable" . "http://stable.melpa.org/packages/") 98 | t) 99 | ;; "Run M-x package-refresh-contents to pull in the package listing." 100 | (package-initialize) 101 | (defvar clojure-packages '(projectile 102 | clojure-mode 103 | cider)) 104 | (dolist (p clojure-packages) 105 | (unless (package-installed-p p) 106 | (package-install p))) 107 | 108 | Here is my `~/.lein/profiles.clj`.
I'm not sure why it is as it is, 109 | but it seems to be harmless: 110 | 111 | ; Sample profile: https://gist.github.com/jamesmacaulay/5603176 112 | {:repl {:dependencies [[org.clojure/tools.namespace "0.2.11"]] 113 | :injections [(require '(clojure.tools.namespace repl find))] 114 | :plugins [[cider/cider-nrepl "0.15.1"]]}} 115 | 116 | ## Gnuplot 117 | 118 | Gnuplot must be installed in order for plotting to work. If you want 119 | to make plots and do not already have gnuplot, you are on your own 120 | because I don't remember how I installed it. 121 | 122 | ----- 123 | 124 | NEXT: [Using metaprob-in-clojure](doc/interaction.md) 125 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: bin/lein .lein_classpath 2 | @echo "Good to go!" 3 | 4 | # Starts a ClojureScript REPL 5 | cljs: 6 | clojure -Acljs -m cljs.main --repl-env nashorn --repl 7 | .PHONY: cljsrepl 8 | 9 | cljstest: 10 | clojure -Acljs:cljstest 11 | .PHONY: cljstest 12 | 13 | cljsclean: 14 | rm -Rf out 15 | 16 | cljsselftest: 17 | plk -c`clojure -Acljs:test -Spath` -m metaprob.test-runner 18 | .PHONY: cljsselftest 19 | 20 | # This target is referenced in README.md 21 | bin/lein: 22 | wget "https://raw.githubusercontent.com/technomancy/leiningen/stable/bin/lein" 23 | mv lein bin/lein 24 | chmod +x bin/lein 25 | @echo "Please add the 'lein' command to you PATH, e.g." 26 | @echo ln -sf $$PWD/bin/lein ~/bin/lein 27 | 28 | # By setting java's classpath explicitly, instead of relying on 'lein' 29 | # to do it for us, we cut the number of Java VM startups in half. 30 | # This is a significant speedup. 31 | # I got this hack from stackoverflow. 
32 | .lein_classpath: bin/lein 33 | bin/lein classpath > $@ 34 | 35 | # Includes long-running tests 36 | test: cljtest cljtestlong cljstest cljsselftest 37 | .PHONY: test 38 | 39 | cljtest: 40 | clojure -Atest 41 | .PHONY: cljtest 42 | 43 | cljtestlong: 44 | clojure -Atest -d src -n metaprob.examples.long-test 45 | .PHONY: cljtestlong 46 | 47 | # Create directory of .trace files from .vnts files. 48 | # Requires python-metaprob. 49 | # NOTE: This must run with the metaprob virtualenv active! 50 | parse: ../metaprob/pythenv.sh python python/transcribe.py 51 | bin/parse-all 52 | 53 | # Create directory of .clj files from .trace files 54 | convert: src/metaprob/main.clj src/metaprob/to_clojure.clj .lein_classpath 55 | bin/lein compile :all 56 | bin/convert-all 57 | 58 | # General rule for converting a .vnts (metaprob) file to a .trace file. 59 | # The rule is for documentation purposes; I don't think it's used. 60 | # Requires python-metaprob. 61 | # NOTE: This must run with the metaprob virtualenv active! 62 | %.trace: %.vnts 63 | ../metaprob/pythenv.sh python python/transcribe.py -f $< $@.new 64 | mv -f $@.new $@ 65 | 66 | # General rule for converting a .trace file to a .clj file 67 | # The rule is for documentation purposes; I don't think it's used. 68 | %.clj: %.trace .lein_classpath 69 | java -cp `cat .lein_classpath` metaprob.main $< $@.new 70 | mv -f $@.new $@ 71 | 72 | # If you get errors with 'lein compile :all' try just 'lein 73 | # compile'. I don't understand the difference. 74 | # To change number of samples, pass the number on the command line. 75 | # 2000 is a good number, but it takes hours to run. 76 | # 5 is good for smoke tests. To get more, you can say e.g.
77 | # make view COUNT=100 78 | COUNT=10 79 | exa: results/samples_from_the_gaussian_demo_prior.samples 80 | 81 | SAMPLES=results/samples_from_the_gaussian_demo_prior.samples 82 | 83 | $(SAMPLES): 84 | mkdir -p results 85 | clojure -Aexamples -a --samples $(COUNT) 86 | 87 | $(SAMPLES).png: $(SAMPLES) 88 | for f in results/*.samples; do bin/gnuplot-hist $$f; done 89 | 90 | view: $(SAMPLES).png 91 | open results/*.png 92 | 93 | # suppress '.#foo.clj' somehow 94 | tags: 95 | etags --language=lisp `find src -name "[a-z]*.clj"` `find test -name "[a-z]*.clj"` 96 | 97 | # Targets for manipulating Docker below. 98 | docker-build: 99 | mkdir -p $(HOME)/.m2/repository 100 | docker build -t probcomp/metaprob-clojure:latest . 101 | .PHONY: docker-build 102 | 103 | docker-test: 104 | docker run --rm -t probcomp/metaprob-clojure:latest bash -c "make test" 105 | .PHONY: docker-test 106 | 107 | docker-bash: 108 | docker run \ 109 | -it \ 110 | --mount type=bind,source=${HOME}/.m2,destination=/home/metaprob/.m2 \ 111 | --mount type=bind,source=${CURDIR},destination=/home/metaprob/projects/metaprob-clojure \ 112 | probcomp/metaprob-clojure:latest \ 113 | bash 114 | .PHONY: docker-cider 115 | 116 | docker-repl: 117 | docker run \ 118 | -it \ 119 | --mount type=bind,source=${HOME}/.m2,destination=/home/metaprob/.m2 \ 120 | --mount type=bind,source=${CURDIR},destination=/home/metaprob/projects/metaprob-clojure \ 121 | probcomp/metaprob-clojure:latest \ 122 | bash -c "sleep 1;clj" 123 | # For more information on why this sleep is necessary see this pull request: 124 | # https://github.com/sflyr/docker-sqlplus/pull/2 125 | .PHONY: docker-repl 126 | 127 | docker-notebook: 128 | docker run \ 129 | -it \ 130 | --mount type=bind,source=${HOME}/.m2,destination=/home/metaprob/.m2 \ 131 | --mount type=bind,source=${CURDIR},destination=/home/metaprob/projects/metaprob-clojure \ 132 | --publish 8888:8888/tcp \ 133 | probcomp/metaprob-clojure:latest \ 134 | bash -c "lein jupyter notebook \ 135 | 
--ip=0.0.0.0 \ 136 | --port=8888 \ 137 | --no-browser \ 138 | --NotebookApp.token= \ 139 | --notebook-dir ./tutorial" 140 | .PHONY: docker-notebook 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Metaprob 2 | 3 | [![Build Status](https://travis-ci.org/probcomp/metaprob.svg?branch=master)](https://travis-ci.org/probcomp/metaprob) 4 | ![Stability: Experimental](https://img.shields.io/badge/stability-experimental-orange.svg) 5 | 6 | A language for probabilistic programming and metaprogramming, embedded in [Clojure](https://clojure.org/). 7 | 8 | **Note: Metaprob is currently an unstable research prototype, with little documentation and low test coverage. Also, future versions may not be backwards compatible with this version. We do not recommend using it for any purpose other than basic research, and are not yet able to support users outside of the MIT Probabilistic Computing Project.** 9 | 10 | ## Key features 11 | 12 | * Models can be represented via generative code, i.e. ordinary code that makes stochastic choices 13 | * Models can also be represented via approximations, e.g. 
importance samplers with nontrivial weights 14 | * Custom inference algorithms can be written in user-space code, via reflective language constructs for: 15 | * tracing program executions 16 | * using partial traces to specify interventions and constraints 17 | * Generic inference algorithms are provided via user-space code in a standard library; adding new algorithms does not require modifying the language implementation 18 | * All Inference algorithms are ordinary generative code and can be traced and treated as models 19 | * New probability distributions and inference algorithms are first-class citizens that can be created dynamically during program execution 20 | 21 | ## Motivations 22 | 23 | * Lightweight embeddings of probabilistic programming and inference metaprogramming 24 | * Interactive, browser-based data analysis tools (via [ClojureScript](https://clojurescript.org/)) 25 | * Smart data pipelines suitable for enterprise deployment (via Clojure on the JVM) 26 | * “Small core” language potentially suitable for formal specification and verification 27 | * Teaching 28 | * Undergraduates and graduate students interested in implementing their own minimal PPL 29 | * Software engineers and data engineers interested in probabilistic modeling and inference 30 | * Research in artificial intelligence and cognitive science 31 | * Combining symbolic and probabilistic reasoning, e.g. via integration with Clojure’s [core.logic](https://github.com/clojure/core.logic) 32 | * “Theory of mind” models, where an agent’s reasoning is modeled as an inference metaprogram acting on a generative model 33 | * Reinforcement learning and other “nested” applications of modeling and approximate inference 34 | * Causal reasoning, via a notion of interventions that extends Pearl's “do” operator 35 | * Research in probabilistic meta-programming, e.g. 
synthesis, reflection, runtime code generation 36 | 37 | ## Modeling and tracing 38 | 39 | Generative models are represented as ordinary functions that make stochastic choices. 40 | 41 | ```clojure 42 | ;; Flip a fair coin n times 43 | (def fair-coin-model 44 | (gen [n] 45 | (map (fn [i] (at i flip 0.5)) (range n)))) 46 | ;; Flip a possibly weighted coin n times 47 | (def biased-coin-model 48 | (gen [n] 49 | (let [p (at "p" uniform 0 1)] 50 | (map (fn [i] (at i flip p)) (range n))))) 51 | ``` 52 | 53 | Execution traces of models, which record the random choices they make, are first-class values that inference algorithms can manipulate. 54 | 55 | We obtain scored traces using `infer-and-score`, which invokes a “tracing interpreter” that is itself a Metaprob program. 56 | 57 | ```clojure 58 | (infer-and-score :procedure fair-coin-model, :inputs [3]) 59 | ``` 60 | 61 | ## Documentation 62 | 63 | * [Contributor installation instructions](INSTALL.md) 64 | * [Using Metaprob](doc/interaction.md) 65 | * [Language reference](doc/language.md) 66 | -------------------------------------------------------------------------------- /bin/convert: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Convert an original-metaprob source file (.trace) to clojure (.clj) 4 | 5 | set -e 6 | 7 | function err { 8 | echo "$@" 1>&2 9 | exit 1 10 | } 11 | 12 | [ $# = 3 ] || err "wna" 13 | 14 | source=$1 15 | namespace=$2 16 | dest=$3 17 | 18 | # Now do the convert command 19 | mkdir -p `dirname $dest` 20 | echo Converting $source '->' $dest 21 | if true ; then 22 | java -cp `cat .lein_classpath` metaprob.main $source $dest.new $namespace 23 | mv -f $dest.new $dest 24 | else 25 | echo "** namespace = " $namespace 26 | fi 27 | -------------------------------------------------------------------------------- /bin/convert-all: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 
sourcedir=parsings 4 | destdir=converted 5 | 6 | set -e 7 | 8 | dashtounderscore=yes 9 | 10 | for f in `cd parsings; find . -name "*.trace"`; do 11 | 12 | # Strip off leading ./ 13 | f=${f:2} 14 | 15 | # Separate directory path from file name 16 | dir=`dirname $f` 17 | base=`basename $f .trace` 18 | 19 | # https://stackoverflow.com/questions/24077667/bash-replace-slash 20 | namespace=metaprob.${dir//\//.}.$base 21 | destpath=$dir/$base 22 | 23 | # Clojure prefers underscores in file names; don't know why. 24 | if [ $dashtounderscore = yes ]; then 25 | # See bash manual under '${parameter/pattern/string}' 26 | destpath=${destpath//-/_} 27 | fi 28 | 29 | # Now do the convert command 30 | source=$sourcedir/$f 31 | dest=$destdir/$destpath.clj 32 | bin/convert "$source" "$namespace" "$dest" 33 | done 34 | -------------------------------------------------------------------------------- /bin/gnuplot-hist: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | data=$1 4 | commands=$1.commands 5 | dest=$data.png 6 | 7 | cp -p "$data" gnuplot-hist.tmp 8 | if [ -e $commands ]; then 9 | echo "commands" 10 | cat $commands bin/gnuplot-hist.gp | gnuplot - 11 | else 12 | gnuplot bin/gnuplot-hist.gp 13 | fi 14 | mv gnuplot-hist-tmp.png $dest 15 | 16 | echo Plot written to $dest 17 | 18 | -------------------------------------------------------------------------------- /bin/gnuplot-hist.gp: -------------------------------------------------------------------------------- 1 | # https://gnuplot-surprising.blogspot.com/2011/09/statistic-analysis-and-histogram.html 2 | width=(max-min)/n #interval width 3 | #function used to map a value to the intervals 4 | hist(x,width)=width*floor(x/width)+width/2.0 5 | set term png #output terminal and file 6 | set output "gnuplot-hist-tmp.png" 7 | set xrange [min:max] 8 | set yrange [0:] 9 | #to put an empty boundary around the 10 | #data inside an autoscaled graph. 
11 | set offset graph 0.05,0.05,0.05,0.0 12 | set xtics min,(max-min)/5,max 13 | set boxwidth width*0.9 14 | set style fill solid 0.5 #fillstyle 15 | set tics out nomirror 16 | set xlabel "x" 17 | set ylabel "Frequency" 18 | #count and plot 19 | plot "gnuplot-hist.tmp" \ 20 | u (hist($1,width)):(1.0) smooth freq w boxes lc rgb"green" notitle 21 | -------------------------------------------------------------------------------- /bin/parse-all: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # NOTE: This must run with the metaprob virtualenv active! 4 | 5 | 6 | set -e 7 | 8 | mkdir -p parsings 9 | 10 | function doit { 11 | 12 | root=$1 13 | 14 | paths=`cd ../metaprob; find $root -name "*.vnts"` 15 | 16 | for path in $paths; do 17 | f=`basename $path .vnts` 18 | if [ $f != propose-subtrace ]; then 19 | d=`dirname $path` 20 | mkdir -p "parsings/$d" 21 | dest="parsings/$d/$f.trace" 22 | echo "writing $dest" 23 | ../metaprob/pythenv.sh python python/transcribe.py -f "../metaprob/$path" > "$dest".new 24 | mv "$dest".new "$dest" 25 | fi 26 | done 27 | 28 | } 29 | 30 | doit src 31 | doit metacirc 32 | doit examples 33 | -------------------------------------------------------------------------------- /deps.edn: -------------------------------------------------------------------------------- 1 | {:deps {kixi/stats {:mvn/version "0.4.0"} 2 | org.apache.commons/commons-math3 {:mvn/version "3.6.1"} 3 | org.clojure/clojure {:mvn/version "1.9.0"} 4 | org.clojure/tools.cli {:mvn/version "0.4.1"}} 5 | 6 | :aliases {:examples {:jvm-opts ["-Xss50M" "-Dhttps.protocols=TLSv1.2"] 7 | :main-opts ["-m" "metaprob.examples.main"]} 8 | :test {:extra-paths ["test"] 9 | :extra-deps {com.cognitect/test-runner {:git/url "https://github.com/cognitect-labs/test-runner.git" 10 | :sha "028a6d41ac9ac5d5c405dfc38e4da6b4cc1255d5"}} 11 | ;; Default stack size is 1MB or less, increase to 50. 
For more 12 | ;; information on `java` options: 13 | ;; https://docs.oracle.com/javase/8/docs/technotes/tools/unix/java.html 14 | :jvm-opts ["-Xss50M" "-Dhttps.protocols=TLSv1.2"] 15 | :main-opts ["-m" "cognitect.test-runner"]} 16 | 17 | :cljs {:extra-deps {org.clojure/clojurescript {:mvn/version "1.10.339"}}} 18 | :cljstest {:extra-paths ["test"] 19 | :main-opts ["-m" "cljs.main" 20 | "-re" "node" 21 | "-m" "metaprob.test-runner"]}}} 22 | -------------------------------------------------------------------------------- /doc/interaction.md: -------------------------------------------------------------------------------- 1 | # Using Metaprob 2 | 3 | This page is about the mechanics of using Metaprob. For 4 | information about the Metaprob 'language' see [language.md](language.md). 5 | 6 | There are many ways to work with Metaprob. They are the 7 | same as the ways one works with Clojure. Generally you alternate 8 | writing functions with exploration (including testing). 9 | 10 | * From a Clojure read-eval-print loop (REPL) 11 | * running Clojure direct from the shell (command line) 12 | * running Clojure under emacs 13 | * From files 14 | * putting code in files 15 | * writing tests and trying them out 16 | * running code via 'main' programs 17 | 18 | ## Using the Clojure REPL 19 | 20 | The REPL can be started from the shell, or from emacs. With emacs you 21 | get many desirable tools such as automatic indentation and namespace 22 | system integration, but the setup is more complex and the learning 23 | curve is steep if you haven't used emacs before. 24 | 25 | ### Using Clojure under the shell 26 | 27 | To get an interactive read-eval-print loop at the shell: 28 | 29 | $ clj 30 | 31 | This should be done with the working directory set to the directory that 32 | contains `deps.edn`, which in this setup would normally be the clone of the 33 | `metaprob` repository. 34 | 35 | ### Using Clojure under Emacs 36 | 37 | This section describes using a REPL in emacs. 
The Emacs interface to 38 | Clojure is called 'Cider'. (Or at least the interface that I know about.) 39 | 40 | Cider also knows to look at `deps.edn`. 41 | 42 | #### Starting a REPL under Emacs 43 | 44 | In emacs, do: 45 | 46 | M-x cider-jack-in RET 47 | 48 | (Alternatively, the fact that the Clojure connection goes over TCP/IP means you 49 | can put the server anywhere on the Internet that you want to and connect with 50 | `cider-connect`. The best way to connect is with an ssh tunnel, but that is a 51 | story for another day.) 52 | 53 | #### Emacs commands 54 | 55 | Evaluate an expression from a buffer with `C-c C-E`. Cider knows how to 56 | figure out which namespace to evaluate in (see below). 57 | 58 | You can load a clojure file by visiting it in a buffer and doing `C-c 59 | C-k`. I think doing so will also load any needed dependencies as 60 | inferred from the `ns` form at the top of the file. 61 | 62 | C-h m will show you commands available in REPL mode. 63 | There's more documentation of Cider somewhere. 64 | 65 | ### Simple scratch namespace for playing around 66 | 67 | Exploration from the REPL is a little bit annoying due to the Clojure 68 | namespace system. Unlike in Common Lisp, which has its `::` syntax, 69 | you can't even access a namespace unless you've previously 'required' 70 | it. 71 | 72 | To get started quickly, you can just switch to the examples namespace, 73 | after which you won't have to think about namespaces until you want to 74 | create new Clojure/Metaprob modules. Enter the following at the REPL 75 | to load Metaprob: 76 | 77 | (require 'metaprob.examples.all) 78 | (in-ns 'metaprob.examples.all) 79 | 80 | The `metaprob.examples.all` namespace imports all of the Metaprob 81 | namespaces, meaning that all Metaprob bindings are directly available 82 | and you don't have to worry about the Clojure namespace system after 83 | this point. 
(This is fine for experimentation, but for more durable 84 | programming I recommend conventional use of the Clojure namespace 85 | system, requiring only those namespaces you use.) 86 | 87 | You can then evaluate metaprob expressions directly, run examples, and so on: 88 | 89 | (def x (trace-set-value {} "foo" 17)) 90 | (trace-value x "foo") 91 | (pprint x) 92 | 93 | ### Creating namespaces for use with metaprob 94 | 95 | Metaprob has some name conflicts with Clojure, so some care is 96 | necessary when preparing files containing Metaprob code. The method is 97 | described in [language.md](language.md). 98 | 99 | 100 | ### Interaction in other namespaces 101 | 102 | It is possible to use metaprob functions and macros from any Clojure 103 | namespace, if Metaprob namespaces are made accessible using `require`. 104 | 105 | For example, the default Clojure namespace, `user`, starts out 106 | knowing nothing about the metaprob namespaces. For any kind of access 107 | from `user`, you need to use `require`. E.g. 108 | 109 | > (require '[metaprob.trace]) 110 | 111 | after which 112 | 113 | > (metaprob.trace/trace-set-value {} "foo" 17) 114 | 115 | Namespace names are typically long, so it's useful to define namespace 116 | prefixes to abbreviate them. This is what `:as` is for: 117 | 118 | > (require '[metaprob.trace :as trace]) 119 | 120 | After which: 121 | 122 | > (trace/trace-value x "foo") 123 | > (trace/addresses-of x) 124 | 125 | and so on. Alternatively, you can access bindings without using a 126 | prefix at all by 'requiring' a namespace with `:refer :all`: 127 | 128 | > (require '[metaprob.trace :refer :all]) 129 | 130 | After which: 131 | 132 | > (trace-value x "foo") 133 | > (addresses-of x) 134 | 135 | ### Refreshing the state 136 | 137 | There is a `refresh` function that reloads your project, giving an 138 | alternative to manually visiting each changed buffer and doing C-c 139 | C-k.
Make `refresh` available at the REPL's `user` namespace with 140 | 141 | > (require '[clojure.tools.namespace.repl :refer [refresh]]) 142 | 143 | and invoke it with 144 | 145 | > (refresh) 146 | 147 | (I guess this `require` could be put in project.clj or 148 | .lein/profiles.clj so that it happens every time you start Clojure? 149 | Need to look into this.) 150 | 151 | See [this stack overflow discussion](https://stackoverflow.com/questions/7658981/how-to-reload-a-clojure-file-in-repl). 152 | 153 | There are other circumstances that require a complete Clojure restart. 154 | 155 | This is a pain because it can take a minute or so to kill 156 | any running clojure under emacs, restart the REPL, connect to the new 157 | REPL, and reload the project. Therefore other alternative interaction 158 | modes (see below) may sometimes be preferable. 159 | 160 | 161 | ## Using files 162 | 163 | ### Files/modules in their own namespace 164 | 165 | Clojure talks about files and modules, but I'm not clear on the 166 | difference. I think they are in 1-1 correspondence. Each file has 167 | its own "namespace", thus another 1-1 correspondence. 168 | 169 | The top of a typical metaprob file, say `myproject/myfile.clj`, would 170 | look something like: 171 | 172 | (ns myproject.myfile 173 | (:refer-clojure :exclude [map replicate apply]) 174 | (:require [metaprob.generative-functions :refer :all] 175 | [metaprob.prelude :refer :all] 176 | [metaprob.distributions :refer :all] 177 | [metaprob.trace :refer :all] 178 | [metaprob.inference :refer :all])) 179 | 180 | If one of these imported modules isn't needed it can be left out of 181 | the list. 182 | 183 | At the REPL you can switch into any file's namespace using `in-ns`, e.g. 184 | 185 | > (in-ns 'myproject.myfile) 186 | 187 | This can be useful but I find I almost never do it. 
Instead I usually 188 | evaluate expressions from inside files using emacs, or I use the test 189 | system, or the user environment or `metaprob.examples.all`. Your 190 | mileage may vary. 191 | 192 | ### Unit tests 193 | 194 | There is documentation on the test system and it should be consulted, 195 | as it doesn't make sense to repeat all that information here. 196 | 197 | Basically the `test` tree in a given project directory is parallel to the 198 | `src` tree. Each `.clj` file in `test/` contains unit tests for the 199 | corresponding file in `src/`. 200 | 201 | It is possible to do most or all development and testing using unit 202 | tests instead of the REPL. The disadvantages of tests compared to the REPL are 203 | 204 | * tests are not so good for exploration (what does this do? what 205 | is in this data structure?) 206 | * it takes 3-5 seconds to start tests going, whereas the REPL has no delay 207 | 208 | You can run tests either from the shell or from inside Clojure. From 209 | the Clojure REPL: Run tests for all modules in the project: 210 | 211 | (require 'metaprob.examples.all) 212 | (require '[clojure.test :refer :all]) 213 | (run-all-tests) 214 | 215 | Single module test: the two `require`s above, plus: 216 | 217 | (require 'metaprob.trace-test) 218 | (run-tests 'metaprob.trace-test) 219 | 220 | Or, from the shell: Run tests for all modules in the project: 221 | 222 | clojure -Atest 223 | 224 | Single module test: 225 | 226 | clojure -Atest -n metaprob.trace-test 227 | 228 | Don't forget the `-test` in the module names, and `_test` in the file 229 | names! I have spent a lot of time being confused, first because I 230 | hadn't realized the `-test` was needed, and later because I just 231 | forgot it. No error is reported when you get this wrong. 232 | 233 | I like for tests that reside in the test system to run quickly so that 234 | I can run `clojure -Atest` frequently and not have to wait a long time. 
235 | It's good to run all the tests frequently, and if this were a slow 236 | operation I would be put off, and would run them less often. 237 | 238 | ### Running code noninteractively from the shell using the `-main` feature 239 | 240 | Rather than use the REPL or the test system I sometimes just put code 241 | in the `-main` function in some file e.g. `main.clj` (could be any 242 | file name) and invoke it directly from the shell: 243 | 244 | $ clojure -m metaprob.examples.main 245 | 246 | Any command line arguments become arguments to the `-main` function. 247 | -------------------------------------------------------------------------------- /doc/language.md: -------------------------------------------------------------------------------- 1 | # Metaprob language reference manual 2 | 3 | Well not really, but here are some notes. 4 | 5 | Metaprob is really just Clojure with some macros and functions. 6 | You'll need some familiarity with Clojure in 7 | order to use Metaprob. Introducing Clojure is beyond the current 8 | scope, so find some independent way to get started with Clojure. 9 | 10 | You may want to create a project (a set of files depending on metaprob 11 | and on one another). This is done with `lein` which has its own 12 | documentation. (There are other project tools besides `lein`. I talk 13 | about `lein` only because it's the only one I'm familiar with.) 14 | 15 | 16 | ## The Clojure namespaces that implement Metaprob 17 | 18 | Metaprob is provided as a set of Clojure namespaces that can 19 | be used in the usual clojure way, with `require` and so on. 20 | 21 | A typical Metaprob program `require`s the following namespaces: 22 | (all namespace names begin with `metaprob.` , which I'll omit for readability) 23 | 24 | * `generative-functions` - provides functions for creating and using generative functions, and also the `gen` macro. 
25 | * `prelude` - exposes utility procedures like eager versions of `map` and `replicate` 26 | * `trace` - implements the trace datatype and its methods 27 | * `distributions` - implements nondeterministic primitives like `flip` and `uniform` as generative functions 28 | * `inference` - provides a standard library for inference programming 29 | 30 | In Clojure code, you can import these namespaces with `:as` or 31 | `:refer` with an explicit list of names, but these namespaces provide 32 | all the basic language primitives and it is a pain to write namespace 33 | prefixes for them all the time in Metaprob source code, where they are 34 | ubiquitous. Therefore, in a Metaprob source file, you generally 35 | import the namespaces with `:refer :all`. Because some names in 36 | `prelude` conflict with the usual clojure bindings, you 37 | need to suppress the conflicting `clojure.core` names, which you do with 38 | 39 | (ns (:refer-clojure :exclude [map replicate apply]) ...) 40 | 41 | So a typical Metaprob source file might start like this: 42 | 43 | (ns thisand.that 44 | (:refer-clojure :exclude [map replicate apply]) 45 | (:require [metaprob.generative-functions :refer :all]) 46 | (:require [metaprob.trace :refer :all]) 47 | (:require [metaprob.prelude :refer :all]) 48 | (:require [metaprob.inference :refer :all]) 49 | (:require [metaprob.distributions :refer :all])) 50 | 51 | ## `generative-functions` namespace 52 | 53 | ([Source](../src/metaprob/generative_functions.cljc)) 54 | 55 | * `(make-generative-function run-in-clj make-constrained-generator)` returns a custom generative function. 56 | * `(gen {:annotation value, ...} [formal ...] body)` returns a generative function based on generative code. 57 | * `(make-primitive sampler scorer)` returns a generative function implementing a primitive distribution. 58 | * `(make-constrained-generator f obs)` returns a constrained version of `f` given an observation trace. 
59 | 60 | ### Addresses 61 | 62 | A child of a trace is given by a key; each child of a trace has a 63 | different key. An 'address' is either a single key or a sequence of 64 | keys specifying, for a given trace t, the descendent of t obtained by 65 | following the path given by the key sequence. An address that is a 66 | key or a singleton specifies that child of a given trace, while longer 67 | addresses specify deeper descendants, and an empty address specifies 68 | the trace itself. 69 | 70 | ### Traces 71 | 72 | In the following, 'A' is an address. 73 | 74 | * `(trace-value t)` - get the value at t 75 | * `(trace-value t A)` - get the value of the A subtrace of t 76 | * `(trace-has-value? t)` - is there a value at t? 77 | * `(trace-has-value? t A)` - does t's A subtrace have a value? 78 | * `(trace-subtrace t A)` - get the subtrace of t at A 79 | * `(trace-has-subtrace? t A)` - does t have a subtrace at A? 80 | * `(trace? x)` - true if x is a trace 81 | 82 | * `(trace-set-value t A x)` - return a new trace that's the same as t except at address A, where one will find the value x 83 | * `(trace-set-subtrace t A u)` - a trace like t, in which u is at address A. 84 | 85 | * `(trace-keys t)` - list of keys for children 86 | * `(addresses-of t)` - the list of addresses, relative to t, of every descendant subtrace that has a value 87 | 88 | * `(trace-merge t1 t2)` - return a trace that is the merge of t1 and t2 (union of respective children, recursively) 89 | * `(trace-clear t A)` - remove the value at A 90 | * `(trace-clear-subtrace t A)` - remove the subtrace at A 91 | * `(merge-subtrace t A u)` - replace t's A subtrace with the result of merging it with u 92 | * `(partition-trace t As)` - if As is a list of addresses, return a pair of traces `[u v]` where `u` contains all of 93 | the addresses in As, and `v` contains the others, such that merging `u` and `v` recovers `t`. 
94 | 95 | ### Output 96 | 97 | * `binned-histogram` - many options - writes out files that can be 98 | displayed using the `bin/gnuplot-hist` script. A set of samples and a 99 | control file are written to the `results` directory (which must be 100 | created manually). 101 | 102 | `bin/gnuplot-hist` is a shell script that uses gnuplot to create a 103 | .png file from a samples file written by `binned-histogram`. E.g. 104 | 105 | bin/gnuplot-hist results/samples_from_the_prior.samples 106 | open results/samples_from_the_prior.samples.png 107 | 108 | (where `open` is the MacOS command for opening a file using the 109 | appropriate Mac application). 110 | 111 | This is pretty much a kludge right now. 112 | 113 | Gnuplot must be installed in order for this to work. 114 | 115 | ## `prelude` namespace 116 | 117 | ([Source](../src/metaprob/prelude.cljc)) 118 | 119 | documentation TBD 120 | 121 | ## `distributions` namespace 122 | 123 | ([Source](../src/metaprob/distributions.cljc)) 124 | 125 | * `(flip w)` - returns true with probability w, false otherwise 126 | * `(uniform a b)` - floating point number drawn from [a,b] 127 | * `(uniform-discrete items)` - one of the members of items (a list) 128 | * `(categorical probabilities)` - returns 0, 1, ... with probability of i proportional to probabilities[i] 129 | * `(log-categorical scores)` - returns 0, 1, ... with probability of i proportional to exp(scores[i]) 130 | * `(gaussian mu sigma)` - samples from a normal distribution 131 | * `(beta a b)` - samples from a beta distribution 132 | * `(gamma shape scale)` - samples from a gamma distribution 133 | -------------------------------------------------------------------------------- /doc/metaprob-in-clojure.txt: -------------------------------------------------------------------------------- 1 | 2 | Need a syntax for expressing metaprob programs in clojure. 
3 | 4 | Need a transformation from 'plain' metaprob-in-clojure to 'expanded' 5 | metaprob-in-clojure with explicit sequencing of trace and score 6 | propagation. 7 | 8 | 9 | 1. Syntax - need to handle everything that's used in the chapter. 10 | (The parser may implement constructs that aren't used.) 11 | 12 | 2. Primitive functions - similarly. Many of them come from venture, 13 | but not all of those are used. Bring them in on demand while working 14 | through all the examples. 15 | 16 | The following (from propose-and-trace-choices.vnts) seems to be the complete set: 17 | 18 | application 19 | variable 20 | literal 21 | program (i.e. lambda) 22 | if 23 | block (progn) 24 | [ ... ] - tuple - no splicing, static subkeys 25 | I don't think this is ever used. 26 | definition 27 | this 28 | with_address 29 | 30 | also: /a/b/c 31 | 32 | The grammar lists the following, but I haven't yet seen them in the 33 | chapter: 34 | 'letvalues' (vars) = body 35 | assignment var := expression - hairy left hand side syntax? 36 | handled by trace_set primitive. 37 | 38 | (body is any expression, I think, not just {...}) 39 | 40 | Also: implemented by functions: 41 | e := e including e[e] := e 42 | *e 43 | x has_value 44 | del x ? 45 | e[e] 46 | e[x:end] 47 | {{ ... }} - implemented by 6 functions 48 | 49 | Many of the main expression types have obvious clojure equivalents: 50 | 51 | application 52 | variable 53 | literal 54 | if 55 | block (progn) 56 | 57 | Need to examine more carefully: 58 | 59 | program (i.e. arrow, lambda) (vars) -> body 60 | - I think this is just (fn [vars] body). 61 | tuple - no splicing, static subkeys 62 | - Makes a tuple or 'array'. 
Can be done as an application, 63 | but be careful about addresses (/0/ must be beginning of array, 64 | but in an application /0/ is the function) 65 | definition 66 | - assigns a variable, as opposed to := which is for traces 67 | this 68 | with_address /a/b/c/ : body 69 | There are ten uses of with_address in the .vnts files. 70 | Many involve /$i/ setting the address inside a map or something. 71 | 72 | 73 | Builtins: 74 | src/builtin.py: 75 | 74 occurrences of register_builtin 76 | is_vpair 77 | vfirst 78 | vlist 79 | vrest 80 | print_py_data 81 | render_address 82 | add 83 | interpret_prim 84 | py_match_bind 85 | etc. etc. 86 | list_to_array 87 | etc. etc. 88 | discrete_histogram 89 | 90 | builtInSPs from venture: populated by 203 calls to 91 | registerBuiltinSP() over 15 source files. 92 | grep -r registerBuiltinSP ../Venturecxx >tmp.tmp 93 | 94 | basic_sps.py 95 | eq neq gt gte lt lte real atom atom_index integer probability 96 | not xor all_p any_p is_number is_probability is_atom 97 | ... dump_data dump_py_data load_py_data 98 | 99 | cmvn.py 100 | make_niw_normal 101 | 102 | conditionals.py 103 | biplex 104 | 105 | continuous.py 106 | multivariate_normal inv_wishart ... 107 | 108 | crp.py 109 | make_crp 110 | 111 | dirichlet.py 112 | dirichlet 113 | symmetric_dirichlet 114 | make_dir_cat 115 | ... 116 | 117 | discrete.py 118 | flip 119 | bernoulli 120 | log_flip 121 | log_bernoulli 122 | log_odds_flip 123 | log_odds_bernoulli 124 | binomial 125 | ... 126 | 127 | eval_sps.py 128 | get_current_environment 129 | get_empty_environment 130 | is_environment 131 | extend_environment 132 | eval 133 | address_of 134 | 135 | function.py 136 | apply_function 137 | 138 | functional.py 139 | apply 140 | mapv 141 | mapv2 142 | imapv 143 | fix 144 | assess 145 | 146 | gp.py 147 | make_gp 148 | gp_mean_const 149 | gp_cov_const 150 | ... 
151 | 152 | hmm.py 153 | make_lazy_hmm 154 | 155 | msp.py 156 | mem 157 | 158 | records.py 159 | is_ 160 | 161 | scope.py 162 | tag 163 | tag_exclude 164 | 165 | vectors.py 166 | array 167 | vector 168 | is_array 169 | is_vector 170 | to_array 171 | to_vector 172 | matrix 173 | is_matrix 174 | simplex is_simplex 175 | normalize arrange fill 176 | 177 | Also 66 definitions in prelude.vnts. 178 | 179 | 180 | grep output, file builtin.py: 181 | 182 | -- venture names collide with metaprob 183 | register_builtin("is_vpair", SPFromLite(sp)) 184 | register_builtin("vfirst", SPFromLite(sp)) 185 | register_builtin("vlist", SPFromLite(sp)) 186 | register_builtin("vrest", SPFromLite(sp)) 187 | register_builtin("inverse_gamma", SPFromLite(sp)) 188 | register_builtin("dump_py_data", Function2SP(dump_py_data_func)) 189 | register_builtin("print_py_data", Function1SP(print_py_data_func)) 190 | register_builtin("render_address", Function1SP(render_address)) 191 | register_builtin("add", Function2SP(metaprob_plus)) 192 | register_builtin("interpret_prim", PythonFunctionSP(interpret_prim)) 193 | register_builtin("py_match_bind", PythonFunctionSP(match_bind_func)) 194 | register_builtin("is_compound_sp_name", Function1SP(is_compound_name)) 195 | register_builtin("is_metaprob_array", Function1SP(is_metaprob_array_func)) 196 | register_builtin("py_propose", PythonFunctionSP(propose_func)) 197 | register_builtin("propose_application", PythonFunctionSP(propose_application_func)) 198 | register_builtin("py_make_env", PythonFunctionSP(make_env)) 199 | register_builtin("is_trace", Function1SP(is_trace)) 200 | register_builtin("subtrace", Function2SP(subtrace)) 201 | register_builtin("trace_has", Function1SP(trace_has)) NOT USED 202 | register_builtin("trace_get", Function1SP(trace_get)) NOT USED 203 | register_builtin("trace_set", Function2SP(trace_set)) NOT USED 204 | register_builtin("trace_update", Function2SP(trace_update)) 205 | register_builtin("trace_clear", 
Function1SP(trace_clear)) 206 | register_builtin("trace_has_key", Function2SP(trace_has_key)) 207 | register_builtin("trace_subkeys", Function1SP(trace_subkeys)) 208 | register_builtin("trace_sites", Function1SP(trace_sites)) 209 | register_builtin("trace_empty", Function1SP(trace_empty)) 210 | register_builtin("trace_copy", Function1SP(copy_trace)) 211 | register_builtin("set_difference", Function2SP(set_difference)) 212 | register_builtin("contains", Function2SP(contains)) 213 | register_builtin("uniform_categorical", UniformCategoricalSP()) 214 | register_builtin("log_categorical", LogCategoricalSP()) 215 | register_builtin("categorical", CategoricalSP()) 216 | register_builtin("normalize", Function1SP(normalize)) 217 | register_builtin("logsumexp", Function1SP(logsumexp)) 218 | register_builtin("address_to_venture_list", 219 | Function1SP(metaprob_list_to_venture_list)) NOT USED 220 | register_builtin("array_to_address", Function1SP(metaprob_array_to_metaprob_list)) 221 | register_builtin("collection_to_address", Function1SP(metaprob_collection_to_metaprob_list)) 222 | register_builtin("collection_to_array", Function1SP(metaprob_collection_to_metaprob_array)) 223 | register_builtin("array_to_list", Function1SP(metaprob_array_to_metaprob_list)) 224 | register_builtin("list_to_array", Function1SP(metaprob_list_to_metaprob_array)) 225 | register_builtin("list_to_vlist", Function1SP(metaprob_list_to_venture_list)) 226 | register_builtin("array_to_vlist", Function1SP(metaprob_array_to_venture_list)) 227 | register_builtin("vlist_to_array", Function1SP(venture_list_to_metaprob_array)) 228 | register_builtin("array_to_varray", Function1SP(metaprob_array_to_venture_array)) 229 | register_builtin("prob_prog_name", Function1SP(prob_prog_name)) 230 | register_builtin("primitive_backpropagator", Function1SP(primitive_backpropagator)) 231 | register_builtin("destructure_compound_name", Function1SP(destructure_compound_name)) 232 | register_builtin("p_p_plot_2samp_to_file", 
PythonFunctionSP(p_p_plot_2samp_to_file)) 233 | register_builtin("random_output", Function1SP(random_output)) 234 | register_builtin("resolve_tag_address", Function1SP(resolve_tag_address)) 235 | register_builtin("pair", Function2SP(metaprob_pair)) 236 | register_builtin("list", PythonFunctionSP(metaprob_list)) 237 | register_builtin("mk_nil", PythonFunctionSP(metaprob_nil)) 238 | register_builtin("trace_set_subtrace_at", PythonFunctionSP(trace_set_subtrace_at)) 239 | register_builtin("mk_spl", Function1SP(mk_spl)) 240 | register_builtin("trace_to_graphviz", PythonFunctionSP(trace_to_graphviz_func)) 241 | register_builtin("is_builtin_env", Function1SP(is_builtin_env)) 242 | register_builtin("and", SPFromLite(lite_and_sp)) 243 | register_builtin("or", SPFromLite(lite_or_sp)) 244 | register_builtin("lookup", SPFromLite(updated_lookup_sp)) 245 | register_builtin("pprint", SPFromLite(deterministic_typed(pprint_fun, [t.Data], t.Nil))) 246 | register_builtin("assert", SPFromLite(assert_sp)) 247 | register_builtin("binned_histogram", PythonFunctionSP(binned_histogram_func)) 248 | register_builtin("log_gamma_function", PythonFunctionSP(log_gamma_function)) 249 | register_builtin("discrete_histogram", PythonFunctionSP(discrete_histogram_func)) 250 | register_builtin("discrete_weighted_histogram", PythonFunctionSP(discrete_weighted_histogram_func)) 251 | register_builtin("toplevel_lookup", Function1SP(toplevel_lookup)) 252 | register_builtin("write", Function1SP(write_func)) 253 | register_builtin("string", Function1SP(string_func)) 254 | register_builtin("dereify_tag", Function1SP(dereify_tag)) 255 | 256 | basic_sps.py:registerBuiltinSP("eq", binaryPred(lambda x,y: x.equal(y), 257 | basic_sps.py:registerBuiltinSP("neq", binaryPred(lambda x,y: not x.equal(y), 258 | basic_sps.py:registerBuiltinSP("gt", binaryPred(lambda x,y: x.compare(y) > 0, 259 | basic_sps.py:registerBuiltinSP("gte", binaryPred(lambda x,y: x.compare(y) >= 0, 260 | basic_sps.py:registerBuiltinSP("lt", 
binaryPred(lambda x,y: x.compare(y) < 0, 261 | basic_sps.py:registerBuiltinSP("lte", binaryPred(lambda x,y: x.compare(y) <= 0, 262 | basic_sps.py:registerBuiltinSP("real", deterministic_typed(lambda x:x, 263 | basic_sps.py:registerBuiltinSP("atom", deterministic_typed(lambda x:x, 264 | basic_sps.py:registerBuiltinSP("atom_index", deterministic_typed(lambda x:x, 265 | basic_sps.py:registerBuiltinSP("integer", deterministic_typed(int, 266 | basic_sps.py:registerBuiltinSP("probability", deterministic_typed(lambda x:x, 267 | basic_sps.py:registerBuiltinSP("not", deterministic_typed(lambda x: not x, 268 | basic_sps.py:registerBuiltinSP("xor", deterministic_typed(lambda x, y: x != y, 269 | basic_sps.py:registerBuiltinSP("all_p", deterministic_typed(all, 270 | basic_sps.py:registerBuiltinSP("any_p", deterministic_typed(any, 271 | basic_sps.py:registerBuiltinSP("is_number", type_test(t.NumberType())) 272 | basic_sps.py:registerBuiltinSP("is_integer", type_test(t.IntegerType())) 273 | basic_sps.py:registerBuiltinSP("is_probability", type_test(t.ProbabilityType())) 274 | basic_sps.py:registerBuiltinSP("is_atom", type_test(t.AtomType())) 275 | basic_sps.py:registerBuiltinSP("is_boolean", type_test(t.BoolType())) 276 | basic_sps.py:registerBuiltinSP("is_symbol", type_test(t.SymbolType())) 277 | basic_sps.py:registerBuiltinSP("is_procedure", type_test(SPType([t.AnyType()], t.AnyType(), 278 | basic_sps.py:registerBuiltinSP("list", deterministic_typed(lambda *args: args, 279 | basic_sps.py:registerBuiltinSP("pair", deterministic_typed(lambda a,d: (a,d), 280 | basic_sps.py:registerBuiltinSP("is_pair", type_test(t.PairType())) 281 | basic_sps.py:registerBuiltinSP("first", deterministic_typed(lambda p: p[0], 282 | basic_sps.py:registerBuiltinSP("rest", deterministic_typed(lambda p: p[1], 283 | basic_sps.py:registerBuiltinSP("second", deterministic_typed(lambda p: p[1][0], 284 | basic_sps.py:registerBuiltinSP("to_list", 285 | basic_sps.py:registerBuiltinSP("zip", 
deterministic_typed(zip, NOT USED 286 | basic_sps.py:registerBuiltinSP("reverse", deterministic_typed(lambda l: list(reversed(l)), 287 | basic_sps.py:registerBuiltinSP("set_difference", deterministic_typed(lambda l1, l2: [v for v in l1 if v not in l2], 288 | basic_sps.py:registerBuiltinSP("dict", 289 | basic_sps.py:registerBuiltinSP("is_dict", type_test(t.DictType())) 290 | basic_sps.py:registerBuiltinSP("to_dict", 291 | basic_sps.py:registerBuiltinSP("keys", 292 | basic_sps.py:registerBuiltinSP("values", 293 | basic_sps.py:registerBuiltinSP("lookup", deterministic_typed(lambda xs, x: xs.lookup(x), 294 | basic_sps.py:registerBuiltinSP("contains", deterministic_typed(lambda xs, x: xs.contains(x), 295 | basic_sps.py:registerBuiltinSP("size", deterministic_typed(lambda xs: xs.size(), 296 | basic_sps.py:registerBuiltinSP("is_empty", deterministic_typed(lambda xs: xs.size() == 0, 297 | basic_sps.py:registerBuiltinSP("take", deterministic_typed(lambda ind, xs: xs.take(ind), 298 | basic_sps.py:registerBuiltinSP("debug", deterministic_typed(debug_print, 299 | basic_sps.py:registerBuiltinSP("value_error", 300 | basic_sps.py:registerBuiltinSP("name", deterministic_typed(make_name, 301 | basic_sps.py:registerBuiltinSP("dump_data", deterministic_typed( 302 | basic_sps.py:registerBuiltinSP("load_data", deterministic_typed( 303 | basic_sps.py:registerBuiltinSP("dump_py_data", deterministic_typed( 304 | basic_sps.py:registerBuiltinSP("load_py_data", deterministic_typed( 305 | cmvn.py:registerBuiltinSP("make_niw_normal", 306 | conditionals.py:registerBuiltinSP("biplex", no_request(generic_biplex)) 307 | continuous.py:registerBuiltinSP("multivariate_normal", typed_nr(MVNormalOutputPSP(), 308 | continuous.py:registerBuiltinSP("inv_wishart", typed_nr(InverseWishartOutputPSP(), 309 | continuous.py:registerBuiltinSP("wishart", typed_nr(WishartOutputPSP(), 310 | continuous.py:registerBuiltinSP("normalss", typed_nr(NormalOutputPSP(), 311 | continuous.py:registerBuiltinSP("normalsv", 
typed_nr(NormalsvOutputPSP(), 312 | continuous.py:registerBuiltinSP("normalvs", typed_nr(NormalvsOutputPSP(), 313 | continuous.py:registerBuiltinSP("normalvv", typed_nr(NormalvvOutputPSP(), 314 | continuous.py:registerBuiltinSP("normal", no_request(generic_normal)) 315 | continuous.py:registerBuiltinSP('lognormal', typed_nr(LogNormalOutputPSP(), 316 | continuous.py:registerBuiltinSP("vonmises", typed_nr(VonMisesOutputPSP(), 317 | continuous.py:registerBuiltinSP("uniform_continuous",typed_nr(UniformOutputPSP(), 318 | continuous.py:registerBuiltinSP("log_odds_uniform", typed_nr(LogOddsUniformOutputPSP(), 319 | continuous.py:registerBuiltinSP("beta", typed_nr(BetaOutputPSP(), 320 | continuous.py:registerBuiltinSP("log_beta", typed_nr(LogBetaOutputPSP(), 321 | continuous.py:registerBuiltinSP("log_odds_beta", typed_nr(LogOddsBetaOutputPSP(), 322 | continuous.py:registerBuiltinSP("expon", typed_nr(ExponOutputPSP(), 323 | continuous.py:registerBuiltinSP("gamma", typed_nr(GammaOutputPSP(), 324 | continuous.py:registerBuiltinSP("student_t", typed_nr(StudentTOutputPSP(), 325 | continuous.py:registerBuiltinSP("inv_gamma", typed_nr(InvGammaOutputPSP(), 326 | continuous.py:registerBuiltinSP("laplace", typed_nr(LaplaceOutputPSP(), 327 | continuous.py:registerBuiltinSP("make_nig_normal", typed_nr(MakerCNigNormalOutputPSP(), 328 | continuous.py:registerBuiltinSP("make_uc_nig_normal", typed_nr(MakerUNigNormalOutputPSP(), 329 | continuous.py:registerBuiltinSP("make_suff_stat_normal", typed_nr(MakerSuffNormalOutputPSP(), 330 | crp.py:registerBuiltinSP('make_crp', typed_nr(MakeCRPOutputPSP(), 331 | csp.py:registerBuiltinSP("make_csp", typed_nr(MakeCSPOutputPSP(), 332 | dirichlet.py:registerBuiltinSP("dirichlet", \ 333 | dirichlet.py:registerBuiltinSP("symmetric_dirichlet", \ 334 | dirichlet.py:registerBuiltinSP("make_dir_cat", \ 335 | dirichlet.py:registerBuiltinSP("make_uc_dir_cat", \ 336 | dirichlet.py:registerBuiltinSP("make_sym_dir_cat", \ 337 | 
dirichlet.py:registerBuiltinSP("make_uc_sym_dir_cat", 338 | discrete.py:registerBuiltinSP("flip", typed_nr(BernoulliOutputPSP(), 339 | discrete.py:registerBuiltinSP("bernoulli", typed_nr(BernoulliOutputPSP(), 340 | discrete.py:registerBuiltinSP("log_flip", typed_nr(LogBernoulliOutputPSP(), 341 | discrete.py:registerBuiltinSP("log_bernoulli", typed_nr(LogBernoulliOutputPSP(), 342 | discrete.py:registerBuiltinSP("log_odds_flip", typed_nr(LogOddsBernoulliOutputPSP(), 343 | discrete.py:registerBuiltinSP("log_odds_bernoulli", typed_nr(LogOddsBernoulliOutputPSP(), 344 | discrete.py:registerBuiltinSP("binomial", typed_nr(BinomialOutputPSP(), 345 | discrete.py:registerBuiltinSP("categorical", typed_nr(CategoricalOutputPSP(), 346 | discrete.py:registerBuiltinSP("log_categorical", typed_nr(LogCategoricalOutputPSP(), 347 | discrete.py:registerBuiltinSP("uniform_categorical", typed_nr(UniformCategoricalOutputPSP(), 348 | discrete.py:registerBuiltinSP("uniform_discrete", typed_nr(UniformDiscreteOutputPSP(), 349 | discrete.py:registerBuiltinSP("poisson", typed_nr(PoissonOutputPSP(), 350 | discrete.py:registerBuiltinSP("make_beta_bernoulli", typed_nr(MakerCBetaBernoulliOutputPSP(), 351 | discrete.py:registerBuiltinSP("make_uc_beta_bernoulli", 352 | discrete.py:registerBuiltinSP("make_suff_stat_bernoulli", 353 | discrete.py:registerBuiltinSP("exactly", typed_nr(ExactlyOutputPSP(), 354 | discrete.py:registerBuiltinSP("make_gamma_poisson", typed_nr(MakerCGammaPoissonOutputPSP(), 355 | discrete.py:registerBuiltinSP("make_uc_gamma_poisson", 356 | discrete.py:registerBuiltinSP("make_suff_stat_poisson", 357 | eval_sps.py:registerBuiltinSP("get_current_environment", 358 | eval_sps.py:registerBuiltinSP("get_empty_environment", 359 | eval_sps.py:registerBuiltinSP("is_environment", type_test(env.EnvironmentType())) 360 | eval_sps.py:registerBuiltinSP("extend_environment", 361 | eval_sps.py:registerBuiltinSP("eval", 362 | eval_sps.py:registerBuiltinSP("address_of", 363 | 
function.py:registerBuiltinSP("apply_function", applyFunctionSP) 364 | functional.py:registerBuiltinSP( 365 | functional.py:registerBuiltinSP( 366 | functional.py:registerBuiltinSP( 367 | functional.py:registerBuiltinSP( 368 | functional.py:registerBuiltinSP( 369 | functional.py:registerBuiltinSP( 370 | gp.py:registerBuiltinSP('make_gp', makeGPSP) 371 | gp.py:registerBuiltinSP('gp_mean_const', 372 | gp.py:registerBuiltinSP('gp_cov_const', 373 | gp.py:registerBuiltinSP('gp_cov_delta', 374 | gp.py:registerBuiltinSP('gp_cov_deltoid', 375 | gp.py:registerBuiltinSP('gp_cov_bump', 376 | gp.py:registerBuiltinSP('gp_cov_se', 377 | gp.py:registerBuiltinSP('gp_cov_periodic', 378 | gp.py:registerBuiltinSP('gp_cov_rq', 379 | gp.py:registerBuiltinSP('gp_cov_matern', 380 | gp.py:registerBuiltinSP('gp_cov_matern_32', 381 | gp.py:registerBuiltinSP('gp_cov_matern_52', 382 | gp.py:registerBuiltinSP('gp_cov_linear', 383 | gp.py:registerBuiltinSP('gp_cov_bias', 384 | gp.py:registerBuiltinSP('gp_cov_scale', 385 | gp.py:registerBuiltinSP('gp_cov_sum', 386 | gp.py:registerBuiltinSP('gp_cov_product', 387 | hmm.py:registerBuiltinSP("make_lazy_hmm", typed_nr(MakeUncollapsedHMMOutputPSP(), 388 | msp.py:registerBuiltinSP("mem",typed_nr(MakeMSPOutputPSP(), 389 | records.py: registerBuiltinSP(name, constructor) 390 | records.py: registerBuiltinSP("is_" + name, tester) 391 | records.py: registerBuiltinSP(f, a) 392 | scope.py:registerBuiltinSP("tag", 393 | scope.py:registerBuiltinSP("tag_exclude", 394 | sp_registry.py:def registerBuiltinSP(name, sp): 395 | vectors.py:registerBuiltinSP("array", 396 | vectors.py:registerBuiltinSP("vector", 397 | vectors.py:registerBuiltinSP("is_array", type_test(t.ArrayType())) 398 | vectors.py:registerBuiltinSP("is_vector", type_test(t.ArrayUnboxedType(t.NumberType()))) 399 | vectors.py:registerBuiltinSP("to_array", 400 | vectors.py:registerBuiltinSP("to_vector", 401 | vectors.py:registerBuiltinSP("matrix", 402 | vectors.py:registerBuiltinSP("is_matrix", 
type_test(t.MatrixType())) 403 | vectors.py:registerBuiltinSP("simplex", 404 | vectors.py:registerBuiltinSP("is_simplex", type_test(t.SimplexType())) 405 | vectors.py:registerBuiltinSP("normalize", 406 | vectors.py:registerBuiltinSP("arange", 407 | vectors.py:registerBuiltinSP("fill", 408 | vectors.py:registerBuiltinSP("linspace", 409 | vectors.py:registerBuiltinSP("zero_matrix", 410 | vectors.py:registerBuiltinSP("id_matrix", 411 | vectors.py:registerBuiltinSP("diag_matrix", 412 | vectors.py:registerBuiltinSP("ravel", 413 | vectors.py:registerBuiltinSP("transpose", 414 | vectors.py:registerBuiltinSP("vector_add", 415 | vectors.py:registerBuiltinSP("hadamard", 416 | vectors.py:registerBuiltinSP("matrix_add", 417 | vectors.py:registerBuiltinSP("scale_vector", 418 | vectors.py:registerBuiltinSP("scale_matrix", 419 | vectors.py:registerBuiltinSP("vector_dot", 420 | vectors.py:registerBuiltinSP("matrix_mul", 421 | vectors.py:registerBuiltinSP("matrix_times_vector", 422 | vectors.py:registerBuiltinSP("vector_times_matrix", 423 | vectors.py:registerBuiltinSP("matrix_inverse", 424 | vectors.py:registerBuiltinSP("matrix_solve", 425 | vectors.py:registerBuiltinSP("matrix_trace", 426 | vectors.py:registerBuiltinSP("row", 427 | vectors.py:registerBuiltinSP("col", 428 | vectors.py:registerBuiltinSP("sum", 429 | vectors.py:registerBuiltinSP("append", 430 | venmath.py:registerBuiltinSP("add", no_request(generic_add)) 431 | venmath.py:registerBuiltinSP("sub", no_request(generic_sub)) 432 | venmath.py:registerBuiltinSP("mul", no_request(generic_times)) 433 | venmath.py:registerBuiltinSP("div", binaryNum(divide, 434 | venmath.py:registerBuiltinSP("int_div", binaryNumInt(integer_divide, 435 | venmath.py:registerBuiltinSP("int_mod", binaryNumInt(integer_mod, 436 | venmath.py:registerBuiltinSP("min", 437 | venmath.py:registerBuiltinSP("max", 438 | venmath.py:registerBuiltinSP("floor", unaryNum(math.floor, 439 | venmath.py:registerBuiltinSP("sin", unaryNum(math.sin, sim_grad=grad_sin, 
440 | venmath.py:registerBuiltinSP("cos", unaryNum(math.cos, sim_grad=grad_cos, 441 | venmath.py:registerBuiltinSP("tan", unaryNum(math.tan, sim_grad=grad_tan, 442 | venmath.py:registerBuiltinSP("hypot", binaryNum(math.hypot, 443 | venmath.py:registerBuiltinSP("exp", unaryNum(exp, 444 | venmath.py:registerBuiltinSP("expm1", unaryNum(expm1, 445 | venmath.py:registerBuiltinSP("log", unaryNum(log, 446 | venmath.py:registerBuiltinSP("log1p", unaryNum(log1p, 447 | venmath.py:registerBuiltinSP("pow", binaryNum(math.pow, sim_grad=grad_pow, 448 | venmath.py:registerBuiltinSP("sqrt", unaryNum(math.sqrt, sim_grad=grad_sqrt, 449 | venmath.py:registerBuiltinSP("atan2", binaryNum(math.atan2, 450 | venmath.py:registerBuiltinSP("negate", unaryNum(lambda x: -x, sim_grad=grad_negate, 451 | venmath.py:registerBuiltinSP("abs", unaryNum(abs, sim_grad=grad_abs, 452 | venmath.py:registerBuiltinSP("signum", unaryNum(signum, 453 | venmath.py:registerBuiltinSP("logistic", unaryNum(logistic, sim_grad=grad_logistic, 454 | venmath.py:registerBuiltinSP("logisticv", deterministic_typed(logistic, 455 | venmath.py:registerBuiltinSP("logit", unaryNum(logit, 456 | venmath.py:registerBuiltinSP("log_logistic", unaryNum(log_logistic, 457 | venmath.py:registerBuiltinSP("logit_exp", unaryNum(logit_exp, 458 | venmath.py:registerBuiltinSP("logsumexp", deterministic_typed(logsumexp, 459 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | services: 4 | notebook: 5 | image: probcomp/metaprob-clojure:latest 6 | ports: 7 | - 8888:8888 8 | hostname: ${USER}-notebook 9 | environment: 10 | NB_UID: ${NB_UID} 11 | volumes: 12 | - .:/home/jovyan/metaprob-clojure:rw 13 | -------------------------------------------------------------------------------- /project.clj: -------------------------------------------------------------------------------- 1 | (defproject 
metaprob "0.1.0-SNAPSHOT" 2 | :jvm-opts [ 3 | "-Xss50M" ; See `deps.edn` for an explanation of this setting 4 | "-Dhttps.protocols=TLSv1.2" ; See https://stackoverflow.com/a/50956622 5 | ] 6 | :source-paths ["tutorial/src"] 7 | :resource-paths ["tutorial/resources"] 8 | :dependencies [[org.clojure/data.json "0.2.6"] 9 | [lein-jupyter "0.1.16"]] 10 | :plugins [[lein-tools-deps "0.4.3"] 11 | [lein-jupyter "0.1.16"]] 12 | :middleware [lein-tools-deps.plugin/resolve-dependencies-with-deps-edn] 13 | :lein-tools-deps/config {:config-files [:install :user :project]}) 14 | -------------------------------------------------------------------------------- /src/metaprob/autotrace.cljc: -------------------------------------------------------------------------------- 1 | (ns metaprob.autotrace 2 | #?(:cljs (:require-macros [metaprob.autotrace])) 3 | (:require [metaprob.code-handlers :as code] 4 | [metaprob.expander :as expander] 5 | [metaprob.generative-functions :refer [gen]])) 6 | 7 | (declare autotrace-expression) 8 | 9 | (defn autotrace-expressions 10 | [expressions stack] 11 | (map-indexed #(autotrace-expression %2 (cons %1 stack)) expressions)) 12 | 13 | (defmacro autotrace [gen-expr] 14 | (let [expr (expander/mp-expand #?(:cljs &env) gen-expr) 15 | result `(gen ~@(if (code/gen-has-annotations? expr) 16 | [(code/gen-annotations expr)] 17 | []) 18 | ~(code/gen-pattern expr) 19 | ~@(autotrace-expressions (code/gen-body expr) '()))] 20 | result)) 21 | 22 | (defn autotrace-expression 23 | [expr stack] 24 | (cond (or (code/fn-expr? expr) 25 | (code/gen-expr? expr)) 26 | `(~(first expr) ~@(if (code/gen-has-annotations? expr) [(code/gen-annotations expr)] []) 27 | ~(code/gen-pattern expr) 28 | ~@(autotrace-expressions (code/gen-body expr) stack)) 29 | 30 | (code/quote-expr? expr) 31 | expr 32 | 33 | (code/if-expr? 
expr) 34 | `(if ~(autotrace-expression (code/if-predicate expr) (cons "predicate" stack)) 35 | ~(autotrace-expression (code/if-then-clause expr) (cons "then" stack)) 36 | ~(autotrace-expression (code/if-else-clause expr) (cons "else" stack))) 37 | 38 | (seq? expr) 39 | (if (or (= (first expr) 'apply-at) 40 | (= (first expr) 'at) 41 | (code/fn-expr? (first expr))) 42 | (autotrace-expressions expr stack) 43 | `(~'at 44 | '~(reverse (cons (str (first expr)) stack)) 45 | ~@(autotrace-expressions expr stack))) 46 | 47 | true expr)) 48 | -------------------------------------------------------------------------------- /src/metaprob/code_handlers.cljc: -------------------------------------------------------------------------------- 1 | (ns metaprob.code-handlers) 2 | 3 | (defn name-checker [n] 4 | (fn [x] 5 | (and (seq? x) 6 | (symbol? (first x)) 7 | (= (name (first x)) n)))) 8 | 9 | (defn symbol-checker [n] 10 | (fn [x] 11 | (and (seq? x) 12 | (= (first x) n)))) 13 | 14 | (def fn-expr? (name-checker "fn")) 15 | (def let-expr? (name-checker "let")) 16 | (def do-expr? (name-checker "do")) 17 | (def let-traced-expr? (name-checker "let-traced")) 18 | (def gen-expr? (name-checker "gen")) 19 | 20 | (defn gen-name 21 | [expr] 22 | (cond 23 | (symbol? (second expr)) (second expr) 24 | (map? (second expr)) (get (second expr) :name) 25 | true nil)) 26 | 27 | (defn gen-annotations 28 | [expr] 29 | (if (map? (second expr)) 30 | (second expr) 31 | {})) 32 | 33 | (defn gen-has-annotations? 34 | [expr] 35 | (not (vector? (second expr)))) 36 | 37 | (defn gen-pattern 38 | [expr] 39 | (if (gen-has-annotations? expr) 40 | (nth expr 2) 41 | (second expr))) 42 | 43 | (defn gen-body 44 | [expr] 45 | (if (gen-has-annotations? expr) 46 | (rest (rest (rest expr))) 47 | (rest (rest expr)))) 48 | 49 | (defn map-gen 50 | [f gen-expr] 51 | (if (gen-has-annotations? 
gen-expr) 52 | (cons (first gen-expr) 53 | (cons (second gen-expr) 54 | (cons (gen-pattern gen-expr) (map f (gen-body gen-expr))))) 55 | (cons (first gen-expr) 56 | (cons (gen-pattern gen-expr) (map f (gen-body gen-expr)))))) 57 | 58 | (def if-expr? (symbol-checker 'if)) 59 | 60 | (def if-predicate second) 61 | 62 | (defn if-then-clause [expr] (nth expr 2)) 63 | 64 | (defn if-else-clause [expr] 65 | (if (< (count expr) 4) nil (nth expr 3))) 66 | 67 | (def variable? symbol?) 68 | 69 | (def quote-expr? (symbol-checker 'quote)) 70 | (def quote-quoted second) 71 | 72 | (defn literal? 73 | [expr] 74 | (or (not (or (seq? expr) (vector? expr) (map? expr))) 75 | (empty? expr))) 76 | 77 | (defn let-bindings 78 | [expr] 79 | (partition 2 (second expr))) 80 | 81 | (defn let-body 82 | [expr] 83 | (rest (rest expr))) 84 | -------------------------------------------------------------------------------- /src/metaprob/distributions.cljc: -------------------------------------------------------------------------------- 1 | (ns metaprob.distributions 2 | (:refer-clojure :exclude [apply map replicate reduce]) 3 | (:require [metaprob.prelude :as mp :refer [map make-primitive]]) 4 | #?(:clj (:import [org.apache.commons.math3.distribution BetaDistribution GammaDistribution]))) 5 | 6 | (def exactly 7 | (make-primitive 8 | (fn [x] x) 9 | (fn [y [x]] (if (not= y x) mp/negative-infinity 0)))) 10 | 11 | (def uniform 12 | (make-primitive 13 | (fn [a b] (mp/sample-uniform a b)) 14 | (fn [x [a b]] (if (<= a x b) (- (mp/log (- b a))) mp/negative-infinity)))) 15 | 16 | (def uniform-discrete 17 | (make-primitive 18 | (fn [items] (nth items (Math/floor (* (mp/sample-uniform) (count items))))) 19 | (fn [item [items]] 20 | (- (mp/log (count (filter #(= % item) items))) 21 | (mp/log (count items)))))) 22 | 23 | (def flip 24 | (make-primitive 25 | (fn [weight] (< (mp/sample-uniform) weight)) 26 | (fn [value [weight]] 27 | (if value 28 | (mp/log weight) 29 | (mp/log1p (- weight)))))) 30 | 31 | (defn 
normalize-numbers [nums] 32 | (let [total (clojure.core/reduce + nums)] (map #(/ % total) nums))) 33 | 34 | (def categorical 35 | (make-primitive 36 | (fn [probs] 37 | (if (map? probs) 38 | (nth (keys probs) (categorical (normalize-numbers (vals probs)))) ;; TODO: normalization not needed here? 39 | (let [total (clojure.core/reduce + probs) 40 | r (* (mp/sample-uniform) total)] 41 | (->> probs 42 | (reductions +) 43 | (take-while #(< % r)) 44 | count)))) 45 | (fn [i [probs]] 46 | (if (map? probs) 47 | (if (not (contains? probs i)) mp/negative-infinity (- (mp/log (get probs i)) (mp/log (clojure.core/reduce + (vals probs))))) 48 | (mp/log (nth probs i)))))) 49 | 50 | (defn logsumexp [scores] 51 | (let [max-score (mp/apply max scores) 52 | weights (map #(Math/exp (- % max-score)) scores)] 53 | (+ (Math/log (clojure.core/reduce + weights)) max-score))) 54 | 55 | (defn logmeanexp [scores] 56 | (- (logsumexp scores) (mp/log (count scores)))) 57 | 58 | (defn log-scores-to-probabilities [scores] 59 | (let [log-normalizer (logsumexp scores)] 60 | (map #(Math/exp (- % log-normalizer)) scores))) 61 | 62 | 63 | (def log-categorical 64 | (make-primitive 65 | (fn [scores] 66 | (let [probs 67 | (if (map? scores) (into {} (clojure.core/map (fn [a b] [a b]) (keys scores) (log-scores-to-probabilities (vals scores)))) 68 | (log-scores-to-probabilities scores))] 69 | (categorical probs))) 70 | (fn [i [scores]] 71 | (let [probs 72 | (if (map? scores) (into {} (clojure.core/map (fn [a b] [a b]) (keys scores) (log-scores-to-probabilities (vals scores)))) 73 | (log-scores-to-probabilities scores))] 74 | (if (map? probs) 75 | (if (not (contains? 
probs i)) mp/negative-infinity (- (mp/log (get probs i)) (mp/log (clojure.core/reduce + (vals probs))))) 76 | (mp/log (nth probs i))))))) 77 | 78 | 79 | (defn generate-gaussian [mu sigma] 80 | (+ mu (* sigma (Math/sqrt (* -2 (Math/log (mp/sample-uniform)))) (Math/cos (* 2 Math/PI (mp/sample-uniform)))))) 81 | (defn standard-gaussian-log-density [x] (* -0.5 (+ (Math/log (* 2 Math/PI)) (* x x)))) 82 | (defn score-gaussian [x [mu sigma]] 83 | (- (standard-gaussian-log-density (/ (- x mu) sigma)) (Math/log sigma))) 84 | 85 | (def gaussian 86 | (make-primitive 87 | generate-gaussian 88 | score-gaussian)) 89 | 90 | (def geometric 91 | (make-primitive 92 | (fn [p] (loop [i 0] (if (flip p) (recur (+ i 1)) i))) 93 | (fn [v [p]] (+ (mp/log1p (- p)) (* (mp/log p) v))))) 94 | 95 | #?(:clj (def gamma 96 | (make-primitive 97 | (fn [shape scale] 98 | (.sample (GammaDistribution. shape scale))) 99 | (fn [x [shape scale]] 100 | (.logDensity (GammaDistribution. shape scale) 101 | x))))) 102 | 103 | #?(:clj (def beta 104 | (make-primitive 105 | (fn [alpha beta] 106 | (.sample (BetaDistribution. alpha beta))) 107 | (fn [x [alpha beta]] 108 | (.logDensity (BetaDistribution. 
alpha beta) 109 | x))))) 110 | -------------------------------------------------------------------------------- /src/metaprob/examples/aide.clj: -------------------------------------------------------------------------------- 1 | (ns metaprob.examples.aide 2 | (:refer-clojure :exclude [map replicate apply]) 3 | (:require [metaprob.trace :refer :all] 4 | [metaprob.generative-functions :refer :all] 5 | [metaprob.prelude :refer [map replicate expt infer-and-score]] 6 | [metaprob.distributions :refer :all] 7 | [clojure.pprint :refer [pprint]] 8 | [metaprob.inference :refer :all])) 9 | 10 | ;; AIDE 11 | 12 | 13 | (def log-likelihood-weighting 14 | (gen [& {:keys [model inputs observations n-particles] 15 | :or {inputs []}}] 16 | (let [scores 17 | (map (fn [i] 18 | (let [[_ _ s] 19 | (at i infer-and-score :procedure model :inputs inputs :observation-trace observations)] 20 | s)) 21 | (range n-particles))] 22 | (logmeanexp scores)))) 23 | 24 | (defn intervene [f t] 25 | (gen [& args] (first (apply-at '() (make-constrained-generator f t) args)))) 26 | 27 | ;; Calculate the mean of some numbers 28 | (defn avg [xs] (/ (reduce + xs) (count xs))) 29 | 30 | ;; Compare two generative functions at some subset of addresses. 31 | ;; Returns an unbiased estimate of an upper bound on the symmetric KL divergence 32 | ;; between f and g. 
33 | (defn compare-generative-functions [f g addresses Nf Mf Ng Mg] 34 | (let [f-traces 35 | (map #(partition-trace % addresses) (replicate Nf #(nth (infer-and-score :procedure f) 1))) 36 | 37 | g-traces 38 | (map #(partition-trace % addresses) (replicate Ng #(nth (infer-and-score :procedure g) 1))) 39 | 40 | ;; Estimate the expectation, when x ~ g, of log P_f(x), using likelihood weighting 41 | f-scores-on-g-samples 42 | (map (fn [[x _]] (log-likelihood-weighting :model f, :observations x, :n-particles Mf)) g-traces) 43 | 44 | ;; Estimate the expectation, when x ~ f, of log P_g(x), using likelihood weighting 45 | g-scores-on-f-samples 46 | (map (fn [[x _]] (log-likelihood-weighting :model g, :observations x, :n-particles Mg)) f-traces) 47 | 48 | ;; Estimate the expectation, when x ~ f, of log P_f(x) 49 | f-scores-on-f-samples 50 | (map (fn [[x u]] ((intervene log-likelihood-weighting {0 u}) :model f, :observations x, :n-particles Mf)) f-traces) 51 | 52 | g-scores-on-g-samples 53 | (map (fn [[x v]] ((intervene log-likelihood-weighting {0 v}) :model g, :observations x, :n-particles Mg)) g-traces)] 54 | 55 | ;; Use Clojure's version of `map`, which can take two lists l and m, 56 | ;; and apply a function (like -) to l[0],m[0], l[1],m[1], etc. 
57 | (+ (avg (clojure.core/map - f-scores-on-f-samples g-scores-on-f-samples)) 58 | (avg (clojure.core/map - g-scores-on-g-samples f-scores-on-g-samples))))) 59 | 60 | (def importance-resampling-model 61 | (gen [model inputs observations N] 62 | (let [;; Generate N particles of the form [retval trace weight], 63 | ;; tracing the ith particle at '("particles" i) 64 | particles 65 | (map (fn [i] 66 | (at `("particles" ~i) 67 | infer-and-score 68 | :procedure model, 69 | :inputs inputs, 70 | :observation-trace observations)) 71 | (range N)) 72 | 73 | ;; Choose one of the particles, according to their weights 74 | chosen-index 75 | (at "chosen-index" log-categorical (map #(nth % 2) particles)) 76 | 77 | ;; Pull out the trace of the chosen particle 78 | chosen-particle-trace 79 | (let [[retval trace score] (nth particles chosen-index)] trace)] 80 | 81 | ;; We need the chosen trace -- the "inference answer" that we're giving 82 | ;; to exist in _our_ trace at a predictable address. If the model we're 83 | ;; doing inference about has latent variables at addresses "x" and "y", 84 | ;; for example, then our inferred latent variables should have addresses (say) 85 | ;; '("inferred-trace" "x") and '("inferred-trace" "y"). 
86 | ;; We do this below by looping through every variable in our inferred trace, 87 | ;; and "sampling" it again, using the deterministic `exactly` distribution: 88 | (map (fn [model-addr] 89 | (at `("inferred-trace" ~@model-addr) 90 | exactly (trace-value chosen-particle-trace model-addr))) 91 | (addresses-of chosen-particle-trace)) 92 | 93 | ;; Return the chosen particle 94 | chosen-particle-trace))) 95 | 96 | 97 | (defn make-smart-importance-resampling-proposer 98 | [meta-observation-trace] 99 | (let [inferred-trace (trace-subtrace meta-observation-trace "inferred-trace")] 100 | (gen [model inputs observations N] 101 | (let [chosen-index 102 | (at "chosen-index" uniform-discrete (range N)) 103 | 104 | other-indices 105 | (filter (fn [i] (not= i chosen-index)) (range N))] 106 | 107 | ;; Randomly sample particles at the other indices 108 | (map (fn [i] 109 | (at `("particles" ~i) 110 | infer-and-score :procedure model :inputs inputs :observation-trace observations)) 111 | other-indices) 112 | 113 | ;; Force exact samples of the inferred trace's choices at the chosen index 114 | (map (fn [addr] 115 | (at `("particles" ~chosen-index ~@addr) exactly (trace-value inferred-trace addr))) 116 | (addresses-of inferred-trace)))))) 117 | 118 | 119 | (def importance-resampling-gf 120 | (with-custom-proposal-attached 121 | importance-resampling-model 122 | make-smart-importance-resampling-proposer 123 | (fn [tr] (trace-has-subtrace? tr "inferred-trace")))) 124 | 125 | ;; Coin-flipping model 126 | 127 | (def coin-model 128 | (gen [n] 129 | (let-traced [p (beta 1 1)] 130 | (map (fn [i] (at i flip p)) (range n))))) 131 | 132 | (defn make-approx-inference-algorithm 133 | [n observations n-particles] 134 | (gen [] 135 | (at '() importance-resampling-gf coin-model [n] observations n-particles))) 136 | 137 | (defn exact-inference [n observations] 138 | (let [all-flips (filter boolean? 
(map (fn [addr] (trace-value observations addr)) (addresses-of observations))) 139 | heads (count (filter true? all-flips)) 140 | tails (count (filter false? all-flips))] 141 | (gen [] 142 | (let [p (at '("inferred-trace" "p") beta (inc heads) (inc tails))] 143 | (doseq [i (range n)] 144 | (when (not (trace-has-value? observations i)) 145 | (at `("inferred-trace" ~i) flip p))))))) 146 | 147 | (defn aide-demo [n observations] 148 | (doseq [i [1 2 3 5 10 15 20]] 149 | (println (str i " particles:")) 150 | (pprint (compare-generative-functions 151 | (exact-inference n observations) 152 | (make-approx-inference-algorithm n observations i) 153 | '(("inferred-trace" "p")) 154 | 100 1, 100 20)))) 155 | 156 | 157 | ;; Create an observation trace specifying we saw 7 heads and 3 tails 158 | (def seven-heads (into {} (map-indexed (fn [i x] [i {:value x}]) (concat (repeat 7 true) (repeat 3 false))))) 159 | 160 | (defn -main [] 161 | (aide-demo 10 seven-heads)) -------------------------------------------------------------------------------- /src/metaprob/examples/all.clj: -------------------------------------------------------------------------------- 1 | ;; Kitchen sink namespace. 2 | ;; Intended for use with (in-ns 'metaprob.examples.all). 3 | 4 | ;; Import a gajillion things so that this namespace is fun to use at 5 | ;; the REPL. 
6 | 7 | (ns metaprob.examples.all 8 | (:refer-clojure :exclude [map replicate apply]) 9 | (:require 10 | [metaprob.trace :refer :all] 11 | [metaprob.generative-functions :refer :all] 12 | [clojure.pprint :refer [pprint]] 13 | [metaprob.prelude :refer :all] 14 | [metaprob.distributions :refer :all] 15 | [metaprob.inference :refer :all] 16 | [metaprob.autotrace :refer :all] 17 | [metaprob.examples.flip-n-coins :refer :all] 18 | [metaprob.examples.earthquake :refer :all :exclude [flip]] 19 | [metaprob.examples.inference-on-gaussian :refer :all])) 20 | ;; The following would simplify startup a bit but I'm getting 21 | ;; an error in `clojure -Atest` 22 | ;; (:require [clojure.tools.namespace.repl :refer [refresh]]) 23 | 24 | 25 | ;; You may prefer to invoke particular demos in the REPL, rather than 26 | ;; run them all wholesale 27 | 28 | (defn demo 29 | [] 30 | ;; Coin flips 31 | ;; (demo-coin-flips) 32 | (pprint (coin-flips-demo-n-flips 2)) 33 | (pprint (coin-flips-demo-biased 10)) ;; with intervention 34 | 35 | ;; Bayes net (earthquake) 36 | ;; prior: 37 | ;; exact probabilities 38 | ;; random sample 39 | ;; with intervention: 40 | ;; exact probabilities 41 | ;; random sample 42 | ;; TBD: importance sampling 43 | ;; TBD: rejection resampling 44 | ;; earthquake_bayesian_network could return the query instead of 45 | ;; the trace (that would then be queried) 46 | 47 | (demo-earthquake) 48 | 49 | ;; 2D gaussian 50 | ;; Harness is in main.clj. 
51 | ;; 4 calls / plots: prior, rejection, importance, MH 52 | ;; (demo-gaussian) 53 | 54 | (let [number-of-runs 100] 55 | (gaussian-prior-samples number-of-runs) 56 | (rejection-assay number-of-runs) 57 | (importance-assay 100 number-of-runs) 58 | ;; (MH-assay number-of-runs) 59 | )) 60 | -------------------------------------------------------------------------------- /src/metaprob/examples/curve_fitting.clj: -------------------------------------------------------------------------------- 1 | (ns metaprob.examples.curve-fitting 2 | (:refer-clojure :exclude [map replicate apply]) 3 | (:require [metaprob.trace :refer :all] 4 | [metaprob.generative-functions :refer :all] 5 | [metaprob.prelude :refer [map expt replicate infer-and-score]] 6 | [metaprob.distributions :refer :all] 7 | [clojure.pprint :refer [pprint]] 8 | [metaprob.inference :refer :all])) 9 | 10 | ;; Generate a random polynomial of degree 0, 1, 2, or 3 11 | (def random-polynomial 12 | (gen [] 13 | (let [coeffs (map (fn [i] (at `("coeffs" ~i) gaussian 0 1)) 14 | (range (+ 1 (at "degree" uniform-discrete [0 1 2 3]))))] 15 | (fn [x] (reduce + (map-indexed (fn [n c] (* c (expt x n))) coeffs)))))) 16 | 17 | ;; Create a generative function that is a noisy version of 18 | ;; a deterministic input function 19 | (def add-noise 20 | (gen [f] 21 | (let-traced [noise (gamma 1 1) 22 | prob-outlier (beta 1 10)] 23 | (gen [x] 24 | (if (at "outlier?" 
flip prob-outlier) 25 | (at "y" gaussian 0 10) 26 | (at "y" gaussian (f x) noise)))))) 27 | 28 | ;; Given a list of xs, create a list of ys that are related 29 | ;; via a noisy polynomial relationship 30 | (def curve-model 31 | (gen [xs] 32 | (let-traced [underlying-curve (random-polynomial) 33 | noisy-curve (add-noise underlying-curve)] 34 | (doall (map-indexed (fn [i x] (at `("data" ~i) noisy-curve x)) xs))))) 35 | 36 | ;; Useful helpers for curve-fitting 37 | (defn make-observation-trace 38 | [ys] 39 | {"data" (into {} (map-indexed (fn [i y] [i {"y" {:value y}}]) ys))}) 40 | 41 | 42 | ;; Create datasets 43 | (def point-count 10) 44 | (def x-min -5) 45 | (def x-max 5) 46 | (def x-range (- x-max x-min)) 47 | (def x-interval (/ x-range (- point-count 1))) 48 | 49 | (def xs (map-indexed (fn [i interval] (+ -5. (* interval i))) (repeat point-count x-interval))) 50 | 51 | ;; Add a random outlier to a dataset 52 | (defn add-outlier [ys] 53 | (let [idx (uniform-discrete (range (count ys)))] 54 | (map-indexed (fn [i y] (if (= i idx) (gaussian 0 10) y)) ys))) 55 | 56 | ;; y = 2x + 1 57 | (def ys-linear (map #(+ (* 2 %) 1) xs)) 58 | 59 | ;; y = 2x + 1, with two outliers 60 | (def ys-linear-outlier (add-outlier (add-outlier ys-linear))) 61 | 62 | ;; y = 2x^2 - 2x - 1, with noisy observations and an outlier 63 | (def ys-quadratic (add-outlier (map #(gaussian (- (* 2 % %) (* 2 %) 1) 0.7) xs))) 64 | 65 | 66 | (defn inference-step [xs] 67 | (let [curve-step 68 | (custom-proposal-mh-step 69 | :model curve-model 70 | :inputs [xs] 71 | :proposal (make-resimulation-proposal 72 | :model curve-model, 73 | :inputs [xs], 74 | :address-predicate #(address-contains? % "underlying-curve"))) 75 | 76 | noise-step 77 | (custom-proposal-mh-step 78 | :model curve-model 79 | :inputs [xs] 80 | :proposal (make-resimulation-proposal 81 | :model curve-model 82 | :inputs [xs] 83 | :address-predicate #(address-contains? 
% "noisy-curve"))) 84 | 85 | outlier-proposal 86 | (fn [i] (fn [trace] 87 | (trace-set-value 88 | trace `("data" ~i "outlier?") 89 | (not (trace-value trace `("data" ~i "outlier?")))))) 90 | 91 | outlier-steps 92 | (map (fn [i] (symmetric-proposal-mh-step :model curve-model 93 | :inputs [xs] 94 | :proposal (outlier-proposal i))) 95 | (range (count xs))) 96 | 97 | coeffs-step 98 | (gaussian-drift-mh-step :model curve-model 99 | :inputs [xs] 100 | :address-predicate #(address-contains? % "coeffs") 101 | :width 0.1)] 102 | 103 | (reduce comp identity `[~curve-step ~coeffs-step ~noise-step ~coeffs-step ~@outlier-steps]))) 104 | 105 | 106 | (defn run-mh [xs ys n] 107 | (let [[_ initial-trace _] 108 | (infer-and-score :procedure curve-model :inputs [xs] :observation-trace (make-observation-trace ys))] 109 | (take n (iterate (inference-step xs) initial-trace)))) 110 | 111 | 112 | (defn -main [] 113 | (pprint (last (run-mh xs ys-linear 200))) 114 | (pprint (last (run-mh xs ys-linear-outlier 300))) 115 | (pprint (last (run-mh xs ys-quadratic 10000)))) -------------------------------------------------------------------------------- /src/metaprob/examples/earthquake.clj: -------------------------------------------------------------------------------- 1 | ;; Selected nuggets from python-metaprob's src/inference.vnts file 2 | ;; See also figure 29 of the 7/17 chapter mss 3 | 4 | (ns metaprob.examples.earthquake 5 | (:refer-clojure :exclude [map replicate apply]) 6 | (:require [metaprob.generative-functions :refer :all] 7 | [metaprob.prelude :refer :all] 8 | [metaprob.trace :refer :all] 9 | [metaprob.distributions :refer :all] 10 | [metaprob.inference :refer :all])) 11 | 12 | ;; Convert a tuple of booleans to an integer. 13 | ;; Tuple element 0 determines the highest order bit. 
14 | 15 | (defn bools-to-binary [bools] 16 | (reduce (fn [n b] (+ (* 2 n) (if b 1 0))) 0 bools)) 17 | ; 18 | ;(defn bools-to-binary [bools] 19 | ; (let [len (count bools)] 20 | ; (loop [i 0, n 0] 21 | ; (if (>= i len) 22 | ; n 23 | ; (recur (inc i) (+ (* 2 n) 24 | ; (if (nth bools i) 1 0))))))) 25 | 26 | (def earthquake-bayesian-network 27 | (gen [] 28 | (let-traced [earthquake (flip 0.1) 29 | burglary (flip 0.1) 30 | alarm (cond (and burglary earthquake) (flip 0.9) 31 | burglary (flip 0.85) 32 | earthquake (flip 0.2) 33 | true (flip 0.05)) 34 | john-call (flip (if alarm 0.8 0.1)) 35 | mary-call (flip (if alarm 0.9 0.4))] 36 | (bools-to-binary [earthquake burglary alarm john-call mary-call])))) 37 | 38 | (defn trace-to-binary [tr] 39 | (bools-to-binary 40 | (map #(trace-value tr %) ["earthquake" "burglary" "alarm" "john-call" "mary-call"]))) 41 | 42 | 43 | (defn earthquake-histogram 44 | [name samples] 45 | (binned-histogram 46 | :name name 47 | :samples samples 48 | :sample-lower-bound 0 49 | :sample-upper-bound 32 50 | :number-of-intervals 32 51 | :overlay-densities '())) 52 | 53 | ;; ---------------------------------------------------------------------------- 54 | ;; Calculate exact probabilities 55 | 56 | 57 | ;; Returns a list of output traces 58 | 59 | (defn joint-enumerate 60 | [addresses] 61 | (if (empty? 
addresses) 62 | '({}) 63 | (let [others (joint-enumerate (rest addresses)) 64 | addr (first addresses) 65 | trace-lists 66 | (map (fn [value] (map (fn [t] (trace-set-value t addr value)) others)) [true false])] 67 | 68 | (apply concat trace-lists)))) 69 | 70 | (defn intervene 71 | [f intervention] 72 | (gen [& args] 73 | (first (apply-at '() (make-constrained-generator f intervention) args)))) 74 | 75 | ;; Returns list of [state score] where state is value returned by 76 | ;; earthquake-bayesian-network 77 | 78 | (defn enumerate-executions 79 | [proc inputs intervention-trace target-trace] 80 | (let [[_ one-run _] 81 | (infer-and-score :procedure (intervene proc intervention-trace) 82 | :inputs inputs) 83 | 84 | all-addrs (addresses-of one-run) 85 | 86 | candidates 87 | (joint-enumerate all-addrs)] 88 | 89 | (map (fn [candidate] 90 | (let [[state _ score] 91 | (infer-and-score :procedure (intervene proc intervention-trace) 92 | :inputs inputs 93 | :observation-trace (trace-merge candidate target-trace))] 94 | [state score])) 95 | candidates))) 96 | 97 | ;(define enumerate-executions 98 | ; (gen [proc inputs intervention-trace target-trace] 99 | ; (print [(count (addresses-of intervention-trace)) "interventions"]) 100 | ; (define [_ one-run _] 101 | ; (infer :procedure proc 102 | ; :inputs inputs 103 | ; :intervention-trace intervention-trace)) 104 | ; (define all-sites (addresses-of one-run)) 105 | ; (print [(count all-sites) "sites"]) 106 | ; (define free-sites 107 | ; (set-difference 108 | ; (set-difference all-sites (addresses-of intervention-trace)) 109 | ; (addresses-of target-trace))) 110 | ; (print [(count free-sites) "free-sites"]) 111 | ; (define candidates (joint-enumerate free-sites)) 112 | ; (map (gen [candidate] 113 | ; ;; Returns [state nil score] 114 | ; (define [state _ score] 115 | ; (infer :procedure proc 116 | ; :inputs inputs 117 | ; :intervention-trace intervention-trace 118 | ; :target-trace (trace-merge candidate target-trace) 119 | ; 
:output-trace? false)) 120 | ; [state score]) 121 | ; candidates))) 122 | 123 | ;; Takes a list of [state score] and returns a list of samples. 124 | ;; A good multiplier is 12240. 125 | ;; The purpose is just so that we can easily reuse the histogram 126 | ;; plotting logic. 127 | 128 | (defn fake-samples-for-enumerated-executions 129 | [state-and-score-list multiplier] 130 | (mapcat (fn [[state score]] 131 | (let [count (round (* (exp score) multiplier))] 132 | (repeat count state))) state-and-score-list)) 133 | 134 | ;; ---------------------------------------------------------------------------- 135 | ;; Sample from the prior 136 | 137 | ;; Each sample is an output trace. 138 | 139 | (defn prior-samples 140 | [n-samples] 141 | (replicate n-samples #((infer-and-score :procedure earthquake-bayesian-network) 1))) 142 | 143 | ;; Test intervention 144 | (def alarm-went-off {"alarm" {:value true}}) 145 | 146 | ;; TODO: Use predicate version of rejection sampling? 147 | (defn eq-rejection-assay 148 | [number-of-runs] 149 | (replicate 150 | number-of-runs 151 | (fn [] 152 | (print "rejection sample") 153 | (rejection-sampling 154 | :model earthquake-bayesian-network 155 | :observation-trace alarm-went-off 156 | :log-bound 0)))) 157 | 158 | (defn eq-importance-assay 159 | [n-particles number-of-runs] 160 | (replicate 161 | number-of-runs 162 | (fn [] 163 | (importance-resampling 164 | :model earthquake-bayesian-network 165 | :observation-trace alarm-went-off 166 | :n-particles n-particles)))) 167 | 168 | 169 | ;; TBD: importance sampling 170 | ;; TBD: rejection sampling 171 | 172 | (defn demo-earthquake 173 | [] 174 | (clojure.pprint/pprint "Exact prior probabilities") 175 | (let [exact-probabilities 176 | (enumerate-executions earthquake-bayesian-network [] {} {}) 177 | 178 | fake-samples 179 | (fake-samples-for-enumerated-executions exact-probabilities 12240) 180 | 181 | exact-awo-probabilities 182 | (enumerate-executions earthquake-bayesian-network [] alarm-went-off 
{}) 183 | 184 | fake-awo-samples 185 | (fake-samples-for-enumerated-executions exact-awo-probabilities 12240) 186 | 187 | n-samples 100] 188 | 189 | (earthquake-histogram "exact bayesnet prior probabilities" 190 | fake-samples) 191 | 192 | (earthquake-histogram "exact bayesnet alarm-went-off probabilities" 193 | fake-awo-samples) 194 | 195 | (earthquake-histogram "bayesnet sampled prior probabilities" 196 | (prior-samples n-samples)) 197 | 198 | (earthquake-histogram "bayesnet samples from rejection sampling" 199 | (map trace-to-binary (eq-rejection-assay n-samples))) 200 | 201 | (earthquake-histogram "bayesnet samples from importance sampling with 20 particles" 202 | (map trace-to-binary (eq-importance-assay 20 n-samples))))) 203 | -------------------------------------------------------------------------------- /src/metaprob/examples/flip_n_coins.clj: -------------------------------------------------------------------------------- 1 | (ns metaprob.examples.flip-n-coins 2 | (:refer-clojure :exclude [map replicate apply]) 3 | (:require [metaprob.generative-functions :refer :all] 4 | [metaprob.prelude :refer :all] 5 | [metaprob.distributions :refer :all])) 6 | 7 | ;; Define a probabilistic model for n flips of a coin 8 | ;; with a custom address name for each coin flip 9 | 10 | (def flip-n-coins 11 | (gen [n] 12 | (let-traced [tricky (flip 0.1) 13 | p (if tricky (uniform 0 1) 0.5)] 14 | (map (fn [i] (at i flip p)) (range n))))) 15 | 16 | 17 | (defn coin-flips-demo-n-flips 18 | [n] 19 | (let [[_ trace-with-n-flips _] 20 | (infer-and-score :procedure flip-n-coins 21 | :inputs [n])] 22 | (infer-and-score :procedure flip-n-coins 23 | :inputs [n] 24 | :observation-trace trace-with-n-flips))) 25 | 26 | ;; make a partial trace that intervenes on flip-coins 27 | ;; to ensure the coin is tricky and the weight is 0.99 28 | ;; but the fourth flip comes up false 29 | 30 | (def ensure-tricky-and-biased 31 | {"tricky" {:value true} 32 | "p" {:value 0.99} 33 | 3 {:value false}}) 34 
| 35 | (defn coin-flips-demo-biased 36 | [n] 37 | (infer-and-score :procedure flip-n-coins 38 | :inputs [n] 39 | :observation-trace ensure-tricky-and-biased)) 40 | 41 | -------------------------------------------------------------------------------- /src/metaprob/examples/inference_on_gaussian.clj: -------------------------------------------------------------------------------- 1 | ;; 5. 2 | 3 | ;; Try this: time lein run -m metaprob.examples.main 10 4 | 5 | (ns metaprob.examples.inference-on-gaussian 6 | (:refer-clojure :exclude [replicate map apply]) 7 | (:require [metaprob.generative-functions :refer :all] 8 | [metaprob.trace :refer :all] 9 | [metaprob.prelude :refer :all] 10 | [metaprob.distributions :refer :all] 11 | [metaprob.inference :refer :all])) 12 | 13 | ;; Exact versions of prior and target density functions, for 14 | ;; graphical comparison with sampled approximations. 15 | 16 | (def normal-normal 17 | (gen [] 18 | (let-traced [x (gaussian 0 1) 19 | y (gaussian x 1)] 20 | y))) 21 | 22 | (defn prior-density 23 | [x] 24 | (exp (score-gaussian x [0 1]))) 25 | 26 | (defn target-density 27 | [x] 28 | (exp (score-gaussian x [1.5 (/ 1.0 (sqrt 2.0))]))) 29 | 30 | ;; Each sample is an output trace. 31 | 32 | ;; Find the location of the (assumed unique) peak of the histogram. 33 | ;; For debugging. 34 | 35 | (defn peak-location 36 | [samples] 37 | (let [so (sort samples) 38 | window (+ 1 (quot (count so) 10)) 39 | lead (drop window so)] 40 | (nth (first (sort (clojure.core/map (fn [x y] [(- y x) (/ (+ x y) 2)]) so lead))) 1))) 41 | 42 | ;; For debugging. 
43 | 44 | (defn analyze 45 | [samples] 46 | (print (first samples)) 47 | (print ["average:" (/ (reduce + samples) (count samples)) 48 | "peak:" (peak-location samples)]) 49 | samples) 50 | 51 | (defn gaussian-histogram 52 | [name samples] 53 | (binned-histogram 54 | :name name 55 | :samples (analyze samples) 56 | :overlay-densities `(["prior" ~prior-density] 57 | ["target" ~target-density]))) 58 | 59 | ;; Sample from prior & plot 60 | 61 | (defn gaussian-prior-samples 62 | [number-of-runs] 63 | (replicate number-of-runs normal-normal)) 64 | 65 | (def obs {"y" {:value 3}}) 66 | 67 | (defn rejection-assay 68 | [number-of-runs] 69 | (replicate 70 | number-of-runs 71 | (fn [] 72 | (print "rejection sample") ;Progress meter 73 | (trace-value 74 | (rejection-sampling :model normal-normal ; :model-procedure 75 | ; :predicate (fn [tr] (< 2.99 (trace-value tr "y") 3.01))) 76 | :observation-trace obs 77 | :log-bound 0.5) 78 | "x")))) 79 | 80 | (defn importance-assay 81 | [n-particles number-of-runs] 82 | (replicate 83 | number-of-runs 84 | (fn [] 85 | (trace-value 86 | (importance-resampling :model normal-normal ; :model-procedure 87 | :observation-trace obs 88 | :n-particles n-particles) "x")))) 89 | 90 | -------------------------------------------------------------------------------- /src/metaprob/examples/long_test.clj: -------------------------------------------------------------------------------- 1 | (ns metaprob.examples.long-test 2 | (:refer-clojure :exclude [map replicate apply]) 3 | (:require [clojure.test :refer :all] 4 | [metaprob.trace :refer :all] 5 | [metaprob.generative-functions :refer :all] 6 | [metaprob.distributions :refer :all] 7 | [metaprob.inference :refer :all] 8 | [metaprob.prelude :refer :all])) 9 | 10 | ;; These tests are smoke tests, not real tests of the methods - the 11 | ;; real tests take too long and make `clojure -Atest` take too long. 12 | ;; For method tests, we use a longer-running procedure. 
(defn tell-travis
  "Occasionally writes `message` to stderr so a long-running test does not stay
  silent. The message is emitted only when a fresh `(uniform 0 1)` draw falls
  below 0.01, which avoids having to keep a call counter."
  [message]
  (when (< (uniform 0 1) 0.01)
    ;; Rebind *out* so the progress line goes to stderr and is flushed at once.
    (binding [*out* *err*]
      (println message)
      (flush))))
71 | pdf (fn [x] (exp (score-gaussian x [0 1])))] 72 | (is (> (badness sampler nsamples pdf nbins) threshold))))) 73 | 74 | (deftest check-rejection 75 | (testing "check rejection sampling" 76 | (let [sampler (fn [i] 77 | (tell-travis "Rejection") 78 | (trace-value 79 | (rejection-sampling :model normal-normal 80 | :observation-trace {"y" {:value 3}} 81 | :log-bound 0.5) 82 | "x")) 83 | pdf target-density] 84 | (is (assay "r" sampler nsamples pdf nbins threshold))))) 85 | 86 | (deftest check-importance 87 | (testing "check importance sampling" 88 | (let [sampler (fn [i] 89 | (tell-travis "Importance") 90 | (trace-value 91 | (importance-resampling :model normal-normal 92 | :observation-trace {"y" {:value 3}} 93 | :n-particles n-particles) 94 | "x")) 95 | pdf target-density] 96 | (is (assay "i" sampler nsamples pdf nbins threshold))))) 97 | ; 98 | ;(deftest check-MH 99 | ; (testing "check M-H sampling" 100 | ; (let [sampler (fn [i] 101 | ; (tell-travis "M-H") 102 | ; (gaussian-sample-value 103 | ; (lightweight-single-site-MH-sampling two-variable-gaussian-model 104 | ; [] 105 | ; target-trace 106 | ; n-mh-steps))) 107 | ; pdf target-density] 108 | ; (is (assay "m" sampler nsamples pdf nbins threshold))))) 109 | -------------------------------------------------------------------------------- /src/metaprob/examples/main.clj: -------------------------------------------------------------------------------- 1 | (ns metaprob.examples.main 2 | (:refer-clojure :exclude [map replicate apply]) 3 | (:import [java.io File]) 4 | (:require [clojure.test :as test] 5 | [clojure.tools.cli :as cli] 6 | [metaprob.prelude :refer [apply replicate map]] 7 | [metaprob.examples.inference-on-gaussian :as ginf] 8 | [metaprob.examples.earthquake :as quake] 9 | [metaprob.examples.long-test])) 10 | 11 | ;; The file from which this one was derived is ../main.clj 12 | 13 | (defn- s-to-ns 14 | "Takes a number of seconds `n` and returns the equivalent number of 15 | nanoseconds." 
(defn- greater-than
  "Returns a validation setting for `clojure.tools.cli/parse-opts`'s `:validate`
  option: a two-element vector of a predicate and an error message.

  The predicate takes the parsed option value and is true only when that value
  is strictly greater than `n`; the message names `n` for the user."
  [n]
  ;; The predicate must receive the parsed value. The previous form, #(< 1 n),
  ;; took no arguments and compared the constant 1 with `n`, so it never
  ;; inspected the value it was supposed to validate.
  [#(< n %) (format "Must be greater than %d" n)])
(defn- print-help
  "Prints a usage banner, the `clojure -m <namespace>` invocation for the
  namespace this function is defined in, and then the option `summary`."
  [summary]
  (let [this-ns (namespace `print-help)]
    (println "\nUSAGE:\n")
    (println "clojure -m" this-ns "\n")
    (println summary)))
"results")) 108 | (let [{:keys [options summary]} 109 | (cli/parse-opts args cli-options)] 110 | (if (:help options) 111 | (print-help summary) 112 | (let [{:keys [mh-count particles samples gaussian-samples quake-samples]} 113 | options] 114 | (when (:test options) 115 | (test/run-tests 'metaprob.examples.long-test)) 116 | 117 | (when (:prior options) 118 | (print-header "Prior") 119 | (ginf/gaussian-histogram 120 | "samples from the gaussian demo prior" 121 | (instrument ginf/gaussian-prior-samples gaussian-samples))) 122 | 123 | (when (:rejection options) 124 | ;; Rejection sampling is very slow - 20 seconds per 125 | (print-header "Rejection") 126 | (ginf/gaussian-histogram 127 | "samples from the gaussian demo target" 128 | (instrument ginf/rejection-assay gaussian-samples))) 129 | 130 | (when (:importance options) 131 | ;; Importance sampling is very fast 132 | (print-header "Importance") 133 | (ginf/gaussian-histogram 134 | (format "importance sampling gaussian demo with %s particles" particles) 135 | (instrument ginf/importance-assay particles gaussian-samples))) 136 | 137 | ;(when (:mh options) 138 | ; ;; MH is fast 139 | ; (print-header "Metropolis Hastings") 140 | ; (ginf/gaussian-histogram 141 | ; (format "samples from gaussian demo lightweight single-site MH with %s iterations" 142 | ; mh-count) 143 | ; (instrument ginf/MH-assay mh-count gaussian-samples))) 144 | 145 | (when (:quake options) 146 | (print-header "Earthquake Bayesnet") 147 | ;; (quake/demo-earthquake) - doesn't work yet 148 | (quake/earthquake-histogram 149 | "bayesnet samples from rejection sampling" 150 | (map quake/trace-to-binary (quake/eq-rejection-assay quake-samples)))))))) 151 | -------------------------------------------------------------------------------- /src/metaprob/examples/multimixture_dsl.clj: -------------------------------------------------------------------------------- 1 | (ns metaprob.examples.multimixture-dsl 2 | (:refer-clojure :exclude [map replicate apply]) 3 | 
(:require [metaprob.trace :refer :all] 4 | [metaprob.generative-functions :refer :all] 5 | [metaprob.prelude :refer [map log apply infer-and-score]] 6 | [metaprob.distributions :refer :all] 7 | [clojure.pprint :refer [pprint]] 8 | [metaprob.inference :refer :all])) 9 | 10 | 11 | ; ------------------- 12 | ; MULTI-MIXTURE MODEL 13 | ; ------------------- 14 | 15 | (defn get-cluster-addr 16 | [v] 17 | (str "cluster-for-" v)) 18 | 19 | 20 | (defn make-view 21 | [[vars-and-dists [cluster-probs cluster-params]]] 22 | (let [view-name (str "view" (gensym)) 23 | var-names (keys vars-and-dists) 24 | cluster-addr (get-cluster-addr view-name) 25 | 26 | ;; Generative model 27 | sampler 28 | (gen [] 29 | (let [cluster-idx (at cluster-addr categorical cluster-probs) 30 | params (nth cluster-params cluster-idx)] 31 | (map (fn [v] (apply-at v (get vars-and-dists v) (get params v))) var-names)))] 32 | (with-custom-proposal-attached 33 | sampler 34 | (fn [observations] 35 | (gen [] 36 | (let [score-cluster 37 | (fn [idx] 38 | (let [new-obs (trace-set-value observations cluster-addr idx) 39 | ;; Score should not depend on any of the stochastic 40 | ;; choices made by infer-and-score, so we leave this 41 | ;; untraced. 42 | [_ _ s] 43 | (infer-and-score 44 | :procedure sampler 45 | :observation-trace new-obs)] 46 | s)) 47 | 48 | cluster-scores 49 | (map score-cluster (range (count cluster-probs))) 50 | 51 | chosen-cluster 52 | (at cluster-addr log-categorical cluster-scores)] 53 | 54 | ;; Fill in the rest of the choices 55 | (at '() infer-and-score 56 | :procedure sampler 57 | :observation-trace 58 | (trace-set-value observations cluster-addr chosen-cluster))))) 59 | 60 | ;; Only use the custom proposal when we don't already know the cluster ID 61 | (fn [tr] (not (trace-has-value? 
; Cluster specification constructor.
(defn clusters
  "Splits an alternating `prob params prob params ...` argument list into a
  pair `[probs params]`: the items at even positions, then those at odd ones."
  [& args]
  (let [probs  (take-nth 2 args)
        params (take-nth 2 (drop 1 args))]
    [probs params]))
;; Heuristics
(defn distance-from-index
  "Returns how far the nearest occurrence of `c` in `s` is from position `idx`
  (clamped to the length of `s`).

  Looking backwards, the character just before `idx` is at distance 1; looking
  forwards, the character at `idx` is at distance 0. When `c` does not occur on
  a side, that side counts as `(count s)` + 1."
  [s idx c]
  (let [len       (count s)
        split     (min len idx)
        before    (clojure.string/reverse (subs s 0 split))
        after     (subs s split)
        back-dist (inc (or (index-of before c) len))
        fwd-dist  (or (index-of after c) (inc len))]
    (min back-dist fwd-dist)))
(defn preferred-error-type
  "Picks the edit that moves the length of `current-string` towards that of
  `observed-string`: :insertion when it is shorter, :deletion when it is
  longer, and :substitution when the two lengths are equal."
  [current-string observed-string]
  (let [length-gap (- (count current-string) (count observed-string))]
    (cond
      (neg? length-gap) :insertion
      (pos? length-gap) :deletion
      :else             :substitution)))
(defn make-add-errors-proposer
  "Builds a proposal generative function from `observation-trace`.

  Reads the observed string from the trace's \"final-string\" address and
  returns a `gen` of one argument (the true string). That function repeatedly
  applies a `smart-add-error` specialised to the observed string, consulting a
  `flip` at address `(i \"done-adding-errors?\")` before each step, and records
  the string it ends with at \"final-string\"."
  [observation-trace]
  (let [observed-string (trace-value observation-trace "final-string")
        smart-add-error (make-smart-add-error observed-string)]

    ;; This is the "smart add-errors"
    (gen [true-string]
      (let [helper
            (fn [current-string i]
              ;; The stop probability is 0.000 while the current string still
              ;; differs from the observed one, and 0.999 once they are equal.
              (let [stop-prob (if (should-probably-add-typo? current-string observed-string) 0.000 0.999)]
                (if (at `(~i "done-adding-errors?") flip stop-prob)
                  ;; If we are done adding errors, just return the current string
                  (at "final-string" exactly current-string)
                  ;; Otherwise, add another typo and loop
                  (recur (at i smart-add-error current-string)
                         (+ i 1)))))]

        ;; Start the process going
        (helper true-string 0)))))
tr "final-string")))) 151 | 152 | 153 | (defn add-errors-demo [] 154 | (println "Importance resampling with custom proposal") 155 | (pprint (importance-resampling-custom-proposal 156 | :model add-errors-model 157 | :proposer make-add-errors-proposer 158 | :inputs ["computer"] 159 | :observation-trace {"final-string" {:value "camptr"}} 160 | :n-particles 30)) 161 | (println "Importance resampling with internal proposal") 162 | (pprint (importance-resampling 163 | :model add-errors 164 | :inputs ["computer"] 165 | :observation-trace {"final-string" {:value "camptr"}} 166 | :n-particles 30))) 167 | 168 | 169 | 170 | ;; Choose a city and maybe add typos to it 171 | (def spelling-model 172 | (gen [] 173 | (let-traced [true-city (categorical washington-cities) 174 | with-typos (add-errors true-city)] 175 | with-typos))) 176 | 177 | (defn spelling-model-demo [] 178 | ;; Inferring that "seattle" was the true string 179 | (println "Observing 'satl'") 180 | (pprint 181 | (importance-resampling 182 | :model spelling-model 183 | :observation-trace {"with-typos" {"final-string" {:value "satl"}}} 184 | :n-particles 100)) 185 | (println "Observing 'spatkne'") 186 | (pprint 187 | (let [results 188 | (replicate 100 189 | (fn [] (trace-value 190 | (importance-resampling 191 | :model spelling-model 192 | :observation-trace {"with-typos" {"final-string" {:value "spatkne"}}} 193 | :n-particles 30) 194 | "true-city")))] 195 | (str "Seattle: " (count (filter (fn [city] (= city "seattle")) results)) ", " 196 | "Spokane: " (count (filter (fn [city] (= city "spokane")) results)))))) 197 | 198 | (defn -main [] 199 | (add-errors-demo) 200 | (spelling-model-demo)) -------------------------------------------------------------------------------- /src/metaprob/expander.cljc: -------------------------------------------------------------------------------- 1 | (ns metaprob.expander 2 | (:require #?(:cljs [cljs.analyzer :as ana]) 3 | [metaprob.code-handlers :as code])) 4 | 5 | (declare mp-expand) 6 
;; Translate a Clojure fn* expression, which defines a potentially
;; anonymous function that dispatches to different function bodies
;; based on number of arguments. Our Metaprob implementation creates
;; a variadic (gen [& args] ...) that dispatches on (count args).
(defn convert-fn*-exp
  "Rewrites a `fn*` form `exp` into an expanded `fn` form.

  A form with a single signature becomes `(fn name? params expanded-body...)`.
  A form with several signatures becomes one `(fn name? [& args] (cond ...))`
  that dispatches on `(count args)`: one clause per declared arity, followed by
  a catch-all that runs the variadic signature if there is one and otherwise
  fails with \"Wrong arity\". Every body is passed through `mp-expand`.

  In ClojureScript the analyzer environment `env` is the first argument."
  #?(:clj [exp] :cljs [env exp])
  (let [mp-expand #?(:clj mp-expand :cljs #(mp-expand env %))
        name (if (symbol? (second exp)) (second exp) nil),

        ;; Rest of the fn* expression, without the word fn* or
        ;; the (optional) function name
        exp-wo-name (if name (rest (rest exp)) (rest exp)),

        ;; We are looking at either
        ;;   ([params...] body) or
        ;;   (([params1...] body1) ([params2...] body2)...)
        sigs (if (vector? (first exp-wo-name))
               (list exp-wo-name)   ;; One signature
               exp-wo-name),        ;; Multiple signatures

        ;; Helpers for counting arguments in a signature.
        count-vars (fn [pat] (count (take-while (partial not= '&) pat))),
        ;; `contains?` on a vector tests indices rather than elements, so it
        ;; can never find the `&` symbol; search the elements instead.
        variadic? (fn [pat] (boolean (some #{'&} pat))),

        ;; Come up with a name for the single variadic argument of the new fn
        argname (gensym),

        ;; arities: a map from number-of-arguments to signature-to-execute
        ;; variadic-n: either false, if this function has no variadic signature,
        ;;             or the number of non-variadic args in the variadic signature
        ;;             (may be 0, which is still truthy in Clojure)
        [arities variadic-n]
        (reduce
          (fn [[arities variadic-n] sig]
            (let [pat (first sig)
                  n-fixed (count-vars pat)]
              (assert (not (contains? arities n-fixed)) "Repeat arity in fn*")
              (assert (or (not (variadic? pat)) (not variadic-n)) "Only one variadic allowed in fn*")
              ;; Checked in both directions so the result does not depend on
              ;; the order in which the signatures are written.
              (assert (if (variadic? pat)
                        (every? #(< % n-fixed) (keys arities))
                        (or (not variadic-n) (< n-fixed variadic-n)))
                      "Non-variadic cannot have more args than variadic implementation")
              [(assoc arities
                 n-fixed
                 `(let
                    [~pat ~argname]
                    ~@(rest sig))),
               (if (variadic? pat) n-fixed variadic-n)]))
          [{} false] sigs),

        ;; Generate cond clauses for the body of the synthesized
        ;; gen expression.
        clauses
        (apply concat
          (concat
            (map
              (fn [a]
                `((= (count ~argname) ~a) ~(mp-expand (get arities a))))
              (sort (keys arities)))
            (list (if variadic-n
                    `(true ~(mp-expand (get arities variadic-n)))
                    `(true (~'assert false "Wrong arity")))))),

        fn-expr (if (= (count sigs) 1)
                  (cons 'fn
                        (cons (first (first sigs))
                              (map mp-expand
                                   (rest (first sigs)))))
                  `(~'fn [& ~argname] (cond ~@clauses)))]

    (if name
      (cons (first fn-expr) (cons name (rest fn-expr)))
      fn-expr)))
(first form)) (name (first form)) "") 117 | ;; Metaprob special forms 118 | "quote" 119 | form 120 | 121 | "gen" 122 | (code/map-gen mp-expand form) 123 | 124 | "do" 125 | (mp-expand `((fn* [] ~@(rest form)))) 126 | ;; (cons 'do (doall (map mp-expand (rest form)))) 127 | 128 | "let" 129 | #?(:clj (expand-let-expr form) 130 | :cljs (expand-let-expr env form)) 131 | 132 | "letfn" 133 | form 134 | 135 | "if" 136 | (map mp-expand form) 137 | 138 | ;; Clojure special forms that need translating 139 | "let*" 140 | (mp-expand (cons 'let (rest form))) 141 | 142 | "letfn*" 143 | ;; TODO: Replace with letgen when we have it. 144 | (mp-expand (cons 'let (rest form))) 145 | 146 | "fn*" 147 | ;; We need to handle cases where the `fn` has a name (and 148 | ;; therefore may be recursive) and also cases where it may have 149 | ;; more than one arity defined. 150 | ;; no recursive call, because convert-fn*-exp handles all expansion 151 | #?(:clj (convert-fn*-exp form) 152 | :cljs (convert-fn*-exp env form)) 153 | 154 | "loop" 155 | (throw (ex-info "Cannot use loop* in Metaprob." {})) 156 | 157 | "case*" 158 | (throw (ex-info "Cannot use case* in Metaprob." {})) 159 | 160 | "throw" 161 | `(~'assert false ~(str "Clojure throw statement encountered: " form)) 162 | 163 | ;; TOTAL HACK and should be removed or made to work: 164 | "." (cons '. (map mp-expand (rest form))) 165 | 166 | ("new" "monitor-exit" "monitor-enter" "try" "finally" 167 | "import*" "deftype*" "set!" "var" "catch" "def" "reify*") 168 | (throw (ex-info "mp-expand encountered unsupported Clojure special form" 169 | {:form form})) 170 | 171 | ;; It's a function or macro... 
172 | (let [next #?(:clj (macroexpand-1 form) 173 | :cljs (ana/macroexpand-1 env form))] 174 | (if (= next form) 175 | (map mp-expand form) 176 | (mp-expand next))))))) 177 | -------------------------------------------------------------------------------- /src/metaprob/generative_functions.cljc: -------------------------------------------------------------------------------- 1 | (ns metaprob.generative-functions 2 | #?(:cljs (:require-macros [metaprob.generative-functions :refer [gen]])) 3 | (:require #?(:cljs [cljs.analyzer :as ana]) 4 | [metaprob.code-handlers :as code] 5 | [metaprob.trace :as trace])) 6 | 7 | (defn at [& args] (assert false "Cannot invoke at outside of a (gen ...) form.")) 8 | 9 | (defn apply-at [& args] (assert false "Cannot invoke apply-at outside of a (gen ...) form.")) 10 | 11 | ;; Most general way of creating a generative function: provide implementations of its 12 | ;; methods. All other ways of creating generative functions boil down, ultimately, to 13 | ;; a call to this function. 
;; make-constrained-generator : generative function, observation trace -> generative function
(defn make-constrained-generator
  "Returns the constrained generator of `procedure` for `observations`.

  If `procedure` carries a :make-constrained-generator entry in its metadata,
  that function is called with `observations`. Otherwise the result is a
  function that applies `procedure` to its arguments and returns
  `[value {} 0]`, ignoring the observations."
  [procedure observations]
  (let [fallback-maker (fn [_]
                         (fn [& args]
                           [(apply procedure args) {} 0]))
        maker (or (:make-constrained-generator (meta procedure))
                  fallback-maker)]
    (maker observations)))
(defmacro gen
  "Creates a generative function from executable code that samples from other
  generative functions.

  The whole `(gen ...)` form is taken from `&form` and split with the
  `metaprob.code-handlers` helpers into an optional name, a parameter pattern
  and a body. The body is wrapped in a function of two tracer arguments
  literally named `at` and `apply-at`, so those symbols inside the body refer
  to the tracers supplied by `generative-function-from-traced-code` rather than
  to the top-level stubs. The original form is stored under :generative-source.

  When the gen is named, the expansion is wrapped in a thunk and the name is
  bound inside the body to the result of calling that thunk, which lets the
  body refer to itself."
  ;; The attr-map has to come before the parameter vector; placed after it, the
  ;; map is read as an (empty) pre/post-condition map and never reaches the var.
  {:style/indent 1}
  [& _]
  (let [expr
        &form

        body
        (code/gen-body expr)

        name
        (code/gen-name expr)

        tracer-name
        'at

        apply-tracer-name
        'apply-at

        params
        (code/gen-pattern expr)

        ;; Only named gens need the self-reference thunk.
        thunk-name
        (if name (gensym (str name "thunk")) nil)

        named-fn-body
        (if name
          `((let [~name (~thunk-name)]
              ~@body))
          body)

        innermost-fn-expr
        `(fn ~params ~@named-fn-body)

        generative-function-expression
        `(generative-function-from-traced-code
           (fn [~tracer-name ~apply-tracer-name] ~innermost-fn-expr)
           {:name '~name, :generative-source '~expr})]

    (if name
      `((fn ~thunk-name [] ~generative-function-expression))
      generative-function-expression)))
(defn importance-sampling
  "Self-normalised importance-sampling estimate of an expectation under `model`.

  Keyword arguments:
    :model             generative function to run
    :inputs            argument vector for the model (default [])
    :f                 optional function of the sampled trace; when absent the
                       model's return value is used as the quantity to average
    :observation-trace observations to condition on (default {})
    :n-particles       number of particles to draw (default 1)

  Each particle is a pair `[weighted-value log-weight]`, where the log weight
  is the score returned by `infer-and-score`. The result is the sum of the
  weighted values divided by the sum of the weights."
  [& {:keys [model inputs f observation-trace n-particles]
      :or {n-particles 1, inputs [], observation-trace {}}}]
  (let [particles (replicate
                    n-particles
                    (fn []
                      (let [[v t s] (mp/infer-and-score :procedure model
                                                        :observation-trace observation-trace
                                                        :inputs inputs)]
                        ;; The log score must be the second element of the pair;
                        ;; it was previously multiplied into the value instead,
                        ;; which left the normalizer with nothing to sum.
                        [(* (mp/exp s)
                            (if f (f t) v))
                         s])))
        normalizer (mp/exp (dist/logsumexp (map second particles)))]

    (/ (reduce + (map first particles)) normalizer)))
;; Wraps `orig-generative-function` in a new generative function (built with
;; gen/make-generative-function) whose constrained generator can use a
;; user-supplied proposal.
;;
;;   orig-generative-function - used unchanged as the first argument to
;;                              gen/make-generative-function.
;;   make-custom-proposer     - called with the observation trace; its result is
;;                              passed to mp/infer-and-score as :procedure.
;;   condition-for-use        - called with the observation trace; a truthy result
;;                              selects the custom-proposal path, otherwise
;;                              gen/make-constrained-generator is used on the
;;                              original function.
;;
;; On the custom path the generator returns [v proposed-trace (- p-score q-score)].
(defn with-custom-proposal-attached
  [orig-generative-function make-custom-proposer condition-for-use]
  (gen/make-generative-function
   ;; To run in Clojure, use the same method as before:
   orig-generative-function

   ;; To create a constrained generator, first check if
   ;; the condition for using the custom proposal holds.
   ;; If so, use it and score it.
   ;; Otherwise, use the original make-constrained-generator
   ;; implementation.
   (fn [observations]
     (if (condition-for-use observations)
       (gen [& args]
         (let [custom-proposal
               (make-custom-proposer observations)

               ;; TODO: allow/require custom-proposal to specify which addresses it is proposing vs. sampling otherwise?
               ;; Run the proposal once (via `at '()`) and keep only its trace.
               [_ tr _]
               (at '() mp/infer-and-score
                   :procedure custom-proposal
                   :inputs args)

               ;; Combine the caller's observations with the proposed choices.
               proposed-trace
               (trace/trace-merge observations tr)

               ;; Score the merged trace under the original function (p).
               ;; `tr2` is bound but not used below.
               [v tr2 p-score]
               (mp/infer-and-score :procedure orig-generative-function
                                   :inputs args
                                   :observation-trace proposed-trace)

               ;; Score the same merged trace under the custom proposal (q).
               [_ _ q-score]
               (mp/infer-and-score :procedure custom-proposal
                                   :inputs args
                                   :observation-trace proposed-trace)]
           [v proposed-trace (- p-score q-score)]))
       (gen/make-constrained-generator orig-generative-function observations)))))

;;; ----------------------------------------------------------------------------
;;; Metropolis-Hastings

;; Returns a one-step transition function `current-trace -> trace`.
;; Keyword arguments: :model, :inputs (default []), :proposal (a function from
;; a trace to a proposed trace). The step scores both traces under the model
;; with mp/infer-and-score and accepts the proposed trace when
;; (dist/flip (exp (min 0 (proposed-score - current-score)))) is truthy;
;; otherwise it returns the current trace. No proposal-density correction is
;; applied here.
(defn symmetric-proposal-mh-step
  [& {:keys [model inputs proposal]
      :or {inputs []}}]
  (fn [current-trace]
    (let [[_ _ current-trace-score]
          (mp/infer-and-score :procedure model
                              :inputs inputs
                              :observation-trace current-trace)

          proposed-trace
          (proposal current-trace)

          [_ _ proposed-trace-score]
          (mp/infer-and-score :procedure model
                              :inputs inputs
                              :observation-trace proposed-trace)

          ;; Capped at 0 so that the exponentiated value is at most 1.
          log-acceptance-ratio
          (min 0 (- proposed-trace-score current-trace-score))]

      (if (dist/flip (mp/exp log-acceptance-ratio))
        proposed-trace
        current-trace))))

;; Returns a function `trace -> trace` that, for each address in `addresses`
;; (processed in order with reduce), replaces the value at that address with
;; (dist/gaussian current-value width).
(defn make-gaussian-drift-proposal
  [addresses width]
  (fn [current-trace]
    (reduce
      (fn [tr addr]
        (trace/trace-set-value tr addr (dist/gaussian (trace/trace-value tr addr) width)))
      current-trace
      addresses)))

;; Returns a one-step transition function built from symmetric-proposal-mh-step
;; and make-gaussian-drift-proposal.
;; Keyword arguments: :model, :inputs (default []), :width (default 0.1), and
;; either :addresses (used as given) or :address-predicate (used to filter
;; (trace/addresses-of tr) on every step). :addresses takes precedence when both
;; are supplied.
;; NOTE(review): when neither :addresses nor :address-predicate is supplied,
;; `filter` is called with a nil predicate -- confirm callers always pass one.
(defn gaussian-drift-mh-step
  [& {:keys [model inputs addresses address-predicate width]
      :or {inputs [], width 0.1}}]
  ;; `proposal` maps the current trace to a drift-proposal function; the `tr`
  ;; argument is ignored in the fixed-addresses branch.
  (let [proposal (if addresses
                   (fn [tr] (make-gaussian-drift-proposal addresses width))
                   (fn [tr] (make-gaussian-drift-proposal (filter address-predicate (trace/addresses-of tr)) width)))]
    (fn [tr]
      ((symmetric-proposal-mh-step :model model, :inputs inputs, :proposal (proposal tr)) tr))))

;; Returns a one-step transition function `current-trace -> trace` for a
;; proposal that is itself run through mp/infer-and-score.
;; Keyword arguments: :model, :inputs (default []), :proposal (called with the
;; current trace as its single input; its return value is used as the proposed
;; trace).
;; Unlike symmetric-proposal-mh-step, the log ratio is not capped at 0 before
;; being exponentiated and handed to dist/flip.
;; NOTE(review): the backward score below is computed with
;; `:observation-trace proposed-trace`, although its comment says it estimates
;; q(t <- t'). Confirm whether the observations should instead describe the
;; choices that lead back to `current-trace`.
(defn custom-proposal-mh-step [& {:keys [model inputs proposal], :or {inputs []}}]
  (fn [current-trace]
    (let [[_ _ current-trace-score]               ;; Evaluate log p(t)
          (mp/infer-and-score :procedure model
                              :inputs inputs
                              :observation-trace current-trace)

          [proposed-trace all-proposer-choices _] ;; Sample t' ~ q(• <- t)
          (mp/infer-and-score :procedure proposal
                              :inputs [current-trace])

          [_ _ new-trace-score]                   ;; Evaluate log p(t')
          (mp/infer-and-score :procedure model
                              :inputs inputs
                              :observation-trace proposed-trace)

          [_ _ forward-proposal-score]            ;; Estimate log q(t' <- t)
          (mp/infer-and-score :procedure proposal
                              :inputs [current-trace]
                              :observation-trace all-proposer-choices)

          [_ _ backward-proposal-score]           ;; Estimate log q(t <- t')
          (mp/infer-and-score :procedure proposal
                              :inputs [proposed-trace]
                              :observation-trace proposed-trace)

          log-acceptance-ratio                    ;; Compute estimate of log [p(t')q(t <- t') / p(t)q(t' <- t)]
          (- (+ new-trace-score backward-proposal-score)
             (+ current-trace-score forward-proposal-score))]

      (if (dist/flip (mp/exp log-acceptance-ratio)) ;; Decide whether to accept or reject
        proposed-trace
        current-trace))))

;; Returns a one-step transition function `current-trace -> trace` that
;; resamples the value at a single `address`.
;; Keyword arguments: :model, :address, :support (a collection of candidate
;; values, indexed with nth), :inputs (default []).
;; For every candidate value the model is scored with that value written at
;; `address`; an index is then drawn with dist/log-categorical over those
;; scores and the corresponding value is written into the current trace.
(defn make-gibbs-step
  [& {:keys [model address support inputs]
      :or {inputs []}}]
  (fn [current-trace]
    (let [log-scores
          (map (fn [value]
                 (nth (mp/infer-and-score :procedure model
                                          :inputs inputs
                                          :observation-trace
                                          (trace/trace-set-value current-trace address value)) 2))
               support)]
      (trace/trace-set-value
        current-trace
        address
        (nth support (dist/log-categorical log-scores))))))

;; Builds a generative function `old-trace -> new-trace` that keeps every choice
;; outside the selected addresses fixed and re-runs the model's constrained
;; generator on the rest.
;; Keyword arguments: :model, :inputs (default []), and either :addresses or
;; :address-predicate. When :address-predicate is supplied it takes precedence
;; and is applied to (trace/addresses-of old-trace) on every call.
(def make-resimulation-proposal
  (fn [& {:keys [model inputs addresses address-predicate]
          :or {inputs []}}]
    (let [get-addresses
          (if address-predicate
            (fn [tr] (filter address-predicate (trace/addresses-of tr)))
            (fn [tr] addresses))]
      (gen [old-trace]
        (let [addresses
              (get-addresses old-trace)

              ;; Second element of the partition: the choices NOT in `addresses`.
              [_ fixed-choices]
              (trace/partition-trace old-trace addresses)

              constrained-generator
              (gen/make-constrained-generator model fixed-choices)

              [_ new-trace _]
              (apply-at '() constrained-generator inputs)]
          new-trace)))))

;; Performs one resimulation Metropolis-Hastings move on trace `tr`, proposing
;; new values for the choices at `addresses` while holding the others fixed.
;; Returns either the proposed trace or `tr`, decided by
;; (dist/flip (exp log-ratio)) where
;; log-ratio = new-p-over-forward-q + reverse-q - old-p.
(defn resimulation-mh-move
  [model inputs tr addresses]
  (let [[current-choices fixed-choices]
        (trace/partition-trace tr addresses)

        ;; Get the log probability of the current trace
        [_ _ old-p]
        (mp/infer-and-score :procedure model,
                            :inputs inputs,
                            :observation-trace tr)

        ;; Propose a new trace, and get its score
        [_ proposed new-p-over-forward-q]
        (mp/infer-and-score :procedure model,
                            :inputs inputs,
                            :observation-trace fixed-choices)

        ;; Figure out what the reverse problem would look like
        [_ reverse-move-starting-point]
        (trace/partition-trace proposed addresses)

        ;; Compute a reverse score: infer-and-score is itself the scored
        ;; procedure here, observed at the choices being replaced.
        [_ _ reverse-q]
        (mp/infer-and-score :procedure mp/infer-and-score,
                            :inputs [:procedure model,
                                     :inputs inputs,
                                     :observation-trace reverse-move-starting-point]
                            :observation-trace current-choices)

        log-ratio
        (+ new-p-over-forward-q (- reverse-q old-p))]

    (if (dist/flip (mp/exp log-ratio))
      proposed
      tr)))

;; Disabled with the #_ reader macro: single-site MH step written against the
;; older `define` / `comp/infer-apply` / `gen/infer` API.
#_(define single-site-metropolis-hastings-step
  (gen [model-procedure inputs trace constraint-addresses]

    ;; choose an address to modify, uniformly at random

    (define choice-addresses (trace/addresses-of trace))
    (define candidates (set-difference choice-addresses constraint-addresses))
    (define target-address (uniform-sample candidates))

    ;; generate a proposal trace

    (define initial-value (trace/trace-value trace target-address))
    (define initial-num-choices (count candidates))
    (define new-target (trace-clear-value trace target-address))

    (define [_ new-trace forward-score]
      (comp/infer-apply model-procedure inputs (make-top-level-tracing-context {} new-target)))
    (define new-value (trace/trace-value new-trace target-address))

    ;; the proposal is to move from trace to new-trace
    ;; now calculate the Metropolis-Hastings acceptance ratio

    (define new-choice-addresses (trace/addresses-of new-trace))
    (define new-candidates (set-difference new-choice-addresses constraint-addresses))
    (define new-num-choices (count new-candidates))

    ;; make a trace that can be used to restore the original trace
    (define restoring-trace
      (trace/trace-set-value
        (clojure.core/reduce
          (gen [so-far next-adr] (trace/trace-set-value so-far next-adr (trace/trace-value trace next-adr)))
          {}
          (set-difference choice-addresses new-choice-addresses))
        target-address initial-value))

    ;; remove the new value
    (define new-target-rev (trace-clear-value new-trace target-address))

    (define [_ _ reverse-score]
      (gen/infer :procedure model-procedure
                 :inputs inputs
                 :intervention-trace restoring-trace
                 :target-trace new-target-rev))

    (define log-acceptance-probability
      (- (+ forward-score (log new-num-choices))
         (+ reverse-score (log initial-num-choices))))

    (if (dist/flip (mp/exp log-acceptance-probability))
      new-trace
      trace)))

;; Should return [output-trace value] ...

;; Disabled with the #_ reader macro: repeats the (disabled) single-site step
;; N times, starting from a trace obtained with gen/infer.
#_(define lightweight-single-site-MH-sampling
  (gen [model-procedure inputs target-trace N]
    (clojure.core/reduce
      (gen [state _]
        ;; VKM had keywords :procedure :inputs :trace :constraint-addresses
        (single-site-metropolis-hastings-step
          model-procedure inputs state (trace/addresses-of target-trace)))
      (nth (gen/infer :procedure model-procedure :inputs inputs :target-trace target-trace) 1)
      (range N))))

;; -----------------------------------------------------------------------------
;; Utilities for checking that inference is giving acceptable sample sets.
;; These are used in the test suites.

#_(declare sillyplot)

(defn sillyplot
  "Prints the numbers in `l` as a vector on one line. When `l` has more than
  50 entries only the middle 50 are printed."
  [l]
  (let [nbins (count l)
        trimmed (if (> nbins 50)
                  ;; clojure.core/take and drop take the count first and the
                  ;; collection second; `quot` keeps the offset an integer.
                  (take 50 (drop (quot (- nbins 50) 2) l))
                  l)]
    (println (str (vec trimmed)))))

;; A direct port of jar's old code
(defn check-bins-against-pdf
  "Compares empirical bin masses with the masses predicted by `pdf`.

  `bins` is a sequence of bins, each an indexed, sorted collection of samples.
  Returns [badness bin-p bin-q] where bin-p is the fraction of all samples in
  each bin, bin-q is the estimated mass of each bin under `pdf` (mean density
  over the bin's samples times a slightly widened bin width), and badness is
  the sum of |p - q| over all bins except the first and last, rescaled by
  (total bins / bins kept)."
  [bins pdf]
  (let [n-samples
        (reduce + (map count bins))

        abs #(Math/abs %)

        bin-p
        (map (fn [bin] (/ (count bin) (float n-samples))) bins)

        bin-q
        (map (fn [bin]
               (let [bincount (count bin)]
                 (* (/ (reduce + (map pdf bin)) bincount)
                    (* (- (nth bin (- bincount 1))
                          (nth bin 0))
                       (/ (+ bincount 1)
                          (* bincount 1.0))))))
             bins)

        discrepancies
        (clojure.core/map #(abs (- %1 %2)) bin-p bin-q)

        ;; Drop the first and last bins before summing.
        trimmed (rest (reverse (rest discrepancies)))

        normalization (/ (count discrepancies) (* (count trimmed) 1.0))]
    [(* normalization (reduce + trimmed))
     bin-p bin-q]))

(defn check-samples-against-pdf
  "Sorts `samples`, splits them into `nbins` contiguous bins of (nearly) equal
  size, and returns the result of check-bins-against-pdf on those bins."
  [samples pdf nbins]
  (let [samples
        (vec (sort samples))

        n-samples
        (count samples)

        bin-size
        (/ n-samples (float nbins))

        bins
        (map (fn [i]
               (let [start (int (* i bin-size))
                     end (int (* (inc i) bin-size))]
                 (subvec samples start end)))
             (range nbins))]
    (check-bins-against-pdf bins pdf)))

(defn report-on-elapsed-time
  "Calls `thunk` with no arguments and returns its result. If the call took
  more than one second (after rounding to whole seconds), prints a line of the
  form \"<tag>: elapsed time <t> sec\"."
  [tag thunk]
  (let [now #?(:clj #(System/nanoTime)
               :cljs #(.getTime (js/Date.)))
        ;; Clock ticks per second for the clock chosen above.
        ticks-per-second #?(:clj 1000000000.0 ; System/nanoTime is in nanoseconds
                            :cljs 1000.0)     ; Date.getTime is in milliseconds
        start (now)
        ret (thunk)
        t (Math/round (/ (double (- (now) start))
                         ticks-per-second))]
    (when (> t 1)
      (print (str tag ": elapsed time " t " sec\n")))
    ret))

(defn assay
  "Draws `nsamples` samples by calling `sampler` on 0..nsamples-1, compares them
  with `pdf` using `nbins` bins, and returns true when the resulting badness is
  below 1.5 * `threshold`. Diagnostics are printed when the badness is above
  `threshold` or below half of it. Timing is reported via
  report-on-elapsed-time under `tag`."
  [tag sampler nsamples pdf nbins threshold]
  (report-on-elapsed-time
   tag
   (fn []
     (let [[badness bin-p bin-q]
           (check-samples-against-pdf (map sampler (range nsamples))
                                      pdf nbins)]
       (when (or (> badness threshold)
                 (< badness (/ threshold 2)))
         (println (str tag "."
                       " n: " nsamples
                       " bins: " nbins
                       " badness: " badness
                       " threshold: " threshold))
         (sillyplot bin-p)
         (sillyplot bin-q))
       (< badness (* threshold 1.5))))))

(defn badness
  "Returns only the badness number computed by check-samples-against-pdf for
  `nsamples` samples drawn by calling `sampler` on 0..nsamples-1."
  [sampler nsamples pdf nbins]
  (first (check-samples-against-pdf (map sampler (range nsamples))
                                    pdf nbins)))

--------------------------------------------------------------------------------
/src/metaprob/prelude.cljc:
--------------------------------------------------------------------------------
(ns metaprob.prelude
  "This module is intended for import by Metaprob code."
  (:refer-clojure :exclude [map reduce apply replicate])
  (:require #?(:clj [clojure.java.io :as io])
            [clojure.set :as set]
            [clojure.string :as string]
            [metaprob.trace :as trace]
            [metaprob.generative-functions :refer [gen make-generative-function make-constrained-generator]])
  #?(:clj (:import [java.util Random])))


;; Useful math
(defn exp [x] (Math/exp x))
(defn expt [x y] (Math/pow x y))
(defn sqrt [x] (Math/sqrt x))
(defn log [x] (Math/log x))
(defn cos [x] (Math/cos x))
(defn sin [x] (Math/sin x))
(defn log1p [x] (Math/log1p x))
(defn floor [x] (Math/floor x))
(defn round [x] (Math/round x))
(def negative-infinity #?(:clj Double/NEGATIVE_INFINITY
                          :cljs js/Number.NEGATIVE_INFINITY))

;; Randomness
;; On the JVM all draws come from this rebindable generator, seeded with 42.
#?(:clj (defonce ^:dynamic *rng* (Random. 42)))
(defn sample-uniform
  "With no arguments, returns a uniform double in [0, 1). With bounds `a` and
  `b`, returns that draw scaled and shifted to start at `a` with width (b - a)."
  ([] #?(:clj (.nextDouble *rng*)
         :cljs (js/Math.random)))
  ([a b] (+ a
            (* (sample-uniform)
               (- b a)))))

;; Set difference
(defn set-difference
  "Returns a seq of the elements of `s1` that are not in `s2` (nil when none)."
  [s1 s2]
  (seq (set/difference (set s1) (set s2))))

;; Apply
;; clojure.core/apply tagged with {:apply? true} metadata.
(def apply
  (with-meta clojure.core/apply {:apply? true}))


;; Eager, generative-function versions of common list functions
;; Each element is processed under `at` with its index as the address.
(def map
  (gen [f l]
    (doall (map-indexed (fn [i x] (at i f x)) l))))

(def replicate
  (gen [n f]
    (map (fn [i] (at i f)) (range n))))

(defn doall*
  "Forces every nested seq inside `s` by walking it with tree-seq, then
  returns `s`."
  [s]
  (dorun (tree-seq seq? seq s)) s)

;; -----------------------------------------------------------------------------
;; Graphical output (via gnuplot or whatever)

#?(:clj (defn binned-histogram
          "Writes two files for later histogram plotting:
          results/<name>.samples (one sample per line) and
          results/<name>.samples.commands (reset, min, max and n settings).
          Spaces in `name` are replaced with underscores. Defaults:
          sample-lower-bound -5, sample-upper-bound 5, number-of-intervals 20.
          `overlay-densities` is accepted but not used. Returns nil."
          [& {:keys [name samples overlay-densities
                     sample-lower-bound sample-upper-bound
                     number-of-intervals]}]
          (let [samples (seq samples)
                sample-lower-bound (or sample-lower-bound -5)
                sample-upper-bound (or sample-upper-bound 5)
                number-of-intervals (or number-of-intervals 20)
                fname (string/replace name " " "_")
                path (str "results/" fname ".samples")
                commands-path (str path ".commands")]
            ;; Create results/ if it does not exist yet so the writers below
            ;; do not fail on a fresh checkout.
            (io/make-parents path)
            (print (format "Writing commands to %s for histogram generation\n" commands-path))
            ;;(print (format "  overlay-densities = %s\n" (freeze overlay-densities)))
            ;; with-open closes each writer on exit; no explicit .close needed.
            (with-open [writor (io/writer commands-path)]
              (.write writor (format "reset\n"))
              (.write writor (format "min=%s.\n" sample-lower-bound))
              (.write writor (format "max=%s.\n" sample-upper-bound))
              (.write writor (format "n=%s\n" number-of-intervals)))
            (print (format "Writing samples to %s\n" path))
            (with-open [writor (io/writer path)]
              (doseq [sample samples]
                (.write writor (str sample))
                (.write writor "\n"))))))
83 | 84 | ;; (defn print-source [f] (clojure.pprint/pprint (get (meta f) :generative-source))) 85 | 86 | ;; Create a "primitive" generative function out of a sampler and scorer 87 | (defn make-primitive [sampler scorer] 88 | (make-generative-function 89 | sampler 90 | (fn [observations] 91 | (if (trace/trace-has-value? observations) 92 | (gen [& args] 93 | [(trace/trace-value observations) 94 | {:value (trace/trace-value observations)} 95 | (scorer (trace/trace-value observations) args)]) 96 | (gen [& args] 97 | (let [result (apply-at '() (make-primitive sampler scorer) args)] 98 | [result {:value result} 0])))))) 99 | 100 | 101 | (def infer-and-score 102 | (gen [& {:keys [procedure inputs observation-trace] 103 | :or {inputs [], observation-trace {}}}] 104 | (apply-at '() (make-constrained-generator procedure observation-trace) inputs))) 105 | -------------------------------------------------------------------------------- /src/metaprob/trace.cljc: -------------------------------------------------------------------------------- 1 | (ns metaprob.trace) 2 | 3 | (defn trace-subtrace [tr adr] 4 | ((if (seq? adr) get-in get) tr adr)) 5 | 6 | (defn trace-has-value? 7 | ([tr] (contains? tr :value)) 8 | ([tr adr] (contains? (trace-subtrace tr adr) :value))) 9 | 10 | (defn trace-value 11 | ([tr] (get tr :value)) 12 | ([tr adr] (get (trace-subtrace tr adr) :value))) 13 | 14 | (defn trace-has-subtrace? [tr adr] 15 | (if (seq? adr) 16 | (if (empty? adr) 17 | true 18 | (if (contains? tr (first adr)) 19 | (recur (get tr (first adr)) (rest adr)))) 20 | (contains? tr adr))) 21 | 22 | (defn trace-keys [tr] 23 | (filter (fn [x] (not= x :value)) (keys tr))) 24 | 25 | (defn subtrace-count [tr] 26 | (- (count tr) (if (trace-has-value? tr) 1 0))) 27 | 28 | (defn trace-set-subtrace [tr adr sub] 29 | (if (seq? adr) 30 | (if (empty? 
adr) 31 | sub 32 | (assoc tr (first adr) (trace-set-subtrace (get tr (first adr)) (rest adr) sub))) 33 | (assoc tr adr sub))) 34 | 35 | (defn trace-set-value 36 | ([tr val] (assoc tr :value val)) 37 | ([tr adr val] 38 | (trace-set-subtrace tr adr (trace-set-value (trace-subtrace tr adr) val)))) 39 | ; TODO: Only traverse once? 40 | 41 | (defn trace-clear-value 42 | ([tr] (dissoc tr :value)) 43 | ([tr adr] (trace-set-subtrace tr adr (trace-clear-value (trace-subtrace tr adr))))) 44 | 45 | (declare trace-clear-subtrace) 46 | (defn maybe-set-subtrace 47 | [output adr suboutput] 48 | (if (empty? suboutput) 49 | (trace-clear-subtrace output adr) 50 | (trace-set-subtrace output adr suboutput))) 51 | 52 | (defn trace-clear-subtrace [tr adr] 53 | (if (seq? adr) 54 | (if (empty? adr) 55 | {} 56 | (if (empty? (rest adr)) 57 | (dissoc tr (first adr)) 58 | (maybe-set-subtrace 59 | tr 60 | (first adr) 61 | (trace-clear-subtrace (trace-subtrace tr (first adr)) (rest adr))))) 62 | (dissoc tr adr))) 63 | 64 | (defn value-only-trace? [tr] 65 | (and (trace-has-value? tr) (= (count tr) 1))) 66 | 67 | ;; Recursively walks the entire state to check it's valid 68 | (defn trace? [s] 69 | (map? s)) 70 | 71 | (defn valid-trace? [s] 72 | (and 73 | (map? s) 74 | (every? 75 | (fn [k] (trace? (get s k))) 76 | (trace-keys s)))) 77 | 78 | ;; Marco's merge operator (+). Commutative and idempotent. 79 | ;; 80 | ;; (trace-merge small large) - when calling, try to make tr1 smaller than tr2, 81 | ;; because it will be tr1 that is traversed. 82 | 83 | ;; Compare states of two values that might or might not be traces. 84 | 85 | (defn trace-merge [tr1 tr2] 86 | (let 87 | [merged 88 | (into tr1 89 | (for [key (trace-keys tr2)] 90 | [key (if (trace-has-subtrace? tr1 key) 91 | (trace-merge (trace-subtrace tr1 key) 92 | (trace-subtrace tr2 key)) 93 | (trace-subtrace tr2 key))]))] 94 | (if (trace-has-value? merged) 95 | (do (if (trace-has-value? 
tr2) 96 | (assert (= (trace-value tr1) (trace-value tr2)) 97 | ["incompatible trace values" tr1 tr2])) 98 | merged) 99 | (if (trace-has-value? tr2) 100 | (trace-set-value merged (trace-value tr2)) 101 | merged)))) 102 | 103 | (defn maybe-subtrace 104 | [tr adr] 105 | (or (trace-subtrace tr adr) {})) 106 | 107 | (defn merge-subtrace 108 | [trace addr subtrace] 109 | (trace-merge trace (maybe-set-subtrace {} addr subtrace))) 110 | 111 | (defn addresses-of [tr] 112 | (letfn [(get-sites [tr] 113 | ;; returns a seq of traces 114 | (let [site-list 115 | (mapcat (fn [key] 116 | (map (fn [site] 117 | (cons key site)) 118 | (get-sites (trace-subtrace tr key)))) 119 | (trace-keys tr))] 120 | (if (trace-has-value? tr) 121 | (cons '() site-list) 122 | site-list)))] 123 | (let [s (get-sites tr)] 124 | (doseq [site s] 125 | (assert (trace-has-value? tr site) ["missing value at" site])) 126 | s))) 127 | 128 | 129 | (defn copy-addresses 130 | [src dst paths] 131 | "Copy values from a source trace to a destination trace, at the given paths." 132 | (reduce #(trace-set-value %1 %2 (trace-value src %2)) 133 | dst paths)) 134 | 135 | (defn partition-trace 136 | [trace paths] 137 | (let [path-set (into #{} (map #(if (not (seq? %)) (list %) %) paths)) 138 | addresses (into #{} (addresses-of trace)) 139 | all-addresses (group-by #(clojure.core/contains? path-set %) addresses)] 140 | [(copy-addresses trace {} (get all-addresses true)) 141 | (copy-addresses trace {} (get all-addresses false))])) 142 | 143 | (defn address-contains? 
[addr elem] 144 | (some #{elem} addr)) 145 | -------------------------------------------------------------------------------- /test/metaprob/compositional_test.cljc: -------------------------------------------------------------------------------- 1 | (ns metaprob.compositional-test 2 | (:require [clojure.test :refer [deftest is testing]] 3 | [metaprob.trace :as trace] 4 | [metaprob.generative-functions :as gen :refer [gen]] 5 | [metaprob.distributions :as dist] 6 | [metaprob.prelude :as pre]) 7 | (:refer-clojure :exclude [assoc dissoc])) 8 | 9 | (defn ez-call [prob-prog & inputs] 10 | (let [inputs (if (= inputs nil) '() inputs) 11 | [value _ _] 12 | (pre/infer-and-score :procedure prob-prog :inputs inputs)] 13 | value)) 14 | 15 | (deftest apply-2 16 | (testing "Apply a procedure to one arg" 17 | (is (= (ez-call - 7) 18 | -7)))) 19 | 20 | (deftest thunk-1 21 | (testing "call a thunk" 22 | (is (= (ez-call (gen [] 7)) 23 | 7)))) 24 | 25 | ;; N.b. this will reify the procedure to get stuff to eval 26 | 27 | (deftest binding-1 28 | (testing "Bind a variable locally to a value (apply)" 29 | (is (= (ez-call (gen [x] x) 5) 30 | 5)))) 31 | 32 | (deftest binding-3 33 | (testing "Bind a variable locally to a value (apply)" 34 | (is (= (ez-call (gen [] (let [x 17] x))) 35 | 17)))) 36 | 37 | (deftest n-ary-1 38 | (testing "n-ary procedure, formal parameter list is [& y]" 39 | (is (= (first (ez-call (gen [& y] y) 40 | 8 9)) 41 | 8)))) 42 | 43 | (deftest n-ary-2 44 | (testing "n-ary procedure, & in formal parameter list" 45 | (let [result (ez-call (gen [x & y] y) 46 | 7 8 9 10)] 47 | (is (seq? 
result)) 48 | (is (= (count result) 3)) 49 | (is (= (first result) 8))))) 50 | 51 | 52 | (deftest traced-and-scored-execution 53 | (testing "traced and scored execution" 54 | (let [f (gen [p] (at "x" dist/flip p)) 55 | p 0.4 56 | [v1 t1 s1] (pre/infer-and-score :procedure f :inputs [p]) 57 | [v2 t2 s2] (pre/infer-and-score :procedure f :inputs [p] :observation-trace t1)] 58 | 59 | (is (boolean? (f 0.5))) 60 | (is (true? (f 1))) 61 | (is (false? (f 0))) 62 | (is (= '(("x")) (trace/addresses-of t1))) 63 | (is (not (trace/trace-has-value? t1))) 64 | (is (= v1 (trace/trace-value t1 "x"))) 65 | (is (= s1 0)) 66 | (is (= s2 (if v1 (pre/log p) (pre/log (- 1 p))))) 67 | (is (= t1 t2)) 68 | (is (= v1 v2))))) 69 | 70 | (deftest control-flow 71 | (testing "infer-and-score with weird control flow" 72 | (let [bar (gen [mu] (at "a" dist/gaussian mu 1)) 73 | baz (gen [mu] (at "b" dist/gaussian mu 1)) 74 | foo 75 | (gen [mu] 76 | (if (at "branch" dist/flip 0.4) 77 | (do (at "x" dist/gaussian mu 1) 78 | (at "u" bar mu)) 79 | (do (at "y" dist/gaussian mu 1) 80 | (at "v" baz mu)))) 81 | 82 | mu 83 | 0.123 84 | 85 | [_ first-branch-trace _] 86 | (pre/infer-and-score :procedure foo :inputs [mu] :observation-trace {"branch" {:value true}}) 87 | 88 | x 89 | (trace/trace-value first-branch-trace "x") 90 | 91 | a 92 | (trace/trace-value first-branch-trace '("u" "a")) 93 | 94 | [_ fixed-choices] 95 | (trace/partition-trace first-branch-trace '(("branch")))] 96 | 97 | (loop [i 0] 98 | (when (< i 10) 99 | (let [[v t s] (pre/infer-and-score :procedure foo :inputs [mu] :observation-trace fixed-choices)] 100 | (if (trace/trace-value t "branch") 101 | (do (is (= t first-branch-trace)) 102 | (is (not= s 0))) 103 | (do (is (trace/trace-has-value? t "y")) 104 | (is (trace/trace-has-value? 
t '("v" "b"))) 105 | (is (= s 0)) 106 | (is (= 3 (count (trace/addresses-of t)))))) 107 | (recur (inc i)))))))) 108 | 109 | (deftest self-execution 110 | (testing "running infer-and-score on infer-and-score" 111 | (let [f (gen [] (and (at 1 dist/flip 0.1) (at 2 dist/flip 0.4))) 112 | [[inner-v inner-t inner-s] t s] 113 | (pre/infer-and-score 114 | :procedure pre/infer-and-score 115 | :inputs [:procedure f, :observation-trace {2 {:value true}}] 116 | :observation-trace {1 {:value true}})] 117 | (is (= (count (trace/addresses-of inner-t)) 2)) 118 | (is (not (trace/trace-has-value? t 2))) 119 | (is (= s (pre/log 0.1))) 120 | (is (= inner-s (pre/log 0.4))) 121 | (is inner-v)))) 122 | 123 | 124 | ;;; `case` expands to use clojure-internal `case*`, which can't work in 125 | ;;; metaprob until we implement that (or manually expand `case`). Until 126 | ;;; we do that, I'm going to leave this commented-out and file an 127 | ;;; issue (jmt) 128 | #_ 129 | (deftest case-1 130 | (testing "case smoke test" 131 | (is (= (ez-eval (mp-expand '(case 1 2))) 2)) 132 | (is (= (ez-eval (mp-expand '(case 1 1 2))) 2)) 133 | (is (= (ez-eval (mp-expand '(case 1 1 2 3))) 2)) 134 | (is (= (ez-eval (mp-expand '(case 1 2 3 1 4))) 4)))) 135 | 136 | #_ 137 | (define tst1 138 | (gen [] 139 | (define x (if (distributions/flip 0.5) 0 1)) 140 | (+ x 3))) 141 | 142 | #_ 143 | (deftest intervene-4 144 | (testing "intervention value is recorded when it overwrites normally-traced execution" 145 | (let [intervene (trace-set-value {} '(0 "x") 5) 146 | [value out __prefix__] (comp/infer-apply 147 | tst1 [] 148 | {:interpretation-id (clojure.core/gensym) 149 | :intervene intervene 150 | :target no-trace 151 | :active? 
true})] 152 | (is (= value 8))))) 153 | 154 | ;;; in situations where an intervention targets a non-random site, 155 | ;;; `infer`'s _value_ should be affected, but the returned trace should 156 | ;;; still contain the (unused) random choices 157 | #_ 158 | (deftest intervene-target-disagree 159 | (testing "intervention and target traces disagree" 160 | (let [intervene (trace-set-value {} '(0 "x") 5) 161 | target (trace-set-value {} '(0 "x") 5) 162 | [value out s] (comp/infer-apply tst1 [] 163 | {:interpretation-id (clojure.core/gensym) 164 | :intervene intervene 165 | :target target 166 | :active? true})] 167 | (is (= value 8)) 168 | (is (= s 0)) 169 | (trace-has-value? out '(0 "x" "predicate" "distributions/flip")) 170 | (is (and (trace-has-value? out '(0 "x" "predicate" "distributions/flip")) 171 | (clojure.core/contains? 172 | #{true false} 173 | (trace-value out '(0 "x" "predicate" "distributions/flip")))))))) 174 | 175 | 176 | ;;; an assert keeps this from working. if that's expected, this test 177 | ;;; should change to catch that AssertionError (jmt) 178 | #_ 179 | (deftest intervene-target-disagree 180 | (testing "intervention and target traces disagree. this should throw." 181 | (let [intervene (trace-set-value {} '(0 "x") 6) 182 | target (trace-set-value {} '(0 "x") 5)] 183 | 184 | (is (thrown? AssertionError 185 | (comp/infer-apply tst1 [] 186 | {:interpretation-id (clojure.core/gensym) 187 | :intervene intervene 188 | :target target 189 | :active? true})))))) 190 | 191 | ;;; an assert keeps this from happening. if that's expected, this test 192 | ;;; should change to catch that AssertionError (jmt) 193 | #_ 194 | (deftest impossible-target 195 | (testing "target is impossible value" 196 | (let [target (trace-set-value {} '(0 "x") 5)] 197 | 198 | (is (thrown? AssertionError 199 | (comp/infer-apply tst1 [] 200 | {:interpretation-id (clojure.core/gensym) 201 | :intervene {} 202 | :target target 203 | :active? 
true})))))) 204 | 205 | #_ 206 | (deftest true-target 207 | (testing "target value is the true value" 208 | (let [form '(block (define x (+ 15 2)) (- x 2)) 209 | target (trace-set-value {} '(0 "x" "+") 17) 210 | [value out s] (comp/infer-eval form top 211 | {:interpretation-id (clojure.core/gensym) 212 | :intervene no-trace 213 | :target no-trace 214 | :active? true})] 215 | (is (= value 15)) 216 | (is (= s 0)) 217 | (is (empty? out))))) 218 | 219 | ;;; Self-application 220 | 221 | ;;; These have to be defined at top level because only top level 222 | ;;; defined gens can be interpreted (due to inability to understand 223 | ;;; environments). 224 | 225 | ;;; TODO: Explore what it would take to remove this limitation. 226 | #_ 227 | (define apply-test 228 | (gen [thunk] 229 | (define [val output score] 230 | (infer-apply thunk [] {:interpretation-id (clojure.core/gensym) 231 | :intervene no-trace 232 | :target no-trace 233 | :active? true})) 234 | output)) 235 | 236 | #_ 237 | (define tst2 (gen [] (distributions/flip 0.5))) 238 | #_ 239 | (define tst3 (gen [] (apply-test tst2))) 240 | 241 | #_ 242 | (deftest infer-apply-self-application 243 | (testing "apply infer-apply to program that calls infer-apply" 244 | (binding [*ambient-interpreter* infer-apply] 245 | ;; When we interpret tst1 directly, the value of flip is 246 | ;; recorded at the length-1 address '(distributions/flip). 247 | (is (= (count (first (addresses-of (apply-test tst2)))) 1)) 248 | 249 | ;; But when we trace the execution of the interpreter, the 250 | ;; address at which the random choice is recorded is 251 | ;; significantly longer, due to the complex chain of function 252 | ;; calls initiated by the interpreter. 
;; Smoke test for the score reported for dist/uniform on [0, 1].
;; `get-score` samples once, then re-runs the procedure with the sampled
;; trace as the observation trace and returns the resulting score.
(deftest uniform-score-1
  ;; The description previously read "flip score smoke test" (copied from
  ;; flip-score-1), which mislabelled failures of this uniform test.
  (testing "uniform score smoke test"
    (let [score (get-score dist/uniform 0 1)]
      ;; The score itself is used as the failure message for easier debugging.
      (is (number? score) score)
      ;; Lower bound only: the score must be close to (or above) zero.
      (is (> score -0.1)))))
;; Checks that dist/log-categorical, given a sequence of log-probabilities,
;; samples each index with (approximately) the corresponding probability.
(deftest log-categorical-1
  (testing "log-categorical"
    (let [;; log of a probability, with an exact zero mapped to -infinity
          log-score (fn [p]
                      (if (= p 0)
                        prelude/negative-infinity
                        (prelude/log p)))
          weights (range 10)
          probabilities (normalize weights)
          scores (map log-score probabilities)
          sampler (fn [] (dist/log-categorical scores))
          ;; seq of [value probability] pairs, as test-generator expects
          expected (clojure.core/map vector weights probabilities)]
      (is (test-generator sampler expected 100000)))))
;; Returns its argument unchanged.
;; NOTE(review): presumably maps a datum index to its trace address (the
;; tests below address flips by integer index) -- confirm; no caller is
;; visible in this file.
(defn datum-addr
  [n]
  (identity n))
;; Smoke test: the example's target density must be non-negligible at x = 1.
(deftest target-density-1
  ;; The description previously read "checking prior density" (copied from
  ;; prior-density-1), which mislabelled failures of this test.
  (testing "checking target density"
    (is (> (target-density 1) 0.01))))
(defn target-density
  "Density at `x` obtained by exponentiating the Gaussian score for
  parameters [1.5, 1/sqrt(2)].
  NOTE(review): presumably the exact posterior of `x` in `normal-normal`
  given y = 3, with the second parameter a standard deviation -- confirm
  against dist/score-gaussian."
  [x]
  (let [mean 1.5
        std-dev (/ 1.0 (prelude/sqrt 2.0))
        log-density (dist/score-gaussian x [mean std-dev])]
    (prelude/exp log-density)))
;; Smoke test for `apply`: argument sequences may be vectors or lists, and
;; `apply` can itself be applied.
(deftest apply-1
  (testing "apply smoke test"
    ;; Same subtraction, once with a vector and once with a list of arguments.
    (doseq [args [[3 2] (list 3 2)]]
      (is (= 1 (apply - args))))
    ;; apply applied to apply
    (is (= 1 (apply apply (list - (list 3 2)))))))
;; Smoke test for prelude/map applied to a generative function.
(deftest map-1
  (testing "map smoke test"
    ;; Compare explicitly. The earlier form `(is (nth ...) 6)` passed 6 as
    ;; the failure *message*, so it succeeded for any truthy element.
    (is (= (nth (prelude/map (gen [x] (+ x 1))
                             (list 4 5 6))
                1)
           6))
    ;; These tests have to run after the call to map
    #?(:clj (is (= (ns-resolve 'metaprob.prelude 'value) nil)
                "namespacing sanity check 1"))
    #?(:clj (is (not (contains? (ns-publics 'metaprob.prelude) 'value))
                "namespacing sanity check 2"))))

;; I'm sort of tired of this and don't anticipate problems, so
;; not putting more work into tests at this time.

;; prelude/map with an ordinary Clojure function over a Clojure list:
;; checks the length and every element of the result.
(deftest map-1a
  (testing "Map over a clojure list"
    (let [start (list 6 7 8)
          foo (prelude/map (fn [x] (+ x 1))
                           start)]
      ;; Was `(is (count foo) 3)`, which only checked that the count is
      ;; truthy (always the case) and used 3 as the message.
      (is (= (count foo) 3))
      (is (= (nth foo 0) 7))
      (is (= (nth foo 1) 8))
      (is (= (nth foo 2) 9)))))
;; Entry point of the ClojureScript test runner: runs the six metaprob test
;; namespaces listed below through `test/run-tests`.
;; `args` is accepted but not used.
;; NOTE(review): the namespaces are passed as literal quoted symbols; the ns
;; form requires cljs.test with `:include-macros true`, so `run-tests` is
;; presumably a macro here -- keep the symbols literal and confirm before
;; building the list dynamically.
(defn -main
  [& args]
  (test/run-tests 'metaprob.compositional-test
                  'metaprob.distributions-test
                  'metaprob.inference-test
                  'metaprob.prelude-test
                  'metaprob.trace-test
                  'metaprob.syntax-test))
;; An empty map used as a trace has neither a subtrace at "a" nor a value.
(deftest empty-as-trace
  (testing "see how well empty seq serves as a trace"
    (let [empty-trace {}
          has-subtrace-a? (trace/trace-has-subtrace? empty-trace "a")
          has-value? (trace/trace-has-value? empty-trace)]
      (is (not has-subtrace-a?))
      (is (not has-value?)))))
;; Every value stored in the trace must be reachable through one of the
;; addresses returned by trace/addresses-of. The value at "c" is itself a
;; map; it is stored under :value, so it is looked up as a plain value.
(deftest addresses-of-2
  (testing "addresses-of (addresses-of)"
    (let [tr {"a" {:value 17}
              "b" {:value 31}
              "c" {:value {"d" {:value 71}}}}
          ;; sites (sequence/sequence-to-seq (addresses-of tr))
          sites (trace/addresses-of tr)
          ;; Named `values` so that clojure.core/vals is not shadowed.
          values (map (fn [site] (trace/trace-value tr site)) sites)
          has? (fn [val] (some (fn [x] (= x val)) values))]
      ;; The has? calls were previously bare expressions, so the test could
      ;; never fail; wrap them in `is` so they are real assertions.
      (is (has? 17))
      (is (has? 31))
      (is (has? {"d" {:value 71}})))))
(defn plot-trace
  "Renders `trace-json` (converted with trace-as-json and serialised with
  json/write-str) as an inline visualisation: a div holding an svg element
  with a fresh random id plus a script that calls the JavaScript function
  `drawTrace` on that id.

  Arities: no size -> 600 x 600; one size `s` -> s x s; otherwise width `w`
  and height `h`. The svg is 100 wider than `w`, and the enclosing div is
  50 taller than `h`.
  NOTE(review): `drawTrace` receives the height before the width -- confirm
  this matches its JavaScript signature."
  ([trace-json]
   (plot-trace trace-json 600 600))
  ([trace-json s]
   (plot-trace trace-json s s))
  ([trace-json w h]
   (let [id (str "svg" (java.util.UUID/randomUUID))
         payload (json/write-str (trace-as-json trace-json))
         code (format "drawTrace(\"%s\", %s, %d, %d);" id payload h w)
         div-style (format "height: %d; width: %d" (+ h 50) (+ w 100))
         svg [:svg {:id id :width (+ w 100) :height h}]]
     (display/hiccup-html
       [:div {:style div-style}
        svg
        [:script code]]))))
(defn lin-range
  "Returns a seq of `n-intervals` evenly spaced numbers running from `low`
  to `high`, both inclusive. Despite its name, `n-intervals` is the number
  of points produced, so the spacing is (high - low) / (n-intervals - 1).

  A request for a single point returns just `low`; previously that case
  divided by zero. Zero (or fewer) points yields an empty seq."
  [low high n-intervals]
  (if (= n-intervals 1)
    ;; One point has no spacing to compute; avoid dividing by (- 1 1).
    (list low)
    ;; The step does not depend on the index, so compute it once.
    (let [step (/ (- high low) (- n-intervals 1))]
      (map (fn [i] (+ low (* i step)))
           (range n-intervals)))))
(defn scatter-trace
  "Builds a Plotly trace (not a Metaprob trace) for a marker-only scatter
  plot. `data-points` is a seq of [x y] pairs; `marker-symbol` and `color`
  style the markers. The trace always carries the id \"samples\" and has
  hover info switched off."
  [data-points marker-symbol color]
  (let [xs (map first data-points)
        ys (map second data-points)
        marker {:symbol marker-symbol :color color}]
    (assoc {:mode "markers"
            :type "scatter"
            :hoverinfo "none"
            :id "samples"}
           :x xs
           :y ys
           :marker marker)))
; Build an animated Plotly scatter plot that steps through `data`, a
; sequence of points, one frame (or two, see below) per point, and hand it
; to plotly-chart-2 together with a "Play" button.
;
; Each point is called as a function -- (next 0) is its x, (next 1) its y --
; so points must be vectors (or otherwise callable with an index).
;
; Frame construction:
;  * the reduce starts from a single black frame for the first point and
;    conses newer frames onto the FRONT of the list, so `l` is newest-first
;    and (first l) is the most recently added frame;
;  * when a point has the same x and y as that most recent frame, two
;    identical red frames are added; otherwise one black frame is added
;    (NOTE(review): a repeated point presumably marks a rejected proposal --
;    confirm with the caller);
;  * map-indexed names the frames while the list is still newest-first, and
;    the final reverse restores chronological order, so the frame names
;    count down ("0" is the last frame).
;
; Note that the local `next` shadows clojure.core/next inside the reducer.
(defn mh-animation
  [title data [x-range y-range]]
  (let
    [
     ; the initial plot shows only the first point
     init-data [(scatter-trace [(first data)] "cross" "black")]
     frames (vec (reverse (map-indexed (fn [i x] {:name (str i) :data [x]}) (reduce
       (fn [l next]
         (if (and (= (((first l) :y) 0) (next 1))
                  (= (((first l) :x) 0) (next 0)))
           ; same position as the previous frame: two red frames
           (cons
             {:x [(next 0)] :y [(next 1)] :marker {:color "red"}}
             (cons {:x [(next 0)] :y [(next 1)] :marker {:color "red"}}
               l))
           ; new position: one black frame
           (cons {:x [(next 0)] :y [(next 1)] :marker {:color "black"}}
             l)))
       (list {:x [((first data) 0)] :y [((first data) 1)] :marker {:color "black"}})
       (rest data)))))
     layout {:xaxis {:range x-range :title "x"}
             :yaxis {:range y-range :title "y"}
             :width 500 :height 500
             :title title
             ; a single "Play" button that triggers Plotly's "animate" method
             :updatemenus
             [{ :type "buttons" :xanchor "left" :yanchor "top" :pad {:t 50}
                :x 0 :y 0
                :buttons [{:method "animate" :args [nil {:mode "immediate" :transition {:duration 60} :frame {:duration 50 :redraw false}}]
                           :label "Play"}] }]}
     config {:displayModeBar false}
     ]
    (plotly-chart-2 {:data init-data :frames frames :layout layout :config config})))
data density [x-range y-range]] 215 | ; generate contour data 216 | (plotly-chart [(contour-trace density [x-range y-range]), 217 | (scatter-trace data "cross" "white")] 218 | {:xaxis {:range x-range :title "x"} 219 | :yaxis {:range y-range :title "y"} 220 | :width 500 221 | :height 500 222 | :title title} 223 | {:displayModeBar false})) --------------------------------------------------------------------------------