├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── GUIDE.md ├── LICENSE.txt ├── README.md ├── check.sh ├── examples ├── .gitignore ├── clickgraph.ipynb ├── clinical-trial.ipynb ├── disease-infection.ipynb ├── election.ipynb ├── fairness-hiring-model-1.ipynb ├── fairness-hiring-model-2.ipynb ├── fairness-income-model.ipynb ├── generate.sh ├── hierarchcial-markov-switching.ipynb ├── indian-gpa.ipynb ├── piecewise-transformation.ipynb ├── poisson-mean-inference.ipynb ├── random-sequence.ipynb ├── robot-localization.ipynb ├── simple-mixture-model.ipynb ├── student-interviews.ipynb ├── trueskill-poisson-binomial.ipynb └── two-dimensional-mixture-model.ipynb ├── magics ├── __init__.py ├── magics.py └── render.py ├── overview.png ├── pythenv.sh ├── requirements.sh ├── setup.py ├── sppl.png ├── src ├── __init__.py ├── compilers │ ├── __init__.py │ ├── ast_to_spe.py │ ├── spe_to_dict.py │ ├── spe_to_sppl.py │ └── sppl_to_python.py ├── distributions.py ├── dnf.py ├── math_util.py ├── poly.py ├── render.py ├── sets.py ├── spe.py ├── sym_util.py ├── timeout.py └── transforms.py └── tests ├── .gitignore ├── __init__.py ├── test_ast_condition.py ├── test_ast_for.py ├── test_ast_ifelse.py ├── test_ast_switch.py ├── test_ast_transform.py ├── test_burglary.py ├── test_cache_duplicate_subtrees.py ├── test_clickgraph.py ├── test_dnf.py ├── test_event_evaluate.py ├── test_indian_gpa.py ├── test_logpdf.py ├── test_mutual_information.py ├── test_nominal_distribution.py ├── test_parse_distributions.py ├── test_parse_spe.py ├── test_parse_transforms.py ├── test_poly.py ├── test_product.py ├── test_real_continuous.py ├── test_real_discrete.py ├── test_render.py ├── test_sets.py ├── test_solve_transforms.py ├── test_spe_to_dict.py ├── test_spe_to_spml.py ├── test_spe_transform.py ├── test_sppl_to_python.py ├── test_substitute.py ├── test_sum.py ├── test_sum_simplify.py └── test_sym_util.py /.github/workflows/python-package.yml: 
-------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | workflow_dispatch: 12 | 13 | jobs: 14 | build: 15 | 16 | runs-on: ubuntu-20.04 17 | strategy: 18 | matrix: 19 | python-version: [3.8, 3.9] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Test src 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install '.[tests]' 31 | ./check.sh ci 32 | - name: Test magics 33 | run: | 34 | sudo sh requirements.sh 35 | pip install '.[all]' 36 | ./check.sh -k '__magics_' 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .dockerignore 2 | .cache/ 3 | .coverage 4 | *.egg 5 | *.egg-info 6 | __pycache__/ 7 | build/ 8 | clean.sh 9 | dist/ 10 | egg-info 11 | htmlcov/ 12 | MANIFEST 13 | -------------------------------------------------------------------------------- /GUIDE.md: -------------------------------------------------------------------------------- 1 | System Overview 2 | =============== 3 | 4 | An overview of the SPPL architecture is shown below. 
5 | For further details, please refer to the [system description](https://doi.org/10.1145/3453483.3454078). 6 | 7 | 8 | 9 | Probabilistic programs written in SPPL are translated into symbolic sum-product expressions 10 | that represent the joint distribution over all program variables and are used to deliver 11 | exact solutions to probabilistic inference queries. 12 | 13 | Guide to Source Files 14 | ===================== 15 | 16 | The table below describes the main source files that make up SPPL. 17 | 18 | | Filename | Description | 19 | | -------- | ----------- | 20 | | [`src/distributions.py`](src/distributions.py) | Wrappers for discrete and continuous probability distributions from [scipy.stats](https://docs.scipy.org/doc/scipy/reference/stats.html), making them available as modeling primitives in SPPL. | 21 | | [`src/dnf.py`](src/dnf.py) | Event preprocessing algorithms, which include converting events to disjunctive normal form, factoring variables in events, and writing an event as a disjoint union of conjunctions. | 22 | | [`src/math_util.py`](src/math_util.py) | Various utilities for mathematical routines. | 23 | | [`src/poly.py`](src/poly.py) | Semi-symbolic solvers for equalities and inequalities involving univariate polynomials with real coefficients. | 24 | | [`src/render.py`](src/render.py) | Renders a sum-product expression as a nested Python list, ideal for use with pprint. | 25 | | [`src/sets.py`](src/sets.py) | Type system and utilities for set theoretic operations including finite nominals, finite reals, and real intervals. | 26 | | [`src/spe.py`](src/spe.py) | Main module implementing the sum-product expressions, including the sum and product combinators and various leaf primitives. | 27 | | [`src/sym_util.py`](src/sym_util.py) | Various utilities for operating on sets and symbolic variables. | 28 | | [`src/timeout.py`](src/timeout.py) | Python context for enforcing a time limit on a block of code. 
| 29 | | [`src/transforms.py`](src/transforms.py) | Main module implementing (i) numerical transformations on symbolic variables, such as absolute values, logarithms, exponentials, polynomials, piecewise transformations, and (ii) logical transformations, which include conjunctions, disjunctions, and negations of primitive events (predicates). | 30 | | [`src/compilers/ast_to_spe.py`](src/compilers/ast_to_spe.py) | Translates an SPPL abstract syntax tree to a sum-product expression. | 31 | | [`src/compilers/spe_to_dict.py`](src/compilers/spe_to_dict.py) | Converts a sum-product expression to a Python dictionary. | 32 | | [`src/compilers/spe_to_sppl.py`](src/compilers/spe_to_sppl.py) | Translates a sum-product expression to an SPPL program. | 33 | | [`src/compilers/sppl_to_python.py`](src/compilers/sppl_to_python.py) | Translates SPPL source code to Python source code that contains the original program abstract syntax tree. | 34 | | [`magics/magics.py`](magics/magics.py) | Provides magics for using SPPL through IPython notebooks (see [examples/](./examples)). | 35 | | [`magics/render.py`](magics/render.py) | Renders an SPE as networkx and graphviz. | 36 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Actions Status](https://github.com/probcomp/sppl/workflows/Python%20package/badge.svg)](https://github.com/probcomp/sppl/actions) 2 | [![pypi](https://img.shields.io/pypi/v/sppl.svg)](https://pypi.org/project/sppl/) 3 | 4 | Sum-Product Probabilistic Language 5 | ================================== 6 | 7 | 8 | 9 | SPPL is a probabilistic programming language that delivers exact solutions 10 | to a broad range of probabilistic inference queries. The language handles 11 | continuous, discrete, and mixed-type probability distributions; many-to-one 12 | numerical transformations; and a query language that includes general 13 | predicates on random variables. 14 | 15 | Users express generative models as probabilistic programs with standard 16 | imperative constructs, such as arrays, if/else branches, for loops, etc. 
17 | The program is then translated to a sum-product expression (a 18 | generalization of [sum-product networks](https://arxiv.org/pdf/2004.01167.pdf)) 19 | that statically represents the probability distribution of all random 20 | variables in the program. This expression is used to deliver answers to 21 | probabilistic inference queries. 22 | 23 | A system description of SPPL is given in the following paper: 24 | 25 | SPPL: Probabilistic Programming with Fast Exact Symbolic Inference. 26 | Saad, F. A.; Rinard, M. C.; and Mansinghka, V. K. 27 | In PLDI 2021: Proceedings of the 42nd ACM SIGPLAN International Conference 28 | on Programming Language Design and Implementation, 29 | June 20-25, Virtual, Canada. ACM, New York, NY, USA. 2021. 30 | https://doi.org/10.1145/3453483.3454078. 31 | 32 | ### Installation 33 | 34 | This software is tested on Ubuntu 20.04 and Python 3.8. 35 | SPPL is available on the PyPI repository 36 | 37 | $ python -m pip install sppl 38 | 39 | To install the Jupyter interface, first obtain the system-wide dependencies in 40 | [requirements.sh](https://github.com/probcomp/sppl/blob/master/requirements.sh) 41 | and then run 42 | 43 | $ python -m pip install 'sppl[magics]' 44 | 45 | ### Examples 46 | 47 | The easiest way to use SPPL is via the browser-based Jupyter interface, which 48 | allows for interactive modeling, querying, and plotting. 49 | Refer to the `.ipynb` notebooks under the 50 | [examples](https://github.com/probcomp/sppl/tree/master/examples) directory. 51 | 52 | ### Benchmarks 53 | 54 | Please refer to the artifact at the ACM Digital Library: 55 | https://doi.org/10.1145/3453483.3454078 56 | 57 | ### Guide to Source Code 58 | 59 | Please refer to [GUIDE.md](./GUIDE.md) for a description of the 60 | main source files in this repository. 
61 | 62 | ### Tests 63 | 64 | To run the test suite as a user, first install the test dependencies: 65 | 66 | $ python -m pip install 'sppl[tests]' 67 | 68 | Then run the test suite: 69 | 70 | $ python -m pytest --pyargs sppl 71 | 72 | To run the test suite as a developer: 73 | 74 | - To run crash tests: `$ ./check.sh` 75 | - To run integration tests: `$ ./check.sh ci` 76 | - To run a specific test: `$ ./check.sh [pytest-options] /path/to/test.py` 77 | - To run the examples: `$ ./check.sh examples` 78 | - To build a docker image: `$ ./check.sh docker` 79 | - To generate a coverage report: `$ ./check.sh coverage` 80 | 81 | To view the coverage report, open `htmlcov/index.html` in the browser. 82 | 83 | ### Language Reference 84 | 85 | Coming Soon! 86 | 87 | ### Citation 88 | 89 | To cite this work, please use the following BibTeX. 90 | 91 | ```bibtex 92 | @inproceedings{saad2021sppl, 93 | title = {{SPPL:} Probabilistic Programming with Fast Exact Symbolic Inference}, 94 | author = {Saad, Feras A. and Rinard, Martin C. and Mansinghka, Vikash K.}, 95 | booktitle = {PLDI 2021: Proceedings of the 42nd ACM SIGPLAN International Conference on Programming Language Design and Implementation}, 96 | pages = {804--819}, 97 | year = 2021, 98 | location = {Virtual, Canada}, 99 | publisher = {ACM}, 100 | address = {New York, NY, USA}, 101 | doi = {10.1145/3453483.3454078}, 102 | address = {New York, NY, USA}, 103 | keywords = {probabilistic programming, symbolic execution, static analysis}, 104 | } 105 | ``` 106 | 107 | ### License 108 | 109 | Apache 2.0; see [LICENSE.txt](./LICENSE.txt) 110 | 111 | ### Acknowledgments 112 | 113 | The [logo](https://github.com/probcomp/sppl/blob/master/sppl.png) was 114 | designed by [McCoy R. Becker](https://femtomc.github.io/). 
115 | -------------------------------------------------------------------------------- /check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -Ceux 4 | 5 | : ${PYTHON:=python} 6 | 7 | root=`cd -- "$(dirname -- "$0")" && pwd` 8 | 9 | ( 10 | set -Ceu 11 | cd -- "${root}" 12 | rm -rf build 13 | "$PYTHON" setup.py build 14 | 15 | # (Default) Run tests not marked __ci__ 16 | if [ $# -eq 0 ] || [ ${1} = 'crash' ]; then 17 | ./pythenv.sh "$PYTHON" -m pytest -k 'not __ci_' --pyargs sppl 18 | 19 | # Run all tests under tests/ 20 | elif [ ${1} = 'ci' ]; then 21 | ./pythenv.sh "$PYTHON" -m pytest --pyargs sppl 22 | 23 | # Generate coverage report. 24 | elif [ ${1} = 'coverage' ]; then 25 | ./pythenv.sh coverage run --source=build/ -m pytest --pyargs sppl 26 | coverage html 27 | coverage report 28 | 29 | # Run the .ipynb notebooks under examples/ 30 | elif [ ${1} = 'examples' ]; then 31 | cd -- examples/ 32 | ./generate.sh 33 | cd -- "${root}" 34 | 35 | # Make a tagged release. 36 | elif [ ${1} = 'tag' ]; then 37 | tag="${2}" 38 | (git diff --quiet --stat && git diff --quiet --staged) \ 39 | || (echo 'fatal: workspace dirty' && exit 1) 40 | git show-ref --quiet --tags v"${tag}" \ 41 | && (echo 'fatal: tag exists' && exit 1) 42 | sed -i "s/__version__ = .*/__version__ = '${tag}'/g" -- src/__init__.py 43 | git add -- src/__init__.py 44 | git commit -m "Pin version ${tag}." 45 | git tag -a -m v"${tag}" v"${tag}" 46 | 47 | # Send release to PyPI. 48 | elif [ ${1} = 'release' ]; then 49 | rm -rf dist 50 | "$PYTHON" setup.py sdist bdist_wheel 51 | twine upload --repository pypi dist/* 52 | 53 | # If args are specified delegate control to user. 
54 | else 55 | ./pythenv.sh "$PYTHON" -m pytest "$@" 56 | 57 | fi 58 | ) 59 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | -------------------------------------------------------------------------------- /examples/clinical-trial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext sppl.magics" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%%sppl model\n", 19 | "\n", 20 | "n = 20\n", 21 | "k = 20\n", 22 | "isEffective ~= bernoulli(p=.5)\n", 23 | "probControl ~= randint(low=0, high=k)\n", 24 | "probTreated ~= randint(low=0, high=k)\n", 25 | "probAll ~= randint(low=0, high=k)\n", 26 | "\n", 27 | "controlGroup = array(n)\n", 28 | "treatedGroup = array(n)\n", 29 | "\n", 30 | "if (isEffective == 1):\n", 31 | " for i in range(n):\n", 32 | " switch (probControl) cases (p in range(k)):\n", 33 | " controlGroup[i] ~= bernoulli(p=p/k)\n", 34 | " switch (probTreated) cases (p in range(k)):\n", 35 | " treatedGroup[i] ~= bernoulli(p=p/k)\n", 36 | "else:\n", 37 | " for i in range(n):\n", 38 | " switch (probAll) cases (p in range(k)):\n", 39 | " controlGroup[i] ~= bernoulli(p=p/k)\n", 40 | " treatedGroup[i] ~= bernoulli(p=p/k)" 41 | ] 42 | } 43 | ], 44 | "metadata": { 45 | "kernelspec": { 46 | "display_name": "Python 3", 47 | "language": "python", 48 | "name": "python3" 49 | }, 50 | "language_info": { 51 | "codemirror_mode": { 52 | "name": "ipython", 53 | "version": 3 54 | }, 55 | "file_extension": ".py", 56 | "mimetype": "text/x-python", 57 | "name": "python", 58 | "nbconvert_exporter": "python", 59 | "pygments_lexer": "ipython3", 60 | "version": "3.6.9" 61 | 
} 62 | }, 63 | "nbformat": 4, 64 | "nbformat_minor": 4 65 | } 66 | -------------------------------------------------------------------------------- /examples/fairness-hiring-model-1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext sppl.magics\n", 10 | "%matplotlib inline\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import numpy as np" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": { 19 | "scrolled": false 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "%%sppl model\n", 24 | "\n", 25 | "# Prior on ethnicity, experience, and college rank.\n", 26 | "ethnicity ~= choice({'minority': .15, 'majority': .85})\n", 27 | "years_experience ~= binom(n=15, p=.5)\n", 28 | "if (ethnicity == 'minority'):\n", 29 | " college_rank ~= dlaplace(loc=25, a=1/5)\n", 30 | "else:\n", 31 | " college_rank ~= dlaplace(loc=20, a=1/5)\n", 32 | "\n", 33 | "# Top 50 colleges and at most 20 years of experience.\n", 34 | "condition((0 <= college_rank) <= 50)\n", 35 | "condition((0 <= years_experience) <= 20)\n", 36 | "\n", 37 | "# Hiring decision.\n", 38 | "if college_rank <= 5:\n", 39 | " hire ~= atomic(loc=1)\n", 40 | "else:\n", 41 | " switch (years_experience) cases (years in range(0, 20)):\n", 42 | " if ((years - 5) > college_rank):\n", 43 | " hire ~= atomic(loc=1)\n", 44 | " else:\n", 45 | " hire ~= atomic(loc=0)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "n = %sppl_get_namespace model\n", 55 | "model = n.model" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "minority = n.ethnicity << {'minority'}\n", 65 | "model_c0 = model.condition(minority)\n", 66 | "model_c1 = 
model.condition(~minority)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 5, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "p_hire_given_minority = model_c0.prob(n.hire << {1})\n", 76 | "p_hire_given_majority = model_c1.prob(n.hire << {1})" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 6, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "0.3666600508668959\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "print(p_hire_given_minority / p_hire_given_majority)" 94 | ] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "Python 3", 100 | "language": "python", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.6.9" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 4 118 | } 119 | -------------------------------------------------------------------------------- /examples/fairness-hiring-model-2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext sppl.magics\n", 10 | "%matplotlib inline\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import numpy as np" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": { 19 | "scrolled": false 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "%%sppl model\n", 24 | "\n", 25 | "# Population model\n", 26 | "is_male ~= bernoulli(p=.5)\n", 27 | "college_rank ~= norm(loc=25, scale=10)\n", 28 | "if is_male == 1:\n", 29 | " years_exp ~= norm(loc=15, scale=5)\n", 30 | "else:\n", 31 | " 
years_exp ~= norm(loc=10, scale=5)\n", 32 | "\n", 33 | "# Hiring decision.\n", 34 | "if ((college_rank <= 5) | (years_exp > 5)):\n", 35 | " hire ~= atomic(loc=1)\n", 36 | "else:\n", 37 | " hire ~= atomic(loc=0)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "n = %sppl_get_namespace model" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "model_m = n.model.condition(n.is_male << {1})\n", 56 | "model_f = n.model.condition(n.is_male << {0})" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 5, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "p_m = model_m.prob(n.hire<<{1})\n", 66 | "p_f = model_f.prob(n.hire<<{1})" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 6, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "0.8641668176293485" 78 | ] 79 | }, 80 | "execution_count": 6, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "p_f/p_m" 87 | ] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 3", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.6.9" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 4 111 | } 112 | -------------------------------------------------------------------------------- /examples/fairness-income-model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | 
"outputs": [], 8 | "source": [ 9 | "%load_ext sppl.magics" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%%sppl model\n", 19 | "\n", 20 | "# Population model.\n", 21 | "sex ~= choice({'female': .3307, 'male': .6693})\n", 22 | "if (sex == 'female'):\n", 23 | " capital_gain ~= norm(loc=568.4105, scale=24248365.5428)\n", 24 | " if capital_gain < 7298.0000:\n", 25 | " age ~= norm(loc=38.4208, scale=184.9151)\n", 26 | " relationship ~= choice({\n", 27 | " '0': .0491, '1': .1556, '2': .4012,\n", 28 | " '3': .2589, '4': .0294, '5': .1058\n", 29 | " })\n", 30 | " else:\n", 31 | " age ~= norm(loc=38.8125, scale=193.4918)\n", 32 | " relationship ~= choice({\n", 33 | " '0': .0416, '1': .1667, '2': .4583,\n", 34 | " '3': .2292, '4': .0166, '5': .0876\n", 35 | " })\n", 36 | "else:\n", 37 | " capital_gain ~= norm(loc=1329.3700, scale=69327473.1006)\n", 38 | " if capital_gain < 5178.0000:\n", 39 | " age ~= norm(loc=38.6361, scale=187.2435)\n", 40 | " relationship ~= choice({\n", 41 | " '0': .0497, '1': .1545, '2': .4021,\n", 42 | " '3': .2590, '4': .0294, '5': .1053\n", 43 | " })\n", 44 | " else:\n", 45 | " age ~= norm(loc=38.2668, scale=187.2747)\n", 46 | " relationship ~= choice({\n", 47 | " '0': .0417, '1': .1624, '2': .3976,\n", 48 | " '3': .2606, '4': .0356, '5': .1021\n", 49 | " })\n", 50 | "\n", 51 | "condition(age > 18)\n", 52 | "\n", 53 | "# Decision model.\n", 54 | "if relationship == '1':\n", 55 | " if capital_gain < 5095.5:\n", 56 | " t ~= atomic(loc=1)\n", 57 | " else:\n", 58 | " t ~= atomic(loc=0)\n", 59 | "elif relationship == '2':\n", 60 | " if capital_gain < 4718.5:\n", 61 | " t ~= atomic(loc=1)\n", 62 | " else:\n", 63 | " t ~= atomic(loc=0)\n", 64 | "elif relationship == '3':\n", 65 | " if capital_gain < 5095.5:\n", 66 | " t ~= atomic(loc=1)\n", 67 | " else:\n", 68 | " t ~= atomic(loc=0)\n", 69 | "elif relationship == '4':\n", 70 | " if capital_gain < 8296:\n", 71 | " t 
#!/bin/sh

# Execute every notebook in the current directory and export it to HTML,
# replacing any previously generated export.

# Iterate over the glob directly instead of parsing `ls` output, so that
# notebook names containing spaces or glob characters are handled intact.
for x in *.ipynb; do
    # When no notebook matches, the pattern is left unexpanded; skip it.
    [ -e "${x}" ] || continue
    # Remove the stale export first so that old output never survives a
    # failed conversion.
    rm -rf -- "${x%.ipynb}.html"
    jupyter nbconvert --execute --to html "${x}"
done
-------------------------------------------------------------------------------- /examples/robot-localization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext sppl.magics\n", 10 | "%matplotlib inline" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%%sppl model\n", 20 | "param ~= randint(low=0, high=101)\n", 21 | "which ~= uniform(loc=0, scale=1)\n", 22 | "if (which < 0.9):\n", 23 | " switch (param) cases (p in range(0, 101)):\n", 24 | " x ~= norm(loc=p/10, scale=1)\n", 25 | "elif (which < 0.95):\n", 26 | " x ~= uniform(loc=0, scale=10)\n", 27 | "else:\n", 28 | " x ~= atomic(loc=10)" 29 | ] 30 | } 31 | ], 32 | "metadata": { 33 | "kernelspec": { 34 | "display_name": "Python 3", 35 | "language": "python", 36 | "name": "python3" 37 | }, 38 | "language_info": { 39 | "codemirror_mode": { 40 | "name": "ipython", 41 | "version": 3 42 | }, 43 | "file_extension": ".py", 44 | "mimetype": "text/x-python", 45 | "name": "python", 46 | "nbconvert_exporter": "python", 47 | "pygments_lexer": "ipython3", 48 | "version": "3.6.9" 49 | } 50 | }, 51 | "nbformat": 4, 52 | "nbformat_minor": 4 53 | } 54 | -------------------------------------------------------------------------------- /magics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
def load_ipython_extension(ipython):
    """Register the SPPL magics on the given IPython shell.

    Builds an SPPL_Magics instance bound to `ipython` and hands it to
    `ipython.register_magics`.
    """
    ipython.register_magics(SPPL_Magics(ipython))
# Imported here because this block is the only user of itertools.
import itertools

# Process-wide counter backing gensym(); never reset, so names stay unique
# across every graph rendered by this process.
_GENSYM_IDS = itertools.count()

def gensym():
    """Return a fresh node name of the form 'r<N>'.

    Names were previously derived from time.time(), so two calls made within
    the clock's resolution returned the same string.  Node names are the keys
    of the networkx graphs built in this module, so a collision silently
    merged two distinct nodes.  A monotonically increasing counter makes
    every returned name distinct.
    """
    return 'r%d' % (next(_GENSYM_IDS),)
def render_graphviz(spe, filename=None, ext=None, show=None):
    """Render an SPE with Graphviz and return the graphviz Source object.

    Parameters:
        spe: expression to draw; converted with render_networkx_graph.
        filename: base path for the generated files.  When None, a unique
            temporary path is used instead.
        ext: output format, 'png' (default) or 'pdf'.
        show: forwarded to Source.render as `view`.

    The DOT description is written to '<filename>.dot' and the rendering is
    requested with `filename` as its base name.  The file located at the base
    path itself is removed afterwards.

    Raises AssertionError if `ext` is neither 'png' nor 'pdf'.
    """
    # Validate before touching the filesystem so that a bad extension does
    # not leave a stray temporary file behind.
    ext = ext or 'png'
    assert ext in ['png', 'pdf'], 'Extension must be .pdf or .png'
    fname = filename
    if filename is None:
        # Only the unique path is needed.  Close the handle immediately
        # (delete=False keeps the file) instead of leaking it.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            fname = f.name
    G = render_networkx_graph(spe)
    fname_dot = '%s.dot' % (fname,)
    nx.nx_agraph.write_dot(G, fname_dot)
    source = graphviz.Source.from_file(fname_dot, format=ext)
    source.render(filename=fname, view=show)
    # NOTE(review): presumably render() saves the DOT source at `fname` and
    # the image next to it -- confirm against the installed graphviz version.
    os.unlink(fname)
    return source
#!/bin/sh

# Run a command with PYTHONPATH and PATH pointing at the local build/ tree.
# Usage: ./pythenv.sh <command> [args...]

set -Ceu

: ${PYTHON:=python}
root=`cd -- "$(dirname -- "$0")" && pwd`
# sysconfig replaces distutils.util here: distutils was removed from the
# standard library in Python 3.12.
platform=$("${PYTHON}" -c 'import sysconfig; print(sysconfig.get_platform())')
# Build "major.minor" from version_info.  Slicing sys.version to three
# characters yields "3.1" on Python 3.10 and later.
version=$("${PYTHON}" -c 'import sys; print("%d.%d" % sys.version_info[:2])')

