├── .github └── workflows │ ├── build_docs.yml │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── contributing.md ├── gen_ref_pages.py ├── gotchas.md ├── index.md ├── install.md └── tutorials │ ├── conditional_demo.ipynb │ ├── customizing_example.ipynb │ ├── dequantization.ipynb │ ├── ensemble_demo.ipynb │ ├── gaussian_errors.ipynb │ ├── index.md │ ├── intro.ipynb │ ├── marginalization.ipynb │ ├── nongaussian_errors.ipynb │ ├── spherical_flow_example.ipynb │ └── weighted.ipynb ├── mkdocs-requirements.txt ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml ├── pzflow ├── __init__.py ├── bijectors.py ├── distributions.py ├── example_files │ ├── checkerboard-data.pkl │ ├── city-data.pkl │ ├── example-flow.pzflow.pkl │ ├── galaxy-data.pkl │ └── two-moons-data.pkl ├── examples.py ├── flow.py ├── flowEnsemble.py └── utils.py ├── setup.cfg └── tests ├── test_bijectors.py ├── test_distributions.py ├── test_ensemble.py ├── test_examples.py ├── test_flow.py ├── test_utils.py └── test_version.py /.github/workflows/build_docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: 3.x 15 | - run: pip install -r mkdocs-requirements.txt 16 | - run: mkdocs gh-deploy --force 17 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # inspired by examples here: https://github.com/snok/install-poetry 3 | 4 | name: build 5 | 6 | on: 7 | push: 8 | branches: [main, dev] 9 | pull_request: 10 | branches: [main, dev] 11 | 12 | jobs: 13 | build: 14 | # 
---------------------------------------------- 15 | # test on linux; several python versions 16 | # ---------------------------------------------- 17 | strategy: 18 | fail-fast: true 19 | matrix: 20 | os: [ubuntu-latest] 21 | python-version: ["3.10", "3.11", "3.12"] 22 | runs-on: ${{ matrix.os }} 23 | 24 | steps: 25 | # ---------------------------------------------- 26 | # checkout repository 27 | # ---------------------------------------------- 28 | - name: Check out repo 29 | uses: actions/checkout@v4 30 | 31 | # ---------------------------------------------- 32 | # install and configure poetry 33 | # ---------------------------------------------- 34 | - name: Install poetry 35 | run: pipx install poetry 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | cache: "poetry" 41 | 42 | # ---------------------------------------------- 43 | # install root project, if required 44 | # ---------------------------------------------- 45 | - name: Install library 46 | run: poetry install --no-interaction 47 | 48 | # ---------------------------------------------- 49 | # run linters 50 | # ---------------------------------------------- 51 | #- name: Run linters 52 | # run: | 53 | # poetry run flake8 . 54 | # poetry run black . --check 55 | # poetry run isort . 
56 | 57 | # ---------------------------------------------- 58 | # run tests 59 | # ---------------------------------------------- 60 | - name: Test with pytest 61 | run: | 62 | poetry run pytest --cov=./pzflow --cov-report=xml -n 10 63 | 64 | # ---------------------------------------------- 65 | # upload coverage to Codecov 66 | # ---------------------------------------------- 67 | - name: Upload coverage to Codecov 68 | uses: codecov/codecov-action@v5 69 | with: 70 | file: ./coverage.xml 71 | flags: unittests 72 | env_vars: OS,PYTHON 73 | name: codecov-umbrella 74 | fail_ci_if_error: true 75 | token: ${{ secrets.CODECOV_TOKEN }} 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode settings 132 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 John Franklin Crenshaw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![build](https://github.com/jfcrenshaw/pzflow/workflows/build/badge.svg) 2 | [![codecov](https://codecov.io/gh/jfcrenshaw/pzflow/branch/main/graph/badge.svg?token=qR5cey0swQ)](https://codecov.io/gh/jfcrenshaw/pzflow) 3 | [![PyPI version](https://badge.fury.io/py/pzflow.svg)](https://badge.fury.io/py/pzflow) 4 | [![DOI](https://zenodo.org/badge/327498448.svg)](https://zenodo.org/badge/latestdoi/327498448) 5 | [![Docs](https://img.shields.io/badge/Docs-https%3A%2F%2Fjfcrenshaw.github.io%2Fpzflow%2F-red)](https://jfcrenshaw.github.io/pzflow/) 6 | 7 | # PZFlow 8 | 9 | PZFlow is a python package for probabilistic modeling of tabular data with normalizing flows. 10 | 11 | If your data consists of continuous variables that can be put into a Pandas DataFrame, pzflow can model the joint probability distribution of your data set. 12 | 13 | The `Flow` class makes building and training a normalizing flow simple. 14 | It also allows you to easily sample from the normalizing flow (e.g., for forward modeling or data augmentation), and calculate posteriors over any of your variables. 15 | 16 | There are several tutorial notebooks in the [docs](https://jfcrenshaw.github.io/pzflow/tutorials/). 17 | 18 | ## Installation 19 | 20 | See the instructions in the [docs](https://jfcrenshaw.github.io/pzflow/install/). 21 | 22 | ## Citation 23 | 24 | If you use this package in your research, please cite the following two sources: 25 | 26 | 1. The paper 27 | ```bibtex 28 | @ARTICLE{2024AJ....168...80C, 29 | author = {{Crenshaw}, John Franklin and {Kalmbach}, J. Bryce and {Gagliano}, Alexander and {Yan}, Ziang and {Connolly}, Andrew J. and {Malz}, Alex I. and {Schmidt}, Samuel J. 
and {The LSST Dark Energy Science Collaboration}}, 30 | title = "{Probabilistic Forward Modeling of Galaxy Catalogs with Normalizing Flows}", 31 | journal = {\aj}, 32 | keywords = {Neural networks, Galaxy photometry, Surveys, Computational methods, 1933, 611, 1671, 1965, Astrophysics - Instrumentation and Methods for Astrophysics, Astrophysics - Cosmology and Nongalactic Astrophysics}, 33 | year = 2024, 34 | month = aug, 35 | volume = {168}, 36 | number = {2}, 37 | eid = {80}, 38 | pages = {80}, 39 | doi = {10.3847/1538-3881/ad54bf}, 40 | archivePrefix = {arXiv}, 41 | eprint = {2405.04740}, 42 | primaryClass = {astro-ph.IM}, 43 | adsurl = {https://ui.adsabs.harvard.edu/abs/2024AJ....168...80C}, 44 | adsnote = {Provided by the SAO/NASA Astrophysics Data System} 45 | } 46 | ``` 47 | 48 | 2. The [Zenodo deposit](https://zenodo.org/records/10710271) associated with the version number you used. 49 | 50 | 51 | 52 | ### Sources 53 | 54 | PZFlow was originally designed for forward modeling of photometric redshifts as a part of the Creation Module of the [DESC](https://lsstdesc.org/) [RAIL](https://github.com/LSSTDESC/RAIL) project. 55 | The idea to use normalizing flows for photometric redshifts originated with [Bryce Kalmbach](https://github.com/jbkalmbach). 56 | The earliest version of the normalizing flow in RAIL was based on a notebook by [Francois Lanusse](https://github.com/eiffl) and included contributions from [Alex Malz](https://github.com/aimalz). 57 | 58 | The functional jax structure of the bijectors was originally based on [`jax-flows`](https://github.com/ChrisWaites/jax-flows) by [Chris Waites](https://github.com/ChrisWaites). The implementation of the Neural Spline Coupling is largely based on the [Tensorflow implementation](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/rational_quadratic_spline.py), with some inspiration from [`nflows`](https://github.com/bayesiains/nflows/). 
59 | 60 | Neural Spline Flows are based on the following papers: 61 | 62 | > [NICE: Non-linear Independent Components Estimation](https://arxiv.org/abs/1410.8516)\ 63 | > Laurent Dinh, David Krueger, Yoshua Bengio\ 64 | > _arXiv:1410.8516_ 65 | 66 | > [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)\ 67 | > Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio\ 68 | > _arXiv:1605.08803_ 69 | 70 | > [Neural Spline Flows](https://arxiv.org/abs/1906.04032)\ 71 | > Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios\ 72 | > _arXiv:1906.04032_ 73 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing to PZFlow 2 | 3 | If you notice any bugs, have any questions, or would like to request a feature, please [submit an issue](https://github.com/jfcrenshaw/pzflow/issues). 4 | 5 | To work on pzflow, after forking and cloning the repo: 6 | 7 | 1. Create a virtual environment with Python 8 | E.g., with conda `conda create -n pzflow python=3.12` 9 | 2. Activate the environment. 10 | E.g., `conda activate pzflow` 11 | 3. 
Install pzflow in edit mode with the `dev` flag 12 | I.e., in the root directory, `pip install -e ".[dev]"` 13 | -------------------------------------------------------------------------------- /docs/gen_ref_pages.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages.""" 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | # recurse through python files in pzflow 8 | for path in sorted(Path("pzflow").rglob("*.py")): 9 | # get the module path (including the pzflow prefix), without the .py suffix 10 | module_path = path.with_suffix("") 11 | # path where we'll save the markdown file for this module 12 | doc_path = "API" / path.relative_to("pzflow").with_suffix(".md") 13 | 14 | # split up the path into its parts 15 | parts = list(module_path.parts) 16 | 17 | # we don't want to explicitly list __init__ 18 | if parts[-1] in ["__init__"]: 19 | #continue 20 | parts = parts[:-1] 21 | doc_path = doc_path.with_name("index.md") 22 | # skip main files 23 | elif parts[-1] == "__main__": 24 | continue 25 | 26 | # create the markdown file in the docs directory 27 | with mkdocs_gen_files.open(doc_path, "w") as fd: 28 | identifier = ".".join(parts) 29 | print(doc_path, identifier) 30 | print("::: " + identifier, file=fd) 31 | 32 | mkdocs_gen_files.set_edit_path(doc_path, path) 33 | -------------------------------------------------------------------------------- /docs/gotchas.md: -------------------------------------------------------------------------------- 1 | # Common gotchas 2 | 3 | * It is important to note that there are two different conventions in the literature for the direction of the bijection in normalizing flows. PZFlow defines the bijection as the mapping from the data space to the latent space, and the inverse bijection as the mapping from the latent space to the data space. This distinction can be important when designing more complicated bijections (e.g., in Example 2 above). 
4 | 5 | * If you get NaNs during training, try decreasing the learning rate (the default is 1e-3): 6 | 7 | ```python 8 | import optax 9 | 10 | opt = optax.adam(learning_rate=...) 11 | flow.train(..., optimizer=opt) 12 | ``` 13 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ![build](https://github.com/jfcrenshaw/pzflow/workflows/build/badge.svg) 2 | [![codecov](https://codecov.io/gh/jfcrenshaw/pzflow/branch/main/graph/badge.svg?token=qR5cey0swQ)](https://codecov.io/gh/jfcrenshaw/pzflow) 3 | [![PyPI version](https://badge.fury.io/py/pzflow.svg)](https://badge.fury.io/py/pzflow) 4 | [![DOI](https://zenodo.org/badge/327498448.svg)](https://zenodo.org/badge/latestdoi/327498448) 5 | [![Docs](https://img.shields.io/badge/Docs-https%3A%2F%2Fjfcrenshaw.github.io%2Fpzflow%2F-red)](https://jfcrenshaw.github.io/pzflow/) 6 | 7 | # PZFlow 8 | 9 | PZFlow is a python package for probabilistic modeling of tabular data with normalizing flows. 10 | 11 | If your data consists of continuous variables that can be put into a Pandas DataFrame, pzflow can model the joint probability distribution of your data set. 12 | 13 | The `Flow` class makes building and training a normalizing flow simple. 14 | It also allows you to easily sample from the normalizing flow (e.g., for forward modeling or data augmentation), and calculate posteriors over any of your variables. 15 | 16 | There are several tutorial notebooks in the [docs](https://jfcrenshaw.github.io/pzflow/tutorials/). 17 | 18 | ## Installation 19 | 20 | See the instructions in the [docs](https://jfcrenshaw.github.io/pzflow/install/). 21 | 22 | ## Citation 23 | 24 | We are preparing a paper on pzflow. 25 | If you use this package in your research, please check back here for a citation before publication. 26 | In the meantime, please cite the [Zenodo release](https://zenodo.org/badge/latestdoi/327498448). 
27 | 28 | ### Sources 29 | 30 | PZFlow was originally designed for forward modeling of photometric redshifts as a part of the Creation Module of the [DESC](https://lsstdesc.org/) [RAIL](https://github.com/LSSTDESC/RAIL) project. 31 | The idea to use normalizing flows for photometric redshifts originated with [Bryce Kalmbach](https://github.com/jbkalmbach). 32 | The earliest version of the normalizing flow in RAIL was based on a notebook by [Francois Lanusse](https://github.com/eiffl) and included contributions from [Alex Malz](https://github.com/aimalz). 33 | 34 | The functional jax structure of the bijectors was originally based on [`jax-flows`](https://github.com/ChrisWaites/jax-flows) by [Chris Waites](https://github.com/ChrisWaites). The implementation of the Neural Spline Coupling is largely based on the [Tensorflow implementation](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/rational_quadratic_spline.py), with some inspiration from [`nflows`](https://github.com/bayesiains/nflows/). 
35 | 36 | Neural Spline Flows are based on the following papers: 37 | 38 | > [NICE: Non-linear Independent Components Estimation](https://arxiv.org/abs/1410.8516)\ 39 | > Laurent Dinh, David Krueger, Yoshua Bengio\ 40 | > _arXiv:1410.8516_ 41 | 42 | > [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)\ 43 | > Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio\ 44 | > _arXiv:1605.08803_ 45 | 46 | > [Neural Spline Flows](https://arxiv.org/abs/1906.04032)\ 47 | > Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios\ 48 | > _arXiv:1906.04032_ 49 | -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Install 2 | 3 | You can install PZFlow from PyPI with pip: 4 | 5 | ```shell 6 | pip install pzflow 7 | ``` 8 | 9 | If you want to run PZFlow on a GPU with CUDA, you need to follow the GPU-enabled installation instructions for Jax [here](https://github.com/google/jax). 10 | You may also need to add cuda to your path. 11 | For example, I needed to add the following to my `.bashrc`: 12 | 13 | ```shell 14 | # cuda setup 15 | export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 16 | export PATH=$PATH:/usr/local/cuda/bin 17 | ``` 18 | 19 | If you have the GPU enabled version of jax installed, but would like to run on a CPU, add the following to the top of your scripts/notebooks: 20 | 21 | ```python 22 | import jax 23 | # Global flag to set a specific platform, must be used at startup. 24 | jax.config.update('jax_platform_name', 'cpu') 25 | ``` 26 | 27 | Note that if you run jax on GPU in multiple Jupyter notebooks simultaneously, you may get `RuntimeError: cuSolver internal error`. Read more [here](https://github.com/google/jax/issues/4497) and [here](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html). 
28 | -------------------------------------------------------------------------------- /docs/tutorials/index.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | Below are example notebooks demonstrating how to use PZFlow. 4 | Each contains a link to open the notebook on Google Colab, as well as a link to the source code on Github. 5 | 6 | ### Basic 7 | 8 | - [Introduction to PZFlow](intro.ipynb) - using the default flow to train, sample, and calculate posteriors 9 | - [Conditional Flows](conditional_demo.ipynb) - building a conditional flow to model conditional distributions 10 | - [Convolving Gaussian Errors](gaussian_errors.ipynb) - convolving Gaussian errors during training and posterior calculation 11 | - [Flow Ensembles](ensemble_demo.ipynb) - using `FlowEnsemble` to create an ensemble of normalizing flows 12 | - [Training Weights](weighted.ipynb) - giving different weights to your training samples 13 | 14 | ### Intermediate 15 | 16 | - [Customizing the flow](customizing_example.ipynb) - Customizing the bijector and latent space 17 | - [Modeling Variables with Periodic Topology](spherical_flow_example.ipynb) - using circular splines to model data with periodic topology, e.g. 
positions on a sphere 18 | 19 | ### Advanced 20 | 21 | - [Marginalizing Variables](marginalization.ipynb) - marginalizing over missing variables during posterior calculation 22 | - [Convolving Non-Gaussian Errors](nongaussian_errors.ipynb) - convolving non-Gaussian errors during training and posterior calculation 23 | -------------------------------------------------------------------------------- /docs/tutorials/nongaussian_errors.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jfcrenshaw/pzflow/blob/main/docs/tutorials/nongaussian_errors.ipynb)\n", 8 | "[![Open on Github](https://img.shields.io/badge/github-Open%20on%20Github-black?logo=github)](https://github.com/jfcrenshaw/pzflow/blob/main/docs/tutorials/nongaussian_errors.ipynb)\n", 9 | "\n", 10 | "If running in Colab, to switch to GPU, go to the menu and select Runtime -> Change runtime type -> Hardware accelerator -> GPU.\n", 11 | "\n", 12 | "In addition, uncomment and run the following code:" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# !pip install pzflow matplotlib" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "------------------\n", 29 | "## Convolving non-Gaussian errors\n", 30 | "\n", 31 | "This notebook demonstrates how to train a flow on data that has non-Gaussian errors/uncertainties, as well as convolving those errors in log_prob and posterior calculations.\n", 32 | "We will use the example galaxy data again.\n", 33 | "\n", 34 | "For an example of how to handle Gaussian errors, which is much easier, see the notebook on Gaussian errors." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 5, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "import numpy as np\n", 44 | "import jax.numpy as jnp\n", 45 | "from jax import random\n", 46 | "import matplotlib.pyplot as plt\n", 47 | "\n", 48 | "from pzflow import Flow\n", 49 | "from pzflow.examples import get_galaxy_data" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 6, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "plt.rcParams[\"figure.facecolor\"] = \"white\"" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "First let's load the example galaxy data set included with PZFlow, and include photometric errors. For simplicity, we will assume all bands have error = 0.1" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 7, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/html": [ 76 | "
\n", 77 | "\n", 90 | "\n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | "
redshiftugrizyu_errg_errr_erri_errz_erry_err
00.28708726.75926125.90177825.18771024.93231824.73690324.6716230.10.10.10.10.10.1
10.29331327.42835826.67929925.97716125.70009425.52276325.4176320.10.10.10.10.10.1
21.49727627.29400126.06879825.45005524.46050723.88722123.2061120.10.10.10.10.10.1
30.28331028.15407526.28316624.59957023.72349123.21410822.8600120.10.10.10.10.10.1
41.54518329.27606527.87830127.33352826.54337426.06194125.3830560.10.10.10.10.10.1
\n", 192 | "
" 193 | ], 194 | "text/plain": [ 195 | " redshift u g r i z y \\\n", 196 | "0 0.287087 26.759261 25.901778 25.187710 24.932318 24.736903 24.671623 \n", 197 | "1 0.293313 27.428358 26.679299 25.977161 25.700094 25.522763 25.417632 \n", 198 | "2 1.497276 27.294001 26.068798 25.450055 24.460507 23.887221 23.206112 \n", 199 | "3 0.283310 28.154075 26.283166 24.599570 23.723491 23.214108 22.860012 \n", 200 | "4 1.545183 29.276065 27.878301 27.333528 26.543374 26.061941 25.383056 \n", 201 | "\n", 202 | " u_err g_err r_err i_err z_err y_err \n", 203 | "0 0.1 0.1 0.1 0.1 0.1 0.1 \n", 204 | "1 0.1 0.1 0.1 0.1 0.1 0.1 \n", 205 | "2 0.1 0.1 0.1 0.1 0.1 0.1 \n", 206 | "3 0.1 0.1 0.1 0.1 0.1 0.1 \n", 207 | "4 0.1 0.1 0.1 0.1 0.1 0.1 " 208 | ] 209 | }, 210 | "execution_count": 7, 211 | "metadata": {}, 212 | "output_type": "execute_result" 213 | } 214 | ], 215 | "source": [ 216 | "data = get_galaxy_data()\n", 217 | "for col in data.columns[1:]:\n", 218 | " data[f\"{col}_err\"] = 0.1 * np.ones(data.shape[0])\n", 219 | "data.head()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "Now, we need to build the machinery to handle non-Gaussian errors. PZFlow convolves errors by sampling from an error model, which by default is Gaussian. If we want to convolve non-Gaussian errors, we need to pass the `Flow` constructor an error model that tells it how to sample from our non-Gaussian distribution.\n", 227 | "\n", 228 | "The error model must be a callable that takes the following arguments:\n", 229 | "- key is a jax rng key, e.g. 
jax.random.PRNGKey(0)\n", 230 | "- X is a 2 dimensional array of data variables, where the order of variables matches the order of the columns in data_columns\n", 231 | "- Xerr is the corresponding 2 dimensional array of errors\n", 232 | "- nsamples is the number of samples to draw from the error distribution\n", 233 | "\n", 234 | "and it must return an array of samples with the shape `(X.shape[0], nsamples, X.shape[1])`\n", 235 | "\n", 236 | "Below we build a photometric error model, which takes the exponential of the magnitudes to convert to flux values, adds Gaussian flux errors, then takes the log to convert back to magnitudes. " 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 9, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "def photometric_error_model(\n", 246 | " key,\n", 247 | " X: np.ndarray,\n", 248 | " Xerr: np.ndarray,\n", 249 | " nsamples: int,\n", 250 | ") -> np.ndarray:\n", 251 | " \n", 252 | " # calculate fluxes\n", 253 | " F = 10 ** (X / -2.5)\n", 254 | " # calculate flux errors\n", 255 | " dF = np.log(10) / 2.5 * F * Xerr\n", 256 | " \n", 257 | " # add Gaussian errors\n", 258 | " eps = random.normal(key, shape=(F.shape[0], nsamples, F.shape[1]))\n", 259 | " F = F[:, None, :] + eps * dF[:, None, :]\n", 260 | " \n", 261 | " # add a flux floor to avoid infinite magnitudes\n", 262 | " # this flux corresponds to a max magnitude of 30\n", 263 | " F = np.clip(F, 1e-12, None)\n", 264 | " \n", 265 | " # calculate magnitudes\n", 266 | " M = -2.5 * np.log10(F)\n", 267 | " \n", 268 | " return M" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "Now we can construct the Flow, this time passing the error model" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 10, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "flow = Flow(\n", 285 | " [\"redshift\"] + list(\"ugrizy\"), \n", 286 | " 
data_error_model=photometric_error_model,\n", 287 | ")" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "Now that we have set up the Flow with the new error model, we can train the flow, calculate posteriors, etc. just like we did in the Gaussian error example.\n", 295 | "\n", 296 | "For example, to train with error convolution:" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 11, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "name": "stderr", 306 | "output_type": "stream", 307 | "text": [ 308 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" 309 | ] 310 | }, 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "Training 200 epochs \n", 316 | "Loss:\n", 317 | "(0) 20.3297\n", 318 | "(1) 2.9366\n", 319 | "(11) 0.1215\n", 320 | "(21) -0.1496\n", 321 | "(31) -0.0601\n", 322 | "(41) -0.1246\n", 323 | "(51) -0.2496\n", 324 | "(61) -0.0759\n", 325 | "(71) 0.1112\n", 326 | "(81) -0.2375\n", 327 | "(91) -0.2352\n", 328 | "(101) -0.2811\n", 329 | "(111) -0.2134\n", 330 | "(121) -0.2033\n", 331 | "(131) -0.2952\n", 332 | "(141) -0.2706\n", 333 | "(151) -0.2759\n", 334 | "(161) -0.2222\n", 335 | "(171) -0.1448\n", 336 | "(181) -0.2966\n", 337 | "(191) -0.1698\n", 338 | "(200) -0.2314\n", 339 | "CPU times: user 1h 8min 26s, sys: 45min 42s, total: 1h 54min 8s\n", 340 | "Wall time: 22min 21s\n" 341 | ] 342 | } 343 | ], 344 | "source": [ 345 | "%%time\n", 346 | "losses = flow.train(data, epochs=200, convolve_errs=True, verbose=True)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "And to calculate posteriors with error convolution:" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 12, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "grid = np.linspace(0, 3, 100)\n", 363 | "pdfs = 
flow.posterior(data[:10], column=\"redshift\", grid=grid, err_samples=int(1e3))" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 13, 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "data": { 373 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAHtCAYAAAA0tCb7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAABJ0AAASdAHeZh94AABglklEQVR4nO3deXxTZb4/8M9J0qTpmpa2lKUt0IJIUXDDBQVUxFEHRERc74ArOuoMDnoVZxTwil6vo44/HRfUERe84wqjeNVRZ9RRkcFdkJ0WCnShS7pnP78/knNyQkvXJCfP6ef9evVlm6bJQ+zTT77PdiRZlmUQERFRwjPp3QAiIiLqGYY2ERGRIBjaREREgmBoExERCYKhTUREJAiGNhERkSAY2kRERIJgaBMREQmCoU1ERCQISzQexOl04tNPP0VBQQFsNls0HpKIesjtdqOiogJTp06Fw+Ho9c+z/xLpp9f9V46CtWvXygD4wQ9+6Pixdu1a9l9+8EPQj57236hU2gUFBQCAtWvXoqSkJBoP2SMH7liC9k2bYB8/HkP/+/64PS9RItm5cydmz56t9sPeYv8l0k9v+29UQlsZUispKUFpaWk0HrJH0hwOtNlsSHE4UBTH5yVKRH0d2mb/JdJfT/svF6IREREJgqFNREQkCIY2ERGRIAwV2rIso6y2FS6vX++mEFEU7KlrRWObV+9mECWMqCxESxQfbK7C9S9/i+LcVLz168nItCf1+bF8Ph8aGhrQ0tICWZaj2Eqi7kmSBJvNhoyMDKSmpkKSJL2bFHff7KnHhU+uxzCHHR8vnorkJLPeTSLSnaEq7fc3VQEAdh1sxX++8UOfw1aWZezbtw+1tbXwevkun+LP7/ejsbERFRUVqKmpGZBvHP9d1gAA2O9sx7/L6nVuDVFiMFSl3eYJD4t/sLkaf/miHFefOrLXj9Pc3Iz29nZkZmZiyJAhA7LKIf15PB5UVlaivr4eqampSEtL07tJcVXT7FI//3xnLaaMydWxNUSJwVCVdl2rJ+LrP36wrU/z201NTQCAvLw8Bjbpxmq1YsiQIQDCv5MDSU2zW/38XztqdWwJUeIwVGhr35kDQLvXj21Vzb1+HK/XC4vFAovFUAMRJCCr1YqkpCS43e7u72wwBzWhvaWyKeJrooHKMKEty7Laqc8Ym6fe/nNl7ysUWZZhMhnmpSHBSZI0IOe0Dw3pL3ex2iYyTDK1uH1weQMAgONHZCHFGlxpuvlAY58ej8PilCgG6u9iTVPkyBmHyIkMFNrad+X5GckYm58OANh8YODNBRKJrtXtQ6sncj3K5ztqB+SIA5GWYUJbu2glLz0ZpUMzAQBbK5vhD7CjE4lE+ya8dGgGAKCqyYXvK5w6tYgoMRgmtLWdPDfdpnb0dq8fZbWtejWLYmzVqlWQJAnl5eV6N4WiSPsmfP4pI2AxBacI/vJFuU4tIkoMhg3tcaHQBvo+r21UStApH8nJyRgzZgxuuukmVFdXR/W52trasGzZMnzyySdRfVwyNm1/PmpYJs47Orj17f9+qsQBZ7tezSLSnXFCuyXYyS0mCQ57EsYMToc59O68LyvIB4J77rkHL730Eh5//HGccsopePLJJ3HyySejra0tas/R1taG5cuXxyy0/+M//gPt7e0oKiqKyeOTPrTbN3PTbeohSf6AjBfWl+vUKiL9GSa0a5qCoZ2bboPJJCE5yYzRecETpH7mYrROnXPOObjiiitwzTXXYNWqVVi0aBHKy
srwt7/9Te+mdau1NTjlYTabkZycHLUV1tF8w0J9pwyPW0wSslOsOHq4AyeMyAIA/O+GvWjz+PRsHpFuDBPaSqWdm25Tbxs3JDhEvvlAE1ed9sAZZ5wBACgrK4PP58N//dd/obi4GDabDSNGjMCdd97Z4ZCPr7/+GmeffTZycnJgt9sxcuRIXHXVVQCA8vJy5OYGj55cvny5Ohy/bNky9ee3bt2KuXPnIjs7G8nJyTj++OPx9ttvRzyHMpz/6aef4te//jXy8vIwfPjwiO8dOqf9xBNPoLS0FDabDUOHDsWNN94Ip9MZcZ9p06Zh/Pjx+OabbzBlyhSkpKTgzjvv7O/LSFGgDI/npAXfhAPAVZOD1XaTy4ePt9To1jYiPRnmyC+lk+emaUJ7aAbe+m4/6ls9qGv1IEfzPepo165dAIBBgwbhmmuuwQsvvIC5c+di8eLF2LBhA+6//35s2bIFa9asAQDU1NRgxowZyM3NxR133AGHw4Hy8nK89dZbAIDc3Fw8+eSTuOGGG3DBBRdgzpw5AICjjz4aALB582ZMnjwZw4YNwx133IHU1FS89tprmD17Nt58801ccMEFEe379a9/jdzcXNx9991qpd2ZZcuWYfny5Zg+fTpuuOEGbNu2DU8++SQ2btyIL774AklJ4au/1dXV4ZxzzsEll1yCK664AoMHD47eC0p9plTa2jfhZxyZh/RkC5pdPry3qRIzJwzVq3lEujFeaGs6+ZBMe8T3+xvay9/ZnFBD7eOGZmDpzNI+/3xjYyNqa2vhcrnwxRdf4J577oHdbsfYsWNx/fXX45prrsEzzzwDAGqF+8c//hH//Oc/cfrpp+PLL79EQ0MD/v73v+P4449XH/fee+8FAKSmpmLu3Lm44YYbcPTRR+OKK66IeP7f/va3KCwsxMaNG2Gz2dTnOfXUU3H77bd3CO3s7Gx8/PHHMJsPf4nGgwcP4v7778eMGTPw3nvvqSfbjR07FjfddBNefvllXHnller9q6qq8NRTT2HhwoV9fh0p+pSDVfI0/dlmMeOsIwfjre/2459bD6LN40OK1TB/woh6xBC/8TKAutZgaGs7eU6aVf28rsVz6I/12s8HmrDBQJcInD59esTXRUVFWL16Nb788ksAwO9+97uI7y9evBh//OMf8e677+L000+Hw+EAAKxbtw4TJkyIqGC7U19fj3/84x+455570NzcjObm8BnxZ599NpYuXYr9+/dj2LBh6u3XXnttl4ENAB999BE8Hg8WLVoUcRTttddeizvvvBPvvvtuRGjbbLaIrykx1Iamu/IyIt9on3PUELz13X60e/34dNtBnHPUED2aR6QbQ4S21x+AMmWtrbQHaSpr5Y9Af2i3kSWC/rbnz3/+M8aMGQOLxYLBgwfjiCOOgMlkwpo1a2AymVBSUhJx//z8fDgcDuzZswcAMHXqVFx44YVYvnw5HnnkEUybNg2zZ8/GZZddplbOh7Nz507Isoy77roLd911V6f3qampiQjtkSO7v8yq0rYjjjgi4nar1YpRo0ap31cMGzYMVqsVlDh8/oB6xb7cQ0bHThudg1SrGa0eP979qZKhTQOOYUJbkXuYSjsaod2foehENGnSpIhh7UN1tyJbkiS88cYb+Oqrr/DOO+/ggw8+wFVXXYWHHnoIX331VZfXfw4Egv/Pbr31Vpx99tmd3ufQNw12u73T+/VHLB6T+qe2xRN+E56RHPG95CQzzjxyMN7+4QD+sbUGLq8fyUldj74QGYkhVo97/eGV4drQzrQnqScp1UZheHygKCoqQiAQwI4dOyJur66uhtPp7LAn+qSTTsKKFSvw9ddfY/Xq1di8eTP++te/Ajh88I8aNQoAkJSUhOnTp3f6kZ6e3qe2A8C2bdsibvd4PCgrK+N+bgEcjDiSuOOIzblH5QMA2jx+rN9dF7d2ESUCQ4S2xxeutPPSw+/MJUnCoFC1XReFSnugOPfccwEAf/rTnyJuf/jhhwEA5513H
gCgoaGhw1a6iRMnAoC6NSwlJQUAOmy3ysvLw7Rp0/D000+jsrKyQxsOHjzYp7ZPnz4dVqsV/+///b+Itj333HNobGxU206J69CDVQ518qgc9fPv9jTEpU1EicJww+OHrhDPSbOhusmtzpFR9yZMmID58+dj5cqVcDqdmDp1Kv7973/jhRdewOzZs3H66acDAF544QU88cQTuOCCC1BcXIzm5mY888wzyMjIUIPfbrdj3LhxePXVVzFmzBhkZ2dj/PjxGD9+PP785z/j1FNPxVFHHYVrr70Wo0aNQnV1NdavX499+/bhhx9+6HXbc3NzsWTJEixfvhy/+MUvMGvWLGzbtg1PPPEETjjhhA4r2Cnx1HRTaWemJKEkLw07a1rw7V5nHFtGpD9DhXaazQK7NXJ+S1mMFo057YHk2WefxahRo7Bq1SqsWbMG+fn5WLJkCZYuXareRwnzv/71r6iurkZmZiYmTZqE1atXRywae/bZZ3HzzTfjlltugcfjwdKlSzF+/HiMGzcOX3/9NZYvX45Vq1ahrq4OeXl5OOaYY3D33Xf3ue3Lli1Dbm4uHn/8cdxyyy3Izs7Gddddh/vuu69XK9xJH1WNXVfaAHBsoQM7a1rwfYUT/oCsHllMZHQGCe3gMGhn78pz1OFxVtqKBQsWYMGCBV3ex2Kx4O677+4yPI855hi88sor3T7fySefjK+//rrT740aNQovvPBCn9t7uO/deOONuPHGG7t8XF7EJDFV1AePks3PSIbN0vkis2MKs/Da1/vQ4vZhZ00Ljsjv/foHIhEZY047VGl3dniKctvBFjePMiUSwJ5QaBdmpxz2PscWZqmff7uX89o0cBgitJXh8Zz0jvttlUrb4wugxc2LDBAlur1KaA86fGiPzktDui04UPgtF6PRAGKM0PYdvtIelKo9YIVD5ESJrM3jU7d8FXVRaZtMEiYWOgAA31U449AyosRgiND2B4LD3oeengRA3fIFcNsXUaJTqmyg60obCM5rA8DOmhY0tnlj2i6iRGGI0FbkdLoQLbpHmRJR7Oyt04R2F5U2ABwTqrQB4Pt9zhi1iCixGCu0u1iIBnB4nCjRaSvtokGpXd63VHP2/q6alpi1iSiRGCq0O9vTmZ0a3St9EVHsKKGdZrMgK6XrPfW5aTakhRajldUe/vrqREZiqNDWXiBEYbWYkGkPdn4OjxMltj114e1ePblgzcicYDXO0KaBwmCh3fnpSeoBK60MbaJEVtGDPdpaIxjaNMAYJrTTbZbDXqJPPcq0mcPjRInKH5BR0RAM7aJuVo4rlEr7QGM7XF5/zNpGlCgME9qHO6MYCG8Fq2WlTZSwKhvb1SOJu9vupRgVCm1ZDg+tExmZYUL7cEPjQHivdm0zQ3sgKy8vhyRJWLVqVa9+bsSIEfjlL3/Z7f0++eQTSJLU4Uzzl156CWPHjkVSUhIcDkevnnsgidij3cPhcaXSBoCyWq4gJ+MzTmh3coSpQjkVrcnli7j29kAjSVKPPnghjejZunUrFixYgOLiYjzzzDNYuXIl2trasGzZMr7Oh9Du0S7K7nq7l2JERGiz0ibjM8RVvoDOT0NTaAO9rtWNIZn2eDQp4bz00ksRX7/44ov48MMPO9x+5JFHxrNZhjFlyhS0t7fDag3/vn3yyScIBAJ49NFHUVJSAgCora3F8uXLAQDTpk3To6kJSZnPNpskDHUk9+hnMu1JGJRqRV2rh5U2DQhCh3ZAc9WurobH8zPCfwAOOF0DNrSvuOKKiK+/+uorfPjhhx1uP1RbWxtSUno2XBkLsizD5XLBbk/s/28mkwnJyZFhU1NTAwAcFu+B6qbg9FVeug0Wc88HAUfmpIZCmyvIyfiEHh5XFq0AnR9hqtDOj1XUcwitK9OmTcP48ePxzTffYMqUKUhJScGdd94JIDi8vmzZsg4/M2LEiA7XtHY6nVi0aBEKCgpgs9lQUlKCBx54AIFA99MTyhzyBx98gOOPPx52ux1PP
/10rx7X6XRiwYIFyMzMhMPhwPz58+F0Ojs8V1VVFa688koMHz4cNpsNQ4YMwfnnn4/y8vIO9/38888xadIkJCcnY9SoUXjxxRcjvn/onPaIESOwdOlSAEBubi4kScKCBQuQm5sLAFi+fLk6JdHZ6zrQVDe5AARDuzfC277Yt8n4hK60vf6A+q6jq+Hx4Vnh0OYK0+7V1dXhnHPOwSWXXIIrrrgCgwcP7tXPt7W1YerUqdi/fz8WLlyIwsJCfPnll1iyZAkqKyvxpz/9qdvH2LZtGy699FIsXLgQ1157LY444ogeP64syzj//PPx+eef4/rrr8eRRx6JNWvWYP78+R2e58ILL8TmzZtx8803Y8SIEaipqcGHH36IvXv3YsSIEer9du7ciblz5+Lqq6/G/Pnz8Ze//AULFizAcccdh9LS0k7/DX/605/w4osvYs2aNXjyySeRlpaGo446CieddBJuuOEGXHDBBZgzZw4A4Oijj+7Va2xENUqlndGzoXGFshittsWNJpcXGcldn6RGJDLhQ1uJ6q4qbbvVjLx0G2qa3RErVKlzVVVVeOqpp7Bw4cI+/fzDDz+MXbt24bvvvsPo0aMBAAsXLsTQoUPx4IMPYvHixSgoKOjyMXbu3In3338fZ599tnrbvffe26PHffvtt/HZZ5/hf/7nf3DbbbcBAG644QacfvrpEc/hdDrx5Zdf4sEHH8Stt96q3r5kyZIO7dm2bRs+++wznHbaaQCAefPmoaCgAM8//zz++Mc/dvpvmD17Nr7//nusWbMGc+fORU5ODgBg2LBhuOGGG3D00Ud3OzUxkNQ0963SHqVZjFZe24qjhzui2SyihCJ4aMvh0O7kCFOtwuwU1DS7+zU8XnXffXBv2drnn48225FjkR8auo7q49psuPLKK/v886+//jpOO+00ZGVloba2Vr19+vTp+O///m989tlnuPzyy7t8jJEjR0YEdm8e9//+7/9gsVhwww03qPcxm824+eab8a9//Uu9zW63w2q14pNPPsHVV1+NrKysw7Zn3LhxamADweHuI444Art37+7+BaFuuX1+NIQurzm4t5V2rnYFOUObjE3w0A7PY3a1EA0IhvbXexr6VWm7t2xF28aNff55UQwbNixiBXRv7dixAz/++KM6d3soZXFWV0aOHNnnx92zZw+GDBmCtLS0iO8fccQREV/bbDY88MADWLx4MQYPHoyTTjoJv/zlL/GrX/0K+fn5EfctLCzs8HxZWVloaGjo9t9C3VOGxgFgcEbvKu0hGeEFigd5FgMZnCFC22ySDnuEqaIgtBitqskFl9ff7f07YztybO8bGUOxak9vV2n7/ZHHRwYCAZx11ln4z//8z07vP2bMmD61IRqPe6hFixZh5syZWLt2LT744APcdddduP/++/GPf/wDxxxzjHo/s7nz3xdZs4OB+q5GE7Z56b2rtNOTLTCbJPgDMhraeFQxGZvgoR38g5nUg+0h2hXk+xraUZKX1sW9OxeLoWiRZGVldViB7fF4UFlZGXFbcXExWlpaMH369Kg+f08ft6ioCB9//DFaWloiqu1t27Yd9nEXL16MxYsXY8eOHZg4cSIeeughvPzyy1Ftv6K7q1cNRDWhleMAkNfLSttkkpCVkoTaFg/qW73RbhpRQhF8y1ew0u5RaA/itq/+Ki4uxmeffRZx28qVKztU2vPmzcP69evxwQcfdHgMp9MJn8/Xp+fv6eOee+658Pl8ePLJJ9Xv+/1+PPbYYxE/09bWBpfLFXFbcXEx0tPT4XbHbphV2fPe2Ra0gapaE9q9ndMGgKyU4HROQysrbTI2wSvtYGhbzN1XLkXZ2m1fPIShL6655hpcf/31uPDCC3HWWWfhhx9+wAcffKCuilbcdtttePvtt/HLX/5S3RbV2tqKn376CW+88QbKy8s7/ExP9PRxZ86cicmTJ+OOO+5AeXk5xo0bh7feeguNjY0Rj7d9+3aceeaZmDdvHsaNGweLxYI1a
9aguroal1xySb9eq67Y7XaMGzcOr776KsaMGYPs7GyMHz8e48ePj9lzJjpleNxikpCd0vv1FFmpwZ+p5/A4GZzQoe3rxfB4broNNosJbl8Ae+vbY900Q7r22mtRVlaG5557Du+//z5OO+00fPjhhzjzzDMj7peSkoJPP/0U9913H15//XW8+OKLyMjIwJgxY7B8+XJkZmb26fl7+rgmkwlvv/02Fi1ahJdffhmSJGHWrFl46KGHIuapCwoKcOmll+Ljjz/GSy+9BIvFgrFjx+K1117DhRde2PcXqgeeffZZ3Hzzzbjlllvg8XiwdOnSAR3aymlouek2mEy9nz4YlMpKmwYGSY7CSprNmzdj/Pjx2LRp02EPmoi2QEDGG1Nm4qjaXagrKcWp697o9mfOevhT7KhpwfQjB+PZ+ccf9n7KNp5Ro0ZFrb1EfdXd72N/+58e/RcA9vzHr9C2cSNSTjgBf5h2I/61oxYTChz4242Te/1Yd675Ca9s2ItBqVZ8c9dZMWgtUWz0tv8JO6fd2O5VV+5aTD37ZyiL0TinTZRYajTnjveFMqTe0OZBIMAV/WRcwoZ2nWYYLKkHc9pAeNvX3vo2btUhSiDVodPQertHW6HMaQdkoMnFFeRkXOKGdkt4dW9P5rSBcKXd7vWjtoVzX0SJICDLcIZOQ+vtHm1Fdmr4vPF6zmuTgQkb2tqO2ZPV4wBQpNn2xcv4ESUGj+ZqfX2utDUrznnAChmZsKEdOTzes3/GmMHp6ufbqpqi3iYi6j3tccS9vcKXYlBqOOzrOIpGBiZsaEdU2j3cIjLMYUeqNXgc5daq5i7vyzlvShRG/130+DSh3ceFaFma4XFW2mRkwoa2dk7b1MNjIU0mCUfkB6vtrkLbZDLB7/cb/o8lJT5ZluH3+w199Km20u7LaWgAkJ0aHh7nUaZkZOKGdh8Xm4wdkgEA2FbVfNhQttls8Pv9qKmpYXCTbnw+HyorK+H3+ztcscxIlEq7r6ehAYA9yQybJfjnjJU2GZmwJ6L1dYXo2FCl3eL2YV9Du7oNTGvw4MFwu92or69HY2MjzGazoSsdSiyyLCMQCKhnqaekpHR5rW/RKZV2TlrfTkMDghdhyU61orLRxdXjZGjCVtp9D+0M9fPDDZGbTCYUFhbC4XDAarUysCmuJEmCxWJBeno6hg0bhsLCQlgswr6/7pZytb5BaX2/hjvAi4bQwCDsX4K+Do8fccgK8rPGDe70fiaTCUOGDOnTcxBRz/lClfagtL4tQlMood/Xvw1EIhCy0g4E5D5X2pkpSRiaGVzssqWbFeREFHtqpZ0apUqbc9pkYEKGdpPLC38/zhfWLkYjIn15A6FKu5+hrawg55w2GZmQod3f4S9l29fugy1wef3RaBIR9ZFygY/+Do8rlXazyxexjYzISIQM7f6+k1ZWkAdkYEd1SzSaRET91N+FaNk8YIUGACFDu7/HFI4bEl5BvvlAY3+bQ0RR0O857YgDVhjaZExihnaru/s7dWFUbhpSQseZ/rSfoU2UCPo7PJ7N0KYBQMjQru9npW02SSgdGqy2GdpEiSFaC9EAoIFHmZJBCRnaykI0cx9PTwKAo4Y5AABbK5sjLlhARPro95x2irbS7t9oHFGiEjK0laEvSw8vydmZo4YHK22PP4Dt1dz6RaQne5IZKdb+nfXkSOFFQ8j4hAxtZWVoUhQqbYBD5ER662+VDQBWiwnpycHgZ6VNRiVkaDe5ghdSMJv7HtqjclLVa2v/uI+hTaSn/i5CUx8nlUeZkrEJGdrNruDQl8XU9+abTBJKh2UCADax0ibSVX8XoSmUbV/cp01GJWRoN7WHKu1+DI8DwFGh0N5a1QS3jyejEeklWqGtVtr93GFClKiEDG2l0u5vaB89PBjaXr+M7VU8GY0onrRXD4jW8DjPHyejEy603T4/3KEtWpZ+hrayVxsAtnEFOVFcaS/6E61KOzs1GP4NbR7Ict8vKkSUqIQL7ebQIjSg/
5X28KwU9fOK+rZ+PRYR9Y72oh7RWD0OhM8f9/plNLt93dybSDzChXZTe3j/ZX8r7eQkM/LSg+/MKxoY2kTx5PNrKu2oDY+HH6e/JycSJSLxQjuKlTYAFGQHq+199e39fiwi6rmISjvKC9EAbvsiYxIutJVFaABg7seWL0VBlh0AK22iePMGYjE8zouGkLEJF9rKdi+g/8PjQLjSrmpycdsXURx5NcPj2VFbiKa9aAhDm4xHuNCOrLSjENqhxWiyDBxwuvr9eETUM77Q8LjZJMFmMUflMbM5PE4GJ1xoN0U5tIdn29XPuYKcKH6USjupHxf+OVSK1QybJfh4PH+cjEi40Fa2fElSdEK7MFuz7Yvz2kRxo0xHWS3R+zMkSRLPHydDEy60lS1faTYL+h/ZwJBMuzo3XsEV5ERxoxySZItiaAOa88cZ2mRAwoW2UmlnJCdF5fHMJglDHVxBThRPHl9A3fJljdJ8toJHmZKRCRfaypy2ct3caCgIzWvv45w2UVxUN7nUw8ejXWlzeJyMTMDQDlXa9uhU2kB4BXlFA4fHieJhvzPc16Id2sqpaKy0yYjEC+3QnHZGVCvtYGjXt3rQyvOKiWLugCa0o7kQDQgf1NLm8cPl5dkLZCzChXa057QBYHiWZtsX57WJYi6WoZ2VwlPRyLiEC+3YzGlrr/bFIXKiWNuvOcjILEVjH0gYjzIlIxMqtAMBGS3u6M9pD3eEK+3KRoY2UaxpK+1o055jzsVoZDRChXaLxwfluvbRHB4flGaDck7LwWaeokQUa7EM7chKm/2ZjEWo0G7WXJYzmsPjZpOkrjhlaBPFlizLsQ3tiDltbxf3JBKPUKGtrBwHojs8DgB56cHQrmFoE8VUU7sPrZ7YrerOtCepRxzXtbA/k7EIFdqxqrQBIDedlTZRPByI8boRk0lSV5BzIRoZjVChHVFpR3FOG2BoE8VLLIfGFTwVjYxKrNDWXJYz2pW2Mjxe2+JGICBH9bGJKCwuoR1aQc7hcTIaoUJbOzwe7TltpdL2BWQ0tPHdOVGsKHu0pSjvz9YalBbsz6y0yWiECm3t8His5rQB4CDfnRPFjFJpR/skNC1leLy+haFNxiJUaDeHDlaxWUywRflyfnnpyernNU0MbaJYqWkOVtpWc+xDu9ntg9vH88fJOIQKbaXSTo/yIjTgkEqbi9GIYqY2VP0mmWM3PJ6dxqNMyZiECm31YiH26A6NAxweJ4oXZXFYUkwr7XB/ruMQORmIUKEdvlhI9CvtNJsFKdbgkDuHx4liw+cPoKEt2I8tsQxtTaVdyzfhZCBChbayqjszyivHFepebXZyopio1+zMiOXw+CBe6YsMSqjQrm0Odr48zVB2NOWGtonUNLm6uScR9YV2qDqmw+NpHB4nYxImtAMBWR3myo1RaOdlsNImiqXI0I5dpZ2RbFEfn3u1yUiECW1nuxe+0EllOWmxrbS5epwoNuo0l8qMZaUtSZJ6iU6eikZGIkxoa4M0VpW28rjNLh9cXu7tJIq22jgNjwNQL7fLSpuMRMzQjlGlrT1ghdU2UfQpVa/FJKmXz4yVnDReNISMR5zQbgkvDot1pQ3wutpEsaCsS8lOtSK2ka250heHx8lAhAltZeU4EJ/QPtjMFeRE0aYsRBsUo9EyLWV4nFu+yEiECW1lRbfVYkJGlC8WoshjpU0UU7WhAM3RHH4SK8oBK20eP9o8vm7uTSQGcUI7FKK5abaYXdIvO9UKZZqNp6IRRZ8yVK09/CRWtM/BvdpkFMKFdk6MhsaB4LGKyhB5FQ9YIYq6eA6Pa5+DQ+RkFMKFdqxWjisGZwRXkFcztImiqs3jQ3toK+WgOAyPZ2sr7VaOnJExiBPaMT4NTcHQJooN7RB1TmrsK23tvDmHx8kohAhtrz+gXiwk1qGdHwrtqkaGNlE0aa+2lZMej4VomvPHOTxOBiFEaNe3eiAHTzCNfWhnBkO7yeVDu4enohFFi7baHRSHSjvVaobVE
vwTV8vdIGQQQoR2PE5DU2i3fXGInCh6tJV2POa0JUlS+zO3cJJRiBfacaq0Aa4gJ4om7RB1PCptgGtUyHiEC+1YXUtbocxpA+zoRNGkVNqpVjPsVnNcnlPpz6y0ySjECG3tApZYb/nSVtpcjEYUNfHco63Iywidu9DogqwsjCESmBihHXqXnGazxPwderrNAntS8DmqeSoaUdQoe6XjMZ+tUCrtdq8fzW4eZUriEyO047RHGwguXlHmtTk8ThQ9lc5gfxqsuQRurA3WTndx5IwMQIzQjtNpaIrBGTzKlCiaAgEZ+5ztAIDhWfa4PW9EaHPkjAxAiNCubAx29tyM+IR2PlecEkVVbasbHl8AQLxDO/w3g2/CyQgSPrS9/gAOhIbVirJT4vKcyrvzmiY3F68QRcG+hnb182FZ8enHwKGVNkObxJfwob2/oR3+QDA4iwbFN7Q9/gCvDkQUBfu1oe2IX6WdarMg3WYBANQwtMkAEj6099S3qZ8XZqfG5Tm1B6xwHoyo//Y7tZV2/EIbCG/j5PA4GUHih3Zdq/r5iJz4VtoAh9SIomFfQ/DNd3qyBZn2pLg+tzKvzTfgZAQChHaws1stprhtFeHiFaLoUobHh8dxPlvBo0zJSIQJ7cLsFJhMUlyeMy89GVLoqQ5ohvWIqG+U4fF4zmcrBmuOMg0EuLCUxJbwob23Pjg8Hq+V40Cwqlf+uOyube3m3kTUFVmW1dXj8dzupVC2cPoDMmpbOUROYkvo0A4EZOwNLUQrjNPKcUVxbhoAYPdBhjZRfzjbvGgLXZtej9DWTnfVcF6bBJfQoV3T7IbLGzyQYcSg+KwcV4RDu4VDakT9sE+n7V4KLiwlI0no0NauHI97pZ0XfJPg9gUitqsQUe/sd4a3beq5EA3gwlISX2KHtmaPdjzntIFwpQ0Auw62xPW5iYwk8jS0+Ffauek2dWEpt32R6BI6tPeGVo6bpPi/Q48Mbc5rE/WVEtopVjOyUuK7RxsAkswm5IQuNlTJUTMSXEKHdnloeHxIph1WS3ybmpNmRUZy8PhDVtpEfafd7iVJ8dm2eaiRoTUx7MskuoQObWXleLxOQtOSJAnFecFqe1cNOzpRXylrU/QYGleUDA725R3VLbwIEAktYUPbH5CxMxSWI3Piu3JcMSonFNocHifqk2aXFztC/XjckAzd2jEm9Aa82e3jvDYJLWFDe3t1s7q38+jhDl3aoKwgr21xo7HNq0sbiET23V4nlML2uKIs3doxenC6+vmOmmbd2kHUXwkb2j9UONXPJxY4dGlDxGK0Wg6RE/XWN3sa1M+PLdQxtPPCfXl7NfsyiSthQ/v7UGin2SwR4RlPEaHNeW2iXvt2bzC0i3NTkZVq1a0duek2dWHpTlbaJLCED+2jhmXCHKcLhRyqaFAKLKHn3sHQJuoVf0DGd3udAPQdGgeCC0vHhIbId7DSJoElZGi3un3YXh18Nzyx0KFbO5LMJhyRH+zoG8vrdWsHkYi2VTWjxe0DABxflK1za4DRoRXk26ubuYKchJWQof3T/kYox31P0GkRmuKkUYMAAD/ta0Rr6A8QEXXvm72a+WydK20AKMkLvgFvcvlwsJkryElMCRna2kVox+hYaQPAiSODFYIvIEcsqiGirn0TGp1ypCRhlE7bNrXGDA6vUeF0F4kqIUNbmc/Oz0iOOOxfD5NGZqvnFm8oq9O1LUSiCARkfLU7GNrHFmbBpNO6FK3ReeFtX8r0G5FoEi60ff4Avg5VtHpt9dJypFgxNj94KITyR4iIuvZVWZ16Ra3Tx+bp3JqgwRk2pNuCK8gZ2iSqhAvtdT9WqvNNU8bk6tyaoJNGBYfIf9znRJuH89pE3Vnz7X4AQJJZwi+PGqJza4IkScLRBZkAgA9/roHXH9C5RUS9lxChHQjIcHn9kGUZT36yCwCQk2bDnGOH6dyyoBNHBhejef0yvt3j1LcxRAmu3ePHe5uqAABnjM3TdX/2oWZPDP5NqW1x4
9NtB3VuDVHv6R7a/oCMi55ej6OX/R3XvvgNtoWGra46dQSSk8w6ty5IWYwGAF/sqtWxJUSJ78Mt1epWrwuOGa5zayKde9QQpFiDf1de/6ZC59YQ9Z7uob2xvB7f7GmAxx/AR1uqAQDpNguuOKlI55aFZaVaMWF4cFjt1Y0VaA+diU5EkWRZxmsbg2GYaU/C6WMTY4pLkWqz4LzQcP3HW2pQ18KtXyQW3UN73Y8HOtz2q1OKkJGcpENrDu+a00YBAOpbPfjrxr06t4Yocbi8ftQ0uSDLMv749234fGdwNGrmhCGwWRJjtEzrouMLAAS3cf51I6ttEotFzyf3+QN476fg3NeUMbm4fsooVDS04cJjE2tIDQgOqz30920or2vDM5/txuUnFsFq0f09D5Gu2j1+zP7zF9hW3YycNBtqQ5VrYXYKfnvmGJ1b17kTRmRhxKAUlNe14cEPtsHrD+A3Z4xOiG1pRN3RNXXW765DXasHADDz6CE4pSQHF59QCIs58cLQbJJw/dRiAMCBRhdWfraLRyHSgPfUp7vUdShKYOem2/Dy1SciN92mZ9MOS5IkrLjgKHVu+08f7cCsP3+O936qhI8ryinB6Vppr/uhEgBgNZswozRfz6b0yAXHDsOfPtqBqiYX/vj37fhyVx0umVSI0qEZKMhKYeVNA0pFfRue+jS422NUbiqKc9PQ7PJi2axSFA5K0bl1XZtckoO/3TgZC1/6BrtrW7FpfxNuWP0t0m0WnDAyG0cOSUdRdiqGZ9kxxGFHXroNKVYzJInVuGhkWcaWymZIEjA2P134/4dxC+29dW34yxdlkGUZARnYU9+Gr3YFTxibMiYHmfbEmsPujM1ixuOXHYObXvkOVU0ufLmrDl+G/g2SBOSl25CTZkNWihV2qxlWiwlJJgkmkwSzJMEkSTCZgu/0pdDPSJDUE9eUXyU9fqk6GzWQ1e8BMuTQf4NfI/R1QA7fHgh9Ih/ymIcbjwi+BsHXAuprIIVel/Dro95f0v63s9ct/BhQf77/r+e+hjZNp8/A8Cx7vx6vt+44Z6zuOykq6tvw3Odl6tcefwDf73XC7QtWpv9z4dE4foT+FwXpjdGD0/G3mybjuc/LsOrLcjjbvGh2+/CPrTX4x9aaDve3WUzISrEiw25Bqs2CVKsFyUkm2CxmJJklWMwmWEL9vav+rejN7+XhRvW0fTR8m7avhvtpQPlvQIZfDm619Qdk+AIBeP3Bz73+APwBOdivgeDPhL72B4KP5Zdltd+r95Wh/m1XflbLJAFmKfQamSVYTJLarmAbgo/hCwTQ5vFDloOXcx0zOB12q7nP/dgfkPH5zlqU1bYCAEblpGJySU5crxwZ7f4bt9CuaXZh1ZflnX5vTgLOYR/O8SOy8fffTcH/vL8Vr26sgNcfCiYZqG5yo7qJq1GNbF9De9yf83czxuge2l3139kThwoX2Ir05CQsmj4G1542Cu9vqsL63XX4urwee+vb1IsWKdy+AKqaXKhq0qetA80P+xrxw77GqD7m7tpW7A4FeLxEu//GLbQtZhMyki3qu9CsFCuOK8rClDG5OGd84g+Na2UkJ+He2UfhD+eNw7aqZmyrasa+hjbsd7rQ0OZBfasHLq8fHl8AvtC7WfWdKsLvfpXPgc7fMcdbZ29ktdW/hMgqVwJgkoKfm0I/bDIdpjo+5HGVil2pCgBo/itrKnrl/uE3R5HfkyN/VvscoXf+yud9lZNmw5FDMiAjOMw2ELcJmSQpYjTMYpJgt5oxMicVvz9vnI4ti45UmwUXHjccFx4XLCA8vgD2O9txIPRR2+JBfasbje1eNLZ70ebxo9Xtg9sXgMvrh9cvw+cP9veArP0d7ny0qS+/jocrNDsboVMr/VC/NYf+7ppMwRE/s0mCSQIsJhPMJglJ5uBtFrMJ5tD3lX4thapkSZJgNiH0s6Gv1fto73doVSwjEAhW6D5/AN6ADL9fVh/fYg6NRJqCFXhykhn+gIxt1c3Yf
bAV/kCgwxuo8Osod1qBa/v7iJxUzJowFLIM/O2H/dhb19ablz3hxC20JxY48OOys+P1dHGRnGTGhAIHJiTAGelEsXRMYRZ+WDpD72bEjdViwsicVIxMgKuTUfRcO2WU3k3oN66cIiIiEgRDm4iISBAMbSIiIkEwtImIiATB0CYiIhJEVFaPu93BLTA7d+6MxsP12AGnE+1uN+xOJ1o2b47rcxMlCqXfKf2wt9h/ifTT2/4bldCuqAheKWf27NnReLjeKy8D1r2jz3MTJYiKigoce+yxffo5gP2XSE897b+SHIWrXjidTnz66acoKCiAzXb4iwTs3LkTs2fPxtq1a1FSUtLfpxUeX49IfD0i9fT1cLvdqKiowNSpU+FwOHr9POy/fcPXIxJfj0ix6r9RqbQdDgfOP//8Ht+/pKQEpaWl0XhqQ+DrEYmvR6SevB59qbAV7L/9w9cjEl+PSNHuv1yIRkREJAiGNhERkSAY2kRERIKIa2jn5uZi6dKlyM3NjefTJiy+HpH4ekRKtNcj0dqjN74ekfh6RIrV6xGV1eNEREQUexweJyIiEgRDm4iISBAMbSIiIkEwtImIiATB0CYiIhJEXELb7Xbj9ttvx9ChQ2G323HiiSfiww8/jMdTJ6SWlhYsXboUv/jFL5CdnQ1JkrBq1Sq9m6WLjRs34qabbkJpaSlSU1NRWFiIefPmYfv27Xo3TRebN2/GRRddhFGjRiElJQU5OTmYMmUK3nlHvwtqsP9GYv8NY/+NFI/+G5fQXrBgAR5++GFcfvnlePTRR2E2m3Huuefi888/j8fTJ5za2lrcc8892LJlCyZMmKB3c3T1wAMP4M0338SZZ56JRx99FNdddx0+++wzHHvssdi0aZPezYu7PXv2oLm5GfPnz8ejjz6Ku+66CwAwa9YsrFy5Upc2sf9GYv8NY/+NFJf+K8fYhg0bZADygw8+qN7W3t4uFxcXyyeffHKsnz4huVwuubKyUpZlWd64caMMQH7++ef1bZROvvjiC9ntdkfctn37dtlms8mXX365Tq1KLD6fT54wYYJ8xBFHxP252X87Yv8NY//tXrT7b8wr7TfeeANmsxnXXXedeltycjKuvvpqrF+/Xr2W70Bis9mQn5+vdzMSwimnnAKr1Rpx2+jRo1FaWootW7bo1KrEYjabUVBQAKfTGffnZv/tiP03jP23e9HuvzEP7e+++w5jxoxBRkZGxO2TJk0CAHz//fexbgIJRpZlVFdXIycnR++m6Ka1tRW1tbXYtWsXHnnkEbz33ns488wz494O9l/qLfbf2PbfqFxPuyuVlZUYMmRIh9uV2w4cOBDrJpBgVq9ejf379+Oee+7Ruym6Wbx4MZ5++mkAgMlkwpw5c/D444/HvR3sv9Rb7L+x7b8xD+329nbYbLYOtycnJ6vfJ1Js3boVN954I04++WTMnz9f7+boZtGiRZg7dy4OHDiA1157DX6/Hx6PJ+7tYP+l3mD/DYpl/4358Ljdbofb7e5wu8vlUr9PBABVVVU477zzkJmZqc6lDlRjx47F9OnT8atf/Qrr1q1DS0sLZs6cCTnO1/dh/6WeYv8Ni2X/jXloDxkyBJWVlR1uV24bOnRorJtAAmhsbMQ555wDp9OJ999/n78Xh5g7dy42btwY9/2v7L/UE+y/XYtm/415aE+cOBHbt29HU1NTxO0bNmxQv08Dm8vlwsyZM7F9+3asW7cO48aN07tJCUcZhm5sbIzr87L/UnfYf7sXzf4b89CeO3cu/H5/xMZyt9uN559/HieeeCIKCgpi3QRKYH6/HxdffDHWr1+P119/HSeffLLeTdJVTU1Nh9u8Xi9efPFF2O32uP9BZP+lrrD/RopH/435QrQTTzwRF110EZYsWYKamhqUlJTghRdeQHl5OZ577rlYP33Cevzxx+F0OtXVt++88w727dsHALj55puRmZmpZ/PiZvHixXj77bcxc+ZM1NfX4+WXX474/hVXXKFTy/Sxc
OFCNDU1YcqUKRg2bBiqqqqwevVqbN26FQ899BDS0tLi2h72386x/wax/0aKS/+NyhEt3Whvb5dvvfVWOT8/X7bZbPIJJ5wgv//++/F46oRVVFQkA+j0o6ysTO/mxc3UqVMP+zrE6dczofzv//6vPH36dHnw4MGyxWKRs7Ky5OnTp8t/+9vfdGsT+29H7L9B7L+R4tF/JVmO83JUIiIi6hNempOIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0iYiIBMHQJiIiEgRDm4iISBAMbSIiIkEwtAeIadOmYdq0ab36mWXLlkGSJNTW1nZ73xEjRmDBggURt+3YsQMzZsxAZmYmJEnC2rVre/X8RBTE/kuKmF/liwau+fPno6ysDCtWrIDD4cDxxx+PV155BTU1NVi0aJHezSOiLrD/JiaGNkXFtm3bYDKFB27a29uxfv16/P73v8dNN92k3v7KK69g06ZN7PRECYT9VxwcHhdAa2ur3k3ols1mQ1JSkvr1wYMHAQAOh0OnFhElBvZfiiaGdoJR5qF+/vlnXHbZZcjKysKpp54KAHj55Zdx3HHHwW63Izs7G5dccgkqKio6PMbKlStRXFwMu92OSZMm4V//+lenz/XYY4+htLQUKSkpyMrKUoe/DuV0OrFgwQI4HA5kZmbiyiuvRFtbW8R9tHNiy5YtQ1FREQDgtttugyRJGDFiBKZNm4Z3330Xe/bsgSRJ6u1ERsH+S7HG4fEEddFFF2H06NG47777IMsyVqxYgbvuugvz5s3DNddcg4MHD+Kxxx7DlClT8N1336nviJ977jksXLgQp5xyChYtWoTdu3dj1qxZyM7ORkFBgfr4zzzzDH7zm99g7ty5+O1vfwuXy4Uff/wRGzZswGWXXRbRlnnz5mHkyJG4//778e233+LZZ59FXl4eHnjggU7bPmfOHDgcDtxyyy249NJLce655yItLQ2pqalobGzEvn378MgjjwAA0tLSYvMCEumI/ZdiRqaEsnTpUhmAfOmll6q3lZeXy2azWV6xYkXEfX/66SfZYrGot3s8HjkvL0+eOHGi7Ha71futXLlSBiBPnTpVve3888+XS0tLe9SWq666KuL2Cy64QB40aFDEbUVFRfL8+fPVr8vKymQA8oMPPhhxv/POO08uKirq8nmJRMX+S7HG4fEEdf3116ufv/XWWwgEApg3bx5qa2vVj/z8fIwePRr//Oc/AQBff/01ampqcP3118Nqtao/v2DBAmRmZkY8vsPhwL59+7Bx48ZetQUATjvtNNTV1aGpqak//0Qiw2L/pVjh8HiCGjlypPr5jh07IMsyRo8e3el9lQUke/bsAYAO90tKSsKoUaMibrv99tvx0UcfYdKkSSgpKcGMGTNw2WWXYfLkyR0ev7CwMOLrrKwsAEBDQwMyMjJ6+S8jMj72X4oVhnaCstvt6ueBQACSJOG9996D2WzucN++zCsdeeSR2LZtG9atW4f3338fb775Jp544gncfffdWL58ecR9O3tOAJBludfPSzQQsP9SrDC0BVBcXAxZljFy5EiMGTPmsPdTVnzu2LEDZ5xxhnq71+tFWVkZJkyYEHH/1NRUXHzxxbj44ovh8XgwZ84crFixAkuWLEFycnJM/i2SJMXkcYkSFfsvRRPntAUwZ84cmM1mLF++vMO7Y1mWUVdXBwA4/vjjkZubi6eeegoej0e9z6pVq+B0OiN+TvkZhdVqxbhx4yDLMrxeb2z+IYC6ApVooGD/pWhipS2A4uJi3HvvvViyZAnKy8sxe/ZspKeno6ysDGvWrMF1112HW2+9FUlJSbj33nuxcOFCnHHGGbj44otRVlaG559/vsOc2IwZM5Cfn4/Jkydj8ODB2LJlCx5//HGcd955SE9Pj9m/5bjjjsOrr76K3/3udzjhhBOQlpaGmTNnxuz5iPTG/kvRxNAWxB133IExY8bgkUceUeesCgoKMGPGDMyaNUu933XXXQe/3
48HH3wQt912G4466ii8/fbbuOuuuyIeb+HChVi9ejUefvhhtLS0YPjw4fjNb36DP/zhDzH9d/z617/G999/j+effx6PPPIIioqK2OnJ8Nh/KVokmasRiIiIhMA5bSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0iYiIBMHQJiIiEkRULhjidDrx6aefoqCgADabLRoPSUQ95Ha7UVFRgalTp8LhcPT659l/ifTT6/4rR8HatWtlAPzgBz90/Fi7di37Lz/4IehHT/tvVCrtgoICAMDatWtRUlISjYekODpwxxK0b9oE+/jxGPrf9+vdHOqlnTt3Yvbs2Wo/7C09+i9/54iCett/oxLaypBaSUkJSktLo/GQFEdpDgfabDakOBwo4v8/YfV1aFuP/svfOaJIPe2/XIhGREQkCIY2ERGRIBjaREREgmBoG5gsy9he3Qy3z693U4jiqt3jR1ltq97NIIo6hraBvbxhL2Y88hmuffEbvZtCFDeyLOOip7/E6X/8BB/+XK13c4iiiqFtYJ9uqwEAfLGzFi4vq20aGNy+ADbtbwIAbNhdp3NriKKLoW1guw4Ghwf9ARk7a1p0bg1RfLS4ferntS1uHVtCFH0MbYNy+/zYW9+mfr2tqlnH1hDFT4tLG9oeHVtCFH0MbYPaW9cGf0BWv95a1aRja4jih5U2GRlD26B2HYwcDt/KSpsGiGYXQ5uMi6FtUIfOYTO0aaDQVtr1rZ6IESci0TG0DUpZhKY42OxGHasOGgBa3F7184AcDG4io2BoG5QyPG61hP8XczEaDQTahWgAUNfKN6tkHAxtA5JlGbtCw+NTx+Sqt29haNMA0OyODO3aZlbaZBwMbQOqbnKj1RM8TGVy8SA4UpIAAFsruYKcjO/QSpuL0chIGNoGpF05XpKXjrH56QCAbdWstMn4Wg6ttBnaZCAMbQPShnZxXirG5mcAALZXN0OWuZKWjO3QSvsgQ5sMhKFtQMp2rxSrGfkZyRieZQcAuLwBONu8Xf0okfCaDh0e55w2GQhD24AOOF0AgMLsFEiShCGZdvV7lY0uvZpFFBfaLV8Ah8fJWBjaBtTQFqwsslKsAID8zGT1e1VN7bq0iSheOKdNRsbQNiAltLNTg6E9RBParLTJ6Drs0+ZFQ8hAGNoG1BA6ASorNbjVKzfdBpMU/F4VQ5sM7tBKu67VzQWYZBgMbYPxB2Q0tgfn9JTh8SSzCbnpNgCstMn4lAuGWM3BP29ef7hPEImOoW0wTe1eKNdHUEIbAPJDi9EqGzmnTcbl8QXg9gUAAAXZ4QWYnNcmo2BoG4wynw2Eh8cBYGhoXpuVNhlZq2ZofGROqvr5QW77IoNgaBtMRGhHVNrB0K5qdHF+jwxLO589YlA4tFlpk1EwtA2mvjU8d6cNbWUFeZvH3+HwCSKjaNb8bhflMLTJeBjaBqOttJUtX0B4ThvgCnIyLm2lXZidou6aYGiTUTC0DUbZ7gVAvboXcOhebS5GI2PSnoaWaU9Cms0SvJ2jS2QQDG2DaQidLZ5kltQ/WACQn6E5FY2VNhmUdng8zWaBLckMAPD4A3o1iSiqGNoGo1TajhQrJElSbx+ckQzlywMMbTIo7fB4erJF3avt9jK0yRgY2gajHmGqWYQGAFaLCTlpwQNWqjg8TgbV0qHSDoW2j6FNxsDQNhj1YiGaPdqKIdyrTQanVNqSFLw0rc0SHB5naJNRMLQNpr418gpfWsq8Nue0yaiUOe00mwWSJMFqUSptv57NIooahrbBOEML0bJSO4b2kEyGNhmbEtrpoUWYtlBoe1hpk0EwtA0kEJA119LuODw+1BHcq93s9qGxjRdQIONRtnylJUeGNofHySgY2gbS7PJ1erEQRWF2ivr5nvrWeDWLKG6UOe00VtpkUAxtA6k/zLnjisJBmtCua4tLm4jiSVk9npYcHGninDYZDUPbQA53hKmiS
HMBhb31DG0ynmb3oXPaPFyFjIWhbSCHO8JUkWazYFAozPfUcXicjKfFFTk8zsNVyGgY2gZS39p1pQ2Eh8g5PE5GpM5pKwvRQoersNImo2BoG4hTsyLc0cmcNhC+xjCHx8lo/AEZbZ7g3PWhC9FYaZNRMLQNRFmIZjZJyEi2dHofZQV5ZaMLLi8X55BxaM8dV4fHuRCNDIahbSDOtvBpaNqLhWgVaVaQV7DaJgNp94SDOcUWXICmLEQLyICPQ+RkAAxtAwkfYdpxEZqiiNu+yKDaPOFKO8UaDGul0gZ4wAoZA0PbQJQ57c5WjisKs8Pbvvaw0iYDadNU2vakyDltgAeskDEwtA2kKbTdJdN++NDOSbOqVchebvsiA9Gu0bCz0iaDYmgbSFN7sNLOSD58aEuSpB6ywkqbjERbaStvTJU5bYCVNhkDQ9tA1NDuotIGgKJs7tUm44kcHu+s0uYKchIfQ9sg/AFZPcKx29AOLUbb19AGv3KFESLBtXs7LkSzcXicDIahbRDK8Y0ADrtHWzE8VGl7/TJqmnltbTKGdk84lO0MbTIohrZBNLaHT0PrrtLOS7epn9c2e7q4J5E4IrZ8JUUergJweJyMgaFtEE2ucGh3tXocAHI1oX2whZU2GYP2cBU7F6KRQTG0DaJJW2l3sXocAHLTNKHd7I5Zm4jiqS205ctiktQKm8PjZDQMbYPQDo/3qtJmaJNBKJW2snIc4OEqZDwMbYPQDo9n2LteiJacZEZ6aLEaQ5uMQg1tazi0ebgKGQ1D2yCa2jWrx7uptIFwtX2whaFNxqAMj6dYtZU257TJWBjaBqEMj5skIM3adaUNhOe1WWmTUbSHVo/bNb//Nq4eJ4NhaBuEMjyenpwEk6nzy3JqqZU2Q5sMot2rzGmH/6xZOadNBsPQNojwEabdV9lAOLRrGNpkEMoxpimHrbQZ2iQ+hrZBKMPj3a0cVyih3ebxwy/zKFMSX2cL0SxmE5SBJw6PkxEwtA1CuSxnd3u0Fdq92l5WIGQA4UrbHHG7shiNw+NkBAxtg+jJZTm18jKS1c+9flbaJL62TvZpA+F5bQ6PkxEwtA2i18Pjmkrb4+cfMxKfy9txeBwIz2uz0iYjYGgbhLJ6vLcL0QDAy9AmwcmyrF4w5NDhcVbaZCQMbQNw+/xweYN/kHo6PJ6dalUX6DC0SXRuXwDKpeFTDjmngJU2GQlD2wC0p6FlpvQstM0mCYNCQ+ReH+e0SWwRV/hK6nwhGlePkxEwtA0g4tzxHlbaQHhem3PaJDrlYBWg45w2h8fJSBjaBhBxWc4ezmkD4XltDo+T6No0lXbHLV8MbTIOhrYB9OaynFoMbTKKrobHWWmTkTC0DUA5WAXo5fC4Gtqc0yaxKSvHgc4WovFwFTIOhrYBRA6P935OW+YxpiS4yDntyD9r4eFxLkQj8TG0DaCvw+M5mr3aRCKLHB7nli8yLoa2ASirx61mU8RVjbqTk2aNVZOI4qqrhWic0yYjYWgbgLJPO8NugSR1fy1tRU4aK20yhjZv96vHWWmTETC0DSB8Le2eD40DwKBUVtpkDC5NpZ18aGgn8XAVMg6GtgGo5473YuU4ADhSwkeZEoksYnj80C1f5nClzUWXJDqGtgE09rHSNpskZKdyiJzE1+YNThFZzSZYzJ2vHg/IgC/A0CaxMbQNwNkWDO2sHp47rsXFaGQEyurxQ48wBcIL0QAuRiPxMbQNwNnmAQA4ellpA8AghjYZgBraSR1DW7ujgovRSHQMbcH5AzKa3cGhwd7s0VZwBTkZgbJ6/NCV4wBgtYRv42I0Eh1DW3DNLi+UtTWZKb2vmgdxTpsMoKvhcVbaZCQMbcEp89lA/4fH/VxZS4JSzh7vvNLmnDYZB0NbcE7NEaaOfi5E44VDSFTt3mAYJ3NOmwyOoS04ZREa0NfQDg+P+3iJThJUexeVti2Jc9pkHAxtwUVeLKQPc9qa0OZ1tUlUyuEqh16WEwgfrgJweJzEx
9AWXMScdh8qbe1RphweJ1F1uRAtiaFNxsHQFpw2tPu75YuVNolKuZ52Z/u0tZU257RJdAxtwTnbg3PaaTYLksy9/99pt5phDh1AztAmEcmyrIZ2Z3Payay0yUAY2oJrDFXafamyFcpZzRweJxG5vAH1rIJOjzE1axaiebkQjcTG0BacsuWrL/PZiiRzsNLm6nESkbJHG+h4hS8gck7bw99xEhxDW3DqueP9Cm1W2iQu7WU5uzsRze1laJPYGNqCUyvtPmz3UoRDm3/QSDytmko7zdbxzav2RDRW2iQ6hrbgmkKhnRmF4XFvIAA/rzdMgml1h0M71db16nFW2iQ6hrbAZFlWt3z15dxxhbrqXI48YY1IBC3u8PB4mq3j4SoWs0ndIeHxcyEaiY2hLbBWjx++UGUcjTltAKhtYWiTWFpcmuHx5I6hDYTntVlpk+gY2gKLOHe8X3Pakvp5dZOrX20iireI4fFOjjEFwvPanNMm0TG0BaY9DS2jH8PjVkt4HrCqkaFNYmlxaxeidV1pu7hPmwTH0BZYYz8vy6mwairtKlbaJJiWiIVonYe2crxpO4fHSXAMbYH192IhCpMUDu1KVtokGGV43GoxRWzv0rKHhs3bNdvDiETE0BaYcu440L85ba2qxvaoPA5RvCiV9uGGxoHwmeTag1iIRMTQFli0Km2tqiZ3VB6HKF6USruzPdoKhjYZBUNbYMqcts1iQnInZy73BSttEk240j78G1d1TpuhTYJjaAssGueOH6qhzcsVtiSUcGj3oNL2ck6bxMbQFlj4NLTozGcruFebRNIaOhHtcCvHAe1CNL4hJbExtAXmjMK5453hCnISCRei0UDC0BZYY6jSzuzHwSqdYaVNIulNaLd7/ZBlXhSHxMXQFlhda3BOOyctusPjrLRJJOHV412FdvB7sgy4eMAKCYyhLahAQEZ9a3B71qBUW1QeU7kSEo8yJVH4A7I65N2TShuIvP42kWgY2oJytnuhXPp6UJQq7aTQaVIMbRKFNoC7Cm27JrS5GI1ExtAWVF1L+BCUQWnRqbRtoUt0VnJOmwTR2oNzx4HISpuL0UhkDG1Baa97nZManUpbObe5mpU2CUIb2oe7ljZwaGhzeJzExdAWVF1r9CttJbRrml3w8brDJIAWd7hq7upwFXtSONA5PE4iY2gLqk5TaWdHq9IODY8HZOBgC88gp8TX4tIMj1s5PE7Gx9AWlDKnLUlAVpQOV9Fe1pDbvkgEPbmWNnBIaPOYXhIYQ1tQtaE92lkpVljM0fnfaLWE/7Dta+CFQyjxaee007uY045cPc45bRIXQ1tQSqU9KEpD4wCQnBT+ddhT2xq1xyWKlZ5W2tqhcw6Pk8gY2oJS5rSjtUcbAMySpJ6utqe+LWqPSxQr2tDu6T5thjaJjKEtKOUI02itHFcUZqcAAPbWMbQp8SnD4xaTBJvl8H/ObBYTQgf+ccsXCY2hLaja0PB4tPZoK0YMSgUA7Knn8DglPu2545IkHfZ+kiSp54+z0iaRMbQF5Pb50Rza6hL1SntQsNKubnLDxVW2lOCae3CFL4UyRM592iQyhraAGlq96ufRnNMGgKJQaAPAXs5rU4Jr7UVo85raZAQMbQHVas8dj9IVvhSF2anq53s4r00JrjV0IlpqF6ehKexJDG0SH0NbQMoiNCD619LWVtp76jivTYmtuQfX0lYolXa7lwvRSFwMbQHF4gpf6uOlWpEa+uPGSpsSnTI83tXBKgouRCMjYGgLSHvueLTntCVJQpG6gpyhTYlNXT3exbnjihQuRCMDYGgLqDZ0hS+r2YT0HgwL9pYyRL6Xw+OU4Fr6MDzOSptExtAWkPY0tK72pvaVsu1rX0M7L9FJCUuW5V6tHrerw+Oc0yZxMbQFpJ47HuWhcUVRaAW5LyDzal+UsNq9fgTk4OdpPZrTZqVN4mNoC0g9wjTK270U2hXk5RwipwTV04uFKMKrx/2QZTlm7SKKJYa2gA42R/8KX
1oleWnq51srm2PyHET9pezRBoC0nuzTDoW2LAMuL6d9SEwMbcG4vH5UNQWHrIdn2WPyHHnpNuSEtpJtOtAYk+cg6i9nW3gXRUZyUrf3T0nSXumL89okJoa2YCrq26CM7I3ISe36zn0kSRLGD8sAAGw+0BST5yDqL2XECQDy0pO7vX8Kr6lNBsDQFkxZbXiOOVahDQDjh2YCAHYdbGFVQgmpRhPauendr+/QXlO7nRfDIUExtAWjXRg2clAMQztUacsysKWS1TYlHm2l3ZOdFClW7fA4Q5vExNAWTFlt8JSyTHsSsmK0EA0ASkOVNgBs2s/QpsRzMLT1MTvViiRz93/KIofHOXpEYmJoC6Y8NDwey6FxILjILdMeXNyzaT8Xo1HiUSrt3B6evx9RabtZaZOYGNqCUYbHR2j2UseCJEkoHRocIt/ExWiUgJTQzsvoQ2hzTpsExdAWSLvHr55QNiKG89mK8cOCQ+Q7qpvh9vGPHCWW3lbaEQvRODxOgmJoC2Sv5qpbI2M8PA5ArbR9ARnbq1pi/nxEPSXLsjqn3ZOV4wC3fJExMLQFEq/tXoqjhoUXo20sr4/58xH1VJPLB48veKpZz0Obq8dJfAxtgcRru5f6HDmpGJoZPLTi463VMX8+op462By+kE1PQ9tmMUG5KB6vqU2iYmgLRFk5npWShMyU7o9t7C9JknDmkYMBABt216PJ5Y35cxL1RMTBKj2c05YkST3KlJU2iYqhLZCyOG330jrzyDwAwXntz7YfjNvzEnXlYC9PQ1Mo19Ru93IhGomJoS0QZXg8HkPjipOLByE1NBf48ZaauD0vUVf6GtqpNlbaJDaGtiAqG9tR3RT8QzV6cHrcntdmMeO00bkAgH9uq4HPz0sakv6UleNWs0k9BKgnHKH71ra4u7knUWJiaAvi32Xh1duTRmbF9bmVIXJnmxff7GmI63MTdUbdo51ug6SsLusBZWpp98HWbu5JlJgY2oJQtlzZLCYcNcwR1+c+Y2yeuur2H1s5RE76U0I7pxdD4wAwKicNAFDZ6OL54yQkhrYglEp7YoEDVkt8/7cNSrNhYoEDAEObEkNvT0NTjMoNrwdhtU0iYmgLoKHVg+3VwRPJThyZrUsbzhwbHCLfUdOCCs3JbER60A6P90ZEaNcytEk8DG0BfK2ZRz5Bp9A+PRTaAKtt0pfXH0B9mwdA70Nbe/zv7oM8mpfEw9AWgDKfbTZJOLYwvovQFOOGZCA/I3g6GkOb9FTf6oEsBz/vbWinWC3qKX8cHicRMbQFsCE0n106NAOpNks3944NSZLUanv97jou4iHd7He2q5/3dk4bAEblBhej7a5lpU3iYWgnuMrGdmza3wgAOGGEPkPjCmVe2+ML4PMdtbq2hQauN7/Zp35+5JDen1lQHJrXLjvYClkp2YkEwdBOcC98uQf+QPAPy6wJQ3Vty+SSHCQnBX9l1n6/X9e20MBU3+rBm98GQ/uMsXko6sPpgEql3erxqwcWEYmCoZ3AWt0+vLJhDwDghBFZmBDadqUXu9WM844KvnH4++Zq1GiutEQUD69s2AOXN3gq3zWnjuzTY0Ru++IQOYmFoZ3A3vx2H5pcwbnjq/v4ByraLjuxEEDwAiKvf72vm3sTRY/b58cL64NvYo8ckoGTiwf16XGUShsAdjG0STAM7QTl8vrxl8/LAAAF2XacNS5f5xYFHVvowNj84DziXzfuRSDAOUGKj7/+u0Ldn33NqSN7dXyp1pCMZHWaZxdXkJNgGNoJasW7W1BeFzzE5KrJI2E29e0PVLRJkqRW2xX17fiUl+ukOGhyefHoxzsAAIXZKZjZj/UdJpOEkaHjTDeW13MxGgmFoZ2A1v14AC99FRwGPKbQgStOKtK5RZFmHzMM9qTgJQ7/sHYTnKGDLohi5clPdqG+Nfh7dsc5Y/t9lO9Z4wYDADYfaMJHvOQsCYShnUBkWcbqDXtw6+s/AAAy7Ul47NJjkGROrP9NGclJuPXsI
wAE98z+7rUfOExOMbO9uhnPhaaKji104Jzx/Z8quvrUkchIDp558NDft/H3l4SRWGkwQHl8Afx9cxV+9Zd/4/drNsHlDUCSgIcumoDhWSl6N69TV00eof7x/MfWGtz81+/USogoWjbtb8QlK7+CxxdcMf77847s81y2VqY9CQunFgMAtlY14+0fDvT7MYniQZ/jtQY4l9ePnyub8O2eBny1ux4bdteh2R0+YSw/IxkPXzwBpxTn6NjKrkmShP+ZezS2VjWjrLYV7/5YiQ2763DppEKcNW4wxgxOR3JoCJ3oUL5QZVvR0I5/flGGETmpyE23ISM5Cf6ADGe7F+/+eAB//XeF2jdu/8VYHFcUvQOGFpwyAs9/UYbaFg9ue+MH1Ld6cOXkEVF5U0AUK3EL7Yr6NnWIK9EpC1Nk9WtAhhz6b/D7/oAMfwDwBwLwy0Ag9E1/QEZAlkP3Cz6KLANtHj+a3V5UN7nVFbCHSrWaMWviMPzn2UcgK9Uah39p/6QnJ+HVhSfhzrc24aMt1aht8eCxf+zEY//YCQDIS7chK8WKDLsFKVYLkpNMsJhNMEsSzCYJJkmC2QSYJAkmkwSTBEiQIEmA8meTf0CDc7h6vwHqa/+VQ33BH5Dh9gXQ2O7Ftqpm3Fhej6MB7G9ow7J3fu72cZbOHIcrJ0d322OqzYJls0qx6K/fw+uXcc+6n/HEJ7swflgGBqXakJxkgtkkQQJ/D6nvot1/4xbaNc0urPqyPF5PJ4zBGTZMLs7B1CNycda4wUixijX4kZeejGd+dRze/uEAnvnXbmza36R+r6bZjZrDvEGhnvvdjDG6h7Ze/ffEkdm4bsoonHnk4Jg8/i+PHoqhDjtufuU77He2o7bFjU+2cUcERU+0+2/cEsIkSerCj1iJ5rth5aG0FZ/yjluSoFaLyodSHSoVpNIWKfRYKVYzUm0W5KbZMNRhR3FeGo4pcGB4ll34d/GSJOH8icNw/sRh2O9sx/pdddjX0Ib9De1wtnvR1O6Fy+tHu9cPX0BGICDDL8sIBIIVmF8OjkYEZFmtzAAgGjtxorGd59D/PwNxi1Bf+q/yupmk4M/bLCbYrWaU5KVh+KYUoC54qdmvlpyJstpWNLZ70OTywWKSkJxkRunQjD4dU9pbxxZm4f9+exre/GYfftznxNaqZjS7fHB5/ervJjAw/79T4olbaB9TmIUfl50dr6cjnQxz2DH3uOF6N4OiLNr9d8/f7WhD8M1vfmYy8kOXy9RLpj0JVyXIqYNEXeHqcSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBRGX1uNsd3Iu7c+fOaDwcxdkBpxPtbjfsTidaNm/WuznUS0q/U/phb+nRf/k7RxTU2/4bldCuqKgAAMyePTsaD0d6KS8D1r2jdyuojyoqKnDsscf26ecAnfovf+eIAPS8/0pyFE4McDqd+PTTT1FQUACbzXbY++3cuROzZ8/G2rVrUVJS0t+nFR5fj0h8PSL19PVwu92oqKjA1KlT4XA4ev087L99w9cjEl+PSLHqv1GptB0OB84///we37+kpASlpaXReGpD4OsRia9HpJ68Hn2psBXsv/3D1yMSX49I0e6/XIhGREQkCIY2ERGRIBjaREREgohraOfm5mLp0qXIzc2N59MmLL4ekfh6REq01yPR2qM3vh6R+HpEitXrEZXV40RERBR7HB4nIiISBEObiIhIEAxtIiIiQTC0iYiIBMHQJiIiEkRcQtvtduP222/H0KFDYbfbceKJJ+LDDz+Mx1MnpJaWFixduhS/+MUvkJ2dDUmSsGrVKr2bpYuNGzfipptuQmlpKVJTU1FYWIh58+Zh+/btejdNF5s3b8ZFF12EUaNGISUlBTk5OZgyZQreeUe/i2qw/0Zi/w1j/40Uj/4bl9BesGABHn74YVx++eV49NFHYTabce655+Lzzz+Px9MnnNraWtxzzz3YsmULJkyYoHdzdPXAAw/gzTffxJlnnolHH
30U1113HT777DMce+yx2LRpk97Ni7s9e/agubkZ8+fPx6OPPoq77roLADBr1iysXLlSlzax/0Zi/w1j/40Ul/4rx9iGDRtkAPKDDz6o3tbe3i4XFxfLJ598cqyfPiG5XC65srJSlmVZ3rhxowxAfv755/VtlE6++OIL2e12R9y2fft22WazyZdffrlOrUosPp9PnjBhgnzEEUfE/bnZfzti/w1j/+1etPtvzCvtN954A2azGdddd516W3JyMq6++mqsX79evZbvQGKz2ZCfn693MxLCKaecAqvVGnHb6NGjUVpaii1btujUqsRiNptRUFAAp9MZ9+dm/+2I/TeM/bd70e6/MQ/t7777DmPGjEFGRkbE7ZMmTQIAfP/997FuAglGlmVUV1cjJydH76boprW1FbW1tdi1axceeeQRvPfeezjzzDPj3g72X+ot9t/Y9t+oXE+7K5WVlRgyZEiH25XbDhw4EOsmkGBWr16N/fv345577tG7KbpZvHgxnn76aQCAyWTCnDlz8Pjjj8e9Hey/1Fvsv7HtvzEP7fb2dthstg63Jycnq98nUmzduhU33ngjTj75ZMyfP1/v5uhm0aJFmDt3Lg4cOIDXXnsNfr8fHo8n7u1g/6XeYP8NimX/jfnwuN1uh9vt7nC7y+VSv08EAFVVVTjvvPOQmZmpzqUOVGPHjsX06dPxq1/9CuvWrUNLSwtmzpwJOc7X92H/pZ5i/w2LZf+NeWgPGTIElZWVHW5Xbhs6dGism0ACaGxsxDnnnAOn04n333+fvxeHmDt3LjZu3Bj3/a/sv9QT7L9di2b/jXloT5w4Edu3b0dTU1PE7Rs2bFC/TwOby+XCzJkzsX37dqxbtw7jxo3Tu0kJRxmGbmxsjOvzsv9Sd9h/uxfN/hvz0J47dy78fn/ExnK3243nn38eJ554IgoKCmLdBEpgfr8fF198MdavX4/XX38dJ598st5N0lVNTU2H27xeL1588UXY7fa4/0Fk/6WusP9Gikf/jflCtBNPPBEXXXQRlixZgpqaGpSUlOCFF15AeXk5nnvuuVg/fcJ6/PHH4XQ61dW377zzDvbt2wcAuPnmm5GZmaln8+Jm8eLFePvttzFz5kzU19fj5Zdfjvj+FVdcoVPL9LFw4UI0NTVhypQpGDZsGKqqqrB69Wps3boVDz30ENLS0uLaHvbfzrH/BrH/RopL/43KES3daG9vl2+99VY5Pz9fttls8gknnCC///778XjqhFVUVCQD6PSjrKxM7+bFzdSpUw/7OsTp1zOh/O///q88ffp0efDgwbLFYpGzsrLk6dOny3/72990axP7b0fsv0Hsv5Hi0X8lWY7zclQiIiLqE16ak4iISBAMbSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0B4hp06Zh2rRpvfqZZcuWQZIk1NbWdnvfESNGYMGCBRG37dixAzNmzEBmZiYkScLatWt79fxEFMT+S4qYX+WLBq758+ejrKwMK1asgMPhwPHHH49XXnkFNTU1WLRokd7NI6IusP8mJoY2RcW2bdtgMoUHbtrb27F+/Xr8/ve/x0033aTe/sorr2DTpk3s9EQJhP1XHBweF0Bra6veTeiWzWZDUlKS+vXBgwcBAA6HQ6cWESUG9l+KJoZ2glHmoX7++WdcdtllyMrKwqmnngoAePnll3HcccfBbrcjOzsbl1xyCSoqKjo8xsqVK1FcXAy73Y5JkybhX//6V6fP9dhjj6G0tBQpKSnIyspSh78O5XQ6sWDBAjgcDmRmZuLKK69EW1tbxH20c2LLli1DUVERAOC2226DJEkYMWIEpk2bhnfffRd79uyBJEnq7URGwf5Lscbh8QR10UUXYfTo0bjvvvsgyzJWrFiBu+66C/PmzcM111yDgwcP4rHHHsOUKVPw3Xffqe+In3vuOSxcuBCnnHIKFi1ahN27d2PWrFnIzs5GQ
UGB+vjPPPMMfvOb32Du3Ln47W9/C5fLhR9//BEbNmzAZZddFtGWefPmYeTIkbj//vvx7bff4tlnn0VeXh4eeOCBTts+Z84cOBwO3HLLLbj00ktx7rnnIi0tDampqWhsbMS+ffvwyCOPAADS0tJi8wIS6Yj9l2JGpoSydOlSGYB86aWXqreVl5fLZrNZXrFiRcR9f/rpJ9lisai3ezweOS8vT544caLsdrvV+61cuVIGIE+dOlW97fzzz5dLS0t71Jarrroq4vYLLrhAHjRoUMRtRUVF8vz589Wvy8rKZADygw8+GHG/8847Ty4qKuryeYlExf5Lscbh8QR1/fXXq5+/9dZbCAQCmDdvHmpra9WP/Px8jB49Gv/85z8BAF9//TVqampw/fXXw2q1qj+/YMECZGZmRjy+w+HAvn37sHHjxl61BQBOO+001NXVoampqT//RCLDYv+lWOHweIIaOXKk+vmOHTsgyzJGjx7d6X2VBSR79uwBgA73S0pKwqhRoyJuu/322/HRRx9h0qRJKCkpwYwZM3DZZZdh8uTJHR6/sLAw4uusrCwAQENDAzIyMnr5LyMyPvZfihWGdoKy2+3q54FAAJIk4b333oPZbO5w377MKx155JHYtm0b1q1bh/fffx9vvvkmnnjiCdx9991Yvnx5xH07e04AkGW5189LNBCw/1KsMLQFUFxcDFmWMXLkSIwZM+aw91NWfO7YsQNnnHGGervX60VZWRkmTJgQcf/U1FRcfPHFuPjii+HxeDBnzhysWLECS5YsQXJyckz+LZIkxeRxiRIV+y9FE+e0BTBnzhyYzWYsX768w7tjWZZRV1cHADj++OORm5uLp556Ch6PR73PqlWr4HQ6I35O+RmF1WrFuHHjIMsyvF5vbP4hgLoClWigYP+laGKlLYDi4mLce++9WLJkCcrLyzF79mykp6ejrKwMa9aswXXXXYdbb70VSUlJuPfee7Fw4UKcccYZuPjii1FWVobnn3++w5zYjBkzkJ+fj8mTJ2Pw4MHYsmULHn/8cZx33nlIT0+P2b/luOOOw6uvvorf/e53OOGEE5CWloaZM2fG7PmI9Mb+S9HE0BbEHXfcgTFjxuCRRx5R56wKCgowY8YMzJo1S73fddddB7/fjwcffBC33XYbjjrqKLz99tu46667Ih5v4cKFWL16NR5++GG0tLRg+PDh+M1vfoM//OEPMf13/PrXv8b333+P559/Ho888giKiorY6cnw2H8pWiSZqxGIiIiEwDltIiIiQTC0iYiIBMHQJiIiEgRDm4iISBAMbSIiIkEwtImIiATB0CYiIhIEQ5uIiEgQDG0iIiJBMLSJiIgEwdAmIiISBEObiIhIEAxtIiIiQTC0iYiIBPH/ARImQkMFFr0ZAAAAAElFTkSuQmCC", 374 | "text/plain": [ 375 | "
" 376 | ] 377 | }, 378 | "metadata": {}, 379 | "output_type": "display_data" 380 | } 381 | ], 382 | "source": [ 383 | "fig, axes = plt.subplots(2, 2, figsize=(4, 4), dpi=120, constrained_layout=True)\n", 384 | "\n", 385 | "for i, ax in enumerate(axes.flatten()):\n", 386 | " \n", 387 | " ax.plot(grid, pdfs[i], label=\"Posterior\")\n", 388 | " ax.axvline(data[\"redshift\"][i], c=\"C3\", label=\"True redshift\")\n", 389 | " ax.set(xlabel=\"redshift\", yticks=[])\n", 390 | "\n", 391 | "axes[0,0].legend()\n", 392 | "plt.show()" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [] 401 | } 402 | ], 403 | "metadata": { 404 | "interpreter": { 405 | "hash": "12e089b4d8e7c489ece8bd483c1f38c5ce283ca32d0fbd08723b9602a5027f48" 406 | }, 407 | "kernelspec": { 408 | "display_name": "Python 3.9.1 64-bit ('pzflow': conda)", 409 | "language": "python", 410 | "name": "python3" 411 | }, 412 | "language_info": { 413 | "codemirror_mode": { 414 | "name": "ipython", 415 | "version": 3 416 | }, 417 | "file_extension": ".py", 418 | "mimetype": "text/x-python", 419 | "name": "python", 420 | "nbconvert_exporter": "python", 421 | "pygments_lexer": "ipython3", 422 | "version": "3.9.12" 423 | } 424 | }, 425 | "nbformat": 4, 426 | "nbformat_minor": 4 427 | } 428 | -------------------------------------------------------------------------------- /mkdocs-requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.6.1 2 | mkdocs-material==9.6.10 3 | mkdocstrings==0.29.1 4 | mkdocstrings-python==1.16.8 5 | mkdocs-gen-files==0.5.0 6 | mkdocs-section-index==0.3.9 7 | mkdocs-literate-nav==0.6.2 8 | mkdocs-jupyter==0.25.1 9 | mike==2.1.3 10 | jax 11 | jaxlib 12 | pandas 13 | dill 14 | tqdm 15 | flake8 16 | mypy 17 | black 18 | -------------------------------------------------------------------------------- /mkdocs.yml: 
-------------------------------------------------------------------------------- 1 | site_name: PZFlow 2 | repo_url: https://github.com/jfcrenshaw/pzflow 3 | nav: 4 | - Home: index.md 5 | - Install: install.md 6 | - Tutorials: 7 | - tutorials/index.md 8 | - Introduction: tutorials/intro.ipynb 9 | - Conditional Flows: tutorials/conditional_demo.ipynb 10 | - Convolving Gaussian Errors: tutorials/gaussian_errors.ipynb 11 | - Flow Ensembles: tutorials/ensemble_demo.ipynb 12 | - Training Weights: tutorials/weighted.ipynb 13 | - Customizing the flow: tutorials/customizing_example.ipynb 14 | - Modeling Variables with Periodic Topology: tutorials/spherical_flow_example.ipynb 15 | - Marginalizing Variables: tutorials/marginalization.ipynb 16 | - Convolving Non-Gaussian Errors: tutorials/nongaussian_errors.ipynb 17 | - Common gotchas: gotchas.md 18 | - API: API/ 19 | 20 | theme: 21 | name: material 22 | palette: 23 | scheme: slate 24 | primary: pink 25 | accent: deep orange 26 | icon: 27 | logo: material/star-face 28 | features: 29 | - navigation.indexes 30 | plugins: 31 | - search 32 | - mkdocs-jupyter: 33 | theme: dark 34 | - gen-files: 35 | scripts: 36 | - docs/gen_ref_pages.py 37 | - section-index 38 | - literate-nav 39 | - mkdocstrings: 40 | handlers: 41 | python: 42 | options: 43 | docstring_style: numpy 44 | extra: 45 | version: 46 | provider: mike 47 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pzflow" 3 | version = "3.2.0" 4 | description = "Probabilistic modeling of tabular data with normalizing flows." 
5 | authors = ["John Franklin Crenshaw "] 6 | license = "MIT" 7 | repository = "https://github.com/jfcrenshaw/pzflow" 8 | readme = "README.md" 9 | 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.10" 13 | jax = ">=0.5.3" 14 | jaxlib = ">=0.5.3" 15 | optax = ">=0.2.4" 16 | pandas = ">=2.2.3" 17 | dill = ">=0.3.9" 18 | tqdm = ">=4.67.1" 19 | 20 | [tool.poetry.group.dev.dependencies] 21 | pre-commit = ">=4.2.0" 22 | black = ">=25.1.0" 23 | mypy = ">=1.15.0" 24 | isort = ">=6.0.1" 25 | flake8 = ">=7.2.0" 26 | flake8-bugbear = ">=24.12.12" 27 | flake8-builtins = ">=2.5.0" 28 | flake8-comprehensions = ">=3.16.0" 29 | flake8-docstrings = ">=1.7.0" 30 | flake8-isort = ">=6.1.2" 31 | flake8-markdown = ">=0.6.0" 32 | flake8-print = ">=5.0.0" 33 | flake8-pytest-style = ">=2.1.0" 34 | flake8-simplify = ">=0.21.0" 35 | flake8-tidy-imports = ">=4.11.0" 36 | pep8-naming = ">=0.14.1" 37 | pandas-vet = ">=2023.8.2" 38 | pytest = ">=8.3.5" 39 | pytest-cov = ">=6.0.0" 40 | pytest-xdist = ">=3.6.1" 41 | jupyter = ">=1.1.1" 42 | matplotlib = ">=3.10.1" 43 | toml = ">=0.10.2" 44 | 45 | [build-system] 46 | requires = ["poetry-core>=1.0.0"] 47 | build-backend = "poetry.core.masonry.api" 48 | 49 | [tool.black] 50 | line-length=79 51 | 52 | [tool.isort] 53 | multi_line_output = 3 54 | include_trailing_comma = true 55 | force_grid_wrap = 0 56 | use_parentheses = true 57 | line_length = 79 58 | 59 | [tool.pyright] 60 | reportGeneralTypeIssues = false 61 | -------------------------------------------------------------------------------- /pzflow/__init__.py: -------------------------------------------------------------------------------- 1 | """Import modules and set version.""" 2 | 3 | from pzflow.flow import Flow 4 | from pzflow.flowEnsemble import FlowEnsemble 5 | 6 | __version__ = "3.2.0" 7 | -------------------------------------------------------------------------------- /pzflow/bijectors.py: -------------------------------------------------------------------------------- 1 | 
"""Define the bijectors used in the normalizing flows.""" 2 | from functools import update_wrapper 3 | from typing import Callable, Sequence, Tuple, Union 4 | 5 | import jax.numpy as jnp 6 | from jax import random 7 | from jax.nn import softmax, softplus 8 | 9 | from pzflow.utils import DenseReluNetwork, RationalQuadraticSpline 10 | 11 | # define a type alias for Jax Pytrees 12 | Pytree = Union[tuple, list] 13 | Bijector_Info = Tuple[str, tuple] 14 | 15 | 16 | class ForwardFunction: 17 | """Return the output and log_det of the forward bijection on the inputs. 18 | 19 | ForwardFunction of a Bijector, originally returned by the 20 | InitFunction of the Bijector. 21 | 22 | Parameters 23 | ---------- 24 | params : a Jax pytree 25 | A pytree of bijector parameters. 26 | This usually looks like a nested tuple or list of parameters. 27 | inputs : jnp.ndarray 28 | The data to be transformed by the bijection. 29 | 30 | Returns 31 | ------- 32 | outputs : jnp.ndarray 33 | Result of the forward bijection applied to the inputs. 34 | log_det : jnp.ndarray 35 | The log determinant of the Jacobian evaluated at the inputs. 36 | """ 37 | 38 | def __init__(self, func: Callable) -> None: 39 | self._func = func 40 | 41 | def __call__( 42 | self, params: Pytree, inputs: jnp.ndarray, **kwargs 43 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 44 | return self._func(params, inputs, **kwargs) 45 | 46 | 47 | class InverseFunction: 48 | """Return the output and log_det of the inverse bijection on the inputs. 49 | 50 | InverseFunction of a Bijector, originally returned by the 51 | InitFunction of the Bijector. 52 | 53 | Parameters 54 | ---------- 55 | params : a Jax pytree 56 | A pytree of bijector parameters. 57 | This usually looks like a nested tuple or list of parameters. 58 | inputs : jnp.ndarray 59 | The data to be transformed by the bijection. 60 | 61 | Returns 62 | ------- 63 | outputs : jnp.ndarray 64 | Result of the inverse bijection applied to the inputs. 
class Bijector:
    """Decorator class that wraps a bijector constructor function.

    The wrapped function is expected to return a tuple of
    (InitFunction, Bijector_Info) when called. Metadata of the wrapped
    function (name, docstring, etc.) is copied onto the wrapper with
    functools.update_wrapper.

    Parameters
    ----------
    func : Callable
        The bijector constructor to wrap.
    """

    def __init__(self, func: Callable) -> None:
        self._func = func
        # copy __name__, __doc__, etc. from the wrapped function
        update_wrapper(self, func)

    def __call__(self, *args, **kwargs) -> Tuple[InitFunction, Bijector_Info]:
        # forward every argument untouched to the wrapped constructor
        init_fun_and_info = self._func(*args, **kwargs)
        return init_fun_and_info
@Bijector
def ColorTransform(
    ref_idx: int, mag_idx: Sequence[int]
) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that calculates photometric colors from magnitudes.

    Using ColorTransform restricts and impacts the order of columns in the
    corresponding normalizing flow. See the notes below for an example.

    Parameters
    ----------
    ref_idx : int
        The index corresponding to the column of the reference band, which
        serves as a proxy for overall luminosity.
    mag_idx : arraylike of int
        The indices of the magnitude columns from which colors will be
        calculated.

    Returns
    -------
    InitFunction
        The InitFunction of the ColorTransform Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.

    Raises
    ------
    ValueError
        If ref_idx is not an integer, is not positive, or is not
        contained in mag_idx.

    Notes
    -----
    ColorTransform requires careful management of column order in the bijector.
    This is best explained with an example:

    Assume we have data
    [redshift, u, g, ellipticity, r, i, z, y, mass]
    Then
    ColorTransform(ref_idx=4, mag_idx=[1, 2, 4, 5, 6, 7])
    will output
    [redshift, ellipticity, mass, r, u-g, g-r, r-i, i-z, z-y]

    Notice how the non-magnitude columns are aggregated at the front of the
    array, maintaining their relative order from the original array.
    These values are then followed by the reference magnitude, and the new
    colors.

    Also notice that the magnitude indices in mag_idx are assumed to be
    adjacent colors. E.g. mag_idx=[1, 2, 5, 4, 6, 7] would have produced
    the colors [u-g, g-i, i-r, r-z, z-y]. You can chain multiple
    ColorTransforms back-to-back to create colors in a non-adjacent manner.
    """

    # validate parameters
    # the type check comes first so that non-numeric input raises the
    # intended ValueError rather than a TypeError from the comparison below
    if not isinstance(ref_idx, int):
        raise ValueError("ref_idx must be an integer.")
    if ref_idx <= 0:
        raise ValueError("ref_idx must be a positive integer.")
    if ref_idx not in mag_idx:
        raise ValueError("ref_idx must be in mag_idx.")

    # save the raw inputs so the bijector can be recreated later
    bijector_info = ("ColorTransform", (ref_idx, mag_idx))

    # convert mag_idx to an array
    mag_idx = jnp.array(mag_idx)

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        # array of all the indices
        all_idx = jnp.arange(input_dim)
        # indices for columns to stick at the front
        front_idx = jnp.setdiff1d(all_idx, mag_idx)
        # the index corresponding to the first magnitude
        mag0_idx = len(front_idx)

        # the new column order
        new_idx = jnp.concatenate((front_idx, mag_idx))
        # the new column for the reference magnitude
        new_ref = jnp.where(new_idx == ref_idx)[0][0]

        # define a convenience function for the inverse_fun below
        # if the first magnitude is the reference mag, do nothing
        if ref_idx == mag_idx[0]:

            def mag0(outputs):
                return outputs

        # if the first magnitude is not the reference mag,
        # then we need to calculate the first magnitude (mag[0])
        else:

            def mag0(outputs):
                return outputs.at[:, mag0_idx].set(
                    outputs[:, mag0_idx] + outputs[:, new_ref],
                    indices_are_sorted=True,
                    unique_indices=True,
                )

        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            # re-order columns and calculate colors
            outputs = jnp.hstack(
                (
                    inputs[:, front_idx],  # other values
                    inputs[:, ref_idx, None],  # ref mag
                    -jnp.diff(inputs[:, mag_idx]),  # colors
                )
            )
            # log-determinant of the Jacobian is zero
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            # convert all colors to be in terms of the first magnitude, mag[0]
            outputs = jnp.hstack(
                (
                    inputs[:, 0:mag0_idx],  # other values unchanged
                    inputs[:, mag0_idx, None],  # reference mag unchanged
                    jnp.cumsum(
                        inputs[:, mag0_idx + 1 :], axis=-1
                    ),  # all colors mag[i-1] - mag[i] --> mag[0] - mag[i]
                )
            )
            # calculate mag[0]
            outputs = mag0(outputs)
            # mag[i] = mag[0] - (mag[0] - mag[i])
            outputs = outputs.at[:, mag0_idx + 1 :].set(
                outputs[:, mag0_idx, None] - outputs[:, mag0_idx + 1 :],
                indices_are_sorted=True,
                unique_indices=True,
            )
            # return to original ordering
            outputs = outputs[:, jnp.argsort(new_idx)]
            # log-determinant of the Jacobian is zero
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info
@Bijector
def NeuralSplineCoupling(
    K: int = 16,
    B: float = 5,
    hidden_layers: int = 2,
    hidden_dim: int = 128,
    transformed_dim: int = None,
    n_conditions: int = 0,
    periodic: bool = False,
) -> Tuple[InitFunction, Bijector_Info]:
    """A coupling layer bijection with rational quadratic splines.

    This Bijector is a Coupling Layer [1,2]: the leading ("upper") input
    dimensions pass through unchanged and are fed to a neural network whose
    outputs parameterize a spline that transforms the trailing ("lower")
    dimensions. To transform every dimension, interleave several couplings
    with Bijectors that re-order the inputs, e.g., Reverse, Shuffle, Roll.

    The splines are the piecewise rational quadratic splines of [3].
    If periodic=True, then this is a Circular Spline as described in [4].

    Parameters
    ----------
    K : int; default=16
        Number of bins in the spline (the number of knots is K+1).
    B : float; default=5
        Range of the splines.
        If periodic=False, outside of (-B,B), the transformation is just
        the identity. If periodic=True, the input is mapped into the
        appropriate location in the range (-B,B).
    hidden_layers : int; default=2
        The number of hidden layers in the neural network used to calculate
        the positions and derivatives of the spline knots.
    hidden_dim : int; default=128
        The width of the hidden layers in the neural network used to
        calculate the positions and derivatives of the spline knots.
    transformed_dim : int; optional
        The number of dimensions transformed by the splines.
        Default is ceiling(input_dim / 2).
    n_conditions : int; default=0
        The number of variables to condition the bijection on.
    periodic : bool; default=False
        Whether to make this a periodic, Circular Spline [4].

    Returns
    -------
    InitFunction
        The InitFunction of the NeuralSplineCoupling Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.

    Raises
    ------
    ValueError
        If periodic is not a bool.

    References
    ----------
    [1] Laurent Dinh, David Krueger, Yoshua Bengio. NICE: Non-linear
        Independent Components Estimation. arXiv:1410.8516, 2015.
        http://arxiv.org/abs/1410.8516
    [2] Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio.
        Density Estimation Using Real NVP. arXiv:1605.08803, 2017.
        http://arxiv.org/abs/1605.08803
    [3] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios.
        Neural Spline Flows. arXiv:1906.04032, 2019.
        https://arxiv.org/abs/1906.04032
    [4] Rezende, Danilo Jimenez et al.
        Normalizing Flows on Tori and Spheres. arXiv:2002.02428, 2020.
        http://arxiv.org/abs/2002.02428
    """

    if not isinstance(periodic, bool):
        raise ValueError("`periodic` must be True or False.")

    # record the constructor arguments so the bijector can be rebuilt later
    hyperparams = (
        K,
        B,
        hidden_layers,
        hidden_dim,
        transformed_dim,
        n_conditions,
        periodic,
    )
    bijector_info = ("NeuralSplineCoupling", hyperparams)

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        # n_upper: variables that determine the network inputs
        # n_lower: variables transformed by the splines
        if transformed_dim is None:
            n_upper = input_dim // 2
            n_lower = input_dim - n_upper
        else:
            n_lower = transformed_dim
            n_upper = input_dim - transformed_dim

        # width of the network input (upper variables plus conditions)
        n_key = n_upper + n_conditions
        # number of spline parameters the network emits per lower dimension
        n_spline_params = 3 * K - 1 + int(periodic)

        # the network maps the upper dimensions (and conditions) to the
        # spline parameters used to transform the lower dimensions
        network_init_fun, network_apply_fun = DenseReluNetwork(
            n_spline_params * n_lower, hidden_layers, hidden_dim
        )
        _, network_params = network_init_fun(rng, (n_key,))

        def spline_params(params, upper, conditions):
            # widths, heights, derivatives = function(upper, conditions)
            key = jnp.hstack((upper, conditions))[:, :n_key]
            raw = network_apply_fun(params, key)
            raw = jnp.reshape(raw, [-1, n_lower, n_spline_params])
            W, H, D = jnp.split(raw, [K, 2 * K], axis=2)
            return 2 * B * softmax(W), 2 * B * softmax(H), softplus(D)

        def coupling(params, inputs, conditions, inverse):
            # shared body of the forward and inverse transformations
            upper, lower = inputs[:, :n_upper], inputs[:, n_upper:]
            W, H, D = spline_params(params, upper, conditions)
            lower, log_det = RationalQuadraticSpline(
                lower, W, H, D, B, periodic, inverse=inverse
            )
            return jnp.hstack((upper, lower)), log_det

        @ForwardFunction
        def forward_fun(params, inputs, conditions, **kwargs):
            return coupling(params, inputs, conditions, False)

        @InverseFunction
        def inverse_fun(params, inputs, conditions, **kwargs):
            return coupling(params, inputs, conditions, True)

        return network_params, forward_fun, inverse_fun

    return init_fun, bijector_info
@Bijector
def Roll(shift: int = 1) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that rolls inputs along their last axis using jnp.roll.

    Parameters
    ----------
    shift : int; default=1
        The number of places to roll.

    Returns
    -------
    InitFunction
        The InitFunction of the Roll Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.

    Raises
    ------
    ValueError
        If shift is not an integer.
    """

    if not isinstance(shift, int):
        raise ValueError("shift must be an integer.")

    bijector_info = ("Roll", (shift,))

    def roll_columns(inputs, places):
        # rolling only re-orders columns, so the log-determinant is zero
        return (
            jnp.roll(inputs, shift=places, axis=-1),
            jnp.zeros(inputs.shape[0]),
        )

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            return roll_columns(inputs, shift)

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            # undo the forward roll by rolling the opposite direction
            return roll_columns(inputs, -shift)

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info
@Bijector
def Scale(scale: float) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that multiplies inputs by a scalar.

    Parameters
    ----------
    scale : float
        Factor by which to scale inputs. A float32 jnp.ndarray is also
        accepted; a one-dimensional array is applied per column.

    Returns
    -------
    InitFunction
        The InitFunction of the Scale Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.

    Raises
    ------
    ValueError
        If scale is neither a float nor a float32 jnp.ndarray.
    """

    if isinstance(scale, jnp.ndarray):
        if scale.dtype != jnp.float32:
            raise ValueError("scale must be a float or array of floats.")
    elif not isinstance(scale, float):
        raise ValueError("scale must be a float or array of floats.")

    bijector_info = ("Scale", (scale,))

    def forward_log_det(inputs):
        # log|det J| = sum_i log|scale_i|. Summing logs (instead of taking
        # the log of scale ** n_dim) avoids overflow, stays finite for
        # negative scales, and is correct for per-column array scales.
        per_dim = jnp.broadcast_to(jnp.asarray(scale), (inputs.shape[-1],))
        total = jnp.log(jnp.abs(per_dim)).sum()
        return total * jnp.ones(inputs.shape[0])

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = scale * inputs
            log_det = forward_log_det(inputs)
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = 1 / scale * inputs
            # the inverse map has the opposite log-determinant
            log_det = -forward_log_det(inputs)
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info
@Bijector
def Shuffle() -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that randomly permutes inputs.

    The permutation of the columns is drawn once, from the random key
    passed to the InitFunction, and is then fixed for the lifetime of the
    forward and inverse functions.

    Returns
    -------
    InitFunction
        The InitFunction of the Shuffle Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    bijector_info = ("Shuffle", ())

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        # draw the column order, and the order that undoes it
        forward_order = random.permutation(rng, jnp.arange(input_dim))
        inverse_order = jnp.argsort(forward_order)

        def permute_columns(inputs, order):
            # a permutation only re-orders columns: log-determinant is zero
            return inputs[:, order], jnp.zeros(inputs.shape[0])

        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            return permute_columns(inputs, forward_order)

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            return permute_columns(inputs, inverse_order)

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info
@Bijector
def UniformDequantizer(column_idx: int) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that dequantizes discrete variables with uniform noise.

    Dequantizers are necessary for modeling discrete values with a flow.
    Note that this isn't technically a bijector: the forward function adds
    uniform noise in [0, 1) to the selected columns, and the inverse
    function applies floor to them.

    Parameters
    ----------
    column_idx : int
        An index or iterable of indices corresponding to the column(s) with
        discrete values.

    Returns
    -------
    InitFunction
        The InitFunction of the UniformDequantizer Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    # save the raw input so the bijector can be recreated later
    bijector_info = ("UniformDequantizer", (column_idx,))
    column_idx = jnp.array(column_idx)

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            # NOTE: the key is fixed, so the same noise is drawn on every
            # call with inputs of the same shape
            u = random.uniform(
                random.PRNGKey(0), shape=inputs[:, column_idx].shape
            )
            outputs = inputs.astype(float)
            # JAX arrays are immutable: .at[...].set() returns a new array,
            # so the result must be re-bound or the noise is silently lost
            outputs = outputs.at[:, column_idx].set(
                outputs[:, column_idx] + u
            )
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            # floor removes the noise and recovers the discrete values
            outputs = inputs.at[:, column_idx].set(
                jnp.floor(inputs[:, column_idx])
            )
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info
jnp.array) -> tuple: 34 | # Calculate mahalanobis distance and log_det of cov. 35 | # Uses scipy method, explained here: 36 | # http://gregorygundersen.com/blog/2019/10/30/scipy-multivariate/ 37 | 38 | vals, vecs = jnp.linalg.eigh(cov) 39 | U = vecs * jnp.sqrt(1 / vals[..., None]) 40 | maha = jnp.square(U @ x[..., None]).reshape(x.shape[0], -1).sum(axis=1) 41 | log_det = jnp.log(vals).sum(axis=-1) 42 | return maha, log_det 43 | 44 | 45 | class CentBeta(LatentDist): 46 | """A centered Beta distribution. 47 | 48 | This distribution is just a regular Beta distribution, scaled and shifted 49 | to have support on the domain [-B, B] in each dimension. 50 | 51 | Alpha and beta parameters for each dimension are learned during training. 52 | """ 53 | 54 | def __init__(self, input_dim: int, B: float = 5) -> None: 55 | """ 56 | Parameters 57 | ---------- 58 | input_dim : int 59 | The dimension of the distribution. 60 | B : float; default=5 61 | The distribution has support (-B, B) along each dimension. 62 | """ 63 | self.input_dim = input_dim 64 | self.B = B 65 | 66 | # save dist info 67 | self._params = tuple([(0.0, 0.0) for i in range(input_dim)]) 68 | self.info = ("CentBeta", (input_dim, B)) 69 | 70 | def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray: 71 | """Calculates log probability density of inputs. 72 | 73 | Parameters 74 | ---------- 75 | params : a Jax pytree 76 | Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta) 77 | for the Nth dimension. 78 | inputs : jnp.ndarray 79 | Input data for which log probability density is calculated. 80 | 81 | Returns 82 | ------- 83 | jnp.ndarray 84 | Device array of shape (inputs.shape[0],). 
class CentBeta13(LatentDist):
    """A centered Beta distribution with alpha, beta = 13.

    This is a regular Beta distribution that has been shifted and rescaled
    so that its support is the domain [-B, B] in each dimension.

    With alpha, beta = 13 the distribution looks like a Gaussian
    distribution, but with hard cutoffs at +/- B.
    """

    def __init__(self, input_dim: int, B: float = 5) -> None:
        """
        Parameters
        ----------
        input_dim : int
            The dimension of the distribution.
        B : float; default=5
            The distribution has support (-B, B) along each dimension.
        """
        self.input_dim = input_dim
        self.B = B

        # fixed shape parameters of the Beta distribution
        self.a = 13
        self.b = 13

        # save dist info
        self._params = tuple((0.0, 0.0) for _ in range(input_dim))
        self.info = ("CentBeta13", (input_dim, B))

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        params : a Jax pytree
            Not used -- alpha and beta are fixed for this distribution.
            This parameter is present to ensure a consistent interface.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """
        # one column of log-densities per dimension
        columns = []
        for dim in range(self.input_dim):
            column = beta.logpdf(
                inputs[:, dim],
                a=self.a,
                b=self.b,
                loc=-self.B,
                scale=2 * self.B,
            )
            columns.append(column.reshape(-1, 1))

        # the dimensions are independent, so the joint log-density is the sum
        return jnp.hstack(columns).sum(axis=1)

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : a Jax pytree
            Not used -- alpha and beta are fixed for this distribution.
            This parameter is present to ensure a consistent interface.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """
        if seed is None:
            seed = np.random.randint(1e18)

        # an independent key for every dimension
        keys = random.split(random.PRNGKey(seed), self.input_dim)
        unit_samples = jnp.hstack(
            [
                random.beta(key, self.a, self.b, shape=(nsamples, 1))
                for key in keys
            ]
        )

        # map the Beta samples from [0, 1] onto [-B, B]
        return 2 * self.B * (unit_samples - 0.5)
class Tdist(LatentDist):
    """A multivariate T distribution with mean zero and unit scale matrix.

    The number of degrees of freedom (i.e. the weight of the tails) is learned
    during training. The parameter of this distribution is log(nu), the log
    of the degrees of freedom; nu itself is recovered as exp(params).

    Note this distribution has infinite support and potentially large tails,
    so it is not recommended to use this distribution with the spline coupling
    layers, which have compact support.
    """

    def __init__(self, input_dim: int) -> None:
        """
        Parameters
        ----------
        input_dim : int
            The dimension of the distribution.
        """
        self.input_dim = input_dim

        # save dist info
        # the parameter is log(nu); this initial value corresponds to nu = 30
        self._params = jnp.log(30.0)
        self.info = ("Tdist", (input_dim,))

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Uses method explained here:
        http://gregorygundersen.com/blog/2020/01/20/multivariate-t/

        Parameters
        ----------
        params : float
            The log of the degrees of freedom, log(nu), of the t-distribution.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """
        cov = jnp.identity(self.input_dim)
        nu = jnp.exp(params)
        maha, log_det = _mahalanobis_and_logdet(inputs, cov)
        t = 0.5 * (nu + self.input_dim)
        A = gammaln(t)
        B = gammaln(0.5 * nu)
        C = self.input_dim / 2.0 * jnp.log(nu * jnp.pi)
        D = 0.5 * log_det
        E = -t * jnp.log(1 + (1.0 / nu) * maha)

        return A - B - C - D + E

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : float
            The log of the degrees of freedom, log(nu), of the t-distribution.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """
        mean = jnp.zeros(self.input_dim)
        nu = jnp.exp(params)

        seed = np.random.randint(1e18) if seed is None else seed

        # chi-square draws come from numpy, the Gaussian draws from jax;
        # both generators are seeded with the same value
        rng = np.random.default_rng(int(seed))
        x = jnp.array(rng.chisquare(nu, nsamples) / nu)
        z = random.multivariate_normal(
            key=random.PRNGKey(seed),
            mean=mean,
            cov=jnp.identity(self.input_dim),
            shape=(nsamples,),
        )

        # a Gaussian draw divided by sqrt(chi2 / nu) follows a t-distribution
        samples = mean + z / jnp.sqrt(x)[:, None]
        return samples
class Joint(LatentDist):
    """A joint distribution built from other distributions.

    Every other distribution already supports multiple dimensions on its
    own. This class is only useful when different dimensions need different
    distributions, e.g. if your first dimension has a Uniform latent space
    and the second dimension has a CentBeta latent space.
    """

    def __init__(self, *inputs: Union[LatentDist, tuple]) -> None:
        """
        Parameters
        ----------
        inputs: LatentDist or tuple
            The latent distributions to join together. Alternatively, the
            string "Joint info" followed by a list of (name, args) info
            tuples, from which the distributions are re-created.
        """

        if inputs[0] == "Joint info":
            # rebuild every sub-distribution from its saved (name, args) info
            dists = []
            for dist_info in inputs[1]:
                dist_class = globals()[dist_info[0]]
                dists.append(dist_class(*dist_info[1]))
            self.dists = dists
        else:
            # otherwise, the inputs are the distributions themselves
            self.dists = inputs

        # save info
        self._params = [dist._params for dist in self.dists]
        self.input_dim = sum([dist.input_dim for dist in self.dists])
        sub_infos = [dist.info for dist in self.dists]
        self.info = ("Joint", ("Joint info", sub_infos))

        # column indices at which log_prob splits its inputs; they are
        # computed once, ahead of time, so that they are concrete values
        # when log_prob is traced by jax during jitting
        dims = jnp.array([dist.input_dim for dist in self.dists])
        self._splits = jnp.cumsum(dims)[:-1]

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        params : Jax Pytree
            Parameters for the distributions.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """

        # hand each sub-distribution the columns that belong to it
        pieces = jnp.split(inputs, self._splits, axis=1)

        # one column of log_probs per sub-distribution
        columns = []
        for i, dist in enumerate(self.dists):
            columns.append(dist.log_prob(params[i], pieces[i]).reshape(-1, 1))

        # the joint log_prob is the sum over the sub-distributions
        return jnp.hstack(columns).sum(axis=1)

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : a Jax pytree
            Parameters for the distributions.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """

        if seed is None:
            seed = np.random.randint(1e18)

        # derive one seed per sub-distribution from the master seed
        seeds = random.randint(
            random.PRNGKey(seed), (len(self.dists),), 0, int(1e9)
        )

        blocks = []
        for i, dist in enumerate(self.dists):
            draws = dist.sample(params[i], nsamples, seeds[i])
            blocks.append(draws.reshape(nsamples, -1))

        return jnp.hstack(blocks)
def _load_example_data(name: str) -> pd.DataFrame:
    """Read the pickled example DataFrame called `name`.

    The file `{name}.pkl` is looked up in the EXAMPLE_FILE_DIR directory
    that sits next to this module.
    """
    module_dir = os.path.dirname(__file__)
    return pd.read_pickle(
        os.path.join(module_dir, f"{EXAMPLE_FILE_DIR}/{name}.pkl")
    )


def get_twomoons_data() -> pd.DataFrame:
    """Return DataFrame with two moons example data.

    Two moons data originally from scikit-learn,
    i.e., `sklearn.datasets.make_moons`.
    """
    return _load_example_data("two-moons-data")


def get_galaxy_data() -> pd.DataFrame:
    """Return DataFrame with example galaxy data.

    100,000 galaxies from the Buzzard simulation [1], with redshifts
    in the range (0,2.3) and photometry in the LSST ugrizy bands.

    References
    ----------
    [1] Joseph DeRose et al. The Buzzard Flock: Dark Energy Survey
    Synthetic Sky Catalogs. arXiv:1901.02401, 2019.
    https://arxiv.org/abs/1901.02401
    """
    return _load_example_data("galaxy-data")


def get_checkerboard_data() -> pd.DataFrame:
    """Return DataFrame with discrete checkerboard data."""
    return _load_example_data("checkerboard-data")
def get_example_flow() -> Flow:
    """Return a normalizing flow that was trained on galaxy data.

    This flow was trained in the `redshift_example.ipynb` Jupyter notebook,
    on the example data returned by `pzflow.examples.get_galaxy_data`.
    For more info: `print(get_example_flow().info)`.
    """
    module_dir = os.path.dirname(__file__)
    saved_flow = os.path.join(
        module_dir, f"{EXAMPLE_FILE_DIR}/example-flow.pzflow.pkl"
    )
    return Flow(file=saved_flow)
37 | """ 38 | 39 | def __init__( 40 | self, 41 | data_columns: Sequence[str] = None, 42 | bijector: Tuple[InitFunction, Bijector_Info] = None, 43 | latent: distributions.LatentDist = None, 44 | conditional_columns: Sequence[str] = None, 45 | data_error_model: Callable = None, 46 | condition_error_model: Callable = None, 47 | autoscale_conditions: bool = True, 48 | N: int = 1, 49 | info: Any = None, 50 | file: str = None, 51 | ) -> None: 52 | """Instantiate an ensemble of normalizing flows. 53 | 54 | Note that while all of the init parameters are technically optional, 55 | you must provide either data_columns and bijector OR file. 56 | In addition, if a file is provided, all other parameters must be None. 57 | 58 | Parameters 59 | ---------- 60 | data_columns : Sequence[str]; optional 61 | Tuple, list, or other container of column names. 62 | These are the columns the flows expect/produce in DataFrames. 63 | bijector : Bijector Call; optional 64 | A Bijector call that consists of the bijector InitFunction that 65 | initializes the bijector and the tuple of Bijector Info. 66 | Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc. 67 | If not provided, the bijector can be set later using 68 | flow.set_bijector, or by calling flow.train, in which case the 69 | default bijector will be used. The default bijector is 70 | ShiftBounds -> RollingSplineCoupling, where the range of shift 71 | bounds is learned from the training data, and the dimensions of 72 | RollingSplineCoupling is inferred. The default bijector assumes 73 | that the latent has support [-5, 5] for every dimension. 74 | latent : distributions.LatentDist; optional 75 | The latent distribution for the normalizing flow. Can be any of 76 | the distributions from pzflow.distributions. If not provided, 77 | a uniform distribution is used with input_dim = len(data_columns), 78 | and B=5. 
79 | conditional_columns : Sequence[str]; optional 80 | Names of columns on which to condition the normalizing flows. 81 | data_error_model : Callable; optional 82 | A callable that defines the error model for data variables. 83 | data_error_model must take key, X, Xerr, nsamples as arguments: 84 | - key is a jax rng key, e.g. jax.random.PRNGKey(0) 85 | - X is 2D array of data variables, where the order of variables 86 | matches the order of the columns in data_columns 87 | - Xerr is the corresponding 2D array of errors 88 | - nsamples is number of samples to draw from error distribution 89 | data_error_model must return an array of samples with the shape 90 | (X.shape[0], nsamples, X.shape[1]). 91 | If data_error_model is not provided, Gaussian error model assumed. 92 | condition_error_model : Callable; optional 93 | A callable that defines the error model for conditional variables. 94 | condition_error_model must take key, X, Xerr, nsamples, where: 95 | - key is a jax rng key, e.g. jax.random.PRNGKey(0) 96 | - X is 2D array of conditional variables, where the order of 97 | variables matches order of columns in conditional_columns 98 | - Xerr is the corresponding 2D array of errors 99 | - nsamples is number of samples to draw from error distribution 100 | condition_error_model must return array of samples with shape 101 | (X.shape[0], nsamples, X.shape[1]). 102 | If condition_error_model is not provided, Gaussian error model 103 | assumed. 104 | autoscale_conditions : bool; default=True 105 | Sets whether or not conditions are automatically standard scaled 106 | when passed to a conditional flow. I recommend you leave as True. 107 | N : int; default=1 108 | The number of flows in the ensemble. 109 | info : Any; optional 110 | An object to attach to the info attribute. 111 | file : str; optional 112 | Path to file from which to load a pretrained flow ensemble. 113 | If a file is provided, all other parameters must be None. 
114 | """ 115 | 116 | # validate parameters 117 | if data_columns is None and file is None: 118 | raise ValueError("You must provide data_columns OR file.") 119 | if file is not None and any( 120 | ( 121 | data_columns is not None, 122 | bijector is not None, 123 | conditional_columns is not None, 124 | latent is not None, 125 | data_error_model is not None, 126 | condition_error_model is not None, 127 | info is not None, 128 | ) 129 | ): 130 | raise ValueError( 131 | "If providing a file, please do not provide any other parameters." 132 | ) 133 | 134 | # if file is provided, load everything from the file 135 | if file is not None: 136 | # load the file 137 | with open(file, "rb") as handle: 138 | save_dict = pickle.load(handle) 139 | 140 | # make sure the saved file is for this class 141 | c = save_dict.pop("class") 142 | if c != self.__class__.__name__: 143 | raise TypeError( 144 | f"This save file isn't a {self.__class__.__name__}. It is a {c}." 145 | ) 146 | 147 | # load the ensemble from the dictionary 148 | self._ensemble = { 149 | name: Flow(_dictionary=flow_dict) 150 | for name, flow_dict in save_dict["ensemble"].items() 151 | } 152 | # load the metadata 153 | self.data_columns = save_dict["data_columns"] 154 | self.conditional_columns = save_dict["conditional_columns"] 155 | self.data_error_model = save_dict["data_error_model"] 156 | self.condition_error_model = save_dict["condition_error_model"] 157 | self.info = save_dict["info"] 158 | 159 | self._latent_info = save_dict["latent_info"] 160 | self.latent = getattr(distributions, self._latent_info[0])( 161 | *self._latent_info[1] 162 | ) 163 | 164 | # otherwise create a new ensemble from the provided parameters 165 | else: 166 | # save the ensemble of flows 167 | self._ensemble = { 168 | f"Flow {i}": Flow( 169 | data_columns=data_columns, 170 | bijector=bijector, 171 | conditional_columns=conditional_columns, 172 | latent=latent, 173 | data_error_model=data_error_model, 174 | 
def log_prob(
    self,
    inputs: pd.DataFrame,
    err_samples: int = None,
    seed: int = None,
    returnEnsemble: bool = False,
) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    inputs : pd.DataFrame
        Input data for which log probability density is calculated.
        Every column in self.data_columns must be present.
        If self.conditional_columns is not None, those must be present
        as well. If other columns are present, they are ignored.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the log_prob calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.
    returnEnsemble : bool; default=False
        If True, returns log_prob for each flow in the ensemble as an
        array of shape (inputs.shape[0], N flows in ensemble).
        If False, the prob is averaged over the flows in the ensemble,
        and the log of this average is returned as an array of shape
        (inputs.shape[0],)

    Returns
    -------
    jnp.ndarray
        For shape, see returnEnsemble description above.
    """
    # imported locally so this method does not depend on a new
    # module-level import
    from jax.scipy.special import logsumexp

    # one row of log-probabilities per flow: (flows in ensemble, inputs)
    ensemble = jnp.array(
        [
            flow.log_prob(inputs, err_samples, seed)
            for flow in self._ensemble.values()
        ]
    )

    # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
    ensemble = jnp.moveaxis(ensemble, 0, 1)

    if returnEnsemble:
        # return the ensemble of log_probs
        return ensemble

    # return log(mean prob) rather than mean log_prob.
    # log(mean(exp(lp))) == logsumexp(lp) - log(N); doing it in log space
    # avoids exp() underflowing to 0 (and the result collapsing to -inf)
    # when every flow assigns a very negative log-probability.
    return logsumexp(ensemble, axis=1) - jnp.log(ensemble.shape[1])
def posterior(
    self,
    inputs: pd.DataFrame,
    column: str,
    grid: jnp.ndarray,
    marg_rules: dict = None,
    normalize: bool = True,
    err_samples: int = None,
    seed: int = None,
    batch_size: int = None,
    returnEnsemble: bool = False,
    nan_to_zero: bool = True,
) -> jnp.ndarray:
    """Calculates posterior distributions for the provided column.

    Every flow in the ensemble evaluates its own un-normalized posterior
    on `grid`, conditioned on the values in the other columns of
    `inputs`. The results are then either returned per flow or averaged
    over the ensemble.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which the posterior distributions are conditioned.
        Must have columns matching self.data_columns, *except*
        for the column specified for the posterior (see below).
    column : str
        Name of the column for which the posterior distribution
        is calculated. Must be one of the columns in self.data_columns.
        Whether or not this column is also present in `inputs` is
        irrelevant.
    grid : jnp.ndarray
        Grid on which to calculate the posterior.
    marg_rules : dict; optional
        Dictionary with rules for marginalizing over missing variables.
        The dictionary must contain the key "flag", which gives the flag
        that indicates a missing value. E.g. if missing values are given
        the value 99, the dictionary should contain {"flag": 99}.
        The dictionary must also contain {"name": callable} for any
        variables that will need to be marginalized over, where name is
        the name of the variable, and callable is a callable that takes
        the row of variables and returns a grid over which to marginalize
        the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
        Note: the callable for a given name must *always* return an array
        of the same length, regardless of the input row.
    normalize : boolean; default=True
        Whether to normalize the posterior so that it integrates to 1.
        Normalization is applied here, after the flows have been
        combined; the individual flows are always called with
        normalize=False.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the posterior calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.
    batch_size : int; default=None
        Size of batches in which to calculate posteriors. If None, all
        posteriors are calculated simultaneously. Simultaneous calculation
        is faster, but memory intensive for large data sets.
    returnEnsemble : bool; default=False
        If True, returns posterior for each flow in the ensemble as an
        array of shape (inputs.shape[0], N flows in ensemble, grid.size).
        If False, the posterior is averaged over the flows in the ensemble,
        and returned as an array of shape (inputs.shape[0], grid.size)
    nan_to_zero : bool; default=True
        Whether to convert NaN's to zero probability in the final pdfs.

    Returns
    -------
    jnp.ndarray
        For shape, see returnEnsemble description above.
    """
    # arguments forwarded unchanged to every flow; normalization is
    # deferred until after the flows have been combined
    flow_kwargs = dict(
        inputs=inputs,
        column=column,
        grid=grid,
        marg_rules=marg_rules,
        err_samples=err_samples,
        seed=seed,
        batch_size=batch_size,
        normalize=False,
        nan_to_zero=nan_to_zero,
    )
    per_flow = [
        member.posterior(**flow_kwargs) for member in self._ensemble.values()
    ]

    # (flows, inputs, grid) -> (inputs, flows, grid)
    stacked = jnp.moveaxis(jnp.array(per_flow), 1, 0)

    if not returnEnsemble:
        # average over the ensemble first, then normalize the average
        averaged = stacked.mean(axis=1)
        if not normalize:
            return averaged
        return averaged / trapezoid(y=averaged, x=grid).reshape(-1, 1)

    if not normalize:
        return stacked

    # normalize every (input, flow) curve independently: flatten to rows,
    # divide each row by its integral, then restore the 3D layout
    rows = stacked.reshape(-1, grid.size)
    rows = rows / trapezoid(y=rows, x=grid).reshape(-1, 1)
    return rows.reshape(inputs.shape[0], -1, grid.size)
def sample(
    self,
    nsamples: int = 1,
    conditions: pd.DataFrame = None,
    save_conditions: bool = True,
    seed: int = None,
    returnEnsemble: bool = False,
) -> pd.DataFrame:
    """Returns samples from the ensemble.

    Parameters
    ----------
    nsamples : int; default=1
        The number of samples to be returned, either overall or per flow
        in the ensemble (see returnEnsemble below).
    conditions : pd.DataFrame; optional
        If this is a conditional flow, you must pass conditions for
        each sample. nsamples will be drawn for each row in conditions.
    save_conditions : bool; default=True
        If true, conditions will be saved in the DataFrame of samples
        that is returned.
    seed : int; optional
        Sets the random seed for the samples.
    returnEnsemble : bool; default=False
        If True, nsamples is drawn from each flow in the ensemble.
        If False, nsamples are drawn uniformly from the flows in the ensemble.

    Returns
    -------
    pd.DataFrame
        Pandas DataFrame of samples. With returnEnsemble=True the frames
        of the individual flows are concatenated, keyed by flow name.
    """
    flows = list(self._ensemble.values())
    n_flows = len(flows)

    if returnEnsemble:
        # return nsamples for each flow in the ensemble
        return pd.concat(
            [
                flow.sample(nsamples, conditions, save_conditions, seed)
                for flow in flows
            ],
            keys=self._ensemble.keys(),
        )

    # if this isn't a conditional flow, sampling is straightforward
    if conditions is None:
        # draw ceil(nsamples / n_flows) from every flow (integer ceil,
        # no float or device round-trip), then keep nsamples rows
        per_flow = -(-nsamples // n_flows)
        samples = pd.concat(
            [
                flow.sample(per_flow, conditions, save_conditions, seed)
                for flow in flows
            ]
        )
        return samples.sample(nsamples, random_state=seed).reset_index(
            drop=True
        )

    # conditional flow: one sample is drawn per row of `conditions`,
    # so nsamples > 1 is handled by duplicating the rows
    if nsamples > 1:
        conditions = pd.concat([conditions] * nsamples)

    seed = np.random.randint(1e18) if seed is None else seed
    n_rows = conditions.shape[0]

    # if we are drawing more samples than the number of flows in the
    # ensemble, shuffle the conditions and hand each flow one chunk
    if n_rows > n_flows:
        # pandas needs a 32-bit seed, hence the down-scaling
        conditions_shuffled = conditions.sample(
            frac=1.0, random_state=int(seed / 1e9)
        )
        # split into ~equal sized chunks. We split the row *positions*
        # and slice with .iloc rather than calling np.array_split on the
        # DataFrame itself, which goes through the deprecated
        # DataFrame.swapaxes; the chunk boundaries are the same.
        chunks = [
            conditions_shuffled.iloc[rows]
            for rows in np.array_split(np.arange(n_rows), n_flows)
        ]
        # shuffle the chunks so the larger ones don't always go to the
        # first flows
        order = random.permutation(
            random.PRNGKey(seed), jnp.arange(len(chunks))
        )
        chunks = [chunks[int(i)] for i in order]
        # sample from each flow, and restore the original row order
        return pd.concat(
            [
                flow.sample(1, chunk, save_conditions, seed).set_index(
                    chunk.index
                )
                for flow, chunk in zip(flows, chunks)
            ]
        ).sort_index()

    # otherwise there are at least as many flows as conditions, so we
    # randomly select a flow (and a fresh seed) for each condition
    rng = np.random.default_rng(seed)
    chosen = rng.choice(flows, size=n_rows, replace=True)
    seeds = rng.integers(1e18, size=n_rows)
    return pd.concat(
        [
            # .iloc: i is a row position, not an index label
            flow.sample(1, conditions.iloc[i : i + 1], save_conditions, new_seed)
            for i, (flow, new_seed) in enumerate(zip(chosen, seeds))
        ],
    ).set_index(conditions.index)
def save(self, file: str) -> None:
    """Saves the ensemble to a file.

    Pickles the ensemble metadata together with the save dictionary of
    every constituent flow. The resulting file can be passed as the
    `file` argument during ensemble instantiation.

    WARNING: Currently, this method only works for bijectors that are
    implemented in the `bijectors` module. If you want to save a flow
    with a custom bijector, you either need to add the bijector to that
    module, or handle the saving and loading on your end.

    Parameters
    ----------
    file : str
        Path to where the ensemble will be saved. The path is used
        exactly as given; no extension is added.
    """
    # ensemble-level metadata
    save_dict = {
        "data_columns": self.data_columns,
        "conditional_columns": self.conditional_columns,
        "latent_info": self.latent.info,
        "data_error_model": self.data_error_model,
        "condition_error_model": self.condition_error_model,
        "info": self.info,
        # recorded so that loading can verify the file type
        "class": self.__class__.__name__,
    }
    # one serializable dictionary per flow, keyed by flow name
    save_dict["ensemble"] = {
        flow_name: member._save_dict()
        for flow_name, member in self._ensemble.items()
    }

    with open(file, "wb") as handle:
        pickle.dump(save_dict, handle, recurse=True)
def train(
    self,
    inputs: pd.DataFrame,
    val_set: pd.DataFrame = None,
    train_weight: np.ndarray = None,
    val_weight: np.ndarray = None,
    epochs: int = 50,
    batch_size: int = 1024,
    optimizer: Callable = None,
    loss_fn: Callable = None,
    convolve_errs: bool = False,
    patience: int = None,
    best_params: bool = True,
    seed: int = 0,
    verbose: bool = False,
    progress_bar: bool = False,
    initial_loss: bool = True,
) -> dict:
    """Trains the normalizing flows on the provided inputs.

    Every flow in the ensemble is trained in turn with identical
    settings, except that each one receives its own random seed derived
    from `seed`.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which to train the normalizing flows.
        Must have columns matching self.data_columns.
    val_set : pd.DataFrame; default=None
        Validation set, of same format as inputs. If provided,
        validation loss will be calculated at the end of each epoch.
    train_weight: np.ndarray; default=None
        Array of weights for each sample in the training set.
    val_weight: np.ndarray; default=None
        Array of weights for each sample in the validation set.
    epochs : int; default=50
        Number of epochs to train.
    batch_size : int; default=1024
        Batch size for training.
    optimizer : optax optimizer
        An optimizer from Optax. default = optax.adam(learning_rate=1e-3)
        see https://optax.readthedocs.io/en/latest/index.html for more.
    loss_fn : Callable; optional
        A function to calculate the loss: loss = loss_fn(params, x).
        If not provided, will be -mean(log_prob).
    convolve_errs : bool; default=False
        Whether to draw new data from the error distributions during
        each epoch of training. Method will look for error columns in
        `inputs`. Error columns must end in `_err`. E.g. the error column
        for the variable `u` must be `u_err`. Zero error assumed for
        any missing error columns. The error distribution is set during
        ensemble instantiation.
    patience : int; optional
        Factor that controls early stopping. Training will stop if the
        loss doesn't decrease for this number of epochs.
    best_params : bool; default=True
        Whether to use the params from the epoch with the lowest loss.
        Note if a validation set is provided, the epoch with the lowest
        validation loss is chosen. If False, the params from the final
        epoch are saved.
    seed : int; default=0
        A random seed to control the batching and the (optional)
        error sampling.
    verbose : bool; default=False
        If true, print the training loss every 5% of epochs.
        The name of each flow is also printed before it is trained.
    progress_bar : bool; default=False
        If true, display a tqdm progress bar during training.
    initial_loss : bool; default=True
        If true, start by calculating the initial loss.

    Returns
    -------
    dict
        Dictionary of training losses from every epoch for each flow
        in the ensemble, keyed by flow name.
    """
    # derive one independent seed per flow from the master seed
    flow_seeds = np.random.default_rng(seed).integers(
        1e9, size=len(self._ensemble)
    )

    # settings shared by every flow in the ensemble
    shared_kwargs = dict(
        inputs=inputs,
        val_set=val_set,
        train_weight=train_weight,
        val_weight=val_weight,
        epochs=epochs,
        batch_size=batch_size,
        optimizer=optimizer,
        loss_fn=loss_fn,
        convolve_errs=convolve_errs,
        patience=patience,
        best_params=best_params,
        verbose=verbose,
        progress_bar=progress_bar,
        initial_loss=initial_loss,
    )

    losses = {}
    for (flow_name, member), flow_seed in zip(
        self._ensemble.items(), flow_seeds
    ):
        if verbose:
            print(flow_name)
        losses[flow_name] = member.train(seed=flow_seed, **shared_kwargs)

    return losses
def DenseReluNetwork(
    out_dim: int, hidden_layers: int, hidden_dim: int
) -> Tuple[Callable, Callable]:
    """Create a dense neural network with LeakyRelu after hidden layers.

    The network consists of `hidden_layers` repetitions of
    (Dense(hidden_dim), LeakyRelu), followed by a final Dense(out_dim)
    layer with no activation.

    Parameters
    ----------
    out_dim : int
        The output dimension.
    hidden_layers : int
        The number of hidden layers
    hidden_dim : int
        The dimension of the hidden layers

    Returns
    -------
    init_fun : function
        The function that initializes the network. Note that this is the
        init function produced by the Jax stax `serial` combinator.
    forward_fun : function
        The function that passes the inputs through the neural network.
    """
    hidden_block = (Dense(hidden_dim), LeakyRelu)
    layers = hidden_block * hidden_layers + (Dense(out_dim),)
    init_fun, forward_fun = serial(*layers)
    return init_fun, forward_fun


def gaussian_error_model(
    key, X: jnp.ndarray, Xerr: jnp.ndarray, nsamples: int
) -> jnp.ndarray:
    """
    Default Gaussian error model where X are the means and Xerr are the stds.

    Draws `nsamples` standard-normal deviates per entry of X and returns
    X + eps * Xerr with shape (X.shape[0], nsamples, X.shape[1]).
    """
    sample_shape = (X.shape[0], nsamples, X.shape[1])
    noise = random.normal(key, shape=sample_shape)

    # insert the sample axis so means/stds broadcast against the noise
    means = X[:, None, :]
    stds = Xerr[:, None, :]
    return means + noise * stds
def RationalQuadraticSpline(
    inputs: jnp.ndarray,
    W: jnp.ndarray,
    H: jnp.ndarray,
    D: jnp.ndarray,
    B: float,
    periodic: bool = False,
    inverse: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply rational quadratic spline to inputs and return outputs with log_det.

    Applies the piecewise rational quadratic spline developed in [1].
    The same knot construction and bin lookup is shared by the forward
    and inverse transformations; only the final evaluation differs.

    Parameters
    ----------
    inputs : jnp.ndarray
        The inputs to be transformed. The log determinant is summed over
        axis 1, so a 2D array is expected (presumably (batch, dimension);
        confirm against callers).
    W : jnp.ndarray
        The widths of the spline bins (last axis indexes the bins).
    H : jnp.ndarray
        The heights of the spline bins (last axis indexes the bins).
    D : jnp.ndarray
        The derivatives of the inner spline knots.
    B : float
        Range of the splines.
        Outside of (-B,B), the transformation is just the identity
        (non-periodic case only).
    periodic : bool; default=False
        Whether to make this a periodic, Circular Spline [2].
    inverse : bool; default=False
        If True, perform the inverse transformation.
        Otherwise perform the forward transformation.

    Returns
    -------
    outputs : jnp.ndarray
        The result of applying the splines to the inputs.
    log_det : jnp.ndarray
        The log determinant of the Jacobian at the inputs, summed over
        axis 1. For the inverse transformation the sign is flipped.

    References
    ----------
    [1] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios.
        Neural Spline Flows. arXiv:1906.04032, 2019.
        https://arxiv.org/abs/1906.04032
    [2] Rezende, Danilo Jimenez et al.
        Normalizing Flows on Tori and Spheres. arxiv:2002.02428, 2020
        http://arxiv.org/abs/2002.02428
    """
    # knot x-positions: cumulative widths shifted to start at -B, with the
    # left-most knot (-B) prepended along the last axis
    xk = jnp.pad(
        -B + jnp.cumsum(W, axis=-1),
        [(0, 0)] * (len(W.shape) - 1) + [(1, 0)],
        mode="constant",
        constant_values=-B,
    )
    # knot y-positions: same construction using the bin heights
    yk = jnp.pad(
        -B + jnp.cumsum(H, axis=-1),
        [(0, 0)] * (len(H.shape) - 1) + [(1, 0)],
        mode="constant",
        constant_values=-B,
    )
    # knot derivatives
    if periodic:
        # wrap-pad one entry on the left, so the first knot reuses the
        # derivative stored at the end of D
        dk = jnp.pad(D, [(0, 0)] * (len(D.shape) - 1) + [(1, 0)], mode="wrap")
    else:
        # boundary knots get derivative 1, matching the identity that is
        # applied outside of the spline range
        dk = jnp.pad(
            D,
            [(0, 0)] * (len(D.shape) - 1) + [(1, 1)],
            mode="constant",
            constant_values=1,
        )
    # knot slopes (average slope of each bin)
    sk = H / W

    # if not periodic, out-of-bounds inputs will have identity applied
    # if periodic, we map the input into the appropriate region inside
    # the period. For now, we will pretend all inputs are periodic.
    # This makes sure that out-of-bounds inputs don't cause problems
    # with the spline, but for the non-periodic case, we will replace
    # these with their original values at the end
    out_of_bounds = (inputs <= -B) | (inputs >= B)
    masked_inputs = jnp.where(out_of_bounds, jnp.abs(inputs) - B, inputs)

    # find bin for each input: count the knots that are <= the input and
    # subtract one to get the index of the bin's left knot. The inverse
    # searches along y because its inputs live in the output space.
    if inverse:
        idx = jnp.sum(yk <= masked_inputs[..., None], axis=-1)[..., None] - 1
    else:
        idx = jnp.sum(xk <= masked_inputs[..., None], axis=-1)[..., None] - 1

    # get kx, ky, kyp1, kd, kdp1, kw, ks for the bin corresponding to each input
    input_xk = jnp.take_along_axis(xk, idx, -1)[..., 0]
    input_yk = jnp.take_along_axis(yk, idx, -1)[..., 0]
    input_dk = jnp.take_along_axis(dk, idx, -1)[..., 0]
    input_dkp1 = jnp.take_along_axis(dk, idx + 1, -1)[..., 0]
    input_wk = jnp.take_along_axis(W, idx, -1)[..., 0]
    input_hk = jnp.take_along_axis(H, idx, -1)[..., 0]
    input_sk = jnp.take_along_axis(sk, idx, -1)[..., 0]

    if inverse:
        # [1] Appendix A.3
        # quadratic formula coefficients
        a = (input_hk) * (input_sk - input_dk) + (masked_inputs - input_yk) * (
            input_dkp1 + input_dk - 2 * input_sk
        )
        b = (input_hk) * input_dk - (masked_inputs - input_yk) * (
            input_dkp1 + input_dk - 2 * input_sk
        )
        c = -input_sk * (masked_inputs - input_yk)

        # root of a*relx**2 + b*relx + c = 0, written in the
        # 2c / (-b - sqrt(discriminant)) form; relx is the relative
        # position inside the bin
        relx = 2 * c / (-b - jnp.sqrt(b**2 - 4 * a * c))
        outputs = relx * input_wk + input_xk
        # if not periodic, replace out-of-bounds values with original values
        if not periodic:
            outputs = jnp.where(out_of_bounds, inputs, outputs)

        # [1] Appendix A.2
        # calculate the log determinant (of the forward map, evaluated at
        # the recovered relx; negated on return)
        dnum = (
            input_dkp1 * relx**2
            + 2 * input_sk * relx * (1 - relx)
            + input_dk * (1 - relx) ** 2
        )
        dden = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
            1 - relx
        )
        log_det = 2 * jnp.log(input_sk) + jnp.log(dnum) - 2 * jnp.log(dden)
        # if not periodic, replace log_det for out-of-bounds values = 0
        if not periodic:
            log_det = jnp.where(out_of_bounds, 0, log_det)
        # sum the per-dimension contributions
        log_det = log_det.sum(axis=1)

        return outputs, -log_det

    else:
        # [1] Appendix A.1
        # calculate spline
        relx = (masked_inputs - input_xk) / input_wk
        num = input_hk * (input_sk * relx**2 + input_dk * relx * (1 - relx))
        den = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
            1 - relx
        )
        outputs = input_yk + num / den
        # if not periodic, replace out-of-bounds values with original values
        if not periodic:
            outputs = jnp.where(out_of_bounds, inputs, outputs)

        # [1] Appendix A.2
        # calculate the log determinant
        dnum = (
            input_dkp1 * relx**2
            + 2 * input_sk * relx * (1 - relx)
            + input_dk * (1 - relx) ** 2
        )
        dden = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
            1 - relx
        )
        log_det = 2 * jnp.log(input_sk) + jnp.log(dnum) - 2 * jnp.log(dden)
        # if not periodic, replace log_det for out-of-bounds values = 0
        if not periodic:
            log_det = jnp.where(out_of_bounds, 0, log_det)
        # sum the per-dimension contributions
        log_det = log_det.sum(axis=1)

        return outputs, log_det
def sub_diag_indices(
    inputs: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Return indices for diagonal of 2D blocks in 3D array

    The three returned index arrays address, for every block along
    axis 0, the diagonal entries (i, i) of that block; the diagonal
    length is the smaller of the two trailing dimensions.

    Raises
    ------
    ValueError
        If `inputs` is not a 3D array.
    """
    if inputs.ndim != 3:
        raise ValueError("Input must be a 3D array.")

    n_blocks, n_rows, n_cols = inputs.shape
    diag_len = min(n_rows, n_cols)

    # block index repeated once per diagonal element ...
    block_idx = jnp.repeat(jnp.arange(n_blocks), diag_len)
    # ... paired with the in-block diagonal position for rows and columns
    row_idx = jnp.tile(jnp.arange(diag_len), n_blocks)
    col_idx = jnp.tile(jnp.arange(diag_len), n_blocks)
    return block_idx, row_idx, col_idx
@pytest.mark.parametrize(
    "bijector,args,conditions",
    [
        (ColorTransform, (3, [1, 3, 5]), jnp.zeros((3, 1))),
        (Reverse, (), jnp.zeros((3, 1))),
        (Roll, (2,), jnp.zeros((3, 1))),
        (Scale, (2.0,), jnp.zeros((3, 1))),
        (Shuffle, (), jnp.zeros((3, 1))),
        (InvSoftplus, (0,), jnp.zeros((3, 1))),
        (InvSoftplus, ([1, 3], [2.0, 12.0]), jnp.zeros((3, 1))),
        (
            StandardScaler,
            (jnp.linspace(-1, 1, 7), jnp.linspace(1, 8, 7)),
            jnp.zeros((3, 1)),
        ),
        (Chain, (Reverse(), Scale(1 / 6), Roll(-1)), jnp.zeros((3, 1))),
        (NeuralSplineCoupling, (), jnp.zeros((3, 1))),
        (
            NeuralSplineCoupling,
            (16, 3, 2, 128, 3),
            jnp.arange(9).reshape(3, 3),
        ),
        (RollingSplineCoupling, (2,), jnp.zeros((3, 1))),
        (
            RollingSplineCoupling,
            (2, 1, 16, 3, 2, 128, None, 0, True),
            jnp.zeros((3, 1)),
        ),
        (ShiftBounds, (-0.5, 0.9, 5), jnp.zeros((3, 1))),
        (
            ShiftBounds,
            (-1 * jnp.ones(7), 1.1 * jnp.ones(7), 3),
            jnp.zeros((3, 1)),
        ),
    ],
)
class TestBijectors:
    """Shared checks run against every parametrized bijector.

    Each test builds the bijector from `args`, initializes it for the
    module-level test array `x`, and exercises the forward and inverse
    functions with the given `conditions`.
    """

    def test_returns_correct_shape(self, bijector, args, conditions):
        """Outputs keep the input shape; log_det has one entry per row."""
        init_fun, bijector_info = bijector(*args)
        params, forward_fun, inverse_fun = init_fun(
            random.PRNGKey(0), x.shape[-1]
        )

        fwd_outputs, fwd_log_det = forward_fun(
            params, x, conditions=conditions
        )
        assert fwd_outputs.shape == x.shape
        assert fwd_log_det.shape == x.shape[:1]

        inv_outputs, inv_log_det = inverse_fun(
            params, x, conditions=conditions
        )
        assert inv_outputs.shape == x.shape
        assert inv_log_det.shape == x.shape[:1]

    def test_is_bijective(self, bijector, args, conditions):
        """inverse(forward(x)) recovers x and the log_dets cancel."""
        init_fun, bijector_info = bijector(*args)
        params, forward_fun, inverse_fun = init_fun(
            random.PRNGKey(0), x.shape[-1]
        )

        fwd_outputs, fwd_log_det = forward_fun(
            params, x, conditions=conditions
        )
        inv_outputs, inv_log_det = inverse_fun(
            params, fwd_outputs, conditions=conditions
        )

        assert jnp.allclose(inv_outputs, x, atol=1e-6)
        assert jnp.allclose(fwd_log_det, -inv_log_det, atol=1e-6)

    def test_is_jittable(self, bijector, args, conditions):
        """Forward and inverse survive jit and give the same results."""
        init_fun, bijector_info = bijector(*args)
        params, forward_fun, inverse_fun = init_fun(
            random.PRNGKey(0), x.shape[-1]
        )

        fwd_outputs_1, fwd_log_det_1 = forward_fun(
            params, x, conditions=conditions
        )
        forward_fun = jit(forward_fun)
        fwd_outputs_2, fwd_log_det_2 = forward_fun(
            params, x, conditions=conditions
        )
        # compare eager vs. jitted results; previously these were computed
        # but never checked. equal_nan: x is not guaranteed to lie in the
        # domain of every bijector, and NaNs must simply match.
        assert jnp.allclose(
            fwd_outputs_1, fwd_outputs_2, atol=1e-5, equal_nan=True
        )
        assert jnp.allclose(
            fwd_log_det_1, fwd_log_det_2, atol=1e-5, equal_nan=True
        )

        inv_outputs_1, inv_log_det_1 = inverse_fun(
            params, x, conditions=conditions
        )
        inverse_fun = jit(inverse_fun)
        inv_outputs_2, inv_log_det_2 = inverse_fun(
            params, x, conditions=conditions
        )
        assert jnp.allclose(
            inv_outputs_1, inv_outputs_2, atol=1e-5, equal_nan=True
        )
        assert jnp.allclose(
            inv_log_det_1, inv_log_det_2, atol=1e-5, equal_nan=True
        )
def test_uniform_dequantizer_returns_correct_shape():
    """UniformDequantizer keeps the input shape in both directions."""
    init_fun, bijector_info = UniformDequantizer([1, 3, 4])
    params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), x.shape[-1])

    conditions = jnp.zeros((3, 1))
    expected_log_det_shape = x.shape[:1]

    # forward first, then inverse: same shape contract for both
    for transform in (forward_fun, inverse_fun):
        outputs, log_det = transform(params, x, conditions=conditions)
        assert outputs.shape == x.shape
        assert log_det.shape == expected_log_det_shape
def test_log_prob():
    """Ensemble log_prob matches the individual flows and their mean."""
    # per-flow output: one column per flow, matching flows seeded 0 and 1
    per_flow = flowEns.log_prob(x, returnEnsemble=True)
    assert per_flow.shape == (3, 2)

    single_flow_lps = [flow0.log_prob(x), flow1.log_prob(x)]
    for col, expected_lp in enumerate(single_flow_lps):
        assert jnp.allclose(per_flow[:, col], expected_lp)

    # averaged output: log of the mean probability over the flows
    averaged = flowEns.log_prob(x)
    assert averaged.shape == single_flow_lps[0].shape

    probs = jnp.array([jnp.exp(lp) for lp in single_flow_lps])
    expected_avg = jnp.log(jnp.mean(probs, axis=0))
    assert jnp.allclose(averaged, expected_avg)
assert pEns.shape == (3, 2, grid.size) 43 | 44 | p0 = flow0.posterior(x, "x", grid) 45 | p1 = flow1.posterior(x, "x", grid) 46 | assert jnp.allclose(pEns[:, 0, :], p0) 47 | assert jnp.allclose(pEns[:, 1, :], p1) 48 | 49 | pEnsMean = flowEns.posterior(x, "x", grid) 50 | assert pEnsMean.shape == p0.shape 51 | 52 | p0 = flow0.posterior(x, "x", grid, normalize=False) 53 | p1 = flow1.posterior(x, "x", grid, normalize=False) 54 | manualMean = (p0 + p1) / 2 55 | manualMean = manualMean / trapezoid(y=manualMean, x=grid).reshape(-1, 1) 56 | assert jnp.allclose(pEnsMean, manualMean) 57 | 58 | 59 | def test_sample(): 60 | # first test everything with returnEnsemble=False 61 | sEns = flowEns.sample(10, seed=0).values 62 | assert sEns.shape == (10, 2) 63 | 64 | s0 = flow0.sample(5, seed=0) 65 | s1 = flow1.sample(5, seed=0) 66 | sManual = jnp.vstack([s0.values, s1.values]) 67 | assert jnp.allclose( 68 | sEns[sEns[:, 0].argsort()], sManual[sManual[:, 0].argsort()] 69 | ) 70 | 71 | # now test everything with returnEnsemble=True 72 | sEns = flowEns.sample(10, seed=0, returnEnsemble=True).values 73 | assert sEns.shape == (20, 2) 74 | 75 | s0 = flow0.sample(10, seed=0) 76 | s1 = flow1.sample(10, seed=0) 77 | sManual = jnp.vstack([s0.values, s1.values]) 78 | assert jnp.allclose(sEns, sManual) 79 | 80 | 81 | def test_conditional_sample(): 82 | cEns = FlowEnsemble( 83 | ("x", "y"), 84 | RollingSplineCoupling(nlayers=2, n_conditions=2), 85 | conditional_columns=("a", "b"), 86 | N=2, 87 | ) 88 | 89 | # test with nsamples = 1, fewer samples than flows 90 | conditions = pd.DataFrame(np.arange(2).reshape(-1, 2), columns=("a", "b")) 91 | samples = cEns.sample( 92 | nsamples=1, conditions=conditions, save_conditions=False 93 | ) 94 | assert samples.shape == (1, 2) 95 | 96 | # test with nsamples = 1, more samples than flows 97 | conditions = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=("a", "b")) 98 | samples = cEns.sample( 99 | nsamples=1, conditions=conditions, save_conditions=False 
100 | ) 101 | assert samples.shape == (5, 2) 102 | 103 | # test with nsamples = 2, more samples than flows 104 | conditions = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=("a", "b")) 105 | samples = cEns.sample( 106 | nsamples=2, conditions=conditions, save_conditions=False 107 | ) 108 | assert samples.shape == (10, 2) 109 | 110 | # test with returnEnsemble=True 111 | conditions = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=("a", "b")) 112 | samples = cEns.sample( 113 | nsamples=1, 114 | conditions=conditions, 115 | save_conditions=False, 116 | returnEnsemble=True, 117 | ) 118 | assert samples.shape == (10, 2) 119 | 120 | 121 | def test_train(): 122 | data = random.normal(random.PRNGKey(0), shape=(100, 2)) 123 | data = pd.DataFrame(np.array(data), columns=("x", "y")) 124 | 125 | loss_dict = flowEns.train(data, epochs=4, batch_size=50, verbose=True) 126 | 127 | rng = np.random.default_rng(0) 128 | seeds = rng.integers(1e9, size=2) 129 | losses0 = flow0.train(data, epochs=4, batch_size=50, seed=seeds[0]) 130 | losses1 = flow1.train(data, epochs=4, batch_size=50, seed=seeds[1]) 131 | 132 | assert jnp.allclose(jnp.array(loss_dict["Flow 0"]), jnp.array(losses0)) 133 | assert jnp.allclose(jnp.array(loss_dict["Flow 1"]), jnp.array(losses1)) 134 | 135 | 136 | def test_load_ensemble(tmp_path): 137 | flowEns = FlowEnsemble(("x", "y"), RollingSplineCoupling(nlayers=2), N=2) 138 | 139 | preSave = flowEns.sample(10, seed=0) 140 | 141 | file = tmp_path / "test-ensemble.pzflow.pkl" 142 | flowEns.save(str(file)) 143 | 144 | file = tmp_path / "test-ensemble.pzflow.pkl" 145 | flowEns = FlowEnsemble(file=str(file)) 146 | 147 | postSave = flowEns.sample(10, seed=0) 148 | 149 | assert jnp.allclose(preSave.values, postSave.values) 150 | 151 | with open(str(file), "rb") as handle: 152 | save_dict = pickle.load(handle) 153 | save_dict["class"] = "Flow" 154 | with open(str(file), "wb") as handle: 155 | pickle.dump(save_dict, handle, recurse=True) 156 | with 
pytest.raises(TypeError): 157 | FlowEnsemble(file=str(file)) 158 | 159 | 160 | @pytest.mark.parametrize( 161 | "data_columns,bijector,info,file", 162 | [ 163 | (None, None, None, None), 164 | (None, Reverse(), None, None), 165 | (("x", "y"), None, None, "file"), 166 | (None, Reverse(), None, "file"), 167 | (None, None, "fake", "file"), 168 | ], 169 | ) 170 | def test_bad_inputs(data_columns, bijector, info, file): 171 | with pytest.raises(ValueError): 172 | FlowEnsemble(data_columns, bijector=bijector, info=info, file=file) 173 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pandas as pd 3 | 4 | from pzflow import Flow, examples 5 | 6 | 7 | def test_get_twomoons_data(): 8 | data = examples.get_twomoons_data() 9 | assert isinstance(data, pd.DataFrame) 10 | assert data.shape == (100_000, 2) 11 | 12 | 13 | def test_get_galaxy_data(): 14 | data = examples.get_galaxy_data() 15 | assert isinstance(data, pd.DataFrame) 16 | assert data.shape == (100_000, 7) 17 | 18 | 19 | def test_get_city_data(): 20 | data = examples.get_city_data() 21 | assert isinstance(data, pd.DataFrame) 22 | assert data.shape == (47_966, 5) 23 | 24 | 25 | def test_get_checkerboard_data(): 26 | data = examples.get_checkerboard_data() 27 | assert isinstance(data, pd.DataFrame) 28 | assert data.shape == (100_000, 2) 29 | 30 | 31 | def test_get_example_flow(): 32 | flow = examples.get_example_flow() 33 | assert isinstance(flow, Flow) 34 | assert isinstance(flow.info, str) 35 | 36 | samples = flow.sample(2) 37 | flow.log_prob(samples) 38 | 39 | grid = jnp.arange(0, 2.5, 0.5) 40 | flow.posterior(samples, column="redshift", grid=grid) 41 | -------------------------------------------------------------------------------- /tests/test_flow.py: -------------------------------------------------------------------------------- 1 | import 
dill as pickle 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from jax import random 7 | 8 | from pzflow import Flow 9 | from pzflow.bijectors import Reverse, RollingSplineCoupling 10 | from pzflow.distributions import * 11 | from pzflow.examples import get_twomoons_data 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "data_columns,bijector,info,file,_dictionary", 16 | [ 17 | (None, None, None, None, None), 18 | (None, Reverse(), None, None, None), 19 | (("x", "y"), None, None, "file", None), 20 | (None, Reverse(), None, "file", None), 21 | (None, None, "fake", "file", None), 22 | (("x", "y"), Reverse(), None, None, "dict"), 23 | (None, None, None, "file", "dict"), 24 | ], 25 | ) 26 | def test_bad_inputs(data_columns, bijector, info, file, _dictionary): 27 | with pytest.raises(ValueError): 28 | Flow( 29 | data_columns, 30 | bijector=bijector, 31 | info=info, 32 | file=file, 33 | _dictionary=_dictionary, 34 | ) 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "flow", 39 | [ 40 | Flow(("redshift", "y"), Reverse(), latent=Normal(2)), 41 | Flow(("redshift", "y"), Reverse(), latent=Tdist(2)), 42 | Flow(("redshift", "y"), Reverse(), latent=Uniform(2, 10)), 43 | Flow(("redshift", "y"), Reverse(), latent=CentBeta(2, 10)), 44 | ], 45 | ) 46 | def test_returns_correct_shape(flow): 47 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 48 | x = pd.DataFrame(xarray, columns=("redshift", "y")) 49 | 50 | conditions = flow._get_conditions(x) 51 | 52 | xfwd, xfwd_log_det = flow._forward( 53 | flow._params, xarray, conditions=conditions 54 | ) 55 | assert xfwd.shape == x.shape 56 | assert xfwd_log_det.shape == (x.shape[0],) 57 | 58 | xinv, xinv_log_det = flow._inverse( 59 | flow._params, xarray, conditions=conditions 60 | ) 61 | assert xinv.shape == x.shape 62 | assert xinv_log_det.shape == (x.shape[0],) 63 | 64 | nsamples = 4 65 | assert flow.sample(nsamples).shape == (nsamples, x.shape[1]) 66 | assert flow.log_prob(x).shape == 
(x.shape[0],) 67 | 68 | grid = jnp.arange(0, 2.1, 0.12) 69 | pdfs = flow.posterior(x, column="y", grid=grid) 70 | assert pdfs.shape == (x.shape[0], grid.size) 71 | pdfs = flow.posterior(x.iloc[:, 1:], column="redshift", grid=grid) 72 | assert pdfs.shape == (x.shape[0], grid.size) 73 | pdfs = flow.posterior( 74 | x.iloc[:, 1:], column="redshift", grid=grid, batch_size=2 75 | ) 76 | assert pdfs.shape == (x.shape[0], grid.size) 77 | 78 | assert len(flow.train(x, epochs=11, verbose=True)) == 12 79 | assert ( 80 | len(flow.train(x, epochs=11, verbose=True, convolve_errs=True)) == 12 81 | ) 82 | 83 | 84 | @pytest.mark.parametrize( 85 | "flag", 86 | [ 87 | 99, 88 | np.nan, 89 | ], 90 | ) 91 | def test_posterior_with_marginalization(flag): 92 | flow = Flow(("a", "b", "c", "d"), Reverse()) 93 | 94 | # test posteriors with marginalization 95 | x = pd.DataFrame( 96 | np.arange(16).reshape(-1, 4), columns=("a", "b", "c", "d") 97 | ) 98 | grid = np.arange(0, 2.1, 0.12) 99 | 100 | marg_rules = { 101 | "flag": flag, 102 | "b": lambda row: np.linspace(0, 1, 2), 103 | "c": lambda row: np.linspace(1, 2, 3), 104 | } 105 | 106 | x["b"] = flag * np.ones(x.shape[0]) 107 | pdfs = flow.posterior(x, column="a", grid=grid, marg_rules=marg_rules) 108 | assert pdfs.shape == (x.shape[0], grid.size) 109 | 110 | x["c"] = flag * np.ones(x.shape[0]) 111 | pdfs = flow.posterior(x, column="a", grid=grid, marg_rules=marg_rules) 112 | assert pdfs.shape == (x.shape[0], grid.size) 113 | 114 | 115 | @pytest.mark.parametrize( 116 | "flow,x,x_with_err", 117 | [ 118 | ( 119 | Flow( 120 | ("redshift", "y"), RollingSplineCoupling(2), latent=Normal(2) 121 | ), 122 | pd.DataFrame( 123 | np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), 124 | columns=("redshift", "y"), 125 | ), 126 | pd.DataFrame( 127 | np.array( 128 | [ 129 | [1.0, 2.0, 0.1, 0.2], 130 | [3.0, 4.0, 0.2, 0.3], 131 | [5.0, 6.0, 0.1, 0.2], 132 | ] 133 | ), 134 | columns=("redshift", "y", "redshift_err", "y_err"), 135 | ), 136 | ), 137 | ( 138 | Flow( 
139 | ("redshift", "y"), 140 | RollingSplineCoupling(2, n_conditions=2), 141 | latent=Normal(2), 142 | conditional_columns=("a", "b"), 143 | ), 144 | pd.DataFrame( 145 | np.array( 146 | [ 147 | [1.0, 2.0, 10, 20], 148 | [3.0, 4.0, 30, 40], 149 | [5.0, 6.0, 50, 60], 150 | ] 151 | ), 152 | columns=("redshift", "y", "a", "b"), 153 | ), 154 | pd.DataFrame( 155 | np.array( 156 | [ 157 | [1.0, 2.0, 10, 20, 0.1, 0.2, 1, 2], 158 | [3.0, 4.0, 30, 40, 0.2, 0.3, 3, 4], 159 | [5.0, 6.0, 50, 60, 0.1, 0.2, 5, 6], 160 | ] 161 | ), 162 | columns=( 163 | "redshift", 164 | "y", 165 | "a", 166 | "b", 167 | "redshift_err", 168 | "y_err", 169 | "a_err", 170 | "b_err", 171 | ), 172 | ), 173 | ), 174 | ( 175 | Flow( 176 | ("redshift", "y"), 177 | RollingSplineCoupling(2, n_conditions=1), 178 | latent=Normal(2), 179 | conditional_columns=("a",), 180 | ), 181 | pd.DataFrame( 182 | np.array([[1.0, 2.0, 10], [3.0, 4.0, 30], [5.0, 6.0, 50]]), 183 | columns=("redshift", "y", "a"), 184 | ), 185 | pd.DataFrame( 186 | np.array( 187 | [ 188 | [1.0, 2.0, 10, 0.1, 0.2, 1], 189 | [3.0, 4.0, 30, 0.2, 0.3, 3], 190 | [5.0, 6.0, 50, 0.1, 0.2, 5], 191 | ] 192 | ), 193 | columns=( 194 | "redshift", 195 | "y", 196 | "a", 197 | "redshift_err", 198 | "y_err", 199 | "a_err", 200 | ), 201 | ), 202 | ), 203 | ( 204 | Flow( 205 | ("y",), 206 | RollingSplineCoupling(1, n_conditions=2), 207 | latent=Normal(1), 208 | conditional_columns=("a", "b"), 209 | ), 210 | pd.DataFrame( 211 | np.array([[1.0, 10, 20], [3.0, 30, 40], [5.0, 50, 60]]), 212 | columns=("y", "a", "b"), 213 | ), 214 | pd.DataFrame( 215 | np.array( 216 | [ 217 | [1.0, 10, 20, 0.1, 1, 2], 218 | [3.0, 30, 40, 0.2, 3, 4], 219 | [5.0, 50, 60, 0.1, 5, 6], 220 | ] 221 | ), 222 | columns=( 223 | "y", 224 | "a", 225 | "b", 226 | "y_err", 227 | "a_err", 228 | "b_err", 229 | ), 230 | ), 231 | ), 232 | ], 233 | ) 234 | def test_error_convolution(flow, x, x_with_err): 235 | assert flow.log_prob(x, err_samples=10).shape == (x.shape[0],) 236 | assert jnp.allclose( 
237 | flow.log_prob(x, err_samples=10, seed=0), 238 | flow.log_prob(x), 239 | ) 240 | assert ~jnp.allclose( 241 | flow.log_prob(x_with_err, err_samples=10, seed=0), 242 | flow.log_prob(x_with_err), 243 | ) 244 | assert jnp.allclose( 245 | flow.log_prob(x_with_err, err_samples=10, seed=0), 246 | flow.log_prob(x_with_err, err_samples=10, seed=0), 247 | ) 248 | assert ~jnp.allclose( 249 | flow.log_prob(x_with_err, err_samples=10, seed=0), 250 | flow.log_prob(x_with_err, err_samples=10, seed=1), 251 | ) 252 | assert ~jnp.allclose( 253 | flow.log_prob(x_with_err, err_samples=10), 254 | flow.log_prob(x_with_err, err_samples=10), 255 | ) 256 | 257 | grid = jnp.arange(0, 2.1, 0.12) 258 | pdfs = flow.posterior(x, column="y", grid=grid, err_samples=10) 259 | assert pdfs.shape == (x.shape[0], grid.size) 260 | assert jnp.allclose( 261 | flow.posterior(x, column="y", grid=grid, err_samples=10, seed=0), 262 | flow.posterior(x, column="y", grid=grid), 263 | rtol=1e-4, 264 | ) 265 | assert jnp.allclose( 266 | flow.posterior( 267 | x_with_err, column="y", grid=grid, err_samples=10, seed=0 268 | ), 269 | flow.posterior( 270 | x_with_err, column="y", grid=grid, err_samples=10, seed=0 271 | ), 272 | ) 273 | 274 | 275 | def test_posterior_batch(): 276 | columns = ("redshift", "y") 277 | flow = Flow(columns, Reverse()) 278 | 279 | xarray = np.array([[1, 2], [3, 4], [5, 6]]) 280 | x = pd.DataFrame(xarray, columns=columns) 281 | 282 | grid = jnp.arange(0, 2.1, 0.12) 283 | pdfs = flow.posterior(x.iloc[:, 1:], column="redshift", grid=grid) 284 | pdfs_batched = flow.posterior( 285 | x.iloc[:, 1:], column="redshift", grid=grid, batch_size=2 286 | ) 287 | assert jnp.allclose(pdfs, pdfs_batched) 288 | 289 | 290 | def test_flow_bijection(): 291 | columns = ("x", "y") 292 | flow = Flow(columns, Reverse()) 293 | 294 | x = jnp.array([[1, 2], [3, 4]]) 295 | xrev = jnp.array([[2, 1], [4, 3]]) 296 | 297 | assert jnp.allclose(flow._forward(flow._params, x)[0], xrev) 298 | assert jnp.allclose( 299 | 
flow._inverse(flow._params, flow._forward(flow._params, x)[0])[0], x 300 | ) 301 | assert jnp.allclose( 302 | flow._forward(flow._params, x)[1], flow._inverse(flow._params, x)[1] 303 | ) 304 | 305 | 306 | def test_load_flow(tmp_path): 307 | columns = ("x", "y") 308 | flow = Flow(columns, Reverse(), info=["random", 42]) 309 | 310 | file = tmp_path / "test-flow.pzflow.pkl" 311 | flow.save(str(file)) 312 | 313 | file = tmp_path / "test-flow.pzflow.pkl" 314 | flow = Flow(file=str(file)) 315 | 316 | x = jnp.array([[1, 2], [3, 4]]) 317 | xrev = jnp.array([[2, 1], [4, 3]]) 318 | 319 | assert jnp.allclose(flow._forward(flow._params, x)[0], xrev) 320 | assert jnp.allclose( 321 | flow._inverse(flow._params, flow._forward(flow._params, x)[0])[0], x 322 | ) 323 | assert jnp.allclose( 324 | flow._forward(flow._params, x)[1], flow._inverse(flow._params, x)[1] 325 | ) 326 | assert flow.info == ["random", 42] 327 | 328 | with open(str(file), "rb") as handle: 329 | save_dict = pickle.load(handle) 330 | save_dict["class"] = "FlowEnsemble" 331 | with open(str(file), "wb") as handle: 332 | pickle.dump(save_dict, handle, recurse=True) 333 | with pytest.raises(TypeError): 334 | Flow(file=str(file)) 335 | 336 | 337 | def test_control_sample_randomness(): 338 | columns = ("x", "y") 339 | flow = Flow(columns, Reverse()) 340 | 341 | assert np.all(~np.isclose(flow.sample(2), flow.sample(2))) 342 | assert np.allclose(flow.sample(2, seed=0), flow.sample(2, seed=0)) 343 | 344 | 345 | @pytest.mark.parametrize( 346 | "epochs,loss_fn,", 347 | [ 348 | (-1, None), 349 | (2.4, None), 350 | ("a", None), 351 | ], 352 | ) 353 | def test_train_bad_inputs(epochs, loss_fn): 354 | columns = ("redshift", "y") 355 | flow = Flow(columns, Reverse()) 356 | 357 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 358 | x = pd.DataFrame(xarray, columns=columns) 359 | 360 | with pytest.raises(ValueError): 361 | flow.train( 362 | x, 363 | epochs=epochs, 364 | loss_fn=loss_fn, 365 | ) 366 | 367 | 368 | def 
test_conditional_sample(): 369 | flow = Flow(("x", "y"), Reverse(), conditional_columns=("a", "b")) 370 | x = np.arange(12).reshape(-1, 4) 371 | x = pd.DataFrame(x, columns=("x", "y", "a", "b")) 372 | 373 | conditions = flow._get_conditions(x) 374 | assert conditions.shape == (x.shape[0], 2) 375 | 376 | with pytest.raises(ValueError): 377 | flow.sample(4) 378 | 379 | samples = flow.sample(4, conditions=x) 380 | assert samples.shape == (4 * x.shape[0], 4) 381 | 382 | samples = flow.sample(4, conditions=x, save_conditions=False) 383 | assert samples.shape == (4 * x.shape[0], 2) 384 | 385 | 386 | def test_train_no_errs_same(): 387 | columns = ("redshift", "y") 388 | flow = Flow(columns, Reverse()) 389 | 390 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 391 | x = pd.DataFrame(xarray, columns=columns) 392 | 393 | losses1 = flow.train(x, convolve_errs=True) 394 | losses2 = flow.train(x, convolve_errs=False) 395 | assert jnp.allclose(jnp.array(losses1), jnp.array(losses2)) 396 | 397 | 398 | def test_get_err_samples(): 399 | rng = random.PRNGKey(0) 400 | 401 | # check Gaussian data samples 402 | columns = ("x", "y") 403 | flow = Flow(columns, Reverse()) 404 | xarray = np.array([[1.0, 2.0, 0.1, 0.2], [3.0, 4.0, 0.3, 0.4]]) 405 | x = pd.DataFrame(xarray, columns=("x", "y", "x_err", "y_err")) 406 | samples = flow._get_err_samples(rng, x, 10) 407 | assert samples.shape == (20, 2) 408 | 409 | # test skip 410 | xarray = np.array([[1.0, 2.0, 0, 0]]) 411 | x = pd.DataFrame(xarray, columns=("x", "y", "x_err", "y_err")) 412 | samples = flow._get_err_samples(rng, x, 10, skip="y") 413 | assert jnp.allclose(samples, jnp.ones((10, 1))) 414 | samples = flow._get_err_samples(rng, x, 10, skip="x") 415 | assert jnp.allclose(samples, 2 * jnp.ones((10, 1))) 416 | 417 | # check Gaussian conditional samples 418 | flow = Flow(("x"), Reverse(), conditional_columns=("y")) 419 | samples = flow._get_err_samples(rng, x, 10, type="conditions") 420 | assert jnp.allclose(samples, 2 * 
jnp.ones((10, 1))) 421 | 422 | # check incorrect type 423 | with pytest.raises(ValueError): 424 | flow._get_err_samples(rng, x, 10, type="wrong") 425 | 426 | # check constant shift data samples 427 | columns = ("x", "y") 428 | shift_err_model = lambda key, X, Xerr, nsamples: jnp.repeat( 429 | X + Xerr, nsamples, axis=0 430 | ).reshape(X.shape[0], nsamples, X.shape[1]) 431 | flow = Flow(columns, Reverse(), data_error_model=shift_err_model) 432 | xarray = np.array([[1.0, 2.0, 0.1, 0.2], [3.0, 4.0, 0.3, 0.4]]) 433 | x = pd.DataFrame(xarray, columns=("x", "y", "x_err", "y_err")) 434 | samples = flow._get_err_samples(rng, x, 10) 435 | assert samples.shape == (20, 2) 436 | assert jnp.allclose( 437 | samples, 438 | shift_err_model(None, xarray[:, :2], xarray[:, 2:], 10).reshape(20, 2), 439 | ) 440 | 441 | # check constant shift conditional samples 442 | flow = Flow( 443 | ("x"), 444 | Reverse(), 445 | conditional_columns=("y"), 446 | condition_error_model=shift_err_model, 447 | ) 448 | samples = flow._get_err_samples(rng, x, 10, type="conditions") 449 | assert jnp.allclose( 450 | samples, jnp.repeat(jnp.array([[2.2], [4.4]]), 10, axis=0) 451 | ) 452 | 453 | 454 | def test_train_w_conditions(): 455 | xarray = np.array( 456 | [[1.0, 2.0, 0.1, 0.2], [3.0, 4.0, 0.3, 0.4], [5.0, 6.0, 0.5, 0.6]] 457 | ) 458 | x = pd.DataFrame(xarray, columns=("redshift", "y", "a", "b")) 459 | 460 | flow = Flow( 461 | ("redshift", "y"), 462 | Reverse(), 463 | latent=Normal(2), 464 | conditional_columns=("a", "b"), 465 | ) 466 | assert len(flow.train(x, epochs=11)) == 12 467 | 468 | print("------->>>>>") 469 | print(flow._condition_stds, "\n\n") 470 | print(xarray[:, 2:].std(axis=0)) 471 | assert jnp.allclose(flow._condition_means, xarray[:, 2:].mean(axis=0)) 472 | assert jnp.allclose(flow._condition_stds, xarray[:, 2:].std(axis=0)) 473 | 474 | 475 | def test_patience(): 476 | columns = ("redshift", "y") 477 | flow = Flow(columns, Reverse()) 478 | 479 | xarray = np.array([[1.0, 2.0], [3.0, 4.0], 
[5.0, 6.0]]) 480 | x = pd.DataFrame(xarray, columns=columns) 481 | 482 | losses = flow.train(x, patience=2) 483 | print(losses) 484 | assert len(losses) == 4 485 | 486 | 487 | def test_latent_with_wrong_dimension(): 488 | cols = ["x", "y"] 489 | latent = Uniform(3) 490 | 491 | with pytest.raises(ValueError): 492 | Flow(data_columns=cols, latent=latent, bijector=Reverse()) 493 | 494 | 495 | def test_bijector_not_set(): 496 | flow = Flow(["x", "y"]) 497 | 498 | with pytest.raises(ValueError): 499 | flow.sample(1) 500 | 501 | with pytest.raises(ValueError): 502 | x = np.linspace(0, 1, 12) 503 | df = pd.DataFrame(x.reshape(-1, 2), columns=("x", "y")) 504 | flow.posterior(x, column="x", grid=x) 505 | 506 | 507 | def test_default_bijector(): 508 | flow = Flow(["x", "y"]) 509 | 510 | losses = flow.train(get_twomoons_data()) 511 | assert all(~np.isnan(losses)) 512 | 513 | 514 | def test_validation_train(): 515 | # load some training data 516 | data = get_twomoons_data()[:10] 517 | train_set = data[:8] 518 | val_set = data[8:] 519 | 520 | # train the default flow 521 | flow = Flow(train_set.columns, Reverse()) 522 | losses = flow.train( 523 | train_set, 524 | val_set, 525 | verbose=True, 526 | epochs=3, 527 | best_params=False, 528 | ) 529 | assert len(losses[0]) == 4 530 | assert len(losses[1]) == 4 531 | 532 | 533 | def test_nan_train_stop(): 534 | # create data with NaNs 535 | data = jnp.nan * jnp.ones((4, 2)) 536 | data = pd.DataFrame(data, columns=["x", "y"]) 537 | 538 | # train the flow 539 | flow = Flow(data.columns, Reverse()) 540 | losses = flow.train(data) 541 | assert len(losses) == 2 542 | 543 | def test_train_weights(): 544 | # load some training data 545 | data = get_twomoons_data()[:10] 546 | train_set = data[:8] 547 | val_set = data[8:] 548 | train_weight = np.linspace(1, 2, len(train_set)) 549 | val_weight = np.linspace(1, 2, len(val_set)) 550 | 551 | # train the default flow 552 | flow = Flow(train_set.columns, Reverse()) 553 | losses = flow.train( 554 | 
train_set, 555 | val_set, 556 | verbose=True, 557 | epochs=3, 558 | best_params=False, 559 | train_weight=train_weight, 560 | val_weight=val_weight, 561 | ) 562 | assert len(losses[0]) == 4 563 | assert len(losses[1]) == 4 564 | 565 | def test_no_initial_loss(): 566 | # load some training data 567 | data = get_twomoons_data()[:10] 568 | train_set = data[:8] 569 | val_set = data[8:] 570 | 571 | # train the default flow 572 | flow = Flow(train_set.columns, Reverse()) 573 | losses = flow.train( 574 | train_set, 575 | val_set, 576 | verbose=True, 577 | epochs=3, 578 | best_params=False, 579 | initial_loss=False, 580 | ) 581 | assert len(losses[0]) == 3 582 | assert len(losses[1]) == 3 -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pytest 3 | from jax import random 4 | 5 | from pzflow.bijectors import * 6 | from pzflow.utils import * 7 | 8 | 9 | def test_build_bijector_from_info(): 10 | x = jnp.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]) 11 | 12 | init_fun, info1 = Chain( 13 | Reverse(), 14 | Chain(ColorTransform(1, [1, 2, 3]), Roll(-1)), 15 | InvSoftplus(0, 1), 16 | Scale(-0.5), 17 | Chain(Roll(), Scale(-4.0)), 18 | ) 19 | 20 | params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), 4) 21 | xfwd1, log_det1 = forward_fun(params, x) 22 | 23 | init_fun, info2 = build_bijector_from_info(info1) 24 | assert info1 == info2 25 | 26 | params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), 4) 27 | xfwd2, log_det2 = forward_fun(params, x) 28 | assert jnp.allclose(xfwd1, xfwd2) 29 | assert jnp.allclose(log_det1, log_det2) 30 | 31 | invx, inv_log_det = inverse_fun(params, xfwd2) 32 | assert jnp.allclose(x, invx) 33 | assert jnp.allclose(log_det2, -inv_log_det) 34 | 35 | 36 | def test_sub_diag_indices_correct(): 37 | x = jnp.array([[[0, 0], [0, 0]], [[1, 1], [1, 1]], [[2, 2], [2, 2]]]) 
38 | y = jnp.array([[[1, 0], [0, 1]], [[2, 1], [1, 2]], [[3, 2], [2, 3]]]) 39 | idx = sub_diag_indices(x) 40 | x = x.at[idx].set(x[idx] + 1) 41 | 42 | assert jnp.allclose(x, y) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "x", 47 | [jnp.ones(2), jnp.ones((2, 2)), jnp.ones((2, 2, 2, 2))], 48 | ) 49 | def test_sub_diag_indices_bad_input(x): 50 | with pytest.raises(ValueError): 51 | idx = sub_diag_indices(x) 52 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import toml 4 | 5 | import pzflow 6 | 7 | 8 | def test_versions_are_in_sync(): 9 | """Checks if the pyproject.toml and pzflow.__init__.py __version__ are in sync.""" 10 | 11 | path = Path(__file__).resolve().parents[1] / "pyproject.toml" 12 | pyproject = toml.loads(open(str(path)).read()) 13 | pyproject_version = pyproject["tool"]["poetry"]["version"] 14 | 15 | package_init_version = pzflow.__version__ 16 | 17 | assert package_init_version == pyproject_version 18 | --------------------------------------------------------------------------------