├── src └── flowMC │ ├── __init__.py │ ├── resource │ ├── __init__.py │ ├── kernel │ │ ├── __init__.py │ │ ├── base.py │ │ ├── Gaussian_random_walk.py │ │ ├── MALA.py │ │ ├── HMC.py │ │ └── NF_proposal.py │ ├── model │ │ ├── __init__.py │ │ ├── nf_model │ │ │ ├── __init__.py │ │ │ ├── realNVP.py │ │ │ └── base.py │ │ ├── flowmatching │ │ │ ├── __init__.py │ │ │ └── base.py │ │ └── common.py │ ├── optimizer.py │ ├── base.py │ ├── buffers.py │ ├── states.py │ └── logPDF.py │ ├── strategy │ ├── __init__.py │ ├── importance_sampling.py │ ├── sequential_monte_carlo.py │ ├── base.py │ ├── lambda_function.py │ ├── update_state.py │ ├── train_model.py │ ├── optimization.py │ └── take_steps.py │ ├── resource_strategy_bundle │ ├── __init__.py │ ├── base.py │ └── RQSpline_MALA.py │ └── Sampler.py ├── ruff.toml ├── docs ├── contribution.md ├── dual_moon.png ├── logo_0810.png ├── tutorials │ └── dual_moon.png ├── stylesheets │ └── extra.css ├── requirements.txt ├── gen_ref_pages.py ├── communityExamples.md ├── index.md ├── FAQ.md ├── configuration.md └── quickstart.md ├── .coverage ├── .gitattributes ├── readthedocs.yml ├── .devcontainer └── devcontainer.json ├── .gitignore ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── pre-commit.yml │ ├── workflowsjoss.yml │ ├── run_tests.yml │ └── python-publish.yml ├── .pre-commit-config.yaml ├── LICENSE ├── test ├── integration │ ├── test_quickstart.py │ ├── test_normalizingFlow.py │ ├── test_MALA.py │ ├── test_HMC.py │ └── test_RWMCMC.py └── unit │ ├── test_bundle.py │ ├── test_nf.py │ ├── test_resources.py │ └── test_flowmatching.py ├── pyproject.toml ├── CONTRIBUTING.md ├── .all-contributorsrc ├── CODE_OF_CONDUCT.md ├── mkdocs.yml ├── joss ├── paper.bib └── paper.md └── README.md /src/flowMC/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | lint.ignore = ["F722"] -------------------------------------------------------------------------------- /src/flowMC/resource/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flowMC/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/contribution.md: -------------------------------------------------------------------------------- 1 | ../CONTRIBUTING.md -------------------------------------------------------------------------------- /src/flowMC/resource/kernel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flowMC/resource/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flowMC/strategy/importance_sampling.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flowMC/resource/model/nf_model/__init__.py: -------------------------------------------------------------------------------- 1 | 
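The tree above maps one-to-one onto the package's import paths. A hedged sketch of the imports for the main components (these mirror the imports used in the test files reproduced later in this dump):

```python
# Sketch only: import paths inferred from the directory tree and the test files below.
from flowMC.Sampler import Sampler                       # top-level driver
from flowMC.resource.buffers import Buffer               # chain storage
from flowMC.resource.states import State                 # sampler bookkeeping
from flowMC.resource.logPDF import LogPDF                # wraps the target density
from flowMC.resource.kernel.MALA import MALA             # a local kernel
from flowMC.strategy.take_steps import TakeSerialSteps   # runs a kernel for n steps
from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle
```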
-------------------------------------------------------------------------------- /src/flowMC/resource_strategy_bundle/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/flowMC/resource/model/flowmatching/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.coverage: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kazewong/flowMC/HEAD/.coverage -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | docs/** linguist-vendored 2 | example/** linguist-vendored -------------------------------------------------------------------------------- /docs/dual_moon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kazewong/flowMC/HEAD/docs/dual_moon.png -------------------------------------------------------------------------------- /docs/logo_0810.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kazewong/flowMC/HEAD/docs/logo_0810.png -------------------------------------------------------------------------------- /docs/tutorials/dual_moon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kazewong/flowMC/HEAD/docs/tutorials/dual_moon.png -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --md-primary-fg-color: #200f80; 3 | --md-primary-fg-color--light: #6cb8bb; 4 | --md-primary-fg-color--dark: #6cb8bb; 5 | } -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | install: 5 | - requirements: docs/requirements.txt 6 | - method: pip 7 | path: . 
8 | 
9 | build:
10 |   os: ubuntu-22.04
11 |   tools:
12 |     python: "3.11"
13 | 
14 | mkdocs:
15 |   configuration: mkdocs.yml
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 |     "image": "mcr.microsoft.com/devcontainers/universal:2",
3 |     "features": {
4 |         "ghcr.io/va-h/devcontainers-features/uv:1": {},
5 |         "ghcr.io/devcontainers/features/python:1": {},
6 |         "ghcr.io/devcontainers-extra/features/ruff:1": {}
7 |     }
8 | }
9 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | flowMC/nfmodel/__pycache__/maf.cpython-37.pyc
3 | log_dir/
4 | *.pkl
5 | flowMC/nfmodel/realNVP_white.py
6 | example/kepler/kepler_white.py
7 | example/kepler/kepler_analyse.py
8 | dist
9 | *egg-info*
10 | docs/build/*
11 | *_autosummary
12 | README.md
13 | *settings.json
14 | *build*
15 | test_docs/test_quickstart.ipynb
16 | node_modules
17 | package.json
18 | yarn.lock
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | mkdocs==1.6.1 # Main documentation generator.
2 | mkdocs-material==9.5.47 # Theme
3 | pymdown-extensions==10.12 # Markdown extensions e.g. to handle LaTeX.
4 | mkdocstrings[python]==0.27.0 # Autogenerate documentation from docstrings.
5 | mkdocs-jupyter==0.25.1 # Turn Jupyter Lab notebooks into webpages.
6 | mkdocs-gen-files==0.5.0
7 | mkdocs-literate-nav==0.6.1
8 | 
--------------------------------------------------------------------------------
/src/flowMC/resource_strategy_bundle/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | 
3 | from flowMC.resource.base import Resource
4 | from flowMC.strategy.base import Strategy
5 | 
6 | 
7 | class ResourceStrategyBundle(ABC):
8 |     """Resource-Strategy Bundle aims to be the highest level of abstraction in the
9 |     flowMC library.
10 | 
11 |     It is a collection of resources and strategies that are used to perform a specific
12 |     task.
13 |     """
14 | 
15 |     resources: dict[str, Resource]
16 |     strategies: dict[str, Strategy]
17 |     strategy_order: list[str]
18 | 
--------------------------------------------------------------------------------
/src/flowMC/strategy/sequential_monte_carlo.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.base import Resource
2 | from flowMC.strategy.base import Strategy
3 | from jaxtyping import Array, Float, PRNGKeyArray
4 | 
5 | 
6 | class SequentialMonteCarlo(Strategy):
7 |     def __init__(self):
8 |         raise NotImplementedError
9 | 
10 |     def __call__(
11 |         self,
12 |         rng_key: PRNGKeyArray,
13 |         resources: dict[str, Resource],
14 |         initial_position: Float[Array, "n_chains n_dim"],
15 |         data: dict,
16 |     ) -> tuple[
17 |         PRNGKeyArray,
18 |         dict[str, Resource],
19 |         Float[Array, "n_chains n_dim"],
20 |     ]:
21 |         raise NotImplementedError
22 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Is your feature request related to a problem?
Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 25.1.0 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/charliermarsh/ruff-pre-commit 7 | rev: 'v0.9.10' 8 | hooks: 9 | - id: ruff 10 | args: ["--fix"] 11 | - repo: https://github.com/RobertCraigie/pyright-python 12 | rev: v1.1.396 13 | hooks: 14 | - id: pyright 15 | additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, typing_extensions, equinox, optax, tqdm, diffrax] 16 | - repo: https://github.com/nbQA-dev/nbQA 17 | rev: 1.9.1 18 | hooks: 19 | - id: nbqa-black 20 | additional_dependencies: [ipython==8.12, black] 21 | - id: nbqa-ruff-format 22 | additional_dependencies: [ipython==8.12, ruff] -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | 10 | pre-commit: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.11", "3.12"] 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | python -m pip install pytest 26 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 27 | python -m pip install . 28 | - uses: pre-commit/action@v3.0.1 -------------------------------------------------------------------------------- /.github/workflows/workflowsjoss.yml: -------------------------------------------------------------------------------- 1 | name: Update-Joss 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - 40-joss-paper 8 | 9 | jobs: 10 | paper: 11 | runs-on: ubuntu-latest 12 | name: Paper Draft 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@v2 16 | - name: Build draft PDF 17 | uses: openjournals/openjournals-draft-action@master 18 | with: 19 | journal: joss 20 | # This should be the path to the paper within your repo. 21 | paper-path: joss/paper.md 22 | - name: Upload 23 | uses: actions/upload-artifact@v1 24 | with: 25 | name: paper 26 | # This is the output path where Pandoc will write the compiled 27 | # PDF. 
Note, this should be the same directory as the input 28 | # paper.md 29 | path: joss/paper.pdf 30 | -------------------------------------------------------------------------------- /src/flowMC/strategy/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from jaxtyping import Array, Float, PRNGKeyArray 4 | 5 | from flowMC.resource.base import Resource 6 | 7 | 8 | class Strategy(ABC): 9 | """Base class for strategies, which are basically wrapper blocks that modify the 10 | state of the sampler. 11 | 12 | This is an abstract template that should not be directly used. 13 | """ 14 | 15 | @abstractmethod 16 | def __init__(self): 17 | raise NotImplementedError 18 | 19 | @abstractmethod 20 | def __call__( 21 | self, 22 | rng_key: PRNGKeyArray, 23 | resources: dict[str, Resource], 24 | initial_position: Float[Array, "n_chains n_dim"], 25 | data: dict, 26 | ) -> tuple[ 27 | PRNGKeyArray, 28 | dict[str, Resource], 29 | Float[Array, "n_chains n_dim"], 30 | ]: 31 | raise NotImplementedError 32 | -------------------------------------------------------------------------------- /src/flowMC/resource/kernel/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import equinox as eqx 3 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree 4 | from flowMC.resource.base import Resource 5 | from flowMC.resource.logPDF import LogPDF 6 | from typing import Callable 7 | 8 | 9 | class ProposalBase(eqx.Module, Resource): 10 | @abstractmethod 11 | def __init__( 12 | self, 13 | ): 14 | """Initialize the sampler class.""" 15 | 16 | @abstractmethod 17 | def kernel( 18 | self, 19 | rng_key: PRNGKeyArray, 20 | position: Float[Array, "nstep n_dim"], 21 | log_prob: Float[Array, "nstep 1"], 22 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 23 | data: PyTree, 24 | ) -> tuple[ 25 | Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "n_step 1"] 26 | ]: 27 | """Kernel for one step in the proposal cycle.""" 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /docs/gen_ref_pages.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages.""" 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | nav = mkdocs_gen_files.Nav() 8 | 9 | 10 | for path in sorted(Path("src").rglob("*.py")): # 11 | module_path = path.relative_to("src").with_suffix("") # 12 | doc_path = path.relative_to("src").with_suffix(".md") # 13 | full_doc_path = Path("api", doc_path) # 14 | 15 | parts = list(module_path.parts) 16 | 17 | if parts[-1] == "__init__": # 18 | continue 19 | elif parts[-1] == "__main__": 20 | continue 21 | 22 | nav[parts] = doc_path.as_posix() 23 | 24 | print(full_doc_path) 25 | 26 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: # 27 | identifier = ".".join(parts) # 28 | print("::: " + identifier, file=fd) # 29 | 30 | mkdocs_gen_files.set_edit_path(full_doc_path, path) # 31 | 32 | with mkdocs_gen_files.open("api/summary.md", "w") as nav_file: # 33 | nav_file.writelines(nav.build_literate_nav()) 34 | -------------------------------------------------------------------------------- /src/flowMC/resource/optimizer.py: -------------------------------------------------------------------------------- 1 | from flowMC.resource.base import Resource 2 | import optax 3 | import equinox as eqx 4 | 5 | 6 | class Optimizer(Resource): 7 | optim: optax.GradientTransformation 8 | optim_state: optax.OptState 9 | 10 | def __repr__(self): 11 | return "Optimizer" 12 | 13 | def __init__( 14 | self, 15 | model: eqx.Module, 16 | learning_rate: float = 1e-3, 17 | momentum: float = 0.9, 18 | ): 19 | self.optim = optax.chain( 20 | optax.clip_by_global_norm(1.0), 21 | optax.adamw(learning_rate=learning_rate, b1=momentum), 22 | ) 23 | self.optim_state = self.optim.init(eqx.filter(model, eqx.is_array)) 24 | 25 | def __call__(self, params, grads): 26 | raise NotImplementedError 27 | 28 | def print_parameters(self): 29 | raise NotImplementedError 30 | 31 | def save_resource(self, path: str): 32 | raise NotImplementedError 33 | 34 | def load_resource(self, path: str): 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /src/flowMC/resource/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from typing import Self 4 | 5 | 6 | class Resource(ABC): 7 | """Base class for resources. Resources are objects such as local sampler and neural 8 | networks. 9 | 10 | This is an abstract template that should not be directly used. 11 | """ 12 | 13 | @abstractmethod 14 | def __init__(self): 15 | raise NotImplementedError 16 | 17 | @abstractmethod 18 | def print_parameters(self): 19 | """Function to print the tunable parameters of the resource.""" 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def save_resource(self, path: str): 24 | """Function to save the resource. 25 | 26 | Args: 27 | path (str): Path to save the resource. 28 | """ 29 | raise NotImplementedError 30 | 31 | @abstractmethod 32 | def load_resource(self, path: str) -> Self: 33 | """Function to load the resource. 34 | 35 | Args: 36 | path (str): Path to load the resource. 
37 | """ 38 | raise NotImplementedError 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kaze Wong & contributor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Run Tests 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.11", "3.12"] 16 | 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | 21 | - name: Install uv 22 | uses: astral-sh/setup-uv@v3 23 | 24 | - name: Set up Python ${{ matrix.python-version }} 25 | run: uv python install ${{ matrix.python-version }} 26 | 27 | - name: Install the project 28 | run: uv sync --all-extras --dev 29 | 30 | - name: Run tests and coverage 31 | run: | 32 | uv run coverage run --source=src -m pytest test 33 | uv run coveralls --service=github-actions 34 | env: 35 | COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} 36 | 37 | - name: Run build 38 | run: uv build -------------------------------------------------------------------------------- /docs/communityExamples.md: -------------------------------------------------------------------------------- 1 | # Community examples 2 | 3 | The core design philosophy of flowMC is to stay lean and simple, so we decided to leave most of the use case-specific optimization to the users. Hence, we keep most of the flowMC internals away from most of the users, 4 | so the only interfaces with flowMC are really defining your likelihood and tuning the sampler parameters exposed on the top level. 5 | That said, it would be useful to have references to see how to use/tune flowMC for different use cases, therefore in this page we host a number of community examples that are contributed by the users. 6 | If you find flowMC useful, please consider contributing your example to this page. 
This will help other users (and perhaps your future students) to get started quickly. 7 | 8 | ## Examples 9 | 10 | - [jim - A JAX-based gravitational-wave inference toolkit](https://github.com/kazewong/jim) 11 | - [Bayeux - Stitching together models and samplers](https://github.com/jax-ml/bayeux) 12 | - [Colab example](https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing) 13 | - [Markovian Flow Matching: Accelerating MCMC with Continuous Normalizing Flows](https://arxiv.org/pdf/2405.14392) -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | workflow_dispatch: 15 | 16 | 17 | permissions: 18 | contents: read 19 | 20 | jobs: 21 | deploy: 22 | 23 | runs-on: ubuntu-latest 24 | 25 | steps: 26 | - uses: actions/checkout@v3 27 | - name: Set up Python 28 | uses: actions/setup-python@v3 29 | with: 30 | python-version: '3.x' 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install build 35 | - name: Build package 36 | run: python -m build 37 | - name: Publish package 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | with: 40 | user: __token__ 41 | password: ${{ secrets.PYPI_API_TOKEN }} 42 | -------------------------------------------------------------------------------- /src/flowMC/strategy/lambda_function.py: -------------------------------------------------------------------------------- 1 | from flowMC.strategy.base import Strategy 2 | from flowMC.resource.base import Resource 3 | from jaxtyping import Array, Float, PRNGKeyArray 4 | from typing import Callable 5 | 6 | 7 | class Lambda(Strategy): 8 | """A strategy that applies a function to the resources. 9 | 10 | This should be used for simple functions or calling methods from 11 | a class. 12 | If you find yourself writing a Lambda strategy that is more than a few lines 13 | long, consider writing a custom strategy instead. 14 | 15 | """ 16 | 17 | def __init__(self, lambda_function: Callable): 18 | """Initialize the lambda strategy. 19 | 20 | Args: 21 | lambda: A callable that takes a resource and applies the lambda function. 
22 | """ 23 | self.lambda_function = lambda_function 24 | 25 | def __call__( 26 | self, 27 | rng_key: PRNGKeyArray, 28 | resources: dict[str, Resource], 29 | initial_position: Float[Array, "n_chains n_dim"], 30 | data: dict, 31 | ) -> tuple[ 32 | PRNGKeyArray, 33 | dict[str, Resource], 34 | Float[Array, "n_chains n_dim"], 35 | ]: 36 | self.lambda_function(rng_key, resources, initial_position, data) 37 | return rng_key, resources, initial_position 38 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | flowMC 2 | ====== 3 | 4 | **Normalizing-flow enhanced sampling package for probabilistic inference** 5 | 6 | 7 | ![](logo_0810.png) 8 | 9 | [![](https://badgen.net/badge/Read/the doc/blue)](https://flowMC.readthedocs.io/en/latest/) 10 | [![](https://badgen.net/badge/License/MIT/blue)](https//github.com/kazewong/flowMC/blob/Packaging/LICENSE) 11 | 12 | 13 | 14 | flowMC is a Jax-based python package for normalizing-flow enhanced Markov chain Monte Carlo (MCMC) sampling. 15 | The code is open source under MIT license, and it is under active development. 16 | 17 | - Just-in-time compilation is supported. 18 | - Native support for GPU acceleration. 19 | - Suit for problems with multi-modality and complex geometry. 20 | 21 | Four steps to use flowMC's guide 22 | ================================ 23 | 24 | 1. You can find installation info, a basic example, and some design and guiding principles of `flowMC` in the [quickstart](quickstart.md). 25 | 2. In tutorials, we have a set of pedagogical notebooks that will give you a better understanding of the package infrastructure and a number of common use cases. 26 | 3. We list some community examples in [community_example](communityExamples.md), so users can see whether there is a similar use case they can adopt their code quickly. 27 | 4. Finally, we have a list of frequently asked questions in [FAQ](FAQ.md). 
28 | 29 | -------------------------------------------------------------------------------- /test/integration/test_quickstart.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flowMC.Sampler import Sampler 4 | from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle 5 | 6 | 7 | def log_posterior(x, data: dict): 8 | return -0.5 * jnp.sum((x - data["data"]) ** 2) 9 | 10 | 11 | n_dims = 2 12 | n_local_steps = 10 13 | n_global_steps = 10 14 | n_training_loops = 5 15 | n_production_loops = 5 16 | n_epochs = 10 17 | n_chains = 10 18 | rq_spline_hidden_units = [64, 64] 19 | rq_spline_n_bins = 8 20 | rq_spline_n_layers = 3 21 | data = {"data": jnp.arange(n_dims).astype(jnp.float32)} 22 | 23 | rng_key = jax.random.PRNGKey(42) 24 | rng_key, subkey = jax.random.split(rng_key) 25 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1 26 | 27 | rng_key, subkey = jax.random.split(rng_key) 28 | bundle = RQSpline_MALA_Bundle( 29 | subkey, 30 | n_chains, 31 | n_dims, 32 | log_posterior, 33 | n_local_steps, 34 | n_global_steps, 35 | n_training_loops, 36 | n_production_loops, 37 | n_epochs, 38 | rq_spline_hidden_units=rq_spline_hidden_units, 39 | rq_spline_n_bins=rq_spline_n_bins, 40 | rq_spline_n_layers=rq_spline_n_layers, 41 | verbose=True, 42 | ) 43 | 44 | nf_sampler = Sampler( 45 | n_dims, 46 | n_chains, 47 | rng_key, 48 | resource_strategy_bundles=bundle, 49 | ) 50 | 51 | nf_sampler.sample(initial_position, data) 52 | -------------------------------------------------------------------------------- /test/unit/test_bundle.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle 3 | from flowMC.resource_strategy_bundle.RQSpline_MALA_PT import RQSpline_MALA_PT_Bundle 4 | 5 | 6 | def logpdf(x, _): 7 | return -0.5 * (x**2).sum() 8 | 9 | 10 | def test_rqspline_mala_bundle_initialization(): 11 | rng_key = jax.random.PRNGKey(0) 12 | n_chains = 2 13 | n_dims = 3 14 | n_local_steps = 10 15 | n_global_steps = 5 16 | n_training_loops = 2 17 | n_production_loops = 1 18 | n_epochs = 3 19 | 20 | bundle = RQSpline_MALA_Bundle( 21 | rng_key=rng_key, 22 | n_chains=n_chains, 23 | n_dims=n_dims, 24 | logpdf=logpdf, 25 | n_local_steps=n_local_steps, 26 | n_global_steps=n_global_steps, 27 | n_training_loops=n_training_loops, 28 | n_production_loops=n_production_loops, 29 | n_epochs=n_epochs, 30 | ) 31 | 32 | assert repr(bundle) == "RQSpline_MALA Bundle" 33 | 34 | 35 | def test_rqspline_mala_pt_bundle_initialization(): 36 | rng_key = jax.random.PRNGKey(0) 37 | n_chains = 2 38 | n_dims = 3 39 | n_local_steps = 10 40 | n_global_steps = 5 41 | n_training_loops = 2 42 | n_production_loops = 1 43 | n_epochs = 3 44 | 45 | bundle = RQSpline_MALA_PT_Bundle( 46 | rng_key=rng_key, 47 | n_chains=n_chains, 48 | n_dims=n_dims, 49 | logpdf=logpdf, 50 | n_local_steps=n_local_steps, 51 | n_global_steps=n_global_steps, 52 | n_training_loops=n_training_loops, 53 | n_production_loops=n_production_loops, 54 | n_epochs=n_epochs, 55 | ) 56 | 57 | assert repr(bundle) == "RQSpline MALA PT Bundle" 58 | -------------------------------------------------------------------------------- /src/flowMC/strategy/update_state.py: -------------------------------------------------------------------------------- 1 | from flowMC.resource.states import State 2 | from flowMC.strategy.base import Strategy 3 | from 
flowMC.resource.base import Resource
4 | from jaxtyping import Array, Float, PRNGKeyArray
5 | 
6 | 
7 | class UpdateState(Strategy):
8 |     """Update a state resource in place.
9 | 
10 |     This strategy is meant for infrequent state updates.
11 |     If you are looking for an option that iterates over some parameters,
12 |     say the parameters of a neural network, you should write a custom strategy
13 |     that does that.
14 |     """
15 | 
16 |     def __init__(
17 |         self, state_name: str, keys: list[str], values: list[int | bool | str]
18 |     ):
19 |         """Initialize the update state strategy.
20 | 
21 |         Args:
22 |             state_name (str): The name of the state resource to update.
23 |             keys (list[str]): The keys to update in the state resource.
24 |             values (list[int | bool | str]): The values to update in the state resource.
25 |         """
26 |         self.state_name = state_name
27 |         self.keys = keys
28 |         self.values = values
29 | 
30 |     def __call__(
31 |         self,
32 |         rng_key: PRNGKeyArray,
33 |         resources: dict[str, Resource],
34 |         initial_position: Float[Array, "n_chains n_dim"],
35 |         data: dict,
36 |     ) -> tuple[
37 |         PRNGKeyArray,
38 |         dict[str, Resource],
39 |         Float[Array, "n_chains n_dim"],
40 |     ]:
41 |         """Update the state resource in place."""
42 |         assert isinstance(
43 |             state := resources[self.state_name], State
44 |         ), f"Resource {self.state_name} is not a State resource."
45 | 
46 |         state.update(self.keys, self.values)
47 |         return rng_key, resources, initial_position
48 | 
--------------------------------------------------------------------------------
/test/integration/test_normalizingFlow.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx  # Equinox utilities
2 | import jax
3 | import jax.numpy as jnp  # JAX NumPy
4 | import optax  # Optimizers
5 | 
6 | from flowMC.resource.model.nf_model.realNVP import RealNVP
7 | from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline
8 | 
9 | 
10 | def test_realNVP():
11 |     key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3)
12 |     data = jax.random.normal(key1, (100, 2))
13 | 
14 |     num_epochs = 5
15 |     batch_size = 100
16 |     learning_rate = 0.001
17 |     momentum = 0.9
18 | 
19 |     model = RealNVP(2, 4, 32, rng)
20 |     optim = optax.adam(learning_rate, momentum)
21 |     state = optim.init(eqx.filter(model, eqx.is_array))
22 | 
23 |     rng, best_model, state, loss_values = model.train(
24 |         init_rng, data, optim, state, num_epochs, batch_size, verbose=True
25 |     )
26 |     rng_key_nf = jax.random.PRNGKey(124098)
27 |     model.sample(rng_key_nf, 10000)
28 | 
29 | 
30 | def test_rqSpline():
31 |     n_dim = 2
32 |     num_epochs = 5
33 |     batch_size = 100
34 |     learning_rate = 0.001
35 |     momentum = 0.9
36 | 
37 |     key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3)
38 |     data = jax.random.normal(key1, (batch_size, n_dim))
39 | 
40 |     n_layers = 4
41 |     hidden_dim = 32
42 |     num_bins = 4
43 | 
44 |     model = MaskedCouplingRQSpline(
45 |         n_dim,
46 |         n_layers,
47 |         [hidden_dim, hidden_dim],
48 |         num_bins,
49 |         rng,
50 |         data_mean=jnp.mean(data, axis=0),
51 |         data_cov=jnp.cov(data.T),
52 |     )
53 |     optim = optax.adam(learning_rate, momentum)
54 |     state = optim.init(eqx.filter(model, eqx.is_array))
55 | 
56 |     rng, best_model, state, loss_values = model.train(
57 |         init_rng, data, optim, state, num_epochs, batch_size, verbose=True
58 |     )
59 |     rng_key_nf = jax.random.PRNGKey(124098)
60 |     model.sample(rng_key_nf, 10000)
61 | 
--------------------------------------------------------------------------------
/src/flowMC/resource/buffers.py:
-------------------------------------------------------------------------------- 1 | from flowMC.resource.base import Resource 2 | from typing import TypeVar 3 | import numpy as np 4 | from jaxtyping import Array, Float 5 | import jax.numpy as jnp 6 | import jax 7 | 8 | TBuffer = TypeVar("TBuffer", bound="Buffer") 9 | 10 | 11 | class Buffer(Resource): 12 | name: str 13 | data: Float[Array, " ..."] 14 | cursor: int = 0 15 | cursor_dim: int = 0 16 | 17 | def __repr__(self): 18 | return "Buffer " + self.name + " with shape " + str(self.data.shape) 19 | 20 | @property 21 | def shape(self): 22 | return self.data.shape 23 | 24 | def __init__(self, name: str, shape: tuple[int, ...], cursor_dim: int = 0): 25 | self.cursor_dim = cursor_dim 26 | self.name = name 27 | self.data = jnp.zeros(shape) - jnp.inf 28 | 29 | def __call__(self): 30 | return self.data 31 | 32 | def update_buffer(self, updates: Array, start: int = 0): 33 | """Update the buffer with new data. 34 | 35 | This will modify the buffer in place. 36 | The cursor is expected to propagate the buffer in the cursor_dim 37 | with length equal to the length of the updates in its first dimension. 38 | """ 39 | self.data = jax.lax.dynamic_update_slice_in_dim( 40 | self.data, updates, start, self.cursor_dim 41 | ) 42 | 43 | def print_parameters(self): 44 | print( 45 | f"Buffer: {self.name} with shape {self.data.shape} and cursor" 46 | f" {self.cursor} at dimension {self.cursor_dim}" 47 | ) 48 | 49 | def get_distribution(self, n_bins: int = 100): 50 | return np.histogram(self.data.flatten(), bins=n_bins) 51 | 52 | def save_resource(self, path: str): 53 | np.savez( 54 | path + self.name, 55 | name=self.name, 56 | data=self.data, 57 | ) 58 | 59 | def load_resource(self: TBuffer, path: str) -> TBuffer: 60 | data = np.load(path) 61 | buffer: Float[Array, " ..."] = data["data"] 62 | result = Buffer(data["name"], buffer.shape) 63 | result.data = buffer 64 | return result # type: ignore 65 | -------------------------------------------------------------------------------- /test/unit/test_nf.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from flowMC.resource.model.nf_model.realNVP import AffineCoupling, RealNVP 5 | from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline 6 | 7 | 8 | def test_affine_coupling_forward_and_inverse(): 9 | n_features = 2 10 | n_hidden = 4 11 | x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) 12 | mask = jnp.where(jnp.arange(n_features) % 2 == 0, 1.0, 0.0) 13 | key = jax.random.PRNGKey(0) 14 | dt = 0.5 15 | layer = AffineCoupling(n_features, n_hidden, mask, key, dt) 16 | 17 | y_forward, log_det_forward = jax.vmap(layer.forward)(x) 18 | x_recon, log_det_inverse = jax.vmap(layer.inverse)(y_forward) 19 | 20 | assert jnp.allclose(x, jnp.round(x_recon, decimals=5)) 21 | assert jnp.allclose(log_det_forward, -log_det_inverse) 22 | 23 | 24 | def test_realnvp(): 25 | n_features = 3 26 | n_hidden = 4 27 | n_layers = 2 28 | x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 29 | 30 | rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2) 31 | model = RealNVP(n_features, n_layers, n_hidden, rng_key) 32 | 33 | y, log_det = jax.vmap(model)(x) 34 | 35 | assert y.shape == x.shape 36 | assert log_det.shape == (2,) 37 | 38 | y_inv, log_det_inv = jax.vmap(model.inverse)(y) 39 | 40 | assert y_inv.shape == x.shape 41 | assert log_det_inv.shape == (2,) 42 | assert jnp.allclose(x, y_inv) 43 | assert jnp.allclose(log_det, -log_det_inv) 44 | 45 | 
rng_key = jax.random.PRNGKey(0)
46 |     samples = model.sample(rng_key, 2)
47 | 
48 |     assert samples.shape == (2, 3)
49 | 
50 |     log_prob = jax.vmap(model.log_prob)(samples)
51 | 
52 |     assert log_prob.shape == (2,)
53 | 
54 | 
55 | def test_rqspline():
56 |     n_features = 3
57 |     hidden_layers = [16, 16]
58 |     n_layers = 2
59 |     n_bins = 8
60 | 
61 |     rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2)
62 |     model = MaskedCouplingRQSpline(
63 |         n_features, n_layers, hidden_layers, n_bins, jax.random.PRNGKey(10)
64 |     )
65 | 
66 |     rng_key = jax.random.PRNGKey(0)
67 |     samples = model.sample(rng_key, 2)
68 | 
69 |     assert samples.shape == (2, 3)
70 | 
71 |     log_prob = jax.vmap(model.log_prob)(samples)
72 | 
73 |     assert log_prob.shape == (2,)
74 | 
--------------------------------------------------------------------------------
/src/flowMC/resource/states.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.base import Resource
2 | from typing import TypeVar
3 | import numpy as np
4 | 
5 | TState = TypeVar("TState", bound="State")
6 | 
7 | 
8 | class State(Resource):
9 |     """A Resource class that holds the state of the system.
10 |     This is essentially a wrapper around a dictionary so that it can be
11 |     handled by flowMC.
12 | 
13 |     We restrict the values of the state to simple types, including integers, booleans and strings.
14 |     The main reason for this is that State is expected to indicate the stage of individual
15 |     strategies, not to store parameters of resources.
16 |     I.e., State should hold whether the sampler is in the training phase or the production phase,
17 |     but not, say, the mass matrix of a kernel.
18 |     """
19 | 
20 |     name: str
21 |     data: dict[str, int | bool | str]
22 | 
23 |     def __repr__(self):
24 |         return "State " + self.name + " with shape " + str(len(self.data))
25 | 
26 |     def __init__(self, data: dict[str, int | bool | str], name: str = "State"):
27 |         """Initialize the state.
28 | 
29 |         Args:
30 |             data (dict): The data to initialize the state with.
31 |             name (str): The name of the state.
32 |         """
33 | 
34 |         self.name = name
35 |         self.data = data
36 | 
37 |     def update(self, key: list[str], value: list[int | bool | str]):
38 |         """Update the state with new data.
39 | 
40 |         This will modify the state in place.
41 | 
42 |         Args:
43 |             key (list[str]): The keys to update.
44 |             value (list[int | bool | str]): The values to update.
45 | """ 46 | for k, v in zip(key, value): 47 | self.data[k] = v 48 | print(f"Updated state {k} to {v}") 49 | 50 | def print_parameters(self): 51 | print(f"State: {self.name} with shape {len(self.data)} and data {self.data}") 52 | 53 | def save_resource(self, path: str): 54 | np.savez( 55 | path + self.name, 56 | name=self.name, 57 | data=self.data, # type: ignore 58 | ) 59 | 60 | def load_resource(self: TState, path: str) -> TState: 61 | data = np.load(path) 62 | result = State(data["name"], data["data"]) 63 | return result # type: ignore 64 | -------------------------------------------------------------------------------- /test/integration/test_MALA.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.scipy.special import logsumexp 4 | from jaxtyping import Array, Float 5 | 6 | from flowMC.resource.kernel.MALA import MALA 7 | from flowMC.strategy.take_steps import TakeSerialSteps 8 | from flowMC.resource.buffers import Buffer 9 | from flowMC.resource.states import State 10 | from flowMC.resource.logPDF import LogPDF 11 | from flowMC.Sampler import Sampler 12 | 13 | 14 | def dual_moon_pe(x: Float[Array, "n_dims"], data: dict): 15 | """ 16 | Term 2 and 3 separate the distribution 17 | and smear it along the first and second dimension 18 | """ 19 | print("compile count") 20 | term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2 21 | term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 22 | term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 23 | return -(term1 - logsumexp(term2) - logsumexp(term3)) 24 | 25 | 26 | n_dims = 5 27 | n_chains = 15 28 | n_local_steps = 30 29 | step_size = 0.01 30 | 31 | data = {"data": jnp.arange(5)} 32 | 33 | rng_key = jax.random.PRNGKey(42) 34 | rng_key, subkey = jax.random.split(rng_key) 35 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1 36 | 37 | # Defining resources 38 | 39 | MALA_Sampler = MALA(step_size=step_size) 40 | positions = Buffer("positions", (n_chains, n_local_steps, n_dims), 1) 41 | log_prob = Buffer("log_prob", (n_chains, n_local_steps), 1) 42 | acceptance = Buffer("acceptance", (n_chains, n_local_steps), 1) 43 | sampler_state = State( 44 | { 45 | "positions": "positions", 46 | "log_prob": "log_prob", 47 | "acceptance": "acceptance", 48 | }, 49 | name="sampler_state", 50 | ) 51 | 52 | resource = { 53 | "positions": positions, 54 | "log_prob": log_prob, 55 | "acceptance": acceptance, 56 | "MALA": MALA_Sampler, 57 | "logpdf": LogPDF(dual_moon_pe, n_dims=n_dims), 58 | "sampler_state": sampler_state, 59 | } 60 | 61 | # Defining strategy 62 | 63 | strategy = TakeSerialSteps( 64 | "logpdf", 65 | kernel_name="MALA", 66 | state_name="sampler_state", 67 | buffer_names=["positions", "log_prob", "acceptance"], 68 | n_steps=n_local_steps, 69 | ) 70 | 71 | nf_sampler = Sampler( 72 | n_dim=n_dims, 73 | n_chains=n_chains, 74 | rng_key=rng_key, 75 | resources=resource, 76 | strategies={"take_steps": strategy}, 77 | strategy_order=["take_steps"], 78 | ) 79 | 80 | nf_sampler.sample(initial_position, data) 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "flowMC" 3 | version = "0.4.5" 4 | description = "Normalizing flow exhanced sampler in jax" 5 | authors = [ 6 | { name = "Kaze Wong", email = "kazewong.physics@gmail.com"}, 7 | { name = "Marylou Gabrié"}, 8 | { 
name = "Dan Foreman-Mackey"} 9 | ] 10 | 11 | classifiers = [ 12 | "Programming Language :: Python :: 3.11", 13 | "Programming Language :: Python :: 3.12", 14 | "Programming Language :: Python :: 3.13", 15 | "License :: OSI Approved :: MIT License", 16 | "Intended Audience :: Science/Research", 17 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 18 | "Topic :: Scientific/Engineering :: Mathematics", 19 | "Topic :: Scientific/Engineering :: Information Analysis", 20 | ] 21 | 22 | readme = "README.md" 23 | requires-python = ">=3.11" 24 | keywords = ["sampling", "inference", "machine learning", "normalizing", "autodiff", "jax"] 25 | dependencies = [ 26 | "chex>=0.1.87", 27 | "diffrax>=0.7.0", 28 | "equinox>=0.11.9", 29 | "jax[cpu]>=0.5.0", 30 | "jaxtyping>=0.2.36", 31 | "optax>=0.2.4", 32 | "scikit-learn>=1.6.0", 33 | "tqdm>=4.67.1", 34 | ] 35 | license = { file = "LICENSE" } 36 | 37 | [project.urls] 38 | Documentation = "https://github.com/kazewong/flowMC" 39 | 40 | 41 | [project.optional-dependencies] 42 | docs = [ 43 | "mkdocs-gen-files==0.5.0", 44 | "mkdocs-jupyter==0.25.1", 45 | "mkdocs-literate-nav==0.6.1", 46 | "mkdocs-material==9.5.47", 47 | "mkdocs==1.6.1", 48 | "mkdocstrings[python]==0.27.0", 49 | "pymdown-extensions==10.12", 50 | ] 51 | visualize = [ 52 | "arviz>=0.21.0", 53 | "corner>=2.2.3", 54 | "matplotlib>=3.9.3", 55 | ] 56 | cuda = [ 57 | "jax[cuda12]>=0.5.0", 58 | ] 59 | 60 | [dependency-groups] 61 | dev = [ 62 | "flowMC", 63 | "ipykernel>=6.29.5", 64 | "coveralls>=4.0.1", 65 | "pre-commit>=4.0.1", 66 | "pyright>=1.1.389", 67 | "pytest>=8.3.3", 68 | "ruff>=0.8.0", 69 | "ipython>=8.30.0", 70 | ] 71 | 72 | [tool.uv.sources] 73 | flowMC = { workspace = true } 74 | 75 | 76 | [build-system] 77 | requires = ["hatchling"] 78 | build-backend = "hatchling.build" 79 | 80 | 81 | [tool.pyright] 82 | include = [ 83 | "src", 84 | "tests", 85 | ] 86 | exclude = [ 87 | "docs" 88 | ] 89 | 90 | [tool.coverage.report] 91 | exclude_also = [ 92 | 'def __repr__', 93 | "raise AssertionError", 94 | "raise NotImplementedError", 95 | "@(abc\\. )?abstractmethod", 96 | "def tree_flatten", 97 | "def tree_unflatten", 98 | ] 99 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | FAQ 2 | === 3 | 4 | **My local sampler is not accepting** 5 | 6 | This usually means you are setting the step size of the local samplerto be too big. 7 | Try reducing the step size in your local sampler. 8 | 9 | Alternative, this could also mean your sampler is proposing in region where the likelihood is ill-defined (i.e. NaN either in likelihood or its derivative if you are using a gradient-based local sampler). 10 | It is worth making sure your likelihood is well-defined within your range of prior. 11 | 12 | **In order for my local sampler to accept, I have to choose a very small step size, which makes my chain very correlated.** 13 | 14 | This usually indicate some of your parameters are much better measured than others. 15 | Since taking a small step in those directions will already change your likelihood value by a lot, the exploration power of the local sampler in other parameters are limited by those which are well measured. 16 | Currently, we support different step size for different parameters, which you can tune to see whether that improves the situation or not. 
17 | If you know the scale of each parameter ahead of time, reparameterizing them to maintain roughly equal scales across parameters also helps.
18 | 
19 | **My global sampler's loss is exploding/not decreasing**
20 | 
21 | This usually means the learning rate used for training the normalizing flow is too large.
22 | Try reducing the learning rate by a factor of ten.
23 | 
24 | Another reason for a flat loss is that your local sampler is not accepting at all.
25 | This is rarer, since it means the data used to train the normalizing flow is just your prior, which the normalizing flow should still be able to learn.
26 | 
27 | **The sampler is stuck for a while before it starts sampling**
28 | 
29 | If you use the option ``Jit`` in constructing the local sampler, the code will be compiled to speed up execution.
30 | The sampler is not really stuck; it is compiling the code. Depending on how you code up your likelihood function, the compilation can take a while.
31 | If you don't want to wait, you can set ``Jit=False``, which will increase the sampling time.
32 | 
33 | **The compilation is slow**
34 | 
35 | If your likelihood has many lines, Jax will take a long time to compile the code.
36 | Jax is known to be slow in compilation, especially if your computational graph uses some sort of loop that calls a function many times.
37 | While we cannot fundamentally get rid of the problem, [using a jax.lax.scan](https://docs.kidger.site/equinox/tricks/#low-overhead-training-loops) is usually how we deal with it.
38 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | 
2 | ### Expectations
3 | 
4 | flowMC is developed and maintained in my spare time and, while I try to be
5 | responsive, I don't always get to every issue immediately. If it has been more
6 | than a week or two, feel free to ping me (@kazewong) to try to get my attention. This is
7 | subject to change as the community grows.
8 | 
9 | ### Did you find a bug?
10 | 
11 | **Ensure the bug was not already reported** by searching on GitHub under
12 | [Issues](https://github.com/kazewong/flowMC/issues). If you're unable to find an
13 | open issue addressing the problem, [open a new
14 | one](https://github.com/kazewong/flowMC/issues/new). Be sure to include a **title
15 | and clear description**, as much relevant information as possible, and the
16 | simplest possible **code sample** demonstrating the expected behavior that is
17 | not occurring. Also label the issue with the bug label.
18 | 
19 | ### Did you write a patch that fixes a bug?
20 | 
21 | Open a new GitHub pull request with the patch. Ensure the PR description clearly
22 | describes the problem and solution. Include the relevant issue number if
23 | applicable.
24 | 
25 | ### Do you intend to add a new feature or change an existing feature?
26 | 
27 | Please follow these principles when you are thinking about adding a new
28 | feature or changing an existing feature:
29 | 
30 | 1. The new feature should be able to take advantage of `jax.jit` whenever possible.
31 | 2. Lightweight and modular implementations are preferred.
32 | 3. The core package only does sampling. If you have a concrete example that
33 |    involves a complete analysis such as plotting and models, see the next
34 |    contribution guide.
35 | 
36 | Suggestions for new features are welcome on [flowMC support
37 | group](https://groups.google.com/u/1/g/flowmc).
Note that features related to the
38 | core algorithm are unlikely to be accepted, since they may involve a lot of
39 | breaking changes.
40 | 
41 | ### Do you intend to introduce an example or tutorial?
42 | 
43 | Open a new GitHub pull request with the example or tutorial. The example should
44 | be self-contained and keep imports from other packages to a minimum. Leave the
45 | case-specific analysis details out. For more extensive tutorials, we encourage the
46 | community to link the minimal example hosted on the flowMC documentation to
47 | documentation from other packages.
48 | 
49 | ### Do you have a question about the code?
50 | 
51 | Do not open an issue. Instead, check whether there are already existing threads
52 | on the [flowMC support group](https://groups.google.com/u/1/g/flowmc). If not,
53 | please open a new conversation there.
54 | 
--------------------------------------------------------------------------------
/test/integration/test_HMC.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.special import logsumexp
4 | from jaxtyping import Array, Float
5 | 
6 | from flowMC.resource.kernel.HMC import HMC
7 | from flowMC.strategy.take_steps import TakeSerialSteps
8 | from flowMC.resource.buffers import Buffer
9 | from flowMC.resource.states import State
10 | from flowMC.resource.logPDF import LogPDF
11 | from flowMC.Sampler import Sampler
12 | 
13 | 
14 | def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
15 |     """
16 |     Terms 2 and 3 separate the distribution
17 |     and smear it along the first and second dimensions
18 |     """
19 |     print("compile count")
20 |     term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2
21 |     term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2
22 |     term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2
23 |     return -(term1 - logsumexp(term2) - logsumexp(term3))
24 | 
25 | 
26 | # Setup parameters
27 | n_dims = 5
28 | n_chains = 15
29 | n_local_steps = 30
30 | step_size = 0.1
31 | n_leapfrog = 10
32 | 
33 | data = {"data": jnp.arange(5)}
34 | 
35 | # Initialize positions
36 | rng_key = jax.random.PRNGKey(42)
37 | rng_key, subkey = jax.random.split(rng_key)
38 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1
39 | 
40 | # Define resources
41 | HMC_sampler = HMC(
42 |     step_size=step_size,
43 |     n_leapfrog=n_leapfrog,
44 |     condition_matrix=jnp.eye(n_dims),
45 | )
46 | 
47 | positions = Buffer("positions", (n_chains, n_local_steps, n_dims), 1)
48 | log_prob = Buffer("log_prob", (n_chains, n_local_steps), 1)
49 | acceptance = Buffer("acceptance", (n_chains, n_local_steps), 1)
50 | 
51 | sampler_state = State(
52 |     {
53 |         "positions": "positions",
54 |         "log_prob": "log_prob",
55 |         "acceptance": "acceptance",
56 |     },
57 |     name="sampler_state",
58 | )
59 | 
60 | resource = {
61 |     "positions": positions,
62 |     "log_prob": log_prob,
63 |     "acceptance": acceptance,
64 |     "HMC": HMC_sampler,
65 |     "logpdf": LogPDF(dual_moon_pe, n_dims=n_dims),
66 |     "sampler_state": sampler_state,
67 | }
68 | 
69 | # Define strategy
70 | strategy = TakeSerialSteps(
71 |     "logpdf",
72 |     kernel_name="HMC",
73 |     state_name="sampler_state",
74 |     buffer_names=["positions", "log_prob", "acceptance"],
75 |     n_steps=n_local_steps,
76 | )
77 | 
78 | # Initialize and run sampler
79 | nf_sampler = Sampler(
80 |     n_dim=n_dims,
81 |     n_chains=n_chains,
82 |     rng_key=rng_key,
83 |     resources=resource,
84 |     strategies={"take_steps": strategy},
85 |     strategy_order=["take_steps"],
86
| ) 87 | 88 | nf_sampler.sample(initial_position, data) 89 | -------------------------------------------------------------------------------- /src/flowMC/resource/kernel/Gaussian_random_walk.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree 4 | from typing import Callable 5 | 6 | from flowMC.resource.kernel.base import ProposalBase 7 | from flowMC.resource.logPDF import LogPDF 8 | 9 | 10 | class GaussianRandomWalk(ProposalBase): 11 | """Gaussian random walk sampler class.""" 12 | 13 | step_size: Float 14 | 15 | def __repr__(self): 16 | return "Gaussian Random Walk with step size " + str(self.step_size) 17 | 18 | def __init__( 19 | self, 20 | step_size: Float, 21 | ): 22 | super().__init__() 23 | self.step_size = step_size 24 | 25 | def kernel( 26 | self, 27 | rng_key: PRNGKeyArray, 28 | position: Float[Array, " n_dim"], 29 | log_prob: Float[Array, "1"], 30 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 31 | data: PyTree, 32 | ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]: 33 | """Random walk gaussian kernel. This is a kernel that only evolve a single 34 | chain. 35 | 36 | Args: 37 | rng_key (PRNGKeyArray): Jax PRNGKey 38 | position (Float[Array, "n_dim"]): current position of the chain 39 | log_prob (Float[Array, "1"]): current log-probability of the chain 40 | data (PyTree): data to be passed to the logpdf function 41 | 42 | Returns: 43 | position (Float[Array, "n_dim"]): new position of the chain 44 | log_prob (Float[Array, "1"]): new log-probability of the chain 45 | do_accept (Int[Array, "1"]): whether the new position is accepted 46 | """ 47 | 48 | key1, key2 = jax.random.split(rng_key) 49 | move_proposal: Float[Array, " n_dim"] = ( 50 | jax.random.normal(key1, shape=position.shape) * self.step_size 51 | ) 52 | 53 | proposal = position + move_proposal 54 | proposal_log_prob: Float[Array, " n_dim"] = logpdf(proposal, data) 55 | 56 | log_uniform = jnp.log(jax.random.uniform(key2)) 57 | do_accept = log_uniform < proposal_log_prob - log_prob 58 | 59 | position = jnp.where(do_accept, proposal, position) 60 | log_prob = jnp.where(do_accept, proposal_log_prob, log_prob) 61 | return position, log_prob, do_accept 62 | 63 | def print_parameters(self): 64 | print("Gaussian Random Walk parameters:") 65 | print(f"step_size: {self.step_size}") 66 | 67 | def save_resource(self, path): 68 | raise NotImplementedError 69 | 70 | def load_resource(self, path): 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /test/integration/test_RWMCMC.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.scipy.special import logsumexp 4 | from jaxtyping import Array, Float 5 | 6 | from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk 7 | from flowMC.strategy.take_steps import TakeSerialSteps 8 | from flowMC.resource.buffers import Buffer 9 | from flowMC.resource.states import State 10 | from flowMC.resource.logPDF import LogPDF 11 | from flowMC.Sampler import Sampler 12 | 13 | 14 | def dual_moon_pe(x: Float[Array, " n_dims"], data: dict): 15 | """ 16 | Term 2 and 3 separate the distribution 17 | and smear it along the first and second dimension 18 | """ 19 | print("compile count") 20 | term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 
0.1) ** 2 21 | term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2 22 | term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2 23 | return -(term1 - logsumexp(term2) - logsumexp(term3)) 24 | 25 | 26 | # Test parameters 27 | n_dims = 5 28 | n_chains = 2 29 | n_local_steps = 3 30 | n_global_steps = 3 31 | step_size = 0.1 32 | 33 | data = {"data": jnp.arange(5)} 34 | 35 | # Initialize random key and position 36 | rng_key = jax.random.PRNGKey(43) 37 | rng_key, subkey = jax.random.split(rng_key) 38 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1 39 | 40 | # Define resources 41 | RWMCMC_sampler = GaussianRandomWalk(step_size=step_size) 42 | positions = Buffer("positions", (n_chains, n_local_steps, n_dims), 1) 43 | log_prob = Buffer("log_prob", (n_chains, n_local_steps), 1) 44 | acceptance = Buffer("acceptance", (n_chains, n_local_steps), 1) 45 | sampler_state = State( 46 | { 47 | "positions": "positions", 48 | "log_prob": "log_prob", 49 | "acceptance": "acceptance", 50 | }, 51 | name="sampler_state", 52 | ) 53 | 54 | # Initialize normalizing flow model 55 | rng_key, subkey = jax.random.split(rng_key) 56 | 57 | resource = { 58 | "logpdf": LogPDF(dual_moon_pe, n_dims=n_dims), 59 | "positions": positions, 60 | "log_prob": log_prob, 61 | "acceptance": acceptance, 62 | "RWMCMC": RWMCMC_sampler, 63 | "sampler_state": sampler_state, 64 | } 65 | 66 | # Define strategy 67 | strategy = TakeSerialSteps( 68 | "logpdf", 69 | kernel_name="RWMCMC", 70 | state_name="sampler_state", 71 | buffer_names=["positions", "log_prob", "acceptance"], 72 | n_steps=n_local_steps, 73 | ) 74 | 75 | print("Initializing sampler class") 76 | 77 | # Initialize and run sampler 78 | nf_sampler = Sampler( 79 | n_dim=n_dims, 80 | n_chains=n_chains, 81 | rng_key=rng_key, 82 | resources=resource, 83 | strategies={"take_steps": strategy}, 84 | strategy_order=["take_steps"], 85 | ) 86 | 87 | nf_sampler.sample(initial_position, data) 88 | -------------------------------------------------------------------------------- /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "projectName": "flowMC", 3 | "projectOwner": "kazewong", 4 | "files": [ 5 | "README.md" 6 | ], 7 | "commitType": "docs", 8 | "commitConvention": "angular", 9 | "contributorsPerLine": 7, 10 | "contributors": [ 11 | { 12 | "login": "HajimeKawahara", 13 | "name": "Hajime Kawahara", 14 | "avatar_url": "https://avatars.githubusercontent.com/u/15956904?v=4", 15 | "profile": "http://secondearths.sakura.ne.jp/en/index.html", 16 | "contributions": [ 17 | "bug" 18 | ] 19 | }, 20 | { 21 | "login": "daniel-dodd", 22 | "name": "Daniel Dodd", 23 | "avatar_url": "https://avatars.githubusercontent.com/u/68821880?v=4", 24 | "profile": "https://github.com/daniel-dodd", 25 | "contributions": [ 26 | "doc", 27 | "review", 28 | "test", 29 | "bug" 30 | ] 31 | }, 32 | { 33 | "login": "matt-graham", 34 | "name": "Matt Graham", 35 | "avatar_url": "https://avatars.githubusercontent.com/u/6746980?v=4", 36 | "profile": "http://matt-graham.github.io", 37 | "contributions": [ 38 | "bug", 39 | "test", 40 | "review", 41 | "doc" 42 | ] 43 | }, 44 | { 45 | "login": "kazewong", 46 | "name": "Kaze Wong", 47 | "avatar_url": "https://avatars.githubusercontent.com/u/8803931?v=4", 48 | "profile": "https://www.kaze-wong.com/", 49 | "contributions": [ 50 | "bug", 51 | "blog", 52 | "code", 53 | "content", 54 | "doc", 55 | "example", 56 | "infra", 57 | "maintenance", 58 | "research", 59 | "review", 60 | "test", 61 
| "tutorial" 62 | ] 63 | }, 64 | { 65 | "login": "marylou-gabrie", 66 | "name": "Marylou Gabrié", 67 | "avatar_url": "https://avatars.githubusercontent.com/u/11092071?v=4", 68 | "profile": "https://marylou-gabrie.github.io/", 69 | "contributions": [ 70 | "bug", 71 | "code", 72 | "content", 73 | "doc", 74 | "example", 75 | "maintenance", 76 | "research", 77 | "test", 78 | "tutorial" 79 | ] 80 | }, 81 | { 82 | "login": "Qazalbash", 83 | "name": "Meesum Qazalbash", 84 | "avatar_url": "https://avatars.githubusercontent.com/u/62182585?v=4", 85 | "profile": "https://github.com/Qazalbash", 86 | "contributions": [ 87 | "code", 88 | "maintenance" 89 | ] 90 | }, 91 | { 92 | "login": "thomasckng", 93 | "name": "Thomas Ng", 94 | "avatar_url": "https://avatars.githubusercontent.com/u/97585527?v=4", 95 | "profile": "https://github.com/thomasckng", 96 | "contributions": [ 97 | "code", 98 | "maintenance" 99 | ] 100 | }, 101 | { 102 | "login": "tedwards2412", 103 | "name": "Thomas Edwards", 104 | "avatar_url": "https://avatars.githubusercontent.com/u/6105841?v=4", 105 | "profile": "https://github.com/tedwards2412", 106 | "contributions": [ 107 | "bug", 108 | "code" 109 | ] 110 | } 111 | ], 112 | "repoType": "github" 113 | } 114 | -------------------------------------------------------------------------------- /test/unit/test_resources.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flowMC.resource.buffers import Buffer 4 | from flowMC.resource.logPDF import LogPDF, Variable, TemperedPDF 5 | from flowMC.resource.kernel.MALA import MALA 6 | from flowMC.resource.states import State 7 | from flowMC.strategy.take_steps import TakeSerialSteps 8 | 9 | 10 | class TestLogPDF: 11 | 12 | n_dims = 5 13 | 14 | def posterior(self, x, data): 15 | return -jnp.sum(jnp.square(x - data["data"])) 16 | 17 | def test_value_and_grad(self): 18 | logpdf = LogPDF( 19 | self.posterior, [Variable("x_" + str(i), True) for i in range(self.n_dims)] 20 | ) 21 | inputs = jnp.arange(self.n_dims).astype(jnp.float32) 22 | data = {"data": jnp.ones(self.n_dims)} 23 | values, grads = jax.value_and_grad(logpdf)(inputs, data) 24 | assert values == self.posterior(inputs, data) 25 | 26 | def test_resource(self): 27 | mala = MALA(1.0) 28 | logpdf = LogPDF( 29 | self.posterior, [Variable("x_" + str(i), True) for i in range(self.n_dims)] 30 | ) 31 | rng_key = jax.random.PRNGKey(0) 32 | initial_position = jnp.zeros(self.n_dims) 33 | data = {"data": jnp.ones(self.n_dims)} 34 | sampler_state = State( 35 | { 36 | "target_positions": "test_positions", 37 | "target_log_probs": "test_log_probs", 38 | "target_acceptances": "test_acceptances", 39 | }, 40 | name="sampler_state", 41 | ) 42 | resources = { 43 | "test_positions": Buffer("test_positions", (self.n_dims, 1), 1), 44 | "test_log_probs": Buffer("test_log_probs", (self.n_dims, 1), 1), 45 | "test_acceptances": Buffer("test_acceptances", (self.n_dims, 1), 1), 46 | "MALA": mala, 47 | "logpdf": logpdf, 48 | "sampler_state": sampler_state, 49 | } 50 | stepper = TakeSerialSteps( 51 | "logpdf", 52 | "MALA", 53 | "sampler_state", 54 | ["target_positions", "target_log_probs", "target_acceptances"], 55 | 1, 56 | ) 57 | key, resources, positions = stepper(rng_key, resources, initial_position, data) 58 | 59 | def test_tempered_pdf(self): 60 | n_temps = 5 61 | logpdf = TemperedPDF( 62 | self.posterior, 63 | lambda x, data: jnp.zeros(1), 64 | n_dims=self.n_dims, 65 | n_temps=n_temps, 66 | max_temp=100, 67 | ) 68 | inputs = 
jnp.ones((n_temps, self.n_dims)).astype(jnp.float32) 69 | data = {"data": jnp.ones(self.n_dims)} 70 | temperatures = jnp.arange(n_temps) + 1.0 71 | values = jax.vmap(logpdf.tempered_log_pdf, in_axes=(0, 0, None))( 72 | temperatures, inputs, data 73 | ) 74 | assert ( 75 | values[:, 0] == jax.vmap(self.posterior, in_axes=(0, None))(inputs, data) 76 | ).all() 77 | assert values.shape == (5, 1) 78 | 79 | 80 | class TestBuffer: 81 | def test_buffer(self): 82 | buffer = Buffer("test", (10, 10), cursor_dim=1) 83 | assert buffer.name == "test" 84 | assert buffer.data.shape == (10, 10) 85 | assert buffer.cursor == 0 86 | assert buffer.cursor_dim == 1 87 | 88 | def test_update_buffer(self): 89 | buffer = Buffer("test", (10, 10), cursor_dim=0) 90 | buffer.update_buffer(jnp.ones((10, 10))) 91 | assert (buffer.data == jnp.ones((10, 10))).all() 92 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 | 
59 | ## Enforcement
60 | 
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | kazewong.physics@gmail.com.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 | 
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 | 
69 | 
70 | ## Attribution
71 | 
72 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
73 | version 2.0, available at
74 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
75 | 
76 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
77 | enforcement ladder](https://github.com/mozilla/diversity).
78 | 
79 | [homepage]: https://www.contributor-covenant.org
80 | 
81 | For answers to common questions about this code of conduct, see the FAQ at
82 | https://www.contributor-covenant.org/faq. Translations are available at
83 | https://www.contributor-covenant.org/translations.
84 | 
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: flowMC
2 | site_description: Normalizing-flow enhanced sampling package for probabilistic inference in Jax
3 | site_author: Kaze Wong
4 | repo_url: https://github.com/kazewong/flowMC
5 | repo_name: kazewong/flowMC
6 | 
7 | theme:
8 |   name: material
9 |   features:
10 |     - navigation  # Sections are included in the navigation on the left.
11 |     - toc  # Table of contents is integrated on the left; does not appear separately on the right.
12 |     - header.autohide  # header disappears as you scroll
13 |   palette:
14 |     # Light mode / dark mode
15 |     # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as
16 |     # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.
17 |     - scheme: default
18 |       primary: cyan
19 |       accent: deep purple
20 |       toggle:
21 |         icon: material/brightness-5
22 |         name: Dark mode
23 |     - scheme: slate
24 |       primary: custom
25 |       accent: purple
26 |       toggle:
27 |         icon: material/brightness-2
28 |         name: Light mode
29 | 
30 | twitter_name: "@physicskaze"
31 | twitter_url: "https://twitter.com/physicskaze"
32 | 
33 | markdown_extensions:
34 |   - pymdownx.arithmatex:  # Render LaTeX via MathJax
35 |       generic: true
36 |   - pymdownx.superfences  # Seems to enable syntax highlighting when used with the Material theme.
37 |   - pymdownx.details
38 |   - pymdownx.snippets:  # Include one Markdown file into another
39 |       base_path: docs
40 |   - admonition
41 |   - toc:
42 |       permalink: "¤"  # Adds a clickable permalink to each section heading
43 |       toc_depth: 4
44 | 
45 | plugins:
46 |   - search  # default search plugin; needs manually re-enabling when using any other plugins
47 |   - autorefs  # Cross-links to headings
48 |   - mkdocs-jupyter:  # Jupyter notebook support
49 |       # show_input: False
50 |   - gen-files:
51 |       scripts:
52 |         - docs/gen_ref_pages.py
53 |   - literate-nav:
54 |       nav_file: SUMMARY.md
55 |   - mkdocstrings:
56 |       handlers:
57 |         python:
58 |           setup_commands:
59 |             - import pytkdocs_tweaks
60 |             - pytkdocs_tweaks.main()
61 |             - import jaxtyping
62 |             - jaxtyping.set_array_name_format("array")
63 | 
64 |           optional:
65 |             docstring_style: google
66 |             inherited_members: true  # Allow looking up inherited methods
67 |             show_root_heading: true  # actually display anything at all...
68 |             show_root_full_path: true  # display "flowMC.asdf" not just "asdf"
69 |             show_if_no_docstring: true
70 |             show_signature_annotations: true
71 |             show_source: false  # don't include source code
72 |             members_order: source  # order methods according to their order of definition in the source code, not alphabetical order
73 |             heading_level: 4
74 | 
75 | 
76 | extra_css:
77 |   - stylesheets/extra.css
78 | 
79 | nav:
80 |   - Home: index.md
81 |   - Quickstart: quickstart.md
82 |   - Tutorial:
83 |       - Step-by-step guide: tutorials/dualmoon.ipynb
84 |       - Custom Strategy: tutorials/custom_strategy.ipynb
85 |       - Training Normalizing Flow: tutorials/train_normalizing_flow.ipynb
86 |       - Loading Resource: tutorials/loading_resources.ipynb
87 |   - Community Examples: communityExamples.md
88 |   - FAQ: FAQ.md
89 |   - Contribution: contribution.md
90 |   - API: api/flowMC/*
91 | 
--------------------------------------------------------------------------------
/src/flowMC/resource/kernel/MALA.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.stats import multivariate_normal
4 | from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, PyTree
5 | from typing import Callable
6 | 
7 | from flowMC.resource.logPDF import LogPDF
8 | from flowMC.resource.kernel.base import ProposalBase
9 | 
10 | 
11 | class MALA(ProposalBase):
12 |     """Metropolis-adjusted Langevin algorithm sampler class."""
13 | 
14 |     step_size: Float
15 | 
16 |     def __repr__(self):
17 |         return "MALA with step size " + str(self.step_size)
18 | 
19 |     def __init__(
20 |         self,
21 |         step_size: Float,
22 |     ):
23 |         super().__init__()
24 |         self.step_size = step_size
25 | 
26 |     def kernel(
27 |         self,
28 |         rng_key: PRNGKeyArray,
29 |         position: Float[Array, " n_dim"],
30 |         log_prob: Float[Array, "1"],
31 |         logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
32 |         data: PyTree,
33 |     ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
34 |         """Metropolis-adjusted Langevin algorithm kernel. This is a kernel that only
35 |         evolves a single chain.
36 | 
37 |         Args:
38 |             rng_key (PRNGKeyArray): Jax PRNGKey
39 |             position (Float[Array, " n_dim"]): current position of the chain
40 |             log_prob (Float[Array, "1"]): current log-probability of the chain
41 |             data (PyTree): data to be passed to the logpdf function
42 | 
43 |         Returns:
44 |             position (Float[Array, " n_dim"]): new position of the chain
45 |             log_prob (Float[Array, "1"]): new log-probability of the chain
46 |             do_accept (Int[Array, "1"]): whether the new position is accepted
47 |         """
48 | 
49 |         def body(
50 |             carry: tuple[Float[Array, " n_dim"], float, dict],
51 |             this_key: PRNGKeyArray,
52 |         ) -> tuple[
53 |             tuple[Float[Array, " n_dim"], float, dict],
54 |             tuple[Float[Array, " n_dim"], Float[Array, "1"], Float[Array, " n_dim"]],
55 |         ]:
56 |             print("Compiling MALA body")
57 |             this_position, dt, data = carry
58 |             dt2 = dt * dt
59 |             this_log_prob, this_d_log = jax.value_and_grad(logpdf)(this_position, data)
60 |             proposal = this_position + jnp.dot(dt2, this_d_log) / 2
61 |             proposal += jnp.dot(
62 |                 dt, jax.random.normal(this_key, shape=this_position.shape)
63 |             )
64 |             return (proposal, dt, data), (proposal, this_log_prob, this_d_log)
65 | 
66 |         key1, key2 = jax.random.split(rng_key)
67 | 
68 |         dt: Float = self.step_size
69 |         dt2 = dt * dt
70 | 
71 |         _, (proposal, logprob, d_logprob) = jax.lax.scan(
72 |             body, (position, dt, data), jnp.array([key1, key1])
73 |         )
74 | 
75 |         ratio = logprob[1] - logprob[0]
76 |         ratio -= multivariate_normal.logpdf(
77 |             proposal[0], position + jnp.dot(dt2, d_logprob[0]) / 2, dt2
78 |         )
79 |         ratio += multivariate_normal.logpdf(
80 |             position, proposal[0] + jnp.dot(dt2, d_logprob[1]) / 2, dt2
81 |         )
82 | 
83 |         log_uniform = jnp.log(jax.random.uniform(key2))
84 |         do_accept: Bool[Array, " n_dim"] = log_uniform < ratio
85 | 
86 |         position = jnp.where(do_accept, proposal[0], position)
87 |         log_prob = jnp.where(do_accept, logprob[1], logprob[0])
88 | 
89 |         return position, log_prob, do_accept
90 | 
91 |     def print_parameters(self):
92 |         print("MALA parameters:")
93 |         print(f"step_size: {self.step_size}")
94 | 
95 |     def save_resource(self, path):
96 |         raise NotImplementedError
97 | 
98 |     def load_resource(self, path):
99 |         raise NotImplementedError
100 | 
--------------------------------------------------------------------------------
/src/flowMC/resource/logPDF.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable, Optional
3 | from flowMC.resource.base import Resource
4 | from jaxtyping import Array, Float, PyTree
5 | import jax
6 | 
7 | 
8 | @dataclass
9 | class Variable:
10 |     """A dataclass that holds the information of a variable in the log-pdf function.
11 | 
12 |     The main purpose of this class is to let users name their variables,
13 |     and specify whether they are continuous or not.
14 |     """
15 | 
16 |     name: str
17 |     continuous: bool
18 | 
19 | 
20 | @jax.tree_util.register_pytree_node_class
21 | class LogPDF(Resource):
22 |     """A resource class that holds the log-pdf function.
23 |     The main purpose of this class is to wrap the log-pdf function into the unified Resource interface.
24 | 25 | Args: 26 | log_pdf (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]): The log-pdf function 27 | variables (list[Variable]): The list of variables in the log-pdf function 28 | """ 29 | 30 | log_pdf: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]] 31 | variables: list[Variable] 32 | 33 | @property 34 | def n_dims(self): 35 | return len(self.variables) 36 | 37 | def __repr__(self): 38 | return "LogPDF with " + str(self.n_dims) + " dimensions" 39 | 40 | def __init__( 41 | self, 42 | log_pdf: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 43 | variables: Optional[list[Variable]] = None, 44 | n_dims: Optional[int] = None, 45 | ): 46 | """ 47 | Args: 48 | log_pdf (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]): The log-pdf function 49 | variables (list[Variable], optional): The list of variables in the log-pdf function. Defaults to None. n_dims must be provided if this argument is None. 50 | n_dims (int, optional): The number of dimensions of the log-pdf function. Defaults to None. If variables is provided, this argument is ignored. 51 | """ 52 | self.log_pdf = log_pdf 53 | if variables is None and n_dims is not None: 54 | self.variables = [Variable("x_" + str(i), True) for i in range(n_dims)] 55 | elif variables is not None: 56 | self.variables = variables 57 | else: 58 | raise ValueError("Either variables or n_dims must be provided") 59 | 60 | def __call__(self, x: Float[Array, " n_dim"], data: PyTree) -> Float[Array, "1"]: 61 | return self.log_pdf(x, data) 62 | 63 | def print_parameters(self): 64 | print("LogPDF with variables:") 65 | for var in self.variables: 66 | print(var.name, var.continuous) 67 | 68 | def save_resource(self, path): 69 | raise NotImplementedError 70 | 71 | def load_resource(self, path): 72 | raise NotImplementedError 73 | 74 | def tree_flatten(self): 75 | children = () 76 | aux_data = (self.log_pdf, self.variables) 77 | return (children, aux_data) 78 | 79 | @classmethod 80 | def tree_unflatten(cls, aux_data, children): 81 | return cls(aux_data[0], aux_data[1]) 82 | 83 | 84 | @jax.tree_util.register_pytree_node_class 85 | class TemperedPDF(LogPDF): 86 | 87 | log_prior: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]] 88 | 89 | def __init__( 90 | self, 91 | log_likelihood: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 92 | log_prior: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 93 | variables=None, 94 | n_dims=None, 95 | n_temps=5, 96 | max_temp=100, 97 | ): 98 | super().__init__(log_likelihood, variables, n_dims) 99 | self.log_prior = log_prior 100 | 101 | def __call__(self, x, data): 102 | return super().__call__(x, data) 103 | 104 | def tempered_log_pdf(self, temperatures, x, data): 105 | base_pdf = super().__call__(x, data) 106 | return (1.0 / temperatures) * base_pdf + self.log_prior(x, data) 107 | 108 | def tree_flatten(self): # type: ignore 109 | children = () 110 | aux_data = (self.log_pdf, self.log_prior, self.variables) 111 | return (children, aux_data) 112 | 113 | @classmethod 114 | def tree_unflatten(cls, aux_data, children): 115 | return cls(*aux_data, *children) 116 | -------------------------------------------------------------------------------- /src/flowMC/strategy/train_model.py: -------------------------------------------------------------------------------- 1 | from flowMC.strategy.base import Strategy 2 | from flowMC.resource.base import Resource 3 | from flowMC.resource.buffers import Buffer 4 | from flowMC.resource.model.nf_model.base 
import NFModel 5 | from flowMC.resource.optimizer import Optimizer 6 | from jaxtyping import Array, Float, PRNGKeyArray 7 | import jax 8 | import jax.numpy as jnp 9 | 10 | 11 | class TrainModel(Strategy): 12 | model_resource: str 13 | data_resource: str 14 | optimizer_resource: str 15 | n_epochs: int 16 | batch_size: int 17 | n_max_examples: int 18 | verbose: bool 19 | thinning: int 20 | 21 | def __repr__(self): 22 | return "Train " + self.model_resource 23 | 24 | def __init__( 25 | self, 26 | model_resource: str, 27 | data_resource: str, 28 | optimizer_resource: str, 29 | loss_buffer_name: str = "", 30 | n_epochs: int = 100, 31 | batch_size: int = 64, 32 | n_max_examples: int = 10000, 33 | history_window: int = 100, 34 | verbose: bool = False, 35 | ): 36 | self.model_resource = model_resource 37 | self.data_resource = data_resource 38 | self.optimizer_resource = optimizer_resource 39 | self.loss_buffer_name = loss_buffer_name 40 | 41 | self.n_epochs = n_epochs 42 | self.batch_size = batch_size 43 | self.n_max_examples = n_max_examples 44 | self.verbose = verbose 45 | self.history_window = history_window 46 | 47 | def __call__( 48 | self, 49 | rng_key: PRNGKeyArray, 50 | resources: dict[str, Resource], 51 | initial_position: Float[Array, "n_chains n_dim"], 52 | data: dict, 53 | ) -> tuple[ 54 | PRNGKeyArray, 55 | dict[str, Resource], 56 | Float[Array, "n_chains n_dim"], 57 | ]: 58 | model = resources[self.model_resource] 59 | assert isinstance(model, NFModel), "Target resource must be a NFModel" 60 | data_resource = resources[self.data_resource] 61 | assert isinstance(data_resource, Buffer), "Data resource must be a buffer" 62 | optimizer = resources[self.optimizer_resource] 63 | assert isinstance( 64 | optimizer, Optimizer 65 | ), "Optimizer resource must be an optimizer" 66 | n_chains = data_resource.data.shape[0] 67 | n_dims = data_resource.data.shape[-1] 68 | training_data = data_resource.data[ 69 | jnp.isfinite(data_resource.data).all(axis=-1) 70 | ].reshape(n_chains, -1, n_dims) 71 | training_data = training_data[:, -self.history_window :].reshape(-1, n_dims) 72 | rng_key, subkey = jax.random.split(rng_key) 73 | training_data = training_data[ 74 | jax.random.choice( 75 | subkey, 76 | jnp.arange(training_data.shape[0]), 77 | shape=(self.n_max_examples,), 78 | replace=True, 79 | ) 80 | ] 81 | rng_key, subkey = jax.random.split(rng_key) 82 | 83 | if self.verbose: 84 | print("Training model") 85 | print(f"Training data shape: {training_data.shape}") 86 | print(f"n_epochs: {self.n_epochs}") 87 | print(f"batch_size: {self.batch_size}") 88 | 89 | (rng_key, model, optim_state, loss_values) = model.train( 90 | rng=subkey, 91 | data=training_data, 92 | optim=optimizer.optim, 93 | state=optimizer.optim_state, 94 | num_epochs=self.n_epochs, 95 | batch_size=self.batch_size, 96 | verbose=self.verbose, 97 | ) 98 | 99 | if self.loss_buffer_name != "": 100 | loss_buffer = resources[self.loss_buffer_name] 101 | assert isinstance( 102 | loss_buffer, Buffer 103 | ), "Loss buffer resource must be a buffer" 104 | loss_buffer.update_buffer(loss_values, start=loss_buffer.cursor) 105 | loss_buffer.cursor += len(loss_values) 106 | resources[self.loss_buffer_name] = loss_buffer 107 | 108 | optimizer.optim_state = optim_state 109 | resources[self.model_resource] = model 110 | resources[self.optimizer_resource] = optimizer 111 | # print(f"Training loss: {loss_values}") 112 | return rng_key, resources, initial_position 113 | -------------------------------------------------------------------------------- 
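For orientation, the following minimal sketch shows how `TrainModel` is wired to named resources. It is illustrative only: it assumes an `NFModel` instance `model` and a flowMC `Optimizer` resource `optimizer` constructed elsewhere, since their constructors are not shown in this listing.

```python
from flowMC.resource.buffers import Buffer
from flowMC.strategy.train_model import TrainModel

n_chains, n_steps, n_dims = 20, 50, 5

resources = {
    # Buffer of positions collected by a local sampler; TrainModel reads it
    # through the name passed as data_resource.
    "positions": Buffer("positions", (n_chains, n_steps, n_dims), 1),
    "model": model,          # assumed: an NFModel built elsewhere
    "optimizer": optimizer,  # assumed: an Optimizer holding optim/optim_state
}

train = TrainModel(
    model_resource="model",
    data_resource="positions",
    optimizer_resource="optimizer",
    n_epochs=100,
    batch_size=64,
    n_max_examples=10000,
    history_window=100,
)

# Within a Sampler run, the strategy is invoked as:
# rng_key, resources, positions = train(rng_key, resources, positions, data)
```

Note that strategies refer to resources purely by their string keys, so the same `TrainModel` instance can be reused against different buffers by renaming entries in the resource dictionary.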
/src/flowMC/Sampler.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jaxtyping import Array, Float, PRNGKeyArray
3 | from typing import Optional
4 | 
5 | from flowMC.strategy.base import Strategy
6 | from flowMC.resource.base import Resource
7 | from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle
8 | 
9 | 
10 | class Sampler:
11 |     """Top-level API that users primarily interact with.
12 | 
13 |     Args:
14 |         n_dim (int): Dimension of the parameter space.
15 |         n_chains (int): Number of chains to sample.
16 |         rng_key (PRNGKeyArray): Jax PRNGKey.
17 |         resources (dict[str, Resource]): Resources to be used by the sampler.
18 |         strategies (dict[str, Strategy]): Strategies to be used by the sampler.
19 |         strategy_order (list[str]): Order in which the strategies are executed.
20 |         resource_strategy_bundles (ResourceStrategyBundle): Bundle used when resources and strategies are not given explicitly.
21 |         verbose (bool): Whether to print out progress. Defaults to False.
22 |         logging (bool): Whether to log the progress. Defaults to True.
23 |         outdir (str): Directory to save the logs. Defaults to "./outdir/".
24 |     """
25 | 
26 |     # Essential parameters
27 |     n_dim: int
28 |     n_chains: int
29 |     rng_key: PRNGKeyArray
30 |     resources: dict[str, Resource]
31 |     strategies: dict[str, Strategy]
32 |     strategy_order: Optional[list[str]]
33 | 
34 |     # Logging hyperparameters
35 |     verbose: bool = False
36 |     logging: bool = True
37 |     outdir: str = "./outdir/"
38 | 
39 |     def __init__(
40 |         self,
41 |         n_dim: int,
42 |         n_chains: int,
43 |         rng_key: PRNGKeyArray,
44 |         resources: None | dict[str, Resource] = None,
45 |         strategies: None | dict[str, Strategy] = None,
46 |         strategy_order: None | list[str] = None,
47 |         resource_strategy_bundles: None | ResourceStrategyBundle = None,
48 |         **kwargs,
49 |     ):
50 |         # Copying input into the model
51 | 
52 |         self.n_dim = n_dim
53 |         self.n_chains = n_chains
54 |         self.rng_key = rng_key
55 | 
56 |         if resources is not None and strategies is not None:
57 |             print(
58 |                 "Resources and strategies provided. Ignoring resource strategy bundles."
59 |             )
60 |             self.resources = resources
61 |             self.strategies = strategies
62 |             self.strategy_order = strategy_order
63 | 
64 |         else:
65 |             print(
66 |                 "Resources or strategies not provided. Using resource strategy bundles."
67 |             )
68 |             if resource_strategy_bundles is None:
69 |                 raise ValueError(
70 |                     "Resource strategy bundles not provided. "
71 |                     "Please provide either resources and strategies or resource strategy bundles."
72 |                 )
73 |             self.resources = resource_strategy_bundles.resources
74 |             self.strategies = resource_strategy_bundles.strategies
75 |             self.strategy_order = resource_strategy_bundles.strategy_order
76 | 
77 |         # Set and override any given hyperparameters
78 |         class_keys = list(self.__class__.__dict__.keys())
79 |         for key, value in kwargs.items():
80 |             if key in class_keys:
81 |                 if not key.startswith("__"):
82 |                     setattr(self, key, value)
83 | 
84 |     def sample(self, initial_position: Float[Array, "n_chains n_dim"], data: dict):
85 |         """Sample from the posterior by running the strategies in ``strategy_order``.
86 | 
87 |         Args:
88 |             initial_position (Float[Array, "n_chains n_dim"]): Initial position.
89 |             data (dict): Data to be used by the likelihood functions
90 |         """
91 | 
92 |         initial_position = jnp.atleast_2d(initial_position)  # type: ignore
93 |         rng_key = self.rng_key
94 |         last_step = initial_position
95 |         assert isinstance(self.strategy_order, list)
96 |         for strategy in self.strategy_order:
97 |             if strategy not in self.strategies:
98 |                 raise ValueError(
99 |                     f"Invalid strategy name '{strategy}' provided. "
" 100 | f"Available strategies are: {list(self.strategies.keys())}." 101 | ) 102 | ( 103 | rng_key, 104 | self.resources, 105 | last_step, 106 | ) = self.strategies[ 107 | strategy 108 | ](rng_key, self.resources, last_step, data) 109 | 110 | # TODO: Implement quick access and summary functions that operates on buffer 111 | 112 | def serialize(self): 113 | """Serialize the sampler object.""" 114 | raise NotImplementedError 115 | 116 | def deserialize(self): 117 | """Deserialize the sampler object.""" 118 | raise NotImplementedError 119 | -------------------------------------------------------------------------------- /src/flowMC/resource/kernel/HMC.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree 6 | 7 | from flowMC.resource.kernel.base import ProposalBase 8 | from flowMC.resource.logPDF import LogPDF 9 | 10 | 11 | class HMC(ProposalBase): 12 | """Hamiltonian Monte Carlo sampler class builiding the hmc_sampler method from 13 | target logpdf. 14 | 15 | Args: 16 | logpdf: target logpdf function 17 | jit: whether to jit the sampler 18 | params: dictionary of parameters for the sampler 19 | """ 20 | 21 | condition_matrix: Float[Array, " n_dim n_dim"] 22 | step_size: Float 23 | leapfrog_coefs: Float[Array, " n_leapfrog n_dim"] 24 | 25 | @property 26 | def n_leapfrog(self) -> Int: 27 | return self.leapfrog_coefs.shape[0] - 2 28 | 29 | def __repr__(self): 30 | return ( 31 | "HMC with step size " 32 | + str(self.step_size) 33 | + " and " 34 | + str(self.n_leapfrog) 35 | + " leapfrog steps" 36 | ) 37 | 38 | def __init__( 39 | self, 40 | condition_matrix: Float[Array, " n_dim n_dim"] | Float = 1, 41 | step_size: Float = 0.1, 42 | n_leapfrog: Int = 10, 43 | ): 44 | self.condition_matrix = condition_matrix 45 | self.step_size = step_size 46 | 47 | coefs = jnp.ones((n_leapfrog + 2, 2)) 48 | coefs = coefs.at[0].set(jnp.array([0, 0.5])) 49 | coefs = coefs.at[-1].set(jnp.array([1, 0.5])) 50 | self.leapfrog_coefs = coefs 51 | 52 | def get_initial_hamiltonian( 53 | self, 54 | potential: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 55 | kinetic: Callable[ 56 | [Float[Array, " n_dim"], Float[Array, " n_dim"]], Float[Array, "1"] 57 | ], 58 | rng_key: PRNGKeyArray, 59 | position: Float[Array, " n_dim"], 60 | data: PyTree, 61 | ): 62 | """Compute the value of the Hamiltonian from positions with initial momentum 63 | draw at random from the standard normal distribution.""" 64 | 65 | momentum = ( 66 | jax.random.normal(rng_key, shape=position.shape) 67 | * self.condition_matrix**-0.5 68 | ) 69 | return potential(position, data) + kinetic(momentum, self.condition_matrix) 70 | 71 | def leapfrog_kernel(self, kinetic, potential, carry, extras): 72 | position, momentum, data, metric, index = carry 73 | position = position + self.step_size * self.leapfrog_coefs[index][0] * jax.grad( 74 | kinetic 75 | )(momentum, metric) 76 | momentum = momentum - self.step_size * self.leapfrog_coefs[index][1] * jax.grad( 77 | potential 78 | )(position, data) 79 | index = index + 1 80 | return (position, momentum, data, metric, index), extras 81 | 82 | def leapfrog_step( 83 | self, 84 | leapfrog_kernel: Callable, 85 | position: Float[Array, " n_dim"], 86 | momentum: Float[Array, " n_dim"], 87 | data: PyTree, 88 | metric: Float[Array, " n_dim n_dim"], 89 | ) -> tuple[Float[Array, " n_dim"], Float[Array, " n_dim"]]: 90 | print("Compiling 
leapfrog step") 91 | (position, momentum, data, metric, index), _ = jax.lax.scan( 92 | leapfrog_kernel, 93 | (position, momentum, data, metric, 0), 94 | jnp.arange(self.n_leapfrog + 2), 95 | ) 96 | return position, momentum 97 | 98 | def kernel( 99 | self, 100 | rng_key: PRNGKeyArray, 101 | position: Float[Array, " n_dim"], 102 | log_prob: Float[Array, "1"], 103 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 104 | data: PyTree, 105 | ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]: 106 | """Note that since the potential function is the negative log likelihood, 107 | hamiltonian is going down, but the likelihood value should go up. 108 | 109 | Args: 110 | rng_key (n_chains, 2): random key 111 | position (n_chains, n_dim): current position 112 | PE (n_chains, ): Potential energy of the current position 113 | """ 114 | 115 | def potential(x: Float[Array, " n_dim"], data: PyTree) -> Float[Array, "1"]: 116 | return -logpdf(x, data) 117 | 118 | def kinetic( 119 | p: Float[Array, " n_dim"], metric: Float[Array, " n_dim"] 120 | ) -> Float[Array, "1"]: 121 | return 0.5 * (p**2 * metric).sum() 122 | 123 | leapfrog_kernel = jax.tree_util.Partial( 124 | self.leapfrog_kernel, kinetic, potential 125 | ) 126 | leapfrog_step = jax.tree_util.Partial(self.leapfrog_step, leapfrog_kernel) 127 | 128 | key1, key2 = jax.random.split(rng_key) 129 | 130 | momentum: Float[Array, " n_dim"] = ( 131 | jax.random.normal(key1, shape=position.shape) * self.condition_matrix**-0.5 132 | ) 133 | momentum = jnp.dot( 134 | jax.random.normal(key1, shape=position.shape), 135 | jnp.linalg.cholesky(jnp.linalg.inv(self.condition_matrix)).T, 136 | ) 137 | H = -log_prob + kinetic(momentum, self.condition_matrix) 138 | proposed_position, proposed_momentum = leapfrog_step( 139 | position, momentum, data, self.condition_matrix 140 | ) 141 | proposed_PE = potential(proposed_position, data) 142 | proposed_ham = proposed_PE + kinetic(proposed_momentum, self.condition_matrix) 143 | log_acc = H - proposed_ham 144 | log_uniform = jnp.log(jax.random.uniform(key2)) 145 | 146 | do_accept = log_uniform < log_acc 147 | 148 | position = jnp.where(do_accept, proposed_position, position) # type: ignore 149 | log_prob = jnp.where(do_accept, -proposed_PE, log_prob) # type: ignore 150 | 151 | return position, log_prob, do_accept 152 | 153 | def print_parameters(self): 154 | print("HMC parameters:") 155 | print(f"step_size: {self.step_size}") 156 | print(f"n_leapfrog: {self.n_leapfrog}") 157 | print(f"condition_matrix: {self.condition_matrix}") 158 | 159 | def save_resource(self, path): 160 | raise NotImplementedError 161 | 162 | def load_resource(self, path): 163 | raise NotImplementedError 164 | -------------------------------------------------------------------------------- /src/flowMC/strategy/optimization.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import optax 6 | from jaxtyping import Array, Float, PRNGKeyArray 7 | 8 | from flowMC.strategy.base import Strategy 9 | from flowMC.resource.base import Resource 10 | 11 | 12 | class AdamOptimization(Strategy): 13 | """Optimize a set of chains using Adam optimization. Note that if the posterior can 14 | go to infinity, this optimization scheme is likely to return NaNs. 15 | 16 | Args: 17 | n_steps: int = 100 18 | Number of optimization steps. 19 | learning_rate: float = 1e-2 20 | Learning rate for the optimization. 
21 | noise_level: float = 10 22 | Variance of the noise added to the gradients. 23 | bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) 24 | Bounds for the optimization. The optimization will be projected to these bounds. 25 | If bounds has shape (1, 2), it will be broadcast to all dimensions. For n_dim > 1, 26 | passing a (1, 2) array applies the same bound to every dimension. To specify different 27 | bounds per dimension, provide an array of shape (n_dim, 2). 28 | """ 29 | 30 | logpdf: Callable[[Float[Array, " n_dim"], dict], Float] 31 | n_steps: int = 100 32 | learning_rate: float = 1e-2 33 | noise_level: float = 10 34 | bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) 35 | 36 | def __repr__(self): 37 | return "AdamOptimization" 38 | 39 | def __init__( 40 | self, 41 | logpdf: Callable[[Float[Array, " n_dim"], dict], Float], 42 | n_steps: int = 100, 43 | learning_rate: float = 1e-2, 44 | noise_level: float = 10, 45 | bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]), 46 | ): 47 | self.logpdf = logpdf 48 | self.n_steps = n_steps 49 | self.learning_rate = learning_rate 50 | self.noise_level = noise_level 51 | self.bounds = bounds 52 | 53 | # Validate bounds shape 54 | if bounds.ndim != 2 or bounds.shape[1] != 2: 55 | raise ValueError( 56 | f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}" 57 | ) 58 | # If bounds is (1, 2), it will be broadcast to all dimensions. If not, check compatibility. 59 | # Try to infer n_dim from logpdf signature or initial_position, but here we can't, so warn in runtime. 60 | 61 | self.solver = optax.chain( 62 | optax.adam(learning_rate=self.learning_rate), 63 | ) 64 | 65 | def __call__( 66 | self, 67 | rng_key: PRNGKeyArray, 68 | resources: dict[str, Resource], 69 | initial_position: Float[Array, " n_chain n_dim"], 70 | data: dict, 71 | ) -> tuple[ 72 | PRNGKeyArray, 73 | dict[str, Resource], 74 | Float[Array, " n_chain n_dim"], 75 | ]: 76 | def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float: 77 | return -self.logpdf(params, data) 78 | 79 | rng_key, optimized_positions, _ = self.optimize( 80 | rng_key, loss_fn, initial_position, data 81 | ) 82 | 83 | return rng_key, resources, optimized_positions 84 | 85 | def optimize( 86 | self, 87 | rng_key: PRNGKeyArray, 88 | objective: Callable, 89 | initial_position: Float[Array, " n_chain n_dim"], 90 | data: dict, 91 | ): 92 | # Validate bounds shape against n_dim 93 | n_dim = initial_position.shape[-1] 94 | if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim): 95 | raise ValueError( 96 | f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. " 97 | "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds." 98 | ) 99 | 100 | """Optimization kernel. This can be used independently of the __call__ method. 101 | 102 | Args: 103 | rng_key: PRNGKeyArray 104 | Random key for the optimization. 105 | objective: Callable 106 | Objective function to optimize. 107 | initial_position: Float[Array, " n_chain n_dim"] 108 | Initial positions for the optimization. 109 | data: dict 110 | Data to pass to the objective function. 111 | 112 | Returns: 113 | rng_key: PRNGKeyArray 114 | Updated random key. 115 | optimized_positions: Float[Array, " n_chain n_dim"] 116 | Optimized positions. 117 | final_log_prob: Float[Array, " n_chain"] 118 | Final log-probabilities of the optimized positions. 
119 | """ 120 | grad_fn = jax.jit(jax.grad(objective)) 121 | 122 | def _kernel(carry, _step): 123 | key, params, opt_state = carry 124 | 125 | key, subkey = jax.random.split(key) 126 | grad = grad_fn(params, data) * ( 127 | 1 + jax.random.normal(subkey) * self.noise_level 128 | ) 129 | updates, opt_state = self.solver.update(grad, opt_state, params) 130 | params = optax.apply_updates(params, updates) 131 | params = optax.projections.projection_box( 132 | params, self.bounds[:, 0], self.bounds[:, 1] 133 | ) 134 | return (key, params, opt_state), None 135 | 136 | def _single_optimize( 137 | key: PRNGKeyArray, 138 | initial_position: Float[Array, " n_dim"], 139 | ) -> Float[Array, " n_dim"]: 140 | opt_state = self.solver.init(initial_position) 141 | 142 | (key, params, opt_state), _ = jax.lax.scan( 143 | _kernel, 144 | (key, initial_position, opt_state), 145 | jnp.arange(self.n_steps), 146 | ) 147 | 148 | return params # type: ignore 149 | 150 | print("Using Adam optimization") 151 | rng_key, subkey = jax.random.split(rng_key) 152 | keys = jax.random.split(subkey, initial_position.shape[0]) 153 | optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))( 154 | keys, initial_position 155 | ) 156 | 157 | final_log_prob = jax.vmap(self.logpdf, in_axes=(0, None))( 158 | optimized_positions, data 159 | ) 160 | 161 | if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any(): 162 | print("Warning: Optimization accessed infinite or NaN log-probabilities.") 163 | 164 | return rng_key, optimized_positions, final_log_prob 165 | -------------------------------------------------------------------------------- /joss/paper.bib: -------------------------------------------------------------------------------- 1 | @article{Gabrie2021, 2 | author = {Gabri{\'{e}}, Marylou and Rotskoff, Grant M. and Vanden-Eijnden, Eric}, 3 | doi = {10.1073/pnas.2109420119}, 4 | eprint = {2105.12603}, 5 | issn = {0027-8424}, 6 | journal = {Proceedings of the National Academy of Sciences}, 7 | mendeley-groups = {publication list,project-capstone-EBM-NF,project-ag-clusters,paper-mixedkernels,proposal-louis-assistant,project-elena,proposal-hiparis-whitebook}, 8 | month = {mar}, 9 | number = {10}, 10 | title = {{Adaptive Monte Carlo augmented with normalizing flows}}, 11 | url = {https://pnas.org/doi/full/10.1073/pnas.2109420119}, 12 | volume = {119}, 13 | year = {2022} 14 | } 15 | @inproceedings{Hoffman2019, 16 | archivePrefix = {arXiv}, 17 | arxivId = {1903.03704}, 18 | author = {Hoffman, Matthew D and Sountsov, Pavel and Dillon, Joshua V. and Langmore, Ian and Tran, Dustin and Vasudevan, Srinivas}, 19 | booktitle = {1st Symposium on Advances in Approximate Bayesian Inference, 2018 1–5}, 20 | eprint = {1903.03704}, 21 | title = {{NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport}}, 22 | url = {http://arxiv.org/abs/1903.03704}, 23 | year = {2019} 24 | } 25 | @article{Albergo2019, 26 | author = {Albergo, M. S. and Kanwar, G. and Shanahan, P. 
E.}, 27 | doi = {10.1103/PhysRevD.100.034515}, 28 | eprint = {1904.12072}, 29 | issn = {2470-0010}, 30 | journal = {Physical Review D}, 31 | keywords = {doi:10.1103/PhysRevD.100.034515 url:https://doi.or}, 32 | month = {aug}, 33 | number = {3}, 34 | pages = {034515}, 35 | publisher = {American Physical Society}, 36 | title = {{Flow-based generative models for Markov chain Monte Carlo in lattice field theory}}, 37 | url = {https://link.aps.org/doi/10.1103/PhysRevD.100.034515}, 38 | volume = {100}, 39 | year = {2019} 40 | } 41 | @inproceedings{Gabrie2021a, 42 | author = {Gabri{\'{e}}, Marylou and Rotskoff, Grant M. and Vanden-Eijnden, Eric}, 43 | booktitle = {Invertible Neural Networks, NormalizingFlows, and Explicit Likelihood Models (ICML Workshop).}, 44 | title = {{Efficient Bayesian Sampling Using Normalizing Flows to Assist Markov Chain Monte Carlo Methods}}, 45 | url = {https://arxiv.org/abs/2107.08001}, 46 | year = {2021} 47 | } 48 | @article{Kobyzev2021, 49 | archivePrefix = {arXiv}, 50 | arxivId = {1908.09257}, 51 | author = {Kobyzev, Ivan and Prince, Simon J.D. and Brubaker, Marcus A.}, 52 | doi = {10.1109/TPAMI.2020.2992934}, 53 | eprint = {1908.09257}, 54 | issn = {19393539}, 55 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, 56 | keywords = {Generative models,density estimation,invertible neural networks,normalizing flows,variational inference}, 57 | number = {11}, 58 | pages = {3964--3979}, 59 | pmid = {32396070}, 60 | publisher = {IEEE}, 61 | title = {{Normalizing Flows: An Introduction and Review of Current Methods}}, 62 | volume = {43}, 63 | year = {2021} 64 | } 65 | @article{Papamakarios2019, 66 | archivePrefix = {arXiv}, 67 | arxivId = {1912.02762}, 68 | author = {Papamakarios, George and Nalisnick, Eric and Rezende, Danilo Jimenez and Mohamed, Shakir and Lakshminarayanan, Balaji}, 69 | eprint = {1912.02762}, 70 | journal = {Journal of Machine Learning Research}, 71 | number = {57}, 72 | pages = {1--64}, 73 | title = {{Normalizing Flows for Probabilistic Modeling and Inference}}, 74 | url = {https://jmlr.org/papers/v22/19-1028.html}, 75 | volume = {22}, 76 | year = {2021} 77 | } 78 | @article{Nicoli2020, 79 | archivePrefix = {arXiv}, 80 | arxivId = {1910.13496}, 81 | author = {Nicoli, Kim A. and Nakajima, Shinichi and Strodthoff, Nils and Samek, Wojciech and M{\"{u}}ller, Klaus Robert and Kessel, Pan}, 82 | doi = {10.1103/PhysRevE.101.023304}, 83 | eprint = {1910.13496}, 84 | issn = {24700053}, 85 | journal = {Physical Review E}, 86 | mendeley-groups = {project-potential-based-learning,paper-mixedkernels,paper-neutravsflex}, 87 | number = {2}, 88 | pmid = {32168605}, 89 | title = {{Asymptotically unbiased estimation of physical observables with neural samplers}}, 90 | volume = {101}, 91 | year = {2020} 92 | } 93 | @article{McNaughton2020, 94 | archivePrefix = {arXiv}, 95 | arxivId = {2002.04292}, 96 | author = {McNaughton, B. and Milo{\v{s}}evi{\'{c}}, M. V. and Perali, A. 
and Pilati, S.}, 97 | doi = {10.1103/PhysRevE.101.053312}, 98 | eprint = {2002.04292}, 99 | issn = {24700053}, 100 | journal = {Physical Review E}, 101 | mendeley-groups = {project-potential-based-learning,paper-mixedkernels,proposal-hiparis-whitebook,paper-neutravsflex}, 102 | number = {Mc}, 103 | pages = {1--13}, 104 | pmid = {32575304}, 105 | title = {{Boosting Monte Carlo simulations of spin glasses using autoregressive neural networks}}, 106 | url = {http://arxiv.org/abs/2002.04292}, 107 | volume = {101}, 108 | year = {2020} 109 | } 110 | @article{Naesseth2020, 111 | archivePrefix = {arXiv}, 112 | arxivId = {2003.10374}, 113 | author = {Naesseth, Christian A. and Lindsten, Fredrik and Blei, David}, 114 | eprint = {2003.10374}, 115 | file = {:Users/marylou/Dropbox/PhD/Literature/2003.10374.pdf:pdf}, 116 | issn = {10495258}, 117 | journal = {Advances in Neural Information Processing Systems}, 118 | mendeley-groups = {paper-aistat-ergofloancao,paper-mixedkernels}, 119 | number = {MCMC}, 120 | title = {{Markovian score climbing: Variational inference with KL(p||q)}}, 121 | volume = {2020-Decem}, 122 | year = {2020} 123 | } 124 | @article{Andrieu2008, 125 | author = {Andrieu, Christophe and Thoms, Johannes}, 126 | doi = {10.1007/s11222-008-9110-y}, 127 | file = {:Users/marylou/Library/Application Support/Mendeley Desktop/Downloaded/Andrieu, Thoms - 2008 - A tutorial on adaptive MCMC.pdf:pdf}, 128 | issn = {09603174}, 129 | journal = {Statistics and Computing}, 130 | keywords = {Adaptive MCMC,Controlled Markov chain,MCMC,Stochastic approximation}, 131 | number = {4}, 132 | pages = {343--373}, 133 | title = {{A tutorial on adaptive MCMC}}, 134 | volume = {18}, 135 | year = {2008} 136 | } 137 | @article{bingham2019pyro, 138 | author = {Eli Bingham and 139 | Jonathan P. Chen and 140 | Martin Jankowiak and 141 | Fritz Obermeyer and 142 | Neeraj Pradhan and 143 | Theofanis Karaletsos and 144 | Rohit Singh and 145 | Paul A. Szerlip and 146 | Paul Horsfall and 147 | Noah D. Goodman}, 148 | title = {Pyro: Deep Universal Probabilistic Programming}, 149 | journal = {J. Mach. Learn. Res.}, 150 | volume = {20}, 151 | pages = {28:1--28:6}, 152 | year = {2019}, 153 | url = {http://jmlr.org/papers/v20/18-403.html} 154 | } 155 | @article{Karamanis2022, 156 | archivePrefix = {arXiv}, 157 | arxivId = {2207.05652}, 158 | author = {Karamanis, Minas and Beutler, Florian and Peacock, John A. 
and Nabergoj, David and Seljak, Uros}, 159 | eprint = {2207.05652}, 160 | file = {:Users/marylou/Dropbox/PhD/Literature/2207.05652.pdf:pdf}, 161 | journal = {arXiv preprint}, 162 | keywords = {cosmology,data analysis,large-scale structure of universe,methods,statistical}, 163 | title = {{Accelerating astronomical and cosmological inference with Preconditioned Monte Carlo}}, 164 | url = {http://arxiv.org/abs/2207.05652}, 165 | volume = {2207.05652}, 166 | year = {2022} 167 | } 168 | -------------------------------------------------------------------------------- /src/flowMC/resource/kernel/NF_proposal.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import random 6 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree 7 | from typing import Callable 8 | import equinox as eqx 9 | 10 | from flowMC.resource.model.nf_model.base import NFModel 11 | from flowMC.resource.kernel.base import ProposalBase 12 | from flowMC.resource.logPDF import LogPDF 13 | 14 | 15 | class NFProposal(ProposalBase): 16 | model: NFModel 17 | n_batch_size: int 18 | 19 | def __repr__(self): 20 | return "NF proposal with " + self.model.__repr__() 21 | 22 | def __init__(self, model: NFModel, n_NFproposal_batch_size: int = 100): 23 | super().__init__() 24 | self.model = model 25 | self.n_batch_size = n_NFproposal_batch_size 26 | 27 | def kernel( 28 | self, 29 | rng_key: PRNGKeyArray, 30 | position: Float[Array, " n_dim"], 31 | log_prob: Float[Array, "1"], 32 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]], 33 | data: PyTree, 34 | ) -> tuple[ 35 | Float[Array, "n_step n_dim"], Float[Array, "n_step 1"], Int[Array, "n_step 1"] 36 | ]: 37 | 38 | print("Compiling NF proposal kernel") 39 | n_steps = data["n_steps"] 40 | 41 | rng_key, subkey = random.split(rng_key) 42 | 43 | # nf_current is size (1, n_dim) 44 | log_prob_nf_current = eqx.filter_jit(self.model.log_prob)(position) 45 | 46 | # All these are size (n_steps, n_dim) 47 | proposed_position, log_prob_nf_proposed = eqx.filter_jit(self.sample_flow)( 48 | subkey, n_steps 49 | ) 50 | if n_steps > self.n_batch_size: 51 | n_batch = ceil(proposed_position.shape[0] / self.n_batch_size) 52 | batched_proposed_position = proposed_position[ 53 | : (n_batch - 1) * self.n_batch_size 54 | ].reshape(n_batch - 1, self.n_batch_size, self.model.n_features) 55 | 56 | def scan_sample( 57 | carry, 58 | aux, 59 | ): 60 | proposed_position = aux 61 | return carry, jax.vmap(logpdf, in_axes=(0, None))( 62 | proposed_position, data 63 | ) 64 | 65 | _, log_prob_proposed = jax.lax.scan( 66 | scan_sample, 67 | (), 68 | batched_proposed_position, 69 | ) 70 | log_prob_proposed = log_prob_proposed.reshape(-1) 71 | log_prob_proposed = jnp.concatenate( 72 | ( 73 | log_prob_proposed, 74 | jax.vmap(logpdf, in_axes=(0, None))( 75 | jax.lax.dynamic_slice_in_dim( 76 | proposed_position, 77 | (n_batch - 1) * self.n_batch_size, 78 | n_steps - (n_batch - 1) * self.n_batch_size, 79 | ), 80 | data, 81 | ), 82 | ), 83 | axis=0, 84 | ) 85 | 86 | else: 87 | log_prob_proposed = jax.vmap(logpdf, in_axes=(0, None))( 88 | proposed_position, data 89 | ) 90 | 91 | def body(carry, data): 92 | ( 93 | rng_key, 94 | position_initial, 95 | log_prob_initial, 96 | log_prob_nf_initial, 97 | ) = carry 98 | (position_proposal, log_prob_proposal, log_prob_nf_proposal) = data 99 | rng_key, subkey = random.split(rng_key) 100 | ratio = (log_prob_proposal - log_prob_initial) - ( 
101 | log_prob_nf_proposal - log_prob_nf_initial 102 | ) 103 | uniform_random = jnp.log(jax.random.uniform(subkey)) 104 | do_accept = uniform_random < ratio 105 | position_current = jnp.where(do_accept, position_proposal, position_initial) 106 | log_prob_current = jnp.where(do_accept, log_prob_proposal, log_prob_initial) 107 | log_prob_nf_current = jnp.where( 108 | do_accept, log_prob_nf_proposal, log_prob_nf_initial 109 | ) 110 | 111 | return (rng_key, position_current, log_prob_current, log_prob_nf_current), ( 112 | position_current, 113 | log_prob_current, 114 | do_accept, 115 | ) 116 | 117 | _, (positions, log_prob, do_accept) = jax.lax.scan( 118 | body, 119 | ( 120 | rng_key, 121 | position, 122 | log_prob, 123 | log_prob_nf_current, 124 | ), 125 | (proposed_position, log_prob_proposed, log_prob_nf_proposed), 126 | ) 127 | 128 | return positions, log_prob, do_accept 129 | 130 | def sample_flow( 131 | self, 132 | rng_key: PRNGKeyArray, 133 | n_steps: int, 134 | ): 135 | if n_steps > self.n_batch_size: 136 | rng_key = rng_key 137 | n_batch = ceil(n_steps / self.n_batch_size) 138 | n_sample = ceil(n_steps / n_batch) 139 | (dynamic, static) = eqx.partition(self.model, eqx.is_array) 140 | 141 | def scan_sample( 142 | carry: tuple[PRNGKeyArray, NFModel], 143 | data, 144 | ): 145 | print("Compiling sample_flow") 146 | rng_key, model = carry 147 | rng_key, subkey = random.split(rng_key) 148 | combined = eqx.combine(model, static) 149 | proposal_position = combined.sample(subkey, n_samples=n_sample) 150 | proposed_log_prob = eqx.filter_vmap(combined.log_prob)( 151 | proposal_position 152 | ) 153 | return (rng_key, model), (proposal_position, proposed_log_prob) 154 | 155 | _, (proposal_position, proposed_log_prob) = jax.lax.scan( 156 | scan_sample, 157 | (rng_key, dynamic), 158 | length=n_batch, 159 | ) 160 | proposal_position = proposal_position.reshape(-1, self.model.n_features)[ 161 | :n_steps 162 | ] 163 | proposed_log_prob = proposed_log_prob.reshape(-1)[:n_steps] 164 | 165 | else: 166 | proposal_position = self.model.sample(rng_key, n_steps) 167 | proposed_log_prob = eqx.filter_vmap(self.model.log_prob)(proposal_position) 168 | 169 | proposal_position = proposal_position.reshape(n_steps, self.model.n_features) 170 | proposed_log_prob = proposed_log_prob.reshape(n_steps) 171 | 172 | return proposal_position, proposed_log_prob 173 | 174 | def print_parameters(self): 175 | # TODO: Implement this 176 | raise NotImplementedError 177 | 178 | def save_resource(self, path): 179 | # TODO: Implement this 180 | raise NotImplementedError 181 | 182 | def load_resource(self, path): 183 | # TODO: Implement this 184 | raise NotImplementedError 185 | -------------------------------------------------------------------------------- /test/unit/test_flowmatching.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pytest 4 | 5 | from flowMC.resource.model.flowmatching.base import ( 6 | FlowMatchingModel, 7 | Solver, 8 | Path, 9 | CondOTScheduler, 10 | ) 11 | from flowMC.resource.model.common import MLP 12 | from diffrax import Dopri5 13 | import equinox as eqx 14 | import optax 15 | 16 | 17 | def get_simple_mlp(n_input, n_hidden, n_output, key): 18 | """Simple 2-layer MLP for testing.""" 19 | shape = ( 20 | [n_input] 21 | + ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden)) 22 | + [n_output] 23 | ) 24 | return MLP(shape=shape, key=key, activation=jax.nn.swish) 25 | 26 | 27 | ############################## 28 | # 
Solver Tests 29 | ############################## 30 | 31 | 32 | class TestSolver: 33 | @pytest.fixture 34 | def solver(self): 35 | key = jax.random.PRNGKey(0) 36 | n_dim = 3 37 | n_hidden = 4 38 | mlp = get_simple_mlp( 39 | n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key 40 | ) 41 | return Solver(model=mlp, method=Dopri5()), key, n_dim 42 | 43 | def test_sample_shape_and_finiteness(self, solver): 44 | solver, key, n_dim = solver 45 | n_samples = 7 46 | samples = solver.sample(key, n_samples) 47 | assert samples.shape == (n_samples, n_dim) 48 | assert jnp.isfinite(samples).all() 49 | 50 | def test_log_prob_shape_and_finiteness(self, solver): 51 | solver, key, n_dim = solver 52 | x1 = jax.random.normal(key, (n_dim,)) 53 | logp = solver.log_prob(x1) 54 | logp_arr = jnp.asarray(logp) 55 | assert logp_arr.size == 1 56 | assert jnp.isfinite(logp_arr).all() 57 | 58 | @pytest.mark.parametrize("dt", [1e-2, 1e-1, 0.5]) 59 | def test_sample_various_dt(self, solver, dt): 60 | solver, key, n_dim = solver 61 | samples = solver.sample(key, 3, dt=dt) 62 | assert samples.shape == (3, n_dim) 63 | assert jnp.isfinite(samples).all() 64 | 65 | 66 | ############################## 67 | # Path & Scheduler Tests 68 | ############################## 69 | 70 | 71 | class TestPathAndScheduler: 72 | def test_path_sample_shapes_and_values(self): 73 | n_dim = 2 74 | scheduler = CondOTScheduler() 75 | path = Path(scheduler=scheduler) 76 | x0 = jnp.ones((5, n_dim)) 77 | x1 = jnp.zeros((5, n_dim)) 78 | for t_val in [0.0, 0.5, 1.0]: 79 | t = jnp.full((5, 1), t_val) 80 | x_t, dx_t = path.sample(x0, x1, t) 81 | assert x_t.shape == (5, n_dim) 82 | assert dx_t.shape == (5, n_dim) 83 | 84 | @pytest.mark.parametrize("t", [0.0, 1.0, 0.5, -0.1, 1.1]) 85 | def test_condotscheduler_call_output(self, t): 86 | sched = CondOTScheduler() 87 | out = sched(jnp.array(t)) 88 | assert isinstance(out, tuple) 89 | assert len(out) == 4 90 | assert all(isinstance(float(x), float) for x in out) 91 | 92 | 93 | ############################## 94 | # FlowMatchingModel Tests 95 | ############################## 96 | 97 | 98 | class TestFlowMatchingModel: 99 | @pytest.fixture 100 | def model(self): 101 | key = jax.random.PRNGKey(42) 102 | n_dim = 2 103 | n_hidden = 8 104 | mlp = get_simple_mlp( 105 | n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key 106 | ) 107 | solver = Solver(model=mlp, method=Dopri5()) 108 | scheduler = CondOTScheduler() 109 | path = Path(scheduler=scheduler) 110 | model = FlowMatchingModel( 111 | solver=solver, 112 | path=path, 113 | data_mean=jnp.zeros(n_dim), 114 | data_cov=jnp.eye(n_dim), 115 | ) 116 | return model, key, n_dim 117 | 118 | def test_sample_and_log_prob(self, model): 119 | model, key, n_dim = model 120 | n_samples = 4 121 | samples = model.sample(key, n_samples) 122 | assert samples.shape == (n_samples, n_dim) 123 | assert jnp.isfinite(samples).all() 124 | logp = eqx.filter_vmap(model.log_prob)(samples) 125 | assert logp.shape == (n_samples, 1) 126 | assert jnp.isfinite(logp).all() 127 | 128 | @pytest.mark.parametrize("n_samples", [1, 5, 10]) 129 | def test_sample_various_shapes(self, model, n_samples): 130 | model, key, n_dim = model 131 | samples = model.sample(key, n_samples) 132 | assert samples.shape == (n_samples, n_dim) 133 | assert jnp.isfinite(samples).all() 134 | logp = eqx.filter_vmap(model.log_prob)(samples) 135 | assert logp.shape[0] == n_samples 136 | assert jnp.isfinite(logp).all() 137 | 138 | def test_log_prob_edge_cases(self, model): 139 | model, key, n_dim = model 140 | for 
arr in [jnp.zeros(n_dim), 1e6 * jnp.ones(n_dim), -1e6 * jnp.ones(n_dim)]: 141 | logp = model.log_prob(arr) 142 | logp_arr = jnp.asarray(logp) 143 | assert logp_arr.size == 1 144 | assert ( 145 | jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all() 146 | ) # may be nan for extreme values 147 | 148 | def test_save_and_load(self, tmp_path, model): 149 | model, key, n_dim = model 150 | save_path = str(tmp_path / "test_model") 151 | model.save_model(save_path) 152 | loaded = model.load_model(save_path) 153 | x = jax.random.normal(key, (2, n_dim)) 154 | assert jnp.allclose( 155 | eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x) 156 | ) 157 | 158 | def test_properties(self, model): 159 | model, key, n_dim = model 160 | mean = jnp.arange(n_dim) 161 | cov = jnp.eye(n_dim) * 2 162 | model2 = FlowMatchingModel( 163 | solver=model.solver, path=model.path, data_mean=mean, data_cov=cov 164 | ) 165 | assert model2.n_features == n_dim 166 | assert jnp.allclose(model2.data_mean, mean) 167 | assert jnp.allclose(model2.data_cov, cov) 168 | 169 | def test_print_parameters_notimplemented(self, model): 170 | model, key, n_dim = model 171 | with pytest.raises(NotImplementedError): 172 | model.print_parameters() 173 | 174 | def test_train_step_and_epoch(self, model): 175 | model, key, n_dim = model 176 | n_batch = 5 177 | x0 = jax.random.normal(key, (n_batch, n_dim)) 178 | x1 = jax.random.normal(key, (n_batch, n_dim)) 179 | t = jax.random.uniform(key, (n_batch, 1)) 180 | optim = optax.adam(learning_rate=1e-3) 181 | state = optim.init(eqx.filter(model, eqx.is_array)) 182 | std = jnp.sqrt(jnp.diag(model.data_cov)) 183 | x1_whitened = (x1 - model.data_mean) / std 184 | x_t, dx_t = model.path.sample(x0, x1_whitened, t) 185 | loss, model2, state2 = model.train_step(x_t, t, dx_t, optim, state) 186 | assert jnp.isfinite(loss) 187 | assert isinstance(model2, FlowMatchingModel) 188 | data = (x0, x1, t) 189 | loss_epoch, model3, state3 = model.train_epoch( 190 | key, optim, state, data, batch_size=n_batch 191 | ) 192 | assert jnp.isfinite(loss_epoch) 193 | assert isinstance(model3, FlowMatchingModel) 194 | -------------------------------------------------------------------------------- /docs/configuration.md: -------------------------------------------------------------------------------- 1 | Configuration Guide 2 | =================== 3 | 4 | This page contains information about the most important hyperparameters which affect the behavior of the sampler. 
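All of these hyperparameters configure a sampler whose target is a JAX-traceable log-density with the signature `log_pdf(x, data)` (this is the signature the kernels call internally). The snippet below is a minimal sketch of such a target using only standard JAX; the names `log_pdf` and `positions` and the concrete shapes are illustrative assumptions, not part of the flowMC API.

```python
import jax
import jax.numpy as jnp

def log_pdf(x, data):
    # Standard-normal log-density up to a constant; `data` is unused
    # here, matching the case where you pass `data=None` to the sampler.
    return -0.5 * jnp.sum(x**2)

n_chains, n_dim = 20, 5
positions = jax.random.normal(jax.random.PRNGKey(0), (n_chains, n_dim))

# The sampler evaluates the target across chains, and gradient-based
# local samplers (e.g. MALA) also need its gradient:
log_probs = jax.vmap(log_pdf, in_axes=(0, None))(positions, None)
grads = jax.vmap(jax.grad(log_pdf), in_axes=(0, None))(positions, None)
```

The table below groups the arguments by how often you are likely to need to touch them.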
5 | 6 | 7 | | Essential | Optional | Advanced | 8 | | --------------------------------- | ----------------------------------------- | ----------------------------------- | 9 | | [`n_dim`](#n_dim) | [`use_global`](#use_global) | [`keep_quantile`](#keep_quantile) | 10 | | [`rng_key_set`](#rng_key_set) | [`n_chains`](#n_chains) | [`momentum`](#momentum) | 11 | | [`local_sampler`](#local_sampler) | [`n_loop_training`](#n_loop_training) | [`nf_variable`](#nf_variable) | 12 | | [`data`](#data) | [`n_loop_production`](#n_loop_production) | [`local_autotune`](#local_autotune) | 13 | | [`nf_model`](#nf_model) | [`n_local_steps`](#n_local_steps) | [`train_thinning`](#train_thinning) | 14 | | | [`n_global_steps`](#n_global_steps) | | 15 | | | [`n_epochs`](#n_epochs) | | 16 | | | [`learning_rate`](#learning_rate) | | 17 | | | [`max_samples`](#max_samples) | | 18 | | | [`batch_size`](#batch_size) | | 19 | | | [`verbose`](#verbose) | | 20 | 21 | 22 | 23 | 24 | Essential arguments 25 | ------------------- 26 | 27 | ## [n_dim](#n_dim) 28 | 29 | The dimension of the problem. The sampler will fail if `n_dim` does not match the input dimension of your likelihood function. 30 | 31 | ## [rng_key_set](#rng_key_set) 32 | 33 | The set of JAX-generated PRNG keys. 34 | 35 | ## [data](#data) 36 | 37 | The data you want to sample from. This is used to precompile the kernels used during the sampling. 38 | Note that you should keep the shape of the data consistent between runs; otherwise it will trigger recompilation. 39 | If your likelihood does not take any data arguments, simply passing `None` should work. 40 | 41 | ## [local_sampler](#local_sampler) 42 | The specific local sampler you want to use. 43 | 44 | ## [nf_model](#nf_model) 45 | The specific normalizing flow model you want to use. 46 | 47 | Optional arguments that you might want to tune 48 | ---------------------------------------------- 49 | 50 | ## [use_global](#use_global) 51 | Whether to use the global sampler or not. Default is ``True``. 52 | Turning off the global sampler will also disable the training phase. 53 | This is useful when you want to test whether the local sampler is behaving normally or perform an ablation study on the benefits of the global sampler. 54 | In production-quality runs, you probably always want to use the global sampler since it improves convergence significantly. 55 | 56 | ## [n_chains](#n_chains) 57 | Number of parallel chains to run. Default is ``20``. 58 | Within your memory bandwidth and without oversubscribing your computational device, you should use as many chains as possible. 59 | The method benefits tremendously from parallelization. 60 | 61 | ## [n_loop_training](#n_loop_training) 62 | Number of local-global sampling loops to run during the training phase. Default is ``3``. 63 | 64 | ## [n_loop_production](#n_loop_production) 65 | Number of local-global sampling loops to run during the production phase. Default is ``3``. 66 | This is similar to ``n_loop_training``; the only difference is that during the production loops, the normalizing flow model is no longer updated. This saves computation time once the flow is sufficiently trained to power global moves. As the MCMC stops being adaptive, detailed balance is also restored in the Metropolis-Hastings steps, and the traditional diagnostics of MCMC convergence can be applied safely. 67 | 68 | 69 | ## [n_local_steps](#n_local_steps) 70 | Number of local steps to run during the local sampling phase. Default is ``50``.
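To see how these loop counts combine, here is a small back-of-the-envelope sketch in plain Python, using the defaults quoted on this page together with ``n_global_steps`` (described next). It assumes every local and global step of every chain is recorded, i.e. no thinning.

```python
# Rough sample bookkeeping with the default settings on this page.
n_chains = 20
n_loop_training = 3
n_loop_production = 3
n_local_steps = 50
n_global_steps = 50

steps_per_loop = n_local_steps + n_global_steps        # 100 steps per loop
training_steps = n_loop_training * steps_per_loop      # 300 per chain
production_steps = n_loop_production * steps_per_loop  # 300 per chain

print(training_steps * n_chains)    # 6000 training-phase samples in total
print(production_steps * n_chains)  # 6000 production-phase samples in total
```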
71 | 72 | ## [n_global_steps](#n_global_steps) 73 | Number of global steps to run during the global sampling phase. Default is ``50``. 74 | 75 | ## [n_epochs](#n_epochs) 76 | Number of epochs to run during the training phase. Default is ``30``. 77 | The higher this number, the better the flow performs, at the cost of increasing the training time. 78 | 79 | ## [learning_rate](#learning_rate) 80 | Learning rate for the Adam optimizer. Default is ``1e-2``. 81 | 82 | ## [max_samples](#max_samples) 83 | Maximum number of samples used to train the normalizing flow model. Default is ``10000``. 84 | 85 | If the total number of samples obtained along the chains is more than ``max_samples`` when reaching a training phase, only a subsample of size ``max_samples``, taken from the most recent steps of the chains, will be used for training. 86 | The chain dimension has priority over the step dimension, meaning the sampler will try to take at least one sample from each chain before going back to previous steps to retrieve more samples. 87 | One usually chooses this number based on the memory capacity of the device. 88 | If this number is larger than what fits in your device's memory, each training loop will take longer to finish. 89 | On the other hand, the training time will not be affected if the entire dataset can fit on your device. 90 | If this number is small, only the most recent samples are used in the training. 91 | This may cause the normalizing flow model to forget about some features of the global landscape that were not visited recently by the chains. For instance, it can lead to mode collapse. 92 | 93 | ## [batch_size](#batch_size) 94 | Batch size for training the normalizing flow model. Default is ``10000``. 95 | Using a large batch size speeds up the training since the training time is determined by the number of batched backward passes. 96 | Unlike typical deep learning use cases, since our training dataset is continuously evolving, we do not really have to worry about overfitting. 97 | Therefore, using a larger batch size is usually better. 98 | The rule of thumb here is: within memory and computational bandwidth, choose the largest number that does not increase the training time. 99 | 100 | ## [keep_quantile](#keep_quantile) 101 | Dictionary with keys ``params`` and ``variables`` allowing you to reuse the model trained during a previous run of the NFSampler. These variables can be retrieved from ``NFSampler.state`` after a run. An example is provided in the tutorials. 102 | 103 | ## [verbose](#verbose) 104 | Whether to print out extra info during the inference. Default is ``False``. 105 | 106 | 107 | 108 | Only-if-you-know-what-you-are-doing arguments 109 | --------------------------------------------- 110 | 111 | 112 | ## [keep_quantile](#keep_quantile) 113 | 114 | ## [momentum](#momentum) 115 | 116 | ## [nf_variable](#nf_variable) 117 | 118 | ## [local_autotune](#local_autotune) 119 | 120 | ## [train_thinning](#train_thinning) 121 | 122 | Thinning factor for the data used to train the normalizing flow. 123 | Given that we only keep ``max_samples`` samples, only the newest ``max_samples/n_chains`` samples in each chain are used for training the normalizing flow. 124 | This thinning factor keeps every ``train_thinning``-th sample in each chain. 125 | The larger the number, the less correlated the samples in each chain are. 126 | When ``max_samples*train_thinning/n_chains > n_local_steps``, samples generated from different global training loops are used in training the new normalizing flow.
127 | This reduces the possibility of mode collapse, since the algorithm retains access to samples generated before any mode collapse would have happened. 128 | 129 | This API is still experimental and might be combined with other hyperparameters into one big tuning parameter later. 130 | -------------------------------------------------------------------------------- /src/flowMC/strategy/take_steps.py: -------------------------------------------------------------------------------- 1 | from flowMC.resource.base import Resource 2 | from flowMC.resource.kernel.base import ProposalBase 3 | from flowMC.resource.buffers import Buffer 4 | from flowMC.resource.states import State 5 | from flowMC.resource.logPDF import LogPDF 6 | from flowMC.strategy.base import Strategy 7 | from jaxtyping import Array, Float, PRNGKeyArray 8 | import jax 9 | import jax.numpy as jnp 10 | import equinox as eqx 11 | from abc import abstractmethod 12 | 13 | 14 | class TakeSteps(Strategy): 15 | logpdf_name: str 16 | kernel_name: str 17 | state_name: str 18 | buffer_names: list[str] 19 | n_steps: int 20 | current_position: int 21 | thinning: int 22 | chain_batch_size: int # If vmap over a large number of chains is memory-bound, this splits the computation 23 | verbose: bool 24 | 25 | def __init__( 26 | self, 27 | logpdf_name: str, 28 | kernel_name: str, 29 | state_name: str, 30 | buffer_names: list[str], 31 | n_steps: int, 32 | thinning: int = 1, 33 | chain_batch_size: int = 0, 34 | verbose: bool = False, 35 | ): 36 | self.logpdf_name = logpdf_name 37 | self.kernel_name = kernel_name 38 | self.state_name = state_name 39 | self.buffer_names = buffer_names 40 | self.n_steps = n_steps 41 | self.current_position = 0 42 | self.thinning = thinning 43 | self.chain_batch_size = chain_batch_size 44 | self.verbose = verbose 45 | 46 | @abstractmethod 47 | def sample( 48 | self, 49 | kernel: ProposalBase, 50 | rng_key: PRNGKeyArray, 51 | initial_position: Float[Array, " n_dim"], 52 | logpdf: LogPDF, 53 | data: dict, 54 | ): 55 | raise NotImplementedError 56 | 57 | def set_current_position(self, current_position: int): 58 | self.current_position = current_position 59 | 60 | def __call__( 61 | self, 62 | rng_key: PRNGKeyArray, 63 | resources: dict[str, Resource], 64 | initial_position: Float[Array, "n_chains n_dim"], 65 | data: dict, 66 | ) -> tuple[ 67 | PRNGKeyArray, 68 | dict[str, Resource], 69 | Float[Array, "n_chains n_dim"], 70 | ]: 71 | rng_key, subkey = jax.random.split(rng_key) 72 | subkey = jax.random.split(subkey, initial_position.shape[0]) 73 | 74 | assert isinstance( 75 | state_resource := resources[self.state_name], State 76 | ), "State resource must be a State" 77 | 78 | assert isinstance( 79 | position_buffer_name := state_resource.data[self.buffer_names[0]], str 80 | ), "Position buffer resource name must be a string" 81 | 82 | assert isinstance( 83 | log_prob_buffer_name := state_resource.data[self.buffer_names[1]], str 84 | ), "Log probability buffer resource name must be a string" 85 | 86 | assert isinstance( 87 | acceptance_buffer_name := state_resource.data[self.buffer_names[2]], str 88 | ), "Acceptance buffer resource name must be a string" 89 | 90 | assert isinstance( 91 | position_buffer := resources[position_buffer_name], Buffer 92 | ), "Position buffer resource must be a Buffer" 93 | assert isinstance( 94 | log_prob_buffer := resources[log_prob_buffer_name], Buffer 95 | ), "Log probability buffer resource must be a Buffer" 96 | assert isinstance( 97 | acceptance_buffer :=
resources[acceptance_buffer_name], Buffer 98 | ), "Acceptance buffer resource must be a Buffer" 99 | 100 | kernel = resources[self.kernel_name] 101 | logpdf = resources[self.logpdf_name] 102 | 103 | # Filter jit will bypass the compilation of 104 | # the function if not clearing the cache 105 | n_chains = initial_position.shape[0] 106 | if self.chain_batch_size > 1 and n_chains > self.chain_batch_size: 107 | positions_list = [] 108 | log_probs_list = [] 109 | do_accepts_list = [] 110 | for i in range(0, n_chains, self.chain_batch_size): 111 | batch_slice = slice(i, min(i + self.chain_batch_size, n_chains)) 112 | subkey_batch = subkey[batch_slice] 113 | initial_position_batch = initial_position[batch_slice] 114 | positions_batch, log_probs_batch, do_accepts_batch = eqx.filter_jit( 115 | eqx.filter_vmap( 116 | jax.tree_util.Partial(self.sample, kernel), 117 | in_axes=(0, 0, None, None), 118 | ) 119 | )(subkey_batch, initial_position_batch, logpdf, data) 120 | positions_list.append(positions_batch) 121 | log_probs_list.append(log_probs_batch) 122 | do_accepts_list.append(do_accepts_batch) 123 | positions = jnp.concatenate(positions_list, axis=0) 124 | log_probs = jnp.concatenate(log_probs_list, axis=0) 125 | do_accepts = jnp.concatenate(do_accepts_list, axis=0) 126 | else: 127 | positions, log_probs, do_accepts = eqx.filter_jit( 128 | eqx.filter_vmap( 129 | jax.tree_util.Partial(self.sample, kernel), 130 | in_axes=(0, 0, None, None), 131 | ) 132 | )(subkey, initial_position, logpdf, data) 133 | 134 | positions = positions[:, :: self.thinning] 135 | log_probs = log_probs[:, :: self.thinning] 136 | do_accepts = do_accepts[:, :: self.thinning].astype( 137 | acceptance_buffer.data.dtype 138 | ) 139 | 140 | position_buffer.update_buffer(positions, self.current_position) 141 | log_prob_buffer.update_buffer(log_probs, self.current_position) 142 | acceptance_buffer.update_buffer(do_accepts, self.current_position) 143 | self.current_position += self.n_steps // self.thinning 144 | return rng_key, resources, positions[:, -1] 145 | 146 | 147 | class TakeSerialSteps(TakeSteps): 148 | """TakeSerialSteps is a strategy that takes a number of steps in a serial manner, 149 | i.e. one after the other. 150 | 151 | This uses jax.lax.scan to iterate over the steps and apply the kernel to the current 152 | position. This is intended to be used for most local kernels that are dependent on 153 | the previous step. 154 | """ 155 | 156 | def body(self, kernel: ProposalBase, carry, aux): 157 | key, position, log_prob, logpdf, data = carry 158 | key, subkey = jax.random.split(key) 159 | position, log_prob, do_accept = kernel.kernel( 160 | subkey, position, log_prob, logpdf, data 161 | ) 162 | return (key, position, log_prob, logpdf, data), (position, log_prob, do_accept) 163 | 164 | def sample( 165 | self, 166 | kernel: ProposalBase, 167 | rng_key: PRNGKeyArray, 168 | initial_position: Float[Array, " n_dim"], 169 | logpdf: LogPDF, 170 | data: dict, 171 | ): 172 | ( 173 | (last_key, last_position, last_log_prob, logpdf, data), 174 | (positions, log_probs, do_accepts), 175 | ) = jax.lax.scan( 176 | jax.tree_util.Partial(self.body, kernel), 177 | (rng_key, initial_position, logpdf(initial_position, data), logpdf, data), 178 | length=self.n_steps, 179 | ) 180 | return positions, log_probs, do_accepts 181 | 182 | 183 | class TakeGroupSteps(TakeSteps): 184 | """TakeGroupSteps is a strategy that takes a number of steps in a group manner, i.e. 185 | all steps are taken at once. 
186 | 187 | This is intended to be used for kernels such as normalizing flows, whose proposal 188 | steps are independent of each other and benefit from being computed in parallel. 189 | """ 190 | 191 | def sample( 192 | self, 193 | kernel: ProposalBase, 194 | rng_key: PRNGKeyArray, 195 | initial_position: Float[Array, " n_dim"], 196 | logpdf: LogPDF, 197 | data: dict, 198 | ): 199 | (positions, log_probs, do_accepts) = kernel.kernel( 200 | rng_key, 201 | initial_position, 202 | logpdf(initial_position, data), 203 | logpdf, 204 | {**data, "n_steps": self.n_steps}, 205 | ) 206 | return positions, log_probs, do_accepts 207 | -------------------------------------------------------------------------------- /src/flowMC/resource/model/nf_model/realNVP.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import equinox as eqx 4 | import jax 5 | import jax.numpy as jnp 6 | from jaxtyping import Array, Float, PRNGKeyArray 7 | 8 | from flowMC.resource.model.nf_model.base import NFModel 9 | from flowMC.resource.model.common import ( 10 | Distribution, 11 | MLP, 12 | Gaussian, 13 | MaskedCouplingLayer, 14 | MLPAffine, 15 | ) 16 | 17 | 18 | class AffineCoupling(eqx.Module): 19 | """ 20 | Affine coupling layer. 21 | (Defined in the RealNVP paper https://arxiv.org/abs/1605.08803) 22 | We use tanh as the default activation function. 23 | 24 | Args: 25 | n_features: (int) The number of features in the input. 26 | n_hidden: (int) The number of hidden units in the MLP. 27 | mask: (ndarray) Alternating mask for the affine coupling layer. 28 | dt: (Float) Scaling factor for the affine coupling layer. 29 | """ 30 | 31 | _mask: Array 32 | scale_MLP: MLP 33 | translate_MLP: MLP 34 | dt: Float = 1 35 | 36 | def __init__( 37 | self, 38 | n_features: int, 39 | n_hidden: int, 40 | mask: Array, 41 | key: PRNGKeyArray, 42 | dt: Float = 1, 43 | scale: Float = 1e-4, 44 | ): 45 | self._mask = mask 46 | self.dt = dt 47 | key, scale_subkey, translate_subkey = jax.random.split(key, 3) 48 | features = [n_features, n_hidden, n_features] 49 | self.scale_MLP = MLP(features, key=scale_subkey, scale=scale) 50 | self.translate_MLP = MLP(features, key=translate_subkey, scale=scale) 51 | 52 | @property 53 | def mask(self): 54 | return jax.lax.stop_gradient(self._mask) 55 | 56 | @property 57 | def n_features(self): 58 | return self.scale_MLP.n_input 59 | 60 | def __call__(self, x: Array): 61 | return self.forward(x) 62 | 63 | def forward(self, x: Array) -> Tuple[Array, Array]: 64 | """From latent space to data space. 65 | 66 | Args: 67 | x: (Array) Latent space. 68 | 69 | Returns: 70 | outputs: (Array) Data space. 71 | log_det: (Array) Log determinant of the Jacobian. 72 | """ 73 | s = self.mask * self.scale_MLP(x * (1 - self.mask)) 74 | s = jnp.tanh(s) * self.dt 75 | t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt 76 | 77 | # Compute log determinant of the Jacobian 78 | log_det = s.sum() 79 | 80 | # Apply the transformation 81 | outputs = (x + t) * jnp.exp(s) 82 | return outputs, log_det 83 | 84 | def inverse(self, x: Array) -> Tuple[Array, Array]: 85 | """From data space to latent space. 86 | 87 | Args: 88 | x: (Array) Data space. 89 | 90 | Returns: 91 | outputs: (Array) Latent space. 92 | log_det: (Array) Log determinant of the Jacobian.
93 | """ 94 | s = self.mask * self.scale_MLP(x * (1 - self.mask)) 95 | s = jnp.tanh(s) * self.dt 96 | t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt 97 | log_det = -s.sum() 98 | outputs = x * jnp.exp(-s) - t 99 | return outputs, log_det 100 | 101 | 102 | class RealNVP(NFModel): 103 | """ 104 | RealNVP mode defined in the paper https://arxiv.org/abs/1605.08803. 105 | MLP is needed to make sure the scaling between layers are more or less the same. 106 | 107 | Args: 108 | n_layers: (int) The number of affine coupling layers. 109 | n_features: (int) The number of features in the input. 110 | n_hidden: (int) The number of hidden units in the MLP. 111 | dt: (Float) Scaling factor for the affine coupling layer. 112 | 113 | Properties: 114 | data_mean: (ndarray) Mean of Gaussian base distribution 115 | data_cov: (ndarray) Covariance of Gaussian base distribution 116 | """ 117 | 118 | base_dist: Distribution 119 | affine_coupling: List[MaskedCouplingLayer] 120 | _n_features: int 121 | _data_mean: Float[Array, " n_dim"] 122 | _data_cov: Float[Array, " n_dim n_dim"] 123 | 124 | @property 125 | def n_features(self) -> int: 126 | return self._n_features 127 | 128 | @property 129 | def data_mean(self): 130 | return jax.lax.stop_gradient(self._data_mean) 131 | 132 | @property 133 | def data_cov(self): 134 | return jax.lax.stop_gradient(jnp.atleast_2d(self._data_cov)) 135 | 136 | def __init__( 137 | self, n_features: int, n_layers: int, n_hidden: int, key: PRNGKeyArray, **kwargs 138 | ): 139 | if kwargs.get("base_dist") is not None: 140 | self.base_dist = kwargs.get("base_dist") # type: ignore 141 | else: 142 | self.base_dist = Gaussian( 143 | jnp.zeros(n_features), jnp.eye(n_features), learnable=False 144 | ) 145 | 146 | if kwargs.get("data_mean") is not None: 147 | data_mean = kwargs.get("data_mean") 148 | assert isinstance(data_mean, Array) 149 | self._data_mean = data_mean 150 | else: 151 | self._data_mean = jnp.zeros(n_features) 152 | 153 | if kwargs.get("data_cov") is not None: 154 | data_cov = kwargs.get("data_cov") 155 | assert isinstance(data_cov, Array) 156 | self._data_cov = data_cov 157 | else: 158 | self._data_cov = jnp.eye(n_features) 159 | 160 | self._n_features = n_features 161 | 162 | def make_layer(i: int, key: PRNGKeyArray): 163 | key, scale_subkey, shift_subkey = jax.random.split(key, 3) 164 | mask = jnp.ones(n_features) 165 | mask = mask.at[: int(n_features / 2)].set(0) 166 | mask = jax.lax.cond(i % 2 == 0, lambda x: 1 - x, lambda x: x, mask) 167 | scale_MLP = MLP([n_features, n_hidden, n_features], key=scale_subkey) 168 | shift_MLP = MLP([n_features, n_hidden, n_features], key=shift_subkey) 169 | return MaskedCouplingLayer(MLPAffine(scale_MLP, shift_MLP), mask) 170 | 171 | keys = jax.random.split(key, n_layers) 172 | self.affine_coupling = eqx.filter_vmap(make_layer)(jnp.arange(n_layers), keys) 173 | 174 | def forward( 175 | self, 176 | x: Float[Array, " n_dim"], 177 | key: Optional[PRNGKeyArray] = None, 178 | condition: Optional[Float[Array, " n_condition"]] = None, 179 | ) -> tuple[Float[Array, " n_dim"], Float]: 180 | log_det = 0.0 181 | dynamics, statics = eqx.partition(self.affine_coupling, eqx.is_array) 182 | 183 | def f(carry, data): 184 | x, log_det = carry 185 | layers = eqx.combine(data, statics) 186 | x, log_det_i = layers(x, condition) 187 | return (x, log_det + log_det_i), None 188 | 189 | (x, log_det), _ = jax.lax.scan(f, (x, log_det), dynamics) 190 | return x, log_det 191 | 192 | def inverse( 193 | self, 194 | x: Float[Array, " n_dim"], 195 | 
condition: Optional[Float[Array, " n_condition"]] = None, 196 | ) -> tuple[Float[Array, " n_dim"], Float]: 197 | """From latent space to data space.""" 198 | log_det = 0.0 199 | dynamics, statics = eqx.partition(self.affine_coupling, eqx.is_array) 200 | 201 | def f(carry, data): 202 | x, log_det = carry 203 | layers = eqx.combine(data, statics) 204 | x, log_det_i = layers.inverse(x, condition) 205 | return (x, log_det + log_det_i), None 206 | 207 | (x, log_det), _ = jax.lax.scan(f, (x, log_det), dynamics, reverse=True) 208 | return x, log_det 209 | 210 | def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: 211 | samples = self.base_dist.sample(rng_key, n_samples) 212 | samples = jax.vmap(self.inverse)(samples)[0] 213 | samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean 214 | return samples 215 | 216 | def log_prob(self, x: Float[Array, " n_dim"]) -> Float: 217 | # TODO: Check whether taking away vmap hurts accuracy. 218 | x = (x - self.data_mean) / jnp.sqrt(jnp.diag(self.data_cov)) 219 | y, log_det = self.__call__(x) 220 | log_det = log_det + jax.scipy.stats.multivariate_normal.logpdf( 221 | y, jnp.zeros(self.n_features), jnp.eye(self.n_features) 222 | ) 223 | return log_det 224 | 225 | def print_parameters(self): 226 | print("RealNVP parameters:") 227 | print(f"Data mean: {self.data_mean}") 228 | print(f"Data covariance: {self.data_cov}") 229 | -------------------------------------------------------------------------------- /joss/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'FlowMC' 3 | tags: 4 | - Python 5 | - Bayesian Inference 6 | - Machine Learning 7 | - Jax 8 | 9 | authors: 10 | - name: Kaze W. K. Wong 11 | orcid: 0000-0001-8432-7788 12 | # equal-contrib: true 13 | affiliation: 1 14 | - name: Marylou Gabrié 15 | orcid: 0000-0002-5989-1018 16 | # equal-contrib: true 17 | affiliation: "2, 3" 18 | - name: Dan Foreman-Mackey 19 | orcid: 0000-0002-9328-5652 20 | affiliation: 1 21 | affiliations: 22 | - name: Center for Computational Astrophysics, Flatiron Institute, New York, NY 10010, US 23 | index: 1 24 | - name: École Polytechnique, Palaiseau 91120, France 25 | index: 2 26 | - name: Center for Computational Mathematics, Flatiron Institute, New York, NY 10010, US 27 | index: 3 28 | date: 30 September 2022 29 | bibliography: paper.bib 30 | --- 31 | 32 | # Summary 33 | 34 | Across scientific fields, the modelling of increasingly complex physical processes requires more flexible models. Yet the estimation of models' parameters becomes more challenging as the dimension of the parameter space grows. A common strategy to explore parameter space is to sample it with Markov Chain Monte Carlo (MCMC). However, even MCMC methods can struggle to fully represent the parameter space when relying only on local updates. 35 | 36 | `FlowMC` is a Python library for accelerated Markov Chain Monte Carlo (MCMC) leveraging deep generative modelling, built on top of the machine learning libraries `Jax` and `Flax`. At its core, `FlowMC` uses a local sampler and a learnable global sampler in tandem to efficiently sample posterior distributions with non-trivial geometry, such as multimodal distributions and distributions with local correlations. While multiple chains of the local sampler generate samples over the region of interest in the target parameter space, the package uses these samples to train a normalizing flow model, then uses it to propose global jumps across the parameter space.
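To make the local-global cycle concrete, the toy sketch below alternates a random-walk local move with an independent flow-style global move on a one-dimensional bimodal target. For brevity, the trained normalizing flow is replaced by a fixed wide Gaussian proposal, so this is a schematic illustration of the mechanism rather than the `FlowMC` implementation, in which the proposal density is learned from the chains.

```python
import jax
import jax.numpy as jnp

def log_target(x):
    # Mixture of two well-separated Gaussians (unnormalized).
    return jnp.logaddexp(-0.5 * (x - 4.0) ** 2, -0.5 * (x + 4.0) ** 2)

def log_proposal(x):
    # Stand-in for the learned flow density q(x): a wide Gaussian.
    return -0.5 * (x / 6.0) ** 2 - jnp.log(6.0 * jnp.sqrt(2.0 * jnp.pi))

key = jax.random.PRNGKey(0)
x = jnp.array(4.0)
for _ in range(1000):
    key, k1, k2, k3, k4 = jax.random.split(key, 5)

    # Local move: Gaussian random walk (MALA or HMC in FlowMC).
    x_prop = x + 0.5 * jax.random.normal(k1)
    log_alpha = log_target(x_prop) - log_target(x)
    x = jnp.where(jnp.log(jax.random.uniform(k2)) < log_alpha, x_prop, x)

    # Global move: independence Metropolis-Hastings with proposal q,
    # which can jump between the two modes in a single step.
    x_prop = 6.0 * jax.random.normal(k3)
    log_alpha = (log_target(x_prop) - log_target(x)
                 + log_proposal(x) - log_proposal(x_prop))
    x = jnp.where(jnp.log(jax.random.uniform(k4)) < log_alpha, x_prop, x)
```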
37 | 38 | The key features of `FlowMC` are summarized in the following list: 39 | 40 | ## Key features 41 | 42 | - Since `FlowMC` is built on top of `Jax`, it supports gradient-based samplers such as MALA and Hamiltonian Monte Carlo (HMC) through automatic differentiation. 43 | - `FlowMC` uses state-of-the-art normalizing flow models such as rational quadratic splines (RQS) for the global sampler, which are very efficient in capturing local features with relatively short training times. 44 | - Use of accelerators such as GPUs and TPUs is natively supported. The code also supports the use of multiple accelerators with SIMD parallelism. 45 | - By default, Just-in-time (JIT) compilation is used to further speed up the sampling process. 46 | - We provide a simple black-box interface for users who want to use `FlowMC` with its default parameters, while at the same time providing an extensive guide explaining the trade-offs involved in tuning the sampler parameters. 47 | 48 | The tight integration of all the above features makes `FlowMC` a highly performant yet simple-to-use package for statistical inference. 49 | 50 | # Statement of need 51 | 52 | Bayesian inference requires computing expectations with respect to the posterior distribution on the parameters $\theta$ after collecting the observations $\mathcal{D}$. This posterior is given by 53 | 54 | $$ 55 | p(\theta|\mathcal{D}) = \frac{\ell(\mathcal{D}|\theta) p_0(\theta)}{Z(\mathcal{D})} 56 | $$ 57 | 58 | where $\ell(\mathcal{D}|\theta)$ is the likelihood induced by the model, $p_0(\theta)$ the prior on the parameters and $Z(\mathcal{D})$ the model evidence. 59 | As soon as the dimension of $\theta$ exceeds 3 or 4, it is necessary to resort to a robust sampling strategy such as MCMC. Drastic gains in computational efficiency can be obtained by a careful selection of the MCMC transition kernel, which can be assisted by machine learning methods and libraries. 60 | 61 | ***Gradient-based sampler*** 62 | In a high-dimensional space, sampling methods that leverage gradient information of the target distribution have been shown to be efficient, as they propose new samples that are likely to be accepted. 63 | `FlowMC` supports gradient-based samplers such as MALA and HMC through automatic differentiation with `Jax`. 64 | The computational cost of obtaining a gradient in this way is often of the same order as evaluating the target function itself, making gradient-based samplers usually compare favorably to random walks with respect to the efficiency/accuracy trade-off. 65 | 66 | ***Learned transition kernels with normalizing flow*** 67 | Posterior distributions of many real-world problems have non-trivial geometry, such as multi-modality and local correlations, which can drastically slow down the convergence of a sampler based only on gradient information. 68 | To address this problem, we combine a gradient-based sampler with a normalizing flow, a class of generative models `[@Papamakarios2019; @Kobyzev2021]`, that is trained to mimic the posterior distribution and used as a proposal in a Metropolis-Hastings step. Variants of this idea have been explored in the past few years (e.g. `[@Albergo2019; @Hoffman2019; @Gabrie2021]` and references therein). 69 | Despite the growing interest in these methods, few accessible implementations for non-experts exist, and none of them support GPU and TPU acceleration. Namely, a version of the NeuTra sampler `[@Hoffman2019]` available in Pyro `[@bingham2019pyro]` and the PocoMC package `[@Karamanis2022]` are both CPU-bound.
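To state the learned-proposal mechanism precisely: when the flow density $q$ is used as an independent proposal, a move from the current state $\theta$ to $\theta' \sim q$ is accepted with the standard Metropolis-Hastings probability

$$
A(\theta, \theta') = \min\left(1, \frac{p(\theta'|\mathcal{D})\, q(\theta)}{p(\theta|\mathcal{D})\, q(\theta')}\right),
$$

so each global move leaves the posterior invariant regardless of how well $q$ approximates it; the quality of $q$ only affects how often these jumps are accepted.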
70 | 71 | `FlowMC` implements the method proposed in `[@Gabrie2021a]`. 72 | As individual chains are exploring their local neighborhood through gradient-based MCMC steps, multiple chains can be combined and fed to the normalizing flow so it can learn the global landscape of the posterior distribution. In turn, the chains can be propagated with a Metropolis-Hastings kernel using the normalizing flow to propose globally in the parameter space. The cycle of local sampling, normalizing-flow tuning, and global sampling is repeated until convergence of the chains. 73 | The entire algorithm belongs to the class of adaptive MCMCs `[@Andrieu2008]`, collecting information from the chains' previous steps to simultaneously improve the transition kernel. 74 | Usual MCMC diagnostics can be applied to assess the robustness of the inference results without worrying about the validation of the normalizing flow model, which is a common problem in deep learning. 75 | If further sampling from the posterior is necessary, the flow trained during a previous run can be reused without further training. 76 | The mathematical details of the method are explained in `[@Gabrie2021a]`. 77 | 78 | ***Use of Accelerator*** 79 | Modern accelerators such as GPUs and TPUs are designed to execute dense computation in parallel. 80 | Due to the sequential nature of MCMC, a common approach in leveraging accelerators is to run multiple chains in parallel, then combine their results to obtain the posterior distribution. 81 | However, a large portion of the computation comes from the burn-in phase, and simply parallelizing over many chains does not speed up the burn-in. 82 | To fully leverage the benefit of having many chains, ensemble methods such as (Cite) are often implemented. 83 | This comes with its own set of challenges, and implementing such a class of methods on accelerators requires careful consideration. 84 | 86 | Since `FlowMC` is built on top of `Jax`, it supports the use of accelerators by default. 87 | Users can write code in the same way as they would on a CPU, and the library will automatically detect the available accelerators and use them at run time. 88 | Furthermore, the library leverages Just-In-Time compilation to further improve the performance of the sampler. 89 | 90 | ***Simplicity and extensibility*** 91 | Since we anticipate that most users would like to spend most of their time building models instead of optimizing the performance of the sampler, 92 | we provide a black-box interface with a few tuning parameters for users who intend to use `FlowMC` without too much customization on the sampler side. 93 | The only inputs we require from the users are the log-likelihood function, the log-prior function, and the initial position of the chains. 94 | On top of the black-box interface, the package offers automatic tuning for the local samplers, in order to reduce the number of hyperparameters the users have to manage. 95 | 96 | While we provide a high-level API for most of the users, the code is also designed to be extensible. In particular, custom local and global sampling kernels can be integrated into the `sampler` module. 97 | 98 | 99 | # Acknowledgements 100 | M.G. acknowledges support from Hi!Paris.
101 | # References 102 | -------------------------------------------------------------------------------- /src/flowMC/resource/model/nf_model/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Optional 3 | 4 | import equinox as eqx 5 | import jax 6 | import jax.numpy as jnp 7 | import optax 8 | from jaxtyping import Array, Float, PRNGKeyArray 9 | from tqdm import tqdm, trange 10 | from typing_extensions import Self 11 | from flowMC.resource.base import Resource 12 | 13 | 14 | class NFModel(eqx.Module, Resource): 15 | """Base class for normalizing flow models. 16 | 17 | This is an abstract template that should not be directly used. 18 | """ 19 | 20 | _n_features: int 21 | _data_mean: Float[Array, " n_dim"] 22 | _data_cov: Float[Array, " n_dim n_dim"] 23 | 24 | @property 25 | def n_features(self): 26 | return self._n_features 27 | 28 | @property 29 | def data_mean(self): 30 | return jax.lax.stop_gradient(self._data_mean) 31 | 32 | @property 33 | def data_cov(self): 34 | return jax.lax.stop_gradient(jnp.atleast_2d(self._data_cov)) 35 | 36 | @abstractmethod 37 | def __init__(self): 38 | raise NotImplementedError 39 | 40 | def __call__( 41 | self, x: Float[Array, " n_dim"] 42 | ) -> tuple[Float[Array, " n_dim"], Float]: 43 | """Forward pass of the model. 44 | 45 | Args: 46 | x (Float[Array, "n_dim"]): Input data. 47 | 48 | Returns: 49 | tuple[Float[Array, "n_dim"], Float]: 50 | Output data and log determinant of the Jacobian. 51 | """ 52 | return self.forward(x) 53 | 54 | @abstractmethod 55 | def log_prob(self, x: Float[Array, " n_dim"]) -> Float: 56 | raise NotImplementedError 57 | 58 | @abstractmethod 59 | def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array: 60 | raise NotImplementedError 61 | 62 | @abstractmethod 63 | def forward( 64 | self, x: Float[Array, " n_dim"], key: Optional[PRNGKeyArray] = None 65 | ) -> tuple[Float[Array, " n_dim"], Float]: 66 | """Forward pass of the model. 67 | 68 | Args: 69 | x (Float[Array, "n_dim"]): Input data. 70 | 71 | Returns: 72 | tuple[Float[Array, "n_dim"], Float]: 73 | Output data and log determinant of the Jacobian. 74 | """ 75 | raise NotImplementedError 76 | 77 | @abstractmethod 78 | def inverse( 79 | self, x: Float[Array, " n_dim"] 80 | ) -> tuple[Float[Array, " n_dim"], Float]: 81 | """Inverse pass of the model. 82 | 83 | Args: 84 | x (Float[Array, "n_dim"]): Input data. 85 | 86 | Returns: 87 | tuple[Float[Array, "n_dim"], Float]: 88 | Output data and log determinant of the Jacobian. 89 | """ 90 | raise NotImplementedError 91 | 92 | def save_model(self, path: str): 93 | eqx.tree_serialise_leaves(path + ".eqx", self) 94 | 95 | def load_model(self, path: str) -> Self: 96 | return eqx.tree_deserialise_leaves(path + ".eqx", self) 97 | 98 | @eqx.filter_value_and_grad 99 | def loss_fn(self, x: Float[Array, "n_batch n_dim"]) -> Float: 100 | return -jnp.mean(jax.vmap(self.log_prob)(x)) 101 | 102 | @eqx.filter_jit 103 | def train_step( 104 | model: Self, 105 | x: Float[Array, "n_batch n_dim"], 106 | optim: optax.GradientTransformation, 107 | state: optax.OptState, 108 | ) -> tuple[Float[Array, " 1"], Self, optax.OptState]: 109 | """Train for a single step. 110 | 111 | Args: 112 | model (eqx.Model): NF model to train. 113 | x (Array): Training data. 114 | opt_state (optax.OptState): Optimizer state. 115 | 116 | Returns: 117 | loss (Array): Loss value. 118 | model (eqx.Model): Updated model. 119 | opt_state (optax.OptState): Updated optimizer state. 
120 | """ 121 | print("Compiling training step") 122 | loss, grads = model.loss_fn(x) 123 | updates, state = optim.update(grads, state, model) # type: ignore 124 | model = eqx.apply_updates(model, updates) 125 | return loss, model, state 126 | 127 | def train_epoch( 128 | self: Self, 129 | rng: PRNGKeyArray, 130 | optim: optax.GradientTransformation, 131 | state: optax.OptState, 132 | data: Float[Array, "n_example n_dim"], 133 | batch_size: Float, 134 | ) -> tuple[Float, Self, optax.OptState]: 135 | """Train for a single epoch.""" 136 | value = 1e9 137 | model = self 138 | train_ds_size = len(data) 139 | steps_per_epoch = train_ds_size // batch_size 140 | if steps_per_epoch > 0: 141 | perms = jax.random.permutation(rng, train_ds_size) 142 | 143 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch 144 | perms = perms.reshape((steps_per_epoch, batch_size)) 145 | for perm in perms: 146 | batch = data[perm, ...] 147 | value, model, state = model.train_step(batch, optim, state) 148 | else: 149 | value, model, state = model.train_step(data, optim, state) 150 | 151 | return value, model, state 152 | 153 | def train( 154 | self: Self, 155 | rng: PRNGKeyArray, 156 | data: Array, 157 | optim: optax.GradientTransformation, 158 | state: optax.OptState, 159 | num_epochs: int, 160 | batch_size: int, 161 | verbose: bool = True, 162 | ) -> tuple[PRNGKeyArray, Self, optax.OptState, Array]: 163 | """Train a normalizing flow model. 164 | 165 | Args: 166 | rng (PRNGKeyArray): JAX PRNGKey. 167 | model (eqx.Module): NF model to train. 168 | data (Array): Training data. 169 | num_epochs (int): Number of epochs to train for. 170 | batch_size (int): Batch size. 171 | verbose (bool): Whether to print progress. 172 | 173 | Returns: 174 | rng (PRNGKeyArray): Updated JAX PRNGKey. 175 | model (eqx.Model): Updated NF model. 176 | loss_values (Array): Loss values. 177 | """ 178 | loss_values = jnp.zeros(num_epochs) 179 | if verbose: 180 | pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10)) 181 | else: 182 | pbar = range(num_epochs) 183 | 184 | best_model = model = self 185 | best_state = state 186 | best_loss = 1e9 187 | model = eqx.tree_at(lambda m: m._data_mean, model, jnp.mean(data, axis=0)) 188 | model = eqx.tree_at(lambda m: m._data_cov, model, jnp.cov(data.T)) 189 | for epoch in pbar: 190 | # Use a separate PRNG key to permute image data during shuffling 191 | rng, input_rng = jax.random.split(rng) 192 | # Run an optimization step over a training batch 193 | value, model, state = model.train_epoch( 194 | input_rng, optim, state, data, batch_size 195 | ) 196 | loss_values = loss_values.at[epoch].set(value) 197 | if loss_values[epoch] < best_loss: 198 | best_model = model 199 | best_state = state 200 | best_loss = loss_values[epoch] 201 | if verbose: 202 | assert isinstance(pbar, tqdm) 203 | if num_epochs > 10: 204 | if epoch % int(num_epochs / 10) == 0: 205 | pbar.set_description(f"Training NF, current loss: {value:.3f}") 206 | else: 207 | if epoch == num_epochs: 208 | pbar.set_description(f"Training NF, current loss: {value:.3f}") 209 | 210 | return rng, best_model, best_state, loss_values 211 | 212 | def to_precision(self, precision: str = "float32"): 213 | """Convert all parameters to a given precision. 214 | 215 | !!! warning 216 | This function is **experimental** and may change in the future. 217 | 218 | Args: 219 | precision (str): Precision to convert to. 220 | 221 | Returns: 222 | eqx.Module: Model with parameters converted to the given precision. 
223 | """ 224 | 225 | precisions_dict = { 226 | "float16": jnp.float16, 227 | "bfloat16": jnp.bfloat16, 228 | "float32": jnp.float32, 229 | "float64": jnp.float64, 230 | } 231 | try: 232 | precision_format = precisions_dict[precision.lower()] 233 | except KeyError: 234 | raise ValueError( 235 | f"Precision {precision} not supported.\ 236 | Choose from {precisions_dict.keys()}" 237 | ) 238 | dynamic_model, static_model = eqx.partition(self, eqx.is_array) 239 | dynamic_model = jax.tree.map( 240 | lambda x: x.astype(precision_format), dynamic_model 241 | ) 242 | return eqx.combine(dynamic_model, static_model) 243 | 244 | save_resource = save_model 245 | load_resource = load_model 246 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # flowMC 2 | 3 | **Normalizing-flow enhanced sampling package for probabilistic inference** 4 | 5 | 6 | doc 7 | 8 | 9 | doc 10 | 11 | Coverage Status 12 | 13 | > [!WARNING] 14 | > As I have new priority in my life now, this code base is in very low maintanence mode. This means I am not actively developing the codebase, and only infrequently responds to the issue board. Small PRs that I can finish reviewing within say 15 minutes will still be looked at, but anything that changes more than 5-10 long files is unlikely to retain my attention. In the meantime, a group of friends of mine forked the repo here https://github.com/GW-JAX-Team/flowMC. I am not officially associated with them and have no idea on what their plan with the fork, but feel free to look over there to find if they have something you need. 15 | 16 | > [!WARNING] 17 | > Note that `flowMC` has not reached v1.0.0, meaning the API could subject to changes. In general, the higher level the API, the less likely it is going to change. However, intermediate level API such as the resource strategy interface could subject to major revision for performance concerns. 18 | 19 | ![flowMC_logo](./docs/logo_0810.png) 20 | 21 | flowMC is a Jax-based python package for normalizing-flow enhanced Markov chain Monte Carlo (MCMC) sampling. 22 | The code is open source under MIT license, and it is under active development. 23 | 24 | - Just-in-time compilation is supported. 25 | - Native support for GPU acceleration. 26 | - Suit for problems with multi-modality. 27 | - Minimal tuning. 28 | 29 | # Installation 30 | 31 | The simplest way to install the package is to do it through pip 32 | 33 | ``` 34 | pip install flowMC 35 | ``` 36 | 37 | This will install the latest stable release and its dependencies. 38 | flowMC is based on [Jax](https://github.com/google/jax) and [Equinox](https://github.com/patrick-kidger/equinox). 39 | By default, installing flowMC will automatically install Jax and Equinox available on [PyPI](https://pypi.org). 40 | By default this install the CPU version of Jax. If you have a GPU and want to use it, you can install the GPU version of Jax by running: 41 | 42 | ``` 43 | pip install flowMC[cuda] 44 | ``` 45 | 46 | If you want to install the latest version of flowMC, you can clone this repo and install it locally: 47 | 48 | ``` 49 | git clone https://github.com/kazewong/flowMC.git 50 | cd flowMC 51 | pip install -e . 52 | ``` 53 | 54 | There are a couple more extras that you can install with flowMC, including: 55 | - `flowMC[docs]`: Install the documentation dependencies. 56 | - `flowMC[codeqa]`: Install the code quality dependencies. 
57 | - `flowMC[visualize]`: Install the visualization dependencies. 58 | 59 | On top of `pip` installation, we highly encourage you to use [uv](https://docs.astral.sh/uv/) to manage your Python environment. Once you clone the repo, you can run `uv sync` to create a virtual environment with all the dependencies installed. 60 | # Attribution 61 | 62 | If you used `flowMC` in your research, we would really appreciate it if you could at least cite the following papers: 63 | 64 | ``` 65 | @article{Wong:2022xvh, 66 | author = "Wong, Kaze W. k. and Gabri\'e, Marylou and Foreman-Mackey, Daniel", 67 | title = "{flowMC: Normalizing flow enhanced sampling package for probabilistic inference in JAX}", 68 | eprint = "2211.06397", 69 | archivePrefix = "arXiv", 70 | primaryClass = "astro-ph.IM", 71 | doi = "10.21105/joss.05021", 72 | journal = "J. Open Source Softw.", 73 | volume = "8", 74 | number = "83", 75 | pages = "5021", 76 | year = "2023" 77 | } 78 | 79 | @article{Gabrie:2021tlu, 80 | author = "Gabri\'e, Marylou and Rotskoff, Grant M. and Vanden-Eijnden, Eric", 81 | title = "{Adaptive Monte Carlo augmented with normalizing flows}", 82 | eprint = "2105.12603", 83 | archivePrefix = "arXiv", 84 | primaryClass = "physics.data-an", 85 | doi = "10.1073/pnas.2109420119", 86 | journal = "Proc. Nat. Acad. Sci.", 87 | volume = "119", 88 | number = "10", 89 | pages = "e2109420119", 90 | year = "2022" 91 | } 92 | ``` 93 | 94 | This will help `flowMC` get more recognition, and the main benefit *for you* is that the `flowMC` community will grow and the package will be continuously improved. If you believe in the magic of open-source software, please support us by attributing our software in your work. 95 | 96 | 97 | `flowMC` is a Jax implementation of methods described in: 98 | > *Efficient Bayesian Sampling Using Normalizing Flows to Assist Markov Chain Monte Carlo Methods* Gabrié M., Rotskoff G. M., Vanden-Eijnden E. - ICML INNF+ workshop 2021 - [pdf](https://openreview.net/pdf?id=mvtooHbjOwx) 99 | 100 | > *Adaptive Monte Carlo augmented with normalizing flows.* 101 | Gabrié M., Rotskoff G. M., Vanden-Eijnden E. - PNAS 2022 - [doi](https://www.pnas.org/doi/10.1073/pnas.2109420119), [arxiv](https://arxiv.org/abs/2105.12603) 102 | 103 | ## Contributors 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 |
- Hajime Kawahara: 🐛
- Daniel Dodd: 📖 👀 ⚠️ 🐛
- Matt Graham: 🐛 ⚠️ 👀 📖
- Kaze Wong: 🐛 📝 💻 🖋 📖 💡 🚇 🚧 🔬 👀 ⚠️
- Marylou Gabrié: 🐛 💻 🖋 📖 💡 🚧 🔬 ⚠️
- Meesum Qazalbash: 💻 🚧
- Thomas Ng: 💻 🚧
- Thomas Edwards: 🐛 💻
124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /src/flowMC/resource_strategy_bundle/RQSpline_MALA.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import jax 4 | from jaxtyping import Array, Float, PRNGKeyArray 5 | import equinox as eqx 6 | 7 | from flowMC.resource.base import Resource 8 | from flowMC.resource.buffers import Buffer 9 | from flowMC.resource.states import State 10 | from flowMC.resource.logPDF import LogPDF 11 | from flowMC.resource.kernel.MALA import MALA 12 | from flowMC.resource.kernel.NF_proposal import NFProposal 13 | from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline 14 | from flowMC.resource.optimizer import Optimizer 15 | from flowMC.strategy.lambda_function import Lambda 16 | from flowMC.strategy.take_steps import TakeSerialSteps, TakeGroupSteps 17 | from flowMC.strategy.train_model import TrainModel 18 | from flowMC.strategy.update_state import UpdateState 19 | from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle 20 | 21 | 22 | class RQSpline_MALA_Bundle(ResourceStrategyBundle): 23 | """A bundle that uses a Rational Quadratic Spline as a normalizing flow model and 24 | the Metropolis Adjusted Langevin Algorithm as a local sampler. 25 | 26 | This is the base algorithm described in https://www.pnas.org/doi/full/10.1073/pnas.2109420119 27 | 28 | """ 29 | 30 | def __repr__(self): 31 | return "RQSpline_MALA Bundle" 32 | 33 | def __init__( 34 | self, 35 | rng_key: PRNGKeyArray, 36 | n_chains: int, 37 | n_dims: int, 38 | logpdf: Callable[[Float[Array, " n_dim"], dict], Float], 39 | n_local_steps: int, 40 | n_global_steps: int, 41 | n_training_loops: int, 42 | n_production_loops: int, 43 | n_epochs: int, 44 | mala_step_size: float = 1e-1, 45 | chain_batch_size: int = 0, 46 | rq_spline_hidden_units: list[int] = [32, 32], 47 | rq_spline_n_bins: int = 8, 48 | rq_spline_n_layers: int = 4, 49 | learning_rate: float = 1e-3, 50 | batch_size: int = 10000, 51 | n_max_examples: int = 10000, 52 | local_thinning: int = 1, 53 | global_thinning: int = 1, 54 | n_NFproposal_batch_size: int = 10000, 55 | verbose: bool = False, 56 | ): 57 | n_training_steps = ( 58 | n_local_steps // local_thinning * n_training_loops 59 | + n_global_steps // global_thinning * n_training_loops 60 | ) 61 | n_production_steps = ( 62 | n_local_steps // local_thinning * n_production_loops 63 | + n_global_steps // global_thinning * n_production_loops 64 | ) 65 | n_total_epochs = n_training_loops * n_epochs 66 | 67 | positions_training = Buffer( 68 | "positions_training", (n_chains, n_training_steps, n_dims), 1 69 | ) 70 | log_prob_training = Buffer("log_prob_training", (n_chains, n_training_steps), 1) 71 | local_accs_training = Buffer( 72 | "local_accs_training", (n_chains, n_training_steps), 1 73 | ) 74 | global_accs_training = Buffer( 75 | "global_accs_training", (n_chains, n_training_steps), 1 76 | ) 77 | loss_buffer = Buffer("loss_buffer", (n_total_epochs,), 0) 78 | 79 | position_production = Buffer( 80 | "positions_production", (n_chains, n_production_steps, n_dims), 1 81 | ) 82 | log_prob_production = Buffer( 83 | "log_prob_production", (n_chains, n_production_steps), 1 84 | ) 85 | local_accs_production = Buffer( 86 | "local_accs_production", (n_chains, n_production_steps), 1 87 | ) 88 | global_accs_production = Buffer( 89 | "global_accs_production", (n_chains, n_production_steps), 1 90 | ) 91 | 92 | 
local_sampler = MALA(step_size=mala_step_size) 93 | rng_key, subkey = jax.random.split(rng_key) 94 | model = MaskedCouplingRQSpline( 95 | n_dims, rq_spline_n_layers, rq_spline_hidden_units, rq_spline_n_bins, subkey 96 | ) 97 | global_sampler = NFProposal( 98 | model, n_NFproposal_batch_size=n_NFproposal_batch_size 99 | ) 100 | optimizer = Optimizer(model=model, learning_rate=learning_rate) 101 | logpdf = LogPDF(logpdf, n_dims=n_dims) 102 | 103 | sampler_state = State( 104 | { 105 | "target_positions": "positions_training", 106 | "target_log_prob": "log_prob_training", 107 | "target_local_accs": "local_accs_training", 108 | "target_global_accs": "global_accs_training", 109 | "training": True, 110 | }, 111 | name="sampler_state", 112 | ) 113 | 114 | self.resources = { 115 | "logpdf": logpdf, 116 | "positions_training": positions_training, 117 | "log_prob_training": log_prob_training, 118 | "local_accs_training": local_accs_training, 119 | "global_accs_training": global_accs_training, 120 | "loss_buffer": loss_buffer, 121 | "positions_production": position_production, 122 | "log_prob_production": log_prob_production, 123 | "local_accs_production": local_accs_production, 124 | "global_accs_production": global_accs_production, 125 | "local_sampler": local_sampler, 126 | "global_sampler": global_sampler, 127 | "model": model, 128 | "optimizer": optimizer, 129 | "sampler_state": sampler_state, 130 | } 131 | 132 | local_stepper = TakeSerialSteps( 133 | "logpdf", 134 | "local_sampler", 135 | "sampler_state", 136 | ["target_positions", "target_log_prob", "target_local_accs"], 137 | n_local_steps, 138 | thinning=local_thinning, 139 | chain_batch_size=chain_batch_size, 140 | verbose=verbose, 141 | ) 142 | 143 | global_stepper = TakeGroupSteps( 144 | "logpdf", 145 | "global_sampler", 146 | "sampler_state", 147 | ["target_positions", "target_log_prob", "target_global_accs"], 148 | n_global_steps, 149 | thinning=global_thinning, 150 | chain_batch_size=chain_batch_size, 151 | verbose=verbose, 152 | ) 153 | 154 | model_trainer = TrainModel( 155 | "model", 156 | "positions_training", 157 | "optimizer", 158 | loss_buffer_name="loss_buffer", 159 | n_epochs=n_epochs, 160 | batch_size=batch_size, 161 | n_max_examples=n_max_examples, 162 | verbose=verbose, 163 | ) 164 | 165 | update_state = UpdateState( 166 | "sampler_state", 167 | [ 168 | "target_positions", 169 | "target_log_prob", 170 | "target_local_accs", 171 | "target_global_accs", 172 | "training", 173 | ], 174 | [ 175 | "positions_production", 176 | "log_prob_production", 177 | "local_accs_production", 178 | "global_accs_production", 179 | False, 180 | ], 181 | ) 182 | 183 | def reset_steppers( 184 | rng_key: PRNGKeyArray, 185 | resources: dict[str, Resource], 186 | initial_position: Float[Array, "n_chains n_dim"], 187 | data: dict, 188 | ) -> tuple[ 189 | PRNGKeyArray, 190 | dict[str, Resource], 191 | Float[Array, "n_chains n_dim"], 192 | ]: 193 | """Reset the steppers to the initial position.""" 194 | local_stepper.set_current_position(0) 195 | global_stepper.set_current_position(0) 196 | return rng_key, resources, initial_position 197 | 198 | reset_steppers_lambda = Lambda( 199 | lambda rng_key, resources, initial_position, data: reset_steppers( 200 | rng_key, resources, initial_position, data 201 | ) 202 | ) 203 | 204 | update_global_step = Lambda( 205 | lambda rng_key, resources, initial_position, data: global_stepper.set_current_position( 206 | local_stepper.current_position 207 | ) 208 | ) 209 | update_local_step = Lambda( 210 | lambda rng_key, 
resources, initial_position, data: local_stepper.set_current_position( 211 | global_stepper.current_position 212 | ) 213 | ) 214 | 215 | def update_model( 216 | rng_key: PRNGKeyArray, 217 | resources: dict[str, Resource], 218 | initial_position: Float[Array, "n_chains n_dim"], 219 | data: dict, 220 | ) -> tuple[ 221 | PRNGKeyArray, 222 | dict[str, Resource], 223 | Float[Array, "n_chains n_dim"], 224 | ]: 225 | """Update the model.""" 226 | model = resources["model"] 227 | resources["global_sampler"] = eqx.tree_at( 228 | lambda x: x.model, 229 | resources["global_sampler"], 230 | model, 231 | ) 232 | return rng_key, resources, initial_position 233 | 234 | update_model_lambda = Lambda( 235 | lambda rng_key, resources, initial_position, data: update_model( 236 | rng_key, resources, initial_position, data 237 | ) 238 | ) 239 | 240 | self.strategies = { 241 | "local_stepper": local_stepper, 242 | "global_stepper": global_stepper, 243 | "model_trainer": model_trainer, 244 | "update_state": update_state, 245 | "update_global_step": update_global_step, 246 | "update_local_step": update_local_step, 247 | "reset_steppers": reset_steppers_lambda, 248 | "update_model": update_model_lambda, 249 | } 250 | 251 | training_phase = [ 252 | "local_stepper", 253 | "update_global_step", 254 | "model_trainer", 255 | "update_model", 256 | "global_stepper", 257 | "update_local_step", 258 | ] 259 | production_phase = [ 260 | "local_stepper", 261 | "update_global_step", 262 | "global_stepper", 263 | "update_local_step", 264 | ] 265 | strategy_order = [] 266 | for _ in range(n_training_loops): 267 | strategy_order.extend(training_phase) 268 | 269 | strategy_order.append("reset_steppers") 270 | strategy_order.append("update_state") 271 | for _ in range(n_production_loops): 272 | strategy_order.extend(production_phase) 273 | 274 | self.strategy_order = strategy_order 275 | -------------------------------------------------------------------------------- /src/flowMC/resource/model/common.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Tuple, Optional 2 | 3 | import equinox as eqx 4 | import jax 5 | import jax.numpy as jnp 6 | from jaxtyping import Array, Float, PRNGKeyArray 7 | from abc import abstractmethod 8 | 9 | 10 | class Bijection(eqx.Module): 11 | """Base class for bijective transformations. 12 | 13 | This is an abstract template that should not be directly used. 14 | """ 15 | 16 | @abstractmethod 17 | def __init__(self): 18 | raise NotImplementedError 19 | 20 | def __call__( 21 | self, 22 | x: Float[Array, " n_dim"], 23 | condition: Float[Array, " n_condition"], 24 | ) -> tuple[Float[Array, " n_dim"], Float]: 25 | return self.forward(x, condition) 26 | 27 | @abstractmethod 28 | def forward( 29 | self, 30 | x: Float[Array, " n_dim"], 31 | condition: Float[Array, " n_condition"], 32 | ) -> tuple[Float[Array, " n_dim"], Float]: 33 | raise NotImplementedError 34 | 35 | @abstractmethod 36 | def inverse( 37 | self, 38 | x: Float[Array, " n_dim"], 39 | condition: Float[Array, " n_condition"], 40 | ) -> tuple[Float[Array, " n_dim"], Float]: 41 | raise NotImplementedError 42 | 43 | 44 | class Distribution(eqx.Module): 45 | """Base class for probability distributions. 46 | 47 | This is an abstract template that should not be directly used. 
48 | """ 49 | 50 | @abstractmethod 51 | def __init__(self): 52 | raise NotImplementedError 53 | 54 | def __call__(self, x: Array, key: Optional[PRNGKeyArray] = None) -> Array: 55 | return self.log_prob(x) 56 | 57 | @abstractmethod 58 | def log_prob(self, x: Array) -> Array: 59 | raise NotImplementedError 60 | 61 | @abstractmethod 62 | def sample( 63 | self, rng_key: PRNGKeyArray, n_samples: int 64 | ) -> Float[Array, " n_samples n_features"]: 65 | raise NotImplementedError 66 | 67 | 68 | class MLP(eqx.Module): 69 | r"""Multilayer perceptron. 70 | 71 | Args: 72 | shape (List[int]): Shape of the MLP. The first element is the input dimension, 73 | the last element is the output dimension. 74 | key (PRNGKeyArray): Random key. 75 | 76 | Attributes: 77 | layers (List): List of layers. 78 | activation (Callable): Activation function. 79 | use_bias (bool): Whether to use bias. 80 | """ 81 | 82 | layers: List 83 | 84 | def __init__( 85 | self, 86 | shape: List[int], 87 | key: PRNGKeyArray, 88 | scale: Float = 1e-4, 89 | activation: Callable = jax.nn.relu, 90 | use_bias: bool = True, 91 | ): 92 | self.layers = [] 93 | for i in range(len(shape) - 2): 94 | key, subkey1, subkey2 = jax.random.split(key, 3) 95 | layer = eqx.nn.Linear( 96 | shape[i], shape[i + 1], key=subkey1, use_bias=use_bias 97 | ) 98 | weight = jax.random.normal(subkey2, (shape[i + 1], shape[i])) * jnp.sqrt( 99 | scale / shape[i] 100 | ) 101 | layer = eqx.tree_at(lambda layer: layer.weight, layer, weight) 102 | self.layers.append(layer) 103 | self.layers.append(activation) 104 | key, subkey = jax.random.split(key) 105 | self.layers.append( 106 | eqx.nn.Linear(shape[-2], shape[-1], key=subkey, use_bias=use_bias) 107 | ) 108 | 109 | def __call__(self, x: Float[Array, " n_in"]) -> Float[Array, " n_out"]: 110 | for layer in self.layers: 111 | x = layer(x) 112 | return x 113 | 114 | @property 115 | def n_input(self) -> int: 116 | return self.layers[0].in_features 117 | 118 | @property 119 | def n_output(self) -> int: 120 | return self.layers[-1].out_features 121 | 122 | @property 123 | def dtype(self) -> jnp.dtype: 124 | return self.layers[0].weight.dtype 125 | 126 | 127 | class MaskedCouplingLayer(Bijection): 128 | r"""Masked coupling layer. 129 | 130 | f(x) = (1-m)*b(x;c(m*x;z)) + m*x 131 | where b is the inner bijector, m is the mask, and c is the conditioner. 132 | 133 | Args: 134 | bijector (Bijection): inner bijector in the masked coupling layer. 135 | mask (Array): Mask. 0 for the input variables that are transformed, 136 | 1 for the input variables that are not transformed. 
137 | """ 138 | 139 | _mask: Float[Array, " n_dim"] 140 | bijector: Bijection 141 | 142 | @property 143 | def mask(self) -> Float[Array, " n_dim"]: 144 | return jax.lax.stop_gradient(self._mask) 145 | 146 | def __init__(self, bijector: Bijection, mask: Float[Array, " n_dim"]): 147 | self.bijector = bijector 148 | self._mask = mask 149 | 150 | def forward( 151 | self, 152 | x: Float[Array, " n_dim"], 153 | condition: Float[Array, " n_condition"], 154 | ) -> tuple[Float[Array, " n_dim"], Float]: 155 | y, log_det = self.bijector(x, x * self.mask) # type: ignore 156 | y = (1 - self.mask) * y + self.mask * x 157 | log_det = ((1 - self.mask) * log_det).sum() 158 | return y, log_det 159 | 160 | def inverse( 161 | self, 162 | x: Float[Array, " n_dim"], 163 | condition: Float[Array, " n_condition"], 164 | ) -> tuple[Float[Array, " n_dim"], Float]: 165 | y, log_det = self.bijector.inverse(x, x * self.mask) # type: ignore 166 | y = (1 - self.mask) * y + self.mask * x 167 | log_det = ((1 - self.mask) * log_det).sum() 168 | return y, log_det 169 | 170 | 171 | class MLPAffine(Bijection): 172 | scale_MLP: MLP 173 | shift_MLP: MLP 174 | dt: Float = 1 175 | 176 | def __init__(self, scale_MLP: MLP, shift_MLP: MLP, dt: Float = 1): 177 | self.scale_MLP = scale_MLP 178 | self.shift_MLP = shift_MLP 179 | self.dt = dt 180 | 181 | def __call__( 182 | self, x: Float[Array, " n_dim"], condition_x: Float[Array, " n_cond"] 183 | ) -> Tuple[Float[Array, " n_dim"], Float]: 184 | return self.forward(x, condition_x) 185 | 186 | def forward( 187 | self, 188 | x: Float[Array, " n_dim"], 189 | condition: Float[Array, " n_condition"], 190 | ) -> tuple[Float[Array, " n_dim"], Float]: 191 | # Note that this note output log_det as an array instead of a number. 192 | # This is because we need to sum over the log_det in the masked coupling layer. 193 | scale = jnp.tanh(self.scale_MLP(condition)) * self.dt 194 | shift = self.shift_MLP(condition) * self.dt 195 | log_det = scale 196 | y = (x + shift) * jnp.exp(scale) 197 | return y, log_det 198 | 199 | def inverse( 200 | self, 201 | x: Float[Array, " n_dim"], 202 | condition: Float[Array, " n_condition"], 203 | ) -> tuple[Float[Array, " n_dim"], Float]: 204 | scale = jnp.tanh(self.scale_MLP(condition)) * self.dt 205 | shift = self.shift_MLP(condition) * self.dt 206 | log_det = -scale 207 | y = x * jnp.exp(-scale) - shift 208 | return y, log_det 209 | 210 | 211 | class ScalarAffine(Bijection): 212 | scale: Array 213 | shift: Array 214 | 215 | def __init__(self, scale: Float, shift: Float): 216 | self.scale = jnp.array(scale) 217 | self.shift = jnp.array(shift) 218 | 219 | def __call__( 220 | self, x: Float[Array, " n_dim"], condition_x: Float[Array, " n_cond"] 221 | ) -> Tuple[Float[Array, " n_dim"], Float]: 222 | return self.forward(x, condition_x) 223 | 224 | def forward( 225 | self, 226 | x: Float[Array, " n_dim"], 227 | condition: Float[Array, " n_condition"], 228 | ) -> tuple[Float[Array, " n_dim"], Float]: 229 | y = (x + self.shift) * jnp.exp(self.scale) 230 | log_det = self.scale 231 | return y, log_det 232 | 233 | def inverse( 234 | self, 235 | x: Float[Array, " n_dim"], 236 | condition: Float[Array, " n_condition"], 237 | ) -> tuple[Float[Array, " n_dim"], Float]: 238 | y = x * jnp.exp(-self.scale) - self.shift 239 | log_det = -self.scale 240 | return y, log_det 241 | 242 | 243 | class Gaussian(Distribution): 244 | r"""Multivariate Gaussian distribution. 245 | 246 | Args: 247 | mean (Array): Mean. 248 | cov (Array): Covariance matrix. 
249 |         learnable (bool):
250 |             Whether the mean and covariance matrix are learnable parameters.
251 | 
252 |     Attributes:
253 |         mean (Array): Mean.
254 |         cov (Array): Covariance matrix.
255 |     """
256 | 
257 |     _mean: Float[Array, " n_dim"]
258 |     _cov: Float[Array, "n_dim n_dim"]
259 |     learnable: bool = False
260 | 
261 |     @property
262 |     def mean(self) -> Float[Array, " n_dim"]:
263 |         if self.learnable:
264 |             return self._mean
265 |         else:
266 |             return jax.lax.stop_gradient(self._mean)
267 | 
268 |     @property
269 |     def cov(self) -> Float[Array, "n_dim n_dim"]:
270 |         if self.learnable:
271 |             return self._cov
272 |         else:
273 |             return jax.lax.stop_gradient(self._cov)
274 | 
275 |     def __init__(
276 |         self,
277 |         mean: Float[Array, " n_dim"],
278 |         cov: Float[Array, "n_dim n_dim"],
279 |         learnable: bool = False,
280 |     ):
281 |         self._mean = mean
282 |         self._cov = cov
283 |         self.learnable = learnable
284 | 
285 |     def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
286 |         return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov)
287 | 
288 |     def sample(
289 |         self, rng_key: PRNGKeyArray, n_samples: int
290 |     ) -> Float[Array, " n_samples n_features"]:
291 |         return jax.random.multivariate_normal(
292 |             rng_key, self.mean, self.cov, (n_samples,)
293 |         )
294 | 
295 | 
296 | class Composable(Distribution):
297 |     distributions: list[Distribution]
298 |     partitions: dict[str, tuple[int, int]]
299 | 
300 |     def __init__(self, distributions: list[Distribution], partitions: dict):
301 |         self.distributions = distributions
302 |         self.partitions = partitions
303 | 
304 |     def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
305 |         log_prob = 0
306 |         for dist, (_, ranges) in zip(self.distributions, self.partitions.items()):
307 |             log_prob += dist.log_prob(x[ranges[0] : ranges[1]])
308 |         return log_prob
309 | 
310 |     def sample(
311 |         self, rng_key: PRNGKeyArray, n_samples: int
312 |     ) -> Float[Array, " n_samples n_features"]:
313 |         samples = {}
314 |         for dist, (key, _) in zip(self.distributions, self.partitions.items()):
315 |             rng_key, sub_key = jax.random.split(rng_key)
316 |             samples[key] = dist.sample(sub_key, n_samples=n_samples)
317 |         return samples  # type: ignore
318 | 
--------------------------------------------------------------------------------
/docs/quickstart.md:
--------------------------------------------------------------------------------
1 | 
2 | # Quick Start
3 | 
4 | 
5 | ## Installation
6 | 
7 | The recommended way to install flowMC is using pip:
8 | 
9 | ```
10 | pip install flowMC
11 | ```
12 | 
13 | This will install the latest stable release and its dependencies.
14 | flowMC is based on [JAX](https://github.com/google/jax) and [Equinox](https://github.com/patrick-kidger/equinox).
15 | By default, installing flowMC will automatically install the versions of JAX and Equinox available on [PyPI](https://pypi.org/).
16 | JAX does not install GPU support by default.
17 | If you want to use JAX with a GPU, you need to install JAX with GPU support according to [their documentation](https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier).
18 | At the time of writing this documentation page, this is the command to install JAX with GPU support:
19 | 
20 | ```
21 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
22 | ```
23 | 
24 | If you want to install the latest version of flowMC, you can clone this repo and install it locally:
25 | 
26 | ```
27 | git clone https://github.com/kazewong/flowMC.git
28 | cd flowMC
29 | pip install -e .
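# the -e flag installs in editable mode, so local changes to the source take effect without reinstalling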
30 | ```
31 | 
32 | If you have [uv](https://docs.astral.sh/uv/) installed, you can also install the latest version of flowMC by running:
33 | 
34 | ```
35 | uv sync
36 | ```
37 | 
38 | once you have cloned the repo.
39 | 
40 | 
41 | ## Basic Usage
42 | 
43 | 
44 | To sample an N-dimensional Gaussian, you would do something like:
45 | 
46 | ``` python
47 | import jax
48 | import jax.numpy as jnp
49 | from flowMC.Sampler import Sampler
50 | from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle
51 | 
52 | # Defining the log posterior
53 | 
54 | def log_posterior(x, data: dict):
55 |     return -0.5 * jnp.sum((x - data["data"]) ** 2)
56 | 
57 | # Declaring hyperparameters
58 | 
59 | n_dims = 2
60 | n_local_steps = 10
61 | n_global_steps = 10
62 | n_training_loops = 3
63 | n_production_loops = 3
64 | n_epochs = 10
65 | n_chains = 10
66 | rq_spline_hidden_units = [64, 64]
67 | rq_spline_n_bins = 8
68 | rq_spline_n_layers = 3
69 | data = {"data": jnp.arange(n_dims).astype(jnp.float32)}
70 | 
71 | # Initializing the strategy bundle
72 | rng_key = jax.random.PRNGKey(42)
73 | rng_key, subkey = jax.random.split(rng_key)
74 | bundle = RQSpline_MALA_Bundle(
75 |     subkey,
76 |     n_chains,
77 |     n_dims,
78 |     log_posterior,
79 |     n_local_steps,
80 |     n_global_steps,
81 |     n_training_loops,
82 |     n_production_loops,
83 |     n_epochs,
84 |     rq_spline_hidden_units=rq_spline_hidden_units,
85 |     rq_spline_n_bins=rq_spline_n_bins,
86 |     rq_spline_n_layers=rq_spline_n_layers,
87 | )
88 | 
89 | # Run the sampler
90 | 
91 | rng_key, subkey = jax.random.split(rng_key)
92 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1
93 | nf_sampler = Sampler(
94 |     n_dims,
95 |     n_chains,
96 |     rng_key,
97 |     resource_strategy_bundles=bundle,
98 | )
99 | 
100 | nf_sampler.sample(initial_position, data)
101 | ```
102 | 
103 | In the ideal case, the only three things you will have to do are:
104 | 
105 | 1. Write down the log-probability density function you want to sample in the form of `log_p(x)`, where `x` is the vector of variables of interest,
106 | 2. Choose your sampling strategy and hyperparameters,
107 | 3. Give the sampler the initial position of your chains and start sampling.
108 | 
109 | Given the script above, you can start playing with the sampler and see how it behaves. Below is a more detailed description of `flowMC` and some guiding principles for using it.
110 | 
111 | ## Anatomy of flowMC
112 | 
113 | Prior to version 0.4.0, `flowMC` was designed to execute the algorithm detailed in [this paper](https://arxiv.org/pdf/2105.12603). Since then, the community has tried applying `flowMC` to different problems. While there were some successes, there were also limiting factors in terms of performance. One of the biggest issues `flowMC` faced was that the global-local sampling algorithm was baked into the top-level `Sampler` API, which meant `flowMC` could only run the exact algorithm described in the paper. What if users want to use a different model? Run some optimization steps during the sampling stage? Apply annealing? These were either impossible or unintuitive in `flowMC` prior to version 0.4.0.
114 | 
115 | Seeing this limitation, we redesigned the middle-level API of `flowMC` while keeping the top-level API as similar as possible.
This guide aims to describe the different components of `flowMC`, how they interact with each other, and to give users who want to extend `flowMC` for their specific problems a starting point on what could be useful to change. It also serves as a set of rules of thumb for users who want to use `flowMC` as a black box and interact with the internal components through hyperparameters only.
116 | 
117 | ### Target distribution
118 | 
119 | The target distribution should be defined as a log-probability density function with the following signature:
120 | 
121 | ```python
122 | def target_log_prob_fn(x: Float[Array, "n_dims"], data: dict[str, Any]) -> Float:
123 |     ...
124 |     return log_prob
125 | ```
126 | 
127 | The `target_log_prob_fn` should take in a `Float[Array, "n_dims"]` array `x` and a dictionary `data` that contains any additional data the target distribution depends on. The function should return a scalar `Float` that is the log-probability density of the target distribution at `x`.
128 | 
129 | To ensure the target distribution is well defined and performant, you should also check that the function behaves as expected when `jax.jit` and `jax.grad` are applied to it.
130 | 
131 | ### Sampler
132 | 
133 | On the top level, the `Sampler` class is a thin wrapper on top of the resource-strategy pair (defined below) that provides some extra functionality. The `Sampler` class manages the resources and strategies, as well as run-related parameters such as where the resources will be stored if the user decides to serialize them.
134 | 
135 | ```python
136 | nf_sampler = Sampler(
137 |     n_dims,
138 |     n_chains,
139 |     rng_key,
140 |     # you can either supply the resources and strategies directly,
141 |     # which is prioritized over the resource-strategy bundles
142 |     resources=resources,
143 |     strategies=strategies,
144 |     # or you can supply the resource-strategy bundles
145 |     resource_strategy_bundles=bundle,
146 | )
147 | ```
148 | 
149 | The main loop of `Sampler` is straightforward after initialization: given the available resources, it iterates through the list of strategies, each of which takes the resources, performs some action (such as taking local steps or training a normalizing flow), and returns the updated resources. In the current implementation, the `Sampler` simply goes through the list of strategies in order, but in the future we plan to support a more flexible main loop, such as automatic stopping based on some criterion.
150 | 
151 | ### Resource and Strategy
152 | 
153 | At the core of the new `flowMC` API are the resource and strategy interfaces. Broadly speaking, resources are similar to data classes, and strategies are similar to functions.
154 | **Resources** store state and can be manipulated, but should not have many methods associated with them. For example, a buffer that stores the sampling results is a resource, a MALA kernel is a resource, and a normalizing flow model is a resource. **Strategies** are functions that take in resources and return updated resources. For example, taking a local step requires two kinds of resources: a proposal distribution and the buffer where the samples are stored. Examples of strategies are taking a local step, training a normalizing flow, and running an optimization step. A minimal sketch of a custom strategy is shown below.
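To make this concrete, here is a minimal sketch of what a custom strategy could look like, modeled on the call signature used throughout the built-in strategies: it takes the PRNG key, the resource dictionary, the current positions, and the data dictionary, and returns their updated counterparts. This is illustrative only: it assumes the abstract base class in `flowMC.strategy.base` is named `Strategy`, that a MALA kernel is registered under the `"local_sampler"` key, and that the kernel exposes a `step_size` leaf that can be swapped out with `eqx.tree_at`.

```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray

from flowMC.resource.base import Resource
from flowMC.strategy.base import Strategy


class ShrinkStepSize(Strategy):
    """Hypothetical strategy: rescale the local kernel's step size
    from the spread of the current chain positions."""

    def __call__(
        self,
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, "n_chains n_dim"],
        data: dict,
    ) -> tuple[PRNGKeyArray, dict[str, Resource], Float[Array, "n_chains n_dim"]]:
        # Read one resource (the positions) and update another (the kernel),
        # leaving all other resources untouched.
        scale = jnp.std(initial_position)
        resources["local_sampler"] = eqx.tree_at(
            lambda kernel: kernel.step_size,
            resources["local_sampler"],
            0.1 * scale,
        )
        return rng_key, resources, initial_position
```

Such a strategy would then be registered in the `strategies` dictionary and referenced by name in `strategy_order`, exactly like the built-in steppers.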
155 | 
156 | If you are initializing the resources and strategies directly, you can do something like:
157 | 
158 | ```python
159 | resources = {
160 |     "buffer": Buffer(name, n_chains, n_steps, n_dims),
161 |     "proposal": MALA(step_size),
162 |     "model": NormalizingFlow(model_parameters),
163 | }
164 | 
165 | strategies = {
166 |     "Strategy 1": Strategy1(),
167 |     "Strategy 2": Strategy2(),
168 | }
169 | 
170 | strategy_order = ["Strategy 1", "Strategy 2", "Strategy 1", ...]
171 | ```
172 | 
173 | The reason for this separation is to allow users to compose different strategies together. For example, the user may want to update the parameters of a proposal kernel like MALA with local information from a normalizing flow model. Instead of hard-coding this functionality into either the MALA kernel or the normalizing flow model, the current API allows the user to define a strategy that takes in both the MALA kernel and the normalizing flow model and updates the MALA kernel with information from the normalizing flow model. This separates the concerns of the different components of the algorithm and makes experimenting with new strategies more manageable.
174 | 
175 | Since this API is designed for users who are willing to look into the guts of `flowMC` and experiment with different strategies, the main question to ask is whether a new data structure or piece of functionality should be a resource or a strategy. While there are no hard rules other than conforming to the respective base classes, a good rule of thumb is to ask whether the new data structure or functionality is something that should be updated by other strategies. If the answer is yes, it should be a resource. If the answer is no, it should be a strategy.
176 | 
177 | One extra criterion that decides whether an implementation should be a resource or a strategy is whether it is compatible with `jax` transformations. Resources should be compatible with `jit`, while strategies are not required to be. An example that illustrates the difference: a training loop involves looping over epochs and logging metadata, which usually does not need to be jitted, so it should be a strategy. A neural network and its main functions need to run efficiently on the GPU during both sampling and training, so it should be a resource.
178 | 
179 | You can find the hyper-parameters of a resource, a strategy, or a resource-strategy bundle in the API docs.
180 | 
181 | ## Guiding principles
182 | 
183 | ### Write the likelihood function in JAX
184 | 
185 | If your likelihood is fully defined in [JAX](https://github.com/google/jax), there are a couple of benefits that compound with each other:
186 | 
187 | 1. JAX allows you to access the gradient of the likelihood function with respect to the parameters of the model through automatic differentiation. Having access to the gradient allows the use of gradient-based local samplers such as the Metropolis-adjusted Langevin algorithm (MALA) and Hamiltonian Monte Carlo (HMC). These algorithms let the sampler handle high-dimensional problems and are often more efficient than gradient-free local samplers such as Metropolis-Hastings.
188 | 2. JAX uses [XLA](https://www.tensorflow.org/xla) to compile your code not only into machine code but also in a way that is optimized for accelerators such as GPUs and TPUs.
Accelerators such as GPUs and TPUs provide parallel computing solutions that are more scalable than CPUs, and having multiple MCMC chains running in parallel also helps speed up the training of the normalizing flow.
189 | 
190 | Since version 0.4.0, we made the design choice of removing support for likelihood functions that are incompatible with `jax` transformations. The reason is that `flowMC` is designed to leverage GPU acceleration and machine learning methods to solve complex problems. If a developer decides to use `flowMC` to solve their problem, it is also a good time to consider rewriting their legacy code base in `jax`, which on its own could provide a significant speedup. Instead of letting people off the hook by allowing likelihood functions that are not `jax`-compatible, we decided to enforce the use of `jax` to encourage users to take advantage of its benefits.
191 | 
192 | ### Parallelize whenever you can
193 | 
194 | You should center your choice of resources and strategies around leveraging parallelization. This is reflected by the fact that `n_chains` is a required parameter of the `Sampler` class: `flowMC` is designed to solve problems with complex geometry using adaptive sampling methods, such as training a normalizing flow alongside a local proposal, which benefit tremendously from having multiple chains running in parallel.
195 | 
--------------------------------------------------------------------------------
/src/flowMC/resource/model/flowmatching/base.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx
2 | from jaxtyping import PRNGKeyArray, Float, Array, PyTree
3 | import optax
4 | from flowMC.resource.base import Resource
5 | from flowMC.resource.model.common import MLP
6 | from typing_extensions import Self
7 | from typing import Optional
8 | import jax.numpy as jnp
9 | import jax
10 | from jax.scipy.stats.multivariate_normal import logpdf
11 | from diffrax import diffeqsolve, ODETerm, Dopri5, AbstractSolver
12 | from tqdm import trange, tqdm
13 | 
14 | 
15 | class Solver(eqx.Module):
16 | 
17 |     model: MLP  # Shape should be [input_dim + t_dim, hiddens, output_dim]
18 |     method: AbstractSolver
19 | 
20 |     def __init__(self, model: MLP, method: AbstractSolver = Dopri5()):
21 |         self.model = model
22 |         self.method = method
23 | 
24 |     def sample(
25 |         self, rng_key: PRNGKeyArray, n_samples: int, dt: Float = 1e-1
26 |     ) -> Float[Array, "n_samples n_dims"]:
27 |         """Sample points from the solver.
28 |         This solves the ODE forward, i.e. from the prior to the posterior.
29 |         """
30 | 
31 |         def model_wrapper(
32 |             t: Float, x: Float[Array, " n_dims"], args: PyTree
33 |         ) -> Float[Array, " n_dims"]:
34 |             """Wrapper for the model to be used in the ODE solver."""
35 |             t = jnp.expand_dims(t, axis=-1)
36 |             x = jnp.concatenate([x, t], axis=-1)
37 |             return self.model(x)
38 | 
39 |         def solve_ode(
40 |             y0: Float[Array, " n_dims"], dt: Float = 1e-1
41 |         ) -> Float[Array, " n_dims"]:
42 |             """Solve the ODE with initial condition y0."""
43 |             term = ODETerm(model_wrapper)
44 |             sol = diffeqsolve(
45 |                 term,
46 |                 self.method,
47 |                 t0=0.0,
48 |                 t1=1.0,
49 |                 dt0=dt,
50 |                 y0=y0,
51 |             )
52 |             return sol.ys[-1]  # type: ignore
53 | 
54 |         x0 = jax.random.normal(rng_key, (n_samples, self.model.n_input - 1))
55 |         sols = eqx.filter_vmap(solve_ode, in_axes=(0, None))(x0, dt)
56 |         return sols
57 | 
58 |     def log_prob(self, x1: Float[Array, " n_dims"], dt: Float = 1e-1) -> Float:
59 |         """Compute the log probability of the initial condition x1.
60 |         This solves the ODE backward, i.e. from the posterior to the prior.
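        The result follows the instantaneous change-of-variables formula:
        the standard-normal prior log-density at the mapped point plus the
        integrated divergence of the vector field along the trajectory.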
61 | """ 62 | 63 | def model_wrapper( 64 | t: Float, x: Float[Array, " n_dims"], args: PyTree 65 | ) -> list[Float[Array, " ..."]]: 66 | """Wrapper for the model to be used in the ODE solver. 67 | 68 | The output shape should be [n_dims, 1]. 69 | """ 70 | t = jnp.expand_dims(t, axis=-1) 71 | x = jnp.concatenate([x[0], t], axis=-1) 72 | y = self.model(x) 73 | div = jax.jacrev(self.model, argnums=0)(x)[:, :-1] 74 | return [y, jnp.trace(div)] 75 | 76 | def solve_ode(y0: Float[Array, " n_dims"], dt: Float = 1e-1) -> PyTree: 77 | """Solve the ODE with initial condition y0.""" 78 | term = ODETerm(model_wrapper) 79 | y_init = jax.tree.map(jnp.asarray, [y0, 0.0]) 80 | sol = diffeqsolve( 81 | term, 82 | self.method, 83 | t0=1.0, 84 | t1=0.0, 85 | dt0=-dt, 86 | y0=y_init, 87 | ) 88 | return sol.ys 89 | 90 | x0, log_p = solve_ode(x1, dt) 91 | return ( 92 | logpdf( 93 | x1, 94 | mean=self.model.n_output * jnp.zeros(self.model.n_output), 95 | cov=jnp.eye(self.model.n_output), 96 | ) 97 | + log_p 98 | ) 99 | 100 | 101 | class Scheduler: 102 | 103 | def __call__(self, t: Float) -> tuple[Float, Float, Float, Float]: 104 | """Return the parameters of the scheduler at time t.""" 105 | raise NotImplementedError 106 | 107 | 108 | class CondOTScheduler(Scheduler): 109 | """Conditional Optimal Transport Scheduler.""" 110 | 111 | def __call__(self, t: Float) -> tuple[Float, Float, Float, Float]: 112 | """Return the parameters of the scheduler at time t.""" 113 | # Implement the logic to compute alpha_t, d_alpha_t, sigma_t, d_sigma_t 114 | return t, 1.0, 1.0 - t, -1.0 115 | 116 | 117 | class Path: 118 | 119 | scheduler: Scheduler 120 | 121 | def __init__(self, scheduler: Scheduler): 122 | self.scheduler = scheduler 123 | 124 | def sample(self, x0: Float, x1: Float, t: Float) -> Float: 125 | """Sample a point along the path between x0 and x1 at time t.""" 126 | alpha_t, d_alpha_t, sigma_t, d_sigma_t = self.scheduler(t) 127 | x_t = sigma_t * x0 + alpha_t * x1 128 | dx_t = d_sigma_t * x0 + d_alpha_t * x1 129 | return x_t, dx_t 130 | 131 | 132 | class FlowMatchingModel(eqx.Module, Resource): 133 | 134 | solver: Solver 135 | path: Path 136 | _data_mean: Float[Array, " n_dim"] 137 | _data_cov: Float[Array, " n_dim n_dim"] 138 | 139 | @property 140 | def n_features(self): 141 | return self.solver.model.n_input - 1 142 | 143 | @property 144 | def data_mean(self): 145 | return jax.lax.stop_gradient(self._data_mean) 146 | 147 | @property 148 | def data_cov(self): 149 | return jax.lax.stop_gradient(jnp.atleast_2d(self._data_cov)) 150 | 151 | def __init__( 152 | self, 153 | solver: Solver, 154 | path: Path, 155 | data_mean: Optional[Float[Array, " n_dim"]] = None, 156 | data_cov: Optional[Float[Array, " n_dim n_dim"]] = None, 157 | ): 158 | self.solver = solver 159 | self.path = path 160 | n_features = self.n_features 161 | if data_mean is not None: 162 | self._data_mean = data_mean 163 | else: 164 | self._data_mean = jnp.zeros(n_features) 165 | 166 | if data_cov is not None: 167 | self._data_cov = data_cov 168 | else: 169 | self._data_cov = jnp.eye(n_features) 170 | 171 | def sample( 172 | self, rng_key: PRNGKeyArray, num_samples: int, dt: Float = 1e-1 173 | ) -> Float[Array, " n_dim"]: 174 | rng_key, subkey = jax.random.split(rng_key) 175 | samples = self.solver.sample(subkey, num_samples, dt=dt) 176 | std = jnp.sqrt(jnp.diag(self.data_cov)) 177 | samples = samples * std + self.data_mean 178 | return samples 179 | 180 | def log_prob(self, x: Float[Array, " n_dim"]) -> Float: 181 | std = jnp.sqrt(jnp.diag(self.data_cov)) 
182 |         x_whitened = (x - self.data_mean) / std
183 |         log_det = -jnp.sum(jnp.log(std))
184 |         return self.solver.log_prob(x_whitened) + log_det
185 | 
186 |     def save_model(self, path: str):
187 |         eqx.tree_serialise_leaves(path + ".eqx", self)
188 | 
189 |     def load_model(self, path: str) -> Self:
190 |         return eqx.tree_deserialise_leaves(path + ".eqx", self)
191 | 
192 |     @eqx.filter_value_and_grad
193 |     def loss_fn(
194 |         self,
195 |         x: Float[Array, "n_batch n_dim"],
196 |         t: Float[Array, "n_batch 1"],
197 |         dx_t: Float[Array, "n_batch n_dim"],
198 |     ) -> Float:
199 |         x = jnp.concatenate([x, t], axis=-1)
200 |         return jnp.mean(
201 |             (eqx.filter_vmap(self.solver.model, in_axes=(0))(x) - dx_t) ** 2
202 |         )
203 | 
204 |     @eqx.filter_jit
205 |     def train_step(
206 |         model: Self,
207 |         x_t: Float[Array, "n_batch n_dim"],
208 |         t: Float[Array, "n_batch 1"],
209 |         dx_t: Float[Array, "n_batch n_dim"],
210 |         optim: optax.GradientTransformation,
211 |         state: optax.OptState,
212 |     ) -> tuple[Float[Array, " 1"], Self, optax.OptState]:
213 |         print("Compiling training step")
214 |         loss, grads = model.loss_fn(x_t, t, dx_t)
215 |         updates, state = optim.update(grads, state, model)  # type: ignore
216 |         model = eqx.apply_updates(model, updates)
217 |         return loss, model, state
218 | 
219 |     def train_epoch(
220 |         self: Self,
221 |         rng: PRNGKeyArray,
222 |         optim: optax.GradientTransformation,
223 |         state: optax.OptState,
224 |         data: tuple[
225 |             Float[Array, "n_example n_dim"],
226 |             Float[Array, "n_example n_dim"],
227 |             Float[Array, "n_example 1"],
228 |         ],
229 |         batch_size: int,
230 |     ) -> tuple[Float, Self, optax.OptState]:
231 |         """Train for a single epoch."""
232 |         value = 1e9
233 |         model = self
234 |         train_ds_size = len(data[0])
235 |         steps_per_epoch = train_ds_size // batch_size
236 |         std = jnp.sqrt(jnp.diag(self.data_cov))
237 |         if steps_per_epoch > 0:
238 |             perms = jax.random.permutation(rng, train_ds_size)
239 | 
240 |             perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
241 |             perms = perms.reshape((steps_per_epoch, batch_size))
242 |             for perm in perms:
243 |                 batch_x0, batch_x1, batch_t = (
244 |                     data[0][perm, ...],
245 |                     data[1][perm, ...],
246 |                     data[2][perm, ...],
247 |                 )
248 |                 batch_x1 = (batch_x1 - self.data_mean) / std
249 |                 batch_x_t, batch_dx_t = self.path.sample(batch_x0, batch_x1, batch_t)
250 |                 value, model, state = model.train_step(
251 |                     batch_x_t, batch_t, batch_dx_t, optim, state
252 |                 )
253 |         else:
254 |             batch_x1 = (data[1] - self.data_mean) / std
255 |             x_t, dx_t = self.path.sample(data[0], batch_x1, data[2])
256 |             value, model, state = model.train_step(x_t, data[2], dx_t, optim, state)
257 |         return value, model, state
258 | 
259 |     def train(
260 |         self: Self,
261 |         rng: PRNGKeyArray,
262 |         data: tuple[
263 |             Float[Array, "n_example n_dim"],
264 |             Float[Array, "n_example n_dim"],
265 |             Float[Array, "n_example 1"],
266 |         ],
267 |         optim: optax.GradientTransformation,
268 |         state: optax.OptState,
269 |         num_epochs: int,
270 |         batch_size: int,
271 |         verbose: bool = True,
272 |     ) -> tuple[PRNGKeyArray, Self, optax.OptState, Array]:
273 |         """Train the flow-matching model.
274 | 
275 |         Args:
276 |             rng (PRNGKeyArray): JAX PRNGKey.
277 |             data (tuple): Training data (x0, x1, t).
278 |             optim (optax.GradientTransformation): Optimizer.
279 |             state (optax.OptState): Optimizer state.
280 |             num_epochs (int): Number of epochs to train for.
281 |             batch_size (int): Batch size. verbose (bool): Whether to print progress.
282 | 
283 |         Returns:
284 |             rng (PRNGKeyArray): Updated JAX PRNGKey.
285 |             model (eqx.Module): Updated NF model.
286 |             state (optax.OptState): Updated optimizer state. loss_values (Array): Loss values.
287 |         """
288 |         loss_values = jnp.zeros(num_epochs)
289 |         if verbose:
290 |             pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
291 |         else:
292 |             pbar = range(num_epochs)
293 | 
294 |         best_model = model = self
295 |         best_state = state
296 |         best_loss = 1e9
297 |         model = eqx.tree_at(lambda m: m._data_mean, model, jnp.mean(data[1], axis=0))
298 |         model = eqx.tree_at(lambda m: m._data_cov, model, jnp.cov(data[1].T))
299 |         for epoch in pbar:
300 |             # Use a separate PRNG key to shuffle the training data each epoch
301 |             rng, input_rng = jax.random.split(rng)
302 |             # Run one epoch of optimization over the training data
303 |             value, model, state = model.train_epoch(
304 |                 input_rng, optim, state, data, batch_size
305 |             )
306 |             loss_values = loss_values.at[epoch].set(value)
307 |             if loss_values[epoch] < best_loss:
308 |                 best_model = model
309 |                 best_state = state
310 |                 best_loss = loss_values[epoch]
311 |             if verbose:
312 |                 assert isinstance(pbar, tqdm)
313 |                 if num_epochs > 10:
314 |                     if epoch % int(num_epochs / 10) == 0:
315 |                         pbar.set_description(f"Training NF, current loss: {value:.3f}")
316 |                 else:
317 |                     if epoch == num_epochs - 1:
318 |                         pbar.set_description(f"Training NF, current loss: {value:.3f}")
319 | 
320 |         return rng, best_model, best_state, loss_values
321 | 
322 |     save_resource = save_model
323 |     load_resource = load_model
324 | 
325 |     def print_parameters(self):
326 |         raise NotImplementedError(
327 |             "print_parameters is not implemented for FlowMatchingModel"
328 |         )
329 | 
--------------------------------------------------------------------------------