# The lib directory varies depending on
#
# (a) whether there are extension modules (here, no); and
# (b) whether some Debian maintainer decided to patch the local Python
#     to behave as though there were.
#
# But there's no obvious way to just ask distutils what the name will
# be.  There's no harm in naming a pathname that doesn't exist, other
# than a handful of microseconds of runtime, so we'll add both.
libdir="${root}/build/lib"
plat_libdir="${libdir}.${platform}-${version}"
export PYTHONPATH="${libdir}:${plat_libdir}${PYTHONPATH:+:${PYTHONPATH}}"

bindir="${root}/build/scripts-${version}"
export PATH="${bindir}${PATH:+:${PATH}}"

exec "$@"
requirements = {
    'src' : [
        'astunparse==1.6.3',
        'numpy==1.23.1',
        'scipy==1.8.1',
        'sympy==1.10.1',
    ],
    'magics' : [
        'graphviz==0.13.2',
        'ipython==7.23.1',
        'jupyter-core==4.6.3',
        'networkx==2.4',
        'notebook==6.0.3',
        'matplotlib==3.3.2',
        'pygraphviz==1.5',
    ],
    'tests' : [
        'pytest-timeout==2.1.0',
        'pytest==7.1.2',
        'coverage==6.4.2',
    ]
}
# 'all' is the union of every group above; it is computed before being
# added so that it does not include itself.
requirements['all'] = [r for v in requirements.values() for r in v]

# Determine the version (hardcoded in src/__init__.py).
dirname = os.path.dirname(os.path.realpath(__file__))
vre = re.compile('__version__ = \'(.*?)\'')
with open(os.path.join(dirname, 'src', '__init__.py')) as f:
    m = f.read()
__version__ = vre.findall(m)[0]

# Read the README relative to this file, not the current working directory,
# so that setup.py also works when invoked from another directory.
with open(os.path.join(dirname, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='sppl',
    version=__version__,
    description='The Sum-Product Probabilistic Language',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/probcomp/sppl',
    license='Apache-2.0',
    maintainer='Feras A. Saad',
    maintainer_email='fsaad@mit.edu',
    classifiers=[
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Topic :: Scientific/Engineering :: Mathematics',
    ],
    packages=[
        'sppl',
        'sppl.compilers',
        'sppl.magics',
        'sppl.tests',
    ],
    package_dir={
        'sppl' : 'src',
        'sppl.compilers' : 'src/compilers',
        'sppl.magics' : 'magics',
        'sppl.tests' : 'tests',
    },
    install_requires=requirements['src'],
    extras_require=requirements,
    python_requires='>=3.8',
)
class Command:
    """Base class of the commands that are interpreted against an SPE."""

    def interpret(self, spe):
        """Apply this command to `spe`; every concrete command overrides it."""
        raise NotImplementedError
class Switch(Command):
    """Branch on the value of `symbol`, with one subcommand per case.

    `values` is an iterable of cases.  Each case is either a Set, a string
    (treated as a nominal value) or any other value (treated as a real
    value).  `f` maps a case to the Command to run when `symbol` falls in
    that case.  When `values` is an `enumerate` object, `f` is called as
    `f(index, case)` instead of `f(case)`.
    """

    def __init__(self, symbol, values, f):
        self.symbol = symbol
        self.f = f
        self.values = values

    def interpret(self, spe=None):
        # Decide the calling convention before consuming the iterable.
        enumerated = isinstance(self.values, enumerate)
        # Materialize exactly once.  The cases are traversed twice below,
        # which silently produced no subcommands for one-shot iterables
        # (generators, zip, map, ...).
        # NOTE: a one-shot iterable is still consumed by this call, so such
        # a Switch can only be interpreted once.
        values = list(self.values)
        if enumerated:
            sets = [self.value_to_set(v[1]) for v in values]
            subcommands = [self.f(*v) for v in values]
        else:
            sets = [self.value_to_set(v) for v in values]
            subcommands = [self.f(v) for v in values]
        # Make case i exclude all earlier cases so that the branch events
        # are mutually exclusive.
        sets_disjoint = [
            reduce(lambda x, s: x & ~s, sets[:i], sets[i])
            for i in range(len(sets))]
        events = [self.symbol << s for s in sets_disjoint]
        return interpret_if_block(spe, events, subcommands)

    def value_to_set(self, v):
        """Wrap a case value in a Set; Set instances are returned as-is."""
        if isinstance(v, Set):
            return v
        if isinstance(v, str):
            return FiniteNominal(v)
        return FiniteReal(v)
def interpret_if_block(spe, events, subcommands):
    """Interpret a guarded block: one subcommand per (possibly None) event.

    For every event the log-probability under `spe` is computed, with a None
    event counting as probability zero.  Branches with zero probability are
    dropped.  Each remaining subcommand is interpreted on `spe` conditioned
    on its event.  A single surviving branch is returned directly; several
    are combined into a SumSPE weighted by the branch log-probabilities,
    simplified unless the SPPL_NO_SIMPLIFY environment variable is set.

    Asserts that at least one branch has positive probability and that the
    surviving branch probabilities sum to one.
    """
    assert len(events) == len(subcommands)
    # One memo table shared by all logprob and condition queries.
    memo = Memo()
    # Mixture log-weight of every branch.
    log_weights = []
    for event in events:
        log_weights.append(-inf if event is None else spe.logprob(event, memo))
    # Positions of the branches that can actually be taken.
    live = [k for k, w in enumerate(log_weights) if not isinf_neg(w)]
    assert live, 'All conditions probability zero.'
    weights_conditioned = [log_weights[k] for k in live]
    spes_conditioned = [spe.condition(events[k], memo) for k in live]
    assert allclose(logsumexp(weights_conditioned), 0)
    # Run each surviving subcommand on its conditioned SPE.
    children = [
        subcommands[k].interpret(S)
        for k, S in zip(live, spes_conditioned)
    ]
    if len(children) == 1:
        return children[0]
    mixture = SumSPE(children, weights_conditioned)
    if os.environ.get('SPPL_NO_SIMPLIFY'):
        return mixture
    return spe_simplify_sum(mixture)
2 | # See LICENSE.txt 3 | 4 | """Convert SPE to JSON friendly dictionary.""" 5 | 6 | from fractions import Fraction 7 | 8 | import scipy.stats 9 | from sympy import E 10 | from sympy import sqrt 11 | 12 | from ..spe import AtomicLeaf 13 | from ..spe import ContinuousLeaf 14 | from ..spe import DiscreteLeaf 15 | from ..spe import NominalLeaf 16 | from ..spe import ProductSPE 17 | from ..spe import SumSPE 18 | 19 | # Needed for "eval" 20 | from ..sets import * 21 | from ..transforms import Id 22 | from ..transforms import Identity 23 | from ..transforms import Radical 24 | from ..transforms import Exponential 25 | from ..transforms import Logarithm 26 | from ..transforms import Abs 27 | from ..transforms import Reciprocal 28 | from ..transforms import Poly 29 | from ..transforms import Piecewise 30 | from ..transforms import EventInterval 31 | from ..transforms import EventFiniteReal 32 | from ..transforms import EventFiniteNominal 33 | from ..transforms import EventOr 34 | from ..transforms import EventAnd 35 | 36 | def env_from_dict(env): 37 | if env is None: 38 | return None 39 | # Used in eval. 
def spe_from_dict(metadata):
    """Rebuild an SPE from the dictionary form produced by spe_to_dict.

    `metadata['class']` selects the node type; sum and product nodes are
    rebuilt recursively from `metadata['children']`.

    NOTE: the 'support' entry of continuous and discrete leaves is passed to
    eval(), and their 'env' entry goes through env_from_dict, so only
    dictionaries from a trusted source should be given to this function.

    Asserts False for an unrecognised class name.
    """
    kind = metadata['class']
    if kind == 'NominalLeaf':
        symbol = Id(metadata['symbol'])
        # Probabilities are stored as (numerator, denominator) pairs.
        probabilities = {
            outcome: Fraction(frac[0], frac[1])
            for outcome, frac in metadata['dist']
        }
        return NominalLeaf(symbol, probabilities)
    if kind == 'AtomicLeaf':
        symbol = Id(metadata['symbol'])
        value = float(metadata['value'])
        return AtomicLeaf(symbol, value, env=env_from_dict(metadata['env']))
    if kind in ('ContinuousLeaf', 'DiscreteLeaf'):
        # The two real leaf types share the same serialized layout.
        leaf_type = ContinuousLeaf if kind == 'ContinuousLeaf' else DiscreteLeaf
        symbol = Id(metadata['symbol'])
        dist = scipy_dist_from_dict(metadata['dist'])
        support = eval(metadata['support'])
        conditioned = metadata['conditioned']
        env = env_from_dict(metadata['env'])
        return leaf_type(symbol, dist, support, conditioned, env=env)
    if kind == 'SumSPE':
        children = [spe_from_dict(child) for child in metadata['children']]
        return SumSPE(children, metadata['weights'])
    if kind == 'ProductSPE':
        return ProductSPE([spe_from_dict(child) for child in metadata['children']])

    assert False, 'Cannot convert %s to SPE' % (metadata,)
isinstance(spe, NominalLeaf): 94 | return { 95 | 'class' : 'NominalLeaf', 96 | 'symbol' : spe.symbol.token, 97 | 'dist' : [ 98 | (str(x), (w.numerator, w.denominator)) 99 | for x, w in spe.dist.items() 100 | ], 101 | 'env' : env_to_dict(spe.env), 102 | } 103 | if isinstance(spe, AtomicLeaf): 104 | return { 105 | 'class' : 'AtomicLeaf', 106 | 'symbol' : spe.symbol.token, 107 | 'value' : spe.value, 108 | 'env' : env_to_dict(spe.env), 109 | } 110 | if isinstance(spe, ContinuousLeaf): 111 | return { 112 | 'class' : 'ContinuousLeaf', 113 | 'symbol' : spe.symbol.token, 114 | 'dist' : scipy_dist_to_dict(spe.dist), 115 | 'support' : repr(spe.support), 116 | 'conditioned' : spe.conditioned, 117 | 'env' : env_to_dict(spe.env), 118 | } 119 | if isinstance(spe, DiscreteLeaf): 120 | return { 121 | 'class' : 'DiscreteLeaf', 122 | 'symbol' : spe.symbol.token, 123 | 'dist' : scipy_dist_to_dict(spe.dist), 124 | 'support' : repr(spe.support), 125 | 'conditioned' : spe.conditioned, 126 | 'env' : env_to_dict(spe.env), 127 | } 128 | if isinstance(spe, SumSPE): 129 | return { 130 | 'class' : 'SumSPE', 131 | 'children' : [spe_to_dict(c) for c in spe.children], 132 | 'weights' : spe.weights, 133 | } 134 | if isinstance(spe, ProductSPE): 135 | return { 136 | 'class' : 'ProductSPE', 137 | 'children' : [spe_to_dict(c) for c in spe.children], 138 | } 139 | assert False, 'Cannot convert %s to JSON' % (spe,) 140 | -------------------------------------------------------------------------------- /src/compilers/spe_to_sppl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
# Leading whitespace for an indentation level of `i` columns.
get_indentation = lambda i: ' ' * i
# Format x with `fw` decimal places, or with str(x) when `fw` is falsy.
float_to_str = lambda x,fw: '%1.*f' % (fw, float(x)) if fw else str(x)

class _SPPL_Render_State:
    """Mutable state threaded through the recursive SPPL renderer.

    stream is the text buffer receiving the program, branches the list of
    (branch_var, branch_dist) pairs collected for sum nodes, indentation the
    current indentation in columns, and fwidth the number of decimals used
    when printing floats (None means str()).
    """
    def __init__(self, stream=None, branches=None, indentation=None,
            fwidth=None):
        self.stream = stream or StringIO()
        self.branches = branches or []
        self.indentation = indentation or 0
        self.fwidth = fwidth

def render_sppl_choice(symbol, dist, stream, indentation, fwidth):
    """Write the statement `symbol ~= choice({...})` for `dist` to `stream`.

    Parameters:
        symbol: name of the variable being sampled.
        dist: mapping from outcome label to weight.
        stream: writable text stream receiving the statement.
        indentation: number of leading spaces of the statement.
        fwidth: decimals used for the weights, or a falsy value for str().

    The outcomes are put on a single line when short, otherwise one per line
    indented four further columns.
    """
    idt = get_indentation(indentation)
    # Quote labels with repr() rather than hand-written quotes, so that
    # labels containing quotes or backslashes still produce valid string
    # literals.  Plain labels render exactly as before: 'label'.
    dist_pairs = [
        '%s : %s' % (repr(str(k)), float_to_str(v, fwidth))
        for (k, v) in dist.items()
    ]
    # Write each outcome on new line.
    if sum(len(x) for x in dist_pairs) > 30:
        idt4 = get_indentation(indentation + 4)
        prefix = ',\n%s' % (idt4,)
        dist_op = '\n%s' % (idt4,)
        dist_cl = '\n%s' % (idt,)
    # Write all outcomes on same line.
    else:
        prefix = ', '
        dist_op = ''
        dist_cl = ''
    dist_str = prefix.join(dist_pairs)
    stream.write('%s%s ~= choice({%s%s%s})' %
        (idt, symbol, dist_op, dist_str, dist_cl))
    stream.write('\n')
def dnf_factor(event, lookup=None):
    """Group the literals of each clause of a DNF event by a symbol key.

    Given an event (in DNF) and a dictionary `lookup` mapping symbols to
    keys, return a tuple R of dictionaries, one per clause, such that
    R[i][j] is the conjunction of the literals in the i-th clause whose
    symbol is mapped to key j.  When `lookup` is None, every symbol is its
    own key.

    For example, if e is any predicate and

        event = (e(X0) & e(X1) & ~e(X2)) | (~e(X1) & e(X2) & e(X3) & e(X4))
        lookup = {X0: 0, X1: 1, X2: 0, X3: 1, X4: 2}

    the output is

        R = (
            {   # First clause
                0: e(X0) & ~e(X2),
                1: e(X1)},
            {   # Second clause
                0: e(X2),
                1: ~e(X1) & e(X3),
                2: e(X4)},
        )

    Asserts that the event really is in DNF.
    """
    if lookup is None:
        lookup = {s:s for s in event.get_symbols()}

    # Literal: a single clause holding a single entry.
    if isinstance(event, EventBasic):
        symbols = event.get_symbols()
        assert len(symbols) == 1
        symbol = list(symbols)[0]
        return ({lookup[symbol]: event},)

    # Conjunction: merge the literals that share a key.
    if isinstance(event, EventAnd):
        assert all(isinstance(e, EventBasic) for e in event.subexprs)
        merged = {}
        for literal in event.subexprs:
            for clause in dnf_factor(literal, lookup):
                for key, ev in clause.items():
                    if key in merged:
                        merged[key] &= ev
                    else:
                        merged[key] = ev
        return (merged,)

    # Disjunction: concatenate the factorizations of the clauses.
    if isinstance(event, EventOr):
        assert all(isinstance(e, (EventAnd, EventBasic)) for e in event.subexprs)
        return tuple(
            clause
            for subexpr in event.subexprs
            for clause in dnf_factor(subexpr, lookup)
        )

    assert False, 'Invalid DNF event: %s' % (event,)
def dnf_non_disjoint_clauses(event):
    # Given an event in DNF, returns a dictionary R
    # such that R[j] = [i | i < j and event[i] intersects event[j]].
    # Clauses that overlap no earlier clause have no entry in R.
    event_factor = dnf_factor(event)
    # solutions[k] maps each symbol of clause k to the solved set of values
    # that the clause allows for that symbol.
    solutions = [
        {symbol: ev.solve() for symbol, ev in clause.items()}
        for clause in event_factor
    ]

    n_clauses = len(event_factor)
    overlap_dict = {}
    for i, j in combinations(range(n_clauses), 2):
        # Skip the pair if any symbol of clause i fails to intersect its
        # counterpart in clause j; a symbol that clause j does not mention
        # is only constrained by clause i.
        intersections = (
            solutions[i][symbol] & solutions[j][symbol]
                if (symbol in solutions[j]) else
                solutions[i][symbol]
            for symbol in solutions[i]
        )
        if any(x is EmptySet for x in intersections):
            continue
        # Skip the pair if clause j is unsatisfiable on its own. This must
        # inspect each per-symbol solution: comparing the dictionary
        # solutions[j] itself against EmptySet is never true, which made
        # this guard a no-op for symbols appearing only in clause j.
        if any(solution is EmptySet for solution in solutions[j].values()):
            continue
        # All symbols intersect, so clauses overlap.
        overlap_dict.setdefault(j, []).append(i)

    return overlap_dict
124 | if isinstance(event, (EventBasic, EventAnd)): 125 | return event 126 | # Find indexes of pairs of clauses that overlap. 127 | overlap_dict = dnf_non_disjoint_clauses(event) 128 | if not overlap_dict: 129 | return event 130 | # Create the cascading negated clauses. 131 | n_clauses = len(event.subexprs) 132 | clauses_disjoint = [ 133 | reduce( 134 | lambda state, event: state & ~event, 135 | (event.subexprs[j] for j in overlap_dict.get(i, [])), 136 | event.subexprs[i]) 137 | for i in range(n_clauses) 138 | ] 139 | # Recursively find the solutions for each clause. 140 | clauses_normalized = [dnf_normalize(clause) for clause in clauses_disjoint] 141 | solutions = [dnf_to_disjoint_union(c) for c in clauses_normalized if c] 142 | # Return the merged solution. 143 | return reduce(lambda a, b: a|b, solutions) 144 | -------------------------------------------------------------------------------- /src/math_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | from math import isinf 5 | 6 | import numpy 7 | 8 | from scipy.special import logsumexp 9 | 10 | # Implementation of log1mexp and logdiffexp from PyMC3 math module. 
# https://github.com/pymc-devs/pymc3/blob/master/pymc3/math.py
def log1mexp(x):
    # Compute log(1 - exp(-x)), choosing between two algebraically
    # equivalent formulas according to the 0.683 cutoff.
    if x < 0.683:
        return numpy.log(-numpy.expm1(-x))
    return numpy.log1p(-numpy.exp(-x))

def logdiffexp(a, b):
    # Compute log(exp(a) - exp(b)); raises ValueError when b exceeds a
    # by more than the allclose tolerance.
    if b < a:
        return a + log1mexp(a - b)
    if not allclose(b, a):
        raise ValueError('Negative term in logdiffexp.')
    return -float('inf')

def lognorm(array):
    # Subtract logsumexp(array) from every entry; returns a list.
    total = logsumexp(array)
    return [term - total for term in array]

def logflip(logp, array, size, rng):
    # Sample from array using (possibly unnormalized) log weights logp.
    weights = numpy.exp(lognorm(logp))
    return flip(weights, array, size, rng)

def flip(p, array, size, rng):
    # Sample size items from array with weights p (normalized here).
    weights = normalize(p)
    source = random(rng)
    return source.choice(array, size=size, p=weights)

def normalize(p):
    # Rescale p so that its entries sum to one; returns a float ndarray.
    total = float(sum(p))
    return numpy.asarray(p, dtype=float) / total

def allclose(values, x):
    # Thin wrapper over numpy.allclose with its default tolerances.
    return numpy.allclose(values, x)

def isinf_pos(x):
    # True when x is positive infinity.
    return isinf(x) and x > 0

def isinf_neg(x):
    # True when x is negative infinity.
    return isinf(x) and x < 0

def random(x):
    # Fall back to the global numpy.random module when x is falsy.
    return x if x else numpy.random

def int_or_isinf_neg(a):
    # True when a is negative infinity or has an integer value.
    return isinf_neg(a) or float(a) == int(a)

def int_or_isinf_pos(a):
    # True when a is positive infinity or has an integer value.
    return isinf_pos(a) or float(a) == int(a)

def float_to_int(a):
    # Convert finite values to int; infinities pass through unchanged.
    if isinf(a):
        return a
    return int(a)
def solve_poly_inequality_numerically(expr, b, strict):
    # Solve expr < b (strict) or expr <= b (non-strict) for a univariate
    # polynomial expr by locating the real roots of expr - b numerically
    # and probing the sign of the polynomial between consecutive roots.
    # Returns a set built from this package's Interval / EmptySet types.
    poly = expr - b
    symX = get_poly_symbol(expr)
    # Endpoints are open exactly when the inequality is strict.
    mk_intvl = lambda lo, hi: \
        Interval(lo, hi, left_open=strict, right_open=strict)
    # Obtain numerical roots.
    roots = sympy.nroots(poly)
    zeros = sorted([r for r in roots if r.is_real])
    if not zeros:
        # Without real roots the polynomial never changes sign, so a single
        # probe decides between the whole line and nothing. The previous
        # code always returned sympy.EmptySet here, which is the wrong
        # answer for an everywhere-negative polynomial and the wrong type
        # for callers that compare against this package's EmptySet.
        if poly.subs(symX, 0) < 0:
            return mk_intvl(-oo, oo)
        return EmptySet
    # Construct intervals around roots.
    intervals = list(chain(
        [mk_intvl(-oo, zeros[0])],
        [mk_intvl(x, y) for x, y in zip(zeros, zeros[1:])],
        [mk_intvl(zeros[-1], oo)]))
    # Define probe points.
    # NOTE(review): interior entries that are not Interval instances are
    # skipped here but kept in `intervals`, so the probe positions may stop
    # lining up with `intervals` when that happens (repeated roots?);
    # confirm against Interval's behavior for equal endpoints.
    xs_probe = list(chain(
        [zeros[0] - 1/2],
        [(i.left + i.right)/2 for i in intervals[1:-1]
            if isinstance(i, Interval)],
        [zeros[-1] + 1/2]))
    # Evaluate poly at the probe points.
    f_xs_probe = [poly.subs(symX, x) for x in xs_probe]
    # Return intervals where poly is less than zero.
    idxs = [i for i, fx in enumerate(f_xs_probe) if fx < 0]
    return make_union(*[intervals[i] for i in idxs])
def solve_poly_equality_inf(expr, b):
    # Solve expr == b for an infinite target b: the only candidate
    # solutions are the points at infinity, which qualify when the limit
    # of expr there is an infinity of the same sign as b.
    assert isinf(b)
    symX = get_poly_symbol(expr)
    lim_at_pos = limit(expr, symX, oo)
    lim_at_neg = limit(expr, symX, -oo)

    def hits_target(value):
        # True when value is infinite with the same sign as b.
        if not isinf(value):
            return False
        return (value > 0) if (b > 0) else (value < 0)

    # Keep +oo ahead of -oo so the argument order of FiniteReal matches.
    candidates = ((oo, lim_at_pos), (-oo, lim_at_neg))
    points = [point for point, value in candidates if hits_target(value)]
    if not points:
        return EmptySet
    return FiniteReal(*points)
def render_nested_lists_concise(spe):
    # Render an SPE as compact nested lists: a leaf becomes the list of
    # its (name, expression) environment pairs as strings; a sum or
    # product becomes ['+(n)' or '*(n)', [rendered children]], where n is
    # the number of children. Any other input yields None.
    if isinstance(spe, LeafSPE):
        return [(str(name), str(expr)) for name, expr in spe.env.items()]
    for node_type, glyph in ((SumSPE, '+'), (ProductSPE, '*')):
        if isinstance(spe, node_type):
            header = '%s(%d)' % (glyph, len(spe.children))
            rendered = [
                render_nested_lists_concise(child)
                for child in spe.children
            ]
            return [header, rendered]
    return None
def partition_list_blocks(values):
    # Group the positions of values by hash(value). Returns a list of
    # index lists, ordered by the first appearance of each hash, with the
    # indexes inside a block in increasing order. Values are compared by
    # hash only, so distinct values with equal hashes share a block.
    blocks = OrderedDict()
    for index, value in enumerate(values):
        blocks.setdefault(hash(value), []).append(index)
    return list(blocks.values())
def sym_log(x):
    # Logarithm of a non-negative x with explicit endpoints: zero maps to
    # -inf and an infinite x maps to +inf; everything else is delegated
    # to sympy.log.
    assert 0 <= x
    if x == 0:
        return -float('inf')
    return float('inf') if isinf(x) else sympy.log(x)
class timeout:
    # Context manager that raises TimeoutError inside its body once
    # `seconds` have elapsed, using SIGALRM. The SIGALRM handler that was
    # installed before entry is put back on exit, so the block does not
    # leave a stale handler behind.
    # Adapted from https://stackoverflow.com/a/22348885
    def __init__(self, seconds=1, error_message='Timeout'):
        # seconds is passed to signal.alarm; error_message becomes the
        # message of the raised TimeoutError.
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None
    def handle_timeout(self, signum, frame):
        # Signal handler: convert the alarm into a TimeoutError.
        raise TimeoutError(self.error_message)
    def __enter__(self):
        # signal.signal returns the handler it replaces; remember it so
        # that __exit__ can restore it.
        self._previous_handler = signal.signal(
            signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
    def __exit__(self, type, value, traceback):
        # Cancel any pending alarm before swapping handlers back.
        signal.alarm(0)
        # The previous handler is None when it was not installed from
        # Python; signal.signal does not accept None, so skip the restore.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
@pytest.mark.xfail(strict=True, reason='https://github.com/probcomp/sum-product-dsl/issues/77')
def test_condition_real_discrete_no_range():
    # Condition a discrete uniform on {0, 1, 2, 3} on the non-contiguous
    # set {0, 2}: the two surviving values should each carry half of the
    # mass and the excluded value 1 should carry none. The earlier version
    # asserted mass 0.5 on the excluded value 1, which could never hold and
    # so kept the strict xfail satisfied regardless of the linked issue.
    command = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Condition(Y << {0, 2}))
    model = command.interpret()
    assert allclose(model.prob(Y << {0}), .5)
    assert allclose(model.prob(Y << {2}), .5)
    assert allclose(model.prob(Y << {1}), 0)
def test_simple_model():
    # A For loop over i = 0..4 samples X[i] ~ bernoulli(1/(i+1)) next to an
    # independent Y; check the symbols and each marginal log-probability.
    sample_y = Sample(Y, bernoulli(p=0.5))
    sample_xs = For(0, 5, lambda i:
        Sample(X[i], bernoulli(p=1/(i+1))))
    model = Sequence(sample_y, sample_xs).interpret()

    symbols = model.get_symbols()
    assert len(symbols) == 6
    assert Y in symbols
    for k in range(5):
        assert X[k] in symbols
    for k in range(5):
        assert model.logprob(X[k] << {1}) == log(1/(k+1))
command.interpret() 53 | assert allclose(model.prob(Y << {'0'}), 0.2) 54 | 55 | def test_complex_model_reorder(): 56 | command = Sequence( 57 | Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})), 58 | For(0, 3, lambda i: 59 | Sample(Z[i], bernoulli(p=0.1))), 60 | For(0, 3, lambda i: 61 | IfElse( 62 | Y << {str(i)}, Sample(X[i], bernoulli(p=1/(i+1))), 63 | Z[i] << {0}, Sample(X[i], bernoulli(p=1/(i+1))), 64 | Otherwise, Sample(X[i], bernoulli(p=0.1))))) 65 | model = command.interpret() 66 | assert(allclose(model.prob(Y << {'0'}), 0.2)) 67 | 68 | def test_repeat_handcode_equivalence(): 69 | model_repeat = make_model_for() 70 | model_hand = make_model_handcode() 71 | 72 | assert allclose(model_repeat.prob(Y << {'0', '1'}), 0.4) 73 | assert allclose(model_repeat.prob(Z[0] << {0}), 0.5) 74 | assert allclose(model_repeat.prob(Z[0] << {1}), 0.5) 75 | 76 | event_condition = (X[0] << {1}) | (Y << {'1'}) 77 | model_repeat_condition = model_repeat.condition(event_condition) 78 | model_hand_condition = model_hand.condition(event_condition) 79 | 80 | for event in [ 81 | Y << {'0','1'}, 82 | Z[0] << {0}, 83 | Z[1] << {0}, 84 | X[0] << {0}, 85 | X[1] << {0}, 86 | ]: 87 | lp_repeat = model_repeat.logprob(event) 88 | lp_hand = model_hand.logprob(event) 89 | assert allclose(lp_hand, lp_repeat) 90 | 91 | lp_repeat_condition = model_repeat_condition.logprob(event) 92 | lp_hand_condition = model_hand_condition.logprob(event) 93 | assert allclose(lp_hand_condition, lp_repeat_condition) 94 | 95 | # This test case ensures that the duplicate subtrees in the mixture 96 | # components are pointers to the same object, and is obtained by 97 | # manually inspecting the rendering of the network generated by the 98 | # following code: 99 | # 100 | # from sppl.magics.render import render_graphviz 101 | # render_graphviz(model_repeat, show=True) 102 | # 103 | # See also test_cache_duplicate_subtrees.test_cache_complex_sum_of_product 104 | a = 
model_repeat.children[0].children[0].children[1].children[0] 105 | b = model_repeat.children[1].children[0] 106 | assert a is b 107 | 108 | # ============================================================================== 109 | # Helper functions. 110 | 111 | def make_model_for(n=2): 112 | command = Sequence( 113 | Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})), 114 | For(0, n, lambda i: Sequence( 115 | Sample(Z[i], bernoulli(p=.5)), 116 | IfElse( 117 | (Y << {str(i)}) | (Z[i] << {0}), Sample(X[i], bernoulli(p=.1)), 118 | Otherwise, Sample(X[i], bernoulli(p=.5)))))) 119 | return command.interpret() 120 | 121 | def make_model_handcode(): 122 | command = Sequence( 123 | Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})), 124 | Sample(Z[0], bernoulli(p=.5)), 125 | Sample(Z[1], bernoulli(p=.5)), 126 | IfElse( 127 | Y << {str(0)}, Sequence( 128 | Sample(X[0], bernoulli(p=.1)), 129 | IfElse( 130 | Z[1] << {0}, Sample(X[1], bernoulli(p=.1)), 131 | Otherwise, Sample(X[1], bernoulli(p=.5)))), 132 | Y << {str(1)}, Sequence( 133 | Sample(X[1], bernoulli(p=.1)), 134 | IfElse( 135 | Z[0] << {0}, Sample(X[0], bernoulli(p=.1)), 136 | Otherwise, Sample(X[0], bernoulli(p=.5)))), 137 | Otherwise, Sequence( 138 | IfElse( 139 | Z[0] << {0}, Sample(X[0], bernoulli(p=.1)), 140 | Otherwise, Sample(X[0], bernoulli(p=.5))), 141 | IfElse( 142 | Z[1] << {0}, Sample(X[1], bernoulli(p=.1)), 143 | Otherwise, Sample(X[1], bernoulli(p=.5)))))) 144 | return command.interpret() 145 | -------------------------------------------------------------------------------- /tests/test_ast_ifelse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
def test_ifelse_zero_conditions():
    # Y ranges over {0, 1, 2}, so the branches on Y = -1 and Y = 3 have
    # probability zero; the interpreted model is expected to keep only the
    # three reachable branches, weighted by the probability of each guard.
    branches = [
        Y << {-1},  Transform(X, Y**(-1)),
        Y << {0},   Sample(X, bernoulli(p=1)),
        Y << {1},   Transform(X, Y),
        Y << {2},   Transform(X, Y**2),
        Y << {3},   Transform(X, Y**3),
    ]
    command = Sequence(
        Sample(Y, randint(low=0, high=3)),
        IfElse(*branches))
    model = command.interpret()
    assert len(model.children) == 3
    assert len(model.weights) == 3
    for k in range(3):
        assert allclose(model.weights[k], model.logprob(Y << {k}))
