├── .gitignore ├── jax_sgmc ├── util │ ├── stop_vmap.py │ ├── __init__.py │ ├── testing.py │ ├── uuid.py │ ├── tree_util.py │ └── list_map.py ├── data │ ├── __init__.py │ ├── hdf5_loader.py │ ├── tensorflow_loader.py │ └── numpy_loader.py ├── version.py ├── __init__.py ├── potential.py └── adaption.py ├── docs ├── api │ ├── jax_sgmc.alias.rst │ ├── index.rst │ ├── jax_sgmc.util.rst │ ├── jax_sgmc.potential.rst │ ├── jax_sgmc.solver.rst │ ├── jax_sgmc.integrator.rst │ ├── jax_sgmc.adaption.rst │ ├── jax_sgmc.io.rst │ ├── jax_sgmc.scheduler.rst │ └── jax_sgmc.data.rst ├── advanced │ ├── scheduler.rst │ └── adaption.rst ├── Makefile ├── make.bat ├── installation.md ├── index.rst ├── usage │ ├── io.rst │ ├── scheduler.rst │ ├── potential.rst │ └── data.rst └── conf.py ├── pytest.ini ├── .readthedocs.yml ├── LICENSE_SHORT ├── .github └── workflows │ ├── publish.yml │ └── ci.yml ├── tests ├── conftest.py ├── test_tree_operations.py ├── test_adaption.py ├── test_scheduler.py ├── test_io.py ├── test_alias.py └── test_potential.py ├── pyproject.toml ├── pylintrc ├── README.md ├── examples ├── quickstart.md └── sgld_rms.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | /docs/build/ 2 | /docs/_autosummary/ 3 | -------------------------------------------------------------------------------- /jax_sgmc/util/stop_vmap.py: -------------------------------------------------------------------------------- 1 | """Stop vectorization and execute function sequentially. """ 2 | 3 | def stop_vmap(wrapped): 4 | return wrapped 5 | -------------------------------------------------------------------------------- /jax_sgmc/util/__init__.py: -------------------------------------------------------------------------------- 1 | from jax_sgmc.util.tree_util import ( 2 | tree_multiply, tree_scale, tree_add, 3 | Array, tree_matmul, tree_dot, Tensor, tensor_matmul) 4 | from jax_sgmc.util.list_map import (list_vmap, list_pmap, pytree_leaves_to_list, 5 | pytree_list_to_leaves) 6 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.alias.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.alias 2 | =============== 3 | 4 | .. automodule:: jax_sgmc.alias 5 | 6 | Solvers 7 | -------- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | sgld 13 | re_sgld 14 | amagold 15 | sggmc 16 | sghmc 17 | obabo 18 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts=--disable-warnings --strict -v 3 | markers= 4 | tensorflow: marking test requiring tensorflow and tensorflow_datasets 5 | hdf5: marking test requiring hdf5/h5py 6 | pmap: marking test requiring more XLA-devices 7 | solver: testing convergence of a solver on simple problem -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-20.04" 5 | tools: 6 | python: "3.8" 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | python: 13 | install: 14 | - method: pip 15 | path: . 
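# the 'docs' extra defined in pyproject.toml provides Sphinx and the other documentation dependencies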
16 | extra_requirements: 17 | - docs 18 | -------------------------------------------------------------------------------- /docs/api/index.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc 2 | ================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: jax_sgmc 7 | 8 | jax_sgmc.adaption 9 | jax_sgmc.alias 10 | jax_sgmc.data 11 | jax_sgmc.integrator 12 | jax_sgmc.io 13 | jax_sgmc.potential 14 | jax_sgmc.scheduler 15 | jax_sgmc.solver 16 | jax_sgmc.util 17 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.util.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.util 2 | ============== 3 | 4 | .. automodule:: jax_sgmc.util 5 | 6 | 7 | .. autosummary:: 8 | :toctree: _autosummary 9 | 10 | Array 11 | tree_multiply 12 | tree_add 13 | tree_scale 14 | list_vmap 15 | list_pmap 16 | pytree_list_to_leaves 17 | pytree_leaves_to_list 18 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.potential.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.potential 2 | =================== 3 | 4 | .. automodule:: jax_sgmc.potential 5 | 6 | Stochastic Potential 7 | --------------------- 8 | 9 | .. autofunction:: jax_sgmc.potential.minibatch_potential 10 | 11 | .. autoclass:: jax_sgmc.potential.StochasticPotential 12 | 13 | .. automethod:: __call__ 14 | 15 | Full Potential 16 | --------------- 17 | 18 | .. autofunction:: jax_sgmc.potential.full_potential 19 | 20 | .. autoclass:: jax_sgmc.potential.FullPotential 21 | 22 | .. automethod:: __call__ 23 | -------------------------------------------------------------------------------- /jax_sgmc/util/testing.py: -------------------------------------------------------------------------------- 1 | """Testing utility.""" 2 | from functools import partial 3 | 4 | import numpy as onp 5 | from numpy import testing 6 | 7 | from jax import tree_util 8 | 9 | assert_equal = partial(tree_util.tree_map, testing.assert_array_equal) 10 | 11 | 12 | def assert_close(x, y, **kwargs): 13 | if "rtol" not in kwargs.keys(): 14 | kwargs["rtol"] = 1e-5 15 | def assert_fn(xi, yi): 16 | xi = onp.ravel(xi) 17 | yi = onp.ravel(yi) 18 | testing.assert_allclose(xi, yi, **kwargs) 19 | tree_util.tree_map(assert_fn, x, y) 20 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.solver.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.solver 2 | ================ 3 | 4 | .. automodule:: jax_sgmc.solver 5 | 6 | 7 | MCMC 8 | ----- 9 | 10 | Run multiple chains of a solver in parallel or vectorized and save the results. 11 | 12 | .. autoclass:: jax_sgmc.solver.mcmc 13 | 14 | 15 | Solvers 16 | ------- 17 | 18 | .. autofunction:: jax_sgmc.solver.sgmc 19 | .. autofunction:: jax_sgmc.solver.amagold 20 | .. autofunction:: jax_sgmc.solver.sggmc 21 | .. autofunction:: jax_sgmc.solver.parallel_tempering 22 | 23 | Solver States 24 | -------------- 25 | 26 | .. autoclass:: jax_sgmc.solver.AMAGOLDState 27 | .. 
autoclass:: jax_sgmc.solver.SGGMCState 28 | -------------------------------------------------------------------------------- /docs/advanced/scheduler.rst: -------------------------------------------------------------------------------- 1 | Extending Schedulers 2 | ===================== 3 | 4 | Integration with Base Scheduler 5 | -------------------------------- 6 | 7 | Global and Local Scheduler Arguments 8 | _____________________________________ 9 | 10 | - *Global only* arguments are provided by position to the scheduler 11 | - *Global and local* arguments are provided by keyword to the scheduler and the 12 | init function such that they can be overwritten. 13 | 14 | For example: 15 | 16 | :: 17 | 18 | def some_scheduler(global_arg, global_or_local=0.0): 19 | 20 | def init_fn(global_or_local = global_or_local): 21 | # Use local arg 22 | ... 23 | -------------------------------------------------------------------------------- /LICENSE_SHORT: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /jax_sgmc/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .core import * 16 | -------------------------------------------------------------------------------- /jax_sgmc/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
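# Single source of the package version: pyproject.toml reads this attribute via [tool.setuptools.dynamic].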
14 | 15 | __version__ = "0.1.5" 16 | 17 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /jax_sgmc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from jax_sgmc.version import __version__ as __version__ 16 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.integrator.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.integrator 2 | ==================== 3 | 4 | Overview 5 | --------- 6 | 7 | .. automodule:: jax_sgmc.integrator 8 | 9 | Integrators 10 | ------------ 11 | 12 | .. autofunction:: jax_sgmc.integrator.obabo 13 | .. autofunction:: jax_sgmc.integrator.reversible_leapfrog 14 | .. autofunction:: jax_sgmc.integrator.friction_leapfrog 15 | .. autofunction:: jax_sgmc.integrator.langevin_diffusion 16 | 17 | Integrator States 18 | ------------------ 19 | 20 | .. autoclass:: jax_sgmc.integrator.ObaboState 21 | .. autoclass:: jax_sgmc.integrator.LeapfrogState 22 | .. autoclass:: jax_sgmc.integrator.LangevinState 23 | 24 | Utility 25 | ------- 26 | 27 | .. autofunction:: jax_sgmc.integrator.random_tree 28 | .. autofunction:: jax_sgmc.integrator.init_mass 29 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.adaption.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.adaption 2 | ================== 3 | 4 | .. automodule:: jax_sgmc.adaption 5 | 6 | Adaption Strategies 7 | --------------------- 8 | 9 | Mass Matrix 10 | _____________ 11 | 12 | .. autoclass:: jax_sgmc.adaption.MassMatrix 13 | 14 | .. autofunction:: jax_sgmc.adaption.mass_matrix 15 | 16 | Manifold 17 | _________ 18 | 19 | .. autoclass:: jax_sgmc.adaption.Manifold 20 | 21 | .. autofunction:: jax_sgmc.adaption.rms_prop 22 | 23 | Noise Model 24 | ____________ 25 | 26 | .. autoclass:: NoiseModel 27 | 28 | .. 
autofunction:: jax_sgmc.adaption.fisher_information 29 | 30 | Developer Information 31 | ---------------------- 32 | 33 | .. autoclass:: jax_sgmc.adaption.AdaptionState 34 | 35 | .. autofunction:: jax_sgmc.adaption.get_unravel_fn 36 | 37 | .. autofunction:: jax_sgmc.adaption.adaption 38 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.io.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.io 2 | =========== 3 | 4 | .. automodule:: jax_sgmc.io 5 | 6 | 7 | Data Collectors 8 | ---------------- 9 | 10 | Data Collector Interface 11 | _________________________ 12 | 13 | .. autoclass:: jax_sgmc.io.DataCollector 14 | :members: 15 | 16 | Collectors 17 | ___________ 18 | 19 | .. autoclass:: jax_sgmc.io.MemoryCollector 20 | :members: 21 | 22 | .. autoclass:: jax_sgmc.io.HDF5Collector 23 | :members: 24 | 25 | 26 | Saving Strategies 27 | ------------------ 28 | 29 | .. autofunction:: jax_sgmc.io.save 30 | 31 | .. autofunction:: jax_sgmc.io.no_save 32 | 33 | 34 | Pytree to Dict Transformation 35 | ------------------------------ 36 | 37 | .. autofunction:: jax_sgmc.io.pytree_to_dict 38 | 39 | .. autofunction:: jax_sgmc.io.dict_to_pytree 40 | 41 | .. autofunction:: jax_sgmc.io.pytree_dict_keys 42 | 43 | .. autofunction:: jax_sgmc.io.register_dictionize_rule 44 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # Upload the package to PyPI. 
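# The workflow runs when a GitHub release is published and authenticates with the PYPI_API_TOKEN repository secret.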
2 | # For more information, see https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | name: Upload to PyPI 5 | 6 | on: 7 | release: 8 | types: [published] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: '3.x' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install build 23 | - name: Build package 24 | run: python -m build 25 | - name: Publish package 26 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 27 | with: 28 | user: __token__ 29 | password: ${{ secrets.PYPI_API_TOKEN }} 30 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import pytest 4 | 5 | import os 6 | from jax.lib import xla_bridge 7 | 8 | from jax import test_util 9 | 10 | # Setup multi-device environment 11 | @pytest.fixture(scope='session', autouse=True) 12 | def pmap_setup(): 13 | # Setup 14 | prev_xla_flags = os.getenv("XLA_FLAGS") 15 | flags_str = prev_xla_flags or "" 16 | # Don't override user-specified device count, or other XLA flags. 17 | if "xla_force_host_platform_device_count" not in flags_str: 18 | os.environ["XLA_FLAGS"] = (flags_str + 19 | " --xla_force_host_platform_device_count=12") 20 | # Clear any cached backends so new CPU backend will pick up the env var. 21 | xla_bridge.get_backend.cache_clear() 22 | 23 | # Run 24 | yield 25 | 26 | # Reset to previous configuration in case other test modules will be run. 27 | if prev_xla_flags is None: 28 | del os.environ["XLA_FLAGS"] 29 | else: 30 | os.environ["XLA_FLAGS"] = prev_xla_flags 31 | xla_bridge.get_backend.cache_clear() 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.scheduler.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.scheduler 2 | ======================= 3 | 4 | .. automodule:: jax_sgmc.scheduler 5 | 6 | Base Scheduler 7 | -------------- 8 | 9 | .. autofunction:: jax_sgmc.scheduler.init_scheduler 10 | 11 | .. autofunction:: jax_sgmc.scheduler.scheduler_state 12 | 13 | .. autofunction:: jax_sgmc.scheduler.schedule 14 | 15 | .. autofunction:: jax_sgmc.scheduler.static_information 16 | 17 | Specific Schedulers 18 | ------------------- 19 | 20 | .. autoclass:: jax_sgmc.scheduler.specific_scheduler 21 | 22 | Step-size 23 | __________ 24 | 25 | .. autosummary:: 26 | :toctree: _autosummary 27 | 28 | polynomial_step_size 29 | polynomial_step_size_first_last 30 | adaptive_step_size 31 | 32 | Temperature 33 | ____________ 34 | 35 | .. autosummary:: 36 | :toctree: _autosummary 37 | 38 | constant_temperature 39 | cyclic_temperature 40 | 41 | Burn In 42 | ________ 43 | 44 | .. autosummary:: 45 | :toctree: _autosummary 46 | 47 | cyclic_burn_in 48 | initial_burn_in 49 | 50 | Thinning 51 | __________ 52 | 53 | .. 
autosummary:: 54 | :toctree: _autosummary 55 | 56 | random_thinning 57 | 58 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Basic Setup 4 | 5 | **JaxSGMC** can be installed with pip: 6 | 7 | ```shell 8 | pip install jax-sgmc --upgrade 9 | ``` 10 | 11 | The above command installs **Jax for CPU**. 12 | 13 | To be able to run **JaxSGMC on the GPU**, a special version of Jax has to be 14 | installed. Further information can be found in 15 | [Jax Installation Instructions](https://github.com/google/jax#installation). 16 | 17 | (additional_requirements)= 18 | ## Additional Packages 19 | 20 | Some parts of **JaxSGMC** require additional packages: 21 | 22 | - Data Loading with tensorflow: 23 | ```shell 24 | pip install jax-sgmc[tensorflow] --upgrade 25 | ``` 26 | - Saving Samples in the HDF5-Format: 27 | ```shell 28 | pip install jax-sgmc[hdf5] --upgrade 29 | ``` 30 | 31 | 32 | ## Installation from Source 33 | 34 | For development purposes, **JaxSGMC** can be installed from source in 35 | editable mode: 36 | 37 | ```shell 38 | git clone git@github.com:tummfm/jax-sgmc.git 39 | pip install -e .[test,docs] 40 | ``` 41 | 42 | This command additionally installs the requirements to run the tests: 43 | 44 | ```shell 45 | pytest tests 46 | ``` 47 | 48 | And to build the documentation (e.g. in html): 49 | 50 | ```shell 51 | make -C docs html 52 | ``` 53 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Jax SGMC 2 | ======== 3 | 4 | JaxSGMC brings Stochastic Gradient Markov chain Monte Carlo (SGMCMC) samplers to JAX. Inspired by `optax `_, JaxSGMC is built on a modular concept to increase reusability and accelerate research of new SGMCMC solvers. Additionally, JaxSGMC aims to promote probabilistic machine learning by removing obstacles in switching from stochastic optimizers to SGMCMC samplers. 5 | 6 | 7 | To get started quickly using SGMCMC samplers, JaxSGMC provides some popular pre-built samplers in :doc:`./api/jax_sgmc.alias`: 8 | 9 | - `SGLD (rms-prop) `_ 10 | - `SGHMC `_ 11 | - `reSGLD `_ 12 | - `SGGMC `_ 13 | - `AMAGOLD `_ 14 | - `OBABO `_ 15 | 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Getting Started 20 | 21 | installation 22 | quickstart 23 | 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | :caption: Reference Documentation 28 | 29 | usage/data 30 | usage/potential 31 | usage/io 32 | usage/scheduler 33 | usage/sgld_rms 34 | 35 | .. toctree:: 36 | :maxdepth: 2 37 | :caption: Advanced Topics 38 | 39 | advanced/adaption 40 | advanced/scheduler 41 | 42 | 43 | .. toctree:: 44 | :maxdepth: 2 45 | :caption: Examples 46 | 47 | examples/cifar 48 | 49 | 50 | .. toctree:: 51 | :maxdepth: 3 52 | :caption: API Documentation 53 | 54 | api/index 55 | 56 | 57 | Indices and tables 58 | ================== 59 | 60 | * :ref:`genindex` 61 | * :ref:`modindex` 62 | * :ref:`search` 63 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61"] # PEP 508 specifications. 
3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "jax-sgmc" 7 | authors = [ 8 | {name = "Paul Fuchs", email = "paul.fuchs@tum.de"}, 9 | {name = "Stephan Thaler", email = "stephan.thaler@tum.de"}, 10 | ] 11 | description = "Stochastic Gradient Monte Carlo in Jax" 12 | readme = "README.md" 13 | license = {"text" = "Apache-2.0"} 14 | classifiers = [ 15 | "Programming Language :: Python :: 3.8", 16 | "Programming Language :: Python :: 3.9", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Operating System :: MacOS", 19 | "Operating System :: POSIX :: Linux", 20 | "Topic :: Scientific/Engineering", 21 | "Intended Audience :: Science/Research", 22 | "Intended Audience :: Developers" 23 | ] 24 | requires-python = ">=3.7" 25 | dependencies = [ 26 | "numpy", 27 | "jax >= 0.1.73", 28 | "jaxlib >= 0.1.52", 29 | "dataclasses", 30 | ] 31 | dynamic = ["version"] 32 | 33 | [project.urls] 34 | "Documentation" = "https://jax-sgmc.readthedocs.io/en/latest/" 35 | "Source" = "https://github.com/tummfm/jax-sgmc" 36 | "Bug Tracker" = "https://github.com/tummfm/jax-sgmc/issues" 37 | 38 | [project.optional-dependencies] 39 | "tensorflow" = [ 40 | "tensorflow", 41 | "tensorflow_datasets", 42 | ] 43 | "test" = [ 44 | "pylint", 45 | "pytest", 46 | "pytest-mock", 47 | ] 48 | "docs" = [ 49 | "sphinx >= 3", 50 | "sphinx_rtd_theme", 51 | "sphinx-autodoc-typehints == 1.11.1", 52 | "myst-nb", 53 | "numpyro", 54 | "matplotlib", 55 | "h5py", 56 | "tensorflow", 57 | "tensorflow_datasets" 58 | ] 59 | 60 | [tool.setuptools.packages] 61 | find = {namespaces = false} 62 | 63 | [tool.setuptools.dynamic] 64 | version = {attr = "jax_sgmc.version.__version__"} 65 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code 6 | extension-pkg-whitelist=numpy 7 | 8 | 9 | [MESSAGES CONTROL] 10 | 11 | # Disable the message, report, category or checker with the given id(s). You 12 | # can either give multiple identifiers separated by comma (,) or put this 13 | # option multiple times (only on the command line, not in the configuration 14 | # file where it should appear only once).You can also use "--disable=all" to 15 | # disable everything first and then reenable specific checks. For example, if 16 | # you want to run only the similarities checker, you can use "--disable=all 17 | # --enable=similarities". If you want to run only the classes checker, but have 18 | # no Warning level messages displayed, use"--disable=all --enable=classes 19 | # --disable=W" 20 | disable=missing-docstring, 21 | too-many-locals, 22 | too-many-lines, 23 | invalid-name, 24 | redefined-outer-name, 25 | redefined-builtin, 26 | protected-name, 27 | no-else-return, 28 | fixme, 29 | protected-access, 30 | too-many-arguments, 31 | blacklisted-name, 32 | too-few-public-methods, 33 | unnecessary-lambda, 34 | anomalous-backslash-in-string, 35 | 36 | 37 | # Enable the message, report, category or checker with the given id(s). You can 38 | # either give multiple identifier separated by comma (,) or put this option 39 | # multiple time (only on the command line, not in the configuration file where 40 | # it should appear only once). See also the "--disable" option for examples. 
41 | enable=c-extension-no-member 42 | 43 | 44 | [FORMAT] 45 | 46 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 47 | # tab). 48 | indent-string=" " -------------------------------------------------------------------------------- /jax_sgmc/util/uuid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import uuid 16 | 17 | import numpy as onp 18 | 19 | import jax.numpy as jnp 20 | from jax import tree_util 21 | 22 | from jax import Array 23 | 24 | @tree_util.register_pytree_node_class 25 | class JaxUUID: 26 | 27 | def __init__(self, ints: Array = None): 28 | if ints is None: 29 | uuid_int = uuid.uuid4().int 30 | ints = [(uuid_int >> bits) & 0xFFFFFFFF for bits in range(0, 128, 32)] 31 | ints = onp.array(ints, dtype=jnp.uint32) 32 | 33 | self._uuid_int = ints 34 | 35 | @property 36 | def as_uuid(self): 37 | # Rearrange 4 int32 to uuid 38 | shifted_ints = [int(int(sint) << bits) 39 | for sint, bits 40 | in zip(self._uuid_int, range(0, 128, 32))] 41 | int128 = sum(shifted_ints) 42 | # Ensure that the hex number has exactly 128 bits and is non-negative 43 | hex128 = hex(int128).replace('0x', '').replace('-','').zfill(32) 44 | return uuid.UUID(hex128) 45 | 46 | def __repr__(self): 47 | return str(self.as_uuid) 48 | 49 | @property 50 | def as_int32s(self): 51 | return self._uuid_int 52 | 53 | def tree_flatten(self): 54 | # Wrapping the ints in a tuple ensures that they remain a single array of 55 | # length 4. 56 | children = (self._uuid_int,) 57 | return (children, None) 58 | 59 | @classmethod 60 | def tree_unflatten(cls, _, children): 61 | ints, = children 62 | return cls(ints=ints) 63 | -------------------------------------------------------------------------------- /docs/api/jax_sgmc.data.rst: -------------------------------------------------------------------------------- 1 | jax_sgmc.data 2 | ============== 3 | 4 | jax_sgmc.data.core 5 | ------------------- 6 | 7 | .. automodule:: jax_sgmc.data.core 8 | 9 | Host Callback Wrappers 10 | ________________________ 11 | 12 | .. autofunction:: jax_sgmc.data.core.random_reference_data 13 | 14 | .. autofunction:: jax_sgmc.data.core.full_reference_data 15 | 16 | .. autoclass:: jax_sgmc.data.core.GetBatchFunction 17 | 18 | .. automethod:: __call__ 19 | 20 | .. autoclass:: jax_sgmc.data.core.FullDataMapFunction 21 | 22 | .. automethod:: __call__ 23 | 24 | .. autoclass:: jax_sgmc.data.core.MaskedMappedFunction 25 | 26 | .. automethod:: __call__ 27 | 28 | .. autoclass:: jax_sgmc.data.core.UnmaskedMappedFunction 29 | 30 | .. automethod:: __call__ 31 | 32 | .. autofunction:: jax_sgmc.data.core.full_data_mapper 33 | 34 | .. autoclass:: jax_sgmc.data.core.FullDataMapperFunction 35 | 36 | .. automethod:: __call__ 37 | 38 | States 39 | _______ 40 | 41 | .. autoclass:: jax_sgmc.data.core.MiniBatchInformation 42 | 43 | .. 
autoclass:: jax_sgmc.data.core.CacheState 44 | 45 | Base Classes 46 | _____________ 47 | 48 | .. autoclass:: jax_sgmc.data.core.DataLoader 49 | :members: 50 | 51 | .. autoclass:: jax_sgmc.data.core.DeviceDataLoader 52 | :members: 53 | 54 | .. autoclass:: jax_sgmc.data.core.HostDataLoader 55 | :members: 56 | 57 | Utility Functions 58 | __________________ 59 | 60 | .. autosummary:: 61 | :toctree: _autosummary 62 | 63 | tree_index 64 | tree_dtype_struct 65 | 66 | jax_sgmc.data.numpy_loader 67 | --------------------------- 68 | 69 | .. automodule:: jax_sgmc.data.numpy_loader 70 | 71 | .. autoclass:: jax_sgmc.data.numpy_loader.NumpyBase 72 | :members: 73 | 74 | .. autoclass:: jax_sgmc.data.numpy_loader.NumpyDataLoader 75 | :members: 76 | 77 | .. autoclass:: jax_sgmc.data.numpy_loader.DeviceNumpyDataLoader 78 | :members: 79 | 80 | jax_sgmc.data.tensorflow_loader 81 | -------------------------------- 82 | 83 | .. automodule:: jax_sgmc.data.tensorflow_loader 84 | 85 | .. autoclass:: jax_sgmc.data.tensorflow_loader.TensorflowDataLoader 86 | :members: 87 | 88 | jax_sgmc.data.hdf5_loader 89 | -------------------------- 90 | 91 | .. automodule:: jax_sgmc.data.hdf5_loader 92 | 93 | .. autoclass:: jax_sgmc.data.hdf5_loader.HDF5Loader 94 | -------------------------------------------------------------------------------- /docs/usage/io.rst: -------------------------------------------------------------------------------- 1 | Saving of Samples 2 | ================== 3 | 4 | Setting up Saving 5 | ------------------ 6 | 7 | **JaxSGMC** supports saving samples from inside jit-compiled 8 | functions. Additionally, checkpointing can be provided (if this is a 9 | priority for you, please open a feature request on GitHub). 10 | Saving data consists of two parts: 11 | 12 | Data Collector 13 | _______________ 14 | 15 | The data collector serializes the data and writes it to disk and/or keeps 16 | it in memory. Every data collector following this interface works. 17 | 18 | Saving 19 | ________ 20 | 21 | The function :func:`jax_sgmc.io.save` initializes the interface between the data collector and the 22 | jit-compiled function. 23 | 24 | If the device memory is large enough, it is possible to use 25 | :func:`jax_sgmc.io.no_save`. This function has the same signature towards the 26 | jit-compiled function but keeps all collected samples in the device memory. 27 | 28 | 29 | Extending Saveable PyTree Types 30 | -------------------------------- 31 | 32 | By default, transformations are defined for some common types: 33 | 34 | - list 35 | - dict 36 | - (named)tuple 37 | 38 | Additionally, transformations for the following optional libraries are 39 | implemented: 40 | 41 | - haiku._src.data_structures.FlatMapping 42 | 43 | A new transformation rule is a function which accepts a pytree node of 44 | a specific type and returns an iterable of `(key, value)` 45 | pairs. 46 | 47 | .. doctest:: 48 | 49 | >>> from jax_sgmc import io 50 | >>> from jax.tree_util import register_pytree_node 51 | >>> 52 | >>> class SomeClass: 53 | ... def __init__(self, value): 54 | ... self._value = value 55 | >>> 56 | >>> # Do not forget to register the class as jax pytree node 57 | >>> register_pytree_node(SomeClass, 58 | ... lambda sc: (sc._value, None), 59 | ... lambda _, data: SomeClass(value=data)) 60 | >>> 61 | >>> # Now define a rule to transform the class into a dict 62 | >>> @io.register_dictionize_rule(SomeClass) 63 | ... def some_class_to_dict(instance_of_some_class): 64 | ...
return [("this_is_the_key", instance_of_some_class._value)] 65 | >>> 66 | >>> some_class = SomeClass({'a': 0.0, 'b': 0.5}) 67 | >>> some_class_as_dict = io.pytree_to_dict(some_class) 68 | >>> 69 | >>> print(some_class_as_dict) 70 | {'this_is_the_key': {'a': 0.0, 'b': 0.5}} 71 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | lint: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python 3.8 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: 3.8 22 | - name: Install requirements 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install .[test,tensorflow] 26 | - name: Lint with pylint 27 | run: | 28 | pylint jax_sgmc --fail-under 9.0 29 | 30 | doctest: 31 | runs-on: ubuntu-latest 32 | 33 | steps: 34 | - uses: actions/checkout@v2 35 | - name: Set up Python 3.8 36 | uses: actions/setup-python@v2 37 | with: 38 | python-version: 3.8 39 | - name: Install requirements 40 | run: | 41 | python -m pip install --upgrade pip 42 | pip install .[docs] 43 | - name: Build documentation 44 | run: | 45 | make -C docs html 46 | make -C docs doctest 47 | 48 | test-build: 49 | 50 | runs-on: ubuntu-latest 51 | 52 | steps: 53 | - uses: actions/checkout@v3 54 | - name: Set up Python 55 | uses: actions/setup-python@v3 56 | with: 57 | python-version: '3.x' 58 | - name: Install dependencies 59 | run: | 60 | python -m pip install --upgrade pip 61 | pip install --upgrade build 62 | - name: Build package 63 | run: python -m build 64 | 65 | test: 66 | 67 | runs-on: ubuntu-latest 68 | 69 | steps: 70 | - uses: actions/checkout@v2 71 | - name: Set up Python 3.8 72 | uses: actions/setup-python@v2 73 | with: 74 | python-version: 3.8 75 | - name: Install requirements 76 | run: | 77 | python -m pip install --upgrade pip 78 | pip install .[test] 79 | - name: Test with pytest 80 | run: | 81 | pytest --tb=line -m "not tensorflow and not hdf5 and not pmap and not solver" 82 | 83 | test-with-tensorflow: 84 | 85 | runs-on: ubuntu-latest 86 | 87 | steps: 88 | - uses: actions/checkout@v2 89 | - name: Set up Python 3.8 90 | uses: actions/setup-python@v2 91 | with: 92 | python-version: 3.8 93 | - name: Install requirements 94 | run: | 95 | python -m pip install --upgrade pip 96 | pip install .[test,tensorflow] 97 | - name: Test with pytest 98 | run: | 99 | pytest --tb=line -m "tensorflow" 100 | 101 | test-alias: 102 | 103 | runs-on: ubuntu-latest 104 | 105 | steps: 106 | - uses: actions/checkout@v2 107 | - name: Set up Python 3.8 108 | uses: actions/setup-python@v2 109 | with: 110 | python-version: 3.8 111 | - name: Install requirements 112 | run: | 113 | python -m pip install --upgrade pip 114 | pip install .[test,tensorflow] 115 | - name: Test with pytest 116 | run: | 117 | pytest --tb=line -m "solver" 118 | -------------------------------------------------------------------------------- /docs/usage/scheduler.rst: -------------------------------------------------------------------------------- 1 | Setup Schedulers 2 | ================== 3 | 4 | A scheduler is a combination 
of specific schedulers, which control only a single 5 | parameter, for example the step size. 6 | Specific schedulers for different variables are combined into a basic 7 | scheduler via :func:`jax_sgmc.scheduler.init_scheduler`, which updates all 8 | specific schedulers and provides default values for parameters without a 9 | specific scheduler. 10 | 11 | 12 | Specific Schedulers 13 | -------------------- 14 | 15 | .. doctest:: 16 | 17 | >>> from jax_sgmc import scheduler 18 | >>> 19 | >>> step_size_schedule_unused = scheduler.polynomial_step_size( 20 | ... a=0.1, b=1.0, gamma=0.33) 21 | 22 | We already provided all required arguments. However, it is also possible to 23 | provide only the arguments, which should stay equal over all chains. 24 | For example we could provide different ``gamma``-values by specifying them 25 | during the initialization of the basic scheduler: 26 | 27 | >>> step_size_schedule_partial = scheduler.polynomial_step_size( 28 | ... a=0.1, b=1.0) 29 | 30 | Basic Scheduler 31 | --------------- 32 | 33 | It is not necessary to setup a scheduler for all parameters, because the basic 34 | scheduler provides default values. 35 | Therefore, we can initialize the basic scheduler only with the specific step 36 | size schedule we initialized above: 37 | 38 | >>> init_fn, next_fn, get_fn = scheduler.init_scheduler( 39 | ... step_size=step_size_schedule_partial, progress_bar=False) 40 | 41 | 42 | After we created the basic scheduler, we can initialize a schedule. 43 | Here we have to provide the missing values for the partially initialized 44 | schedulers. 45 | 46 | >>> sched_a, static_information = init_fn(10, step_size={'gamma': 0.1}) 47 | >>> sched_b, _ = init_fn(10, step_size={'gamma': 1.0}) 48 | 49 | Static information is returned in addition to the scheduler state, e.g. 50 | the total number of iterations or the expected number of collected samples. 51 | This information is necessary, e.g., for the ``io``-module to allocate 52 | sufficient memory for the samples to be saved. 53 | 54 | >>> print(static_information) 55 | static_information(samples_collected=10) 56 | 57 | In this example, we can see that the temperature parameter has been assigned to 58 | a default value of 1.0 and the different step size schedules are updated with 59 | different gamma parameters: 60 | 61 | >>> curr_sched_a = get_fn(sched_a) 62 | >>> curr_sched_b = get_fn(sched_b) 63 | 64 | >>> print(f"Scheduler a\n===========\n" 65 | ... f" Step-Size = {curr_sched_a.step_size : .2f}\n" 66 | ... f" Temperature = {curr_sched_a.temperature : .2f}") 67 | Scheduler a 68 | =========== 69 | Step-Size = 0.10 70 | Temperature = 1.00 71 | >>> print(f"Scheduler b\n===========\n" 72 | ... f" Step-Size = {curr_sched_b.step_size : .2f}\n" 73 | ... f" Temperature = {curr_sched_b.temperature : .2f}") 74 | Scheduler b 75 | =========== 76 | Step-Size = 0.10 77 | Temperature = 1.00 78 | 79 | >>> # Get the parameters at the next iteration 80 | >>> sched_a = next_fn(sched_a) 81 | >>> sched_b = next_fn(sched_b) 82 | >>> curr_sched_a = get_fn(sched_a) 83 | >>> curr_sched_b = get_fn(sched_b) 84 | 85 | >>> print(f"Scheduler a\n===========\n" 86 | ... f" Step-Size = {curr_sched_a.step_size : .2f}\n" 87 | ... f" Temperature = {curr_sched_a.temperature : .2f}") 88 | Scheduler a 89 | =========== 90 | Step-Size = 0.09 91 | Temperature = 1.00 92 | >>> print(f"Scheduler b\n===========\n" 93 | ... f" Step-Size = {curr_sched_b.step_size : .2f}\n" 94 | ... 
f" Temperature = {curr_sched_b.temperature : .2f}") 95 | Scheduler b 96 | =========== 97 | Step-Size = 0.05 98 | Temperature = 1.00 99 | -------------------------------------------------------------------------------- /docs/advanced/adaption.rst: -------------------------------------------------------------------------------- 1 | Extend Adapted Quantities 2 | ========================= 3 | 4 | Extension of Adaption Strategies 5 | _________________________________ 6 | 7 | Each adaption strategy is expected to return three functions 8 | 9 | :: 10 | 11 | @adaption(quantity=SomeQuantity) 12 | def some_adaption(minibatch_potential: Callable = None): 13 | ... 14 | return init_adaption, update_adaption, get_adaption 15 | 16 | The decorator :func:`adaption` wraps all three functions to flatten pytrees to 17 | 1D-arrays and unflatten the results of :func:`get_adaption`. 18 | 19 | The rule is that all arguments that are passed by position are expected 20 | to have the same shape as the sample pytree and are flattened to 1D-arrays. 21 | Arguments that should not be raveled have to be passed by keyword. 22 | 23 | 1. :func:`init_adaption` 24 | 25 | This function initializes the state of the adaption and the ravel- and unravel 26 | functions. Therefore, it must accept at least one positional argument with 27 | the shape of the sample pytree. 28 | 29 | :: 30 | 31 | ... 32 | def init_adaption(sample, momentum, parameter = 0.5): 33 | ... 34 | 35 | In the example above, the sample and the momentum are 1D-arrays with size 36 | equal to the latent variable count. Parameter is a scalar and will not be 37 | raveled. 38 | 39 | 2. :func:`update_adaption` 40 | 41 | This function updates the state of the adaption. It must accept at least one 42 | positional argument, the state, even if the adaption is stateless. 43 | 44 | :: 45 | 46 | ... 47 | # This is a stateless adaption 48 | def update_adaption(state, *args, **kwargs): 49 | del state, args, kwargs 50 | return None 51 | 52 | If the factory function of the adaption strategy is called with a potential 53 | function as keyword argument (`minibatch_potential = some_fun`), then 54 | :func:`update_adaption` is additionally called with the keyword arguments 55 | `flat_potential` and `mini_batch`. `flat_potential` is a wrapped version of 56 | the original potential function and can be called with the raveled sample. 57 | 58 | 3. :func:`get_adaption` 59 | 60 | This function calculates the desired quantity. Its argument-signature equals 61 | :func:`update_adaption`. It should return a 1D tuple of values in the right 62 | order, such that the quantity of the type ``NamedTuple`` can be created by 63 | providing positional arguments. For example, if the quantity has 64 | the fields `q = namedtuple('q', ['a', 'b', 'c'])`, the get function should 65 | look like 66 | 67 | :: 68 | 69 | ... 70 | def get_adaption(state, *args, **kwargs): 71 | ... 72 | return a, b, c 73 | 74 | The returned arrays can have dimension 1 or 2. 75 | 76 | 77 | Extension of Quantities 78 | _________________________ 79 | 80 | The introduction of quantities simplifies the implementation into an integrator 81 | or solver. 82 | 83 | For example, adapting a manifold :math:`G` for SGLD requires the calculation of 84 | :math:`G^{-1},\ G^{-\frac{1}{2}},\ \text{and}\ \Gamma`. If 85 | :func:`get_adaption` returns all three quantities in the order 86 | 87 | :: 88 | 89 | @adaption(quantity=Manifold) 90 | def some_adaption(): 91 | ... 92 | def get_adaption(state, ...): 93 | ... 
94 | return g_inv, g_inv_sqrt, gamma 95 | 96 | the manifold should be defined as follows, where the correct order of 97 | field names is important: 98 | 99 | :: 100 | 101 | class Manifold(NamedTuple): 102 | g_inv: PyTree 103 | g_inv_sqrt: PyTree 104 | gamma: PyTree 105 | 106 | The new :func:`get_adaption` then only returns a single value of type 107 | :class:`Manifold`. 108 | 109 | :: 110 | 111 | init_adaption, update_adaption, get_adaption = some_adaption() 112 | ... 113 | G = get_adaption(state, ...) 114 | -------------------------------------------------------------------------------- /tests/test_tree_operations.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax import random 7 | from jax import flatten_util 8 | from jax import tree_util 9 | 10 | from jax_sgmc import util 11 | from jax_sgmc import data 12 | from jax_sgmc.util import testing 13 | 14 | # Todo: Test vmap on custom host_callback 15 | 16 | import pytest 17 | 18 | 19 | 20 | @pytest.fixture 21 | def random_tree(): 22 | key = random.PRNGKey(0) 23 | split1, split2 = random.split(key) 24 | tree = {"a": random.normal(split1, shape=(2,)), 25 | "b": {"b1": 0.0, 26 | "b2": random.normal(split2, shape=(3, 4))}} 27 | 28 | flat_tree, unravel_fn = flatten_util.ravel_pytree(tree) 29 | ravel_fn = lambda t: flatten_util.ravel_pytree(t)[0] 30 | return (tree, flat_tree), (ravel_fn, unravel_fn) 31 | 32 | 33 | class TestTree: 34 | 35 | def test_tree_scale(self, random_tree): 36 | (tree, flat_tree), (ravel_fn, unravel_fn) = random_tree 37 | alpha = random.normal(random.PRNGKey(1)) 38 | 39 | true_result = unravel_fn(alpha * flat_tree) 40 | treemap_result = util.tree_scale(alpha, tree) 41 | 42 | testing.assert_equal(true_result, treemap_result) 43 | 44 | def test_tree_multiply(self, random_tree): 45 | (tree, flat_tree), (ravel_fn, unravel_fn) = random_tree 46 | rnd = random.normal(random.PRNGKey(1), shape=flat_tree.shape) 47 | 48 | true_result = unravel_fn(jnp.multiply(rnd, flat_tree)) 49 | treemap_result = util.tree_multiply(unravel_fn(rnd), tree) 50 | 51 | testing.assert_equal(true_result, treemap_result) 52 | 53 | def test_tree_add(self, random_tree): 54 | (tree, flat_tree), (ravel_fn, unravel_fn) = random_tree 55 | rnd = random.normal(random.PRNGKey(1), shape=flat_tree.shape) 56 | 57 | true_result = unravel_fn(jnp.add(rnd, flat_tree)) 58 | treemap_result = util.tree_add(unravel_fn(rnd), tree) 59 | 60 | testing.assert_equal(true_result, treemap_result) 61 | 62 | def test_tree_matmul(self, random_tree): 63 | (tree, flat_tree), (ravel_fn, unravel_fn) = random_tree 64 | rnd = random.normal(random.PRNGKey(1), shape=(flat_tree.size, flat_tree.size)) 65 | 66 | true_result = unravel_fn(jnp.matmul(rnd, flat_tree)) 67 | treemap_result = util.tree_matmul(rnd, tree) 68 | 69 | testing.assert_equal(true_result, treemap_result) 70 | 71 | def test_tree_dot(self, random_tree): 72 | (tree, flat_tree), (ravel_fn, unravel_fn) = random_tree 73 | 74 | true_result = jnp.dot(flat_tree, flat_tree) 75 | treemap_result = util.tree_dot(tree, tree) 76 | 77 | testing.assert_equal(true_result, treemap_result) 78 | 79 | class TestTreeMap(): 80 | 81 | def test_vmap(self, random_tree): 82 | (random_tree, _), _ = random_tree 83 | 84 | zero_tree = tree_util.tree_map(jnp.zeros_like, random_tree) 85 | one_tree = tree_util.tree_map(jnp.ones_like, random_tree) 86 | modified_tree = tree_util.tree_map(jnp.add, random_tree, one_tree) 87 | 88 | @jax.jit 89
| @util.list_vmap 90 | def test_substract(tree): 91 | return tree_util.tree_map(jnp.subtract, modified_tree, tree) 92 | 93 | one, zero = test_substract(random_tree, modified_tree) 94 | 95 | testing.assert_close(zero, zero_tree) 96 | testing.assert_close(one, one_tree) 97 | 98 | @pytest.mark.pmap 99 | def test_pmap(self, random_tree): 100 | (random_tree, _), _ = random_tree 101 | 102 | zero_tree = tree_util.tree_map(jnp.zeros_like, random_tree) 103 | one_tree = tree_util.tree_map(jnp.ones_like, random_tree) 104 | modified_tree = tree_util.tree_map(jnp.add, random_tree, one_tree) 105 | 106 | @jax.jit 107 | @util.list_pmap 108 | def test_substract(tree): 109 | return tree_util.tree_map(jnp.subtract, modified_tree, tree) 110 | 111 | one, zero = test_substract(random_tree, modified_tree) 112 | 113 | testing.assert_close(zero, zero_tree) 114 | testing.assert_close(one, one_tree) -------------------------------------------------------------------------------- /jax_sgmc/util/tree_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Defines types special to jax or this library. """ 16 | 17 | from typing import Any, NamedTuple 18 | from functools import partial 19 | 20 | from jax import tree_util 21 | from jax import flatten_util 22 | import jax.numpy as jnp 23 | 24 | Array = jnp.ndarray 25 | PyTree = Any 26 | 27 | class Tensor(NamedTuple): 28 | """Vector and matrix pytree-products. 29 | 30 | Attributes: 31 | ndim: Dimension of the tensor (0: scalar, 1: vector/diagonal, 2: matrix) 32 | tensor: Data of the pytree 33 | 34 | """ 35 | ndim: int 36 | tensor: PyTree 37 | 38 | def tensor_matmul(matrix: Tensor, vector: PyTree): 39 | """Matrix vector product with a tensor and a pytree. 40 | 41 | Distinguishes between full matrices and diagonal matrices. 42 | 43 | Args: 44 | matrix: Matrix in tensor format 45 | vector: PyTree, which is compatible with the tensor 46 | 47 | """ 48 | if matrix.ndim == 0: 49 | return tree_scale(matrix.tensor, vector) 50 | elif matrix.ndim == 1: 51 | return tree_multiply(matrix.tensor, vector) 52 | elif matrix.ndim == 2: 53 | return tree_matmul(matrix.tensor, vector) 54 | else: 55 | raise NotImplementedError(f"Cannot multiply matrix with dimension " 56 | f"{matrix.ndim}") 57 | 58 | def tree_multiply(tree_a: PyTree, tree_b: PyTree) -> PyTree: 59 | """Maps elementwise product over two pytrees. 60 | 61 | Args: 62 | tree_a: First pytree 63 | tree_b: Second pytree, must have the same shape as tree_a 64 | 65 | Returns: 66 | Returns a PyTree obtained by an element-wise product of all PyTree leaves. 67 | 68 | """ 69 | return tree_util.tree_map(jnp.multiply, tree_a, tree_b) 70 | 71 | 72 | def tree_scale(alpha: Array, tree: PyTree) -> PyTree: 73 | """Scalar-Pytree product via tree_map.
74 | 75 | Args: 76 | alpha: Scalar 77 | tree: Arbitrary PyTree 78 | 79 | Returns: 80 | Returns a PyTree with all leaves scaled by alpha. 81 | 82 | """ 83 | @partial(partial, tree_util.tree_map) 84 | def tree_scale_imp(x: PyTree): 85 | return alpha * x 86 | return tree_scale_imp(tree) 87 | 88 | 89 | def tree_add(tree_a: PyTree, tree_b: PyTree) -> PyTree: 90 | """Maps elementwise sum over PyTrees. 91 | 92 | Args: 93 | tree_a: First PyTree 94 | tree_b: Second PyTree with the same shape as tree_a 95 | 96 | Returns: 97 | Returns a PyTree obtained by leaf-wise summation. 98 | """ 99 | @partial(partial, tree_util.tree_map) 100 | def tree_add_imp(leaf_a, leaf_b): 101 | return leaf_a + leaf_b 102 | return tree_add_imp(tree_a, tree_b) 103 | 104 | 105 | def tree_matmul(tree_mat: Array, tree_vec: PyTree): 106 | """Matrix tree product for LD on manifold. 107 | 108 | Args: 109 | tree_mat: Matrix to be multiplied with flattened tree 110 | tree_vec: Tree representing vector 111 | 112 | Returns: 113 | Returns the un-flattened product of the matrix and the flattened tree. 114 | """ 115 | # Todo: Redefine without need for flatten util 116 | vec_flat, unravel_fn = flatten_util.ravel_pytree(tree_vec) 117 | return unravel_fn(jnp.matmul(tree_mat, vec_flat)) 118 | 119 | def tree_dot(tree_a: PyTree, tree_b: PyTree): 120 | """Scalar product of two pytrees. 121 | 122 | Args: 123 | tree_a: First pytree 124 | tree_b: Second pytree with the same tree structure and leaf shape as 125 | tree_a 126 | 127 | Returns: 128 | Returns a scalar, which is the sum of the element-wise product of all leaves. 129 | 130 | """ 131 | leaves_a = tree_util.tree_leaves(tree_a) 132 | leaves_b = tree_util.tree_leaves(tree_b) 133 | return sum((jnp.sum(jnp.multiply(a, b)) for a, b in zip(leaves_a, leaves_b))) 134 | -------------------------------------------------------------------------------- /tests/test_adaption.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
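"""Tests for jax_sgmc.adaption: checks that the adaption decorator lifts array-based adaption strategies to pytree samples."""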
14 | 15 | import jax.numpy as jnp 16 | from jax import random 17 | from jax import flatten_util 18 | 19 | import pytest 20 | 21 | from jax_sgmc import adaption 22 | from jax_sgmc import util 23 | from jax_sgmc.util import testing 24 | 25 | class TestDecorator: 26 | """Test that the decorator transforms array adaption to tree adaption.""" 27 | 28 | @pytest.fixture 29 | def random_tree(self): 30 | key = random.PRNGKey(0) 31 | split1, split2 = random.split(key) 32 | tree = {"a": random.normal(split1, shape=(2,)), 33 | "b": {"b1": jnp.array(0.0), 34 | "b2": random.normal(split2, shape=(3, 4))}} 35 | 36 | return tree 37 | 38 | 39 | @pytest.mark.parametrize("test_arg, test_kwarg", [(0.0, 1.0), 40 | (1.0, 2.0), 41 | (1.0, 1.0)]) 42 | def test_args_kwargs(self, random_tree, test_arg, test_kwarg): 43 | 44 | @adaption.adaption(adaption.Manifold) 45 | def init_adaption(): 46 | def init(sample, arg, kwarg=1.0): 47 | return sample, arg, kwarg 48 | def update(state, sample, sample_grad, arg, kwarg=1.0): 49 | return state, sample, sample_grad, arg, kwarg 50 | def get(state, sample, sample_grad, arg, kwarg=1.0): 51 | return arg * sample, kwarg * sample_grad, sample-sample_grad 52 | return init, update, get 53 | 54 | init, update, get = init_adaption() 55 | init_state = init(random_tree, test_arg, kwarg=test_kwarg) 56 | update_state = update(init_state, random_tree, random_tree, arg=test_arg, kwarg=test_kwarg) 57 | manifold = get(init_state, random_tree, random_tree, arg=test_arg, kwarg=test_kwarg) 58 | 59 | # Assert that parameters are passed correctly 60 | 61 | testing.assert_equal(init_state.ravel_fn(random_tree), init_state.state[0]) 62 | assert init_state.state[1] == test_arg 63 | assert init_state.state[2] == test_kwarg 64 | 65 | testing.assert_equal(update_state.ravel_fn(random_tree), update_state.state[1]) 66 | testing.assert_equal(update_state.ravel_fn(random_tree), update_state.state[2]) 67 | assert update_state.state[3] == test_arg 68 | assert update_state.state[4] == test_kwarg 69 | 70 | assert manifold.g_inv.ndim == 1 71 | assert manifold.sqrt_g_inv.ndim == 1 72 | assert manifold.gamma.ndim == 1 73 | testing.assert_equal(manifold.g_inv.tensor, util.tree_scale(test_arg, random_tree)) 74 | testing.assert_equal(manifold.sqrt_g_inv.tensor, util.tree_scale(test_kwarg, random_tree)) 75 | testing.assert_equal(manifold.gamma.tensor, util.tree_scale(0.0, random_tree)) 76 | 77 | @pytest.mark.parametrize("diag", [True, False]) 78 | def test_diag_full(self, diag, random_tree): 79 | @adaption.adaption(adaption.Manifold) 80 | def init_adaption(): 81 | def init(*args): 82 | return None 83 | def update(*args): 84 | return None 85 | def get(state, sample, *args, diag=True): 86 | if diag: 87 | return jnp.ones_like(sample), jnp.ones_like(sample), jnp.ones_like(sample) 88 | else: 89 | return jnp.eye(sample.size), jnp.eye(sample.size), jnp.ones_like(sample) 90 | return init, update, get 91 | init, _, get = init_adaption() 92 | 93 | init_state = init(random_tree) 94 | manifold = get(init_state, random_tree, random_tree, None, diag=diag) 95 | 96 | # Assert correctness of diagonal / full manifold 97 | if diag: 98 | assert manifold.g_inv.ndim == 1 99 | testing.assert_equal(util.tree_multiply(manifold.g_inv.tensor, random_tree), random_tree) 100 | else: 101 | assert manifold.g_inv.ndim == 2 102 | testing.assert_equal(util.tree_matmul(manifold.g_inv.tensor, random_tree), random_tree) 103 | -------------------------------------------------------------------------------- /docs/conf.py: 
-------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | 14 | import os 15 | import sys 16 | 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | from jax_sgmc.version import __version__ as jax_sgmc_version 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'jax-sgmc' 24 | copyright = '2021, Multiscale Modeling of Fluid Materials, TU Munich' 25 | author = 'Multiscale Modeling of Fluid Materials' 26 | 27 | # The full version, including alpha/beta/rc tags 28 | release = jax_sgmc_version 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | 'sphinx.ext.autodoc', 38 | 'sphinx.ext.autosummary', 39 | 'sphinx.ext.coverage', 40 | 'sphinx.ext.doctest', 41 | 'sphinx.ext.intersphinx', 42 | 'sphinx.ext.mathjax', 43 | 'sphinx.ext.napoleon', 44 | 'sphinx.ext.viewcode', 45 | 'sphinx_autodoc_typehints', 46 | 'sphinx_rtd_theme', 47 | 'myst_nb' 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | source_suffix = ['.rst', '.ipynb'] 54 | 55 | # List of patterns, relative to source directory, that match files and 56 | # directories to ignore when looking for source files. 57 | # This pattern also affects html_static_path and html_extra_path. 
58 | exclude_patterns = ['build'] 59 | 60 | autosummary_generate = True 61 | 62 | intersphinx_mapping = { 63 | 'python': ('https://docs.python.org/3/', None), 64 | 'numpy': ('https://numpy.org/doc/stable/', None), 65 | 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), 66 | 'jax': ('https://jax.readthedocs.io/en/latest/', None) 67 | } 68 | 69 | # -- MathJax ------------------------------------------------------------------ 70 | 71 | mathjax3_config = { 72 | "tex": { 73 | "inlineMath": [['$', '$'], ['\\(', '\\)']] 74 | }, 75 | "svg": { 76 | "fontCache": 'global' 77 | } 78 | } 79 | 80 | # -- MystNB ------------------------------------------------------------------ 81 | 82 | myst_heading_anchors = 3 83 | 84 | nb_execution_excludepatterns = [ 85 | # Require long computations 86 | 'examples/*', 87 | ] 88 | 89 | nb_render_priority = { 90 | "html": ( 91 | "application/vnd.jupyter.widget-view+json", 92 | "application/javascript", 93 | "text/html", 94 | "image/svg+xml", 95 | "image/png", 96 | "image/jpeg", 97 | "text/markdown", 98 | "text/latex", 99 | "text/plain", 100 | ), 101 | "doctest": ( 102 | "application/vnd.jupyter.widget-view+json", 103 | "application/javascript", 104 | "text/html", 105 | "image/svg+xml", 106 | "image/png", 107 | "image/jpeg", 108 | "text/markdown", 109 | "text/latex", 110 | "text/plain", 111 | ), 112 | "coverage": ( 113 | "application/vnd.jupyter.widget-view+json", 114 | "application/javascript", 115 | "text/html", 116 | "image/svg+xml", 117 | "image/png", 118 | "image/jpeg", 119 | "text/markdown", 120 | "text/latex", 121 | "text/plain", 122 | ) 123 | } 124 | 125 | # -- Options for HTML output ------------------------------------------------- 126 | 127 | # The theme to use for HTML and HTML Help pages. See the documentation for 128 | # a list of builtin themes. 129 | # 130 | html_theme = 'sphinx_rtd_theme' 131 | 132 | # Add any paths that contain custom static files (such as style sheets) here, 133 | # relative to this directory. They are copied after the builtin static files, 134 | # so a file named "default.css" will overwrite the builtin "default.css". 135 | html_static_path = ['_static'] 136 | -------------------------------------------------------------------------------- /jax_sgmc/util/list_map.py: -------------------------------------------------------------------------------- 1 | from functools import partial, wraps 2 | 3 | import jax 4 | from jax import tree_util 5 | import jax.numpy as jnp 6 | 7 | def pytree_list_to_leaves(pytrees): 8 | """Transform a list of pytrees to allow pmap/vmap. 9 | 10 | The trees must have the same tree structure and only differ in the value of 11 | their leaves. This means, that the trees might contain custom nodes, such as 12 | :class:`jax.tree_util.Partial`, but those tree nodes must equivalent. For 13 | example 14 | 15 | .. doctest:: 16 | 17 | >>> from jax.tree_util import Partial 18 | >>> 19 | >>> Partial(lambda x: x + 1) == Partial(lambda x: x + 1) 20 | False 21 | 22 | because they are defined on different functions, but are still equivalent as 23 | the functions perform the same computations. 24 | 25 | Example usage: 26 | 27 | .. 
doctest:: 28 | 29 | >>> import jax.numpy as jnp 30 | >>> import jax_sgmc.util.list_map as lm 31 | >>> 32 | >>> tree_a = {"a": 0.0, "b": jnp.zeros((2,))} 33 | >>> tree_b = {"a": 1.0, "b": jnp.ones((2,))} 34 | >>> 35 | >>> concat_tree = lm.pytree_list_to_leaves([tree_a, tree_b]) 36 | >>> print(concat_tree) 37 | {'a': Array([0., 1.], dtype=float32, weak_type=True), 'b': Array([[0., 0.], 38 | [1., 1.]], dtype=float32)} 39 | 40 | 41 | Args: 42 | pytrees: A list of trees with similar tree structure and equally shaped 43 | leaves 44 | 45 | Returns: 46 | Returns a tree with the same tree structure but corresponding leaves 47 | concatenated along the first dimension. 48 | 49 | """ 50 | 51 | # Transpose the pytrees, i. e. make a list (array) of leaves from a list of 52 | # pytrees. Only then vmap can be used to vectorize an operation over pytrees 53 | treedef = tree_util.tree_structure(pytrees[0]) 54 | superleaves = [jnp.stack(leaves, axis=0) 55 | for leaves in zip(*map(tree_util.tree_leaves, pytrees))] 56 | return tree_util.tree_unflatten(treedef, superleaves) 57 | 58 | 59 | def pytree_leaves_to_list(pytree): 60 | """Splits a pytree in a list of pytrees. 61 | 62 | Splits every leaf of the pytree along the first dimension, thus undoing the 63 | :func:`pytree_list_to_leaves` transformation. 64 | 65 | Example usage: 66 | 67 | .. doctest:: 68 | 69 | >>> import jax.numpy as jnp 70 | >>> import jax_sgmc.util.list_map as lm 71 | >>> 72 | >>> tree = {"a": jnp.array([0.0, 1.0]), "b": jnp.zeros((2, 2))} 73 | >>> 74 | >>> tree_list = lm.pytree_leaves_to_list(tree) 75 | >>> print(tree_list) 76 | [{'a': Array(0., dtype=float32), 'b': Array([0., 0.], dtype=float32)}, {'a': Array(1., dtype=float32), 'b': Array([0., 0.], dtype=float32)}] 77 | 78 | 79 | Args: 80 | pytree: A single pytree where each leaf has eqal `leaf.shape[0]`. 81 | 82 | Returns: 83 | Returns a list of pytrees with similar structure. 84 | 85 | """ 86 | leaves, treedef = tree_util.tree_flatten(pytree) 87 | num_trees = leaves[0].shape[0] 88 | pytrees = [tree_util.tree_unflatten(treedef, [leaf[idx] for leaf in leaves]) 89 | for idx in range(num_trees)] 90 | return pytrees 91 | 92 | 93 | def list_vmap(fun): 94 | """vmaps a function over similar pytrees. 95 | 96 | Example usage: 97 | 98 | .. doctest:: 99 | 100 | >>> from jax import tree_map 101 | >>> import jax.numpy as jnp 102 | >>> import jax_sgmc.util.list_map as lm 103 | >>> 104 | >>> tree_a = {"a": 0.0, "b": jnp.zeros((2,))} 105 | >>> tree_b = {"a": 1.0, "b": jnp.ones((2,))} 106 | >>> 107 | ... @lm.list_vmap 108 | ... def tree_add(pytree): 109 | ... return tree_map(jnp.subtract, pytree, tree_b) 110 | >>> 111 | >>> print(tree_add(tree_a, tree_b)) 112 | [{'a': Array(-1., dtype=float32, weak_type=True), 'b': Array([-1., -1.], dtype=float32)}, {'a': Array(0., dtype=float32, weak_type=True), 'b': Array([0., 0.], dtype=float32)}] 113 | 114 | Args: 115 | fun: Function accepting a single pytree as first argument. 116 | 117 | Returns: 118 | Returns a vmapped-function accepting multiple pytree args with similar tree- 119 | structure. 120 | 121 | """ 122 | vmap_fun = jax.vmap(fun, 0, 0) 123 | @wraps(fun) 124 | def vmapped(*pytrees): 125 | single_tree = pytree_list_to_leaves(pytrees) 126 | single_result = vmap_fun(single_tree) 127 | return pytree_leaves_to_list(single_result) 128 | return vmapped 129 | 130 | def list_pmap(fun): 131 | """pmaps a function over similar pytrees. 132 | 133 | Args: 134 | fun: Function accepting a single pytree as first argument. 
135 | 136 | Returns: 137 | Returns a pmapped-function accepting multiple pytree args with similar tree- 138 | structure. 139 | 140 | """ 141 | pmap_fun = jax.pmap(fun, 0) 142 | @wraps(fun) 143 | def pmapped(*pytrees): 144 | single_tree = pytree_list_to_leaves(pytrees) 145 | single_result = pmap_fun(single_tree) 146 | return pytree_leaves_to_list(single_result) 147 | return pmapped 148 | -------------------------------------------------------------------------------- /jax_sgmc/data/hdf5_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Use samples saved with :class:`jax_sgmc.io.HDF5Collector` as reference data. """ 16 | 17 | import itertools 18 | from typing import Any 19 | 20 | import h5py 21 | import numpy as onp 22 | import jax.numpy as jnp 23 | import jax 24 | 25 | from jax_sgmc.data.numpy_loader import NumpyDataLoader 26 | from jax_sgmc.io import pytree_dict_keys 27 | 28 | PyTree = Any 29 | 30 | # Inherit from NumpyDataLoader because slicing of arrays is similar 31 | class HDF5Loader(NumpyDataLoader): 32 | """Load reference data from HDF5-files. 33 | 34 | This data loader can load reference data stored in HDF5 files. This makes it 35 | possible to use the :mod:`jax_sgmc.data` module to evaluate samples saved via 36 | the :class:`jax_sgmc.io.HDF5Collector`. 37 | 38 | Args: 39 | file: Path to the HDF5 file containing the reference data 40 | subdir: Path to the subset of the data set which should be loaded 41 | sample: PyTree to specify the original shape of the sub-pytree before it 42 | has been saved by the :class:`jax_sgmc.io.HDF5Collector` 43 | 44 | """ 45 | 46 | def __init__(self, file, subdir="/chain~0/variables/", sample=None): 47 | # The sample is necessary to return the observations in the correct format. 48 | super().__init__() 49 | 50 | if isinstance(file, h5py.File): 51 | self._dataset = file 52 | else: 53 | self._dataset = h5py.File(name=file, mode="r") 54 | self._reference_data = ["/".join(itertools.chain([subdir], key_tuple)) 55 | for key_tuple in pytree_dict_keys(sample)] 56 | self._pytree_structure = jax.tree_structure(sample) 57 | self._sample_format = jax.tree_map( 58 | lambda leaf: jax.ShapeDtypeStruct(shape=leaf.shape, dtype=leaf.dtype), 59 | sample) 60 | 61 | observations_counts = [len(self._dataset[leaf_name]) 62 | for leaf_name in self._reference_data] 63 | self._observation_count = observations_counts[0] 64 | 65 | def get_batches(self, chain_id: int) -> PyTree: 66 | """Draws a batch from a chain. 67 | 68 | Args: 69 | chain_id: ID of the chain, which holds the information about the form of 70 | the batch and the process of assembling. 71 | 72 | Returns: 73 | Returns a superbatch as registered by :func:`register_random_pipeline` or 74 | :func:`register_ordered_pipeline` with `cache_size` batches holding 75 | `mb_size` observations. 
76 | 77 | """ 78 | # Data slicing is the same for all methods of random and ordered access, 79 | # only the indices for slicing differ. The method _get_indices find the 80 | # correct method for the chain. 81 | selections_idx, selections_mask = self._get_indices(chain_id) 82 | select_unique_idx = [onp.unique(batch_idx, return_inverse=True) 83 | for batch_idx in selections_idx] 84 | 85 | # Slice the data and transform into pytree 86 | selected_observations = [] 87 | for leaf_name in self._reference_data: 88 | unique_selections = [jnp.array(self._dataset[leaf_name][batch_idx]) 89 | for batch_idx, select_unique in select_unique_idx] 90 | selected_observations.append([unique[restore_idx] 91 | for unique, (_, restore_idx) 92 | in zip(unique_selections, select_unique_idx)]) 93 | 94 | selected_observations = [jnp.array(leaf) for leaf in selected_observations] 95 | selected_observations = jax.tree_unflatten(self._pytree_structure, 96 | selected_observations) 97 | 98 | return selected_observations, jnp.array(selections_mask, dtype=jnp.bool_) 99 | 100 | def save_state(self, chain_id: int) -> PyTree: 101 | raise NotImplementedError("Saving of the DataLoader state is not supported.") 102 | 103 | def load_state(self, chain_id: int, data) -> None: 104 | raise NotImplementedError("Loading of the DataLoader state is not supported.") 105 | 106 | @property 107 | def _format(self): 108 | """Returns shape and dtype of a single observation. """ 109 | return self._sample_format 110 | 111 | @property 112 | def static_information(self): 113 | """Returns information about total samples count and batch size. """ 114 | information = { 115 | "observation_count": self._observation_count 116 | } 117 | return information 118 | 119 | def close(self): 120 | self._dataset.close() 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modular Stochastic Gradient MCMC for Jax 2 | 3 | **[Paper](https://www.sciencedirect.com/science/article/pii/S2352711024000931) | [Introduction](#introduction) | 4 | [Implemented Solvers](#quickstart-with-solvers-from-aliaspy) | 5 | [Features](#features) | [Installation](#installation) | 6 | [Contributing](#contributing)** 7 | 8 | [![CI](https://github.com/tummfm/jax-sgmc/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/tummfm/jax-sgmc/actions/workflows/ci.yml) 9 | [![Documentation Status](https://readthedocs.org/projects/jax-sgmc/badge/?version=latest)](https://jax-sgmc.readthedocs.io/en/latest/?badge=latest) 10 | [![PyPI version](https://badge.fury.io/py/jax-sgmc.svg)](https://badge.fury.io/py/jax-sgmc) 11 | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 12 | 13 | ## Introduction 14 | 15 | **JaxSGMC** brings Stochastic Gradient Markov chain Monte Carlo (SGMCMC) 16 | samplers to JAX. Inspired by [optax](https://github.com/deepmind/optax), 17 | **JaxSGMC** is built on a modular concept to increase reusability and 18 | accelerate research of new SGMCMC solvers. Additionally, **JaxSGMC** aims to 19 | promote probabilistic machine learning by removing obstacles in switching 20 | from stochastic optimizers to SGMCMC samplers. 
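A condensed sketch of the typical workflow (data loader, single-observation likelihood, stochastic potential, pre-built solver), abridged from [examples/quickstart.md](examples/quickstart.md); the toy regression data and the solver settings are purely illustrative:

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

from jax_sgmc import potential, alias
from jax_sgmc.data.numpy_loader import NumpyDataLoader

# Toy linear-regression data: y = x @ w_true + noise
key1, key2, key3 = random.split(random.PRNGKey(0), 3)
x = random.uniform(key1, minval=-10, maxval=10, shape=(1000, 4))
w_true = random.uniform(key2, minval=-1, maxval=1, shape=(4, 1))
y = jnp.matmul(x, w_true) + 0.5 * random.normal(key3, shape=(1000, 1))

# Host-side data loader; batches are streamed to jit-compiled device code
data_loader = NumpyDataLoader(x=x, y=y)

# Log-likelihood of a single observation; JaxSGMC vectorizes it over the batch
def likelihood(sample, observation):
    sigma = jnp.exp(sample["log_sigma"])
    y_pred = jnp.dot(observation["x"], sample["w"])
    return norm.logpdf(observation["y"] - y_pred, scale=sigma)

prior = lambda sample: 1 / jnp.exp(sample["log_sigma"])

# Stochastic potential and a pre-built SGHMC solver from alias.py
potential_fn = potential.minibatch_potential(prior=prior,
                                             likelihood=likelihood,
                                             strategy="vmap")
sghmc = alias.sghmc(potential_fn,
                    data_loader,
                    cache_size=512,
                    batch_size=2,
                    burn_in=5000,
                    accepted_samples=1000,
                    integration_steps=5,
                    friction=0.9,
                    first_step_size=0.01,
                    last_step_size=0.00015)

# One Markov chain starting from an all-zero sample
init_sample = {"w": jnp.zeros((4, 1)), "log_sigma": jnp.array(2.5)}
results = sghmc(init_sample, iterations=30000)
samples = results[0]["samples"]["variables"]
```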
21 | 22 | ## Quickstart with solvers from ``alias.py`` 23 | 24 | To get started quickly using SGMCMC samplers, **JaxSGMC** provides some popular 25 | pre-built samplers in [alias.py](jax_sgmc/alias.py): 26 | 27 | - **SGLD (rms-prop)**: 28 | - **SGHMC**: 29 | - **reSGLD**: 30 | - **SGGMC**: 31 | - **AMAGOLD**: 32 | - **OBABO**: 33 | 34 | ## Features 35 | 36 | ### Modular SGMCMC solvers 37 | 38 | **JaxSGMC** aims to increase reusability of SGMCMC components via a toolbox of 39 | helper functions and a modular concept: 40 | 41 | ![](https://raw.githubusercontent.com/tummfm/jax-sgmc/main/jax-sgmc-structure.svg) 42 | 43 | In the simplest case of employing a pre-built sampler from 44 | [alias.py](jax_sgmc/alias.py), the user only needs to provide the computational 45 | model, consisting of functions for Prior and Likelihood. 46 | Schedulers allow to change sampler properies over the course of the training. 47 | Advanced users may build custom samplers from given components. 48 | 49 | ### Data Input / Output under ``jit`` 50 | 51 | **JaxSGMC** provides a toolbox to pass reference data to the computation 52 | and save collected samples from the Markov chain. 53 | 54 | By combining different data loader / collector classes and general wrappers it 55 | is possible to read data from and save samples to different data types via the 56 | mechanisms of JAX's Host-Callback module. 57 | It is therefore also possible to access datasets bigger than the device memory. 58 | 59 | Saving Data: 60 | - HDF5 61 | - Numpy ``.npz`` 62 | 63 | Loading Data: 64 | - HDF5 65 | - Numpy arrays 66 | - Tensorflow datasets 67 | 68 | ### Computing the stochastic potential 69 | 70 | Stochastic Gradient MCMC requires the evaluation of a potential function for a 71 | batch of data. 72 | **JaxSGMC** allows to compute this potential from likelihoods accepting only 73 | single observations and batches them automatically with sequential, parallel or 74 | vectorized execution. 75 | Moreover, **JaxSGMC** supports passing a model state between the evaluations of 76 | the likelihood function, which is saved corresponding to the samples, speeding 77 | up postprocessing. 78 | 79 | ## Installation 80 | 81 | ### Basic Setup 82 | 83 | **JaxSGMC** can be installed via pip: 84 | 85 | ```shell 86 | pip install jax-sgmc --upgrade 87 | ``` 88 | 89 | The above command installs **Jax for CPU**. To run **JaxSGMC on the GPU**, 90 | the GPU version of JAX has to be installed. 91 | Further information can be found here: 92 | [Jax Installation Instructions](https://github.com/google/jax#installation) 93 | 94 | ### Additional Packages 95 | 96 | Some parts of **JaxSGMC** require additional packages: 97 | 98 | - Data Loading with tensorflow: 99 | ```shell 100 | pip install jax-sgmc[tensorflow] --upgrade 101 | ``` 102 | - Saving Samples in the HDF5-Format: 103 | ```shell 104 | pip install jax-sgmc[hdf5] --upgrade 105 | ``` 106 | 107 | 108 | ### Installation from Source 109 | 110 | For development purposes, **JaxSGMC** can be installed from source in 111 | editable mode: 112 | 113 | ```shell 114 | git clone git@github.com:tummfm/jax-sgmc.git 115 | pip install -e .[test,docs] 116 | ``` 117 | 118 | This command additionally installs the requirements to run the tests: 119 | 120 | ```shell 121 | pytest tests 122 | ``` 123 | 124 | And to build the documentation (e.g. in html): 125 | 126 | ```shell 127 | make -C docs html 128 | ``` 129 | 130 | ## Contributing 131 | 132 | Contributions are always welcome! 
Please open a pull request to discuss the code 133 | additions. 134 | 135 | ## Citation 136 | 137 | If you use **JaxSGMC** in your own work, please consider citing 138 | 139 | ``` 140 | @article{jaxsgmc2024, 141 | title = {JaxSGMC: Modular stochastic gradient MCMC in JAX}, 142 | journal = {SoftwareX}, 143 | volume = {26}, 144 | pages = {101722}, 145 | year = {2024}, 146 | issn = {2352-7110}, 147 | doi = {https://doi.org/10.1016/j.softx.2024.101722}, 148 | url = {https://www.sciencedirect.com/science/article/pii/S2352711024000931}, 149 | author = {Stephan Thaler and Paul Fuchs and Ana Cukarska and Julija Zavadlav}, 150 | } 151 | ``` 152 | 153 | -------------------------------------------------------------------------------- /tests/test_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import itertools 16 | 17 | import jax.numpy as jnp 18 | from jax import random 19 | 20 | import numpy as onp 21 | 22 | import pytest 23 | 24 | from jax_sgmc import scheduler 25 | from jax_sgmc.util import testing 26 | 27 | 28 | class TestScheduler: 29 | step_size = [None, 30 | (scheduler.polynomial_step_size, 31 | {"a": 1.0, "b": 1.0, "gamma": 0.55}), 32 | (scheduler.polynomial_step_size_first_last, 33 | {"first": 1.0, "last": 0.1, "gamma": 0.35})] 34 | 35 | burn_in = [None, 36 | (scheduler.initial_burn_in, 37 | {"n": 100})] 38 | 39 | temperature = [None, 40 | (scheduler.constant_temperature, 41 | {"tau": 0.5})] 42 | 43 | @pytest.mark.parametrize( 44 | "step_size, burn_in, temperature", 45 | itertools.product(step_size, burn_in, temperature)) 46 | def test_scheduler(self, step_size, burn_in, temperature): 47 | """Test that the scheduler can initialize all specific schedulers. Moreover, 48 | the capability to provide default values is tested. 
""" 49 | 50 | # Initialize all the specific schedulers 51 | if step_size is not None: 52 | fun, kwargs = step_size 53 | step_size = fun(**kwargs) 54 | if burn_in is not None: 55 | fun, kwargs = burn_in 56 | burn_in = fun(**kwargs) 57 | if temperature is not None: 58 | fun, kwargs = temperature 59 | temperature = fun(**kwargs) 60 | 61 | # Initialize the specific scheduler 62 | 63 | schedule = scheduler.init_scheduler(step_size=step_size, 64 | burn_in=burn_in, 65 | temperature=temperature) 66 | 67 | iterations = 100 68 | state, _ = schedule[0](iterations) 69 | 70 | for _ in range(iterations): 71 | 72 | sched = schedule[2](state) 73 | state = schedule[1](state) 74 | 75 | assert sched.step_size.shape == tuple() 76 | assert sched.temperature.shape == tuple() 77 | assert sched.burn_in.shape == tuple() 78 | assert sched.accept 79 | 80 | 81 | class TestStepSize(): 82 | 83 | @pytest.mark.parametrize("first, last, gamma, iterations", 84 | itertools.product([1.0, 0.05], 85 | [0.01, 0.0009], 86 | [0.33, 0.55], 87 | [100, 14723])) 88 | def test_first_last(self, first, last, gamma, iterations): 89 | """Test, that the first and last step size are computed right.""" 90 | 91 | first = jnp.array(first) 92 | last = jnp.array(last) 93 | gamma = jnp.array(gamma) 94 | 95 | schedule = scheduler.polynomial_step_size_first_last(first=first, 96 | last=last, 97 | gamma=gamma) 98 | 99 | state = schedule.init(iterations) 100 | 101 | testing.assert_close(schedule.get(state, 0), first) 102 | testing.assert_close(schedule.get(state, iterations-1), last) 103 | 104 | 105 | class TestBurnIn(): 106 | 107 | @pytest.mark.parametrize("n", [123, 243]) 108 | def test_initial_burn_in(self, n): 109 | """Test, that no off by one error exists.""" 110 | burn_in = scheduler.initial_burn_in(n=n) 111 | 112 | state, _ = burn_in.init(1000) 113 | 114 | # Check that samples are not accepted 115 | for idx in range(n): 116 | bi = burn_in.get(state, idx) 117 | state = burn_in.update(state, idx) 118 | assert bi == 0.0 119 | 120 | # Check that next sample is accepted 121 | assert burn_in.get(state, n + 1) == 1.0 122 | 123 | class TestThinning(): 124 | 125 | @pytest.fixture(params=[100, 1000]) 126 | def burn_in(self, request): 127 | iterations = request.param 128 | accepted = random.bernoulli(random.PRNGKey(0), p=0.3, shape=(100,)) 129 | init = lambda *args: (None, onp.sum(accepted)) 130 | update = lambda *args, **kwargs: None 131 | get = lambda _, iteration, **kwargs: accepted[iteration] 132 | nonzero, = onp.nonzero(accepted) 133 | return scheduler.specific_scheduler(init, update, get), nonzero, iterations 134 | 135 | def test_random_thinning(self, burn_in): 136 | """Given a burn in schedule, the sample can only be accepted if it is not 137 | subject to burn in.""" 138 | 139 | burn_in, non_zero, iterations = burn_in 140 | step_size = scheduler.polynomial_step_size(a=1.0, b=1.0, gamma=1.0) 141 | 142 | thinning = scheduler.random_thinning( 143 | step_size_schedule=step_size, 144 | burn_in_schedule=burn_in, 145 | selections=int(0.5 * non_zero.size)) 146 | state, _ = thinning.init(iterations) 147 | 148 | accepted = 0 149 | for idx in range(iterations): 150 | # If the state is accepted, it must also be not subject to burn in 151 | if thinning.get(state, idx): 152 | assert (idx in non_zero) 153 | accepted += 1 154 | assert accepted == int(0.5 * non_zero.size) -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | """Test io 
module. """ 2 | from functools import partial 3 | 4 | from jax import random, jit, pmap, vmap, lax, tree_map 5 | import jax.numpy as jnp 6 | 7 | import pytest 8 | 9 | from jax_sgmc import io, scheduler, util 10 | from jax_sgmc.util import testing 11 | 12 | class TestDictTransformation: 13 | pass 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | @pytest.mark.hdf5 22 | class TestHDF5Collector: 23 | pass 24 | 25 | 26 | 27 | 28 | 29 | 30 | class TestMemoryCollector: 31 | pass 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | class TestSaving: 41 | """Test saving.""" 42 | 43 | @pytest.fixture 44 | def test_function(self): 45 | 46 | shape = (4, 3) 47 | 48 | def _update(state, _): 49 | key, split1, split2 = random.split(state["key"], 3) 50 | noise = random.normal(split1, shape=shape) 51 | new_state = {"key": key, 52 | "results": {"noise": noise, 53 | "sum": state["results"]["sum"] + noise}} 54 | return new_state, new_state 55 | 56 | def _init(seed): 57 | init_state = {"key": random.PRNGKey(seed), 58 | "results": {"noise": jnp.zeros(shape), 59 | "sum": jnp.zeros(shape)}} 60 | return init_state 61 | 62 | def _run(init_state, keep): 63 | _, all_results = lax.scan(_update, init_state, keep) 64 | accepted_results = tree_map(lambda leaf: leaf[keep], all_results) 65 | return accepted_results 66 | 67 | return _init, _update, _run 68 | 69 | def test_no_save(self, test_function): 70 | """Test saving by running save against no_save and direct output on random 71 | operations. 72 | """ 73 | 74 | # Calculate the true results 75 | _init, _update, _run = test_function 76 | accepted_list = random.bernoulli(random.PRNGKey(11), shape=(15,)) 77 | count = jnp.sum(accepted_list) 78 | 79 | init_state = _init(0) 80 | reference_solution = _run(init_state, accepted_list) 81 | 82 | # Run the no_save solution 83 | 84 | init_save, save, postprocess_save = io.no_save() 85 | 86 | def no_save_run(init_sample): 87 | def update(state, keep): 88 | saving_state, simulation_state = state 89 | simulation_state, sample = _update(simulation_state, None) 90 | saving_state, _ = save(saving_state, keep, simulation_state) 91 | return (saving_state, simulation_state), None 92 | saving_state = init_save(init_sample, {}, io.scheduler.static_information(samples_collected=count)) 93 | (saving_state, _), _ = lax.scan(update, (saving_state, init_sample), accepted_list) 94 | return postprocess_save(saving_state, None)["samples"] 95 | 96 | no_save_results = no_save_run(init_state) 97 | 98 | # Check close 99 | 100 | testing.assert_equal(no_save_results, reference_solution) 101 | 102 | def test_save(self, test_function): 103 | """Test saving by running save against no_save and direct output on random 104 | operations. 
105 | """ 106 | 107 | # Calculate the true results 108 | _init, _update, _run = test_function 109 | accepted_list = random.bernoulli(random.PRNGKey(11), shape=(15,)) 110 | count = jnp.sum(accepted_list) 111 | 112 | init_state = _init(0) 113 | reference_solution = _run(init_state, accepted_list) 114 | 115 | # Run the no_save solution 116 | 117 | data_collector = io.MemoryCollector() 118 | init_save, save, postprocess_save = io.save(data_collector) 119 | 120 | def save_run(init_sample): 121 | def update(state, keep): 122 | saving_state, simulation_state = state 123 | simulation_state, sample = _update(simulation_state, None) 124 | saving_state, _ = save(saving_state, keep, simulation_state) 125 | return (saving_state, simulation_state), None 126 | saving_state = init_save(init_sample, {}, io.scheduler.static_information(samples_collected=count)) 127 | (saving_state, _), _ = lax.scan(update, (saving_state, init_sample), accepted_list) 128 | return postprocess_save(saving_state, None)["samples"] 129 | 130 | save_results = save_run(init_state) 131 | 132 | # Check close 133 | 134 | testing.assert_equal(save_results, reference_solution) 135 | 136 | @pytest.mark.skip 137 | def test_save_vmap(self, test_function): 138 | """Test saving by running save against no_save and direct output on random 139 | operations. 140 | """ 141 | 142 | # Calculate the true results 143 | _init, _update, _run = test_function 144 | accepted_list = random.bernoulli(random.PRNGKey(11), shape=(15,)) 145 | count = jnp.sum(accepted_list) 146 | seeds = jnp.arange(2) 147 | 148 | 149 | init_states = list(map(_init, seeds)) 150 | reference_solution = list(map(lambda state: _run(state, accepted_list), init_states)) 151 | 152 | # Run the no_save solution 153 | 154 | data_collector = io.MemoryCollector() 155 | init_save, save, postprocess_save = io.save(data_collector) 156 | 157 | vmap_init_states = [((init_save(init_sample, {}, io.scheduler.static_information(samples_collected=count)), 158 | init_sample), accepted_list) for init_sample in init_states] 159 | 160 | # We use list_vmap, because we tun the chain over a list of pytrees. Also, 161 | # we currently have to vectorize over the acceptance list because of the 162 | # implementation of stop_vmap 163 | @util.list_vmap 164 | def save_run(init_state): 165 | init_args, accept = init_state 166 | saving_state, init_sample = init_args 167 | def update(state, keep): 168 | saving_state, simulation_state = state 169 | simulation_state, sample = _update(simulation_state, None) 170 | saving_state, _ = save(saving_state, keep, simulation_state) 171 | return (saving_state, simulation_state), None 172 | 173 | (saving_state, _), _ = lax.scan(update, (saving_state, init_sample), accept) 174 | return saving_state 175 | 176 | save_results = [postprocess_save(res, None)["samples"] for res in save_run(*vmap_init_states)] 177 | 178 | # Check close 179 | 180 | testing.assert_equal(save_results, reference_solution) 181 | -------------------------------------------------------------------------------- /jax_sgmc/data/tensorflow_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Load Tensorflow-Datasets in jit-compiled functions.
16 | 
17 | The tensorflow data loader supports tensorflow Datasets, e.g. from the
18 | `tensorflow_datasets` package.
19 | 
20 | Note:
21 | This submodule requires that ``tensorflow`` and ``tensorflow_datasets`` are
22 | installed. Additional information can be found in the
23 | :ref:`installation instructions`.
24 | 
25 | """
26 | 
27 | from typing import List, Any
28 | 
29 | import jax
30 | import jax.numpy as jnp
31 | from jax import tree_util
32 | 
33 | from tensorflow import data as tfd
34 | import tensorflow_datasets as tfds
35 | 
36 | from jax_sgmc.data.core import HostDataLoader
37 | 
38 | PyTree = Any
39 | TFDataSet = tfd.Dataset
40 | 
41 | class TensorflowDataLoader(HostDataLoader):
42 | """Load data from a tensorflow dataset object.
43 | 
44 | The tensorflow datasets package provides a large number of ready-to-go
45 | datasets, which can be passed directly to the Tensorflow Data Loader.
46 | 
47 | ::
48 | 
49 | 
50 | import tensorflow_datasets as tfds
51 | from jax_sgmc import data
52 | from jax_sgmc.data.tensorflow_loader import TensorflowDataLoader
53 | 
54 | pipeline = tfds.load("cifar10", split="train")
55 | data_loader = TensorflowDataLoader(pipeline, shuffle_cache=100, exclude_keys=['id'])
56 | 
57 | Args:
58 | pipeline: A tensorflow data pipeline, which can be obtained from the
59 | tensorflow dataset package
60 | 
61 | """
62 | 
63 | def __init__(self,
64 | pipeline: TFDataSet,
65 | mini_batch_size: int = None,
66 | shuffle_cache: int = 100,
67 | exclude_keys: List = None):
68 | super().__init__()
69 | # Tensorflow is in general not required to use the library
70 | assert mini_batch_size is None, "Deprecated"
71 | assert TFDataSet is not None, "Tensorflow must be installed to use this " \
72 | "feature."
73 | assert tfds is not None, "Tensorflow datasets must be installed to use " \
74 | "this feature."
75 | 
76 | self._observation_count = jnp.int32(pipeline.cardinality().numpy())
77 | # Basic pipeline, from which all other pipelines are constructed
78 | self._pipeline = pipeline
79 | self._exclude_keys = [] if exclude_keys is None else exclude_keys
80 | self._shuffle_cache = shuffle_cache
81 | 
82 | self._pipelines: List[TFDataSet] = []
83 | 
84 | def register_random_pipeline(self,
85 | cache_size: int = 1,
86 | mb_size: int = None,
87 | **kwargs) -> int:
88 | """Register a new chain which draws samples randomly.
89 | 
90 | Args:
91 | cache_size: The number of drawn batches.
92 | mb_size: The number of observations per batch.
93 | 
94 | Returns:
95 | Returns the id of the new chain.
96 | 
97 | """
98 | 
99 | # Assert that no kwargs are passed with the intention to control the
100 | # initial state of the tensorflow data loader
101 | assert kwargs == {}, "Tensorflow data loader does not accept additional "\
102 | "kwargs"
103 | new_chain_id = len(self._pipelines)
104 | 
105 | # Randomly draw cache_size mini-batches, where each mini-batch
106 | # contains mb_size observations.
107 | random_data = self._pipeline.repeat() 108 | random_data = random_data.shuffle(self._shuffle_cache) 109 | random_data = random_data.batch(mb_size) 110 | random_data = random_data.batch(cache_size) 111 | 112 | # The data must be transformed to numpy arrays, as most numpy arrays can 113 | # be transformed to the duck-typed jax array form 114 | # random_data = tfds.as_numpy(random_data) 115 | random_data = iter(random_data) 116 | 117 | self._pipelines.append(random_data) 118 | 119 | return new_chain_id 120 | 121 | def register_ordered_pipeline(self, 122 | cache_size: int = 1, 123 | mb_size: int = None, 124 | **kwargs 125 | ) -> int: 126 | """Register a chain which assembles batches in an ordered manner. 127 | 128 | Args: 129 | cache_size: The number of drawn batches. 130 | mb_size: The number of observations per batch. 131 | 132 | Returns: 133 | Returns the id of the new chain. 134 | 135 | """ 136 | raise NotImplementedError 137 | 138 | def get_batches(self, chain_id: int) -> PyTree: 139 | """Draws a batch from a chain. 140 | 141 | Args: 142 | chain_id: ID of the chain, which holds the information about the form of 143 | the batch and the process of assembling. 144 | 145 | Returns: 146 | Returns a batch of batches as registered by :func:`register_random_pipeline` or 147 | :func:`register_ordered_pipeline` with `cache_size` batches holding 148 | `mb_size` observations. 149 | 150 | """ 151 | 152 | # Not supported data types, such as strings, must be excluded before 153 | # transformation to jax types. 154 | numpy_batch = next(self._pipelines[chain_id]) 155 | if self._exclude_keys is not None: 156 | for key in self._exclude_keys: 157 | del numpy_batch[key] 158 | 159 | return tree_util.tree_map(jnp.array, numpy_batch), None 160 | 161 | @property 162 | def _format(self): 163 | data_spec = self._pipeline.element_spec 164 | if self._exclude_keys is not None: 165 | not_excluded_elements = {id: elem for id, elem in data_spec.items() 166 | if id not in self._exclude_keys} 167 | else: 168 | not_excluded_elements = data_spec 169 | 170 | def leaf_dtype_struct(leaf): 171 | shape = tuple(int(s) for s in leaf.shape if s is not None) 172 | dtype = leaf.dtype.as_numpy_dtype 173 | return jax.ShapeDtypeStruct( 174 | dtype=dtype, 175 | shape=shape) 176 | 177 | return tree_util.tree_map(leaf_dtype_struct, not_excluded_elements) 178 | 179 | @property 180 | def static_information(self): 181 | """Returns information about total samples count and batch size. """ 182 | information = { 183 | "observation_count" : self._observation_count 184 | } 185 | return information 186 | -------------------------------------------------------------------------------- /docs/usage/potential.rst: -------------------------------------------------------------------------------- 1 | .. _likelihood_to_potential: 2 | Compute Potential from Likelihood 3 | ================================== 4 | 5 | Stochastic Gradient MCMC evaluates the potential and the model 6 | for a subset of observations or all observations. Therefore, this module 7 | acts as an interface between the different likelihoods and the integrators. 8 | The likelihood can be implemented for only a single sample or a batch of data. 9 | 10 | Setup DataLoaders 11 | ------------------- 12 | 13 | For demonstration purposes, we setup a data loader to compute the potential for 14 | a random batch of data as well as for the full dataset. 
Note that the keyword 15 | arguments selected to initialize the data (here 'mean') have to be used to 16 | access the data of the observations in the likelihood. 17 | 18 | .. doctest:: 19 | 20 | >>> from functools import partial 21 | >>> import jax.numpy as jnp 22 | >>> import jax.scipy as jscp 23 | >>> from jax import random, vmap 24 | >>> from jax_sgmc import data, potential 25 | >>> from jax_sgmc.data.numpy_loader import NumpyDataLoader 26 | 27 | >>> mean = random.normal(random.PRNGKey(0), shape=(100, 5)) 28 | >>> data_loader = NumpyDataLoader(mean=mean) 29 | >>> 30 | >>> test_sample = {'mean': jnp.zeros(5), 'std': jnp.ones(1)} 31 | 32 | 33 | Stochastic Potential 34 | _____________________ 35 | 36 | The stochastic potential is an estimate of the true potential. It is 37 | calculated over a mini-batch and rescaled to the full dataset. 38 | To this end, we need to initialize functions that retreive mini-batches of the 39 | data. 40 | 41 | >>> batch_init, batch_get, _ = data.random_reference_data(data_loader, 42 | ... cached_batches_count=50, 43 | ... mb_size=5) 44 | >>> random_data_state = batch_init() 45 | 46 | 47 | Full Potential 48 | _______________ 49 | 50 | In combination with the :mod:`jax_sgmc.data` it is possible to calculate the 51 | true potential over the full dataset. 52 | If we specify a batch size of 3, then the likelihood will be sequentially 53 | calculated over batches with the size 3. 54 | 55 | 56 | >>> init_fun, fmap_fun, _ = data.full_reference_data(data_loader, 57 | ... cached_batches_count=50, 58 | ... mb_size=3) 59 | >>> data_state = init_fun() 60 | 61 | 62 | Unbatched Likelihood 63 | ---------------------- 64 | 65 | In the simplest case, the likelihood and model function only accept a single 66 | observation and parameter set. 67 | Therefore, this module maps the evaluation over the mini-batch or even all 68 | observations by making use of the Jax tools ``map``, ``vmap`` and ``pmap``. 69 | 70 | The likelihood can be written for a single observation. The 71 | :mod:`jax_sgmc.potential` module then evaluates the likelihood for a batch of 72 | reference data sequentially via ``map`` or in parallel via ``vmap`` or ``pmap``. 73 | The first input to the likelihood function is the sample, i.e. the model 74 | parameters. You can access all parameters of the dict via the keywords defined 75 | in the initial sample (e.g. 'test_sample' above). The second input is the 76 | observation from the dataset, where the data can be accessed with the same 77 | keyword arguments used during the initialization of the DataLoader. 78 | 79 | >>> def likelihood(sample, observation): 80 | ... likelihoods = jscp.stats.norm.logpdf(observation['mean'], 81 | ... loc=sample['mean'], 82 | ... scale=sample['std']) 83 | ... return jnp.sum(likelihoods) 84 | >>> prior = lambda unused_sample: 0.0 85 | 86 | 87 | Stochastic Potential 88 | ______________________ 89 | 90 | The stochastic potential is computed automatically from the prior and likelihood 91 | of a single observation. 92 | 93 | >>> 94 | >>> stochastic_potential_fn = potential.minibatch_potential(prior, 95 | ... likelihood, 96 | ... strategy='map') 97 | >>> new_random_data_state, random_batch = batch_get(random_data_state, information=True) 98 | >>> potential_eval, unused_state = stochastic_potential_fn(test_sample, random_batch) 99 | >>> 100 | >>> print(round(potential_eval)) 101 | 838 102 | 103 | For debugging purposes, it is recommended to check with a test sample and a test 104 | observation whether the potential is evaluated correctly. 
This simplifies the 105 | search for bugs without the overhead from the SG-MCMC sampler. 106 | 107 | Full Potential 108 | _______________ 109 | 110 | Here, the likelihood written for a single observation can be re-used. 111 | 112 | >>> potential_fn = potential.full_potential(prior, likelihood, strategy='vmap') 113 | >>> 114 | >>> potential_eval, (data_state, unused_state) = potential_fn( 115 | ... test_sample, data_state, fmap_fun) 116 | >>> 117 | >>> print(round(potential_eval)) 118 | 707 119 | 120 | 121 | 122 | Batched Likelihood 123 | ------------------ 124 | 125 | Some models already accept a batch of reference data. In this case, the 126 | potential function can be constructed by setting ``is_batched = True``. In this 127 | case, it is expected that the returned likelihoods are a vector with shape 128 | ``(N,)``, where N is the batch-size. 129 | 130 | 131 | >>> @partial(vmap, in_axes=(None, 0)) 132 | ... def batched_likelihood(sample, observation): 133 | ... likelihoods = jscp.stats.norm.logpdf(observation['mean'], 134 | ... loc=sample['mean'], 135 | ... scale=sample['std']) 136 | ... # Only valid samples contribute to the likelihood 137 | ... return jnp.sum(likelihoods) 138 | >>> 139 | 140 | 141 | Stochastic Potential 142 | _____________________ 143 | 144 | To compute the correct potential now, the function needs to know that the 145 | likelihood is batched by setting ``is_batched=True``. The strategy setting 146 | has no meaning anymore and can be kept on the default value. 147 | 148 | >>> stochastic_potential_fn = potential.minibatch_potential(prior, 149 | ... batched_likelihood, 150 | ... is_batched=True, 151 | ... strategy='map') 152 | >>> 153 | >>> new_random_data_state, random_batch = batch_get(random_data_state, information=True) 154 | >>> potential_eval, unused_state = stochastic_potential_fn(test_sample, random_batch) 155 | >>> 156 | >>> print(round(potential_eval)) 157 | 838 158 | >>> 159 | >>> _, (likelihoods, _) = stochastic_potential_fn(test_sample, 160 | ... random_batch, 161 | ... likelihoods=True) 162 | >>> 163 | >>> print(round(jnp.var(likelihoods))) 164 | 7 165 | 166 | Full Potential 167 | __________________ 168 | 169 | The batched likelihood can also be used to calculate the full potential. 170 | 171 | >>> prior = lambda unused_sample: 0.0 172 | >>> 173 | >>> potential_fn = potential.full_potential(prior, batched_likelihood, is_batched=True) 174 | >>> 175 | >>> potential_eval, (data_state, unused_state) = potential_fn( 176 | ... test_sample, data_state, fmap_fun) 177 | >>> 178 | >>> print(round(potential_eval)) 179 | 707 180 | 181 | Likelihoods with States 182 | ------------------------ 183 | 184 | By setting the argument ``has_state = True``, the likelihood accepts an 185 | additional state as first positional argument. This state should not influence 186 | the results of the computation. 187 | 188 | >>> def stateful_likelihood(state, sample, observation): 189 | ... n, mean = state 190 | ... n += 1 191 | ... new_mean = (n-1)/n * mean + 1/n * observation['mean'] 192 | ... 193 | ... likelihoods = jscp.stats.norm.logpdf((observation['mean'] - new_mean), 194 | ... loc=(sample['mean'] - new_mean), 195 | ... scale=sample['std']) 196 | ... return jnp.sum(likelihoods), (n, new_mean) 197 | 198 | .. note:: 199 | If the likelihood is not batched (``is_batched=False``), only the state 200 | corresponding to the computation with the first sample of the batch is 201 | returned. 
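As in the stateless case, it can be worth calling the stateful likelihood directly on a single observation before wrapping it, to check that the returned state has the expected structure. A small sketch (``single_obs`` is simply the first observation of the dataset defined above):

>>> single_obs = {'mean': mean[0]}
>>> value, (n, running_mean) = stateful_likelihood(
...     (jnp.array(2), jnp.ones(5)), test_sample, single_obs)
>>> print(n)
3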
202 | 203 | Stochastic Potential 204 | ____________________ 205 | 206 | >>> potential_fn = potential.minibatch_potential(prior, 207 | ... stateful_likelihood, 208 | ... has_state=True) 209 | >>> 210 | >>> potential_eval, new_state = potential_fn(test_sample, 211 | ... random_batch, 212 | ... state=(jnp.array(2), jnp.ones(5))) 213 | >>> 214 | >>> print(round(potential_eval)) 215 | 838 216 | >>> print(f"n: {new_state[0] : d}") 217 | n: 3 218 | 219 | Full Potential 220 | _______________ 221 | 222 | 223 | >>> full_potential_fn = potential.full_potential(prior, 224 | ... stateful_likelihood, 225 | ... has_state=True) 226 | >>> 227 | >>> potential_eval, (cache_state, new_state) = full_potential_fn( 228 | ... test_sample, data_state, fmap_fun, state=(jnp.array(2), jnp.ones(5))) 229 | >>> 230 | >>> print(f"n: {new_state[0] : d}") 231 | n: 36 232 | -------------------------------------------------------------------------------- /examples/quickstart.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | formats: examples///ipynb,examples///md:myst,docs///ipynb 4 | main_language: python 5 | text_representation: 6 | extension: .md 7 | format_name: myst 8 | format_version: 0.13 9 | jupytext_version: 1.14.4 10 | kernelspec: 11 | display_name: Python 3 (ipykernel) 12 | language: python 13 | name: python3 14 | --- 15 | 16 | ```{raw-cell} 17 | 18 | --- 19 | Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 20 | 21 | Licensed under the Apache License, Version 2.0 (the "License"); 22 | you may not use this file except in compliance with the License. 23 | You may obtain a copy of the License at 24 | 25 | http://www.apache.org/licenses/LICENSE-2.0 26 | 27 | Unless required by applicable law or agreed to in writing, software 28 | distributed under the License is distributed on an "AS IS" BASIS, 29 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | See the License for the specific language governing permissions and 31 | limitations under the License. 32 | --- 33 | ``` 34 | 35 | ```{code-cell} 36 | :tags: [hide-cell] 37 | 38 | import warnings 39 | warnings.filterwarnings('ignore') 40 | 41 | import time 42 | 43 | import numpy as onp 44 | 45 | import jax.numpy as jnp 46 | 47 | from jax import random 48 | 49 | import matplotlib.pyplot as plt 50 | 51 | from jax.scipy.stats import norm 52 | 53 | from numpyro import sample as npy_smpl 54 | import numpyro.infer as npy_inf 55 | import numpyro.distributions as npy_dist 56 | 57 | from scipy.stats import gaussian_kde 58 | 59 | 60 | from jax_sgmc import potential 61 | from jax_sgmc.data.numpy_loader import NumpyDataLoader 62 | from jax_sgmc import alias 63 | ``` 64 | 65 | # Quickstart 66 | 67 | ## Data Generation 68 | 69 | For demonstration purposes, we look at the simple problem 70 | 71 | ```{math} 72 | y^{(i)} \sim \mathcal{N}\left(\sum_{j=1}^d w_jx_j^{(i)}, \sigma^2\right) 73 | ``` 74 | 75 | where $d \ll N$ such that we have a large amount of reference data. 76 | 77 | The reference data are generated such that the weights are correlated: 78 | 79 | ```{math} 80 | u_1, u_2, u_3, u_4 \sim \mathcal{U}\left(-1, 1 \right) 81 | ``` 82 | 83 | and 84 | 85 | ```{math} 86 | \boldsymbol{x} = \left(\begin{array}{c} u_1 + u_2 \\ u_2 \\ 0.1u_3 -0.5u_4 \\ u_4 \end{array} \right). 
87 | ``` 88 | 89 | The correct solution $w$ is drawn randomly from 90 | 91 | ```{math} 92 | w_j \sim \mathcal{U}\left(-1, 1\right) 93 | ``` 94 | 95 | and the standard deviation of the error is chosen to be 96 | 97 | ```{math} 98 | \sigma = 0.5. 99 | ``` 100 | 101 | ```{code-cell} 102 | N = 4 103 | samples = 1000 # Total samples 104 | 105 | key = random.PRNGKey(0) 106 | split1, split2, split3 = random.split(key, 3) 107 | 108 | # Correct solution 109 | sigma = 0.5 110 | w = random.uniform(split3, minval=-1, maxval=1, shape=(N, 1)) 111 | 112 | # Data generation 113 | noise = sigma * random.normal(split2, shape=(samples, 1)) 114 | x = random.uniform(split1, minval=-10, maxval=10, shape=(samples, N)) 115 | x = jnp.stack([x[:, 0] + x[:, 1], x[:, 1], 0.1 * x[:, 2] - 0.5 * x[:, 3], 116 | x[:, 3]]).transpose() 117 | y = jnp.matmul(x, w) + noise 118 | ``` 119 | 120 | ## Data Loader 121 | 122 | A feature of **JaxSGMC** is that it can store large datasets on the host and 123 | only send chunks of it to jit-compiled functions on the device (GPU) 124 | when they are required. 125 | 126 | Therefore, reference data must be stored in a ``DataLoader`` class, which 127 | also takes care of batching and shuffling. 128 | 129 | If the data fit into memory and are available as numpy arrays, then the 130 | ``NumpyDataLoader`` can be used. It expects a single array or multiple arrays where 131 | all observations are concatenated along the first dimension. The arrays are 132 | passed as keyword-arguments and the batches are returned as a flat dictionary 133 | with the corresponding keys. 134 | 135 | For our dataset, we stick to the names x and y such that we can later access the 136 | data via ``batch['x']`` and ``batch['y']``: 137 | 138 | ```{code-cell} 139 | data_loader = NumpyDataLoader(x=x, y=y) 140 | ``` 141 | 142 | Sometimes, a model needs the shape and type of the data to initialize its state. 143 | Therefore, each DataLoader has a method to get an all-zero observation and an 144 | all-zero batch of observations: 145 | 146 | ```{code-cell} 147 | # Print a single observation 148 | print("Single observation:") 149 | print(data_loader.initializer_batch()) 150 | 151 | # Print a batch of observations, e. g. to initialize the model 152 | print("Batch of two observations:") 153 | print(data_loader.initializer_batch(2)) 154 | ``` 155 | 156 | Note that the returned dictionary has the keys ``x`` and ``y``, just like the 157 | arrays have been passed to the ``NumpyDataLoader``. 158 | 159 | ## (Log-)Likelihood and (Log-)Prior 160 | 161 | The model is connected to the solver via the (log-)prior and (log-)likelihood 162 | function. The model for our problem is: 163 | 164 | ```{code-cell} 165 | def model(sample, observations): 166 | weights = sample["w"] 167 | predictors = observations["x"] 168 | return jnp.dot(predictors, weights) 169 | ``` 170 | 171 | **JaxSGMC** supports samples in the form of pytrees, so no flattening of e.g. 172 | Neural Net parameters is necessary. 
In our case we can separate the standard 173 | deviation, which is only part of the likelihood, from the weights by using a 174 | dictionary: 175 | 176 | ```{code-cell} 177 | def likelihood(sample, observations): 178 | sigma = jnp.exp(sample["log_sigma"]) 179 | y = observations["y"] 180 | y_pred = model(sample, observations) 181 | return norm.logpdf(y - y_pred, scale=sigma) 182 | 183 | def prior(sample): 184 | return 1 / jnp.exp(sample["log_sigma"]) 185 | 186 | ``` 187 | 188 | The (log-)prior and (log-)likelihood are not passed to the solver directly, but are 189 | first transformed into a (stochastic) potential. 190 | This allowed us to formulate the model and also the log-likelihood with only a single 191 | observation in mind and let **JaxSGMC** take care of evaluating it for a batch 192 | of observations. As the model is not computationally demanding, we let 193 | **JaxSGMC** vectorize the evaluation of the likelihood: 194 | 195 | ```{code-cell} 196 | potential_fn = potential.minibatch_potential(prior=prior, 197 | likelihood=likelihood, 198 | strategy="vmap") 199 | ``` 200 | 201 | For more complex models it is also possible to sequentially evaluate the 202 | likelihood via ``strategy="map"`` or to make use of multiple accelerators via ``strategy="pmap"``. 203 | 204 | Note that it is also possible to write the likelihood for a batch of 205 | observations and that **JaxSGMC** also supports stateful models (see 206 | {doc}`/usage/potential`). 207 | 208 | ## Solvers in alias.py 209 | 210 | Samples can be generated by using one of the ready to use solvers in 211 | **JaxSGMC**, which can be found in ``alias.py``. 212 | 213 | First, the solver must be built from the previously generated potential, the 214 | data loader and some static settings. 215 | 216 | ```{code-cell} 217 | :tags: [remove-output] 218 | 219 | sghmc = alias.sghmc(potential_fn, 220 | data_loader, 221 | cache_size=512, 222 | batch_size=2, 223 | burn_in=5000, 224 | accepted_samples=1000, 225 | integration_steps=5, 226 | adapt_noise_model=False, 227 | friction=0.9, 228 | first_step_size=0.01, 229 | last_step_size=0.00015, 230 | diagonal_noise=False) 231 | ``` 232 | 233 | Afterward, the solver can be applied to multiple initial samples, which are 234 | passed as positional arguments. Each inital sample is the starting point of a 235 | dedicated Markov chain, allowing straightforward multichain SG-MCMC sampling. 236 | Note that the initial sample has the same from as the sample we expect in our 237 | likelihood. 238 | The solver then returns a list of results, one for each initial sample, which 239 | contains solver-specific additional information beside the samples: 240 | 241 | ```{code-cell} 242 | start = time.time() 243 | iterations = 30000 244 | init_sample = {"w": jnp.zeros((N, 1)), "log_sigma": jnp.array(2.5)} 245 | 246 | # Run the solver 247 | results = sghmc(init_sample, iterations=iterations) 248 | 249 | # Access the obtained samples from the first Markov chain 250 | results = results[0]["samples"]["variables"] 251 | sghmc_time = time.time() - start 252 | ``` 253 | 254 | For a full list of ready to use solvers see {doc}`/api/jax_sgmc.alias`. 255 | Moreover, it is possible to construct custom solvers by the combination of 256 | different modules. 257 | 258 | ## Comparison with NumPyro 259 | 260 | In the following section, we plot the results of the solver and compare them with 261 | a solution returned by [NumPyro](https://github.com/pyro-ppl/numpyro). 
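As a quick sanity check before the detailed comparison, the posterior mean of the sampled weights should lie close to the true ``w`` used to generate the data. A small sketch (shown as a plain code block rather than an executed cell):

```python
# Posterior mean over the accepted SGHMC samples next to the ground truth
w_posterior_mean = jnp.mean(results["w"], axis=0)
print(jnp.concatenate([w_posterior_mean, w], axis=1))
```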
262 | 263 | ### NumPyro Solution 264 | 265 | ```{code-cell} 266 | def numpyro_model(y_obs=None): 267 | sigma = npy_smpl("sigma", npy_dist.Uniform(low=0.0, high=10.0)) 268 | weights = npy_smpl("weights", 269 | npy_dist.Uniform(low=-10 * jnp.ones((N, 1)), 270 | high=10 * jnp.ones((N, 1)))) 271 | 272 | y_pred = jnp.matmul(x, weights) 273 | npy_smpl("likelihood", npy_dist.Normal(loc=y_pred, scale=sigma), obs=y_obs) 274 | 275 | # Collect 1000 samples 276 | 277 | kernel = npy_inf.HMC(numpyro_model) 278 | mcmc = npy_inf.MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=False) 279 | start = time.time() 280 | mcmc.run(random.PRNGKey(0), y_obs=y) 281 | mcmc.print_summary() 282 | hmc_time = time.time() - start 283 | 284 | w_npy = mcmc.get_samples()["weights"] 285 | ``` 286 | 287 | ### Comparison 288 | 289 | ```{code-cell} 290 | print(f"Runtime of SGHMC: {sghmc_time :.0f} seconds\nRuntume of HMC: {hmc_time: .0f} seconds") 291 | ``` 292 | 293 | ```{code-cell} 294 | :tags: [hide-input] 295 | 296 | w_npy = mcmc.get_samples()["weights"] 297 | 298 | fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(13, 6)) 299 | 300 | ax1.set_title("σ") 301 | 302 | ax1.plot(onp.exp(results["log_sigma"]), label="SGHMC σ") 303 | ax1.plot([sigma]*onp.shape(results["log_sigma"])[0], label="True σ") 304 | ax1.set_xlabel("Sample Number") 305 | ax1.set_ylim([0.4, 0.6]) 306 | ax1.legend() 307 | 308 | w_sghmc = results["w"] 309 | 310 | # Contours of NumPyro solution 311 | 312 | levels = onp.linspace(0.1, 1.0, 5) 313 | 314 | # w1 vs w2 315 | w12 = gaussian_kde(jnp.squeeze(w_npy[:, 0:2].transpose())) 316 | w1d = onp.linspace(0.00, 0.20, 100) 317 | w2d = onp.linspace(-0.70, -0.30, 100) 318 | W1d, W2d = onp.meshgrid(w1d, w2d) 319 | p12d = onp.vstack([W1d.ravel(), W2d.ravel()]) 320 | Z12d = onp.reshape(w12(p12d).T, W1d.shape) 321 | Z12d /= Z12d.max() 322 | 323 | ax2.set_xlabel("$w_1$") 324 | ax2.set_ylabel("$w_2$") 325 | ax2.set_title("$w_1$ vs $w_2$") 326 | ax2.set_xlim([0.07, 0.12]) 327 | ax2.set_ylim([-0.515, -0.455]) 328 | ax2.contour(W1d, W2d, Z12d, levels, colors='red', linewidths=0.5, label="Numpyro") 329 | ax2.plot(w_sghmc[:, 0], w_sghmc[:, 1], 'o', alpha=0.5, markersize=1, zorder=-1, label="SGHMC") 330 | ax2.legend() 331 | # w3 vs w4 332 | 333 | w34 = gaussian_kde(jnp.squeeze(w_npy[:, 2:4].transpose())) 334 | w3d = onp.linspace(-0.3, -0.05, 100) 335 | w4d = onp.linspace(-0.75, -0.575, 100) 336 | W3d, W4d = onp.meshgrid(w3d, w4d) 337 | p34d = onp.vstack([W3d.ravel(), W4d.ravel()]) 338 | Z34d = onp.reshape(w34(p34d).T, W3d.shape) 339 | Z34d /= Z34d.max() 340 | 341 | ax3.set_xlabel("$w_3$") 342 | ax3.set_ylabel("$w_4$") 343 | ax3.set_title("$w_3$ vs $w_4$") 344 | ax3.contour(W3d, W4d, Z34d, levels, colors='red', linewidths=0.5, label="Numpyro") 345 | ax3.plot(w_sghmc[:, 2], w_sghmc[:, 3], 'o', alpha=0.5, markersize=1, zorder=-1, label="SGHMC") 346 | ax3.legend() 347 | fig.tight_layout(pad=0.3) 348 | plt.show() 349 | ``` 350 | 351 | ```{code-cell} 352 | 353 | ``` 354 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tests/test_alias.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Test the convergence of the solver on small toy problem. """ 16 | 17 | import jax 18 | from jax import random 19 | import jax.numpy as jnp 20 | from jax.scipy.stats import norm 21 | 22 | from scipy import stats as scpstats 23 | 24 | import pytest 25 | 26 | from jax_sgmc import data 27 | from jax_sgmc import potential 28 | from jax_sgmc import alias 29 | from jax_sgmc.data.numpy_loader import NumpyDataLoader 30 | 31 | 32 | @pytest.fixture 33 | def kolmogorov_smirnov_setup(): 34 | 35 | # Sample from a gaussian distribution 36 | samples = 100 37 | sigma = 0.5 38 | 39 | key = random.PRNGKey(11) 40 | x = sigma * random.normal(key, (samples, )) 41 | 42 | # No need for reference data 43 | data_loader = data.numpy_loader.DeviceNumpyDataLoader(x=jnp.zeros((2, 1))) 44 | batch_fn = data.random_reference_data(data_loader, 1, 1, verify_calls=True) 45 | 46 | def likelihood_fn(sample, _): 47 | return -0.5 * (sample / sigma) ** 2 48 | 49 | def prior_fn(sample): 50 | return jnp.asarray(0.0) 51 | 52 | potential_fn = potential.minibatch_potential( 53 | prior_fn, likelihood_fn, strategy="vmap") 54 | full_potential_fn = potential.full_potential( 55 | prior_fn, likelihood_fn, strategy="vmap" 56 | ) 57 | 58 | def check_fn(sampled): 59 | statistic = scpstats.kstest(sampled, x) 60 | assert statistic.pvalue > 0.05, ( 61 | f"Solver generated non-normal distributed samples with 95% confidence " 62 | f"(p_value is {statistic.pvalue})." 
63 | ) 64 | 65 | return data_loader, batch_fn, potential_fn, full_potential_fn, x[0], check_fn 66 | 67 | 68 | @pytest.fixture 69 | def problem(): 70 | 71 | # Reference Data 72 | 73 | N = 4 74 | samples = 1000 # Total samples 75 | sigma = 0.5 76 | 77 | key = random.PRNGKey(0) 78 | split1, split2, split3 = random.split(key, 3) 79 | 80 | w = random.uniform(split3, minval=-1, maxval=1, shape=(N, 1)) 81 | noise = sigma * random.normal(split2, shape=(samples, 1)) 82 | x = random.uniform(split1, minval=-10, maxval=10, shape=(samples, N)) 83 | x = jnp.stack([x[:, 0] + x[:, 1], x[:, 1], 0.1 * x[:, 2] - 0.5 * x[:, 3], 84 | x[:, 3]]).transpose() 85 | y = jnp.matmul(x, w) + noise 86 | w_init = sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(10.0)} 87 | 88 | M = 10 89 | cs = 1000 90 | 91 | data_loader = NumpyDataLoader(x=x, y=y) 92 | batch_fn = data.random_reference_data(data_loader, 93 | cached_batches_count=cs, 94 | mb_size=M) 95 | 96 | def model(sample, observations): 97 | weights = sample["w"] 98 | predictors = observations["x"] 99 | return jnp.dot(predictors, weights) 100 | 101 | def likelihood(sample, observations): 102 | sigma = sample["sigma"] 103 | y = observations["y"] 104 | y_pred = model(sample, observations) 105 | return norm.logpdf(y - y_pred, scale=sigma) 106 | 107 | def prior(unused_sample): 108 | return 0.0 109 | 110 | # If the model is more complex, the strategy can be set to map for sequential 111 | # evaluation and pmap for parallel evaluation. 112 | potential_fn = potential.minibatch_potential(prior=prior, 113 | likelihood=likelihood, 114 | strategy="vmap") 115 | full_potential_fn = potential.full_potential(prior=prior, 116 | likelihood=likelihood, 117 | strategy="vmap") 118 | return data_loader, batch_fn, potential_fn, full_potential_fn, w, w_init 119 | 120 | 121 | class TestAliasKolmogorovSmirnov: 122 | 123 | @pytest.mark.solver 124 | def test_rms_prop(self, kolmogorov_smirnov_setup): 125 | data_loader, batch_fn, potential_fn, _, init_sample, assert_fn =\ 126 | kolmogorov_smirnov_setup 127 | 128 | solver = alias.sgld( 129 | potential_fn, 130 | data_loader, 131 | cache_size=1, 132 | batch_size=1, 133 | first_step_size=0.5, 134 | last_step_size=0.1, 135 | burn_in=100, 136 | accepted_samples=100, 137 | rms_prop=True 138 | ) 139 | 140 | sampled = solver(init_sample, iterations=500)[0]["samples"]["variables"] 141 | assert_fn(sampled) 142 | 143 | @pytest.mark.solver 144 | def test_re_sgld(self, kolmogorov_smirnov_setup): 145 | data_loader, batch_fn, potential_fn, _, init_sample, assert_fn =\ 146 | kolmogorov_smirnov_setup 147 | 148 | solver = alias.re_sgld( 149 | potential_fn, 150 | data_loader, 151 | cache_size=1, 152 | batch_size=1, 153 | first_step_size=0.5, 154 | last_step_size=0.1, 155 | burn_in=100, 156 | accepted_samples=100, 157 | temperature=1.5 158 | ) 159 | 160 | sampled = solver( 161 | (init_sample, init_sample), iterations=500 162 | )[0]["samples"]["variables"] 163 | assert_fn(sampled) 164 | 165 | @pytest.mark.solver 166 | def test_amagold(self, kolmogorov_smirnov_setup): 167 | data_loader, batch_fn, potential_fn, full_pot_fn, init_sample, assert_fn =\ 168 | kolmogorov_smirnov_setup 169 | 170 | solver = alias.amagold( 171 | potential_fn, 172 | full_pot_fn, 173 | data_loader, 174 | cache_size=1, 175 | batch_size=1, 176 | first_step_size=0.5, 177 | last_step_size=0.1, 178 | burn_in=100 179 | ) 180 | 181 | sampled = solver(init_sample, iterations=200)[0]["samples"]["variables"] 182 | assert_fn(sampled) 183 | 184 | @pytest.mark.solver 185 | def test_sggmc(self, 
kolmogorov_smirnov_setup): 186 | data_loader, batch_fn, potential_fn, full_pot_fn, init_sample, assert_fn =\ 187 | kolmogorov_smirnov_setup 188 | 189 | solver = alias.sggmc( 190 | potential_fn, 191 | full_pot_fn, 192 | data_loader, 193 | cache_size=1, 194 | batch_size=1, 195 | first_step_size=0.5, 196 | last_step_size=0.1, 197 | burn_in=100 198 | ) 199 | 200 | sampled = solver(init_sample, iterations=200)[0]["samples"]["variables"] 201 | assert_fn(sampled) 202 | 203 | @pytest.mark.solver 204 | def test_obabo(self, kolmogorov_smirnov_setup): 205 | data_loader, batch_fn, potential_fn, _, init_sample, assert_fn =\ 206 | kolmogorov_smirnov_setup 207 | 208 | solver = alias.obabo( 209 | potential_fn, 210 | data_loader, 211 | cache_size=1, 212 | batch_size=1, 213 | first_step_size=0.5, 214 | last_step_size=0.1, 215 | friction=10, 216 | burn_in=100, 217 | accepted_samples=100 218 | ) 219 | 220 | sampled = solver(init_sample, iterations=500)[0]["samples"]["variables"] 221 | assert_fn(sampled) 222 | 223 | @pytest.mark.solver 224 | def test_sghmc(self, kolmogorov_smirnov_setup): 225 | data_loader, batch_fn, potential_fn, _, init_sample, assert_fn =\ 226 | kolmogorov_smirnov_setup 227 | 228 | solver = alias.sghmc( 229 | potential_fn, 230 | data_loader, 231 | cache_size=1, 232 | batch_size=1, 233 | integration_steps=5, 234 | first_step_size=0.5, 235 | last_step_size=0.1, 236 | friction=0.995, 237 | burn_in=100, 238 | accepted_samples=100, 239 | adapt_noise_model=False, 240 | ) 241 | 242 | sampled = solver(init_sample, iterations=500)[0]["samples"]["variables"] 243 | 244 | assert_fn(sampled) 245 | 246 | 247 | class TestAliasLinearRegression: 248 | 249 | @pytest.mark.solver 250 | def test_rms_prop(self, problem): 251 | data_loader, batch_fn, potential_fn, _, w, w_init = problem 252 | 253 | solver = alias.sgld( 254 | potential_fn, 255 | data_loader, 256 | cache_size=512, 257 | batch_size=10, 258 | first_step_size=0.05, 259 | last_step_size=0.001, 260 | burn_in=20000, 261 | accepted_samples=4000, 262 | rms_prop=True 263 | ) 264 | 265 | results = solver(w_init, iterations=50000) 266 | 267 | # Check that the standard deviation is close 268 | assert jnp.all( 269 | jnp.abs(results[0]["samples"]["variables"]["sigma"] - 0.5) < 0.5) 270 | 271 | @pytest.mark.solver 272 | def test_re_sgld(self, problem): 273 | data_loader, batch_fn, potential_fn, _, w, w_init = problem 274 | 275 | solver = alias.re_sgld( 276 | potential_fn, 277 | data_loader, 278 | cache_size=512, 279 | batch_size=10, 280 | first_step_size=0.0001, 281 | last_step_size=0.000005, 282 | burn_in=20000, 283 | accepted_samples=4000, 284 | temperature=100.0 285 | ) 286 | 287 | results = solver((w_init, w_init), iterations=50000) 288 | 289 | # Check that the standard deviation is close 290 | assert jnp.all( 291 | jnp.abs(results[0]["samples"]["variables"]["sigma"] - 0.7) < 0.7) 292 | 293 | @pytest.mark.solver 294 | def test_amagold(self, problem): 295 | data_loader, batch_fn, potential_fn, full_potential_fn, w, w_init = problem 296 | 297 | solver = alias.amagold( 298 | potential_fn, 299 | full_potential_fn, 300 | data_loader, 301 | cache_size=512, 302 | batch_size=64, 303 | first_step_size=0.005, 304 | last_step_size=0.0005, 305 | burn_in=2000 306 | ) 307 | 308 | results = solver(w_init, iterations=50000) 309 | 310 | # Check that the standard deviation is close 311 | assert jnp.all( 312 | jnp.abs(results[0]["samples"]["variables"]["sigma"] - 0.5) < 0.5) 313 | 314 | @pytest.mark.solver 315 | def test_sggmc(self, problem): 316 | data_loader, batch_fn, 
potential_fn, full_potential_fn, w, w_init = problem 317 | 318 | solver = alias.sggmc( 319 | potential_fn, 320 | full_potential_fn, 321 | data_loader, 322 | cache_size=512, 323 | batch_size=64, 324 | first_step_size=0.005, 325 | last_step_size=0.0005, 326 | burn_in=2000 327 | ) 328 | 329 | results = solver(w_init, iterations=50000) 330 | 331 | # Check that the standard deviation is close 332 | assert jnp.all( 333 | jnp.abs(results[0]["samples"]["variables"]["sigma"] - 0.5) < 0.5) 334 | 335 | @pytest.mark.solver 336 | def test_obabo(self, problem): 337 | data_loader, batch_fn, potential_fn, _, w, w_init = problem 338 | 339 | solver = alias.obabo( 340 | potential_fn, 341 | data_loader, 342 | cache_size=512, 343 | batch_size=10, 344 | first_step_size=0.05, 345 | last_step_size=0.001, 346 | friction=1000, 347 | burn_in=2000 348 | ) 349 | 350 | results = solver(w_init, iterations=50000) 351 | 352 | # Check that the standard deviation is close 353 | assert jnp.all( 354 | jnp.abs(results[0]["samples"]["variables"]["sigma"] - 0.5) < 0.5) 355 | 356 | @pytest.mark.solver 357 | def test_sghmc(self, problem): 358 | data_loader, batch_fn, potential_fn, _, w, w_init = problem 359 | 360 | solver = alias.sghmc( 361 | potential_fn, 362 | data_loader, 363 | cache_size=512, 364 | batch_size=10, 365 | integration_steps=5, 366 | first_step_size=0.05, 367 | last_step_size=0.001, 368 | friction=0.5, 369 | burn_in=5000, 370 | adapt_noise_model=True, 371 | diagonal_noise=False, 372 | ) 373 | 374 | results = solver(w_init, iterations=10000) 375 | 376 | # Check that the standard deviation is close 377 | assert jnp.all( 378 | jnp.abs(results[0]["samples"]["variables"]["sigma"] - 0.5) < 0.5) 379 | -------------------------------------------------------------------------------- /jax_sgmc/potential.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility to evaluate stochastic or true potential. 16 | 17 | This module transforms the likelihood function for a single observation or a 18 | batch of observations to a function calculating the stochastic or full potential 19 | making use of ``map``, ``vmap`` and ``pmap``. 
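
A minimal usage sketch (the single-observation Gaussian likelihood, the flat prior and the
field names below are illustrative placeholders, not functions defined in this module)::

    from jax.scipy.stats import norm
    from jax_sgmc import potential

    def log_likelihood(sample, observation):
        # Log-density of a single observation given the current sample
        return norm.logpdf(observation["y"], loc=sample["mean"], scale=1.0)

    def log_prior(sample):
        # Improper flat prior (log-density of zero)
        return 0.0

    # Vectorize the single-observation likelihood over a mini-batch
    potential_fn = potential.minibatch_potential(
        prior=log_prior, likelihood=log_likelihood, strategy="vmap")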
20 | 21 | """ 22 | 23 | from functools import partial 24 | 25 | from typing import Callable, Any, AnyStr, Optional, Tuple, Union, Protocol 26 | 27 | from jax import vmap, pmap, lax, tree_util, named_call 28 | 29 | import jax.numpy as jnp 30 | 31 | from jax_sgmc import util 32 | from jax_sgmc.data import CacheState, MiniBatch 33 | 34 | PyTree = Any 35 | Array = util.Array 36 | 37 | Likelihood = Union[ 38 | Callable[[PyTree, PyTree, MiniBatch], Tuple[Array, PyTree]], 39 | Callable[[PyTree, MiniBatch], Array]] 40 | Prior = Callable[[PyTree], Array] 41 | 42 | class StochasticPotential(Protocol): 43 | def __call__(self, 44 | sample: PyTree, 45 | reference_data: MiniBatch, 46 | state: PyTree = None, 47 | mask: Array = None, 48 | likelihoods: bool = False 49 | ) -> Union[Tuple[Array, PyTree], 50 | Tuple[Array, Tuple[Array, PyTree]]]: 51 | """Calculates the stochastic potential for a mini-batch of data. 52 | 53 | Args: 54 | sample: Model parameters 55 | reference_data: Batch of observations 56 | state: Special parameters of the model which should not change the 57 | result of a model evaluation. 58 | mask: Marking invalid (e.g. double) samples 59 | likelihoods: Return the likelihoods of all model evaluations separately 60 | 61 | Returns: 62 | Returns an approximation of the true potential based on a mini-batch of 63 | reference data. Moreover, the likelihood for every single observation 64 | can be returned. 65 | 66 | """ 67 | 68 | class FullPotential(Protocol): 69 | def __call__(self, 70 | sample: PyTree, 71 | data_state: CacheState, 72 | full_data_map_fn: Callable, 73 | state: PyTree = None 74 | ) -> Tuple[Array, Tuple[CacheState, PyTree]]: 75 | """Calculates the potential over the full dataset. 76 | 77 | Args: 78 | sample: Model parameters 79 | data_state: State of the ``full_data_map`` functional 80 | full_data_map_fn: Functional mapping a function over the complete 81 | dataset 82 | state: Special parameters of the model which should not change the 83 | result of a model evaluation. 84 | 85 | Returns: 86 | Returns the potential of the current sample using the full dataset. 87 | 88 | """ 89 | 90 | 91 | 92 | # Todo: Possibly support soft-vmap (numpyro) 93 | 94 | def minibatch_potential(prior: Prior, 95 | likelihood: Callable, 96 | strategy: AnyStr = "map", 97 | has_state: bool = False, 98 | is_batched: bool = False, 99 | temperature: float = 1.) -> StochasticPotential: 100 | """Initializes the potential function for a minibatch of data. 101 | 102 | Args: 103 | prior: Log-prior function which is evaluated for a single 104 | sample. 105 | likelihood: Log-likelihood function. If ``has_state = True``, then the 106 | first argument is the model state, otherwise the arguments are ``sample, 107 | reference_data``. 108 | strategy: Determines hwo to evaluate the model function with respect for 109 | sample: 110 | 111 | - ``'map'`` sequential evaluation 112 | - ``'vmap'`` parallel evaluation via vectorization 113 | - ``'pmap'`` parallel evaluation on multiple devices 114 | 115 | has_state: If an additional state is provided for the model evaluation 116 | is_batched: If likelihood expects a batch of observations instead of a 117 | single observation. If the likelihood is batched, choosing the strategy 118 | has no influence on the computation. 119 | temperature: Posterior temperature. T = 1 is the Bayesian posterior. 120 | 121 | Returns: 122 | Returns a function which evaluates the stochastic potential for a mini-batch 123 | of data. 
The first argument are the latent variables and the second is the 124 | mini-batch. 125 | """ 126 | 127 | # State is always passed to simplify usage in solvers 128 | def stateful_likelihood(state: PyTree, 129 | sample: PyTree, 130 | reference_data: PyTree): 131 | if has_state: 132 | lks, state = likelihood(state, sample, reference_data) 133 | else: 134 | lks = likelihood(sample, reference_data) 135 | state = None 136 | # Ensure that a scalar is returned to avoid broadcasting with mask 137 | return jnp.squeeze(lks), state 138 | 139 | # Define the strategies to evaluate the likelihoods sequantially, vectorized 140 | # or in parallel 141 | if is_batched: 142 | batched_likelihood = stateful_likelihood 143 | elif strategy == 'map': 144 | def batched_likelihood(state: PyTree, 145 | sample: PyTree, 146 | reference_data: PyTree): 147 | partial_likelihood = partial(stateful_likelihood, state, sample) 148 | return lax.map(partial_likelihood, reference_data) 149 | elif strategy == 'pmap': 150 | batched_likelihood = pmap(stateful_likelihood, 151 | in_axes=(None, None, 0)) 152 | elif strategy == 'vmap': 153 | batched_likelihood = vmap(stateful_likelihood, 154 | in_axes=(None, None, 0)) 155 | else: 156 | raise NotImplementedError(f"Strategy {strategy} is unknown") 157 | 158 | 159 | def batch_potential(state: PyTree, 160 | sample: PyTree, 161 | reference_data: MiniBatch, 162 | mask: Array): 163 | # Approximate the potential by taking the average and scaling it to the 164 | # full data set size 165 | batch_data, batch_information = reference_data 166 | N = batch_information.observation_count 167 | n = batch_information.batch_size 168 | 169 | batch_likelihoods, new_states = batched_likelihood( 170 | state, sample, batch_data) 171 | if is_batched: 172 | # Batched evaluation returns single state 173 | new_state = new_states 174 | elif state is not None: 175 | new_state = tree_util.tree_map( 176 | lambda ary, org: jnp.reshape(jnp.take(ary, 0, axis=0), org.shape), 177 | new_states, state) 178 | else: 179 | new_state = None 180 | 181 | # The mask is only necessary for the full potential evaluation 182 | if mask is None: 183 | stochastic_potential = - N * jnp.mean(batch_likelihoods, axis=0) 184 | else: 185 | stochastic_potential = - N / n * jnp.dot(batch_likelihoods, mask) 186 | return stochastic_potential, batch_likelihoods, new_state 187 | 188 | @partial(named_call, name='evaluate_stochastic_potential') 189 | def potential_function(sample: PyTree, 190 | reference_data: MiniBatch, 191 | state: PyTree = None, 192 | mask: Array = None, 193 | likelihoods: bool = False): 194 | # Never differentiate w. r. t. reference data 195 | reference_data = lax.stop_gradient(reference_data) 196 | 197 | # Evaluate the likelihood and model for each reference data sample 198 | # likelihood_value = batched_likelihood_and_model(sample, reference_data) 199 | # It is also possible to combine the prior and the likelihood into a single 200 | # callable. 
201 | 202 | batch_likelihood, observation_likelihoods, new_state = batch_potential( 203 | state, sample, reference_data, mask) 204 | 205 | # The prior has to be evaluated only once, therefore the extra call 206 | prior_value = prior(sample) 207 | 208 | if likelihoods: 209 | return ( 210 | jnp.squeeze(batch_likelihood - prior_value) / temperature, 211 | (observation_likelihoods, new_state)) 212 | else: 213 | return (jnp.squeeze(batch_likelihood - prior_value) / temperature, 214 | new_state) 215 | 216 | return potential_function 217 | 218 | 219 | def full_potential(prior: Callable[[PyTree], Array], 220 | likelihood: Callable[[PyTree, PyTree], Array], 221 | strategy: AnyStr = "map", 222 | has_state: bool = False, 223 | is_batched: bool = False, 224 | temperature: float = 1., 225 | ) -> FullPotential: 226 | """Transforms a pdf to compute the full potential over all reference data. 227 | 228 | Args: 229 | prior: Log-prior function which is evaluated for a single 230 | sample. 231 | likelihood: Log-likelihood function. If ``has_state = True``, then the 232 | first argument is the model state, otherwise the arguments are ``sample, 233 | reference_data``. 234 | strategy: Determines how to evaluate the model function with respect for 235 | sample: 236 | 237 | - ``'map'`` sequential evaluation 238 | - ``'vmap'`` parallel evaluation via vectorization 239 | - ``'pmap'`` parallel evaluation on multiple devices 240 | 241 | has_state: If an additional state is provided for the model evaluation 242 | is_batched: If likelihood expects a batch of observations instead of a 243 | single observation. If the likelihood is batched, choosing the strategy 244 | has no influence on the computation. In this case, the last argument of 245 | the likelihood should be an optional mask. The mask is an arrays with ones 246 | for valid observations and zeros for non-valid observations. 247 | temperature: Posterior temperature. T = 1 is the Bayesian posterior. 248 | 249 | Returns: 250 | Returns a function which evaluates the potential over the full dataset via 251 | a dataset mapping from the :mod:`jax_sgmc.data` module. 252 | 253 | """ 254 | assert strategy != 'pmap', "Pmap is currently not supported" 255 | 256 | # Can use the potential evaluation strategy for a minibatch of data. The prior 257 | # must be evaluated independently. 
batch_potential = minibatch_potential(lambda _: jnp.array(0.0), 259 | likelihood, 260 | strategy=strategy, 261 | has_state=has_state, 262 | is_batched=is_batched) 263 | 264 | def batch_evaluation(sample, reference_data, mask, state): 265 | potential, state = batch_potential(sample, reference_data, state, mask) 266 | # We need to undo the scaling to get the real potential 267 | _, batch_information = reference_data 268 | N = batch_information.observation_count 269 | n = batch_information.batch_size 270 | unscaled_potential = potential * n / N 271 | return unscaled_potential, state 272 | 273 | @partial(named_call, name='evaluate_true_potential') 274 | def sum_batched_evaluations(sample: PyTree, 275 | data_state: CacheState, 276 | full_data_map_fn: Callable, 277 | state: PyTree = None): 278 | body_fn = partial(batch_evaluation, sample) 279 | 280 | if data_state is None: # quick fix to let it run with full_data_mapper 281 | results, new_state = full_data_map_fn( 282 | body_fn, state, masking=True, information=True) 283 | else: 284 | data_state, (results, new_state) = full_data_map_fn( 285 | body_fn, data_state, state, masking=True, information=True) 286 | 287 | # The prior needs just a single evaluation 288 | prior_eval = prior(sample) 289 | 290 | return (jnp.squeeze(jnp.sum(results) - prior_eval) / temperature, 291 | (data_state, new_state)) 292 | 293 | return sum_batched_evaluations 294 | -------------------------------------------------------------------------------- /docs/usage/data.rst: -------------------------------------------------------------------------------- 1 | Data Loading 2 | ============= 3 | 4 | Steps for Setting up a Data Chain 5 | --------------------------------- 6 | 7 | Accessing data in **JaxSGMC** consists of two steps: 8 | 9 | - Set up the Data Loader 10 | - Set up the Callback Wrappers 11 | 12 | The DataLoader determines where the data is stored and how it is passed 13 | to the device (e.g. shuffled in epochs). 14 | 15 | The Callback Wrappers request new batches from the DataLoader and pass them 16 | to the device via Jax's Host-Callback module. Therefore, only a subset of the 17 | data is stored in the device memory. 18 | 19 | The combination of a DataLoader and Callback Wrappers determines how the data is 20 | passed to the computation. Therefore, this guide presents different methods of 21 | data access with ``NumpyDataLoader`` and ``TensorflowDataLoader``. 22 | 23 | .. note:: 24 | When using multiple DataLoaders sequentially in a single script, the 25 | release function should be called after the Callback Wrapper has been used in 26 | the last computation. After this, the reference to the DataLoader has been 27 | discarded and the DataLoader can be deleted. 28 | 29 | Important Notes 30 | ---------------- 31 | 32 | Getting shape and dtype of the data 33 | ____________________________________ 34 | 35 | Some models need to know the shape and dtype of the reference data. Therefore, 36 | an all-zero batch can be drawn from every DataLoader. 37 | 38 | :: 39 | 40 | print(data_loader.initializer_batch(3)) 41 | {'x_r': Array([0, 0, 0], dtype=int32), 'y_r': Array([[0., 0.], 42 | [0., 0.], 43 | [0., 0.]], dtype=float32)} 44 | 45 | If no batch size is specified, a single observation is returned (all leaves' 46 | shapes are reduced by the first axis). 47 | 48 | :: 49 | 50 | print(data_loader.initializer_batch()) 51 | {'x_r': Array(0, dtype=int32), 'y_r': Array([0., 0.], dtype=float32)} 52 | 53 | 54 | Numpy Data Loader 55 | ------------------ 56 | 57 | .. 
doctest:: 58 | 59 | >>> import numpy as onp 60 | >>> from jax_sgmc import data 61 | >>> from jax_sgmc.data.numpy_loader import NumpyDataLoader 62 | 63 | First, we set up the dataset. This is very simple, as each array can be assigned 64 | as a keyword argument to the dataloader. The keywords of the single arrays form 65 | the keys of the pytree-dict, bundling all observations. Note that you can access 66 | the data supplied to the likelihood via the same keywords. 67 | 68 | >>> # The arrays must have the same length along the first dimension, 69 | >>> # corresponding to the total observation count 70 | >>> x = onp.arange(10) 71 | >>> y = onp.zeros((10, 2)) 72 | >>> 73 | >>> data_loader = NumpyDataLoader(x_r=x, y_r=y) 74 | 75 | The host callback wrappers cache some data in the device memory to reduce the 76 | number of calls to the host. The cache size equals the number of batches stored 77 | on the device. A larger cache size is more efficient in computation time, but 78 | has an increased device memory consumption. 79 | 80 | >>> rd_init, rd_batch, _ = data.random_reference_data(data_loader, 100, 2) 81 | 82 | The ``NumpyDataLoader`` accepts keyword arguments in 83 | the init function to determine the starting points of the chains. 84 | 85 | >>> rd_state = rd_init(seed=0) 86 | >>> new_state, (rd_batch, info) = rd_batch(rd_state, information=True) 87 | >>> print(rd_batch) 88 | {'x_r': Array([8, 9], dtype=int32), 'y_r': Array([[0., 0.], 89 | [0., 0.]], dtype=float32)} 90 | >>> # If necessary, information about the total sample count can be passed 91 | >>> print(info) 92 | MiniBatchInformation(observation_count=10, mask=Array([ True, True], dtype=bool), batch_size=2) 93 | 94 | 95 | Random Data Access 96 | ___________________ 97 | 98 | The ``NumpyDataLoader`` provides three different methods to randomly select 99 | observations: 100 | 101 | - Independent draw (default): Draw from all samples with replacement. 102 | - Shuffling: Draw from all samples without replacement and immediately reshuffle 103 | if all samples have been drawn. 104 | - Shuffling in epochs: Draw from all samples without replacement and return mask 105 | to mark invalid samples at the end of the epoch. 106 | 107 | This is illustrated for a small toy-dataset, for which the observation count is 108 | not a multiple of the batch size: 109 | 110 | .. doctest:: 111 | 112 | >>> import numpy as onp 113 | >>> from jax_sgmc import data 114 | >>> from jax_sgmc.data.numpy_loader import NumpyDataLoader 115 | 116 | >>> x = onp.arange(10) 117 | >>> data_loader = NumpyDataLoader(x=x) 118 | >>> init_fn, batch_fn, _ = data.random_reference_data(data_loader, 2, 3) 119 | 120 | The preferred method has to be passed when initializing the different chains: 121 | 122 | >>> random_chain = init_fn() 123 | >>> shuffle_chain = init_fn(shuffle=True) 124 | >>> epoch_chain = init_fn(shuffle=True, in_epochs=True) 125 | 126 | In the fourth draw, the epoch chain should return a mask with invalid samples: 127 | 128 | >>> def eval_fn(chain): 129 | ... for _ in range(4): 130 | ... chain, batch = batch_fn(chain, information=True) 131 | ... 
print(batch) 132 | >>> 133 | >>> eval_fn(random_chain) 134 | ({'x': Array([4, 6, 6], dtype=int32)}, MiniBatchInformation(observation_count=10, mask=Array([ True, True, True], dtype=bool), batch_size=3)) 135 | >>> eval_fn(shuffle_chain) 136 | ({'x': Array([0, 4, 7], dtype=int32)}, MiniBatchInformation(observation_count=10, mask=Array([ True, True, True], dtype=bool), batch_size=3)) 137 | >>> eval_fn(epoch_chain) 138 | ({'x': Array([5, 0, 0], dtype=int32)}, MiniBatchInformation(observation_count=10, mask=Array([ True, False, False], dtype=bool), batch_size=3)) 139 | 140 | 141 | Mapping over Full Dataset 142 | __________________________ 143 | 144 | It is also possible to map a function over the complete dataset provided by a 145 | DataLoader. In each iteration, the function is mapped over a batch of data to 146 | speed up the calculation while limiting the memory consumption. 147 | 148 | In this toy example, the dataset consists of the potential bases 149 | :math:`\mathcal{D} = \left\{i \mid i = 0, \ldots, 9 \right\}`. In a scan loop, 150 | the sum of the potentials with given exponents is calculated: 151 | 152 | .. math:: 153 | 154 | f_e = \sum_{i=0}^{9}d_i^e \mid d_i \in \mathcal{D}, e = 0,\ldots, 2. 155 | 156 | .. doctest:: 157 | 158 | >>> from functools import partial 159 | >>> import jax.numpy as jnp 160 | >>> from jax.lax import scan 161 | >>> from jax_sgmc import data 162 | >>> from jax_sgmc.data.numpy_loader import NumpyDataLoader 163 | 164 | First, the data loader must be set up. The batch size is not required to 165 | divide the total observation count. This is realized by filling up the 166 | last batch with some values, which are sorted out either automatically or 167 | directly by the user with the provided mask. 168 | 169 | >>> base = jnp.arange(10) 170 | >>> 171 | >>> data_loader = NumpyDataLoader(base=base) 172 | 173 | The mask is a boolean array with ``True`` if the value is valid and ``False`` 174 | if it is just a filler. 175 | If ``masking=False`` is set (the default), no positional mask argument is expected 176 | in the function signature. 177 | 178 | >>> def sum_potentials(exp, data, mask, unused_state): 179 | ... # Mask out the invalid samples (filler values, already mapped over) 180 | ... sum = jnp.sum(mask * jnp.power(data['base'], exp)) 181 | ... return sum, unused_state 182 | >>> 183 | >>> init_fun, map_fun, _ = data.full_reference_data(data_loader, 184 | ... cached_batches_count=3, 185 | ... mb_size=4) 186 | 187 | The results per batch must be post-processed. If ``masking=False``, a result for 188 | each observation is returned. Therefore, using the masking option reduces the 189 | memory consumption. 190 | 191 | >>> # The exponential value is fixed during the mapping, therefore add it via 192 | >>> # functools.partial to the mapped function. 193 | >>> map_results = map_fun(partial(sum_potentials, 2), 194 | ... init_fun(), 195 | ... None, 196 | ... masking=True) 197 | >>> 198 | >>> data_state, (batch_sums, unused_state) = map_results 199 | >>> 200 | >>> # As we used the masking, a single result for each batch is returned. 201 | >>> # Now we need to postprocess those results, in this case by summing, to 202 | >>> # get the true result. 203 | >>> summed_result = jnp.sum(batch_sums) 204 | >>> print(f"Result: {summed_result : d}") 205 | Result: 285 206 | 207 | The full data map can be used in ``jit``-compiled functions, e.g. in a scan loop, 208 | such that it is possible to compute the results for multiple exponents in a 209 | ``lax.scan``-loop.
210 | 211 | >>> # Calculate for multiple exponents: 212 | >>> def body_fun(data_state, exp): 213 | ... map_results = map_fun(partial(sum_potentials, exp), data_state, None, masking=True) 214 | ... # Currently, we only summed over each mini-batch but not the whole 215 | ... # dataset. 216 | ... data_state, (batch_sums, unused_state) = map_results 217 | ... return data_state, (jnp.sum(batch_sums), unused_state) 218 | >>> 219 | >>> init_data_state = init_fun() 220 | >>> _, (result, _) = scan(body_fun, init_data_state, jnp.arange(3)) 221 | >>> print(result) 222 | [ 10 45 285] 223 | 224 | It is also possible to store the ``CacheStates`` in the host memory, such that 225 | it is not necessary to carry the ``data_state`` through all function calls. 226 | The :func:`jax_sgmc.data.core.full_data_mapper` function does this, such that 227 | its usage is a little bit simpler: 228 | 229 | >>> mapper_fn, release_fn = data.full_data_mapper(data_loader, 230 | ... cached_batches_count=3, 231 | ... mb_size=4) 232 | >>> 233 | >>> results, _ = mapper_fn(partial(sum_potentials, 2), None, masking=True) 234 | >>> 235 | >>> print(f"Result with exp = 2: {jnp.sum(results) : d}") 236 | Result with exp = 2: 285 237 | >>> 238 | >>> # Delete the reference to the Data Loader (optional) 239 | >>> release_fn() 240 | 241 | 242 | Tensorflow Data Loader 243 | ----------------------- 244 | 245 | Random Access 246 | _______________________ 247 | 248 | The tensorflow DataLoader is a great choice for many standard datasets 249 | available on tensorflow_datasets. 250 | 251 | .. doctest:: 252 | 253 | >>> import tensorflow_datasets as tfds 254 | >>> from jax import tree_util 255 | >>> from jax_sgmc import data 256 | >>> from jax_sgmc.data.tensorflow_loader import TensorflowDataLoader 257 | >>> 258 | >>> import contextlib 259 | >>> import io 260 | >>> 261 | >>> # Helper function to look at the data provided 262 | >>> def show_data(data): 263 | ... for key, item in data.items(): 264 | ... print(f"{key} with shape {item.shape} and dtype {item.dtype}") 265 | 266 | The pipeline returned by tfds load can be directly passed to the DataLoader. 267 | However, not all tensorflow data types can be transformed to Jax data types, for 268 | example the feature 'id', which is a string. Those keys can be simply excluded 269 | via the keyword argument `exclude_keys`. 270 | 271 | >>> # The data pipeline can be used directly 272 | >>> with contextlib.redirect_stdout(io.StringIO()): 273 | ... pipeline, info = tfds.load("cifar10", split="train", with_info=True) 274 | >>> print(info.features) 275 | FeaturesDict({ 276 | 'id': Text(shape=(), dtype=string), 277 | 'image': Image(shape=(32, 32, 3), dtype=uint8), 278 | 'label': ClassLabel(shape=(), dtype=int64, num_classes=10), 279 | }) 280 | >>> 281 | >>> data_loader = TensorflowDataLoader(pipeline, shuffle_cache=10, exclude_keys=['id']) 282 | >>> 283 | >>> # If the model needs data for initialization, an all zero batch can be 284 | >>> # drawn with the correct shapes and dtypes 285 | >>> show_data(data_loader.initializer_batch(mb_size=1000)) 286 | image with shape (1000, 32, 32, 3) and dtype uint8 287 | label with shape (1000,) and dtype int32 288 | 289 | The host callback wrappers cache some data in the device memory to reduce the 290 | number of calls to the host. The cache size equals the number of batches stored 291 | on the device. 
292 | 293 | >>> data_init, data_batch, _ = data.random_reference_data(data_loader, 100, 1000) 294 | >>> 295 | >>> init_state = data_init() 296 | >>> new_state, batch = data_batch(init_state) 297 | >>> show_data(batch) 298 | image with shape (1000, 32, 32, 3) and dtype uint8 299 | label with shape (1000,) and dtype int32 300 | 301 | Combining ``pmap`` and ``jit`` 302 | ______________________________ 303 | 304 | .. warning:: 305 | Jit-compiling a function including pmap requires adjustments of the Callback 306 | Wrapper functions, which can lead to memory leaks if not done correctly. 307 | 308 | Additionally, combining ``jit`` and ``pmap`` can lead to inefficient data 309 | movement. See ``_. 310 | 311 | When jitting a function f that includes a pmap function g, also the parts of f 312 | outside of g are run on all involved devices. This causes all devices to request 313 | the same cache state (verified by a token) from a single chain. 314 | 315 | For example, if g is pmapped to 5 devices, f is also going to run on 5 devices 316 | and hence 5 times the same cache state is requested from a chain. 317 | 318 | JaxSGMC resolved this issue by caching all requested states on the host for a 319 | specified number of requests. 320 | 321 | In the above example, the Callback Wrapper used in f should be called like: 322 | 323 | :: 324 | 325 | ... = full_data_map(to_map_fn, data_state, carry, device_count=5) 326 | 327 | 328 | It is important to note that providing a device count larger than the actual 329 | number of calling devices causes a memory leak, as all requested cache states 330 | will remain on the host until the program has finished. 331 | -------------------------------------------------------------------------------- /jax_sgmc/adaption.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Adapt quantities to the local or global geometry. 16 | """ 17 | 18 | import functools 19 | 20 | from typing import Any, Callable, Tuple, NamedTuple 21 | 22 | from collections import namedtuple 23 | 24 | from jax import tree_util, flatten_util, jit, named_call, lax, vmap, grad 25 | import jax.numpy as jnp 26 | 27 | from jax_sgmc.util import Array, Tensor 28 | from jax_sgmc.data import MiniBatch 29 | 30 | PartialFn = Any 31 | PyTree = Any 32 | 33 | AdaptionStrategy = Callable 34 | 35 | class AdaptionState(NamedTuple): 36 | """Extended adaption state returned by adaption decorator. 37 | 38 | This tuple stores functions to ravel and unravel the parameter and gradient 39 | pytree in addition to the adaption state. 
40 | 41 | Attributes: 42 | state: State of the adaption strategy 43 | ravel_fn: Jax-partial function to transform pytree to 1D array 44 | unravel_fn: Jax-partial function to undo ravelling of pytree 45 | flat_potential: Potential function on the flattened pytree 46 | """ 47 | state: PyTree 48 | ravel_fn: Callable 49 | unravel_fn: Callable 50 | flat_potential: Callable 51 | 52 | 53 | class Manifold(NamedTuple): 54 | """Adapted manifold. 55 | 56 | Attributes: 57 | g_inv: Inverse manifold. 58 | sqrt_g_inv: Square root of inverse manifold. 59 | gamma: Diffusion to correct for positional dependence of manifold. 60 | 61 | """ 62 | g_inv: Tensor 63 | sqrt_g_inv: Tensor 64 | gamma: Tensor 65 | 66 | 67 | class MassMatrix(NamedTuple): 68 | """Mass matrix for HMC. 69 | 70 | Attributes: 71 | inv: Inverse of the mass matrix 72 | sqrt: Square root of the mass matrix 73 | 74 | """ 75 | inv: PyTree 76 | sqrt: PyTree 77 | 78 | 79 | class NoiseModel(NamedTuple): 80 | """Approximation of the gradient noise. 81 | 82 | Attributes: 83 | cb_diff_sqrt: Square root of the difference between the friction term and 84 | the noise model 85 | b_sqrt: Square root of the noise model 86 | 87 | """ 88 | cb_diff_sqrt: PyTree 89 | b_sqrt: PyTree 90 | 91 | 92 | def get_unravel_fn(tree: PyTree): 93 | """Calculates the unravel function. 94 | 95 | Args: 96 | tree: Parameter pytree 97 | 98 | Returns: 99 | Returns a jax Partial object such that the function can be passed as 100 | valid argument. 101 | 102 | """ 103 | _, unravel_fn = flatten_util.ravel_pytree(tree) 104 | return tree_util.Partial(jit(unravel_fn)) 105 | 106 | 107 | def adaption(quantity: namedtuple = tuple): 108 | """Decorator to make adaption strategies operate on 1D arrays. 109 | 110 | Positional arguments are flattened while keyword arguments are passed 111 | unchanged. 112 | 113 | Args: 114 | quantity: Namedtuple to specify which fields are returned by 115 | :func:``get_adaption``. 
116 | 117 | """ 118 | return functools.partial(_adaption, quantity=quantity) 119 | 120 | def _adaption(adaption_fn: Callable, quantity: namedtuple = tuple): 121 | @functools.wraps(adaption_fn) 122 | def pytree_adaption(*args, **kwargs) -> AdaptionStrategy: 123 | init, update, get = adaption_fn(*args, **kwargs) 124 | # Name call for debugging 125 | named_update = named_call(update, name='update_adaption_state') 126 | named_get = named_call(get, name='get_adapted_manifold') 127 | @functools.wraps(init) 128 | def new_init(x0: PyTree, 129 | *init_args, 130 | **init_kwargs) -> AdaptionState: 131 | # Calculate the flattened state and the ravel and unravel fun 132 | ravel_fn = tree_util.Partial( 133 | jit(lambda tree: flatten_util.ravel_pytree(tree)[0])) 134 | unravel_fn = get_unravel_fn(x0) 135 | 136 | x0_flat = ravel_fn(x0) 137 | init_flat = map(ravel_fn, init_args) 138 | state = init(x0_flat, *init_flat, **init_kwargs) 139 | 140 | # Wrap the potential in a flatten function if potential is provided as 141 | # kwarg 142 | minibatch_potential = kwargs.get("minibatch_potential") 143 | if minibatch_potential is not None: 144 | @tree_util.Partial 145 | def flat_potential(sample, 146 | mini_batch, 147 | model_state: PyTree = None, 148 | **kwargs): 149 | sample_tree = unravel_fn(sample) 150 | if model_state is not None: 151 | return minibatch_potential( 152 | model_state, sample_tree, mini_batch, **kwargs) 153 | else: 154 | return minibatch_potential( 155 | sample_tree, mini_batch, **kwargs) 156 | else: 157 | flat_potential = None 158 | 159 | return AdaptionState(state=state, 160 | ravel_fn=ravel_fn, 161 | unravel_fn=unravel_fn, 162 | flat_potential=flat_potential) 163 | 164 | @functools.wraps(update) 165 | def new_update(state: AdaptionState, 166 | *update_args, 167 | mini_batch: PyTree = None, 168 | **update_kwargs): 169 | # Flat the params and the gradient 170 | flat_args = map(state.ravel_fn, update_args) 171 | 172 | # Update with flattened arguments 173 | if state.flat_potential is None: 174 | new_state = named_update( 175 | state.state, 176 | *flat_args, 177 | **update_kwargs) 178 | else: 179 | assert mini_batch, "Adaption strategy requires mini-batch" 180 | new_state = named_update( 181 | state.state, 182 | *flat_args, 183 | mini_batch, 184 | state.flat_potential, 185 | **update_kwargs) 186 | 187 | updated_state = AdaptionState( 188 | state=new_state, 189 | ravel_fn=state.ravel_fn, 190 | unravel_fn=state.unravel_fn, 191 | flat_potential=state.flat_potential) 192 | return updated_state 193 | 194 | @functools.wraps(get) 195 | def new_get(state: AdaptionState, 196 | *get_args, 197 | mini_batch: PyTree = None, 198 | **get_kwargs 199 | ) -> quantity: 200 | # Flat the params and the gradient 201 | flat_args = map(state.ravel_fn, get_args) 202 | 203 | # Get with flattened arguments 204 | if state.flat_potential is None: 205 | adapted_quantities = named_get( 206 | state.state,*flat_args, **get_kwargs) 207 | else: 208 | adapted_quantities = named_get( 209 | state.state, *flat_args, 210 | mini_batch=mini_batch, flat_potential=state.flat_potential, 211 | **get_kwargs) 212 | 213 | def unravel_quantities(): 214 | for q in adapted_quantities: 215 | if q.ndim == 1: 216 | yield Tensor(ndim=1, tensor=state.unravel_fn(q)) 217 | else: 218 | yield Tensor(ndim=2, tensor=q) 219 | 220 | return quantity(*unravel_quantities()) 221 | return new_init, new_update, new_get 222 | return pytree_adaption 223 | 224 | 225 | @adaption(quantity=Manifold) 226 | def rms_prop() -> AdaptionStrategy: 227 | """RMSprop adaption. 
228 | 229 | Adapt a diagonal matrix to the local curvature requiring only the stochastic 230 | gradient. 231 | 232 | Returns: 233 | Returns RMS-prop adaption strategy. 234 | 235 | [1] https://arxiv.org/abs/1512.07666 236 | """ 237 | 238 | def init(sample: Array, 239 | alpha: Array = 0.9, 240 | lmbd: Array = 1e-5): 241 | """Initializes RMSprop algorithm. 242 | 243 | Args: 244 | sample: Initial sample to derive the sample size 245 | alpha: Adaption speed 246 | lmbd: Stabilization constant 247 | 248 | Returns: 249 | Returns the initial adaption state 250 | """ 251 | v = jnp.ones_like(sample) 252 | return (v, alpha, lmbd) 253 | 254 | def update(state: Tuple[Array, Array, Array], 255 | sample: Array, 256 | sample_grad: Array, 257 | *args: Any, 258 | **kwargs: Any): 259 | """Updates the RMS-prop adaption. 260 | 261 | Args: 262 | state: Adaption state 263 | sample_grad: Stochastic gradient 264 | 265 | Returns: 266 | Returns adapted RMSprop state. 267 | """ 268 | del sample, args, kwargs 269 | 270 | v, alpha, lmbd = state 271 | new_v = alpha * v + (1 - alpha) * jnp.square(sample_grad) 272 | return new_v, alpha, lmbd 273 | 274 | def get(state: Tuple[Array, Array, Array], 275 | sample: Array, 276 | sample_grad: Array, 277 | *args: Any, 278 | **kwargs: Any): 279 | """Calculates the current manifold of the RMS-prop adaption. 280 | 281 | Args: 282 | state: Current RMSprop adaption state 283 | 284 | Returns: 285 | Returns a manifold tuple with ``ndim == 1``. 286 | """ 287 | del sample, sample_grad, args, kwargs 288 | 289 | v, _, lmbd = state 290 | g = jnp.power(lmbd + jnp.sqrt(v), -1.0) 291 | return g, jnp.sqrt(g), jnp.zeros_like(g) 292 | 293 | return init, update, get 294 | 295 | 296 | @adaption(quantity=MassMatrix) 297 | def mass_matrix(diagonal=True, burn_in=1000): 298 | """Adapt the mass matrix for HMC. 299 | 300 | Args: 301 | diagonal: Restrict the adapted matrix to be diagonal 302 | burn_in: Number of steps in which the matrix should be updated 303 | 304 | Returns: 305 | Returns an adaption strategy for the mass matrix. 
306 | 307 | """ 308 | 309 | def _update_matrix(args): 310 | iteration, ssq, _, _ = args 311 | if diagonal: 312 | inv = ssq / iteration 313 | sqrt = jnp.sqrt(iteration / ssq) 314 | else: 315 | inv = ssq / iteration 316 | eigw, eigv = jnp.linalg.eigh(ssq / iteration) 317 | # Todo: More effective computation 318 | sqrt = jnp.matmul(jnp.transpose(eigv), jnp.matmul(jnp.diag(jnp.sqrt(eigw)), eigv)) 319 | return inv, sqrt 320 | 321 | def init(sample: Array, init_cov: Array): 322 | iteration = 0 323 | mean = jnp.zeros_like(sample) 324 | if diagonal: 325 | ssq = jnp.zeros_like(sample) 326 | else: 327 | ssq = jnp.zeros((sample.size, sample.size)) 328 | 329 | if init_cov is None: 330 | init_cov = jnp.ones_like(sample) 331 | 332 | if diagonal: 333 | m_inv = init_cov 334 | m_sqrt = 1 / jnp.sqrt(init_cov) 335 | else: 336 | m_inv = jnp.diag(init_cov) 337 | m_sqrt = jnp.diag(1 / jnp.sqrt(init_cov)) 338 | 339 | return iteration, mean, ssq, m_inv, m_sqrt 340 | 341 | def update(state: Tuple[Array, Array, Array, Array, Array], 342 | sample: Array, 343 | *args: Any, 344 | **kwargs: Any): 345 | del args, kwargs 346 | iteration, mean, ssq, m_inv, m_sqrt = state 347 | 348 | iteration += 1 349 | new_mean = (iteration - 1) / iteration * mean + 1 / iteration * sample 350 | if diagonal: 351 | ssq += jnp.multiply(sample - mean, sample - new_mean) 352 | else: 353 | ssq += jnp.outer(sample - mean, sample - new_mean) 354 | 355 | # Only update once 356 | new_m_inv, new_m_sqrt = lax.cond( 357 | iteration == burn_in, 358 | _update_matrix, 359 | lambda arg: (arg[2], arg[3]), 360 | (iteration, ssq, m_inv, m_sqrt)) 361 | 362 | return iteration, new_mean, ssq, new_m_inv, new_m_sqrt 363 | 364 | def get(state: Tuple[Array, Array, Array, Array, Array]): 365 | _, _, _, m_inv, m_sqrt = state 366 | 367 | return m_inv, m_sqrt 368 | 369 | return init, update, get 370 | 371 | @adaption(quantity=NoiseModel) 372 | def fisher_information(minibatch_potential: Callable = None, 373 | diagonal = True 374 | ) -> AdaptionStrategy: 375 | """Adapts empirical fisher information. 376 | 377 | Use the empirical fisher information as a noise model for SGHMC. The empirical 378 | fisher information is approximated according to [1]. 379 | 380 | Returns: 381 | Returns noise model approximation strategy. 382 | 383 | [1] https://arxiv.org/abs/1206.6380 384 | """ 385 | assert minibatch_potential, "Fisher information requires potential function." 
386 | 387 | def init(*args): 388 | del args 389 | 390 | def update(*args, 391 | **kwargs): 392 | del args, kwargs 393 | 394 | def get(state, 395 | sample: Array, 396 | sample_grad: Array, 397 | friction: Array, 398 | *args: Any, 399 | mini_batch: MiniBatch, 400 | flat_potential, 401 | step_size: Any = jnp.array(1.0), 402 | model_state: PyTree = None, 403 | **kwargs: Any): 404 | del state, args, kwargs 405 | 406 | _, mb_information = mini_batch 407 | N = mb_information.observation_count 408 | n = mb_information.batch_size 409 | 410 | # Unscale the gradient to get mean 411 | sample_grad /= N 412 | 413 | def potential_at_obs(smp, obs_idx): 414 | _, (likelihoods, _) = flat_potential( 415 | smp, mini_batch, likelihoods=True, state=model_state) 416 | return likelihoods[obs_idx] 417 | 418 | grad_diff_idx = grad(potential_at_obs, argnums=0) 419 | 420 | @functools.partial(vmap, out_axes=0) 421 | def sqd(idx): 422 | if diagonal: 423 | return jnp.square(grad_diff_idx(sample, idx) - sample_grad) 424 | else: 425 | diff = grad_diff_idx(sample, idx) - sample_grad 426 | return jnp.outer(diff, diff) 427 | 428 | ssq = jnp.sum(sqd(jnp.arange(mb_information.batch_size)), axis=0) 429 | 430 | if diagonal: 431 | v = 1 / (n - 1) * ssq 432 | b = 0.5 * step_size * v 433 | 434 | # Correct for negative eigenvalues 435 | correction = friction - b 436 | smallest_positive = jnp.min(jnp.where(correction <= 0, jnp.inf, correction)) 437 | positive_correction = jnp.where(correction <= 0, smallest_positive, correction) 438 | 439 | # Apply the corrections to b 440 | b_corrected = friction - positive_correction 441 | 442 | noise_scale = jnp.sqrt(positive_correction) 443 | scale = jnp.sqrt(b_corrected) 444 | else: 445 | v = 1 / (n - 1) * ssq 446 | b = 0.5 * step_size * v 447 | 448 | # Using svd also works for positive semi-definite matrices 449 | u_cb, s_cb, _ = jnp.linalg.svd(jnp.diag(friction) - b) 450 | u, s, _ = jnp.linalg.svd(b) 451 | 452 | noise_scale = jnp.dot(u_cb, jnp.sqrt(s_cb)) 453 | scale = jnp.dot(u, jnp.sqrt(s)) 454 | 455 | return noise_scale, scale 456 | 457 | return init, update, get 458 | -------------------------------------------------------------------------------- /examples/sgld_rms.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | formats: examples///ipynb,examples///md:myst,docs//usage//ipynb 4 | main_language: python 5 | text_representation: 6 | extension: .md 7 | format_name: myst 8 | format_version: 0.13 9 | jupytext_version: 1.14.4 10 | kernelspec: 11 | display_name: Python 3 12 | name: python3 13 | --- 14 | 15 | ```{raw-cell} 16 | 17 | --- 18 | Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 19 | 20 | Licensed under the Apache License, Version 2.0 (the "License"); 21 | you may not use this file except in compliance with the License. 22 | You may obtain a copy of the License at 23 | 24 | http://www.apache.org/licenses/LICENSE-2.0 25 | 26 | Unless required by applicable law or agreed to in writing, software 27 | distributed under the License is distributed on an "AS IS" BASIS, 28 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 29 | See the License for the specific language governing permissions and 30 | limitations under the License. 
31 | ---
32 | ```
33 | 
34 | ```{code-cell} ipython3
35 | :tags: [hide-cell]
36 | 
37 | import warnings
38 | warnings.filterwarnings('ignore')
39 | 
40 | import h5py
41 | import matplotlib.pyplot as plt
42 | import numpy as onp
43 | import jax.numpy as jnp
44 | from jax import random
45 | from jax.scipy.stats import norm
46 | 
47 | from jax_sgmc import data, potential, scheduler, adaption, integrator, solver, io
48 | from jax_sgmc.data.numpy_loader import NumpyDataLoader
49 | from jax_sgmc.data.hdf5_loader import HDF5Loader
50 | ```
51 | 
52 | # Setup Custom Solver
53 | 
54 | This example shows how to customize a solver by combining the individual
55 | modules of **JaxSGMC**.
56 | It covers all necessary steps to build an *SGLD* solver with *RMSprop* adaption
57 | and applies it to the same problem as described in {doc}`/quickstart/`.
58 | 
59 | ## Overview
60 | 
61 | Schematically, a solver in **JaxSGMC** has the structure:
62 | 
63 | ![Structure of JaxSGMC](https://raw.githubusercontent.com/tummfm/jax-sgmc/main/jax-sgmc-structure.svg)
64 | 
65 | The SGLD solver with RMSprop adaption will make use of all modules.
66 | It is set up in these steps:
67 | 
68 | - **[Load Reference Data](#load-reference-data)**
69 | - **[Transform Log-likelihood to Potential](#transform-log-likelihood-to-potential)**
70 | - **[RMSprop Adaption](#rmsprop-adaption)**
71 | - **[Integrator and Solver](#integrator-and-solver)**
72 | - **[Scheduler](#scheduler)**
73 | - **[Save Samples](#save-samples-in-numpy-arrays)**
74 | - **[Run Solver](#run-solver)**
75 | 
76 | ## Load Reference Data
77 | 
78 | The reference data is passed to the solver via two components, the Data Loader
79 | and the Host Callback Wrapper.
80 | 
81 | The Data Loader assembles the batches requested by the host callback wrappers.
82 | It loads the data from a source (HDF-File, numpy-array, tensorflow dataset)
83 | and selects the observations in each batch according to a specific method
84 | (ordered access, shuffling, ...).
85 | 
86 | The Host Callback Wrapper requests new batches from the Data Loader and loads
87 | them into jit-compiled programs via JAX's Host Callback module.
88 | To balance the memory usage and the delay due to loading the data, each device
89 | call returns multiple batches.
90 | 
91 | ```{code-cell} ipython3
92 | :tags: [hide-cell]
93 | 
94 | N = 4
95 | samples = 1000 # Total samples
96 | 
97 | key = random.PRNGKey(0)
98 | split1, split2, split3 = random.split(key, 3)
99 | 
100 | # Correct solution
101 | sigma = 0.5
102 | w = random.uniform(split3, minval=-1, maxval=1, shape=(N, 1))
103 | 
104 | # Data generation
105 | noise = sigma * random.normal(split2, shape=(samples, 1))
106 | x = random.uniform(split1, minval=-10, maxval=10, shape=(samples, N))
107 | x = jnp.stack([x[:, 0] + x[:, 1], x[:, 1], 0.1 * x[:, 2] - 0.5 * x[:, 3],
108 |                x[:, 3]]).transpose()
109 | y = jnp.matmul(x, w) + noise
110 | ```
111 | 
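The generated reference data thus consist of 1000 observations, each with a
four-dimensional feature vector and a single noisy target value. A quick shape
check (plain JAX so far, nothing from **JaxSGMC** is involved yet):

```{code-cell} ipython3
# 1000 observations with N = 4 features and one target value each
print(f"x: {x.shape}, y: {y.shape}")
```
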
112 | The NumpyDataLoader assembles batches randomly by drawing from the complete
113 | dataset with or without replacement (shuffling).
114 | It also provides the possibility to start the batching from a defined state,
115 | controlled via the seed.
116 | 
117 | These settings can be chosen differently for every chain and are thus not fixed
118 | during the initialization.
119 | Instead, they have to be passed during the
120 | [initialization of the chains](#integrator-and-solver).
121 | 
122 | In this example, the batches are shuffled, i.e. every sample is used once
123 | before an already drawn sample is used again, and the chains start from a
124 | defined state.
125 | 
126 | ```{code-cell} ipython3
127 | # The construction of the data loader can differ between loader types. For the
128 | # numpy data loader, the numpy arrays are passed as keyword arguments and are
129 | # later returned as a dictionary with the corresponding keys.
130 | data_loader = NumpyDataLoader(x=x, y=y)
131 | 
132 | # The cache size corresponds to the number of batches per cache. The state
133 | # initialized via the init function is necessary to identify which data chain
134 | # requests new batches of data.
135 | data_fn = data.random_reference_data(data_loader,
136 |                                      mb_size=N,
137 |                                      cached_batches_count=100)
138 | 
139 | data_loader_kwargs = {
140 |     "seed": 0,
141 |     "shuffle": True,
142 |     "in_epochs": False
143 | }
144 | ```
145 | 
146 | ## Transform Log-likelihood to Potential
147 | 
148 | The model is connected to the solver via the (log-)prior and (log-)likelihood
149 | function. The model for our problem is:
150 | 
151 | ```{code-cell} ipython3
152 | def model(sample, observations):
153 |     weights = sample["w"]
154 |     predictors = observations["x"]
155 |     return jnp.dot(predictors, weights)
156 | ```
157 | 
158 | **JaxSGMC** supports samples in the form of pytrees, so no flattening of e.g.
159 | neural network parameters is necessary. In our case, we can separate the standard
160 | deviation, which is only part of the likelihood, from the weights by using a
161 | dictionary:
162 | 
163 | ```{code-cell} ipython3
164 | sample = {"log_sigma": jnp.array(1.0), "w": jnp.zeros((N, 1))}
165 | 
166 | def likelihood(sample, observations):
167 |     sigma = jnp.exp(sample["log_sigma"])
168 |     y = observations["y"]
169 |     y_pred = model(sample, observations)
170 |     return norm.logpdf(y - y_pred, scale=sigma)
171 | 
172 | def prior(sample):
173 |     return 1 / jnp.exp(sample["log_sigma"])
174 | 
175 | ```
176 | 
177 | The prior and likelihood are not passed to the solver directly, but are
178 | first transformed into a (stochastic) potential.
179 | This allows us to formulate the model, and thus the likelihood, with only a single
180 | observation in mind and to let **JaxSGMC** take care of evaluating it for a batch
181 | of observations. As the model is not computationally demanding, we let
182 | **JaxSGMC** vectorize the evaluation of the likelihood:
183 | 
184 | ```{code-cell} ipython3
185 | potential_fn = potential.minibatch_potential(prior=prior,
186 |                                              likelihood=likelihood,
187 |                                              strategy="vmap")
188 | ```
189 | 
190 | ## RMSprop Adaption
191 | 
192 | The adaption module simplifies the implementation of an adaption strategy
193 | by raveling / unraveling the latent variables pytree.
194 | 
195 | The RMSprop adaption is characterized by two parameters, which can be set
196 | dynamically for each chain.
197 | As for the data loader arguments, non-default RMSprop parameters must be passed
198 | during the [initialization of the chains](#integrator-and-solver).
199 | 
200 | ```{code-cell} ipython3
201 | rms_prop_adaption = adaption.rms_prop()
202 | 
203 | adaption_kwargs = {
204 |     "lmbd": 1e-6,
205 |     "alpha": 0.99
206 | }
207 | ```
208 | 
209 | ## Integrator and Solver
210 | 
211 | The integrator proposes new samples based on a specific process; these proposals
212 | are then processed by the solver.
213 | For example, the solver might reject a proposal by a Metropolis-Hastings
214 | acceptance step (AMAGOLD, SGGMC) or swap it with another proposal by a parallel
215 | tempering chain swap (reSGLD).
216 | 
217 | In this case, a Langevin Diffusion process proposes a new sample, which is
218 | accepted unconditionally by the solver.
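
For orientation only (this is a sketch of the idea, not the exact implementation),
combining the Langevin diffusion integrator with the RMSprop adaption corresponds
to the preconditioned SGLD update of [Li et al.](https://arxiv.org/abs/1512.07666),
neglecting the correction term $\Gamma$ of the original algorithm:

$$
\theta_{k+1} = \theta_k
  - \frac{\varepsilon_k}{2}\, G(\theta_k)\, \nabla \tilde{U}(\theta_k)
  + \sqrt{\varepsilon_k\, G(\theta_k)}\, \eta_k,
  \qquad \eta_k \sim \mathcal{N}(0, I),
$$

where $\varepsilon_k$ is the step size provided by the scheduler,
$\nabla \tilde{U}$ the gradient of the stochastic potential and
$G(\theta_k) = \left(\lambda + \sqrt{v_k}\right)^{-1}$ the diagonal preconditioner
built from the running average $v_k$ of the squared stochastic gradients.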
219 | 
220 | After this step, our process is fully defined.
221 | Therefore, we can now initialize the starting states of each chain with the
222 | dynamic settings for the data loader and adaption.
223 | 
224 | ```{code-cell} ipython3
225 | langevin_diffusion = integrator.langevin_diffusion(potential_fn=potential_fn,
226 |                                                    batch_fn=data_fn,
227 |                                                    adaption=rms_prop_adaption)
228 | 
229 | # Returns a triplet of init_fn, update_fn and get_fn
230 | rms_prop_solver = solver.sgmc(langevin_diffusion)
231 | 
232 | # Initialize the solver by providing initial values for the latent variables.
233 | # We provide extra arguments for the data loader and the adaption method.
234 | init_sample = {"log_sigma": jnp.array(0.0), "w": jnp.zeros(N)}
235 | init_state = rms_prop_solver[0](init_sample,
236 |                                 adaption_kwargs=adaption_kwargs,
237 |                                 batch_kwargs=data_loader_kwargs)
238 | ```
239 | 
240 | ## Scheduler
241 | 
242 | Next, we set up a schedule which updates process parameters such as the
243 | temperature and the step size independently of the solver state.
244 | Moreover, it is necessary to determine which samples should be saved or discarded.
245 | 
246 | SGLD only depends on the step size, which is chosen to follow a polynomial
247 | schedule.
248 | However, as only a few, independent samples should be saved, we also set up a
249 | burn-in schedule, which rejects the first 2000 samples, and a thinning schedule,
250 | which randomly selects 1000 samples not subject to burn-in.
251 | 
252 | ```{code-cell} ipython3
253 | step_size_schedule = scheduler.polynomial_step_size_first_last(first=0.05,
254 |                                                                last=0.001,
255 |                                                                gamma=0.33)
256 | burn_in_schedule = scheduler.initial_burn_in(2000)
257 | thinning_schedule = scheduler.random_thinning(step_size_schedule=step_size_schedule,
258 |                                               burn_in_schedule=burn_in_schedule,
259 |                                               selections=1000)
260 | 
261 | # Bundles all specific schedules
262 | schedule = scheduler.init_scheduler(step_size=step_size_schedule,
263 |                                     burn_in=burn_in_schedule,
264 |                                     thinning=thinning_schedule)
265 | ```
266 | 
267 | ## Save Samples in numpy Arrays
268 | 
269 | By default, **JaxSGMC** saves accepted samples in the device memory.
270 | However, for some models the required memory rapidly exceeds the available
271 | memory. Therefore, **JaxSGMC** supports saving the samples on the host in a
272 | similar manner as it loads reference data from the host.
273 | 
274 | Hence, the saving step also consists of setting up a Data Collector, which takes
275 | care of saving the data in different formats, and a general Host Callback Wrapper,
276 | which transfers the data out of jit-compiled computations.
277 | 
278 | In this example, the data is simply passed to (real) numpy arrays in the host
279 | memory.
280 | 
281 | ```{code-cell} ipython3
282 | data_collector = io.MemoryCollector()
283 | save_fn = io.save(data_collector=data_collector)
284 | ```
285 | 
295 | ## Run Solver
296 | 
297 | Finally, all parts of the solver are set up and can be combined into a runnable
298 | process.
299 | The mcmc function updates the scheduler and integrator in the correct order and
300 | passes the results to the saving module.
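
The general call pattern is sketched below. This block is not executed here;
`init_state_2` merely stands for a hypothetical second initial state created in
the same way as `init_state` above:

```python
# Combine solver, scheduler and saving into one runnable process ...
mcmc = solver.mcmc(solver=rms_prop_solver, scheduler=schedule, saving=save_fn)

# ... and run a single chain ...
results = mcmc(init_state, iterations=10000)[0]

# ... or several chains at once, which yields one result per chain.
results_1, results_2 = mcmc(init_state, init_state_2, iterations=10000)
```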
301 | 302 | The mcmc function can be called with multiple ``init_states`` as 303 | positional arguments to run multiple chains and returns a list of results, one 304 | for each chain. 305 | 306 | ```{code-cell} ipython3 307 | mcmc = solver.mcmc(solver=rms_prop_solver, 308 | scheduler=schedule, 309 | saving=save_fn) 310 | 311 | # Take the result of the first chain 312 | results = mcmc(init_state, iterations=10000)[0] 313 | 314 | 315 | print(f"Collected {results['sample_count']} samples") 316 | ``` 317 | 318 | ## Plot Results 319 | 320 | ```{code-cell} ipython3 321 | :tags: [hide-input] 322 | 323 | plt.figure() 324 | plt.title("Sigma") 325 | 326 | plt.plot(onp.exp(results["samples"]["variables"]["log_sigma"]), label="RMSprop") 327 | 328 | w_rms = results["samples"]["variables"]["w"] 329 | 330 | # w1 vs w2 331 | w1d = onp.linspace(0.00, 0.20, 100) 332 | w2d = onp.linspace(-0.70, -0.30, 100) 333 | W1d, W2d = onp.meshgrid(w1d, w2d) 334 | p12d = onp.vstack([W1d.ravel(), W2d.ravel()]) 335 | 336 | plt.figure() 337 | plt.title("w_1 vs w_2 (rms)") 338 | 339 | plt.xlim([0.07, 0.12]) 340 | plt.ylim([-0.525, -0.450]) 341 | plt.plot(w_rms[:, 0], w_rms[:, 1], 'o', alpha=0.5, markersize=0.5, zorder=-1) 342 | 343 | # w3 vs w4 344 | w3d = onp.linspace(-0.3, -0.05, 100) 345 | w4d = onp.linspace(-0.75, -0.575, 100) 346 | W3d, W4d = onp.meshgrid(w3d, w4d) 347 | p34d = onp.vstack([W3d.ravel(), W4d.ravel()]) 348 | 349 | plt.figure() 350 | plt.title("w_3 vs w_4 (rms)") 351 | plt.plot(w_rms[:, 2], w_rms[:, 3], 'o', alpha=0.5, markersize=0.5, zorder=-1) 352 | ``` 353 | 354 | ## Large Models: Save Data to HDF5 355 | 356 | ```{code-cell} ipython3 357 | # Open a HDF5 file to store data in 358 | with h5py.File("sgld_rms.hdf5", "w") as file: 359 | 360 | data_collector = io.HDF5Collector(file) 361 | save_fn = io.save(data_collector=data_collector) 362 | 363 | mcmc = solver.mcmc(solver=rms_prop_solver, 364 | scheduler=schedule, 365 | saving=save_fn) 366 | 367 | # The solver has to be reinitialized, as the data loader has to be reinitialized 368 | init_state = rms_prop_solver[0](init_sample, 369 | adaption_kwargs=adaption_kwargs, 370 | batch_kwargs=data_loader_kwargs) 371 | results = mcmc(init_state, iterations=10000)[0] 372 | 373 | print(f"Collected {results['sample_count']} samples") 374 | ``` 375 | 376 | ```{code-cell} ipython3 377 | # Sum up and count all values 378 | def map_fn(batch, mask, count): 379 | return jnp.sum(batch["w"].T * mask, axis=1), count + jnp.sum(mask) 380 | 381 | # Load only the samples from the file 382 | with h5py.File("sgld_rms.hdf5", "r") as file: 383 | postprocess_loader = HDF5Loader(file, subdir="/chain~0/variables", sample=init_sample) 384 | 385 | full_data_mapper, _ = data.full_data_mapper(postprocess_loader, 128, 128) 386 | w_sums, count = full_data_mapper(map_fn, 0, masking=True) 387 | 388 | # Sum up the sums from the individual batches 389 | w_means = jnp.sum(w_sums, axis=0) / count 390 | 391 | print(f"Collected {count} samples with means:") 392 | for idx, (w, w_old) in enumerate(zip(w_means, onp.mean(w_rms, axis=0))): 393 | print(f" w_{idx}: new = {w}, old = {w_old})") 394 | ``` 395 | 396 | ```{code-cell} ipython3 397 | :tags: [hide-cell] 398 | 399 | import os 400 | os.remove("sgld_rms.hdf5") 401 | ``` 402 | -------------------------------------------------------------------------------- /tests/test_potential.py: -------------------------------------------------------------------------------- 1 | """Test the evaluation of the potential.""" 2 | 3 | from functools import partial 4 | import 
itertools 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | from jax import random, jit, lax 10 | 11 | import pytest 12 | 13 | # from jax_sgmc.data import mini_batch 14 | from jax_sgmc.potential import minibatch_potential, full_potential 15 | from jax_sgmc.data import MiniBatchInformation, full_reference_data 16 | from jax_sgmc.data.numpy_loader import NumpyDataLoader 17 | from jax_sgmc.util import testing 18 | 19 | # Todo: Test the potential evaluation function on arbitrary pytrees. 20 | 21 | class TestPotential(): 22 | # Helper functions 23 | 24 | @pytest.fixture 25 | def linear_potential(self): 26 | 27 | def likelihood(sample, reference_data): 28 | return jnp.sum(sample * reference_data) 29 | 30 | def prior(sample): 31 | return jnp.sum(sample) 32 | 33 | return prior, likelihood 34 | 35 | @pytest.fixture 36 | def potential(self): 37 | """Define likelihood with pytree sample and pytree reference data.""" 38 | 39 | def likelihood(sample, reference_data): 40 | scale = sample["scale"] 41 | bases = sample["base"] 42 | powers = reference_data["power"] 43 | ref_scale = reference_data["scale"] 44 | return scale * ref_scale * jnp.sum(jnp.power(bases, powers)) 45 | 46 | def prior(sample): 47 | return jnp.exp(-sample["scale"]) 48 | 49 | return prior, likelihood 50 | 51 | @pytest.fixture 52 | def stateful_potential(self): 53 | def likelihood(state, sample, reference_data): 54 | new_state = sample 55 | scale = sample["scale"] * state["scale"] 56 | bases = sample["base"] + state["base"] 57 | powers = reference_data["power"] 58 | ref_scale = reference_data["scale"] 59 | return scale * ref_scale * jnp.sum(jnp.power(bases, powers)), new_state 60 | 61 | def prior(sample): 62 | return jnp.exp(-sample["scale"]) 63 | return prior, likelihood 64 | 65 | @pytest.mark.parametrize("obs, dim", itertools.product([7, 11], [3, 5])) 66 | def test_linear_potential(self, linear_potential, obs, dim): 67 | prior, likelihood = linear_potential 68 | # Setup potential 69 | 70 | scan_pot = minibatch_potential(prior, likelihood, strategy="map") 71 | vmap_pot = minibatch_potential(prior, likelihood, strategy="vmap") 72 | # pmap_pot = minibatch_potential(prior, likelihood, strategy="pmap") 73 | 74 | # Setup reference data 75 | key = random.PRNGKey(0) 76 | 77 | split1, split2 = random.split(key, 2) 78 | observations = jnp.tile(jnp.arange(4), (dim, 1)) 79 | reference_data = observations, MiniBatchInformation(observation_count=obs, 80 | batch_size=dim, 81 | mask=jnp.ones(dim)) 82 | sample = jnp.ones(4) 83 | 84 | true_result = -jnp.sum(jnp.arange(4)) * obs -4 85 | 86 | scan_result, _ = scan_pot(sample, reference_data) 87 | vmap_result, _ = vmap_pot(sample, reference_data) 88 | # pmap_result, _ = pmap_pot(sample, reference_data) 89 | 90 | testing.assert_close(scan_result, true_result) 91 | testing.assert_close(vmap_result, true_result) 92 | # test_util.check_close(pmap_result, true_result) 93 | 94 | @pytest.mark.parametrize("obs, dim", itertools.product([7, 11], [3, 5])) 95 | def test_stochastic_potential_zero(self, potential, obs, dim): 96 | _, likelihood = potential 97 | prior = lambda _: 0.0 98 | # Setup potential 99 | 100 | scan_pot = minibatch_potential(prior, likelihood, strategy="map") 101 | vmap_pot = minibatch_potential(prior, likelihood, strategy="vmap") 102 | # pmap_pot = minibatch_potential(prior, likelihood, strategy="pmap") 103 | 104 | # Setup reference data 105 | key = random.PRNGKey(0) 106 | 107 | split1, split2 = random.split(key, 2) 108 | observations = {"scale": random.exponential(split1, shape=(obs,)), 109 | 
"power": random.exponential(split2, shape=(obs, dim))} 110 | reference_data = observations, MiniBatchInformation(observation_count=obs, 111 | batch_size=obs, 112 | mask=jnp.ones(dim)) 113 | sample = {"scale": 0.5, "base": jnp.zeros(dim)} 114 | 115 | zero_array = jnp.array(-0.0) 116 | scan_result, _ = scan_pot(sample, reference_data) 117 | vmap_result, _ = vmap_pot(sample, reference_data) 118 | # pmap_result, _ = pmap_pot(sample, reference_data) 119 | 120 | testing.assert_equal(scan_result, zero_array) 121 | testing.assert_equal(vmap_result, zero_array) 122 | # test_util.check_close(pmap_result, zero_array) 123 | 124 | @pytest.mark.parametrize("obs, dim", itertools.product([7, 11], [3, 5])) 125 | def test_stochastic_potential_jit(self, potential, obs, dim): 126 | _, likelihood = potential 127 | prior = lambda _: 0.0 128 | # Setup potential 129 | 130 | scan_pot = jit(minibatch_potential(prior, likelihood, strategy="map")) 131 | vmap_pot = jit(minibatch_potential(prior, likelihood, strategy="vmap")) 132 | # pmap_pot = jit(minibatch_potential(prior, likelihood, strategy="pmap")) 133 | 134 | # Setup reference data 135 | key = random.PRNGKey(0) 136 | 137 | split1, split2 = random.split(key, 2) 138 | observations = {"scale": random.exponential(split1, shape=(obs,)), 139 | "power": random.exponential(split2, shape=(obs, dim))} 140 | reference_data = observations, MiniBatchInformation( 141 | observation_count=obs, 142 | batch_size=obs, 143 | mask=jnp.ones(dim)) 144 | sample = {"scale": 0.5, "base": jnp.zeros(dim)} 145 | 146 | zero_array = jnp.array(-0.0) 147 | scan_result, _ = scan_pot(sample, reference_data) 148 | vmap_result, _ = vmap_pot(sample, reference_data) 149 | # pmap_result, _ = pmap_pot(sample, reference_data) 150 | 151 | testing.assert_equal(scan_result, zero_array) 152 | testing.assert_equal(vmap_result, zero_array) 153 | # test_util.check_close(pmap_result, zero_array) 154 | 155 | @pytest.mark.parametrize("obs, dim", itertools.product([7, 11], [3, 5])) 156 | def test_stochastic_potential_equal(self, potential, obs, dim): 157 | prior, likelihood = potential 158 | # Setup potential 159 | 160 | scan_pot = jit(minibatch_potential(prior, likelihood, strategy="map")) 161 | vmap_pot = jit(minibatch_potential(prior, likelihood, strategy="vmap")) 162 | # pmap_pot = jit(minibatch_potential(prior, likelihood, strategy="pmap")) 163 | 164 | # Setup reference data 165 | key = random.PRNGKey(0) 166 | 167 | split1, split2, split3 = random.split(key, 3) 168 | observations = {"scale": random.exponential(split1, shape=(obs,)), 169 | "power": random.exponential(split2, shape=(obs, dim))} 170 | reference_data = observations, MiniBatchInformation( 171 | observation_count=obs, 172 | batch_size=obs, 173 | mask=jnp.ones(dim)) 174 | sample = {"scale": 0.5, "base": random.uniform(split3, (dim, ))} 175 | 176 | scan_result, _ = scan_pot(sample, reference_data) 177 | vmap_result, _ = vmap_pot(sample, reference_data) 178 | # pmap_result, _ = pmap_pot(sample, reference_data) 179 | 180 | testing.assert_close(scan_result, vmap_result) 181 | # test_util.check_close(scan_result, pmap_result) 182 | 183 | @pytest.mark.parametrize("obs, dim", itertools.product([7, 11], [3, 5])) 184 | def test_stochastic_potential_gradient_equal(self, potential, obs, dim): 185 | prior, likelihood = potential 186 | # Setup potential 187 | 188 | scan_grad = jit( 189 | jax.grad(minibatch_potential(prior, likelihood, strategy="map"), 190 | has_aux=True, 191 | argnums=0)) 192 | vmap_grad = jit( 193 | jax.grad(minibatch_potential(prior, 
likelihood, strategy="vmap"), 194 | has_aux=True, 195 | argnums=0)) 196 | # pmap_grad = jit( 197 | # jax.grad(minibatch_potential(prior, likelihood, strategy="pmap"), 198 | # has_aux=True, 199 | # argnums=0)) 200 | 201 | # Setup reference data 202 | key = random.PRNGKey(0) 203 | 204 | split1, split2, split3 = random.split(key, 3) 205 | observations = {"scale": random.exponential(split1, shape=(obs,)), 206 | "power": random.exponential(split2, shape=(obs, dim))} 207 | reference_data = observations, MiniBatchInformation( 208 | observation_count=obs, 209 | batch_size=obs, 210 | mask=jnp.ones(dim)) 211 | sample = {"scale": 0.5, "base": random.uniform(split3, (dim,))} 212 | 213 | scan_result, _ = scan_grad(sample, reference_data) 214 | vmap_result, _ = vmap_grad(sample, reference_data) 215 | # pmap_result, _ = pmap_grad(sample, reference_data) 216 | 217 | testing.assert_close(scan_result, vmap_result) 218 | # test_util.check_close(scan_result, pmap_result) 219 | 220 | @pytest.mark.parametrize("obs, dim", itertools.product([7, 11], [3, 5])) 221 | def test_stochastic_potential_gradient_shape(self, potential, obs, dim): 222 | _, likelihood = potential 223 | prior = lambda _: 0.0 224 | # Setup potential 225 | 226 | scan_grad = jit( 227 | jax.grad(minibatch_potential(prior, likelihood, strategy="map"), 228 | has_aux=True, 229 | argnums=0)) 230 | vmap_grad = jit( 231 | jax.grad(minibatch_potential(prior, likelihood, strategy="vmap"), 232 | has_aux=True, 233 | argnums=0)) 234 | # pmap_grad = jit( 235 | # jax.grad(minibatch_potential(prior, likelihood, strategy="pmap"), 236 | # has_aux=True, 237 | # argnums=0)) 238 | 239 | # Setup reference data 240 | key = random.PRNGKey(0) 241 | 242 | # Set scale to zero to get zero gradient 243 | split1, split2 = random.split(key, 2) 244 | observations = {"scale": jnp.zeros(obs), 245 | "power": random.exponential(split1, shape=(obs, dim))} 246 | reference_data = observations, MiniBatchInformation( 247 | observation_count=obs, 248 | batch_size=obs, 249 | mask=jnp.ones(dim)) 250 | sample = {"scale": 0.5, "base": random.uniform(split2, (dim,))} 251 | 252 | zero_gradient = jax.tree_map(jnp.zeros_like, sample) 253 | scan_result, _ = scan_grad(sample, reference_data) 254 | vmap_result, _ = vmap_grad(sample, reference_data) 255 | # pmap_result, _ = pmap_grad(sample, reference_data) 256 | 257 | print(scan_result) 258 | print(vmap_result) 259 | # print(pmap_result) 260 | 261 | testing.assert_equal(scan_result, zero_gradient) 262 | testing.assert_equal(vmap_result, zero_gradient) 263 | # test_util.check_close(pmap_result, zero_gradient) 264 | 265 | @pytest.mark.parametrize("obs, dim", itertools.product([7, 11], [3, 5])) 266 | def test_stateful_stochastic_potential_zero(self, stateful_potential, obs, dim): 267 | _, likelihood = stateful_potential 268 | prior = lambda _: 0.0 269 | # Setup potential 270 | 271 | scan_pot = minibatch_potential(prior, likelihood, strategy="map", has_state=True) 272 | vmap_pot = minibatch_potential(prior, likelihood, strategy="vmap", has_state=True) 273 | # pmap_pot = minibatch_potential(prior, likelihood, strategy="pmap", has_state=True) 274 | 275 | # Setup reference data 276 | key = random.PRNGKey(0) 277 | 278 | split1, split2 = random.split(key, 2) 279 | observations = {"scale": random.exponential(split1, shape=(obs,)), 280 | "power": random.exponential(split2, shape=(obs, dim))} 281 | reference_data = observations, MiniBatchInformation(observation_count=obs, 282 | batch_size=obs, 283 | mask=jnp.ones(dim)) 284 | sample = {"scale": 
jnp.array([0.5]), "base": jnp.ones(dim)} 285 | init_state = {"scale": jnp.array([0.0]), "base": jnp.zeros(dim)} 286 | 287 | _, new_state_map = scan_pot(sample, reference_data, state=init_state) 288 | _, new_state_vmap = vmap_pot(sample, reference_data, state=init_state) 289 | # _, new_state_pmap = pmap_pot(sample, reference_data, state=init_state) 290 | 291 | print(init_state) 292 | print(new_state_map) 293 | 294 | testing.assert_close(new_state_map, sample) 295 | testing.assert_close(new_state_vmap, sample) 296 | # test_util.check_close(new_state_pmap, sample) 297 | 298 | @pytest.mark.parametrize("obs, dim, mbsize", itertools.product([7, 11], [3, 5], [2, 3])) 299 | def test_full_potential(self, potential, obs, dim, mbsize): 300 | prior, likelihood = potential 301 | # Setup potential 302 | 303 | scan_pot = minibatch_potential(prior, likelihood, strategy="map") 304 | 305 | # Setup reference data 306 | key = random.PRNGKey(0) 307 | 308 | split1, split2 = random.split(key, 2) 309 | observations = {"scale": random.exponential(split1, shape=(obs,)), 310 | "power": random.exponential(split2, shape=(obs, dim))} 311 | reference_data = observations, MiniBatchInformation(observation_count=obs, 312 | batch_size=obs, 313 | mask=jnp.ones(mbsize)) 314 | sample = {"scale": jnp.array([0.5]), "base": jnp.ones(dim)} 315 | init_state = {"scale": jnp.array([0.0]), "base": jnp.zeros(dim)} 316 | 317 | reference_sol, _ = scan_pot(sample, reference_data, state=init_state) 318 | 319 | # Initialize dataloader for full potential evaluation 320 | data_loader = NumpyDataLoader(**observations) 321 | full_data_map = full_reference_data(data_loader, cached_batches_count=2, mb_size=mbsize) 322 | 323 | map_data_state = full_data_map[0]() 324 | vmap_data_state = full_data_map[0]() 325 | 326 | map_pot = full_potential(prior, likelihood, strategy="map") 327 | vmap_pot = full_potential(prior, likelihood, strategy="vmap") 328 | 329 | map_sol, _ = map_pot(sample, map_data_state, full_data_map[1]) 330 | vmap_sol, _ = vmap_pot(sample, vmap_data_state, full_data_map[1]) 331 | 332 | testing.assert_close(reference_sol, map_sol) 333 | testing.assert_close(reference_sol, vmap_sol) 334 | 335 | @pytest.mark.parametrize("obs, dim, mbsize", itertools.product([7, 11], [3, 5], [2, 3])) 336 | def test_variance(self, potential, obs, dim, mbsize): 337 | prior, likelihood = potential 338 | # Setup potential 339 | 340 | scan_pot = minibatch_potential(prior, likelihood, strategy="map") 341 | 342 | # Setup reference data 343 | key = random.PRNGKey(0) 344 | 345 | split1, split2 = random.split(key, 2) 346 | observations = {"scale": random.exponential(split1, shape=(obs,)), 347 | "power": random.exponential(split2, shape=(obs, dim))} 348 | reference_data = observations, MiniBatchInformation(observation_count=obs, 349 | batch_size=obs, 350 | mask=jnp.ones(mbsize)) 351 | 352 | sample = {"scale": jnp.array([0.5]), "base": jnp.ones(dim)} 353 | pot_results = lambda obs: scan_pot( 354 | sample, 355 | (jax.tree_map(partial(jnp.expand_dims, axis=0), obs), 356 | MiniBatchInformation( 357 | observation_count=1, 358 | batch_size=1, 359 | mask=jnp.ones(1)))) 360 | 361 | likelihoods, _ = lax.map(pot_results, observations) 362 | true_variance = jnp.var(likelihoods) 363 | 364 | _, (lkls, _) = scan_pot(sample, reference_data, likelihoods=True) 365 | variance = jnp.var(lkls) 366 | 367 | testing.assert_close(variance, true_variance) -------------------------------------------------------------------------------- /jax_sgmc/data/numpy_loader.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Load numpy arrays in jit-compiled functions. 16 | 17 | The numpy data loader is easy to use if the whole dataset fits into RAM and is 18 | already present as numpy-arrays. 19 | 20 | """ 21 | from copy import deepcopy 22 | 23 | import math 24 | import itertools 25 | from typing import Tuple, Any, Dict, List 26 | 27 | import numpy as onp 28 | import jax.numpy as jnp 29 | import jax 30 | from jax import random 31 | 32 | from jax_sgmc.data.core import DeviceDataLoader, HostDataLoader, DataLoader 33 | from jax_sgmc.data.core import MiniBatchInformation 34 | from jax_sgmc.data.core import tree_index 35 | from jax_sgmc.util import Array 36 | 37 | PyTree = Any 38 | 39 | class NumpyBase(DataLoader): 40 | 41 | def __init__(self, on_device: bool = True, copy=True, **reference_data): 42 | super().__init__() 43 | 44 | observation_counts = [] 45 | self._reference_data = {} 46 | for name, array in reference_data.items(): 47 | observation_counts.append(len(array)) 48 | # Transform to jax arrays if on device 49 | if on_device: 50 | self._reference_data[name] = jnp.array(array, copy=copy) 51 | else: 52 | self._reference_data[name] = onp.array(array, copy=copy) 53 | 54 | if len(observation_counts) != 0: 55 | # Check same number of observations 56 | if onp.any(onp.array(observation_counts) != observation_counts[0]): 57 | raise ValueError("All reference_data arrays must have the same length " 58 | "in the first dimension.") 59 | 60 | self._observation_count = observation_counts[0] 61 | else: 62 | self._observation_count = 0 63 | 64 | @property 65 | def reference_data(self): 66 | """Returns the reference data as a dictionary.""" 67 | return self._reference_data 68 | 69 | @property 70 | def _format(self): 71 | """Returns shape and dtype of a single observation. """ 72 | mb_format = {} 73 | for name, array in self._reference_data.items(): 74 | # Get the format and dtype of the data 75 | mb_format[name] = jax.ShapeDtypeStruct( 76 | dtype=self._reference_data[name].dtype, 77 | shape=tuple(int(s) for s in array.shape[1:])) 78 | return mb_format 79 | 80 | @property 81 | def static_information(self): 82 | """Returns information about total samples count and batch size. """ 83 | information = { 84 | "observation_count": self._observation_count 85 | } 86 | return information 87 | 88 | 89 | class DeviceNumpyDataLoader(NumpyBase, DeviceDataLoader): 90 | """Load complete dataset into memory from multiple numpy arrays. 91 | 92 | This data loader supports checkpointing, starting chains from a well-defined 93 | state and true random access. 94 | 95 | The pipeline can be constructed directly from numpy arrays: 96 | 97 | .. 
doctest:: 98 | 99 | >>> import numpy as onp 100 | >>> from jax_sgmc.data.numpy_loader import DeviceNumpyDataLoader 101 | >>> 102 | >>> x, y = onp.arange(10), onp.zeros((10, 4, 3)) 103 | >>> 104 | >>> data_loader = DeviceNumpyDataLoader(name_for_x=x, name_for_y=y) 105 | >>> 106 | >>> zero_batch = data_loader.initializer_batch(4) 107 | >>> for key, value in zero_batch.items(): 108 | ... print(f"{key}: shape={value.shape}, dtype={value.dtype}") 109 | name_for_x: shape=(4,), dtype=int32 110 | name_for_y: shape=(4, 4, 3), dtype=float32 111 | 112 | Args: 113 | reference_data: Each kwarg-pair is an entry in the returned data-dict. 114 | copy: Whether to copy the reference data (default True) or only create a 115 | reference. 116 | 117 | """ 118 | 119 | def __init__(self, copy=True, **reference_data): 120 | super().__init__(on_device=True, copy=copy, **reference_data) 121 | 122 | def init_random_data(self, *args, **kwargs) -> PyTree: 123 | del args 124 | key = kwargs.get("key", random.PRNGKey(0)) 125 | return key 126 | 127 | # Todo: Provide shuffling and in_epoch_shuffling too 128 | def get_random_data(self, 129 | state, 130 | batch_size 131 | ) ->Tuple[PyTree, Tuple[PyTree, MiniBatchInformation]]: 132 | key, split = random.split(state) 133 | selection_indices = random.randint( 134 | split, shape=(batch_size,), minval=0, maxval=self._observation_count) 135 | 136 | selected_observations = tree_index(self._reference_data, selection_indices) 137 | info = MiniBatchInformation(observation_count=self._observation_count, 138 | batch_size=batch_size, 139 | mask=jnp.ones(batch_size, dtype=jnp.bool_)) 140 | 141 | return key, (selected_observations, info) 142 | 143 | def get_full_data(self) -> Dict: 144 | return self._reference_data 145 | 146 | 147 | class NumpyDataLoader(NumpyBase, HostDataLoader): 148 | """Load complete dataset into memory from multiple numpy arrays. 149 | 150 | This data loader supports checkpointing, starting chains from a well-defined 151 | state and true random access. 152 | 153 | The pipeline can be constructed directly from numpy arrays: 154 | 155 | .. doctest:: 156 | 157 | >>> import numpy as onp 158 | >>> from jax_sgmc.data.numpy_loader import NumpyDataLoader 159 | >>> 160 | >>> x, y = onp.arange(10), onp.zeros((10, 4, 3)) 161 | >>> 162 | >>> data_loader = NumpyDataLoader(name_for_x=x, name_for_y=y) 163 | >>> 164 | >>> zero_batch = data_loader.initializer_batch(4) 165 | >>> for key, value in zero_batch.items(): 166 | ... print(f"{key}: shape={value.shape}, dtype={value.dtype}") 167 | name_for_x: shape=(4,), dtype=int32 168 | name_for_y: shape=(4, 4, 3), dtype=float32 169 | 170 | Args: 171 | reference_data: Each kwarg-pair is an entry in the returned data-dict. 172 | copy: Whether to copy the reference data (default True) or only create a 173 | reference. 174 | 175 | """ 176 | 177 | def __init__(self, 178 | copy=True, 179 | **reference_data): 180 | super().__init__( 181 | on_device=False, 182 | copy=copy, 183 | **reference_data) 184 | self._chains: List = [] 185 | 186 | def save_state(self, chain_id: int) -> PyTree: 187 | """Returns all necessary information to restore the dataloader state. 188 | 189 | Args: 190 | chain_id: Each chain can be checkpointed independently. 191 | 192 | Returns: 193 | Returns necessary information to restore the state of the chain via 194 | :func:`load_state`. 195 | 196 | """ 197 | # Get the state of all random data generators. 
All other information will be 198 | # set by initializing the generator on the same way as before 199 | 200 | chain_data = self._chains[chain_id] 201 | if chain_data['type'] == 'random': 202 | data = {key: deepcopy(value) 203 | for key, value in chain_data.items() if key != 'rng'} 204 | return {'random': (chain_data['rng'].bit_generator.state, data)} 205 | elif chain_data['type'] == 'ordered': 206 | return {'ordered': chain_data['idx_offset']} 207 | else: 208 | raise ValueError(f"Chain type {chain_data['type']} is unknown.") 209 | 210 | def load_state(self, chain_id: int, data) -> None: 211 | """Restores dataloader state from previously computed checkpoint. 212 | 213 | Args: 214 | chain_id: The chain to restore the state. 215 | data: Data from :func:`save_state` to restore state of the chain. 216 | 217 | """ 218 | # Restore the state by setting the random number generators to the 219 | # checkpointed state 220 | type, value = data.popitem() 221 | if type == 'random': 222 | rng_state, chain_data = value 223 | self._chains[chain_id]['rng'].bit_generator.state = rng_state 224 | for key, value in chain_data.items(): 225 | self._chains[chain_id][key] = value 226 | elif type == 'ordered': 227 | self._chains[chain_id]['idx_offset'] = value 228 | else: 229 | raise ValueError(f"Chain type {type} is unknown.") 230 | 231 | def register_random_pipeline(self, 232 | cache_size: int = 1, 233 | mb_size: int = None, 234 | in_epochs: bool = False, 235 | shuffle: bool = False, 236 | **kwargs: Any) -> int: 237 | """Register a new chain which draws samples randomly. 238 | 239 | Args: 240 | cache_size: The number of drawn batches. 241 | mb_size: The number of observations per batch. 242 | shuffle: Shuffle dataset instead of drawing randomly from the 243 | observations. 244 | in_epochs: Samples returned twice per epoch are marked via mask = 0 (only 245 | if ``shuffle = True``. 246 | seed: Set the random seed to start the chain at a well-defined state. 247 | 248 | Returns: 249 | Returns the id of the new chain. 250 | 251 | """ 252 | # The random state of each chain can be defined unambiguously via the 253 | # PRNGKey 254 | if mb_size > self._observation_count: 255 | raise ValueError(f"The batch size cannot be bigger than the observation " 256 | f"count. Provided {mb_size} and " 257 | f"{self._observation_count}") 258 | if not shuffle and in_epochs: 259 | raise ValueError("in_epochs = True can only be used for shuffle = True.") 260 | 261 | chain_id = len(self._chains) 262 | 263 | seed = kwargs.get("seed", chain_id) 264 | rng = onp.random.default_rng( 265 | onp.random.SeedSequence(seed).spawn(1)[0]) 266 | 267 | # The indices list must have at least the length equal to the number of 268 | # observations but should also be a multiple of the mb_size to simplify 269 | # getting new indices. 270 | new_chain = {'type': 'random', 271 | 'rng': rng, 272 | 'idx_offset': None, 273 | 'in_epochs': in_epochs, 274 | 'shuffle': shuffle, 275 | 'remaining_samples': 0, 276 | 'draws': math.ceil(self._observation_count / mb_size), 277 | 'random_indices': None, 278 | 'mb_size': mb_size, 279 | 'cache_size': cache_size} 280 | 281 | self._chains.append(new_chain) 282 | return chain_id 283 | 284 | def register_ordered_pipeline(self, 285 | cache_size: int = 1, 286 | mb_size: int = None, 287 | **kwargs 288 | ) -> int: 289 | """Register a chain which assembles batches in an ordered manner. 290 | 291 | Args: 292 | cache_size: The number of drawn batches. 293 | mb_size: The number of observations per batch. 
294 | seed: Set the random seed to start the chain at a well-defined state. 295 | 296 | Returns: 297 | Returns the id of the new chain. 298 | 299 | """ 300 | assert mb_size <= self._observation_count, \ 301 | (f"The batch size cannot be bigger than the observation count. Provided " 302 | f"{mb_size} and {self._observation_count}") 303 | chain_id = len(self._chains) 304 | 305 | new_chain = {'type': 'ordered', 306 | 'rng': None, 307 | 'idx_offset': 0, 308 | 'mb_size': mb_size, 309 | 'cache_size': cache_size} 310 | 311 | self._chains.append(new_chain) 312 | return chain_id 313 | 314 | def get_batches(self, chain_id: int) -> PyTree: 315 | """Draws a batch from a chain. 316 | 317 | Args: 318 | chain_id: ID of the chain, which holds the information about the form of 319 | the batch and the process of assembling. 320 | 321 | Returns: 322 | Returns a batch of batches as registered by 323 | :func:`register_random_pipeline` or :func:`register_ordered_pipeline` with 324 | `cache_size` batches holding `mb_size` observations. 325 | 326 | """ 327 | # Data slicing is the same for all methods of random and ordered access, 328 | # only the indices for slicing differ. The method _get_indices find the 329 | # correct method for the chain. 330 | selections_idx, selections_mask = self._get_indices(chain_id) 331 | 332 | # Slice the data and transform into device array. 333 | selected_observations: Dict[str, Array] = {} 334 | for key, data in self._reference_data.items(): 335 | if data.ndim == 1: 336 | selection = jnp.array(data[selections_idx,]) 337 | else: 338 | selection = jnp.array(data[selections_idx,::]) 339 | selected_observations[key] = selection 340 | return selected_observations, jnp.array(selections_mask, dtype=jnp.bool_) 341 | 342 | def _get_indices(self, chain_id: int): 343 | chain = self._chains[chain_id] 344 | if chain['type'] == 'ordered': 345 | index_fn = self._ordered_indices 346 | elif chain['in_epochs']: 347 | index_fn = self._shuffle_in_epochs 348 | elif chain['shuffle']: 349 | index_fn = self._shuffle_indices 350 | else: 351 | index_fn = self._draw_indices 352 | indices, masks = list(zip(*map( 353 | lambda _: index_fn(chain), 354 | itertools.repeat(chain_id, self._chains[chain_id]['cache_size'])))) 355 | return onp.array(indices), onp.array(masks, dtype=onp.bool_) 356 | 357 | def _ordered_indices(self, chain): 358 | idcs = onp.arange(chain['mb_size']) + chain['idx_offset'] 359 | # For consistency also return a mask to mark 360 | # the samples returned double. 361 | mask = onp.arange(chain['mb_size']) + chain['idx_offset'] < self._observation_count 362 | 363 | # Start again at the first sample if all samples have been returned 364 | if chain['idx_offset'] + chain['mb_size'] >= self._observation_count: 365 | chain['idx_offset'] = 0 366 | else: 367 | chain['idx_offset'] += chain['mb_size'] 368 | # Simply return the first samples again if less samples remain than 369 | # necessary to fill the cache. 370 | return onp.mod(idcs, self._observation_count), mask 371 | 372 | def _random_indices(self, chain_id: int) -> Tuple[List, Any]: 373 | """Returns indices and mask to access random data. 
""" 374 | chain = self._chains[chain_id] 375 | if chain['in_epochs']: 376 | return self._shuffle_in_epochs(chain) 377 | elif chain['shuffle']: 378 | return self._shuffle_indices(chain) 379 | else: 380 | return self._draw_indices(chain) 381 | 382 | def _draw_indices(self, chain): 383 | # Randomly choose batches 384 | selections = chain['rng'].choice( 385 | onp.arange(0, self._observation_count), 386 | size=chain['mb_size'], 387 | replace=True) 388 | mask = onp.ones(chain['mb_size'], dtype=onp.bool_) 389 | return selections, mask 390 | 391 | def _shuffle_indices(self, chain): 392 | floor_draws = math.floor(self._observation_count / chain['mb_size']) 393 | # The partial valid cache must not be changed when updating the indices 394 | ceil_draws = floor_draws + 2 395 | 396 | if chain['remaining_samples'] < chain['mb_size']: 397 | # The indices have to be refreshed. Shuffling is equivalent to drawing 398 | # without replacement. 399 | new_indices = chain['rng'].choice( 400 | onp.arange(0, self._observation_count), 401 | size=self._observation_count, 402 | replace=False) 403 | 404 | if chain['random_indices'] is None: 405 | # Special options for first run 406 | chain['draws'] = 0 407 | chain['random_indices'] = onp.zeros( 408 | ceil_draws * chain['mb_size'], dtype=onp.int_) 409 | 410 | # Update only invalid samples (do not overwrite still valid samples) 411 | update_idxs = onp.mod( 412 | onp.arange(self._observation_count) 413 | + chain['draws'] * chain['mb_size'] 414 | + chain['remaining_samples'], 415 | ceil_draws * chain['mb_size']) 416 | chain['random_indices'][update_idxs] = new_indices 417 | chain['remaining_samples'] += self._observation_count 418 | 419 | # All samples are valid 420 | mask = onp.ones(chain['mb_size'], dtype=onp.bool_) 421 | 422 | # Take the new indices 423 | selections_idxs = onp.mod( 424 | onp.arange(chain['mb_size']) + chain['draws'] * chain['mb_size'], 425 | chain['mb_size'] * ceil_draws) 426 | selections = onp.copy(chain['random_indices'][selections_idxs]) 427 | chain['draws'] = (chain['draws'] + 1) % ceil_draws 428 | chain['remaining_samples'] -= chain['mb_size'] 429 | 430 | return selections, mask 431 | 432 | def _shuffle_in_epochs(self, chain): 433 | ceil_draws = math.ceil(self._observation_count / chain['mb_size']) 434 | 435 | if chain['draws'] == ceil_draws: 436 | # The indices have to be refreshed. Shuffling is equivalent to drawing 437 | # without replacement. 438 | new_indices = chain['rng'].choice( 439 | onp.arange(0, self._observation_count), 440 | size=self._observation_count, 441 | replace=False) 442 | 443 | if chain['random_indices'] is None: 444 | # Special options for first run 445 | chain['draws'] = 0 446 | chain['random_indices'] = onp.zeros( 447 | ceil_draws * chain['mb_size'], dtype=onp.int_) 448 | 449 | chain['random_indices'][0:self._observation_count] = new_indices 450 | chain['draws'] = 0 451 | 452 | start_idx = chain['mb_size'] * chain['draws'] 453 | end_idx = chain['mb_size'] * (chain['draws'] + 1) 454 | 455 | mask = onp.arange(start_idx, end_idx) < self._observation_count 456 | 457 | selections = onp.copy(chain['random_indices'][start_idx:end_idx]) 458 | chain['draws'] += 1 459 | 460 | return selections, mask 461 | --------------------------------------------------------------------------------