├── .github
│   └── workflows
│       ├── build_docs.yml
│       └── main.yml
├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── contributing.md
│   ├── gen_ref_pages.py
│   ├── gotchas.md
│   ├── index.md
│   ├── install.md
│   └── tutorials
│       ├── conditional_demo.ipynb
│       ├── customizing_example.ipynb
│       ├── dequantization.ipynb
│       ├── ensemble_demo.ipynb
│       ├── gaussian_errors.ipynb
│       ├── index.md
│       ├── intro.ipynb
│       ├── marginalization.ipynb
│       ├── nongaussian_errors.ipynb
│       ├── spherical_flow_example.ipynb
│       └── weighted.ipynb
├── mkdocs-requirements.txt
├── mkdocs.yml
├── poetry.lock
├── pyproject.toml
├── pzflow
│   ├── __init__.py
│   ├── bijectors.py
│   ├── distributions.py
│   ├── example_files
│   │   ├── checkerboard-data.pkl
│   │   ├── city-data.pkl
│   │   ├── example-flow.pzflow.pkl
│   │   ├── galaxy-data.pkl
│   │   └── two-moons-data.pkl
│   ├── examples.py
│   ├── flow.py
│   ├── flowEnsemble.py
│   └── utils.py
├── setup.cfg
└── tests
    ├── test_bijectors.py
    ├── test_distributions.py
    ├── test_ensemble.py
    ├── test_examples.py
    ├── test_flow.py
    ├── test_utils.py
    └── test_version.py
/.github/workflows/build_docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | on:
4 |   push:
5 |     branches: [main]
6 |
7 | jobs:
8 |   deploy:
9 |     runs-on: ubuntu-latest
10 |     steps:
11 |       - uses: actions/checkout@v4
12 |       - uses: actions/setup-python@v5
13 |         with:
14 |           python-version: 3.x
15 |       - run: pip install -r mkdocs-requirements.txt
16 |       - run: mkdocs gh-deploy --force
17 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # inspired by examples here: https://github.com/snok/install-poetry
3 |
4 | name: build
5 |
6 | on:
7 |   push:
8 |     branches: [main, dev]
9 |   pull_request:
10 |     branches: [main, dev]
11 |
12 | jobs:
13 |   build:
14 |     # ----------------------------------------------
15 |     # test linux with several python versions
16 |     # ----------------------------------------------
17 |     strategy:
18 |       fail-fast: true
19 |       matrix:
20 |         os: [ubuntu-latest]
21 |         python-version: ["3.10", "3.11", "3.12"]
22 |     runs-on: ${{ matrix.os }}
23 |
24 |     steps:
25 |       # ----------------------------------------------
26 |       # checkout repository
27 |       # ----------------------------------------------
28 |       - name: Check out repo
29 |         uses: actions/checkout@v4
30 |
31 |       # ----------------------------------------------
32 |       # install and configure poetry
33 |       # ----------------------------------------------
34 |       - name: Install poetry
35 |         run: pipx install poetry
36 |       - name: Set up Python ${{ matrix.python-version }}
37 |         uses: actions/setup-python@v5
38 |         with:
39 |           python-version: ${{ matrix.python-version }}
40 |           cache: "poetry"
41 |
42 |       # ----------------------------------------------
43 |       # install root project, if required
44 |       # ----------------------------------------------
45 |       - name: Install library
46 |         run: poetry install --no-interaction
47 |
48 |       # ----------------------------------------------
49 |       # run linters
50 |       # ----------------------------------------------
51 |       #- name: Run linters
52 |       #  run: |
53 |       #    poetry run flake8 .
54 |       #    poetry run black . --check
55 |       #    poetry run isort .
56 |
57 |       # ----------------------------------------------
58 |       # run tests
59 |       # ----------------------------------------------
60 |       - name: Test with pytest
61 |         run: |
62 |           poetry run pytest --cov=./pzflow --cov-report=xml -n 10
63 |
64 |       # ----------------------------------------------
65 |       # upload coverage to Codecov
66 |       # ----------------------------------------------
67 |       - name: Upload coverage to Codecov
68 |         uses: codecov/codecov-action@v5
69 |         with:
70 |           files: ./coverage.xml
71 |           flags: unittests
72 |           env_vars: OS,PYTHON
73 |           name: codecov-umbrella
74 |           fail_ci_if_error: true
75 |           token: ${{ secrets.CODECOV_TOKEN }}
76 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # vscode settings
132 | .vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 John Franklin Crenshaw
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | [](https://codecov.io/gh/jfcrenshaw/pzflow)
3 | [](https://badge.fury.io/py/pzflow)
4 | [](https://zenodo.org/badge/latestdoi/327498448)
5 | [](https://jfcrenshaw.github.io/pzflow/)
6 |
7 | # PZFlow
8 |
9 | PZFlow is a Python package for probabilistic modeling of tabular data with normalizing flows.
10 |
11 | If your data consists of continuous variables that can be put into a Pandas DataFrame, pzflow can model the joint probability distribution of your data set.
12 |
13 | The `Flow` class makes building and training a normalizing flow simple.
14 | It also allows you to easily sample from the normalizing flow (e.g., for forward modeling or data augmentation), and calculate posteriors over any of your variables.
15 |
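  | As a quick taste, here is a minimal sketch of that workflow (the column names are illustrative; `get_galaxy_data` is the example catalog shipped with PZFlow, and the `seed` keyword for `sample` is an assumption based on the tutorials):
  |
  | ```python
  | import numpy as np
  |
  | from pzflow import Flow
  | from pzflow.examples import get_galaxy_data
  |
  | data = get_galaxy_data()  # pandas DataFrame of redshifts and ugrizy magnitudes
  |
  | # declare the columns to model and build a flow with the default bijector
  | flow = Flow(["redshift"] + list("ugrizy"))
  |
  | losses = flow.train(data, verbose=True)  # fit the flow to the data
  | samples = flow.sample(10_000, seed=0)    # draw synthetic galaxies
  |
  | # posterior p(redshift | photometry) on a grid, for the first ten galaxies
  | grid = np.linspace(0, 3, 100)
  | pdfs = flow.posterior(data[:10], column="redshift", grid=grid)
  | ```
  |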
16 | There are several tutorial notebooks in the [docs](https://jfcrenshaw.github.io/pzflow/tutorials/).
17 |
18 | ## Installation
19 |
20 | See the instructions in the [docs](https://jfcrenshaw.github.io/pzflow/install/).
21 |
22 | ## Citation
23 |
24 | If you use this package in your research, please cite the following two sources:
25 |
26 | 1. The paper
27 | ```bibtex
28 | @ARTICLE{2024AJ....168...80C,
29 |        author = {{Crenshaw}, John Franklin and {Kalmbach}, J. Bryce and {Gagliano}, Alexander and {Yan}, Ziang and {Connolly}, Andrew J. and {Malz}, Alex I. and {Schmidt}, Samuel J. and {The LSST Dark Energy Science Collaboration}},
30 |         title = "{Probabilistic Forward Modeling of Galaxy Catalogs with Normalizing Flows}",
31 |       journal = {\aj},
32 |      keywords = {Neural networks, Galaxy photometry, Surveys, Computational methods, 1933, 611, 1671, 1965, Astrophysics - Instrumentation and Methods for Astrophysics, Astrophysics - Cosmology and Nongalactic Astrophysics},
33 |          year = 2024,
34 |         month = aug,
35 |        volume = {168},
36 |        number = {2},
37 |           eid = {80},
38 |         pages = {80},
39 |           doi = {10.3847/1538-3881/ad54bf},
40 | archivePrefix = {arXiv},
41 |        eprint = {2405.04740},
42 |  primaryClass = {astro-ph.IM},
43 |        adsurl = {https://ui.adsabs.harvard.edu/abs/2024AJ....168...80C},
44 |       adsnote = {Provided by the SAO/NASA Astrophysics Data System}
45 | }
46 | ```
47 |
48 | 2. The [Zenodo deposit](https://zenodo.org/records/10710271) associated with the version number you used.
50 |
51 |
52 | ### Sources
53 |
54 | PZFlow was originally designed for forward modeling of photometric redshifts as a part of the Creation Module of the [DESC](https://lsstdesc.org/) [RAIL](https://github.com/LSSTDESC/RAIL) project.
55 | The idea to use normalizing flows for photometric redshifts originated with [Bryce Kalmbach](https://github.com/jbkalmbach).
56 | The earliest version of the normalizing flow in RAIL was based on a notebook by [Francois Lanusse](https://github.com/eiffl) and included contributions from [Alex Malz](https://github.com/aimalz).
57 |
58 | The functional jax structure of the bijectors was originally based on [`jax-flows`](https://github.com/ChrisWaites/jax-flows) by [Chris Waites](https://github.com/ChrisWaites). The implementation of the Neural Spline Coupling is largely based on the [Tensorflow implementation](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/rational_quadratic_spline.py), with some inspiration from [`nflows`](https://github.com/bayesiains/nflows/).
59 |
60 | Neural Spline Flows are based on the following papers:
61 |
62 | > [NICE: Non-linear Independent Components Estimation](https://arxiv.org/abs/1410.8516)\
63 | > Laurent Dinh, David Krueger, Yoshua Bengio\
64 | > _arXiv:1410.8516_
65 |
66 | > [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)\
67 | > Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio\
68 | > _arXiv:1605.08803_
69 |
70 | > [Neural Spline Flows](https://arxiv.org/abs/1906.04032)\
71 | > Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios\
72 | > _arXiv:1906.04032_
73 |
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing to PZFlow
2 |
3 | If you notice any bugs, have any questions, or would like to request a feature, please [submit an issue](https://github.com/jfcrenshaw/pzflow/issues).
4 |
5 | To work on pzflow, after forking and cloning the repo:
6 |
7 | 1. Create a virtual environment with Python.
8 |    E.g., with conda: `conda create -n pzflow python`
9 | 2. Activate the environment.
10 |    E.g., `conda activate pzflow`
11 | 3. Install pzflow in edit mode with the `dev` flag.
12 |    I.e., in the root directory: `pip install -e ".[dev]"` (the full sequence is collected below)
13 |
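  | Putting it together, a minimal sketch of the full setup (assuming conda; any virtual-environment tool works):
  |
  | ```shell
  | conda create -n pzflow python
  | conda activate pzflow
  | pip install -e ".[dev]"   # run from the root of your pzflow clone
  | ```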
--------------------------------------------------------------------------------
/docs/gen_ref_pages.py:
--------------------------------------------------------------------------------
1 | """Generate the code reference pages."""
2 |
3 | from pathlib import Path
4 |
5 | import mkdocs_gen_files
6 |
7 | # recurse through python files in pzflow
8 | for path in sorted(Path("pzflow").rglob("*.py")):
9 |     # get the module path (with the .py suffix stripped)
10 |     module_path = path.with_suffix("")
11 |     # path where we'll save the markdown file for this module
12 |     doc_path = "API" / path.relative_to("pzflow").with_suffix(".md")
13 |
14 |     # split up the path into its parts
15 |     parts = list(module_path.parts)
16 |
17 |     # we don't want to explicitly list __init__
18 |     if parts[-1] in ["__init__"]:
19 |         # drop __init__ and document the package on an index page instead
20 |         parts = parts[:-1]
21 |         doc_path = doc_path.with_name("index.md")
22 |     # skip main files
23 |     elif parts[-1] == "__main__":
24 |         continue
25 |
26 |     # create the markdown file in the docs directory
27 |     with mkdocs_gen_files.open(doc_path, "w") as fd:
28 |         identifier = ".".join(parts)
29 |         print(doc_path, identifier)  # log which page is being generated
30 |         print("::: " + identifier, file=fd)
31 |
32 |     mkdocs_gen_files.set_edit_path(doc_path, path)
33 |
--------------------------------------------------------------------------------
/docs/gotchas.md:
--------------------------------------------------------------------------------
1 | # Common gotchas
2 |
3 | * It is important to note that there are two different conventions in the literature for the direction of the bijection in normalizing flows. PZFlow defines the bijection as the mapping from the data space to the latent space, and the inverse bijection as the mapping from the latent space to the data space. This distinction can be important when designing more complicated custom bijections.
4 |
5 | * If you get NaNs during training, try decreasing the learning rate (the default is 1e-3):
6 |
7 | ```python
8 | import optax
9 |
10 | opt = optax.adam(learning_rate=...)
11 | flow.train(..., optimizer=opt)
12 | ```
13 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | 
2 | [](https://codecov.io/gh/jfcrenshaw/pzflow)
3 | [](https://badge.fury.io/py/pzflow)
4 | [](https://zenodo.org/badge/latestdoi/327498448)
5 | [](https://jfcrenshaw.github.io/pzflow/)
6 |
7 | # PZFlow
8 |
9 | PZFlow is a Python package for probabilistic modeling of tabular data with normalizing flows.
10 |
11 | If your data consists of continuous variables that can be put into a Pandas DataFrame, pzflow can model the joint probability distribution of your data set.
12 |
13 | The `Flow` class makes building and training a normalizing flow simple.
14 | It also allows you to easily sample from the normalizing flow (e.g., for forward modeling or data augmentation), and calculate posteriors over any of your variables.
15 |
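  | Trained flows can also be saved to disk and restored later; here is a hypothetical sketch (the `save` method and `file` constructor argument are assumptions here, not verified against the current API):
  |
  | ```python
  | from pzflow import Flow
  |
  | flow = Flow(["redshift", "u", "g"])  # illustrative column names
  | # ... train the flow ...
  | flow.save("my-flow.pzflow.pkl")      # write the trained flow to disk
  |
  | # later: restore the trained flow from the saved file
  | flow = Flow(file="my-flow.pzflow.pkl")
  | ```
  |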
16 | There are several tutorial notebooks in the [docs](https://jfcrenshaw.github.io/pzflow/tutorials/).
17 |
18 | ## Installation
19 |
20 | See the instructions in the [docs](https://jfcrenshaw.github.io/pzflow/install/).
21 |
22 | ## Citation
23 |
24 | If you use this package in your research, please cite the PZFlow paper, [Crenshaw et al. 2024, AJ, 168, 80](https://ui.adsabs.harvard.edu/abs/2024AJ....168...80C) (the full BibTeX entry is in the README),
25 | as well as the [Zenodo deposit](https://zenodo.org/records/10710271) associated with the version you used.
27 |
28 | ### Sources
29 |
30 | PZFlow was originally designed for forward modeling of photometric redshifts as a part of the Creation Module of the [DESC](https://lsstdesc.org/) [RAIL](https://github.com/LSSTDESC/RAIL) project.
31 | The idea to use normalizing flows for photometric redshifts originated with [Bryce Kalmbach](https://github.com/jbkalmbach).
32 | The earliest version of the normalizing flow in RAIL was based on a notebook by [Francois Lanusse](https://github.com/eiffl) and included contributions from [Alex Malz](https://github.com/aimalz).
33 |
34 | The functional jax structure of the bijectors was originally based on [`jax-flows`](https://github.com/ChrisWaites/jax-flows) by [Chris Waites](https://github.com/ChrisWaites). The implementation of the Neural Spline Coupling is largely based on the [Tensorflow implementation](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/rational_quadratic_spline.py), with some inspiration from [`nflows`](https://github.com/bayesiains/nflows/).
35 |
36 | Neural Spline Flows are based on the following papers:
37 |
38 | > [NICE: Non-linear Independent Components Estimation](https://arxiv.org/abs/1410.8516)\
39 | > Laurent Dinh, David Krueger, Yoshua Bengio\
40 | > _arXiv:1410.8516_
41 |
42 | > [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)\
43 | > Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio\
44 | > _arXiv:1605.08803_
45 |
46 | > [Neural Spline Flows](https://arxiv.org/abs/1906.04032)\
47 | > Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios\
48 | > _arXiv:1906.04032_
49 |
--------------------------------------------------------------------------------
/docs/install.md:
--------------------------------------------------------------------------------
1 | # Install
2 |
3 | You can install PZFlow from PyPI with pip:
4 |
5 | ```shell
6 | pip install pzflow
7 | ```
8 |
9 | If you want to run PZFlow on a GPU with CUDA, you need to follow the GPU-enabled installation instructions for JAX [here](https://github.com/google/jax).
10 | You may also need to add CUDA to your path.
11 | For example, I needed to add the following to my `.bashrc`:
12 |
13 | ```shell
14 | # cuda setup
15 | export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
16 | export PATH=$PATH:/usr/local/cuda/bin
17 | ```
18 |
19 | If you have the GPU-enabled version of JAX installed but would like to run on a CPU, add the following to the top of your scripts/notebooks:
20 |
21 | ```python
22 | import jax
23 | # Global flag to set a specific platform, must be used at startup.
24 | jax.config.update('jax_platform_name', 'cpu')
25 | ```
26 |
27 | Note that if you run JAX on a GPU in multiple Jupyter notebooks simultaneously, you may get `RuntimeError: cuSolver internal error`. Read more [here](https://github.com/google/jax/issues/4497) and [here](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html).
28 |
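  | One mitigation described in the linked JAX memory-allocation docs is to disable GPU memory preallocation (this environment variable comes from JAX itself, not from PZFlow):
  |
  | ```shell
  | # let each process allocate GPU memory on demand instead of preallocating most of it
  | export XLA_PYTHON_CLIENT_PREALLOCATE=false
  | ```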
--------------------------------------------------------------------------------
/docs/tutorials/index.md:
--------------------------------------------------------------------------------
1 | # Tutorials
2 |
3 | Below are example notebooks demonstrating how to use PZFlow.
4 | Each contains a link to open the notebook on Google Colab, as well as a link to the source code on GitHub.
5 |
6 | ### Basic
7 |
8 | - [Introduction to PZFlow](intro.ipynb) - using the default flow to train, sample, and calculate posteriors
9 | - [Conditional Flows](conditional_demo.ipynb) - building a conditional flow to model conditional distributions
10 | - [Convolving Gaussian Errors](gaussian_errors.ipynb) - convolving Gaussian errors during training and posterior calculation
11 | - [Flow Ensembles](ensemble_demo.ipynb) - using `FlowEnsemble` to create an ensemble of normalizing flows
12 | - [Training Weights](weighted.ipynb) - giving different weights to your training samples
13 |
14 | ### Intermediate
15 |
16 | - [Customizing the flow](customizing_example.ipynb) - Customizing the bijector and latent space
17 | - [Modeling Variables with Periodic Topology](spherical_flow_example.ipynb) - using circular splines to model data with periodic topology, e.g. positions on a sphere
18 |
19 | ### Advanced
20 |
21 | - [Marginalizing Variables](marginalization.ipynb) - marginalizing over missing variables during posterior calculation
22 | - [Convolving Non-Gaussian Errors](nongaussian_errors.ipynb) - convolving non-Gaussian errors during training and posterior calculation
23 |
--------------------------------------------------------------------------------
/docs/tutorials/nongaussian_errors.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "[](https://colab.research.google.com/github/jfcrenshaw/pzflow/blob/main/docs/tutorials/nongaussian_errors.ipynb)\n",
8 | "[](https://github.com/jfcrenshaw/pzflow/blob/main/docs/tutorials/nongaussian_errors.ipynb)\n",
9 | "\n",
10 | "If running in Colab, to switch to GPU, go to the menu and select Runtime -> Change runtime type -> Hardware accelerator -> GPU.\n",
11 | "\n",
12 | "In addition, uncomment and run the following code:"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "# !pip install pzflow matplotlib"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "------------------\n",
29 | "## Convolving non-Gaussian errors\n",
30 | "\n",
31 | "This notebook demonstrates how to train a flow on data that has non-Gaussian errors/uncertainties, as well as convolving those errors in log_prob and posterior calculations.\n",
32 | "We will use the example galaxy data again.\n",
33 | "\n",
34 | "For an example of how to handle Gaussian errors, which is much easier, see the notebook on Gaussian errors."
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 5,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "import numpy as np\n",
44 | "import jax.numpy as jnp\n",
45 | "from jax import random\n",
46 | "import matplotlib.pyplot as plt\n",
47 | "\n",
48 | "from pzflow import Flow\n",
49 | "from pzflow.examples import get_galaxy_data"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 6,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "plt.rcParams[\"figure.facecolor\"] = \"white\""
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "metadata": {},
64 | "source": [
65 | "First let's load the example galaxy data set included with PZFlow, and include photometric errors. For simplicity, we will assume all bands have error = 0.1"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 7,
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "data": {
75 | "text/html": [
76 | "
\n",
77 | "\n",
90 | "
\n",
91 | " \n",
92 | "
\n",
93 | "
\n",
94 | "
redshift
\n",
95 | "
u
\n",
96 | "
g
\n",
97 | "
r
\n",
98 | "
i
\n",
99 | "
z
\n",
100 | "
y
\n",
101 | "
u_err
\n",
102 | "
g_err
\n",
103 | "
r_err
\n",
104 | "
i_err
\n",
105 | "
z_err
\n",
106 | "
y_err
\n",
107 | "
\n",
108 | " \n",
109 | " \n",
110 | "
\n",
111 | "
0
\n",
112 | "
0.287087
\n",
113 | "
26.759261
\n",
114 | "
25.901778
\n",
115 | "
25.187710
\n",
116 | "
24.932318
\n",
117 | "
24.736903
\n",
118 | "
24.671623
\n",
119 | "
0.1
\n",
120 | "
0.1
\n",
121 | "
0.1
\n",
122 | "
0.1
\n",
123 | "
0.1
\n",
124 | "
0.1
\n",
125 | "
\n",
126 | "
\n",
127 | "
1
\n",
128 | "
0.293313
\n",
129 | "
27.428358
\n",
130 | "
26.679299
\n",
131 | "
25.977161
\n",
132 | "
25.700094
\n",
133 | "
25.522763
\n",
134 | "
25.417632
\n",
135 | "
0.1
\n",
136 | "
0.1
\n",
137 | "
0.1
\n",
138 | "
0.1
\n",
139 | "
0.1
\n",
140 | "
0.1
\n",
141 | "
\n",
142 | "
\n",
143 | "
2
\n",
144 | "
1.497276
\n",
145 | "
27.294001
\n",
146 | "
26.068798
\n",
147 | "
25.450055
\n",
148 | "
24.460507
\n",
149 | "
23.887221
\n",
150 | "
23.206112
\n",
151 | "
0.1
\n",
152 | "
0.1
\n",
153 | "
0.1
\n",
154 | "
0.1
\n",
155 | "
0.1
\n",
156 | "
0.1
\n",
157 | "
\n",
158 | "
\n",
159 | "
3
\n",
160 | "
0.283310
\n",
161 | "
28.154075
\n",
162 | "
26.283166
\n",
163 | "
24.599570
\n",
164 | "
23.723491
\n",
165 | "
23.214108
\n",
166 | "
22.860012
\n",
167 | "
0.1
\n",
168 | "
0.1
\n",
169 | "
0.1
\n",
170 | "
0.1
\n",
171 | "
0.1
\n",
172 | "
0.1
\n",
173 | "
\n",
174 | "
\n",
175 | "
4
\n",
176 | "
1.545183
\n",
177 | "
29.276065
\n",
178 | "
27.878301
\n",
179 | "
27.333528
\n",
180 | "
26.543374
\n",
181 | "
26.061941
\n",
182 | "
25.383056
\n",
183 | "
0.1
\n",
184 | "
0.1
\n",
185 | "
0.1
\n",
186 | "
0.1
\n",
187 | "
0.1
\n",
188 | "
0.1
\n",
189 | "
\n",
190 | " \n",
191 | "
\n",
192 | "
"
193 | ],
194 | "text/plain": [
195 | " redshift u g r i z y \\\n",
196 | "0 0.287087 26.759261 25.901778 25.187710 24.932318 24.736903 24.671623 \n",
197 | "1 0.293313 27.428358 26.679299 25.977161 25.700094 25.522763 25.417632 \n",
198 | "2 1.497276 27.294001 26.068798 25.450055 24.460507 23.887221 23.206112 \n",
199 | "3 0.283310 28.154075 26.283166 24.599570 23.723491 23.214108 22.860012 \n",
200 | "4 1.545183 29.276065 27.878301 27.333528 26.543374 26.061941 25.383056 \n",
201 | "\n",
202 | " u_err g_err r_err i_err z_err y_err \n",
203 | "0 0.1 0.1 0.1 0.1 0.1 0.1 \n",
204 | "1 0.1 0.1 0.1 0.1 0.1 0.1 \n",
205 | "2 0.1 0.1 0.1 0.1 0.1 0.1 \n",
206 | "3 0.1 0.1 0.1 0.1 0.1 0.1 \n",
207 | "4 0.1 0.1 0.1 0.1 0.1 0.1 "
208 | ]
209 | },
210 | "execution_count": 7,
211 | "metadata": {},
212 | "output_type": "execute_result"
213 | }
214 | ],
215 | "source": [
216 | "data = get_galaxy_data()\n",
217 | "for col in data.columns[1:]:\n",
218 | " data[f\"{col}_err\"] = 0.1 * np.ones(data.shape[0])\n",
219 | "data.head()"
220 | ]
221 | },
222 | {
223 | "cell_type": "markdown",
224 | "metadata": {},
225 | "source": [
226 | "Now, we need to build the machinery to handle non-Gaussian errors. PZFlow convolves errors by sampling from an error model, which by default is Gaussian. If we want to convolve non-Gaussian errors, we need to pass the `Flow` constructor an error model that tells it how to sample from our non-Gaussian distribution.\n",
227 | "\n",
228 | "The error model must be a callable that takes the following arguments:\n",
229 | "- key is a jax rng key, e.g. jax.random.PRNGKey(0)\n",
230 | "- X is a 2 dimensional array of data variables, where the order of variables matches the order of the columns in data_columns\n",
231 | "- Xerr is the corresponding 2 dimensional array of errors\n",
232 | "- nsamples is the number of samples to draw from the error distribution\n",
233 | "\n",
234 | "and it must return an array of samples with the shape `(X.shape[0], nsamples, X.shape[1])`\n",
235 | "\n",
236 | "Below we build a photometric error model, which takes the exponential of the magnitudes to convert to flux values, adds Gaussian flux errors, then takes the log to convert back to magnitudes. "
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 9,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "def photometric_error_model(\n",
246 | " key,\n",
247 | " X: np.ndarray,\n",
248 | " Xerr: np.ndarray,\n",
249 | " nsamples: int,\n",
250 | ") -> np.ndarray:\n",
251 | " \n",
252 | " # calculate fluxes\n",
253 | " F = 10 ** (X / -2.5)\n",
254 | " # calculate flux errors\n",
255 | " dF = np.log(10) / 2.5 * F * Xerr\n",
256 | " \n",
257 | " # add Gaussian errors\n",
258 | " eps = random.normal(key, shape=(F.shape[0], nsamples, F.shape[1]))\n",
259 | " F = F[:, None, :] + eps * dF[:, None, :]\n",
260 | " \n",
261 | " # add a flux floor to avoid infinite magnitudes\n",
262 | " # this flux corresponds to a max magnitude of 30\n",
263 | " F = np.clip(F, 1e-12, None)\n",
264 | " \n",
265 | " # calculate magnitudes\n",
266 | " M = -2.5 * np.log10(F)\n",
267 | " \n",
268 | " return M"
269 | ]
270 | },
271 | {
272 | "cell_type": "markdown",
273 | "metadata": {},
274 | "source": [
275 | "Now we can construct the Flow, this time passing the error model"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 10,
281 | "metadata": {},
282 | "outputs": [],
283 | "source": [
284 | "flow = Flow(\n",
285 | " [\"redshift\"] + list(\"ugrizy\"), \n",
286 | " data_error_model=photometric_error_model,\n",
287 | ")"
288 | ]
289 | },
290 | {
291 | "cell_type": "markdown",
292 | "metadata": {},
293 | "source": [
294 | "Now that we have set up the Flow with the new error model, we can train the flow, calculate posteriors, etc. just like we did in the Gaussian error example.\n",
295 | "\n",
296 | "For example, to train with error convolution:"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 11,
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "name": "stderr",
306 | "output_type": "stream",
307 | "text": [
308 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
309 | ]
310 | },
311 | {
312 | "name": "stdout",
313 | "output_type": "stream",
314 | "text": [
315 | "Training 200 epochs \n",
316 | "Loss:\n",
317 | "(0) 20.3297\n",
318 | "(1) 2.9366\n",
319 | "(11) 0.1215\n",
320 | "(21) -0.1496\n",
321 | "(31) -0.0601\n",
322 | "(41) -0.1246\n",
323 | "(51) -0.2496\n",
324 | "(61) -0.0759\n",
325 | "(71) 0.1112\n",
326 | "(81) -0.2375\n",
327 | "(91) -0.2352\n",
328 | "(101) -0.2811\n",
329 | "(111) -0.2134\n",
330 | "(121) -0.2033\n",
331 | "(131) -0.2952\n",
332 | "(141) -0.2706\n",
333 | "(151) -0.2759\n",
334 | "(161) -0.2222\n",
335 | "(171) -0.1448\n",
336 | "(181) -0.2966\n",
337 | "(191) -0.1698\n",
338 | "(200) -0.2314\n",
339 | "CPU times: user 1h 8min 26s, sys: 45min 42s, total: 1h 54min 8s\n",
340 | "Wall time: 22min 21s\n"
341 | ]
342 | }
343 | ],
344 | "source": [
345 | "%%time\n",
346 | "losses = flow.train(data, epochs=200, convolve_errs=True, verbose=True)"
347 | ]
348 | },
349 | {
350 | "cell_type": "markdown",
351 | "metadata": {},
352 | "source": [
353 | "And to calculate posteriors with error convolution:"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 12,
359 | "metadata": {},
360 | "outputs": [],
361 | "source": [
362 | "grid = np.linspace(0, 3, 100)\n",
363 | "pdfs = flow.posterior(data[:10], column=\"redshift\", grid=grid, err_samples=int(1e3))"
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": 13,
369 | "metadata": {},
370 | "outputs": [
371 | {
372 | "data": {
373 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAHtCAYAAAA0tCb7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAABJ0AAASdAHeZh94AABglklEQVR4nO3deXxTZb4/8M9J0qTpmpa2lKUt0IJIUXDDBQVUxFEHRERc74ArOuoMDnoVZxTwil6vo44/HRfUERe84wqjeNVRZ9RRkcFdkJ0WCnShS7pnP78/knNyQkvXJCfP6ef9evVlm6bJQ+zTT77PdiRZlmUQERFRwjPp3QAiIiLqGYY2ERGRIBjaREREgmBoExERCYKhTUREJAiGNhERkSAY2kRERIJgaBMREQmCoU1ERCQISzQexOl04tNPP0VBQQFsNls0HpKIesjtdqOiogJTp06Fw+Ho9c+z/xLpp9f9V46CtWvXygD4wQ9+6Pixdu1a9l9+8EPQj57236hU2gUFBQCAtWvXoqSkJBoP2SMH7liC9k2bYB8/HkP/+/64PS9RItm5cydmz56t9sPeYv8l0k9v+29UQlsZUispKUFpaWk0HrJH0hwOtNlsSHE4UBTH5yVKRH0d2mb/JdJfT/svF6IREREJgqFNREQkCIY2ERGRIAwV2rIso6y2FS6vX++mEFEU7KlrRWObV+9mECWMqCxESxQfbK7C9S9/i+LcVLz168nItCf1+bF8Ph8aGhrQ0tICWZaj2Eqi7kmSBJvNhoyMDKSmpkKSJL2bFHff7KnHhU+uxzCHHR8vnorkJLPeTSLSnaEq7fc3VQEAdh1sxX++8UOfw1aWZezbtw+1tbXwevkun+LP7/ejsbERFRUVqKmpGZBvHP9d1gAA2O9sx7/L6nVuDVFiMFSl3eYJD4t/sLkaf/miHFefOrLXj9Pc3Iz29nZkZmZiyJAhA7LKIf15PB5UVlaivr4eqampSEtL07tJcVXT7FI//3xnLaaMydWxNUSJwVCVdl2rJ+LrP36wrU/z201NTQCAvLw8Bjbpxmq1YsiQIQDCv5MDSU2zW/38XztqdWwJUeIwVGhr35kDQLvXj21Vzb1+HK/XC4vFAovFUAMRJCCr1YqkpCS43e7u72wwBzWhvaWyKeJrooHKMKEty7Laqc8Ym6fe/nNl7ysUWZZhMhnmpSHBSZI0IOe0Dw3pL3ex2iYyTDK1uH1weQMAgONHZCHFGlxpuvlAY58ej8PilCgG6u9iTVPkyBmHyIkMFNrad+X5GckYm58OANh8YODNBRKJrtXtQ6sncj3K5ztqB+SIA5GWYUJbu2glLz0ZpUMzAQBbK5vhD7CjE4lE+ya8dGgGAKCqyYXvK5w6tYgoMRgmtLWdPDfdpnb0dq8fZbWtejWLYmzVqlWQJAnl5eV6N4WiSPsmfP4pI2AxBacI/vJFuU4tIkoMhg3tcaHQBvo+r21UStApH8nJyRgzZgxuuukmVFdXR/W52trasGzZMnzyySdRfVwyNm1/PmpYJs47Orj17f9+qsQBZ7tezSLSnXFCuyXYyS0mCQ57EsYMToc59O68LyvIB4J77rkHL730Eh5//HGccsopePLJJ3HyySejra0tas/R1taG5cuXxyy0/+M//gPt7e0oKiqKyeOTPrTbN3PTbeohSf6AjBfWl+vUKiL9GSa0a5qCoZ2bboPJJCE5yYzRecETpH7mYrROnXPOObjiiitwzTXXYNWqVVi0aBHKysrwt7/9Te+mdau1NTjlYTabkZycHLUV1tF8w0J9pwyPW0wSslOsOHq4AyeMyAIA/O+GvWjz+PRsHpFuDBPaSqWdm25Tbxs3JDhEvvlAE1ed9sAZZ5wBACgrK4PP58N//dd/obi4GDabDSNGjMCdd97Z4ZCPr7/+GmeffTZycnJgt9sxcuRIXHXVVQCA8vJy5OYGj55cvny5Ohy/bNky9ee3bt2KuXPnIjs7G8nJyTj++OPx9ttvRzyHMpz/6aef4te//jXy8vIwfPjwiO8dOqf9xBNPoLS0FDabDUOHDsWNN94Ip9MZcZ9p06Zh/Pjx+OabbzBlyhSkpKTgzjvv7O/LSFGgDI/npAXfhAPAVZOD1XaTy4ePt9To1jYiPRnmyC+lk+emaUJ7aAbe+m4/6ls9qGv1IEfzPepo165dAIBBgwbhmmuuwQsvvIC5c+di8eLF2LBhA+6//35s2bIFa9asAQDU1NRgxowZyM3NxR133AGHw4Hy8nK89dZbAIDc3Fw8+eSTuOGGG3DBBRdgzpw5AICjjz4aALB582ZMnjwZw4YNwx133IHU1FS89tprmD17Nt58801ccMEFEe379a9/jdzcXNx9991qpd2ZZcuWYfny5Zg+fTpuuOEGbNu2DU8++SQ2btyIL774AklJ4au/1dXV4ZxzzsEll1yCK664AoMHD47eC0p9plTa2jfhZxyZh/RkC5pdPry3qRIzJwzVq3lEujFeaGs6+ZBMe8T3+xvay9/ZnFBD7eOGZmDpzNI+/3xjYyNqa2vhcrnwxRdf4J577oHdbsfYsWNx/fXX45prrsEzzzwDAGqF+8c//hH//Oc/cfrpp+PLL79EQ0MD/v73v+P4449XH/fee+8FAKSmpmLu3Lm44YYbcPTRR+OKK66IeP7f/va3KCwsxMaNG2Gz2dTnOfXUU3H77bd3CO3s7Gx8/PHHMJsPf4nGgwcP4v7778eMGTPw3nvvqSfbjR07FjfddBNefvllXHnller9q6qq8NRTT2HhwoV9fh0p+pSDVfI0/dlmMeOsIwfjre/2459bD6LN40OK1TB/woh6xBC/8TKAutZgaGs7eU6aVf28rsVz6I/12s8HmrDBQJcInD59esTXRUVFWL16Nb788ksAwO9+97uI7y9evBh//OMf8e677+L000+Hw+EAAKxbtw4TJkyIqGC7U19fj3/84x+455570NzcjObm8BnxZ599NpYuXYr9+/dj2LBh6u3XXnttl4ENAB999BE8Hg8WLVoUcRTttddeizvvvBPvvvtuRGjbbLaIrykx1Iamu/IyIt9on3PUELz13X60e/34dNtBnHPUED2aR6QbQ4S21x+AMmWtrbQHaSpr5Y9Af2i3kSWC/rbnz3/+M8aMGQOLxYLBgwfjiCOOgMlkwpo1a2AymVBSUhJx//z8fDgcDuzZswcAMHXqVFx44YVYvnw5HnnkEUybNg2zZ8/GZZddplbOh7Nz507Isoy77roLd911V6f3qampiQjtkSO7v8yq0rYjjjgi4nar1YpRo0ap31cMGzYMVqsVlDh8/oB6xb7cQ0bHThudg1SrGa0eP979qZKhTQOOYUJbkXuYSjsaod2foehENGnSpIhh7UN1tyJbkiS88cYb+Oqrr/DOO+/ggw8+wFVXXYWHHnoIX331VZfXfw4Egv/Pbr31Vpx99tmd3ufQNw12u73T+/VHLB6T+qe2xRN+E56RHPG95CQzzjxyMN7+4QD+sbUGLq8fyUldj74QGYkhVo97/eGV4drQzrQnqSc
p1UZheHygKCoqQiAQwI4dOyJur66uhtPp7LAn+qSTTsKKFSvw9ddfY/Xq1di8eTP++te/Ajh88I8aNQoAkJSUhOnTp3f6kZ6e3qe2A8C2bdsibvd4PCgrK+N+bgEcjDiSuOOIzblH5QMA2jx+rN9dF7d2ESUCQ4S2xxeutPPSw+/MJUnCoFC1XReFSnugOPfccwEAf/rTnyJuf/jhhwEA5513HgCgoaGhw1a6iRMnAoC6NSwlJQUAOmy3ysvLw7Rp0/D000+jsrKyQxsOHjzYp7ZPnz4dVqsV/+///b+Itj333HNobGxU206J69CDVQ518qgc9fPv9jTEpU1EicJww+OHrhDPSbOhusmtzpFR9yZMmID58+dj5cqVcDqdmDp1Kv7973/jhRdewOzZs3H66acDAF544QU88cQTuOCCC1BcXIzm5mY888wzyMjIUIPfbrdj3LhxePXVVzFmzBhkZ2dj/PjxGD9+PP785z/j1FNPxVFHHYVrr70Wo0aNQnV1NdavX499+/bhhx9+6HXbc3NzsWTJEixfvhy/+MUvMGvWLGzbtg1PPPEETjjhhA4r2Cnx1HRTaWemJKEkLw07a1rw7V5nHFtGpD9DhXaazQK7NXJ+S1mMFo057YHk2WefxahRo7Bq1SqsWbMG+fn5WLJkCZYuXareRwnzv/71r6iurkZmZiYmTZqE1atXRywae/bZZ3HzzTfjlltugcfjwdKlSzF+/HiMGzcOX3/9NZYvX45Vq1ahrq4OeXl5OOaYY3D33Xf3ue3Lli1Dbm4uHn/8cdxyyy3Izs7Gddddh/vuu69XK9xJH1WNXVfaAHBsoQM7a1rwfYUT/oCsHllMZHQGCe3gMGhn78pz1OFxVtqKBQsWYMGCBV3ex2Kx4O677+4yPI855hi88sor3T7fySefjK+//rrT740aNQovvPBCn9t7uO/deOONuPHGG7t8XF7EJDFV1AePks3PSIbN0vkis2MKs/Da1/vQ4vZhZ00Ljsjv/foHIhEZY047VGl3dniKctvBFjePMiUSwJ5QaBdmpxz2PscWZqmff7uX89o0cBgitJXh8Zz0jvttlUrb4wugxc2LDBAlur1KaA86fGiPzktDui04UPgtF6PRAGKM0PYdvtIelKo9YIVD5ESJrM3jU7d8FXVRaZtMEiYWOgAA31U449AyosRgiND2B4LD3oeengRA3fIFcNsXUaJTqmyg60obCM5rA8DOmhY0tnlj2i6iRGGI0FbkdLoQLbpHmRJR7Oyt04R2F5U2ABwTqrQB4Pt9zhi1iCixGCu0u1iIBnB4nCjRaSvtokGpXd63VHP2/q6alpi1iSiRGCq0O9vTmZ0a3St9EVHsKKGdZrMgK6XrPfW5aTakhRajldUe/vrqREZiqNDWXiBEYbWYkGkPdn4OjxMltj114e1ePblgzcicYDXO0KaBwmCh3fnpSeoBK60MbaJEVtGDPdpaIxjaNMAYJrTTbZbDXqJPPcq0mcPjRInKH5BR0RAM7aJuVo4rlEr7QGM7XF5/zNpGlCgME9qHO6MYCG8Fq2WlTZSwKhvb1SOJu9vupRgVCm1ZDg+tExmZYUL7cEPjQHivdm0zQ3sgKy8vhyRJWLVqVa9+bsSIEfjlL3/Z7f0++eQTSJLU4Uzzl156CWPHjkVSUhIcDkevnnsgidij3cPhcaXSBoCyWq4gJ+MzTmh3coSpQjkVrcnli7j29kAjSVKPPnghjejZunUrFixYgOLiYjzzzDNYuXIl2trasGzZMr7Oh9Du0S7K7nq7l2JERGiz0ibjM8RVvoDOT0NTaAO9rtWNIZn2eDQp4bz00ksRX7/44ov48MMPO9x+5JFHxrNZhjFlyhS0t7fDag3/vn3yyScIBAJ49NFHUVJSAgCora3F8uXLAQDTpk3To6kJSZnPNpskDHUk9+hnMu1JGJRqRV2rh5U2DQhCh3ZAc9WurobH8zPCfwAOOF0DNrSvuOKKiK+/+uorfPjhhx1uP1RbWxtSUno2XBkLsizD5XLBbk/s/28mkwnJyZFhU1NTAwAcFu+B6qbg9FVeug0Wc88HAUfmpIZCmyvIyfiEHh5XFq0AnR9hqtDOj1XUcwitK9OmTcP48ePxzTffYMqUKUhJScGdd94JIDi8vmzZsg4/M2LEiA7XtHY6nVi0aBEKCgpgs9lQUlKCBx54AIFA99MTyhzyBx98gOOPPx52ux1PP/10rx7X6XRiwYIFyMzMhMPhwPz58+F0Ojs8V1VVFa688koMHz4cNpsNQ4YMwfnnn4/y8vIO9/38888xadIkJCcnY9SoUXjxxRcjvn/onPaIESOwdOlSAEBubi4kScKCBQuQm5sLAFi+fLk6JdHZ6zrQVDe5AARDuzfC277Yt8n4hK60vf6A+q6jq+Hx4Vnh0OYK0+7V1dXhnHPOwSWXXIIrrrgCgwcP7tXPt7W1YerUqdi/fz8WLlyIwsJCfPnll1iyZAkqKyvxpz/9qdvH2LZtGy699FIsXLgQ1157LY444ogeP64syzj//PPx+eef4/rrr8eRRx6JNWvWYP78+R2e58ILL8TmzZtx8803Y8SIEaipqcGHH36IvXv3YsSIEer9du7ciblz5+Lqq6/G/Pnz8Ze//AULFizAcccdh9LS0k7/DX/605/w4osvYs2aNXjyySeRlpaGo446CieddBJuuOEGXHDBBZgzZw4A4Oijj+7Va2xENUqlndGzoXGFshittsWNJpcXGcldn6RGJDLhQ1uJ6q4qbbvVjLx0G2qa3RErVKlzVVVVeOqpp7Bw4cI+/fzDDz+MXbt24bvvvsPo0aMBAAsXLsTQoUPx4IMPYvHixSgoKOjyMXbu3In3338fZ599tnrbvffe26PHffvtt/HZZ5/hf/7nf3DbbbcBAG644QacfvrpEc/hdDrx5Zdf4sEHH8Stt96q3r5kyZIO7dm2bRs+++wznHbaaQCAefPmoaCgAM8//zz++Mc/dvpvmD17Nr7//nusWbMGc+fORU5ODgBg2LBhuOGGG3D00Ud3OzUxkNQ0963SHqVZjFZe24qjhzui2SyihCJ4aMvh0O7kCFOtwuwU1DS7+zU8XnXffXBv2drnn48225FjkR8auo7q49psuPLKK/v886+//jpOO+00ZGVloba2Vr19+vTp+O///m989tlnuPzyy7t8jJEjR0YEdm8e9//+7/9gsVhwww03qPcxm824+eab8a9//Uu9zW63w2q14pNPPsHVV1+NrKysw7Zn3LhxamADweHuI444Art37+7+BaFuuX1+NIQurzm4t5V2rnYFOUObjE3w0A7PY3a1EA0IhvbXexr6VWm7t2xF28aNff55UQwbNixiBXRv7dixAz/++KM6d3soZXFWV0aOHNnnx92zZw+GDBmCtLS0iO8fccQREV/bbDY88MADWLx4MQYPHoyTTjoJv/zlL/GrX/0K+fn5EfctLCzs8HxZWVloaGjo9t9C3VOGxgFgcEbvKu0hGeEFigd5FgMZnCFC22ySDnuEqaIgtBitqskFl9ff7f07YztybO8bGUOxak
9vV2n7/ZHHRwYCAZx11ln4z//8z07vP2bMmD61IRqPe6hFixZh5syZWLt2LT744APcdddduP/++/GPf/wDxxxzjHo/s7nz3xdZs4OB+q5GE7Z56b2rtNOTLTCbJPgDMhraeFQxGZvgoR38g5nUg+0h2hXk+xraUZKX1sW9OxeLoWiRZGVldViB7fF4UFlZGXFbcXExWlpaMH369Kg+f08ft6ioCB9//DFaWloiqu1t27Yd9nEXL16MxYsXY8eOHZg4cSIeeughvPzyy1Ftv6K7q1cNRDWhleMAkNfLSttkkpCVkoTaFg/qW73RbhpRQhF8y1ew0u5RaA/itq/+Ki4uxmeffRZx28qVKztU2vPmzcP69evxwQcfdHgMp9MJn8/Xp+fv6eOee+658Pl8ePLJJ9Xv+/1+PPbYYxE/09bWBpfLFXFbcXEx0tPT4XbHbphV2fPe2Ra0gapaE9q9ndMGgKyU4HROQysrbTI2wSvtYGhbzN1XLkXZ2m1fPIShL6655hpcf/31uPDCC3HWWWfhhx9+wAcffKCuilbcdtttePvtt/HLX/5S3RbV2tqKn376CW+88QbKy8s7/ExP9PRxZ86cicmTJ+OOO+5AeXk5xo0bh7feeguNjY0Rj7d9+3aceeaZmDdvHsaNGweLxYI1a9aguroal1xySb9eq67Y7XaMGzcOr776KsaMGYPs7GyMHz8e48ePj9lzJjpleNxikpCd0vv1FFmpwZ+p5/A4GZzQoe3rxfB4broNNosJbl8Ae+vbY900Q7r22mtRVlaG5557Du+//z5OO+00fPjhhzjzzDMj7peSkoJPP/0U9913H15//XW8+OKLyMjIwJgxY7B8+XJkZmb26fl7+rgmkwlvv/02Fi1ahJdffhmSJGHWrFl46KGHIuapCwoKcOmll+Ljjz/GSy+9BIvFgrFjx+K1117DhRde2PcXqgeeffZZ3Hzzzbjlllvg8XiwdOnSAR3aymlouek2mEy9nz4YlMpKmwYGSY7CSprNmzdj/Pjx2LRp02EPmoi2QEDGG1Nm4qjaXagrKcWp697o9mfOevhT7KhpwfQjB+PZ+ccf9n7KNp5Ro0ZFrb1EfdXd72N/+58e/RcA9vzHr9C2cSNSTjgBf5h2I/61oxYTChz4242Te/1Yd675Ca9s2ItBqVZ8c9dZMWgtUWz0tv8JO6fd2O5VV+5aTD37ZyiL0TinTZRYajTnjveFMqTe0OZBIMAV/WRcwoZ2nWYYLKkHc9pAeNvX3vo2btUhSiDVodPQertHW6HMaQdkoMnFFeRkXOKGdkt4dW9P5rSBcKXd7vWjtoVzX0SJICDLcIZOQ+vtHm1Fdmr4vPF6zmuTgQkb2tqO2ZPV4wBQpNn2xcv4ESUGj+ZqfX2utDUrznnAChmZsKEdOTzes3/GmMHp6ufbqpqi3iYi6j3tccS9vcKXYlBqOOzrOIpGBiZsaEdU2j3cIjLMYUeqNXgc5daq5i7vyzlvShRG/130+DSh3ceFaFma4XFW2mRkwoa2dk7b1MNjIU0mCUfkB6vtrkLbZDLB7/cb/o8lJT5ZluH3+w199Km20u7LaWgAkJ0aHh7nUaZkZOKGdh8Xm4wdkgEA2FbVfNhQttls8Pv9qKmpYXCTbnw+HyorK+H3+ztcscxIlEq7r6ehAYA9yQybJfjnjJU2GZmwJ6L1dYXo2FCl3eL2YV9Du7oNTGvw4MFwu92or69HY2MjzGazoSsdSiyyLCMQCKhnqaekpHR5rW/RKZV2TlrfTkMDghdhyU61orLRxdXjZGjCVtp9D+0M9fPDDZGbTCYUFhbC4XDAarUysCmuJEmCxWJBeno6hg0bhsLCQlgswr6/7pZytb5BaX2/hjvAi4bQwCDsX4K+Do8fccgK8rPGDe70fiaTCUOGDOnTcxBRz/lClfagtL4tQlMood/Xvw1EIhCy0g4E5D5X2pkpSRiaGVzssqWbFeREFHtqpZ0apUqbc9pkYEKGdpPLC38/zhfWLkYjIn15A6FKu5+hrawg55w2GZmQod3f4S9l29fugy1wef3RaBIR9ZFygY/+Do8rlXazyxexjYzISIQM7f6+k1ZWkAdkYEd1SzSaRET91N+FaNk8YIUGACFDu7/HFI4bEl5BvvlAY3+bQ0RR0O857YgDVhjaZExihnaru/s7dWFUbhpSQseZ/rSfoU2UCPo7PJ7N0KYBQMjQru9npW02SSgdGqy2GdpEiSFaC9EAoIFHmZJBCRnaykI0cx9PTwKAo4Y5AABbK5sjLlhARPro95x2irbS7t9oHFGiEjK0laEvSw8vydmZo4YHK22PP4Dt1dz6RaQne5IZKdb+nfXkSOFFQ8j4hAxtZWVoUhQqbYBD5ER662+VDQBWiwnpycHgZ6VNRiVkaDe5ghdSMJv7HtqjclLVa2v/uI+hTaSn/i5CUx8nlUeZkrEJGdrNruDQl8XU9+abTBJKh2UCADax0ibSVX8XoSmUbV/cp01GJWRoN7WHKu1+DI8DwFGh0N5a1QS3jyejEeklWqGtVtr93GFClKiEDG2l0u5vaB89PBjaXr+M7VU8GY0onrRXD4jW8DjPHyejEy603T4/3KEtWpZ+hrayVxsAtnEFOVFcaS/6E61KOzs1GP4NbR7Ict8vKkSUqIQL7ebQIjSg/5X28KwU9fOK+rZ+PRYR9Y72oh7RWD0OhM8f9/plNLt93dybSDzChXZTe3j/ZX8r7eQkM/LSg+/MKxoY2kTx5PNrKu2oDY+HH6e/JycSJSLxQjuKlTYAFGQHq+199e39fiwi6rmISjvKC9EAbvsiYxIutJVFaABg7seWL0VBlh0AK22iePMGYjE8zouGkLEJF9rKdi+g/8PjQLjSrmpycdsXURx5NcPj2VFbiKa9aAhDm4xHuNCOrLSjENqhxWiyDBxwuvr9eETUM77Q8LjZJMFmMUflMbM5PE4GJ1xoN0U5tIdn29XPuYKcKH6USjupHxf+OVSK1QybJfh4PH+cjEi40Fa2fElSdEK7MFuz7Yvz2kRxo0xHWS3R+zMkSRLPHydDEy60lS1faTYL+h/ZwJBMuzo3XsEV5ERxoxySZItiaAOa88cZ2mRAwoW2UmlnJCdF5fHMJglDHVxBThRPHl9A3fJljdJ8toJHmZKRCRfaypy2ct3caCgIzWvv45w2UVxUN7nUw8ejXWlzeJyMTMDQDlXa9uhU2kB4BXlFA4fHieJhvzPc16Id2sqpaKy0yYjEC+3QnHZGVCvtYGjXt3rQyvOKiWLugCa0o7kQDQgf1NLm8cPl5dkLZCzChXa057QBYHiWZtsX57WJYi6WoZ2VwlPRyLiEC+3YzGlrr/bFIXKiWNuvOcjILEVjH0gYjzIlIxMqtAMBGS3u6M9pD3eEK+3KRoY2UaxpK+1o055jzsVoZDRChXaLxwfluvbRHB4flGaDck7LwWaeokQUa7EM7chKm/2ZjEWo0G7WXJYzmsPjZpOkrjhlaBPFlizLsQ3tiDltbxf3JBKPUKGtrBwHojs8DgB56cHQrmFoE8VUU7sPrZ7YrerOtCepR
xzXtbA/k7EIFdqxqrQBIDedlTZRPByI8boRk0lSV5BzIRoZjVChHVFpR3FOG2BoE8VLLIfGFTwVjYxKrNDWXJYz2pW2Mjxe2+JGICBH9bGJKCwuoR1aQc7hcTIaoUJbOzwe7TltpdL2BWQ0tPHdOVGsKHu0pSjvz9YalBbsz6y0yWiECm3t8His5rQB4CDfnRPFjFJpR/skNC1leLy+haFNxiJUaDeHDlaxWUywRflyfnnpyernNU0MbaJYqWkOVtpWc+xDu9ntg9vH88fJOIQKbaXSTo/yIjTgkEqbi9GIYqY2VP0mmWM3PJ6dxqNMyZiECm31YiH26A6NAxweJ4oXZXFYUkwr7XB/ruMQORmIUKEdvlhI9CvtNJsFKdbgkDuHx4liw+cPoKEt2I8tsQxtTaVdyzfhZCBChbayqjszyivHFepebXZyopio1+zMiOXw+CBe6YsMSqjQrm0Odr48zVB2NOWGtonUNLm6uScR9YV2qDqmw+NpHB4nYxImtAMBWR3myo1RaOdlsNImiqXI0I5dpZ2RbFEfn3u1yUiECW1nuxe+0EllOWmxrbS5epwoNuo0l8qMZaUtSZJ6iU6eikZGIkxoa4M0VpW28rjNLh9cXu7tJIq22jgNjwNQL7fLSpuMRMzQjlGlrT1ghdU2UfQpVa/FJKmXz4yVnDReNISMR5zQbgkvDot1pQ3wutpEsaCsS8lOtSK2ka250heHx8lAhAltZeU4EJ/QPtjMFeRE0aYsRBsUo9EyLWV4nFu+yEiECW1lRbfVYkJGlC8WoshjpU0UU7WhAM3RHH4SK8oBK20eP9o8vm7uTSQGcUI7FKK5abaYXdIvO9UKZZqNp6IRRZ8yVK09/CRWtM/BvdpkFMKFdk6MhsaB4LGKyhB5FQ9YIYq6eA6Pa5+DQ+RkFMKFdqxWjisGZwRXkFcztImiqs3jQ3toK+WgOAyPZ2sr7VaOnJExiBPaMT4NTcHQJooN7RB1TmrsK23tvDmHx8kohAhtrz+gXiwk1qGdHwrtqkaGNlE0aa+2lZMej4VomvPHOTxOBiFEaNe3eiAHTzCNfWhnBkO7yeVDu4enohFFi7baHRSHSjvVaobVEvwTV8vdIGQQQoR2PE5DU2i3fXGInCh6tJV2POa0JUlS+zO3cJJRiBfacaq0Aa4gJ4om7RB1PCptgGtUyHiEC+1YXUtbocxpA+zoRNGkVNqpVjPsVnNcnlPpz6y0ySjECG3tApZYb/nSVtpcjEYUNfHco63Iywidu9DogqwsjCESmBihHXqXnGazxPwderrNAntS8DmqeSoaUdQoe6XjMZ+tUCrtdq8fzW4eZUriEyO047RHGwguXlHmtTk8ThQ9lc5gfxqsuQRurA3WTndx5IwMQIzQjtNpaIrBGTzKlCiaAgEZ+5ztAIDhWfa4PW9EaHPkjAxAiNCubAx29tyM+IR2PlecEkVVbasbHl8AQLxDO/w3g2/CyQgSPrS9/gAOhIbVirJT4vKcyrvzmiY3F68QRcG+hnb182FZ8enHwKGVNkObxJfwob2/oR3+QDA4iwbFN7Q9/gCvDkQUBfu1oe2IX6WdarMg3WYBANQwtMkAEj6099S3qZ8XZqfG5Tm1B6xwHoyo//Y7tZV2/EIbCG/j5PA4GUHih3Zdq/r5iJz4VtoAh9SIomFfQ/DNd3qyBZn2pLg+tzKvzTfgZAQChHaws1stprhtFeHiFaLoUobHh8dxPlvBo0zJSIQJ7cLsFJhMUlyeMy89GVLoqQ5ohvWIqG+U4fF4zmcrBmuOMg0EuLCUxJbwob23Pjg8Hq+V40Cwqlf+uOyube3m3kTUFVmW1dXj8dzupVC2cPoDMmpbOUROYkvo0A4EZOwNLUQrjNPKcUVxbhoAYPdBhjZRfzjbvGgLXZtej9DWTnfVcF6bBJfQoV3T7IbLGzyQYcSg+KwcV4RDu4VDakT9sE+n7V4KLiwlI0no0NauHI97pZ0XfJPg9gUitqsQUe/sd4a3beq5EA3gwlISX2KHtmaPdjzntIFwpQ0Auw62xPW5iYwk8jS0+Ffauek2dWEpt32R6BI6tPeGVo6bpPi/Q48Mbc5rE/WVEtopVjOyUuK7RxsAkswm5IQuNlTJUTMSXEKHdnloeHxIph1WS3ybmpNmRUZy8PhDVtpEfafd7iVJ8dm2eaiRoTUx7MskuoQObWXleLxOQtOSJAnFecFqe1cNOzpRXylrU/QYGleUDA725R3VLbwIEAktYUPbH5CxMxSWI3Piu3JcMSonFNocHifqk2aXFztC/XjckAzd2jEm9Aa82e3jvDYJLWFDe3t1s7q38+jhDl3aoKwgr21xo7HNq0sbiET23V4nlML2uKIs3doxenC6+vmOmmbd2kHUXwkb2j9UONXPJxY4dGlDxGK0Wg6RE/XWN3sa1M+PLdQxtPPCfXl7NfsyiSthQ/v7UGin2SwR4RlPEaHNeW2iXvt2bzC0i3NTkZVq1a0duek2dWHpTlbaJLCED+2jhmXCHKcLhRyqaFAKLKHn3sHQJuoVf0DGd3udAPQdGgeCC0vHhIbId7DSJoElZGi3un3YXh18Nzyx0KFbO5LMJhyRH+zoG8vrdWsHkYi2VTWjxe0DABxflK1za4DRoRXk26ubuYKchJWQof3T/kYox31P0GkRmuKkUYMAAD/ta0Rr6A8QEXXvm72a+WydK20AKMkLvgFvcvlwsJkryElMCRna2kVox+hYaQPAiSODFYIvIEcsqiGirn0TGp1ypCRhlE7bNrXGDA6vUeF0F4kqIUNbmc/Oz0iOOOxfD5NGZqvnFm8oq9O1LUSiCARkfLU7GNrHFmbBpNO6FK3ReeFtX8r0G5FoEi60ff4Avg5VtHpt9dJypFgxNj94KITyR4iIuvZVWZ16Ra3Tx+bp3JqgwRk2pNuCK8gZ2iSqhAvtdT9WqvNNU8bk6tyaoJNGBYfIf9znRJuH89pE3Vnz7X4AQJJZwi+PGqJza4IkScLRBZkAgA9/roHXH9C5RUS9lxChHQjIcHn9kGUZT36yCwCQk2bDnGOH6dyyoBNHBhejef0yvt3j1LcxRAmu3ePHe5uqAABnjM3TdX/2oWZPDP5NqW1x49NtB3VuDVHv6R7a/oCMi55ej6OX/R3XvvgNtoWGra46dQSSk8w6ty5IWYwGAF/sqtWxJUSJ78Mt1epWrwuOGa5zayKde9QQpFiDf1de/6ZC59YQ9Z7uob2xvB7f7GmAxx/AR1uqAQDpNguuOKlI55aFZaVaMWF4cFjt1Y0VaA+diU5EkWRZxmsbg2GYaU/C6WMTY4pLkWqz4LzQcP3HW2pQ18KtXyQW3UN73Y8HOtz2q1OKkJGcpENrDu+a00YBAOpbPfjrxr06t4Yocbi8ftQ0uSDLMv749234fGdwNGrmhCGwWRJjtEzrouMLAAS3cf51I6ttEotFzyf3+QN476fg3NeUMbm4fsooVDS04cJjE2tIDQgOqz30920or2vDM5/txuUnFsFq0f09D5Gu2j1+zP7zF9hW3YycNBtqQ5VrYXYKfnvmGJ1b17kTRmRh
xKAUlNe14cEPtsHrD+A3Z4xOiG1pRN3RNXXW765DXasHADDz6CE4pSQHF59QCIs58cLQbJJw/dRiAMCBRhdWfraLRyHSgPfUp7vUdShKYOem2/Dy1SciN92mZ9MOS5IkrLjgKHVu+08f7cCsP3+O936qhI8ryinB6Vppr/uhEgBgNZswozRfz6b0yAXHDsOfPtqBqiYX/vj37fhyVx0umVSI0qEZKMhKYeVNA0pFfRue+jS422NUbiqKc9PQ7PJi2axSFA5K0bl1XZtckoO/3TgZC1/6BrtrW7FpfxNuWP0t0m0WnDAyG0cOSUdRdiqGZ9kxxGFHXroNKVYzJInVuGhkWcaWymZIEjA2P134/4dxC+29dW34yxdlkGUZARnYU9+Gr3YFTxibMiYHmfbEmsPujM1ixuOXHYObXvkOVU0ufLmrDl+G/g2SBOSl25CTZkNWihV2qxlWiwlJJgkmkwSzJMEkSTCZgu/0pdDPSJDUE9eUXyU9fqk6GzWQ1e8BMuTQf4NfI/R1QA7fHgh9Ih/ymIcbjwi+BsHXAuprIIVel/Dro95f0v63s9ct/BhQf77/r+e+hjZNp8/A8Cx7vx6vt+44Z6zuOykq6tvw3Odl6tcefwDf73XC7QtWpv9z4dE4foT+FwXpjdGD0/G3mybjuc/LsOrLcjjbvGh2+/CPrTX4x9aaDve3WUzISrEiw25Bqs2CVKsFyUkm2CxmJJklWMwmWEL9vav+rejN7+XhRvW0fTR8m7avhvtpQPlvQIZfDm619Qdk+AIBeP3Bz73+APwBOdivgeDPhL72B4KP5Zdltd+r95Wh/m1XflbLJAFmKfQamSVYTJLarmAbgo/hCwTQ5vFDloOXcx0zOB12q7nP/dgfkPH5zlqU1bYCAEblpGJySU5crxwZ7f4bt9CuaXZh1ZflnX5vTgLOYR/O8SOy8fffTcH/vL8Vr26sgNcfCiYZqG5yo7qJq1GNbF9De9yf83czxuge2l3139kThwoX2Ir05CQsmj4G1542Cu9vqsL63XX4urwee+vb1IsWKdy+AKqaXKhq0qetA80P+xrxw77GqD7m7tpW7A4FeLxEu//GLbQtZhMyki3qu9CsFCuOK8rClDG5OGd84g+Na2UkJ+He2UfhD+eNw7aqZmyrasa+hjbsd7rQ0OZBfasHLq8fHl8AvtC7WfWdKsLvfpXPgc7fMcdbZ29ktdW/hMgqVwJgkoKfm0I/bDIdpjo+5HGVil2pCgBo/itrKnrl/uE3R5HfkyN/VvscoXf+yud9lZNmw5FDMiAjOMw2ELcJmSQpYjTMYpJgt5oxMicVvz9vnI4ti45UmwUXHjccFx4XLCA8vgD2O9txIPRR2+JBfasbje1eNLZ70ebxo9Xtg9sXgMvrh9cvw+cP9veArP0d7ny0qS+/jocrNDsboVMr/VC/NYf+7ppMwRE/s0mCSQIsJhPMJglJ5uBtFrMJ5tD3lX4thapkSZJgNiH0s6Gv1fto73doVSwjEAhW6D5/AN6ADL9fVh/fYg6NRJqCFXhykhn+gIxt1c3YfbAV/kCgwxuo8Osod1qBa/v7iJxUzJowFLIM/O2H/dhb19ablz3hxC20JxY48OOys+P1dHGRnGTGhAIHJiTAGelEsXRMYRZ+WDpD72bEjdViwsicVIxMgKuTUfRcO2WU3k3oN66cIiIiEgRDm4iISBAMbSIiIkEwtImIiATB0CYiIhJEVFaPu93BLTA7d+6MxsP12AGnE+1uN+xOJ1o2b47rcxMlCqXfKf2wt9h/ifTT2/4bldCuqAheKWf27NnReLjeKy8D1r2jz3MTJYiKigoce+yxffo5gP2XSE897b+SHIWrXjidTnz66acoKCiAzXb4iwTs3LkTs2fPxtq1a1FSUtLfpxUeX49IfD0i9fT1cLvdqKiowNSpU+FwOHr9POy/fcPXIxJfj0ix6r9RqbQdDgfOP//8Ht+/pKQEpaWl0XhqQ+DrEYmvR6SevB59qbAV7L/9w9cjEl+PSNHuv1yIRkREJAiGNhERkSAY2kRERIKIa2jn5uZi6dKlyM3NjefTJiy+HpH4ekRKtNcj0dqjN74ekfh6RIrV6xGV1eNEREQUexweJyIiEgRDm4iISBAMbSIiIkEwtImIiATB0CYiIhJEXELb7Xbj9ttvx9ChQ2G323HiiSfiww8/jMdTJ6SWlhYsXboUv/jFL5CdnQ1JkrBq1Sq9m6WLjRs34qabbkJpaSlSU1NRWFiIefPmYfv27Xo3TRebN2/GRRddhFGjRiElJQU5OTmYMmUK3nlHvwtqsP9GYv8NY/+NFI/+G5fQXrBgAR5++GFcfvnlePTRR2E2m3Huuefi888/j8fTJ5za2lrcc8892LJlCyZMmKB3c3T1wAMP4M0338SZZ56JRx99FNdddx0+++wzHHvssdi0aZPezYu7PXv2oLm5GfPnz8ejjz6Ku+66CwAwa9YsrFy5Upc2sf9GYv8NY/+NFJf+K8fYhg0bZADygw8+qN7W3t4uFxcXyyeffHKsnz4huVwuubKyUpZlWd64caMMQH7++ef1bZROvvjiC9ntdkfctn37dtlms8mXX365Tq1KLD6fT54wYYJ8xBFHxP252X87Yv8NY//tXrT7b8wr7TfeeANmsxnXXXedeltycjKuvvpqrF+/Xr2W70Bis9mQn5+vdzMSwimnnAKr1Rpx2+jRo1FaWootW7bo1KrEYjabUVBQAKfTGffnZv/tiP03jP23e9HuvzEP7e+++w5jxoxBRkZGxO2TJk0CAHz//fexbgIJRpZlVFdXIycnR++m6Ka1tRW1tbXYtWsXHnnkEbz33ns488wz494O9l/qLfbf2PbfqFxPuyuVlZUYMmRIh9uV2w4cOBDrJpBgVq9ejf379+Oee+7Ruym6Wbx4MZ5++mkAgMlkwpw5c/D444/HvR3sv9Rb7L+x7b8xD+329nbYbLYOtycnJ6vfJ1Js3boVN954I04++WTMnz9f7+boZtGiRZg7dy4OHDiA1157DX6/Hx6PJ+7tYP+l3mD/DYpl/4358Ljdbofb7e5wu8vlUr9PBABVVVU477zzkJmZqc6lDlRjx47F9OnT8atf/Qrr1q1DS0sLZs6cCTnO1/dh/6WeYv8Ni2X/jXloDxkyBJWVlR1uV24bOnRorJtAAmhsbMQ555wDp9OJ999/n78Xh5g7dy42btwY9/2v7L/UE+y/XYtm/415aE+cOBHbt29HU1NTxO0bNmxQv08Dm8vlwsyZM7F9+3asW7cO48aN07tJCUcZhm5sbIzr87L/UnfYf7sXzf4b89CeO3cu/H5/xMZyt9uN559/HieeeCIKCgpi3QRKYH6/HxdffDHWr1+P119/HSeffLLeTdJVTU1Nh9u8Xi9efPFF2O32uP9BZP+lrrD/RopH/435QrQTTzwRF110EZYsWYKamhqUlJTghRdeQHl5OZ577rlYP33Cevzxx+F0OtXVt++88w727dsHALj55puRmZmpZ/PiZvHixXj77bcxc+ZM1NfX4+WXX474/hVXXKFTy/SxcOF
CNDU1YcqUKRg2bBiqqqqwevVqbN26FQ899BDS0tLi2h72386x/wax/0aKS/+NyhEt3Whvb5dvvfVWOT8/X7bZbPIJJ5wgv//++/F46oRVVFQkA+j0o6ysTO/mxc3UqVMP+zrE6dczofzv//6vPH36dHnw4MGyxWKRs7Ky5OnTp8t/+9vfdGsT+29H7L9B7L+R4tF/JVmO83JUIiIi6hNempOIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0iYiIBMHQJiIiEgRDm4iISBAMbSIiIkEwtAeIadOmYdq0ab36mWXLlkGSJNTW1nZ73xEjRmDBggURt+3YsQMzZsxAZmYmJEnC2rVre/X8RBTE/kuKmF/liwau+fPno6ysDCtWrIDD4cDxxx+PV155BTU1NVi0aJHezSOiLrD/JiaGNkXFtm3bYDKFB27a29uxfv16/P73v8dNN92k3v7KK69g06ZN7PRECYT9VxwcHhdAa2ur3k3ols1mQ1JSkvr1wYMHAQAOh0OnFhElBvZfiiaGdoJR5qF+/vlnXHbZZcjKysKpp54KAHj55Zdx3HHHwW63Izs7G5dccgkqKio6PMbKlStRXFwMu92OSZMm4V//+lenz/XYY4+htLQUKSkpyMrKUoe/DuV0OrFgwQI4HA5kZmbiyiuvRFtbW8R9tHNiy5YtQ1FREQDgtttugyRJGDFiBKZNm4Z3330Xe/bsgSRJ6u1ERsH+S7HG4fEEddFFF2H06NG47777IMsyVqxYgbvuugvz5s3DNddcg4MHD+Kxxx7DlClT8N1336nviJ977jksXLgQp5xyChYtWoTdu3dj1qxZyM7ORkFBgfr4zzzzDH7zm99g7ty5+O1vfwuXy4Uff/wRGzZswGWXXRbRlnnz5mHkyJG4//778e233+LZZ59FXl4eHnjggU7bPmfOHDgcDtxyyy249NJLce655yItLQ2pqalobGzEvn378MgjjwAA0tLSYvMCEumI/ZdiRqaEsnTpUhmAfOmll6q3lZeXy2azWV6xYkXEfX/66SfZYrGot3s8HjkvL0+eOHGi7Ha71futXLlSBiBPnTpVve3888+XS0tLe9SWq666KuL2Cy64QB40aFDEbUVFRfL8+fPVr8vKymQA8oMPPhhxv/POO08uKirq8nmJRMX+S7HG4fEEdf3116ufv/XWWwgEApg3bx5qa2vVj/z8fIwePRr//Oc/AQBff/01ampqcP3118Nqtao/v2DBAmRmZkY8vsPhwL59+7Bx48ZetQUATjvtNNTV1aGpqak//0Qiw2L/pVjh8HiCGjlypPr5jh07IMsyRo8e3el9lQUke/bsAYAO90tKSsKoUaMibrv99tvx0UcfYdKkSSgpKcGMGTNw2WWXYfLkyR0ev7CwMOLrrKwsAEBDQwMyMjJ6+S8jMj72X4oVhnaCstvt6ueBQACSJOG9996D2WzucN++zCsdeeSR2LZtG9atW4f3338fb775Jp544gncfffdWL58ecR9O3tOAJBludfPSzQQsP9SrDC0BVBcXAxZljFy5EiMGTPmsPdTVnzu2LEDZ5xxhnq71+tFWVkZJkyYEHH/1NRUXHzxxbj44ovh8XgwZ84crFixAkuWLEFycnJM/i2SJMXkcYkSFfsvRRPntAUwZ84cmM1mLF++vMO7Y1mWUVdXBwA4/vjjkZubi6eeegoej0e9z6pVq+B0OiN+TvkZhdVqxbhx4yDLMrxeb2z+IYC6ApVooGD/pWhipS2A4uJi3HvvvViyZAnKy8sxe/ZspKeno6ysDGvWrMF1112HW2+9FUlJSbj33nuxcOFCnHHGGbj44otRVlaG559/vsOc2IwZM5Cfn4/Jkydj8ODB2LJlCx5//HGcd955SE9Pj9m/5bjjjsOrr76K3/3udzjhhBOQlpaGmTNnxuz5iPTG/kvRxNAWxB133IExY8bgkUceUeesCgoKMGPGDMyaNUu933XXXQe/348HH3wQt912G4466ii8/fbbuOuuuyIeb+HChVi9ejUefvhhtLS0YPjw4fjNb36DP/zhDzH9d/z617/G999/j+effx6PPPIIioqK2OnJ8Nh/KVokmasRiIiIhMA5bSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0iYiIBMHQJiIiEkRULhjidDrx6aefoqCgADabLRoPSUQ95Ha7UVFRgalTp8LhcPT659l/ifTT6/4rR8HatWtlAPzgBz90/Fi7di37Lz/4IehHT/tvVCrtgoICAMDatWtRUlISjYekODpwxxK0b9oE+/jxGPrf9+vdHOqlnTt3Yvbs2Wo/7C09+i9/54iCett/oxLaypBaSUkJSktLo/GQFEdpDgfabDakOBwo4v8/YfV1aFuP/svfOaJIPe2/XIhGREQkCIY2ERGRIBjaREREgmBoG5gsy9he3Qy3z693U4jiqt3jR1ltq97NIIo6hraBvbxhL2Y88hmuffEbvZtCFDeyLOOip7/E6X/8BB/+XK13c4iiiqFtYJ9uqwEAfLGzFi4vq20aGNy+ADbtbwIAbNhdp3NriKKLoW1guw4Ghwf9ARk7a1p0bg1RfLS4ferntS1uHVtCFH0MbYNy+/zYW9+mfr2tqlnH1hDFT4tLG9oeHVtCFH0MbYPaW9cGf0BWv95a1aRja4jih5U2GRlD26B2HYwcDt/KSpsGiGYXQ5uMi6FtUIfOYTO0aaDQVtr1rZ6IESci0TG0DUpZhKY42OxGHasOGgBa3F7184AcDG4io2BoG5QyPG61hP8XczEaDQTahWgAUNfKN6tkHAxtA5JlGbtCw+NTx+Sqt29haNMA0OyODO3aZlbaZBwMbQOqbnKj1RM8TGVy8SA4UpIAAFsruYKcjO/QSpuL0chIGNoGpF05XpKXjrH56QCAbdWstMn4Wg6ttBnaZCAMbQPShnZxXirG5mcAALZXN0OWuZKWjO3QSvsgQ5sMhKFtQMp2rxSrGfkZyRieZQcAuLwBONu8Xf0okfCaDh0e55w2GQhD24AOOF0AgMLsFEiShCGZdvV7lY0uvZpFFBfaLV8Ah8fJWBjaBtTQFqwsslKsAID8zGT1e1VN7bq0iSheOKdNRsbQNiAltLNTg6E9RBParLTJ6Drs0+ZFQ8hAGNoG1BA6ASorNbjVKzfdBpMU/F4VQ5sM7tBKu67VzQWYZBgMbYPxB2Q0tgfn9JTh8SSzCbnpNgCstMn4lAuGWM3BP29ef7hPEImOoW0wTe1eKNdHUEIbAPJDi9EqGzmnTcbl8QXg9gUAAAXZ4QWYnNcmo2BoG4wynw2Eh8cBYGhoXpuVNhlZq2ZofGROqvr5QW77IoNgaBtMRGhHVNrB0K5qdHF+jwxLO589YlA4tFlpk1EwtA2mvjU8d6cNbWUFeZvH3+HwCSKjaNb8bhflMLTJeBjaBqOttJUtX0B4ThvgCnIyLm2lXZidou6aYGiTUTC0DUbZ7gVAvboXcOhebS5GI2PSnoaWaU9Cms0SvJ
2jS2QQDG2DaQidLZ5kltQ/WACQn6E5FY2VNhmUdng8zWaBLckMAPD4A3o1iSiqGNoGo1TajhQrJElSbx+ckQzlywMMbTIo7fB4erJF3avt9jK0yRgY2gajHmGqWYQGAFaLCTlpwQNWqjg8TgbV0qHSDoW2j6FNxsDQNhj1YiGaPdqKIdyrTQanVNqSFLw0rc0SHB5naJNRMLQNpr418gpfWsq8Nue0yaiUOe00mwWSJMFqUSptv57NIooahrbBOEML0bJSO4b2kEyGNhmbEtrpoUWYtlBoe1hpk0EwtA0kEJA119LuODw+1BHcq93s9qGxjRdQIONRtnylJUeGNofHySgY2gbS7PJ1erEQRWF2ivr5nvrWeDWLKG6UOe00VtpkUAxtA6k/zLnjisJBmtCua4tLm4jiSVk9npYcHGninDYZDUPbQA53hKmiSHMBhb31DG0ynmb3oXPaPFyFjIWhbSCHO8JUkWazYFAozPfUcXicjKfFFTk8zsNVyGgY2gZS39p1pQ2Eh8g5PE5GpM5pKwvRQoersNImo2BoG4hTsyLc0cmcNhC+xjCHx8lo/AEZbZ7g3PWhC9FYaZNRMLQNRFmIZjZJyEi2dHofZQV5ZaMLLi8X55BxaM8dV4fHuRCNDIahbSDOtvBpaNqLhWgVaVaQV7DaJgNp94SDOcUWXICmLEQLyICPQ+RkAAxtAwkfYdpxEZqiiNu+yKDaPOFKO8UaDGul0gZ4wAoZA0PbQJQ57c5WjisKs8Pbvvaw0iYDadNU2vakyDltgAeskDEwtA2kKbTdJdN++NDOSbOqVchebvsiA9Gu0bCz0iaDYmgbSFN7sNLOSD58aEuSpB6ywkqbjERbaStvTJU5bYCVNhkDQ9tA1NDuotIGgKJs7tUm44kcHu+s0uYKchIfQ9sg/AFZPcKx29AOLUbb19AGv3KFESLBtXs7LkSzcXicDIahbRDK8Y0ADrtHWzE8VGl7/TJqmnltbTKGdk84lO0MbTIohrZBNLaHT0PrrtLOS7epn9c2e7q4J5E4IrZ8JUUergJweJyMgaFtEE2ucGh3tXocAHI1oX2whZU2GYP2cBU7F6KRQTG0DaJJW2l3sXocAHLTNKHd7I5Zm4jiqS205ctiktQKm8PjZDQMbYPQDo/3qtJmaJNBKJW2snIc4OEqZDwMbYPQDo9n2LteiJacZEZ6aLEaQ5uMQg1tazi0ebgKGQ1D2yCa2jWrx7uptIFwtX2whaFNxqAMj6dYtZU257TJWBjaBqEMj5skIM3adaUNhOe1WWmTUbSHVo/bNb//Nq4eJ4NhaBuEMjyenpwEk6nzy3JqqZU2Q5sMot2rzGmH/6xZOadNBsPQNojwEabdV9lAOLRrGNpkEMoxpimHrbQZ2iQ+hrZBKMPj3a0cVyih3ebxwy/zKFMSX2cL0SxmE5SBJw6PkxEwtA1CuSxnd3u0Fdq92l5WIGQA4UrbHHG7shiNw+NkBAxtg+jJZTm18jKS1c+9flbaJL62TvZpA+F5bQ6PkxEwtA2i18Pjmkrb4+cfMxKfy9txeBwIz2uz0iYjYGgbhLJ6vLcL0QDAy9AmwcmyrF4w5NDhcVbaZCQMbQNw+/xweYN/kHo6PJ6dalUX6DC0SXRuXwDKpeFTDjmngJU2GQlD2wC0p6FlpvQstM0mCYNCQ+ReH+e0SWwRV/hK6nwhGlePkxEwtA0g4tzxHlbaQHhem3PaJDrlYBWg45w2h8fJSBjaBhBxWc4ezmkD4XltDo+T6No0lXbHLV8MbTIOhrYB9OaynFoMbTKKrobHWWmTkTC0DUA5WAXo5fC4Gtqc0yaxKSvHgc4WovFwFTIOhrYBRA6P935OW+YxpiS4yDntyD9r4eFxLkQj8TG0DaCvw+M5mr3aRCKLHB7nli8yLoa2ASirx61mU8RVjbqTk2aNVZOI4qqrhWic0yYjYWgbgLJPO8NugSR1fy1tRU4aK20yhjZv96vHWWmTETC0DSB8Le2eD40DwKBUVtpkDC5NpZ18aGgn8XAVMg6GtgGo5473YuU4ADhSwkeZEoksYnj80C1f5nClzUWXJDqGtgE09rHSNpskZKdyiJzE1+YNThFZzSZYzJ2vHg/IgC/A0CaxMbQNwNkWDO2sHp47rsXFaGQEyurxQ48wBcIL0QAuRiPxMbQNwNnmAQA4ellpA8AghjYZgBraSR1DW7ujgovRSHQMbcH5AzKa3cGhwd7s0VZwBTkZgbJ6/NCV4wBgtYRv42I0Eh1DW3DNLi+UtTWZKb2vmgdxTpsMoKvhcVbaZCQMbcEp89lA/4fH/VxZS4JSzh7vvNLmnDYZB0NbcE7NEaaOfi5E44VDSFTt3mAYJ3NOmwyOoS04ZREa0NfQDg+P+3iJThJUexeVti2Jc9pkHAxtwUVeLKQPc9qa0OZ1tUlUyuEqh16WEwgfrgJweJzEx9AWXMScdh8qbe1RphweJ1F1uRAtiaFNxsHQFpw2tPu75YuVNolKuZ52Z/u0tZU257RJdAxtwTnbg3PaaTYLksy9/99pt5phDh1AztAmEcmyrIZ2Z3Payay0yUAY2oJrDFXafamyFcpZzRweJxG5vAH1rIJOjzE1axaiebkQjcTG0BacsuWrL/PZiiRzsNLm6nESkbJHG+h4hS8gck7bw99xEhxDW3DqueP9Cm1W2iQu7WU5uzsRze1laJPYGNqCUyvtPmz3UoRDm3/QSDytmko7zdbxzav2RDRW2iQ6hrbgmkKhnRmF4XFvIAA/rzdMgml1h0M71db16nFW2iQ6hrbAZFlWt3z15dxxhbrqXI48YY1IBC3u8PB4mq3j4SoWs0ndIeHxcyEaiY2hLbBWjx++UGUcjTltAKhtYWiTWFpcmuHx5I6hDYTntVlpk+gY2gKLOHe8X3Pakvp5dZOrX20iireI4fFOjjEFwvPanNMm0TG0BaY9DS2jH8PjVkt4HrCqkaFNYmlxaxeidV1pu7hPmwTH0BZYYz8vy6mwairtKlbaJJiWiIVonYe2crxpO4fHSXAMbYH192IhCpMUDu1KVtokGGV43GoxRWzv0rKHhs3bNdvDiETE0BaYcu440L85ba2qxvaoPA5RvCiV9uGGxoHwmeTag1iIRMTQFli0Km2tqiZ3VB6HKF6USruzPdoKhjYZBUNbYMqcts1iQnInZy73BSttEk240j78G1d1TpuhTYJjaAssGueOH6qhzcsVtiSUcGj3oNL2ck6bxMbQFlj4NLTozGcruFebRNIaOhHtcCvHAe1CNL4hJbExtAXmjMK5453hCnISCRei0UDC0BZYY6jSzuzHwSqdYaVNIulNaLd7/ZBlXhSHxMXQFlhda3BOOyctusPjrLRJJOHV412FdvB7sgy4eMAKCYyhLahAQEZ9a3B71qBUW1QeU7kSEo8yJVH4A7I65N2TShuIvP42kWgY2oJytnuhXPp6UJQq7aTQaVIMbRKFNoC7Cm27JrS5GI1ExtAWVF1L+BCUQWnRqbRtoUt0VnJOmwTR2oNzx4HISpuL0UhkDG1Baa97nZManUpbObe5mpU2CUIb2
oe7ljZwaGhzeJzExdAWVF1r9CttJbRrml3w8brDJIAWd7hq7upwFXtSONA5PE4iY2gLqk5TaWdHq9IODY8HZOBgC88gp8TX4tIMj1s5PE7Gx9AWlDKnLUlAVpQOV9Fe1pDbvkgEPbmWNnBIaPOYXhIYQ1tQtaE92lkpVljM0fnfaLWE/7Dta+CFQyjxaee007uY045cPc45bRIXQ1tQSqU9KEpD4wCQnBT+ddhT2xq1xyWKlZ5W2tqhcw6Pk8gY2oJS5rSjtUcbAMySpJ6utqe+LWqPSxQr2tDu6T5thjaJjKEtKOUI02itHFcUZqcAAPbWMbQp8SnD4xaTBJvl8H/ObBYTQgf+ccsXCY2hLaja0PB4tPZoK0YMSgUA7Knn8DglPu2545IkHfZ+kiSp54+z0iaRMbQF5Pb50Rza6hL1SntQsNKubnLDxVW2lOCae3CFL4UyRM592iQyhraAGlq96ufRnNMGgKJQaAPAXs5rU4Jr7UVo85raZAQMbQHVas8dj9IVvhSF2anq53s4r00JrjV0IlpqF6ehKexJDG0SH0NbQMoiNCD619LWVtp76jivTYmtuQfX0lYolXa7lwvRSFwMbQHF4gpf6uOlWpEa+uPGSpsSnTI83tXBKgouRCMjYGgLSHvueLTntCVJQpG6gpyhTYlNXT3exbnjihQuRCMDYGgLqDZ0hS+r2YT0HgwL9pYyRL6Xw+OU4Fr6MDzOSptExtAWkPY0tK72pvaVsu1rX0M7L9FJCUuW5V6tHrerw+Oc0yZxMbQFpJ47HuWhcUVRaAW5LyDzal+UsNq9fgTk4OdpPZrTZqVN4mNoC0g9wjTK270U2hXk5RwipwTV04uFKMKrx/2QZTlm7SKKJYa2gA42R/8KX1oleWnq51srm2PyHET9pezRBoC0nuzTDoW2LAMuL6d9SEwMbcG4vH5UNQWHrIdn2WPyHHnpNuSEtpJtOtAYk+cg6i9nW3gXRUZyUrf3T0nSXumL89okJoa2YCrq26CM7I3ISe36zn0kSRLGD8sAAGw+0BST5yDqL2XECQDy0pO7vX8Kr6lNBsDQFkxZbXiOOVahDQDjh2YCAHYdbGFVQgmpRhPauendr+/QXlO7nRfDIUExtAWjXRg2clAMQztUacsysKWS1TYlHm2l3ZOdFClW7fA4Q5vExNAWTFlt8JSyTHsSsmK0EA0ASkOVNgBs2s/QpsRzMLT1MTvViiRz93/KIofHOXpEYmJoC6Y8NDwey6FxILjILdMeXNyzaT8Xo1HiUSrt3B6evx9RabtZaZOYGNqCUYbHR2j2UseCJEkoHRocIt/ExWiUgJTQzsvoQ2hzTpsExdAWSLvHr55QNiKG89mK8cOCQ+Q7qpvh9vGPHCWW3lbaEQvRODxOgmJoC2Sv5qpbI2M8PA5ArbR9ARnbq1pi/nxEPSXLsjqn3ZOV4wC3fJExMLQFEq/tXoqjhoUXo20sr4/58xH1VJPLB48veKpZz0Obq8dJfAxtgcRru5f6HDmpGJoZPLTi463VMX8+op462By+kE1PQ9tmMUG5KB6vqU2iYmgLRFk5npWShMyU7o9t7C9JknDmkYMBABt216PJ5Y35cxL1RMTBKj2c05YkST3KlJU2iYqhLZCyOG330jrzyDwAwXntz7YfjNvzEnXlYC9PQ1Mo19Ru93IhGomJoS0QZXg8HkPjipOLByE1NBf48ZaauD0vUVf6GtqpNlbaJDaGtiAqG9tR3RT8QzV6cHrcntdmMeO00bkAgH9uq4HPz0sakv6UleNWs0k9BKgnHKH71ra4u7knUWJiaAvi32Xh1duTRmbF9bmVIXJnmxff7GmI63MTdUbdo51ug6SsLusBZWpp98HWbu5JlJgY2oJQtlzZLCYcNcwR1+c+Y2yeuur2H1s5RE76U0I7pxdD4wAwKicNAFDZ6OL54yQkhrYglEp7YoEDVkt8/7cNSrNhYoEDAEObEkNvT0NTjMoNrwdhtU0iYmgLoKHVg+3VwRPJThyZrUsbzhwbHCLfUdOCCs3JbER60A6P90ZEaNcytEk8DG0BfK2ZRz5Bp9A+PRTaAKtt0pfXH0B9mwdA70Nbe/zv7oM8mpfEw9AWgDKfbTZJOLYwvovQFOOGZCA/I3g6GkOb9FTf6oEsBz/vbWinWC3qKX8cHicRMbQFsCE0n106NAOpNks3944NSZLUanv97jou4iHd7He2q5/3dk4bAEblBhej7a5lpU3iYWgnuMrGdmza3wgAOGGEPkPjCmVe2+ML4PMdtbq2hQauN7/Zp35+5JDen1lQHJrXLjvYClkp2YkEwdBOcC98uQf+QPAPy6wJQ3Vty+SSHCQnBX9l1n6/X9e20MBU3+rBm98GQ/uMsXko6sPpgEql3erxqwcWEYmCoZ3AWt0+vLJhDwDghBFZmBDadqUXu9WM844KvnH4++Zq1GiutEQUD69s2AOXN3gq3zWnjuzTY0Ru++IQOYmFoZ3A3vx2H5pcwbnjq/v4ByraLjuxEEDwAiKvf72vm3sTRY/b58cL64NvYo8ckoGTiwf16XGUShsAdjG0STAM7QTl8vrxl8/LAAAF2XacNS5f5xYFHVvowNj84DziXzfuRSDAOUGKj7/+u0Ldn33NqSN7dXyp1pCMZHWaZxdXkJNgGNoJasW7W1BeFzzE5KrJI2E29e0PVLRJkqRW2xX17fiUl+ukOGhyefHoxzsAAIXZKZjZj/UdJpOEkaHjTDeW13MxGgmFoZ2A1v14AC99FRwGPKbQgStOKtK5RZFmHzMM9qTgJQ7/sHYTnKGDLohi5clPdqG+Nfh7dsc5Y/t9lO9Z4wYDADYfaMJHvOQsCYShnUBkWcbqDXtw6+s/AAAy7Ul47NJjkGROrP9NGclJuPXsIwAE98z+7rUfOExOMbO9uhnPhaaKji104Jzx/Z8quvrUkchIDp558NDft/H3l4SRWGkwQHl8Afx9cxV+9Zd/4/drNsHlDUCSgIcumoDhWSl6N69TV00eof7x/MfWGtz81+/USogoWjbtb8QlK7+CxxdcMf77847s81y2VqY9CQunFgMAtlY14+0fDvT7MYniQZ/jtQY4l9ePnyub8O2eBny1ux4bdteh2R0+YSw/IxkPXzwBpxTn6NjKrkmShP+ZezS2VjWjrLYV7/5YiQ2763DppEKcNW4wxgxOR3JoCJ3oUL5QZVvR0I5/flGGETmpyE23ISM5Cf6ADGe7F+/+eAB//XeF2jdu/8VYHFcUvQOGFpwyAs9/UYbaFg9ue+MH1Ld6cOXkEVF5U0AUK3EL7Yr6NnWIK9EpC1Nk9WtAhhz6b/D7/oAMfwDwBwLwy0Ag9E1/QEZAlkP3Cz6KLANtHj+a3V5UN7nVFbCHSrWaMWviMPzn2UcgK9Uah39p/6QnJ+HVhSfhzrc24aMt1aht8eCxf+zEY//YCQDIS7chK8WKDLsFKVYLkpNMsJhNMEsSzCYJJkmC2QSYJAkmkwSTBEiQIEmA8meTf0CDc7h6vwHqa/+VQ33BH5Dh9gXQ2O7Ftqpm3Fhej6MB7G9ow7J3fu72cZbO
HIcrJ0d322OqzYJls0qx6K/fw+uXcc+6n/HEJ7swflgGBqXakJxkgtkkQQJ/D6nvot1/4xbaNc0urPqyPF5PJ4zBGTZMLs7B1CNycda4wUixijX4kZeejGd+dRze/uEAnvnXbmza36R+r6bZjZrDvEGhnvvdjDG6h7Ze/ffEkdm4bsoonHnk4Jg8/i+PHoqhDjtufuU77He2o7bFjU+2cUcERU+0+2/cEsIkSerCj1iJ5rth5aG0FZ/yjluSoFaLyodSHSoVpNIWKfRYKVYzUm0W5KbZMNRhR3FeGo4pcGB4ll34d/GSJOH8icNw/sRh2O9sx/pdddjX0Ib9De1wtnvR1O6Fy+tHu9cPX0BGICDDL8sIBIIVmF8OjkYEZFmtzAAgGjtxorGd59D/PwNxi1Bf+q/yupmk4M/bLCbYrWaU5KVh+KYUoC54qdmvlpyJstpWNLZ70OTywWKSkJxkRunQjD4dU9pbxxZm4f9+exre/GYfftznxNaqZjS7fHB5/ervJjAw/79T4olbaB9TmIUfl50dr6cjnQxz2DH3uOF6N4OiLNr9d8/f7WhD8M1vfmYy8kOXy9RLpj0JVyXIqYNEXeHqcSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBRGX1uNsd3Iu7c+fOaDwcxdkBpxPtbjfsTidaNm/WuznUS0q/U/phb+nRf/k7RxTU2/4bldCuqKgAAMyePTsaD0d6KS8D1r2jdyuojyoqKnDsscf26ecAnfovf+eIAPS8/0pyFE4McDqd+PTTT1FQUACbzXbY++3cuROzZ8/G2rVrUVJS0t+nFR5fj0h8PSL19PVwu92oqKjA1KlT4XA4ev087L99w9cjEl+PSLHqv1GptB0OB84///we37+kpASlpaXReGpD4OsRia9HpJ68Hn2psBXsv/3D1yMSX49I0e6/XIhGREQkCIY2ERGRIBjaREREgohraOfm5mLp0qXIzc2N59MmLL4ekfh6REq01yPR2qM3vh6R+HpEitXrEZXV40RERBR7HB4nIiISBEObiIhIEAxtIiIiQTC0iYiIBMHQJiIiEkRcQtvtduP222/H0KFDYbfbceKJJ+LDDz+Mx1MnpJaWFixduhS/+MUvkJ2dDUmSsGrVKr2bpYuNGzfipptuQmlpKVJTU1FYWIh58+Zh+/btejdNF5s3b8ZFF12EUaNGISUlBTk5OZgyZQreeUe/i2qw/0Zi/w1j/40Uj/4bl9BesGABHn74YVx++eV49NFHYTabce655+Lzzz+Px9MnnNraWtxzzz3YsmULJkyYoHdzdPXAAw/gzTffxJlnnolHH30U1113HT777DMce+yx2LRpk97Ni7s9e/agubkZ8+fPx6OPPoq77roLADBr1iysXLlSlzax/0Zi/w1j/40Ul/4rx9iGDRtkAPKDDz6o3tbe3i4XFxfLJ598cqyfPiG5XC65srJSlmVZ3rhxowxAfv755/VtlE6++OIL2e12R9y2fft22WazyZdffrlOrUosPp9PnjBhgnzEEUfE/bnZfzti/w1j/+1etPtvzCvtN954A2azGdddd516W3JyMq6++mqsX79evZbvQGKz2ZCfn693MxLCKaecAqvVGnHb6NGjUVpaii1btujUqsRiNptRUFAAp9MZ9+dm/+2I/TeM/bd70e6/MQ/t7777DmPGjEFGRkbE7ZMmTQIAfP/997FuAglGlmVUV1cjJydH76boprW1FbW1tdi1axceeeQRvPfeezjzzDPj3g72X+ot9t/Y9t+oXE+7K5WVlRgyZEiH25XbDhw4EOsmkGBWr16N/fv345577tG7KbpZvHgxnn76aQCAyWTCnDlz8Pjjj8e9Hey/1Fvsv7HtvzEP7fb2dthstg63Jycnq98nUmzduhU33ngjTj75ZMyfP1/v5uhm0aJFmDt3Lg4cOIDXXnsNfr8fHo8n7u1g/6XeYP8NimX/jfnwuN1uh9vt7nC7y+VSv08EAFVVVTjvvPOQmZmpzqUOVGPHjsX06dPxq1/9CuvWrUNLSwtmzpwJOc7X92H/pZ5i/w2LZf+NeWgPGTIElZWVHW5Xbhs6dGism0ACaGxsxDnnnAOn04n333+fvxeHmDt3LjZu3Bj3/a/sv9QT7L9di2b/jXloT5w4Edu3b0dTU1PE7Rs2bFC/TwOby+XCzJkzsX37dqxbtw7jxo3Tu0kJRxmGbmxsjOvzsv9Sd9h/uxfN/hvz0J47dy78fn/ExnK3243nn38eJ554IgoKCmLdBEpgfr8fF198MdavX4/XX38dJ598st5N0lVNTU2H27xeL1588UXY7fa4/0Fk/6WusP9Gikf/jflCtBNPPBEXXXQRlixZgpqaGpSUlOCFF15AeXk5nnvuuVg/fcJ6/PHH4XQ61dW377zzDvbt2wcAuPnmm5GZmaln8+Jm8eLFePvttzFz5kzU19fj5Zdfjvj+FVdcoVPL9LFw4UI0NTVhypQpGDZsGKqqqrB69Wps3boVDz30ENLS0uLaHvbfzrH/BrH/RopL/43KES3daG9vl2+99VY5Pz9fttls8gknnCC///778XjqhFVUVCQD6PSjrKxM7+bFzdSpUw/7OsTp1zOh/O///q88ffp0efDgwbLFYpGzsrLk6dOny3/72990axP7b0fsv0Hsv5Hi0X8lWY7zclQiIiLqE16ak4iISBAMbSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0B4hp06Zh2rRpvfqZZcuWQZIk1NbWdnvfESNGYMGCBRG37dixAzNmzEBmZiYkScLatWt79fxEFMT+S4qYX+WLBq758+ejrKwMK1asgMPhwPHHH49XXnkFNTU1WLRokd7NI6IusP8mJoY2RcW2bdtgMoUHbtrb27F+/Xr8/ve/x0033aTe/sorr2DTpk3s9EQJhP1XHBweF0Bra6veTeiWzWZDUlKS+vXBgwcBAA6HQ6cWESUG9l+KJoZ2glHmoX7++WdcdtllyMrKwqmnngoAePnll3HcccfBbrcjOzsbl1xyCSoqKjo8xsqVK1FcXAy73Y5JkybhX//6V6fP9dhjj6G0tBQpKSnIyspSh78O5XQ6sWDBAjgcDmRmZuLKK69EW1tbxH20c2LLli1DUVERAOC2226DJEkYMWIEpk2bhnfffRd79uyBJEnq7URGwf5Lscbh8QR10UUXYfTo0bjvvvsgyzJWrFiBu+66C/PmzcM111yDgwcP4rHHHsOUKVPw3Xffqe+In3vuOSxcuBCnnHIKFi1ahN27d2PWrFnIzs5GQUGB+vjPPPMMfvOb32Du3Ln47W9/C5fLhR9//BEbNmzAZZddFtGWefPmYeTIkbj//vvx7bff4tlnn0VeXh4eeOCBTts+Z84cOBwO3HLLLbj00ktx7rnnIi0tDampqWhsbMS+ffvwyCOPAADS0tJi8wIS6Yj9l2JGpoSydOlSGYB86aWXqreVl5fLZrNZXrFiRcR9f/rpJ9lisai3ezweOS8vT544caL
sdrvV+61cuVIGIE+dOlW97fzzz5dLS0t71Jarrroq4vYLLrhAHjRoUMRtRUVF8vz589Wvy8rKZADygw8+GHG/8847Ty4qKuryeYlExf5Lscbh8QR1/fXXq5+/9dZbCAQCmDdvHmpra9WP/Px8jB49Gv/85z8BAF9//TVqampw/fXXw2q1qj+/YMECZGZmRjy+w+HAvn37sHHjxl61BQBOO+001NXVoampqT//RCLDYv+lWOHweIIaOXKk+vmOHTsgyzJGjx7d6X2VBSR79uwBgA73S0pKwqhRoyJuu/322/HRRx9h0qRJKCkpwYwZM3DZZZdh8uTJHR6/sLAw4uusrCwAQENDAzIyMnr5LyMyPvZfihWGdoKy2+3q54FAAJIk4b333oPZbO5w377MKx155JHYtm0b1q1bh/fffx9vvvkmnnjiCdx9991Yvnx5xH07e04AkGW5189LNBCw/1KsMLQFUFxcDFmWMXLkSIwZM+aw91NWfO7YsQNnnHGGervX60VZWRkmTJgQcf/U1FRcfPHFuPjii+HxeDBnzhysWLECS5YsQXJyckz+LZIkxeRxiRIV+y9FE+e0BTBnzhyYzWYsX768w7tjWZZRV1cHADj++OORm5uLp556Ch6PR73PqlWr4HQ6I35O+RmF1WrFuHHjIMsyvF5vbP4hgLoClWigYP+laGKlLYDi4mLce++9WLJkCcrLyzF79mykp6ejrKwMa9aswXXXXYdbb70VSUlJuPfee7Fw4UKcccYZuPjii1FWVobnn3++w5zYjBkzkJ+fj8mTJ2Pw4MHYsmULHn/8cZx33nlIT0+P2b/luOOOw6uvvorf/e53OOGEE5CWloaZM2fG7PmI9Mb+S9HE0BbEHXfcgTFjxuCRRx5R56wKCgowY8YMzJo1S73fddddB7/fjwcffBC33XYbjjrqKLz99tu46667Ih5v4cKFWL16NR5++GG0tLRg+PDh+M1vfoM//OEPMf13/PrXv8b333+P559/Ho888giKiorY6cnw2H8pWiSZqxGIiIiEwDltIiIiQTC0iYiIBMHQJiIiEgRDm4iISBAMbSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0iYiIBPH/ARImQkMFFr0ZAAAAAElFTkSuQmCC",
374 | "text/plain": [
375 | "<Figure size 480x480 with 4 Axes>"
376 | ]
377 | },
378 | "metadata": {},
379 | "output_type": "display_data"
380 | }
381 | ],
382 | "source": [
383 | "fig, axes = plt.subplots(2, 2, figsize=(4, 4), dpi=120, constrained_layout=True)\n",
384 | "\n",
385 | "for i, ax in enumerate(axes.flatten()):\n",
386 | " \n",
387 | " ax.plot(grid, pdfs[i], label=\"Posterior\")\n",
388 | " ax.axvline(data[\"redshift\"][i], c=\"C3\", label=\"True redshift\")\n",
389 | " ax.set(xlabel=\"redshift\", yticks=[])\n",
390 | "\n",
391 | "axes[0,0].legend()\n",
392 | "plt.show()"
393 | ]
394 | },
395 | {
396 | "cell_type": "code",
397 | "execution_count": null,
398 | "metadata": {},
399 | "outputs": [],
400 | "source": []
401 | }
402 | ],
403 | "metadata": {
404 | "interpreter": {
405 | "hash": "12e089b4d8e7c489ece8bd483c1f38c5ce283ca32d0fbd08723b9602a5027f48"
406 | },
407 | "kernelspec": {
408 | "display_name": "Python 3.9.1 64-bit ('pzflow': conda)",
409 | "language": "python",
410 | "name": "python3"
411 | },
412 | "language_info": {
413 | "codemirror_mode": {
414 | "name": "ipython",
415 | "version": 3
416 | },
417 | "file_extension": ".py",
418 | "mimetype": "text/x-python",
419 | "name": "python",
420 | "nbconvert_exporter": "python",
421 | "pygments_lexer": "ipython3",
422 | "version": "3.9.12"
423 | }
424 | },
425 | "nbformat": 4,
426 | "nbformat_minor": 4
427 | }
428 |
--------------------------------------------------------------------------------
/mkdocs-requirements.txt:
--------------------------------------------------------------------------------
1 | mkdocs==1.6.1
2 | mkdocs-material==9.6.10
3 | mkdocstrings==0.29.1
4 | mkdocstrings-python==1.16.8
5 | mkdocs-gen-files==0.5.0
6 | mkdocs-section-index==0.3.9
7 | mkdocs-literate-nav==0.6.2
8 | mkdocs-jupyter==0.25.1
9 | mike==2.1.3
10 | jax
11 | jaxlib
12 | pandas
13 | dill
14 | tqdm
15 | flake8
16 | mypy
17 | black
18 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: PZFlow
2 | repo_url: https://github.com/jfcrenshaw/pzflow
3 | nav:
4 | - Home: index.md
5 | - Install: install.md
6 | - Tutorials:
7 | - tutorials/index.md
8 | - Introduction: tutorials/intro.ipynb
9 | - Conditional Flows: tutorials/conditional_demo.ipynb
10 | - Convolving Gaussian Errors: tutorials/gaussian_errors.ipynb
11 | - Flow Ensembles: tutorials/ensemble_demo.ipynb
12 | - Training Weights: tutorials/weighted.ipynb
13 | - Customizing the flow: tutorials/customizing_example.ipynb
14 | - Modeling Variables with Periodic Topology: tutorials/spherical_flow_example.ipynb
15 | - Marginalizing Variables: tutorials/marginalization.ipynb
16 | - Convolving Non-Gaussian Errors: tutorials/nongaussian_errors.ipynb
17 | - Common gotchas: gotchas.md
18 | - API: API/
19 |
20 | theme:
21 | name: material
22 | palette:
23 | scheme: slate
24 | primary: pink
25 | accent: deep orange
26 | icon:
27 | logo: material/star-face
28 | features:
29 | - navigation.indexes
30 | plugins:
31 | - search
32 | - mkdocs-jupyter:
33 | theme: dark
34 | - gen-files:
35 | scripts:
36 | - docs/gen_ref_pages.py
37 | - section-index
38 | - literate-nav
39 | - mkdocstrings:
40 | handlers:
41 | python:
42 | options:
43 | docstring_style: numpy
44 | extra:
45 | version:
46 | provider: mike
47 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "pzflow"
3 | version = "3.2.0"
4 | description = "Probabilistic modeling of tabular data with normalizing flows."
5 | authors = ["John Franklin Crenshaw "]
6 | license = "MIT"
7 | repository = "https://github.com/jfcrenshaw/pzflow"
8 | readme = "README.md"
9 |
10 |
11 | [tool.poetry.dependencies]
12 | python = "^3.10"
13 | jax = ">=0.5.3"
14 | jaxlib = ">=0.5.3"
15 | optax = ">=0.2.4"
16 | pandas = ">=2.2.3"
17 | dill = ">=0.3.9"
18 | tqdm = ">=4.67.1"
19 |
20 | [tool.poetry.group.dev.dependencies]
21 | pre-commit = ">=4.2.0"
22 | black = ">=25.1.0"
23 | mypy = ">=1.15.0"
24 | isort = ">=6.0.1"
25 | flake8 = ">=7.2.0"
26 | flake8-bugbear = ">=24.12.12"
27 | flake8-builtins = ">=2.5.0"
28 | flake8-comprehensions = ">=3.16.0"
29 | flake8-docstrings = ">=1.7.0"
30 | flake8-isort = ">=6.1.2"
31 | flake8-markdown = ">=0.6.0"
32 | flake8-print = ">=5.0.0"
33 | flake8-pytest-style = ">=2.1.0"
34 | flake8-simplify = ">=0.21.0"
35 | flake8-tidy-imports = ">=4.11.0"
36 | pep8-naming = ">=0.14.1"
37 | pandas-vet = ">=2023.8.2"
38 | pytest = ">=8.3.5"
39 | pytest-cov = ">=6.0.0"
40 | pytest-xdist = ">=3.6.1"
41 | jupyter = ">=1.1.1"
42 | matplotlib = ">=3.10.1"
43 | toml = ">=0.10.2"
44 |
45 | [build-system]
46 | requires = ["poetry-core>=1.0.0"]
47 | build-backend = "poetry.core.masonry.api"
48 |
49 | [tool.black]
50 | line-length=79
51 |
52 | [tool.isort]
53 | multi_line_output = 3
54 | include_trailing_comma = true
55 | force_grid_wrap = 0
56 | use_parentheses = true
57 | line_length = 79
58 |
59 | [tool.pyright]
60 | reportGeneralTypeIssues = false
61 |
--------------------------------------------------------------------------------
/pzflow/__init__.py:
--------------------------------------------------------------------------------
1 | """Import modules and set version."""
2 |
3 | from pzflow.flow import Flow
4 | from pzflow.flowEnsemble import FlowEnsemble
5 |
6 | __version__ = "3.2.0"
7 |
--------------------------------------------------------------------------------
/pzflow/bijectors.py:
--------------------------------------------------------------------------------
1 | """Define the bijectors used in the normalizing flows."""
2 | from functools import update_wrapper
3 | from typing import Callable, Sequence, Tuple, Union
4 |
5 | import jax.numpy as jnp
6 | from jax import random
7 | from jax.nn import softmax, softplus
8 |
9 | from pzflow.utils import DenseReluNetwork, RationalQuadraticSpline
10 |
11 | # define a type alias for Jax Pytrees
12 | Pytree = Union[tuple, list]
13 | Bijector_Info = Tuple[str, tuple]
14 |
15 |
16 | class ForwardFunction:
17 | """Return the output and log_det of the forward bijection on the inputs.
18 |
19 | ForwardFunction of a Bijector, originally returned by the
20 | InitFunction of the Bijector.
21 |
22 | Parameters
23 | ----------
24 | params : a Jax pytree
25 | A pytree of bijector parameters.
26 | This usually looks like a nested tuple or list of parameters.
27 | inputs : jnp.ndarray
28 | The data to be transformed by the bijection.
29 |
30 | Returns
31 | -------
32 | outputs : jnp.ndarray
33 | Result of the forward bijection applied to the inputs.
34 | log_det : jnp.ndarray
35 | The log determinant of the Jacobian evaluated at the inputs.
36 | """
37 |
38 | def __init__(self, func: Callable) -> None:
39 | self._func = func
40 |
41 | def __call__(
42 | self, params: Pytree, inputs: jnp.ndarray, **kwargs
43 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
44 | return self._func(params, inputs, **kwargs)
45 |
46 |
47 | class InverseFunction:
48 | """Return the output and log_det of the inverse bijection on the inputs.
49 |
50 | InverseFunction of a Bijector, originally returned by the
51 | InitFunction of the Bijector.
52 |
53 | Parameters
54 | ----------
55 | params : a Jax pytree
56 | A pytree of bijector parameters.
57 | This usually looks like a nested tuple or list of parameters.
58 | inputs : jnp.ndarray
59 | The data to be transformed by the bijection.
60 |
61 | Returns
62 | -------
63 | outputs : jnp.ndarray
64 | Result of the inverse bijection applied to the inputs.
65 | log_det : jnp.ndarray
66 | The log determinant of the Jacobian evaluated at the inputs.
67 | """
68 |
69 | def __init__(self, func: Callable) -> None:
70 | self._func = func
71 |
72 | def __call__(
73 | self, params: Pytree, inputs: jnp.ndarray, **kwargs
74 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
75 | return self._func(params, inputs, **kwargs)
76 |
77 |
78 | class InitFunction:
79 | """Initialize the corresponding Bijector.
80 |
81 | InitFunction returned by the initialization of a Bijector.
82 |
83 | Parameters
84 | ----------
85 | rng : jnp.ndarray
86 | A Random Number Key from jax.random.PRNGKey.
87 | input_dim : int
88 | The input dimension of the bijection.
89 |
90 | Returns
91 | -------
92 | params : a Jax pytree
93 | A pytree of bijector parameters.
94 | This usually looks like a nested tuple or list of parameters.
95 | forward_fun : ForwardFunction
96 | The forward function of the Bijector.
97 | inverse_fun : InverseFunction
98 | The inverse function of the Bijector.
99 | """
100 |
101 | def __init__(self, func: Callable) -> None:
102 | self._func = func
103 |
104 | def __call__(
105 | self, rng: jnp.ndarray, input_dim: int, **kwargs
106 | ) -> Tuple[Pytree, ForwardFunction, InverseFunction]:
107 | return self._func(rng, input_dim, **kwargs)
108 |
109 |
110 | class Bijector:
111 | """Wrapper class for bijector functions"""
112 |
113 | def __init__(self, func: Callable) -> None:
114 | self._func = func
115 | update_wrapper(self, func)
116 |
117 | def __call__(self, *args, **kwargs) -> Tuple[InitFunction, Bijector_Info]:
118 | return self._func(*args, **kwargs)
119 |
120 |
121 | @Bijector
122 | def Chain(
123 | *inputs: Sequence[Tuple[InitFunction, Bijector_Info]]
124 | ) -> Tuple[InitFunction, Bijector_Info]:
125 | """Bijector that chains multiple InitFunctions into a single InitFunction.
126 |
127 | Parameters
128 | ----------
129 | inputs : (Bijector1(), Bijector2(), ...)
130 | A container of Bijector calls to be chained together.
131 |
132 | Returns
133 | -------
134 | InitFunction
135 | The InitFunction of the total chained Bijector.
136 | Bijector_Info
137 | Tuple('Chain', Tuple(Bijector_Info for each bijection in the chain))
138 | This allows the chain to be recreated later.
139 | """
140 |
141 | init_funs = tuple(i[0] for i in inputs)
142 | bijector_info = ("Chain", tuple(i[1] for i in inputs))
143 |
144 | @InitFunction
145 | def init_fun(rng, input_dim, **kwargs):
146 | all_params, forward_funs, inverse_funs = [], [], []
147 | for init_f in init_funs:
148 | rng, layer_rng = random.split(rng)
149 | param, forward_f, inverse_f = init_f(layer_rng, input_dim)
150 |
151 | all_params.append(param)
152 | forward_funs.append(forward_f)
153 | inverse_funs.append(inverse_f)
154 |
155 | def bijector_chain(params, bijectors, inputs, **kwargs):
156 | log_dets = jnp.zeros(inputs.shape[0])
157 | for bijector, param in zip(bijectors, params):
158 | inputs, log_det = bijector(param, inputs, **kwargs)
159 | log_dets += log_det
160 | return inputs, log_dets
161 |
162 | @ForwardFunction
163 | def forward_fun(params, inputs, **kwargs):
164 | return bijector_chain(params, forward_funs, inputs, **kwargs)
165 |
166 | @InverseFunction
167 | def inverse_fun(params, inputs, **kwargs):
168 | return bijector_chain(
169 | params[::-1], inverse_funs[::-1], inputs, **kwargs
170 | )
171 |
172 | return all_params, forward_fun, inverse_fun
173 |
174 | return init_fun, bijector_info
175 |
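# A minimal usage sketch (illustrative only, not part of the library source;
# `Reverse` and `Roll` are defined later in this module):
#
#     import jax.numpy as jnp
#     from jax import random
#
#     init_fun, info = Chain(Reverse(), Roll(1))
#     params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), input_dim=3)
#     x = jnp.arange(6.0).reshape(2, 3)
#     y, log_det = forward_fun(params, x)  # Reverse then Roll; log_dets summed
#     x_back, _ = inverse_fun(params, y)   # applies the inverses in reverse order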
176 |
177 | @Bijector
178 | def ColorTransform(
179 |     ref_idx: int, mag_idx: Sequence[int]
180 | ) -> Tuple[InitFunction, Bijector_Info]:
181 | """Bijector that calculates photometric colors from magnitudes.
182 |
183 |     Using ColorTransform constrains the order of columns in the
184 | corresponding normalizing flow. See the notes below for an example.
185 |
186 | Parameters
187 | ----------
188 | ref_idx : int
189 | The index corresponding to the column of the reference band, which
190 | serves as a proxy for overall luminosity.
191 | mag_idx : arraylike of int
192 | The indices of the magnitude columns from which colors will be calculated.
193 |
194 | Returns
195 | -------
196 | InitFunction
197 | The InitFunction of the ColorTransform Bijector.
198 | Bijector_Info
199 | Tuple of the Bijector name and the input parameters.
200 | This allows it to be recreated later.
201 |
202 | Notes
203 | -----
204 | ColorTransform requires careful management of column order in the bijector.
205 | This is best explained with an example:
206 |
207 | Assume we have data
208 | [redshift, u, g, ellipticity, r, i, z, y, mass]
209 | Then
210 | ColorTransform(ref_idx=4, mag_idx=[1, 2, 4, 5, 6, 7])
211 | will output
212 | [redshift, ellipticity, mass, r, u-g, g-r, r-i, i-z, z-y]
213 |
214 | Notice how the non-magnitude columns are aggregated at the front of the
215 | array, maintaining their relative order from the original array.
216 | These values are then followed by the reference magnitude, and the new colors.
217 |
218 |     Also notice that colors are calculated between consecutive entries
219 |     of mag_idx. E.g. mag_idx=[1, 2, 5, 4, 6, 7] would have produced
220 | the colors [u-g, g-i, i-r, r-z, z-y]. You can chain multiple ColorTransforms
221 | back-to-back to create colors in a non-adjacent manner.
222 | """
223 |
224 | # validate parameters
225 |     if not isinstance(ref_idx, int):
226 |         raise ValueError("ref_idx must be an integer.")
227 |     if ref_idx <= 0:
228 |         raise ValueError("ref_idx must be a positive integer.")
229 | if ref_idx not in mag_idx:
230 | raise ValueError("ref_idx must be in mag_idx.")
231 |
232 | bijector_info = ("ColorTransform", (ref_idx, mag_idx))
233 |
234 | # convert mag_idx to an array
235 | mag_idx = jnp.array(mag_idx)
236 |
237 | @InitFunction
238 | def init_fun(rng, input_dim, **kwargs):
239 | # array of all the indices
240 | all_idx = jnp.arange(input_dim)
241 | # indices for columns to stick at the front
242 | front_idx = jnp.setdiff1d(all_idx, mag_idx)
243 | # the index corresponding to the first magnitude
244 | mag0_idx = len(front_idx)
245 |
246 | # the new column order
247 | new_idx = jnp.concatenate((front_idx, mag_idx))
248 | # the new column for the reference magnitude
249 | new_ref = jnp.where(new_idx == ref_idx)[0][0]
250 |
251 | # define a convenience function for the forward_fun below
252 | # if the first magnitude is the reference mag, do nothing
253 | if ref_idx == mag_idx[0]:
254 |
255 | def mag0(outputs):
256 | return outputs
257 |
258 | # if the first magnitude is not the reference mag,
259 | # then we need to calculate the first magnitude (mag[0])
260 | else:
261 |
262 | def mag0(outputs):
263 | return outputs.at[:, mag0_idx].set(
264 | outputs[:, mag0_idx] + outputs[:, new_ref],
265 | indices_are_sorted=True,
266 | unique_indices=True,
267 | )
268 |
269 | @ForwardFunction
270 | def forward_fun(params, inputs, **kwargs):
271 | # re-order columns and calculate colors
272 | outputs = jnp.hstack(
273 | (
274 | inputs[:, front_idx], # other values
275 | inputs[:, ref_idx, None], # ref mag
276 | -jnp.diff(inputs[:, mag_idx]), # colors
277 | )
278 | )
279 | # determinant of Jacobian is zero
280 | log_det = jnp.zeros(inputs.shape[0])
281 | return outputs, log_det
282 |
283 | @InverseFunction
284 | def inverse_fun(params, inputs, **kwargs):
285 | # convert all colors to be in terms of the first magnitude, mag[0]
286 | outputs = jnp.hstack(
287 | (
288 | inputs[:, 0:mag0_idx], # other values unchanged
289 | inputs[:, mag0_idx, None], # reference mag unchanged
290 | jnp.cumsum(
291 | inputs[:, mag0_idx + 1 :], axis=-1
292 | ), # all colors mag[i-1] - mag[i] --> mag[0] - mag[i]
293 | )
294 | )
295 | # calculate mag[0]
296 | outputs = mag0(outputs)
297 | # mag[i] = mag[0] - (mag[0] - mag[i])
298 | outputs = outputs.at[:, mag0_idx + 1 :].set(
299 | outputs[:, mag0_idx, None] - outputs[:, mag0_idx + 1 :],
300 | indices_are_sorted=True,
301 | unique_indices=True,
302 | )
303 | # return to original ordering
304 | outputs = outputs[:, jnp.argsort(new_idx)]
305 |             # log determinant of the Jacobian is zero
306 | log_det = jnp.zeros(inputs.shape[0])
307 | return outputs, log_det
308 |
309 | return (), forward_fun, inverse_fun
310 |
311 | return init_fun, bijector_info
312 |
313 |
314 | @Bijector
315 | def InvSoftplus(
316 | column_idx: int, sharpness: float = 1
317 | ) -> Tuple[InitFunction, Bijector_Info]:
318 | """Bijector that applies inverse softplus to the specified column(s).
319 |
320 | Applying the inverse softplus ensures that samples from that column will
321 | always be non-negative. This is because samples are the output of the
322 | inverse bijection -- so samples will have a softplus applied to them.
323 |
324 | Parameters
325 | ----------
326 | column_idx : int
327 | An index or iterable of indices corresponding to the column(s)
328 | you wish to be transformed.
329 | sharpness : float; default=1
330 | The sharpness(es) of the softplus transformation. If more than one
331 | is provided, the list of sharpnesses must be of the same length as
332 | column_idx.
333 |
334 | Returns
335 | -------
336 | InitFunction
337 | The InitFunction of the Softplus Bijector.
338 | Bijector_Info
339 | Tuple of the Bijector name and the input parameters.
340 | This allows it to be recreated later.
341 | """
342 |
343 | idx = jnp.atleast_1d(column_idx)
344 | k = jnp.atleast_1d(sharpness)
345 | if len(idx) != len(k) and len(k) != 1:
346 | raise ValueError(
347 | "Please provide either a single sharpness or one for each column index."
348 | )
349 |
350 | bijector_info = ("InvSoftplus", (column_idx, sharpness))
351 |
352 | @InitFunction
353 | def init_fun(rng, input_dim, **kwargs):
354 | @ForwardFunction
355 | def forward_fun(params, inputs, **kwargs):
356 | outputs = inputs.at[:, idx].set(
357 | jnp.log(-1 + jnp.exp(k * inputs[:, idx])) / k,
358 | )
359 | log_det = jnp.log(1 + jnp.exp(-k * outputs[:, idx])).sum(axis=1)
360 | return outputs, log_det
361 |
362 | @InverseFunction
363 | def inverse_fun(params, inputs, **kwargs):
364 | outputs = inputs.at[:, idx].set(
365 | jnp.log(1 + jnp.exp(k * inputs[:, idx])) / k,
366 | )
367 | log_det = -jnp.log(1 + jnp.exp(-k * inputs[:, idx])).sum(axis=1)
368 | return outputs, log_det
369 |
370 | return (), forward_fun, inverse_fun
371 |
372 | return init_fun, bijector_info
373 |
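# For reference, the transformations above implement, per selected column:
#
#     forward:  y = log(exp(k * x) - 1) / k    (inverse softplus)
#     inverse:  x = log(1 + exp(k * y)) / k    (softplus)
#
# and log|dy/dx| = log(1 + exp(-k * y)), summed over the selected columns,
# which is exactly the log_det computed in forward_fun.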
374 |
375 | @Bijector
376 | def NeuralSplineCoupling(
377 | K: int = 16,
378 | B: float = 5,
379 | hidden_layers: int = 2,
380 | hidden_dim: int = 128,
381 | transformed_dim: int = None,
382 | n_conditions: int = 0,
383 | periodic: bool = False,
384 | ) -> Tuple[InitFunction, Bijector_Info]:
385 | """A coupling layer bijection with rational quadratic splines.
386 |
387 | This Bijector is a Coupling Layer [1,2], and as such only transforms
388 | the second half of input dimensions (or the last N dimensions, where
389 | N = transformed_dim). In order to transform all of the dimensions,
390 | you need multiple Couplings interspersed with Bijectors that change
391 |     the order of input dimensions, e.g., Reverse, Shuffle, Roll, etc.
392 |
393 | NeuralSplineCoupling uses piecewise rational quadratic splines,
394 | as developed in [3].
395 |
396 | If periodic=True, then this is a Circular Spline as described in [4].
397 |
398 | Parameters
399 | ----------
400 | K : int; default=16
401 | Number of bins in the spline (the number of knots is K+1).
402 | B : float; default=5
403 | Range of the splines.
404 | If periodic=False, outside of (-B,B), the transformation is just
405 | the identity. If periodic=True, the input is mapped into the
406 | appropriate location in the range (-B,B).
407 | hidden_layers : int; default=2
408 | The number of hidden layers in the neural network used to calculate
409 | the positions and derivatives of the spline knots.
410 | hidden_dim : int; default=128
411 | The width of the hidden layers in the neural network used to
412 | calculate the positions and derivatives of the spline knots.
413 | transformed_dim : int; optional
414 | The number of dimensions transformed by the splines.
415 |         Default is ceiling(input_dim / 2).
416 | n_conditions : int; default=0
417 | The number of variables to condition the bijection on.
418 | periodic : bool; default=False
419 | Whether to make this a periodic, Circular Spline [4].
420 |
421 | Returns
422 | -------
423 | InitFunction
424 | The InitFunction of the NeuralSplineCoupling Bijector.
425 | Bijector_Info
426 | Tuple of the Bijector name and the input parameters.
427 | This allows it to be recreated later.
428 |
429 | References
430 | ----------
431 | [1] Laurent Dinh, David Krueger, Yoshua Bengio. NICE: Non-linear
432 |         Independent Components Estimation. arXiv: 1410.8516, 2015.
433 |         http://arxiv.org/abs/1410.8516
434 | [2] Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio.
435 | Density Estimation Using Real NVP. arXiv: 1605.08803, 2017.
436 | http://arxiv.org/abs/1605.08803
437 | [3] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios.
438 | Neural Spline Flows. arXiv:1906.04032, 2019.
439 | https://arxiv.org/abs/1906.04032
440 | [4] Rezende, Danilo Jimenez et al.
441 | Normalizing Flows on Tori and Spheres. arxiv:2002.02428, 2020
442 | http://arxiv.org/abs/2002.02428
443 | """
444 |
445 | if not isinstance(periodic, bool):
446 | raise ValueError("`periodic` must be True or False.")
447 |
448 | bijector_info = (
449 | "NeuralSplineCoupling",
450 | (
451 | K,
452 | B,
453 | hidden_layers,
454 | hidden_dim,
455 | transformed_dim,
456 | n_conditions,
457 | periodic,
458 | ),
459 | )
460 |
461 | @InitFunction
462 | def init_fun(rng, input_dim, **kwargs):
463 | if transformed_dim is None:
464 | upper_dim = input_dim // 2 # variables that determine NN params
465 | lower_dim = (
466 | input_dim - upper_dim
467 | ) # variables transformed by the NN
468 | else:
469 | upper_dim = input_dim - transformed_dim
470 | lower_dim = transformed_dim
471 |
472 | # create the neural network that will take in the upper dimensions and
473 | # will return the spline parameters to transform the lower dimensions
474 | network_init_fun, network_apply_fun = DenseReluNetwork(
475 | (3 * K - 1 + int(periodic)) * lower_dim, hidden_layers, hidden_dim
476 | )
477 | _, network_params = network_init_fun(rng, (upper_dim + n_conditions,))
478 |
479 | # calculate spline parameters as a function of the upper variables
480 | def spline_params(params, upper, conditions):
481 | key = jnp.hstack((upper, conditions))[
482 | :, : upper_dim + n_conditions
483 | ]
484 | outputs = network_apply_fun(params, key)
485 | outputs = jnp.reshape(
486 | outputs, [-1, lower_dim, 3 * K - 1 + int(periodic)]
487 | )
488 | W, H, D = jnp.split(outputs, [K, 2 * K], axis=2)
489 | W = 2 * B * softmax(W)
490 | H = 2 * B * softmax(H)
491 | D = softplus(D)
492 | return W, H, D
493 |
494 | @ForwardFunction
495 | def forward_fun(params, inputs, conditions, **kwargs):
496 | # lower dimensions are transformed as function of upper dimensions
497 | upper, lower = inputs[:, :upper_dim], inputs[:, upper_dim:]
498 | # widths, heights, derivatives = function(upper dimensions)
499 | W, H, D = spline_params(params, upper, conditions)
500 | # transform the lower dimensions with the Rational Quadratic Spline
501 | lower, log_det = RationalQuadraticSpline(
502 | lower, W, H, D, B, periodic, inverse=False
503 | )
504 | outputs = jnp.hstack((upper, lower))
505 | return outputs, log_det
506 |
507 | @InverseFunction
508 | def inverse_fun(params, inputs, conditions, **kwargs):
509 | # lower dimensions are transformed as function of upper dimensions
510 | upper, lower = inputs[:, :upper_dim], inputs[:, upper_dim:]
511 | # widths, heights, derivatives = function(upper dimensions)
512 | W, H, D = spline_params(params, upper, conditions)
513 | # transform the lower dimensions with the Rational Quadratic Spline
514 | lower, log_det = RationalQuadraticSpline(
515 | lower, W, H, D, B, periodic, inverse=True
516 | )
517 | outputs = jnp.hstack((upper, lower))
518 | return outputs, log_det
519 |
520 | return network_params, forward_fun, inverse_fun
521 |
522 | return init_fun, bijector_info
523 |
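# A minimal usage sketch (illustrative only; note that the forward and inverse
# functions expect a `conditions` argument even when n_conditions=0 -- an
# empty array works):
#
#     import jax.numpy as jnp
#     from jax import random
#
#     init_fun, _ = NeuralSplineCoupling(K=4, hidden_dim=16)
#     params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), input_dim=2)
#     x = jnp.zeros((5, 2))
#     cond = jnp.zeros((5, 0))  # no conditional variables
#     y, log_det = forward_fun(params, x, conditions=cond)
#     x_back, _ = inverse_fun(params, y, conditions=cond)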
524 |
525 | @Bijector
526 | def Reverse() -> Tuple[InitFunction, Bijector_Info]:
527 | """Bijector that reverses the order of inputs.
528 |
529 | Returns
530 | -------
531 | InitFunction
532 |         The InitFunction of the Reverse Bijector.
533 | Bijector_Info
534 | Tuple of the Bijector name and the input parameters.
535 | This allows it to be recreated later.
536 | """
537 |
538 | bijector_info = ("Reverse", ())
539 |
540 | @InitFunction
541 | def init_fun(rng, input_dim, **kwargs):
542 | @ForwardFunction
543 | def forward_fun(params, inputs, **kwargs):
544 | outputs = inputs[:, ::-1]
545 | log_det = jnp.zeros(inputs.shape[0])
546 | return outputs, log_det
547 |
548 | @InverseFunction
549 | def inverse_fun(params, inputs, **kwargs):
550 | outputs = inputs[:, ::-1]
551 | log_det = jnp.zeros(inputs.shape[0])
552 | return outputs, log_det
553 |
554 | return (), forward_fun, inverse_fun
555 |
556 | return init_fun, bijector_info
557 |
558 |
559 | @Bijector
560 | def Roll(shift: int = 1) -> Tuple[InitFunction, Bijector_Info]:
561 |     """Bijector that rolls inputs along their last axis using jnp.roll.
562 |
563 | Parameters
564 | ----------
565 | shift : int; default=1
566 | The number of places to roll.
567 |
568 | Returns
569 | -------
570 | InitFunction
571 |         The InitFunction of the Roll Bijector.
572 | Bijector_Info
573 | Tuple of the Bijector name and the input parameters.
574 | This allows it to be recreated later.
575 | """
576 |
577 | if not isinstance(shift, int):
578 | raise ValueError("shift must be an integer.")
579 |
580 | bijector_info = ("Roll", (shift,))
581 |
582 | @InitFunction
583 | def init_fun(rng, input_dim, **kwargs):
584 | @ForwardFunction
585 | def forward_fun(params, inputs, **kwargs):
586 | outputs = jnp.roll(inputs, shift=shift, axis=-1)
587 | log_det = jnp.zeros(inputs.shape[0])
588 | return outputs, log_det
589 |
590 | @InverseFunction
591 | def inverse_fun(params, inputs, **kwargs):
592 | outputs = jnp.roll(inputs, shift=-shift, axis=-1)
593 | log_det = jnp.zeros(inputs.shape[0])
594 | return outputs, log_det
595 |
596 | return (), forward_fun, inverse_fun
597 |
598 | return init_fun, bijector_info
599 |
600 |
601 | @Bijector
602 | def RollingSplineCoupling(
603 | nlayers: int,
604 | shift: int = 1,
605 | K: int = 16,
606 | B: float = 5,
607 | hidden_layers: int = 2,
608 | hidden_dim: int = 128,
609 | transformed_dim: int = None,
610 | n_conditions: int = 0,
611 | periodic: bool = False,
612 | ) -> Tuple[InitFunction, Bijector_Info]:
613 | """Bijector that alternates NeuralSplineCouplings and Roll bijections.
614 |
615 | Parameters
616 | ----------
617 | nlayers : int
618 | The number of (NeuralSplineCoupling(), Roll()) couplets in the chain.
619 | shift : int
620 | How far the inputs are shifted on each Roll().
621 | K : int; default=16
622 | Number of bins in the RollingSplineCoupling.
623 | B : float; default=5
624 | Range of the splines in the RollingSplineCoupling.
625 | If periodic=False, outside of (-B,B), the transformation is just
626 | the identity. If periodic=True, the input is mapped into the
627 | appropriate location in the range (-B,B).
628 | hidden_layers : int; default=2
629 | The number of hidden layers in the neural network used to calculate
630 | the bins and derivatives in the RollingSplineCoupling.
631 | hidden_dim : int; default=128
632 | The width of the hidden layers in the neural network used to
633 | calculate the bins and derivatives in the RollingSplineCoupling.
634 | transformed_dim : int; optional
635 | The number of dimensions transformed by the splines.
636 |         Default is ceiling(input_dim / 2).
637 | n_conditions : int; default=0
638 | The number of variables to condition the bijection on.
639 | periodic : bool; default=False
640 | Whether to make this a periodic, Circular Spline
641 |
642 | Returns
643 | -------
644 | InitFunction
645 | The InitFunction of the RollingSplineCoupling Bijector.
646 | Bijector_Info
647 | Nested tuple of the Bijector name and input parameters. This allows
648 | it to be recreated later.
649 |
650 | """
651 | return Chain(
652 | *(
653 | NeuralSplineCoupling(
654 | K=K,
655 | B=B,
656 | hidden_layers=hidden_layers,
657 | hidden_dim=hidden_dim,
658 | transformed_dim=transformed_dim,
659 | n_conditions=n_conditions,
660 | periodic=periodic,
661 | ),
662 | Roll(shift),
663 | )
664 | * nlayers
665 | )
666 |
667 |
668 | @Bijector
669 | def Scale(scale: float) -> Tuple[InitFunction, Bijector_Info]:
670 | """Bijector that multiplies inputs by a scalar.
671 |
672 | Parameters
673 | ----------
674 | scale : float
675 | Factor by which to scale inputs.
676 |
677 | Returns
678 | -------
679 | InitFunction
680 |         The InitFunction of the Scale Bijector.
681 | Bijector_Info
682 | Tuple of the Bijector name and the input parameters.
683 | This allows it to be recreated later.
684 | """
685 |
686 | if isinstance(scale, jnp.ndarray):
687 | if scale.dtype != jnp.float32:
688 | raise ValueError("scale must be a float or array of floats.")
689 | elif not isinstance(scale, float):
690 | raise ValueError("scale must be a float or array of floats.")
691 |
692 | bijector_info = ("Scale", (scale,))
693 |
694 | @InitFunction
695 | def init_fun(rng, input_dim, **kwargs):
696 | @ForwardFunction
697 | def forward_fun(params, inputs, **kwargs):
698 | outputs = scale * inputs
699 | log_det = jnp.log(scale ** inputs.shape[-1]) * jnp.ones(
700 | inputs.shape[0]
701 | )
702 | return outputs, log_det
703 |
704 | @InverseFunction
705 | def inverse_fun(params, inputs, **kwargs):
706 | outputs = 1 / scale * inputs
707 | log_det = -jnp.log(scale ** inputs.shape[-1]) * jnp.ones(
708 | inputs.shape[0]
709 | )
710 | return outputs, log_det
711 |
712 | return (), forward_fun, inverse_fun
713 |
714 | return init_fun, bijector_info
715 |
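# Worked log-determinant for reference: for y = s * x with x in R^D, the
# Jacobian is s * I, so log|det J| = D * log(s) = log(s ** D), which is what
# forward_fun computes via jnp.log(scale ** inputs.shape[-1]).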
716 |
717 | @Bijector
718 | def ShiftBounds(
719 | min: float, max: float, B: float = 5
720 | ) -> Tuple[InitFunction, Bijector_Info]:
721 |     """Bijector that shifts the bounds of inputs so they lie in the range (-B, B).
722 |
723 | Parameters
724 | ----------
725 | min : float
726 | The minimum of the input range.
727 |     max : float
728 | The maximum of the input range.
729 | B : float; default=5
730 | The extent of the output bounds, which will be (-B, B).
731 |
732 | Returns
733 | -------
734 | InitFunction
735 | The InitFunction of the ShiftBounds Bijector.
736 | Bijector_Info
737 | Tuple of the Bijector name and the input parameters.
738 | This allows it to be recreated later.
739 | """
740 |
741 | min = jnp.atleast_1d(min)
742 | max = jnp.atleast_1d(max)
743 | if len(min) != len(max):
744 | raise ValueError(
745 | "Lengths of min and max do not match. "
746 | + "Please provide either a single min and max, "
747 | + "or a min and max for each dimension."
748 | )
749 | if (min > max).any():
750 | raise ValueError("All mins must be less than maxes.")
751 |
752 | bijector_info = ("ShiftBounds", (min, max, B))
753 |
754 | mean = (max + min) / 2
755 | half_range = (max - min) / 2
756 |
757 | @InitFunction
758 | def init_fun(rng, input_dim, **kwargs):
759 | @ForwardFunction
760 | def forward_fun(params, inputs, **kwargs):
761 | outputs = B * (inputs - mean) / half_range
762 | log_det = jnp.log(jnp.prod(B / half_range)) * jnp.ones(
763 | inputs.shape[0]
764 | )
765 | return outputs, log_det
766 |
767 | @InverseFunction
768 | def inverse_fun(params, inputs, **kwargs):
769 | outputs = inputs * half_range / B + mean
770 | log_det = jnp.log(jnp.prod(half_range / B)) * jnp.ones(
771 | inputs.shape[0]
772 | )
773 | return outputs, log_det
774 |
775 | return (), forward_fun, inverse_fun
776 |
777 | return init_fun, bijector_info
778 |
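# Worked example for reference: with min=0, max=1, B=5, the forward bijection
# maps x -> 5 * (x - 0.5) / 0.5, so 0 -> -5, 0.5 -> 0, and 1 -> +5; the
# constant per-sample log-determinant is log(B / half_range) = log(10).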
779 |
780 | @Bijector
781 | def Shuffle() -> Tuple[InitFunction, Bijector_Info]:
782 | """Bijector that randomly permutes inputs.
783 |
784 | Returns
785 | -------
786 | InitFunction
787 | The InitFunction of the Shuffle Bijector.
788 | Bijector_Info
789 | Tuple of the Bijector name and the input parameters.
790 | This allows it to be recreated later.
791 | """
792 |
793 | bijector_info = ("Shuffle", ())
794 |
795 | @InitFunction
796 | def init_fun(rng, input_dim, **kwargs):
797 | perm = random.permutation(rng, jnp.arange(input_dim))
798 | inv_perm = jnp.argsort(perm)
799 |
800 | @ForwardFunction
801 | def forward_fun(params, inputs, **kwargs):
802 | outputs = inputs[:, perm]
803 | log_det = jnp.zeros(inputs.shape[0])
804 | return outputs, log_det
805 |
806 | @InverseFunction
807 | def inverse_fun(params, inputs, **kwargs):
808 | outputs = inputs[:, inv_perm]
809 | log_det = jnp.zeros(inputs.shape[0])
810 | return outputs, log_det
811 |
812 | return (), forward_fun, inverse_fun
813 |
814 | return init_fun, bijector_info
815 |
816 |
817 | @Bijector
818 | def StandardScaler(
819 |     means: jnp.ndarray, stds: jnp.ndarray
820 | ) -> Tuple[InitFunction, Bijector_Info]:
821 | """Bijector that applies standard scaling to each input.
822 |
823 | Each input dimension i has an associated mean u_i and standard dev s_i.
824 | Each input is rescaled as (input[i] - u_i)/s_i, so that each input dimension
825 | has mean zero and unit variance.
826 |
827 | Parameters
828 | ----------
829 | means : jnp.ndarray
830 | The mean of each column.
831 | stds : jnp.ndarray
832 | The standard deviation of each column.
833 |
834 | Returns
835 | -------
836 | InitFunction
837 | The InitFunction of the StandardScaler Bijector.
838 | Bijector_Info
839 | Tuple of the Bijector name and the input parameters.
840 | This allows it to be recreated later.
841 | """
842 |
843 | bijector_info = ("StandardScaler", (means, stds))
844 |
845 | @InitFunction
846 | def init_fun(rng, input_dim, **kwargs):
847 | @ForwardFunction
848 | def forward_fun(params, inputs, **kwargs):
849 | outputs = (inputs - means) / stds
850 | log_det = jnp.log(1 / jnp.prod(stds)) * jnp.ones(inputs.shape[0])
851 | return outputs, log_det
852 |
853 | @InverseFunction
854 | def inverse_fun(params, inputs, **kwargs):
855 | outputs = inputs * stds + means
856 | log_det = jnp.log(jnp.prod(stds)) * jnp.ones(inputs.shape[0])
857 | return outputs, log_det
858 |
859 | return (), forward_fun, inverse_fun
860 |
861 | return init_fun, bijector_info
862 |
863 |
864 | @Bijector
865 | def UniformDequantizer(column_idx: int) -> Tuple[InitFunction, Bijector_Info]:
866 | """Bijector that dequantizes discrete variables with uniform noise.
867 |
868 | Dequantizers are necessary for modeling discrete values with a flow.
869 | Note that this isn't technically a bijector.
870 |
871 | Parameters
872 | ----------
873 | column_idx : int
874 | An index or iterable of indices corresponding to the column(s) with
875 | discrete values.
876 |
877 | Returns
878 | -------
879 | InitFunction
880 | The InitFunction of the UniformDequantizer Bijector.
881 | Bijector_Info
882 | Tuple of the Bijector name and the input parameters.
883 | This allows it to be recreated later.
884 | """
885 |
886 | bijector_info = ("UniformDequantizer", (column_idx,))
887 | column_idx = jnp.array(column_idx)
888 |
889 | @InitFunction
890 | def init_fun(rng, input_dim, **kwargs):
891 | @ForwardFunction
892 | def forward_fun(params, inputs, **kwargs):
893 | u = random.uniform(
894 | random.PRNGKey(0), shape=inputs[:, column_idx].shape
895 | )
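            # NOTE: random.PRNGKey(0) is a fixed key, so the same
            # dequantization noise is drawn on every call for inputs
            # of the same shape.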
896 | outputs = inputs.astype(float)
897 |             outputs = outputs.at[:, column_idx].set(outputs[:, column_idx] + u)
898 | log_det = jnp.zeros(inputs.shape[0])
899 | return outputs, log_det
900 |
901 | @InverseFunction
902 | def inverse_fun(params, inputs, **kwargs):
903 | outputs = inputs.at[:, column_idx].set(
904 | jnp.floor(inputs[:, column_idx])
905 | )
906 | log_det = jnp.zeros(inputs.shape[0])
907 | return outputs, log_det
908 |
909 | return (), forward_fun, inverse_fun
910 |
911 | return init_fun, bijector_info
912 |
--------------------------------------------------------------------------------
/pzflow/distributions.py:
--------------------------------------------------------------------------------
1 | """Define the latent distributions used in the normalizing flows."""
2 | import sys
3 | from abc import ABC, abstractmethod
4 | from typing import Union
5 |
6 | import jax.numpy as jnp
7 | import numpy as np
8 | from jax import random
9 | from jax.scipy.special import gammaln
10 | from jax.scipy.stats import beta, multivariate_normal
11 |
12 | from pzflow.bijectors import Pytree
13 |
14 | epsilon = sys.float_info.epsilon
15 |
16 |
17 | class LatentDist(ABC):
18 | """Base class for latent distributions."""
19 |
20 | info = ("LatentDist", ())
21 |
22 | @abstractmethod
23 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
24 | """Calculate log-probability of the inputs."""
25 |
26 | @abstractmethod
27 | def sample(
28 | self, params: Pytree, nsamples: int, seed: int = None
29 | ) -> jnp.ndarray:
30 | """Sample from the distribution."""
31 |
32 |
33 | def _mahalanobis_and_logdet(x: jnp.ndarray, cov: jnp.ndarray) -> tuple:
34 | # Calculate mahalanobis distance and log_det of cov.
35 | # Uses scipy method, explained here:
36 | # http://gregorygundersen.com/blog/2019/10/30/scipy-multivariate/
37 |
38 | vals, vecs = jnp.linalg.eigh(cov)
39 | U = vecs * jnp.sqrt(1 / vals[..., None])
40 | maha = jnp.square(U @ x[..., None]).reshape(x.shape[0], -1).sum(axis=1)
41 | log_det = jnp.log(vals).sum(axis=-1)
42 | return maha, log_det
43 |
44 |
45 | class CentBeta(LatentDist):
46 | """A centered Beta distribution.
47 |
48 | This distribution is just a regular Beta distribution, scaled and shifted
49 | to have support on the domain [-B, B] in each dimension.
50 |
51 | Alpha and beta parameters for each dimension are learned during training.
52 | """
53 |
54 | def __init__(self, input_dim: int, B: float = 5) -> None:
55 | """
56 | Parameters
57 | ----------
58 | input_dim : int
59 | The dimension of the distribution.
60 | B : float; default=5
61 | The distribution has support (-B, B) along each dimension.
62 | """
63 | self.input_dim = input_dim
64 | self.B = B
65 |
66 | # save dist info
67 | self._params = tuple([(0.0, 0.0) for i in range(input_dim)])
68 | self.info = ("CentBeta", (input_dim, B))
69 |
70 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
71 | """Calculates log probability density of inputs.
72 |
73 | Parameters
74 | ----------
75 | params : a Jax pytree
76 | Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta)
77 | for the Nth dimension.
78 | inputs : jnp.ndarray
79 | Input data for which log probability density is calculated.
80 |
81 | Returns
82 | -------
83 | jnp.ndarray
84 | Device array of shape (inputs.shape[0],).
85 | """
86 | log_prob = jnp.hstack(
87 | [
88 | beta.logpdf(
89 | inputs[:, i],
90 | a=jnp.exp(params[i][0]),
91 | b=jnp.exp(params[i][1]),
92 | loc=-self.B,
93 | scale=2 * self.B,
94 | ).reshape(-1, 1)
95 | for i in range(self.input_dim)
96 | ]
97 | ).sum(axis=1)
98 |
99 | return log_prob
100 |
101 | def sample(
102 | self, params: Pytree, nsamples: int, seed: int = None
103 | ) -> jnp.ndarray:
104 | """Returns samples from the distribution.
105 |
106 | Parameters
107 | ----------
108 | params : a Jax pytree
109 | Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta)
110 | for the Nth dimension.
111 | nsamples : int
112 | The number of samples to be returned.
113 | seed : int; optional
114 | Sets the random seed for the samples.
115 |
116 | Returns
117 | -------
118 | jnp.ndarray
119 | Device array of shape (nsamples, self.input_dim).
120 | """
121 | seed = np.random.randint(1e18) if seed is None else seed
122 | seeds = random.split(random.PRNGKey(seed), self.input_dim)
123 | samples = jnp.hstack(
124 | [
125 | random.beta(
126 | seeds[i],
127 | jnp.exp(params[i][0]),
128 | jnp.exp(params[i][1]),
129 | shape=(nsamples, 1),
130 | )
131 | for i in range(self.input_dim)
132 | ]
133 | )
134 | return 2 * self.B * (samples - 0.5)
135 |
136 |
137 | class CentBeta13(LatentDist):
138 | """A centered Beta distribution with alpha, beta = 13.
139 |
140 | This distribution is just a regular Beta distribution, scaled and shifted
141 | to have support on the domain [-B, B] in each dimension.
142 |
143 | Alpha, beta = 13 means that the distribution looks like a Gaussian
144 | distribution, but with hard cutoffs at +/- B.
145 | """
146 |
147 | def __init__(self, input_dim: int, B: float = 5) -> None:
148 | """
149 | Parameters
150 | ----------
151 | input_dim : int
152 | The dimension of the distribution.
153 | B : float; default=5
154 | The distribution has support (-B, B) along each dimension.
155 | """
156 | self.input_dim = input_dim
157 | self.B = B
158 |
159 | # save dist info
160 | self._params = tuple([(0.0, 0.0) for i in range(input_dim)])
161 | self.info = ("CentBeta13", (input_dim, B))
162 | self.a = 13
163 | self.b = 13
164 |
165 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
166 | """Calculates log probability density of inputs.
167 |
168 | Parameters
169 | ----------
170 | params : a Jax pytree
171 | Empty pytree -- this distribution doesn't have learnable parameters.
172 | This parameter is present to ensure a consistent interface.
173 | inputs : jnp.ndarray
174 | Input data for which log probability density is calculated.
175 |
176 | Returns
177 | -------
178 | jnp.ndarray
179 | Device array of shape (inputs.shape[0],).
180 | """
181 | log_prob = jnp.hstack(
182 | [
183 | beta.logpdf(
184 | inputs[:, i],
185 | a=self.a,
186 | b=self.b,
187 | loc=-self.B,
188 | scale=2 * self.B,
189 | ).reshape(-1, 1)
190 | for i in range(self.input_dim)
191 | ]
192 | ).sum(axis=1)
193 |
194 | return log_prob
195 |
196 | def sample(
197 | self, params: Pytree, nsamples: int, seed: int = None
198 | ) -> jnp.ndarray:
199 | """Returns samples from the distribution.
200 |
201 | Parameters
202 | ----------
203 | params : a Jax pytree
204 | Empty pytree -- this distribution doesn't have learnable parameters.
205 | This parameter is present to ensure a consistent interface.
206 | nsamples : int
207 | The number of samples to be returned.
208 | seed : int; optional
209 | Sets the random seed for the samples.
210 |
211 | Returns
212 | -------
213 | jnp.ndarray
214 | Device array of shape (nsamples, self.input_dim).
215 | """
216 | seed = np.random.randint(1e18) if seed is None else seed
217 | seeds = random.split(random.PRNGKey(seed), self.input_dim)
218 | samples = jnp.hstack(
219 | [
220 | random.beta(
221 | seeds[i],
222 | self.a,
223 | self.b,
224 | shape=(nsamples, 1),
225 | )
226 | for i in range(self.input_dim)
227 | ]
228 | )
229 | return 2 * self.B * (samples - 0.5)
230 |
231 |
232 | class Normal(LatentDist):
233 | """A multivariate Gaussian distribution with mean zero and unit variance.
234 |
235 | Note this distribution has infinite support, so it is not recommended that
236 | you use it with the spline coupling layers, which have compact support.
237 | If you do use the two together, you should set the support of the spline
238 | layers (using the spline parameter B) to be large enough that you rarely
239 | draw Gaussian samples outside the support of the splines.
240 | """
241 |
242 | def __init__(self, input_dim: int) -> None:
243 | """
244 | Parameters
245 | ----------
246 | input_dim : int
247 | The dimension of the distribution.
248 | """
249 | self.input_dim = input_dim
250 |
251 | # save dist info
252 | self._params = ()
253 | self.info = ("Normal", (input_dim,))
254 |
255 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
256 | """Calculates log probability density of inputs.
257 |
258 | Parameters
259 | ----------
260 | params : a Jax pytree
261 | Empty pytree -- this distribution doesn't have learnable parameters.
262 | This parameter is present to ensure a consistent interface.
263 | inputs : jnp.ndarray
264 | Input data for which log probability density is calculated.
265 |
266 | Returns
267 | -------
268 | jnp.ndarray
269 | Device array of shape (inputs.shape[0],).
270 | """
271 | return multivariate_normal.logpdf(
272 | inputs,
273 | mean=jnp.zeros(self.input_dim),
274 | cov=jnp.identity(self.input_dim),
275 | )
276 |
277 | def sample(
278 | self, params: Pytree, nsamples: int, seed: int = None
279 | ) -> jnp.ndarray:
280 | """Returns samples from the distribution.
281 |
282 | Parameters
283 | ----------
284 | params : a Jax pytree
285 | Empty pytree -- this distribution doesn't have learnable parameters.
286 | This parameter is present to ensure a consistent interface.
287 | nsamples : int
288 | The number of samples to be returned.
289 | seed : int; optional
290 | Sets the random seed for the samples.
291 |
292 | Returns
293 | -------
294 | jnp.ndarray
295 | Device array of shape (nsamples, self.input_dim).
296 | """
297 | seed = np.random.randint(1e18) if seed is None else seed
298 | return random.multivariate_normal(
299 | key=random.PRNGKey(seed),
300 | mean=jnp.zeros(self.input_dim),
301 | cov=jnp.identity(self.input_dim),
302 | shape=(nsamples,),
303 | )
304 |
305 |
306 | class Tdist(LatentDist):
307 | """A multivariate T distribution with mean zero and unit scale matrix.
308 |
309 | The number of degrees of freedom (i.e. the weight of the tails) is learned
310 | during training.
311 |
312 | Note this distribution has infinite support and potentially large tails,
313 | so it is not recommended to use this distribution with the spline coupling
314 | layers, which have compact support.
315 | """
316 |
317 | def __init__(self, input_dim: int) -> None:
318 | """
319 | Parameters
320 | ----------
321 | input_dim : int
322 | The dimension of the distribution.
323 | """
324 | self.input_dim = input_dim
325 |
326 | # save dist info
327 | self._params = jnp.log(30.0)
328 | self.info = ("Tdist", (input_dim,))
329 |
330 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
331 | """Calculates log probability density of inputs.
332 |
333 | Uses method explained here:
334 | http://gregorygundersen.com/blog/2020/01/20/multivariate-t/
335 |
336 | Parameters
337 | ----------
338 | params : float
339 | The degrees of freedom (nu) of the t-distribution.
340 | inputs : jnp.ndarray
341 | Input data for which log probability density is calculated.
342 |
343 | Returns
344 | -------
345 | jnp.ndarray
346 | Device array of shape (inputs.shape[0],).
347 | """
348 | cov = jnp.identity(self.input_dim)
349 | nu = jnp.exp(params)
350 | maha, log_det = _mahalanobis_and_logdet(inputs, cov)
351 | t = 0.5 * (nu + self.input_dim)
352 | A = gammaln(t)
353 | B = gammaln(0.5 * nu)
354 | C = self.input_dim / 2.0 * jnp.log(nu * jnp.pi)
355 | D = 0.5 * log_det
356 | E = -t * jnp.log(1 + (1.0 / nu) * maha)
357 |
358 | return A - B - C - D + E
359 |
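    # For reference, the terms A-E above assemble the standard multivariate-t
    # log-density with identity scale matrix:
    #
    #   log p(x) = log Gamma((nu + d) / 2) - log Gamma(nu / 2)
    #              - (d / 2) * log(nu * pi) - (1 / 2) * log|cov|
    #              - ((nu + d) / 2) * log(1 + maha / nu)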
360 | def sample(
361 | self, params: Pytree, nsamples: int, seed: int = None
362 | ) -> jnp.ndarray:
363 | """Returns samples from the distribution.
364 |
365 | Parameters
366 | ----------
367 | params : float
368 | The degrees of freedom (nu) of the t-distribution.
369 | nsamples : int
370 | The number of samples to be returned.
371 | seed : int; optional
372 | Sets the random seed for the samples.
373 |
374 | Returns
375 | -------
376 | jnp.ndarray
377 | Device array of shape (nsamples, self.input_dim).
378 | """
379 | mean = jnp.zeros(self.input_dim)
380 | nu = jnp.exp(params)
381 |
382 | seed = np.random.randint(1e18) if seed is None else seed
383 | rng = np.random.default_rng(int(seed))
384 | x = jnp.array(rng.chisquare(nu, nsamples) / nu)
385 | z = random.multivariate_normal(
386 | key=random.PRNGKey(seed),
387 | mean=jnp.zeros(self.input_dim),
388 | cov=jnp.identity(self.input_dim),
389 | shape=(nsamples,),
390 | )
391 | samples = mean + z / jnp.sqrt(x)[:, None]
392 | return samples
393 |
394 |
395 | class Uniform(LatentDist):
396 | """A multivariate uniform distribution with support [-B, B]."""
397 |
398 | def __init__(self, input_dim: int, B: float = 5) -> None:
399 | """
400 | Parameters
401 | ----------
402 | input_dim : int
403 | The dimension of the distribution.
404 | B : float; default=5
405 | The distribution has support (-B, B) along each dimension.
406 | """
407 | self.input_dim = input_dim
408 | self.B = B
409 |
410 | # save dist info
411 | self._params = ()
412 | self.info = ("Uniform", (input_dim, B))
413 |
414 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
415 | """Calculates log probability density of inputs.
416 |
417 | Parameters
418 | ----------
419 | params : Jax Pytree
420 | Empty pytree -- this distribution doesn't have learnable parameters.
421 | This parameter is present to ensure a consistent interface.
422 | inputs : jnp.ndarray
423 | Input data for which log probability density is calculated.
424 |
425 | Returns
426 | -------
427 | jnp.ndarray
428 | Device array of shape (inputs.shape[0],).
429 | """
430 |
431 | # which inputs are inside the support of the distribution
432 | mask = jnp.prod((inputs >= -self.B) & (inputs <= self.B), axis=-1)
433 |
434 | # calculate log_prob
435 | log_prob = jnp.where(
436 | mask,
437 | -self.input_dim * jnp.log(2 * self.B),
438 | -jnp.inf,
439 | )
440 |
441 | return log_prob
442 |
443 | def sample(
444 | self, params: Pytree, nsamples: int, seed: int = None
445 | ) -> jnp.ndarray:
446 | """Returns samples from the distribution.
447 |
448 | Parameters
449 | ----------
450 | params : a Jax pytree
451 | Empty pytree -- this distribution doesn't have learnable parameters.
452 | This parameter is present to ensure a consistent interface.
453 | nsamples : int
454 | The number of samples to be returned.
455 | seed : int; optional
456 | Sets the random seed for the samples.
457 |
458 | Returns
459 | -------
460 | jnp.ndarray
461 | Device array of shape (nsamples, self.input_dim).
462 | """
463 | seed = np.random.randint(1e18) if seed is None else seed
464 | samples = random.uniform(
465 | random.PRNGKey(seed),
466 | shape=(nsamples, self.input_dim),
467 | minval=-self.B,
468 | maxval=self.B,
469 | )
470 | return jnp.array(samples)
471 |
472 |
473 | class Joint(LatentDist):
474 | """A joint distribution built from other distributions.
475 |
476 | Note that each of the other distributions already have support for
477 | multiple dimensions. This is only useful if you want to combine
478 | different distributions for different dimensions, e.g. if your first
479 | dimension has a Uniform latent space and the second dimension has a
480 | CentBeta latent space.
481 | """
482 |
483 | def __init__(self, *inputs: Union[LatentDist, tuple]) -> None:
484 | """
485 | Parameters
486 | ----------
487 | inputs: LatentDist or tuple
488 | The latent distributions to join together.
489 | """
490 |
491 | # if Joint info provided, use that for setup
492 | if inputs[0] == "Joint info":
493 | self.dists = [globals()[dist[0]](*dist[1]) for dist in inputs[1]]
494 | # otherwise, assume it's a list of distributions
495 | else:
496 | self.dists = inputs
497 |
498 | # save info
499 | self._params = [dist._params for dist in self.dists]
500 | self.input_dim = sum([dist.input_dim for dist in self.dists])
501 | self.info = (
502 | "Joint",
503 | ("Joint info", [dist.info for dist in self.dists]),
504 | )
505 |
506 | # save the indices at which inputs will be split for log_prob
507 | # they must be concretely saved ahead-of-time so that jax trace
508 | # works properly when jitting
509 | self._splits = jnp.cumsum(
510 | jnp.array([dist.input_dim for dist in self.dists])
511 | )[:-1]
512 |
513 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
514 | """Calculates log probability density of inputs.
515 |
516 | Parameters
517 | ----------
518 | params : Jax Pytree
519 | Parameters for the distributions.
520 | inputs : jnp.ndarray
521 | Input data for which log probability density is calculated.
522 |
523 | Returns
524 | -------
525 | jnp.ndarray
526 | Device array of shape (inputs.shape[0],).
527 | """
528 |
529 | # split inputs for corresponding distribution
530 | inputs = jnp.split(inputs, self._splits, axis=1)
531 |
532 | # calculate log_prob with respect to each sub-distribution,
533 | # then sum all the log_probs for each input
534 | log_prob = jnp.hstack(
535 | [
536 | self.dists[i].log_prob(params[i], inputs[i]).reshape(-1, 1)
537 | for i in range(len(self.dists))
538 | ]
539 | ).sum(axis=1)
540 |
541 | return log_prob
542 |
543 | def sample(
544 | self, params: Pytree, nsamples: int, seed: int = None
545 | ) -> jnp.ndarray:
546 | """Returns samples from the distribution.
547 |
548 | Parameters
549 | ----------
550 | params : a Jax pytree
551 | Parameters for the distributions.
552 | nsamples : int
553 | The number of samples to be returned.
554 | seed : int; optional
555 | Sets the random seed for the samples.
556 |
557 | Returns
558 | -------
559 | jnp.ndarray
560 | Device array of shape (nsamples, self.input_dim).
561 | """
562 |
563 | seed = np.random.randint(1e18) if seed is None else seed
564 | seeds = random.randint(
565 | random.PRNGKey(seed), (len(self.dists),), 0, int(1e9)
566 | )
567 | samples = jnp.hstack(
568 | [
569 | self.dists[i]
570 | .sample(params[i], nsamples, seeds[i])
571 | .reshape(nsamples, -1)
572 | for i in range(len(self.dists))
573 | ]
574 | )
575 |
576 | return samples
577 |
--------------------------------------------------------------------------------
/pzflow/example_files/checkerboard-data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfcrenshaw/pzflow/b3fb837cff6d758a28063c4cd82489d9607ab086/pzflow/example_files/checkerboard-data.pkl
--------------------------------------------------------------------------------
/pzflow/example_files/city-data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfcrenshaw/pzflow/b3fb837cff6d758a28063c4cd82489d9607ab086/pzflow/example_files/city-data.pkl
--------------------------------------------------------------------------------
/pzflow/example_files/example-flow.pzflow.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfcrenshaw/pzflow/b3fb837cff6d758a28063c4cd82489d9607ab086/pzflow/example_files/example-flow.pzflow.pkl
--------------------------------------------------------------------------------
/pzflow/example_files/galaxy-data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfcrenshaw/pzflow/b3fb837cff6d758a28063c4cd82489d9607ab086/pzflow/example_files/galaxy-data.pkl
--------------------------------------------------------------------------------
/pzflow/example_files/two-moons-data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfcrenshaw/pzflow/b3fb837cff6d758a28063c4cd82489d9607ab086/pzflow/example_files/two-moons-data.pkl
--------------------------------------------------------------------------------
/pzflow/examples.py:
--------------------------------------------------------------------------------
1 | """Functions that return example data and a example flow trained on
2 | galaxy data. To see these examples in action, see the tutorial notebooks.
3 | """
4 |
5 | import os
6 |
7 | import pandas as pd
8 |
9 | from pzflow import Flow
10 |
11 | EXAMPLE_FILE_DIR = "example_files"
12 |
13 |
14 | def _load_example_data(name: str) -> pd.DataFrame:
15 | this_dir, _ = os.path.split(__file__)
16 | data_path = os.path.join(this_dir, f"{EXAMPLE_FILE_DIR}/{name}.pkl")
17 | data = pd.read_pickle(data_path)
18 | return data
19 |
20 |
21 | def get_twomoons_data() -> pd.DataFrame:
22 | """Return DataFrame with two moons example data.
23 |
24 | Two moons data originally from scikit-learn,
25 | i.e., `sklearn.datasets.make_moons`.
26 | """
27 | return _load_example_data("two-moons-data")
28 |
29 |
30 | def get_galaxy_data() -> pd.DataFrame:
31 | """Return DataFrame with example galaxy data.
32 |
33 | 100,000 galaxies from the Buzzard simulation [1], with redshifts
34 | in the range (0,2.3) and photometry in the LSST ugrizy bands.
35 |
36 | References
37 | ----------
38 | [1] Joseph DeRose et al. The Buzzard Flock: Dark Energy Survey
39 | Synthetic Sky Catalogs. arXiv:1901.02401, 2019.
40 | https://arxiv.org/abs/1901.02401
41 | """
42 | return _load_example_data("galaxy-data")
43 |
44 |
45 | def get_checkerboard_data() -> pd.DataFrame:
46 | """Return DataFrame with discrete checkerboard data."""
47 | return _load_example_data("checkerboard-data")
48 |
49 |
50 | def get_city_data() -> pd.DataFrame:
51 | """Return DataFrame with example city data.
52 |
53 |     The countries, names, populations, and coordinates of 47,966 cities.
54 |
55 | Subset of the Kaggle world cities database.
56 | https://www.kaggle.com/max-mind/world-cities-database
57 | This database was downloaded from MaxMind. The license follows:
58 |
59 | OPEN DATA LICENSE for MaxMind WorldCities and Postal Code Databases
60 |
61 | Copyright (c) 2008 MaxMind Inc. All Rights Reserved.
62 |
63 | The database uses toponymic information, based on the Geographic Names
64 | Data Base, containing official standard names approved by the United States
65 | Board on Geographic Names and maintained by the National
66 | Geospatial-Intelligence Agency. More information is available at the Maps
67 | and Geodata link at www.nga.mil. The National Geospatial-Intelligence Agency
68 | name, initials, and seal are protected by 10 United States Code Section 445.
69 |
70 | It also uses free population data from Stefan Helders www.world-gazetteer.com.
71 | Visit his website to download the free population data. Our database
72 | combines Stefan's population data with the list of all cities in the world.
73 |
74 | All advertising materials and documentation mentioning features or use of
75 | this database must display the following acknowledgment:
76 | "This product includes data created by MaxMind, available from
77 | http://www.maxmind.com/"
78 |
79 | Redistribution and use with or without modification, are permitted provided
80 | that the following conditions are met:
81 | 1. Redistributions must retain the above copyright notice, this list of
82 | conditions and the following disclaimer in the documentation and/or other
83 | materials provided with the distribution.
84 | 2. All advertising materials and documentation mentioning features or use of
85 | this database must display the following acknowledgement:
86 | "This product includes data created by MaxMind, available from
87 | http://www.maxmind.com/"
88 | 3. "MaxMind" may not be used to endorse or promote products derived from this
89 | database without specific prior written permission.
90 |
91 | THIS DATABASE IS PROVIDED BY MAXMIND.COM ``AS IS'' AND ANY
92 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
93 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
94 | DISCLAIMED. IN NO EVENT SHALL MAXMIND.COM BE LIABLE FOR ANY
95 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
96 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
97 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
98 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
99 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
100 | DATABASE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
101 | """
102 | return _load_example_data("city-data")
103 |
104 |
105 | def get_example_flow() -> Flow:
106 | """Return a normalizing flow that was trained on galaxy data.
107 |
108 |     This flow was trained in the `redshift_example.ipynb` Jupyter notebook,
109 |     on the example data available from `pzflow.examples.get_galaxy_data`.
110 |     For more info: `print(get_example_flow().info)`.
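
    Examples
    --------
    A minimal sketch (the posterior grid below is illustrative):

    >>> import jax.numpy as jnp
    >>> flow = get_example_flow()
    >>> samples = flow.sample(2, seed=0)
    >>> grid = jnp.linspace(0, 2.3, 100)
    >>> pdfs = flow.posterior(samples, column="redshift", grid=grid)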
111 | """
112 | this_dir, _ = os.path.split(__file__)
113 | flow_path = os.path.join(
114 | this_dir, f"{EXAMPLE_FILE_DIR}/example-flow.pzflow.pkl"
115 | )
116 | flow = Flow(file=flow_path)
117 | return flow
118 |
--------------------------------------------------------------------------------
/pzflow/flowEnsemble.py:
--------------------------------------------------------------------------------
1 | """Define FlowEnsemble object that holds an ensemble of normalizing flows."""
2 |
3 | from typing import Any, Callable, Sequence, Tuple
4 |
5 | import dill as pickle
6 | import jax.numpy as jnp
7 | import numpy as np
8 | import pandas as pd
9 | from jax import random
10 | from jax.scipy.integrate import trapezoid
11 |
12 | from pzflow import Flow, distributions
13 | from pzflow.bijectors import Bijector_Info, InitFunction
14 |
15 |
16 | class FlowEnsemble:
17 | """An ensemble of normalizing flows.
18 |
19 | Attributes
20 | ----------
21 | data_columns : tuple
22 | List of DataFrame columns that the flows expect/produce.
23 | conditional_columns : tuple
24 | List of DataFrame columns on which the flows are conditioned.
25 |     latent : distributions.LatentDist
26 |         The latent distribution of the normalizing flows.
27 |         Has its own sample and log_prob methods.
28 | data_error_model : Callable
29 | The error model for the data variables. See the docstring of
30 | __init__ for more details.
31 | condition_error_model : Callable
32 | The error model for the conditional variables. See the docstring
33 | of __init__ for more details.
34 | info : Any
35 | Object containing any kind of info included with the ensemble.
36 |         Often describes the data the flows are trained on.
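
    Examples
    --------
    A minimal sketch of building, training, and sampling an ensemble
    (the column names, bijector settings, and sizes here are illustrative):

    >>> from pzflow import FlowEnsemble
    >>> from pzflow.bijectors import RollingSplineCoupling
    >>> ens = FlowEnsemble(("x", "y"), RollingSplineCoupling(nlayers=2), N=2)
    >>> losses = ens.train(data)  # data is a pd.DataFrame with x, y columns
    >>> samples = ens.sample(10, seed=0)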
37 | """
38 |
39 | def __init__(
40 | self,
41 | data_columns: Sequence[str] = None,
42 | bijector: Tuple[InitFunction, Bijector_Info] = None,
43 | latent: distributions.LatentDist = None,
44 | conditional_columns: Sequence[str] = None,
45 | data_error_model: Callable = None,
46 | condition_error_model: Callable = None,
47 | autoscale_conditions: bool = True,
48 | N: int = 1,
49 | info: Any = None,
50 | file: str = None,
51 | ) -> None:
52 | """Instantiate an ensemble of normalizing flows.
53 |
54 | Note that while all of the init parameters are technically optional,
55 |         you must provide either data_columns OR file.
56 | In addition, if a file is provided, all other parameters must be None.
57 |
58 | Parameters
59 | ----------
60 | data_columns : Sequence[str]; optional
61 | Tuple, list, or other container of column names.
62 | These are the columns the flows expect/produce in DataFrames.
63 | bijector : Bijector Call; optional
64 | A Bijector call that consists of the bijector InitFunction that
65 | initializes the bijector and the tuple of Bijector Info.
66 | Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
67 | If not provided, the bijector can be set later using
68 | flow.set_bijector, or by calling flow.train, in which case the
69 | default bijector will be used. The default bijector is
70 | ShiftBounds -> RollingSplineCoupling, where the range of shift
71 | bounds is learned from the training data, and the dimensions of
72 |             RollingSplineCoupling are inferred. The default bijector assumes
73 | that the latent has support [-5, 5] for every dimension.
74 | latent : distributions.LatentDist; optional
75 | The latent distribution for the normalizing flow. Can be any of
76 | the distributions from pzflow.distributions. If not provided,
77 | a uniform distribution is used with input_dim = len(data_columns),
78 | and B=5.
79 | conditional_columns : Sequence[str]; optional
80 | Names of columns on which to condition the normalizing flows.
81 | data_error_model : Callable; optional
82 | A callable that defines the error model for data variables.
83 | data_error_model must take key, X, Xerr, nsamples as arguments:
84 | - key is a jax rng key, e.g. jax.random.PRNGKey(0)
85 | - X is 2D array of data variables, where the order of variables
86 | matches the order of the columns in data_columns
87 | - Xerr is the corresponding 2D array of errors
88 | - nsamples is number of samples to draw from error distribution
89 | data_error_model must return an array of samples with the shape
90 | (X.shape[0], nsamples, X.shape[1]).
91 | If data_error_model is not provided, Gaussian error model assumed.
92 | condition_error_model : Callable; optional
93 | A callable that defines the error model for conditional variables.
94 | condition_error_model must take key, X, Xerr, nsamples, where:
95 | - key is a jax rng key, e.g. jax.random.PRNGKey(0)
96 | - X is 2D array of conditional variables, where the order of
97 | variables matches order of columns in conditional_columns
98 | - Xerr is the corresponding 2D array of errors
99 | - nsamples is number of samples to draw from error distribution
100 | condition_error_model must return array of samples with shape
101 | (X.shape[0], nsamples, X.shape[1]).
102 | If condition_error_model is not provided, Gaussian error model
103 | assumed.
104 | autoscale_conditions : bool; default=True
105 | Sets whether or not conditions are automatically standard scaled
106 |             when passed to a conditional flow. I recommend leaving this as True.
107 | N : int; default=1
108 | The number of flows in the ensemble.
109 | info : Any; optional
110 | An object to attach to the info attribute.
111 | file : str; optional
112 | Path to file from which to load a pretrained flow ensemble.
113 | If a file is provided, all other parameters must be None.
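
        Examples
        --------
        A minimal sketch of a custom data_error_model with the required
        signature and return shape (a uniform error distribution here,
        purely for illustration):

        >>> from jax import random
        >>> def uniform_error_model(key, X, Xerr, nsamples):
        ...     eps = random.uniform(
        ...         key,
        ...         shape=(X.shape[0], nsamples, X.shape[1]),
        ...         minval=-1.0,
        ...         maxval=1.0,
        ...     )
        ...     return X[:, None, :] + eps * Xerr[:, None, :]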
114 | """
115 |
116 | # validate parameters
117 | if data_columns is None and file is None:
118 | raise ValueError("You must provide data_columns OR file.")
119 | if file is not None and any(
120 | (
121 | data_columns is not None,
122 | bijector is not None,
123 | conditional_columns is not None,
124 | latent is not None,
125 | data_error_model is not None,
126 | condition_error_model is not None,
127 | info is not None,
128 | )
129 | ):
130 | raise ValueError(
131 | "If providing a file, please do not provide any other parameters."
132 | )
133 |
134 | # if file is provided, load everything from the file
135 | if file is not None:
136 | # load the file
137 | with open(file, "rb") as handle:
138 | save_dict = pickle.load(handle)
139 |
140 | # make sure the saved file is for this class
141 | c = save_dict.pop("class")
142 | if c != self.__class__.__name__:
143 | raise TypeError(
144 | f"This save file isn't a {self.__class__.__name__}. It is a {c}."
145 | )
146 |
147 | # load the ensemble from the dictionary
148 | self._ensemble = {
149 | name: Flow(_dictionary=flow_dict)
150 | for name, flow_dict in save_dict["ensemble"].items()
151 | }
152 | # load the metadata
153 | self.data_columns = save_dict["data_columns"]
154 | self.conditional_columns = save_dict["conditional_columns"]
155 | self.data_error_model = save_dict["data_error_model"]
156 | self.condition_error_model = save_dict["condition_error_model"]
157 | self.info = save_dict["info"]
158 |
159 | self._latent_info = save_dict["latent_info"]
160 | self.latent = getattr(distributions, self._latent_info[0])(
161 | *self._latent_info[1]
162 | )
163 |
164 | # otherwise create a new ensemble from the provided parameters
165 | else:
166 | # save the ensemble of flows
167 | self._ensemble = {
168 | f"Flow {i}": Flow(
169 | data_columns=data_columns,
170 | bijector=bijector,
171 | conditional_columns=conditional_columns,
172 | latent=latent,
173 | data_error_model=data_error_model,
174 | condition_error_model=condition_error_model,
175 | autoscale_conditions=autoscale_conditions,
176 | seed=i,
177 | info=f"Flow {i}",
178 | )
179 | for i in range(N)
180 | }
181 | # save the metadata
182 | self.data_columns = data_columns
183 | self.conditional_columns = conditional_columns
184 | self.latent = self._ensemble["Flow 0"].latent
185 | self.data_error_model = data_error_model
186 | self.condition_error_model = condition_error_model
187 | self.info = info
188 |
189 | def log_prob(
190 | self,
191 | inputs: pd.DataFrame,
192 | err_samples: int = None,
193 | seed: int = None,
194 | returnEnsemble: bool = False,
195 | ) -> jnp.ndarray:
196 | """Calculates log probability density of inputs.
197 |
198 | Parameters
199 | ----------
200 | inputs : pd.DataFrame
201 | Input data for which log probability density is calculated.
202 | Every column in self.data_columns must be present.
203 | If self.conditional_columns is not None, those must be present
204 | as well. If other columns are present, they are ignored.
205 | err_samples : int; default=None
206 | Number of samples from the error distribution to average over for
207 | the log_prob calculation. If provided, Gaussian errors are assumed,
208 | and method will look for error columns in `inputs`. Error columns
209 | must end in `_err`. E.g. the error column for the variable `u` must
210 | be `u_err`. Zero error assumed for any missing error columns.
211 | seed : int; default=None
212 | Random seed for drawing the samples with Gaussian errors.
213 | returnEnsemble : bool; default=False
214 | If True, returns log_prob for each flow in the ensemble as an
215 | array of shape (inputs.shape[0], N flows in ensemble).
216 | If False, the prob is averaged over the flows in the ensemble,
217 | and the log of this average is returned as an array of shape
218 |             (inputs.shape[0],).
219 |
220 | Returns
221 | -------
222 | jnp.ndarray
223 | For shape, see returnEnsemble description above.
224 | """
225 |
226 | # calculate log_prob for each flow in the ensemble
227 | ensemble = jnp.array(
228 | [
229 | flow.log_prob(inputs, err_samples, seed)
230 | for flow in self._ensemble.values()
231 | ]
232 | )
233 |
234 | # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
235 | ensemble = jnp.rollaxis(ensemble, axis=1)
236 |
237 | if returnEnsemble:
238 | # return the ensemble of log_probs
239 | return ensemble
240 | else:
241 | # return mean over ensemble
242 | # note we return log(mean prob) instead of just mean log_prob
243 | return jnp.log(jnp.exp(ensemble).mean(axis=1))
244 |
245 | def posterior(
246 | self,
247 | inputs: pd.DataFrame,
248 | column: str,
249 | grid: jnp.ndarray,
250 | marg_rules: dict = None,
251 | normalize: bool = True,
252 | err_samples: int = None,
253 | seed: int = None,
254 | batch_size: int = None,
255 | returnEnsemble: bool = False,
256 | nan_to_zero: bool = True,
257 | ) -> jnp.ndarray:
258 | """Calculates posterior distributions for the provided column.
259 |
260 | Calculates the conditional posterior distribution, assuming the
261 | data values in the other columns of the DataFrame.
262 |
263 | Parameters
264 | ----------
265 | inputs : pd.DataFrame
266 | Data on which the posterior distributions are conditioned.
267 | Must have columns matching self.data_columns, *except*
268 | for the column specified for the posterior (see below).
269 | column : str
270 | Name of the column for which the posterior distribution
271 | is calculated. Must be one of the columns in self.data_columns.
272 | However, whether or not this column is one of the columns in
273 | `inputs` is irrelevant.
274 | grid : jnp.ndarray
275 | Grid on which to calculate the posterior.
276 | marg_rules : dict; optional
277 | Dictionary with rules for marginalizing over missing variables.
278 | The dictionary must contain the key "flag", which gives the flag
279 | that indicates a missing value. E.g. if missing values are given
280 | the value 99, the dictionary should contain {"flag": 99}.
281 | The dictionary must also contain {"name": callable} for any
282 | variables that will need to be marginalized over, where name is
283 | the name of the variable, and callable is a callable that takes
284 |             the row of variables and returns a grid over which to marginalize
285 | the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
286 | Note: the callable for a given name must *always* return an array
287 | of the same length, regardless of the input row.
288 | normalize : boolean; default=True
289 | Whether to normalize the posterior so that it integrates to 1.
290 | err_samples : int; default=None
291 | Number of samples from the error distribution to average over for
292 | the posterior calculation. If provided, Gaussian errors are assumed,
293 | and method will look for error columns in `inputs`. Error columns
294 | must end in `_err`. E.g. the error column for the variable `u` must
295 | be `u_err`. Zero error assumed for any missing error columns.
296 | seed : int; default=None
297 | Random seed for drawing the samples with Gaussian errors.
298 | batch_size : int; default=None
299 | Size of batches in which to calculate posteriors. If None, all
300 | posteriors are calculated simultaneously. Simultaneous calculation
301 | is faster, but memory intensive for large data sets.
302 | returnEnsemble : bool; default=False
303 | If True, returns posterior for each flow in the ensemble as an
304 | array of shape (inputs.shape[0], N flows in ensemble, grid.size).
305 | If False, the posterior is averaged over the flows in the ensemble,
306 |             and returned as an array of shape (inputs.shape[0], grid.size).
307 | nan_to_zero : bool; default=True
308 | Whether to convert NaN's to zero probability in the final pdfs.
309 |
310 | Returns
311 | -------
312 | jnp.ndarray
313 | For shape, see returnEnsemble description above.
314 | """
315 |
316 | # calculate posterior for each flow in the ensemble
317 | ensemble = jnp.array(
318 | [
319 | flow.posterior(
320 | inputs=inputs,
321 | column=column,
322 | grid=grid,
323 | marg_rules=marg_rules,
324 | err_samples=err_samples,
325 | seed=seed,
326 | batch_size=batch_size,
327 | normalize=False,
328 | nan_to_zero=nan_to_zero,
329 | )
330 | for flow in self._ensemble.values()
331 | ]
332 | )
333 |
334 | # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
335 | ensemble = jnp.rollaxis(ensemble, axis=1)
336 |
337 | if returnEnsemble:
338 | # return the ensemble of posteriors
339 | if normalize:
340 | ensemble = ensemble.reshape(-1, grid.size)
341 | ensemble = ensemble / trapezoid(y=ensemble, x=grid).reshape(
342 | -1, 1
343 | )
344 | ensemble = ensemble.reshape(inputs.shape[0], -1, grid.size)
345 | return ensemble
346 | else:
347 | # return mean over ensemble
348 | pdfs = ensemble.mean(axis=1)
349 | if normalize:
350 | pdfs = pdfs / trapezoid(y=pdfs, x=grid).reshape(-1, 1)
351 | return pdfs
352 |
353 | def sample(
354 | self,
355 | nsamples: int = 1,
356 | conditions: pd.DataFrame = None,
357 | save_conditions: bool = True,
358 | seed: int = None,
359 | returnEnsemble: bool = False,
360 | ) -> pd.DataFrame:
361 | """Returns samples from the ensemble.
362 |
363 | Parameters
364 | ----------
365 | nsamples : int; default=1
366 | The number of samples to be returned, either overall or per flow
367 | in the ensemble (see returnEnsemble below).
368 | conditions : pd.DataFrame; optional
369 | If this is a conditional flow, you must pass conditions for
370 | each sample. nsamples will be drawn for each row in conditions.
371 | save_conditions : bool; default=True
372 | If true, conditions will be saved in the DataFrame of samples
373 | that is returned.
374 | seed : int; optional
375 | Sets the random seed for the samples.
376 | returnEnsemble : bool; default=False
377 | If True, nsamples is drawn from each flow in the ensemble.
378 |             If False, a total of nsamples is drawn uniformly from the flows in the ensemble.
379 |
380 | Returns
381 | -------
382 | pd.DataFrame
383 | Pandas DataFrame of samples.
384 | """
385 |
386 | if returnEnsemble:
387 | # return nsamples for each flow in the ensemble
388 | return pd.concat(
389 | [
390 | flow.sample(nsamples, conditions, save_conditions, seed)
391 | for flow in self._ensemble.values()
392 | ],
393 | keys=self._ensemble.keys(),
394 | )
395 | else:
396 | # if this isn't a conditional flow, sampling is straightforward
397 | if conditions is None:
398 | # return nsamples drawn uniformly from the flows in the ensemble
399 | N = int(jnp.ceil(nsamples / len(self._ensemble)))
400 | samples = pd.concat(
401 | [
402 | flow.sample(N, conditions, save_conditions, seed)
403 | for flow in self._ensemble.values()
404 | ]
405 | )
406 | return samples.sample(nsamples, random_state=seed).reset_index(
407 | drop=True
408 | )
409 | # if this is a conditional flow, it's a little more complicated...
410 | else:
411 | # if nsamples > 1, we duplicate the rows of the conditions
412 | if nsamples > 1:
413 | conditions = pd.concat([conditions] * nsamples)
414 |
415 | # now the main sampling algorithm
416 | seed = np.random.randint(1e18) if seed is None else seed
417 | # if we are drawing more samples than the number of flows in
418 | # the ensemble, then we will shuffle the conditions and randomly
419 | # assign them to one of the constituent flows
420 | if conditions.shape[0] > len(self._ensemble):
421 | # shuffle the conditions
422 | conditions_shuffled = conditions.sample(
423 | frac=1.0, random_state=int(seed / 1e9)
424 | )
425 | # split conditions into ~equal sized chunks
426 | chunks = np.array_split(
427 | conditions_shuffled, len(self._ensemble)
428 | )
429 | # shuffle the chunks
430 | chunks = [
431 | chunks[i]
432 | for i in random.permutation(
433 | random.PRNGKey(seed), jnp.arange(len(chunks))
434 | )
435 | ]
436 | # sample from each flow, and return all the samples
437 | return pd.concat(
438 | [
439 | flow.sample(
440 | 1, chunk, save_conditions, seed
441 | ).set_index(chunk.index)
442 | for flow, chunk in zip(
443 | self._ensemble.values(), chunks
444 | )
445 | ]
446 | ).sort_index()
447 | # however, if there are more flows in the ensemble than samples
448 | # being drawn, then we will randomly select flows for each condition
449 | else:
450 | rng = np.random.default_rng(seed)
451 | # randomly select a flow to sample from for each condition
452 | flows = rng.choice(
453 | list(self._ensemble.values()),
454 | size=conditions.shape[0],
455 | replace=True,
456 | )
457 | # sample from each flow and return all the samples together
458 | seeds = rng.integers(1e18, size=conditions.shape[0])
459 | return pd.concat(
460 | [
461 | flow.sample(
462 | 1,
463 | conditions[i : i + 1],
464 | save_conditions,
465 | new_seed,
466 | )
467 | for i, (flow, new_seed) in enumerate(
468 | zip(flows, seeds)
469 | )
470 | ],
471 | ).set_index(conditions.index)
472 |
473 | def save(self, file: str) -> None:
474 | """Saves the ensemble to a file.
475 |
476 | Pickles the ensemble and saves it to a file that can be passed as
477 |         the `file` argument during ensemble instantiation.
478 |
479 | WARNING: Currently, this method only works for bijectors that are
480 | implemented in the `bijectors` module. If you want to save a flow
481 | with a custom bijector, you either need to add the bijector to that
482 | module, or handle the saving and loading on your end.
483 |
484 | Parameters
485 | ----------
486 | file : str
487 | Path to where the ensemble will be saved.
488 |             Note the path is used exactly as given; no extension is appended.
489 | """
490 | save_dict = {
491 | "data_columns": self.data_columns,
492 | "conditional_columns": self.conditional_columns,
493 | "latent_info": self.latent.info,
494 | "data_error_model": self.data_error_model,
495 | "condition_error_model": self.condition_error_model,
496 | "info": self.info,
497 | "class": self.__class__.__name__,
498 | "ensemble": {
499 | name: flow._save_dict()
500 | for name, flow in self._ensemble.items()
501 | },
502 | }
503 |
504 | with open(file, "wb") as handle:
505 | pickle.dump(save_dict, handle, recurse=True)
506 |
507 | def train(
508 | self,
509 | inputs: pd.DataFrame,
510 | val_set: pd.DataFrame = None,
511 | train_weight: np.ndarray = None,
512 | val_weight: np.ndarray = None,
513 | epochs: int = 50,
514 | batch_size: int = 1024,
515 | optimizer: Callable = None,
516 | loss_fn: Callable = None,
517 | convolve_errs: bool = False,
518 | patience: int = None,
519 | best_params: bool = True,
520 | seed: int = 0,
521 | verbose: bool = False,
522 | progress_bar: bool = False,
523 | initial_loss: bool = True,
524 | ) -> dict:
525 | """Trains the normalizing flows on the provided inputs.
526 |
527 | Parameters
528 | ----------
529 | inputs : pd.DataFrame
530 | Data on which to train the normalizing flows.
531 | Must have columns matching self.data_columns.
532 | val_set : pd.DataFrame; default=None
533 | Validation set, of same format as inputs. If provided,
534 | validation loss will be calculated at the end of each epoch.
535 | train_weight: np.ndarray; default=None
536 | Array of weights for each sample in the training set.
537 | val_weight: np.ndarray; default=None
538 | Array of weights for each sample in the validation set.
539 | epochs : int; default=50
540 | Number of epochs to train.
541 | batch_size : int; default=1024
542 | Batch size for training.
543 |         optimizer : optax optimizer; optional
544 |             An optimizer from Optax; default = optax.adam(learning_rate=1e-3).
545 |             See https://optax.readthedocs.io/en/latest/index.html for more.
546 | loss_fn : Callable; optional
547 | A function to calculate the loss: loss = loss_fn(params, x).
548 | If not provided, will be -mean(log_prob).
549 | convolve_errs : bool; default=False
550 | Whether to draw new data from the error distributions during
551 | each epoch of training. Method will look for error columns in
552 | `inputs`. Error columns must end in `_err`. E.g. the error column
553 | for the variable `u` must be `u_err`. Zero error assumed for
554 | any missing error columns. The error distribution is set during
555 | ensemble instantiation.
556 | patience : int; optional
557 |             Controls early stopping. Training will stop if the
558 | loss doesn't decrease for this number of epochs.
559 | best_params : bool; default=True
560 | Whether to use the params from the epoch with the lowest loss.
561 | Note if a validation set is provided, the epoch with the lowest
562 | validation loss is chosen. If False, the params from the final
563 | epoch are saved.
564 | seed : int; default=0
565 | A random seed to control the batching and the (optional)
566 | error sampling.
567 | verbose : bool; default=False
568 | If true, print the training loss every 5% of epochs.
569 | progress_bar : bool; default=False
570 | If true, display a tqdm progress bar during training.
571 | initial_loss : bool; default=True
572 | If true, start by calculating the initial loss.
573 |
574 | Returns
575 | -------
576 | dict
577 | Dictionary of training losses from every epoch for each flow
578 | in the ensemble.
579 | """
580 |
581 | # generate random seeds for each flow
582 | rng = np.random.default_rng(seed)
583 | seeds = rng.integers(1e9, size=len(self._ensemble))
584 |
585 | loss_dict = dict()
586 |
587 | for i, (name, flow) in enumerate(self._ensemble.items()):
588 | if verbose:
589 | print(name)
590 |
591 | loss_dict[name] = flow.train(
592 | inputs=inputs,
593 | val_set=val_set,
594 | train_weight=train_weight,
595 | val_weight=val_weight,
596 | epochs=epochs,
597 | batch_size=batch_size,
598 | optimizer=optimizer,
599 | loss_fn=loss_fn,
600 | convolve_errs=convolve_errs,
601 | patience=patience,
602 | best_params=best_params,
603 | seed=seeds[i],
604 | verbose=verbose,
605 | progress_bar=progress_bar,
606 | initial_loss=initial_loss,
607 | )
608 |
609 | return loss_dict
610 |
--------------------------------------------------------------------------------
/pzflow/utils.py:
--------------------------------------------------------------------------------
1 | """Define utility functions for use in other modules."""
2 | from typing import Callable, Tuple
3 |
4 | import jax.numpy as jnp
5 | from jax import random
6 | from jax.example_libraries.stax import Dense, LeakyRelu, serial
7 |
8 | from pzflow import bijectors
9 |
10 |
11 | def build_bijector_from_info(info: tuple) -> tuple:
12 | """Build a Bijector from a Bijector_Info object"""
13 |
14 | # recurse through chains
15 | if info[0] == "Chain":
16 | return bijectors.Chain(*(build_bijector_from_info(i) for i in info[1]))
17 | # build individual bijector from name and parameters
18 | else:
19 | return getattr(bijectors, info[0])(*info[1])
20 |
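# A minimal round-trip sketch of build_bijector_from_info (the info tuples
# follow the (name, args) convention used by pzflow.bijectors):
#
#     from pzflow.bijectors import Chain, Reverse, Roll
#     init_fun, info = Chain(Reverse(), Roll(2))
#     rebuilt_init_fun, rebuilt_info = build_bijector_from_info(info)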
21 |
22 | def DenseReluNetwork(
23 | out_dim: int, hidden_layers: int, hidden_dim: int
24 | ) -> Tuple[Callable, Callable]:
25 | """Create a dense neural network with Relu after hidden layers.
26 |
27 | Parameters
28 | ----------
29 | out_dim : int
30 | The output dimension.
31 |     hidden_layers : int
32 |         The number of hidden layers.
33 |     hidden_dim : int
34 |         The dimension of the hidden layers.
35 |
36 | Returns
37 | -------
38 | init_fun : function
39 | The function that initializes the network. Note that this is the
40 | init_function defined in the Jax stax module, which is different
41 | from the functions of my InitFunction class.
42 | forward_fun : function
43 | The function that passes the inputs through the neural network.
44 | """
45 | init_fun, forward_fun = serial(
46 | *(Dense(hidden_dim), LeakyRelu) * hidden_layers,
47 | Dense(out_dim),
48 | )
49 | return init_fun, forward_fun
50 |
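# A minimal usage sketch of DenseReluNetwork (the dimensions are
# illustrative). The returned stax init_fun takes an rng key and an input
# shape and returns (output_shape, params):
#
#     from jax import random
#     init_fun, forward_fun = DenseReluNetwork(out_dim=2, hidden_layers=2, hidden_dim=32)
#     output_shape, params = init_fun(random.PRNGKey(0), (-1, 3))
#     outputs = forward_fun(params, jnp.ones((4, 3)))  # shape (4, 2)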
51 |
52 | def gaussian_error_model(
53 | key, X: jnp.ndarray, Xerr: jnp.ndarray, nsamples: int
54 | ) -> jnp.ndarray:
55 | """
56 |     Default Gaussian error model where X are the means and Xerr are the stds.
57 | """
58 |
59 | eps = random.normal(key, shape=(X.shape[0], nsamples, X.shape[1]))
60 |
61 | return X[:, None, :] + eps * Xerr[:, None, :]
62 |
63 |
64 | def RationalQuadraticSpline(
65 | inputs: jnp.ndarray,
66 | W: jnp.ndarray,
67 | H: jnp.ndarray,
68 | D: jnp.ndarray,
69 | B: float,
70 | periodic: bool = False,
71 | inverse: bool = False,
72 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
73 | """Apply rational quadratic spline to inputs and return outputs with log_det.
74 |
75 | Applies the piecewise rational quadratic spline developed in [1].
76 |
77 | Parameters
78 | ----------
79 | inputs : jnp.ndarray
80 | The inputs to be transformed.
81 | W : jnp.ndarray
82 | The widths of the spline bins.
83 | H : jnp.ndarray
84 | The heights of the spline bins.
85 | D : jnp.ndarray
86 | The derivatives of the inner spline knots.
87 | B : float
88 | Range of the splines.
89 | Outside of (-B,B), the transformation is just the identity.
90 |     periodic : bool; default=False
91 |         Whether to make this a periodic, Circular Spline [2].
92 |     inverse : bool; default=False
93 |         If True, perform the inverse transformation.
94 |         Otherwise perform the forward transformation.
95 |
96 | Returns
97 | -------
98 | outputs : jnp.ndarray
99 | The result of applying the splines to the inputs.
100 | log_det : jnp.ndarray
101 | The log determinant of the Jacobian at the inputs.
102 |
103 | References
104 | ----------
105 | [1] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios.
106 | Neural Spline Flows. arXiv:1906.04032, 2019.
107 | https://arxiv.org/abs/1906.04032
108 | [2] Rezende, Danilo Jimenez et al.
109 | Normalizing Flows on Tori and Spheres. arxiv:2002.02428, 2020
110 | http://arxiv.org/abs/2002.02428
111 | """
112 | # knot x-positions
113 | xk = jnp.pad(
114 | -B + jnp.cumsum(W, axis=-1),
115 | [(0, 0)] * (len(W.shape) - 1) + [(1, 0)],
116 | mode="constant",
117 | constant_values=-B,
118 | )
119 | # knot y-positions
120 | yk = jnp.pad(
121 | -B + jnp.cumsum(H, axis=-1),
122 | [(0, 0)] * (len(H.shape) - 1) + [(1, 0)],
123 | mode="constant",
124 | constant_values=-B,
125 | )
126 | # knot derivatives
127 | if periodic:
128 | dk = jnp.pad(D, [(0, 0)] * (len(D.shape) - 1) + [(1, 0)], mode="wrap")
129 | else:
130 | dk = jnp.pad(
131 | D,
132 | [(0, 0)] * (len(D.shape) - 1) + [(1, 1)],
133 | mode="constant",
134 | constant_values=1,
135 | )
136 | # knot slopes
137 | sk = H / W
138 |
139 | # if not periodic, out-of-bounds inputs will have identity applied
140 | # if periodic, we map the input into the appropriate region inside
141 | # the period. For now, we will pretend all inputs are periodic.
142 | # This makes sure that out-of-bounds inputs don't cause problems
143 | # with the spline, but for the non-periodic case, we will replace
144 | # these with their original values at the end
145 | out_of_bounds = (inputs <= -B) | (inputs >= B)
146 | masked_inputs = jnp.where(out_of_bounds, jnp.abs(inputs) - B, inputs)
147 |
148 | # find bin for each input
149 | if inverse:
150 | idx = jnp.sum(yk <= masked_inputs[..., None], axis=-1)[..., None] - 1
151 | else:
152 | idx = jnp.sum(xk <= masked_inputs[..., None], axis=-1)[..., None] - 1
153 |
154 | # get kx, ky, kyp1, kd, kdp1, kw, ks for the bin corresponding to each input
155 | input_xk = jnp.take_along_axis(xk, idx, -1)[..., 0]
156 | input_yk = jnp.take_along_axis(yk, idx, -1)[..., 0]
157 | input_dk = jnp.take_along_axis(dk, idx, -1)[..., 0]
158 | input_dkp1 = jnp.take_along_axis(dk, idx + 1, -1)[..., 0]
159 | input_wk = jnp.take_along_axis(W, idx, -1)[..., 0]
160 | input_hk = jnp.take_along_axis(H, idx, -1)[..., 0]
161 | input_sk = jnp.take_along_axis(sk, idx, -1)[..., 0]
162 |
163 | if inverse:
164 | # [1] Appendix A.3
165 | # quadratic formula coefficients
166 | a = (input_hk) * (input_sk - input_dk) + (masked_inputs - input_yk) * (
167 | input_dkp1 + input_dk - 2 * input_sk
168 | )
169 | b = (input_hk) * input_dk - (masked_inputs - input_yk) * (
170 | input_dkp1 + input_dk - 2 * input_sk
171 | )
172 | c = -input_sk * (masked_inputs - input_yk)
173 |
174 | relx = 2 * c / (-b - jnp.sqrt(b**2 - 4 * a * c))
175 | outputs = relx * input_wk + input_xk
176 | # if not periodic, replace out-of-bounds values with original values
177 | if not periodic:
178 | outputs = jnp.where(out_of_bounds, inputs, outputs)
179 |
180 | # [1] Appendix A.2
181 | # calculate the log determinant
182 | dnum = (
183 | input_dkp1 * relx**2
184 | + 2 * input_sk * relx * (1 - relx)
185 | + input_dk * (1 - relx) ** 2
186 | )
187 | dden = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
188 | 1 - relx
189 | )
190 | log_det = 2 * jnp.log(input_sk) + jnp.log(dnum) - 2 * jnp.log(dden)
191 | # if not periodic, replace log_det for out-of-bounds values = 0
192 | if not periodic:
193 | log_det = jnp.where(out_of_bounds, 0, log_det)
194 | log_det = log_det.sum(axis=1)
195 |
196 | return outputs, -log_det
197 |
198 | else:
199 | # [1] Appendix A.1
200 | # calculate spline
201 | relx = (masked_inputs - input_xk) / input_wk
202 | num = input_hk * (input_sk * relx**2 + input_dk * relx * (1 - relx))
203 | den = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
204 | 1 - relx
205 | )
206 | outputs = input_yk + num / den
207 | # if not periodic, replace out-of-bounds values with original values
208 | if not periodic:
209 | outputs = jnp.where(out_of_bounds, inputs, outputs)
210 |
211 | # [1] Appendix A.2
212 | # calculate the log determinant
213 | dnum = (
214 | input_dkp1 * relx**2
215 | + 2 * input_sk * relx * (1 - relx)
216 | + input_dk * (1 - relx) ** 2
217 | )
218 | dden = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
219 | 1 - relx
220 | )
221 | log_det = 2 * jnp.log(input_sk) + jnp.log(dnum) - 2 * jnp.log(dden)
222 | # if not periodic, replace log_det for out-of-bounds values = 0
223 | if not periodic:
224 | log_det = jnp.where(out_of_bounds, 0, log_det)
225 | log_det = log_det.sum(axis=1)
226 |
227 | return outputs, log_det
228 |
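# A minimal identity-spline sketch (shapes and values illustrative): with
# equal bin widths and heights and unit derivatives at the inner knots,
# the spline reduces to the identity, so outputs match inputs and the
# log determinant is zero:
#
#     B, K = 3.0, 4
#     x = jnp.array([[0.5, -1.2]])         # (1 sample, 2 dimensions)
#     W = jnp.full((1, 2, K), 2 * B / K)   # bin widths, summing to 2B
#     H = jnp.full((1, 2, K), 2 * B / K)   # bin heights, summing to 2B
#     D = jnp.ones((1, 2, K - 1))          # derivatives at the inner knots
#     y, log_det = RationalQuadraticSpline(x, W, H, D, B)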
229 |
230 | def sub_diag_indices(
231 | inputs: jnp.ndarray,
232 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
233 | """Return indices for diagonal of 2D blocks in 3D array"""
234 | if inputs.ndim != 3:
235 | raise ValueError("Input must be a 3D array.")
236 | nblocks = inputs.shape[0]
237 | ndiag = min(inputs.shape[1], inputs.shape[2])
238 | idx = (
239 | jnp.repeat(jnp.arange(nblocks), ndiag),
240 | jnp.tile(jnp.arange(ndiag), nblocks),
241 | jnp.tile(jnp.arange(ndiag), nblocks),
242 | )
243 | return idx
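
# A minimal sketch of sub_diag_indices: the returned index arrays select
# the diagonal of every 2D block in a 3D array:
#
#     a = jnp.arange(18).reshape(2, 3, 3)
#     diag = a[sub_diag_indices(a)]  # the 2 * 3 diagonal entries, shape (6,)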
244 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | extend-ignore = E203
3 | docstring-convention = numpy
4 | ban-relative-imports = true
5 |
6 | [mypy]
7 | follow_imports = silent
8 | warn_redundant_casts = True
9 | warn_unused_ignores = True
10 | disallow_any_generics = False
11 | check_untyped_defs = True
12 | implicit_reexport = False
13 | disallow_untyped_defs = True
14 | ignore_missing_imports = True
15 |
--------------------------------------------------------------------------------
/tests/test_bijectors.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import pytest
3 | from jax import jit, random
4 |
5 | from pzflow.bijectors import *
6 |
7 | x = jnp.array(
8 | [
9 | [0.2, 0.1, -0.3, 0.5, 0.1, -0.4, -0.3],
10 | [0.6, 0.5, 0.2, 0.2, -0.4, -0.1, 0.7],
11 | [0.9, 0.2, -0.3, 0.3, 0.4, -0.4, -0.1],
12 | ]
13 | )
14 |
15 |
16 | @pytest.mark.parametrize(
17 | "bijector,args,conditions",
18 | [
19 | (ColorTransform, (3, [1, 3, 5]), jnp.zeros((3, 1))),
20 | (Reverse, (), jnp.zeros((3, 1))),
21 | (Roll, (2,), jnp.zeros((3, 1))),
22 | (Scale, (2.0,), jnp.zeros((3, 1))),
23 | (Shuffle, (), jnp.zeros((3, 1))),
24 | (InvSoftplus, (0,), jnp.zeros((3, 1))),
25 | (InvSoftplus, ([1, 3], [2.0, 12.0]), jnp.zeros((3, 1))),
26 | (
27 | StandardScaler,
28 | (jnp.linspace(-1, 1, 7), jnp.linspace(1, 8, 7)),
29 | jnp.zeros((3, 1)),
30 | ),
31 | (Chain, (Reverse(), Scale(1 / 6), Roll(-1)), jnp.zeros((3, 1))),
32 | (NeuralSplineCoupling, (), jnp.zeros((3, 1))),
33 | (
34 | NeuralSplineCoupling,
35 | (16, 3, 2, 128, 3),
36 | jnp.arange(9).reshape(3, 3),
37 | ),
38 | (RollingSplineCoupling, (2,), jnp.zeros((3, 1))),
39 | (
40 | RollingSplineCoupling,
41 | (2, 1, 16, 3, 2, 128, None, 0, True),
42 | jnp.zeros((3, 1)),
43 | ),
44 | (ShiftBounds, (-0.5, 0.9, 5), jnp.zeros((3, 1))),
45 | (
46 | ShiftBounds,
47 | (-1 * jnp.ones(7), 1.1 * jnp.ones(7), 3),
48 | jnp.zeros((3, 1)),
49 | ),
50 | ],
51 | )
52 | class TestBijectors:
53 | def test_returns_correct_shape(self, bijector, args, conditions):
54 | init_fun, bijector_info = bijector(*args)
55 | params, forward_fun, inverse_fun = init_fun(
56 | random.PRNGKey(0), x.shape[-1]
57 | )
58 |
59 | fwd_outputs, fwd_log_det = forward_fun(
60 | params, x, conditions=conditions
61 | )
62 | assert fwd_outputs.shape == x.shape
63 | assert fwd_log_det.shape == x.shape[:1]
64 |
65 | inv_outputs, inv_log_det = inverse_fun(
66 | params, x, conditions=conditions
67 | )
68 | assert inv_outputs.shape == x.shape
69 | assert inv_log_det.shape == x.shape[:1]
70 |
71 | def test_is_bijective(self, bijector, args, conditions):
72 | init_fun, bijector_info = bijector(*args)
73 | params, forward_fun, inverse_fun = init_fun(
74 | random.PRNGKey(0), x.shape[-1]
75 | )
76 |
77 | fwd_outputs, fwd_log_det = forward_fun(
78 | params, x, conditions=conditions
79 | )
80 | inv_outputs, inv_log_det = inverse_fun(
81 | params, fwd_outputs, conditions=conditions
82 | )
83 |
85 | assert jnp.allclose(inv_outputs, x, atol=1e-6)
86 | assert jnp.allclose(fwd_log_det, -inv_log_det, atol=1e-6)
87 |
88 | def test_is_jittable(self, bijector, args, conditions):
89 | init_fun, bijector_info = bijector(*args)
90 | params, forward_fun, inverse_fun = init_fun(
91 | random.PRNGKey(0), x.shape[-1]
92 | )
93 |
94 | fwd_outputs_1, fwd_log_det_1 = forward_fun(
95 | params, x, conditions=conditions
96 | )
97 | forward_fun = jit(forward_fun)
98 | fwd_outputs_2, fwd_log_det_2 = forward_fun(
99 | params, x, conditions=conditions
100 | )
101 |
102 | inv_outputs_1, inv_log_det_1 = inverse_fun(
103 | params, x, conditions=conditions
104 | )
105 | inverse_fun = jit(inverse_fun)
106 | inv_outputs_2, inv_log_det_2 = inverse_fun(
107 | params, x, conditions=conditions
108 | )
        # check that the jitted functions agree with the unjitted originals
        assert jnp.allclose(fwd_outputs_1, fwd_outputs_2)
        assert jnp.allclose(fwd_log_det_1, fwd_log_det_2)
        assert jnp.allclose(inv_outputs_1, inv_outputs_2)
        assert jnp.allclose(inv_log_det_1, inv_log_det_2)
109 |
110 |
111 | @pytest.mark.parametrize(
112 | "bijector,args",
113 | [
114 | (ColorTransform, (0, [1, 2, 3, 4])),
115 | (ColorTransform, (1.3, [1, 2, 3, 4])),
116 | (ColorTransform, (1, [2, 3, 4])),
117 | (Roll, (2.4,)),
118 | (Scale, (2,)),
119 | (Scale, (jnp.arange(7),)),
120 | (InvSoftplus, ([0, 1, 2], [1.0, 2.0])),
121 | (
122 | RollingSplineCoupling,
123 | (2, 1, 16, 3, 2, 128, None, 0, "fake"),
124 | ),
125 | (ShiftBounds, (4, 2, 1)),
126 | (ShiftBounds, (jnp.array([0, 1]), 2, 1)),
127 | ],
128 | )
129 | def test_bad_inputs(bijector, args):
130 | with pytest.raises(ValueError):
131 | bijector(*args)
132 |
133 |
134 | def test_uniform_dequantizer_returns_correct_shape():
135 | init_fun, bijector_info = UniformDequantizer([1, 3, 4])
136 | params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), x.shape[-1])
137 |
138 | conditions = jnp.zeros((3, 1))
139 | fwd_outputs, fwd_log_det = forward_fun(params, x, conditions=conditions)
140 | assert fwd_outputs.shape == x.shape
141 | assert fwd_log_det.shape == x.shape[:1]
142 |
143 | inv_outputs, inv_log_det = inverse_fun(params, x, conditions=conditions)
144 | assert inv_outputs.shape == x.shape
145 | assert inv_log_det.shape == x.shape[:1]
146 |
--------------------------------------------------------------------------------
/tests/test_distributions.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import pytest
3 |
4 | from pzflow.distributions import *
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "distribution,inputs,params",
9 | [
10 | (CentBeta, (2, 3), ((0, 1), (2, 3))),
11 | (Normal, (2,), ()),
12 | (Tdist, (2,), jnp.log(30.0)),
13 | (Uniform, (2,), ()),
14 | (Joint, (Normal(1), Uniform(1, 4)), ((), ())),
15 | (Joint, (Normal(1), Tdist(1)), ((), jnp.log(30.0))),
16 | (Joint, (Joint(Normal(1), Uniform(1)).info[1]), ((), ())),
17 | (CentBeta13, (2, 4), ()),
18 | ],
19 | )
20 | class TestDistributions:
21 | def test_returns_correct_shapes(self, distribution, inputs, params):
22 | dist = distribution(*inputs)
23 |
24 | nsamples = 8
25 | samples = dist.sample(params, nsamples)
26 | assert samples.shape == (nsamples, 2)
27 |
28 | log_prob = dist.log_prob(params, samples)
29 | assert log_prob.shape == (nsamples,)
30 |
31 | def test_control_sample_randomness(self, distribution, inputs, params):
32 | dist = distribution(*inputs)
33 |
34 | nsamples = 8
35 | s1 = dist.sample(params, nsamples)
36 | s2 = dist.sample(params, nsamples)
37 | assert ~jnp.all(jnp.isclose(s1, s2))
38 |
39 | s1 = dist.sample(params, nsamples, seed=0)
40 | s2 = dist.sample(params, nsamples, seed=0)
41 | assert jnp.allclose(s1, s2)
42 |
--------------------------------------------------------------------------------
/tests/test_ensemble.py:
--------------------------------------------------------------------------------
1 | import dill as pickle
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from jax import random
7 | from jax.scipy.integrate import trapezoid
8 |
9 | from pzflow import Flow, FlowEnsemble
10 | from pzflow.bijectors import Reverse, RollingSplineCoupling
11 |
12 | flowEns = FlowEnsemble(("x", "y"), RollingSplineCoupling(nlayers=2), N=2)
13 | flow0 = Flow(("x", "y"), RollingSplineCoupling(nlayers=2), seed=0)
14 | flow1 = Flow(("x", "y"), RollingSplineCoupling(nlayers=2), seed=1)
15 |
16 | xarray = np.arange(6).reshape(3, 2) / 10
17 | x = pd.DataFrame(xarray, columns=("x", "y"))
18 |
19 |
20 | def test_log_prob():
21 | lpEns = flowEns.log_prob(x, returnEnsemble=True)
22 | assert lpEns.shape == (3, 2)
23 |
24 | lp0 = flow0.log_prob(x)
25 | lp1 = flow1.log_prob(x)
26 | assert jnp.allclose(lpEns[:, 0], lp0)
27 | assert jnp.allclose(lpEns[:, 1], lp1)
28 |
29 | lpEnsMean = flowEns.log_prob(x)
30 | assert lpEnsMean.shape == lp0.shape
31 |
32 | manualMean = jnp.log(
33 | jnp.mean(jnp.array([jnp.exp(lp0), jnp.exp(lp1)]), axis=0)
34 | )
35 | assert jnp.allclose(lpEnsMean, manualMean)
36 |
37 |
38 | def test_posterior():
39 | grid = jnp.linspace(-1, 1, 5)
40 |
41 | pEns = flowEns.posterior(x, "x", grid, returnEnsemble=True)
42 | assert pEns.shape == (3, 2, grid.size)
43 |
44 | p0 = flow0.posterior(x, "x", grid)
45 | p1 = flow1.posterior(x, "x", grid)
46 | assert jnp.allclose(pEns[:, 0, :], p0)
47 | assert jnp.allclose(pEns[:, 1, :], p1)
48 |
49 | pEnsMean = flowEns.posterior(x, "x", grid)
50 | assert pEnsMean.shape == p0.shape
51 |
52 | p0 = flow0.posterior(x, "x", grid, normalize=False)
53 | p1 = flow1.posterior(x, "x", grid, normalize=False)
54 | manualMean = (p0 + p1) / 2
55 | manualMean = manualMean / trapezoid(y=manualMean, x=grid).reshape(-1, 1)
56 | assert jnp.allclose(pEnsMean, manualMean)
57 |
58 |
59 | def test_sample():
60 | # first test everything with returnEnsemble=False
61 | sEns = flowEns.sample(10, seed=0).values
62 | assert sEns.shape == (10, 2)
63 |
64 | s0 = flow0.sample(5, seed=0)
65 | s1 = flow1.sample(5, seed=0)
66 | sManual = jnp.vstack([s0.values, s1.values])
67 | assert jnp.allclose(
68 | sEns[sEns[:, 0].argsort()], sManual[sManual[:, 0].argsort()]
69 | )
70 |
71 | # now test everything with returnEnsemble=True
72 | sEns = flowEns.sample(10, seed=0, returnEnsemble=True).values
73 | assert sEns.shape == (20, 2)
74 |
75 | s0 = flow0.sample(10, seed=0)
76 | s1 = flow1.sample(10, seed=0)
77 | sManual = jnp.vstack([s0.values, s1.values])
78 | assert jnp.allclose(sEns, sManual)
79 |
80 |
81 | def test_conditional_sample():
82 | cEns = FlowEnsemble(
83 | ("x", "y"),
84 | RollingSplineCoupling(nlayers=2, n_conditions=2),
85 | conditional_columns=("a", "b"),
86 | N=2,
87 | )
88 |
89 | # test with nsamples = 1, fewer samples than flows
90 | conditions = pd.DataFrame(np.arange(2).reshape(-1, 2), columns=("a", "b"))
91 | samples = cEns.sample(
92 | nsamples=1, conditions=conditions, save_conditions=False
93 | )
94 | assert samples.shape == (1, 2)
95 |
96 | # test with nsamples = 1, more samples than flows
97 | conditions = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=("a", "b"))
98 | samples = cEns.sample(
99 | nsamples=1, conditions=conditions, save_conditions=False
100 | )
101 | assert samples.shape == (5, 2)
102 |
103 | # test with nsamples = 2, more samples than flows
104 | conditions = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=("a", "b"))
105 | samples = cEns.sample(
106 | nsamples=2, conditions=conditions, save_conditions=False
107 | )
108 | assert samples.shape == (10, 2)
109 |
110 | # test with returnEnsemble=True
111 | conditions = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=("a", "b"))
112 | samples = cEns.sample(
113 | nsamples=1,
114 | conditions=conditions,
115 | save_conditions=False,
116 | returnEnsemble=True,
117 | )
118 | assert samples.shape == (10, 2)
119 |
120 |
121 | def test_train():
122 | data = random.normal(random.PRNGKey(0), shape=(100, 2))
123 | data = pd.DataFrame(np.array(data), columns=("x", "y"))
124 |
125 | loss_dict = flowEns.train(data, epochs=4, batch_size=50, verbose=True)
126 |
127 | rng = np.random.default_rng(0)
128 | seeds = rng.integers(1e9, size=2)
129 | losses0 = flow0.train(data, epochs=4, batch_size=50, seed=seeds[0])
130 | losses1 = flow1.train(data, epochs=4, batch_size=50, seed=seeds[1])
131 |
132 | assert jnp.allclose(jnp.array(loss_dict["Flow 0"]), jnp.array(losses0))
133 | assert jnp.allclose(jnp.array(loss_dict["Flow 1"]), jnp.array(losses1))
134 |
135 |
136 | def test_load_ensemble(tmp_path):
137 | flowEns = FlowEnsemble(("x", "y"), RollingSplineCoupling(nlayers=2), N=2)
138 |
139 | preSave = flowEns.sample(10, seed=0)
140 |
141 | file = tmp_path / "test-ensemble.pzflow.pkl"
142 | flowEns.save(str(file))
143 |
144 | file = tmp_path / "test-ensemble.pzflow.pkl"
145 | flowEns = FlowEnsemble(file=str(file))
146 |
147 | postSave = flowEns.sample(10, seed=0)
148 |
149 | assert jnp.allclose(preSave.values, postSave.values)
150 |
151 | with open(str(file), "rb") as handle:
152 | save_dict = pickle.load(handle)
153 | save_dict["class"] = "Flow"
154 | with open(str(file), "wb") as handle:
155 | pickle.dump(save_dict, handle, recurse=True)
156 | with pytest.raises(TypeError):
157 | FlowEnsemble(file=str(file))
158 |
159 |
160 | @pytest.mark.parametrize(
161 | "data_columns,bijector,info,file",
162 | [
163 | (None, None, None, None),
164 | (None, Reverse(), None, None),
165 | (("x", "y"), None, None, "file"),
166 | (None, Reverse(), None, "file"),
167 | (None, None, "fake", "file"),
168 | ],
169 | )
170 | def test_bad_inputs(data_columns, bijector, info, file):
171 | with pytest.raises(ValueError):
172 | FlowEnsemble(data_columns, bijector=bijector, info=info, file=file)
173 |
--------------------------------------------------------------------------------
/tests/test_examples.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import pandas as pd
3 |
4 | from pzflow import Flow, examples
5 |
6 |
7 | def test_get_twomoons_data():
8 | data = examples.get_twomoons_data()
9 | assert isinstance(data, pd.DataFrame)
10 | assert data.shape == (100_000, 2)
11 |
12 |
13 | def test_get_galaxy_data():
14 | data = examples.get_galaxy_data()
15 | assert isinstance(data, pd.DataFrame)
16 | assert data.shape == (100_000, 7)
17 |
18 |
19 | def test_get_city_data():
20 | data = examples.get_city_data()
21 | assert isinstance(data, pd.DataFrame)
22 | assert data.shape == (47_966, 5)
23 |
24 |
25 | def test_get_checkerboard_data():
26 | data = examples.get_checkerboard_data()
27 | assert isinstance(data, pd.DataFrame)
28 | assert data.shape == (100_000, 2)
29 |
30 |
31 | def test_get_example_flow():
32 | flow = examples.get_example_flow()
33 | assert isinstance(flow, Flow)
34 | assert isinstance(flow.info, str)
35 |
36 | samples = flow.sample(2)
37 | flow.log_prob(samples)
38 |
39 | grid = jnp.arange(0, 2.5, 0.5)
40 | flow.posterior(samples, column="redshift", grid=grid)
41 |
--------------------------------------------------------------------------------
/tests/test_flow.py:
--------------------------------------------------------------------------------
1 | import dill as pickle
2 | import jax.numpy as jnp
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from jax import random
7 |
8 | from pzflow import Flow
9 | from pzflow.bijectors import Reverse, RollingSplineCoupling
10 | from pzflow.distributions import *
11 | from pzflow.examples import get_twomoons_data
12 |
13 |
14 | @pytest.mark.parametrize(
15 | "data_columns,bijector,info,file,_dictionary",
16 | [
17 | (None, None, None, None, None),
18 | (None, Reverse(), None, None, None),
19 | (("x", "y"), None, None, "file", None),
20 | (None, Reverse(), None, "file", None),
21 | (None, None, "fake", "file", None),
22 | (("x", "y"), Reverse(), None, None, "dict"),
23 | (None, None, None, "file", "dict"),
24 | ],
25 | )
26 | def test_bad_inputs(data_columns, bijector, info, file, _dictionary):
27 | with pytest.raises(ValueError):
28 | Flow(
29 | data_columns,
30 | bijector=bijector,
31 | info=info,
32 | file=file,
33 | _dictionary=_dictionary,
34 | )
35 |
36 |
37 | @pytest.mark.parametrize(
38 | "flow",
39 | [
40 | Flow(("redshift", "y"), Reverse(), latent=Normal(2)),
41 | Flow(("redshift", "y"), Reverse(), latent=Tdist(2)),
42 | Flow(("redshift", "y"), Reverse(), latent=Uniform(2, 10)),
43 | Flow(("redshift", "y"), Reverse(), latent=CentBeta(2, 10)),
44 | ],
45 | )
46 | def test_returns_correct_shape(flow):
47 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
48 | x = pd.DataFrame(xarray, columns=("redshift", "y"))
49 |
50 | conditions = flow._get_conditions(x)
51 |
52 | xfwd, xfwd_log_det = flow._forward(
53 | flow._params, xarray, conditions=conditions
54 | )
55 | assert xfwd.shape == x.shape
56 | assert xfwd_log_det.shape == (x.shape[0],)
57 |
58 | xinv, xinv_log_det = flow._inverse(
59 | flow._params, xarray, conditions=conditions
60 | )
61 | assert xinv.shape == x.shape
62 | assert xinv_log_det.shape == (x.shape[0],)
63 |
64 | nsamples = 4
65 | assert flow.sample(nsamples).shape == (nsamples, x.shape[1])
66 | assert flow.log_prob(x).shape == (x.shape[0],)
67 |
68 | grid = jnp.arange(0, 2.1, 0.12)
69 | pdfs = flow.posterior(x, column="y", grid=grid)
70 | assert pdfs.shape == (x.shape[0], grid.size)
71 | pdfs = flow.posterior(x.iloc[:, 1:], column="redshift", grid=grid)
72 | assert pdfs.shape == (x.shape[0], grid.size)
73 | pdfs = flow.posterior(
74 | x.iloc[:, 1:], column="redshift", grid=grid, batch_size=2
75 | )
76 | assert pdfs.shape == (x.shape[0], grid.size)
77 |
78 | assert len(flow.train(x, epochs=11, verbose=True)) == 12
79 | assert (
80 | len(flow.train(x, epochs=11, verbose=True, convolve_errs=True)) == 12
81 | )
82 |
83 |
84 | @pytest.mark.parametrize(
85 | "flag",
86 | [
87 | 99,
88 | np.nan,
89 | ],
90 | )
91 | def test_posterior_with_marginalization(flag):
92 | flow = Flow(("a", "b", "c", "d"), Reverse())
93 |
94 | # test posteriors with marginalization
95 | x = pd.DataFrame(
96 | np.arange(16).reshape(-1, 4), columns=("a", "b", "c", "d")
97 | )
98 | grid = np.arange(0, 2.1, 0.12)
99 |
100 | marg_rules = {
101 | "flag": flag,
102 | "b": lambda row: np.linspace(0, 1, 2),
103 | "c": lambda row: np.linspace(1, 2, 3),
104 | }
105 |
106 | x["b"] = flag * np.ones(x.shape[0])
107 | pdfs = flow.posterior(x, column="a", grid=grid, marg_rules=marg_rules)
108 | assert pdfs.shape == (x.shape[0], grid.size)
109 |
110 | x["c"] = flag * np.ones(x.shape[0])
111 | pdfs = flow.posterior(x, column="a", grid=grid, marg_rules=marg_rules)
112 | assert pdfs.shape == (x.shape[0], grid.size)
113 |
114 |
115 | @pytest.mark.parametrize(
116 | "flow,x,x_with_err",
117 | [
118 | (
119 | Flow(
120 | ("redshift", "y"), RollingSplineCoupling(2), latent=Normal(2)
121 | ),
122 | pd.DataFrame(
123 | np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
124 | columns=("redshift", "y"),
125 | ),
126 | pd.DataFrame(
127 | np.array(
128 | [
129 | [1.0, 2.0, 0.1, 0.2],
130 | [3.0, 4.0, 0.2, 0.3],
131 | [5.0, 6.0, 0.1, 0.2],
132 | ]
133 | ),
134 | columns=("redshift", "y", "redshift_err", "y_err"),
135 | ),
136 | ),
137 | (
138 | Flow(
139 | ("redshift", "y"),
140 | RollingSplineCoupling(2, n_conditions=2),
141 | latent=Normal(2),
142 | conditional_columns=("a", "b"),
143 | ),
144 | pd.DataFrame(
145 | np.array(
146 | [
147 | [1.0, 2.0, 10, 20],
148 | [3.0, 4.0, 30, 40],
149 | [5.0, 6.0, 50, 60],
150 | ]
151 | ),
152 | columns=("redshift", "y", "a", "b"),
153 | ),
154 | pd.DataFrame(
155 | np.array(
156 | [
157 | [1.0, 2.0, 10, 20, 0.1, 0.2, 1, 2],
158 | [3.0, 4.0, 30, 40, 0.2, 0.3, 3, 4],
159 | [5.0, 6.0, 50, 60, 0.1, 0.2, 5, 6],
160 | ]
161 | ),
162 | columns=(
163 | "redshift",
164 | "y",
165 | "a",
166 | "b",
167 | "redshift_err",
168 | "y_err",
169 | "a_err",
170 | "b_err",
171 | ),
172 | ),
173 | ),
174 | (
175 | Flow(
176 | ("redshift", "y"),
177 | RollingSplineCoupling(2, n_conditions=1),
178 | latent=Normal(2),
179 | conditional_columns=("a",),
180 | ),
181 | pd.DataFrame(
182 | np.array([[1.0, 2.0, 10], [3.0, 4.0, 30], [5.0, 6.0, 50]]),
183 | columns=("redshift", "y", "a"),
184 | ),
185 | pd.DataFrame(
186 | np.array(
187 | [
188 | [1.0, 2.0, 10, 0.1, 0.2, 1],
189 | [3.0, 4.0, 30, 0.2, 0.3, 3],
190 | [5.0, 6.0, 50, 0.1, 0.2, 5],
191 | ]
192 | ),
193 | columns=(
194 | "redshift",
195 | "y",
196 | "a",
197 | "redshift_err",
198 | "y_err",
199 | "a_err",
200 | ),
201 | ),
202 | ),
203 | (
204 | Flow(
205 | ("y",),
206 | RollingSplineCoupling(1, n_conditions=2),
207 | latent=Normal(1),
208 | conditional_columns=("a", "b"),
209 | ),
210 | pd.DataFrame(
211 | np.array([[1.0, 10, 20], [3.0, 30, 40], [5.0, 50, 60]]),
212 | columns=("y", "a", "b"),
213 | ),
214 | pd.DataFrame(
215 | np.array(
216 | [
217 | [1.0, 10, 20, 0.1, 1, 2],
218 | [3.0, 30, 40, 0.2, 3, 4],
219 | [5.0, 50, 60, 0.1, 5, 6],
220 | ]
221 | ),
222 | columns=(
223 | "y",
224 | "a",
225 | "b",
226 | "y_err",
227 | "a_err",
228 | "b_err",
229 | ),
230 | ),
231 | ),
232 | ],
233 | )
234 | def test_error_convolution(flow, x, x_with_err):
235 | assert flow.log_prob(x, err_samples=10).shape == (x.shape[0],)
236 | assert jnp.allclose(
237 | flow.log_prob(x, err_samples=10, seed=0),
238 | flow.log_prob(x),
239 | )
240 | assert ~jnp.allclose(
241 | flow.log_prob(x_with_err, err_samples=10, seed=0),
242 | flow.log_prob(x_with_err),
243 | )
244 | assert jnp.allclose(
245 | flow.log_prob(x_with_err, err_samples=10, seed=0),
246 | flow.log_prob(x_with_err, err_samples=10, seed=0),
247 | )
248 | assert ~jnp.allclose(
249 | flow.log_prob(x_with_err, err_samples=10, seed=0),
250 | flow.log_prob(x_with_err, err_samples=10, seed=1),
251 | )
252 | assert ~jnp.allclose(
253 | flow.log_prob(x_with_err, err_samples=10),
254 | flow.log_prob(x_with_err, err_samples=10),
255 | )
256 |
257 | grid = jnp.arange(0, 2.1, 0.12)
258 | pdfs = flow.posterior(x, column="y", grid=grid, err_samples=10)
259 | assert pdfs.shape == (x.shape[0], grid.size)
260 | assert jnp.allclose(
261 | flow.posterior(x, column="y", grid=grid, err_samples=10, seed=0),
262 | flow.posterior(x, column="y", grid=grid),
263 | rtol=1e-4,
264 | )
265 | assert jnp.allclose(
266 | flow.posterior(
267 | x_with_err, column="y", grid=grid, err_samples=10, seed=0
268 | ),
269 | flow.posterior(
270 | x_with_err, column="y", grid=grid, err_samples=10, seed=0
271 | ),
272 | )
273 |
274 |
275 | def test_posterior_batch():
276 | columns = ("redshift", "y")
277 | flow = Flow(columns, Reverse())
278 |
279 | xarray = np.array([[1, 2], [3, 4], [5, 6]])
280 | x = pd.DataFrame(xarray, columns=columns)
281 |
282 | grid = jnp.arange(0, 2.1, 0.12)
283 | pdfs = flow.posterior(x.iloc[:, 1:], column="redshift", grid=grid)
284 | pdfs_batched = flow.posterior(
285 | x.iloc[:, 1:], column="redshift", grid=grid, batch_size=2
286 | )
287 | assert jnp.allclose(pdfs, pdfs_batched)
288 |
289 |
290 | def test_flow_bijection():
291 | columns = ("x", "y")
292 | flow = Flow(columns, Reverse())
293 |
294 | x = jnp.array([[1, 2], [3, 4]])
295 | xrev = jnp.array([[2, 1], [4, 3]])
296 |
297 | assert jnp.allclose(flow._forward(flow._params, x)[0], xrev)
298 | assert jnp.allclose(
299 | flow._inverse(flow._params, flow._forward(flow._params, x)[0])[0], x
300 | )
301 | assert jnp.allclose(
302 | flow._forward(flow._params, x)[1], flow._inverse(flow._params, x)[1]
303 | )
304 |
305 |
306 | def test_load_flow(tmp_path):
307 | columns = ("x", "y")
308 | flow = Flow(columns, Reverse(), info=["random", 42])
309 |
310 | file = tmp_path / "test-flow.pzflow.pkl"
311 | flow.save(str(file))
312 |
313 |     # reload the flow from the saved file
314 | flow = Flow(file=str(file))
315 |
316 | x = jnp.array([[1, 2], [3, 4]])
317 | xrev = jnp.array([[2, 1], [4, 3]])
318 |
319 | assert jnp.allclose(flow._forward(flow._params, x)[0], xrev)
320 | assert jnp.allclose(
321 | flow._inverse(flow._params, flow._forward(flow._params, x)[0])[0], x
322 | )
323 | assert jnp.allclose(
324 | flow._forward(flow._params, x)[1], flow._inverse(flow._params, x)[1]
325 | )
326 | assert flow.info == ["random", 42]
327 |
328 | with open(str(file), "rb") as handle:
329 | save_dict = pickle.load(handle)
330 |     save_dict["class"] = "FlowEnsemble"  # wrong class tag: loading as a Flow must fail
331 | with open(str(file), "wb") as handle:
332 | pickle.dump(save_dict, handle, recurse=True)
333 | with pytest.raises(TypeError):
334 | Flow(file=str(file))
335 |
336 |
337 | def test_control_sample_randomness():
338 | columns = ("x", "y")
339 | flow = Flow(columns, Reverse())
340 |
341 | assert np.all(~np.isclose(flow.sample(2), flow.sample(2)))
342 | assert np.allclose(flow.sample(2, seed=0), flow.sample(2, seed=0))
343 |
344 |
345 | @pytest.mark.parametrize(
346 | "epochs,loss_fn,",
347 | [
348 | (-1, None),
349 | (2.4, None),
350 | ("a", None),
351 | ],
352 | )
353 | def test_train_bad_inputs(epochs, loss_fn):
354 | columns = ("redshift", "y")
355 | flow = Flow(columns, Reverse())
356 |
357 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
358 | x = pd.DataFrame(xarray, columns=columns)
359 |
360 | with pytest.raises(ValueError):
361 | flow.train(
362 | x,
363 | epochs=epochs,
364 | loss_fn=loss_fn,
365 | )
366 |
367 |
368 | def test_conditional_sample():
369 | flow = Flow(("x", "y"), Reverse(), conditional_columns=("a", "b"))
370 | x = np.arange(12).reshape(-1, 4)
371 | x = pd.DataFrame(x, columns=("x", "y", "a", "b"))
372 |
373 | conditions = flow._get_conditions(x)
374 | assert conditions.shape == (x.shape[0], 2)
375 |
376 | with pytest.raises(ValueError):
377 | flow.sample(4)
378 |
379 | samples = flow.sample(4, conditions=x)
380 | assert samples.shape == (4 * x.shape[0], 4)
381 |
382 | samples = flow.sample(4, conditions=x, save_conditions=False)
383 | assert samples.shape == (4 * x.shape[0], 2)
384 |
385 |
386 | def test_train_no_errs_same():
387 | columns = ("redshift", "y")
388 | flow = Flow(columns, Reverse())
389 |
390 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
391 | x = pd.DataFrame(xarray, columns=columns)
392 |
393 |     losses1 = flow.train(x, convolve_errs=True)  # errors default to zero, so convolution is a no-op
394 | losses2 = flow.train(x, convolve_errs=False)
395 | assert jnp.allclose(jnp.array(losses1), jnp.array(losses2))
396 |
397 |
398 | def test_get_err_samples():
399 | rng = random.PRNGKey(0)
400 |
401 | # check Gaussian data samples
402 | columns = ("x", "y")
403 | flow = Flow(columns, Reverse())
404 | xarray = np.array([[1.0, 2.0, 0.1, 0.2], [3.0, 4.0, 0.3, 0.4]])
405 | x = pd.DataFrame(xarray, columns=("x", "y", "x_err", "y_err"))
406 | samples = flow._get_err_samples(rng, x, 10)
407 | assert samples.shape == (20, 2)
408 |
409 | # test skip
410 | xarray = np.array([[1.0, 2.0, 0, 0]])
411 | x = pd.DataFrame(xarray, columns=("x", "y", "x_err", "y_err"))
412 | samples = flow._get_err_samples(rng, x, 10, skip="y")
413 | assert jnp.allclose(samples, jnp.ones((10, 1)))
414 | samples = flow._get_err_samples(rng, x, 10, skip="x")
415 | assert jnp.allclose(samples, 2 * jnp.ones((10, 1)))
416 |
417 | # check Gaussian conditional samples
418 | flow = Flow(("x"), Reverse(), conditional_columns=("y"))
419 | samples = flow._get_err_samples(rng, x, 10, type="conditions")
420 | assert jnp.allclose(samples, 2 * jnp.ones((10, 1)))
421 |
422 | # check incorrect type
423 | with pytest.raises(ValueError):
424 | flow._get_err_samples(rng, x, 10, type="wrong")
425 |
426 | # check constant shift data samples
427 | columns = ("x", "y")
428 |     def shift_err_model(key, X, Xerr, nsamples):
429 |         # shift every point by its error, tiled nsamples times per row
430 |         return jnp.repeat(X + Xerr, nsamples, axis=0).reshape(X.shape[0], nsamples, X.shape[1])
431 | flow = Flow(columns, Reverse(), data_error_model=shift_err_model)
432 | xarray = np.array([[1.0, 2.0, 0.1, 0.2], [3.0, 4.0, 0.3, 0.4]])
433 | x = pd.DataFrame(xarray, columns=("x", "y", "x_err", "y_err"))
434 | samples = flow._get_err_samples(rng, x, 10)
435 | assert samples.shape == (20, 2)
436 | assert jnp.allclose(
437 | samples,
438 | shift_err_model(None, xarray[:, :2], xarray[:, 2:], 10).reshape(20, 2),
439 | )
440 |
441 | # check constant shift conditional samples
442 |     flow = Flow(
443 |         ("x",),
444 |         Reverse(),
445 |         conditional_columns=("y",),
446 |         condition_error_model=shift_err_model,
447 |     )
448 | samples = flow._get_err_samples(rng, x, 10, type="conditions")
449 | assert jnp.allclose(
450 | samples, jnp.repeat(jnp.array([[2.2], [4.4]]), 10, axis=0)
451 | )
452 |
453 |
454 | def test_train_w_conditions():
455 | xarray = np.array(
456 | [[1.0, 2.0, 0.1, 0.2], [3.0, 4.0, 0.3, 0.4], [5.0, 6.0, 0.5, 0.6]]
457 | )
458 | x = pd.DataFrame(xarray, columns=("redshift", "y", "a", "b"))
459 |
460 | flow = Flow(
461 | ("redshift", "y"),
462 | Reverse(),
463 | latent=Normal(2),
464 | conditional_columns=("a", "b"),
465 | )
466 | assert len(flow.train(x, epochs=11)) == 12
467 |
468 | print("------->>>>>")
469 | print(flow._condition_stds, "\n\n")
470 | print(xarray[:, 2:].std(axis=0))
471 | assert jnp.allclose(flow._condition_means, xarray[:, 2:].mean(axis=0))
472 | assert jnp.allclose(flow._condition_stds, xarray[:, 2:].std(axis=0))
473 |
474 |
475 | def test_patience():
476 | columns = ("redshift", "y")
477 | flow = Flow(columns, Reverse())
478 |
479 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
480 | x = pd.DataFrame(xarray, columns=columns)
481 |
482 | losses = flow.train(x, patience=2)
483 |     # with patience=2, training halts early once the loss stops improving
484 | assert len(losses) == 4
485 |
486 |
487 | def test_latent_with_wrong_dimension():
488 | cols = ["x", "y"]
489 | latent = Uniform(3)
490 |
491 | with pytest.raises(ValueError):
492 | Flow(data_columns=cols, latent=latent, bijector=Reverse())
493 |
494 |
495 | def test_bijector_not_set():
496 | flow = Flow(["x", "y"])
497 |
498 | with pytest.raises(ValueError):
499 | flow.sample(1)
500 |
501 | with pytest.raises(ValueError):
502 | x = np.linspace(0, 1, 12)
503 | df = pd.DataFrame(x.reshape(-1, 2), columns=("x", "y"))
504 | flow.posterior(x, column="x", grid=x)
505 |
506 |
507 | def test_default_bijector():
508 | flow = Flow(["x", "y"])
509 |
510 | losses = flow.train(get_twomoons_data())
511 | assert all(~np.isnan(losses))
512 |
513 |
514 | def test_validation_train():
515 | # load some training data
516 | data = get_twomoons_data()[:10]
517 | train_set = data[:8]
518 | val_set = data[8:]
519 |
520 | # train the default flow
521 | flow = Flow(train_set.columns, Reverse())
522 | losses = flow.train(
523 | train_set,
524 | val_set,
525 | verbose=True,
526 | epochs=3,
527 | best_params=False,
528 | )
529 | assert len(losses[0]) == 4
530 | assert len(losses[1]) == 4
531 |
532 |
533 | def test_nan_train_stop():
534 | # create data with NaNs
535 | data = jnp.nan * jnp.ones((4, 2))
536 | data = pd.DataFrame(data, columns=["x", "y"])
537 |
538 | # train the flow
539 | flow = Flow(data.columns, Reverse())
540 | losses = flow.train(data)
541 |     assert len(losses) == 2  # training aborts once the loss goes NaN
542 |
543 | def test_train_weights():
544 | # load some training data
545 | data = get_twomoons_data()[:10]
546 | train_set = data[:8]
547 | val_set = data[8:]
548 | train_weight = np.linspace(1, 2, len(train_set))
549 | val_weight = np.linspace(1, 2, len(val_set))
550 |
551 | # train the default flow
552 | flow = Flow(train_set.columns, Reverse())
553 | losses = flow.train(
554 | train_set,
555 | val_set,
556 | verbose=True,
557 | epochs=3,
558 | best_params=False,
559 | train_weight=train_weight,
560 | val_weight=val_weight,
561 | )
562 | assert len(losses[0]) == 4
563 | assert len(losses[1]) == 4
564 |
565 | def test_no_initial_loss():
566 | # load some training data
567 | data = get_twomoons_data()[:10]
568 | train_set = data[:8]
569 | val_set = data[8:]
570 |
571 | # train the default flow
572 | flow = Flow(train_set.columns, Reverse())
573 | losses = flow.train(
574 | train_set,
575 | val_set,
576 | verbose=True,
577 | epochs=3,
578 | best_params=False,
579 | initial_loss=False,
580 | )
581 | assert len(losses[0]) == 3
582 | assert len(losses[1]) == 3
--------------------------------------------------------------------------------
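
A minimal end-to-end sketch of the public Flow workflow that the tests above
exercise (the column names, epoch count, and file name here are illustrative,
not canonical):

    import numpy as np
    import pandas as pd

    from pzflow import Flow
    from pzflow.bijectors import Reverse

    # toy two-column training set
    data = pd.DataFrame(np.random.rand(100, 2), columns=("redshift", "y"))

    flow = Flow(("redshift", "y"), Reverse())
    losses = flow.train(data, epochs=11)  # initial loss + one loss per epoch
    samples = flow.sample(4, seed=0)      # 4 draws from the learned distribution
    pdfs = flow.posterior(data, column="y", grid=np.linspace(0, 2, 18))

    flow.save("example.pzflow.pkl")  # reload later with Flow(file="example.pzflow.pkl")
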
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import pytest
3 | from jax import random
4 |
5 | from pzflow.bijectors import *
6 | from pzflow.utils import *
7 |
8 |
9 | def test_build_bijector_from_info():
10 | x = jnp.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
11 |
12 | init_fun, info1 = Chain(
13 | Reverse(),
14 | Chain(ColorTransform(1, [1, 2, 3]), Roll(-1)),
15 | InvSoftplus(0, 1),
16 | Scale(-0.5),
17 | Chain(Roll(), Scale(-4.0)),
18 | )
19 |
20 | params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), 4)
21 | xfwd1, log_det1 = forward_fun(params, x)
22 |
23 | init_fun, info2 = build_bijector_from_info(info1)
24 | assert info1 == info2
25 |
26 | params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), 4)
27 | xfwd2, log_det2 = forward_fun(params, x)
28 | assert jnp.allclose(xfwd1, xfwd2)
29 | assert jnp.allclose(log_det1, log_det2)
30 |
31 | invx, inv_log_det = inverse_fun(params, xfwd2)
32 | assert jnp.allclose(x, invx)
33 | assert jnp.allclose(log_det2, -inv_log_det)
34 |
35 |
36 | def test_sub_diag_indices_correct():
37 | x = jnp.array([[[0, 0], [0, 0]], [[1, 1], [1, 1]], [[2, 2], [2, 2]]])
38 | y = jnp.array([[[1, 0], [0, 1]], [[2, 1], [1, 2]], [[3, 2], [2, 3]]])
39 | idx = sub_diag_indices(x)
40 | x = x.at[idx].set(x[idx] + 1)
41 |
42 | assert jnp.allclose(x, y)
43 |
44 |
45 | @pytest.mark.parametrize(
46 | "x",
47 | [jnp.ones(2), jnp.ones((2, 2)), jnp.ones((2, 2, 2, 2))],
48 | )
49 | def test_sub_diag_indices_bad_input(x):
50 | with pytest.raises(ValueError):
51 |         sub_diag_indices(x)
52 |
--------------------------------------------------------------------------------
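
The info round-trip that test_build_bijector_from_info checks is what makes
saved flows reconstructable: each bijector call returns an (init_fun, info)
pair, and the info alone suffices to rebuild an equivalent bijector. A minimal
sketch (this particular chain is illustrative):

    import jax.numpy as jnp
    from jax import random

    from pzflow.bijectors import Chain, Reverse, Scale
    from pzflow.utils import build_bijector_from_info

    init_fun, info = Chain(Reverse(), Scale(-0.5))
    rebuilt_init_fun, rebuilt_info = build_bijector_from_info(info)
    assert info == rebuilt_info

    # both versions initialize to the same forward transform
    x = jnp.ones((2, 4))
    params, forward_fun, _ = init_fun(random.PRNGKey(0), 4)
    params2, forward_fun2, _ = rebuilt_init_fun(random.PRNGKey(0), 4)
    assert jnp.allclose(forward_fun(params, x)[0], forward_fun2(params2, x)[0])
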
/tests/test_version.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import toml
4 |
5 | import pzflow
6 |
7 |
8 | def test_versions_are_in_sync():
9 | """Checks if the pyproject.toml and pzflow.__init__.py __version__ are in sync."""
10 |
11 | path = Path(__file__).resolve().parents[1] / "pyproject.toml"
12 |     pyproject = toml.load(str(path))  # load() opens and closes the file itself
13 | pyproject_version = pyproject["tool"]["poetry"]["version"]
14 |
15 | package_init_version = pzflow.__version__
16 |
17 | assert package_init_version == pyproject_version
18 |
--------------------------------------------------------------------------------