2 | # See LICENSE.txt 3 | 4 | from math import log 5 | 6 | import pytest 7 | 8 | from numpy import linspace 9 | 10 | from sppl.distributions import bernoulli 11 | from sppl.distributions import beta 12 | from sppl.distributions import randint 13 | from sppl.compilers.ast_to_spe import IfElse 14 | from sppl.compilers.ast_to_spe import Sample 15 | from sppl.compilers.ast_to_spe import Sequence 16 | from sppl.compilers.ast_to_spe import Switch 17 | from sppl.compilers.ast_to_spe import Id 18 | from sppl.math_util import allclose 19 | from sppl.math_util import logsumexp 20 | from sppl.sym_util import binspace 21 | 22 | Y = Id('Y') 23 | X = Id('X') 24 | 25 | def test_simple_model_eq(): 26 | command_switch = Sequence( 27 | Sample(Y, randint(low=0, high=4)), 28 | Switch(Y, range(0, 4), lambda i: 29 | Sample(X, bernoulli(p=1/(i+1))))) 30 | model_switch = command_switch.interpret() 31 | 32 | command_ifelse = Sequence( 33 | Sample(Y, randint(low=0, high=4)), 34 | IfElse( 35 | Y << {0}, Sample(X, bernoulli(p=1/(0+1))), 36 | Y << {1}, Sample(X, bernoulli(p=1/(1+1))), 37 | Y << {2}, Sample(X, bernoulli(p=1/(2+1))), 38 | Y << {3}, Sample(X, bernoulli(p=1/(3+1))), 39 | )) 40 | model_ifelse = command_ifelse.interpret() 41 | 42 | for model in [model_switch, model_ifelse]: 43 | symbols = model.get_symbols() 44 | assert symbols == {X, Y} 45 | assert allclose( 46 | model.logprob(X << {1}), 47 | logsumexp([-log(4) - log(i+1) for i in range(4)])) 48 | 49 | def test_simple_model_lte(): 50 | command_switch = Sequence( 51 | Sample(Y, beta(a=2, b=3)), 52 | Switch(Y, binspace(0, 1, 5), lambda i: 53 | Sample(X, bernoulli(p=i.right)))) 54 | model_switch = command_switch.interpret() 55 | 56 | command_ifelse = Sequence( 57 | Sample(Y, beta(a=2, b=3)), 58 | IfElse( 59 | Y <= 0, Sample(X, bernoulli(p=0)), 60 | Y <= 0.25, Sample(X, bernoulli(p=.25)), 61 | Y <= 0.50, Sample(X, bernoulli(p=.50)), 62 | Y <= 0.75, Sample(X, bernoulli(p=.75)), 63 | Y <= 1, Sample(X, bernoulli(p=1)), 64 | )) 65 | 
model_ifelse = command_ifelse.interpret() 66 | 67 | grid = [float(x) for x in linspace(0, 1, 5)] 68 | for model in [model_switch, model_ifelse]: 69 | symbols = model.get_symbols() 70 | assert symbols == {X, Y} 71 | assert allclose( 72 | model.logprob(X << {1}), 73 | logsumexp([ 74 | model.logprob((il < Y) <= ih) + log(ih) 75 | for il, ih in zip(grid[:-1], grid[1:]) 76 | ])) 77 | 78 | def test_simple_model_enumerate(): 79 | command_switch = Sequence( 80 | Sample(Y, randint(low=0, high=4)), 81 | Switch(Y, enumerate(range(0, 4)), lambda i,j: 82 | Sample(X, bernoulli(p=1/(i+j+1))))) 83 | model = command_switch.interpret() 84 | assert allclose(model.prob(Y<<{0} & (X << {1})), .25 * 1/(0+0+1)) 85 | assert allclose(model.prob(Y<<{1} & (X << {1})), .25 * 1/(1+1+1)) 86 | assert allclose(model.prob(Y<<{2} & (X << {1})), .25 * 1/(2+2+1)) 87 | assert allclose(model.prob(Y<<{3} & (X << {1})), .25 * 1/(3+3+1)) 88 | 89 | def test_error_range(): 90 | with pytest.raises(AssertionError): 91 | # Switch cases do not sum to one. 92 | command = Sequence( 93 | Sample(Y, randint(low=0, high=4)), 94 | Switch(Y, range(0, 3), lambda i: 95 | Sample(X, bernoulli(p=1/(i+1))))) 96 | command.interpret() 97 | 98 | def test_error_linspace(): 99 | with pytest.raises(AssertionError): 100 | # Switch cases do not sum to one. 101 | command = Sequence( 102 | Sample(Y, beta(a=2, b=3)), 103 | Switch(Y, linspace(0, .5, 5), lambda i: 104 | Sample(X, bernoulli(p=i)))) 105 | command.interpret() 106 | 107 | def test_error_binspace(): 108 | with pytest.raises(AssertionError): 109 | # Switch cases do not sum to one. 
110 | command = Sequence( 111 | Sample(Y, beta(a=2, b=3)), 112 | Switch(Y, binspace(0, .5, 5), lambda i: 113 | Sample(X, bernoulli(p=i.right)))) 114 | command.interpret() 115 | -------------------------------------------------------------------------------- /tests/test_ast_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | from math import log 5 | 6 | from sppl.compilers.ast_to_spe import Id 7 | from sppl.compilers.ast_to_spe import IfElse 8 | from sppl.compilers.ast_to_spe import Otherwise 9 | from sppl.compilers.ast_to_spe import Sample 10 | from sppl.compilers.ast_to_spe import Sequence 11 | from sppl.compilers.ast_to_spe import Transform 12 | from sppl.distributions import bernoulli 13 | from sppl.distributions import norm 14 | from sppl.math_util import allclose 15 | 16 | X = Id('X') 17 | Y = Id('Y') 18 | Z = Id('Z') 19 | 20 | def test_simple_transform(): 21 | command = Sequence( 22 | Sample(X, norm(loc=0, scale=1)), 23 | Transform(Z, X**2)) 24 | model = command.interpret() 25 | assert model.get_symbols() == {Z, X} 26 | assert model.env == {Z:X**2, X:X} 27 | assert (model.logprob(Z > 0)) == 0 28 | 29 | def test_if_else_transform(): 30 | model = Sequence( 31 | Sample(X, norm(loc=0, scale=1)), 32 | IfElse( 33 | X > 0, Transform(Z, X**2), 34 | Otherwise, Transform(Z, X))).interpret() 35 | assert model.children[0].env == {X:X, Z:X**2} 36 | assert model.children[1].env == {X:X, Z:X} 37 | assert allclose(model.children[0].logprob(Z > 0), 0) 38 | assert allclose(model.children[1].logprob(Z > 0), -float('inf')) 39 | assert allclose(model.logprob(Z > 0), -log(2)) 40 | 41 | def test_if_else_transform_reverse(): 42 | command = Sequence( 43 | Sample(X, norm(loc=0, scale=1)), 44 | Sample(Y, bernoulli(p=0.5)), 45 | IfElse( 46 | Y << {0}, Transform(Z, X**2), 47 | Otherwise, Transform(Z, X))) 48 | model = command.interpret() 49 | assert 
allclose(model.logprob(Z > 0), log(3) - log(4)) 50 | -------------------------------------------------------------------------------- /tests/test_burglary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | ''' 5 | Burglary network example from: 6 | 7 | Artificial Intelligence: A Modern Approach (3rd Edition). 8 | Russel and Norvig, Fig 14.2 pp 512. 9 | ''' 10 | 11 | from sppl.compilers.ast_to_spe import Id 12 | from sppl.compilers.ast_to_spe import IfElse 13 | from sppl.compilers.ast_to_spe import Otherwise 14 | from sppl.compilers.ast_to_spe import Sample 15 | from sppl.compilers.ast_to_spe import Sequence 16 | from sppl.distributions import bernoulli 17 | 18 | Burglary = Id('Burglary') 19 | Earthquake = Id('Earthquake') 20 | Alarm = Id('Alarm') 21 | JohnCalls = Id('JohnCalls') 22 | MaryCalls = Id('MaryCalls') 23 | 24 | program = Sequence( 25 | Sample(Burglary, bernoulli(p=0.001)), 26 | Sample(Earthquake, bernoulli(p=0.002)), 27 | IfElse( 28 | Burglary << {1}, 29 | IfElse( 30 | Earthquake << {1}, Sample(Alarm, bernoulli(p=0.95)), 31 | Otherwise, Sample(Alarm, bernoulli(p=0.94))), 32 | Otherwise, 33 | IfElse( 34 | Earthquake << {1}, Sample(Alarm, bernoulli(p=0.29)), 35 | Otherwise, Sample(Alarm, bernoulli(p=0.001)))), 36 | IfElse( 37 | Alarm << {1}, Sequence( 38 | Sample(JohnCalls, bernoulli(p=0.90)), 39 | Sample(MaryCalls, bernoulli(p=0.70))), 40 | Otherwise, Sequence( 41 | Sample(JohnCalls, bernoulli(p=0.05)), 42 | Sample(MaryCalls, bernoulli(p=0.01))), 43 | )) 44 | model = program.interpret() 45 | 46 | def test_marginal_probability(): 47 | # Query on pp. 514. 48 | event = ((JohnCalls << {1}) & (MaryCalls << {1}) & (Alarm << {1}) 49 | & (Burglary << {0}) & (Earthquake << {0})) 50 | x = model.prob(event) 51 | assert str(x)[:8] == '0.000628' 52 | 53 | def test_conditional_probability(): 54 | # Query on pp. 
523 55 | event = (JohnCalls << {1}) & (MaryCalls << {1}) 56 | model_condition = model.condition(event) 57 | x = model_condition.prob(Burglary << {1}) 58 | assert str(x)[:5] == '0.284' 59 | 60 | def test_mutual_information(): 61 | event_a = (JohnCalls << {1}) | (MaryCalls << {1}) 62 | event_b = (Burglary << {1}) & (Earthquake << {0}) 63 | print(model.mutual_information(event_a, event_b)) 64 | -------------------------------------------------------------------------------- /tests/test_cache_duplicate_subtrees.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | from math import log 5 | 6 | import numpy 7 | 8 | from sppl.distributions import bernoulli 9 | from sppl.distributions import choice 10 | from sppl.distributions import norm 11 | from sppl.spe import ProductSPE 12 | from sppl.spe import SumSPE 13 | from sppl.spe import spe_cache_duplicate_subtrees 14 | from sppl.transforms import Id 15 | 16 | rng = numpy.random.RandomState(1) 17 | 18 | W = Id('W') 19 | Y = Id('Y') 20 | X = [Id('X[0]'), Id('X[1]')] 21 | Z = [Id('Z[0]'), Id('Z[1]')] 22 | 23 | def test_cache_simple_leaf(): 24 | spe = .5 * (W >> norm(loc=0, scale=1)) | .5 * (W >> norm(loc=0, scale=1)) 25 | assert spe.children[0] is not spe.children[1] 26 | spe_cached = spe_cache_duplicate_subtrees(spe, {}) 27 | assert spe_cached.children[0] is spe_cached.children[1] 28 | 29 | def test_cache_simple_sum_of_product(): 30 | spe \ 31 | = 0.3 * ((W >> norm(loc=0, scale=1)) & (Y >> norm(loc=0, scale=1))) \ 32 | | 0.7 * ((W >> norm(loc=0, scale=1)) & (Y >> norm(loc=0, scale=2))) 33 | spe_cached = spe_cache_duplicate_subtrees(spe, {}) 34 | assert spe_cached.children[0].children[0] is spe_cached.children[1].children[0] 35 | 36 | def test_cache_complex_sum_of_product(): 37 | # Test case adapted from the SPE generated by 38 | # test_repeat.make_model_repeat(n=2) 39 | duplicate_subtrees = [None, None] 40 | 
for i in range(2): 41 | duplicate_subtrees[i] = SumSPE([ 42 | ProductSPE([ 43 | (X[0] >> bernoulli(p=.1)), 44 | SumSPE([ 45 | (Z[0] >> bernoulli(p=.5)) 46 | & (Y >> choice({'0':.1, '1': .9})), 47 | (Z[0] >> bernoulli(p=.1)) 48 | & (Y >> choice({'0':.9, '1': .1})) 49 | ], weights=[log(.730), log(.270)]) 50 | ]), 51 | ProductSPE([ 52 | Z[0] >> bernoulli(p=.1), 53 | Y >> choice({'0':.9, '1':.1}), 54 | X[0] >> bernoulli(p=.5), 55 | ]), 56 | ], weights=[log(.925), log(.075)]) 57 | 58 | assert duplicate_subtrees[0] == duplicate_subtrees[1] 59 | assert duplicate_subtrees[0] is not duplicate_subtrees[1] 60 | 61 | left_subtree = ProductSPE([ 62 | X[1] >> bernoulli(p=.5), 63 | SumSPE([ 64 | ProductSPE([ 65 | duplicate_subtrees[0], 66 | Z[1] >> bernoulli(p=.5), 67 | ]), 68 | ProductSPE([ 69 | Z[1] >> bernoulli(p=.7), 70 | SumSPE([ 71 | Y >> choice({'0':.3, '1':.7}) 72 | & X[0] >> bernoulli(p=.1) 73 | & Z[0] >> bernoulli(p=.1), 74 | Y >> choice({'0':.7, '1':.3}) 75 | & X[0] >> bernoulli(p=.5) 76 | & Z[0] >> bernoulli(p=.5), 77 | ], weights=[log(.9), log(.1)]) 78 | ]) 79 | ], weights=[log(.783), log(.217)]) 80 | ]) 81 | 82 | right_subtree = ProductSPE([ 83 | Z[1] >> bernoulli(p=.8), 84 | X[1] >> bernoulli(p=.1), 85 | duplicate_subtrees[1] 86 | ]) 87 | 88 | spe = .92 * left_subtree | .08 * right_subtree 89 | 90 | spe_cached = spe_cache_duplicate_subtrees(spe, {}) 91 | assert spe_cached.children[0].children[1].children[0].children[0] is duplicate_subtrees[0] 92 | assert spe_cached.children[1].children[2] is duplicate_subtrees[0] 93 | -------------------------------------------------------------------------------- /tests/test_clickgraph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sppl.distributions import bernoulli 7 | from sppl.distributions import beta 8 | from sppl.distributions import randint 9 | from sppl.distributions import uniform 10 | 11 | from sppl.compilers.ast_to_spe import For 12 | from sppl.compilers.ast_to_spe import Id 13 | from sppl.compilers.ast_to_spe import IdArray 14 | from sppl.compilers.ast_to_spe import IfElse 15 | from sppl.compilers.ast_to_spe import Sample 16 | from sppl.compilers.ast_to_spe import Sequence 17 | from sppl.compilers.ast_to_spe import Switch 18 | from sppl.compilers.ast_to_spe import Transform 19 | 20 | from sppl.sym_util import binspace 21 | 22 | simAll = Id('simAll') 23 | sim = IdArray('sim', 5) 24 | p1 = IdArray('p1', 5) 25 | p2 = IdArray('p2', 5) 26 | clickA = IdArray('clickA', 5) 27 | clickB = IdArray('clickB', 5) 28 | 29 | ns = 5 30 | nd = ns - 1 31 | 32 | def get_command_randint(): 33 | return Sequence( 34 | Sample(simAll, randint(low=0, high=ns)), 35 | For(0, 5, lambda k: 36 | Switch (simAll, range(0, ns), lambda i: 37 | Sequence ( 38 | Sample (sim[k], bernoulli(p=i/nd)), 39 | Sample (p1[k], randint(low=0, high=ns)), 40 | IfElse( 41 | sim[k] << {1}, 42 | Sequence( 43 | Transform (p2[k], p1[k]), 44 | Switch (p1[k], range(ns), lambda j: 45 | Sequence( 46 | Sample(clickA[k], bernoulli(p=i/nd)), 47 | Sample(clickB[k], bernoulli(p=i/nd))))), 48 | True, 49 | Sequence( 50 | Sample (p2[k], randint(low=0, high=ns)), 51 | Switch (p1[k], range(ns), lambda j: 52 | Sample(clickA[k], bernoulli(p=j/nd))), 53 | Switch (p2[k], range(ns), lambda j: 54 | Sample (clickB[k], bernoulli(p=j/nd))))))))) 55 | 56 | def get_command_beta(): 57 | return Sequence( 58 | Sample(simAll, beta(a=2, b=3)), 59 | For(0, 5, lambda k: 60 | Switch (simAll, binspace(0, 1, ns), lambda i: 61 | Sequence ( 62 | Sample (sim[k], bernoulli(p=i.right)), 63 | Sample (p1[k], uniform()), 64 | IfElse( 65 | sim[k] << {1}, 66 | Sequence( 67 | Transform (p2[k], p1[k]), 68 | Switch (p1[k], 
binspace(0, 1, ns), lambda j: 69 | Sequence( 70 | Sample(clickA[k], bernoulli(p=i.right)), 71 | Sample(clickB[k], bernoulli(p=i.right))))), 72 | True, 73 | Sequence( 74 | Sample (p2[k], uniform()), 75 | Switch (p1[k], binspace(0, 1, ns), lambda j: 76 | Sample(clickA[k], bernoulli(p=j.right))), 77 | Switch (p2[k], binspace(0, 1, ns), lambda j: 78 | Sample (clickB[k], bernoulli(p=j.right))))))))) 79 | 80 | @pytest.mark.parametrize('get_command', [get_command_randint, get_command_beta]) 81 | def test_clickgraph_crash__ci_(get_command): 82 | command = get_command() 83 | model = command.interpret() 84 | model_condition = model.condition( 85 | (clickA[0] << {1}) & (clickB[0] << {1}) 86 | & (clickA[1] << {1}) & (clickB[1] << {1}) 87 | & (clickA[2] << {1}) & (clickB[2] << {1}) 88 | & (clickA[3] << {0}) & (clickB[3] << {0}) 89 | & (clickA[4] << {0}) & (clickB[4] << {0})) 90 | probabilities = [model_condition.prob(simAll << {i}) for i in range(ns)] 91 | assert all(p <= probabilities[nd-1] for p in probabilities) 92 | -------------------------------------------------------------------------------- /tests/test_dnf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sppl.dnf import dnf_factor 7 | from sppl.dnf import dnf_non_disjoint_clauses 8 | from sppl.dnf import dnf_to_disjoint_union 9 | 10 | from sppl.transforms import EventOr 11 | from sppl.transforms import Exp 12 | from sppl.transforms import Id 13 | from sppl.transforms import Log 14 | from sppl.transforms import Sqrt 15 | 16 | (X0, X1, X2, X3, X4, X5) = [Id("X%d" % (i,)) for i in range(6)] 17 | 18 | events = [ 19 | X0 < 0, 20 | (X1<<(0,1)) & (X2<0), 21 | (X1<0) | (X2<0), 22 | (X0<0) | ((X1<0) & ~(X2<0)), 23 | ((X0<0) & ~(0 ~(A | B | C) | ~D 76 | assert event.to_dnf() == (~A & ~B & ~C) | ~D 77 | 78 | event = ~((A | ~B | C) & ~D) 79 | # => ~(A | B | C) | ~D 80 | assert event.to_dnf() == (~A & B & ~C) | D 81 | 82 | def test_dnf_factor(): 83 | E00 = Exp(X0) > 0 84 | E01 = X0 < 10 85 | E10 = X1 < 10 86 | E20 = (X2**2 - X2*3) < 0 87 | E30 = X3 > 10 88 | E31 = (Sqrt(2*X3)) < 0 89 | E40 = X4 > 0 90 | E41 = X4 << [1, 5] 91 | E50 = 10*Log(X5) + 9 > 5 92 | 93 | event = (E00) 94 | event_dnf = event.to_dnf() 95 | dnf = dnf_factor(event_dnf) 96 | assert len(dnf) == 1 97 | assert dnf[0][X0] == E00 98 | 99 | event = E00 & E01 100 | event_dnf = event.to_dnf() 101 | dnf = dnf_factor(event_dnf) 102 | assert len(dnf) == 1 103 | assert dnf[0][X0] == E00 & E01 104 | 105 | event = E00 | E01 106 | event_dnf = event.to_dnf() 107 | dnf = dnf_factor(event_dnf) 108 | assert len(dnf) == 2 109 | assert dnf[0][X0] == E00 110 | assert dnf[1][X0] == E01 111 | 112 | event = E00 | (E01 & E10) 113 | event_dnf = event.to_dnf() 114 | dnf = dnf_factor(event_dnf, {X0: 0, X1: 0}) 115 | assert len(dnf) == 2 116 | assert dnf[0][0] == E00 117 | assert dnf[1][0] == E01 & E10 118 | 119 | event = (E00 & E01 & E10 & E30 & E40) | (E20 & E50 & E31 & ~E41) 120 | # For the second clause we have: 121 | # ~E41 = (-oo, 1) U (1, 5) U (5, oo) 122 | # so the second clause becomes 123 | # = (E20 & E50 & E31 & ((-oo, 1) U (1, 5) U (5, oo))) 124 | # = (E20 & E50 & E31 & 
(-oo, 1)) 125 | # or (E20 & E50 & E31 & (1, 5)) 126 | # or (E20 & E50 & E31 & (5, oo)) 127 | event_dnf = event.to_dnf() 128 | event_factor = dnf_factor(event_dnf) 129 | assert len(event_factor) == 4 130 | # clause 0 131 | assert len(event_factor[0]) == 4 132 | assert event_factor[0][X0] == E00 & E01 133 | assert event_factor[0][X1] == E10 134 | assert event_factor[0][X3] == E30 135 | assert event_factor[0][X4] == E40 136 | # clause 1 137 | assert len(event_factor[1]) == 4 138 | assert event_factor[1][X3] == E31 139 | assert event_factor[1][X2] == E20 140 | assert event_factor[1][X4] == (X4 < 1) 141 | assert event_factor[1][X5] == E50 142 | # clause 2 143 | assert len(event_factor[2]) == 4 144 | assert event_factor[2][X3] == E31 145 | assert event_factor[2][X2] == E20 146 | assert event_factor[2][X4] == (1 < (X4 < 5)) 147 | assert event_factor[2][X5] == E50 148 | # clause 3 149 | assert len(event_factor[3]) == 4 150 | assert event_factor[3][X3] == E31 151 | assert event_factor[3][X2] == E20 152 | assert event_factor[3][X4] == (5 < X4) 153 | assert event_factor[3][X5] == E50 154 | 155 | def test_dnf_factor_1(): 156 | A = Exp(X0) > 0 157 | B = X0 < 10 158 | C = X1 < 10 159 | D = X2 < 0 160 | 161 | event = A & B & C & ~D 162 | event_dnf = event.to_dnf() 163 | event_factor = dnf_factor(event_dnf, {X0:0, X1:0, X2:0, X3:1, X4:1, X5:2}) 164 | assert len(event_factor) == 1 165 | assert event_factor[0][0] == event 166 | 167 | def test_dnf_factor_2(): 168 | A = X0 < 1 169 | B = X4 < 1 170 | C = X5 < 1 171 | event = A & B & C 172 | event_dnf = event.to_dnf() 173 | event_factor = dnf_factor(event_dnf, {X0:0, X1:0, X2:0, X3:1, X4:1, X5:2}) 174 | assert len(event_factor) == 1 175 | assert event_factor[0][0] == A 176 | assert event_factor[0][1] == B 177 | assert event_factor[0][2] == C 178 | 179 | def test_dnf_factor_3(): 180 | A = (Exp(X0) > 0) 181 | B = X0 < 10 182 | C = X1 < 10 183 | D = X4 > 0 184 | E = (X2**2 - 3*X2) << (0, 10, 100) 185 | F = (10*Log(X5) + 9) > 5 186 | G = X4 
< 4 187 | 188 | event = (A & B & C & ~D) | (E & F & G) 189 | event_dnf = event.to_dnf() 190 | event_factor = dnf_factor(event_dnf, {X0:0, X1:0, X2:0, X3:1, X4:1, X5:2}) 191 | assert len(event_factor) == 2 192 | assert event_factor[0][0] == A & B & C 193 | assert event_factor[0][1] == ~D 194 | assert event_factor[1][0] == E 195 | assert event_factor[1][1] == G 196 | assert event_factor[1][2] == F 197 | 198 | def test_dnf_non_disjoint_clauses(): 199 | X = Id('X') 200 | Y = Id('Y') 201 | Z = Id('Z') 202 | 203 | event = (X > 0) | (Y < 0) 204 | overlaps = dnf_non_disjoint_clauses(event) 205 | assert overlaps == {1: [0]} 206 | 207 | event = (X > 0) | ((X < 0) & (Y < 0)) 208 | overlaps = dnf_non_disjoint_clauses(event) 209 | assert not overlaps 210 | 211 | event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1)) 212 | overlaps = dnf_non_disjoint_clauses(event) 213 | assert overlaps == {2: [0]} 214 | 215 | event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1) & (Z > 1)) 216 | overlaps = dnf_non_disjoint_clauses(event) 217 | assert not overlaps 218 | 219 | event = ((X**2 < 9)) | (1 < X) 220 | overlaps = dnf_non_disjoint_clauses(event) 221 | assert overlaps == {1: [0]} 222 | 223 | event = ((X**2 < 9) & (0 < X < 1)) | (1 < X) 224 | overlaps = dnf_non_disjoint_clauses(event) 225 | assert not overlaps 226 | 227 | def test_event_to_disjiont_union_numerical(): 228 | X = Id('X') 229 | Y = Id('Y') 230 | Z = Id('Z') 231 | 232 | for event in [ 233 | (X > 0) | (X < 3), 234 | (X > 0) | (Y < 3), 235 | ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0), 236 | ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X <<{1, 3}), 237 | ]: 238 | event_dnf = dnf_to_disjoint_union(event) 239 | assert not dnf_non_disjoint_clauses(event_dnf) 240 | 241 | def test_event_to_disjoint_union_nominal(): 242 | X = Id('X') 243 | Y = Id('Y') 244 | event = (X << {'1'}) | (X << {'1', '2'}) 245 | assert dnf_to_disjoint_union(event) == X << {'1', '2'} 246 | 247 | event = (X << {'1'}) | ~(Y << 
{'1'}) 248 | assert dnf_to_disjoint_union(event) == EventOr([ 249 | (X << {'1'}), ~(Y << {'1'}) & ~(X << {'1'}) 250 | ]) 251 | 252 | def test_event_to_disjoint_union_five(): 253 | # This test case can be visualized as follows: 254 | # fig, ax = plt.subplots() 255 | # ax.add_patch(Rectangle((2, 5), 6, 3, fill=False)) 256 | # ax.add_patch(Rectangle((3, 7), 1, 4, fill=False)) 257 | # ax.add_patch(Rectangle((3.5, 2), 2, 7, fill=False)) 258 | # ax.add_patch(Rectangle((4.5, 1), 2, 5, fill=False)) 259 | # ax.add_patch(Rectangle((5, 7), 2, 3, fill=False)) 260 | # 261 | # // ((2 < X < 8) & (5 < Y < 8)) 262 | # // ((3 < X < 4) & (8 <= Y < 11)) 263 | # // ((3.5 < X < 5.5) & (2 < Y <= 5)) 264 | # // ((5.5 <= X < 6.5) & (1 < Y <= 5)) 265 | # // ((4.5 < X < 5.5) & (1 < Y <= 2)) 266 | # // ((5 < X < 7) & (8 <= Y < 10)) 267 | # 268 | # fig, ax = plt.subplots() 269 | # ax.add_patch(Rectangle((2, 5), 6, 3, fill=False)) 270 | # ax.add_patch(Rectangle((3, 8), 1, 3, fill=False)) 271 | # ax.add_patch(Rectangle((3.5, 2), 2, 3, fill=False)) 272 | # ax.add_patch(Rectangle((5.5, 1), 1, 4, fill=False)) 273 | # ax.add_patch(Rectangle((4.5, 1), 1, 1, fill=False)) 274 | # ax.add_patch(Rectangle((5, 8), 2, 2, fill=False)) 275 | X = Id('X') 276 | Y = Id('Y') 277 | E1 = ((2 < X) < 8) & ((5 < Y) < 8) 278 | E2 = ((3 < X) < 4) & ((7 < Y) < 11) 279 | E3 = ((3.5 < X) < 5.5) & ((2 < Y) < 7) 280 | E4 = ((4.5 < X) < 6.5) & ((1 < Y) < 6) 281 | E5 = ((5 < X) < 7) & ((7 < Y) < 10) 282 | event = E1 | E2 | E3 | E4 | E5 283 | dnf_to_disjoint_union(event) 284 | -------------------------------------------------------------------------------- /tests/test_event_evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sppl.transforms import Exp 7 | from sppl.transforms import Id 8 | from sppl.transforms import Log 9 | 10 | X = Id('X') 11 | Y = Id('Y') 12 | W = Id('W') 13 | Z = Id('Z') 14 | 15 | def test_event_basic_invertible(): 16 | expr = X**2 + 10*X 17 | 18 | event = (0 < expr) 19 | assert event.evaluate({X: 10}) 20 | assert not (~event).evaluate({X: 10}) 21 | 22 | event = (0 > expr) 23 | assert not event.evaluate({X: 10}) 24 | assert (~event).evaluate({X: 10}) 25 | 26 | event = (0 < expr) < 100 27 | for val in [0, 10]: 28 | assert not event.evaluate({X: val}) 29 | assert (~event).evaluate({X: val}) 30 | 31 | event = (0 <= expr) <= 100 32 | x_eq_100 = (expr << {100}).solve() 33 | for val in [0, list(x_eq_100)[0]]: 34 | assert event.evaluate({X: val}) 35 | assert not (~event).evaluate({X: val}) 36 | 37 | event = expr << {11} 38 | assert event.evaluate({X: 1}) 39 | assert not event.evaluate({X: 3}) 40 | assert not (~event).evaluate({X: 1}) 41 | assert (~event).evaluate({X :3}) 42 | 43 | with pytest.raises(ValueError): 44 | event.evaluate({Y: 10}) 45 | 46 | def test_event_compound(): 47 | expr0 = abs(X)**2 + 10*abs(X) 48 | expr1 = Y-1 49 | 50 | event = (0 < expr0) | (0 < expr1) 51 | assert not event.evaluate({X: 0, Y: 0}) 52 | for i in range(1, 100): 53 | assert event.evaluate({X: i, Y: i}) 54 | with pytest.raises(ValueError): 55 | event.evaluate({X: 0}) 56 | with pytest.raises(ValueError): 57 | event.evaluate({Y: 0}) 58 | 59 | event = (100 <= expr0) & ((expr0 < 200) | (3 < expr1)) 60 | assert not event.evaluate({X: 0, Y: 0}) 61 | 62 | x_eq_150 = (expr0 << {150}).solve() 63 | assert event.evaluate({X: list(x_eq_150)[0], Y: 0}) 64 | assert event.evaluate({X: list(x_eq_150)[1], Y: 0}) 65 | 66 | x_eq_500 = (expr0 << {500}).solve() 67 | assert not event.evaluate({X: list(x_eq_500)[1], Y: 4}) 68 | assert event.evaluate({X: list(x_eq_500)[1], Y: 5}) 69 | 70 | def test_event_solve_multi(): 71 | event = (Exp(abs(3*X**2)) > 1) | 
(Log(Y) < 0.5) 72 | with pytest.raises(ValueError): 73 | event.solve() 74 | event = (Exp(abs(3*X**2)) > 1) & (Log(Y) < 0.5) 75 | with pytest.raises(ValueError): 76 | event.solve() 77 | -------------------------------------------------------------------------------- /tests/test_indian_gpa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | ''' 5 | Indian GPA example from: 6 | 7 | Discrete-Continuous Mixtures in Probabilistic Programming: Generalized 8 | Semantics and Inference Algorithms, Wu et. al., ICML 2018. 9 | https://arxiv.org/pdf/1806.02027.pdf 10 | ''' 11 | 12 | import pytest 13 | 14 | from sppl.compilers.ast_to_spe import Id 15 | from sppl.compilers.ast_to_spe import IfElse 16 | from sppl.compilers.ast_to_spe import Sample 17 | from sppl.compilers.ast_to_spe import Sequence 18 | from sppl.compilers.sppl_to_python import SPPL_Compiler 19 | from sppl.distributions import atomic 20 | from sppl.distributions import choice 21 | from sppl.distributions import uniform 22 | from sppl.math_util import allclose 23 | from sppl.sets import Interval 24 | from sppl.spe import ExposedSumSPE 25 | 26 | Nationality = Id('Nationality') 27 | Perfect = Id('Perfect') 28 | GPA = Id('GPA') 29 | 30 | def model_no_latents(): 31 | return \ 32 | 0.5 * ( # American student 33 | 0.99 * (GPA >> uniform(loc=0, scale=4)) | \ 34 | 0.01 * (GPA >> atomic(loc=4))) | \ 35 | 0.5 * ( # Indian student 36 | 0.99 * (GPA >> uniform(loc=0, scale=10)) | \ 37 | 0.01 * (GPA >> atomic(loc=10))) 38 | 39 | def model_exposed(): 40 | return ExposedSumSPE( 41 | spe_weights=(Nationality >> choice({'India': 0.5, 'USA': 0.5})), 42 | children={ 43 | # American student. 44 | 'USA': ExposedSumSPE( 45 | spe_weights=(Perfect >> choice({'True': 0.01, 'False': 0.99})), 46 | children={ 47 | 'False' : GPA >> uniform(loc=0, scale=4), 48 | 'True' : GPA >> atomic(loc=4), 49 | }), 50 | # Indian student. 
51 | 'India': ExposedSumSPE( 52 | spe_weights=(Perfect >> choice({'True': 0.01, 'False': 0.99})), 53 | children={ 54 | 'False' : GPA >> uniform(loc=0, scale=10), 55 | 'True' : GPA >> atomic(loc=10), 56 | })}, 57 | ) 58 | 59 | def model_ifelse_exhuastive(): 60 | command = Sequence( 61 | Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})), 62 | Sample(Perfect, choice({'True': 0.01, 'False': 0.99})), 63 | IfElse( 64 | (Nationality << {'India'}) & (Perfect << {'False'}), 65 | Sample(GPA, uniform(loc=0, scale=10)) 66 | , 67 | (Nationality << {'India'}) & (Perfect << {'True'}), 68 | Sample(GPA, atomic(loc=10)) 69 | , 70 | (Nationality << {'USA'}) & (Perfect << {'False'}), 71 | Sample(GPA, uniform(loc=0, scale=4)) 72 | , 73 | (Nationality << {'USA'}) & (Perfect << {'True'}), 74 | Sample(GPA, atomic(loc=4)))) 75 | return command.interpret() 76 | 77 | def model_ifelse_non_exhuastive(): 78 | Nationality = Id('Nationality') 79 | Perfect = Id('Perfect') 80 | GPA = Id('GPA') 81 | command = Sequence( 82 | Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})), 83 | Sample(Perfect, choice({'True': 0.01, 'False': 0.99})), 84 | IfElse( 85 | (Nationality << {'India'}) & (Perfect << {'False'}), 86 | Sample(GPA, uniform(loc=0, scale=10)) 87 | , 88 | (Nationality << {'India'}) & (Perfect << {'True'}), 89 | Sample(GPA, atomic(loc=10)) 90 | , 91 | (Nationality << {'USA'}) & (Perfect << {'False'}), 92 | Sample(GPA, uniform(loc=0, scale=4)) 93 | , 94 | True, 95 | Sample(GPA, atomic(loc=4)))) 96 | return command.interpret() 97 | 98 | def model_ifelse_nested(): 99 | Nationality = Id('Nationality') 100 | Perfect = Id('Perfect') 101 | GPA = Id('GPA') 102 | command = Sequence( 103 | Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})), 104 | Sample(Perfect, choice({'True': 0.01, 'False': 0.99})), 105 | IfElse( 106 | Nationality << {'India'}, 107 | IfElse( 108 | Perfect << {'True'}, Sample(GPA, atomic(loc=10)), 109 | Perfect << {'False'}, Sample(GPA, uniform(scale=10)), 110 | ), 111 | 
Nationality << {'USA'}, 112 | IfElse( 113 | Perfect << {'True'}, Sample(GPA, atomic(loc=4)), 114 | Perfect << {'False'}, Sample(GPA, uniform(scale=4)), 115 | ))) 116 | return command.interpret() 117 | 118 | def model_perfect_nested(): 119 | Nationality = Id('Nationality') 120 | Perfect = Id('Perfect') 121 | GPA = Id('GPA') 122 | command = Sequence( 123 | Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})), 124 | IfElse( 125 | Nationality << {'India'}, Sequence( 126 | Sample(Perfect, choice({'True': 0.01, 'False': 0.99})), 127 | IfElse( 128 | Perfect << {'True'}, Sample(GPA, atomic(loc=10)), 129 | True, Sample(GPA, uniform(scale=10)) 130 | )), 131 | Nationality << {'USA'}, Sequence( 132 | Sample(Perfect, choice({'True': 0.01, 'False': 0.99})), 133 | IfElse( 134 | Perfect << {'True'}, Sample(GPA, atomic(loc=4)), 135 | True, Sample(GPA, uniform(scale=4)), 136 | )))) 137 | return command.interpret() 138 | 139 | def model_ifelse_exhuastive_compiled(): 140 | compiler = SPPL_Compiler(''' 141 | Nationality ~= choice({'India': 0.5, 'USA': 0.5}) 142 | Perfect ~= choice({'True': 0.01, 'False': 0.99}) 143 | if (Nationality == 'India') & (Perfect == 'False'): 144 | GPA ~= uniform(loc=0, scale=10) 145 | elif (Nationality == 'India') & (Perfect == 'True'): 146 | GPA ~= atomic(loc=10) 147 | elif (Nationality == 'USA') & (Perfect == 'False'): 148 | GPA ~= uniform(loc=0, scale=4) 149 | elif (Nationality == 'USA') & (Perfect == 'True'): 150 | GPA ~= atomic(loc=4) 151 | ''') 152 | namespace = compiler.execute_module() 153 | return namespace.model 154 | 155 | def model_ifelse_non_exhuastive_compiled(): 156 | compiler = SPPL_Compiler(''' 157 | Nationality ~= choice({'India': 0.5, 'USA': 0.5}) 158 | Perfect ~= choice({'True': 0.01, 'False': 0.99}) 159 | if (Nationality == 'India') & (Perfect == 'False'): 160 | GPA ~= uniform(loc=0, scale=10) 161 | elif (Nationality == 'India') & (Perfect == 'True'): 162 | GPA ~= atomic(loc=10) 163 | elif (Nationality == 'USA') & (Perfect == 'False'): 
164 | GPA ~= uniform(loc=0, scale=4) 165 | else: 166 | GPA ~= atomic(loc=4) 167 | ''') 168 | namespace = compiler.execute_module() 169 | return namespace.model 170 | 171 | def model_ifelse_nested_compiled(): 172 | compiler = SPPL_Compiler(''' 173 | Nationality ~= choice({'India': 0.5, 'USA': 0.5}) 174 | Perfect ~= choice({'True': 0.01, 'False': 0.99}) 175 | if (Nationality == 'India'): 176 | if (Perfect == 'False'): 177 | GPA ~= uniform(loc=0, scale=10) 178 | else: 179 | GPA ~= atomic(loc=10) 180 | elif (Nationality == 'USA'): 181 | if (Perfect == 'False'): 182 | GPA ~= uniform(loc=0, scale=4) 183 | elif (Perfect == 'True'): 184 | GPA ~= atomic(loc=4) 185 | ''') 186 | namespace = compiler.execute_module() 187 | return namespace.model 188 | 189 | def model_perfect_nested_compiled(): 190 | compiler = SPPL_Compiler(''' 191 | Nationality ~= choice({'India': 0.5, 'USA': 0.5}) 192 | if (Nationality == 'India'): 193 | Perfect ~= choice({'True': 0.01, 'False': 0.99}) 194 | if (Perfect == 'False'): 195 | GPA ~= uniform(loc=0, scale=10) 196 | else: 197 | GPA ~= atomic(loc=10) 198 | elif (Nationality == 'USA'): 199 | Perfect ~= choice({'True': 0.01, 'False': 0.99}) 200 | if (Perfect == 'False'): 201 | GPA ~= uniform(loc=0, scale=4) 202 | else: 203 | GPA ~= atomic(loc=4) 204 | ''') 205 | namespace = compiler.execute_module() 206 | return namespace.model 207 | 208 | @pytest.mark.parametrize('get_model', [ 209 | # Manual 210 | model_no_latents, 211 | model_exposed, 212 | # Interpreter 213 | model_ifelse_exhuastive, 214 | model_ifelse_non_exhuastive, 215 | model_ifelse_nested, 216 | model_perfect_nested, 217 | # Compiler 218 | model_ifelse_exhuastive_compiled, 219 | model_ifelse_non_exhuastive_compiled, 220 | model_ifelse_nested_compiled, 221 | model_perfect_nested_compiled, 222 | ]) 223 | def test_prior(get_model): 224 | model = get_model() 225 | GPA = Id('GPA') 226 | assert allclose(model.prob(GPA << {10}), 0.5*0.01) 227 | assert allclose(model.prob(GPA << {4}), 0.5*0.01) 228 | 
assert allclose(model.prob(GPA << {5}), 0) 229 | assert allclose(model.prob(GPA << {1}), 0) 230 | 231 | assert allclose(model.prob((2 < GPA) < 4), 232 | 0.5*0.99*0.5 + 0.5*0.99*0.2) 233 | assert allclose(model.prob((2 <= GPA) < 4), 234 | 0.5*0.99*0.5 + 0.5*0.99*0.2) 235 | assert allclose(model.prob((2 < GPA) <= 4), 236 | 0.5*(0.99*0.5 + 0.01) + 0.5*0.99*0.2) 237 | assert allclose(model.prob((2 < GPA) <= 8), 238 | 0.5*(0.99*0.5 + 0.01) + 0.5*0.99*0.6) 239 | assert allclose(model.prob((2 < GPA) < 10), 240 | 0.5*(0.99*0.5 + 0.01) + 0.5*0.99*0.8) 241 | assert allclose(model.prob((2 < GPA) <= 10), 242 | 0.5*(0.99*0.5 + 0.01) + 0.5*(0.99*0.8 + 0.01)) 243 | 244 | assert allclose(model.prob(((2 <= GPA) < 4) | (7 < GPA)), 245 | (0.5*0.99*0.5 + 0.5*0.99*0.2) + (0.5*(0.99*0.3 + 0.01))) 246 | 247 | assert allclose(model.prob(((2 <= GPA) < 4) & (7 < GPA)), 0) 248 | 249 | def test_condition(): 250 | model = model_no_latents() 251 | GPA = Id('GPA') 252 | model_condition = model.condition(GPA << {4} | GPA << {10}) 253 | assert len(model_condition.children) == 2 254 | assert model_condition.children[0].support == Interval.Ropen(4, 5) 255 | assert model_condition.children[1].support == Interval.Ropen(10, 11) 256 | 257 | model_condition = model.condition((0 < GPA < 4)) 258 | assert len(model_condition.children) == 2 259 | assert model_condition.children[0].support \ 260 | == model_condition.children[1].support 261 | assert allclose( 262 | model_condition.children[0].logprob(GPA < 1), 263 | model_condition.children[1].logprob(GPA < 1)) 264 | 265 | def test_logpdf(): 266 | model = model_no_latents() 267 | assert allclose(0.005, model.pdf({GPA: 4})) 268 | assert allclose(0.005, model.pdf({GPA: 10})) 269 | -------------------------------------------------------------------------------- /tests/test_logpdf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | from math import log 5 | 6 | import pytest 7 | 8 | from sppl.distributions import atomic 9 | from sppl.distributions import choice 10 | from sppl.distributions import discrete 11 | from sppl.distributions import gamma 12 | from sppl.distributions import norm 13 | from sppl.distributions import poisson 14 | from sppl.math_util import allclose 15 | from sppl.math_util import isinf_neg 16 | from sppl.math_util import logsumexp 17 | from sppl.spe import SumSPE 18 | from sppl.spe import ProductSPE 19 | from sppl.transforms import Id 20 | 21 | X = Id('X') 22 | Y = Id('Y') 23 | Z = Id('Z') 24 | 25 | def test_logpdf_real_continuous(): 26 | spe = (X >> norm()) 27 | assert allclose(spe.logpdf({X: 0}), norm().dist.logpdf(0)) 28 | 29 | def test_logpdf_real_discrete(): 30 | spe = (X >> poisson(mu=2)) 31 | assert isinf_neg(spe.logpdf({X: 1.5})) 32 | assert isinf_neg(spe.logpdf({X: '1'})) 33 | assert not isinf_neg(spe.logpdf({X: 0})) 34 | 35 | def test_logpdf_nominal(): 36 | spe = (X >> choice({'a' : .6, 'b': .4})) 37 | assert isinf_neg(spe.logpdf({X: 1.5})) 38 | allclose(spe.logpdf({X: 'a'}), log(.6)) 39 | 40 | def test_logpdf_mixture_real_continuous_continuous(): 41 | spe = X >> (.3*norm() | .7*gamma(a=1)) 42 | assert allclose( 43 | spe.logpdf({X: .5}), 44 | logsumexp([ 45 | log(.3) + spe.children[0].logpdf({X: 0.5}), 46 | log(.7) + spe.children[1].logpdf({X: 0.5}), 47 | ])) 48 | 49 | @pytest.mark.xfail 50 | def test_logpdf_mixture_real_continuous_discrete(): 51 | spe = X >> (.3*norm() | .7*poisson(mu=1)) 52 | assert allclose( 53 | spe.logpdf(X << {.5}), 54 | logsumexp([ 55 | log(.3) + spe.children[0].logpdf({X: 0.5}), 56 | log(.7) + spe.children[1].logpdf({X: 0.5}), 57 | ])) 58 | assert False, 'Invalid base measure addition' 59 | 60 | def test_logpdf_mixture_nominal(): 61 | spe = SumSPE([X >> norm(), X >> choice({'a':.1, 'b':.9})], [log(.4), log(.6)]) 62 | assert allclose( 63 | spe.logpdf({X: .5}), 64 | log(.4) + spe.children[0].logpdf({X: .5})) 
65 | assert allclose( 66 | spe.logpdf({X: 'a'}), 67 | log(.6) + spe.children[1].logpdf({X: 'a'})) 68 | 69 | def test_logpdf_error_event(): 70 | spe = (X >> norm()) 71 | with pytest.raises(Exception): 72 | spe.logpdf(X < 1) 73 | 74 | def test_logpdf_error_transform_base(): 75 | spe = (X >> norm()) 76 | with pytest.raises(Exception): 77 | spe.logpdf({X**2: 0}) 78 | 79 | def test_logpdf_error_transform_env(): 80 | spe = (X >> norm()).transform(Z, X**2) 81 | with pytest.raises(Exception): 82 | spe.logpdf({Z: 0}) 83 | 84 | def test_logpdf_bivariate(): 85 | spe = (X >> norm()) & (Y >> choice({'a': .5, 'b': .5})) 86 | assert allclose( 87 | spe.logpdf({X: 0, Y: 'a'}), 88 | norm().dist.logpdf(0) + log(.5)) 89 | 90 | def test_logpdf_lexicographic_either(): 91 | spe = .75*(X >> norm() & Y >> atomic(loc=0) & Z >> discrete({1:.1, 2:.9})) \ 92 | | .25*(X >> atomic(loc=0) & Y >> norm() & Z >> norm()) 93 | # Lexicographic, Branch 1 94 | assignment = {X:0, Y:0, Z:2} 95 | assert allclose( 96 | spe.logpdf(assignment), 97 | log(.75) + norm().dist.logpdf(0) + log(1) + log(.9)) 98 | assert isinstance(spe.constrain(assignment), ProductSPE) 99 | # Lexicographic, Branch 2 100 | assignment = {X:0, Y:0, Z:0} 101 | assert allclose( 102 | spe.logpdf(assignment), 103 | log(.25) + log(1) + norm().dist.logpdf(0) + norm().dist.logpdf(0)) 104 | assert isinstance(spe.constrain(assignment), ProductSPE) 105 | 106 | def test_logpdf_lexicographic_both(): 107 | spe = .75*(X >> norm() & Y >> atomic(loc=0) & Z >> discrete({1:.2, 2:.8})) \ 108 | | .25*(X >> discrete({1:.5, 2:.5}) & Y >> norm() & Z >> atomic(loc=2)) 109 | # Lexicographic, Mix 110 | assignment = {X:1, Y:0, Z:2} 111 | assert allclose( 112 | spe.logpdf(assignment), 113 | logsumexp([ 114 | log(.75) + norm().dist.logpdf(1) + log(1) + log(.8), 115 | log(.25) + log(.5) + norm().dist.logpdf(0) + log(1)])) 116 | assert isinstance(spe.constrain(assignment), SumSPE) 117 | -------------------------------------------------------------------------------- 
/tests/test_mutual_information.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | from math import exp 5 | from math import log 6 | 7 | import numpy 8 | import pytest 9 | 10 | from sppl.distributions import norm 11 | from sppl.math_util import allclose 12 | from sppl.math_util import isinf_neg 13 | from sppl.math_util import logdiffexp 14 | from sppl.math_util import logsumexp 15 | from sppl.spe import Memo 16 | from sppl.transforms import Id 17 | 18 | prng = numpy.random.RandomState(1) 19 | 20 | def entropy(spe, A, memo): 21 | lpA1 = spe.logprob(A, memo=memo) 22 | lpA0 = logdiffexp(0, lpA1) 23 | e1 = -exp(lpA1) * lpA1 if not isinf_neg(lpA1) else 0 24 | e0 = -exp(lpA0) * lpA0 if not isinf_neg(lpA0) else 0 25 | return e1 + e0 26 | def entropyc(spe, A, B, memo): 27 | lpB1 = spe.logprob(B) 28 | lpB0 = logdiffexp(0, lpB1) 29 | lp11 = spe.logprob(B & A, memo=memo) 30 | lp10 = spe.logprob(B & ~A, memo=memo) 31 | lp01 = spe.logprob(~B & A, memo=memo) 32 | # lp00 = self.logprob(~B & ~A, memo) 33 | lp00 = logdiffexp(0, logsumexp([lp11, lp10, lp01])) 34 | m11 = exp(lp11) * (lpB1 - lp11) if not isinf_neg(lp11) else 0 35 | m10 = exp(lp10) * (lpB1 - lp10) if not isinf_neg(lp10) else 0 36 | m01 = exp(lp01) * (lpB0 - lp01) if not isinf_neg(lp01) else 0 37 | m00 = exp(lp00) * (lpB0 - lp00) if not isinf_neg(lp00) else 0 38 | return m11 + m10 + m01 + m00 39 | 40 | def check_mi_properties(spe, A, B, memo): 41 | miAB = spe.mutual_information(A, B, memo=memo) 42 | miAA = spe.mutual_information(A, A, memo=memo) 43 | miBB = spe.mutual_information(B, B, memo=memo) 44 | eA = entropy(spe, A, memo=memo) 45 | eB = entropy(spe, B, memo=memo) 46 | eAB = entropyc(spe, A, B, memo=memo) 47 | eBA = entropyc(spe, B, A, memo=memo) 48 | assert allclose(miAA, eA) 49 | assert allclose(miBB, eB) 50 | assert allclose(miAB, eA - eAB) 51 | assert allclose(miAB, eB - eBA) 52 | 53 | 
@pytest.mark.parametrize('memo', [Memo(), None]) 54 | def test_mutual_information_four_clusters(memo): 55 | X = Id('X') 56 | Y = Id('Y') 57 | spe \ 58 | = 0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=0, scale=0.5)) \ 59 | | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=0, scale=0.5)) \ 60 | | 0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=5, scale=0.5)) \ 61 | | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=5, scale=0.5)) \ 62 | 63 | A = X > 2 64 | B = Y > 2 65 | samples = spe.sample(100, prng) 66 | mi = spe.mutual_information(A, B, memo=memo) 67 | assert allclose(mi, 0) 68 | check_mi_properties(spe, A, B, memo) 69 | 70 | event = ((X>2) & (Y<2) | ((X<2) & (Y>2))) 71 | spe_condition = spe.condition(event) 72 | samples = spe_condition.sample(100, prng) 73 | assert all(event.evaluate(sample) for sample in samples) 74 | mi = spe_condition.mutual_information(X > 2, Y > 2) 75 | assert allclose(mi, log(2)) 76 | 77 | check_mi_properties(spe, (X>1) | (Y<1), (Y>2), memo) 78 | check_mi_properties(spe, (X>1) | (Y<1), (X>1.5) & (Y>2), memo) 79 | -------------------------------------------------------------------------------- /tests/test_nominal_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
# See LICENSE.txt

