├── .github
└── workflows
│ └── python-package.yml
├── .gitignore
├── GUIDE.md
├── LICENSE.txt
├── README.md
├── check.sh
├── examples
├── .gitignore
├── clickgraph.ipynb
├── clinical-trial.ipynb
├── disease-infection.ipynb
├── election.ipynb
├── fairness-hiring-model-1.ipynb
├── fairness-hiring-model-2.ipynb
├── fairness-income-model.ipynb
├── generate.sh
├── hierarchcial-markov-switching.ipynb
├── indian-gpa.ipynb
├── piecewise-transformation.ipynb
├── poisson-mean-inference.ipynb
├── random-sequence.ipynb
├── robot-localization.ipynb
├── simple-mixture-model.ipynb
├── student-interviews.ipynb
├── trueskill-poisson-binomial.ipynb
└── two-dimensional-mixture-model.ipynb
├── magics
├── __init__.py
├── magics.py
└── render.py
├── overview.png
├── pythenv.sh
├── requirements.sh
├── setup.py
├── sppl.png
├── src
├── __init__.py
├── compilers
│ ├── __init__.py
│ ├── ast_to_spe.py
│ ├── spe_to_dict.py
│ ├── spe_to_sppl.py
│ └── sppl_to_python.py
├── distributions.py
├── dnf.py
├── math_util.py
├── poly.py
├── render.py
├── sets.py
├── spe.py
├── sym_util.py
├── timeout.py
└── transforms.py
└── tests
├── .gitignore
├── __init__.py
├── test_ast_condition.py
├── test_ast_for.py
├── test_ast_ifelse.py
├── test_ast_switch.py
├── test_ast_transform.py
├── test_burglary.py
├── test_cache_duplicate_subtrees.py
├── test_clickgraph.py
├── test_dnf.py
├── test_event_evaluate.py
├── test_indian_gpa.py
├── test_logpdf.py
├── test_mutual_information.py
├── test_nominal_distribution.py
├── test_parse_distributions.py
├── test_parse_spe.py
├── test_parse_transforms.py
├── test_poly.py
├── test_product.py
├── test_real_continuous.py
├── test_real_discrete.py
├── test_render.py
├── test_sets.py
├── test_solve_transforms.py
├── test_spe_to_dict.py
├── test_spe_to_spml.py
├── test_spe_transform.py
├── test_sppl_to_python.py
├── test_substitute.py
├── test_sum.py
├── test_sum_simplify.py
└── test_sym_util.py
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 | workflow_dispatch:
12 |
13 | jobs:
14 | build:
15 |
16 | runs-on: ubuntu-20.04
17 | strategy:
18 | matrix:
19 | python-version: [3.8, 3.9]
20 |
21 | steps:
22 | - uses: actions/checkout@v2
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v2
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Test src
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install '.[tests]'
31 | ./check.sh ci
32 | - name: Test magics
33 | run: |
34 | sudo sh requirements.sh
35 | pip install '.[all]'
36 | ./check.sh -k '__magics_'
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .dockerignore
2 | .cache/
3 | .coverage
4 | *.egg
5 | *.egg-info
6 | __pycache__/
7 | build/
8 | clean.sh
9 | dist/
10 | egg-info
11 | htmlcov/
12 | MANIFEST
13 |
--------------------------------------------------------------------------------
/GUIDE.md:
--------------------------------------------------------------------------------
1 | System Overview
2 | ===============
3 |
4 | An overview of the SPPL architecture is shown below.
5 | For further details, please refer to the [system description](https://doi.org/10.1145/3453483.3454078).
6 |
7 | ![Overview of the SPPL architecture](./overview.png)
8 |
9 | Probabilistic programs written in SPPL are translated into symbolic sum-product expressions
10 | that represent the joint distribution over all program variables and are used to deliver
11 | exact solutions to probabilistic inference queries.
12 |
13 | Guide to Source Files
14 | =====================
15 |
16 | The table below describes the main source files that make up SPPL.
17 |
18 | | Filename | Description |
19 | | -------- | ----------- |
20 | | [`src/distributions.py`](src/distributions.py) | Wrappers for discrete and continuous probability distributions from [scipy.stats](https://docs.scipy.org/doc/scipy/reference/stats.html), making them available as modeling primitives in SPPL. |
21 | | [`src/dnf.py`](src/dnf.py) | Event preprocessing algorithms, which include converting events to disjunctive normal form, factoring variables in events, and writing an event as a disjoint union of conjunctions. |
22 | | [`src/math_util.py`](src/math_util.py) | Various utilities for mathematical routines. |
23 | | [`src/poly.py`](src/poly.py) | Semi-symbolic solvers for equalities and inequalities involving univariate polynomials with real coefficients. |
24 | | [`src/render.py`](src/render.py) | Renders a sum-product expression as a nested Python list, ideal for use with pprint. |
25 | | [`src/sets.py`](src/sets.py) | Type system and utilities for set-theoretic operations including finite nominals, finite reals, and real intervals. |
26 | | [`src/spe.py`](src/spe.py) | Main module implementing the sum-product expressions, including the sum and product combinators and various leaf primitives. |
27 | | [`src/sym_util.py`](src/sym_util.py) | Various utilities for operating on sets and symbolic variables. |
28 | | [`src/timeout.py`](src/timeout.py) | Python context for enforcing a time limit on a block of code. |
29 | | [`src/transforms.py`](src/transforms.py) | Main module implementing (i) numerical transformations on symbolic variables, such as absolute values, logarithms, exponentials, polynomials, piecewise transformations, and (ii) logical transformations, which include conjunctions, disjunctions, and negations of primitive events (predicates). |
30 | | [`src/compilers/ast_to_spe.py`](src/compilers/ast_to_spe.py) | Translates an SPPL abstract syntax tree to a sum-product expression. |
31 | | [`src/compilers/spe_to_dict.py`](src/compilers/spe_to_dict.py) | Converts a sum-product expression to a Python dictionary. |
32 | | [`src/compilers/spe_to_sppl.py`](src/compilers/spe_to_sppl.py) | Translates a sum-product expression to an SPPL program. |
33 | | [`src/compilers/sppl_to_python.py`](src/compilers/sppl_to_python.py) | Translates SPPL source code to Python source code that contains the original program abstract syntax tree. |
34 | | [`magics/magics.py`](magics/magics.py) | Provides magics for using SPPL through IPython notebooks (see [examples/](./examples)). |
35 | | [`magics/render.py`](magics/render.py) | Renders a sum-product expression (SPE) as networkx and graphviz graphs. |
36 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://github.com/probcomp/sppl/actions)
2 | [](https://pypi.org/project/sppl/)
3 |
4 | Sum-Product Probabilistic Language
5 | ==================================
6 |
7 |
8 |
9 | SPPL is a probabilistic programming language that delivers exact solutions
10 | to a broad range of probabilistic inference queries. The language handles
11 | continuous, discrete, and mixed-type probability distributions; many-to-one
12 | numerical transformations; and a query language that includes general
13 | predicates on random variables.
14 |
15 | Users express generative models as probabilistic programs with standard
16 | imperative constructs, such as arrays, if/else branches, for loops, etc.
17 | The program is then translated to a sum-product expression (a
18 | generalization of [sum-product networks](https://arxiv.org/pdf/2004.01167.pdf))
19 | that statically represents the probability distribution of all random
20 | variables in the program. This expression is used to deliver answers to
21 | probabilistic inference queries.
22 |
23 | A system description of SPPL is given in the following paper:
24 |
25 | SPPL: Probabilistic Programming with Fast Exact Symbolic Inference.
26 | Saad, F. A.; Rinard, M. C.; and Mansinghka, V. K.
27 | In PLDI 2021: Proceedings of the 42nd ACM SIGPLAN International Conference
28 | on Programming Language Design and Implementation,
29 | June 20-25, Virtual, Canada. ACM, New York, NY, USA. 2021.
30 | https://doi.org/10.1145/3453483.3454078.
31 |
32 | ### Installation
33 |
34 | This software is tested on Ubuntu 20.04 with Python 3.8 and 3.9.
35 | SPPL is available on the PyPI repository:
36 |
37 | $ python -m pip install sppl
38 |
39 | To install the Jupyter interface, first obtain the system-wide dependencies in
40 | [requirements.sh](https://github.com/probcomp/sppl/blob/master/requirements.sh)
41 | and then run
42 |
43 | $ python -m pip install 'sppl[magics]'
44 |
45 | ### Examples
46 |
47 | The easiest way to use SPPL is via the browser-based Jupyter interface, which
48 | allows for interactive modeling, querying, and plotting.
49 | Refer to the `.ipynb` notebooks under the
50 | [examples](https://github.com/probcomp/sppl/tree/master/examples) directory.
51 |
52 | ### Benchmarks
53 |
54 | Please refer to the artifact at the ACM Digital Library:
55 | https://doi.org/10.1145/3453483.3454078
56 |
57 | ### Guide to Source Code
58 |
59 | Please refer to [GUIDE.md](./GUIDE.md) for a description of the
60 | main source files in this repository.
61 |
62 | ### Tests
63 |
64 | To run the test suite as a user, first install the test dependencies:
65 |
66 | $ python -m pip install 'sppl[tests]'
67 |
68 | Then run the test suite:
69 |
70 | $ python -m pytest --pyargs sppl
71 |
72 | To run the test suite as a developer:
73 |
74 | - To run crash tests: `$ ./check.sh`
75 | - To run integration tests: `$ ./check.sh ci`
76 | - To run a specific test: `$ ./check.sh [<pytest-args>] /path/to/test.py`
77 | - To run the examples: `$ ./check.sh examples`
78 | - To build a docker image: `$ ./check.sh docker`
79 | - To generate a coverage report: `$ ./check.sh coverage`
80 |
81 | To view the coverage report, open `htmlcov/index.html` in the browser.
82 |
83 | ### Language Reference
84 |
85 | Coming Soon!
86 |
87 | ### Citation
88 |
89 | To cite this work, please use the following BibTeX.
90 |
91 | ```bibtex
92 | @inproceedings{saad2021sppl,
93 | title = {{SPPL:} Probabilistic Programming with Fast Exact Symbolic Inference},
94 | author = {Saad, Feras A. and Rinard, Martin C. and Mansinghka, Vikash K.},
95 | booktitle = {PLDI 2021: Proceedings of the 42nd ACM SIGPLAN International Conference on Programming Language Design and Implementation},
96 | pages = {804--819},
97 | year = 2021,
98 | location = {Virtual, Canada},
99 | publisher = {ACM},
100 | address = {New York, NY, USA},
101 | doi = {10.1145/3453483.3454078},
102 | url = {https://doi.org/10.1145/3453483.3454078},
103 | keywords = {probabilistic programming, symbolic execution, static analysis},
104 | }
105 | ```
106 |
107 | ### License
108 |
109 | Apache 2.0; see [LICENSE.txt](./LICENSE.txt)
110 |
111 | ### Acknowledgments
112 |
113 | The [logo](https://github.com/probcomp/sppl/blob/master/sppl.png) was
114 | designed by [McCoy R. Becker](https://femtomc.github.io/).
115 |
--------------------------------------------------------------------------------
/check.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # Developer entry point for building, testing, and releasing SPPL.
4 | # Usage: ./check.sh [crash | ci | coverage | examples | tag <version> | release]
5 | #    or: ./check.sh <pytest-args>...
6 |
7 | set -Ceux
8 |
9 | # Python interpreter to use; override by exporting PYTHON.
10 | : "${PYTHON:=python}"
11 |
12 | # Absolute path of the directory containing this script.
13 | root=`cd -- "$(dirname -- "$0")" && pwd`
14 |
15 | (
16 | set -Ceu
17 | cd -- "${root}"
18 | # Remove any stale build/ tree and rebuild the package.
19 | rm -rf build
20 | "$PYTHON" setup.py build
21 |
22 | # NOTE: "${1}" is quoted in every test below so that an empty or
23 | # whitespace-containing first argument cannot break the `[` syntax.
24 |
25 | # (Default) Run tests whose names do not match the keyword '__ci_'.
26 | if [ $# -eq 0 ] || [ "${1}" = 'crash' ]; then
27 | ./pythenv.sh "$PYTHON" -m pytest -k 'not __ci_' --pyargs sppl
28 |
29 | # Run all tests under tests/
30 | elif [ "${1}" = 'ci' ]; then
31 | ./pythenv.sh "$PYTHON" -m pytest --pyargs sppl
32 |
33 | # Generate coverage report.
34 | elif [ "${1}" = 'coverage' ]; then
35 | ./pythenv.sh coverage run --source=build/ -m pytest --pyargs sppl
36 | coverage html
37 | coverage report
38 |
39 | # Run the .ipynb notebooks under examples/
40 | elif [ "${1}" = 'examples' ]; then
41 | cd -- examples/
42 | ./generate.sh
43 | cd -- "${root}"
44 |
45 | # Make a tagged release: refuse on a dirty workspace or an existing tag,
46 | # then pin __version__ in src/__init__.py, commit, and create the tag.
47 | elif [ "${1}" = 'tag' ]; then
48 | # Fail with a usage message (not set -u's generic error) if no version is given.
49 | tag="${2:?usage: ./check.sh tag <version>}"
50 | (git diff --quiet --stat && git diff --quiet --staged) \
51 | || (echo 'fatal: workspace dirty' && exit 1)
52 | git show-ref --quiet --tags v"${tag}" \
53 | && (echo 'fatal: tag exists' && exit 1)
54 | sed -i "s/__version__ = .*/__version__ = '${tag}'/g" -- src/__init__.py
55 | git add -- src/__init__.py
56 | git commit -m "Pin version ${tag}."
57 | git tag -a -m v"${tag}" v"${tag}"
58 |
59 | # Send release to PyPI.
60 | elif [ "${1}" = 'release' ]; then
61 | rm -rf dist
62 | "$PYTHON" setup.py sdist bdist_wheel
63 | twine upload --repository pypi dist/*
64 |
65 | # If args are specified delegate control to user (forwarded to pytest).
66 | else
67 | ./pythenv.sh "$PYTHON" -m pytest "$@"
68 |
69 | fi
70 | )
71 |
--------------------------------------------------------------------------------
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | *.html
2 |
--------------------------------------------------------------------------------
/examples/clinical-trial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext sppl.magics"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%%sppl model\n",
19 | "\n",
20 | "n = 20\n",
21 | "k = 20\n",
22 | "isEffective ~= bernoulli(p=.5)\n",
23 | "probControl ~= randint(low=0, high=k)\n",
24 | "probTreated ~= randint(low=0, high=k)\n",
25 | "probAll ~= randint(low=0, high=k)\n",
26 | "\n",
27 | "controlGroup = array(n)\n",
28 | "treatedGroup = array(n)\n",
29 | "\n",
30 | "if (isEffective == 1):\n",
31 | " for i in range(n):\n",
32 | " switch (probControl) cases (p in range(k)):\n",
33 | " controlGroup[i] ~= bernoulli(p=p/k)\n",
34 | " switch (probTreated) cases (p in range(k)):\n",
35 | " treatedGroup[i] ~= bernoulli(p=p/k)\n",
36 | "else:\n",
37 | " for i in range(n):\n",
38 | " switch (probAll) cases (p in range(k)):\n",
39 | " controlGroup[i] ~= bernoulli(p=p/k)\n",
40 | " treatedGroup[i] ~= bernoulli(p=p/k)"
41 | ]
42 | }
43 | ],
44 | "metadata": {
45 | "kernelspec": {
46 | "display_name": "Python 3",
47 | "language": "python",
48 | "name": "python3"
49 | },
50 | "language_info": {
51 | "codemirror_mode": {
52 | "name": "ipython",
53 | "version": 3
54 | },
55 | "file_extension": ".py",
56 | "mimetype": "text/x-python",
57 | "name": "python",
58 | "nbconvert_exporter": "python",
59 | "pygments_lexer": "ipython3",
60 | "version": "3.6.9"
61 | }
62 | },
63 | "nbformat": 4,
64 | "nbformat_minor": 4
65 | }
66 |
--------------------------------------------------------------------------------
/examples/fairness-hiring-model-1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext sppl.magics\n",
10 | "%matplotlib inline\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "import numpy as np"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {
19 | "scrolled": false
20 | },
21 | "outputs": [],
22 | "source": [
23 | "%%sppl model\n",
24 | "\n",
25 | "# Prior on ethnicity, experience, and college rank.\n",
26 | "ethnicity ~= choice({'minority': .15, 'majority': .85})\n",
27 | "years_experience ~= binom(n=15, p=.5)\n",
28 | "if (ethnicity == 'minority'):\n",
29 | " college_rank ~= dlaplace(loc=25, a=1/5)\n",
30 | "else:\n",
31 | " college_rank ~= dlaplace(loc=20, a=1/5)\n",
32 | "\n",
33 | "# Top 50 colleges and at most 20 years of experience.\n",
34 | "condition((0 <= college_rank) <= 50)\n",
35 | "condition((0 <= years_experience) <= 20)\n",
36 | "\n",
37 | "# Hiring decision.\n",
38 | "if college_rank <= 5:\n",
39 | " hire ~= atomic(loc=1)\n",
40 | "else:\n",
41 | " switch (years_experience) cases (years in range(0, 20)):\n",
42 | " if ((years - 5) > college_rank):\n",
43 | " hire ~= atomic(loc=1)\n",
44 | " else:\n",
45 | " hire ~= atomic(loc=0)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 3,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "n = %sppl_get_namespace model\n",
55 | "model = n.model"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 4,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "minority = n.ethnicity << {'minority'}\n",
65 | "model_c0 = model.condition(minority)\n",
66 | "model_c1 = model.condition(~minority)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 5,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "p_hire_given_minority = model_c0.prob(n.hire << {1})\n",
76 | "p_hire_given_majority = model_c1.prob(n.hire << {1})"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 6,
82 | "metadata": {},
83 | "outputs": [
84 | {
85 | "name": "stdout",
86 | "output_type": "stream",
87 | "text": [
88 | "0.3666600508668959\n"
89 | ]
90 | }
91 | ],
92 | "source": [
93 | "print(p_hire_given_minority / p_hire_given_majority)"
94 | ]
95 | }
96 | ],
97 | "metadata": {
98 | "kernelspec": {
99 | "display_name": "Python 3",
100 | "language": "python",
101 | "name": "python3"
102 | },
103 | "language_info": {
104 | "codemirror_mode": {
105 | "name": "ipython",
106 | "version": 3
107 | },
108 | "file_extension": ".py",
109 | "mimetype": "text/x-python",
110 | "name": "python",
111 | "nbconvert_exporter": "python",
112 | "pygments_lexer": "ipython3",
113 | "version": "3.6.9"
114 | }
115 | },
116 | "nbformat": 4,
117 | "nbformat_minor": 4
118 | }
119 |
--------------------------------------------------------------------------------
/examples/fairness-hiring-model-2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext sppl.magics\n",
10 | "%matplotlib inline\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "import numpy as np"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {
19 | "scrolled": false
20 | },
21 | "outputs": [],
22 | "source": [
23 | "%%sppl model\n",
24 | "\n",
25 | "# Population model\n",
26 | "is_male ~= bernoulli(p=.5)\n",
27 | "college_rank ~= norm(loc=25, scale=10)\n",
28 | "if is_male == 1:\n",
29 | " years_exp ~= norm(loc=15, scale=5)\n",
30 | "else:\n",
31 | " years_exp ~= norm(loc=10, scale=5)\n",
32 | "\n",
33 | "# Hiring decision.\n",
34 | "if ((college_rank <= 5) | (years_exp > 5)):\n",
35 | " hire ~= atomic(loc=1)\n",
36 | "else:\n",
37 | " hire ~= atomic(loc=0)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "n = %sppl_get_namespace model"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 4,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "model_m = n.model.condition(n.is_male << {1})\n",
56 | "model_f = n.model.condition(n.is_male << {0})"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 5,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "p_m = model_m.prob(n.hire<<{1})\n",
66 | "p_f = model_f.prob(n.hire<<{1})"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 6,
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "data": {
76 | "text/plain": [
77 | "0.8641668176293485"
78 | ]
79 | },
80 | "execution_count": 6,
81 | "metadata": {},
82 | "output_type": "execute_result"
83 | }
84 | ],
85 | "source": [
86 | "p_f/p_m"
87 | ]
88 | }
89 | ],
90 | "metadata": {
91 | "kernelspec": {
92 | "display_name": "Python 3",
93 | "language": "python",
94 | "name": "python3"
95 | },
96 | "language_info": {
97 | "codemirror_mode": {
98 | "name": "ipython",
99 | "version": 3
100 | },
101 | "file_extension": ".py",
102 | "mimetype": "text/x-python",
103 | "name": "python",
104 | "nbconvert_exporter": "python",
105 | "pygments_lexer": "ipython3",
106 | "version": "3.6.9"
107 | }
108 | },
109 | "nbformat": 4,
110 | "nbformat_minor": 4
111 | }
112 |
--------------------------------------------------------------------------------
/examples/fairness-income-model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext sppl.magics"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%%sppl model\n",
19 | "\n",
20 | "# Population model.\n",
21 | "sex ~= choice({'female': .3307, 'male': .6693})\n",
22 | "if (sex == 'female'):\n",
23 | " capital_gain ~= norm(loc=568.4105, scale=24248365.5428)\n",
24 | " if capital_gain < 7298.0000:\n",
25 | " age ~= norm(loc=38.4208, scale=184.9151)\n",
26 | " relationship ~= choice({\n",
27 | " '0': .0491, '1': .1556, '2': .4012,\n",
28 | " '3': .2589, '4': .0294, '5': .1058\n",
29 | " })\n",
30 | " else:\n",
31 | " age ~= norm(loc=38.8125, scale=193.4918)\n",
32 | " relationship ~= choice({\n",
33 | " '0': .0416, '1': .1667, '2': .4583,\n",
34 | " '3': .2292, '4': .0166, '5': .0876\n",
35 | " })\n",
36 | "else:\n",
37 | " capital_gain ~= norm(loc=1329.3700, scale=69327473.1006)\n",
38 | " if capital_gain < 5178.0000:\n",
39 | " age ~= norm(loc=38.6361, scale=187.2435)\n",
40 | " relationship ~= choice({\n",
41 | " '0': .0497, '1': .1545, '2': .4021,\n",
42 | " '3': .2590, '4': .0294, '5': .1053\n",
43 | " })\n",
44 | " else:\n",
45 | " age ~= norm(loc=38.2668, scale=187.2747)\n",
46 | " relationship ~= choice({\n",
47 | " '0': .0417, '1': .1624, '2': .3976,\n",
48 | " '3': .2606, '4': .0356, '5': .1021\n",
49 | " })\n",
50 | "\n",
51 | "condition(age > 18)\n",
52 | "\n",
53 | "# Decision model.\n",
54 | "if relationship == '1':\n",
55 | " if capital_gain < 5095.5:\n",
56 | " t ~= atomic(loc=1)\n",
57 | " else:\n",
58 | " t ~= atomic(loc=0)\n",
59 | "elif relationship == '2':\n",
60 | " if capital_gain < 4718.5:\n",
61 | " t ~= atomic(loc=1)\n",
62 | " else:\n",
63 | " t ~= atomic(loc=0)\n",
64 | "elif relationship == '3':\n",
65 | " if capital_gain < 5095.5:\n",
66 | " t ~= atomic(loc=1)\n",
67 | " else:\n",
68 | " t ~= atomic(loc=0)\n",
69 | "elif relationship == '4':\n",
70 | " if capital_gain < 8296:\n",
71 | " t ~= atomic(loc=1)\n",
72 | " else:\n",
73 | " t ~= atomic(loc=0)\n",
74 | "elif relationship == '5':\n",
75 | " t ~= atomic(loc=1)\n",
76 | "else:\n",
77 | " if capital_gain < 4668.5:\n",
78 | " t ~= atomic(loc=1)\n",
79 | " else:\n",
80 | " t ~= atomic(loc=0)"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 3,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "n = %sppl_get_namespace model\n",
90 | "model = n.model"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 4,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "model_c1 = model.condition(n.t << {0})"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 5,
105 | "metadata": {},
106 | "outputs": [
107 | {
108 | "name": "stdout",
109 | "output_type": "stream",
110 | "text": [
111 | "1.046470962495416\n"
112 | ]
113 | }
114 | ],
115 | "source": [
116 | "p_female_prior = model.prob(n.sex << {'female'})\n",
117 | "p_female_given_no_hire = model_c1.prob(n.sex << {'female'})\n",
118 | "print(100 * (p_female_given_no_hire / p_female_prior - 1))"
119 | ]
120 | }
121 | ],
122 | "metadata": {
123 | "kernelspec": {
124 | "display_name": "Python 3",
125 | "language": "python",
126 | "name": "python3"
127 | },
128 | "language_info": {
129 | "codemirror_mode": {
130 | "name": "ipython",
131 | "version": 3
132 | },
133 | "file_extension": ".py",
134 | "mimetype": "text/x-python",
135 | "name": "python",
136 | "nbconvert_exporter": "python",
137 | "pygments_lexer": "ipython3",
138 | "version": "3.6.9"
139 | }
140 | },
141 | "nbformat": 4,
142 | "nbformat_minor": 4
143 | }
144 |
--------------------------------------------------------------------------------
/examples/generate.sh:
--------------------------------------------------------------------------------
#!/bin/sh

# Re-execute every notebook in this directory and export it to HTML.
# Iterate over the glob directly instead of parsing `ls`, and quote every
# expansion so notebook names containing whitespace are handled correctly.
for x in *.ipynb; do
    # With no matching notebooks the glob expands to itself; skip it.
    [ -e "${x}" ] || continue
    rm -f -- "${x%.ipynb}.html"
    jupyter nbconvert --execute --to html "${x}"
done
7 |
--------------------------------------------------------------------------------
/examples/robot-localization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext sppl.magics\n",
10 | "%matplotlib inline"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%%sppl model\n",
20 | "param ~= randint(low=0, high=101)\n",
21 | "which ~= uniform(loc=0, scale=1)\n",
22 | "if (which < 0.9):\n",
23 | " switch (param) cases (p in range(0, 101)):\n",
24 | " x ~= norm(loc=p/10, scale=1)\n",
25 | "elif (which < 0.95):\n",
26 | " x ~= uniform(loc=0, scale=10)\n",
27 | "else:\n",
28 | " x ~= atomic(loc=10)"
29 | ]
30 | }
31 | ],
32 | "metadata": {
33 | "kernelspec": {
34 | "display_name": "Python 3",
35 | "language": "python",
36 | "name": "python3"
37 | },
38 | "language_info": {
39 | "codemirror_mode": {
40 | "name": "ipython",
41 | "version": 3
42 | },
43 | "file_extension": ".py",
44 | "mimetype": "text/x-python",
45 | "name": "python",
46 | "nbconvert_exporter": "python",
47 | "pygments_lexer": "ipython3",
48 | "version": "3.6.9"
49 | }
50 | },
51 | "nbformat": 4,
52 | "nbformat_minor": 4
53 | }
54 |
--------------------------------------------------------------------------------
/magics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from .magics import SPPL_Magics
5 |
def load_ipython_extension(ipython):
    """IPython extension hook invoked by ``%load_ext sppl.magics``.

    Builds an SPPL_Magics instance bound to the running shell and registers
    its magics with that shell.
    """
    ipython.register_magics(SPPL_Magics(ipython))
9 |
--------------------------------------------------------------------------------
/magics/magics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import sys
5 | from collections import namedtuple
6 |
7 | from IPython.core.magic import Magics
8 | from IPython.core.magic import cell_magic
9 | from IPython.core.magic import line_magic
10 | from IPython.core.magic import magics_class
11 | from IPython.core.magic import needs_local_scope
12 |
13 | from sppl.compilers.sppl_to_python import SPPL_Compiler
14 |
15 | from .render import render_graphviz
16 |
# Record kept for each compiled %%sppl cell: the cell source text, the
# SPPL_Compiler instance, and the namespace returned by execute_module().
Model = namedtuple('Model', 'source compiler namespace')
18 |
@magics_class
class SPPL_Magics(Magics):
    """IPython magics for compiling, inspecting, and rendering SPPL models.

    A ``%%sppl NAME`` cell is compiled and stored in ``self.programs`` under
    ``NAME``; the line magics below look models up by that name.
    """

    def __init__(self, shell):
        super().__init__(shell)
        # Maps model name -> Model(source, compiler, namespace).
        self.programs = {}

    def _get_program(self, name):
        """Return the Model registered as ``name`` (AssertionError if unknown)."""
        assert name in self.programs, 'unknown program %s' % (name,)
        return self.programs[name]

    @line_magic
    def sppl_get_spe(self, line):
        """Return the variable named ``line`` from model ``line``'s namespace."""
        name = line.strip()
        return getattr(self._get_program(name).namespace, name)

    @cell_magic
    def sppl(self, line, cell):
        """Compile the SPPL program in ``cell`` and register it as ``line``."""
        name = line.strip()
        if not name:
            sys.stderr.write('specify model name after %%sppl\n')
            return
        # Remove any previous program under this name before compiling, so
        # a failed recompilation does not leave the old model registered.
        self.programs.pop(name, None)
        compiler = SPPL_Compiler(cell, name)
        namespace = compiler.execute_module()
        self.programs[name] = Model(cell, compiler, namespace)

    @line_magic
    def sppl_to_python(self, line):
        """Print the Python code generated for model ``line``."""
        print(self._get_program(line.strip()).compiler.render_module())

    @needs_local_scope
    @line_magic
    def sppl_to_graph(self, line, local_ns):
        """Render a model (or an SPE variable in the caller's scope).

        Usage: ``%sppl_to_graph NAME [FILENAME]``. NAME is looked up first
        among compiled models, then in the caller's local namespace.
        """
        # split() with no argument collapses runs of whitespace; splitting
        # on ' ' produced empty tokens and silently dropped FILENAME.
        tokens = line.split()
        assert 1 <= len(tokens) <= 2, 'usage: %sppl_to_graph NAME [FILENAME]'
        name = tokens[0]
        filename = tokens[1] if len(tokens) == 2 else None
        if name in self.programs:
            spe = self.sppl_get_spe(name)
        elif name in local_ns:
            spe = local_ns[name]
        else:
            assert False, 'unknown program %s' % (name,)
        return render_graphviz(spe, filename=filename)

    @line_magic
    def sppl_get_namespace(self, line):
        """Return the namespace produced by executing model ``line``."""
        return self._get_program(line.strip()).namespace
65 |
--------------------------------------------------------------------------------
/magics/render.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
import itertools
import os
import tempfile
import time
7 |
8 | from math import exp
9 |
10 | import graphviz
11 | import networkx as nx
12 |
13 | from sppl.spe import AtomicLeaf
14 | from sppl.spe import NominalLeaf
15 | from sppl.spe import ProductSPE
16 | from sppl.spe import RealLeaf
17 | from sppl.spe import SumSPE
18 |
# Source of unique graph-node identifiers. The previous time.time()-based
# ids could repeat when two nodes were created within the clock's
# resolution, silently merging distinct SPE nodes in the rendered graph;
# a process-wide counter can never repeat.
_gensym_counter = itertools.count()

def gensym():
    """Return a fresh node identifier of the form 'r<integer>'."""
    return 'r%d' % (next(_gensym_counter),)
20 |
def _render_networkx_subgraph(spe):
    """Return ``(G, root)``: the DiGraph drawing ``spe`` and its top node id."""
    if isinstance(spe, NominalLeaf):
        G = nx.DiGraph()
        root = gensym()
        G.add_node(root, label='%s\n%s' % (spe.symbol.token, 'Nominal'))
        return G, root
    if isinstance(spe, AtomicLeaf):
        G = nx.DiGraph()
        root = gensym()
        G.add_node(root, label='%s\n%s(%s)'
            % (spe.symbol.token, 'Atomic', str(spe.value)))
        return G, root
    if isinstance(spe, RealLeaf):
        G = nx.DiGraph()
        root = gensym()
        kwds = '\n%s' % (tuple(spe.dist.kwds.values()),) if spe.dist.kwds else ''
        G.add_node(root, label='%s\n%s%s' % (spe.symbol.token, spe.dist.dist.name, kwds))
        # Environment entries other than the leaf's own symbol are drawn as
        # filled nodes attached to the leaf by dashed, labelled edges.
        if len(spe.env) > 1:
            for k, v in spe.env.items():
                if k != spe.symbol:
                    roott = gensym()
                    G.add_node(roott, label=str(v), style='filled')
                    G.add_edge(root, roott, label=' %s' % (str(k),), style='dashed')
        return G, root
    if isinstance(spe, SumSPE):
        G = nx.DiGraph()
        root = gensym()
        G.add_node(root, label='\N{PLUS SIGN}')
        for child, weight in zip(spe.children, spe.weights):
            subroot = _merge_subgraph(G, child)
            # Weights are stored as logs; label the edge with exp(weight).
            G.add_edge(root, subroot, label='%1.3f' % (exp(weight),))
        return G, root
    if isinstance(spe, ProductSPE):
        G = nx.DiGraph()
        root = gensym()
        G.add_node(root, label='\N{MULTIPLICATION SIGN}')
        for child in spe.children:
            G.add_edge(root, _merge_subgraph(G, child))
        return G, root
    assert False, 'Unknown SPE type: %s' % (spe,)

def _merge_subgraph(G, spe):
    """Render ``spe``, copy its nodes and edges into ``G``, return its root id."""
    G_child, root = _render_networkx_subgraph(spe)
    G.add_nodes_from(G_child.nodes.data())
    G.add_edges_from(G_child.edges.data())
    return root

def render_networkx_graph(spe):
    """Return a networkx DiGraph that draws the SPE ``spe``.

    Sum nodes are labelled with a plus sign and their outgoing edges with the
    branch probabilities; product nodes with a multiplication sign; leaves
    with their symbol and distribution. Child roots are tracked explicitly
    rather than recovered by topologically sorting every subgraph at every
    level of the recursion.
    """
    G, _root = _render_networkx_subgraph(spe)
    return G
70 |
def render_graphviz(spe, filename=None, ext=None, show=None):
    """Render ``spe`` with graphviz and return the ``graphviz.Source`` object.

    Parameters:
      spe: the SPE to draw (see render_networkx_graph).
      filename: base path for the output; when None, a fresh temporary path
        is used.
      ext: output format, 'png' (default) or 'pdf'.
      show: forwarded to ``Source.render`` as ``view``.

    The graph is written as DOT to ``<fname>.dot`` and rendered from there.
    ``render`` also saves a source copy at ``<fname>``, which is removed.
    For a temporary path the intermediate ``.dot`` file is removed as well.
    For a user-supplied filename it is kept.
    """
    ext = ext or 'png'
    # Validate before creating any files so a bad argument leaks nothing.
    assert ext in ['png', 'pdf'], 'Extension must be .pdf or .png'
    use_temporary = filename is None
    if use_temporary:
        # Only a unique path is needed. Close the handle immediately so it
        # does not leak, and so the path can be reopened on every platform.
        with tempfile.NamedTemporaryFile(delete=False) as f:
            fname = f.name
    else:
        fname = filename
    G = render_networkx_graph(spe)
    fname_dot = '%s.dot' % (fname,)
    nx.nx_agraph.write_dot(G, fname_dot)
    source = graphviz.Source.from_file(fname_dot, format=ext)
    source.render(filename=fname, view=show)
    os.unlink(fname)
    if use_temporary:
        os.unlink(fname_dot)
    return source
87 |
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probsys/sppl/ab6435648e56df1603c4d8d27029605c247cb9f5/overview.png
--------------------------------------------------------------------------------
/pythenv.sh:
--------------------------------------------------------------------------------
#!/bin/sh

set -Ceu

: "${PYTHON:=python}"
root=$(cd -- "$(dirname -- "$0")" && pwd)
# sysconfig provides the platform tag; distutils was removed from the
# standard library in Python 3.12.
platform=$("${PYTHON}" -c 'import sysconfig; print(sysconfig.get_platform())')
# Build MAJOR.MINOR from sys.version_info: slicing sys.version[0:3] gives
# "3.1" on Python 3.10 and later.
version=$("${PYTHON}" -c 'import sys; print("%d.%d" % sys.version_info[:2])')

# The lib directory varies depending on
#
# (a) whether there are extension modules (here, no); and
# (b) whether some Debian maintainer decided to patch the local Python
#     to behave as though there were.
#
# But there's no obvious way to just ask the build tooling what the name
# will be. There's no harm in naming a pathname that doesn't exist, other
# than a handful of microseconds of runtime, so we'll add both.
libdir="${root}/build/lib"
plat_libdir="${libdir}.${platform}-${version}"
export PYTHONPATH="${libdir}:${plat_libdir}${PYTHONPATH:+:${PYTHONPATH}}"

bindir="${root}/build/scripts-${version}"
export PATH="${bindir}${PATH:+:${PATH}}"

exec "$@"
27 |
--------------------------------------------------------------------------------
/requirements.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Install the system packages needed to build the Python dependencies.
# Tested on Ubuntu 18.04+; other platforms may vary.
#
#   apt package        needed by (pypi)
#   ===========        ================
#   python3-dev        Python development headers
#   graphviz           pygraphviz
#   libgraphviz-dev    pygraphviz
#   gfortran           scipy

apt_packages=(
    python3-dev
    graphviz
    libgraphviz-dev
    gfortran
)

apt-get -y install "${apt_packages[@]}"
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import os
5 | import re
6 |
7 | from setuptools import setup
8 |
# Specify the requirements.
requirements = {
    'src' : [
        'astunparse==1.6.3',
        'numpy==1.23.1',
        'scipy==1.8.1',
        'sympy==1.10.1',
    ],
    'magics' : [
        'graphviz==0.13.2',
        'ipython==7.23.1',
        'jupyter-core==4.6.3',
        'networkx==2.4',
        'notebook==6.0.3',
        'matplotlib==3.3.2',
        'pygraphviz==1.5',
    ],
    'tests' : [
        'pytest-timeout==2.1.0',
        'pytest==7.1.2',
        'coverage==6.4.2',
    ]
}
requirements['all'] = [r for v in requirements.values() for r in v]

# Resolve files relative to this script rather than the current working
# directory, so setup.py works when invoked from anywhere.
dirname = os.path.dirname(os.path.realpath(__file__))

def _read_text(*parts):
    """Return the UTF-8 contents of os.path.join(dirname, *parts), closing the file."""
    with open(os.path.join(dirname, *parts), encoding='utf-8') as f:
        return f.read()

# Determine the version (hardcoded in src/__init__.py).
vre = re.compile(r"__version__ = '(.*?)'")
m = _read_text('src', '__init__.py')
__version__ = vre.findall(m)[0]

setup(
    name='sppl',
    version=__version__,
    description='The Sum-Product Probabilistic Language',
    long_description=_read_text('README.md'),
    long_description_content_type='text/markdown',
    url='https://github.com/probcomp/sppl',
    license='Apache-2.0',
    maintainer='Feras A. Saad',
    maintainer_email='fsaad@mit.edu',
    classifiers=[
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Topic :: Scientific/Engineering :: Mathematics',
    ],
    packages=[
        'sppl',
        'sppl.compilers',
        'sppl.magics',
        'sppl.tests',
    ],
    package_dir={
        'sppl' : 'src',
        'sppl.compilers' : 'src/compilers',
        'sppl.magics' : 'magics',
        'sppl.tests' : 'tests',
    },
    install_requires=requirements['src'],
    extras_require=requirements,
    python_requires='>=3.8',
)
71 |
--------------------------------------------------------------------------------
/sppl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probsys/sppl/ab6435648e56df1603c4d8d27029605c247cb9f5/sppl.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
# Package version. setup.py extracts it with the regex
# __version__ = '(.*?)', so keep this exact single-quoted form.
__version__ = '2.0.4'
5 |
--------------------------------------------------------------------------------
/src/compilers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
--------------------------------------------------------------------------------
/src/compilers/ast_to_spe.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | """Convert AST to SPE."""
5 |
6 | import os
7 |
8 | from functools import reduce
9 |
10 | from ..dnf import dnf_normalize
11 | from ..math_util import allclose
12 | from ..math_util import isinf_neg
13 | from ..math_util import logsumexp
14 | from ..sets import FiniteNominal
15 | from ..sets import FiniteReal
16 | from ..sets import Set
17 | from ..spe import Memo
18 | from ..spe import SumSPE
19 | from ..spe import spe_simplify_sum
20 |
21 | from .. import transforms
22 |
# Positive infinity; -inf serves as the log-weight of impossible branches.
inf = float('inf')

# Alias of transforms.Id used throughout this module.
Id = transforms.Id

def IdArray(token, n):
    """Return the list of identifiers ``token[0]``, ..., ``token[n-1]``."""
    symbols = []
    for index in range(n):
        symbols.append(Id('%s[%d]' % (token, index)))
    return symbols
28 |
class Command:
    """Abstract base for SPPL commands; subclasses implement interpret()."""

    def interpret(self, spe):
        """Return the SPE that results from running this command on ``spe``."""
        raise NotImplementedError
32 |
class Skip(Command):
    """Command that leaves the SPE untouched."""

    def __init__(self):
        """Skip takes no arguments."""

    def interpret(self, spe=None):
        """Return ``spe`` unchanged (``None`` when no SPE is given)."""
        return spe
38 |
class Sample(Command):
    """Command ``symbol ~ distribution``: introduce a new leaf variable."""

    def __init__(self, symbol, distribution):
        self.symbol = symbol
        self.distribution = distribution

    def interpret(self, spe=None):
        """Build the leaf ``symbol >> distribution`` and combine it with ``spe``.

        Returns the leaf alone when ``spe`` is None, else ``spe & leaf``.
        """
        leaf = self.symbol >> self.distribution
        if spe is None:
            return leaf
        return spe & leaf
46 |
class Transform(Command):
    """Command defining ``symbol`` as the expression ``expr``."""

    def __init__(self, symbol, expr):
        self.symbol, self.expr = symbol, expr

    def interpret(self, spe=None):
        """Return ``spe.transform(symbol, expr)``; ``spe`` must not be None."""
        assert spe is not None
        transformed = spe.transform(self.symbol, self.expr)
        return transformed
54 |
class Condition(Command):
    """Command that conditions the SPE on ``event``."""

    def __init__(self, event):
        self.event = event

    def interpret(self, spe=None):
        """Return ``spe.condition(event)``; ``spe`` must not be None."""
        assert spe is not None
        conditioned = spe.condition(self.event)
        return conditioned
61 |
class Constrain(Command):
    """Command that constrains the SPE to the given ``assignment``."""

    def __init__(self, assignment):
        self.assignment = assignment

    def interpret(self, spe=None):
        """Return ``spe.constrain(assignment)``; ``spe`` must not be None."""
        assert spe is not None
        constrained = spe.constrain(self.assignment)
        return constrained
68 |
class IfElse(Command):
    """If/elif/else command.

    ``branches`` alternates conditions and subcommands:
    ``IfElse(c0, cmd0, c1, cmd1, ..., [Otherwise, cmd_else])``.
    Branch i runs when condition i holds and no earlier condition does.
    A final ``Otherwise`` (i.e. ``True``) condition marks the else branch.
    """
    def __init__(self, *branches):
        assert len(branches) % 2 == 0
        self.branches = branches
    def interpret(self, spe=None):
        assert spe is not None
        conditions = self.branches[::2]
        subcommands = self.branches[1::2]
        has_else = conditions[-1] is True
        guards = conditions[:-1] if has_else else conditions
        # A lone else branch always runs. Handle it directly: there are no
        # guards to negate, and reduce() below would raise on an empty
        # sequence with no initial value.
        if not guards:
            return subcommands[0].interpret(spe)
        # Branch i fires on conditions[i] & ~conditions[0] & ... & ~conditions[i-1].
        events_unorm = [
            reduce(lambda x, e: x & ~e, conditions[:i], conditions[i])
            for i in range(len(guards))
        ]
        if has_else:
            # The else branch fires when none of the guards hold.
            events_unorm.append(~reduce(lambda x, e: x | e, guards))
        # Rewrite events in normalized form. dnf_normalize returns None for
        # an unsatisfiable event; interpret_if_block assigns it weight -inf.
        events = [dnf_normalize(event) for event in events_unorm]
        # Weight, condition, and interpret each branch.
        return interpret_if_block(spe, events, subcommands)
94 |
class For(Command):
    """Loop command: run ``f(i)`` for ``i`` in ``range(n0, n1)``, in order."""

    def __init__(self, n0, n1, f):
        self.n0 = n0
        self.n1 = n1
        self.f = f

    def interpret(self, spe=None):
        """Unroll the loop into a Sequence and interpret it on ``spe``."""
        unrolled = Sequence(*map(self.f, range(self.n0, self.n1)))
        return unrolled.interpret(spe)
104 |
class Switch(Command):
    """Multi-way branch on the value of ``symbol``.

    ``values`` is an iterable of cases, and case ``v`` runs subcommand
    ``f(v)``. When ``values`` is an ``enumerate`` object, each case is an
    ``(index, value)`` pair and the subcommand is ``f(index, value)``.
    Cases are made disjoint: values matched by an earlier case are removed
    from later ones.
    """
    def __init__(self, symbol, values, f):
        self.symbol = symbol
        self.f = f
        self.values = values
    def interpret(self, spe=None):
        if isinstance(self.values, enumerate):
            values = list(self.values)
            sets = [self.value_to_set(v[1]) for v in values]
            subcommands = [self.f(*v) for v in values]
        else:
            # Materialize once. values may be a one-shot iterator (e.g. a
            # generator); iterating it twice would leave no subcommands.
            values = list(self.values)
            sets = [self.value_to_set(v) for v in values]
            subcommands = [self.f(v) for v in values]
        sets_disjoint = [
            reduce(lambda x, s: x & ~s, sets[:i], sets[i])
            for i in range(len(sets))]
        events = [self.symbol << s for s in sets_disjoint]
        return interpret_if_block(spe, events, subcommands)
    def value_to_set(self, v):
        """Wrap a case value as a Set (strings become nominal sets)."""
        if isinstance(v, Set):
            return v
        if isinstance(v, str):
            return FiniteNominal(v)
        return FiniteReal(v)
129 |
class Sequence(Command):
    """Run ``commands`` one after another, threading the SPE through them."""

    def __init__(self, *commands):
        self.commands = commands

    def interpret(self, spe=None):
        """Feed ``spe`` through every command in order; return the final SPE."""
        current = spe
        for command in self.commands:
            current = command.interpret(current)
        return current
135 |
# Marker for the else branch of IfElse. IfElse.interpret tests whether the
# last condition `is True`, so this must remain the True singleton.
Otherwise = True
137 |
def interpret_if_block(spe, events, subcommands):
    """Interpret branches guarded by mutually exclusive ``events`` on ``spe``.

    Each branch is weighted by the log-probability of its event. Branches
    with zero probability (or a None event) are dropped. Each remaining
    subcommand is interpreted on ``spe`` conditioned on its event, and the
    results are combined into a SumSPE (simplified unless the
    SPPL_NO_SIMPLIFY environment variable is set). The surviving weights
    must sum to one.
    """
    assert len(events) == len(subcommands)
    # Shared memo table for the logprob/condition queries below.
    memo = Memo()
    log_weights = [
        -inf if event is None else spe.logprob(event, memo)
        for event in events
    ]
    live = [i for i, lw in enumerate(log_weights) if not isinf_neg(lw)]
    assert live, 'All conditions probability zero.'
    live_weights = [log_weights[i] for i in live]
    branch_spes = [spe.condition(events[i], memo) for i in live]
    live_commands = [subcommands[i] for i in live]
    assert allclose(logsumexp(live_weights), 0)
    children = [
        command.interpret(branch)
        for branch, command in zip(branch_spes, live_commands)
    ]
    # A single surviving branch needs no mixture node.
    if len(children) == 1:
        return children[0]
    mixture = SumSPE(children, live_weights)
    if not os.environ.get('SPPL_NO_SIMPLIFY'):
        mixture = spe_simplify_sum(mixture)
    return mixture
167 |
--------------------------------------------------------------------------------
/src/compilers/spe_to_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | """Convert SPE to JSON friendly dictionary."""
5 |
6 | from fractions import Fraction
7 |
8 | import scipy.stats
9 | from sympy import E
10 | from sympy import sqrt
11 |
12 | from ..spe import AtomicLeaf
13 | from ..spe import ContinuousLeaf
14 | from ..spe import DiscreteLeaf
15 | from ..spe import NominalLeaf
16 | from ..spe import ProductSPE
17 | from ..spe import SumSPE
18 |
19 | # Needed for "eval"
20 | from ..sets import *
21 | from ..transforms import Id
22 | from ..transforms import Identity
23 | from ..transforms import Radical
24 | from ..transforms import Exponential
25 | from ..transforms import Logarithm
26 | from ..transforms import Abs
27 | from ..transforms import Reciprocal
28 | from ..transforms import Poly
29 | from ..transforms import Piecewise
30 | from ..transforms import EventInterval
31 | from ..transforms import EventFiniteReal
32 | from ..transforms import EventFiniteNominal
33 | from ..transforms import EventOr
34 | from ..transforms import EventAnd
35 |
def env_from_dict(env):
    """Rebuild an environment from the {repr(key): repr(value)} encoding.

    Returns None when ``env`` is None. Keys and values are decoded with
    eval() in this module's namespace.
    NOTE(security): eval() runs arbitrary code; only decode trusted input.
    """
    if env is None:
        return None
    decoded = {}
    for key_src, value_src in env.items():
        key = eval(key_src)
        decoded[key] = eval(value_src)
    return decoded
41 |
def env_to_dict(env):
    """Encode an environment as {repr(key): repr(value)}.

    A single-entry environment is encoded as None (presumably the trivial
    environment holding only the leaf's own symbol).
    """
    if len(env) != 1:
        return dict((repr(key), repr(value)) for key, value in env.items())
    return None
46 |
def scipy_dist_from_dict(dist):
    """Rebuild a frozen scipy.stats distribution from its dict description."""
    family = getattr(scipy.stats, dist['name'])
    args = dist['args']
    kwds = dist['kwds']
    return family(*args, **kwds)
50 |
def scipy_dist_to_dict(dist):
    """Describe a frozen scipy.stats distribution by family name and parameters."""
    record = {}
    record['name'] = dist.dist.name
    record['args'] = dist.args
    record['kwds'] = dist.kwds
    return record
57 |
def spe_from_dict(metadata):
    """Rebuild an SPE from the dictionary produced by spe_to_dict.

    NOTE(security): support strings and environment entries are decoded
    with eval(); only load dictionaries from trusted sources.
    """
    kind = metadata['class']
    if kind == 'NominalLeaf':
        symbol = Id(metadata['symbol'])
        # Weights are stored as (numerator, denominator) pairs.
        weights = {x: Fraction(w[0], w[1]) for x, w in metadata['dist']}
        return NominalLeaf(symbol, weights)
    if kind == 'AtomicLeaf':
        symbol = Id(metadata['symbol'])
        value = float(metadata['value'])
        env = env_from_dict(metadata['env'])
        return AtomicLeaf(symbol, value, env=env)
    if kind in ('ContinuousLeaf', 'DiscreteLeaf'):
        # Both real-valued leaf kinds share one layout; only the class differs.
        leaf_class = ContinuousLeaf if kind == 'ContinuousLeaf' else DiscreteLeaf
        symbol = Id(metadata['symbol'])
        dist = scipy_dist_from_dict(metadata['dist'])
        support = eval(metadata['support'])
        conditioned = metadata['conditioned']
        env = env_from_dict(metadata['env'])
        return leaf_class(symbol, dist, support, conditioned, env=env)
    if kind in ('SumSPE', 'ProductSPE'):
        children = [spe_from_dict(c) for c in metadata['children']]
        if kind == 'SumSPE':
            return SumSPE(children, metadata['weights'])
        return ProductSPE(children)

    assert False, 'Cannot convert %s to SPE' % (metadata,)
91 |
def spe_to_dict(spe):
    """Convert an SPE into a JSON-friendly dictionary (inverse of spe_from_dict)."""
    if isinstance(spe, NominalLeaf):
        # Exact Fraction weights are stored as (numerator, denominator).
        dist_pairs = [
            (str(x), (w.numerator, w.denominator))
            for x, w in spe.dist.items()
        ]
        return {
            'class': 'NominalLeaf',
            'symbol': spe.symbol.token,
            'dist': dist_pairs,
            'env': env_to_dict(spe.env),
        }
    if isinstance(spe, AtomicLeaf):
        return {
            'class': 'AtomicLeaf',
            'symbol': spe.symbol.token,
            'value': spe.value,
            'env': env_to_dict(spe.env),
        }
    real_leaves = (
        (ContinuousLeaf, 'ContinuousLeaf'),
        (DiscreteLeaf, 'DiscreteLeaf'),
    )
    for leaf_class, class_name in real_leaves:
        if isinstance(spe, leaf_class):
            return {
                'class': class_name,
                'symbol': spe.symbol.token,
                'dist': scipy_dist_to_dict(spe.dist),
                'support': repr(spe.support),
                'conditioned': spe.conditioned,
                'env': env_to_dict(spe.env),
            }
    if isinstance(spe, SumSPE):
        return {
            'class': 'SumSPE',
            'children': [spe_to_dict(child) for child in spe.children],
            'weights': spe.weights,
        }
    if isinstance(spe, ProductSPE):
        return {
            'class': 'ProductSPE',
            'children': [spe_to_dict(child) for child in spe.children],
        }
    assert False, 'Cannot convert %s to JSON' % (spe,)
140 |
--------------------------------------------------------------------------------
/src/compilers/spe_to_sppl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | """Convert SPE to SPPL."""
5 |
6 | from io import StringIO
7 | from math import exp
8 |
9 | from ..spe import RealLeaf
10 | from ..spe import NominalLeaf
11 | from ..spe import ProductSPE
12 | from ..spe import SumSPE
13 |
def get_indentation(i):
    """Return a string of ``i`` spaces."""
    return ' ' * i

def float_to_str(x, fw):
    """Format ``x`` with ``fw`` decimals, or with str() when ``fw`` is falsy."""
    if not fw:
        return str(x)
    return '%1.*f' % (fw, float(x))
class _SPPL_Render_State:
    """Mutable state threaded through the recursive SPPL renderer."""

    def __init__(self, stream=None, branches=None, indentation=None,
            fwidth=None):
        # Buffer receiving the rendered program body.
        self.stream = stream if stream else StringIO()
        # (variable name, distribution) pairs for auxiliary branch variables.
        self.branches = branches if branches else []
        # Current indentation width, in spaces.
        self.indentation = indentation if indentation else 0
        # Decimal places for numbers (None means str()).
        self.fwidth = fwidth
def render_sppl_choice(symbol, dist, stream, indentation, fwidth):
    """Write ``symbol ~= choice({...})`` for the categorical ``dist`` to ``stream``.

    ``dist`` maps outcomes to weights. Weights are formatted with
    float_to_str(weight, fwidth). When the outcomes' total text exceeds 30
    characters, each outcome goes on its own line, indented four spaces
    deeper; otherwise all outcomes share one line.
    """
    idt = get_indentation(indentation)
    # repr() quotes and escapes each outcome, so outcomes containing quotes
    # or backslashes still yield a valid string literal. For ordinary
    # outcomes the output is unchanged: 'k'.
    dist_pairs = [
        '%s : %s' % (repr(str(k)), float_to_str(v, fwidth))
        for (k, v) in dist.items()
    ]
    # Write each outcome on new line.
    if sum(len(x) for x in dist_pairs) > 30:
        idt4 = get_indentation(indentation + 4)
        prefix = ',\n%s' % (idt4,)
        dist_op = '\n%s' % (idt4,)
        dist_cl = '\n%s' % (idt,)
    # Write all outcomes on same line.
    else:
        prefix = ', '
        dist_op = ''
        dist_cl = ''
    dist_str = prefix.join(dist_pairs)
    stream.write('%s%s ~= choice({%s%s%s})' %
        (idt, symbol, dist_op, dist_str, dist_cl))
    stream.write('\n')
def render_sppl_helper(spe, state):
    """Recursively write SPPL source for ``spe`` into ``state``.

    Output goes to ``state.stream`` at ``state.indentation``; the (possibly
    updated) state is returned. A sum node with two or more children becomes
    an if/elif chain over a fresh variable ``branch_var_<k>``. That
    variable's distribution is appended to ``state.branches`` so the caller
    can declare it (render_sppl writes these declarations first).
    """
    if isinstance(spe, NominalLeaf):
        # Only nominal leaves without derived variables are supported.
        assert len(spe.env) == 1
        render_sppl_choice(
            spe.symbol,
            spe.dist,
            state.stream,
            state.indentation,
            state.fwidth)
        return state
    if isinstance(spe, RealLeaf):
        # Emit `symbol ~= family(k=v, ...)` from the keyword parameters.
        # NOTE(review): positional parameters in spe.dist.args are not
        # rendered; confirm leaves are always built with keywords only.
        kwds = ', '.join([
            '%s=%s' % (k, float_to_str(v, state.fwidth))
            for k, v in spe.dist.kwds.items()
        ])
        dist = '%s(%s)' % (spe.dist.dist.name, kwds)
        idt = get_indentation(state.indentation)
        state.stream.write('%s%s ~= %s' % (idt, spe.symbol, dist))
        state.stream.write('\n')
        # A conditioned leaf also records the restriction to its support.
        if spe.conditioned:
            event = spe.symbol << spe.support
            # TODO: Consider using repr(event)
            state.stream.write('%scondition(%s)' % (idt, event))
            state.stream.write('\n')
        # Env entries after the first are emitted as `var ~= expr`
        # (presumably the first entry is the leaf's own symbol).
        for i, (var, expr) in enumerate(spe.env.items()):
            if 1 <= i:
                state.stream.write('%s%s ~= %s' % (idt, var, expr))
                state.stream.write('\n')
        return state
    if isinstance(spe, ProductSPE):
        # Product children are written one after another at the same level.
        for child in spe.children:
            state = render_sppl_helper(child, state)
        return state
    if isinstance(spe, SumSPE):
        if len(spe.children) == 0:
            return state
        if len(spe.children) == 1:
            return render_sppl_helper(spe.children[0], state)
        # Allocate the next branch variable; its outcomes '0', '1', ... pick
        # a child, with probabilities exp(weight) (weights are logs).
        branch_var = 'branch_var_%s' % (len(state.branches))
        branch_idxs = [str(i) for i in range(len(spe.children))]
        branch_dist = {k: exp(w) for k, w in zip(branch_idxs, spe.weights)}
        state.branches.append((branch_var, branch_dist))
        # Write the branches.
        for i, child in zip(branch_idxs, spe.children):
            ifstmt = 'if' if i == '0' else 'elif'
            idt = get_indentation(state.indentation)
            state.stream.write('%s%s (%s == \'%s\'):'
                % (idt, ifstmt, branch_var, i))
            state.stream.write('\n')
            # Indent the branch body, then restore the indentation.
            state.indentation += 4
            state = render_sppl_helper(child, state)
            state.stream.write('\n')
            state.indentation -= 4
        return state
    assert False, 'Unknown spe %s' % (spe,)
99 |
def render_sppl(spe, stream=None, fwidth=None):
    """Render ``spe`` as an SPPL program; return the stream holding the text.

    Layout: the distributions import, a blank line, one choice() line per
    auxiliary branch variable (each followed by a blank line), then the
    program body. ``fwidth`` sets the number of decimals for numbers.
    """
    out = StringIO() if stream is None else stream
    rendered = render_sppl_helper(spe, _SPPL_Render_State(fwidth=fwidth))
    assert rendered.indentation == 0
    # Write the import.
    out.write('from sppl.distributions import *')
    out.write('\n')
    out.write('\n')
    # Declare the branch variables introduced for sum nodes (if any).
    for name, weights in rendered.branches:
        render_sppl_choice(name, weights, out, 0, fwidth)
        out.write('\n')
    # Append the program body.
    out.write(rendered.stream.getvalue())
    return out
117 |
--------------------------------------------------------------------------------
/src/dnf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from functools import reduce
5 | from itertools import chain
6 | from itertools import combinations
7 |
8 | from .sets import EmptySet
9 | from .transforms import EventAnd
10 | from .transforms import EventBasic
11 | from .transforms import EventOr
12 | from .transforms import Id
13 |
def dnf_factor(event, lookup=None):
    """Factor each clause of a DNF ``event`` by the groups in ``lookup``.

    ``lookup`` maps each symbol to a group key (by default every symbol is
    its own key). Returns a tuple R with one dict per DNF clause, where
    R[i][j] is the conjunction of the literals in clause i whose symbols
    map to key j.

    Example, for any predicate e:
        event  = (e(X0) & e(X1) & ~e(X2)) | (~e(X1) & e(X2) & e(X3) & e(X4))
        lookup = {X0: 0, X1: 1, X2: 0, X3: 1, X4: 2}
    gives
        R = (
            {0: e(X0) & ~e(X2), 1: e(X1)},                  # first clause
            {0: e(X2), 1: ~e(X1) & e(X3), 2: e(X4)},        # second clause
        )
    """
    if lookup is None:
        lookup = {symbol: symbol for symbol in event.get_symbols()}

    # Literal: a basic event over exactly one symbol.
    if isinstance(event, EventBasic):
        symbols = event.get_symbols()
        assert len(symbols) == 1
        symbol = next(iter(symbols))
        return ({lookup[symbol]: event},)

    # Conjunction of literals: AND together literals sharing a group key.
    if isinstance(event, EventAnd):
        assert all(isinstance(e, EventBasic) for e in event.subexprs)
        merged = {}
        literal_factors = chain.from_iterable(
            dnf_factor(e, lookup) for e in event.subexprs)
        for factor in literal_factors:
            for key, literal in factor.items():
                if key in merged:
                    merged[key] &= literal
                else:
                    merged[key] = literal
        return (merged,)

    # Disjunction of clauses: one factored dict per clause, in order.
    if isinstance(event, EventOr):
        assert all(isinstance(e, (EventAnd, EventBasic)) for e in event.subexprs)
        return tuple(
            factor
            for e in event.subexprs
            for factor in dnf_factor(e, lookup)
        )

    assert False, 'Invalid DNF event: %s' % (event,)
65 |
def dnf_normalize(event):
    """Rewrite `event` as a DNF formula whose literals constrain only Ids.

    Each clause's per-symbol sub-events are solved into sets and re-expressed
    as `symbol << S`. Clauses with an empty solution are dropped and
    duplicate clauses are merged. Returns None when every clause is empty.
    A literal that already applies directly to an Id is returned unchanged.
    """
    if isinstance(event, EventBasic) and isinstance(event.subexpr, Id):
        return event
    solved = []
    for clause in dnf_factor(event.to_dnf()):
        pairs = [(symbol, ev.solve()) for symbol, ev in clause.items()]
        # A conjunction with an empty component is itself empty.
        if all(soln is not EmptySet for _symbol, soln in pairs):
            solved.append(pairs)
    if not solved:
        return None
    conjunctions = []
    for i, pairs in enumerate(solved):
        # Skip clauses identical to an earlier one.
        if pairs in solved[:i]:
            continue
        literals = [(symbol << S) for symbol, S in pairs]
        conjunctions.append(reduce(lambda acc, lit: acc & lit, literals))
    disjunction = reduce(lambda acc, conj: acc | conj, conjunctions)
    return disjunction.to_dnf()
87 |
def dnf_non_disjoint_clauses(event):
    """Find, for each clause of a DNF `event`, the earlier clauses it overlaps.

    Returns a dict R with R[j] = [i for i < j if clause i intersects clause j],
    listed in increasing i. Clauses that overlap no earlier clause are absent.
    """
    event_factor = dnf_factor(event)
    solutions = [
        {symbol: ev.solve() for symbol, ev in clause.items()}
        for clause in event_factor
    ]

    overlap_dict = {}
    for i, j in combinations(range(len(solutions)), 2):
        # Skip the pair if some symbol of clause i has an empty intersection
        # with clause j (a symbol absent from clause j is unconstrained there).
        intersections = (
            solutions[i][symbol] & solutions[j][symbol]
            if symbol in solutions[j] else
            solutions[i][symbol]
            for symbol in solutions[i]
        )
        if any(x is EmptySet for x in intersections):
            continue
        # Skip the pair if clause j is itself empty in some symbol; test each
        # per-symbol solution, not the dictionary holding them.
        if any(solutions[j][symbol] is EmptySet for symbol in solutions[j]):
            continue
        # Every symbol intersects, so the clauses overlap.
        overlap_dict.setdefault(j, []).append(i)

    return overlap_dict
118 |
def dnf_to_disjoint_union(event):
    """Rewrite a DNF `event` so that its clauses are pairwise disjoint.

    Uses the identity (A or B or C) = A or (B and ~A) or (C and ~A and ~B),
    subtracting from each clause only the earlier clauses it overlaps with,
    then normalizing and recursing on each resulting clause.
    """
    # A literal or a single conjunction is trivially a disjoint union.
    if isinstance(event, (EventBasic, EventAnd)):
        return event
    overlaps = dnf_non_disjoint_clauses(event)
    if not overlaps:
        return event
    clauses = event.subexprs
    disjoint_clauses = []
    for i, clause in enumerate(clauses):
        for j in overlaps.get(i, []):
            clause = clause & ~clauses[j]
        disjoint_clauses.append(clause)
    normalized = [dnf_normalize(c) for c in disjoint_clauses]
    pieces = [dnf_to_disjoint_union(c) for c in normalized if c]
    return reduce(lambda left, right: left | right, pieces)
144 |
--------------------------------------------------------------------------------
/src/math_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import isinf
5 |
6 | import numpy
7 |
8 | from scipy.special import logsumexp
9 |
10 | # Implementation of log1mexp and logdiffexp from PyMC3 math module.
11 | # https://github.com/pymc-devs/pymc3/blob/master/pymc3/math.py
def log1mexp(x):
    """Return log(1 - exp(-x)), choosing the numerically stable form.

    Uses log(-expm1(-x)) below the 0.683 switch point and
    log1p(-exp(-x)) at or above it.
    """
    return numpy.log(-numpy.expm1(-x)) if x < 0.683 else numpy.log1p(-numpy.exp(-x))


def logdiffexp(a, b):
    """Return log(exp(a) - exp(b)).

    Returns -inf when a and b are numerically equal; raises ValueError
    when b is larger than a (the difference would be negative).
    """
    if not b < a:
        if allclose(b, a):
            return -float('inf')
        raise ValueError('Negative term in logdiffexp.')
    return a + log1mexp(a - b)


def lognorm(array):
    """Shift log-weights so they log-sum-exp to zero; returns a list."""
    log_total = logsumexp(array)
    return [entry - log_total for entry in array]


def logflip(logp, array, size, rng):
    """Draw `size` items from `array` given unnormalized log-weights `logp`."""
    return flip(numpy.exp(lognorm(logp)), array, size, rng)


def flip(p, array, size, rng):
    """Draw `size` items from `array` given unnormalized weights `p`."""
    probabilities = normalize(p)
    return random(rng).choice(array, size=size, p=probabilities)


def normalize(p):
    """Return `p` as a float numpy array divided by its sum."""
    total = float(sum(p))
    return numpy.asarray(p, dtype=float) / total


def allclose(values, x):
    """Delegate to numpy.allclose with its default tolerances."""
    return numpy.allclose(values, x)


def isinf_pos(x):
    """True if `x` is +inf."""
    return isinf(x) and x > 0


def isinf_neg(x):
    """True if `x` is -inf."""
    return isinf(x) and x < 0


def random(x):
    """Return `x` if truthy, else the global numpy.random module."""
    return x if x else numpy.random


def int_or_isinf_neg(a):
    """True if `a` is -inf or has an integral value."""
    return isinf_neg(a) or float(a) == int(a)


def int_or_isinf_pos(a):
    """True if `a` is +inf or has an integral value."""
    return isinf_pos(a) or float(a) == int(a)


def float_to_int(a):
    """Convert `a` to int, passing infinities through unchanged."""
    return a if isinf(a) else int(a)
56 |
--------------------------------------------------------------------------------
/src/poly.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import os
5 |
6 | from itertools import chain
7 | from math import isinf
8 |
9 | import sympy
10 |
11 | from sympy import limit
12 |
13 | from .sets import EmptySet
14 | from .sets import ExtReals
15 | from .sets import FiniteReal
16 | from .sets import Interval
17 | from .sets import Reals
18 | from .sets import convert_sympy
19 | from .sets import oo
20 | from .sets import make_union
21 | from .sym_util import get_symbols
22 | from .timeout import timeout
23 |
# Seconds granted to a symbolic sympy solve before the solvers below fall
# back to numerical root finding.
TIMEOUT_SYMBOLIC = 5

def get_poly_symbol(expr):
    """Return the single free symbol of `expr` (asserts there is exactly one)."""
    free_symbols = list(get_symbols(expr))
    assert len(free_symbols) == 1
    return free_symbols[0]
30 |
31 | # ==============================================================================
32 | # Solving inequalities.
33 |
def solve_poly_inequality(expr, b, strict, extended=None):
    """Solve `expr < b` (strict) or `expr <= b` for a univariate polynomial.

    Infinite `b` is delegated to solve_poly_inequality_inf. Otherwise sympy
    is tried first (bounded by TIMEOUT_SYMBOLIC seconds) unless the
    SPPL_NO_SYMBOLIC environment variable is set; a timeout or an
    unresolved ConditionSet/Intersection result falls back to the
    numerical solver.
    """
    if isinf(b):
        return solve_poly_inequality_inf(expr, b, strict, extended=extended)
    if not os.environ.get('SPPL_NO_SYMBOLIC'):
        try:
            with timeout(seconds=TIMEOUT_SYMBOLIC):
                result_symbolic = solve_poly_inequality_symbolically(expr, b, strict)
        except TimeoutError:
            result_symbolic = None
        unresolved = (sympy.ConditionSet, sympy.Intersection)
        if result_symbolic is not None and not isinstance(result_symbolic, unresolved):
            return convert_sympy(result_symbolic)
    return solve_poly_inequality_numerically(expr, b, strict)
52 |
def solve_poly_inequality_symbolically(expr, b, strict):
    """Return sympy's solution set of `expr < b` (or `<=`) over the reals."""
    relation = expr < b if strict else expr <= b
    return sympy.solveset(relation, domain=sympy.Reals)
56 |
def solve_poly_inequality_numerically(expr, b, strict):
    """Solve `expr < b` (strict) or `expr <= b` numerically over the reals.

    The real roots of `expr - b` split the line into intervals on which the
    polynomial keeps a constant sign; one probe point per interval decides
    whether that interval belongs to the solution. Always returns a set from
    this project's `sets` module (never a sympy set), matching the symbolic
    path, which goes through convert_sympy.
    """
    poly = expr - b
    symX = get_poly_symbol(expr)
    # Obtain numerical roots. Repeated roots (e.g. a double root) are merged
    # so that every interval below is non-degenerate and owns one probe.
    roots = sympy.nroots(poly)
    zeros = sorted(set(r for r in roots if r.is_real))
    if not zeros:
        # No real root: the polynomial never changes sign, so a single probe
        # decides whether the inequality holds everywhere or nowhere.
        return Reals if poly.subs(symX, 0) < 0 else EmptySet

    def mk_intvl(left, right):
        return Interval(left, right, left_open=strict, right_open=strict)

    # Interval k runs from zeros[k-1] to zeros[k], with -oo/oo at the ends.
    intervals = [
        mk_intvl(x, y) for x, y in zip([-oo] + zeros, zeros + [oo])
    ]
    # Probe points, aligned one-to-one (by index) with `intervals`.
    xs_probe = list(chain(
        [zeros[0] - 1/2],
        [(x + y)/2 for x, y in zip(zeros, zeros[1:])],
        [zeros[-1] + 1/2]))
    negative = [bool(poly.subs(symX, x) < 0) for x in xs_probe]
    pieces = [intvl for intvl, neg in zip(intervals, negative) if neg]
    if not strict:
        # Every root satisfies poly <= 0. Roots bordering a negative interval
        # are already closed endpoints of a piece; add the remaining ones,
        # where poly touches zero without changing sign.
        isolated = [
            z for k, z in enumerate(zeros)
            if not (negative[k] or negative[k + 1])
        ]
        if isolated:
            pieces.append(FiniteReal(*isolated))
    return make_union(*pieces) if pieces else EmptySet
83 |
def solve_poly_inequality_inf(expr, b, strict, extended=None):
    """Solve `expr < b` / `expr <= b` where `b` is +oo or -oo.

    `extended` (default True when None) controls whether the points +-oo
    themselves may belong to the solution.
    """
    assert isinf(b)
    ext = True if extended is None else extended
    if b < 0:
        # Nothing lies strictly below -oo; expr <= -oo holds only where
        # expr reaches -oo, which needs the extended reals.
        if strict or not ext:
            return EmptySet
        return solve_poly_equality_inf(expr, b)
    if not strict:
        # Every value is <= +oo.
        return ExtReals if ext else Reals
    # expr < +oo holds on the reals and, in the extended case, at the
    # infinite points where expr tends to -oo.
    xinf = solve_poly_equality_inf(expr, -oo) if ext else EmptySet
    return Reals | xinf
100 |
101 | # ==============================================================================
102 | # Solving equalities.
103 |
def solve_poly_equality(expr, b):
    """Solve `expr == b` for a univariate polynomial.

    Infinite `b` is delegated to solve_poly_equality_inf. Otherwise sympy
    is tried first (bounded by TIMEOUT_SYMBOLIC seconds) unless the
    SPPL_NO_SYMBOLIC environment variable is set; a timeout or an
    unresolved ConditionSet/Intersection result falls back to the
    numerical solver.
    """
    if isinf(b):
        return solve_poly_equality_inf(expr, b)
    if not os.environ.get('SPPL_NO_SYMBOLIC'):
        try:
            with timeout(seconds=TIMEOUT_SYMBOLIC):
                result_symbolic = solve_poly_equality_symbolically(expr, b)
        except TimeoutError:
            result_symbolic = None
        unresolved = (sympy.ConditionSet, sympy.Intersection)
        if result_symbolic is not None and not isinstance(result_symbolic, unresolved):
            return convert_sympy(result_symbolic)
    return solve_poly_equality_numerically(expr, b)
122 |
def solve_poly_equality_symbolically(expr, b):
    """Solve `expr == b` as the intersection of `expr <= b` and `expr >= b`."""
    at_most_b = solve_poly_inequality_symbolically(expr, b, False)
    # expr >= b is expressed as -expr <= -b.
    at_least_b = solve_poly_inequality_symbolically(-expr, -b, False)
    return sympy.Intersection(at_most_b, at_least_b)
127 |
def solve_poly_equality_numerically(expr, b):
    """Return the real roots of `expr - b` as a FiniteReal, or EmptySet.

    The no-real-root case (e.g. x**2 == -1) is handled explicitly with
    EmptySet, consistent with solve_poly_inequality_numerically, rather than
    relying on FiniteReal() accepting no values.
    """
    roots = sympy.nroots(expr - b)
    zeros = [r for r in roots if r.is_real]
    if not zeros:
        return EmptySet
    return FiniteReal(*zeros)
132 |
def solve_poly_equality_inf(expr, b):
    """Return the infinite points (+oo and/or -oo) at which `expr` tends to `b`.

    `b` must be +oo or -oo; the limits of `expr` at both ends of the real
    line are compared against its sign.
    """
    assert isinf(b)
    symX = get_poly_symbol(expr)
    val_pos_inf = limit(expr, symX, oo)
    val_neg_inf = limit(expr, symX, -oo)

    def reaches_b(value):
        if not isinf(value):
            return False
        return (value > 0) if (b > 0) else (value < 0)

    hit_pos = reaches_b(val_pos_inf)
    hit_neg = reaches_b(val_neg_inf)
    if hit_pos and hit_neg:
        return FiniteReal(oo, -oo)
    if hit_pos:
        return FiniteReal(oo)
    if hit_neg:
        return FiniteReal(-oo)
    return EmptySet
146 |
--------------------------------------------------------------------------------
/src/render.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import exp
5 |
6 | from .spe import AtomicLeaf
7 | from .spe import DiscreteLeaf
8 | from .spe import LeafSPE
9 | from .spe import NominalLeaf
10 | from .spe import ProductSPE
11 | from .spe import RealLeaf
12 | from .spe import SumSPE
13 |
def render_nested_lists_concise(spe):
    """Render `spe` as compact nested lists.

    Leaves become lists of (str(key), str(value)) pairs from their env;
    SumSPE and ProductSPE nodes become ['+(n)', children] and
    ['*(n)', children] respectively (mixture weights are not shown).
    Any other node type yields None.
    """
    if isinstance(spe, LeafSPE):
        return [(str(key), str(value)) for key, value in spe.env.items()]
    for node_type, operator in ((SumSPE, '+'), (ProductSPE, '*')):
        if isinstance(spe, node_type):
            label = '%s(%d)' % (operator, len(spe.children))
            rendered = [render_nested_lists_concise(child) for child in spe.children]
            return [label, rendered]
    return None
26 |
def render_nested_lists(spe):
    """Render `spe` as nested lists of the form [type_name, [[field, value], ...]].

    SumSPE and ProductSPE nodes recurse into their children; SumSPE weights
    are passed through exp() before being reported (presumably they are
    stored as log-weights). Raises AssertionError for an unknown node type.
    """
    if isinstance(spe, NominalLeaf):
        return ['NominalLeaf', [
            ['symbol', spe.symbol],
            ['env', dict(spe.env)],
            ['dist', {str(x): float(w) for x, w in spe.dist.items()}]]
        ]
    if isinstance(spe, AtomicLeaf):
        return ['AtomicLeaf', [
            ['symbol', spe.symbol],
            ['value', spe.value],
            ['env', dict(spe.env)]]
        ]
    # DiscreteLeaf is tested before RealLeaf so that this branch is reachable
    # even if DiscreteLeaf derives from RealLeaf (NOTE(review): confirm the
    # class hierarchy in spe.py). It reports the same fields as RealLeaf,
    # including 'env', like every other leaf type.
    if isinstance(spe, DiscreteLeaf):
        return ['DiscreteLeaf', [
            ['symbol', spe.symbol],
            ['env', dict(spe.env)],
            ['dist', (spe.dist.dist.name, spe.dist.args, spe.dist.kwds)],
            ['support', spe.support],
            ['conditioned', spe.conditioned]]
        ]
    if isinstance(spe, RealLeaf):
        return ['RealLeaf', [
            ['symbol', spe.symbol],
            ['env', dict(spe.env)],
            ['dist', (spe.dist.dist.name, spe.dist.args, spe.dist.kwds)],
            ['support', spe.support],
            ['conditioned', spe.conditioned]]
        ]
    if isinstance(spe, SumSPE):
        return ['SumSPE', [
            ['symbols', list(spe.symbols)],
            ['weights', [exp(w) for w in spe.weights]],
            ['n_children', len(spe.children)],
            ['children', [render_nested_lists(c) for c in spe.children]]]
        ]
    if isinstance(spe, ProductSPE):
        return ['ProductSPE', [
            ['symbols', list(spe.symbols)],
            ['n_children', len(spe.children)],
            ['children', [render_nested_lists(c) for c in spe.children]]]
        ]
    assert False, 'Unknown SPE type: %s' % (spe,)
69 |
--------------------------------------------------------------------------------
/src/sym_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from collections import OrderedDict
5 | from itertools import chain
6 | from itertools import combinations
7 | from math import isinf
8 |
9 | import numpy
10 | import sympy
11 |
12 | from sympy.core.relational import Relational
13 |
14 | from .sets import FiniteReal
15 | from .sets import Interval
16 |
def get_symbols(expr):
    """Return the sympy.Symbol atoms of `expr`, as a list."""
    return list(filter(lambda atom: isinstance(atom, sympy.Symbol), expr.atoms()))
20 |
def get_union(sets):
    """Return the union of a non-empty sequence of sets."""
    first = sets[0]
    return first.union(*sets[1:])

def get_intersection(sets):
    """Return the intersection of a non-empty sequence of sets."""
    first = sets[0]
    return first.intersection(*sets[1:])

def are_disjoint(sets):
    """True iff the finite sets in `sets` are pairwise disjoint."""
    # No element is shared iff the union is as large as all sets combined.
    union_size = len(get_union(sets))
    return union_size == sum(map(len, sets))

def are_identical(sets):
    """True iff all the finite sets in `sets` hold the same elements."""
    common_size = len(get_intersection(sets))
    return all(len(s) == common_size for s in sets)
34 |
def binspace(start, stop, num=10):
    """Return the num - 1 adjacent Intervals between num evenly spaced points."""
    edges = numpy.linspace(start, stop, num)
    return [Interval(left, right) for left, right in zip(edges[:-1], edges[1:])]
39 |
def powerset(values, start=0):
    """Iterate over all subsets (as tuples) of `values` of size >= `start`.

    Subsets are produced by increasing size and, within one size, in
    itertools.combinations order.
    """
    pool = list(values)
    sizes = range(start, len(pool) + 1)
    return chain.from_iterable(combinations(pool, k) for k in sizes)
44 |
def partition_list_blocks(values):
    """Group the indices of `values` into blocks of equal elements.

    Returns a list of index lists ordered by first occurrence, e.g.
    partition_list_blocks(['a', 'b', 'a']) == [[0, 2], [1]].
    Elements must be hashable.
    """
    # Key on the value itself rather than on hash(value): distinct values
    # may share a hash (in CPython hash(-1) == hash(-2)), and keying on the
    # hash alone would wrongly merge their blocks.
    partition = OrderedDict()
    for i, v in enumerate(values):
        partition.setdefault(v, []).append(i)
    return list(partition.values())
53 |
def partition_finite_real_contiguous(x):
    """Split a FiniteReal into FiniteReals of consecutive (step-1) values.

    For example, the values {1, 2, 3, 7, 9, 10} become blocks [1, 2, 3],
    [7], and [9, 10], in increasing order.
    """
    assert isinstance(x, FiniteReal)
    ordered = sorted(x.values)
    blocks = [[ordered[0]]]
    for value in ordered[1:]:
        current = blocks[-1]
        if value == current[-1] + 1:
            current.append(value)
        else:
            blocks.append([value])
    return [FiniteReal(*block) for block in blocks]
66 |
def sympify_number(x):
    """Return `x` as a number: ints/floats unchanged, others via sympy.

    Raises TypeError when `x` cannot be parsed into a sympy number.
    """
    if isinstance(x, (int, float)):
        return x
    msg = 'Expected a numeric term, not %s' % (x,)
    # String fallback in sympify has been deprecated since SymPy 1.6. Use
    # sympify(str(obj)) or sympy.core.sympify.converter or obj._sympy_
    # instead. See https://github.com/sympy/sympy/issues/18066 for more
    # info.
    # Only the parsing step sits in the try, so the TypeError for a
    # non-numeric result is not raised inside (and re-caught by) the handler.
    try:
        sym = sympy.sympify(str(x))
        is_number = sym.is_number
    except (sympy.SympifyError, AttributeError, TypeError) as err:
        raise TypeError(msg) from err
    if not is_number:
        raise TypeError(msg)
    return sym
82 |
def sym_log(x):
    """Return log(x) for x >= 0: -inf at 0, inf at inf, else sympy.log(x)."""
    assert 0 <= x
    if x == 0:
        return -float('inf')
    return float('inf') if isinf(x) else sympy.log(x)
90 |
def sympy_solver(expr):
    """Solve a univariate sympy relation or boolean combination over the reals.

    Relationals go to sympy.solveset; Or/And/Not recurse and combine via
    Union/Intersection/complement. Raises ValueError when `expr` does not
    have exactly one symbol, has an unsupported type, or solves to a
    ConditionSet. NOTE: sympy is buggy and slow here; prefer Transforms.
    """
    if len(get_symbols(expr)) != 1:
        raise ValueError('Expression "%s" needs exactly one symbol.' % (expr,))

    if isinstance(expr, Relational):
        result = sympy.solveset(expr, domain=sympy.Reals)
    elif isinstance(expr, (sympy.Or, sympy.And)):
        parts = [sympy_solver(arg) for arg in expr.args]
        combine = sympy.Union if isinstance(expr, sympy.Or) else sympy.Intersection
        result = combine(*parts)
    elif isinstance(expr, sympy.Not):
        (negated,) = expr.args
        result = sympy_solver(negated).complement(sympy.Reals)
    else:
        raise ValueError('Expression "%s" has unknown type.' % (expr,))

    if isinstance(result, sympy.ConditionSet):
        raise ValueError('Expression "%s" is not invertible.' % (expr,))
    return result
118 |
--------------------------------------------------------------------------------
/src/timeout.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import signal
5 |
class timeout:
    """Context manager that raises TimeoutError if its body runs too long.

    Adapted from https://stackoverflow.com/a/22348885. Relies on SIGALRM,
    so it only works on Unix and in the main thread. `seconds` is passed to
    signal.alarm, which takes whole seconds (0 disables the alarm). The
    SIGALRM handler active before entry is reinstalled on exit.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        signal.alarm(0)
        # signal.signal returns None when the prior handler was not installed
        # from Python; there is then nothing that can be reinstalled.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
        return False
18 |
--------------------------------------------------------------------------------
/tests/.gitignore:
--------------------------------------------------------------------------------
1 | disabled_*
2 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
--------------------------------------------------------------------------------
/tests/test_ast_condition.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.compilers.ast_to_spe import Condition
7 | from sppl.compilers.ast_to_spe import Id
8 | from sppl.compilers.ast_to_spe import Sample
9 | from sppl.compilers.ast_to_spe import Sequence
10 | from sppl.distributions import beta
11 | from sppl.distributions import choice
12 | from sppl.distributions import randint
13 | from sppl.math_util import allclose
14 |
Y = Id('Y')
X = Id('X')

def test_condition_nominal():
    model = Sequence(
        Sample(Y, choice({'a':.1, 'b':.1, 'c':.8})),
        Condition(Y << {'a', 'b'})).interpret()
    # Conditioning on {'a', 'b'} splits the mass evenly and removes 'c'.
    for value, expected in (('a', .5), ('b', .5), ('c', 0)):
        assert allclose(model.prob(Y << {value}), expected)
26 |
def test_condition_real_discrete_range():
    model = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Condition(Y << {0, 1})).interpret()
    # The test expects each of the two conditioned values to keep mass 1/2.
    for value in (0, 1):
        assert allclose(model.prob(Y << {value}), .5)
34 |
@pytest.mark.xfail(strict=True, reason='https://github.com/probcomp/sum-product-dsl/issues/77')
def test_condition_real_discrete_no_range():
    # Conditioning on the non-contiguous set {0, 2} should leave mass 1/2 on
    # each of 0 and 2 and none on 1. The checks must target the conditioned
    # values; otherwise the strict xfail can never report the issue as fixed.
    # NOTE(review): if this now XPASSes, drop the xfail marker.
    command = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Condition(Y << {0, 2}))
    model = command.interpret()
    assert allclose(model.prob(Y << {0}), .5)
    assert allclose(model.prob(Y << {2}), .5)
    assert allclose(model.prob(Y << {1}), 0)
43 |
def test_condition_real_continuous():
    model = Sequence(
        Sample(Y, beta(a=1, b=1)),
        Condition(Y < .5)).interpret()
    # All posterior mass lies below the conditioning threshold.
    assert allclose(model.prob(Y < .5), 1)
    assert allclose(model.prob(Y > .5), 0)
51 |
def test_condition_prob_zero():
    # Conditioning on an event of probability zero must raise. The nominal
    # case builds its distribution with choice(...), as the other tests do,
    # so that the exception comes from the Condition rather than from handing
    # a raw dict to Sample.
    with pytest.raises(Exception):
        Sequence(
            Sample(Y, choice({'a':.1, 'b':.1, 'c':.8})),
            Condition(Y << {'d'})).interpret()
    with pytest.raises(Exception):
        Sequence(
            Sample(Y, beta(a=1, b=1)),
            Condition(Y > 1)).interpret()
61 |
--------------------------------------------------------------------------------
/tests/test_ast_for.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | from sppl.compilers.ast_to_spe import For
7 | from sppl.compilers.ast_to_spe import Id
8 | from sppl.compilers.ast_to_spe import IdArray
9 | from sppl.compilers.ast_to_spe import IfElse
10 | from sppl.compilers.ast_to_spe import Otherwise
11 | from sppl.compilers.ast_to_spe import Sample
12 | from sppl.compilers.ast_to_spe import Sequence
13 | from sppl.distributions import bernoulli
14 | from sppl.distributions import choice
15 | from sppl.math_util import allclose
16 |
Y = Id('Y')
X = IdArray('X', 5)
Z = IdArray('Z', 5)

def test_simple_model():
    model = Sequence(
        Sample(Y, bernoulli(p=0.5)),
        For(0, 5, lambda i:
            Sample(X[i], bernoulli(p=1/(i+1))))).interpret()

    symbols = model.get_symbols()
    # Y plus the five array entries X[0..4].
    assert len(symbols) == 6
    assert Y in symbols
    for i in range(5):
        assert X[i] in symbols
    for i in range(5):
        assert model.logprob(X[i] << {1}) == log(1/(i+1))
41 |
def test_complex_model():
    # Slow for larger number of repetitions
    # https://github.com/probcomp/sum-product-dsl/issues/43
    def body(i):
        return Sequence(
            Sample(Z[i], bernoulli(p=0.1)),
            IfElse(
                (Y << {str(i)}) | (Z[i] << {0}), Sample(X[i], bernoulli(p=1/(i+1))),
                Otherwise, Sample(X[i], bernoulli(p=0.1))))
    model = Sequence(
        Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})),
        For(0, 3, body)).interpret()
    assert allclose(model.prob(Y << {'0'}), 0.2)
54 |
def test_complex_model_reorder():
    # Same model family as test_complex_model, but all Z[i] are sampled
    # in a first loop before the X[i] branches in a second loop.
    def sample_z(i):
        return Sample(Z[i], bernoulli(p=0.1))
    def sample_x(i):
        return IfElse(
            Y << {str(i)}, Sample(X[i], bernoulli(p=1/(i+1))),
            Z[i] << {0}, Sample(X[i], bernoulli(p=1/(i+1))),
            Otherwise, Sample(X[i], bernoulli(p=0.1)))
    model = Sequence(
        Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})),
        For(0, 3, sample_z),
        For(0, 3, sample_x)).interpret()
    assert allclose(model.prob(Y << {'0'}), 0.2)
67 |
def test_repeat_handcode_equivalence():
    model_repeat = make_model_for()
    model_hand = make_model_handcode()

    assert allclose(model_repeat.prob(Y << {'0', '1'}), 0.4)
    assert allclose(model_repeat.prob(Z[0] << {0}), 0.5)
    assert allclose(model_repeat.prob(Z[0] << {1}), 0.5)

    # The For-loop and hand-written models must agree on every event, both
    # before and after conditioning on the same evidence.
    evidence = (X[0] << {1}) | (Y << {'1'})
    repeat_given = model_repeat.condition(evidence)
    hand_given = model_hand.condition(evidence)

    events = [
        Y << {'0','1'},
        Z[0] << {0},
        Z[1] << {0},
        X[0] << {0},
        X[1] << {0},
    ]
    for event in events:
        lp_repeat = model_repeat.logprob(event)
        lp_hand = model_hand.logprob(event)
        assert allclose(lp_hand, lp_repeat)
        lp_repeat_given = repeat_given.logprob(event)
        lp_hand_given = hand_given.logprob(event)
        assert allclose(lp_hand_given, lp_repeat_given)

    # This test case ensures that the duplicate subtrees in the mixture
    # components are pointers to the same object, and is obtained by
    # manually inspecting the rendering of the network generated by the
    # following code:
    #
    #   from sppl.magics.render import render_graphviz
    #   render_graphviz(model_repeat, show=True)
    #
    # See also test_cache_duplicate_subtrees.test_cache_complex_sum_of_product
    shared_in_first = model_repeat.children[0].children[0].children[1].children[0]
    shared_in_second = model_repeat.children[1].children[0]
    assert shared_in_first is shared_in_second
107 |
108 | # ==============================================================================
109 | # Helper functions.
110 |
def make_model_for(n=2):
    # X[i] ~ bernoulli(.1) when Y selects i or Z[i] is 0, else bernoulli(.5).
    def body(i):
        return Sequence(
            Sample(Z[i], bernoulli(p=.5)),
            IfElse(
                (Y << {str(i)}) | (Z[i] << {0}), Sample(X[i], bernoulli(p=.1)),
                Otherwise, Sample(X[i], bernoulli(p=.5))))
    return Sequence(
        Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})),
        For(0, n, body)).interpret()
120 |
def make_model_handcode():
    # Hand-unrolled equivalent of make_model_for(n=2).
    def x_given_z(k):
        # X[k] ~ bernoulli(.1) when Z[k] is 0, else bernoulli(.5).
        return IfElse(
            Z[k] << {0}, Sample(X[k], bernoulli(p=.1)),
            Otherwise, Sample(X[k], bernoulli(p=.5)))
    command = Sequence(
        Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})),
        Sample(Z[0], bernoulli(p=.5)),
        Sample(Z[1], bernoulli(p=.5)),
        IfElse(
            Y << {str(0)}, Sequence(
                Sample(X[0], bernoulli(p=.1)),
                x_given_z(1)),
            Y << {str(1)}, Sequence(
                Sample(X[1], bernoulli(p=.1)),
                x_given_z(0)),
            Otherwise, Sequence(
                x_given_z(0),
                x_given_z(1))))
    return command.interpret()
145 |
--------------------------------------------------------------------------------
/tests/test_ast_ifelse.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from sppl.compilers.ast_to_spe import Id
5 | from sppl.compilers.ast_to_spe import IfElse
6 | from sppl.compilers.ast_to_spe import Sample
7 | from sppl.compilers.ast_to_spe import Sequence
8 | from sppl.compilers.ast_to_spe import Transform
9 | from sppl.distributions import bernoulli
10 | from sppl.distributions import randint
11 | from sppl.math_util import allclose
12 |
Y = Id('Y')
X = Id('X')

def test_ifelse_zero_conditions():
    model = Sequence(
        Sample(Y, randint(low=0, high=3)),
        IfElse(
            Y << {-1}, Transform(X, Y**(-1)),
            Y << {0}, Sample(X, bernoulli(p=1)),
            Y << {1}, Transform(X, Y),
            Y << {2}, Transform(X, Y**2),
            Y << {3}, Transform(X, Y**3),
        )).interpret()
    # The test expects only the branches for Y in {0, 1, 2} to survive,
    # each weighted by the log-probability of its guard.
    assert len(model.children) == 3
    assert len(model.weights) == 3
    for k in range(3):
        assert allclose(model.weights[k], model.logprob(Y << {k}))
32 |
--------------------------------------------------------------------------------
/tests/test_ast_switch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | import pytest
7 |
8 | from numpy import linspace
9 |
10 | from sppl.distributions import bernoulli
11 | from sppl.distributions import beta
12 | from sppl.distributions import randint
13 | from sppl.compilers.ast_to_spe import IfElse
14 | from sppl.compilers.ast_to_spe import Sample
15 | from sppl.compilers.ast_to_spe import Sequence
16 | from sppl.compilers.ast_to_spe import Switch
17 | from sppl.compilers.ast_to_spe import Id
18 | from sppl.math_util import allclose
19 | from sppl.math_util import logsumexp
20 | from sppl.sym_util import binspace
21 |
Y = Id('Y')
X = Id('X')

def test_simple_model_eq():
    model_switch = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(Y, range(0, 4), lambda i:
            Sample(X, bernoulli(p=1/(i+1))))).interpret()

    model_ifelse = Sequence(
        Sample(Y, randint(low=0, high=4)),
        IfElse(
            Y << {0}, Sample(X, bernoulli(p=1/(0+1))),
            Y << {1}, Sample(X, bernoulli(p=1/(1+1))),
            Y << {2}, Sample(X, bernoulli(p=1/(2+1))),
            Y << {3}, Sample(X, bernoulli(p=1/(3+1))),
        )).interpret()

    # Switch and the equivalent IfElse chain must give the same marginal
    # P(X = 1) = sum_i 1/4 * 1/(i + 1).
    for model in (model_switch, model_ifelse):
        assert model.get_symbols() == {X, Y}
        assert allclose(
            model.logprob(X << {1}),
            logsumexp([-log(4) - log(i+1) for i in range(4)]))
48 |
def test_simple_model_lte():
    model_switch = Sequence(
        Sample(Y, beta(a=2, b=3)),
        Switch(Y, binspace(0, 1, 5), lambda i:
            Sample(X, bernoulli(p=i.right)))).interpret()

    model_ifelse = Sequence(
        Sample(Y, beta(a=2, b=3)),
        IfElse(
            Y <= 0, Sample(X, bernoulli(p=0)),
            Y <= 0.25, Sample(X, bernoulli(p=.25)),
            Y <= 0.50, Sample(X, bernoulli(p=.50)),
            Y <= 0.75, Sample(X, bernoulli(p=.75)),
            Y <= 1, Sample(X, bernoulli(p=1)),
        )).interpret()

    edges = [float(x) for x in linspace(0, 1, 5)]
    bins = list(zip(edges[:-1], edges[1:]))
    # P(X = 1) = sum over bins (il, ih] of P(il < Y <= ih) * ih.
    for model in (model_switch, model_ifelse):
        assert model.get_symbols() == {X, Y}
        actual = model.logprob(X << {1})
        expected = logsumexp([
            model.logprob((il < Y) <= ih) + log(ih)
            for il, ih in bins
        ])
        assert allclose(actual, expected)
77 |
def test_simple_model_enumerate():
    model = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(Y, enumerate(range(0, 4)), lambda i,j:
            Sample(X, bernoulli(p=1/(i+j+1))))).interpret()
    # Each case receives (index, value) from enumerate; here index == value.
    for k in range(4):
        assert allclose(model.prob((Y << {k}) & (X << {1})), .25 * 1/(k+k+1))
88 |
def test_error_range():
    # The Switch cases cover only Y in range(0, 3) while randint(0, 4)
    # also produces 3, so the cases do not sum to one.
    with pytest.raises(AssertionError):
        Sequence(
            Sample(Y, randint(low=0, high=4)),
            Switch(Y, range(0, 3), lambda i:
                Sample(X, bernoulli(p=1/(i+1))))).interpret()
97 |
def test_error_linspace():
    # Point cases on a grid of [0, .5] carry no mass under a continuous
    # beta, so the Switch cases do not sum to one.
    with pytest.raises(AssertionError):
        Sequence(
            Sample(Y, beta(a=2, b=3)),
            Switch(Y, linspace(0, .5, 5), lambda i:
                Sample(X, bernoulli(p=i)))).interpret()
106 |
def test_error_binspace():
    # The bins only cover [0, .5], leaving the rest of the beta's support
    # uncovered, so the Switch cases do not sum to one.
    with pytest.raises(AssertionError):
        Sequence(
            Sample(Y, beta(a=2, b=3)),
            Switch(Y, binspace(0, .5, 5), lambda i:
                Sample(X, bernoulli(p=i.right)))).interpret()
115 |
--------------------------------------------------------------------------------
/tests/test_ast_transform.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | from sppl.compilers.ast_to_spe import Id
7 | from sppl.compilers.ast_to_spe import IfElse
8 | from sppl.compilers.ast_to_spe import Otherwise
9 | from sppl.compilers.ast_to_spe import Sample
10 | from sppl.compilers.ast_to_spe import Sequence
11 | from sppl.compilers.ast_to_spe import Transform
12 | from sppl.distributions import bernoulli
13 | from sppl.distributions import norm
14 | from sppl.math_util import allclose
15 |
# Program variables shared by the tests in this module.
X, Y, Z = Id('X'), Id('Y'), Id('Z')
19 |
def test_simple_transform():
    """Transform binds Z to X**2 in the environment of the sampled model."""
    command = Sequence(
        Sample(X, norm(loc=0, scale=1)),
        Transform(Z, X**2))
    model = command.interpret()
    assert model.get_symbols() == {Z, X}
    assert model.env == {Z: X**2, X: X}
    # X**2 > 0 except on the null set {X = 0}, so the log-probability is 0.
    # Compare with a tolerance, as the sibling tests do, instead of exact
    # float equality.
    assert allclose(model.logprob(Z > 0), 0)
28 |
def test_if_else_transform():
    """Each IfElse branch gets its own environment binding for Z."""
    model = Sequence(
        Sample(X, norm(loc=0, scale=1)),
        IfElse(
            X > 0, Transform(Z, X**2),
            Otherwise, Transform(Z, X))).interpret()
    positive, nonpositive = model.children[0], model.children[1]
    assert positive.env == {X: X, Z: X**2}
    assert nonpositive.env == {X: X, Z: X}
    # Z > 0 holds on the X > 0 branch and never on the other one.
    assert allclose(positive.logprob(Z > 0), 0)
    assert allclose(nonpositive.logprob(Z > 0), -float('inf'))
    assert allclose(model.logprob(Z > 0), -log(2))
40 |
def test_if_else_transform_reverse():
    """The IfElse condition is on Y, which is sampled independently of X."""
    model = Sequence(
        Sample(X, norm(loc=0, scale=1)),
        Sample(Y, bernoulli(p=0.5)),
        IfElse(
            Y << {0}, Transform(Z, X**2),
            Otherwise, Transform(Z, X))).interpret()
    # P(Z > 0) = P(Y=0) * P(X**2 > 0) + P(Y=1) * P(X > 0) = 1/2 + 1/4.
    assert allclose(model.logprob(Z > 0), log(3) - log(4))
50 |
--------------------------------------------------------------------------------
/tests/test_burglary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | '''
5 | Burglary network example from:
6 |
7 | Artificial Intelligence: A Modern Approach (3rd Edition).
8 | Russell and Norvig, Fig 14.2 pp 512.
9 | '''
10 |
11 | from sppl.compilers.ast_to_spe import Id
12 | from sppl.compilers.ast_to_spe import IfElse
13 | from sppl.compilers.ast_to_spe import Otherwise
14 | from sppl.compilers.ast_to_spe import Sample
15 | from sppl.compilers.ast_to_spe import Sequence
16 | from sppl.distributions import bernoulli
17 |
# Nodes of the burglary network.
(Burglary, Earthquake, Alarm, JohnCalls, MaryCalls) = [
    Id(name) for name in
    ('Burglary', 'Earthquake', 'Alarm', 'JohnCalls', 'MaryCalls')
]
23 |
def _burglary_program():
    """Build the burglary network as an SPPL command."""
    def alarm_given_earthquake(p_quake, p_no_quake):
        # Alarm CPT row for one fixed value of Burglary.
        return IfElse(
            Earthquake << {1}, Sample(Alarm, bernoulli(p=p_quake)),
            Otherwise, Sample(Alarm, bernoulli(p=p_no_quake)))

    def neighbours_call(p_john, p_mary):
        # JohnCalls/MaryCalls for one fixed value of Alarm.
        return Sequence(
            Sample(JohnCalls, bernoulli(p=p_john)),
            Sample(MaryCalls, bernoulli(p=p_mary)))

    return Sequence(
        Sample(Burglary, bernoulli(p=0.001)),
        Sample(Earthquake, bernoulli(p=0.002)),
        IfElse(
            Burglary << {1}, alarm_given_earthquake(0.95, 0.94),
            Otherwise, alarm_given_earthquake(0.29, 0.001)),
        IfElse(
            Alarm << {1}, neighbours_call(0.90, 0.70),
            Otherwise, neighbours_call(0.05, 0.01)))


program = _burglary_program()
model = program.interpret()
45 |
def test_marginal_probability():
    """Joint probability P(j, m, a, ~b, ~e) from pp. 514 (about 0.000628)."""
    from math import isclose
    event = ((JohnCalls << {1}) & (MaryCalls << {1}) & (Alarm << {1})
        & (Burglary << {0}) & (Earthquake << {0}))
    x = model.prob(event)
    # The query factorizes along the network as
    # P(j|a) P(m|a) P(a|~b,~e) P(~b) P(~e). Comparing numerically avoids
    # relying on the decimal string representation of the float.
    expected = 0.90 * 0.70 * 0.001 * (1 - 0.001) * (1 - 0.002)
    assert isclose(x, expected, rel_tol=1e-6)
52 |
def test_conditional_probability():
    """Posterior P(Burglary | JohnCalls, MaryCalls) from pp. 523 (~0.284)."""
    event = (JohnCalls << {1}) & (MaryCalls << {1})
    model_condition = model.condition(event)
    x = model_condition.prob(Burglary << {1})
    # Numeric form of "the first three decimals are 0.284", independent of
    # how the float happens to be rendered as a string.
    assert 0.284 <= x < 0.285
59 |
def test_mutual_information():
    """Mutual information between 'someone calls' and 'burglary only'."""
    event_a = (JohnCalls << {1}) | (MaryCalls << {1})
    event_b = (Burglary << {1}) & (Earthquake << {0})
    mi = model.mutual_information(event_a, event_b)
    # A burglary sharply raises P(Alarm) and hence the chance that a
    # neighbour calls, so the two events are dependent and their mutual
    # information must be strictly positive (the test previously asserted
    # nothing).
    assert mi > 0
64 |
--------------------------------------------------------------------------------
/tests/test_cache_duplicate_subtrees.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | import numpy
7 |
8 | from sppl.distributions import bernoulli
9 | from sppl.distributions import choice
10 | from sppl.distributions import norm
11 | from sppl.spe import ProductSPE
12 | from sppl.spe import SumSPE
13 | from sppl.spe import spe_cache_duplicate_subtrees
14 | from sppl.transforms import Id
15 |
# NOTE(review): rng is not referenced by the tests visible in this module.
rng = numpy.random.RandomState(1)

W, Y = Id('W'), Id('Y')
# Indexed symbols X[0], X[1] and Z[0], Z[1].
X = [Id('X[%d]' % (i,)) for i in range(2)]
Z = [Id('Z[%d]' % (i,)) for i in range(2)]
22 |
def test_cache_simple_leaf():
    """Two equal but distinct leaves are merged into one shared object."""
    def leaf():
        return W >> norm(loc=0, scale=1)

    spe = .5 * leaf() | .5 * leaf()
    assert spe.children[0] is not spe.children[1]
    cached = spe_cache_duplicate_subtrees(spe, {})
    assert cached.children[0] is cached.children[1]
28 |
def test_cache_simple_sum_of_product():
    """Equal leaves inside different product children are shared."""
    left = (W >> norm(loc=0, scale=1)) & (Y >> norm(loc=0, scale=1))
    right = (W >> norm(loc=0, scale=1)) & (Y >> norm(loc=0, scale=2))
    spe = 0.3 * left | 0.7 * right
    cached = spe_cache_duplicate_subtrees(spe, {})
    # Both products start with an identical W ~ N(0, 1) leaf.
    assert cached.children[0].children[0] is cached.children[1].children[0]
35 |
def test_cache_complex_sum_of_product():
    # Test case adapted from the SPE generated by
    # test_repeat.make_model_repeat(n=2)

    def make_shared_subtree():
        # Every call builds a fresh SPE: two calls give trees that compare
        # equal but are distinct objects.
        return SumSPE([
            ProductSPE([
                (X[0] >> bernoulli(p=.1)),
                SumSPE([
                    (Z[0] >> bernoulli(p=.5))
                        & (Y >> choice({'0':.1, '1': .9})),
                    (Z[0] >> bernoulli(p=.1))
                        & (Y >> choice({'0':.9, '1': .1}))
                ], weights=[log(.730), log(.270)])
            ]),
            ProductSPE([
                Z[0] >> bernoulli(p=.1),
                Y >> choice({'0':.9, '1':.1}),
                X[0] >> bernoulli(p=.5),
            ]),
        ], weights=[log(.925), log(.075)])

    duplicate_subtrees = [make_shared_subtree() for _ in range(2)]
    assert duplicate_subtrees[0] == duplicate_subtrees[1]
    assert duplicate_subtrees[0] is not duplicate_subtrees[1]

    nested_mixture = SumSPE([
        (Y >> choice({'0':.3, '1':.7}))
            & (X[0] >> bernoulli(p=.1))
            & (Z[0] >> bernoulli(p=.1)),
        (Y >> choice({'0':.7, '1':.3}))
            & (X[0] >> bernoulli(p=.5))
            & (Z[0] >> bernoulli(p=.5)),
    ], weights=[log(.9), log(.1)])

    left_subtree = ProductSPE([
        X[1] >> bernoulli(p=.5),
        SumSPE([
            ProductSPE([duplicate_subtrees[0], Z[1] >> bernoulli(p=.5)]),
            ProductSPE([Z[1] >> bernoulli(p=.7), nested_mixture]),
        ], weights=[log(.783), log(.217)])
    ])

    right_subtree = ProductSPE([
        Z[1] >> bernoulli(p=.8),
        X[1] >> bernoulli(p=.1),
        duplicate_subtrees[1],
    ])

    spe = .92 * left_subtree | .08 * right_subtree
    spe_cached = spe_cache_duplicate_subtrees(spe, {})
    # After caching, both occurrences refer to the first copy.
    assert spe_cached.children[0].children[1].children[0].children[0] \
        is duplicate_subtrees[0]
    assert spe_cached.children[1].children[2] is duplicate_subtrees[0]
93 |
--------------------------------------------------------------------------------
/tests/test_clickgraph.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.distributions import bernoulli
7 | from sppl.distributions import beta
8 | from sppl.distributions import randint
9 | from sppl.distributions import uniform
10 |
11 | from sppl.compilers.ast_to_spe import For
12 | from sppl.compilers.ast_to_spe import Id
13 | from sppl.compilers.ast_to_spe import IdArray
14 | from sppl.compilers.ast_to_spe import IfElse
15 | from sppl.compilers.ast_to_spe import Sample
16 | from sppl.compilers.ast_to_spe import Sequence
17 | from sppl.compilers.ast_to_spe import Switch
18 | from sppl.compilers.ast_to_spe import Transform
19 |
20 | from sppl.sym_util import binspace
21 |
# ns: number of discrete levels / trials; nd: divisor mapping a level in
# range(ns) to a probability in [0, 1].
ns = 5
nd = ns - 1

simAll = Id('simAll')
sim, p1, p2, clickA, clickB = [
    IdArray(name, ns) for name in ('sim', 'p1', 'p2', 'clickA', 'clickB')
]
31 |
def get_command_randint():
    """Build the clickgraph program with discrete latents.

    simAll, p1[k] and p2[k] are uniform over range(ns); a level i is turned
    into a probability i/nd.
    """
    def similar(k, i):
        # sim[k] == 1: p2[k] is a copy of p1[k].
        # NOTE(review): the clicks here use the simAll level i, not the p1
        # level j of the Switch. The usual clickgraph model presumably draws
        # both clicks from p1; the posterior assertion in
        # test_clickgraph_crash__ci_ depends on the current form, so
        # confirm before changing it.
        return Sequence(
            Transform(p2[k], p1[k]),
            Switch(p1[k], range(ns), lambda j: Sequence(
                Sample(clickA[k], bernoulli(p=i/nd)),
                Sample(clickB[k], bernoulli(p=i/nd)))))

    def dissimilar(k):
        # sim[k] == 0: p2[k] is drawn independently of p1[k].
        return Sequence(
            Sample(p2[k], randint(low=0, high=ns)),
            Switch(p1[k], range(ns), lambda j:
                Sample(clickA[k], bernoulli(p=j/nd))),
            Switch(p2[k], range(ns), lambda j:
                Sample(clickB[k], bernoulli(p=j/nd))))

    def trial(k):
        return Switch(simAll, range(0, ns), lambda i: Sequence(
            Sample(sim[k], bernoulli(p=i/nd)),
            Sample(p1[k], randint(low=0, high=ns)),
            IfElse(sim[k] << {1}, similar(k, i), True, dissimilar(k))))

    return Sequence(
        Sample(simAll, randint(low=0, high=ns)),
        For(0, 5, trial))
55 |
def get_command_beta():
    """Build the clickgraph program with continuous latents.

    simAll has a beta(2, 3) prior and p1[k], p2[k] are uniform. Each Switch
    branches over the bins of binspace(0, 1, ns); a case value is a bin,
    whose right endpoint is used as the probability.
    """
    def similar(k, i):
        # sim[k] == 1: p2[k] is a copy of p1[k].
        # NOTE(review): the clicks here use the simAll bin i, not the p1
        # bin j of the Switch. The usual clickgraph model presumably draws
        # both clicks from p1; confirm against the test's posterior
        # assertion before changing it.
        return Sequence(
            Transform(p2[k], p1[k]),
            Switch(p1[k], binspace(0, 1, ns), lambda j: Sequence(
                Sample(clickA[k], bernoulli(p=i.right)),
                Sample(clickB[k], bernoulli(p=i.right)))))

    def dissimilar(k):
        # sim[k] == 0: p2[k] is drawn independently of p1[k].
        return Sequence(
            Sample(p2[k], uniform()),
            Switch(p1[k], binspace(0, 1, ns), lambda j:
                Sample(clickA[k], bernoulli(p=j.right))),
            Switch(p2[k], binspace(0, 1, ns), lambda j:
                Sample(clickB[k], bernoulli(p=j.right))))

    def trial(k):
        return Switch(simAll, binspace(0, 1, ns), lambda i: Sequence(
            Sample(sim[k], bernoulli(p=i.right)),
            Sample(p1[k], uniform()),
            IfElse(sim[k] << {1}, similar(k, i), True, dissimilar(k))))

    return Sequence(
        Sample(simAll, beta(a=2, b=3)),
        For(0, 5, trial))
79 |
@pytest.mark.parametrize('get_command', [get_command_randint, get_command_beta])
def test_clickgraph_crash__ci_(get_command):
    """Condition on matching clicks and inspect the posterior over simAll."""
    model = get_command().interpret()
    # Both links clicked in trials 0-2; neither clicked in trials 3-4.
    evidence = None
    for k, clicked in enumerate([1, 1, 1, 0, 0]):
        for click in (clickA[k], clickB[k]):
            literal = click << {clicked}
            evidence = literal if evidence is None else evidence & literal
    model_condition = model.condition(evidence)
    probabilities = [model_condition.prob(simAll << {i}) for i in range(ns)]
    # No simAll value is more probable a posteriori than index nd - 1.
    peak = probabilities[nd-1]
    assert all(p <= peak for p in probabilities)
92 |
--------------------------------------------------------------------------------
/tests/test_dnf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.dnf import dnf_factor
7 | from sppl.dnf import dnf_non_disjoint_clauses
8 | from sppl.dnf import dnf_to_disjoint_union
9 |
10 | from sppl.transforms import EventOr
11 | from sppl.transforms import Exp
12 | from sppl.transforms import Id
13 | from sppl.transforms import Log
14 | from sppl.transforms import Sqrt
15 |
# Six independent symbols named X0 .. X5.
X0, X1, X2, X3, X4, X5 = map(Id, ("X0", "X1", "X2", "X3", "X4", "X5"))
17 |
18 | events = [
19 | X0 < 0,
20 | (X1<<(0,1)) & (X2<0),
21 | (X1<0) | (X2<0),
22 | (X0<0) | ((X1<0) & ~(X2<0)),
23 | ((X0<0) & ~(0 ~(A | B | C) | ~D
76 | assert event.to_dnf() == (~A & ~B & ~C) | ~D
77 |
78 | event = ~((A | ~B | C) & ~D)
79 | # => ~(A | B | C) | ~D
80 | assert event.to_dnf() == (~A & B & ~C) | D
81 |
def test_dnf_factor():
    E00 = Exp(X0) > 0
    E01 = X0 < 10
    E10 = X1 < 10
    E20 = (X2**2 - X2*3) < 0
    E30 = X3 > 10
    E31 = (Sqrt(2*X3)) < 0
    E40 = X4 > 0
    E41 = X4 << [1, 5]
    E50 = 10*Log(X5) + 9 > 5

    def factor(event, *args):
        # Every case below factors the DNF form of its event.
        return dnf_factor(event.to_dnf(), *args)

    # A single literal: one clause keyed by its symbol.
    clauses = factor(E00)
    assert len(clauses) == 1
    assert clauses[0][X0] == E00

    # Literals over the same symbol stay together in one conjunction.
    clauses = factor(E00 & E01)
    assert len(clauses) == 1
    assert clauses[0][X0] == E00 & E01

    # A disjunction yields one clause per disjunct.
    clauses = factor(E00 | E01)
    assert len(clauses) == 2
    assert clauses[0][X0] == E00
    assert clauses[1][X0] == E01

    # With an explicit lookup, X0 and X1 share the key 0.
    clauses = factor(E00 | (E01 & E10), {X0: 0, X1: 0})
    assert len(clauses) == 2
    assert clauses[0][0] == E00
    assert clauses[1][0] == E01 & E10

    # For the second clause we have:
    #   ~E41 = (-oo, 1) U (1, 5) U (5, oo)
    # so the second clause becomes
    #   = (E20 & E50 & E31 & ((-oo, 1) U (1, 5) U (5, oo)))
    #   = (E20 & E50 & E31 & (-oo, 1))
    #     or (E20 & E50 & E31 & (1, 5))
    #     or (E20 & E50 & E31 & (5, oo))
    clauses = factor(
        (E00 & E01 & E10 & E30 & E40) | (E20 & E50 & E31 & ~E41))
    assert len(clauses) == 4
    # clause 0
    assert len(clauses[0]) == 4
    assert clauses[0][X0] == E00 & E01
    assert clauses[0][X1] == E10
    assert clauses[0][X3] == E30
    assert clauses[0][X4] == E40
    # clauses 1-3 differ only in which piece of ~E41 constrains X4.
    x4_pieces = [X4 < 1, 1 < (X4 < 5), 5 < X4]
    for index, x4_piece in enumerate(x4_pieces, start=1):
        clause = clauses[index]
        assert len(clause) == 4
        assert clause[X3] == E31
        assert clause[X2] == E20
        assert clause[X4] == x4_piece
        assert clause[X5] == E50
154 |
def test_dnf_factor_1():
    """All symbols of the conjunction map to key 0, so it stays whole."""
    lookup = {X0:0, X1:0, X2:0, X3:1, X4:1, X5:2}
    A = Exp(X0) > 0
    B = X0 < 10
    C = X1 < 10
    D = X2 < 0
    event = A & B & C & ~D
    factored = dnf_factor(event.to_dnf(), lookup)
    assert len(factored) == 1
    assert factored[0][0] == event
166 |
def test_dnf_factor_2():
    """X0, X4 and X5 map to keys 0, 1 and 2: one literal per group."""
    lookup = {X0:0, X1:0, X2:0, X3:1, X4:1, X5:2}
    A, B, C = X0 < 1, X4 < 1, X5 < 1
    factored = dnf_factor((A & B & C).to_dnf(), lookup)
    assert len(factored) == 1
    for key, literal in enumerate((A, B, C)):
        assert factored[0][key] == literal
178 |
def test_dnf_factor_3():
    """Two clauses, each grouped by the integer keys of the lookup."""
    lookup = {X0:0, X1:0, X2:0, X3:1, X4:1, X5:2}
    A = (Exp(X0) > 0)
    B = X0 < 10
    C = X1 < 10
    D = X4 > 0
    E = (X2**2 - 3*X2) << (0, 10, 100)
    F = (10*Log(X5) + 9) > 5
    G = X4 < 4

    event = (A & B & C & ~D) | (E & F & G)
    clauses = dnf_factor(event.to_dnf(), lookup)
    assert len(clauses) == 2
    expected = [
        {0: A & B & C, 1: ~D},
        {0: E, 1: G, 2: F},
    ]
    for clause, groups in zip(clauses, expected):
        for key, literal in groups.items():
            assert clause[key] == literal
197 |
def test_dnf_non_disjoint_clauses():
    """dnf_non_disjoint_clauses maps a clause index to earlier overlapping ones."""
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')

    event = (X > 0) | (Y < 0)
    overlaps = dnf_non_disjoint_clauses(event)
    assert overlaps == {1: [0]}

    event = (X > 0) | ((X < 0) & (Y < 0))
    overlaps = dnf_non_disjoint_clauses(event)
    assert not overlaps

    event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1))
    overlaps = dnf_non_disjoint_clauses(event)
    assert overlaps == {2: [0]}

    event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1) & (Z > 1))
    overlaps = dnf_non_disjoint_clauses(event)
    assert not overlaps

    event = ((X**2 < 9)) | (1 < X)
    overlaps = dnf_non_disjoint_clauses(event)
    assert overlaps == {1: [0]}

    # Python expands the chained comparison `0 < X < 1` into
    # `(0 < X) and (X < 1)`, which does not build the interval event; use
    # the explicit `(0 < X) < 1` form used elsewhere in these tests.
    event = ((X**2 < 9) & ((0 < X) < 1)) | (1 < X)
    overlaps = dnf_non_disjoint_clauses(event)
    assert not overlaps
226 |
def test_event_to_disjiont_union_numerical():
    """dnf_to_disjoint_union output has no pairwise-overlapping clauses."""
    X, Y, Z = Id('X'), Id('Y'), Id('Z')
    events = (
        (X > 0) | (X < 3),
        (X > 0) | (Y < 3),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X <<{1, 3}),
    )
    for event in events:
        disjoint = dnf_to_disjoint_union(event)
        assert not dnf_non_disjoint_clauses(disjoint)
240 |
def test_event_to_disjoint_union_nominal():
    """Disjoint unions of nominal (set-membership) events."""
    X, Y = Id('X'), Id('Y')

    # Overlapping literals on one symbol collapse into a single literal.
    merged = dnf_to_disjoint_union((X << {'1'}) | (X << {'1', '2'}))
    assert merged == X << {'1', '2'}

    # Across symbols, the second clause is restricted to the complement of
    # the first.
    expected = EventOr([
        (X << {'1'}), ~(Y << {'1'}) & ~(X << {'1'})
    ])
    assert dnf_to_disjoint_union((X << {'1'}) | ~(Y << {'1'})) == expected
251 |
def test_event_to_disjoint_union_five():
    """Union of five overlapping rectangles in the (X, Y) plane."""
    # This test case can be visualized as follows:
    # fig, ax = plt.subplots()
    # ax.add_patch(Rectangle((2, 5), 6, 3, fill=False))
    # ax.add_patch(Rectangle((3, 7), 1, 4, fill=False))
    # ax.add_patch(Rectangle((3.5, 2), 2, 5, fill=False))
    # ax.add_patch(Rectangle((4.5, 1), 2, 5, fill=False))
    # ax.add_patch(Rectangle((5, 7), 2, 3, fill=False))
    #
    # // ((2 < X < 8) & (5 < Y < 8))
    # // ((3 < X < 4) & (8 <= Y < 11))
    # // ((3.5 < X < 5.5) & (2 < Y <= 5))
    # // ((5.5 <= X < 6.5) & (1 < Y <= 5))
    # // ((4.5 < X < 5.5) & (1 < Y <= 2))
    # // ((5 < X < 7) & (8 <= Y < 10))
    #
    # fig, ax = plt.subplots()
    # ax.add_patch(Rectangle((2, 5), 6, 3, fill=False))
    # ax.add_patch(Rectangle((3, 8), 1, 3, fill=False))
    # ax.add_patch(Rectangle((3.5, 2), 2, 3, fill=False))
    # ax.add_patch(Rectangle((5.5, 1), 1, 4, fill=False))
    # ax.add_patch(Rectangle((4.5, 1), 1, 1, fill=False))
    # ax.add_patch(Rectangle((5, 8), 2, 2, fill=False))
    X = Id('X')
    Y = Id('Y')
    E1 = ((2 < X) < 8) & ((5 < Y) < 8)
    E2 = ((3 < X) < 4) & ((7 < Y) < 11)
    E3 = ((3.5 < X) < 5.5) & ((2 < Y) < 7)
    E4 = ((4.5 < X) < 6.5) & ((1 < Y) < 6)
    E5 = ((5 < X) < 7) & ((7 < Y) < 10)
    event = E1 | E2 | E3 | E4 | E5
    event_dnf = dnf_to_disjoint_union(event)
    # Check the result rather than only that the conversion runs: no two
    # clauses of the output may overlap.
    assert not dnf_non_disjoint_clauses(event_dnf)
284 |
--------------------------------------------------------------------------------
/tests/test_event_evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.transforms import Exp
7 | from sppl.transforms import Id
8 | from sppl.transforms import Log
9 |
# Symbols used by the event-evaluation tests.
X, Y, W, Z = (Id(name) for name in ('X', 'Y', 'W', 'Z'))
14 |
def test_event_basic_invertible():
    """Events over an invertible transform of X, and their negations."""
    expr = X**2 + 10*X

    def check(event, value, holds):
        # The event and its negation must give opposite verdicts.
        assert bool(event.evaluate({X: value})) is holds
        assert bool((~event).evaluate({X: value})) is (not holds)

    check(0 < expr, 10, True)
    check(0 > expr, 10, False)

    # Open interval (0, 100): excludes expr == 0 (X = 0) and expr == 200.
    open_interval = (0 < expr) < 100
    for val in [0, 10]:
        check(open_interval, val, False)

    # Closed interval [0, 100]: includes both endpoints.
    closed_interval = (0 <= expr) <= 100
    x_eq_100 = (expr << {100}).solve()
    for val in [0, list(x_eq_100)[0]]:
        check(closed_interval, val, True)

    membership = expr << {11}
    check(membership, 1, True)
    check(membership, 3, False)

    # Evaluating with X unassigned raises ValueError.
    with pytest.raises(ValueError):
        membership.evaluate({Y: 10})
45 |
def test_event_compound():
    """Evaluate disjunctions and conjunctions over X and Y."""
    expr0 = abs(X)**2 + 10*abs(X)
    expr1 = Y-1

    event = (0 < expr0) | (0 < expr1)
    assert not event.evaluate({X: 0, Y: 0})
    assert all(event.evaluate({X: i, Y: i}) for i in range(1, 100))
    # Leaving either symbol unassigned raises ValueError.
    for partial in ({X: 0}, {Y: 0}):
        with pytest.raises(ValueError):
            event.evaluate(partial)

    event = (100 <= expr0) & ((expr0 < 200) | (3 < expr1))
    assert not event.evaluate({X: 0, Y: 0})

    roots_150 = list((expr0 << {150}).solve())
    for root in (roots_150[0], roots_150[1]):
        assert event.evaluate({X: root, Y: 0})

    # With expr0 == 500 the event hinges on 3 < Y - 1.
    root_500 = list((expr0 << {500}).solve())[1]
    assert not event.evaluate({X: root_500, Y: 4})
    assert event.evaluate({X: root_500, Y: 5})
69 |
def test_event_solve_multi():
    """solve() raises ValueError for these events mentioning both X and Y."""
    on_x = Exp(abs(3*X**2)) > 1
    on_y = Log(Y) < 0.5
    for event in (on_x | on_y, on_x & on_y):
        with pytest.raises(ValueError):
            event.solve()
77 |
--------------------------------------------------------------------------------
/tests/test_indian_gpa.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | '''
5 | Indian GPA example from:
6 |
7 | Discrete-Continuous Mixtures in Probabilistic Programming: Generalized
8 | Semantics and Inference Algorithms, Wu et al., ICML 2018.
9 | https://arxiv.org/pdf/1806.02027.pdf
10 | '''
11 |
12 | import pytest
13 |
14 | from sppl.compilers.ast_to_spe import Id
15 | from sppl.compilers.ast_to_spe import IfElse
16 | from sppl.compilers.ast_to_spe import Sample
17 | from sppl.compilers.ast_to_spe import Sequence
18 | from sppl.compilers.sppl_to_python import SPPL_Compiler
19 | from sppl.distributions import atomic
20 | from sppl.distributions import choice
21 | from sppl.distributions import uniform
22 | from sppl.math_util import allclose
23 | from sppl.sets import Interval
24 | from sppl.spe import ExposedSumSPE
25 |
# Latent nationality, latent "perfect score" flag, and the observed GPA.
Nationality, Perfect, GPA = (
    Id(name) for name in ('Nationality', 'Perfect', 'GPA'))
29 |
def model_no_latents():
    """GPA mixture written directly, without Nationality/Perfect symbols."""
    # American students: GPA on [0, 4], with a 1% atom at 4.
    american = 0.99 * (GPA >> uniform(loc=0, scale=4)) \
        | 0.01 * (GPA >> atomic(loc=4))
    # Indian students: GPA on [0, 10], with a 1% atom at 10.
    indian = 0.99 * (GPA >> uniform(loc=0, scale=10)) \
        | 0.01 * (GPA >> atomic(loc=10))
    return 0.5 * american | 0.5 * indian
38 |
def model_exposed():
    """GPA mixture whose branch choices are exposed as symbols."""
    def student(top):
        # Perfect students score exactly `top`; others are uniform on
        # [0, top].
        return ExposedSumSPE(
            spe_weights=(Perfect >> choice({'True': 0.01, 'False': 0.99})),
            children={
                'False': GPA >> uniform(loc=0, scale=top),
                'True': GPA >> atomic(loc=top),
            })

    return ExposedSumSPE(
        spe_weights=(Nationality >> choice({'India': 0.5, 'USA': 0.5})),
        children={
            'USA': student(4),
            'India': student(10),
        },
    )
58 |
def model_ifelse_exhuastive():
    """IfElse with one explicit branch per (Nationality, Perfect) pair."""
    def when(nationality, perfect):
        return (Nationality << {nationality}) & (Perfect << {perfect})

    command = Sequence(
        Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})),
        Sample(Perfect, choice({'True': 0.01, 'False': 0.99})),
        IfElse(
            when('India', 'False'), Sample(GPA, uniform(loc=0, scale=10)),
            when('India', 'True'), Sample(GPA, atomic(loc=10)),
            when('USA', 'False'), Sample(GPA, uniform(loc=0, scale=4)),
            when('USA', 'True'), Sample(GPA, atomic(loc=4))))
    return command.interpret()
76 |
def model_ifelse_non_exhuastive():
    """Like model_ifelse_exhuastive, but the last branch is a catch-all."""
    def when(nationality, perfect):
        return (Nationality << {nationality}) & (Perfect << {perfect})

    command = Sequence(
        Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})),
        Sample(Perfect, choice({'True': 0.01, 'False': 0.99})),
        IfElse(
            when('India', 'False'), Sample(GPA, uniform(loc=0, scale=10)),
            when('India', 'True'), Sample(GPA, atomic(loc=10)),
            when('USA', 'False'), Sample(GPA, uniform(loc=0, scale=4)),
            True, Sample(GPA, atomic(loc=4))))
    return command.interpret()
97 |
def model_ifelse_nested():
    """Branch on Nationality, then on Perfect within each nationality."""
    def gpa_given_perfect(top):
        return IfElse(
            Perfect << {'True'}, Sample(GPA, atomic(loc=top)),
            Perfect << {'False'}, Sample(GPA, uniform(scale=top)),
        )

    command = Sequence(
        Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})),
        Sample(Perfect, choice({'True': 0.01, 'False': 0.99})),
        IfElse(
            Nationality << {'India'}, gpa_given_perfect(10),
            Nationality << {'USA'}, gpa_given_perfect(4)))
    return command.interpret()
117 |
def model_perfect_nested():
    """Sample Perfect separately inside each Nationality branch."""
    def student(top):
        return Sequence(
            Sample(Perfect, choice({'True': 0.01, 'False': 0.99})),
            IfElse(
                Perfect << {'True'}, Sample(GPA, atomic(loc=top)),
                True, Sample(GPA, uniform(scale=top))))

    command = Sequence(
        Sample(Nationality, choice({'India': 0.5, 'USA': 0.5})),
        IfElse(
            Nationality << {'India'}, student(10),
            Nationality << {'USA'}, student(4)))
    return command.interpret()
138 |
def model_ifelse_exhuastive_compiled():
    """Compile the exhaustive if/elif model from SPPL source."""
    source = '''
Nationality ~= choice({'India': 0.5, 'USA': 0.5})
Perfect ~= choice({'True': 0.01, 'False': 0.99})
if (Nationality == 'India') & (Perfect == 'False'):
    GPA ~= uniform(loc=0, scale=10)
elif (Nationality == 'India') & (Perfect == 'True'):
    GPA ~= atomic(loc=10)
elif (Nationality == 'USA') & (Perfect == 'False'):
    GPA ~= uniform(loc=0, scale=4)
elif (Nationality == 'USA') & (Perfect == 'True'):
    GPA ~= atomic(loc=4)
'''
    return SPPL_Compiler(source).execute_module().model
154 |
def model_ifelse_non_exhuastive_compiled():
    """Compile the if/elif/else model (catch-all last branch) from SPPL."""
    source = '''
Nationality ~= choice({'India': 0.5, 'USA': 0.5})
Perfect ~= choice({'True': 0.01, 'False': 0.99})
if (Nationality == 'India') & (Perfect == 'False'):
    GPA ~= uniform(loc=0, scale=10)
elif (Nationality == 'India') & (Perfect == 'True'):
    GPA ~= atomic(loc=10)
elif (Nationality == 'USA') & (Perfect == 'False'):
    GPA ~= uniform(loc=0, scale=4)
else:
    GPA ~= atomic(loc=4)
'''
    return SPPL_Compiler(source).execute_module().model
170 |
def model_ifelse_nested_compiled():
    """Compile the nested Nationality/Perfect branching from SPPL source."""
    source = '''
Nationality ~= choice({'India': 0.5, 'USA': 0.5})
Perfect ~= choice({'True': 0.01, 'False': 0.99})
if (Nationality == 'India'):
    if (Perfect == 'False'):
        GPA ~= uniform(loc=0, scale=10)
    else:
        GPA ~= atomic(loc=10)
elif (Nationality == 'USA'):
    if (Perfect == 'False'):
        GPA ~= uniform(loc=0, scale=4)
    elif (Perfect == 'True'):
        GPA ~= atomic(loc=4)
'''
    return SPPL_Compiler(source).execute_module().model
188 |
def model_perfect_nested_compiled():
    """Compile the model that samples Perfect inside each branch (SPPL)."""
    source = '''
Nationality ~= choice({'India': 0.5, 'USA': 0.5})
if (Nationality == 'India'):
    Perfect ~= choice({'True': 0.01, 'False': 0.99})
    if (Perfect == 'False'):
        GPA ~= uniform(loc=0, scale=10)
    else:
        GPA ~= atomic(loc=10)
elif (Nationality == 'USA'):
    Perfect ~= choice({'True': 0.01, 'False': 0.99})
    if (Perfect == 'False'):
        GPA ~= uniform(loc=0, scale=4)
    else:
        GPA ~= atomic(loc=4)
'''
    return SPPL_Compiler(source).execute_module().model
207 |
@pytest.mark.parametrize('get_model', [
    # Manual
    model_no_latents,
    model_exposed,
    # Interpreter
    model_ifelse_exhuastive,
    model_ifelse_non_exhuastive,
    model_ifelse_nested,
    model_perfect_nested,
    # Compiler
    model_ifelse_exhuastive_compiled,
    model_ifelse_non_exhuastive_compiled,
    model_ifelse_nested_compiled,
    model_perfect_nested_compiled,
])
def test_prior(get_model):
    """Every construction of the model yields the same prior over GPA."""
    model = get_model()

    # Only the perfect scores 4 and 10 carry point mass.
    point_masses = [(10, 0.5*0.01), (4, 0.5*0.01), (5, 0), (1, 0)]
    for value, expected in point_masses:
        assert allclose(model.prob(GPA << {value}), expected)

    interval_cases = [
        ((2 < GPA) < 4,
            0.5*0.99*0.5 + 0.5*0.99*0.2),
        ((2 <= GPA) < 4,
            0.5*0.99*0.5 + 0.5*0.99*0.2),
        ((2 < GPA) <= 4,
            0.5*(0.99*0.5 + 0.01) + 0.5*0.99*0.2),
        ((2 < GPA) <= 8,
            0.5*(0.99*0.5 + 0.01) + 0.5*0.99*0.6),
        ((2 < GPA) < 10,
            0.5*(0.99*0.5 + 0.01) + 0.5*0.99*0.8),
        ((2 < GPA) <= 10,
            0.5*(0.99*0.5 + 0.01) + 0.5*(0.99*0.8 + 0.01)),
        (((2 <= GPA) < 4) | (7 < GPA),
            (0.5*0.99*0.5 + 0.5*0.99*0.2) + (0.5*(0.99*0.3 + 0.01))),
        (((2 <= GPA) < 4) & (7 < GPA),
            0),
    ]
    for event, expected in interval_cases:
        assert allclose(model.prob(event), expected)
248 |
def test_condition():
    """Conditioning on the perfect scores, then on an open GPA range."""
    model = model_no_latents()
    model_condition = model.condition((GPA << {4}) | (GPA << {10}))
    assert len(model_condition.children) == 2
    assert model_condition.children[0].support == Interval.Ropen(4, 5)
    assert model_condition.children[1].support == Interval.Ropen(10, 11)

    # Python turns the chained comparison `0 < GPA < 4` into
    # `(0 < GPA) and (GPA < 4)`, which does not build the interval event;
    # write the interval explicitly as test_prior does.
    model_condition = model.condition((0 < GPA) < 4)
    assert len(model_condition.children) == 2
    assert model_condition.children[0].support \
        == model_condition.children[1].support
    assert allclose(
        model_condition.children[0].logprob(GPA < 1),
        model_condition.children[1].logprob(GPA < 1))
264 |
def test_logpdf():
    """The atoms at GPA 4 and 10 each have weight 0.5 * 0.01."""
    model = model_no_latents()
    for perfect_score in (4, 10):
        assert allclose(0.005, model.pdf({GPA: perfect_score}))
269 |
--------------------------------------------------------------------------------
/tests/test_logpdf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | import pytest
7 |
8 | from sppl.distributions import atomic
9 | from sppl.distributions import choice
10 | from sppl.distributions import discrete
11 | from sppl.distributions import gamma
12 | from sppl.distributions import norm
13 | from sppl.distributions import poisson
14 | from sppl.math_util import allclose
15 | from sppl.math_util import isinf_neg
16 | from sppl.math_util import logsumexp
17 | from sppl.spe import SumSPE
18 | from sppl.spe import ProductSPE
19 | from sppl.transforms import Id
20 |
# Symbols used by the logpdf tests.
X, Y, Z = Id('X'), Id('Y'), Id('Z')
24 |
def test_logpdf_real_continuous():
    """A continuous leaf reports the underlying scipy density."""
    leaf = X >> norm()
    reference = norm().dist
    assert allclose(leaf.logpdf({X: 0}), reference.logpdf(0))
28 |
def test_logpdf_real_discrete():
    """Values outside a Poisson's integer support have logpdf -inf."""
    leaf = X >> poisson(mu=2)
    for outside in (1.5, '1'):
        assert isinf_neg(leaf.logpdf({X: outside}))
    assert not isinf_neg(leaf.logpdf({X: 0}))
34 |
def test_logpdf_nominal():
    """A nominal leaf assigns log-mass to its categories and -inf elsewhere."""
    spe = (X >> choice({'a' : .6, 'b': .4}))
    assert isinf_neg(spe.logpdf({X: 1.5}))
    # Without `assert` the allclose result was computed and discarded, so
    # the check could never fail.
    assert allclose(spe.logpdf({X: 'a'}), log(.6))
39 |
def test_logpdf_mixture_real_continuous_continuous():
    """A mixture of two densities is their weighted log-sum-exp."""
    spe = X >> (.3*norm() | .7*gamma(a=1))
    point = {X: .5}
    actual = spe.logpdf(point)
    weighted = [
        log(.3) + spe.children[0].logpdf({X: 0.5}),
        log(.7) + spe.children[1].logpdf({X: 0.5}),
    ]
    assert allclose(actual, logsumexp(weighted))
48 |
@pytest.mark.xfail
def test_logpdf_mixture_real_continuous_discrete():
    # Expected to fail: adding the log-density of norm to the log-mass of
    # poisson mixes base measures (see the message of the final assert).
    # NOTE(review): logpdf is called with an event, not an assignment dict;
    # test_logpdf_error_event expects such calls to raise.
    spe = X >> (.3*norm() | .7*poisson(mu=1))
    target = spe.logpdf(X << {.5})
    reference = logsumexp([
        log(.3) + spe.children[0].logpdf({X: 0.5}),
        log(.7) + spe.children[1].logpdf({X: 0.5}),
    ])
    assert allclose(target, reference)
    assert False, 'Invalid base measure addition'
59 |
def test_logpdf_mixture_nominal():
    """Mixing a real and a nominal leaf over the same symbol X."""
    real_leaf = X >> norm()
    nominal_leaf = X >> choice({'a':.1, 'b':.9})
    spe = SumSPE([real_leaf, nominal_leaf], [log(.4), log(.6)])
    # Each value lies in the support of exactly one component, so the
    # mixture reduces to that component's weighted logpdf.
    expectations = [
        ({X: .5}, log(.4), 0),
        ({X: 'a'}, log(.6), 1),
    ]
    for assignment, log_weight, index in expectations:
        assert allclose(
            spe.logpdf(assignment),
            log_weight + spe.children[index].logpdf(assignment))
68 |
def test_logpdf_error_event():
    """Passing an event, rather than a point assignment, to logpdf fails."""
    leaf = X >> norm()
    with pytest.raises(Exception):
        leaf.logpdf(X < 1)
73 |
def test_logpdf_error_transform_base():
    """Keying the assignment by a transform of the base symbol fails."""
    leaf = X >> norm()
    with pytest.raises(Exception):
        leaf.logpdf({X**2: 0})
78 |
def test_logpdf_error_transform_env():
    """Querying the density of a derived (transformed) variable fails."""
    base = X >> norm()
    derived = base.transform(Z, X**2)
    with pytest.raises(Exception):
        derived.logpdf({Z: 0})
83 |
def test_logpdf_bivariate():
    """A product's log-density is the sum of its factors' log-densities."""
    x_factor = X >> norm()
    y_factor = Y >> choice({'a': .5, 'b': .5})
    spe = x_factor & y_factor
    expected = norm().dist.logpdf(0) + log(.5)
    assert allclose(spe.logpdf({X: 0, Y: 'a'}), expected)
89 |
def test_logpdf_lexicographic_either():
    """For each assignment, the expected density uses a single mixture
    branch, and constraining on it yields a ProductSPE."""
    branch_1 = (
        (X >> norm()) & (Y >> atomic(loc=0)) & (Z >> discrete({1: .1, 2: .9})))
    branch_2 = (X >> atomic(loc=0)) & (Y >> norm()) & (Z >> norm())
    spe = .75*branch_1 | .25*branch_2
    cases = [
        # Lexicographic, Branch 1
        ({X: 0, Y: 0, Z: 2},
         log(.75) + norm().dist.logpdf(0) + log(1) + log(.9)),
        # Lexicographic, Branch 2
        ({X: 0, Y: 0, Z: 0},
         log(.25) + log(1) + norm().dist.logpdf(0) + norm().dist.logpdf(0)),
    ]
    for assignment, expected in cases:
        assert allclose(spe.logpdf(assignment), expected)
        assert isinstance(spe.constrain(assignment), ProductSPE)
105 |
def test_logpdf_lexicographic_both():
    """When both branches contribute, the density is their logsumexp, and
    constraining yields a SumSPE."""
    branch_1 = (
        (X >> norm()) & (Y >> atomic(loc=0)) & (Z >> discrete({1: .2, 2: .8})))
    branch_2 = (
        (X >> discrete({1: .5, 2: .5})) & (Y >> norm()) & (Z >> atomic(loc=2)))
    spe = .75*branch_1 | .25*branch_2
    # Lexicographic, Mix
    assignment = {X: 1, Y: 0, Z: 2}
    expected = logsumexp([
        log(.75) + norm().dist.logpdf(1) + log(1) + log(.8),
        log(.25) + log(.5) + norm().dist.logpdf(0) + log(1),
    ])
    assert allclose(spe.logpdf(assignment), expected)
    assert isinstance(spe.constrain(assignment), SumSPE)
117 |
--------------------------------------------------------------------------------
/tests/test_mutual_information.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import exp
5 | from math import log
6 |
7 | import numpy
8 | import pytest
9 |
10 | from sppl.distributions import norm
11 | from sppl.math_util import allclose
12 | from sppl.math_util import isinf_neg
13 | from sppl.math_util import logdiffexp
14 | from sppl.math_util import logsumexp
15 | from sppl.spe import Memo
16 | from sppl.transforms import Id
17 |
# Seeded generator so sampling in this module is reproducible.
prng = numpy.random.RandomState(seed=1)
19 |
def entropy(spe, A, memo):
    """Return the entropy (natural log) of the binary indicator of event A.

    Outcomes whose log-probability is -inf contribute 0 rather than NaN.
    """
    log_p_true = spe.logprob(A, memo=memo)
    log_p_false = logdiffexp(0, log_p_true)
    total = 0
    for lp in (log_p_true, log_p_false):
        if not isinf_neg(lp):
            total += -exp(lp) * lp
    return total
def entropyc(spe, A, B, memo):
    """Return the conditional entropy H(A | B) (natural log) of the binary
    indicators of events A and B.

    Computed as the sum over the four joint outcomes (a, b) of
    P(a, b) * (log P(b) - log P(a, b)). Outcomes with -inf log-probability
    contribute 0.
    """
    # Pass the memo here too, like every other query in this helper.
    lpB1 = spe.logprob(B, memo=memo)
    lpB0 = logdiffexp(0, lpB1)
    lp11 = spe.logprob(B & A, memo=memo)
    lp10 = spe.logprob(B & ~A, memo=memo)
    lp01 = spe.logprob(~B & A, memo=memo)
    # The fourth joint outcome follows by complement, saving one query.
    lp00 = logdiffexp(0, logsumexp([lp11, lp10, lp01]))
    m11 = exp(lp11) * (lpB1 - lp11) if not isinf_neg(lp11) else 0
    m10 = exp(lp10) * (lpB1 - lp10) if not isinf_neg(lp10) else 0
    m01 = exp(lp01) * (lpB0 - lp01) if not isinf_neg(lp01) else 0
    m00 = exp(lp00) * (lpB0 - lp00) if not isinf_neg(lp00) else 0
    return m11 + m10 + m01 + m00
39 |
def check_mi_properties(spe, A, B, memo):
    """Assert identities between mutual information and entropy:
    I(A;A) = H(A), I(B;B) = H(B), and I(A;B) = H(A) - H(A|B) = H(B) - H(B|A).
    """
    mi_ab = spe.mutual_information(A, B, memo=memo)
    mi_aa = spe.mutual_information(A, A, memo=memo)
    mi_bb = spe.mutual_information(B, B, memo=memo)
    h_a = entropy(spe, A, memo=memo)
    h_b = entropy(spe, B, memo=memo)
    h_a_given_b = entropyc(spe, A, B, memo=memo)
    h_b_given_a = entropyc(spe, B, A, memo=memo)
    identities = [
        (mi_aa, h_a),
        (mi_bb, h_b),
        (mi_ab, h_a - h_a_given_b),
        (mi_ab, h_b - h_b_given_a),
    ]
    for actual, expected in identities:
        assert allclose(actual, expected)
52 |
@pytest.mark.parametrize('memo', [Memo(), None])
def test_mutual_information_four_clusters(memo):
    """Mutual information between the events X > 2 and Y > 2 in an equally
    weighted mixture of four 2-D Gaussian clusters.

    Unconditioned, the events have zero mutual information. After
    conditioning on exactly one of X, Y exceeding 2, the mutual information
    is log 2.
    """
    X = Id('X')
    Y = Id('Y')
    spe = (
        0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=5, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=5, scale=0.5)))

    A = X > 2
    B = Y > 2
    samples = spe.sample(100, prng)
    # Every joint sample assigns both variables.
    assert all(X in sample and Y in sample for sample in samples)
    mi = spe.mutual_information(A, B, memo=memo)
    assert allclose(mi, 0)
    check_mi_properties(spe, A, B, memo)

    event = ((X > 2) & (Y < 2)) | ((X < 2) & (Y > 2))
    spe_condition = spe.condition(event)
    samples = spe_condition.sample(100, prng)
    assert all(event.evaluate(sample) for sample in samples)
    mi = spe_condition.mutual_information(X > 2, Y > 2)
    assert allclose(mi, log(2))

    check_mi_properties(spe, (X > 1) | (Y < 1), (Y > 2), memo)
    check_mi_properties(spe, (X > 1) | (Y < 1), (X > 1.5) & (Y > 2), memo)
79 |
--------------------------------------------------------------------------------
/tests/test_nominal_distribution.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from fractions import Fraction
5 | from math import log
6 |
7 | import numpy
8 | import pytest
9 |
10 | from sppl.distributions import choice
11 | from sppl.math_util import allclose
12 | from sppl.math_util import isinf_neg
13 | from sppl.sets import FiniteNominal
14 | from sppl.transforms import Id
15 |
def test_nominal_distribution():
    """Probability, sampling, and conditioning queries on a nominal leaf."""
    X = Id('X')
    spe = X >> choice({
        'a': Fraction(1, 5),
        'b': Fraction(1, 5),
        'c': Fraction(3, 5),
    })

    expected_probs = [
        (X << {'a'}, Fraction(1, 5)),
        (X << {'b'}, Fraction(1, 5)),
        (X << {'a', 'c'}, Fraction(4, 5)),
        ((X << {'a'}) & ~(X << {'b'}), Fraction(1, 5)),
        ((X << {'a', 'b'}) & ~(X << {'b'}), Fraction(1, 5)),
    ]
    for event, prob in expected_probs:
        assert allclose(spe.logprob(event), log(prob))
    for impossible in ((X << {'d'}), (X << ())):
        assert spe.logprob(impossible) == -float('inf')

    drawn = spe.sample(100)
    assert all(s[X] in spe.support for s in drawn)

    drawn = spe.sample_subset([X], 100)
    assert all(len(s) == 1 and s[X] in spe.support for s in drawn)

    with pytest.raises(Exception):
        spe.sample_subset(['f'], 100)

    # sample_func inspects parameter names (a lambda over Y is rejected
    # below), so every callable here takes a parameter literally named X.
    drawn = spe.sample_func(lambda X: X in {'a', 'b'} or X in {'c'}, 100)
    assert all(drawn)

    drawn = spe.sample_func(
        lambda X: X not in {'a', 'b'} and X not in {'c'}, 100)
    assert not any(drawn)

    drawn = spe.sample_func(
        lambda X: 1 if X in {'a'} else None, 100,
        prng=numpy.random.RandomState(1))
    assert sum(s == 1 for s in drawn) > 12
    assert sum(s is None for s in drawn) > 70

    with pytest.raises(ValueError):
        spe.sample_func(lambda Y: Y, 100)

    spe_condition = spe.condition(X << {'a', 'b'})
    assert spe_condition.support == FiniteNominal('a', 'b', 'c')
    for value in ('a', 'b'):
        assert allclose(spe_condition.logprob(X << {value}), -log(2))
    assert spe_condition.logprob(X << {'c'}) == -float('inf')

    assert isinf_neg(spe_condition.logprob(X**2 << {1}))

    with pytest.raises(ValueError):
        spe.condition(X << {'python'})
    assert spe.condition(~(X << {'python'})) == spe
71 |
--------------------------------------------------------------------------------
/tests/test_parse_distributions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | import pytest
7 |
8 | from sppl.distributions import DistributionMix
9 | from sppl.distributions import bernoulli
10 | from sppl.distributions import choice
11 | from sppl.distributions import discrete
12 | from sppl.distributions import norm
13 | from sppl.distributions import poisson
14 | from sppl.distributions import rv_discrete
15 | from sppl.distributions import uniformd
16 | from sppl.math_util import allclose
17 | from sppl.sets import FiniteNominal
18 | from sppl.sets import Interval
19 | from sppl.sets import inf as oo
20 | from sppl.spe import ContinuousLeaf
21 | from sppl.spe import DiscreteLeaf
22 | from sppl.spe import NominalLeaf
23 | from sppl.spe import SumSPE
24 | from sppl.transforms import Id
25 |
# Random-variable symbol shared by the tests in this module.
X, = map(Id, ['X'])
27 |
def test_simple_parse_real():
    """A weighted mix of real distributions applied to a symbol gives a
    SumSPE with one leaf per component."""
    assert isinstance(.3*bernoulli(p=.1), DistributionMix)
    mix = .3*bernoulli(p=.1) | .5 * norm() | .2*poisson(mu=7)
    spe = mix(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(w) for w in (.3, .5, .2)])
    expected = [
        (DiscreteLeaf, Interval(0, 1)),
        (ContinuousLeaf, Interval(-oo, oo)),
        (DiscreteLeaf, Interval(0, oo)),
    ]
    for index, (leaf_type, support) in enumerate(expected):
        assert isinstance(spe.children[index], leaf_type)
        assert spe.children[index].support == support
40 |
def test_simple_parse_nominal():
    """Real and nominal distributions can be mixed under one symbol."""
    assert isinstance(.7 * choice({'a': .1, 'b': .9}), DistributionMix)
    mix = .3*bernoulli(p=.1) | .7*choice({'a': .1, 'b': .9})
    spe = mix(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(w) for w in (.3, .7)])
    discrete_child, nominal_child = spe.children[0], spe.children[1]
    assert isinstance(discrete_child, DiscreteLeaf)
    assert isinstance(nominal_child, NominalLeaf)
    assert discrete_child.support == Interval(0, 1)
    assert nominal_child.support == FiniteNominal('a', 'b')
51 |
def test_error():
    """A non-numeric weight is a TypeError. Applying a mix whose weights
    total .8 to a symbol also raises."""
    with pytest.raises(TypeError):
        'a'*bernoulli(p=.1)
    incomplete = .1 *bernoulli(p=.1) | .7*poisson(mu=8)
    with pytest.raises(Exception):
        incomplete(X)
58 |
def test_parse_rv_discrete():
    """Explicit discrete distributions put their mass exactly on the listed
    values."""
    candidates = [
        rv_discrete(values=((1, 2, 10), (.3, .5, .2))),
        discrete({1: .3, 2: .5, 10: .2}),
    ]
    for dist in candidates:
        spe = dist(X)
        assert spe.support == Interval(1, 10)
        for value, mass in ((1, .3), (2, .5), (10, .2)):
            assert allclose(spe.prob(X << {value}), mass)
        assert allclose(spe.prob(X <= 10), 1)

    spe = uniformd(values=(1, 2, 10, 0))(X)
    assert spe.support == Interval(0, 10)
    for value in (1, 2, 10, 0):
        assert allclose(spe.prob(X << {value}), .25)
78 |
--------------------------------------------------------------------------------
/tests/test_parse_spe.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | import pytest
7 |
8 | from sppl.spe import ContinuousLeaf
9 | from sppl.spe import PartialSumSPE
10 | from sppl.spe import ProductSPE
11 | from sppl.spe import SumSPE
12 |
13 | from sppl.distributions import gamma
14 | from sppl.distributions import norm
15 | from sppl.transforms import Id
16 |
17 | from sppl.math_util import allclose
18 |
# Random-variable symbols shared by every test in this module.
X, Y, Z = map(Id, ('X', 'Y', 'Z'))
22 |
def test_mul_leaf():
    """Scaling a leaf by a weight, on either side, gives a one-term partial
    sum."""
    scaled_left = 0.3 * (X >> norm())
    scaled_right = (X >> norm()) * 0.3
    for partial in (scaled_left, scaled_right):
        assert isinstance(partial, PartialSumSPE)
        assert len(partial.weights) == 1
        assert allclose(float(sum(partial.weights)), 0.3)
28 |
def test_sum_leaf():
    """Rules for combining weighted leaves with `|` into partial and
    complete sums."""
    # Cannot sum leaves without weights.
    with pytest.raises(TypeError):
        (X >> norm()) | (X >> gamma(a=1))
    # Cannot sum a partial sum with an unweighted leaf.
    with pytest.raises(TypeError):
        0.3*(X >> norm()) | (X >> gamma(a=1))
    # Cannot sum an unweighted leaf with a partial sum.
    with pytest.raises(TypeError):
        (X >> norm()) | 0.3*(X >> gamma(a=1))
    # Wrong symbol.
    with pytest.raises(ValueError):
        0.4*(X >> norm()) | 0.6*(Y >> gamma(a=1))
    # Sum exceeds one. Both terms use X, so the weight total is the only
    # possible cause of the error (not a symbol mismatch).
    with pytest.raises(ValueError):
        0.4*(X >> norm()) | 0.7*(X >> gamma(a=1))

    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1))
    assert isinstance(y, PartialSumSPE)
    assert len(y.weights) == 2
    assert allclose(float(y.weights[0]), 0.4)
    assert allclose(float(y.weights[1]), 0.3)

    y = 0.4*(X >> norm()) | 0.6*(X >> gamma(a=1))
    assert isinstance(y, SumSPE)
    assert len(y.weights) == 2
    assert allclose(float(y.weights[0]), log(0.4))
    assert allclose(float(y.weights[1]), log(0.6))
    # A completed SumSPE (weights already sum to one) cannot take more terms.
    with pytest.raises(TypeError):
        y | 0.7 * (X >> norm())

    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1)) | 0.1*(X >> norm())
    assert isinstance(y, PartialSumSPE)
    assert len(y.weights) == 3
    assert allclose(float(y.weights[0]), 0.4)
    assert allclose(float(y.weights[1]), 0.3)
    assert allclose(float(y.weights[2]), 0.1)

    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1)) | 0.3*(X >> norm())
    assert isinstance(y, SumSPE)
    assert len(y.weights) == 3
    assert allclose(float(y.weights[0]), log(0.4))
    assert allclose(float(y.weights[1]), log(0.3))
    assert allclose(float(y.weights[2]), log(0.3))

    # A partial sum cannot be rescaled, but a complete sum can.
    with pytest.raises(TypeError):
        (0.3)*(0.3*(X >> norm()))
    with pytest.raises(TypeError):
        (0.3*(X >> norm())) * (0.3)
    with pytest.raises(TypeError):
        0.3*(0.3*(X >> norm()) | 0.5*(X >> norm()))

    w = 0.3*(0.4*(X >> norm()) | 0.6*(X >> norm()))
    assert isinstance(w, PartialSumSPE)
84 |
def test_product_leaf():
    """`&` rejects weighted operands and repeated symbols, and chains
    distinct leaves into a flat ProductSPE."""
    invalid = [
        (TypeError, lambda: 0.3*(X >> gamma(a=1)) & (X >> norm())),
        (TypeError, lambda: (X >> norm()) & 0.3*(X >> gamma(a=1))),
        (ValueError, lambda: (X >> norm()) & (X >> gamma(a=1))),
    ]
    for error, build in invalid:
        with pytest.raises(error):
            build()

    product = (X >> norm()) & (Y >> gamma(a=1)) & (Z >> norm())
    assert isinstance(product, ProductSPE)
    assert len(product.children) == 3
    assert product.get_symbols() == frozenset([X, Y, Z])
97 |
def test_sum_of_sums():
    """Weighted sums of sums flatten into a single sum whose weights are
    products of the nested weights."""
    w = (
        0.3*(0.4*(X >> norm()) | 0.6*(X >> norm()))
        | 0.7*(0.1*(X >> norm()) | 0.9*(X >> norm())))
    assert isinstance(w, SumSPE)
    assert len(w.children) == 4
    expected = [
        log(0.3) + log(0.4),
        log(0.3) + log(0.6),
        log(0.7) + log(0.1),
        log(0.7) + log(0.9),
    ]
    for weight, target in zip(w.weights, expected):
        assert allclose(float(weight), target)

    w = (
        0.3*(0.4*(X >> norm()) | 0.6*(X >> norm()))
        | 0.2*(0.1*(X >> norm()) | 0.9*(X >> norm())))
    assert isinstance(w, PartialSumSPE)
    assert allclose(float(w.weights[0]), 0.3)
    assert allclose(float(w.weights[1]), 0.2)

    a = w | 0.5*(X >> gamma(a=1))
    assert isinstance(a, SumSPE)
    assert len(a.children) == 5
    expected = [
        log(0.3) + log(0.4),
        log(0.3) + log(0.6),
        log(0.2) + log(0.1),
        log(0.2) + log(0.9),
        log(0.5),
    ]
    for weight, target in zip(a.weights, expected):
        assert allclose(float(weight), target)

    # Wrong symbol. Only the raised error matters, so no result is bound.
    with pytest.raises(ValueError):
        w | 0.4*(Y >> gamma(a=1))
128 |
def test_or_and():
    """A mixture over one symbol can be a factor of a product. A mixture
    over mismatched symbols is rejected first."""
    with pytest.raises(ValueError):
        (0.3*(X >> norm()) | 0.7*(Y >> gamma(a=1))) & (Z >> norm())
    mixture = 0.3*(X >> norm()) | 0.7*(X >> gamma(a=1))
    product = mixture & (Z >> norm())
    assert isinstance(product, ProductSPE)
    first, second = product.children[0], product.children[1]
    assert isinstance(first, SumSPE)
    assert isinstance(second, ContinuousLeaf)
136 |
--------------------------------------------------------------------------------
/tests/test_poly.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from sympy import Poly as SymPoly
5 | from sympy import Rational
6 | from sympy import sqrt as SymSqrt
7 | from sympy.abc import x
8 |
9 | from sppl.poly import solve_poly_equality
10 | from sppl.poly import solve_poly_inequality
11 |
12 | from sppl.math_util import allclose
13 |
14 | from sppl.sets import EmptySet
15 | from sppl.sets import ExtReals
16 | from sppl.sets import FiniteReal
17 | from sppl.sets import Interval
18 | from sppl.sets import Reals
19 | from sppl.sets import Union
20 | from sppl.sets import inf as oo
21 |
def test_solve_poly_inequaltiy_pos_inf():
    """Solve p(x) < oo (strict=True) and p(x) <= oo (strict=False)."""
    cases = [
        (x**2 - 10*x + 100, Reals, ExtReals),
        (-x**3 + 10*x, Reals | FiniteReal(oo), ExtReals),
        (x**3 - 10*x, Reals | FiniteReal(-oo), ExtReals),
    ]
    for poly, strict_solution, weak_solution in cases:
        assert solve_poly_inequality(poly, oo, True) == strict_solution
        assert solve_poly_inequality(poly, oo, False) == weak_solution
31 |
32 |
def test_solve_poly_inequaltiy_neg_inf():
    """Solve p(x) < -oo and p(x) <= -oo. The only solutions are infinite
    points."""
    for poly in (x**2 - 10*x + 100, x**3 - 10*x, -x**2 + 10*x + 100):
        assert solve_poly_inequality(poly, -oo, True) is EmptySet
    assert solve_poly_inequality(x**2 - 10*x + 100, -oo, False) is EmptySet
    assert solve_poly_inequality(x**3 - 10*x, -oo, False) == FiniteReal(-oo)
    assert (
        solve_poly_inequality(-x**2 + 10*x + 100, -oo, False)
        == FiniteReal(-oo, oo))
42 |
43 |
# Quadratic with roots sqrt(2)/10 and -10/7. The first element of
# Poly.args is used as the plain expression handed to the solvers.
p_quadratic = SymPoly((x - SymSqrt(2)/10) * (x + Rational(10, 7)), x)
expr_quadratic, *_ = p_quadratic.args
def test_solve_poly_equality_quadratic_zero():
    """The zeros of expr_quadratic are exactly its two factor roots."""
    expected = FiniteReal(SymSqrt(2)/10, -Rational(10, 7))
    assert solve_poly_equality(expr_quadratic, 0) == expected
def test_solve_poly_inequality_quadratic_zero():
    """expr_quadratic <= 0 on the closed interval between its roots, and
    < 0 on the open one."""
    lo, hi = -Rational(10, 7), SymSqrt(2)/10
    assert solve_poly_inequality(expr_quadratic, 0, False) == Interval(lo, hi)
    assert (
        solve_poly_inequality(expr_quadratic, 0, True)
        == Interval.open(lo, hi))
54 |
# Closed-form solutions of expr_quadratic == 1.
xe1_quad0, xe1_quad1 = (
    -5/7 + SymSqrt(2)/20 + SymSqrt(2)*SymSqrt(700*SymSqrt(2) + 14849)/140,
    -SymSqrt(2)*SymSqrt(700*SymSqrt(2) + 14849)/140 - 5/7 + SymSqrt(2)/20,
)
def test_solve_poly_equality_quadratic_one():
    """expr_quadratic == 1 has two roots matching the closed forms.

    SymPy cannot simplify the irrational roots to a canonical symbolic
    form, so the comparison is numerical.
    """
    roots = solve_poly_equality(expr_quadratic, 1)
    assert len(roots) == 2
    approximations = [float(root) for root in roots]
    for target in (xe1_quad0, xe1_quad1):
        assert any(allclose(value, float(target)) for value in approximations)
64 |
65 |
# Cubic with integer roots -2, 1, 11. The first element of Poly.args is
# used as the plain expression handed to the solvers.
p_cubic_int = SymPoly((x - 1) * (x + 2) * (x - 11), x)
expr_cubic_int, *_ = p_cubic_int.args
def test_solve_poly_equality_cubic_int_zero():
    """The zeros of expr_cubic_int are its three integer roots."""
    expected = FiniteReal(-2, 1, 11)
    assert solve_poly_equality(expr_cubic_int, 0) == expected
def test_solve_poly_inequality_cubic_int_zero():
    """Where expr_cubic_int is non-positive (weak) and negative (strict)."""
    weak = solve_poly_inequality(expr_cubic_int, 0, False)
    assert weak == Interval(-oo, -2) | Interval(1, 11)
    strict = solve_poly_inequality(expr_cubic_int, 0, True)
    assert strict == Interval.open(-oo, -2) | Interval.open(1, 11)
76 |
# Numerical solutions of expr_cubic_int == 1, in increasing order.
xe1_cubic_int0, xe1_cubic_int1, xe1_cubic_int2 = (
    -1.97408387376586,
    0.966402009973818,
    11.007681863792,
)
def test_solve_poly_equality_cubic_int_one():
    """expr_cubic_int == 1 has three roots matching the reference values."""
    roots = solve_poly_equality(expr_cubic_int, 1)
    assert len(roots) == 3
    approximations = [float(root) for root in roots]
    for target in (xe1_cubic_int0, xe1_cubic_int1, xe1_cubic_int2):
        assert any(allclose(value, target) for value in approximations)
def test_solve_poly_inequality_cubic_int_one():
    """Strict solutions of p(x) < 1 and -p(x) < -1 for
    p = (x-1)(x+2)(x-11). Each solution is a union of two intervals."""
    below_one = solve_poly_inequality(expr_cubic_int, 1, True)
    assert isinstance(below_one, Union)
    assert len(below_one.args) == 2
    left, right = below_one.args
    # (-oo, r0)
    assert left.left == -oo
    assert allclose(float(left.right), xe1_cubic_int0)
    assert left.right_open
    # (r1, r2)
    assert allclose(float(right.left), xe1_cubic_int1)
    assert allclose(float(right.right), xe1_cubic_int2)
    assert right.left_open
    assert right.right_open

    above_one = solve_poly_inequality(-1*expr_cubic_int, -1, True)
    assert isinstance(above_one, Union)
    assert len(above_one.args) == 2
    left, right = above_one.args
    # (r0, r1)
    assert allclose(float(left.left), xe1_cubic_int0)
    assert allclose(float(left.right), xe1_cubic_int1)
    # (r2, oo)
    assert allclose(float(right.left), xe1_cubic_int2)
    assert right.right == oo
109 |
110 |
# Cubic with roots sqrt(2)/10, -10/7 and sqrt(5). The first element of
# Poly.args is used as the plain expression handed to the solvers.
p_cubic_irrat = SymPoly(
    (x - SymSqrt(2)/10) * (x + Rational(10, 7)) * (x - SymSqrt(5)), x)
expr_cubic_irrat, *_ = p_cubic_irrat.args
def test_solve_poly_equality_cubic_irrat_zero():
    """The zeros of expr_cubic_irrat are found symbolically."""
    roots = solve_poly_equality(expr_cubic_irrat, 0)
    # The exact rational root must be present, which confirms the roots are
    # symbolic (no timeout fallback).
    assert -Rational(10, 7) in roots
    # The irrational roots are compared numerically because SymPy does not
    # simplify them to a canonical form.
    approximations = [float(root) for root in roots]
    for target in (SymSqrt(2)/10, SymSqrt(5)):
        assert any(allclose(value, float(target)) for value in approximations)
121 |
# Numerical solutions of expr_cubic_irrat == 1, in increasing order.
xe1_cubic_irrat0, xe1_cubic_irrat1, xe1_cubic_irrat2 = (
    -1.21493150246058,
    -0.19158140952462,
    2.35543081715088,
)
def test_solve_poly_equality_cubic_irrat_one():
    """expr_cubic_irrat == 1 has three roots matching the reference values.

    Per the original author, solving this symbolically is too slow, so the
    solver's 5s timeout falls back to a numerical approximation.
    """
    roots = solve_poly_equality(expr_cubic_irrat, 1)
    assert len(roots) == 3
    approximations = [float(root) for root in roots]
    for target in (xe1_cubic_irrat0, xe1_cubic_irrat1, xe1_cubic_irrat2):
        assert any(allclose(value, target) for value in approximations)
133 |
--------------------------------------------------------------------------------
/tests/test_real_continuous.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from math import log
5 |
6 | import numpy
7 | import pytest
8 | import scipy.stats
9 |
10 | from sppl.distributions import gamma
11 | from sppl.distributions import norm
12 | from sppl.math_util import allclose
13 | from sppl.math_util import isinf_neg
14 | from sppl.math_util import logdiffexp
15 | from sppl.sets import Interval
16 | from sppl.sets import Reals
17 | from sppl.sets import inf as oo
18 | from sppl.spe import ContinuousLeaf
19 | from sppl.spe import SumSPE
20 | from sppl.transforms import Id
21 |
def test_numeric_distribution_normal():
    """Probability, sampling, and conditioning queries on a standard normal
    leaf."""
    X = Id('X')
    spe = X >> norm(loc=0, scale=1)

    assert spe.size() == 1
    assert allclose(spe.logprob(X > 0), -log(2))
    assert allclose(
        spe.logprob(abs(X) < 2), log(spe.dist.cdf(2) - spe.dist.cdf(-2)))

    # Events of probability one.
    for event in (X**2 > 0, abs(X) > 0, ~(X << {1})):
        assert allclose(spe.logprob(event), 0)
    # Events of probability zero.
    for event in (X**2 - X + 10 < 0, abs(X) < 0, X << {1}):
        assert isinf_neg(spe.logprob(event))

    spe.sample(100)
    spe.sample_subset([X], 100)
    assert spe.sample_subset([], 100) == [{}]*100
    # sample_func keys on parameter names, so each callable takes X.
    for func, count in (
            (lambda X: X**2, 1),
            (lambda X: abs(X)+X**2, 1),
            (lambda X: X**2 if X > 0 else X**3, 100)):
        spe.sample_func(func, count)

    samples = spe.condition((X < 2) | (X > 10)).sample(100)
    assert all(s[X] < 2 for s in samples)

    spe_condition_b = spe.condition((X < -10) | (X > 10))
    assert isinstance(spe_condition_b, SumSPE)
    first_weight = spe_condition_b.weights[0]
    assert allclose(first_weight, -log(2))
    assert allclose(first_weight, spe_condition_b.weights[1])

    for event in [(X < -10), (X > 3)]:
        spe_condition_c = spe.condition(event)
        assert isinstance(spe_condition_c, ContinuousLeaf)
        assert isinf_neg(spe_condition_c.logprob((-1 < X) < 1))
        samples = spe_condition_c.sample(100, prng=numpy.random.RandomState(1))
        assert all(s[X] in event.values for s in samples)

    for invalid in (
            lambda: spe.condition((X > 1) & (X < 1)),
            lambda: spe.condition(X << {1}),
            lambda: spe.sample_func(lambda Z: Z**2, 1)):
        with pytest.raises(ValueError):
            invalid()

    assert allclose(
        spe.logprob((X << {1, 2}) | (X < -1)), spe.logprob(X < -1))

    with pytest.raises(AssertionError):
        spe.logprob(Id('Y') << {1, 2})
75 |
def test_numeric_distribution_gamma():
    """Conditioning and constraining gamma leaves, including one built
    with a domain wider than the gamma support."""
    X = Id('X')

    def zero_mass_event():
        # Finite points plus the negative half-line: probability zero for
        # a gamma distribution.
        return (X << {1, 2}) | (X < 0)

    spe = X >> gamma(a=1, scale=1)
    with pytest.raises(ValueError):
        spe.condition(zero_mass_event())

    # The domain is set to Reals on purpose, to exercise an important code
    # path in dist.condition (the Union case with zero weights).
    spe = ContinuousLeaf(X, scipy.stats.gamma(a=1, scale=1), Reals)
    assert isinf_neg(spe.logprob(zero_mass_event()))
    with pytest.raises(ValueError):
        spe.condition(zero_mass_event())

    spe_condition = spe.condition((X << {1, 2}) | (X <= 3))
    assert isinstance(spe_condition, ContinuousLeaf)
    assert spe_condition.conditioned
    assert spe_condition.support == Interval(-oo, 3)
    expected = (
        logdiffexp(spe.logprob(X <= 2), spe.logprob(X <= 0))
        - spe_condition.logZ)
    assert allclose(spe_condition.logprob(X <= 2), expected)

    # Shifted gamma: support on (-3, oo).
    spe = X >> gamma(loc=-3, a=1)
    assert spe.prob((-3 < X) < 0) > 0.95

    # Constrain.
    with pytest.raises(Exception):
        spe.constrain({X: -4})
    spe_constrain = spe.constrain({X: .5})
    samples = spe_constrain.sample(100, prng=None)
    assert all(s == {X: .5} for s in samples)
109 |
--------------------------------------------------------------------------------
/tests/test_real_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | import numpy
7 |
8 | from sppl.distributions import poisson
9 | from sppl.distributions import randint
10 | from sppl.math_util import allclose
11 | from sppl.math_util import logdiffexp
12 | from sppl.math_util import logsumexp
13 | from sppl.sets import Interval
14 | from sppl.sets import Range
15 | from sppl.sets import inf as oo
16 | from sppl.spe import DiscreteLeaf
17 | from sppl.spe import SumSPE
18 | from sppl.transforms import Id
19 |
def test_poisson():
    """Probability, conditioning, sampling, and constraining on a Poisson
    leaf."""
    X = Id('X')
    spe = X >> poisson(mu=5)

    # Three equivalent ways of computing P(1 <= X <= 7).
    a = spe.logprob((1 <= X) <= 7)
    b = spe.logprob(X << {1, 2, 3, 4, 5, 6, 7})
    c = logsumexp([spe.logprob(X << {i}) for i in range(1, 8)])
    assert allclose(a, b)
    assert allclose(a, c)
    assert allclose(b, c)

    spe_condition = spe.condition(10 <= X)
    assert spe_condition.conditioned
    assert spe_condition.support == Range(10, oo)
    # Use a tolerance: logZ and this reference may be computed by different
    # floating-point routes.
    assert allclose(spe_condition.logZ, logdiffexp(0, spe.logprob(X <= 9)))

    assert allclose(
        spe_condition.logprob(X <= 10),
        spe_condition.logprob(X << {10}))
    assert allclose(
        spe_condition.logprob(X <= 10),
        spe_condition.logpdf({X: 10}))

    samples = spe_condition.sample(100)
    assert all(10 <= s[X] for s in samples)

    # Unify X = 5 (from 3*X + 1 = 16) with the left interval to make one
    # distribution.
    event = ((1 <= X) < 5) | ((3*X + 1) << {16})
    spe_condition = spe.condition(event)
    assert isinstance(spe_condition, DiscreteLeaf)
    assert spe_condition.conditioned
    assert spe_condition.xl == 1
    assert spe_condition.xu == 5
    assert spe_condition.support == Range(1, 5)
    samples = spe_condition.sample(100, prng=numpy.random.RandomState(1))
    assert all(event.evaluate(s) for s in samples)

    # Ignore X = 14/3 as a probability zero condition.
    spe_condition = spe.condition(((1 <= X) < 5) | ((3*X + 1) << {15}))
    assert isinstance(spe_condition, DiscreteLeaf)
    assert spe_condition.conditioned
    assert spe_condition.xl == 1
    assert spe_condition.xu == 4
    assert spe_condition.support == Interval.Ropen(1, 5)

    # Make a mixture of two components (X in [1, 4] and X = 7).
    spe_condition = spe.condition(((1 <= X) < 5) | ((3*X + 1) << {22}))
    assert isinstance(spe_condition, SumSPE)
    xl = spe_condition.children[0].xl
    idx0 = 0 if xl == 7 else 1
    idx1 = 1 if xl == 7 else 0
    assert spe_condition.children[idx1].conditioned
    assert spe_condition.children[idx1].xl == 1
    assert spe_condition.children[idx1].xu == 4
    assert spe_condition.children[idx0].conditioned
    assert spe_condition.children[idx0].xl == 7
    assert spe_condition.children[idx0].xu == 7
    assert spe_condition.children[idx0].support == Range(7, 7)

    # Condition on probability zero event.
    with pytest.raises(ValueError):
        spe.condition(((-3 <= X) < 0) | ((3*X + 1) << {20}))

    # Condition on FiniteReal contiguous.
    spe_condition = spe.condition(X << {1, 2, 3})
    assert spe_condition.xl == 1
    assert spe_condition.xu == 3
    assert allclose(spe_condition.logprob((1 <= X) <= 3), 0)

    # Condition on single point.
    assert allclose(0, spe.condition(X << {2}).logprob(X << {2}))

    # Constrain.
    with pytest.raises(Exception):
        spe.constrain({X: -1})
    with pytest.raises(Exception):
        spe.constrain({X: .5})
    spe_constrain = spe.constrain({X: 10})
    assert allclose(spe_constrain.prob(X << {0, 10}), 1)
99 |
def test_condition_non_contiguous():
    """Conditioning a Poisson leaf on non-contiguous sets gives a SumSPE
    whose children cover the contiguous pieces."""
    X = Id('X')
    spe = X >> poisson(mu=5)
    # FiniteSet.
    for values in ({0, 2, 3}, {-1, 0, 2, 3}, {-1, 0, 2, 3, 'z'}):
        conditioned = spe.condition(X << values)
        assert isinstance(conditioned, SumSPE)
        assert allclose(0, conditioned.children[0].logprob(X << {0}))
        assert allclose(0, conditioned.children[1].logprob(X << {2, 3}))
    # FiniteSet or Interval.
    conditioned = spe.condition((X << {-1, 'x', 0, 2, 3}) | (X > 7))
    assert isinstance(conditioned, SumSPE)
    assert len(conditioned.children) == 3
    pieces = [X << {0}, X << {2, 3}, X > 7]
    for index, piece in enumerate(pieces):
        assert allclose(0, conditioned.children[index].logprob(piece))
116 |
def test_randint():
    """Conditioning a discrete-uniform leaf on a complement event splits it
    into contiguous pieces."""
    X = Id('X')
    spe = X >> randint(low=0, high=5)
    assert spe.xl == 0
    assert spe.xu == 4
    assert spe.logprob(X < 5) == spe.logprob(X <= 4) == 0
    # (X+1) not in {1, 4}, i.e., X is not in {0, 3}. The remaining support
    # {1, 2, 4} splits into the pieces {1, 2} and {4}.
    spe_condition = spe.condition(~((X+1) << {1, 4}))
    assert isinstance(spe_condition, SumSPE)
    # Child order is not fixed, so locate each piece by its lower bound.
    xl = spe_condition.children[0].xl
    idx0 = 0 if xl == 1 else 1
    idx1 = 1 if xl == 1 else 0
    assert spe_condition.children[idx0].xl == 1
    assert spe_condition.children[idx0].xu == 2
    assert spe_condition.children[idx1].xl == 4
    assert spe_condition.children[idx1].xu == 4
    assert allclose(spe_condition.children[idx0].logprob(X << {1, 2}), 0)
    assert allclose(spe_condition.children[idx1].logprob(X << {4}), 0)
135 |
--------------------------------------------------------------------------------
/tests/test_render.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import os
5 |
6 | import pytest
7 |
8 | from sppl.compilers.ast_to_spe import Id
9 | from sppl.compilers.ast_to_spe import IfElse
10 | from sppl.compilers.ast_to_spe import Otherwise
11 | from sppl.compilers.ast_to_spe import Sample
12 | from sppl.compilers.ast_to_spe import Sequence
13 | from sppl.compilers.ast_to_spe import Transform
14 | from sppl.distributions import bernoulli
15 | from sppl.distributions import choice
16 | from sppl.render import render_nested_lists
17 | from sppl.render import render_nested_lists_concise
18 |
def get_model():
    """Build and interpret the small SPPL program shared by render tests.

    Y is drawn from choice() over the strings '0'..'4' with weight .2 each
    and Z ~ bernoulli(p=0.1). If Y is '0' or Z is 0, X is sampled from
    bernoulli(p=1); otherwise X is the transform Z**2 + Z.
    """
    Y, X, Z = Id('Y'), Id('X'), Id('Z')
    y_weights = {'0': .2, '1': .2, '2': .2, '3': .2, '4': .2}
    branch_condition = (Y << {'0'}) | (Z << {0})
    program = Sequence(
        Sample(Y, choice(y_weights)),
        Sample(Z, bernoulli(p=0.1)),
        IfElse(
            branch_condition, Sample(X, bernoulli(p=1.0)),
            Otherwise, Transform(X, Z**2 + Z)))
    return program.interpret()
30 |
def test_render_lists_crash():
    """Smoke test: both nested-list renderers run on the model without raising."""
    model = get_model()
    for renderer in (render_nested_lists_concise, render_nested_lists):
        renderer(model)
35 |
def test_render_graphviz_crash__magics_():
    """Smoke-test the optional graphviz/networkx renderers in sppl.magics.

    Skipped unless graphviz, pygraphviz and networkx are importable. For a
    named output path and an explicit extension, ``<path>.dot`` and
    ``<path>.<ext>`` must exist after rendering; they are then deleted.
    """
    import tempfile

    pytest.importorskip('graphviz')
    pytest.importorskip('pygraphviz')
    pytest.importorskip('networkx')

    from sppl.magics.render import render_networkx_graph
    from sppl.magics.render import render_graphviz

    model = get_model()
    render_networkx_graph(model)
    # Use the platform temp directory instead of a hard-coded /tmp.
    named_path = os.path.join(tempfile.gettempdir(), 'spe.test.render')
    for fname in [None, named_path]:
        for e in ['pdf', 'png', None]:
            render_graphviz(model, fname, ext=e)
            if fname is None:
                continue
            assert not os.path.exists(fname)
            if e is None:
                # NOTE(review): outputs for ext=None are neither checked
                # nor cleaned up here.
                continue
            for ext in ['dot', e]:
                f = '%s.%s' % (fname, ext,)
                # Must be asserted: a bare os.path.exists() checks nothing.
                assert os.path.exists(f), f
                os.unlink(f)
56 |
--------------------------------------------------------------------------------
/tests/test_sets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.sets import EmptySet
7 | from sppl.sets import FiniteNominal as FN
8 | from sppl.sets import FiniteReal as FR
9 | from sppl.sets import Interval
10 | from sppl.sets import Union
11 | from sppl.sets import inf
12 | from sppl.sets import union_intervals
13 | from sppl.sets import union_intervals_finite
14 |
def test_FiniteNominal_in():
    """Membership in a FiniteNominal set and in its complement (b=True)."""
    # An FN with neither values nor the complement flag is rejected.
    with pytest.raises(Exception):
        FN()
    only_a = FN('a')
    all_but_a = FN('a', b=True)
    assert 'a' in only_a
    assert 'b' not in only_a
    assert 'b' in all_but_a
    assert 'a' not in all_but_a
    # FN(b=True) excludes nothing.
    assert 'a' in FN(b=True)
23 |
def test_FiniteNominal_invert():
    """Inverting an FN toggles its b flag; inverting FN(b=True) gives EmptySet."""
    pairs = [
        (FN('a'), FN('a', b=True)),
        (FN('a', b=True), FN('a')),
    ]
    for value, complement in pairs:
        assert ~(value) == complement
    assert ~FN(b=True) == EmptySet
28 |
def test_FiniteNominal_and():
    """Intersect FiniteNominal sets with nominal, real, and empty sets.

    ``FN(..., b=True)`` is the complement of the listed values (see
    test_FiniteNominal_invert).
    """
    # Intersections expected to equal a specific nominal set.
    equal_cases = [
        (FN('a', 'b', 'c'), FN('a'), FN('a')),
        (FN('a', 'b', 'c'), FN(b=True), FN('a', 'b', 'c')),
        (FN('a', 'b', 'c', b=True), FN('d', 'a', 'b'), FN('d')),
        (FN('a', 'b', 'c', b=True), FN('d'), FN('d')),
        (FN('a', 'b', 'c'), FN('a', b=True), FN('b', 'c')),
        (FN('a', 'b', 'c'), FN('d', 'a', 'b', b=True), FN('c')),
        (FN('a', 'b', 'c'), FN('d', b=True), FN('a', 'b', 'c')),
        (FN('a', 'b', 'c', b=True), FN('d', b=True),
            FN('a', 'b', 'c', 'd', b=True)),
        (FN('a', 'b', 'c', b=True), FN('a', b=True),
            FN('a', 'b', 'c', b=True)),
        (FN(b=True), FN(b=True), FN(b=True)),
    ]
    for lhs, rhs, expected in equal_cases:
        assert lhs & rhs == expected, (lhs, rhs)

    # Intersections expected to collapse to the EmptySet singleton.
    empty_cases = [
        (FN('a', 'b'), EmptySet),
        (FN('a', 'b'), FN('c')),
        (FN('a', 'b', 'c', b=True), FN('a')),
        # FiniteReal
        (FN('a'), FR(1)),
        (FN('a', b=True), FR(1)),
        # Interval
        (FN('a'), Interval(0, 1)),
    ]
    for lhs, rhs in empty_cases:
        assert lhs & rhs is EmptySet, (lhs, rhs)
49 |
def test_FiniteNominal_or():
    """Union of FiniteNominal sets with empty, nominal, and real sets."""
    cases = [
        # EmptySet
        (FN('a', 'b'), EmptySet, FN('a', 'b')),
        # Nominal
        (FN('a', 'b'), FN('c'), FN('a', 'b', 'c')),
        (FN('a', 'b', 'c'), FN('a'), FN('a', 'b', 'c')),
        (FN('a', 'b', 'c'), FN(b=True), FN(b=True)),
        (FN('a', 'b', 'c', b=True), FN('a'), FN('b', 'c', b=True)),
        (FN('a', 'b', 'c', b=True), FN('d', 'a', 'b'), FN('c', b=True)),
        (FN('a', 'b', 'c', b=True), FN('d'), FN('a', 'b', 'c', b=True)),
        (FN('a', 'b', 'c'), FN('a', b=True), FN(b=True)),
        (FN('a', 'b', 'c'), FN('d', 'a', 'b', b=True), FN('d', b=True)),
        (FN('a', 'b', 'c'), FN('d', b=True), FN('d', b=True)),
        (FN('a', 'b', 'c', b=True), FN('d', b=True), FN(b=True)),
        (FN('a', 'b', 'c', b=True), FN('a', b=True), FN('a', b=True)),
        (FN(b=True), FN(b=True), FN(b=True)),
        # FiniteReal
        (FN('a'), FR(1), Union(FN('a'), FR(1))),
        (FN('a', b=True), FR(1), Union(FR(1), FN('a', b=True))),
        # Interval
        (FN('a'), Interval(0, 1), Union(FN('a'), Interval(0, 1))),
    ]
    for lhs, rhs, expected in cases:
        assert lhs | rhs == expected, (lhs, rhs)
71 |
def test_FiniteReal_in():
    """Membership in FiniteReal sets; an empty FiniteReal is rejected."""
    with pytest.raises(Exception):
        FR()
    # Real membership must be exercised on FR (FiniteReal), not FN.
    assert 1 in FR(1, 2)
    assert 2 not in FR(1)
77 |
def test_FiniteReal_invert():
    """The complement of finitely many reals is the real line with those points removed."""
    cases = [
        (FR(1), Union(
            Interval.Ropen(-inf, 1),
            Interval.Lopen(1, inf))),
        (FR(0, -1), Union(
            Interval.Ropen(-inf, -1),
            Interval.open(-1, 0),
            Interval.Lopen(0, inf))),
    ]
    for points, complement in cases:
        assert ~(points) == complement, points
86 |
def test_FiniteReal_and():
    """Intersection of FiniteReal sets with reals, nominals and intervals."""
    equal_cases = [
        (FR(1, 2), FR(2), FR(2)),
        (FR(1, 2), Interval(0, 1), FR(1)),
        (FR(0, 2), Interval(0, 1), FR(0)),
        (FR(-1, 1), Interval(-10, 10), FR(-1, 1)),
        (FR(-1, 11), Interval(-10, 10), FR(-1)),
    ]
    for lhs, rhs, expected in equal_cases:
        assert lhs & rhs == expected, (lhs, rhs)

    empty_cases = [
        (FR(1), FR(2)),
        (FR(1, 2), FN('2')),
        (FR(1, 2), FN('2', b=True)),
        (FR(1, 2), Interval.Ropen(0, 1)),
        (FR(0, 2), Interval.Lopen(0, 1)),
    ]
    for lhs, rhs in empty_cases:
        assert lhs & rhs is EmptySet, (lhs, rhs)
98 |
def test_FiniteReal_or():
    """Union of FiniteReal sets with reals, nominals and intervals."""
    cases = [
        (FR(1), FR(2), FR(1, 2)),
        (FR(1, 2), FR(2), FR(1, 2)),
        (FR(1, 2), FN('2'), Union(FR(1, 2), FN('2'))),
        (FR(1, 2), Interval(0, 1), Union(Interval(0, 1), FR(2))),
        (FR(1, 2), Interval.Ropen(0, 1), Union(Interval(0, 1), FR(2))),
        (FR(0, 1, 2), Interval.open(0, 1), Union(Interval(0, 1), FR(2))),
        (FR(0, 2), Interval.Lopen(0, 1), Union(Interval(0, 1), FR(2))),
        (FR(0, 2), Interval.Lopen(2.5, 10),
            Union(Interval.Lopen(2.5, 10), FR(0, 2))),
        (FR(-1, 1), Interval(-10, 10), Interval(-10, 10)),
        (FR(-1, 11), Interval(-10, 10), Union(Interval(-10, 10), FR(11))),
    ]
    for lhs, rhs, expected in cases:
        assert lhs | rhs == expected, (lhs, rhs)
110 |
def test_Interval_in():
    """Membership respects open/closed endpoints; +/-inf are never members."""
    # An interval with its lower bound above its upper bound is rejected.
    with pytest.raises(Exception):
        Interval(3, 1)
    members = [
        (1, Interval(0, 1)),
        (1, Interval.Lopen(0, 1)),
        (0, Interval(0, 1)),
        (0, Interval.Ropen(0, 1)),
        (10, Interval(-inf, inf)),
    ]
    non_members = [
        (1, Interval.Ropen(0, 1)),
        (0, Interval.Lopen(0, 1)),
        (inf, Interval(-inf, inf)),
        (-inf, Interval(-inf, 0)),
    ]
    for point, interval in members:
        assert point in interval, (point, interval)
    for point, interval in non_members:
        assert point not in interval, (point, interval)
123 |
def test_Interval_invert():
    """Complements of bounded and half-unbounded intervals; the full line inverts to EmptySet."""
    cases = [
        (Interval(0, 1),
            Union(Interval.Ropen(-inf, 0), Interval.Lopen(1, inf))),
        (Interval.open(0, 1), Union(Interval(-inf, 0), Interval(1, inf))),
        (Interval.Lopen(0, 1),
            Union(Interval(-inf, 0), Interval.Lopen(1, inf))),
        (Interval.Ropen(0, 1),
            Union(Interval.Ropen(-inf, 0), Interval(1, inf))),
        (Interval(3, inf), Interval.Ropen(-inf, 3)),
        (Interval.open(3, inf), Interval(-inf, 3)),
        (Interval.Lopen(3, inf), Interval(-inf, 3)),
        (Interval.Ropen(3, inf), Interval.Ropen(-inf, 3)),
        (Interval(-inf, 3), Interval.Lopen(3, inf)),
        (Interval.open(-inf, 3), Interval(3, inf)),
        (Interval.Lopen(-inf, 3), Interval.Lopen(3, inf)),
        (Interval.Ropen(-inf, 3), Interval(3, inf)),
    ]
    for interval, complement in cases:
        assert ~(interval) == complement, interval

    for whole_line in (Interval(-inf, inf), Interval.open(-inf, inf)):
        assert ~(whole_line) is EmptySet
139 |
def test_Interval_and():
    """Intersection of intervals across every combination of open/closed ends."""
    equal_cases = [
        (Interval(0, 1), Interval(-1, 1), Interval(0, 1)),
        (Interval(0, 2), Interval.open(0, 1), Interval.open(0, 1)),
        (Interval.Lopen(0, 1), Interval(-1, 1), Interval.Lopen(0, 1)),
        (Interval.Ropen(0, 1), Interval(-1, 1), Interval.Ropen(0, 1)),
        # Closed endpoints that touch meet in a single point.
        (Interval(0, 1), Interval(1, 2), FR(1)),
        (Interval.Lopen(0, 1), Interval(1, 2), FR(1)),
        (Interval(1, 2), Interval.Lopen(0, 1), FR(1)),
        (Interval(0, 2), Interval.Lopen(0.5, 2.5), Interval.Lopen(0.5, 2)),
        (Interval.Ropen(0, 2), Interval.Lopen(0.5, 2.5),
            Interval.open(0.5, 2)),
        (Interval.open(0, 2), Interval(0.5, 2.5), Interval.Ropen(0.5, 2)),
        (Interval.Lopen(0, 2), Interval.Ropen(0, 2), Interval.open(0, 2)),
        (Interval.open(0, 1), Interval.open(0, 1), Interval.open(0, 1)),
        (Interval.Ropen(-inf, -3), Interval(-inf, inf),
            Interval.Ropen(-inf, -3)),
        (Interval(-inf, inf), Interval.Ropen(-inf, -3),
            Interval.Ropen(-inf, -3)),
        (Interval(0, inf), (Interval.Lopen(-5, inf)), Interval(0, inf)),
        (Interval.Lopen(0, 1), Interval.Ropen(0, 1), Interval.open(0, 1)),
        (Interval.Ropen(0, 1), Interval.Lopen(0, 1), Interval.open(0, 1)),
        (Interval.Ropen(0, 5), Interval.Ropen(-inf, 5), Interval.Ropen(0, 5)),
    ]
    for lhs, rhs, expected in equal_cases:
        assert lhs & rhs == expected, (lhs, rhs)

    empty_cases = [
        (Interval.Ropen(0, 1), Interval(1, 2)),
        (Interval.Lopen(0, 1), Interval.Lopen(1, 2)),
        (Interval(1, 2), Interval.open(0, 1)),
        (Interval(0, 1), Interval(2, 3)),
        (Interval(2, 3), Interval(0, 1)),
    ]
    for lhs, rhs in empty_cases:
        assert lhs & rhs is EmptySet, (lhs, rhs)
164 |
def test_Interval_or():
    """Union of intervals with each other and with finite real sets.

    Pieces that overlap, or that touch at a point covered by either side,
    merge into one Interval. Otherwise the result is a Union. Each case
    appears once; earlier copies of this test repeated five of them.
    """
    cases = [
        (Interval(0, 1), Interval(-1, 1), Interval(-1, 1)),
        (Interval(0, 2), Interval.open(0, 1), Interval(0, 2)),
        (Interval.Lopen(0, 1), Interval(-1, 1), Interval(-1, 1)),
        (Interval.Ropen(0, 1), Interval(-1, 1), Interval(-1, 1)),
        (Interval(0, 1), Interval(1, 2), Interval(0, 2)),
        (Interval.open(0, 1), Interval(0, 1), Interval(0, 1)),
        (Interval(0, 1), Interval(0, .5), Interval(0, 1)),
        (Interval.Lopen(0, 1), Interval(1, 2), Interval.Lopen(0, 2)),
        (Interval.Ropen(-1, 0), Interval.Ropen(0, 1), Interval.Ropen(-1, 1)),
        (Interval.Ropen(0, 1), Interval.Ropen(-1, 0), Interval.Ropen(-1, 1)),
        (Interval.Lopen(0, 1), Interval.Lopen(1, 2), Interval.Lopen(0, 2)),
        (Interval.Lopen(0, 1), Interval.Ropen(1, 2), Interval.open(0, 2)),
        (Interval.open(0, 2), Interval(0, 1), Interval.Ropen(0, 2)),
        (Interval.open(0, 1), Interval.Ropen(-1, 0),
            Union(Interval.open(0, 1), Interval.Ropen(-1, 0))),
        (Interval(1, 2), Interval.Ropen(0, 1), Interval(0, 2)),
        (Interval.Ropen(0, 1), Interval(1, 2), Interval(0, 2)),
        (Interval.open(0, 1), Interval(1, 2), Interval.Lopen(0, 2)),
        (Interval.Ropen(0, 1), Interval.Ropen(1, 2), Interval.Ropen(0, 2)),
        (Interval(1, 2), Interval.open(0, 1), Interval.Lopen(0, 2)),
        (Interval(1, 2), Interval.Lopen(0, 1), Interval.Lopen(0, 2)),
        (Interval(0, 2), Interval.open(0.5, 2.5), Interval.Ropen(0, 2.5)),
        (Interval.open(0, 2), Interval.open(0, 2.5), Interval.open(0, 2.5)),
        (Interval.open(0, 2.5), Interval.open(0, 2), Interval.open(0, 2.5)),
        # The shared endpoint 1 is excluded by both sides: no merge.
        (Interval.open(0, 1), Interval.open(1, 2),
            Union(Interval.open(0, 1), Interval.open(1, 2))),
        (Interval.Ropen(0, 2), Interval.Lopen(0.5, 2.5), Interval(0, 2.5)),
        (Interval.open(0, 2), Interval(0.5, 2.5), Interval.Lopen(0, 2.5)),
        (Interval.Lopen(0, 2), Interval.Ropen(0, 2), Interval(0, 2)),
        (Interval(0, 1), Interval(2, 3),
            Union(Interval(0, 1), Interval(2, 3))),
        (Interval(2, 3), Interval.Ropen(0, 1),
            Union(Interval(2, 3), Interval.Ropen(0, 1))),
        # Finite real sets.
        (Interval(-10, 10), FR(-1, 1), Interval(-10, 10)),
        (Interval(-10, 10), FR(-1, 11), Union(Interval(-10, 10), FR(11))),
        # Unbounded.
        (Interval(-inf, -3, right_open=True), Interval(-inf, inf),
            Interval(-inf, inf)),
    ]
    for lhs, rhs, expected in cases:
        assert lhs | rhs == expected, (lhs, rhs)
203 |
def test_union_intervals():
    """union_intervals merges overlapping or touching intervals into a minimal list."""
    cases = [
        ([Interval(0, 1), Interval(2, 3), Interval(1, 2)],
            [Interval(0, 3)]),
        ([Interval.open(0, 1), Interval(2, 3), Interval(1, 2)],
            [Interval.Lopen(0, 3)]),
        ([Interval.open(0, 1), Interval(2, 3), Interval.Lopen(1, 2)],
            [Interval.open(0, 1), Interval.Lopen(1, 3)]),
        ([Interval.open(0, 1), Interval.Ropen(0, 3), Interval.Lopen(1, 2)],
            [Interval.Ropen(0, 3)]),
        ([Interval.open(-2, -1), Interval.Ropen(0, 3), Interval.Lopen(1, 2)],
            [Interval.open(-2, -1), Interval.Ropen(0, 3)]),
    ]
    for intervals, expected in cases:
        assert union_intervals(intervals) == expected, intervals
230 |
def test_union_intervals_finite():
    """Finite points fill gaps between intervals; leftover points remain a FiniteReal."""
    cases = [
        ([Interval.open(0, 1), Interval(2, 3), Interval.Lopen(1, 2)],
            FR(1),
            [Interval.Lopen(0, 3)]),
        ([Interval.open(0, 1), Interval.open(2, 3), Interval.open(1, 2)],
            FR(1, 3),
            [Interval.open(0, 2), Interval.Lopen(2, 3)]),
        ([Interval.open(0, 1), Interval.open(1, 3), Interval.open(11, 15)],
            FR(1, -11, -19, 3),
            [Interval.Lopen(0, 3), Interval.open(11, 15), FR(-11, -19)]),
    ]
    for intervals, points, expected in cases:
        assert union_intervals_finite(intervals, points) == expected, \
            (intervals, points)
250 |
def test_Union_or():
    """Chained unions of intervals, finite reals and nominals normalize correctly."""
    disjoint = Interval(0, 1) | Interval(5, 6) | Interval(10, 11)
    assert disjoint == Union(Interval(0, 1), Interval(5, 6), Interval(10, 11))

    punctured = Interval.Ropen(0, 1) | Interval.Lopen(1, 2) | Interval(10, 11)
    assert punctured == Union(
        Interval.Ropen(0, 1), Interval.Lopen(1, 2), Interval(10, 11))

    # Adding the missing point 1 fuses [0,1) and (1,2] into [0,2],
    # however the operands are grouped.
    fused = Union(Interval(0, 2), Interval(10, 11))
    left_assoc = (Interval.Ropen(0, 1) | Interval.Lopen(1, 2)
                  | Interval(10, 11) | FR(1))
    assert left_assoc == fused
    grouped = ((Interval.Ropen(0, 1) | Interval.Lopen(1, 2))
               | (Interval(10, 11) | FR(1)))
    assert grouped == fused

    nested = FR(1) | (
        (Interval.Ropen(0, 1) | Interval.Lopen(1, 2) | FR(10, 13))
        | (Interval.Lopen(10, 11) | FR(7)))
    assert nested == Union(Interval(0, 2), Interval(10, 11), FR(13, 7))
    assert 2 in nested
    assert 13 in nested

    mixed = FN('f') | (FR(1) | FN('g', b=True))
    assert mixed == Union(FR(1), FN('g', b=True))
    assert 'w' in mixed
    assert 'g' not in mixed
269 |
def test_Union_and():
    """Intersections where at least one side is a mixed Union of sets."""
    empty_cases = [
        ((Interval(0, 1) | FR(1)), (FN('a'))),
        ((FN('x') | Interval(0, 1) | FR(1)), (FN('a'))),
        ((FN('x') | Interval.open(0, 1) | FR(7)), (FR(3))),
    ]
    for lhs, rhs in empty_cases:
        assert lhs & rhs is EmptySet, (lhs, rhs)

    equal_cases = [
        ((FN('x', b=True) | Interval(0, 1) | FR(1)), (FN('a')), FN('a')),
        ((FN('x') | Interval.open(0, 1) | FR(7)),
            ((FN('x')) | FR(.5) | Interval(.75, 1.2)),
            Union(FR(.5), FN('x'), Interval.Ropen(.75, 1))),
        ((Interval.Lopen(-5, inf)), (Interval(0, inf) | FR(inf)),
            Interval(0, inf)),
        ((FR(1, 2) | Interval.Ropen(-inf, 0)), Interval(0, inf), FR(1, 2)),
        ((FR(1, 12) | Interval(0, 5) | Interval(7, 10)), Interval(4, 12),
            Union(Interval(4, 5), Interval(7, 10), FR(12))),
    ]
    for lhs, rhs, expected in equal_cases:
        assert lhs & rhs == expected, (lhs, rhs)
287 |
--------------------------------------------------------------------------------
/tests/test_spe_to_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import json
5 | import pytest
6 |
7 | from sympy import sqrt
8 |
9 | from sppl.compilers.spe_to_dict import spe_from_dict
10 | from sppl.compilers.spe_to_dict import spe_to_dict
11 | from sppl.distributions import choice
12 | from sppl.distributions import gamma
13 | from sppl.distributions import norm
14 | from sppl.distributions import poisson
15 | from sppl.sets import EmptySet
16 | from sppl.transforms import EventFiniteNominal
17 | from sppl.transforms import Exp
18 | from sppl.transforms import Exponential
19 | from sppl.transforms import Id
20 | from sppl.transforms import Log
21 | from sppl.transforms import Logarithm
22 |
X, Y = Id('X'), Id('Y')

# SPEs that test_serialize_equal round-trips through spe_to_dict / JSON /
# spe_from_dict.
spes = [
    # Leaves: real continuous, real discrete, nominal.
    X >> norm(loc=0, scale=1),
    X >> poisson(mu=7),
    Y >> choice({'a': 0.5, 'b': 0.5}),
    # Product of two leaves.
    (X >> norm(loc=0, scale=1)) & (Y >> gamma(a=1)),
    # Weighted sum over the same variable.
    0.2*(X >> norm(loc=0, scale=1)) | 0.8*(X >> gamma(a=1)),
    # Product after constraining Y to 1.
    ((X >> norm(loc=0, scale=1)) & (Y >> gamma(a=1))).constrain({Y: 1}),
]
@pytest.mark.parametrize('spe', spes)
def test_serialize_equal(spe):
    """spe_to_dict -> JSON text -> spe_from_dict yields an SPE equal to the input."""
    decoded = json.loads(json.dumps(spe_to_dict(spe)))
    restored = spe_from_dict(decoded)
    assert restored == spe
41 |
# Expressions that test_serialize_env binds to Y; together they exercise a
# range of transform and event types in spe_to_dict.
transforms = [
    # Identity, radicals, exponentials and logarithms.
    X,
    X**(1,3),
    Exponential(X, base=3),
    Logarithm(X, base=2),
    2**Log(X),
    1/Exp(X),
    # Absolute value, reciprocal, polynomial, piecewise.
    abs(X),
    1/X,
    2*X + X**3,
    (X/2)*(X<0) + (X**(1,2))*(0<=X),
    # Events over intervals, empty/finite real and nominal sets.
    X < sqrt(3),
    X << [],
    ~(X << []),
    EventFiniteNominal(1/X**(1,10), EmptySet),
    X << {1, 2},
    X << {'a', 'x'},
    ~(X << {'a', '1'}),
    # Boolean combinations of events.
    (X < 3) | (X << {1,2}),
    (X < 3) & (X << {1,2}),
]
@pytest.mark.parametrize('transform', transforms)
def test_serialize_env(transform):
    """An SPE whose env defines Y as `transform` survives a JSON round trip."""
    original = (X >> norm()).transform(Y, transform)
    as_text = json.dumps(spe_to_dict(original))
    restored = spe_from_dict(json.loads(as_text))
    assert restored == original
71 |
--------------------------------------------------------------------------------
/tests/test_spe_to_spml.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from sppl.compilers.spe_to_sppl import render_sppl
5 | from sppl.compilers.sppl_to_python import SPPL_Compiler
6 | from sppl.math_util import allclose
7 | from sppl.tests.test_render import get_model
8 |
def test_render_sppl():
    """Rendering a model to SPPL and recompiling it preserves its probabilities."""
    model = get_model()
    compiler = SPPL_Compiler(render_sppl(model).getvalue())
    namespace = compiler.execute_module()
    (X, Y) = (namespace.X, namespace.Y)
    compiled = namespace.model

    # Each value of Y must match the log-probability of Y == '0', in both
    # the original and the recompiled model.
    for i in range(5):
        event = Y << {str(i)}
        reference = model.logprob(Y << {'0'})
        assert allclose(reference, [model.logprob(event), compiled.logprob(event)])

    for i in range(4):
        event = X << {i}
        assert allclose(model.logprob(event), compiled.logprob(event))
24 |
--------------------------------------------------------------------------------
/tests/test_spe_transform.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.distributions import choice
7 | from sppl.distributions import norm
8 | from sppl.distributions import poisson
9 | from sppl.math_util import allclose
10 | from sppl.transforms import Id
11 |
def test_transform_real_leaf_logprob():
    """Transforms of a real leaf: env bookkeeping and logprob equivalence.

    Querying a transformed symbol must give the same logprob as querying
    the expression it stands for.
    """
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    spe = (X >> norm(loc=0, scale=1))

    # Y is not (yet) a symbol of spe.
    with pytest.raises(AssertionError):
        spe.transform(Z, Y**2)
    # X is already a symbol of spe and cannot be redefined.
    with pytest.raises(AssertionError):
        spe.transform(X, X**2)

    spe = spe.transform(Z, X**2)
    assert spe.env == {X: X, Z: X**2}
    assert spe.get_symbols() == {X, Z}
    assert spe.logprob(Z < 1) == spe.logprob(X**2 < 1)
    assert (spe.logprob((Z < 1) | ((X + 1) < 3))
            == spe.logprob((X**2 < 1) | ((X+1) < 3)))

    spe = spe.transform(Y, 2*Z)
    assert spe.env == {X: X, Z: X**2, Y: 2*Z}
    # Parenthesized continuation instead of a trailing backslash, which
    # previously dangled before a blank line.
    assert (spe.logprob(Y**(1,3) < 10)
            == spe.logprob((2*Z)**(1,3) < 10)
            == spe.logprob((2*(X**2))**(1,3) < 10))

    W = Id('W')
    spe = spe.transform(W, X > 1)
    assert allclose(spe.logprob(W), spe.logprob(X > 1))
39 |
def test_transform_real_leaf_sample():
    """Samples of a shifted Poisson leaf are consistent with Z = X+1 and Y = Z-1."""
    X = Id('X')
    Z = Id('Z')
    Y = Id('Y')
    spe = (X >> poisson(loc=-1, mu=1)).transform(Z, X+1).transform(Y, Z-1)

    draws = spe.sample(100)
    assert any(d[X] == -1 for d in draws)
    assert all(0 <= d[Z] for d in draws)
    assert all(d[Y] == d[X] for d in draws)
    # The lambda's parameter names match the symbol names X, Y, Z.
    assert all(spe.sample_func(lambda X, Y, Z: X-Y+Z == Z, 100))
    subset_draws = spe.sample_subset([X, Y], 100)
    assert all(set(d) == {X, Y} for d in subset_draws)
53 |
def test_transform_sum():
    """Transforming a SumSPE updates the env of every child identically."""
    X, Z, Y = Id('X'), Id('Z'), Id('Y')

    with_nominal = (0.3*(X >> norm(loc=0, scale=1))
                    | 0.7*(X >> choice({'0': 0.4, '1': 0.6})))
    with pytest.raises(Exception):
        # Cannot transform Nominal variate.
        with_nominal.transform(Z, X**2)

    numeric = (0.3*(X >> norm(loc=0, scale=1))
               | 0.7*(X >> poisson(mu=2)))
    numeric = numeric.transform(Z, X**2)
    assert numeric.logprob(Z < 1) == numeric.logprob(X**2 < 1)
    assert numeric.children[0].env == numeric.children[1].env

    numeric = numeric.transform(Y, Z/2)
    first_env = numeric.children[0].env
    second_env = numeric.children[1].env
    assert first_env == second_env == {X: X, Z: X**2, Y: Z/2}
74 |
def test_transform_product():
    """Transforms on a ProductSPE must not mix symbols from different children."""
    X = Id('X')
    Y = Id('Y')
    W = Id('W')
    Z = Id('Z')
    V = Id('V')
    spe \
        = (X >> norm(loc=0, scale=1)) \
        & (Y >> poisson(mu=10))
    with pytest.raises(Exception):
        # Cannot use symbols from different transforms.
        spe.transform(W, (X > 0) | (Y << {'0'}))
    spe = spe.transform(W, (X**2 - 3*X)**(1,10))
    spe = spe.transform(Z, (W > 0) | (X**3 < 1))
    spe = spe.transform(V, Y/10)
    assert allclose(
        spe.logprob(W>1),
        spe.logprob((X**2 - 3*X)**(1,10) > 1))
    with pytest.raises(Exception):
        # V derives from Y and W from X, which live in different children.
        # (This used to call the misspelled `tarnsform`, so the
        # AttributeError made the check pass without testing anything.)
        spe.transform(Id('R'), (V>1) | (W < 0))
95 |
--------------------------------------------------------------------------------
/tests/test_sppl_to_python.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.compilers.sppl_to_python import SPPL_Compiler
7 |
def isclose(a, b):
    """Return True when a and b differ by less than the absolute tolerance 1e-10."""
    return abs(a-b) < 1e-10
9 |
10 | # Test errors in visit_Assign:
11 | # - [x] overwrite variable / array
12 | # - [x] unknown sample target array
13 | # - [x] assigning fresh variables in else
14 | # - [x] non-array variable in for
15 | # - [x] array target must not be subscript
16 | # - [x] assign invalid Python after sampling
17 | # - [x] assign invalid Python non-global
18 | # - [x] invalid left-hand side (tuple, unpacking, etc)
19 |
# Programs that re-assign a name that is already bound.
overwrite_cases = [
    # Sampled variable re-sampled.
    'X ~= norm(); X ~= bernoulli(p=.5)',
    # Array overwritten by a sample.
    'X = array(5); X ~= bernoulli(p=.5)',
    # Sampled variable overwritten by an array.
    'X ~= bernoulli(p=.5); X = array(5)',
    # Sampled variable re-sampled inside both branches.
    '''
X ~= norm()
if (X > 0):
    X ~= norm()
else:
    X ~= norm()
''',
    # Array re-declared.
    'X = array(5); W ~= norm(); X = array(10)',
]
@pytest.mark.parametrize('case', overwrite_cases)
def test_error_assign_overwrite_variable(case):
    """Each overwrite case is rejected when the compiler is constructed."""
    with pytest.raises(AssertionError):
        SPPL_Compiler(case)
37 |
def test_error_assign_unknown_array():
    """Sampling into Y[0] is rejected unless Y was declared with array()."""
    program = '''
X ~= uniform(loc=0, scale=1);
Y[0] ~= bernoulli(p=.1)
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    with_declaration = 'Y = array(10);\n%s' % (program,)
    SPPL_Compiler(with_declaration)
46 |
def test_error_assign_fresh_variable_else():
    """The else branch may not bind a variable that the if branch does not."""
    program = '''
X ~= norm()
if (X > 0):
    Y ~= 2*X
else:
    W ~= 2*X
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    # Binding the same variable in both branches is accepted.
    same_target = program.replace('W', 'Y')
    SPPL_Compiler(same_target)
58 |
def test_error_assign_non_array_in_for():
    """Inside a for loop only array elements may be assigned."""
    program = '''
Y ~= norm();
for i in range(10):
    X ~= norm()
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    indexed = program.replace('X', 'X[i]')
    SPPL_Compiler('X = array(5);\n%s' % (indexed,))
69 |
def test_error_assign_array_subscript():
    """An array element may hold a distribution but not a nested array."""
    program = '''
Y = array(5)
Y[0] = array(2)
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    element_sample = program.replace('array(2)', 'norm()')
    SPPL_Compiler(element_sample)
78 |
def test_error_assign_py_constant_after_sampling():
    """A plain Python constant may not be assigned after a sampling statement."""
    program = '''
X ~= norm()
Y = "foo"
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
86 |
def test_error_assign_py_constant_non_global():
    """A Python constant may not be assigned inside an if/else branch."""
    program = '''
X ~= norm()
if (X > 0):
    Y = "ali"
else:
    Y = norm()
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
97 |
def test_error_assign_assign_invalid_lhs():
    """List and tuple targets cannot receive a sample."""
    for program in ('[Y] ~= norm()', '(X, Y) ~= norm()'):
        with pytest.raises(AssertionError):
            SPPL_Compiler(program)
103 |
104 | # Test errors in visit_For:
105 | # - [x] func is not range.
106 | # - [x] func range has > 2 arguments
107 | # - [x] more than one iteration variable
108 |
def test_error_for_func_not_range():
    """A for loop must iterate over range(); any other callable is rejected."""
    program = '''
X = array(10)
for i in xrange(10):
    X[i] ~= norm()
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    with_range = program.replace('xrange', 'range')
    SPPL_Compiler(with_range)
118 |
def test_error_for_range_step():
    """range() in a for loop accepts at most two arguments (no step)."""
    program = '''
X = array(10)
X[0] ~= norm()
for i in range(1, 10, 2):
    X[i] ~= 2*X[i-1]
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    without_step = program.replace(', 2)' , ')')
    SPPL_Compiler(without_step)
129 |
def test_error_for_range_multiple_vars():
    """A for loop binds exactly one iteration variable."""
    program = '''
X = array(10)
X[0] ~= norm()
for (i, j) in range(1, 9):
    X[i] ~= 2*X[i-1]
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    single_var = program.replace('(i, j)' , 'i')
    SPPL_Compiler(single_var)
140 |
141 | # Test errors in visit_If:
142 | # - [x] if without matching else/elif
def test_error_if_no_else():
    """An if statement must be followed by an else or elif branch."""
    program = '''
X ~= norm()
if (X > 0):
    Y ~= 2*X + 1
'''
    with pytest.raises(AssertionError):
        SPPL_Compiler(program)
    for branch in ('\nelse:\n Y~= norm()', '\nelif (X<0):\n Y~= norm()'):
        SPPL_Compiler('%s%s' % (program, branch))
153 |
# Test SPPL_Transformer: each SPPL expression (left) is rewritten to the
# event expression on the right.
#   Q = Z in {0, 1}        ->  Q = Z << {0, 1}
#   Q = Z not in {0, 1}    ->  Q = ~(Z << {0, 1})
#   Q = Z == 'foo'         ->  Q = Z << {'foo'}
#   Q = Z != 'foo'         ->  Q = ~(Z << {'foo'})
159 |
def test_transform_in():
    """`X in {...}` is rendered as the event `X << {...}`."""
    program = '''
X ~= choice({'foo': .5, 'bar': .1, 'baz': .4})
Y = X in {'foo', 'baz'}
'''
    rendered = SPPL_Compiler(program).render_module()
    assert "X << {'foo', 'baz'}" in rendered
    assert 'X in' not in rendered
169 |
def test_transform_in_not():
    """`X not in {...}` is rendered as the negated event `~ (X << {...})`."""
    program = '''
X ~= choice({'foo': .5, 'bar': .1, 'baz': .4})
Y = X not in {'foo', 'baz'}
'''
    rendered = SPPL_Compiler(program).render_module()
    assert "~ (X << {'foo', 'baz'})" in rendered
    assert 'not in' not in rendered
179 |
def test_transform_eq():
    """`X == 'foo'` is rendered as the event `X << {'foo'}`."""
    program = '''
X ~= choice({'foo': .5, 'bar': .1, 'baz': .4})
Y = X == 'foo'
'''
    rendered = SPPL_Compiler(program).render_module()
    assert "X << {'foo'}" in rendered
    assert '==' not in rendered
189 |
def test_transform_eq_not():
    """`X != 'foo'` is rendered as the negated event `~ (X << {'foo'})`."""
    program = '''
X ~= choice({'foo': .5, 'bar': .1, 'baz': .4})
Y = X != 'foo'
'''
    rendered = SPPL_Compiler(program).render_module()
    assert "~ (X << {'foo'})" in rendered
    assert '!=' not in rendered
199 |
def test_compile_all_constructs():
    """Compile and execute a program that uses arrays, loops, if/elif/else,
    nested ifs, comments and blank/whitespace-only lines, then check the
    resulting marginals of X[5] and E.
    """
    # The line after "# Here is a comment ..." must contain only whitespace.
    # It is written with \x20 escapes so that editors which strip trailing
    # whitespace cannot silently remove what this test exercises.
    source = '''
X = array(10)
W = array(10)
Y = randint(low=1, high=2)
Z = bernoulli(p=0.1)

E = choice({'1': 0.3, '2': 0.7})


for i in range(1,5):
    W[i] = uniform(loc=0, scale=2)
    X[i] = bernoulli(p=0.5)

X[0] ~= gamma(a=1)
H = (X[0]**2 + 2*X[0] + 3)**(1, 10)

# Here is a comment, with indentation on next line
\x20\x20\x20\x20


X[5] ~= 0.3*atomic(loc=0) | 0.4*atomic(loc=-1) | 0.3*atomic(loc=3)
if X[5] == 0:
    X[7] = bernoulli(p=0.1)
    X[8] = 1 + X[3]
elif (X[5]**2 == 1):
    X[7] = bernoulli(p=0.2)
    X[8] = 1 + X[3]
else:
    if (X[3] in {0, 1}):
        X[7] = bernoulli(p=0.2)
        X[8] = 1 + X[3]
    else:
        X[7] = bernoulli(p=0.2)
        X[8] = 1 + X[3]
'''
    compiler = SPPL_Compiler(source)
    namespace = compiler.execute_module()
    model = namespace.model
    expected = [
        (namespace.X[5] << {0}, .3),
        (namespace.X[5] << {-1}, .4),
        (namespace.X[5] << {3}, .3),
        (namespace.E << {'1'}, .3),
        (namespace.E << {'2'}, .7),
    ]
    for event, prob in expected:
        assert isclose(model.prob(event), prob)
244 |
def test_imports():
    """Names used inside an SPPL program must be imported by that program."""
    body = '''
Y ~= bernoulli(p=.5)
Z ~= choice({str(i): Fraction(1, 5) for i in range(5)})
X = array(5)
for i in range(5):
    X[i] ~= Fraction(1,2) * Y
'''
    # Compiling succeeds; executing fails because Fraction is undefined.
    missing_import = SPPL_Compiler(body)
    with pytest.raises(NameError):
        missing_import.execute_module()

    with_import = SPPL_Compiler('from fractions import Fraction\n%s' % (body,))
    namespace = with_import.execute_module()
    for label in (str(i) for i in range(5)):
        assert isclose(namespace.model.prob(namespace.Z << {label}), .2)
260 |
def test_ifexp():
    """A conditional expression on the right of `~=` compiles to an IfElse."""
    program = '''
from fractions import Fraction
Y ~= choice({str(i): Fraction(1, 4) for i in range(4)})
Z ~= (
    atomic(loc=0) if (Y in {'0', '1'}) else
    atomic(loc=4) if (Y == '2') else
    atomic(loc=6))
'''
    compiler = SPPL_Compiler(program)
    assert 'IfElse' in compiler.render_module()
    namespace = compiler.execute_module()
    for value, prob in ((0, .5), (4, .25), (6, .25)):
        assert isclose(namespace.model.prob(namespace.Z << {value}), prob)
276 |
def test_switch_shallow():
    """A single-level switch over Y's values propagates Y's probabilities to Z."""
    program = '''
Y ~= choice({'0': .25, '1': .5, '2': .25})

switch (Y) cases (i in ['0', '1', '2']):
    Z ~= atomic(loc=int(i))
'''
    namespace = SPPL_Compiler(program).execute_module()
    for value, prob in ((0, .25), (1, .5), (2, .25)):
        assert isclose(namespace.model.prob(namespace.Z << {value}), prob)
289 |
def test_switch_enumerate():
    """A switch may iterate over enumerate(), binding index and value."""
    program = '''
Y ~= choice({'0': .25, '1': .5, '2': .25})

switch (Y) cases (i,j in enumerate(['0', '1', '2'])):
    Z ~= atomic(loc=i+int(j))
'''
    namespace = SPPL_Compiler(program).execute_module()
    # Case k (value str(k)) places Z at k + k = 2k.
    for value, prob in ((0, .25), (2, .5), (4, .25)):
        assert isclose(namespace.model.prob(namespace.Z << {value}), prob)
302 |
def test_switch_nested():
    """A switch nested inside another switch body compiles and runs."""
    program = '''
Y ~= randint(low=0, high=4)
W ~= randint(low=0, high=2)

switch (Y) cases (i in range(0, 5)):
    Z ~= choice({str(i): 1})
    switch (W) cases (i in range(0, 2)): V ~= atomic(loc=i)
'''
    namespace = SPPL_Compiler(program).execute_module()
    model, Z, V = namespace.model, namespace.Z, namespace.V
    for label in ('0', '1', '2', '3'):
        assert isclose(model.prob(Z << {label}), .25)
    for value in (0, 1):
        assert isclose(model.prob(V << {value}), .5)
320 |
def test_condition_simple():
    """Top-level ``condition`` statements restrict the posterior support."""
    source = '''
Y ~= norm(loc=0, scale=2)
condition((0 < Y) < 2)
Z ~= binom(n=10, p=.2)
condition(Z == 0)
'''
    namespace = SPPL_Compiler(source).execute_module()
    in_window = (0 < namespace.Y) < 2
    # After conditioning, all of Y's mass lies inside the conditioned window.
    assert isclose(namespace.model.prob(in_window), 1)
331 |
def test_condition_if():
    """``condition`` inside if/else branches must be satisfiable in each branch."""
    source = '''
Y ~= norm(loc=0, scale=2)
if (Y > 1):
    condition(Y < 1)
else:
    condition(Y > -1)
'''
    # With the guard Y > 1, the true branch conditions on the contradictory
    # event Y < 1, so compiling and executing the program must fail.
    with pytest.raises(Exception):
        SPPL_Compiler(source).execute_module()
    # With the guard Y > 0, both branches are satisfiable and the posterior
    # concentrates on (-1, 1).
    satisfiable = source.replace('Y > 1', 'Y > 0')
    namespace = SPPL_Compiler(satisfiable).execute_module()
    assert isclose(namespace.model.prob((-1 < namespace.Y) < 1), 1)
346 |
def test_constrain_simple():
    """``constrain`` pins variables to exact observed values."""
    source = '''
Y ~= norm(loc=0, scale=2)
Z ~= binom(n=10, p=.2)
constrain({Z: 10, Y: 0})
'''
    namespace = SPPL_Compiler(source).execute_module()
    model, Y, Z = namespace.model, namespace.Y, namespace.Z
    # Y is now a point mass at 0: an open lower bound at 0 excludes it,
    # and a closed lower bound includes it.
    assert isclose(model.prob((0 < Y) < 2), 0)
    assert isclose(model.prob((0 <= Y) < 2), 1)
    assert isclose(model.prob(Z << {10}), 1)
358 |
def test_constrain_if():
    """``constrain`` inside if/else keeps the prior weight of each branch."""
    source = '''
Y ~= norm(loc=0, scale=2)
if (Y > 1):
    constrain({Y: 2})
else:
    constrain({Y: 0})
'''
    namespace = SPPL_Compiler(source).execute_module()
    model, Y = namespace.model, namespace.Y
    # Prior probability that a Normal(loc=0, scale=2) draw exceeds 1,
    # i.e. the weight of the true branch (where Y is pinned to 2).
    p_true = 0.3085375387259868
    assert isclose(model.prob(Y > 0), p_true)
    assert isclose(model.prob(Y < 0), 0)
    assert isclose(model.prob(Y << {0}), 1-p_true)
372 |
def test_constant_parameter():
    """Plain Python constants and lists can parameterize SPPL distributions."""
    source = '''
parameters = [1, 2, 3]
n_array = 2
Y = array(n_array + 1)
for i in range(n_array + 1):
    Y[i] = randint(low=parameters[i], high=parameters[i]+1)
'''
    namespace = SPPL_Compiler(source).execute_module()
    # randint(low=k, high=k+1) puts all of its mass on k.
    for index, value in enumerate([1, 2, 3]):
        assert isclose(namespace.model.prob(namespace.Y[index] << {value}), 1)
    # Appending a string-valued assignment is rejected by the compiler.
    with pytest.raises(AssertionError):
        SPPL_Compiler(source + 'Z = "foo"\n')
388 |
def test_error_array_length():
    """``array`` rejects non-integer lengths; a negative length yields an empty array."""
    for bad_program in ('Y = array(1.3)', "Y = array('foo')"):
        with pytest.raises(TypeError):
            SPPL_Compiler(bad_program).execute_module()
    # A negative length produces an empty array instead of raising.
    namespace = SPPL_Compiler('Y = array(-1)').execute_module()
    assert len(namespace.Y) == 0
397 |
--------------------------------------------------------------------------------
/tests/test_substitute.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sppl.transforms import Id
7 | from sppl.transforms import Logarithm
8 |
X = Id('X')
Z = Id('Z')
Y = Id('Y')

# Tuple exponent shared by several cases below.
b = (1, 10)

# Each case is [expression, substitution environment, expected result].
cases = [
    # Basic cases.
    [X, {}, X],
    [1/X, {}, 1/X],
    [X**2 + X, {}, X**2 + X],
    [X, {}, X],
    [X, {Z: X}, X],
    [Z, {Z: 2**X}, 2**X],
    [(Z + 1)**b, {Z: 1/X}, (1/X + 1)**b],
    [(2**Z + 1)**b, {Z: X + 1}, (2**(X + 1) + 1)**b],
    [(2**Logarithm(Z, 2) + 1)**b,
        {Z: X + 1},
        (2**Logarithm(X + 1, 2) + 1)**b],
    [(2**abs(Logarithm(Z, 2)) + 1)**b,
        {Z: X + 1},
        (2**abs(Logarithm(X + 1, 2)) + 1)**b],
    [(2**abs(Logarithm(1/Z, 2)) + 1)**b,
        {Z: X + 1},
        (2**abs(Logarithm(1/(X + 1), 2)) + 1)**b],
    # Compound cases.
    [(Z > 1)*(1/Z) + (Z < 1)*Z**b,
        {Z: X + 1},
        ((X + 1) > 1)*(1/(X + 1)) + ((X + 1) < 1)*(X + 1)**b],
    [(Z << {'a'}) & (Y < 3),
        {Y: 1/X},
        (Z << {'a'}) & (1/X < 3)],
    # Z maps to an expression in Y, which in turn maps to 1/X.
    [((Y > 1) | Z << {'a'}) & (Y < 3),
        {Z: Y**2, Y: 1/X},
        (((1/X) > 1) | ((1/X)**2) << {'a'}) & ((1/X) < 3)],
]
@pytest.mark.parametrize('case', cases)
def test_substitute_basic(case):
    """Substituting ``env`` into ``expr`` yields ``expr_prime``."""
    expr, env, expr_prime = case
    substituted = expr.substitute(env)
    assert substituted == expr_prime
42 |
@pytest.mark.parametrize('case', cases)
def test_substitute_transitive(case):
    """Routing a binding through a fresh intermediate symbol gives the same result.

    Only cases whose environment has exactly one binding are checked; the
    remaining cases return immediately and pass trivially.
    """
    expr, env, expr_prime = case
    if len(env) != 1:
        return
    [(source, target)] = env.items()
    intermediate = Id('s2')
    chained_env = {source: intermediate, intermediate: target}
    assert expr.substitute(chained_env) == expr_prime
51 |
--------------------------------------------------------------------------------
/tests/test_sum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | from fractions import Fraction
5 | from math import log
6 |
7 | import pytest
8 |
9 | import numpy
10 |
11 | from sppl.distributions import choice
12 | from sppl.distributions import gamma
13 | from sppl.distributions import norm
14 | from sppl.math_util import allclose
15 | from sppl.math_util import isinf_neg
16 | from sppl.math_util import logsumexp
17 | from sppl.sets import FiniteNominal
18 | from sppl.sets import Interval
19 | from sppl.sets import inf as oo
20 | from sppl.spe import ContinuousLeaf
21 | from sppl.spe import ExposedSumSPE
22 | from sppl.spe import NominalLeaf
23 | from sppl.spe import ProductSPE
24 | from sppl.spe import SumSPE
25 | from sppl.transforms import Id
26 |
def test_sum_normal_gamma():
    """Mixture of a normal and a gamma over the same variable X."""
    X = Id('X')
    weights = [log(Fraction(2, 3)), log(Fraction(1, 3))]
    spe = SumSPE([X >> norm(loc=0, scale=1), X >> gamma(loc=0, a=1)], weights)

    def mixture_logprob(event):
        # Reference value: log-sum-exp of the weighted component log-probs.
        return logsumexp([
            weight + child.logprob(event)
            for weight, child in zip(spe.weights, spe.children)
        ])

    assert spe.logprob(X > 0) == mixture_logprob(X > 0)
    # Only the normal component has mass below 0, and half of it lies there.
    assert spe.logprob(X < 0) == log(Fraction(2, 3)) + log(Fraction(1, 2))

    samples = spe.sample(100, prng=numpy.random.RandomState(1))
    assert all(s[X] for s in samples)
    spe.sample_func(lambda X: abs(X**3), 100)
    # Y is not a model variable, presumably why sample_func rejects it.
    with pytest.raises(ValueError):
        spe.sample_func(lambda Y: abs(X**3), 100)

    # Conditioning on X < 0 leaves only the normal component as a leaf.
    spe_condition = spe.condition(X < 0)
    assert isinstance(spe_condition, ContinuousLeaf)
    assert spe_condition.conditioned
    assert spe_condition.logprob(X < 0) == 0
    assert all(s[X] < 0 for s in spe_condition.sample(100))

    # The original SPE is left unchanged by conditioning.
    assert spe.logprob(X < 0) == mixture_logprob(X < 0)
58 |
def test_sum_normal_gamma_exposed():
    """ExposedSumSPE: the branch indicator W is itself a queryable variable."""
    X = Id('X')
    W = Id('W')
    weights = W >> choice({
        '0': Fraction(2, 3),
        '1': Fraction(1, 3),
    })
    children = {
        '0': X >> norm(loc=0, scale=1),
        '1': X >> gamma(loc=0, a=1),
    }
    spe = ExposedSumSPE(children, weights)

    assert spe.logprob(W << {'0'}) == log(Fraction(2, 3))
    assert spe.logprob(W << {'1'}) == log(Fraction(1, 3))
    assert allclose(spe.logprob((W << {'0'}) | (W << {'1'})), 0)
    assert spe.logprob((W << {'0'}) & (W << {'1'})) == -float('inf')

    assert allclose(
        spe.logprob((W << {'0', '1'}) & (X < 1)),
        spe.logprob(X < 1))

    assert allclose(
        spe.logprob((W << {'0'}) & (X < 1)),
        spe.weights[0] + spe.children[0].logprob(X < 1))

    # Conditioning on an event that covers both branches keeps a SumSPE whose
    # weights are the prior branch weights. Their order is not guaranteed, so
    # accept either permutation. The previous check tested weights[0] twice
    # and required both weights to equal log(2/3), so it never checked the
    # log(1/3) weight.
    spe_condition = spe.condition((W << {'1'}) | (W << {'0'}))
    assert isinstance(spe_condition, SumSPE)
    assert len(spe_condition.weights) == 2
    w0, w1 = spe_condition.weights
    log_two_thirds = log(Fraction(2, 3))
    log_one_third = log(Fraction(1, 3))
    assert (
        (allclose(w0, log_two_thirds) and allclose(w1, log_one_third))
        or (allclose(w0, log_one_third) and allclose(w1, log_two_thirds))
    )

    # Conditioning on a single branch yields a product of a nominal leaf
    # (for W) and a continuous leaf distributed like that branch's child.
    spe_condition = spe.condition((W << {'1'}))
    assert isinstance(spe_condition, ProductSPE)
    assert isinstance(spe_condition.children[0], NominalLeaf)
    assert isinstance(spe_condition.children[1], ContinuousLeaf)
    assert spe_condition.logprob(X < 5) == spe.children[1].logprob(X < 5)
101 |
102 | def test_sum_normal_nominal():
103 | X = Id('X')
104 | children = [
105 | X >> norm(loc=0, scale=1),
106 | X >> choice({'low': Fraction(3, 10), 'high': Fraction(7, 10)}),
107 | ]
108 | weights = [log(Fraction(4,7)), log(Fraction(3, 7))]
109 |
110 | spe = SumSPE(children, weights)
111 |
112 | assert allclose(
113 | spe.logprob(X < 0),
114 | log(Fraction(4,7)) + log(Fraction(1,2)))
115 |
116 | assert allclose(
117 | spe.logprob(X << {'low'}),
118 | log(Fraction(3,7)) + log(Fraction(3, 10)))
119 |
120 | # The semantics of ~(X<<{'low'}) are (X << String and X != 'low')
121 | assert allclose(
122 | spe.logprob(~(X << {'low'})),
123 | spe.logprob((X << {'high'})))
124 | assert allclose(
125 | spe.logprob((X<> norm(loc=0, scale=1), X >> norm(loc=0, scale=2)],
19 | [log(0.4), log(0.6)]),
20 | X >> gamma(loc=0, a=1),
21 | ]
22 | spe = SumSPE(children, [log(0.7), log(0.3)])
23 | assert spe.size() == 4
24 | assert spe.children == (
25 | children[0].children[0],
26 | children[0].children[1],
27 | children[1]
28 | )
29 | assert allclose(spe.weights[0], log(0.7) + log(0.4))
30 | assert allclose(spe.weights[1], log(0.7) + log(0.6))
31 | assert allclose(spe.weights[2], log(0.3))
32 |
def test_sum_simplify_nested_sum_2():
    """Nested sums of products are flattened into a single sum of six products."""
    X = Id('X')
    W = Id('W')
    children = [
        SumSPE([
            (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)),
            (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1))],
            [log(0.9), log(0.1)]),
        (X >> norm(loc=0, scale=4)) & (W >> norm(loc=0, scale=10)),
        SumSPE([
            (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)),
            (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1)),
            (X >> norm(loc=0, scale=8)) & (W >> norm(loc=0, scale=3)),],
            [log(0.4), log(0.3), log(0.3)]),
    ]
    spe = SumSPE(children, [log(0.4), log(0.4), log(0.2)])
    # Consistent with one root sum over six two-leaf products: 1 + 6 * 3.
    assert spe.size() == 19

    # Every flattened child is a product with 2 leaves.
    expected_children = (
        children[0].children[0],
        children[0].children[1],
        children[1],
        children[2].children[0],
        children[2].children[1],
        children[2].children[2],
    )
    assert spe.children == expected_children

    # Flattened weights are outer weight times inner weight (sums in log space).
    expected_weights = [
        log(0.4) + log(0.9),
        log(0.4) + log(0.1),
        log(0.4),
        log(0.2) + log(0.4),
        log(0.2) + log(0.3),
        log(0.2) + log(0.3),
    ]
    for index, expected in enumerate(expected_weights):
        assert allclose(spe.weights[index], expected)
64 |
def test_sum_simplify_leaf():
    """spe_simplify_sum merges equal leaves and adds their weights."""
    # Three distinct leaves: nothing to merge.
    Xd0, Xd1, Xd2 = [Id('X') >> norm(loc=0, scale=s) for s in (1, 2, 3)]
    spe = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)])
    assert spe.size() == 4
    assert spe_simplify_sum(spe) == spe

    # Three equal leaves collapse to a single leaf.
    Xd0, Xd1, Xd2 = [Id('X') >> norm(loc=0, scale=1) for _ in range(3)]
    spe = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)])
    assert spe_simplify_sum(spe) == Xd0

    # Interleaved duplicates merge in first-occurrence order:
    # Xd0 and Xd1 get 0.5 + 0.3, and the two Xd3 copies get 0.1 + 0.1.
    Xd3 = Id('X') >> norm(loc=0, scale=2)
    spe = SumSPE([Xd0, Xd3, Xd1, Xd3], [log(0.5), log(0.1), log(.3), log(.1)])
    spe_simplified = spe_simplify_sum(spe)
    assert len(spe_simplified.children) == 2
    assert spe_simplified.children[0] == Xd0
    assert spe_simplified.children[1] == Xd3
    assert allclose(spe_simplified.weights[0], log(0.8))
    assert allclose(spe_simplified.weights[1], log(0.2))
87 |
def test_sum_simplify_product_collapse():
    """A sum of products with identical factors collapses to one product."""
    def std_normal(name):
        # Fresh standard-normal leaf on the variable called ``name``.
        return Id(name) >> norm(loc=0, scale=1)

    A1, A0 = std_normal('A'), std_normal('A')
    B, B1, B0 = std_normal('B'), std_normal('B'), std_normal('B')
    C, C1 = std_normal('C'), std_normal('C')
    D = std_normal('D')
    # All three terms have the same joint law, so the weights do not matter.
    spe = SumSPE([
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ], [log(0.4), log(0.4), log(0.2)])
    assert spe_simplify_sum(spe) == ProductSPE([A1, B, C, D])
103 |
def test_sum_simplify_product_complex():
    """Shared factors are pulled out of a sum of products, recursively.

    Expected shape of the simplified SPE:
        Product(
            Sum(Product(Sum(Product(A1, B), Product(A0, B1)), C),
                Product(A0, B0, C1)),
            D)
    """
    def normal(name, scale):
        # Zero-mean normal leaf on the variable called ``name``.
        return Id(name) >> norm(loc=0, scale=scale)

    A1, A0 = normal('A', 1), normal('A', 2)
    B, B1, B0 = normal('B', 1), normal('B', 2), normal('B', 3)
    C, C1 = normal('C', 1), normal('C', 2)
    D = normal('D', 1)
    spe = SumSPE([
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ], [log(0.4), log(0.4), log(0.2)])

    simplified = spe_simplify_sum(spe)
    # D is common to every term and is factored out at the top.
    assert isinstance(simplified, ProductSPE)
    assert isinstance(simplified.children[0], SumSPE)
    assert simplified.children[1] == D

    outer_sum = simplified.children[0]
    # The third term shares nothing further and remains a plain product.
    assert isinstance(outer_sum.children[1], ProductSPE)
    assert outer_sum.children[1].children == (A0, B0, C1)

    # The first two terms share C, which is factored out next.
    assert isinstance(outer_sum.children[0], ProductSPE)
    assert outer_sum.children[0].children[1] == C

    inner_sum = outer_sum.children[0].children[0]
    assert isinstance(inner_sum, SumSPE)
    assert isinstance(inner_sum.children[0], ProductSPE)
    assert isinstance(inner_sum.children[1], ProductSPE)
    assert inner_sum.children[0].children == (A1, B)
    assert inner_sum.children[1].children == (A0, B1)
137 |
--------------------------------------------------------------------------------
/tests/test_sym_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MIT Probabilistic Computing Project.
2 | # See LICENSE.txt
3 |
4 | import pytest
5 |
6 | from sympy import exp as SymExp
7 | from sympy import log as SymLog
8 | from sympy import symbols
9 |
10 | from sppl.sets import FiniteReal
11 | from sppl.sym_util import get_symbols
12 | from sppl.sym_util import partition_finite_real_contiguous
13 | from sppl.sym_util import partition_list_blocks
14 |
# Ten SymPy symbols named X0..X9, shared by the tests below.
X0, X1, X2, X3, X4, X5, X6, X7, X8, X9 = symbols('X:10')
16 |
def test_get_symbols():
    """get_symbols finds every symbol, including those nested in function calls."""
    checks = [
        ((X0 > 3) & (X1 < 4), [X0, X1]),
        ((SymExp(X0) > SymLog(X1)+10) & (X2 < 4), [X0, X1, X2]),
    ]
    for expr, expected in checks:
        syms = get_symbols(expr)
        assert len(syms) == len(expected)
        for sym in expected:
            assert sym in syms
28 |
@pytest.mark.parametrize('a, b', [
    ([0, 1, 2, 3], [[0], [1], [2], [3]]),
    ([0, 1, 2, 1], [[0], [1, 3], [2]]),
    # The string '0' and the integer 2 fall into different blocks.
    (['0', '0', 2, '0'], [[0, 1, 3], [2]]),
])
def test_partition_list_blocks(a, b):
    """Indices of equal items are grouped, blocks in first-occurrence order."""
    assert partition_list_blocks(a) == b
37 |
@pytest.mark.parametrize('a, b', [
    (FiniteReal(0,1,2), [FiniteReal(0,1,2)]),
    (FiniteReal(0,3,1,2), [FiniteReal(0,1,2,3)]),
    (FiniteReal(-1,3,1,2), [FiniteReal(-1), FiniteReal(1,2,3)]),
    (FiniteReal(-1,3,1,2,-2,-7), [FiniteReal(-7), FiniteReal(-1,-2), FiniteReal(1,2,3)]),
    # Adding 0 bridges the two runs on either side of it into one.
    (FiniteReal(-1,3,1,2,-2,-7,0), [FiniteReal(-7), FiniteReal(-2,-1,0,1,2,3)]),
])
def test_parition_finite_real_contiguous(a, b):
    """A finite real set is split into runs of consecutive integers, ascending."""
    assert partition_finite_real_contiguous(a) == b
48 |
--------------------------------------------------------------------------------