├── .gitignore
├── .travis.yml
├── HACKING
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── check.sh
├── conda
├── meta.yaml
└── upload.sh
├── docker
├── README
├── ubuntu1604
└── ubuntu1604-int
├── images
├── linear_regression.png
├── one_view.png
├── smc.gif
└── spherical_gaussian.png
├── pythenv.sh
├── setup.py
├── src
├── __init__.py
├── cgpm.py
├── crosscat
│ ├── README.md
│ ├── __init__.py
│ ├── engine.py
│ ├── loomcat.py
│ ├── lovecat.py
│ ├── sampling.py
│ ├── state.py
│ └── statedoc.py
├── dummy
│ ├── README
│ ├── __init__.py
│ ├── barebones.py
│ ├── fourway.py
│ ├── piecewise.py
│ ├── trollnormal.py
│ └── twoway.py
├── factor
│ ├── README
│ ├── __init__.py
│ └── factor.py
├── kde
│ ├── README
│ ├── __init__.py
│ └── mvkde.py
├── knn
│ ├── README
│ ├── __init__.py
│ └── mvknn.py
├── mixtures
│ ├── README
│ ├── __init__.py
│ ├── dim.py
│ ├── relevance.py
│ └── view.py
├── network
│ ├── README
│ ├── __init__.py
│ ├── helpers.py
│ └── importance.py
├── primitives
│ ├── README
│ ├── __init__.py
│ ├── bernoulli.py
│ ├── beta.py
│ ├── categorical.py
│ ├── crp.py
│ ├── distribution.py
│ ├── exponential.py
│ ├── geometric.py
│ ├── lognormal.py
│ ├── normal.py
│ ├── normal_trunc.py
│ ├── poisson.py
│ └── vonmises.py
├── regressions
│ ├── README
│ ├── __init__.py
│ ├── forest.py
│ ├── linreg.py
│ └── ols.py
├── uncorrelated
│ ├── README
│ ├── __init__.py
│ ├── diamond.py
│ ├── directed.py
│ ├── dots.py
│ ├── linear.py
│ ├── parabola.py
│ ├── ring.py
│ ├── sin.py
│ ├── undirected.py
│ ├── uniformx.py
│ └── xcross.py
├── utils
│ ├── __init__.py
│ ├── config.py
│ ├── data.py
│ ├── entropy_estimators.py
│ ├── general.py
│ ├── mvnormal.py
│ ├── parallel_map.py
│ ├── plots.py
│ ├── render.py
│ ├── sampling.py
│ ├── test.py
│ ├── timer.py
│ └── validation.py
└── venturescript
│ ├── __init__.py
│ ├── vscgpm.py
│ └── vsinline.py
└── tests
├── __init__.py
├── conftest.py
├── disabled_test_loomcat.py
├── disabled_test_render_utils.py
├── disabled_test_simulate_univariate.py
├── disabled_test_uncorrelated_simulate.py
├── disabled_test_vsinline_determinism.py
├── graphical
├── animals.py
├── depprob_id.py
├── dpmm_nignormal.py
├── one_view.py
├── recover.py
├── resources
│ └── satellites.csv
├── slice.py
├── two_views.py
└── zero_corr.py
├── hacks.py
├── markers.py
├── stochastic.py
├── test_add_state.py
├── test_bernoulli.py
├── test_check_env_debug.py
├── test_cmi.py
├── test_cmi_partition.py
├── test_constr_crp.py
├── test_crp.py
├── test_dependence_constraints.py
├── test_dependence_probability.py
├── test_diagnostics.py
├── test_direct_network.py
├── test_engine_alter.py
├── test_engine_dimensions.py
├── test_engine_seed.py
├── test_exponentials_transition_hypers.py
├── test_factor_analysis.py
├── test_force_cell.py
├── test_forest.py
├── test_get_vqe.py
├── test_gpmcc_simple_composite.py
├── test_importance_helpers.py
├── test_impossible_evidence.py
├── test_incorporate_dim.py
├── test_incorporate_row.py
├── test_iter_counter.py
├── test_linreg.py
├── test_linreg_mixture.py
├── test_logpdf_score.py
├── test_logsumexp.py
├── test_lovecat.py
├── test_lw_rf.py
├── test_mvkde.py
├── test_mvknn.py
├── test_normal_categorical.py
├── test_ols.py
├── test_piecewise.py
├── test_populate_evidence.py
├── test_relevance.py
├── test_row_similarity.py
├── test_serialize.py
├── test_simulate_many_decorator.py
├── test_state_initialize.py
├── test_stochastic.py
├── test_teh_murphy.py
├── test_type_check.py
├── test_update_cctype.py
├── test_view_logpdf_cluster.py
├── test_vscgpm.py
└── test_vsinline.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # C extensions
6 | *.so
7 |
8 | # Distribution / packaging
9 | .Python
10 | env/
11 | bin/
12 | build/
13 | develop-eggs/
14 | dist/
15 | eggs/
16 | lib/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 |
25 | # Installer logs
26 | pip-log.txt
27 | pip-delete-this-directory.txt
28 |
29 | # Unit test / coverage reports
30 | htmlcov/
31 | .tox/
32 | .coverage
33 | .cache
34 | nosetests.xml
35 | coverage.xml
36 |
37 | # Translations
38 | *.mo
39 |
40 | # Mr Developer
41 | .mr.developer.cfg
42 | .project
43 | .pydevproject
44 |
45 | # Rope
46 | .ropeproject
47 |
48 | # Django stuff:
49 | *.log
50 | *.pot
51 |
52 | # Sphinx documentation
53 | docs/_build/
54 |
55 | # gpmcc artifacts
56 | *.engine*
57 | *.bak*
58 |
59 | # numpy binaries
60 | *.npy
61 |
62 | # latex artifacts
63 | *.aux
64 | *.bbl
65 | *.blg
66 | *.fdb_latexmk
67 | *.fls
68 | *.synctex.gz
69 | *_minted-*
70 | *.pdf
71 |
72 | # specific directories
73 | examples/malawi/resources/
74 | examples/satellites/splits/
75 | examples/swallows/resources/
76 | examples/swallows/splits/
77 | examples/swallows/univariate.py
78 | experiments/resources/
79 | experiments/*/resources/
80 | cgpm
81 | src/debug
82 | src/venturescript/resources
83 | notes/gpm/vs/vsgpm.pdf
84 | tests/resources
85 | /src/version.py
86 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | dist: trusty
3 | env:
4 | global:
5 | # get all the branches referencing this commit
6 | - REAL_BRANCH=$(git ls-remote origin | sed -n "\|$TRAVIS_COMMIT\s\+refs/heads/|{s///p}")
7 |
8 | python:
9 | - 2.7
10 | install:
11 | - wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh
12 | - bash miniconda.sh -b -p $HOME/miniconda
13 | - export PATH="$HOME/miniconda/bin:$PATH"
14 | - hash -r
15 | - conda config --set always_yes yes --set changeps1 no
16 | - conda update -q conda
17 | - conda install -q conda=4.6.14 conda-build=3.13.0
18 | before_script:
19 | - gem install travis
20 | script:
21 | - export CONDA_PACKAGE_VERSION="${TRAVIS_TAG:-$(date +%Y.%m.%d)}"
22 | # remove leading v from tags if they exist
23 | - CONDA_PACKAGE_VERSION="$(sed s/^v// <<<$CONDA_PACKAGE_VERSION)"
24 | - conda build . -c probcomp -c anaconda
25 | after_success:
26 | - bash conda/upload.sh
27 |
--------------------------------------------------------------------------------
/HACKING:
--------------------------------------------------------------------------------
1 | Notes For cgpm Hackers
2 |
3 | * Branches
4 |
5 | Should be [year][month][day]-[username]-[branch-description...], for example a
6 | well named branch is 20160621-fsaad-update-hacking.
7 |
8 | Please do not delete branches, history should be maintained.
9 |
10 | * Python coding style
11 |
12 | Generally follow PEP 8, 80-char max, with these exceptions:
13 |
14 | - New line, instead of alignment, for continuation lines. For function
15 | definitions, no new line and use eight spaces.
16 |
17 | Example: Yes
18 | model = cgpm.crosscat.state.State(
19 | X, outputs=[1,2,3,4], inputs=None,
20 | cctypes=['normal', 'bernoulli', 'poisson', 'lognormal'],
21 | rng=np.random.RandomState(0))
22 |
23 | Example: No
24 | model = cgpm.crosscat.state.State(X, outputs=[1,2,3], inputs=None,
25 | cctypes=['normal', 'bernoulli', 'poisson',
26 | 'lognormal'],
27 | rng=np.random.RandomState(0))
28 |
29 | Example: Yes, Preferred
30 | def generate_mh_sample(
31 | x, logpdf_target, jump_std, D, num_samples=1,
32 | num_burn=1, num_chains=7, num_lag=1, rng=None):
33 | ...body...
34 |
35 | Example: Yes
36 | def generate_mh_sample(x, logpdf_target, jump_std, D, num_samples=1,
37 | num_burn=1, num_chains=7, num_lag=1, rng=None):
38 | ...body...
39 |
40 | Example: No
41 | def generate_mh_sample(x, logpdf_target, jump_std, D, num_samples=1,
42 | num_burn=1, num_chains=7, num_lag=1, rng=None):
43 | ...body...
44 |
45 | - Generally use single-quoted strings, except use """ for doc strings.
46 |
47 | * Python imports
48 |
49 | Should be organized as follows:
50 |
51 | [standard library imports]
52 | - blank line -
53 | [third-party library imports]
54 | - blank line -
55 | [sister projects library imports]
56 | - blank line -
57 | [cgpm module imports]
58 |
59 | Each import block should be organized alphabetically: first all unqualified
60 | imports (import baz), then all named imports (from foo import nix).
61 | For example
62 |
63 | import math
64 | import multiprocessing as mp
65 |
66 | from array import ArrayType
67 | from struct import pack
68 |
69 | import numpy as np
70 |
71 | from scipy.misc import logsumexp
72 | from scipy.stats import geom
73 | from scipy.stats import norm
74 |
75 | import bayeslite.core
76 | import bayeslite.math_util
77 |
78 | from cgpm.crosscat.state import State
79 | from cgpm.utils import general as gu
80 |
81 | * Testing
82 |
83 | The tip of every branch merged into master __must__ pass ./check.sh, and be
84 | consistent with the code conventions here. New functionality must always be
85 | accompanied by tests -- fixing bugs should preferably include a test as well
86 | (less strict).
87 |
88 | * Entropy
89 |
90 | Please, never, ever use global random state. Every source of random bits must
91 | be managed explicitly.
92 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE.txt
2 | include README.md
3 | include check.sh
4 | include pythenv.sh
5 | include tests/graphical/animals.py
6 | include tests/graphical/depprob_id.py
7 | include tests/graphical/dpmm_nignormal.py
8 | include tests/graphical/one_view.py
9 | include tests/graphical/recover.py
10 | include tests/graphical/resources/satellites.csv
11 | include tests/graphical/slice.py
12 | include tests/graphical/two_views.py
13 | include tests/graphical/zero_corr.py
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cgpm
2 |
3 | [![Build Status](https://travis-ci.org/probcomp/cgpm.svg?branch=master)](https://travis-ci.org/probcomp/cgpm)
4 |
5 | The aim of this project is to provide a unified probabilistic programming
6 | framework to express different models and techniques from statistics, machine
7 | learning and non-parametric Bayes. It serves as the primary modeling and
8 | inference runtime system for [bayeslite](https://github.com/probcomp/bayeslite),
9 | an open-source implementation of BayesDB.
10 |
11 | Composable generative population models (CGPM) are a computational abstraction
12 | for probabilistic objects. They provide an interface that explicitly
13 | differentiates between the _sampler_ of a random variable from its conditional
14 | distribution and the _assessor_ of its conditional density. By encapsulating
15 | models as probabilistic programs that implement CGPMs, complex models can be
16 | built as compositions of sub-CGPMs, and queried in a model-independent way
17 | using the Bayesian Query Language.
18 |
19 | ## Installing
20 |
21 | ### Conda
22 |
23 | The easiest way to install cgpm is to use the
24 | [package](https://anaconda.org/probcomp/cgpm) on Anaconda Cloud.
25 | Please follow [these instructions](https://github.com/probcomp/iventure/blob/master/docs/conda.md).
26 |
27 | ### Manual Build
28 |
29 | `cgpm` targets Ubuntu 14.04 and 16.04. The package can be installed by cloning
30 | this repository and following these instructions. It is _highly recommended_ to
31 | install `cgpm` inside of a virtualenv which was created using the
32 | `--system-site-packages` flag.
33 |
34 | 1. Install dependencies from `apt`, [listed here](https://github.com/probcomp/cgpm/blob/71fe62790f466e9dd2149d0f527c584cce19e70f/docker/ubuntu1604#L4-L14).
35 |
36 | 2. Retrieve and build the source.
37 |
38 | ```
39 | % git clone git@github.com:probcomp/cgpm
40 | % cd cgpm
41 | % pip install --no-deps .
42 | ```
43 |
44 | 3. Verify the installation.
45 |
46 | ```
47 | % python -c 'import cgpm'
48 | % ./check.sh
49 | ```
50 |
51 | ## Publications
52 |
53 | CGPMs, and their integration as a runtime system for
54 | [BayesDB](http://probcomp.csail.mit.edu/bayesdb/), are described in the following
55 | technical report:
56 |
57 | - __Probabilistic Data Analysis with Probabilistic Programming__.
58 | Saad, F., and Mansinghka, V. [_arXiv preprint, arXiv:1608.05347_](https://arxiv.org/abs/1608.05347), 2017.
59 |
60 | Applications of using cgpm and bayeslite for data analysis tasks can be further
61 | found in:
62 |
63 | - __Probabilistic Search for Structured Data via Probabilistic Programming and Nonparametric Bayes__.
64 | Saad, F., Casarsa, L., and Mansinghka, V. [_arXiv preprint, arXiv:1704.01087_](https://arxiv.org/abs/1704.01087), 2017.
65 |
66 | - __Detecting Dependencies in Sparse, Multivariate Databases Using Probabilistic Programming and Non-parametric Bayes__.
67 | Saad, F., and Mansinghka, V. [_Artificial Intelligence and Statistics (AISTATS)_](http://proceedings.mlr.press/v54/saad17a.html), 2017.
68 |
69 | - __A Probabilistic Programming Approach to Probabilistic Data Analysis__.
70 | Saad, F., and Mansinghka, V. [_Advances in Neural Information Processing Systems (NIPS)_](https://papers.nips.cc/paper/6060-a-probabilistic-programming-approach-to-probabilistic-data-analysis.html), 2016.
71 |
72 |
73 | ## Tests
74 |
75 | Running `./check.sh` will run a subset of the tests that are considered complete
76 | and stable. To launch the full test suite, including continuous integration
77 | tests, run `py.test` in the root directory. There are more tests in the `tests/`
78 | directory, but those that do not start with `test_` or do start with `disabled_`
79 | are not considered ready. The tip of every branch merged into master __must__
80 | pass `./check.sh`, and be consistent with the code conventions outlined in
81 | [HACKING](HACKING).
82 |
83 | To run the full test suite, use `./check.sh --integration tests/`. Note that the
84 | full integration test suite requires installing the C++
85 | [crosscat](https://github.com/probcomp/crosscat) backend.
86 |
87 | ## License
88 |
89 | Copyright (c) 2015-2016 MIT Probabilistic Computing Project
90 |
91 | Licensed under the Apache License, Version 2.0 (the "License");
92 | you may not use this file except in compliance with the License.
93 | You may obtain a copy of the License at
94 |
95 | http://www.apache.org/licenses/LICENSE-2.0
96 |
97 | Unless required by applicable law or agreed to in writing, software
98 | distributed under the License is distributed on an "AS IS" BASIS,
99 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100 | See the License for the specific language governing permissions and
101 | limitations under the License.
102 |
--------------------------------------------------------------------------------
/check.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -Ceu
4 |
5 | : ${PYTHON:=python}
6 |
7 | root=`cd -- "$(dirname -- "$0")" && pwd`
8 |
9 | (
10 | set -Ceu
11 | cd -- "${root}"
12 | rm -rf build
13 | "$PYTHON" setup.py build
14 | if [ $# -eq 0 ]; then
15 | # By default, when running all tests, skip tests that have
16 | # been marked for continuous integration by using __ci_ in
17 | # their names. (git grep __ci_ to find these.)
18 | ./pythenv.sh "$PYTHON" -m pytest --pyargs cgpm -k "not __ci_"
19 | else
20 | # If args are specified, run all tests, including continuous
21 | # integration tests, for the selected components.
22 | ./pythenv.sh "$PYTHON" -m pytest "$@"
23 | fi
24 | )
25 |
--------------------------------------------------------------------------------
/conda/meta.yaml:
--------------------------------------------------------------------------------
1 | package:
2 | name: cgpm
3 | version: {{ CONDA_PACKAGE_VERSION }}
4 |
5 | source:
6 | path: ../
7 |
8 | build:
9 | script: python setup.py install
10 |
11 | requirements:
12 | build:
13 | - git
14 | - matplotlib 1.5.*
15 | - numpy 1.11.*
16 | - python 2.7.*
17 | run:
18 | - nomkl
19 | - matplotlib 1.5.*
20 | - numpy 1.11.*
21 | - pandas 0.18.*
22 | - python 2.7.*
23 | - scikit-learn 0.17.*
24 | - scipy 0.17.*
25 | - statsmodels 0.6.*
26 |
27 | test:
28 | requires:
29 | - matplotlib 1.5.*
30 | - pytest 2.8.*
31 | - python 2.7.*
32 | imports:
33 | - cgpm
34 | commands:
35 | - python -m pytest --pyargs cgpm
36 |
37 | about:
38 | home: https://github.com/probcomp/cgpm
39 | license: Apache
40 | license_file: LICENSE.txt
41 |
--------------------------------------------------------------------------------
/conda/upload.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -ev

# Upload the conda package built by Travis to Anaconda Cloud.
#
# On a tagged build TRAVIS_BRANCH and TRAVIS_TAG are the same, so the branch
# the tag belongs to is taken from REAL_BRANCH (computed in .travis.yml with
# `git ls-remote`). REAL_BRANCH may be empty, or may list several branches
# one per line, so it is matched line by line instead of being compared as a
# single unquoted word (which errors out with an empty or multi-word value).
if [ -n "${TRAVIS_TAG}" ]; then
    conda install anaconda-client
    # If the tag didn't come from master, add the "dev" label.
    if printf '%s\n' "${REAL_BRANCH}" | grep -qx 'master'; then
        anaconda -t "${CONDA_UPLOAD_TOKEN}" upload -u "${CONDA_USER}" ~/miniconda/conda-bld/linux-64/cgpm-*.tar.bz2 --force
    else
        anaconda -t "${CONDA_UPLOAD_TOKEN}" upload -u "${CONDA_USER}" -l dev ~/miniconda/conda-bld/linux-64/cgpm-*.tar.bz2 --force
    fi
elif [ "${TRAVIS_BRANCH}" = "master" ]; then
    if [ "${TRAVIS_EVENT_TYPE}" = "cron" ]; then
        # Don't build a package for the nightly cron; it only tracks test
        # stability.
        exit 0
    else
        conda install anaconda-client
        anaconda -t "${CONDA_UPLOAD_TOKEN}" upload -u "${CONDA_USER}" -l edge ~/miniconda/conda-bld/linux-64/cgpm-*.tar.bz2 --force
        # Trigger a downstream bayeslite build using the edge package.
        curl -LO https://raw.github.com/stephanmg/travis-dependent-builds/master/trigger.sh
        bash trigger.sh probcomp bayeslite master "$TRAVIS_ACCESS_TOKEN" | grep -v TOKEN
    fi
else
    exit 0
fi
28 |
--------------------------------------------------------------------------------
/docker/README:
--------------------------------------------------------------------------------
1 | This directory contains Dockerfiles that build and install cgpm.
2 |
3 | Note that the file ubuntu1604-int is used as part of an integration test on the
4 | probcomp Jenkins test server at https://probcomp-4.csail.mit.edu. This file is
5 | not appropriate for being run in isolation.
6 |
--------------------------------------------------------------------------------
/docker/ubuntu1604:
--------------------------------------------------------------------------------
1 | FROM ubuntu:16.04
2 | MAINTAINER MIT Probabilistic Computing Project
3 |
4 | RUN apt-get update -qq \
5 | && apt-get install -qq -y \
6 | git \
7 | python-matplotlib \
8 | python-numpy \
9 | python-pandas \
10 | python-pytest \
11 | python-scipy \
12 | python-sklearn \
13 | python-statsmodels
14 |
15 | ADD . /cgpm
16 | WORKDIR /cgpm
17 | RUN ./check.sh tests
18 | RUN python setup.py sdist
19 | RUN python setup.py bdist
20 |
--------------------------------------------------------------------------------
/docker/ubuntu1604-int:
--------------------------------------------------------------------------------
1 | # Integration tests.
2 |
3 | FROM ubuntu:16.04
4 | MAINTAINER MIT Probabilistic Computing Project
5 |
6 | RUN apt-get update -qq \
7 | && apt-get upgrade -qq \
8 | && apt-get install -y \
9 | build-essential \
10 | ccache \
11 | git \
12 | libboost-all-dev \
13 | libgsl0-dev \
14 | python-flask \
15 | python-jsonschema \
16 | python-matplotlib \
17 | python-nose \
18 | python-nose-testconfig \
19 | python-numpy \
20 | python-pandas \
21 | python-pexpect \
22 | python-pytest \
23 | python-requests \
24 | python-scipy \
25 | python-six \
26 | python-sklearn \
27 | python-statsmodels
28 |
29 | ADD . /cgpm
30 | WORKDIR /cgpm
31 | RUN \
32 | ./docker/deps/bayeslite-apsw/pythenv.sh \
33 | ./docker/deps/bayeslite/pythenv.sh \
34 | ./docker/deps/crosscat/pythenv.sh \
35 | ./docker/deps/venture/pythenv.sh \
36 | ./check.sh --integration tests
37 | RUN python setup.py sdist
38 | RUN python setup.py bdist
39 |
--------------------------------------------------------------------------------
/images/linear_regression.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probcomp/cgpm/56a481829448bddc9cdfebd42f65023287d5b7c7/images/linear_regression.png
--------------------------------------------------------------------------------
/images/one_view.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probcomp/cgpm/56a481829448bddc9cdfebd42f65023287d5b7c7/images/one_view.png
--------------------------------------------------------------------------------
/images/smc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probcomp/cgpm/56a481829448bddc9cdfebd42f65023287d5b7c7/images/smc.gif
--------------------------------------------------------------------------------
/images/spherical_gaussian.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/probcomp/cgpm/56a481829448bddc9cdfebd42f65023287d5b7c7/images/spherical_gaussian.png
--------------------------------------------------------------------------------
/pythenv.sh:
--------------------------------------------------------------------------------
#!/bin/sh

set -Ceu

: ${PYTHON:=python}
root=`cd -- "$(dirname -- "$0")" && pwd`
# print(...) with a single argument and '%d.%d' % sys.version_info[:2] give
# the same output as the old print statement / sys.version[0:3] on Python 2.7,
# but also run under Python 3, and do not truncate two-digit minor versions.
platform=`"${PYTHON}" -c 'import distutils.util as u; print(u.get_platform())'`
version=`"${PYTHON}" -c 'import sys; print("%d.%d" % sys.version_info[:2])'`

# The lib directory varies depending on
#
#   (a) whether there are extension modules (here, no); and
#   (b) whether some Debian maintainer decided to patch the local Python
#   to behave as though there were.
#
# But there's no obvious way to just ask distutils what the name will
# be.  There's no harm in naming a pathname that doesn't exist, other
# than a handful of microseconds of runtime, so we'll add both.
libdir="${root}/build/lib"
plat_libdir="${libdir}.${platform}-${version}"
export PYTHONPATH="${libdir}:${plat_libdir}${PYTHONPATH:+:${PYTHONPATH}}"

bindir="${root}/build/scripts-${version}"
export PATH="${bindir}${PATH:+:${PATH}}"

exec "$@"
27 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .version import __version__
18 |
--------------------------------------------------------------------------------
/src/crosscat/README.md:
--------------------------------------------------------------------------------
1 | # gpmcc
2 |
3 | Implementation of [crosscat](http://probcomp.csail.mit.edu/crosscat/) from
4 | the lens of generative population models (GPMs). The goal is to express the
5 | hierarchical generative process that defines crosscat as a composition of
6 | modules that follow the GPM interface.
7 |
8 | ## Research Goals
9 |
10 | gpmcc aims to implement all the features that exist in current crosscat
11 | implementations, as well as new features not available in vanilla crosscat.
12 | Key ideas on the development roadmap are:
13 |
14 | - Interface that permits key constructs of the Metamodeling Language (MML),
15 | such as:
16 | - Suggesting column dependencies.
17 | - Suggesting row dependencies, with respect to a subset of columns.
18 | - Sequential incorporate/unincorporate of (partial) observations
19 | interleaved with analysis.
20 | - Targeted analysis over the crosscat inference kernels.
21 | - Column-specific datatype constraints (`REAL`, `POSITIVE`,
22 | `IN-RANGE(min,max)`, `CATEGORICAL`, `ORDINAL`, etc).
23 |
24 | - Sequential Monte Carlo (SMC) implementation of the posterior inference
25 | algorithm described in [Mansinghka, et
26 | al.](http://arxiv.org/pdf/1512.01272.pdf) Section 2.4, as opposed to
27 | observe-all then Gibbs forever.
28 |
29 | - Interface for the Bayesian Query Language (BQL) and
30 | [bayeslite](https://github.com/probcomp/bayeslite) integration, with new
31 | BQL additions such as:
32 | - Conditional mutual information.
33 | - KL-divergence of predictive distribution against synthetic GPMs.
34 | - Marginal likelihood estimates of real datasets.
35 |
36 | - Interface for foreign GPMs that are jointly analyzed with crosscat.
37 | Current implementations only allow foreign GPMs to be composed at query,
38 | not analysis, time.
39 |
40 | - Subsampling, where each model is responsible for a subset of data from an
41 | overlapping partition of the overall dataset.
42 |
43 | - Multiprocessing for analysis. Distributed?
44 |
45 | - Several DistributionGpms for the different MML data-types, not just
46 | Normal and Multinomial.
47 |
48 | ## Static Example
49 |
50 | The simplest example is creating a synthetic dataset where each variable is a
51 | mixture of one of the available DistributionGpms. Inference is run using an
52 | extended implementation of CrossCat from the lens of compositions of composable
53 | generative population models.
54 |
55 | ```
56 | $ python -i tests/graphical/one_view.py
57 | ```
58 |
59 | A plot similar to ![One View](../../images/one_view.png) should appear.
60 |
61 | ## Interactive Example (Experimental)
62 |
63 | Single-particle SMC in a DP Mixture with Normal-InverseGamma base measure and
64 | normal observations can be run interactively:
65 |
66 | ```
67 | $ python -i tests/graphical/dpmm_nignormal.py
68 | ```
69 |
70 | Click on the graph to produce observations and watch the Gibbs kernel cycle
71 | through the hypothesis space:
72 |
73 | ![SMC](../../images/smc.gif)
74 |
75 | The values printed in the console after each click are estimates of the
76 | marginal-log-likelihood of observations, based on the single particle
77 | weight. The following output
78 |
79 | ```
80 | Observation 8.000000: 0.209677
81 | [-8.0740236375201153]
82 | ```
83 |
84 | means the eighth observation is 0.209677, and the estimated marginal
85 | log-likelihood is -8.0740236375201153.
86 |
--------------------------------------------------------------------------------
/src/crosscat/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/crosscat/sampling.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | '''
18 | This module contains implementations of simulate and logpdf specialized to
19 | cgpm.crosscat, avoiding overhead of recursive implementations using the
20 | importance network on the sub-cgpms that comprise cgpm.crosscat.State.
21 | '''
22 |
23 | from itertools import chain
24 |
25 | import numpy as np
26 |
27 | from cgpm.primitives.crp import Crp
28 |
29 | from cgpm.utils.general import log_normalize
30 | from cgpm.utils.general import log_pflip
31 | from cgpm.utils.general import logsumexp
32 | from cgpm.utils.general import merged
33 |
34 | from cgpm.utils.validation import partition_query_evidence
35 |
36 |
def state_logpdf(state, rowid, targets, constraints=None):
    """Return the joint log density of `targets` at `rowid` under a State.

    `targets` and `constraints` are split per view with
    `partition_query_evidence(state.Zv(), ...)`. Every view that owns at least
    one target contributes `view_logpdf` of its own targets, conditioned only
    on the constraints that fall in that same view (an empty dict if none do).
    The per-view contributions are added together.
    """
    targets_lookup, constraints_lookup = partition_query_evidence(
        state.Zv(), targets, constraints)
    total = 0
    for view_index in targets_lookup:
        total += view_logpdf(
            view=state.views[view_index],
            rowid=rowid,
            targets=targets_lookup[view_index],
            constraints=constraints_lookup.get(view_index, {}),
        )
    return total
50 |
51 |
def state_simulate(state, rowid, targets, constraints=None, N=None):
    """Simulate `targets` at `rowid` from a State, given `constraints`.

    `targets` and `constraints` are split per view with
    `partition_query_evidence(state.Zv(), ...)`. Each view that owns a target
    produces its own draws via `view_simulate`. The i-th joint sample is
    `merged` from the i-th draw of every view.

    Returns a list of N joint samples when `N` is given, otherwise a single
    joint sample (one draw is simulated and unwrapped).
    """
    targets_lookup, constraints_lookup = partition_query_evidence(
        state.Zv(), targets, constraints)
    num_draws = 1 if N is None else N
    view_draws = [
        view_simulate(
            view=state.views[view_index],
            rowid=rowid,
            targets=targets_lookup[view_index],
            constraints=constraints_lookup.get(view_index, {}),
            N=num_draws,
        )
        for view_index in targets_lookup
    ]
    joint = [merged(*parts) for parts in zip(*view_draws)]
    if N is None:
        return joint[0]
    return joint
68 |
69 |
def view_logpdf(view, rowid, targets, constraints):
    """Return the log density of `targets` at `rowid` within a single view.

    For an incorporated (non-hypothetical) row the cluster is fixed at
    `view.Zr(rowid)`, and the targets are scored there directly; constraints
    are not used in that case. For a hypothetical row the density is averaged
    over the candidate tables from `gibbs_tables(-1)`. Each table is weighted
    by `Crp.calc_predictive_logp` plus the log density of `constraints` in
    that table, normalized with `log_normalize`.

    Raises ValueError if the constraints have zero density in every table.
    """
    if view.hypothetical(rowid):
        table_sizes = view.Nk()
        num_rows = len(view.Zr())
        tables = view.crp.clusters[0].gibbs_tables(-1)
        log_prior = [
            Crp.calc_predictive_logp(t, num_rows, table_sizes, view.alpha())
            for t in tables
        ]
        log_evidence = [_logpdf_row(view, constraints, t) for t in tables]
        if all(np.isinf(log_evidence)):
            raise ValueError('Zero density constraints: %s' % (constraints,))
        log_weights = log_normalize(np.add(log_prior, log_evidence))
        log_targets = [_logpdf_row(view, targets, t) for t in tables]
        return logsumexp(np.add(log_weights, log_targets))
    return _logpdf_row(view, targets, view.Zr(rowid))
83 |
84 |
def view_simulate(view, rowid, targets, constraints, N):
    """Draw N joint samples of `targets` at `rowid` from a single view.

    If `rowid` is already incorporated (not hypothetical), every sample comes
    from its current cluster `view.Zr(rowid)`. Otherwise a cluster is drawn
    independently for each of the N samples from the candidate tables
    `gibbs_tables(-1)`. Each table's log weight is the CRP term from
    `Crp.calc_predictive_logp` plus the log density of `constraints` in that
    table. The targets are then simulated within the drawn cluster.

    Returns an iterable of N samples. For a hypothetical row the i-th sample
    belongs to the i-th drawn cluster assignment, so samples appear in the
    random order of the draws rather than grouped by cluster.

    Raises ValueError if the constraints have zero density in every table.
    """
    if not view.hypothetical(rowid):
        return _simulate_row(view, targets, view.Zr(rowid), N)
    Nk = view.Nk()
    N_rows = len(view.Zr())
    K = view.crp.clusters[0].gibbs_tables(-1)
    lp_crp = [Crp.calc_predictive_logp(k, N_rows, Nk, view.alpha()) for k in K]
    lp_constraints = [_logpdf_row(view, constraints, k) for k in K]
    if all(np.isinf(lp_constraints)):
        raise ValueError('Zero density constraints: %s' % (constraints,))
    lp_cluster = np.add(lp_crp, lp_constraints)
    ks = log_pflip(lp_cluster, array=K, size=N, rng=view.rng)
    # Simulate all draws that landed in the same cluster with one call,
    # visiting clusters in ascending order so RNG consumption is deterministic.
    counts = np.bincount(ks)
    pending = {
        k: iter(list(_simulate_row(view, targets, k, n)))
        for k, n in enumerate(counts) if n > 0
    }
    # Hand the samples back in the order the clusters were drawn. Returning
    # them grouped by cluster would make the cluster of the i-th sample depend
    # on i. state_simulate zips per-view samples position by position, so
    # grouping would pair low-numbered clusters across views and correlate
    # views that should be independent.
    return [next(pending[k]) for k in ks]
100 |
101 |
def _logpdf_row(view, targets, cluster):
    """Return the joint log density of `targets` within a fixed cluster.

    `targets` maps column -> value. Each column is scored independently by
    `view.dims[column].logpdf`, with the view's first output pinned to
    `cluster`, and the per-column log densities are summed. An empty
    `targets` yields 0.
    """
    # .items() (rather than the Python-2-only .iteritems()) keeps this working
    # under both Python 2 and Python 3.
    return sum(
        view.dims[c].logpdf(None, {c: x}, None, {view.outputs[0]: cluster})
        for c, x in targets.items()
    )
108 |
109 |
def _simulate_row(view, targets, cluster, N):
    """Return an iterator of N joint samples of `targets` within `cluster`.

    Each target column is simulated separately by `view.dims[column]`, with
    the view's first output (`view.outputs[0]`) fixed to `cluster`. The i-th
    draws of all columns are then merged into one dict per sample. The
    per-column simulate calls run immediately; only the merging is lazy.
    """
    per_column = [
        view.dims[col].simulate(
            None, [col], None, {view.outputs[0]: cluster}, N)
        for col in targets
    ]
    return (merged(*draws) for draws in zip(*per_column))
117 |
--------------------------------------------------------------------------------
/src/dummy/README:
--------------------------------------------------------------------------------
1 | These CGPMs are dummies, used mainly for test and exploratory purposes.
2 |
3 |
--------------------------------------------------------------------------------
/src/dummy/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/dummy/barebones.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """A barebones cgpm with only outputs, inputs, serialization, and not much else."""
18 |
19 | from cgpm.cgpm import CGpm
20 |
class BareBonesCGpm(CGpm):
    """Minimal CGpm stub: records outputs/inputs and serializes, nothing else.

    Every query has a constant answer: logpdf is always 0, and simulate
    assigns the value 1 to every output variable (regardless of `targets`).
    """

    def __init__(self, outputs, inputs, distargs=None, rng=None):
        # distargs and rng are accepted only for interface compatibility.
        self.outputs = outputs
        self.inputs = inputs

    def incorporate(self, rowid, observation, inputs=None):
        return

    def unincorporate(self, rowid):
        return

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        return 0

    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        """Return {output: 1 for every output}, or a list of N such dicts.

        Each of the N samples is a distinct dict, so callers may mutate one
        sample without silently changing the others.
        """
        result = {i: 1 for i in self.outputs}
        if N is None:
            return result
        return [dict(result) for _ in range(N)]

    def transition(self, **kwargs):
        return

    def to_metadata(self):
        metadata = dict()
        metadata['outputs'] = self.outputs
        metadata['inputs'] = self.inputs
        metadata['factory'] = ('cgpm.dummy.barebones', 'BareBonesCGpm')
        return metadata

    @classmethod
    def from_metadata(cls, metadata, rng=None):
        # rng is ignored: this cgpm has no randomness.
        return cls(metadata['outputs'], metadata['inputs'])
53 |
--------------------------------------------------------------------------------
/src/dummy/fourway.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import time
18 |
19 | import numpy as np
20 |
21 | from cgpm.cgpm import CGpm
22 | from cgpm.utils import general as gu
23 |
24 |
class FourWay(CGpm):
    """Outputs categorical(4) (quadrant indicator) on R2 valued input.

    The two inputs (y0, y1) select a row ("regime") of a fixed probability
    table via `lookup_quadrant`, and the output x in {0, 1, 2, 3} is drawn
    from that row. The model has no learnable state.
    """

    def __init__(self, outputs, inputs, distargs=None, rng=None):
        if rng is None:
            rng = gu.gen_rng(1)
        self.rng = rng
        # Row r holds the weights of x in {0,1,2,3} under regime r. Rows 0
        # and 1 sum to 0.9 and 1.1 respectively, so they are treated as
        # unnormalized weights (see logpdf).
        self.probabilities = [
            [.70, .10, .05, .05],
            [.10, .80, .10, .10],
            [.10, .15, .65, .10],
            [.10, .05, .10, .75],
        ]
        assert len(outputs) == 1
        assert len(inputs) == 2
        self.outputs = list(outputs)
        self.inputs = list(inputs)

    def incorporate(self, rowid, observation, inputs=None):
        return

    def unincorporate(self, rowid):
        return

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        regime = self.lookup_quadrant(
            inputs[self.inputs[0]],
            inputs[self.inputs[1]],
        )
        x = gu.pflip(self.probabilities[regime], rng=self.rng)
        return {self.outputs[0]: x}

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        """Return log p(x | quadrant of the inputs).

        Values outside {0, 1, 2, 3}, including non-integers, get -inf.
        Integer-valued floats such as 2.0 are accepted. The table row is
        normalized so that the density agrees with simulate(), which samples
        the row through gu.pflip (a sampler that is fed unnormalized weights
        elsewhere in this package, e.g. counts + alpha).
        """
        x = targets[self.outputs[0]]
        if not (x % 1 == 0 and 0 <= x <= 3):
            return -float('inf')
        regime = self.lookup_quadrant(
            inputs[self.inputs[0]],
            inputs[self.inputs[1]],
        )
        weights = self.probabilities[regime]
        return np.log(weights[int(x)]) - np.log(np.sum(weights))

    def transition(self, N=None, S=None):
        time.sleep(.1)

    @staticmethod
    def lookup_quadrant(y0, y1):
        # Points on an axis resolve to the first matching quadrant below;
        # only NaN-like inputs fall through to the error.
        if y0 >= 0 and y1 >= 0: return 0
        if y0 <= 0 and y1 >= 0: return 1
        if y0 >= 0 and y1 <= 0: return 2
        if y0 <= 0 and y1 <= 0: return 3
        raise ValueError('Invalid value: %s' % str((y0, y1)))

    @staticmethod
    def retrieve_y_for_x(x):
        # Representative input point inside the quadrant for regime x.
        if x == 0: return [2, 2]
        if x == 1: return [-2, 2]
        if x == 2: return [2, -2]
        if x == 3: return [-2, -2]
        raise ValueError('Invalid value: %s' % str(x))

    def to_metadata(self):
        metadata = dict()
        metadata['outputs'] = self.outputs
        metadata['inputs'] = self.inputs
        metadata['factory'] = ('cgpm.dummy.fourway', 'FourWay')
        return metadata

    @classmethod
    def from_metadata(cls, metadata, rng=None):
        if rng is None: rng = gu.gen_rng(0)
        return cls(
            outputs=metadata['outputs'],
            inputs=metadata['inputs'],
            rng=rng)
101 |
--------------------------------------------------------------------------------
/src/dummy/trollnormal.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import time
18 |
19 | import numpy as np
20 |
21 | from cgpm.cgpm import CGpm
22 | from cgpm.utils import general as gu
23 |
24 |
class TrollNormal(CGpm):
    """Normal output whose location and scale are read from two inputs.

    x ~ Normal(loc=inputs[0], scale=|inputs[1]| + 1)

    Nothing is learned: incorporate/unincorporate only track the set of
    rowids seen, and transition just sleeps briefly.
    """

    def __init__(self, outputs, inputs, rng=None, distargs=None):
        self.rng = gu.gen_rng(1) if rng is None else rng
        assert len(outputs) == 1
        assert len(inputs) == 2
        self.outputs = list(outputs)
        self.inputs = list(inputs)
        self.rowids = set()

    def incorporate(self, rowid, observation, inputs=None):
        assert rowid not in self.rowids
        self.rowids.add(rowid)

    def unincorporate(self, rowid):
        assert rowid in self.rowids
        self.rowids.remove(rowid)

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        loc = self._retrieve_location(inputs)
        scale = self._retrieve_scale(inputs)
        return {self.outputs[0]: self.rng.normal(loc, scale)}

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        value = targets[self.outputs[0]]
        loc = self._retrieve_location(inputs)
        scale = self._retrieve_scale(inputs)
        return self._gaussian_log_pdf(value, loc, scale)

    def _gaussian_log_pdf(self, x, mu, s):
        # log N(x; mu, s) with s the standard deviation.
        return (-(np.log(2 * np.pi) / 2) - np.log(s)) \
            - ((x - mu)**2 / (2 * s**2))

    def transition(self, N=None, S=None):
        time.sleep(.1)

    def _retrieve_location(self, inputs):
        return inputs[self.inputs[0]]

    def _retrieve_scale(self, inputs):
        # Shifted by one so the scale is never zero.
        return abs(inputs[self.inputs[1]]) + 1

    def to_metadata(self):
        return {
            'outputs': self.outputs,
            'inputs': self.inputs,
            'factory': ('cgpm.dummy.trollnormal', 'TrollNormal'),
        }

    @classmethod
    def from_metadata(cls, metadata, rng=None):
        rng = gu.gen_rng(0) if rng is None else rng
        return cls(
            outputs=metadata['outputs'],
            inputs=metadata['inputs'],
            rng=rng,
        )
86 |
--------------------------------------------------------------------------------
/src/dummy/twoway.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import time
18 |
19 | import numpy as np
20 |
21 | from cgpm.cgpm import CGpm
22 | from cgpm.utils import general as gu
23 |
24 |
class TwoWay(CGpm):
    """Generates {0,1} output on {0,1} valued input with given CPT.

    Row y of `self.probabilities` is the distribution of the output x given
    input y. The model has no learnable state.
    """

    def __init__(self, outputs, inputs, distargs=None, rng=None):
        if rng is None:
            rng = gu.gen_rng(1)
        self.rng = rng
        self.probabilities = [
            [.9, .1],
            [.3, .7],
        ]
        assert len(outputs) == 1
        assert len(inputs) == 1
        self.outputs = list(outputs)
        self.inputs = list(inputs)

    def incorporate(self, rowid, observation, inputs=None):
        return

    def unincorporate(self, rowid):
        return

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        y = self._retrieve_regime(inputs)
        x = gu.pflip(self.probabilities[y], rng=self.rng)
        return {self.outputs[0]: x}

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        """Return log p(x | y); -inf for x outside {0, 1}.

        Integer-valued floats (e.g. 1.0) are accepted for both x and y.
        """
        y = self._retrieve_regime(inputs)
        x = targets[self.outputs[0]]
        if x not in [0, 1]:
            return -float('inf')
        return np.log(self.probabilities[y][int(x)])

    def _retrieve_regime(self, inputs):
        """Validate the input y and return it as an int usable as an index."""
        y = inputs[self.inputs[0]]
        assert int(y) == float(y)
        assert y in [0, 1]
        # y may be a float like 1.0; list indices must be ints.
        return int(y)

    def transition(self, N=None, S=None):
        # S is accepted for signature parity with the other dummy cgpms.
        time.sleep(.1)

    @staticmethod
    def retrieve_y_for_x(x):
        if x == 0:
            return 0
        if x == 1:
            return 1
        raise ValueError('Invalid value: %s' % str(x))

    def to_metadata(self):
        metadata = dict()
        metadata['outputs'] = self.outputs
        metadata['inputs'] = self.inputs
        metadata['factory'] = ('cgpm.dummy.twoway', 'TwoWay')
        return metadata

    @classmethod
    def from_metadata(cls, metadata, rng=None):
        if rng is None:
            rng = gu.gen_rng(0)
        return cls(
            outputs=metadata['outputs'],
            inputs=metadata['inputs'],
            rng=rng,
        )
89 |
--------------------------------------------------------------------------------
/src/factor/README:
--------------------------------------------------------------------------------
1 | These CGPMs learn low dimensional, continuous latent variable representations.
2 |
--------------------------------------------------------------------------------
/src/factor/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/kde/README:
--------------------------------------------------------------------------------
1 | These CGPMs learn the joint density of a high-dimensional variable using KDE.
2 |
--------------------------------------------------------------------------------
/src/kde/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/knn/README:
--------------------------------------------------------------------------------
1 | This CGPM builds ad-hoc machine learning models on a per-query basis using KNN.
2 |
--------------------------------------------------------------------------------
/src/knn/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/mixtures/README:
--------------------------------------------------------------------------------
1 | These CGPMs represent mixtures of CGPMs, using crosscat-like terminology.
2 |
3 |
--------------------------------------------------------------------------------
/src/mixtures/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/network/README:
--------------------------------------------------------------------------------
1 | Importance sampling for inference in a composite network of CGPMs.
2 |
--------------------------------------------------------------------------------
/src/network/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/network/helpers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import itertools as it
18 |
19 | import numpy as np
20 | from scipy.sparse.csgraph import connected_components
21 |
22 |
def validate_cgpms(cgpms):
    """Check the cgpms have non-empty, pairwise-disjoint outputs.

    Returns `cgpms` unchanged. Raises ValueError if some cgpm has no
    outputs, or if two cgpms share an output variable.
    """
    output_sets = [set(cgpm.outputs) for cgpm in cgpms]
    if any(not outs for outs in output_sets):
        raise ValueError('No output for a cgpm: %s' % output_sets)
    for left, right in it.combinations(output_sets, 2):
        if left & right:
            raise ValueError('Duplicate outputs for cgpms: %s' % output_sets)
    return cgpms
30 |
31 |
def retrieve_variable_to_cgpm(cgpms):
    """Return map of variable v to its index i in the list of cgpms.

    If a variable is output by several cgpms, the last one wins.
    """
    mapping = {}
    for index, cgpm in enumerate(cgpms):
        for var in cgpm.outputs:
            mapping[var] = index
    return mapping
35 |
36 |
def retrieve_adjacency_list(cgpms, v_to_c):
    """Return map of cgpm index to list of indexes of its parent cgpms.

    A parent is any cgpm that outputs one of this cgpm's inputs; inputs not
    present in `v_to_c` are ignored, and each parent is listed once.
    """
    adjacency = {}
    for index, cgpm in enumerate(cgpms):
        parents = set(v_to_c[var] for var in cgpm.inputs if var in v_to_c)
        adjacency[index] = list(parents)
    return adjacency
43 |
def retrieve_adjacency_matrix(cgpms, v_to_c):
    """Return a directed adjacency matrix of cgpms.

    Entry [p, c] is 1 when cgpm p is a parent of cgpm c (p outputs one of
    c's inputs), and 0 otherwise.
    """
    parents_of = retrieve_adjacency_list(cgpms, v_to_c)
    size = len(parents_of)
    matrix = np.zeros((size, size))
    for child in parents_of:
        # Fill row `child` with its parents, then transpose at the end.
        matrix[child, parents_of[child]] = 1
    return matrix.T
51 |
def retrieve_extraneous_inputs(cgpms, v_to_c):
    """Return list of inputs that are not the output of any cgpm.

    Order follows the cgpms and then each cgpm's inputs; duplicates are kept.
    """
    return [
        var
        for cgpm in cgpms
        for var in cgpm.inputs
        if var not in v_to_c
    ]
56 |
57 |
def retrieve_ancestors(cgpms, q):
    """Return list of all variables that are ancestors of q (duplicates).

    Ancestors of a variable are the inputs of the cgpm that outputs it,
    followed recursively; a variable not output by any cgpm has none. Each
    level lists the ancestors of its parents before the parents themselves.

    Raises ValueError if q is not an output of any cgpm.
    """
    producer = retrieve_variable_to_cgpm(cgpms)
    if q not in producer:
        raise ValueError('Invalid node: %s, %s' % (q, producer))

    def collect(var):
        if var not in producer:
            return []
        parents = cgpms[producer[var]].inputs
        found = []
        for parent in parents:
            found.extend(collect(parent))
        return found + parents

    return collect(q)
68 |
def retrieve_descendents(cgpms, q):
    """Return list of all variables that are descendants of q (duplicates).

    The children of a variable are the outputs of every cgpm that takes it
    as an input, followed recursively. Each level lists the descendants of
    its children before the children themselves.

    Raises ValueError if q is not an output of any cgpm.
    """
    producer = retrieve_variable_to_cgpm(cgpms)
    if q not in producer:
        raise ValueError('Invalid node: %s, %s' % (q, producer))

    def collect(var):
        children = [
            out for cgpm in cgpms if var in cgpm.inputs for out in cgpm.outputs
        ]
        found = []
        for child in children:
            found.extend(collect(child))
        return found + children

    return collect(q)
80 |
81 |
def retrieve_weakly_connected_components(cgpms):
    """Return an array giving the weakly connected component label of each cgpm.

    Labels are indexed by position in `cgpms`; edges follow the
    input/output relations between cgpms.
    """
    adjacency = retrieve_adjacency_matrix(
        cgpms, retrieve_variable_to_cgpm(cgpms))
    _num_components, labels = connected_components(
        adjacency, directed=True, connection='weak', return_labels=True)
    return labels
88 |
89 |
def topological_sort(graph):
    """Topologically sort a directed graph represented as an adjacency list.

    Assumes that edges are incoming, ie (10: [8,7]) means 8->10 and 7->10.
    Parents that are not themselves keys of the graph (8 and 7 above) are
    treated as already satisfied and do not appear in the output.

    Parameters
    ----------
    graph : list or dict
        Adjacency list or dict representing the graph, for example:
        graph_l = [(10, [8, 7]), (5, [8, 7, 9, 10, 11, 13, 15])]
        graph_d = {10: [8, 7], 5: [8, 7, 9, 10, 11, 13, 15]}
        The caller's object is not modified.

    Returns
    -------
    graph_sorted : list
        The nodes (keys of the graph), ordered so that every node comes
        after all of its parents that are also keys of the graph.

    Raises
    ------
    ValueError
        If the graph contains a cycle.
    """
    graph_sorted = []
    remaining = dict(graph)
    while remaining:
        cyclic = True
        # Iterate over a snapshot: nodes are deleted from `remaining` inside
        # the loop, which raises RuntimeError on a live dict view (Python 3).
        for node, edges in list(remaining.items()):
            if all(e not in remaining for e in edges):
                cyclic = False
                del remaining[node]
                graph_sorted.append(node)
        if cyclic:
            raise ValueError('Cyclic dependency occurred in topological_sort.')
    return graph_sorted
119 |
120 |
def retrieve_required_inputs(cgpms, topo, targets, constraints, extraneous):
    """Return list of input addresses required to answer query.

    Walks the cgpms in reverse topological order (`topo` is a list of cgpm
    indexes). Whenever a cgpm produces a target, a constraint, or an
    already-required variable, its inputs become required. Variables that
    are themselves targets, constraints, or extraneous are excluded from the
    result.
    """
    needed = set(targets)
    for index in reversed(topo):
        produced = cgpms[index].outputs
        consumed = cgpms[index].inputs
        if any(var in needed or var in constraints for var in produced):
            needed.update(consumed)
    excluded = [targets, constraints, extraneous]
    return [
        var for var in needed
        if not any(var in group for group in excluded)
    ]
133 |
--------------------------------------------------------------------------------
/src/primitives/README:
--------------------------------------------------------------------------------
1 | These CGPMs are univariate probability distributions from the exponential family.
2 |
3 |
--------------------------------------------------------------------------------
/src/primitives/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/primitives/bernoulli.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from math import log
18 |
19 | from scipy.special import betaln
20 |
21 | from cgpm.primitives.distribution import DistributionGpm
22 | from cgpm.utils import general as gu
23 |
24 |
class Bernoulli(DistributionGpm):
    """Collapsed Beta-Bernoulli model for a single binary output.

    theta ~ Beta(alpha, beta)
    x ~ Bernoulli(theta)

    theta is integrated out. Only the number of observations N and the
    number of ones x_sum are stored, and every density is computed from them
    in closed form.
    """

    def __init__(self, outputs, inputs, hypers=None, params=None,
            distargs=None, rng=None):
        DistributionGpm.__init__(
            self, outputs, inputs, hypers, params, distargs, rng)
        # Sufficient statistics.
        self.N = 0
        self.x_sum = 0
        # Beta prior hyperparameters; both default to 1 (uniform prior).
        prior = {} if hypers is None else hypers
        self.alpha = prior.get('alpha', 1.)
        self.beta = prior.get('beta', 1.)
        assert self.alpha > 0
        assert self.beta > 0

    def incorporate(self, rowid, observation, inputs=None):
        DistributionGpm.incorporate(self, rowid, observation, inputs)
        value = observation[self.outputs[0]]
        if value not in (0, 1):
            raise ValueError('Invalid Bernoulli: %s' % str(value))
        self.N += 1
        self.x_sum += value
        self.data[rowid] = value

    def unincorporate(self, rowid):
        value = self.data.pop(rowid)
        self.N -= 1
        self.x_sum -= value

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        DistributionGpm.logpdf(self, rowid, targets, constraints, inputs)
        value = targets[self.outputs[0]]
        if value in (0, 1):
            return Bernoulli.calc_predictive_logp(
                value, self.N, self.x_sum, self.alpha, self.beta)
        return -float('inf')

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        DistributionGpm.simulate(self, rowid, targets, constraints, inputs, N)
        if rowid in self.data:
            # An incorporated row simulates to its recorded value.
            return {self.outputs[0]: self.data[rowid]}
        log_weights = [
            Bernoulli.calc_predictive_logp(
                outcome, self.N, self.x_sum, self.alpha, self.beta)
            for outcome in (0, 1)
        ]
        return {self.outputs[0]: gu.log_pflip(log_weights, rng=self.rng)}

    def logpdf_score(self):
        return Bernoulli.calc_logpdf_marginal(
            self.N, self.x_sum, self.alpha, self.beta)

    ##################
    # NON-GPM METHOD #
    ##################

    def transition_params(self):
        # Collapsed model: there are no explicit parameters to resample.
        return

    def set_hypers(self, hypers):
        new_alpha, new_beta = hypers['alpha'], hypers['beta']
        assert new_alpha > 0
        assert new_beta > 0
        self.alpha = new_alpha
        self.beta = new_beta

    def get_hypers(self):
        return dict(alpha=self.alpha, beta=self.beta)

    def get_params(self):
        return dict()

    def get_suffstats(self):
        return dict(N=self.N, x_sum=self.x_sum)

    def get_distargs(self):
        return dict(k=2)

    @staticmethod
    def construct_hyper_grids(X, n_grid=30):
        upper = float(len(X))
        return {
            'alpha': gu.log_linspace(1., upper, n_grid),
            'beta': gu.log_linspace(1., upper, n_grid),
        }

    @staticmethod
    def name():
        return 'bernoulli'

    @staticmethod
    def is_collapsed():
        return True

    @staticmethod
    def is_continuous():
        return False

    @staticmethod
    def is_conditional():
        return False

    @staticmethod
    def is_numeric():
        return False

    ##################
    # HELPER METHODS #
    ##################

    @staticmethod
    def calc_predictive_logp(x, N, x_sum, alpha, beta):
        # p(x=1) = (x_sum + alpha) / (N + alpha + beta); p(x=0) is the rest.
        log_denom = log(N + alpha + beta)
        numer = x_sum + alpha if x == 1 else N - x_sum + beta
        return log(numer) - log_denom

    @staticmethod
    def calc_logpdf_marginal(N, x_sum, alpha, beta):
        # Log marginal likelihood of the data: B(a + ones, b + zeros) / B(a, b).
        return betaln(x_sum + alpha, N - x_sum + beta) - betaln(alpha, beta)
151 |
--------------------------------------------------------------------------------
/src/primitives/categorical.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from math import log
18 |
19 | import numpy as np
20 |
21 | from scipy.special import gammaln
22 |
23 | from cgpm.primitives.distribution import DistributionGpm
24 | from cgpm.utils import general as gu
25 |
26 |
27 | class Categorical(DistributionGpm):
28 | """Categorical distribution with symmetric dirichlet prior on
29 | category weight vector v.
30 |
31 | k := distarg
32 | v ~ Symmetric-Dirichlet(alpha/k)
33 | x ~ Categorical(v)
34 | http://www.cs.berkeley.edu/~stephentu/writeups/dirichlet-conjugate-prior.pdf
35 | """
36 |
def __init__(self, outputs, inputs, hypers=None, params=None,
        distargs=None, rng=None):
    """Build an empty Categorical over the categories {0, ..., k-1}.

    `distargs['k']` (the number of categories) is required. A `distargs` of
    None is reported the same way as a missing `k`, assuming
    DistributionGpm.__init__ itself accepts None (TODO confirm). `hypers`
    may supply the concentration `alpha`, which defaults to 1 and must be
    > 0 (the same constraint set_hypers enforces).

    Raises ValueError if `k` is not supplied.
    """
    DistributionGpm.__init__(
        self, outputs, inputs, hypers, params, distargs, rng)
    # Distargs. Guard against distargs=None so the intended ValueError is
    # raised instead of an AttributeError on None.get.
    k = (distargs or {}).get('k', None)
    if k is None:
        raise ValueError('Categorical requires distarg `k`.')
    self.k = int(k)
    # Sufficient statistics: total count and per-category counts.
    self.N = 0
    self.counts = np.zeros(self.k)
    # Hyperparameters.
    if hypers is None: hypers = {}
    self.alpha = hypers.get('alpha', 1.)
    assert self.alpha > 0
52 |
def incorporate(self, rowid, observation, inputs=None):
    """Record the category observed for `rowid` and bump its count.

    Raises ValueError unless the value is integer-valued and in [0, k).
    """
    DistributionGpm.incorporate(self, rowid, observation, inputs)
    value = observation[self.outputs[0]]
    if value % 1 != 0 or not 0 <= value < self.k:
        raise ValueError('Invalid Categorical(%d): %s' % (self.k, value))
    category = int(value)
    self.N += 1
    self.counts[category] += 1
    self.data[rowid] = category
62 |
def unincorporate(self, rowid):
    """Forget `rowid` and roll back its contribution to the counts."""
    category = self.data.pop(rowid)
    self.counts[category] -= 1
    self.N -= 1
67 |
def logpdf(self, rowid, targets, constraints=None, inputs=None):
    """Posterior predictive log probability; -inf outside {0, ..., k-1}."""
    DistributionGpm.logpdf(self, rowid, targets, constraints, inputs)
    value = targets[self.outputs[0]]
    in_support = value % 1 == 0 and 0 <= value < self.k
    if not in_support:
        return -float('inf')
    return Categorical.calc_predictive_logp(
        int(value), self.N, self.counts, self.alpha)
75 |
76 | @gu.simulate_many
77 | def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
78 | DistributionGpm.simulate(self, rowid, targets, constraints, inputs, N)
79 | if rowid in self.data:
80 | return {self.outputs[0]: self.data[rowid]}
81 | x = gu.pflip(self.counts + self.alpha, rng=self.rng)
82 | return {self.outputs[0]: x}
83 |
84 | def logpdf_score(self):
85 | return Categorical.calc_logpdf_marginal(self.N, self.counts, self.alpha)
86 |
87 | ##################
88 | # NON-GPM METHOD #
89 | ##################
90 |
91 | def transition_params(self):
92 | return
93 |
94 | def set_hypers(self, hypers):
95 | assert hypers['alpha'] > 0
96 | self.alpha = hypers['alpha']
97 |
98 | def get_hypers(self):
99 | return {'alpha': self.alpha}
100 |
101 | def get_params(self):
102 | return {}
103 |
104 | def get_suffstats(self):
105 | return {'N' : self.N, 'counts' : list(self.counts)}
106 |
107 | def get_distargs(self):
108 | return {'k': self.k}
109 |
110 | @staticmethod
111 | def construct_hyper_grids(X, n_grid=30):
112 | grids = dict()
113 | grids['alpha'] = gu.log_linspace(1., float(len(X)), n_grid)
114 | return grids
115 |
116 | @staticmethod
117 | def name():
118 | return 'categorical'
119 |
120 | @staticmethod
121 | def is_collapsed():
122 | return True
123 |
124 | @staticmethod
125 | def is_continuous():
126 | return False
127 |
128 | @staticmethod
129 | def is_conditional():
130 | return False
131 |
132 | @staticmethod
133 | def is_numeric():
134 | return False
135 |
136 | ##################
137 | # HELPER METHODS #
138 | ##################
139 |
140 | @staticmethod
141 | def validate(x, K):
142 | return int(x) == float(x) and 0 <= x < K
143 |
144 | @staticmethod
145 | def calc_predictive_logp(x, N, counts, alpha):
146 | numer = log(alpha + counts[x])
147 | denom = log(np.sum(counts) + alpha * len(counts))
148 | return numer - denom
149 |
150 | @staticmethod
151 | def calc_logpdf_marginal(N, counts, alpha):
152 | K = len(counts)
153 | A = K * alpha
154 | lg = sum(gammaln(counts[k] + alpha) for k in xrange(K))
155 | return gammaln(A) - gammaln(A+N) + lg - K * gammaln(alpha)
156 |
--------------------------------------------------------------------------------
/src/regressions/README:
--------------------------------------------------------------------------------
1 | These CGPMs are discriminative, i.e. numeric- or symbolic-valued regression.
2 |
3 |
--------------------------------------------------------------------------------
/src/regressions/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/uncorrelated/README:
--------------------------------------------------------------------------------
1 | These CGPMs are zero-correlation bivariate functions on the R2 plane.
2 |
3 |
--------------------------------------------------------------------------------
/src/uncorrelated/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/uncorrelated/diamond.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from cgpm.cgpm import CGpm
18 | from cgpm.network.importance import ImportanceNetwork
19 | from cgpm.uncorrelated.directed import DirectedXyGpm
20 | from cgpm.uncorrelated.uniformx import UniformX
21 | from cgpm.utils import general as gu
22 |
23 |
class DiamondY(CGpm):
    """Y given X on the boundary of the diamond |X| + |Y| = 1.

    With probability .5 each, Y is placed on one of the two diamond edges
    above X and shifted toward the interior by U(0, noise). Y is clamped so
    it never crosses the opposite edge.
    """

    def __init__(self, outputs=None, inputs=None, noise=None, rng=None):
        if rng is None:
            rng = gu.gen_rng(1)
        if outputs is None:
            outputs = [0]
        if inputs is None:
            inputs = [1]
        if noise is None:
            noise = .1
        self.rng = rng
        self.outputs = outputs
        self.inputs = inputs
        self.noise = noise

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        assert targets == self.outputs
        assert list(inputs.keys()) == self.inputs
        x = inputs[self.inputs[0]]
        slope = self.rng.rand()
        noise = self.rng.uniform(high=self.noise)
        # NaN is the only value unequal to itself; it has no diamond edge.
        if x != x:
            raise ValueError('DiamondY requires a numeric input: %s' % (x,))
        # slope < .5 selects the edge with positive slope (y = x+1 on the
        # left, y = x-1 on the right); otherwise the negative-slope edge.
        # x == 0 is handled by the right-hand branch: both sides agree there.
        positive = slope < .5
        if x < 0:
            if positive:
                y = max(-x-1, x+1 - noise)
            else:
                y = min(x+1, -x-1 + noise)
        else:
            if positive:
                y = min(-x+1, x-1 + noise)
            else:
                y = max(x-1, -x+1 - noise)
        return {self.outputs[0]: y}

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        raise NotImplementedError
60 |
61 |
class Diamond(DirectedXyGpm):
    """(X, Y) on the boundary of the diamond |X| + |Y| = 1.

    X ~ U(-1, 1); Y given X lies on one of the two diamond edges above X
    (w.p. .5 each), shifted toward the interior by U(0, noise).
    """

    def __init__(self, outputs=None, inputs=None, noise=None, rng=None):
        DirectedXyGpm.__init__(
            self, outputs=outputs, inputs=inputs, noise=noise, rng=rng)
        # Give the components the parent's entropy source; otherwise they
        # fall back to fixed default seeds and `rng` has no effect on the
        # simulated values.
        self.x = UniformX(
            outputs=[self.outputs[0]], low=-1, high=1, rng=self.rng)
        self.y = DiamondY(
            outputs=[self.outputs[1]],
            inputs=[self.outputs[0]],
            noise=self.noise,
            rng=self.rng)
        self.network = ImportanceNetwork([self.x, self.y], rng=self.rng)
75 |
--------------------------------------------------------------------------------
/src/uncorrelated/directed.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from cgpm.cgpm import CGpm
18 | from cgpm.utils.general import gen_rng
19 |
20 |
class DirectedXyGpm(CGpm):
    """Interface directed two-dimensional GPMs over the R2 plane.

    Subclasses build `self.network`, e.g. an ImportanceNetwork over an X
    component and a Y-given-X component. logpdf and simulate delegate to it.
    """

    def __init__(self, outputs=None, inputs=None, noise=None, acc=1, rng=None):
        """Initialize the Gpm with given noise parameter.

        Parameters
        ----------
        outputs : list, optional
            The two output ids, default [0, 1].
        inputs : ignored
            Directed XY GPMs take no exogenous inputs; self.inputs is [].
        noise : float
            Value in (0,1) indicating the noise level of the distribution.
        acc : int
            Stored as self.acc for use by subclasses.
        rng : np.random.RandomState, optional.
            Source of entropy.
        """
        if type(self) is DirectedXyGpm:
            raise Exception('Cannot directly instantiate DirectedXyGpm.')
        if rng is None:
            rng = gen_rng(0)
        if outputs is None:
            outputs = [0, 1]
        if noise is None:
            noise = .1
        self.rng = rng
        self.outputs = outputs
        # Match UnDirectedXyGpm: no exogenous inputs.
        self.inputs = []
        self.noise = noise
        self.acc = acc
        # Override the network in subclass.
        self.network = None

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        if self.network is None:
            raise ValueError('self.network not defined by %s' % (type(self),))
        # Forward in CGpm argument order (rowid, targets, constraints, inputs).
        return self.network.logpdf(rowid, targets, constraints, inputs)

    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        if self.network is None:
            raise ValueError('self.network not defined by %s' % (type(self),))
        # Forward in CGpm argument order; N must not land in the inputs slot.
        return self.network.simulate(rowid, targets, constraints, inputs, N)
58 |
--------------------------------------------------------------------------------
/src/uncorrelated/dots.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 |
19 | from scipy.stats import norm
20 |
21 | from cgpm.uncorrelated.undirected import UnDirectedXyGpm
22 | from cgpm.utils import general as gu
23 |
24 |
class Dots(UnDirectedXyGpm):
    """(X,Y) ~ Four Dots.

    Equal-weight mixture of four isotropic normals with scale `noise`,
    centred at (+/-1, +/-1). Every combination of the X and Y centres
    appears, so X and Y are independent. Each has the marginal
    .5 N(-1, noise^2) + .5 N(1, noise^2).
    """

    mx = [ -1, 1, -1, 1]
    my = [ -1, -1, 1, 1]

    def simulate_joint(self):
        n = self.rng.randint(4)
        x = self.rng.normal(loc=self.mx[n], scale=self.noise)
        y = self.rng.normal(loc=self.my[n], scale=self.noise)
        return [x, y]

    def simulate_conditional(self, z):
        # X and Y are independent with identical marginals, so conditioning
        # on z changes nothing; reuse a marginal draw.
        return self.simulate_joint()[0]

    def logpdf_joint(self, x, y):
        return gu.logsumexp([np.log(.25)
            + norm.logpdf(x, loc=cx, scale=self.noise)
            + norm.logpdf(y, loc=cy, scale=self.noise)
            for (cx, cy) in zip(self.mx, self.my)])

    def logpdf_marginal(self, z):
        return gu.logsumexp(
            [np.log(.5) + norm.logpdf(z, loc=cx, scale=self.noise)
            for cx in set(self.mx)])

    def logpdf_conditional(self, w, z):
        # By independence p(w | z) = p(w). The density is evaluated at the
        # target w, not at the conditioning value z.
        return self.logpdf_marginal(w)
53 |
--------------------------------------------------------------------------------
/src/uncorrelated/linear.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 |
19 | from scipy.stats import norm
20 |
21 | from cgpm.uncorrelated.undirected import UnDirectedXyGpm
22 | from cgpm.utils import mvnormal as multivariate_normal
23 |
24 |
class Linear(UnDirectedXyGpm):
    """(X, Y) bivariate normal with zero means, unit variances and
    correlation rho = 1 - noise."""

    def _rho(self):
        # Correlation of X and Y (equal to their covariance: unit variances).
        return 1-self.noise

    def simulate_joint(self):
        rho = self._rho()
        return self.rng.multivariate_normal([0,0], [[1,rho],[rho,1]])

    def simulate_conditional(self, z):
        loc = self.conditional_mean(z)
        scale = np.sqrt(self.conditional_variance(z))
        return self.rng.normal(loc=loc, scale=scale)

    def logpdf_joint(self, x, y):
        rho = self._rho()
        point = np.array([x, y])
        center = np.array([0, 0])
        sigma = np.array([[1, rho], [rho, 1]])
        return multivariate_normal.logpdf(point, center, sigma)

    def logpdf_marginal(self, z):
        # Each coordinate is marginally standard normal.
        return norm.logpdf(z, scale=1)

    def logpdf_conditional(self, w, z):
        loc = self.conditional_mean(z)
        scale = np.sqrt(self.conditional_variance(z))
        return norm.logpdf(w, loc=loc, scale=scale)

    def conditional_mean(self, z):
        return self._rho()*z

    def conditional_variance(self, z):
        # Does not depend on z for a bivariate normal.
        return 1-self._rho()**2

    def mutual_information(self):
        # I(X; Y) = -1/2 log(1 - rho^2) for a bivariate normal.
        return -.5 * np.log(1-self._rho()**2)
58 |
--------------------------------------------------------------------------------
/src/uncorrelated/parabola.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 |
19 | from scipy.misc import logsumexp
20 | from scipy.stats import uniform
21 |
22 | from cgpm.cgpm import CGpm
23 | from cgpm.network.importance import ImportanceNetwork
24 | from cgpm.uncorrelated.directed import DirectedXyGpm
25 | from cgpm.uncorrelated.uniformx import UniformX
26 | from cgpm.utils import general as gu
27 |
28 |
class ParabolaY(CGpm):
    """Y given X: Y = S * (X^2 + E), S = +/-1 w.p. .5, E ~ U(-noise, noise)."""

    def __init__(self, outputs=None, inputs=None, noise=None, rng=None):
        if rng is None:
            rng = gu.gen_rng(1)
        if outputs is None:
            outputs = [0]
        if inputs is None:
            inputs = [1]
        if noise is None:
            noise = .1
        self.rng = rng
        self.outputs = outputs
        self.inputs = inputs
        self.noise = noise
        # Distribution of the additive term E.
        self.uniform = uniform(loc=-self.noise, scale=2*self.noise)

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        assert targets == self.outputs
        assert list(inputs.keys()) == self.inputs
        assert not constraints
        x = inputs[self.inputs[0]]
        u = self.rng.rand()
        noise = self.rng.uniform(low=-self.noise, high=self.noise)
        if u < .5:
            y = x**2 + noise
        else:
            y = -(x**2 + noise)
        return {self.outputs[0]: y}

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        assert list(targets.keys()) == self.outputs
        assert list(inputs.keys()) == self.inputs
        assert not constraints
        x = inputs[self.inputs[0]]
        y = targets[self.outputs[0]]
        # Mixture over the sign S: for S = +1 the noise is y - x^2, and for
        # S = -1 it is -y - x^2. Uses gu.logsumexp like the sibling modules
        # (scipy.misc.logsumexp was removed in SciPy 1.3).
        return gu.logsumexp([
            np.log(.5) + self.uniform.logpdf(y - x**2),
            np.log(.5) + self.uniform.logpdf(-y - x**2),
        ])
69 |
70 |
class Parabola(DirectedXyGpm):
    """X ~ U(-1, 1); Y = (+/- w.p .5) (X^2 + U(-noise, noise))."""

    def __init__(self, outputs=None, inputs=None, noise=None, rng=None):
        DirectedXyGpm.__init__(
            self, outputs=outputs, inputs=inputs, noise=noise, rng=rng)
        # Give the components the parent's entropy source; otherwise they
        # fall back to fixed default seeds and `rng` has no effect on the
        # simulated values.
        self.x = UniformX(
            outputs=[self.outputs[0]], low=-1, high=1, rng=self.rng)
        self.y = ParabolaY(
            outputs=[self.outputs[1]],
            inputs=[self.outputs[0]],
            noise=self.noise,
            rng=self.rng)
        self.network = ImportanceNetwork([self.x, self.y], rng=self.rng)
84 |
--------------------------------------------------------------------------------
/src/uncorrelated/ring.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 |
19 | from cgpm.uncorrelated.undirected import UnDirectedXyGpm
20 |
21 |
class Ring(UnDirectedXyGpm):
    """Points near the unit circle: sqrt(X**2 + Y**2) = 1 - U(0, noise).

    The angle is uniform on [0, 2*pi) and the radius uniform on
    [1 - noise, 1].
    """

    def simulate_joint(self):
        theta = self.rng.uniform(0., 2.*np.pi)
        radius = self.rng.uniform(1.-self.noise, 1.)
        return [np.cos(theta)*radius, np.sin(theta)*radius]
30 | return [x, y]
31 |
--------------------------------------------------------------------------------
/src/uncorrelated/undirected.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from cgpm.cgpm import CGpm
18 | from cgpm.utils import general as gu
19 |
20 |
class UnDirectedXyGpm(CGpm):
    """Interface undirected two-dimensional GPMs over the R2 plane.

    Subclasses implement the internal simulate_* / logpdf_* hooks. The
    public logpdf/simulate pick the joint, marginal or conditional hook from
    the shape of targets and constraints.
    """

    def __init__(self, outputs=None, inputs=None, noise=None, rng=None):
        if rng is None:
            rng = gu.gen_rng(0)
        if outputs is None:
            outputs = [0, 1]
        if noise is None:
            noise = .1
        self.rng = rng
        self.outputs = outputs
        self.inputs = []
        self.noise = noise

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        assert not inputs
        if not constraints:
            if len(targets) == 2:
                # Look values up by output id: dict iteration order need not
                # match the (x, y) order of self.outputs.
                x = targets[self.outputs[0]]
                y = targets[self.outputs[1]]
                return self.logpdf_joint(x, y)
            z = list(targets.values())[0]
            return self.logpdf_marginal(z)
        assert len(constraints) == len(targets) == 1
        z = list(constraints.values())[0]
        w = list(targets.values())[0]
        return self.logpdf_conditional(w, z)

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        assert not inputs
        if not constraints:
            sample = self.simulate_joint()
            return {q: sample[self.outputs.index(q)] for q in targets}
        assert len(constraints) == len(targets) == 1
        z = list(constraints.values())[0]
        return {targets[0]: self.simulate_conditional(z)}

    # Internal simulators and assesors.

    def simulate_joint(self):
        raise NotImplementedError

    def simulate_conditional(self, z):
        raise NotImplementedError

    def logpdf_marginal(self, z):
        raise NotImplementedError

    def logpdf_joint(self, x, y):
        raise NotImplementedError

    def logpdf_conditional(self, w, z):
        raise NotImplementedError

    def mutual_information(self):
        raise NotImplementedError
80 |
--------------------------------------------------------------------------------
/src/uncorrelated/uniformx.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from scipy.stats import uniform
18 |
19 | from cgpm.cgpm import CGpm
20 | from cgpm.utils import general as gu
21 |
22 |
class UniformX(CGpm):
    """Univariate X ~ Uniform(low, high) with no inputs."""

    def __init__(self, outputs=None, inputs=None, low=0, high=1, rng=None):
        assert not inputs
        self.rng = gu.gen_rng(0) if rng is None else rng
        self.low = low
        self.high = high
        self.outputs = [0] if outputs is None else outputs
        self.inputs = []
        # scipy parametrizes the uniform by its left edge and width.
        width = self.high - self.low
        self.uniform = uniform(loc=self.low, scale=width)

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        assert not constraints
        assert targets == self.outputs
        draw = self.rng.uniform(low=self.low, high=self.high)
        return {self.outputs[0]: draw}

    def logpdf(self, rowid, targets, constraints=None, inputs=None):
        assert not constraints
        assert not inputs
        assert targets.keys() == self.outputs
        value = targets[self.outputs[0]]
        return self.uniform.logpdf(value)
51 |
--------------------------------------------------------------------------------
/src/uncorrelated/xcross.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 |
19 | from cgpm.uncorrelated.undirected import UnDirectedXyGpm
20 | from cgpm.utils import general as gu
21 | from cgpm.utils import mvnormal as multivariate_normal
22 |
23 |
class XCross(UnDirectedXyGpm):
    """(X, Y) ~ .5 N2(0, S+) + .5 N2(0, S-): an X-shaped cloud.

    Both covariances have unit variances; S+ has correlation 1 - noise and
    S- has correlation -(1 - noise). Informally, Y = (+/- w.p .5) X + noise.
    Each coordinate is marginally N(0, 1).
    """

    def simulate_joint(self):
        if self.rng.rand() < .5:
            cov = np.array([[1,1-self.noise],[1-self.noise,1]])
        else:
            cov = np.array([[1,-1+self.noise],[-1+self.noise,1]])
        return self.rng.multivariate_normal([0,0], cov=cov)

    def simulate_conditional(self, z):
        """Sample one coordinate given the other equals z."""
        rho = 1 - self.noise
        # Both components give z the same N(0, 1) marginal, so each remains
        # equally likely after conditioning on z.
        sign = 1 if self.rng.rand() < .5 else -1
        return self.rng.normal(loc=sign*rho*z, scale=np.sqrt(1 - rho**2))

    def logpdf_marginal(self, z):
        from scipy.stats import norm
        return norm.logpdf(z)

    def logpdf_joint(self, x, y):
        X = np.array([x, y])
        Mu = np.array([0, 0])
        Sigma0 = np.array([[1, 1 - self.noise], [1 - self.noise, 1]])
        Sigma1 = np.array([[1, -1 + self.noise], [-1 + self.noise, 1]])
        return gu.logsumexp([
            np.log(.5)+multivariate_normal.logpdf(X, Mu, Sigma0),
            np.log(.5)+multivariate_normal.logpdf(X, Mu, Sigma1),
        ])

    def logpdf_conditional(self, w, z):
        """Log density of one coordinate w given the other equals z."""
        from scipy.stats import norm
        rho = 1 - self.noise
        scale = np.sqrt(1 - rho**2)
        # Equal-weight mixture of the two component conditionals
        # N(+/- rho*z, 1 - rho^2).
        return gu.logsumexp([
            np.log(.5) + norm.logpdf(w, loc=rho*z, scale=scale),
            np.log(.5) + norm.logpdf(w, loc=-rho*z, scale=scale),
        ])
43 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import re
19 |
20 | from datetime import datetime
21 |
22 | import importlib
23 |
24 |
# Maps each cctype name to the (module path, class name) implementing it.
# cctype_class() imports the module on demand; valid_cctype() and
# all_cctypes() consult the keys.
cctype_class_lookup = {
    'bernoulli' : ('cgpm.primitives.bernoulli', 'Bernoulli'),
    'beta' : ('cgpm.primitives.beta', 'Beta'),
    'categorical' : ('cgpm.primitives.categorical', 'Categorical'),
    'crp' : ('cgpm.primitives.crp', 'Crp'),
    'exponential' : ('cgpm.primitives.exponential', 'Exponential'),
    'geometric' : ('cgpm.primitives.geometric', 'Geometric'),
    'linear_regression' : ('cgpm.regressions.linreg', 'LinearRegression'),
    'lognormal' : ('cgpm.primitives.lognormal', 'Lognormal'),
    'normal' : ('cgpm.primitives.normal', 'Normal'),
    'normal_trunc' : ('cgpm.primitives.normal_trunc', 'NormalTrunc'),
    'poisson' : ('cgpm.primitives.poisson', 'Poisson'),
    'random_forest' : ('cgpm.regressions.forest', 'RandomForest'),
    'vonmises' : ('cgpm.primitives.vonmises', 'Vonmises'),
}

# Maps cctype names to loom feature types; read by loom_stattype(). The
# regression cctypes have no loom equivalent and are absent.
# https://github.com/posterior/loom/blob/master/doc/using.md#input-format
cctype_loom_lookup = {
    'bernoulli' : 'boolean',
    'beta' : 'real',
    'categorical' : 'categorical',
    'crp' : 'unbounded_categorical',
    'exponential' : 'real',
    'geometric' : 'real',
    'lognormal' : 'real',
    'normal' : 'real',
    'normal_trunc' : 'real',
    'poisson' : 'count',
    'vonmises' : 'real',
}
55 |
def timestamp():
    """Return the current local time as a 'YYYYmmdd-HHMMSS' string."""
    now = datetime.now()
    return now.strftime('%Y%m%d-%H%M%S')
58 |
def colors():
    """Return a fresh list of nine named colors, in a fixed order."""
    palette = (
        'red', 'blue', 'green', 'yellow', 'orange', 'purple', 'brown',
        'black', 'pink',
    )
    return list(palette)
63 |
def cctype_class(cctype):
    """Return the class object implementing the named cctype.

    There is no default: a falsy `cctype` raises ValueError, and a name
    missing from cctype_class_lookup raises KeyError. The implementing
    module is imported on demand.
    """
    if not cctype:
        raise ValueError('Specify a cctype!')
    modulename, classname = cctype_class_lookup[cctype]
    mod = importlib.import_module(modulename)
    return getattr(mod, classname)
71 |
def loom_stattype(cctype, distargs):
    """Translate a cgpm cctype into the matching loom feature type.

    Raises ValueError when loom has no equivalent type.
    """
    # XXX Loom's bounded categorical holds at most 256 values; larger ones
    # must be modelled as unbounded_categorical (aka crp).
    effective = cctype
    if cctype == 'categorical' and distargs['k'] > 256:
        effective = 'crp'
    stattype = cctype_loom_lookup.get(effective)
    if stattype is None:
        raise ValueError(
            'Cannot convert cgpm type to loom type: %s' % (effective,))
    return stattype
82 |
def valid_cctype(dist):
    """Check whether `dist` names a registered DistributionGpm.

    Returns True iff `dist` is a key of cctype_class_lookup.
    """
    is_known = dist in cctype_class_lookup
    return is_known
86 |
def all_cctypes():
    """Returns a list of all known DistributionGpm cctype names.

    Builds an actual list: on Python 3, dict.keys() is a view, not a list.
    """
    return list(cctype_class_lookup)
90 |
def parse_distargs(dists):
    """Split cctype strings into bare cctypes and their distargs.

    A cctype may carry comma-separated `key=value` distargs in parentheses.
    Every value is parsed as a float, and cctypes without parentheses get
    None. Example:

        parse_distargs(['normal', 'categorical(k=8)', 'beta'])
        == (['normal', 'categorical', 'beta'], [None, {'k': 8.0}, None])

    Whitespace around names and keys is ignored. Empty parentheses or a
    trailing comma contribute no entries.
    """
    cctypes, distargs = [], []
    for cctype in dists:
        keywords = None
        # Raw string: '\(' in a plain literal is an invalid escape sequence.
        match = re.search(r'\(.*\)', cctype)
        if match is not None:
            body = match.group(0).replace('(', '').replace(')', '')
            keywords = {}
            for subpair in body.split(','):
                if not subpair.strip():
                    continue
                key, val = subpair.split('=')
                keywords[key.strip()] = float(val)
            cctype = cctype[:cctype.index('(')]
        cctypes.append(cctype.strip())
        distargs.append(keywords)
    return cctypes, distargs
111 |
def check_env_debug():
    """Read the GPMCCDEBUG environment variable.

    Returns False when it is unset, otherwise its value converted by int().
    """
    raw = os.environ.get('GPMCCDEBUG')
    if raw is None:
        return False
    return int(raw)
115 |
--------------------------------------------------------------------------------
/src/utils/parallel_map.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2016 MIT Probabilistic Computing Project.
2 | #
3 | # This file is part of Venture.
4 | #
5 | # Venture is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # Venture is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
16 | # along with Venture. If not, see .
17 |
18 | import cPickle as pickle
19 | import os
20 | import struct
21 | import traceback
22 |
23 | from multiprocessing import Process
24 | from multiprocessing import Pipe
25 | from multiprocessing import cpu_count
26 |
27 |
28 | def le32enc(n):
29 | return struct.pack(' (cos x) 0) (uniform_continuous (- (cos x) .5) (cos x)) (uniform_continuous (cos x) (+ (cos x) .5))))''',
31 | rng=gu.gen_rng(4))
32 |
# The CGPM for Y with seed 5. It appears to use the same expression as the
# seed-4 CGPM for Y, differing only in the RNG seed. Y lies within .5 of
# cos(x): uniform on [cos(x) - .5, cos(x)] when cos(x) > 0, otherwise uniform
# on [cos(x), cos(x) + .5].
vsy_s5 = InlineVsCGpm([1], [0],
    expression='''(lambda (x) (if (> (cos x) 0) (uniform_continuous (- (cos x) .5) (cos x)) (uniform_continuous (cos x) (+ (cos x) .5))))''',
    rng=gu.gen_rng(5))

# Simulate the uniform x from vsx (200 samples, indexed by output id 0).
samples_x4 = vsx_s4.simulate(0, [0], N=200)

# Simulate Y from each of the seed 4 and 5 CGPMs, passing each x sample as the
# input for the corresponding Y draw.
samples_y4 = [vsy_s4.simulate(0, [1], sx) for sx in samples_x4]
samples_y5 = [vsy_s5.simulate(0, [1], sx) for sx in samples_x4]

# Convert all samples from dictionaries (keyed by output id) to lists.
xs4 = [s[0] for s in samples_x4]
ys4 = [s[1] for s in samples_y4]
ys5 = [s[1] for s in samples_y5]

# Compute the error cos(x) - y at each data point for both sample sets; by the
# expression above every error lies in [-.5, .5].
errors1 = np.cos(xs4)-ys4
errors2 = np.cos(xs4)-ys5

# Plot the joint query.
fig, ax = plt.subplots()
# NOTE(review): the labels below say "no noise" (blue) vs "noise" (red), yet
# both series are drawn from the same expression and differ only in seed --
# confirm what the labels are meant to convey.
ax.scatter(xs4, ys4, color='blue', alpha=.4) # There is no noise in the cosx.
ax.scatter(xs4, ys5, color='red', alpha=.4) # This is noise in the cosx.
ax.set_xlim([-1.5*np.pi, 1.5*np.pi])
ax.set_ylim([-1.75, 1.75])
# Rug of tick marks along the bottom edge showing where each x sample fell.
for x in xs4:
    ax.vlines(x, -1.75, -1.65, linewidth=.5)
ax.grid()

# Plot the errors.
fig, ax = plt.subplots()
ax.scatter(xs4, errors1, color='blue', alpha=.4)
ax.scatter(xs4, errors2, color='red', alpha=.4)
ax.set_xlabel('value of x')
ax.set_ylabel('error of cos(x) - y')
ax.grid()
71 |
--------------------------------------------------------------------------------
/tests/graphical/animals.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 |
19 | import matplotlib.pyplot as plt
20 | import numpy as np
21 | import pandas as pd
22 | import pytest
23 |
24 | from cgpm.crosscat.engine import Engine
25 | from cgpm.utils import general as gu
26 | from cgpm.utils import plots as pu
27 | from cgpm.utils import test as tu
28 | from cgpm.utils import render as ru
29 |
30 |
# Animals table: one row per animal (index), one column per feature.
animals = pd.read_csv('resources/animals/animals.csv', index_col=0)
# Raw feature matrix together with its row labels and column labels.
animal_values, animal_names, animal_features = (
    animals.values,
    animals.index.values,
    animals.columns.values,
)
35 |
36 |
# XXX This function should be parametrized better!
def launch_analysis():
    """Fit CrossCat to the animals table, pickle it, and plot dependence.

    Builds a 64-state engine over the animals features (each modeled as a
    2-category categorical), runs 900 transitions, writes it to
    resources/animals/animals.engine, reloads it from disk, and draws a
    clustermap of the pairwise dependence probabilities.
    """
    n_cols = len(animals.values[0])
    engine = Engine(
        animals.values.astype(float),
        num_states=64,
        cctypes=['categorical']*n_cols,
        distargs=[{'k':2}]*n_cols,
        rng=gu.gen_rng(7))

    engine.transition(N=900)
    # Pickles are byte streams: use binary mode so no newline translation
    # can corrupt them on platforms that distinguish text and binary files.
    with open('resources/animals/animals.engine', 'wb') as f:
        engine.to_pickle(f)

    # Reload through a context manager so the file handle is always closed.
    with open('resources/animals/animals.engine', 'rb') as f:
        engine = Engine.from_pickle(f)
    D = engine.dependence_probability_pairwise()
    pu.plot_clustermap(D)
53 |
54 |
def render_states_to_disk(filepath, prefix):
    """Render each state of a pickled engine to the file ``<prefix>-<i>``.

    Rows and columns are labelled with the animal names and feature names;
    the index of the state being rendered is printed as progress.
    """
    engine = Engine.from_pickle(filepath)
    num_states = engine.num_states()
    for index in range(num_states):
        print('\r%d' % (index,))
        target = '%s-%d' % (prefix, index)
        ru.viz_state(
            engine.get_state(index),
            row_names=animal_names,
            col_names=animal_features,
            savefile=target)
64 |
65 |
def compare_dependence_heatmap():
    """Show dependence heatmaps of two saved animals engines side by side.

    Both panels are ordered by the row dendrogram of the first engine's
    clustermap, so the same cell position refers to the same column pair.
    """
    paths = (
        'resources/animals/animals.engine',
        'resources/animals/animals-lovecat.engine',
    )
    engines = [Engine.from_pickle(path) for path in paths]
    deps = [e.dependence_probability_pairwise() for e in engines]
    order = pu.plot_clustermap(deps[0]).dendrogram_row.reordered_ind

    _fig, axes = plt.subplots(nrows=1, ncols=2)
    for dep, panel in zip(deps, axes):
        pu.plot_heatmap(dep, xordering=order, yordering=order, ax=panel)
79 |
--------------------------------------------------------------------------------
/tests/graphical/depprob_id.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import matplotlib.pyplot as plt
18 | import numpy as np
19 | import seaborn as sns
20 |
21 | from cgpm.crosscat.engine import Engine
22 | from cgpm.utils import config as cu
23 |
np.random.seed(0)

N_ROWS = 300
N_STATES = 12
N_ITERS = 100

# Column 0 is a row identifier: categorical with one category per row.
cctypes = ['categorical(k={})'.format(N_ROWS)] + ['normal']*8
cctypes, distargs = cu.parse_distargs(cctypes)
column_names = ['id'] + ['one cluster']*4 + ['four cluster']*4

# id column.
X = np.zeros((N_ROWS, 9))
X[:,0] = np.arange(N_ROWS)

# Four columns of one cluster from the standard normal.
X[:,1:5] = np.random.randn(N_ROWS, 4)

# Four columns of four clusters with unit variance. Each row draws a single
# cluster z in {0,1,2,3}, shared by all four columns, and its mean is 4*z,
# so the cluster means are {0,4,8,12}.
Z = np.random.randint(4, size=(N_ROWS))
X[:,5:] = 4*np.reshape(np.repeat(Z,4), (len(Z),4)) + np.random.randn(N_ROWS, 4)

# Inference.
engine = Engine(
    X, cctypes=cctypes, distargs=distargs, num_states=N_STATES)
engine.transition(N=N_ITERS)

# Dependence probability for every pair of columns, drawn as a clustered
# heatmap labelled with the column names.
D = engine.dependence_probability_pairwise()
zmat = sns.clustermap(D, yticklabels=column_names, xticklabels=column_names)
plt.setp(zmat.ax_heatmap.get_yticklabels(), rotation=0)
plt.setp(zmat.ax_heatmap.get_xticklabels(), rotation=90)
plt.show()
56 |
--------------------------------------------------------------------------------
/tests/graphical/dpmm_nignormal.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import matplotlib.pyplot as plt
17 |
18 | from cgpm.crosscat.state import State
19 |
def observe_datum(x):
    """Add x as a new row of column 0 and keep redrawing the fitted state.

    Incorporates x under the next free row id, transitions the dimension
    grids once, prints the observation, then loops forever: each pass runs
    row, hyperparameter and alpha transitions and redraws column 0's
    plot_dist on the module-level axes `ax`, pausing .8 seconds per frame.
    """
    global state
    state.incorporate(rowid=state.n_rows(), observation={0:x})
    state.transition_dim_grids()
    print 'Observation %d: %f' % (state.n_rows(), x)
    # NOTE(review): this loop never exits. plt.pause() services GUI events, so
    # a later click presumably reaches observe_datum again (via on_click) from
    # inside this loop, nesting another infinite loop -- confirm intended.
    while True:
        state.transition_view_rows()
        state.transition_dim_hypers()
        state.transition_view_alphas()
        ax.clear()
        state.dim_for(0).plot_dist(
            state.X[0], Y=np.linspace(0.01,0.99,200), ax=ax)
        ax.grid()
        plt.pause(.8)
34 |
def on_click(event):
    """Observe the x coordinate of a left-button click that lands in an axes."""
    is_left_click = event.button == 1
    if is_left_click and event.inaxes is not None:
        observe_datum(event.xdata)
39 |
# Seed a single-column normal state with one observation.
initial_point = .8
state = State([[initial_point]], cctypes=['normal'])

# Open an empty gridded figure; every left click inside it is routed through
# on_click and incorporated as a new datum.
fig, ax = plt.subplots()
ax.grid()
plt.connect('button_press_event', on_click)
plt.show()
49 |
--------------------------------------------------------------------------------
/tests/graphical/one_view.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 | import pytest
19 |
20 | from cgpm.crosscat.state import State
21 | from cgpm.utils import config as cu
22 | from cgpm.utils import general as gu
23 | from cgpm.utils import general as gu
24 | from cgpm.utils import test as tu
25 |
# Data generation: one column for each of nine statistical types.
column_types = [
    'normal', 'poisson', 'bernoulli', 'categorical(k=4)', 'lognormal',
    'exponential', 'beta', 'geometric', 'vonmises',
]
cctypes, distargs = cu.parse_distargs(column_types)

# 200 rows; the remaining arguments presumably give a single view (weight
# [1]), three row clusters weighted (.25, .25, .5), and a per-column
# separation of .95 -- see tu.gen_data_table.
T, Zv, Zc = tu.gen_data_table(
    200, [1], [[.25, .25, .5]], cctypes, distargs,
    [.95]*len(cctypes), rng=gu.gen_rng(10))

# Fit a state to the (transposed) generated table.
state = State(T.T, cctypes=cctypes, distargs=distargs, rng=gu.gen_rng(312))
state.transition(N=10, progress=1)
44 |
def test_crash_simulate_joint(state):
    """Jointly simulate all nine columns for a hypothetical row."""
    targets = list(range(9))
    state.simulate(-1, targets, N=10)

def test_crash_logpdf_joint(state):
    """Joint logpdf of values for all nine columns, hypothetical row."""
    values = {0:1, 1:2, 2:1, 3:3, 4:1, 5:10, 6:.4, 7:2, 8:1.8}
    state.logpdf(-1, values)

def test_crash_simulate_conditional(state):
    """Simulate six columns given the other three, hypothetical row."""
    targets = [1, 4, 5, 6, 7, 8]
    constraints = {0:1, 2:1, 3:3}
    state.simulate(-1, targets, constraints, None, 10)

def test_crash_logpdf_conditional(state):
    """Conditional logpdf of six columns given three, hypothetical row."""
    values = {1:2, 4:1, 5:10, 6:.4, 7:2, 8:1.8}
    constraints = {0:1, 2:1, 3:3}
    state.logpdf(-1, values, constraints)

def test_crash_simulate_joint_observed(state):
    """Simulating every column of observed row 1 is expected not to raise."""
    targets = list(range(9))
    state.simulate(1, targets, None, None, 10)

def test_crash_logpdf_joint_observed(state):
    """A joint logpdf query on observed row 1 must raise ValueError."""
    values = {0:1, 1:2, 2:1, 3:3, 4:1, 5:10, 6:.4, 7:2, 8:1.8}
    with pytest.raises(ValueError):
        state.logpdf(1, values)

def test_crash_simulate_conditional_observed(state):
    """A constrained simulate on observed row 1 must raise ValueError."""
    targets = [1, 4, 5, 6, 7, 8]
    constraints = {0:1, 2:1, 3:3}
    with pytest.raises(ValueError):
        state.simulate(1, targets, constraints, None, 10)

def test_crash_logpdf_conditional_observed(state):
    """A conditional logpdf on observed row 1 must raise ValueError."""
    values = {1:2, 4:1, 5:10, 6:.4, 7:2, 8:1.8}
    constraints = {0:1, 2:1, 3:3}
    with pytest.raises(ValueError):
        state.logpdf(1, values, constraints)
73 |
# Plot!
state.plot()

# Run every crash check defined above against the fitted state, in order.
for crash_check in (
        test_crash_simulate_joint,
        test_crash_logpdf_joint,
        test_crash_simulate_conditional,
        test_crash_logpdf_conditional,
        test_crash_simulate_joint_observed,
        test_crash_logpdf_joint_observed,
        test_crash_simulate_conditional_observed,
        test_crash_logpdf_conditional_observed):
    crash_check(state)

# Chain rule for state 1: log p(x0, x1) == log p(x0 | x1) + log p(x1).
joint = state.logpdf(-1, {0:1, 1:2})
log_conditional = state.logpdf(-1, {0:1}, {1:2})
log_marginal = state.logpdf(-1, {1:2})
chain = log_conditional + log_marginal
assert np.allclose(joint, chain)
91 |
# Disabled experiment (never runs as written): fit a second state with a
# different seed, then compute two-state Monte Carlo averages of the
# conditional, joint and marginal log densities. The results are computed but
# neither asserted nor printed.
if False:
    state2 = State(T.T, cctypes=cctypes, distargs=distargs, rng=gu.gen_rng(12))
    state2.transition(N=10, progress=1)

    # Joint equals chain rule for state 2.
    # NOTE(review): both expressions below are evaluated and discarded.
    state2.logpdf(-1, {0:1, 1:2})
    state2.logpdf(-1, {0:1}, {1:2}) + state2.logpdf(-1, {1:2})

    # Take the Monte Carlo average of the conditional, in log space:
    # log((p1 + p2) / 2) = log(.5) + logsumexp([log p1, log p2]).
    mc_conditional = np.log(.5) + gu.logsumexp([
        state.logpdf(-1, {0:1}, {1:2}),
        state2.logpdf(-1, {0:1}, {1:2})
    ])

    # Take the Monte Carlo average of the joint.
    mc_joint = np.log(.5) + gu.logsumexp([
        state.logpdf(-1, {0:1, 1:2}),
        state2.logpdf(-1, {0:1, 1:2})
    ])

    # Take the Monte Carlo average of the marginal.
    mc_marginal = np.log(.5) + gu.logsumexp([
        state.logpdf(-1, {1:2}),
        state2.logpdf(-1, {1:2})
    ])
117 |
--------------------------------------------------------------------------------
/tests/graphical/recover.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 | import matplotlib.pyplot as plt
19 |
20 | from cgpm.utils import test as tu
21 | from cgpm.utils import sampling as su
22 | from cgpm.crosscat.engine import Engine
23 |
# Synthetic two-dimensional shapes to recover, and the generator for each.
shapes = ['x', 'sin', 'ring', 'dots']
gen_function = dict(
    sin=tu.gen_sine_wave,
    x=tu.gen_x,
    ring=tu.gen_ring,
    dots=tu.gen_four_dots,
)

# Both coordinates are modeled as normal columns with no extra arguments.
cctypes = ['normal'] * 2
distargs = [None] * 2
34 |
35 |
def run_test(args):
    """Fit CrossCat to each synthetic shape and plot data vs. simulations.

    args: dict with integer entries 'num_rows', 'num_iters', 'num_chains'.

    For every shape in `shapes`, generates num_rows points, fits an engine
    with num_chains states for num_iters transitions, and simulates
    num_rows // num_chains points from each state. Returns the figure: the
    originals on the top row and the simulated points (on the same axis
    limits as the corresponding original) on the bottom row.
    """
    n_rows = args["num_rows"]
    n_iters = args["num_iters"]
    n_chains = args["num_chains"]

    # Split the simulation budget evenly across chains (rounded down).
    n_per_chain = int(float(n_rows)/n_chains)

    fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16,9))
    axes = axes.ravel()
    for k, shape in enumerate(shapes):
        print("Shape: %s" % shape)
        T_o = np.asarray(gen_function[shape](n_rows))

        engine = Engine(
            T_o.T, cctypes=cctypes, distargs=distargs, num_states=n_chains)
        engine.transition(N=n_iters)

        T_i = []
        for chain in range(n_chains):
            state = engine.get_state(chain)
            print("chain %i of %i" % (chain+1, n_chains))
            samples = state.simulate(-1, [0,1], N=n_per_chain)
            # Samples are indexed by output id (0 -> X, 1 -> Y); flatten each
            # into an [x, y] pair so the result forms a numeric array instead
            # of an object array of dicts.
            T_i.extend([s[0], s[1]] for s in samples)
        # reshape keeps an (n, 2) shape even when no samples were drawn.
        T_i = np.array(T_i, dtype=float).reshape(-1, 2)

        ax_orig = axes[k]
        ax_orig.scatter(T_o[0], T_o[1], color='blue', edgecolor='none')
        ax_orig.set_xlabel("X")
        ax_orig.set_ylabel("Y")
        ax_orig.set_title("%s original" % shape)

        ax_sim = axes[k+4]
        ax_sim.scatter(T_i[:,0], T_i[:,1], color='red', edgecolor='none')
        ax_sim.set_xlabel("X")
        ax_sim.set_ylabel("Y")
        # Match the original panel's limits so both panels share a scale.
        ax_sim.set_xlim(ax_orig.get_xlim())
        ax_sim.set_ylim(ax_orig.get_ylim())
        ax_sim.set_title("%s simulated" % shape)

    print("Done.")
    return fig
80 |
if __name__ == "__main__":
    # Full-size run: 1000 rows, 5000 transitions, 6 chains.
    args = {'num_rows': 1000, 'num_iters': 5000, 'num_chains': 6}
    run_test(args).savefig('recover.png')
85 |
--------------------------------------------------------------------------------
/tests/graphical/slice.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import math
18 |
19 | import numpy as np
20 |
21 | from matplotlib import pyplot as plt
22 | from scipy.integrate import trapz
23 |
24 | from cgpm.utils import general as gu
25 | from cgpm.utils import sampling as su
26 |
27 |
def main(num_samples, burn, lag, w):
    """Slice-sample a CRP concentration and compare it to the target density.

    Simulates a CRP partition of N=25 items with alpha=1, then slice-samples
    alpha from the unnormalized posterior (CRP term minus alpha, i.e. a
    Gamma(1, 1) log prior) on (0, inf). Plots two histograms of the samples
    and overlays the target density, normalized on the sampled range.
    """
    alpha = 1.0
    N = 25
    Z = gu.simulate_crp(N, alpha)
    K = max(Z) + 1

    # CRP with gamma prior.
    log_pdf_fun = lambda alpha : gu.logp_crp_unorm(N, K, alpha) - alpha
    proposal_fun = lambda : np.random.gamma(1.0, 1.0)
    D = (0, float('Inf'))

    samples = su.slice_sample(proposal_fun, log_pdf_fun, D,
        num_samples=num_samples, burn=burn, lag=lag, w=w)

    minval = min(samples)
    maxval = max(samples)
    xvals = np.linspace(minval, maxval, 100)
    yvals = np.array([math.exp(log_pdf_fun(x)) for x in xvals])
    # trapz takes (y, x): integrate the density values over the x grid. The
    # reversed order integrated x with respect to y and could even be negative.
    yvals /= trapz(yvals, xvals)

    fig, ax = plt.subplots(2,1)

    ax[0].hist(samples, 50, normed=True)

    ax[1].hist(samples, 100, normed=True)
    # The normalized density is non-negative, so plot it as-is; a negated
    # curve would fall outside the non-negative y-limits copied below.
    ax[1].plot(xvals, yvals, c='red', lw=3, alpha=.8)
    ax[1].set_xlim(ax[0].get_xlim())
    ax[1].set_ylim(ax[0].get_ylim())
    plt.show()
57 |
if __name__ == '__main__':
    import argparse

    # Command-line flags, each with its default and type.
    parser = argparse.ArgumentParser()
    for flag, default, kind in (
            ('--num_samples', 100, int),
            ('--burn', 10, int),
            ('--lag', 5, int),
            ('--w', 1.0, float)):
        parser.add_argument(flag, default=default, type=kind)

    args = parser.parse_args()

    num_samples, burn, lag, w = args.num_samples, args.burn, args.lag, args.w

    main(num_samples, burn, lag, w)
75 |
--------------------------------------------------------------------------------
/tests/hacks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
"""pytest hacks"""
18 |
19 | import pytest
20 |
21 |
def pytest_major_version(version):
    """Return the leading integer of a version string like '3.0.7' (0 if none)."""
    import re
    match = re.match(r'\d+', version)
    return int(match.group(0)) if match else 0


# XXX https://github.com/pytest-dev/pytest/issues/2338
def skip(reason):
    """Skip with `reason`; on pytest >= 3 this also works at module level.

    On pytest >= 3 the Skipped exception is raised directly with
    ``allow_module_level=True``; older versions use ``pytest.skip``.
    """
    # Compare the numeric major version: a string comparison would rank
    # '10.0' below '3'.
    if pytest_major_version(pytest.__version__) >= 3:
        raise pytest.skip.Exception(reason, allow_module_level=True)
    else:
        pytest.skip(reason)
28 |
--------------------------------------------------------------------------------
/tests/markers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 |
# Marker for integration tests: they are skipped unless pytest was invoked
# with --integration (that option is presumably registered via
# pytest_addoption in conftest.py -- verify).
# NOTE(review): pytest.config was deprecated in pytest 4.1 and removed in
# pytest 5.0, so this module-level lookup fails on newer pytest.
integration = pytest.mark.skipif(not pytest.config.getoption('--integration'),
    reason='specify --integration to run integration tests')
21 |
--------------------------------------------------------------------------------
/tests/stochastic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2010-2016, MIT Probabilistic Computing Project
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import sys
19 |
class StochasticError(Exception):
    """Failure of a stochastic test, reported with the seed that produced it.

    Attributes:
      seed: raw seed bytes handed to the failing run.
      exctype: type of the exception that run raised.
      excvalue: value of the exception that run raised.
    """

    def __init__(self, seed, exctype, excvalue):
        self.seed = seed
        self.exctype = exctype
        self.excvalue = excvalue

    def __str__(self):
        import binascii
        # hexlify accepts Python 2 str and Python 3 bytes alike, unlike the
        # Python-2-only 'hex' codec used via str.encode('hex').
        hexseed = str(binascii.hexlify(self.seed).decode('ascii'))
        if hasattr(self.exctype, '__name__'):
            typename = self.exctype.__name__
        else:
            typename = repr(self.exctype)
        return '[seed %s]\n%s: %s' % (hexseed, typename, self.excvalue)
32 |
33 | def stochastic(max_runs, min_passes):
34 | assert 0 < max_runs
35 | assert min_passes <= max_runs
36 | def wrap(f):
37 | def f_(seed=None):
38 | if seed is not None:
39 | return f(seed)
40 | npasses = 0
41 | last_seed = None
42 | last_exc_info = None
43 | for i in xrange(max_runs):
44 | seed = os.urandom(32)
45 | try:
46 | value = f(seed)
47 | except:
48 | last_seed = seed
49 | last_exc_info = sys.exc_info()
50 | else:
51 | npasses += 1
52 | if min_passes <= npasses:
53 | return value
54 | t, v, tb = last_exc_info
55 | raise StochasticError, StochasticError(last_seed, t, v), tb
56 | return f_
57 | return wrap
58 |
--------------------------------------------------------------------------------
/tests/test_add_state.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Test suite targeting cgpm.crosscat.engine.add_state."""
18 |
19 | import pytest
20 |
21 | from cgpm.crosscat.engine import Engine
22 | from cgpm.dummy.twoway import TwoWay
23 | from cgpm.utils import general as gu
24 |
25 |
def get_engine():
    """Build a 4-state engine over a 3x3 table of mixed column types.

    The columns use output ids 8, 7, 9 and are modeled as normal, bernoulli
    and 3-category categorical respectively.
    """
    data = [
        [0.123, 1, 0],
        [1.12, 0, 1],
        [1.1, 1, 2],
    ]
    seeded_rng = gu.gen_rng(1)
    return Engine(
        data,
        outputs=[8, 7, 9],
        num_states=4,
        cctypes=['normal', 'bernoulli', 'categorical'],
        distargs=[None, None, {'k': 3}],
        rng=seeded_rng,
    )
37 |
38 |
def test_engine_add_state_basic():
    engine = get_engine()
    base = engine.num_states()
    engine.add_state()
    assert engine.num_states() == base + 1
    engine.add_state(count=2)
    assert engine.num_states() == base + 3
    engine.transition(N=3)
    # Drop the last of the seven states (four initial plus three added).
    engine.drop_state(6)
    assert engine.num_states() == base + 2


def test_engine_add_state_custom():
    # Add states with a fixed column partition (Zv) and fixed row partitions
    # per view (Zrv), then check the most recently added state matches them.
    engine = get_engine()
    engine.add_state(count=2, Zv={7:0, 8:1, 9:0}, Zrv={0: [0,1,1], 1: [1,1,1]})
    new_state = engine.get_state(engine.num_states()-1)
    assert new_state.Zv() == {7:0, 8:1, 9:0}
    expected_rows = [(0, [0, 1, 1]), (1, [1, 1, 1])]
    for view_index, clusters in expected_rows:
        view = new_state.views[view_index]
        for rowid, cluster in enumerate(clusters):
            assert view.Zr(rowid) == cluster


def test_engine_add_state_kwarg_errors():
    # Dataset-level settings are fixed by the engine; add_state rejects each
    # of them individually and all of them together.
    engine = get_engine()
    forbidden_kwargs = [
        dict(X=[[0,1]]),
        dict(outputs=[1,2]),
        dict(cctypes=['normal', 'normal']),
        dict(distargs=[None, None, {'k' : 3}]),
        dict(X=[[0,1]], outputs=[1,2], cctypes=['normal', 'normal']),
    ]
    for kwargs in forbidden_kwargs:
        with pytest.raises(ValueError):
            engine.add_state(**kwargs)


def test_engine_add_state_composite_errors():
    # XXX Add a Github ticket to support this feature. User should provide all
    # the composite cgpms to match the count of initialized models.
    engine = get_engine()
    composites = [
        TwoWay(outputs=[4], inputs=[7]) for _ in range(engine.num_states())
    ]
    engine.compose_cgpm(composites)
    with pytest.raises(ValueError):
        engine.add_state()
93 |
--------------------------------------------------------------------------------
/tests/test_bernoulli.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 | import numpy as np
19 |
20 | from cgpm.crosscat.engine import Engine
21 | from cgpm.utils import general as gu
22 |
23 |
DATA_NUM_0 = 100
DATA_NUM_1 = 200
NUM_SIM = 10000
NUM_ITER = 5


def test_bernoulli():
    """Fit a two-category column and check simulate/logpdf against the data.

    The data holds DATA_NUM_0 zeros and DATA_NUM_1 ones. Checks that
    simulated frequencies match the empirical proportion of ones, that the
    hypothetical-row probabilities of 0 and 1 sum to one, and that logpdf
    queries on an observed row raise ValueError.
    """
    # Switch for multiprocess (0 is faster).
    multiprocess = 0

    # Create categorical data of DATA_NUM_0 zeros and DATA_NUM_1 ones.
    data = np.transpose(np.array([[0] * DATA_NUM_0 + [1] * DATA_NUM_1]))

    # Run a single chain for a few iterations.
    engine = Engine(
        data, cctypes=['categorical'], distargs=[{'k': 2}],
        rng=gu.gen_rng(0), multiprocess=multiprocess)
    engine.transition(NUM_ITER, multiprocess=multiprocess)

    # Simulate from a hypothetical row and compute the proportion of ones;
    # [0] selects the samples of the single state.
    sample = engine.simulate(-1, [0], N=NUM_SIM, multiprocess=multiprocess)[0]
    sum_b = sum(s[0] for s in sample)
    observed_prob_of_1 = (float(sum_b) / float(NUM_SIM))
    true_prob_of_1 = float(DATA_NUM_1) / float(DATA_NUM_0 + DATA_NUM_1)
    # Check a 10% relative match.
    assert np.allclose(true_prob_of_1, observed_prob_of_1, rtol=.1)

    # Simulate from an observed row as a crash test.
    engine.simulate(1, [0], N=1, multiprocess=multiprocess)

    # Ensure normalized unobserved probabilities.
    p0_uob = engine.logpdf(-1, {0:0}, multiprocess=multiprocess)[0]
    p1_uob = engine.logpdf(-1, {0:1}, multiprocess=multiprocess)[0]
    assert np.allclose(gu.logsumexp([p0_uob, p1_uob]), 0)

    # A logpdf query constraining an observed row returns an error.
    with pytest.raises(ValueError):
        engine.logpdf(1, {0:0}, multiprocess=multiprocess)
    with pytest.raises(ValueError):
        engine.logpdf(1, {0:1}, multiprocess=multiprocess)
64 |
--------------------------------------------------------------------------------
/tests/test_check_env_debug.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import os
19 | import pytest
20 |
21 | from cgpm.utils.config import check_env_debug
22 |
token = 'GPMCCDEBUG'


class _DebugEnv(object):
    """Temporarily set (or unset) the debug variable, restoring it on exit.

    ``value=None`` removes the variable; any other value is assigned. The
    previous setting is restored on exit so these tests do not leak
    GPMCCDEBUG into tests that run after them.
    """

    def __init__(self, value):
        self.value = value
        self.saved = None

    def __enter__(self):
        self.saved = os.environ.get(token)
        if self.value is None:
            os.environ.pop(token, None)
        else:
            os.environ[token] = self.value
        return self

    def __exit__(self, *exc_info):
        if self.saved is None:
            os.environ.pop(token, None)
        else:
            os.environ[token] = self.saved
        return False


def test_debug_none():
    with _DebugEnv(None):
        assert not check_env_debug()


def test_debug_false():
    with _DebugEnv('0'):
        assert not check_env_debug()


def test_debug_true():
    with _DebugEnv('1'):
        assert check_env_debug()
37 |
--------------------------------------------------------------------------------
/tests/test_cmi_partition.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import numpy as np
19 | import pytest
20 |
21 | from cgpm.crosscat.engine import DummyCgpm
22 | from cgpm.crosscat.state import State
23 | from cgpm.utils.general import gen_rng
24 |
25 |
def retrieve_state():
    """Build a 7-column normal State with a fixed three-view partition.

    Outputs 10..16 are split into views {10, 11, 16}, {12} and {13, 14, 15};
    the data is the 7x7 identity matrix.
    """
    outputs = list(range(10, 17))
    view_of = {10:0, 11:0, 12:1, 13:2, 14:2, 15:2, 16:0}
    return State(
        np.eye(7),
        outputs=outputs,
        Zv=view_of,
        cctypes=['normal'] * len(outputs),
        rng=gen_rng(2),
    )
36 |
37 |
def test_partition_mutual_information_query():
    """Check how State._partition_mutual_information_query splits queries.

    Each query is a triple (variables, variables, constraints dict), and the
    expected result lists one triple per block of variables that must be
    handled together. With retrieve_state()'s views {10, 11, 16}, {12} and
    {13, 14, 15}, blocks first follow the views. Composing DummyCgpms that
    link views then merges the corresponding blocks.
    """
    state = retrieve_state()

    def check_expected_partitions(query, expected):
        # Same number of blocks, and every returned block appears in
        # `expected`. Lists inside a block are compared in order.
        blocks = state._partition_mutual_information_query(*query)
        assert len(blocks) == len(expected)
        for b in blocks:
            assert b in expected

    check_expected_partitions(
        query=([10], [11], {}),
        expected=[
            ([10], [11], {}),
        ])
    check_expected_partitions(
        query=([10,16], [11], {}),
        expected=[
            ([10,16], [11], {}),
        ])
    # A constraint on a variable in another view becomes its own block.
    check_expected_partitions(
        query=([10,16], [11], {12:None}),
        expected=[
            ([10,16], [11], {}),
            ([], [], {12:None}),
        ])
    check_expected_partitions(
        query=([10,16], [11, 14], {12:None}),
        expected=[
            ([10,16], [11], {}),
            ([], [14], {}),
            ([], [], {12:None}),
        ])
    check_expected_partitions(
        query=([15, 16], [11, 14, 13], {12:None, 13:2, 10:-12}),
        expected=[
            ([15], [14,13], {13:2}),
            ([16], [11], {10:-12}),
            ([], [], {12:None}),
        ])
    check_expected_partitions(
        query=([15, 16], [11, 14, 13], {12:None, 13:2}),
        expected=[
            ([15], [14,13], {13:2}),
            ([16], [11], {}),
            ([], [], {12:None}),
        ])
    check_expected_partitions(
        query=([15, 16], [15, 16], {12:None, 13:2}),
        expected=[
            ([15], [15], {13:2}),
            ([16], [16], {}),
            ([], [], {12:None}),
        ])
    check_expected_partitions(
        query=([13, 14], [14, 13], {}),
        expected=[
            ([13, 14], [14,13], {}),
        ])

    # Connect variable 12 with variables in view 0.
    state.compose_cgpm(DummyCgpm(outputs=[100, 102], inputs=[10, 12]))

    # NOTE(review): the next two checks are identical; one is presumably a
    # leftover duplicate.
    check_expected_partitions(
        query=([15, 16], [11, 14, 13], {12:None, 13:2}),
        expected=[
            ([15], [14,13], {13:2}),
            ([16], [11], {12:None}),
        ])
    check_expected_partitions(
        query=([15, 16], [11, 14, 13], {12:None, 13:2}),
        expected=[
            ([15], [14,13], {13:2}),
            ([16], [11], {12:None}),
        ])
    check_expected_partitions(
        query=([15, 100, 16], [11, 14, 13], {102: -12, 12:None, 13:2}),
        expected=[
            ([15], [14,13], {13:2}),
            ([100, 16], [11], {12:None, 102: -12}),
        ])

    # Connect variables in view 0 with variables in view 2.
    state.compose_cgpm(DummyCgpm(outputs=[200, 202], inputs=[100, 12]))
    state.compose_cgpm(DummyCgpm(outputs=[300], inputs=[200, 13]))

    # Every view is now linked, so each query collapses into a single block.
    check_expected_partitions(
        query=([15, 16], [11, 14, 13], {12:None, 13:2, 10:-12}),
        expected=[
            ([15, 16], [11, 14, 13], {12:None, 13:2, 10:-12})
        ])
    check_expected_partitions(
        query=([15, 16], [11, 14, 13], {12:None, 13:2}),
        expected=[
            ([15, 16], [11, 14, 13], {12:None, 13:2})
        ])
    check_expected_partitions(
        query=([300], [202], {13:2}),
        expected=[
            ([300], [202], {13:2})
        ])
138 |
--------------------------------------------------------------------------------
/tests/test_dependence_constraints.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import itertools
18 | import pytest
19 |
20 | import numpy as np
21 |
22 | from cgpm.crosscat.state import State
23 | from cgpm.utils import general as gu
24 | from cgpm.utils import validation as vu
25 |
26 | from markers import integration
27 |
28 |
def test_naive_bayes_independence():
    """Every column pair is constrained independent; inference must obey.

    All ten columns are copies of the same Gaussian draw, so the partition
    found after transitioning is checked against the full set of pairwise
    independence constraints.
    """
    rng = gu.gen_rng(1)
    column = rng.normal(size=(10, 1))
    table = np.repeat(column, 10, axis=1)
    independent_pairs = [pair for pair in itertools.combinations(range(10), 2)]
    state = State(
        table, cctypes=['normal'] * 10, Ci=independent_pairs, rng=rng)
    state.transition(N=10, progress=0)
    vu.validate_crp_constrained_partition(
        state.Zv(), [], independent_pairs, {}, {})
37 |
38 |
def test_complex_independent_relationships():
    """A sparse set of independence constraints survives inference."""
    rng = gu.gen_rng(1)
    table = np.repeat(rng.normal(size=(10, 1)), 10, axis=1)
    independent_pairs = [(2, 8), (0, 3)]
    state = State(
        table, cctypes=['normal'] * 10, Ci=independent_pairs, rng=rng)
    state.transition(N=10, progress=0)
    vu.validate_crp_constrained_partition(
        state.Zv(), [], independent_pairs, {}, {})
47 |
48 |
# Independence constraints combined with the dependence constraints below:
# none at all, and a set compatible with Cd.
CIs = [[], [(2,8), (0,3)]]
@pytest.mark.parametrize('Ci', CIs)
def test_simple_dependence_constraint(Ci):
    """Dependence constraints hold under every kernel except 'columns'.

    Transitioning the column partition of a state with dependence constraints
    raises ValueError. Every other kernel must run and leave a partition that
    satisfies both Cd and Ci.
    """
    rng = gu.gen_rng(1)
    D = rng.normal(size=(10,1))
    T = np.repeat(D, 10, axis=1)
    Cd = [(2,0), (8,3)]
    state = State(T, cctypes=['normal']*10, Ci=Ci, Cd=Cd, rng=rng)
    with pytest.raises(ValueError):
        # Cannot transition columns with dependencies.
        state.transition(N=10, kernels=['columns'], progress=0)
    # Run each remaining kernel once per iteration ('column_params' was
    # missing and 'alpha' was listed twice).
    state.transition(
        N=10,
        kernels=['rows', 'alpha', 'column_hypers', 'column_params',
            'view_alphas'],
        progress=False)
    vu.validate_crp_constrained_partition(state.Zv(), Cd, Ci, {}, {})
65 |
66 |
def test_zero_based_outputs():
    """Constraints must have zero-based output variables for now.

    The states below use outputs 10..19, while the Cd/Ci pairs refer to
    variables 0 and 2, so construction must raise ValueError.
    """
    rng = gu.gen_rng(1)
    D = rng.normal(size=(10,1))
    T = np.repeat(D, 10, axis=1)
    with pytest.raises(ValueError):
        State(T, outputs=range(10,20), cctypes=['normal']*10,
            Cd=[(2,0)], rng=rng)
    with pytest.raises(ValueError):
        State(T, outputs=range(10,20), cctypes=['normal']*10,
            Ci=[(2,0)], rng=gu.gen_rng(0))
79 |
@integration
def test_naive_bayes_independence_lovecat():
    """Full pairwise independence survives both cgpm and lovecat inference."""
    rng = gu.gen_rng(1)
    table = np.repeat(rng.normal(size=(10, 1)), 10, axis=1)
    independent_pairs = list(itertools.combinations(range(10), 2))
    state = State(
        table, cctypes=['normal'] * 10, Ci=independent_pairs,
        rng=gu.gen_rng(0))
    # Native transitions first, then lovecat; validate after each.
    for run, iterations in ((state.transition, 10),
                            (state.transition_lovecat, 100)):
        run(N=iterations, progress=0)
        vu.validate_crp_constrained_partition(
            state.Zv(), [], independent_pairs, {}, {})
91 |
92 |
@integration
def test_complex_independent_relationships_lovecat():
    """Mixed independence and dependence constraints hold under lovecat."""
    rng = gu.gen_rng(1)
    column = rng.normal(size=(10, 1))
    table = np.repeat(column, 10, axis=1)
    independent_pairs = [(2, 8), (0, 3)]
    dependent_pairs = [(2, 3), (0, 8)]
    state = State(
        table, cctypes=['normal'] * 10, Ci=independent_pairs,
        Cd=dependent_pairs, rng=gu.gen_rng(0))
    state.transition_lovecat(N=1000, progress=1)
    vu.validate_crp_constrained_partition(
        state.Zv(), dependent_pairs, independent_pairs, {}, {})
103 |
@integration
def test_independence_inference_quality_lovecat():
    """Lovecat recovers the two-view structure with and without Cd.

    Columns 0-3 are copies of one N(0, 1) column. Columns 4-7 are copies of a
    column whose first 25 rows are centered at 10 and last 25 at 20.
    """
    rng = gu.gen_rng(584)
    column_view_1 = rng.normal(loc=0, size=(50,1))
    column_view_2 = np.concatenate((
        rng.normal(loc=10, size=(25,1)),
        rng.normal(loc=20, size=(25,1)),
    ))
    data = np.column_stack((
        np.repeat(column_view_1, 4, axis=1),
        np.repeat(column_view_2, 4, axis=1),
    ))

    def assert_two_views(state):
        # Columns 0-3 share one view, columns 4-7 share another, and the
        # two views are distinct.
        for col in [0, 1, 2, 3]:
            assert state.Zv(col) == state.Zv(0)
        for col in [4, 5, 6, 7]:
            assert state.Zv(col) == state.Zv(4)
        assert state.Zv(0) != state.Zv(4)

    # Start with every column in a single view.
    single_view = {i: 0 for i in xrange(8)}
    state = State(
        data, Zv=single_view, cctypes=['normal']*8, rng=gu.gen_rng(10))
    state.transition_lovecat(N=100, progress=1)
    assert_two_views(state)

    # Start from four views with dependence constraints; lovecat must merge
    # the dependent column pairs into the same two-view structure.
    Cd = [(0,1), (2,3), (4,5), (6,7)]
    four_views = {0:0, 1:0, 2:1, 3:1, 4:2, 5:2, 6:3, 7:3}
    state = State(
        data, Zv=four_views, cctypes=['normal']*8, Cd=Cd, rng=gu.gen_rng(1))
    state.transition_lovecat(N=100, progress=1)
    assert_two_views(state)
137 |
--------------------------------------------------------------------------------
/tests/test_diagnostics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import time
18 |
19 | from cgpm.crosscat.engine import Engine
20 | from cgpm.utils import general as gu
21 | from cgpm.utils import test as tu
22 |
23 | from markers import integration
24 |
25 |
def retrieve_normal_dataset():
    """Return a synthetic table of two normal columns over 20 rows.

    The two columns are generated in separate views (view_partition=[0, 1]).
    Callers transpose the result and use len() as the column count, so it is
    presumably shaped (n_cols, n_rows); confirm against tu.gen_data_table.
    """
    table, _zv, _zc = tu.gen_data_table(
        n_rows=20,
        view_weights=None,
        cluster_weights=[[.2, .2, .2, .4], [.2, .8]],
        cctypes=['normal', 'normal'],
        distargs=[None, None],
        separation=[0.95, 0.95],
        view_partition=[0, 1],
        rng=gu.gen_rng(12),
    )
    return table
37 |
38 |
@integration
def test_simple_diagnostics():
    """Diagnostic traces grow by one entry per checkpoint interval.

    A transition of N iterations with checkpoint=c appends N // c entries to
    every diagnostic other than 'iterations'. A transition without a
    checkpoint appends nothing.
    """
    def all_traces_have(engine, predicate):
        # Apply predicate to the length of every diagnostic trace, skipping
        # the 'iterations' counter, across all states of the engine.
        return all(
            predicate(len(trace))
            for state in engine.states
            for name, trace in state.diagnostics.iteritems()
            if name != 'iterations'
        )

    D = retrieve_normal_dataset()
    engine = Engine(
        D.T, cctypes=['normal']*len(D), num_states=4, rng=gu.gen_rng(12),)

    engine.transition(N=20, checkpoint=2)
    assert all_traces_have(engine, lambda n: n == 10)
    engine.transition(N=7, checkpoint=2)
    assert all_traces_have(engine, lambda n: n == 13)
    engine.transition_lovecat(N=7, checkpoint=3)
    assert all_traces_have(engine, lambda n: n == 15)
    # No checkpoint: the traces must not grow.
    engine.transition(S=0.5)
    assert all_traces_have(engine, lambda n: n == 15)
    engine.transition(S=0.5, checkpoint=1)
    assert all_traces_have(engine, lambda n: n > 15)

    # Timed lovecat analysis. The elapsed-time bounds check that the S=1
    # time budget takes precedence over the large N=20000. The large N is
    # used due to oddness of diagnostic tracing in lovecat.
    start = time.time()
    engine.transition_lovecat(N=20000, S=1, checkpoint=1)
    assert 1 < time.time() - start < 3
76 |
--------------------------------------------------------------------------------
/tests/test_engine_alter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 | import pytest
19 |
20 | from cgpm.crosscat.engine import Engine
21 | from cgpm.utils import config as cu
22 | from cgpm.utils import general as gu
23 | from cgpm.utils import test as tu
24 |
25 |
def get_engine():
    """Build a 6-state Engine over synthetic mixed-type data.

    Six columns (normal, poisson, bernoulli, lognormal, beta, vonmises) are
    generated with tu.gen_data_table. Row 5 of columns 0-3 and row 8 of
    column 4 are set to nan, and the engine is transitioned for N=2.
    """
    cctypes, distargs = cu.parse_distargs(
        ['normal', 'poisson', 'bernoulli', 'lognormal', 'beta', 'vonmises'])
    T, Zv, Zc = tu.gen_data_table(
        20, [1], [[.25, .25, .5]], cctypes, distargs,
        [.95]*len(cctypes), rng=gu.gen_rng(0))
    T = T.T
    # Make some nan cells for evidence.
    T[5, :4] = np.nan
    T[8, 4] = np.nan
    engine = Engine(
        T, cctypes=cctypes, distargs=distargs, num_states=6,
        rng=gu.gen_rng(0))
    engine.transition(N=2)
    return engine
52 |
53 |
def test_simple_alterations():
    """Engine.alter applies the alterations only to the selected states."""
    engine = get_engine()

    # Snapshot of the initial state outputs. Take a copy: the alterations
    # mutate `state.outputs` in place, so holding a reference to the live
    # list would let the "unaltered" baseline change if alter operates on
    # the original state objects.
    out_initial = list(engine.states[0].outputs)

    # Indexes of outputs to alter.
    out_f = 0
    out_g = 3

    def alteration_f(state):
        state.outputs[out_f] *= 13
        return state

    def alteration_g(state):
        state.outputs[out_g] *= 12
        return state

    statenos = [0, 3]
    engine.alter((alteration_f, alteration_g), statenos)

    out_expected = list(out_initial)
    out_expected[out_f] *= 13
    out_expected[out_g] *= 12

    for s in xrange(engine.num_states()):
        if s in statenos:
            assert engine.states[s].outputs == out_expected
        else:
            assert engine.states[s].outputs == out_initial
84 |
--------------------------------------------------------------------------------
/tests/test_engine_seed.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from cgpm.crosscat.engine import Engine
18 | from cgpm.utils import general as gu
19 |
def test_engine_simulate_no_repeat():
    """Generate 3 samples from 2 states 10 times, and ensure uniqueness."""
    rng = gu.gen_rng(1)
    engine = Engine(X=[[1]], cctypes=['normal'], num_states=2, rng=rng)
    samples_list = []
    for _i in xrange(10):
        # First state's batch of N=3 simulated rows; keep variable 0 of each.
        draws = engine.simulate(None, [0], N=3)[0]
        samples_list.append([sample[0] for sample in draws])
    samples_set = {frozenset(s) for s in samples_list}
    assert len(samples_set) == len(samples_list)
30 |
--------------------------------------------------------------------------------
/tests/test_exponentials_transition_hypers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 |
19 | import numpy as np
20 |
21 | from cgpm.utils import config as cu
22 | from cgpm.utils import general as gu
23 | from cgpm.utils import test as tu
24 |
25 |
# (cctype name, distargs) pairs exercised by test_transition_hypers.
cctypes = [
    ('normal', None),
    ('categorical', {'k':4}),
    ('lognormal', None),
    ('poisson', None),
    ('bernoulli', None),
    ('exponential', None),
    ('geometric', None),
    ('vonmises', None)
]


@pytest.mark.parametrize('cctype', cctypes)
def test_transition_hypers(cctype):
    """transition_hypers moves the hyperparameters as data is incorporated.

    The first 25 generated values are incorporated (rowids 0-24) and at
    least one hyper must change after transitioning. The remaining values
    are then incorporated (rowids 25 onward) and the hypers must move again.
    """
    name, arg = cctype
    model = cu.cctype_class(name)(
        outputs=[0], inputs=None, distargs=arg, rng=gu.gen_rng(10))
    D, Zv, Zc = tu.gen_data_table(
        50, [1], [[.33, .33, .34]], [name], [arg], [.8], rng=gu.gen_rng(1))
    values = np.ravel(D)

    hypers_previous = model.get_hypers()
    for rowid, x in enumerate(values[:25]):
        model.incorporate(rowid, {0:x}, None)
    model.transition_hypers(N=3)
    hypers_new = model.get_hypers()
    assert not all(
        np.allclose(hypers_new[hyper], hypers_previous[hyper])
        for hyper in hypers_new)

    # Incorporate the remaining, not-yet-seen values under rowids 25, 26, ...
    for rowid, x in enumerate(values[25:], start=25):
        model.incorporate(rowid, {0:x}, None)
    model.transition_hypers(N=3)
    hypers_newer = model.get_hypers()
    assert not all(
        np.allclose(hypers_new[hyper], hypers_newer[hyper])
        for hyper in hypers_newer)

    # In general inference should improve the log score (check disabled).
    # logpdf_score = model.logpdf_score()
    # model.transition_hypers(N=200)
    # assert model.logpdf_score() > logpdf_score
67 |
68 |
--------------------------------------------------------------------------------
/tests/test_force_cell.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 |
19 | import numpy as np
20 |
21 | from cgpm.crosscat.state import State
22 | from cgpm.utils import general as gu
23 |
24 |
# 5x5 fixture table; nan entries are missing cells.
X = [[1,  np.nan, 2,      -1,     np.nan],
     [1,  3,      2,      -1,     -5    ],
     [18, -7,     -2,     11,     -12   ],
     [1,  np.nan, np.nan, np.nan, np.nan],
     [18, -7,     -2,     11,     -12   ]]


def get_state():
    """Return a fresh State over X with columns 0-2 and 3-4 in two views."""
    column_to_view = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1}
    return State(
        X,
        outputs=range(5),
        cctypes=['normal'] * 5,
        Zv=column_to_view,
        rng=gu.gen_rng(0),
    )
40 |
41 |
def test_invalid_nonnan_cell():
    """Forcing a cell that already holds an observed value is rejected."""
    state = get_state()
    # Cell (rowid 0, output 0) holds the observed value 1, not nan, so it
    # cannot be forced.
    with pytest.raises(ValueError):
        state.force_cell(0, {0: 1})
47 |
def test_invalid_variable():
    """Forcing an unknown output variable raises KeyError."""
    state = get_state()
    missing_output = 10  # The state's outputs are 0-4.
    with pytest.raises(KeyError):
        state.force_cell(0, {missing_output: 1})
53 |
def test_invalid_rowid():
    """Only incorporated rowids can have cells forced."""
    state = get_state()
    # Neither an unincorporated rowid nor None is accepted.
    for bad_rowid in (10, None):
        with pytest.raises(ValueError):
            state.force_cell(bad_rowid, {0: 1})
61 |
def test_force_cell_valid():
    """Forcing a missing cell updates the owning component's statistics."""
    state = get_state()
    rowid, dim = 0, 1
    value = 1.5
    # Locate the normal component that currently owns cell (rowid, dim).
    cluster = state.view_for(dim).Zr(rowid)
    component = state.dim_for(dim).clusters[cluster]
    # Sufficient statistics before forcing.
    count_before = component.N
    sum_before = component.sum_x
    sum_sq_before = component.sum_x_sq

    state.force_cell(rowid, {dim: value})

    # Forcing incorporates the value into the component.
    assert component.N == count_before + 1
    assert np.allclose(component.sum_x, sum_before + value)
    assert np.allclose(component.sum_x_sq, sum_sq_before + value**2)
    # The same cell cannot be forced twice.
    with pytest.raises(ValueError):
        state.force_cell(rowid, {dim: 1})
    state.transition(N=1)
    # Force the two missing cells (3,1) and (3,2) in one call.
    state.force_cell(3, {1: -7, 2: -2})
    state.transition(N=1)
87 |
--------------------------------------------------------------------------------
/tests/test_get_vqe.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from cgpm.utils import validation as vu
18 |
19 |
def test_partition_query_evidence_dict():
    """Dict-valued query and evidence are split into per-view pieces.

    Zv[i] gives the view of variable i: views 0, 1 and 2 hold three
    variables each, and view 3 holds variable 9.
    """
    Zv = [0,0,0,1,1,1,2,2,2,3]
    query = {9:101, 1:1, 4:2, 5:7, 7:0}
    evidence = {2:4, 3:1, 9:1, 6:-1, 0:0}
    queries, evidences = vu.partition_query_evidence(Zv, query, evidence)

    # All 4 views have queries.
    assert len(queries) == 4

    # View 0 has 1 query.
    assert len(queries[0]) == 1
    assert queries[0][1] == 1
    # View 1 has 2 queries.
    assert len(queries[1]) == 2
    assert queries[1][4] == 2
    assert queries[1][5] == 7
    # View 2 has 1 query.
    assert len(queries[2]) == 1
    assert queries[2][7] == 0
    # View 3 has 1 query.
    assert len(queries[3]) == 1
    assert queries[3][9] == 101

    # Views 0,1,2,3 have evidence.
    assert len(evidences) == 4
    # View 0 has 2 evidence values.
    assert len(evidences[0]) == 2
    assert evidences[0][0] == 0
    assert evidences[0][2] == 4
    # View 1 has 1 evidence value.
    assert len(evidences[1]) == 1
    assert evidences[1][3] == 1
    # View 2 has 1 evidence value.
    assert len(evidences[2]) == 1
    assert evidences[2][6] == -1
    # View 3 has 1 evidence value.
    assert len(evidences[3]) == 1
    assert evidences[3][9] == 1
58 |
59 |
def test_partition_query_evidence_list():
    """A list-valued query and dict evidence are split into per-view pieces.

    Zv[i] gives the view of variable i: views 0, 1 and 2 hold three
    variables each, and view 3 holds variable 9.
    """
    Zv = [0,0,0,1,1,1,2,2,2,3]
    query = [9, 1, 4, 5, 7]
    evidence = {2:-4, 3:-1, 9:-1, 6:1, 0:100}
    queries, evidences = vu.partition_query_evidence(Zv, query, evidence)

    # All 4 views have queries.
    assert len(queries) == 4

    # View 0 has 1 query.
    assert len(queries[0]) == 1
    assert 1 in queries[0]
    # View 1 has 2 queries.
    assert len(queries[1]) == 2
    assert 4 in queries[1]
    assert 5 in queries[1]
    # View 2 has 1 query.
    assert len(queries[2]) == 1
    assert 7 in queries[2]
    # View 3 has 1 query.
    assert len(queries[3]) == 1
    assert 9 in queries[3]

    # Views 0,1,2,3 have evidence.
    assert len(evidences) == 4
    # View 0 has 2 evidence values.
    assert len(evidences[0]) == 2
    assert evidences[0][0] == 100
    assert evidences[0][2] == -4
    # View 1 has 1 evidence value.
    assert len(evidences[1]) == 1
    assert evidences[1][3] == -1
    # View 2 has 1 evidence value.
    assert len(evidences[2]) == 1
    assert evidences[2][6] == 1
    # View 3 has 1 evidence value.
    assert len(evidences[3]) == 1
    assert evidences[3][9] == -1
98 |
--------------------------------------------------------------------------------
/tests/test_impossible_evidence.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """This test suite ensures that simulate and logpdf with zero-density evidence
18 | raises a ValueError."""
19 |
20 | import pytest
21 |
22 | from cgpm.crosscat.engine import State
23 | from cgpm.utils import config as cu
24 | from cgpm.utils import general as gu
25 | from cgpm.utils import test as tu
26 |
27 |
@pytest.fixture(scope='module')
def state():
    """Module-wide State over six mixed-type columns, all in view 0.

    Column order: 0 normal, 1 poisson, 2 bernoulli, 3 lognormal, 4 beta,
    5 vonmises.
    """
    cctypes, distargs = cu.parse_distargs(
        ['normal', 'poisson', 'bernoulli', 'lognormal', 'beta', 'vonmises'])
    T, _zv, _zc = tu.gen_data_table(
        30, [1], [[.25, .25, .5]], cctypes, distargs,
        [.95]*len(cctypes), rng=gu.gen_rng(0))
    single_view = {i: 0 for i in xrange(len(cctypes))}
    return State(
        T.T,
        cctypes=cctypes,
        distargs=distargs,
        Zv=single_view,
        rng=gu.gen_rng(0),
    )
50 |
51 |
def test_impossible_simulate_evidence(state):
    """simulate raises ValueError when the evidence has zero density."""
    with pytest.raises(ValueError):
        # Variable 2 is binary-valued, so .8 is impossible.
        state.simulate(-1, [0,1], {2:.8})
    with pytest.raises(ValueError):
        # Variable 3 is lognormal so -1 impossible.
        state.simulate(-1, [4], {3:-1})
    with pytest.raises(ValueError):
        # Variable 4 is beta so 1.1 impossible.
        state.simulate(-1, [5], {4:1.1})
62 |
63 |
def test_impossible_logpdf_evidence(state):
    """logpdf raises ValueError when the evidence has zero density."""
    with pytest.raises(ValueError):
        # Variable 2 is binary-valued, so .8 is impossible.
        state.logpdf(-1, {0:-1}, {2:.8})
    with pytest.raises(ValueError):
        # Variable 3 is lognormal so -1 impossible.
        state.logpdf(-1, {1:1}, {3:-1})
    with pytest.raises(ValueError):
        # Variable 4 is beta so 1.1 impossible (as evidence; zero-density
        # queries are allowed, see test_valid_logpdf_query).
        state.logpdf(-1, {0:-1}, {4:1.1})
74 |
75 |
def test_valid_logpdf_query(state):
    """A zero-density query value must not raise."""
    # Variable 2 is binary-valued, so .8 has zero density as a query.
    state.logpdf(-1, {2:.8}, {1:1.})
    state.logpdf(-1, {5:18})
80 |
81 |
--------------------------------------------------------------------------------
/tests/test_incorporate_row.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 |
19 | import numpy as np
20 |
21 | from cgpm.crosscat.state import State
22 | from cgpm.utils import general as gu
23 |
24 |
# 5x5 fixture table; nan entries are missing cells.
X = [
    [1,  np.nan, 2,      -1,     np.nan],
    [1,  3,      2,      -1,     -5    ],
    [18, -7,     -2,     11,     -12   ],
    [1,  np.nan, np.nan, np.nan, np.nan],
    [18, -7,     -2,     11,     -12   ],
]


def get_state():
    """Return a fresh State over X: columns 0-2 in view 0, 3-4 in view 1."""
    return State(
        X,
        outputs=range(5),
        cctypes=['normal'] * 5,
        Zv={0: 0, 1: 0, 2: 0, 3: 1, 4: 1},
        rng=gu.gen_rng(0),
    )
40 |
41 |
def test_invalid_evidence_keys():
    """Incorporating a value for a non-existent view variable is rejected."""
    state = get_state()
    # crp_id_view+2 is expected not to be the cluster variable of any view.
    # TODO confirm against State.crp_id_view.
    row = {0:0, 1:1, 2:2, 3:3, 4:4, state.crp_id_view+2:0}
    with pytest.raises(ValueError):
        state.incorporate(state.n_rows(), row)
50 |
51 |
def test_invalid_evidence():
    """Passing evidence to incorporate fails for a State without inputs."""
    state = get_state()
    with pytest.raises(Exception):
        next_rowid = state.n_rows()
        state.incorporate(next_rowid, {0:0, 1:1, 2:2, 3:3, 4:4}, {12:1})
57 |
58 |
def test_invalid_cluster():
    """A cluster assignment of None is rejected."""
    state = get_state()
    with pytest.raises(Exception):
        cluster_variable = state.views[0].outputs[0]
        row = {0:0, 1:1, 2:2, 3:3, 4:4, cluster_variable:None}
        state.incorporate(state.n_rows(), row)
66 |
67 |
def test_invalid_query_nan():
    """nan values cannot be incorporated; missing cells are omitted instead."""
    state = get_state()
    row = {0: np.nan, 1: 1, 2: 2, 3: 3, 4: 4}
    with pytest.raises(ValueError):
        state.incorporate(state.n_rows(), row)
73 |
74 |
def test_invalid_rowid():
    """Re-incorporating any existing rowid fails (rowids must be contiguous)."""
    state = get_state()
    existing_rowids = range(state.n_rows())
    for rowid in existing_rowids:
        with pytest.raises(ValueError):
            state.incorporate(rowid, {0:2})
81 |
def test_incorporate_valid():
    """Rows can be incorporated into given, partially-given and new clusters."""
    state = get_state()
    full_row = {0:0, 1:1, 2:2, 3:3, 4:4}

    # Incorporate row into cluster 0 for all views.
    previous = np.asarray([state.views[v].Nk(0) for v in [0,1]])
    row = dict(full_row)
    row[state.views[0].outputs[0]] = 0
    row[state.views[1].outputs[0]] = 0
    state.incorporate(state.n_rows(), row)
    assert [state.views[v].Nk(0) for v in [0,1]] == list(previous+1)

    # Incorporate row into cluster 0 for view 1 with some missing values.
    previous = state.views[1].Nk(0)
    state.incorporate(state.n_rows(), {0:0, 2:2, state.views[1].outputs[0]:0})
    assert state.views[1].Nk(0) == previous+1
    state.transition(N=2)

    # Hypothetical cluster 100 in the first view of state.views.
    first_view = state.views[next(iter(state.views))]
    row = dict(full_row)
    row[first_view.outputs[0]] = 100
    state.incorporate(state.n_rows(), row)
100 | state.n_rows(),
101 | {0:0, 1:1, 2:2, 3:3, 4:4, view.outputs[0]:100})
102 |
103 |
104 | def test_unincorporate():
105 | state = get_state()
106 | # Unincorporate all the rows except for the last one.
107 | # XXX Must remove the last rowid only at each invocation.
108 | rowids = range(0, state.n_rows())
109 | for rowid in rowids[:-1]:
110 | with pytest.raises(ValueError):
111 | state.unincorporate(rowid)
112 | # Remove rowids starting from state.n_rows()-1 down to 1.
113 | for rowid in reversed(rowids[1:]):
114 | state.unincorporate(rowid)
115 | assert state.n_rows() == 1
116 | # Cannot unincorporate the final rowid.
117 | with pytest.raises(ValueError):
118 | state.unincorporate(0)
119 | state.transition(N=2)
120 |
121 |
def test_incorporate_session():
    """Incorporate rows into new clusters and with defaults, then remove one."""
    rng = gu.gen_rng(4)
    state = State(
        X, cctypes=['normal']*5, Zv={0:0, 1:0, 2:1, 3:1, 4:2}, rng=rng)
    view_ids = [0, 1, 2]

    # Incorporate row into a singleton cluster for all views: the cluster id
    # equals the current number of clusters in each view.
    previous = [len(state.views[v].Nk()) for v in view_ids]
    data = {i: rng.normal() for i in xrange(5)}
    clusters = {
        state.views[v].outputs[0]: count
        for v, count in zip(view_ids, previous)
    }
    state.incorporate(state.n_rows(), gu.merged(data, clusters))
    assert [len(state.views[v].Nk()) for v in view_ids] == \
        [count + 1 for count in previous]

    # Incorporate row without specifying clusters, and some missing values.
    partial = {i: rng.normal() for i in xrange(2)}
    state.incorporate(state.n_rows(), partial)
    state.transition(N=3)

    # Remove the incorporated rowid.
    state.unincorporate(state.n_rows()-1)
    state.transition(N=3)
144 |
--------------------------------------------------------------------------------
/tests/test_iter_counter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import time
18 |
19 | import pytest
20 |
21 | from cgpm.crosscat.engine import Engine
22 | from cgpm.crosscat.state import State
23 | from cgpm.dummy.fourway import FourWay
24 | from cgpm.dummy.twoway import TwoWay
25 | from cgpm.utils import general as gu
26 |
27 |
def test_all_kernels():
    """A default transition of N=5 records 5 iterations for every kernel."""
    rng = gu.gen_rng(0)
    X = rng.normal(size=(5,5))
    state = State(X, cctypes=['normal']*5)
    state.transition(N=5)
    counts = state.to_metadata()['diagnostics']['iterations']
    for _kernel, count in counts.iteritems():
        assert count == 5
35 |
def test_individual_kernels():
    """Per-kernel iteration counters accumulate across transitions."""
    rng = gu.gen_rng(0)
    X = rng.normal(size=(5,5))
    state = State(X, cctypes=['normal']*5)

    state.transition(N=3, kernels=['alpha', 'rows'])
    expected = {'alpha':3, 'rows':3}
    check_expected_counts(state.diagnostics['iterations'], expected)

    state.transition(N=5, kernels=['view_alphas', 'column_params'])
    expected.update({'view_alphas':5, 'column_params':5})
    check_expected_counts(
        state.to_metadata()['diagnostics']['iterations'], expected)

    state.transition(
        N=1, kernels=['view_alphas', 'column_params', 'column_hypers'])
    expected.update({'view_alphas':6, 'column_params':6, 'column_hypers':1})
    check_expected_counts(
        state.to_metadata()['diagnostics']['iterations'], expected)
54 |
55 |
def test_transition_foreign():
    """Foreign cgpms keep their own iteration counters in diagnostics."""
    rng = gu.gen_rng(0)
    X = rng.normal(size=(5,5))
    state = State(X, cctypes=['normal']*5)

    # A first foreign cgpm, transitioned on its own.
    token_a = state.compose_cgpm(FourWay(outputs=[12], inputs=[0,1], rng=rng))
    state.transition_foreign(cols=[12], N=5)
    key_a = 'foreign-%s' % token_a
    check_expected_counts(state.diagnostics['iterations'], {key_a: 5})

    # A second foreign cgpm; the first counter must be untouched.
    token_b = state.compose_cgpm(TwoWay(outputs=[22], inputs=[2], rng=rng))
    state.transition_foreign(cols=[22], N=1)
    key_b = 'foreign-%s' % token_b
    check_expected_counts(
        state.diagnostics['iterations'], {key_a: 5, key_b: 1})

    # Without cols, every foreign cgpm is transitioned.
    state.transition_foreign(N=3)
    check_expected_counts(
        state.diagnostics['iterations'], {key_a: 8, key_b: 4})

    # S is a time budget; the call must last at least S seconds.
    start = time.time()
    state.transition_foreign(S=2)
    assert time.time() - start >= 2

    # Crash test for engine to transition everyone.
    engine = Engine(X, cctypes=['normal']*5)
    engine.compose_cgpm([FourWay(outputs=[12], inputs=[0,1], rng=rng)])
    engine.transition(N=4)
    # Cannot use transition with a foreign variable.
    with pytest.raises(ValueError):
        engine.transition(N=4, cols=state.outputs+[12], multiprocess=False)
    # Cannot use transition_foreign with a Crosscat variable.
    with pytest.raises(ValueError):
        engine.transition_foreign(cols=[0, 12], N=4, multiprocess=False)
    engine.transition_foreign(cols=[12], N=4)
94 |
95 |
def check_expected_counts(actual, expected):
    """Assert that each key of `expected` maps to the same count in `actual`.

    Keys present in `actual` but absent from `expected` are ignored. Uses
    `items()` rather than the Python 2-only `iteritems()` so the helper runs
    on Python 2 and 3.

    Raises AssertionError, naming the offending key, on a missing key or a
    count mismatch.
    """
    for kernel, count in expected.items():
        assert kernel in actual, 'missing kernel %r' % (kernel,)
        assert actual[kernel] == count, \
            'kernel %r: expected %r, got %r' % (kernel, count, actual[kernel])
99 |
--------------------------------------------------------------------------------
/tests/test_linreg_mixture.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import matplotlib.cm as cm
19 | import matplotlib.pyplot as plt
20 | import numpy as np
21 |
22 | from cgpm.crosscat.state import State
23 | from cgpm.utils import general as gu
24 |
25 |
def _compute_y(x):
    """Return a noisy response for `x` from a two-regime linear model.

    For ``x <= 5`` the slope is -2 with noise scale .5; for ``x > 5`` the
    slope is 5 with noise scale 1. Noise is drawn from the module-level
    `rng`, which must exist by the time this is called.
    """
    noise = [.5, 1]
    slopes = [-2, 5]
    # `x > 5` yields a numpy bool for numpy inputs. Convert it explicitly,
    # because numpy booleans are not reliably accepted as list indices.
    model = int(x > 5)
    return slopes[model] * x + rng.normal(scale=noise[model])
31 |
32 |
rng = gu.gen_rng(1)
X = rng.uniform(low=0, high=10, size=50)
# Build a real list: on Python 3, `map` returns a lazy iterator, which
# np.column_stack cannot stack as a column.
Y = [_compute_y(x) for x in X]
D = np.column_stack((X, Y))
37 |
38 |
def replace_key(d, a, b):
    """Rename key `a` of dict `d` to `b` in place and return `d`.

    The value is popped before it is re-inserted, so renaming a key to
    itself (``a == b``) keeps the entry. Assigning ``d[b] = d[a]`` and then
    deleting ``d[a]`` would drop it entirely.

    Raises KeyError if `a` is not in `d`.
    """
    d[b] = d.pop(a)
    return d
43 |
44 |
def generate_gaussian_samples():
    """Simulate (x, y, cluster) samples from a CrossCat Gaussian mixture.

    Both columns share one view. That view's latent output
    (``view.outputs[0]``) is simulated with the data and re-keyed to -1.
    """
    state = State(
        D,
        cctypes=['normal', 'normal'],
        Zv={0: 0, 1: 0},
        rng=gu.gen_rng(0),
    )
    view = state.view_for(1)
    state.transition(S=15, kernels=['rows', 'column_params', 'column_hypers'])
    latent = view.outputs[0]
    draws = view.simulate(-1, [0, 1, latent], N=100)
    return [replace_key(draw, latent, -1) for draw in draws]
52 |
53 |
def generate_regression_samples():
    """Simulate (x, y, cluster) samples after switching column 1 to the
    'linear_regression' cctype.

    Column 1 is re-typed while sharing a view with column 0. The view's
    latent output (``view.outputs[0]``) is re-keyed to -1 in every sample.
    """
    state = State(
        D,
        cctypes=['normal', 'normal'],
        Zv={0: 0, 1: 0},
        rng=gu.gen_rng(4),
    )
    view = state.view_for(1)
    # Swapping in the regression turns the state composite.
    assert not state._composite
    state.update_cctype(1, 'linear_regression')
    assert state._composite
    state.transition(S=30, kernels=['rows', 'column_params', 'column_hypers'])
    latent = view.outputs[0]
    draws = view.simulate(-1, [0, 1, latent], N=100)
    return [replace_key(draw, latent, -1) for draw in draws]
64 |
def plot_samples(samples, title):
    """Scatter the observed data `D` plus simulated samples, one color per
    cluster.

    Each sample is read as ``s[0]`` (x), ``s[1]`` (y) and ``s[-1]`` (cluster
    label).
    """
    _fig, ax = plt.subplots()
    clusters = set(s[-1] for s in samples)
    # Two spare colors, as in the original palette size.
    palette = cm.gist_rainbow(np.linspace(0, 1, len(clusters) + 2))
    ax.scatter(D[:, 0], D[:, 1], color='k', label='Data')
    for idx, (cluster, color) in enumerate(zip(clusters, palette)):
        points = [(s[0], s[1]) for s in samples if s[-1] == cluster]
        xs, ys = zip(*points)
        ax.scatter(
            xs, ys, alpha=.5, color=color,
            label='Simulated (cluster %d)' % idx)
    ax.set_title(title)
    ax.legend(framealpha=0, loc='upper left')
    ax.grid()
79 |
80 |
def test_regression_plot_crash__ci_():
    """Crash test: simulate from both models and plot the samples.

    All figures are closed afterwards, even on failure, so that repeated
    test runs do not accumulate open matplotlib figures.
    """
    samples_a = generate_gaussian_samples()
    samples_b = generate_regression_samples()
    try:
        plot_samples(samples_a, 'Model: Mixture of 2D Gaussians')
        plot_samples(samples_b, 'Model: Mixture of Linear Regression')
    finally:
        plt.close('all')
87 |
--------------------------------------------------------------------------------
/tests/test_logpdf_score.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import numpy as np
19 |
20 | from cgpm.crosscat.engine import Engine
21 | from cgpm.utils.general import gen_rng
22 |
def test_logpdf_score_crash():
    """logpdf_score stays below logpdf_likelihood, and inference raises the
    best score."""
    rng = gen_rng(8)
    T = rng.normal(size=30).reshape(-1, 1)
    engine = Engine(T, cctypes=['normal'], rng=rng, num_states=4)

    likelihood_before = np.array(engine.logpdf_likelihood())
    score_before = np.array(engine.logpdf_score())
    # Elementwise comparison (presumably one entry per state).
    assert np.all(score_before < likelihood_before)

    engine.transition(N=100)
    engine.transition(kernels=['column_hypers', 'view_alphas'], N=10)

    likelihood_after = np.asarray(engine.logpdf_likelihood())
    score_after = np.asarray(engine.logpdf_score())
    assert np.all(score_after < likelihood_after)
    # The best score must improve after inference.
    assert np.max(score_before) < np.max(score_after)
38 |
--------------------------------------------------------------------------------
/tests/test_logsumexp.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2010-2016, MIT Probabilistic Computing Project
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import math
18 | import pytest
19 |
20 | from cgpm.utils import general as gu
21 |
def relerr(expected, actual):
    """Relative error between `expected` and `actual`: ``abs((a - e)/e)``."""
    difference = actual - expected
    return abs(difference / expected)
25 |
def test_logsumexp():
    """Check gu.logsumexp for overflow safety, empty input, infinities and
    nans."""
    inf, nan = float('inf'), float('nan')

    # Exponentiating directly overflows; logsumexp must stay finite.
    with pytest.raises(OverflowError):
        math.log(sum(map(math.exp, range(1000))))
    assert relerr(999.4586751453871, gu.logsumexp(range(1000))) < 1e-15

    assert gu.logsumexp([]) == -inf
    assert gu.logsumexp([-1000.]) == -1000.
    assert gu.logsumexp([-1000., -1000.]) == -1000. + math.log(2.)
    assert relerr(math.log(2.), gu.logsumexp([0., 0.])) < 1e-15

    # Infinite inputs with well-defined results.
    infinite_cases = [
        ([-inf, 1], 1),
        ([-inf, -inf], -inf),
        ([+inf, +inf], +inf),
    ]
    for values, expected in infinite_cases:
        assert gu.logsumexp(values) == expected

    # Opposite infinities, or any nan, give nan.
    for values in ([-inf, +inf], [nan, inf], [nan, -3]):
        assert math.isnan(gu.logsumexp(values))
42 |
def test_logmeanexp():
    """Check gu.logmeanexp on empty input, precision, infinities and nans."""
    inf, nan = float('inf'), float('nan')

    assert gu.logmeanexp([]) == -inf
    assert relerr(992.550919866405, gu.logmeanexp(range(1000))) < 1e-15
    assert gu.logmeanexp([-1000., -1000.]) == -1000.

    # Results that must match to within 1e-15 relative error.
    close_cases = [
        (math.log(0.5 * (1 + math.exp(-1.))), [0., -1.]),
        (math.log(0.5), [0., -1000.]),
        (-3 - math.log(2.), [-inf, -3]),
        (-3 - math.log(2.), [-3, -inf]),
    ]
    for expected, values in close_cases:
        assert relerr(expected, gu.logmeanexp(values)) < 1e-15

    # Any +inf dominates the mean.
    for values in ([+inf, -3], [-3, +inf], [-inf, 0, +inf]):
        assert gu.logmeanexp(values) == +inf

    # Any nan poisons the result.
    for values in ([nan, inf], [nan, -3], [nan]):
        assert math.isnan(gu.logmeanexp(values))
61 |
--------------------------------------------------------------------------------
/tests/test_lw_rf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Crash and sanity tests for queries using likelihood weighting inference
18 | with a RandomForest component model. Not an inference quality test suite."""
19 |
20 | import pytest
21 |
22 | import numpy as np
23 |
24 | from cgpm.crosscat.state import State
25 | from cgpm.utils import config as cu
26 | from cgpm.utils import general as gu
27 | from cgpm.utils import test as tu
28 |
29 |
@pytest.fixture(scope='module')
def state():
    """Module-wide State in which column 0 is modeled by a random forest."""
    cctypes, distargs = cu.parse_distargs(
        ['categorical(k=5)', 'normal', 'poisson', 'bernoulli'])
    T, _Zv, _Zc = tu.gen_data_table(
        50, [1], [[.33, .33, .34]], cctypes, distargs,
        [.95]*len(cctypes), rng=gu.gen_rng(0))
    # All columns start in a single view.
    single_view = {c: 0 for c in range(len(cctypes))}
    s = State(
        T.T, cctypes=cctypes, distargs=distargs, Zv=single_view,
        rng=gu.gen_rng(0))
    # Replace column 0's model with a random forest over 5 categories.
    s.update_cctype(0, 'random_forest', distargs={'k': 5})
    # XXX Uncomment me for a bug!
    # s.update_cctype(1, 'linear_regression')
    s.transition(N=1, kernels=[
        'rows', 'view_alphas', 'alpha', 'column_params', 'column_hypers'])
    return s
48 |
49 |
def test_simulate_unconditional__ci_(state):
    """Unconditional draws of column 0 fall in its category range 0..4."""
    for rowid in (-1, 1):
        draws = state.simulate(rowid, [0], N=2)
        check_entries_in_list(draws, range(5))
54 |
55 |
def test_simulate_conditional__ci_(state):
    """Conditional and multi-target draws of column 0 fall in 0..4."""
    # Column 0 given values for every other column.
    draws = state.simulate(-1, [0], {1:-1, 2:1, 3:1}, None, 2)
    check_entries_in_list(draws, range(5))
    # Several targets for a hypothetical row.
    draws = state.simulate(-1, [0, 2, 3], None, None, N=2)
    check_entries_in_list(draws, range(5))
    # Several targets for an existing row.
    draws = state.simulate(1, [0, 2, 3], None, None, 2)
    check_entries_in_list(draws, range(5))
64 |
65 |
def test_logpdf_unconditional__ci_(state):
    """Every category of column 0 has log probability strictly below 0."""
    for category in range(5):
        logp = state.logpdf(None, {0: category})
        assert logp < 0
69 |
70 |
def test_logpdf_deterministic__ci_(state):
    """logpdf is repeatable when all parents are constrained, and querying
    an observed (non-nan) cell raises ValueError."""
    # Ensure logpdf estimation is deterministic when all parents are in the
    # constraints.
    for k in range(5):
        lp1 = state.logpdf(-1, {0: k, 3: 0}, {1: 1, 2: 1})
        lp2 = state.logpdf(-1, {0: k, 3: 0}, {1: 1, 2: 1})
        assert np.allclose(lp1, lp2)
    # Row 1's cells are observed, so they already have parents in the
    # constraints. logpdf of a non-nan observed cell is currently not
    # possible. The call raises, so there is no value to compare; the
    # previous trailing allclose only re-checked stale values from the loop
    # above.
    for k in range(5):
        with pytest.raises(ValueError):
            state.logpdf(1, {0: k, 3: 0})
85 |
86 |
def test_logpdf_impute__ci_(state):
    """Crash test for logpdf queries where not every parent is constrained.

    The estimates are printed rather than compared. In such queries the
    estimate may be nondeterministic, but the random forest discretizes its
    input, so different importance-sampling estimates can still coincide.
    No equality or inequality assertion is reliable.
    """
    # Column 0 with only column 1 constrained.
    for k in range(5):
        lp1 = state.logpdf(-1, {0: k}, {1: 1})
        lp2 = state.logpdf(-1, {0: k}, {1: 1})
        # Function-call form parses under both Python 2 and Python 3.
        print('%s %s' % (lp1, lp2))
    # Columns 1 and 2 with column 0 constrained.
    for k in range(5):
        lp1 = state.logpdf(-1, {1: 1, 2: 2}, {0: k})
        lp2 = state.logpdf(-1, {1: 1, 2: 2}, {0: k})
        print('%s %s' % (lp1, lp2))
101 |
102 |
def check_entries_in_list(entries, allowed):
    """Assert that ``entry[0]`` is a member of `allowed` for every entry.

    Raises AssertionError at the first entry that fails.
    """
    assert all(entry[0] in allowed for entry in entries)
106 |
--------------------------------------------------------------------------------
/tests/test_normal_categorical.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """This graphical test trains a gpmcc state on a bivariate population [X, Z].
18 | X (called the data) is a cctype from DistributionGpm. Z is a categorical
19 | variable that is a function of the latent cluster of each row
20 | (called the indicator).
21 |
22 | The three simulations are:
23 | - Joint Z,X.
24 | - Data conditioned on the indicator, X|Z.
25 | - Indicator conditioned on the data, Z|X.
26 |
27 | Simulations are compared to synthetic data at indicator subpopulations.
28 | """
29 |
30 | import pytest
31 |
32 | import matplotlib.pyplot as plt
33 | import numpy as np
34 |
35 | from scipy.stats import ks_2samp
36 |
37 | from cgpm.crosscat.engine import Engine
38 | from cgpm.utils import general as gu
39 | from cgpm.utils import test as tu
40 |
41 |
N_SAMPLES = 250

T, Zv, Zc = tu.gen_data_table(
    N_SAMPLES, [1], [[.3, .5, .2]], ['normal'], [None], [.95],
    rng=gu.gen_rng(0))

# Column 0: synthetic normal data. Column 1: indicator derived from clusters.
DATA = np.zeros((N_SAMPLES, 2))
DATA[:, 0] = T[0]

INDICATORS = [0, 1, 2, 3, 4, 5]

# Rows of latent cluster k alternate between indicators 2k and 2k+1.
counts = {0: 0, 1: 0, 2: 0}
for i in range(N_SAMPLES):
    k = Zc[0][i]
    DATA[i, 1] = 2*INDICATORS[k] + counts[k] % 2
    counts[k] += 1
58 |
59 |
@pytest.fixture(scope='module')
def state():
    """Return the highest-scoring state of a 4-state engine fit to DATA."""
    engine = Engine(
        DATA,
        cctypes=['normal', 'categorical'],
        distargs=[None, {'k': 6}],
        num_states=4,
        rng=gu.gen_rng(212),
    )
    engine.transition(N=15)
    scores = engine.logpdf_score()
    # Last index of the ascending sort is the best-scoring state.
    best = np.argsort(scores)[-1]
    return engine.get_state(best)
70 |
71 |
def test_joint(state):
    """Joint draws of (x, z) match the data within each indicator value."""
    joint_samples = state.simulate(-1, [0, 1], N=N_SAMPLES)
    _, ax = plt.subplots()
    ax.set_title('Joint Simulation')
    for t in INDICATORS:
        color = gu.colors[t]
        observed = DATA[DATA[:, 1] == t]
        # Observed points sit at x=t.
        ax.scatter(observed[:, 1], observed[:, 0], color=color)
        simulated = [s[0] for s in joint_samples if s[1] == t]
        # Simulated points are offset by .25.
        ax.scatter(
            np.add([t]*len(simulated), .25), simulated, color=color)
        # Two-sample Kolmogorov-Smirnov test of data vs. simulation.
        pvalue = ks_2samp(observed[:, 0], simulated)[1]
        assert .05 < pvalue
    ax.set_xlabel('Indicator')
    ax.set_ylabel('x')
    ax.grid()
92 |
93 |
def test_conditional_indicator(state):
    """Draws of X given Z=t match the observed data with indicator t."""
    _, ax = plt.subplots()
    ax.set_title('Conditional Simulation Of Data X Given Indicator Z')
    for t in INDICATORS:
        observed = DATA[DATA[:, 1] == t]
        n_obs = len(observed)
        ax.scatter(observed[:, 1], observed[:, 0], color=gu.colors[t])
        draws = state.simulate(-1, [0], {1: t}, None, n_obs)
        simulated = [s[0] for s in draws]
        ax.scatter(np.repeat(t, n_obs) + .25, simulated, color=gu.colors[t])
        # Two-sample Kolmogorov-Smirnov test of data vs. simulation.
        pvalue = ks_2samp(observed[:, 0], simulated)[1]
        assert .01 < pvalue
    ax.set_xlabel('Indicator')
    ax.set_ylabel('x')
    ax.grid()
114 |
115 |
def test_conditional_real(state):
    """Draws of Z given X at an indicator's mean mostly recover that
    indicator's latent cluster."""
    fig, axes = plt.subplots(2, 3)
    fig.suptitle('Conditional Simulation Of Indicator Z Given Data X')
    # Representative x value per indicator: mean of its observed data.
    means = [np.mean(DATA[DATA[:, 1] == t], axis=0)[0] for t in INDICATORS]
    for mean, indicator, ax in zip(means, INDICATORS, axes.ravel('F')):
        samples_subpop = [
            s[1] for s in state.simulate(-1, [1], {0: mean}, None, N_SAMPLES)]
        ax.hist(samples_subpop, color='g', alpha=.4)
        ax.set_title('True Indicator %d' % indicator)
        ax.set_xlabel('Simulated Indicator')
        ax.set_xticks(INDICATORS)
        ax.set_ylabel('Frequency')
        ax.set_ylim([0, ax.get_ylim()[1]+10])
        ax.grid()
        # Indicators 2k and 2k+1 come from the same latent cluster k, so a
        # draw of either counts as agreement with the true indicator.
        true_ind_a = indicator
        true_ind_b = indicator-1 if indicator % 2 else indicator+1
        # bincount needs integer input. minlength guarantees both indices
        # exist even when the largest indicators were never drawn; without
        # it the fancy index below can raise IndexError.
        sample_counts = np.bincount(
            np.asarray(samples_subpop).astype(int),
            minlength=len(INDICATORS))
        frac = (sum(sample_counts[[true_ind_a, true_ind_b]])
                / float(sum(sample_counts)))
        assert .8 < frac
138 |
--------------------------------------------------------------------------------
/tests/test_ols.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import importlib
18 | import pytest
19 |
20 | from math import log
21 |
22 | import numpy as np
23 |
24 | from cgpm.regressions.ols import OrdinaryLeastSquares
25 | from cgpm.utils import config as cu
26 | from cgpm.utils import general as gu
27 | from cgpm.utils import test as tu
28 |
29 |
cctypes, distargs = cu.parse_distargs([
    'normal',
    'categorical(k=3)',
    'poisson',
    'bernoulli',
    'lognormal',
    'exponential',
    'geometric',
    'vonmises',
    'normal',
])

T, Zv, Zc = tu.gen_data_table(
    100, [1], [[.33, .33, .34]], cctypes, distargs,
    [.2]*len(cctypes), rng=gu.gen_rng(0))

D = T.T
# Column 0 is the OLS output; the remaining columns are its inputs. Input
# statargs mirror cctypes[1:]: categorical k=3 and bernoulli given as k=2.
OLS_DISTARGS = {
    'inputs': {
        'stattypes': cctypes[1:],
        'statargs': [{'k': 3}, None, {'k': 2}, None, None, None, None, None],
    },
}
OLS_OUTPUTS = [0]
# A real list on both Python 2 and 3 (range is a lazy sequence on Python 3).
OLS_INPUTS = list(range(1, len(cctypes)))
59 | outputs=OLS_OUTPUTS,
60 | inputs=OLS_INPUTS,
61 | distargs=OLS_DISTARGS,
62 | rng=gu.gen_rng(0)
63 | )
64 | # Incorporate first 20 rows.
65 | for rowid, row in enumerate(D[:20]):
66 | observation = {0: row[0]}
67 | inputs = {i: row[i] for i in ols.inputs}
68 | ols.incorporate(rowid, observation, inputs)
69 | # Unincorporating row 20 should raise.
70 | with pytest.raises(ValueError):
71 | ols.unincorporate(20)
72 | # Unincorporate all rows.
73 | for rowid in xrange(20):
74 | ols.unincorporate(rowid)
75 | # Unincorporating row 0 should raise.
76 | with pytest.raises(ValueError):
77 | ols.unincorporate(0)
78 | # Incorporating with wrong covariate dimensions should raise.
79 | with pytest.raises(ValueError):
80 | observation = {0: D[0,0]}
81 | inputs = {i: v for (i, v) in enumerate(D[0])}
82 | ols.incorporate(0, observation, inputs)
83 | # Incorporating with None output value should raise.
84 | with pytest.raises(ValueError):
85 | observation = {0: None}
86 | inputs = {i: D[0,i] for i in ols.inputs}
87 | ols.incorporate(0, observation, inputs)
88 | # Incorporating with nan inputs value should raise.
89 | with pytest.raises(ValueError):
90 | observation = {0: 100}
91 | inputs = {i: D[0,i] for i in ols.inputs}
92 | inputs[inputs.keys()[0]] = np.nan
93 | ols.incorporate(0, observation, inputs)
94 | # Incorporate some more rows.
95 | for rowid, row in enumerate(D[:10]):
96 | observation = {0: row[0]}
97 | inputs = {i: row[i] for i in ols.inputs}
98 | ols.incorporate(rowid, observation, inputs)
99 |
100 | # Run a transition.
101 | ols.transition()
102 | assert ols.noise > 0
103 |
104 | # Invalid categorical inputs 5 for categorical(k=3).
105 | targets = {OLS_OUTPUTS[0]: 2}
106 | inputs = dict(zip(OLS_INPUTS, [5, 5, 0, 1.4, 7, 4, 2, -2]))
107 | with pytest.raises(ValueError):
108 | ols.logpdf(-1, targets, None, inputs)
109 | with pytest.raises(ValueError):
110 | ols.simulate(-1, OLS_OUTPUTS, None, inputs)
111 |
112 | # Invalid categorical inputs 2 for bernoulli.
113 | targets = {OLS_OUTPUTS[0]: 2}
114 | inputs = dict(zip(OLS_INPUTS, [5, 5, 2, 1.4, 7, 4, 2, -2]))
115 | with pytest.raises(ValueError):
116 | ols.logpdf(-1, targets, None, inputs)
117 | with pytest.raises(ValueError):
118 | ols.simulate(-1, OLS_OUTPUTS, None, inputs)
119 |
120 | # Do a logpdf computation.
121 | targets = {OLS_OUTPUTS[0]: 2}
122 | inputs = dict(zip(OLS_INPUTS, [2, 5, 0, 1.4, 7, 4, 2, -2]))
123 | logp_old = ols.logpdf(-1, targets, None, inputs)
124 | assert logp_old < 0
125 | ols.simulate(-1, OLS_OUTPUTS, None, inputs)
126 |
127 | # Now serialize and deserialize, and check if logp_old is the same.
128 | metadata = ols.to_metadata()
129 | builder = getattr(
130 | importlib.import_module(metadata['factory'][0]),
131 | metadata['factory'][1])
132 | ols2 = builder.from_metadata(metadata, rng=gu.gen_rng(1))
133 |
134 | assert ols2.noise == ols.noise
135 | logp_new = ols2.logpdf(-1, targets, None, inputs)
136 | assert np.allclose(logp_new, logp_old)
137 | ols2.simulate(-1, OLS_OUTPUTS, None, inputs)
138 |
--------------------------------------------------------------------------------
/tests/test_piecewise.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 |
19 | from cgpm.dummy.piecewise import PieceWise
20 | from cgpm.utils.general import logsumexp
21 |
22 |
def test_piecewise_logpdf():
    """Crash test for PieceWise simulate/logpdf, with normalization checks
    on the discrete output z (column 1)."""
    pw = PieceWise([0, 1], [2], sigma=1, flip=.8)

    # Every query uses input {2: 1}; a fresh dict is built per call.
    def simulate(targets, constraints=None):
        return pw.simulate(None, targets, constraints, {2: 1})

    def logpdf(targets, constraints=None):
        return pw.logpdf(None, targets, constraints, {2: 1})

    # Joint (x, z).
    simulate([0, 1])
    logpdf({0: 1.5, 1: 0})

    # Marginal x.
    simulate([0])
    logpdf({0: 1.5})

    # Marginal z: probabilities over z in {0, 1} sum to one.
    simulate([1])
    assert np.allclose(logsumexp([logpdf({1: 0}), logpdf({1: 1})]), 0)

    # z given x: also normalized.
    simulate([1], {0: 1.5})
    assert np.allclose(
        logsumexp([logpdf({1: 0}, {0: 1.5}), logpdf({1: 1}, {0: 1.5})]),
        0)

    # x given z.
    simulate([0], {1: 0})
    logpdf({0: 1.5}, {1: 0})
52 |
--------------------------------------------------------------------------------
/tests/test_populate_evidence.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 | import pytest
19 |
20 | from cgpm.mixtures.view import View
21 | from cgpm.crosscat.state import State
22 |
23 | """Test suite for View._populate_constraints.
24 |
25 | Ensures that View._populate_constraints correctly retrieves observed values
26 | from the dataset, and that State._validate_cgpm_query rejects invalid queries.
27 | """
28 |
29 |
30 | # ------------------------------------------------------------------------------
31 | # Tests for cgpm.mixtures.view.View
32 |
def retrieve_view():
    """Build a View over a 3x5 normal dataset with missing (nan) entries.

    Each row is placed in its own cluster (Zr=[0, 1, 2]), and -1 is the
    first output of the view.
    """
    X = np.asarray([
        [1, np.nan, 2, -1, np.nan],
        [1, 3, 2, -1, -5],
        [1, np.nan, np.nan, np.nan, np.nan],
    ])
    columns = [0, 1, 2, 3, 4]
    data = {c: X[:, c].tolist() for c in columns}
    return View(
        data,
        outputs=[-1] + columns,
        cctypes=['normal']*5,
        Zr=[0, 1, 2],
    )
46 |
47 |
def test_view_hypothetical_unchanged():
    """For a hypothetical row (rowid -1), constraints come back unchanged."""
    view = retrieve_view()
    constraints = {1: 1, 2: 2}
    populated = view._populate_constraints(-1, {3: -1}, constraints)
    assert constraints == populated
56 |
57 |
def test_view_only_rowid_to_populate():
    """Targeting X[2,0] for simulate populates only row 2's cluster.

    Row 2's only observed cell is column 0, which is itself the target.
    """
    view = retrieve_view()
    rowid = 2
    populated = view._populate_constraints(rowid, [0], {})
    assert populated == {-1: view.Zr(rowid)}
67 |
68 |
def test_view_constrain_cluster():
    """Targeting the cluster assignment of an observed row is rejected."""
    view = retrieve_view()
    with pytest.raises(ValueError):
        view._populate_constraints(1, {-1: 2}, {})
78 |
79 |
def test_view_values_to_populate():
    """Row 0's observed, non-target cells and cluster are added to the
    constraints, for both list-style and dict-style targets."""
    view = retrieve_view()
    rowid = 0
    for targets in ([1], {1: 1}):
        populated = view._populate_constraints(rowid, targets, {4: 2})
        assert populated == {0: 1, 2: 2, 3: -1, 4: 2, -1: view.Zr(rowid)}
94 |
95 |
96 | # ------------------------------------------------------------------------------
97 | # Tests for cgpm.crosscat.state.State
98 |
def retrieve_state():
    """Build a single-view State over a 3x5 normal dataset with nans.

    Each row is placed in its own cluster of view 0.
    """
    X = np.asarray([
        [1, np.nan, 2, -1, np.nan],
        [1, 3, 2, -1, -5],
        [1, np.nan, np.nan, np.nan, np.nan],
    ])
    columns = [0, 1, 2, 3, 4]
    return State(
        X,
        outputs=columns,
        cctypes=['normal']*5,
        Zv={c: 0 for c in columns},
        Zrv={0: [0, 1, 2]},
    )
113 |
def test_state_constrain_logpdf():
    """Cannot target the observed cell X[2,0] for logpdf."""
    state = retrieve_state()
    with pytest.raises(ValueError):
        state._validate_cgpm_query(2, {0: 2}, {})
122 |
def test_state_constrain_errors():
    """Queries that target or constrain row 1's observed cells are rejected."""
    state = retrieve_state()
    invalid_queries = [
        ({1: 1, 4: 1}, {}),
        ({1: 3}, {4: -5}),
        ({0: 1, 1: 3}, {}),
    ]
    for targets, constraints in invalid_queries:
        with pytest.raises(ValueError):
            state._validate_cgpm_query(1, targets, constraints)
143 |
--------------------------------------------------------------------------------
/tests/test_row_similarity.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 |
@pytest.mark.xfail(reason='Stub: test not implemented yet.', strict=True)
def test_row_similarity_basic():
    """Placeholder for a row-similarity test; strictly expected to fail."""
    raise ValueError('Implement me!')
22 |
--------------------------------------------------------------------------------
/tests/test_simulate_many_decorator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from cgpm.utils import general as gu
18 |
class DummyCgpm():
    """Minimal cgpm whose simulate echoes its arguments back to the caller.

    Wrapping it with gu.simulate_many lets the tests observe how the
    decorator forwards arguments and when it repeats the call.
    """

    @gu.simulate_many
    def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
        echoed = (self, rowid, targets, constraints, inputs, N)
        return echoed
24 |
def test_simulate_many_kwarg_none():
    """N is None as a kwarg."""
    gpm = DummyCgpm()
    cases = [
        (gpm.simulate(None, [1, 2]),
            (gpm, None, [1, 2], None, None, None)),
        (gpm.simulate(None, [1, 2], inputs={4: 1}),
            (gpm, None, [1, 2], None, {4: 1}, None)),
        (gpm.simulate(4, [1, 2], {3: 2}),
            (gpm, 4, [1, 2], {3: 2}, None, None)),
        (gpm.simulate(None, [1, 2], {4: 2}, inputs={5: 2}),
            (gpm, None, [1, 2], {4: 2}, {5: 2}, None)),
        (gpm.simulate(None, [1, 2], {5: 2}),
            (gpm, None, [1, 2], {5: 2}, None, None)),
        (gpm.simulate(2, [1, 2], constraints={4: 2}, inputs={5: 2}),
            (gpm, 2, [1, 2], {4: 2}, {5: 2}, None)),
    ]
    for actual, expected in cases:
        assert actual == expected
40 |
def test_simulate_many_kwarg_not_none():
    """N is not None, used as a named parameter."""
    gpm = DummyCgpm()
    cases = [
        (gpm.simulate(None, [1, 2], N=10),
            [(gpm, None, [1, 2], None, None, 10)] * 10),
        (gpm.simulate(9, [1, 2], inputs=None, N=7),
            [(gpm, 9, [1, 2], None, None, 7)] * 7),
        (gpm.simulate(None, [1, 2], {2: 1}, N=10),
            [(gpm, None, [1, 2], {2: 1}, None, 10)] * 10),
        (gpm.simulate(None, [1, 2], constraints={4: 2}, inputs={5: 2}, N=1),
            [(gpm, None, [1, 2], {4: 2}, {5: 2}, 1)]),
    ]
    for actual, expected in cases:
        assert actual == expected
52 |
def test_simulate_many_positional():
    """N is positional."""
    gpm = DummyCgpm()
    # Each case: (positional args ending in N, one echoed sample, count).
    repeated = [
        ((77, [1,2], None, None, 10),
            (gpm, 77, [1,2], None, None, 10), 10),
        ((None, [1,2], None, {4:1}, 7),
            (gpm, None, [1,2], None, {4:1}, 7), 7),
        ((None, [1,2], {5:2}, None, 1),
            (gpm, None, [1,2], {5:2}, None, 1), 1),
    ]
    for args, sample, count in repeated:
        assert gpm.simulate(*args) == [sample] * count
    # A positional N of None yields the bare tuple rather than a list.
    single = gpm.simulate(None, [1,2], None, {3:1}, None)
    assert single == (gpm, None, [1,2], None, {3:1}, None)
64 |
--------------------------------------------------------------------------------
/tests/test_state_initialize.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from cgpm.crosscat.state import State
18 | from cgpm.utils import general as gu
19 |
def test_Zv_without_Zrv():
    """Construct a State from a Zv assignment without passing any Zrv.

    Smoke test: it only checks that construction does not raise.
    """
    rng = gu.gen_rng(2)
    data = rng.normal(size=(10,4))
    num_cols = data.shape[1]
    # Integer assignment per output column (presumably the column-to-view
    # partition); note that the indices used skip 3.
    column_views = {3:0, 2:1, 1:2, 0:4}
    State(
        data,
        outputs=[3,2,1,0],
        cctypes=['normal'] * num_cols,
        Zv=column_views,
        rng=rng,
    )
31 |
--------------------------------------------------------------------------------
/tests/test_stochastic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2010-2016, MIT Probabilistic Computing Project
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 |
19 | from stochastic import StochasticError
20 | from stochastic import stochastic
21 |
class Quagga(Exception):
    """Distinctive exception raised by the helper tests in this module.

    Assertions can then tell a helper's deliberate failure apart from any
    other error.
    """
24 |
@stochastic(max_runs=1, min_passes=1)
def _test_fail(seed):
    """Fail on every run by raising Quagga; the seed is ignored."""
    raise Quagga()
28 |
@stochastic(max_runs=1, min_passes=1)
def _test_pass(_seed):
    """Pass on every run; the seed is ignored and nothing is returned."""
    return None
32 |
# Call counter (mod 2) driving _test_passthenfail's alternating outcome.
passthenfail_counter = 0
@stochastic(max_runs=2, min_passes=1)
def _test_passthenfail(seed):
    """Alternate outcomes across calls: pass, fail, pass, fail, ...

    The outcome depends only on how many times the function has been
    invoked (tracked in a module-level counter), not on the seed.
    """
    global passthenfail_counter
    passthenfail_counter = (passthenfail_counter + 1) % 2
    if not passthenfail_counter:
        raise Quagga
41 |
# Call counter (mod 2) driving _test_failthenpass's alternating outcome.
failthenpass_counter = 0
@stochastic(max_runs=2, min_passes=1)
def _test_failthenpass(seed):
    """Alternate outcomes across calls: fail, pass, fail, pass, ...

    The outcome depends only on how many times the function has been
    invoked (tracked in a module-level counter), not on the seed.
    """
    global failthenpass_counter
    failthenpass_counter = (failthenpass_counter + 1) % 2
    if failthenpass_counter:
        raise Quagga
50 |
@stochastic(max_runs=2, min_passes=1)
def _test_failthenfail(seed):
    """Fail on every run (so both of the max_runs attempts fail)."""
    raise Quagga()
54 |
@stochastic(max_runs=1, min_passes=1)
def test_stochastic(seed):
    """Exercise the @stochastic decorator on deterministic helper tests.

    As asserted below: called without a seed, a decorated test that never
    reaches min_passes within max_runs raises StochasticError, and a failing
    first run may be followed by a passing retry. Called with an explicit
    seed, the helper's own exception (Quagga) propagates unchanged.

    The call order matters: the passthenfail/failthenpass helpers are driven
    by module-level counters.
    """
    # Without a seed, exhausting the runs raises StochasticError that records
    # the underlying exception. Using the ExceptionInfo guarantees the
    # isinstance check actually runs (a bare try/except would silently skip
    # it if nothing were raised).
    with pytest.raises(StochasticError) as excinfo:
        _test_fail()
    assert isinstance(excinfo.value.excvalue, Quagga)
    # With an explicit seed the helper runs once and its error escapes.
    with pytest.raises(Quagga):
        _test_fail(seed)
    _test_pass()
    _test_pass(seed)
    _test_passthenfail()
    with pytest.raises(Quagga):
        _test_passthenfail(seed)
    _test_failthenpass()
    with pytest.raises(Quagga):
        _test_failthenpass(seed)
    with pytest.raises(StochasticError):
        _test_failthenfail()
    with pytest.raises(Quagga):
        _test_failthenfail(seed)
77 |
--------------------------------------------------------------------------------
/tests/test_teh_murphy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """The class cgpm.primitives.normal.Normal uses derivations from both
18 |
19 | http://www.stats.ox.ac.uk/~teh/research/notes/GaussianInverseGamma.pdf
20 | https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf (Section 3)
21 |
22 | The two sources use different parameterizations. This suite ensures that the
23 | conversions of these parameterizations performed in Normal produce consistent
24 | numerical results between the two sources.
25 |
26 | In particular, truncating the Normal requires computing the logcdf to normalize
27 | the probability, which is a Student T derived in Murphy but not Teh.
28 | """
29 |
30 | import itertools as it
31 |
32 | import numpy as np
33 | from numpy import log, pi, sqrt
34 | from scipy.special import gammaln
35 | from scipy.stats import t
36 |
37 | from cgpm.primitives.normal import Normal
38 | from cgpm.utils.general import gen_rng
39 |
40 |
41 | # Prepare some functions for use in the test.
def teh_posterior(m, r, s, nu, x):
    """Eq 10 Teh: posterior hyperparameters (mn, rn, sn, nun) given x.

    Computes the count, sum and sum of squares of ``x`` and delegates to
    ``Normal.posterior_hypers``.
    """
    count, total, total_sq = len(x), np.sum(x), np.sum(x**2)
    return Normal.posterior_hypers(count, total, total_sq, m, r, s, nu)
48 |
49 |
def murphy_posterior(a, b, k, mu, x):
    """Eqs 85 to 89 Murphy: Normal-Gamma posterior hyperparameters.

    Returns (an, bn, kn, mun) after updating the prior (a, b, k, mu) with the
    observations in ``x`` (expected to be a NumPy array, since
    ``x - xbar`` is taken elementwise).
    """
    n = len(x)
    xbar = np.mean(x)
    kn = k + n
    an = a + n/2.
    mun = (k*mu + n*xbar) / kn
    scatter = np.sum((x - xbar)**2)
    bn = b + .5*scatter + k*n*(xbar - mu)**2 / (2*kn)
    return an, bn, kn, mun
59 |
60 |
def murphy_posterior_predictive(an1, an, bn1, bn, kn1, kn):
    """Eq 99 Murphy: log posterior predictive density of one test point.

    The ``*1`` arguments are the posterior hyperparameters after also
    incorporating the test point; the others are the posterior without it.
    The density is a ratio of Normal-Gamma normalizers.
    """
    # Accumulate term by term, in the same left-to-right order as Eq 99.
    logp = gammaln(an1) - gammaln(an)
    logp += an*log(bn)
    logp -= an1*log(bn1)
    logp += .5*(log(kn) - log(kn1))
    logp -= .5*log(2*pi)
    return logp
65 |
66 |
67 | # Test suite.
68 |
def test_agreement():
    """Teh, Murphy, Student t and Normal posterior computations agree.

    For several prior settings and two synthetic datasets, this checks that
    (1) the posterior hyperparameters from both derivations coincide under
    the conversion a=nu/2, b=s/2, k=r, mu=m, and (2) five computations of
    the posterior predictive log density agree on a grid of test points.
    """
    # Hyperparameters in Teh notation.
    all_m = (1., 7., .43, 1.2)
    all_r = (2., 18., 3., 11.)
    all_s = (2., 6., 15., 55.)
    all_nu = (4., .6, 14., 8.)

    # Dataset.
    rng = gen_rng(0)
    x1 = rng.normal(10, 3, size=100)
    x2 = rng.normal(-3, 7, size=100)

    # numpy.linspace needs an integer sample count; the former float 14.1
    # raises TypeError on current NumPy (older releases truncated it to 14).
    xtests = np.linspace(1.1, 80.8, 14)

    for (m, r, s, nu), x in \
            it.product(zip(all_m, all_r, all_s, all_nu), [x1, x2]):
        # Murphy hypers in terms of Teh.
        a = nu/2.
        b = s/2.
        k = r
        mu = m

        # Test equality of posterior hypers.
        mn, rn, sn, nun = teh_posterior(m, r, s, nu, x)
        an, bn, kn, mun = murphy_posterior(a, b, k, mu, x)
        assert np.allclose(an, nun/2., atol=1e-5)
        assert np.allclose(bn, sn/2., atol=1e-5)
        assert np.allclose(kn, rn, atol=1e-5)
        assert np.allclose(mun, mn, atol=1e-5)

        # Test the posterior predictives agree with each other and with
        # the Student t.
        for xtest in xtests:
            x_plus = np.append(x, xtest)

            # Murphy exact, Eq 99.
            an1, bn1, kn1, _mun1 = murphy_posterior(a, b, k, mu, x_plus)
            logprob_murphy = murphy_posterior_predictive(
                an1, an, bn1, bn, kn1, kn)

            # Student T Murphy, Eq 100.
            scalesq_murphy = bn*(kn+1)/(an*kn)
            logprob_t_murphy = t.logpdf(
                xtest, 2*an, loc=mun, scale=sqrt(scalesq_murphy))

            # Teh exact using Murphy Eq 99.
            _mn1, rn1, sn1, nun1 = teh_posterior(m, r, s, nu, x_plus)
            logprob_teh = murphy_posterior_predictive(
                nun1/2., nun/2., sn1/2., sn/2., rn1, rn)

            # Posterior predictive from Normal DistributionGpm.
            logprob_nignormal = Normal.calc_predictive_logp(
                xtest, len(x), np.sum(x), np.sum(x**2), m, r, s, nu)

            # Student T Teh using Murphy Eq 100; 2*(nun/2) = nun dof.
            scalesq_teh = sn/2.*(rn+1)/(nun/2.*rn)
            logprob_t_teh = t.logpdf(
                xtest, nun, loc=mn, scale=sqrt(scalesq_teh))

            # Aggregate all values and test their equality.
            values = [logprob_murphy, logprob_teh, logprob_t_murphy,
                logprob_t_teh, logprob_nignormal]
            for v in values:
                assert np.allclose(v, values[0], atol=1e-2)
130 |
--------------------------------------------------------------------------------
/tests/test_type_check.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from collections import namedtuple
18 |
19 | import pytest
20 |
21 | from cgpm.utils import config as cu
22 |
23 |
# One parametrized scenario for ``test_distributions``:
#   outputs  -- output column ids passed to the model constructor.
#   inputs   -- input column ids, or None for models that take no inputs.
#   distargs -- passed to the constructor as ``distargs=``.
#   good     -- test values the model must accept.
#   bad      -- test values the model must reject.
# Each good/bad entry is either a bare value for output column 0, or an
# (output value, inputs dict) tuple for models that take inputs.
Case = namedtuple(
    'Case', ['outputs', 'inputs', 'distargs', 'good', 'bad'])

# Scenarios keyed by cctype name (resolved through ``cu.cctype_class``).
cases = {
    'bernoulli' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[0, 1.],
        bad=[-1, .5, 3]),

    'beta' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[.3, .1, .9, 0.001, .9999],
        bad=[-1, 1.02, 21]),

    # With k=4, the good values are exactly 0..3; 4 and non-integers are bad.
    'categorical' : Case(
        outputs=[0],
        inputs=None,
        distargs={'k': 4},
        good=[0., 1, 2, 3.],
        bad=[-1, 2.5, 4]),

    'exponential' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[0, 1, 2, 3],
        bad=[-1, -2.5]),

    'geometric' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[0, 2, 12],
        bad=[-1, .5, -4]),

    'lognormal' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[1, 2, 3],
        bad=[-12, -0.01, 0]),

    # No invalid values are listed for the normal.
    'normal' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[-1, 0, 10],
        bad=[]),

    # Bounds come from distargs (l=-1, h=10); the bad values lie outside.
    'normal_trunc' : Case(
        outputs=[0],
        inputs=None,
        distargs={'l': -1, 'h': 10},
        good=[0, 4, 9],
        bad=[44, -1.02]),

    'poisson' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[0, 5, 11],
        bad=[-1, .5, -4]),

    # Bad entries: an invalid output label (-1), an unexpected input column
    # (0) and a missing input column (2).
    # NOTE(review): stattypes are integers here, whereas linear_regression
    # uses stattype names; confirm what the forest expects.
    'random_forest' : Case(
        outputs=[0],
        inputs=[1, 2],
        distargs={'k': 2, 'inputs':{'stattypes': [1,2]}},
        good=[(0, {1:1, 2:2}), (1, {1:0, 2:2})],
        bad=[(-1, {1:1, 2:2}), (0, {0:1, 2:2}), (0, {1: 3})]),

    'vonmises' : Case(
        outputs=[0],
        inputs=None,
        distargs=None,
        good=[0.1, 3.14, 6.2],
        bad=[-1, 7, 12]),

    # The bad entry passes the output column itself (0) among the inputs.
    'linear_regression' : Case(
        outputs=[0],
        inputs=[1, 2],
        distargs={
            'inputs': {
                'stattypes': ['normal', 'bernoulli'],
                'statargs': [None, {'k':2}]}},
        good=[(0, {1:1, 2:0})],
        bad=[(0, {0:1, 1:1, 2:0})]),
}
115 |
116 |
def get_observation_inputs(t):
    """Split a good/bad test value into an (observation, inputs) pair.

    A tuple ``t`` is read as (output value, inputs dict, ...); any other
    value is a bare output value with no inputs. The output is always
    column id 0.
    """
    if not isinstance(t, tuple):
        return {0: t}, None
    value, inputs = t[0], t[1]
    return {0: value}, inputs
120 |
121 |
# Sorted so the parametrized test ids are collected in the same order in
# every process; plain dict iteration order is not guaranteed to be stable
# across interpreters/processes, which breaks distributed test collection.
@pytest.mark.parametrize('cctype', sorted(cases))
def test_distributions(cctype):
    """Check that ``cctype`` accepts its good values and rejects its bad ones."""
    case = cases[cctype]
    assert_distribution(
        cctype, case.outputs, case.inputs, case.distargs, case.good, case.bad)
127 |
128 |
def assert_distribution(cctype, outputs, inputs, distargs, good, bad):
    """Build one ``cctype`` model and check every good and bad value on it.

    The same model instance is reused throughout. Row ids restart at 0 for
    the bad values, which are checked after all good values.
    """
    model_class = cu.cctype_class(cctype)
    model = model_class(outputs, inputs, distargs=distargs)
    for rowid, value in enumerate(good):
        assert_good(model, rowid, value)
    for rowid, value in enumerate(bad):
        assert_bad(model, rowid, value)
135 |
136 |
def assert_good(model, rowid, g):
    """Check that ``model`` accepts the valid test value ``g``.

    The value is incorporated at ``rowid`` and then unincorporated again.
    Its log density, queried at row id -1, must not be -inf.
    """
    observation, inputs = get_observation_inputs(g)
    model.incorporate(rowid, observation, inputs)
    model.unincorporate(rowid)
    logp = model.logpdf(-1, observation, None, inputs)
    assert logp != -float('inf')
142 |
143 |
def assert_bad(model, rowid, b):
    """Check that ``model`` rejects the invalid test value ``b``.

    Incorporating ``b`` must raise. Unincorporating ``rowid`` must then
    raise as well, since the failed incorporate should have left no row
    there. Finally, the log density at row id -1 must be -inf; a ``logpdf``
    that raises instead is also accepted (conditional GPMs throw on wrong
    input variables).
    """
    observation, inputs = get_observation_inputs(b)
    with pytest.raises(Exception):
        model.incorporate(rowid, observation, inputs)
    with pytest.raises(Exception):
        model.unincorporate(rowid)
    try:
        logp = model.logpdf(-1, observation, None, inputs)
    except Exception:
        # Conditional GPM throws error on wrong input variables.
        return
    # The assert must stay outside the try: AssertionError is itself an
    # Exception, so inside the try it would be swallowed and the check
    # could never fail.
    assert logp == -float('inf')
154 |
--------------------------------------------------------------------------------
/tests/test_view_logpdf_cluster.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2015-2016 MIT Probabilistic Computing Project
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 |
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import pytest
18 |
19 | import numpy as np
20 |
21 | from cgpm.mixtures.view import View
22 | from cgpm.utils import general as gu
23 |
24 |
def retrieve_view():
    """Return a View over a fixed 5-row, 3-column dataset with NaN cells.

    The first output (1000) holds the row's cluster; the tests query it
    with cluster ids. Rows 0-2 start in cluster 0 and rows 3-4 in
    cluster 1.
    """
    nan = np.nan
    data = np.asarray([
        [1.1, -2.1, 0],     # rowid=0
        [2.,  .1,   0],     # rowid=1
        [1.5, nan,  .5],    # rowid=2
        [4.7, 7.4,  .5],    # rowid=3
        [5.2, 9.6,  nan],   # rowid=4
    ])
    outputs = [0, 1, 2]
    columns = {}
    for i, c in enumerate(outputs):
        columns[c] = data[:, i].tolist()
    return View(
        columns,
        outputs=[1000] + outputs,
        alpha=2.,
        cctypes=['normal'] * len(outputs),
        Zr=[0, 0, 0, 1, 1],
    )
43 |
44 |
def test_crp_prior_logpdf():
    """CRP prior of a hypothetical row's cluster matches the closed form.

    With 3 rows in cluster 0 and 2 in cluster 1, a new row picks cluster k
    with probability n_k / (alpha + 5), or alpha / (alpha + 5) for a fresh
    cluster (k=2). The view's logpdf of the assignment must agree as well.
    """
    view = retrieve_view()
    alpha = view.alpha()
    normalizer = alpha + 5.
    expected = np.log(np.asarray([3, 2, alpha]) / normalizer)
    cluster_var = view.outputs[0]
    for k, expected_logpdf in enumerate(expected):
        crp_logpdf = view.crp.clusters[0].logpdf(None, {cluster_var: k})
        assert np.allclose(expected_logpdf, crp_logpdf)
        view_logpdf = view.logpdf(None, {cluster_var: k})
        assert np.allclose(view_logpdf, crp_logpdf)
60 |
61 |
def test_crp_posterior_logpdf():
    """A fresh row's cluster posteriors (clusters 0, 1, new) sum to one."""
    view = retrieve_view()
    fresh_row = {0:2, 1:3, 2:.5}
    cluster_var = view.outputs[0]
    logps = []
    for k in range(3):
        logps.append(view.logpdf(None, {cluster_var: k}, fresh_row))
    assert np.allclose(gu.logsumexp(logps), 0)
70 |
71 |
def test_logpdf_observed_nan():
    """View logpdf for a missing cell of an observed row matches the dim's.

    Row 2 has NaN in column 1. The view's density for a value there must
    equal column 1's density conditioned on row 2's cluster.
    """
    view = retrieve_view()
    logp_view = view.logpdf(2, {1:1})
    dim = view.dims[1]
    row_cluster = {view.outputs[0]: view.Zr(2)}
    logp_dim = dim.logpdf(2, {1:1}, None, row_cluster)
    assert np.allclose(logp_view, logp_dim)
77 |
78 |
def test_logpdf_chain():
    """Chain rule: log p(z=0) + log p(x | z=0) equals log p(x, z=0)."""
    view = retrieve_view()
    z = view.outputs[0]
    logp_cluster = view.logpdf(None, {z: 0})
    logp_data = view.logpdf(None, {1:1, 2:0}, {z: 0})
    logp_joint = view.logpdf(None, {1:1, 2:0, z: 0})
    assert np.allclose(logp_cluster + logp_data, logp_joint)
85 |
86 |
def test_logpdf_bayes():
    """Bayes rule: log p(z=0, x1 | x2) = log p(x1, x2, z=0) - log p(x2)."""
    view = retrieve_view()
    z = view.outputs[0]
    logp_posterior = view.logpdf(None, {z: 0, 1:1}, {2:0})
    logp_evidence = view.logpdf(None, {2:0})
    logp_joint = view.logpdf(None, {1:1, 2:0, z: 0})
    difference = logp_joint - logp_evidence
    assert np.allclose(difference, logp_posterior)
93 |
--------------------------------------------------------------------------------
/tests/test_vsinline.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2010-2016, MIT Probabilistic Computing Project
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import hacks
18 | import pytest
# Integration tests are opt-in: unless pytest was run with --integration,
# ``hacks.skip`` is called (presumably skipping the rest of this module).
# NOTE(review): the global ``pytest.config`` was deprecated and removed in
# pytest 5.0; confirm the pinned pytest version still provides it, or move
# this check into a conftest hook.
if not pytest.config.getoption('--integration'):
    hacks.skip('specify --integration to run integration tests')
21 |
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | import pytest
25 |
26 | from cgpm.venturescript.vsinline import InlineVsCGpm
27 | from cgpm.utils import general as gu
28 |
29 |
def test_input_matches_args():
    """The expression's parameter list must match the declared inputs.

    Accepted cases pair input column ids with an expression whose
    parenthesised parameter list has the same arity, written with varying
    whitespace. Rejected cases have mismatched arity and must raise.
    """
    accepted = [
        ([], '() ~> {normal(0, 1)}'),
        ([], '( ) ~> {normal(0, 1)}'),
        ([], '() ~> {normal(0, 1)}'),
        ([], ' ( ) ~> {normal(0, 1)}'),
        ([1], '(a) ~> {normal(0, 1)}'),
        ([1], '(a ) ~> {normal(0, 1)}'),
        ([1], ' ( a ) ~> {normal(0, 1)}'),
        ([1], '( ab) ~> {normal(0, 1)}'),
        ([1,2], '(a, b) ~> {normal(0, 1)}'),
        ([1,2], '( a, b ) ~> {normal(0, 1)}'),
        ([1,2], '(a, b ) ~> {normal(0, 1)}'),
        ([1,2,3], '(a, b, bc) ~> {2}'),
    ]
    for inputs, expression in accepted:
        InlineVsCGpm([0], inputs, expression=expression)

    rejected = [
        ([], '(a) ~> {normal(0,1)}'),
        ([1], '(a, b) ~> {normal(0,1)}'),
        ([4], '(a, b , c) ~> {normal(0,1)}'),
        ([1,2], '(a) ~> {normal(0,1)}'),
    ]
    for inputs, expression in rejected:
        with pytest.raises(Exception):
            InlineVsCGpm([0], inputs, expression=expression)
52 |
53 |
def test_simulate_uniform():
    """Uniform(-4.71, 4.71): flat density on its support, samples inside it.

    Also histograms the samples for visual inspection. The figure is closed
    afterwards so repeated runs do not accumulate open figures.
    """
    vs = InlineVsCGpm([0], [],
        expression='() ~> {uniform(low: -4.71, high: 4.71)}',
        rng=gu.gen_rng(10))

    # Constant density inside the support, zero (logpdf inf) outside it.
    lp = vs.logpdf(0, {0:0})
    for x in np.linspace(-4.70, 4.70, 100):
        assert np.allclose(vs.logpdf(0, {0:x}), lp)
    assert np.isinf(vs.logpdf(0, {0:12}))

    samples = vs.simulate(0, [0], None, None, N=200)
    extracted = [s[0] for s in samples]
    # Every draw must lie within the declared bounds.
    assert all(-4.71 <= v <= 4.71 for v in extracted)

    fig, ax = plt.subplots()
    ax.hist(extracted)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
68 |
69 |
def test_simulate_noisy_cos():
    """Simulate x ~ Uniform(-4.71, 4.71), then y | x from a band at cos(x).

    By construction y is uniform on [cos(x) - .5, cos(x)] when cos(x) > 0,
    and on [cos(x), cos(x) + .5] otherwise, so every draw lies within .5 of
    cos(x). Also plots the joint samples and the density of y at x = 0;
    both figures are closed afterwards so they do not accumulate.
    """
    vs_x = InlineVsCGpm([0], [],
        expression='() ~> {uniform(low: -4.71, high: 4.71)}',
        rng=gu.gen_rng(10))

    vs_y = InlineVsCGpm([1], [0],
        expression="""
        (x) ~>
            {if (cos(x) > 0)
                {uniform(low: cos(x) - .5, high: cos(x))}
            else
                {uniform(low: cos(x), high: cos(x) + .5)}}""",
        rng=gu.gen_rng(12))

    samples_x = vs_x.simulate(0, [0], None, None, N=200)
    samples_y = [vs_y.simulate(0, [1], None, sx) for sx in samples_x]

    xs = [s[0] for s in samples_x]
    ys = [s[1] for s in samples_y]

    # Each y must fall in the band of half-width .5 next to cos(x); the tiny
    # slack absorbs floating-point differences between cos implementations.
    for x, y in zip(xs, ys):
        assert abs(y - np.cos(x)) <= .5 + 1e-8

    # Plot the joint query.
    fig, ax = plt.subplots()
    # Scatter the dots.
    ax.scatter(xs, ys, color='blue', alpha=.4)
    ax.set_xlim([-1.5*np.pi, 1.5*np.pi])
    ax.set_ylim([-1.75, 1.75])
    for x in xs:
        ax.vlines(x, -1.75, -1.65, linewidth=.5)
    ax.grid()
    plt.close(fig)

    # Plot the density from y=0 to y=2 for x = 0.
    fig, ax = plt.subplots()
    y_grid = np.linspace(0, 2, 50)
    pdfs = np.exp([
        vs_y.logpdf(0, {1:y}, None, {0:0})
        for y in y_grid
    ])
    ax.plot(y_grid, pdfs)
    plt.close(fig)
108 |
--------------------------------------------------------------------------------