from fractions import Fraction
from math import log

import numpy
import pytest

from sppl.distributions import choice
from sppl.math_util import allclose
from sppl.math_util import isinf_neg
from sppl.sets import FiniteNominal
from sppl.transforms import Id

def test_nominal_distribution():
    """Exercise logprob, sampling and conditioning of a nominal leaf.

    The leaf assigns probabilities 1/5, 1/5 and 3/5 to 'a', 'b' and 'c'.
    """
    X = Id('X')
    spe = X >> choice({
        'a': Fraction(1, 5),
        'b': Fraction(1, 5),
        'c': Fraction(3, 5),
    })
    # Membership queries, including conjunction with a negated event.
    assert allclose(spe.logprob(X << {'a'}), log(Fraction(1, 5)))
    assert allclose(spe.logprob(X << {'b'}), log(Fraction(1, 5)))
    assert allclose(spe.logprob(X << {'a', 'c'}), log(Fraction(4, 5)))
    assert allclose(
        spe.logprob((X << {'a'}) & ~(X << {'b'})),
        log(Fraction(1, 5)))
    assert allclose(
        spe.logprob((X << {'a', 'b'}) & ~(X << {'b'})),
        log(Fraction(1, 5)))
    # Values outside the support, and the empty collection, have probability 0.
    assert spe.logprob((X << {'d'})) == -float('inf')
    assert spe.logprob((X << ())) == -float('inf')

    samples = spe.sample(100)
    assert all(s[X] in spe.support for s in samples)

    samples = spe.sample_subset([X], 100)
    assert all(len(s)==1 and s[X] in spe.support for s in samples)

    # 'f' is not a symbol of this leaf.
    with pytest.raises(Exception):
        spe.sample_subset(['f'], 100)

    # True for every value in the support, so every sample must be truthy.
    predicate = lambda X: (X in {'a', 'b'}) or X in {'c'}
    samples = spe.sample_func(predicate, 100)
    assert all(samples)

    # The negation of the predicate above, so every sample must be falsy.
    predicate = lambda X: (not (X in {'a', 'b'})) and (not (X in {'c'}))
    samples = spe.sample_func(predicate, 100)
    assert not any(samples)

    # With P('a') = 1/5 roughly 20 of 100 results are 1 and 80 are None; the
    # thresholds below leave slack for the fixed seed.
    func = lambda X: 1 if X in {'a'} else None
    samples = spe.sample_func(func, 100, prng=numpy.random.RandomState(1))
    assert sum(1 for s in samples if s == 1) > 12
    assert sum(1 for s in samples if s is None) > 70

    # NOTE(review): presumably rejected because the lambda's parameter name
    # 'Y' does not match any symbol of the leaf -- confirm in sample_func.
    with pytest.raises(ValueError):
        spe.sample_func(lambda Y: Y, 100)

    # Conditioning renormalises 'a' and 'b' to 1/2 each; the reported support
    # still lists all three original outcomes.
    spe_condition = spe.condition(X<<{'a', 'b'})
    assert spe_condition.support == FiniteNominal('a', 'b', 'c')
    assert allclose(spe_condition.logprob(X << {'a'}), -log(2))
    assert allclose(spe_condition.logprob(X << {'b'}), -log(2))
    assert spe_condition.logprob(X << {'c'}) == -float('inf')

    # An event on a numeric transform of a nominal variable has probability 0.
    assert isinf_neg(spe_condition.logprob(X**2 << {1}))

    # Conditioning on a probability-zero event is an error, while conditioning
    # on its complement leaves the leaf unchanged.
    with pytest.raises(ValueError):
        spe.condition(X << {'python'})
    assert spe.condition(~(X << {'python'})) == spe

--------------------------------------------------------------------------------
/tests/test_parse_distributions.py:
--------------------------------------------------------------------------------
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from math import log

import pytest

from sppl.distributions import DistributionMix
from sppl.distributions import bernoulli
from sppl.distributions import choice
from sppl.distributions import discrete
from sppl.distributions import norm
from sppl.distributions import poisson
from sppl.distributions import rv_discrete
from sppl.distributions import uniformd
from sppl.math_util import allclose
from sppl.sets import FiniteNominal
from sppl.sets import Interval
from sppl.sets import inf as oo
from sppl.spe import ContinuousLeaf
from sppl.spe import DiscreteLeaf
from sppl.spe import NominalLeaf
from sppl.spe import SumSPE
from sppl.transforms import Id

X = Id('X')

def test_simple_parse_real():
    """A weighted mix of real-valued distributions applied to X is a SumSPE.

    Checks the log-weights, the leaf type of each child and each child's
    support.
    """
    assert isinstance(.3*bernoulli(p=.1), DistributionMix)
    a = .3*bernoulli(p=.1) | .5 * norm() | .2*poisson(mu=7)
    spe = a(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(.3), log(.5), log(.2)])
    assert isinstance(spe.children[0], DiscreteLeaf)
    assert isinstance(spe.children[1], ContinuousLeaf)
    assert isinstance(spe.children[2], DiscreteLeaf)
    assert spe.children[0].support == Interval(0, 1)
    assert spe.children[1].support == Interval(-oo, oo)
    assert spe.children[2].support == Interval(0, oo)

def test_simple_parse_nominal():
    """A mix may combine a real-valued and a nominal distribution."""
    assert isinstance(.7 * choice({'a': .1, 'b': .9}), DistributionMix)
    a = .3*bernoulli(p=.1) | .7*choice({'a': .1, 'b': .9})
    spe = a(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(.3), log(.7)])
    assert isinstance(spe.children[0], DiscreteLeaf)
    assert isinstance(spe.children[1], NominalLeaf)
    assert spe.children[0].support == Interval(0, 1)
    assert spe.children[1].support == FiniteNominal('a', 'b')

def test_error():
    """Invalid weights are rejected."""
    # A non-numeric weight is a TypeError at construction time.
    with pytest.raises(TypeError):
        'a'*bernoulli(p=.1)
    # The weights .1 and .7 do not sum to one; applying the mix to a symbol
    # is expected to fail (presumably for that reason -- confirm).
    a = .1 *bernoulli(p=.1) | .7*poisson(mu=8)
    with pytest.raises(Exception):
        a(X)

def test_parse_rv_discrete():
    """rv_discrete and discrete describe the same finite distribution."""
    for dist in [
        rv_discrete(values=((1, 2, 10), (.3, .5, .2))),
        discrete({1: .3, 2: .5, 10: .2})
    ]:
        spe = dist(X)
        assert spe.support == Interval(1, 10)
        assert allclose(spe.prob(X<<{1}), .3)
        assert allclose(spe.prob(X<<{2}), .5)
        assert allclose(spe.prob(X<<{10}), .2)
        assert allclose(spe.prob(X<=10), 1)

    # uniformd puts equal mass (here 1/4) on each listed value; the support
    # spans the smallest to the largest value.
    dist = uniformd(values=((1, 2, 10, 0)))
    spe = dist(X)
    assert spe.support == Interval(0, 10)
    assert allclose(spe.prob(X<<{1}), .25)
    assert allclose(spe.prob(X<<{2}), .25)
    assert allclose(spe.prob(X<<{10}), .25)
    assert allclose(spe.prob(X<<{0}), .25)

--------------------------------------------------------------------------------
/tests/test_parse_spe.py:
--------------------------------------------------------------------------------
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from math import log

import pytest

from sppl.spe import ContinuousLeaf
from sppl.spe import PartialSumSPE
from sppl.spe import ProductSPE
from sppl.spe import SumSPE

from sppl.distributions import gamma
from sppl.distributions import norm
from sppl.transforms import Id

from sppl.math_util import allclose

X = Id('X')
Y = Id('Y')
Z = Id('Z')

def test_mul_leaf():
    """Scaling a leaf by a weight (on either side) gives a PartialSumSPE."""
    for y in [0.3 * (X >> norm()), (X >> norm()) * 0.3]:
        assert isinstance(y, PartialSumSPE)
        assert len(y.weights) == 1
        assert allclose(float(sum(y.weights)), 0.3)

def test_sum_leaf():
    """Exercise `|` on weighted leaves, including the error cases.

    Weights that sum to less than one yield a PartialSumSPE holding the raw
    weights; weights that sum to one yield a SumSPE holding log-weights.
    """
    # Cannot sum leaves without weights.
    with pytest.raises(TypeError):
        (X >> norm()) | (X >> gamma(a=1))
    # Cannot sum a leaf with a partial sum.
    with pytest.raises(TypeError):
        0.3*(X >> norm()) | (X >> gamma(a=1))
    # Cannot sum a leaf with a partial sum.
    with pytest.raises(TypeError):
        (X >> norm()) | 0.3*(X >> gamma(a=1))
    # Wrong symbol.
    with pytest.raises(ValueError):
        0.4*(X >> norm()) | 0.6*(Y >> gamma(a=1))
    # Sum exceeds one (the two terms also use different symbols).
    with pytest.raises(ValueError):
        0.4*(X >> norm()) | 0.7*(Y >> gamma(a=1))

    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1))
    assert isinstance(y, PartialSumSPE)
    assert len(y.weights) == 2
    assert allclose(float(y.weights[0]), 0.4)
    assert allclose(float(y.weights[1]), 0.3)

    y = 0.4*(X >> norm()) | 0.6*(X >> gamma(a=1))
    assert isinstance(y, SumSPE)
    assert len(y.weights) == 2
    assert allclose(float(y.weights[0]), log(0.4))
    assert allclose(float(y.weights[1]), log(0.6))
    # y is already a complete SumSPE (weights sum to one); adding a further
    # weighted term to it is rejected with a TypeError.
    with pytest.raises(TypeError):
        y | 0.7 * (X >> norm())

    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1)) | 0.1*(X >> norm())
    assert isinstance(y, PartialSumSPE)
    assert len(y.weights) == 3
    assert allclose(float(y.weights[0]), 0.4)
    assert allclose(float(y.weights[1]), 0.3)
    assert allclose(float(y.weights[2]), 0.1)

    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1)) | 0.3*(X >> norm())
    assert isinstance(y, SumSPE)
    assert len(y.weights) == 3
    assert allclose(float(y.weights[0]), log(0.4))
    assert allclose(float(y.weights[1]), log(0.3))
    assert allclose(float(y.weights[2]), log(0.3))

    # A partial sum cannot itself be re-weighted, from either side.
    with pytest.raises(TypeError):
        (0.3)*(0.3*(X >> norm()))
    with pytest.raises(TypeError):
        (0.3*(X >> norm())) * (0.3)
    with pytest.raises(TypeError):
        0.3*(0.3*(X >> norm()) | 0.5*(X >> norm()))

    # A complete SumSPE, in contrast, can be weighted.
    w = 0.3*(0.4*(X >> norm()) | 0.6*(X >> norm()))
    assert isinstance(w, PartialSumSPE)

def test_product_leaf():
    """Exercise `&` on leaves: partial sums and repeated symbols are errors."""
    with pytest.raises(TypeError):
        0.3*(X >> gamma(a=1)) & (X >> norm())
    with pytest.raises(TypeError):
        (X >> norm()) & 0.3*(X >> gamma(a=1))
    # Both factors are over the same symbol X.
    with pytest.raises(ValueError):
        (X >> norm()) & (X >> gamma(a=1))

    y = (X >> norm()) & (Y >> gamma(a=1)) & (Z >> norm())
    assert isinstance(y, ProductSPE)
    assert len(y.children) == 3
    assert y.get_symbols() == frozenset([X, Y, Z])

def test_sum_of_sums():
    """Nested weighted sums are flattened, multiplying weights through.

    In log space the flattened weight of each leaf is the sum of its outer
    and inner log-weights.
    """
    w \
        = 0.3*(0.4*(X >> norm()) | 0.6*(X >> norm())) \
        | 0.7*(0.1*(X >> norm()) | 0.9*(X >> norm()))
    assert isinstance(w, SumSPE)
    assert len(w.children) == 4
    assert allclose(float(w.weights[0]), log(0.3) + log(0.4))
    assert allclose(float(w.weights[1]), log(0.3) + log(0.6))
    assert allclose(float(w.weights[2]), log(0.7) + log(0.1))
    assert allclose(float(w.weights[3]), log(0.7) + log(0.9))

    # Outer weights 0.3 + 0.2 < 1: the result stays partial and keeps only
    # the two outer weights.
    w \
        = 0.3*(0.4*(X >> norm()) | 0.6*(X >> norm())) \
        | 0.2*(0.1*(X >> norm()) | 0.9*(X >> norm()))
    assert isinstance(w, PartialSumSPE)
    assert allclose(float(w.weights[0]), 0.3)
    assert allclose(float(w.weights[1]), 0.2)

    # Completing the partial sum flattens all five leaves.
    a = w | 0.5*(X >> gamma(a=1))
    assert isinstance(a, SumSPE)
    assert len(a.children) == 5
    assert allclose(float(a.weights[0]), log(0.3) + log(0.4))
    assert allclose(float(a.weights[1]), log(0.3) + log(0.6))
    assert allclose(float(a.weights[2]), log(0.2) + log(0.1))
    assert allclose(float(a.weights[3]), log(0.2) + log(0.9))
    assert allclose(float(a.weights[4]), log(0.5))

    # Wrong symbol.
    with pytest.raises(ValueError):
        z = w | 0.4*(Y >> gamma(a=1))

def test_or_and():
    """A complete sum over one symbol can be a factor of a product."""
    # The sum mixes symbols X and Y, which is rejected.
    with pytest.raises(ValueError):
        (0.3*(X >> norm()) | 0.7*(Y >> gamma(a=1))) & (Z >> norm())
    a = (0.3*(X >> norm()) | 0.7*(X >> gamma(a=1))) & (Z >> norm())
    assert isinstance(a, ProductSPE)
    assert isinstance(a.children[0], SumSPE)
    assert isinstance(a.children[1], ContinuousLeaf)

--------------------------------------------------------------------------------
/tests/test_poly.py:
--------------------------------------------------------------------------------
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from sympy import Poly as SymPoly
from sympy import Rational
from sympy import sqrt as SymSqrt
from sympy.abc import x

from sppl.poly import solve_poly_equality
from sppl.poly import solve_poly_inequality

from sppl.math_util import allclose

from sppl.sets import EmptySet
from sppl.sets import ExtReals
from sppl.sets import FiniteReal
from sppl.sets import Interval
from sppl.sets import Reals
from sppl.sets import Union
from sppl.sets import inf as oo

# NOTE(review): judging by the assertions in this module, the third argument
# of solve_poly_inequality selects a strict inequality (True -> open
# endpoints, False -> closed endpoints) -- confirm against sppl.poly.

def test_solve_poly_inequaltiy_pos_inf():
    """Solve inequalities whose right-hand side is +infinity."""
    assert solve_poly_inequality(x**2-10*x+100, oo, True) == Reals
    assert solve_poly_inequality(x**2-10*x+100, oo, False) == ExtReals

    assert solve_poly_inequality(-x**3+10*x, oo, False) == ExtReals
    assert solve_poly_inequality(-x**3+10*x, oo, True) == Reals | FiniteReal(oo)

    assert solve_poly_inequality(x**3-10*x, oo, False) == ExtReals
    assert solve_poly_inequality(x**3-10*x, oo, True) == Reals | FiniteReal(-oo)


def test_solve_poly_inequaltiy_neg_inf():
    """Solve inequalities whose right-hand side is -infinity."""
    assert solve_poly_inequality(x**2-10*x+100, -oo, True) is EmptySet
    assert solve_poly_inequality(x**2-10*x+100, -oo, False) is EmptySet

    assert solve_poly_inequality(x**3-10*x, -oo, True) is EmptySet
    assert solve_poly_inequality(x**3-10*x, -oo, False) == FiniteReal(-oo)

    assert solve_poly_inequality(-x**2+10*x+100, -oo, True) is EmptySet
    assert solve_poly_inequality(-x**2+10*x+100, -oo, False) == FiniteReal(-oo, oo)


# Quadratic with one irrational root (sqrt(2)/10) and one rational (-10/7).
p_quadratic = SymPoly((x-SymSqrt(2)/10)*(x+Rational(10, 7)), x)
expr_quadratic = p_quadratic.args[0]
def test_solve_poly_equality_quadratic_zero():
    """Roots of the quadratic are returned exactly."""
    roots = solve_poly_equality(expr_quadratic, 0)
    assert roots == FiniteReal(SymSqrt(2)/10, -Rational(10,7))
def test_solve_poly_inequality_quadratic_zero():
    """The quadratic is non-positive between its two roots."""
    interval = solve_poly_inequality(expr_quadratic, 0, False)
    assert interval == Interval(-Rational(10,7), SymSqrt(2)/10)
    interval = solve_poly_inequality(expr_quadratic, 0, True)
    assert interval == Interval.open(-Rational(10,7), SymSqrt(2)/10)

# Expected solutions of expr_quadratic == 1.
xe1_quad0 = -5/7 + SymSqrt(2)/20 + SymSqrt(2)*SymSqrt(700*SymSqrt(2) + 14849)/140
xe1_quad1 = -SymSqrt(2)*SymSqrt(700*SymSqrt(2) + 14849)/140 - 5/7 + SymSqrt(2)/20
def test_solve_poly_equality_quadratic_one():
    """Solutions of the quadratic equal to one match the closed forms."""
    roots = solve_poly_equality(expr_quadratic, 1)
    # SymPy is not smart enough to simplify irrational roots symbolically
    # so check numerical equality of the symbolic roots.
    assert len(roots) == 2
    assert any(allclose(float(x), float(xe1_quad0)) for x in roots)
    assert any(allclose(float(x), float(xe1_quad1)) for x in roots)


# Cubic with integer roots -2, 1 and 11.
p_cubic_int = SymPoly((x-1)*(x+2)*(x-11), x)
expr_cubic_int = p_cubic_int.args[0]
def test_solve_poly_equality_cubic_int_zero():
    """Integer roots of the cubic are returned exactly."""
    roots = solve_poly_equality(expr_cubic_int, 0)
    assert roots == FiniteReal(-2, 1, 11)
def test_solve_poly_inequality_cubic_int_zero():
    """The cubic is non-positive on (-oo, -2] and on [1, 11]."""
    interval = solve_poly_inequality(expr_cubic_int, 0, False)
    assert interval == Interval(-oo, -2) | Interval(1, 11)
    interval = solve_poly_inequality(expr_cubic_int, 0, True)
    assert interval == Interval.open(-oo, -2) | Interval.open(1, 11)

# Numerical solutions of expr_cubic_int == 1.
xe1_cubic_int0 = -1.97408387376586
xe1_cubic_int1 = 0.966402009973818
xe1_cubic_int2 = 11.007681863792
def test_solve_poly_equality_cubic_int_one():
    """Solutions of the cubic equal to one match the numerical values."""
    roots = solve_poly_equality(expr_cubic_int, 1)
    assert len(roots) == 3
    assert any(allclose(float(x), xe1_cubic_int0) for x in roots)
    assert any(allclose(float(x), xe1_cubic_int1) for x in roots)
    assert any(allclose(float(x), xe1_cubic_int2) for x in roots)
def test_solve_poly_inequality_cubic_int_one():
    """Strict inequality on the cubic yields a union of two open intervals.

    Negating both the polynomial and the bound gives the complementary pair
    of intervals over the same three breakpoints.
    """
    interval = solve_poly_inequality(expr_cubic_int, 1, True)
    assert isinstance(interval, Union)
    assert len(interval.args)==2
    # First interval.
    assert interval.args[0].left == -oo
    assert allclose(float(interval.args[0].right), xe1_cubic_int0)
    assert interval.args[0].right_open
    # Second interval.
    assert allclose(float(interval.args[1].left), xe1_cubic_int1)
    assert allclose(float(interval.args[1].right), xe1_cubic_int2)
    assert interval.args[1].left_open
    assert interval.args[1].right_open

    interval = solve_poly_inequality(-1*expr_cubic_int, -1, True)
    assert isinstance(interval, Union)
    assert len(interval.args) == 2
    # First interval.
    assert allclose(float(interval.args[0].left), xe1_cubic_int0)
    assert allclose(float(interval.args[0].right), xe1_cubic_int1)
    # Second interval.
    assert allclose(float(interval.args[1].left), xe1_cubic_int2)
    assert interval.args[1].right == oo


# Cubic with two irrational roots (sqrt(2)/10, sqrt(5)) and one rational.
p_cubic_irrat = SymPoly((x-SymSqrt(2)/10)*(x+Rational(10, 7))*(x-SymSqrt(5)), x)
expr_cubic_irrat = p_cubic_irrat.args[0]
def test_solve_poly_equality_cubic_irrat_zero():
    """Roots of the irrational cubic are found symbolically."""
    roots = solve_poly_equality(expr_cubic_irrat, 0)
    # Confirm that roots contains symbolic elements (no timeout).
    assert -Rational(10,7) in roots
    # SymPy is not smart enough to simplify irrational roots symbolically
    # so check numerical equality of the symbolic roots.
    assert any(allclose(float(x), float(SymSqrt(2)/10)) for x in roots)
    assert any(allclose(float(x), float(SymSqrt(5))) for x in roots)

# Numerical solutions of expr_cubic_irrat == 1.
xe1_cubic_irrat0 = -1.21493150246058
xe1_cubic_irrat1 = -0.19158140952462
xe1_cubic_irrat2 = 2.35543081715088
def test_solve_poly_equality_cubic_irrat_one():
    """Solutions of the irrational cubic equal to one are found numerically."""
    # This expression is too slow to solve symbolically.
    # The 5s timeout will trigger a numerical approximation.
    roots = solve_poly_equality(expr_cubic_irrat, 1)
    assert len(roots) == 3
    assert any(allclose(float(x), xe1_cubic_irrat0) for x in roots)
    assert any(allclose(float(x), xe1_cubic_irrat1) for x in roots)
    assert any(allclose(float(x), xe1_cubic_irrat2) for x in roots)

--------------------------------------------------------------------------------
/tests/test_real_continuous.py:
--------------------------------------------------------------------------------
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from math import log

import numpy
import pytest
import scipy.stats

from sppl.distributions import gamma
from sppl.distributions import norm
from sppl.math_util import allclose
from sppl.math_util import isinf_neg
from sppl.math_util import logdiffexp
from sppl.sets import Interval
from sppl.sets import Reals
from sppl.sets import inf as oo
from sppl.spe import ContinuousLeaf
from sppl.spe import SumSPE
from sppl.transforms import Id

def test_numeric_distribution_normal():
    """Exercise logprob, sampling and conditioning of a standard normal leaf."""
    X = Id('X')
    spe = (X >> norm(loc=0, scale=1))

    assert spe.size() == 1
    assert allclose(spe.logprob(X > 0), -log(2))
    assert allclose(spe.logprob(abs(X) < 2), log(spe.dist.cdf(2) - spe.dist.cdf(-2)))

    # Events that hold everywhere except on a set of measure zero.
    assert allclose(spe.logprob(X**2 > 0), 0)
    assert allclose(spe.logprob(abs(X) > 0), 0)
    assert allclose(spe.logprob(~(X << {1})), 0)

    # Events that are empty or have measure zero.
    assert isinf_neg(spe.logprob(X**2 - X + 10 < 0))
    assert isinf_neg(spe.logprob(abs(X) < 0))
    assert isinf_neg(spe.logprob(X << {1}))

    # Smoke tests: these calls only need to run without raising.
    spe.sample(100)
    spe.sample_subset([X], 100)
    assert spe.sample_subset([], 100) == [{}]*100
    spe.sample_func(lambda X: X**2, 1)
    spe.sample_func(lambda X: abs(X)+X**2, 1)
    spe.sample_func(lambda X: X**2 if X > 0 else X**3, 100)

    # The X > 10 branch has negligible mass under a standard normal, so all
    # 100 draws are expected to fall below 2.
    spe_condition_a = spe.condition((X < 2) | (X > 10))
    samples = spe_condition_a.sample(100)
    assert all(s[X] < 2 for s in samples)

    # Two symmetric tails produce a mixture with equal weights of 1/2.
    spe_condition_b = spe.condition((X < -10) | (X > 10))
    assert isinstance(spe_condition_b, SumSPE)
    assert allclose(spe_condition_b.weights[0], -log(2))
    assert allclose(spe_condition_b.weights[0], spe_condition_b.weights[1])

    # Conditioning on a single interval keeps a single leaf.
    for event in [(X<-10), (X>3)]:
        spe_condition_c = spe.condition(event)
        assert isinstance(spe_condition_c, ContinuousLeaf)
        assert isinf_neg(spe_condition_c.logprob((-1 < X) < 1))
        samples = spe_condition_c.sample(100, prng=numpy.random.RandomState(1))
        assert all(s[X] in event.values for s in samples)

    # Conditioning on an empty event.
    with pytest.raises(ValueError):
        spe.condition((X > 1) & (X < 1))

    # Conditioning on a measure-zero point event.
    with pytest.raises(ValueError):
        spe.condition(X << {1})

    # NOTE(review): presumably rejected because the lambda's parameter name
    # 'Z' does not match any symbol of the leaf -- confirm in sample_func.
    with pytest.raises(ValueError):
        spe.sample_func(lambda Z: Z**2, 1)

    # Finite points add no mass to an interval event.
    x = spe.logprob((X << {1, 2}) | (X < -1))
    assert allclose(x, spe.logprob(X < -1))

    # Querying an event over a symbol the leaf does not contain.
    with pytest.raises(AssertionError):
        spe.logprob(Id('Y') << {1, 2})

def test_numeric_distribution_gamma():
    """Exercise conditioning and constraining of gamma-distributed leaves."""
    X = Id('X')

    spe = (X >> gamma(a=1, scale=1))
    with pytest.raises(ValueError):
        spe.condition((X << {1, 2}) | (X < 0))

    # Intentionally set Reals as the domain to exercise an important
    # code path in dist.condition (Union case with zero weights).
    spe = ContinuousLeaf(X, scipy.stats.gamma(a=1, scale=1), Reals)
    assert isinf_neg(spe.logprob((X << {1, 2}) | (X < 0)))
    with pytest.raises(ValueError):
        spe.condition((X << {1, 2}) | (X < 0))

    spe_condition = spe.condition((X << {1,2} | (X <= 3)))
    assert isinstance(spe_condition, ContinuousLeaf)
    assert spe_condition.conditioned
    assert spe_condition.support == Interval(-oo, 3)
    # The conditional log-probability is the unconditioned mass on (0, 2]
    # normalised by the conditioned leaf's logZ.
    assert allclose(
        spe_condition.logprob(X <= 2),
        logdiffexp(spe.logprob(X<=2), spe.logprob(X<=0))
            - spe_condition.logZ)

    # Support on (-3, oo)
    spe = (X >> gamma(loc=-3, a=1))
    assert spe.prob((-3 < X) < 0) > 0.95

    # Constrain.
    with pytest.raises(Exception):
        spe.constrain({X: -4})
    spe_constrain = spe.constrain({X: .5})
    samples = spe_constrain.sample(100, prng=None)
    assert all(s == {X: .5} for s in samples)

--------------------------------------------------------------------------------
/tests/test_real_discrete.py:
--------------------------------------------------------------------------------
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

import pytest

import numpy

from sppl.distributions import poisson
from sppl.distributions import randint
from sppl.math_util import allclose
from sppl.math_util import logdiffexp
from sppl.math_util import logsumexp
from sppl.sets import Interval
from sppl.sets import Range
from sppl.sets import inf as oo
from sppl.spe import DiscreteLeaf
from sppl.spe import SumSPE
from sppl.transforms import Id

def test_poisson():
    """Exercise logprob, conditioning and constraining of a Poisson leaf."""
    X = Id('X')
    spe = X >> poisson(mu=5)

    # Three equivalent ways of measuring the set {1, ..., 7}.
    a = spe.logprob((1 <= X) <= 7)
    b = spe.logprob(X << {1,2,3,4,5,6,7})
    c = logsumexp([spe.logprob(X << {i}) for i in range(1, 8)])
    assert allclose(a, b)
    assert allclose(a, c)
    assert allclose(b, c)

    spe_condition = spe.condition(10 <= X)
    assert spe_condition.conditioned
    assert spe_condition.support == Range(10, oo)
    assert spe_condition.logZ == logdiffexp(0, spe.logprob(X<=9))

    # On the conditioned leaf, X <= 10 is the same event as X == 10.
    assert allclose(
        spe_condition.logprob(X <= 10),
        spe_condition.logprob(X << {10}))
    assert allclose(
        spe_condition.logprob(X <= 10),
        spe_condition.logpdf({X: 10}))

    samples = spe_condition.sample(100)
    assert all(10 <= s[X] for s in samples)

    # Unify X = 5 with left interval to make one distribution.
    event = ((1 <= X) < 5) | ((3*X + 1) << {16})
    spe_condition = spe.condition(event)
    assert isinstance(spe_condition, DiscreteLeaf)
    assert spe_condition.conditioned
    assert spe_condition.xl == 1
    assert spe_condition.xu == 5
    assert spe_condition.support == Range(1, 5)
    samples = spe_condition.sample(100, prng=numpy.random.RandomState(1))
    assert all(event.evaluate(s) for s in samples)

    # Ignore X = 14/3 as a probability zero condition.
    spe_condition = spe.condition(((1 <= X) < 5) | (3*X + 1) << {15})
    assert isinstance(spe_condition, DiscreteLeaf)
    assert spe_condition.conditioned
    assert spe_condition.xl == 1
    assert spe_condition.xu == 4
    assert spe_condition.support == Interval.Ropen(1,5)

    # Make a mixture of two components.
    spe_condition = spe.condition(((1 <= X) < 5) | (3*X + 1) << {22})
    assert isinstance(spe_condition, SumSPE)
    # The order of the children is not assumed: locate the singleton
    # component {7} (idx0) and the range component [1, 4] (idx1) by xl.
    xl = spe_condition.children[0].xl
    idx0 = 0 if xl == 7 else 1
    idx1 = 1 if xl == 7 else 0
    assert spe_condition.children[idx1].conditioned
    assert spe_condition.children[idx1].xl == 1
    assert spe_condition.children[idx1].xu == 4
    assert spe_condition.children[idx0].conditioned
    assert spe_condition.children[idx0].xl == 7
    assert spe_condition.children[idx0].xu == 7
    assert spe_condition.children[idx0].support == Range(7, 7)

    # Condition on probability zero event.
    with pytest.raises(ValueError):
        spe.condition(((-3 <= X) < 0) | (3*X + 1) << {20})

    # Condition on FiniteReal contiguous.
    spe_condition = spe.condition(X << {1,2,3})
    assert spe_condition.xl == 1
    assert spe_condition.xu == 3
    assert allclose(spe_condition.logprob((1 <= X) <=3), 0)

    # Condition on single point.
    assert allclose(0, spe.condition(X << {2}).logprob(X<<{2}))

    # Constrain.
    with pytest.raises(Exception):
        spe.constrain({X: -1})
    with pytest.raises(Exception):
        spe.constrain({X: .5})
    spe_constrain = spe.constrain({X: 10})
    assert allclose(spe_constrain.prob(X << {0, 10}), 1)

def test_condition_non_contiguous():
    """Conditioning on non-contiguous integer sets yields one child per run.

    Values with zero mass under the Poisson (negative numbers, strings) are
    ignored.
    """
    X = Id('X')
    spe = X >> poisson(mu=5)
    # FiniteSet.
    for c in [{0,2,3}, {-1,0,2,3}, {-1,0,2,3,'z'}]:
        spe_condition = spe.condition((X << c))
        assert isinstance(spe_condition, SumSPE)
        assert allclose(0, spe_condition.children[0].logprob(X<<{0}))
        assert allclose(0, spe_condition.children[1].logprob(X<<{2,3}))
    # FiniteSet or Interval.
    spe_condition = spe.condition((X << {-1,'x',0,2,3}) | (X > 7))
    assert isinstance(spe_condition, SumSPE)
    assert len(spe_condition.children) == 3
    assert allclose(0, spe_condition.children[0].logprob(X<<{0}))
    assert allclose(0, spe_condition.children[1].logprob(X<<{2,3}))
    assert allclose(0, spe_condition.children[2].logprob(X>7))

def test_randint():
    """Exercise a discrete uniform leaf on {0, ..., 4} and its conditioning."""
    X = Id('X')
    spe = X >> randint(low=0, high=5)
    # The upper bound `high` is exclusive.
    assert spe.xl == 0
    assert spe.xu == 4
    assert spe.logprob(X < 5) == spe.logprob(X <= 4) == 0
    # i.e., X is not in the set {0, 3} (X + 1 is neither 1 nor 4), which
    # leaves the two runs {1, 2} and {4}.
    spe_condition = spe.condition(~((X+1) << {1, 4}))
    assert isinstance(spe_condition, SumSPE)
    # The order of the children is not assumed: locate them by xl.
    xl = spe_condition.children[0].xl
    idx0 = 0 if xl == 1 else 1
    idx1 = 1 if xl == 1 else 0
    assert spe_condition.children[idx0].xl == 1
    assert spe_condition.children[idx0].xu == 2
    assert spe_condition.children[idx1].xl == 4
    assert spe_condition.children[idx1].xu == 4
    assert allclose(spe_condition.children[idx0].logprob(X<<{1,2}), 0)
    assert allclose(spe_condition.children[idx1].logprob(X<<{4}), 0)

--------------------------------------------------------------------------------
/tests/test_render.py:
--------------------------------------------------------------------------------
# Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt 3 | 4 | import os 5 | 6 | import pytest 7 | 8 | from sppl.compilers.ast_to_spe import Id 9 | from sppl.compilers.ast_to_spe import IfElse 10 | from sppl.compilers.ast_to_spe import Otherwise 11 | from sppl.compilers.ast_to_spe import Sample 12 | from sppl.compilers.ast_to_spe import Sequence 13 | from sppl.compilers.ast_to_spe import Transform 14 | from sppl.distributions import bernoulli 15 | from sppl.distributions import choice 16 | from sppl.render import render_nested_lists 17 | from sppl.render import render_nested_lists_concise 18 | 19 | def get_model(): 20 | Y = Id('Y') 21 | X = Id('X') 22 | Z = Id('Z') 23 | command = Sequence( 24 | Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})), 25 | Sample(Z, bernoulli(p=0.1)), 26 | IfElse( 27 | Y << {str(0)} | Z << {0}, Sample(X, bernoulli(p=1/(0+1))), 28 | Otherwise, Transform(X, Z**2 + Z))) 29 | return command.interpret() 30 | 31 | def test_render_lists_crash(): 32 | model = get_model() 33 | render_nested_lists_concise(model) 34 | render_nested_lists(model) 35 | 36 | def test_render_graphviz_crash__magics_(): 37 | pytest.importorskip('graphviz') 38 | pytest.importorskip('pygraphviz') 39 | pytest.importorskip('networkx') 40 | 41 | from sppl.magics.render import render_networkx_graph 42 | from sppl.magics.render import render_graphviz 43 | 44 | model = get_model() 45 | render_networkx_graph(model) 46 | for fname in [None, '/tmp/spe.test.render']: 47 | for e in ['pdf', 'png', None]: 48 | render_graphviz(model, fname, ext=e) 49 | if fname is not None: 50 | assert not os.path.exists(fname) 51 | for ext in ['dot', e]: 52 | f = '%s.%s' % (fname, ext,) 53 | if e is not None: 54 | os.path.exists(f) 55 | os.unlink(f) 56 | -------------------------------------------------------------------------------- /tests/test_sets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sppl.sets import EmptySet 7 | from sppl.sets import FiniteNominal as FN 8 | from sppl.sets import FiniteReal as FR 9 | from sppl.sets import Interval 10 | from sppl.sets import Union 11 | from sppl.sets import inf 12 | from sppl.sets import union_intervals 13 | from sppl.sets import union_intervals_finite 14 | 15 | def test_FiniteNominal_in(): 16 | with pytest.raises(Exception): 17 | FN() 18 | assert 'a' in FN('a') 19 | assert 'b' not in FN('a') 20 | assert 'b' in FN('a', b=True) 21 | assert 'a' not in FN('a', b=True) 22 | assert 'a' in FN(b=True) 23 | 24 | def test_FiniteNominal_invert(): 25 | assert ~(FN('a')) == FN('a', b=True) 26 | assert ~(FN('a', b=True)) == FN('a') 27 | assert ~FN(b=True) == EmptySet 28 | 29 | def test_FiniteNominal_and(): 30 | assert FN('a','b') & EmptySet is EmptySet 31 | assert FN('a','b') & FN('c') is EmptySet 32 | assert FN('a','b','c') & FN('a') == FN('a') 33 | assert FN('a','b','c') & FN(b=True) == FN('a','b','c') 34 | assert FN('a','b','c') & FN('a') == FN('a') 35 | assert FN('a','b','c', b=True) & FN('a') is EmptySet 36 | assert FN('a','b','c', b=True) & FN('d','a','b') == FN('d') 37 | assert FN('a','b','c', b=True) & FN('d') == FN('d') 38 | assert FN('a','b','c') & FN('a', b=True) == FN('b','c') 39 | assert FN('a','b','c') & FN('d','a','b', b=True) == FN('c') 40 | assert FN('a','b','c') & FN('d', b=True) == FN('a','b','c') 41 | assert FN('a','b','c', b=True) & FN('d', b=True) == FN('a','b','c','d', b=True) 42 | assert FN('a','b','c', b=True) & FN('a', b=True) == FN('a','b','c', b=True) 43 | assert FN(b=True) & FN(b=True) == FN(b=True) 44 | # FiniteReal 45 | assert FN('a') & FR(1) is EmptySet 46 | assert FN('a', b=True) & FR(1) is EmptySet 47 | # Interval 48 | assert FN('a') & Interval(0,1) is EmptySet 49 | 50 | def test_FiniteNominal_or(): 51 | # EmptySet 52 | assert FN('a','b') | EmptySet == FN('a','b') 53 | # Nominal 54 | assert FN('a','b') | FN('c') == FN('a','b','c') 55 | 
assert FN('a','b','c') | FN('a') == FN('a','b','c') 56 | assert FN('a','b','c') | FN(b=True) == FN(b=True) 57 | assert FN('a','b','c', b=True) | FN('a') == FN('b','c', b=True) 58 | assert FN('a','b','c', b=True) | FN('d','a','b') == FN('c', b=True) 59 | assert FN('a','b','c', b=True) | FN('d') == FN('a','b','c', b=True) 60 | assert FN('a','b','c') | FN('a', b=True) == FN(b=True) 61 | assert FN('a','b','c') | FN('d','a','b', b=True) == FN('d', b=True) 62 | assert FN('a','b','c') | FN('d', b=True) == FN('d', b=True) 63 | assert FN('a','b','c', b=True) | FN('d', b=True) == FN(b=True) 64 | assert FN('a','b','c', b=True) | FN('a', b=True) == FN('a', b=True) 65 | assert FN(b=True) | FN(b=True) == FN(b=True) 66 | # FiniteReal 67 | assert FN('a') | FR(1) == Union(FN('a'), FR(1)) 68 | assert FN('a', b=True) | FR(1) == Union(FR(1), FN('a', b=True)) 69 | # Interval 70 | assert FN('a') | Interval(0,1) == Union(FN('a'), Interval(0,1)) 71 | 72 | def test_FiniteReal_in(): 73 | with pytest.raises(Exception): 74 | FR() 75 | assert 1 in FN(1,2) 76 | assert 2 not in FN(1) 77 | 78 | def test_FiniteReal_invert(): 79 | assert ~FR(1) == Union( 80 | Interval.Ropen(-inf,1), 81 | Interval.Lopen(1, inf)) 82 | assert ~(FR(0,-1)) == Union( 83 | Interval.Ropen(-inf,-1), 84 | Interval.open(-1, 0), 85 | Interval.Lopen(0, inf)) 86 | 87 | def test_FiniteReal_and(): 88 | assert FR(1) & FR(2) is EmptySet 89 | assert FR(1,2) & FR(2) == FR(2) 90 | assert FR(1,2) & FN('2') is EmptySet 91 | assert FR(1,2) & FN('2', b=True) is EmptySet 92 | assert FR(1,2) & Interval(0,1) == FR(1) 93 | assert FR(1,2) & Interval.Ropen(0,1) is EmptySet 94 | assert FR(0,2) & Interval(0,1) == FR(0) 95 | assert FR(0,2) & Interval.Lopen(0,1) is EmptySet 96 | assert FR(-1,1) & Interval(-10,10) == FR(-1,1) 97 | assert FR(-1,11) & Interval(-10,10) == FR(-1) 98 | 99 | def test_FiniteReal_or(): 100 | assert FR(1) | FR(2) == FR(1,2) 101 | assert FR(1,2) | FR(2) == FR(1,2) 102 | assert FR(1,2) | FN('2') == Union(FR(1,2), FN('2')) 103 | 
assert FR(1,2) | Interval(0,1) == Union(Interval(0,1), FR(2)) 104 | assert FR(1,2) | Interval.Ropen(0,1) == Union(Interval(0,1), FR(2)) 105 | assert FR(0,1,2) | Interval.open(0,1) == Union(Interval(0,1), FR(2)) 106 | assert FR(0,2) | Interval.Lopen(0,1) == Union(Interval(0,1), FR(2)) 107 | assert FR(0,2) | Interval.Lopen(2.5,10) == Union(Interval.Lopen(2.5,10), FR(0,2)) 108 | assert FR(-1,1) | Interval(-10,10) == Interval(-10,10) 109 | assert FR(-1,11) | Interval(-10,10) == Union(Interval(-10, 10), FR(11)) 110 | 111 | def test_Interval_in(): 112 | with pytest.raises(Exception): 113 | Interval(3, 1) 114 | assert 1 in Interval(0,1) 115 | assert 1 not in Interval.Ropen(0,1) 116 | assert 1 in Interval.Lopen(0,1) 117 | assert 0 in Interval(0,1) 118 | assert 0 not in Interval.Lopen(0,1) 119 | assert 0 in Interval.Ropen(0,1) 120 | assert inf not in Interval(-inf, inf) 121 | assert -inf not in Interval(-inf, 0) 122 | assert 10 in Interval(-inf, inf) 123 | 124 | def test_Interval_invert(): 125 | assert ~(Interval(0,1)) == Union(Interval.Ropen(-inf, 0), Interval.Lopen(1, inf)) 126 | assert ~(Interval.open(0,1)) == Union(Interval(-inf, 0), Interval(1, inf)) 127 | assert ~(Interval.Lopen(0,1)) == Union(Interval(-inf, 0), Interval.Lopen(1, inf)) 128 | assert ~(Interval.Ropen(0,1)) == Union(Interval.Ropen(-inf, 0), Interval(1, inf)) 129 | assert ~(Interval(-inf, inf)) is EmptySet 130 | assert ~(Interval(3, inf)) == Interval.Ropen(-inf, 3) 131 | assert ~(Interval.open(3, inf)) == Interval(-inf, 3) 132 | assert ~(Interval.Lopen(3, inf)) == Interval(-inf, 3) 133 | assert ~(Interval.Ropen(3, inf)) == Interval.Ropen(-inf, 3) 134 | assert ~(Interval(-inf, 3)) == Interval.Lopen(3, inf) 135 | assert ~(Interval.open(-inf, 3)) == Interval(3, inf) 136 | assert ~(Interval.Lopen(-inf, 3)) == Interval.Lopen(3, inf) 137 | assert ~(Interval.Ropen(-inf, 3)) == Interval(3, inf) 138 | assert ~(Interval.open(-inf, inf)) is EmptySet 139 | 140 | def test_Interval_and(): 141 | assert Interval(0,1) & 
Interval(-1,1) == Interval(0,1) 142 | assert Interval(0,2) & Interval.open(0,1) == Interval.open(0,1) 143 | assert Interval.Lopen(0,1) & Interval(-1,1) == Interval.Lopen(0,1) 144 | assert Interval.Ropen(0,1) & Interval(-1,1) == Interval.Ropen(0,1) 145 | assert Interval(0,1) & Interval(1,2) == FR(1) 146 | assert Interval.Lopen(0,1) & Interval(1,2) == FR(1) 147 | assert Interval.Ropen(0,1) & Interval(1,2) is EmptySet 148 | assert Interval.Lopen(0,1) & Interval.Lopen(1,2) is EmptySet 149 | assert Interval(1,2) & Interval.Lopen(0,1) == FR(1) 150 | assert Interval(1,2) & Interval.open(0,1) is EmptySet 151 | assert Interval(0,2) & Interval.Lopen(0.5,2.5) == Interval.Lopen(0.5,2) 152 | assert Interval.Ropen(0,2) & Interval.Lopen(0.5,2.5) == Interval.open(0.5,2) 153 | assert Interval.open(0,2) & Interval(0.5,2.5) == Interval.Ropen(0.5,2) 154 | assert Interval.Lopen(0,2) & Interval.Ropen(0,2) == Interval.open(0,2) 155 | assert Interval(0,1) & Interval(2,3) is EmptySet 156 | assert Interval(2,3) & Interval(0,1) is EmptySet 157 | assert Interval.open(0,1) & Interval.open(0,1) == Interval.open(0,1) 158 | assert Interval.Ropen(-inf, -3) & Interval(-inf, inf) == Interval.Ropen(-inf, -3) 159 | assert Interval(-inf, inf) & Interval.Ropen(-inf, -3) == Interval.Ropen(-inf, -3) 160 | assert Interval(0, inf) & (Interval.Lopen(-5, inf)) == Interval(0, inf) 161 | assert Interval.Lopen(0, 1) & Interval.Ropen(0, 1) == Interval.open(0, 1) 162 | assert Interval.Ropen(0, 1) & Interval.Lopen(0, 1) == Interval.open(0, 1) 163 | assert Interval.Ropen(0, 5) & Interval.Ropen(-inf, 5) == Interval.Ropen(0, 5) 164 | 165 | def test_Interval_or(): 166 | assert Interval(0,1) | Interval(-1,1) == Interval(-1,1) 167 | assert Interval(0, 2) | Interval.open(0, 1) == Interval(0, 2) 168 | assert Interval.Lopen(0,1) | Interval(-1,1) == Interval(-1,1) 169 | assert Interval.Ropen(0,1) | Interval(-1,1) == Interval(-1,1) 170 | assert Interval(0,1) | Interval(1, 2) == Interval(0, 2) 171 | assert Interval.open(0, 1) 
| Interval(0,1) == Interval(0, 1) 172 | assert Interval(0, 1) | Interval(0,.5) == Interval(0, 1) 173 | assert Interval.Lopen(0, 1) | Interval(1,2) == Interval.Lopen(0, 2) 174 | assert Interval.Ropen(-1, 0) | Interval.Ropen(0, 1) == Interval.Ropen(-1,1) 175 | assert Interval.Ropen(0, 1) | Interval.Ropen(-1, 0) == Interval.Ropen(-1,1) 176 | assert Interval.Lopen(0, 1) | Interval.Lopen(1, 2) == Interval.Lopen(0,2) 177 | assert Interval.Lopen(0, 1) | Interval.Ropen(1, 2) == Interval.open(0,2) 178 | assert Interval.open(0, 2) | Interval(0, 1) == Interval.Ropen(0, 2) 179 | assert Interval.open(0, 1) | Interval.Ropen(-1, 0) == Union(Interval.open(0, 1), Interval.Ropen(-1,0)) 180 | assert Interval(1, 2) | Interval.Ropen(0, 1) == Interval(0,2) 181 | assert Interval.Ropen(0, 1) | Interval(1, 2) == Interval(0,2) 182 | assert Interval.open(0, 1) | Interval(1, 2) == Interval.Lopen(0,2) 183 | assert Interval.Ropen(0, 1) | Interval.Ropen(1, 2) == Interval.Ropen(0,2) 184 | assert Interval(1, 2) | Interval.open(0, 1) == Interval.Lopen(0,2) 185 | assert Interval(1, 2) | Interval.Lopen(0, 1) == Interval.Lopen(0,2) 186 | assert Interval(1, 2) | Interval.Ropen(0, 1) == Interval(0,2) 187 | assert Interval.open(0,1) | Interval(1,2) == Interval.Lopen(0,2) 188 | assert Interval.Lopen(0,1) | Interval(1,2) == Interval.Lopen(0,2) 189 | assert Interval.Ropen(0,1) | Interval(1,2) == Interval(0,2) 190 | assert Interval(0,2) | Interval.open(0.5, 2.5) == Interval.Ropen(0, 2.5) 191 | assert Interval.open(0,2) | Interval.open(0, 2.5) == Interval.open(0, 2.5) 192 | assert Interval.open(0,2.5) | Interval.open(0, 2) == Interval.open(0, 2.5) 193 | assert Interval.open(0,1) | Interval.open(1, 2) == Union(Interval.open(0, 1), Interval.open(1,2)) 194 | assert Interval.Ropen(0,2) | Interval.Lopen(0.5, 2.5) == Interval(0, 2.5) 195 | assert Interval.open(0,2) | Interval(0.5,2.5) == Interval.Lopen(0, 2.5) 196 | assert Interval.Lopen(0,2) | Interval.Ropen(0,2) == Interval(0,2) 197 | assert Interval(0,1) | 
Interval(2,3) == Union(Interval(0,1), Interval(2,3)) 198 | assert Interval(2,3) | Interval.Ropen(0,1) == Union(Interval(2,3), Interval.Ropen(0,1)) 199 | assert Interval.Lopen(0,1) | Interval(1,2) == Interval.Lopen(0,2) 200 | assert Interval(-10,10) | FR(-1,1) == Interval(-10,10) 201 | assert Interval(-10,10) | FR(-1,11) == Union(Interval(-10, 10), FR(11)) 202 | assert Interval(-inf, -3, right_open=True) | Interval(-inf, inf) == Interval(-inf, inf) 203 | 204 | def test_union_intervals(): 205 | assert union_intervals([ 206 | Interval(0,1), 207 | Interval(2,3), 208 | Interval(1,2) 209 | ]) == [Interval(0,3)] 210 | assert union_intervals([ 211 | Interval.open(0,1), 212 | Interval(2,3), 213 | Interval(1,2) 214 | ]) == [Interval.Lopen(0,3)] 215 | assert union_intervals([ 216 | Interval.open(0,1), 217 | Interval(2,3), 218 | Interval.Lopen(1,2) 219 | ]) == [Interval.open(0,1), Interval.Lopen(1,3)] 220 | assert union_intervals([ 221 | Interval.open(0,1), 222 | Interval.Ropen(0,3), 223 | Interval.Lopen(1,2) 224 | ]) == [Interval.Ropen(0,3)] 225 | assert union_intervals([ 226 | Interval.open(-2,-1), 227 | Interval.Ropen(0,3), 228 | Interval.Lopen(1,2) 229 | ]) == [Interval.open(-2,-1), Interval.Ropen(0,3)] 230 | 231 | def test_union_intervals_finite(): 232 | assert union_intervals_finite([ 233 | Interval.open(0,1), 234 | Interval(2,3), 235 | Interval.Lopen(1,2) 236 | ], FR(1)) \ 237 | == [Interval.Lopen(0, 3)] 238 | assert union_intervals_finite([ 239 | Interval.open(0,1), 240 | Interval.open(2, 3), 241 | Interval.open(1,2) 242 | ], FR(1, 3)) \ 243 | == [Interval.open(0, 2), Interval.Lopen(2, 3)] 244 | assert union_intervals_finite([ 245 | Interval.open(0,1), 246 | Interval.open(1, 3), 247 | Interval.open(11,15) 248 | ], FR(1, -11, -19, 3)) \ 249 | == [Interval.Lopen(0, 3), Interval.open(11,15), FR(-11, -19)] 250 | 251 | def test_Union_or(): 252 | x = Interval(0,1) | Interval(5,6) | Interval(10,11) 253 | assert x == Union(Interval(0,1), Interval(5,6), Interval(10,11)) 254 | x 
= Interval.Ropen(0,1) | Interval.Lopen(1,2) | Interval(10,11) 255 | assert x == Union(Interval.Ropen(0,1), Interval.Lopen(1,2), Interval(10,11)) 256 | x = Interval.Ropen(0,1) | Interval.Lopen(1,2) | Interval(10,11) | FR(1) 257 | assert x == Union(Interval(0,2), Interval(10,11)) 258 | x = (Interval.Ropen(0,1) | Interval.Lopen(1,2)) | (Interval(10,11) | FR(1)) 259 | assert x == Union(Interval(0,2), Interval(10,11)) 260 | x = FR(1) | ((Interval.Ropen(0,1) | Interval.Lopen(1,2) | FR(10,13)) \ 261 | | (Interval.Lopen(10,11) | FR(7))) 262 | assert x == Union(Interval(0,2), Interval(10,11), FR(13, 7)) 263 | assert 2 in x 264 | assert 13 in x 265 | x = FN('f') | (FR(1) | FN('g', b=True)) 266 | assert x == Union(FR(1), FN('g', b=True)) 267 | assert 'w' in x 268 | assert 'g' not in x 269 | 270 | def test_Union_and(): 271 | x = (Interval(0,1) | FR(1)) & (FN('a')) 272 | assert x is EmptySet 273 | x = (FN('x', b=True)| Interval(0,1) | FR(1)) & (FN('a')) 274 | assert x == FN('a') 275 | x = (FN('x')| Interval(0,1) | FR(1)) & (FN('a')) 276 | assert x is EmptySet 277 | x = (FN('x')| Interval.open(0,1) | FR(7)) & ((FN('x')) | FR(.5) | Interval(.75, 1.2)) 278 | assert x == Union(FR(.5), FN('x'), Interval.Ropen(.75, 1)) 279 | x = (FN('x')| Interval.open(0,1) | FR(7)) & (FR(3)) 280 | assert x is EmptySet 281 | x = (Interval.Lopen(-5, inf)) & (Interval(0, inf) | FR(inf)) 282 | assert x == Interval(0, inf) 283 | x = (FR(1,2) | Interval.Ropen(-inf, 0)) & Interval(0, inf) 284 | assert x == FR(1,2) 285 | x = (FR(1,12) | Interval(0, 5) | Interval(7,10)) & Interval(4, 12) 286 | assert x == Union(Interval(4,5), Interval(7,10), FR(12)) 287 | -------------------------------------------------------------------------------- /tests/test_spe_to_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import json 5 | import pytest 6 | 7 | from sympy import sqrt 8 | 9 | from sppl.compilers.spe_to_dict import spe_from_dict 10 | from sppl.compilers.spe_to_dict import spe_to_dict 11 | from sppl.distributions import choice 12 | from sppl.distributions import gamma 13 | from sppl.distributions import norm 14 | from sppl.distributions import poisson 15 | from sppl.sets import EmptySet 16 | from sppl.transforms import EventFiniteNominal 17 | from sppl.transforms import Exp 18 | from sppl.transforms import Exponential 19 | from sppl.transforms import Id 20 | from sppl.transforms import Log 21 | from sppl.transforms import Logarithm 22 | 23 | X = Id('X') 24 | Y = Id('Y') 25 | 26 | spes = [ 27 | X >> norm(loc=0, scale=1), 28 | X >> poisson(mu=7), 29 | Y >> choice({'a': 0.5, 'b': 0.5}), 30 | (X >> norm(loc=0, scale=1)) & (Y >> gamma(a=1)), 31 | 0.2*(X >> norm(loc=0, scale=1)) | 0.8*(X >> gamma(a=1)), 32 | ((X >> norm(loc=0, scale=1)) & (Y >> gamma(a=1))).constrain({Y:1}), 33 | ] 34 | @pytest.mark.parametrize('spe', spes) 35 | def test_serialize_equal(spe): 36 | metadata = spe_to_dict(spe) 37 | spe_json_encoded = json.dumps(metadata) 38 | spe_json_decoded = json.loads(spe_json_encoded) 39 | spe2 = spe_from_dict(spe_json_decoded) 40 | assert spe2 == spe 41 | 42 | transforms = [ 43 | X, 44 | X**(1,3), 45 | Exponential(X, base=3), 46 | Logarithm(X, base=2), 47 | 2**Log(X), 48 | 1/Exp(X), 49 | abs(X), 50 | 1/X, 51 | 2*X + X**3, 52 | (X/2)*(X<0) + (X**(1,2))*(0<=X), 53 | X < sqrt(3), 54 | X << [], 55 | ~(X << []), 56 | EventFiniteNominal(1/X**(1,10), EmptySet), 57 | X << {1, 2}, 58 | X << {'a', 'x'}, 59 | ~(X << {'a', '1'}), 60 | (X < 3) | (X << {1,2}), 61 | (X < 3) & (X << {1,2}), 62 | ] 63 | @pytest.mark.parametrize('transform', transforms) 64 | def test_serialize_env(transform): 65 | spe = (X >> norm()).transform(Y, transform) 66 | metadata = spe_to_dict(spe) 67 | spe_json_encoded = json.dumps(metadata) 68 | spe_json_decoded = 
json.loads(spe_json_encoded) 69 | spe2 = spe_from_dict(spe_json_decoded) 70 | assert spe2 == spe 71 | -------------------------------------------------------------------------------- /tests/test_spe_to_spml.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | from sppl.compilers.spe_to_sppl import render_sppl 5 | from sppl.compilers.sppl_to_python import SPPL_Compiler 6 | from sppl.math_util import allclose 7 | from sppl.tests.test_render import get_model 8 | 9 | def test_render_sppl(): 10 | model = get_model() 11 | sppl_code = render_sppl(model) 12 | compiler = SPPL_Compiler(sppl_code.getvalue()) 13 | namespace = compiler.execute_module() 14 | (X, Y) = (namespace.X, namespace.Y) 15 | for i in range(5): 16 | assert allclose(model.logprob(Y << {'0'}), [ 17 | model.logprob(Y << {str(i)}), 18 | namespace.model.logprob(Y << {str(i)}) 19 | ]) 20 | for i in range(4): 21 | assert allclose( 22 | model.logprob(X << {i}), 23 | namespace.model.logprob(X << {i})) 24 | -------------------------------------------------------------------------------- /tests/test_spe_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sppl.distributions import choice 7 | from sppl.distributions import norm 8 | from sppl.distributions import poisson 9 | from sppl.math_util import allclose 10 | from sppl.transforms import Id 11 | 12 | def test_transform_real_leaf_logprob(): 13 | X = Id('X') 14 | Y = Id('Y') 15 | Z = Id('Z') 16 | spe = (X >> norm(loc=0, scale=1)) 17 | 18 | with pytest.raises(AssertionError): 19 | spe.transform(Z, Y**2) 20 | with pytest.raises(AssertionError): 21 | spe.transform(X, X**2) 22 | 23 | spe = spe.transform(Z, X**2) 24 | assert spe.env == {X:X, Z:X**2} 25 | assert spe.get_symbols() == {X, Z} 26 | assert spe.logprob(Z < 1) == spe.logprob(X**2 < 1) 27 | assert spe.logprob((Z < 1) | ((X + 1) < 3)) \ 28 | == spe.logprob((X**2 < 1) | ((X+1) < 3)) 29 | 30 | spe = spe.transform(Y, 2*Z) 31 | assert spe.env == {X:X, Z:X**2, Y:2*Z} 32 | assert spe.logprob(Y**(1,3) < 10) \ 33 | == spe.logprob((2*Z)**(1,3) < 10) \ 34 | == spe.logprob((2*(X**2))**(1,3) < 10) \ 35 | 36 | W = Id('W') 37 | spe = spe.transform(W, X > 1) 38 | assert allclose(spe.logprob(W), spe.logprob(X > 1)) 39 | 40 | def test_transform_real_leaf_sample(): 41 | X = Id('X') 42 | Z = Id('Z') 43 | Y = Id('Y') 44 | spe = (X >> poisson(loc=-1, mu=1)) 45 | spe = spe.transform(Z, X+1) 46 | spe = spe.transform(Y, Z-1) 47 | samples = spe.sample(100) 48 | assert any(s[X] == -1 for s in samples) 49 | assert all(0 <= s[Z] for s in samples) 50 | assert all(s[Y] == s[X] for s in samples) 51 | assert all(spe.sample_func(lambda X,Y,Z: X-Y+Z==Z, 100)) 52 | assert all(set(s) == {X,Y} for s in spe.sample_subset([X, Y], 100)) 53 | 54 | def test_transform_sum(): 55 | X = Id('X') 56 | Z = Id('Z') 57 | Y = Id('Y') 58 | spe \ 59 | = 0.3*(X >> norm(loc=0, scale=1)) \ 60 | | 0.7*(X >> choice({'0': 0.4, '1': 0.6})) 61 | with pytest.raises(Exception): 62 | # Cannot transform Nominal variate. 
63 | spe.transform(Z, X**2) 64 | spe \ 65 | = 0.3*(X >> norm(loc=0, scale=1)) \ 66 | | 0.7*(X >> poisson(mu=2)) 67 | spe = spe.transform(Z, X**2) 68 | assert spe.logprob(Z < 1) == spe.logprob(X**2 < 1) 69 | assert spe.children[0].env == spe.children[1].env 70 | spe = spe.transform(Y, Z/2) 71 | assert spe.children[0].env \ 72 | == spe.children[1].env \ 73 | == {X:X, Z:X**2, Y:Z/2} 74 | 75 | def test_transform_product(): 76 | X = Id('X') 77 | Y = Id('Y') 78 | W = Id('W') 79 | Z = Id('Z') 80 | V = Id('V') 81 | spe \ 82 | = (X >> norm(loc=0, scale=1)) \ 83 | & (Y >> poisson(mu=10)) 84 | with pytest.raises(Exception): 85 | # Cannot use symbols from different transforms. 86 | spe.transform(W, (X > 0) | (Y << {'0'})) 87 | spe = spe.transform(W, (X**2 - 3*X)**(1,10)) 88 | spe = spe.transform(Z, (W > 0) | (X**3 < 1)) 89 | spe = spe.transform(V, Y/10) 90 | assert allclose( 91 | spe.logprob(W>1), 92 | spe.logprob((X**2 - 3*X)**(1,10) > 1)) 93 | with pytest.raises(Exception): 94 | spe.tarnsform(Id('R'), (V>1) | (W < 0)) 95 | -------------------------------------------------------------------------------- /tests/test_sppl_to_python.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sppl.compilers.sppl_to_python import SPPL_Compiler 7 | 8 | isclose = lambda a, b : abs(a-b) < 1e-10 9 | 10 | # Test errors in visit_Assign: 11 | # - [x] overwrite variable / array 12 | # - [x] unknown sample target array 13 | # - [x] assigning fresh variables in else 14 | # - [x] non-array variable in for 15 | # - [x] array target must not be subscript 16 | # - [x] assign invalid Python after sampling 17 | # - [x] assign invalid Python non-global 18 | # - [x] invalid left-hand side (tuple, unpacking, etc) 19 | 20 | overwrite_cases = [ 21 | 'X ~= norm(); X ~= bernoulli(p=.5)', 22 | 'X = array(5); X ~= bernoulli(p=.5)', 23 | 'X ~= bernoulli(p=.5); X = array(5)', 24 | ''' 25 | X ~= norm() 26 | if (X > 0): 27 | X ~= norm() 28 | else: 29 | X ~= norm() 30 | ''', 31 | 'X = array(5); W ~= norm(); X = array(10)', 32 | ] 33 | @pytest.mark.parametrize('case', overwrite_cases) 34 | def test_error_assign_overwrite_variable(case): 35 | with pytest.raises(AssertionError): 36 | SPPL_Compiler(case) 37 | 38 | def test_error_assign_unknown_array(): 39 | source = ''' 40 | X ~= uniform(loc=0, scale=1); 41 | Y[0] ~= bernoulli(p=.1) 42 | ''' 43 | with pytest.raises(AssertionError): 44 | SPPL_Compiler(source) 45 | SPPL_Compiler('Y = array(10);\n%s' % (source,)) 46 | 47 | def test_error_assign_fresh_variable_else(): 48 | source = ''' 49 | X ~= norm() 50 | if (X > 0): 51 | Y ~= 2*X 52 | else: 53 | W ~= 2*X 54 | ''' 55 | with pytest.raises(AssertionError): 56 | SPPL_Compiler(source) 57 | SPPL_Compiler(source.replace('W', 'Y')) 58 | 59 | def test_error_assign_non_array_in_for(): 60 | source = ''' 61 | Y ~= norm(); 62 | for i in range(10): 63 | X ~= norm() 64 | ''' 65 | with pytest.raises(AssertionError): 66 | SPPL_Compiler(source) 67 | source_prime = 'X = array(5);\n%s' % (source.replace('X', 'X[i]'),) 68 | SPPL_Compiler(source_prime) 69 | 70 | def test_error_assign_array_subscript(): 71 | source = ''' 72 | Y = array(5) 73 | Y[0] = 
array(2) 74 | ''' 75 | with pytest.raises(AssertionError): 76 | SPPL_Compiler(source) 77 | SPPL_Compiler(source.replace('array(2)', 'norm()')) 78 | 79 | def test_error_assign_py_constant_after_sampling(): 80 | source = ''' 81 | X ~= norm() 82 | Y = "foo" 83 | ''' 84 | with pytest.raises(AssertionError): 85 | SPPL_Compiler(source) 86 | 87 | def test_error_assign_py_constant_non_global(): 88 | source = ''' 89 | X ~= norm() 90 | if (X > 0): 91 | Y = "ali" 92 | else: 93 | Y = norm() 94 | ''' 95 | with pytest.raises(AssertionError): 96 | SPPL_Compiler(source) 97 | 98 | def test_error_assign_assign_invalid_lhs(): 99 | with pytest.raises(AssertionError): 100 | SPPL_Compiler('[Y] ~= norm()') 101 | with pytest.raises(AssertionError): 102 | SPPL_Compiler('(X, Y) ~= norm()') 103 | 104 | # Test errors in visit_For: 105 | # - [x] func is not range. 106 | # - [x] func range has > 2 arguments 107 | # - [x] more than one iteration variable 108 | 109 | def test_error_for_func_not_range(): 110 | source = ''' 111 | X = array(10) 112 | for i in xrange(10): 113 | X[i] ~= norm() 114 | ''' 115 | with pytest.raises(AssertionError): 116 | SPPL_Compiler(source) 117 | SPPL_Compiler(source.replace('xrange', 'range')) 118 | 119 | def test_error_for_range_step(): 120 | source = ''' 121 | X = array(10) 122 | X[0] ~= norm() 123 | for i in range(1, 10, 2): 124 | X[i] ~= 2*X[i-1] 125 | ''' 126 | with pytest.raises(AssertionError): 127 | SPPL_Compiler(source) 128 | SPPL_Compiler(source.replace(', 2)' , ')')) 129 | 130 | def test_error_for_range_multiple_vars(): 131 | source = ''' 132 | X = array(10) 133 | X[0] ~= norm() 134 | for (i, j) in range(1, 9): 135 | X[i] ~= 2*X[i-1] 136 | ''' 137 | with pytest.raises(AssertionError): 138 | SPPL_Compiler(source) 139 | SPPL_Compiler(source.replace('(i, j)' , 'i')) 140 | 141 | # Test errors in visit_If: 142 | # - [x] if without matching else/elif 143 | def test_error_if_no_else(): 144 | source = ''' 145 | X ~= norm() 146 | if (X > 0): 147 | Y ~= 2*X + 1 148 | 
''' 149 | with pytest.raises(AssertionError): 150 | SPPL_Compiler(source) 151 | SPPL_Compiler('%s\nelse:\n Y~= norm()' % (source,)) 152 | SPPL_Compiler('%s\nelif (X<0):\n Y~= norm()' % (source,)) 153 | 154 | # Test SPPL_Transformer 155 | # Q = Z in {0, 1} E = Z << {0, 1} 156 | # Q = Z not in {0, 1} Q = ~(Z << {0, 1}) 157 | # Q = Z == 'foo' Q = Z << {'foo'} 158 | # B = Z != 'foo' Q = ~(Z << {'foo'}) 159 | 160 | def test_transform_in(): 161 | source = ''' 162 | X ~= choice({'foo': .5, 'bar': .1, 'baz': .4}) 163 | Y = X in {'foo', 'baz'} 164 | ''' 165 | compiler = SPPL_Compiler(source) 166 | py_source = compiler.render_module() 167 | assert 'X << {\'foo\', \'baz\'}' in py_source 168 | assert 'X in' not in py_source 169 | 170 | def test_transform_in_not(): 171 | source = ''' 172 | X ~= choice({'foo': .5, 'bar': .1, 'baz': .4}) 173 | Y = X not in {'foo', 'baz'} 174 | ''' 175 | compiler = SPPL_Compiler(source) 176 | py_source = compiler.render_module() 177 | assert '~ (X << {\'foo\', \'baz\'})' in py_source 178 | assert 'not in' not in py_source 179 | 180 | def test_transform_eq(): 181 | source = ''' 182 | X ~= choice({'foo': .5, 'bar': .1, 'baz': .4}) 183 | Y = X == 'foo' 184 | ''' 185 | compiler = SPPL_Compiler(source) 186 | py_source = compiler.render_module() 187 | assert 'X << {\'foo\'}' in py_source 188 | assert '==' not in py_source 189 | 190 | def test_transform_eq_not(): 191 | source = ''' 192 | X ~= choice({'foo': .5, 'bar': .1, 'baz': .4}) 193 | Y = X != 'foo' 194 | ''' 195 | compiler = SPPL_Compiler(source) 196 | py_source = compiler.render_module() 197 | assert '~ (X << {\'foo\'})' in py_source 198 | assert '!=' not in py_source 199 | 200 | def test_compile_all_constructs(): 201 | source = ''' 202 | X = array(10) 203 | W = array(10) 204 | Y = randint(low=1, high=2) 205 | Z = bernoulli(p=0.1) 206 | 207 | E = choice({'1': 0.3, '2': 0.7}) 208 | 209 | 210 | for i in range(1,5): 211 | W[i] = uniform(loc=0, scale=2) 212 | X[i] = bernoulli(p=0.5) 213 | 214 | X[0] 
~= gamma(a=1) 215 | H = (X[0]**2 + 2*X[0] + 3)**(1, 10) 216 | 217 | # Here is a comment, with indentation on next line 218 | 219 | 220 | 221 | X[5] ~= 0.3*atomic(loc=0) | 0.4*atomic(loc=-1) | 0.3*atomic(loc=3) 222 | if X[5] == 0: 223 | X[7] = bernoulli(p=0.1) 224 | X[8] = 1 + X[3] 225 | elif (X[5]**2 == 1): 226 | X[7] = bernoulli(p=0.2) 227 | X[8] = 1 + X[3] 228 | else: 229 | if (X[3] in {0, 1}): 230 | X[7] = bernoulli(p=0.2) 231 | X[8] = 1 + X[3] 232 | else: 233 | X[7] = bernoulli(p=0.2) 234 | X[8] = 1 + X[3] 235 | ''' 236 | compiler = SPPL_Compiler(source) 237 | namespace = compiler.execute_module() 238 | model = namespace.model 239 | assert isclose(model.prob(namespace.X[5] << {0}), .3) 240 | assert isclose(model.prob(namespace.X[5] << {-1}), .4) 241 | assert isclose(model.prob(namespace.X[5] << {3}), .3) 242 | assert isclose(model.prob(namespace.E << {'1'}), .3) 243 | assert isclose(model.prob(namespace.E << {'2'}), .7) 244 | 245 | def test_imports(): 246 | source = ''' 247 | Y ~= bernoulli(p=.5) 248 | Z ~= choice({str(i): Fraction(1, 5) for i in range(5)}) 249 | X = array(5) 250 | for i in range(5): 251 | X[i] ~= Fraction(1,2) * Y 252 | ''' 253 | compiler = SPPL_Compiler(source) 254 | with pytest.raises(NameError): 255 | compiler.execute_module() 256 | compiler = SPPL_Compiler('from fractions import Fraction\n%s' % (source,)) 257 | namespace = compiler.execute_module() 258 | for i in range(5): 259 | assert isclose(namespace.model.prob(namespace.Z << {str(i)}), .2) 260 | 261 | def test_ifexp(): 262 | source = ''' 263 | from fractions import Fraction 264 | Y ~= choice({str(i): Fraction(1, 4) for i in range(4)}) 265 | Z ~= ( 266 | atomic(loc=0) if (Y in {'0', '1'}) else 267 | atomic(loc=4) if (Y == '2') else 268 | atomic(loc=6)) 269 | ''' 270 | compiler = SPPL_Compiler(source) 271 | assert 'IfElse' in compiler.render_module() 272 | namespace = compiler.execute_module() 273 | assert isclose(namespace.model.prob(namespace.Z << {0}), .5) 274 | assert 
isclose(namespace.model.prob(namespace.Z << {4}), .25) 275 | assert isclose(namespace.model.prob(namespace.Z << {6}), .25) 276 | 277 | def test_switch_shallow(): 278 | source = ''' 279 | Y ~= choice({'0': .25, '1': .5, '2': .25}) 280 | 281 | switch (Y) cases (i in ['0', '1', '2']): 282 | Z ~= atomic(loc=int(i)) 283 | ''' 284 | compiler = SPPL_Compiler(source) 285 | namespace = compiler.execute_module() 286 | assert isclose(namespace.model.prob(namespace.Z << {0}), .25) 287 | assert isclose(namespace.model.prob(namespace.Z << {1}), .5) 288 | assert isclose(namespace.model.prob(namespace.Z << {2}), .25) 289 | 290 | def test_switch_enumerate(): 291 | source = ''' 292 | Y ~= choice({'0': .25, '1': .5, '2': .25}) 293 | 294 | switch (Y) cases (i,j in enumerate(['0', '1', '2'])): 295 | Z ~= atomic(loc=i+int(j)) 296 | ''' 297 | compiler = SPPL_Compiler(source) 298 | namespace = compiler.execute_module() 299 | assert isclose(namespace.model.prob(namespace.Z << {0}), .25) 300 | assert isclose(namespace.model.prob(namespace.Z << {2}), .5) 301 | assert isclose(namespace.model.prob(namespace.Z << {4}), .25) 302 | 303 | def test_switch_nested(): 304 | source = ''' 305 | Y ~= randint(low=0, high=4) 306 | W ~= randint(low=0, high=2) 307 | 308 | switch (Y) cases (i in range(0, 5)): 309 | Z ~= choice({str(i): 1}) 310 | switch (W) cases (i in range(0, 2)): V ~= atomic(loc=i) 311 | ''' 312 | compiler = SPPL_Compiler(source) 313 | namespace = compiler.execute_module() 314 | assert isclose(namespace.model.prob(namespace.Z << {'0'}), .25) 315 | assert isclose(namespace.model.prob(namespace.Z << {'1'}), .25) 316 | assert isclose(namespace.model.prob(namespace.Z << {'2'}), .25) 317 | assert isclose(namespace.model.prob(namespace.Z << {'3'}), .25) 318 | assert isclose(namespace.model.prob(namespace.V << {0}), .5) 319 | assert isclose(namespace.model.prob(namespace.V << {1}), .5) 320 | 321 | def test_condition_simple(): 322 | source = ''' 323 | Y ~= norm(loc=0, scale=2) 324 | condition((0 < 
Y) < 2) 325 | Z ~= binom(n=10, p=.2) 326 | condition(Z == 0) 327 | ''' 328 | compiler = SPPL_Compiler(source) 329 | namespace = compiler.execute_module() 330 | assert isclose(namespace.model.prob((0 < namespace.Y) < 2), 1) 331 | 332 | def test_condition_if(): 333 | source = ''' 334 | Y ~= norm(loc=0, scale=2) 335 | if (Y > 1): 336 | condition(Y < 1) 337 | else: 338 | condition(Y > -1) 339 | ''' 340 | with pytest.raises(Exception): 341 | compiler = SPPL_Compiler(source) 342 | compiler.execute_module() 343 | compiler = SPPL_Compiler(source.replace('Y > 1', 'Y > 0')) 344 | namespace = compiler.execute_module() 345 | assert isclose(namespace.model.prob((-1 < namespace.Y) < 1), 1) 346 | 347 | def test_constrain_simple(): 348 | source = ''' 349 | Y ~= norm(loc=0, scale=2) 350 | Z ~= binom(n=10, p=.2) 351 | constrain({Z: 10, Y: 0}) 352 | ''' 353 | compiler = SPPL_Compiler(source) 354 | namespace = compiler.execute_module() 355 | assert isclose(namespace.model.prob((0 < namespace.Y) < 2), 0) 356 | assert isclose(namespace.model.prob((0 <= namespace.Y) < 2), 1) 357 | assert isclose(namespace.model.prob(namespace.Z << {10}), 1) 358 | 359 | def test_constrain_if(): 360 | source = ''' 361 | Y ~= norm(loc=0, scale=2) 362 | if (Y > 1): 363 | constrain({Y: 2}) 364 | else: 365 | constrain({Y: 0}) 366 | ''' 367 | compiler = SPPL_Compiler(source) 368 | namespace = compiler.execute_module() 369 | assert isclose(namespace.model.prob(namespace.Y > 0), 0.3085375387259868) 370 | assert isclose(namespace.model.prob(namespace.Y < 0), 0) 371 | assert isclose(namespace.model.prob(namespace.Y << {0}), 1-0.3085375387259868) 372 | 373 | def test_constant_parameter(): 374 | source = ''' 375 | parameters = [1, 2, 3] 376 | n_array = 2 377 | Y = array(n_array + 1) 378 | for i in range(n_array + 1): 379 | Y[i] = randint(low=parameters[i], high=parameters[i]+1) 380 | ''' 381 | compiler = SPPL_Compiler(source) 382 | namespace = compiler.execute_module() 383 | assert 
isclose(namespace.model.prob(namespace.Y[0] << {1}), 1) 384 | assert isclose(namespace.model.prob(namespace.Y[1] << {2}), 1) 385 | assert isclose(namespace.model.prob(namespace.Y[2] << {3}), 1) 386 | with pytest.raises(AssertionError): 387 | SPPL_Compiler('%sZ = "foo"\n' % (source,)) 388 | 389 | def test_error_array_length(): 390 | with pytest.raises(TypeError): 391 | SPPL_Compiler('Y = array(1.3)').execute_module() 392 | with pytest.raises(TypeError): 393 | SPPL_Compiler('Y = array(\'foo\')').execute_module() 394 | # Length zero array 395 | namespace = SPPL_Compiler('Y = array(-1)').execute_module() 396 | assert len(namespace.Y) == 0 397 | -------------------------------------------------------------------------------- /tests/test_substitute.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sppl.transforms import Id 7 | from sppl.transforms import Logarithm 8 | 9 | X = Id('X') 10 | Z = Id('Z') 11 | Y = Id('Y') 12 | 13 | b = (1, 10) 14 | cases = [ 15 | # Basic cases. 16 | [X , {} , X], 17 | [1/X , {} , 1/X], 18 | [X**2+X , {} , X**2+X], 19 | [X , {} , X], 20 | [X , {Z:X} , X], 21 | [Z , {Z:2**X} , 2**X], 22 | [((Z+1)**b) , {Z:1/X} , (1/X+1)**b], 23 | [((2**Z+1)**b) , {Z:X+1} , (2**(X+1)+1)**b], 24 | [((2**Logarithm(Z, 2)+1)**b) , {Z:X+1} , (2**Logarithm((X+1), 2)+1)**b] , 25 | [((2**abs(Logarithm(Z, 2))+1)**b) , {Z:X+1} , (2**abs(Logarithm((X+1), 2))+1)**b], 26 | [((2**abs(Logarithm(1/Z, 2))+1)**b) , {Z:X+1} , (2**abs(Logarithm(1/(X+1), 2))+1)**b], 27 | # Compound cases. 
28 | [(Z > 1)*(1/Z) + (Z < 1)*Z**b 29 | , {Z:X+1} 30 | , ((X+1) > 1)*(1/(X+1)) + ((X+1) < 1)*(X+1)**b], 31 | [(Z << {'a'}) & (Y < 3) 32 | , {Y:1/X} 33 | , (Z << {'a'}) & (1/X < 3)], 34 | [((Y > 1) | Z << {'a'}) & (Y < 3) 35 | , {Z : Y**2, Y:1/X} 36 | , (((1/X) > 1) | ((1/X)**2) << {'a'}) & ((1/X) < 3)], 37 | ] 38 | @pytest.mark.parametrize('case', cases) 39 | def test_substitute_basic(case): 40 | (expr, env, expr_prime) = case 41 | assert expr.substitute(env) == expr_prime 42 | 43 | @pytest.mark.parametrize('case', cases) 44 | def test_substitute_transitive(case): 45 | (expr, env, expr_prime) = case 46 | if len(env) == 1: 47 | [(s0, s1)] = env.items() 48 | s2 = Id('s2') 49 | env_prime = {s0: s2, s2: s1} 50 | assert expr.substitute(env_prime) == expr_prime 51 | -------------------------------------------------------------------------------- /tests/test_sum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | from fractions import Fraction 5 | from math import log 6 | 7 | import pytest 8 | 9 | import numpy 10 | 11 | from sppl.distributions import choice 12 | from sppl.distributions import gamma 13 | from sppl.distributions import norm 14 | from sppl.math_util import allclose 15 | from sppl.math_util import isinf_neg 16 | from sppl.math_util import logsumexp 17 | from sppl.sets import FiniteNominal 18 | from sppl.sets import Interval 19 | from sppl.sets import inf as oo 20 | from sppl.spe import ContinuousLeaf 21 | from sppl.spe import ExposedSumSPE 22 | from sppl.spe import NominalLeaf 23 | from sppl.spe import ProductSPE 24 | from sppl.spe import SumSPE 25 | from sppl.transforms import Id 26 | 27 | def test_sum_normal_gamma(): 28 | X = Id('X') 29 | weights = [ 30 | log(Fraction(2, 3)), 31 | log(Fraction(1, 3)) 32 | ] 33 | spe = SumSPE( 34 | [X >> norm(loc=0, scale=1), X >> gamma(loc=0, a=1),], weights) 35 | 36 | assert spe.logprob(X > 0) == logsumexp([ 37 | spe.weights[0] + spe.children[0].logprob(X > 0), 38 | spe.weights[1] + spe.children[1].logprob(X > 0), 39 | ]) 40 | assert spe.logprob(X < 0) == log(Fraction(2, 3)) + log(Fraction(1, 2)) 41 | samples = spe.sample(100, prng=numpy.random.RandomState(1)) 42 | assert all(s[X] for s in samples) 43 | spe.sample_func(lambda X: abs(X**3), 100) 44 | with pytest.raises(ValueError): 45 | spe.sample_func(lambda Y: abs(X**3), 100) 46 | 47 | spe_condition = spe.condition(X < 0) 48 | assert isinstance(spe_condition, ContinuousLeaf) 49 | assert spe_condition.conditioned 50 | assert spe_condition.logprob(X < 0) == 0 51 | samples = spe_condition.sample(100) 52 | assert all(s[X] < 0 for s in samples) 53 | 54 | assert spe.logprob(X < 0) == logsumexp([ 55 | spe.weights[0] + spe.children[0].logprob(X < 0), 56 | spe.weights[1] + spe.children[1].logprob(X < 0), 57 | ]) 58 | 59 | def test_sum_normal_gamma_exposed(): 60 | X = Id('X') 61 | W = Id('W') 62 | weights = W >> choice({ 63 | '0': Fraction(2,3), 64 | 
'1': Fraction(1,3), 65 | }) 66 | children = { 67 | '0': X >> norm(loc=0, scale=1), 68 | '1': X >> gamma(loc=0, a=1), 69 | } 70 | spe = ExposedSumSPE(children, weights) 71 | 72 | assert spe.logprob(W << {'0'}) == log(Fraction(2, 3)) 73 | assert spe.logprob(W << {'1'}) == log(Fraction(1, 3)) 74 | assert allclose(spe.logprob((W << {'0'}) | (W << {'1'})), 0) 75 | assert spe.logprob((W << {'0'}) & (W << {'1'})) == -float('inf') 76 | 77 | assert allclose( 78 | spe.logprob((W << {'0', '1'}) & (X < 1)), 79 | spe.logprob(X < 1)) 80 | 81 | assert allclose( 82 | spe.logprob((W << {'0'}) & (X < 1)), 83 | spe.weights[0] + spe.children[0].logprob(X < 1)) 84 | 85 | spe_condition = spe.condition((W << {'1'}) | (W << {'0'})) 86 | assert isinstance(spe_condition, SumSPE) 87 | assert len(spe_condition.weights) == 2 88 | assert \ 89 | allclose(spe_condition.weights[0], log(Fraction(2,3))) \ 90 | and allclose(spe_condition.weights[0], log(Fraction(2,3))) \ 91 | or \ 92 | allclose(spe_condition.weights[1], log(Fraction(2,3))) \ 93 | and allclose(spe_condition.weights[0], log(Fraction(2,3)) 94 | ) 95 | 96 | spe_condition = spe.condition((W << {'1'})) 97 | assert isinstance(spe_condition, ProductSPE) 98 | assert isinstance(spe_condition.children[0], NominalLeaf) 99 | assert isinstance(spe_condition.children[1], ContinuousLeaf) 100 | assert spe_condition.logprob(X < 5) == spe.children[1].logprob(X < 5) 101 | 102 | def test_sum_normal_nominal(): 103 | X = Id('X') 104 | children = [ 105 | X >> norm(loc=0, scale=1), 106 | X >> choice({'low': Fraction(3, 10), 'high': Fraction(7, 10)}), 107 | ] 108 | weights = [log(Fraction(4,7)), log(Fraction(3, 7))] 109 | 110 | spe = SumSPE(children, weights) 111 | 112 | assert allclose( 113 | spe.logprob(X < 0), 114 | log(Fraction(4,7)) + log(Fraction(1,2))) 115 | 116 | assert allclose( 117 | spe.logprob(X << {'low'}), 118 | log(Fraction(3,7)) + log(Fraction(3, 10))) 119 | 120 | # The semantics of ~(X<<{'low'}) are (X << String and X != 'low') 121 | assert 
allclose( 122 | spe.logprob(~(X << {'low'})), 123 | spe.logprob((X << {'high'}))) 124 | assert allclose( 125 | spe.logprob((X<> norm(loc=0, scale=1), X >> norm(loc=0, scale=2)], 19 | [log(0.4), log(0.6)]), 20 | X >> gamma(loc=0, a=1), 21 | ] 22 | spe = SumSPE(children, [log(0.7), log(0.3)]) 23 | assert spe.size() == 4 24 | assert spe.children == ( 25 | children[0].children[0], 26 | children[0].children[1], 27 | children[1] 28 | ) 29 | assert allclose(spe.weights[0], log(0.7) + log(0.4)) 30 | assert allclose(spe.weights[1], log(0.7) + log(0.6)) 31 | assert allclose(spe.weights[2], log(0.3)) 32 | 33 | def test_sum_simplify_nested_sum_2(): 34 | X = Id('X') 35 | W = Id('W') 36 | children = [ 37 | SumSPE([ 38 | (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)), 39 | (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1))], 40 | [log(0.9), log(0.1)]), 41 | (X >> norm(loc=0, scale=4)) & (W >> norm(loc=0, scale=10)), 42 | SumSPE([ 43 | (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)), 44 | (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1)), 45 | (X >> norm(loc=0, scale=8)) & (W >> norm(loc=0, scale=3)),], 46 | [log(0.4), log(0.3), log(0.3)]), 47 | ] 48 | spe = SumSPE(children, [log(0.4), log(0.4), log(0.2)]) 49 | assert spe.size() == 19 50 | assert spe.children == ( 51 | children[0].children[0], # 2 leaves 52 | children[0].children[1], # 2 leaves 53 | children[1], # 2 leaf 54 | children[2].children[0], # 2 leaves 55 | children[2].children[1], # 2 leaves 56 | children[2].children[2], # 2 leaves 57 | ) 58 | assert allclose(spe.weights[0], log(0.4) + log(0.9)) 59 | assert allclose(spe.weights[1], log(0.4) + log(0.1)) 60 | assert allclose(spe.weights[2], log(0.4)) 61 | assert allclose(spe.weights[3], log(0.2) + log(0.4)) 62 | assert allclose(spe.weights[4], log(0.2) + log(0.3)) 63 | assert allclose(spe.weights[5], log(0.2) + log(0.3)) 64 | 65 | def test_sum_simplify_leaf(): 66 | Xd0 = Id('X') >> norm(loc=0, scale=1) 67 | Xd1 = Id('X') >> norm(loc=0, 
scale=2) 68 | Xd2 = Id('X') >> norm(loc=0, scale=3) 69 | spe = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)]) 70 | assert spe.size() == 4 71 | assert spe_simplify_sum(spe) == spe 72 | 73 | Xd0 = Id('X') >> norm(loc=0, scale=1) 74 | Xd1 = Id('X') >> norm(loc=0, scale=1) 75 | Xd2 = Id('X') >> norm(loc=0, scale=1) 76 | spe = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)]) 77 | assert spe_simplify_sum(spe) == Xd0 78 | 79 | Xd3 = Id('X') >> norm(loc=0, scale=2) 80 | spe = SumSPE([Xd0, Xd3, Xd1, Xd3], [log(0.5), log(0.1), log(.3), log(.1)]) 81 | spe_simplified = spe_simplify_sum(spe) 82 | assert len(spe_simplified.children) == 2 83 | assert spe_simplified.children[0] == Xd0 84 | assert spe_simplified.children[1] == Xd3 85 | assert allclose(spe_simplified.weights[0], log(0.8)) 86 | assert allclose(spe_simplified.weights[1], log(0.2)) 87 | 88 | def test_sum_simplify_product_collapse(): 89 | A1 = Id('A') >> norm(loc=0, scale=1) 90 | A0 = Id('A') >> norm(loc=0, scale=1) 91 | B = Id('B') >> norm(loc=0, scale=1) 92 | B1 = Id('B') >> norm(loc=0, scale=1) 93 | B0 = Id('B') >> norm(loc=0, scale=1) 94 | C = Id('C') >> norm(loc=0, scale=1) 95 | C1 = Id('C') >> norm(loc=0, scale=1) 96 | D = Id('D') >> norm(loc=0, scale=1) 97 | spe = SumSPE([ 98 | ProductSPE([A1, B, C, D]), 99 | ProductSPE([A0, B1, C, D]), 100 | ProductSPE([A0, B0, C1, D]), 101 | ], [log(0.4), log(0.4), log(0.2)]) 102 | assert spe_simplify_sum(spe) == ProductSPE([A1, B, C, D]) 103 | 104 | def test_sum_simplify_product_complex(): 105 | A1 = Id('A') >> norm(loc=0, scale=1) 106 | A0 = Id('A') >> norm(loc=0, scale=2) 107 | B = Id('B') >> norm(loc=0, scale=1) 108 | B1 = Id('B') >> norm(loc=0, scale=2) 109 | B0 = Id('B') >> norm(loc=0, scale=3) 110 | C = Id('C') >> norm(loc=0, scale=1) 111 | C1 = Id('C') >> norm(loc=0, scale=2) 112 | D = Id('D') >> norm(loc=0, scale=1) 113 | spe = SumSPE([ 114 | ProductSPE([A1, B, C, D]), 115 | ProductSPE([A0, B1, C, D]), 116 | ProductSPE([A0, B0, C1, D]), 117 | ], 
[log(0.4), log(0.4), log(0.2)]) 118 | 119 | spe_simplified = spe_simplify_sum(spe) 120 | assert isinstance(spe_simplified, ProductSPE) 121 | assert isinstance(spe_simplified.children[0], SumSPE) 122 | assert spe_simplified.children[1] == D 123 | 124 | ssc0 = spe_simplified.children[0] 125 | assert isinstance(ssc0.children[1], ProductSPE) 126 | assert ssc0.children[1].children == (A0, B0, C1) 127 | 128 | assert isinstance(ssc0.children[0], ProductSPE) 129 | assert ssc0.children[0].children[1] == C 130 | 131 | ssc0c0 = ssc0.children[0].children[0] 132 | assert isinstance(ssc0c0, SumSPE) 133 | assert isinstance(ssc0c0.children[0], ProductSPE) 134 | assert isinstance(ssc0c0.children[1], ProductSPE) 135 | assert ssc0c0.children[0].children == (A1, B) 136 | assert ssc0c0.children[1].children == (A0, B1) 137 | -------------------------------------------------------------------------------- /tests/test_sym_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 
2 | # See LICENSE.txt 3 | 4 | import pytest 5 | 6 | from sympy import exp as SymExp 7 | from sympy import log as SymLog 8 | from sympy import symbols 9 | 10 | from sppl.sets import FiniteReal 11 | from sppl.sym_util import get_symbols 12 | from sppl.sym_util import partition_finite_real_contiguous 13 | from sppl.sym_util import partition_list_blocks 14 | 15 | (X0, X1, X2, X3, X4, X5, X6, X7, X8, X9) = symbols('X:10') 16 | 17 | def test_get_symbols(): 18 | syms = get_symbols((X0 > 3) & (X1 < 4)) 19 | assert len(syms) == 2 20 | assert X0 in syms 21 | assert X1 in syms 22 | 23 | syms = get_symbols((SymExp(X0) > SymLog(X1)+10) & (X2 < 4)) 24 | assert len(syms) == 3 25 | assert X0 in syms 26 | assert X1 in syms 27 | assert X2 in syms 28 | 29 | @pytest.mark.parametrize('a, b', [ 30 | ([0, 1, 2, 3], [[0], [1], [2], [3]]), 31 | ([0, 1, 2, 1], [[0], [1, 3], [2]]), 32 | (['0', '0', 2, '0'], [[0, 1, 3], [2]]), 33 | ]) 34 | def test_partition_list_blocks(a, b): 35 | solution = partition_list_blocks(a) 36 | assert solution == b 37 | 38 | @pytest.mark.parametrize('a, b', [ 39 | (FiniteReal(0,1,2), [FiniteReal(0,1,2)]), 40 | (FiniteReal(0,3,1,2), [FiniteReal(0,1,2,3)]), 41 | (FiniteReal(-1,3,1,2), [FiniteReal(-1), FiniteReal(1,2,3)]), 42 | (FiniteReal(-1,3,1,2,-2,-7), [FiniteReal(-7), FiniteReal(-1,-2), FiniteReal(1,2,3)]), 43 | (FiniteReal(-1,3,1,2,-2,-7,0), [FiniteReal(-7), FiniteReal(-2,-1,0,1,2,3)]), 44 | ]) 45 | def test_parition_finite_real_contiguous(a, b): 46 | solution = partition_finite_real_contiguous(a) 47 | assert solution == b 48 | --------------------------------------------------------------------------------