├── src
│   └── flowMC
│       ├── __init__.py
│       ├── resource
│       │   ├── __init__.py
│       │   ├── kernel
│       │   │   ├── __init__.py
│       │   │   ├── base.py
│       │   │   ├── Gaussian_random_walk.py
│       │   │   ├── MALA.py
│       │   │   ├── HMC.py
│       │   │   └── NF_proposal.py
│       │   ├── model
│       │   │   ├── __init__.py
│       │   │   ├── nf_model
│       │   │   │   ├── __init__.py
│       │   │   │   ├── realNVP.py
│       │   │   │   └── base.py
│       │   │   ├── flowmatching
│       │   │   │   ├── __init__.py
│       │   │   │   └── base.py
│       │   │   └── common.py
│       │   ├── optimizer.py
│       │   ├── base.py
│       │   ├── buffers.py
│       │   ├── states.py
│       │   └── logPDF.py
│       ├── strategy
│       │   ├── __init__.py
│       │   ├── importance_sampling.py
│       │   ├── sequential_monte_carlo.py
│       │   ├── base.py
│       │   ├── lambda_function.py
│       │   ├── update_state.py
│       │   ├── train_model.py
│       │   ├── optimization.py
│       │   └── take_steps.py
│       ├── resource_strategy_bundle
│       │   ├── __init__.py
│       │   ├── base.py
│       │   └── RQSpline_MALA.py
│       └── Sampler.py
├── ruff.toml
├── docs
│   ├── contribution.md
│   ├── dual_moon.png
│   ├── logo_0810.png
│   ├── tutorials
│   │   └── dual_moon.png
│   ├── stylesheets
│   │   └── extra.css
│   ├── requirements.txt
│   ├── gen_ref_pages.py
│   ├── communityExamples.md
│   ├── index.md
│   ├── FAQ.md
│   ├── configuration.md
│   └── quickstart.md
├── .coverage
├── .gitattributes
├── readthedocs.yml
├── .devcontainer
│   └── devcontainer.json
├── .gitignore
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── feature_request.md
│   │   └── bug_report.md
│   └── workflows
│       ├── pre-commit.yml
│       ├── workflowsjoss.yml
│       ├── run_tests.yml
│       └── python-publish.yml
├── .pre-commit-config.yaml
├── LICENSE
├── test
│   ├── integration
│   │   ├── test_quickstart.py
│   │   ├── test_normalizingFlow.py
│   │   ├── test_MALA.py
│   │   ├── test_HMC.py
│   │   └── test_RWMCMC.py
│   └── unit
│       ├── test_bundle.py
│       ├── test_nf.py
│       ├── test_resources.py
│       └── test_flowmatching.py
├── pyproject.toml
├── CONTRIBUTING.md
├── .all-contributorsrc
├── CODE_OF_CONDUCT.md
├── mkdocs.yml
├── joss
│   ├── paper.bib
│   └── paper.md
└── README.md
/src/flowMC/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | lint.ignore = ["F722"]
--------------------------------------------------------------------------------
/src/flowMC/resource/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/flowMC/strategy/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/contribution.md:
--------------------------------------------------------------------------------
1 | ../CONTRIBUTING.md
--------------------------------------------------------------------------------
/src/flowMC/resource/kernel/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/flowMC/resource/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/flowMC/strategy/importance_sampling.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/flowMC/resource/model/nf_model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/flowMC/resource_strategy_bundle/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/flowMC/resource/model/flowmatching/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.coverage:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kazewong/flowMC/HEAD/.coverage
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | docs/** linguist-vendored
2 | example/** linguist-vendored
--------------------------------------------------------------------------------
/docs/dual_moon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kazewong/flowMC/HEAD/docs/dual_moon.png
--------------------------------------------------------------------------------
/docs/logo_0810.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kazewong/flowMC/HEAD/docs/logo_0810.png
--------------------------------------------------------------------------------
/docs/tutorials/dual_moon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kazewong/flowMC/HEAD/docs/tutorials/dual_moon.png
--------------------------------------------------------------------------------
/docs/stylesheets/extra.css:
--------------------------------------------------------------------------------
1 | :root {
2 | --md-primary-fg-color: #200f80;
3 | --md-primary-fg-color--light: #6cb8bb;
4 | --md-primary-fg-color--dark: #6cb8bb;
5 | }
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | python:
4 | install:
5 | - requirements: docs/requirements.txt
6 | - method: pip
7 | path: .
8 |
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.11"
13 |
14 | mkdocs:
15 | configuration: mkdocs.yml
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "image": "mcr.microsoft.com/devcontainers/universal:2",
3 | "features": {
4 | "ghcr.io/va-h/devcontainers-features/uv:1": {},
5 | "ghcr.io/devcontainers/features/python:1": {},
6 | "ghcr.io/devcontainers-extra/features/ruff:1": {}
7 | }
8 | }
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | flowMC/nfmodel/__pycache__/maf.cpython-37.pyc
3 | log_dir/
4 | *.pkl
5 | flowMC/nfmodel/realNVP_white.py
6 | example/kepler/kepler_white.py
7 | example/kepler/kepler_analyse.py
8 | dist
9 | *egg-info*
10 | docs/build/*
11 | *_autosummary
12 | README.md
13 | *settings.json
14 | *build*
15 | test_docs/test_quickstart.ipynb
16 | node_modules
17 | package.json
18 | yarn.lock
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | mkdocs==1.6.1 # Main documentation generator.
2 | mkdocs-material==9.5.47 # Theme
3 | pymdown-extensions==10.12 # Markdown extensions e.g. to handle LaTeX.
4 | mkdocstrings[python]==0.27.0 # Autogenerate documentation from docstrings.
5 | mkdocs-jupyter==0.25.1 # Turn Jupyter Lab notebooks into webpages.
6 | mkdocs-gen-files==0.5.0
7 | mkdocs-literate-nav==0.6.1
8 |
--------------------------------------------------------------------------------
/src/flowMC/resource_strategy_bundle/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from flowMC.resource.base import Resource
4 | from flowMC.strategy.base import Strategy
5 |
6 |
7 | class ResourceStrategyBundle(ABC):
8 | """Resource-Strategy Bundle is aim to be the highest level of abstraction in the
9 | flowMC library.
10 |
11 | It is a collection of resources and strategies that are used to perform a specific
12 | task.
13 | """
14 |
15 | resources: dict[str, Resource]
16 | strategies: dict[str, Strategy]
17 | strategy_order: list[str]
18 |
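19 | # Illustrative sketch (not part of the library) of a concrete bundle; see
20 | # RQSpline_MALA.py for a real one:
21 | #
22 | # class MyBundle(ResourceStrategyBundle):
23 | #     def __init__(self, positions: Resource, take_steps: Strategy):
24 | #         self.resources = {"positions": positions}
25 | #         self.strategies = {"take_steps": take_steps}
26 | #         self.strategy_order = ["take_steps"]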
--------------------------------------------------------------------------------
/src/flowMC/strategy/sequential_monte_carlo.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.base import Resource
2 | from flowMC.strategy.base import Strategy
3 | from jaxtyping import Array, Float, PRNGKeyArray
4 |
5 |
6 | class SequentialMonteCarlo(Strategy):
7 |     def __init__(self):
8 |         raise NotImplementedError
9 |
10 |     def __call__(
11 |         self,
12 |         rng_key: PRNGKeyArray,
13 |         resources: dict[str, Resource],
14 |         initial_position: Float[Array, "n_chains n_dim"],
15 |         data: dict,
16 |     ) -> tuple[
17 |         PRNGKeyArray,
18 |         dict[str, Resource],
19 |         Float[Array, "n_chains n_dim"],
20 |     ]:
21 |         raise NotImplementedError
22 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/ambv/black
3 | rev: 25.1.0
4 | hooks:
5 | - id: black
6 | - repo: https://github.com/charliermarsh/ruff-pre-commit
7 | rev: 'v0.9.10'
8 | hooks:
9 | - id: ruff
10 | args: ["--fix"]
11 | - repo: https://github.com/RobertCraigie/pyright-python
12 | rev: v1.1.396
13 | hooks:
14 | - id: pyright
15 | additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, typing_extensions, equinox, optax, tqdm, diffrax]
16 | - repo: https://github.com/nbQA-dev/nbQA
17 | rev: 1.9.1
18 | hooks:
19 | - id: nbqa-black
20 | additional_dependencies: [ipython==8.12, black]
21 | - id: nbqa-ruff-format
22 | additional_dependencies: [ipython==8.12, ruff]
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 |
10 | pre-commit:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | python-version: ["3.11", "3.12"]
16 | steps:
17 | - uses: actions/checkout@v3
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v3
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | python -m pip install pytest
26 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
27 | python -m pip install .
28 | - uses: pre-commit/action@v3.0.1
--------------------------------------------------------------------------------
/.github/workflows/workflowsjoss.yml:
--------------------------------------------------------------------------------
1 | name: Update-Joss
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - 40-joss-paper
8 |
9 | jobs:
10 | paper:
11 | runs-on: ubuntu-latest
12 | name: Paper Draft
13 | steps:
14 | - name: Checkout
15 | uses: actions/checkout@v2
16 | - name: Build draft PDF
17 | uses: openjournals/openjournals-draft-action@master
18 | with:
19 | journal: joss
20 | # This should be the path to the paper within your repo.
21 | paper-path: joss/paper.md
22 | - name: Upload
23 | uses: actions/upload-artifact@v1
24 | with:
25 | name: paper
26 | # This is the output path where Pandoc will write the compiled
27 | # PDF. Note, this should be the same directory as the input
28 | # paper.md
29 | path: joss/paper.pdf
30 |
--------------------------------------------------------------------------------
/src/flowMC/strategy/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from jaxtyping import Array, Float, PRNGKeyArray
4 |
5 | from flowMC.resource.base import Resource
6 |
7 |
8 | class Strategy(ABC):
9 | """Base class for strategies, which are basically wrapper blocks that modify the
10 | state of the sampler.
11 |
12 | This is an abstract template that should not be directly used.
13 | """
14 |
15 | @abstractmethod
16 | def __init__(self):
17 | raise NotImplementedError
18 |
19 | @abstractmethod
20 | def __call__(
21 | self,
22 | rng_key: PRNGKeyArray,
23 | resources: dict[str, Resource],
24 | initial_position: Float[Array, "n_chains n_dim"],
25 | data: dict,
26 | ) -> tuple[
27 | PRNGKeyArray,
28 | dict[str, Resource],
29 | Float[Array, "n_chains n_dim"],
30 | ]:
31 | raise NotImplementedError
32 |
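33 | # Illustrative sketch (not part of the library) of a minimal concrete strategy
34 | # satisfying this interface:
35 | #
36 | # class NoOp(Strategy):
37 | #     def __init__(self):
38 | #         pass
39 | #
40 | #     def __call__(self, rng_key, resources, initial_position, data):
41 | #         return rng_key, resources, initial_position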
--------------------------------------------------------------------------------
/src/flowMC/resource/kernel/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import equinox as eqx
3 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
4 | from flowMC.resource.base import Resource
5 | from flowMC.resource.logPDF import LogPDF
6 | from typing import Callable
7 |
8 |
9 | class ProposalBase(eqx.Module, Resource):
10 | @abstractmethod
11 | def __init__(
12 | self,
13 | ):
14 | """Initialize the sampler class."""
15 |
16 | @abstractmethod
17 | def kernel(
18 | self,
19 | rng_key: PRNGKeyArray,
20 | position: Float[Array, "nstep n_dim"],
21 | log_prob: Float[Array, "nstep 1"],
22 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
23 | data: PyTree,
24 | ) -> tuple[
25 |         Float[Array, "nstep n_dim"], Float[Array, "nstep 1"], Int[Array, "nstep 1"]
26 | ]:
27 | """Kernel for one step in the proposal cycle."""
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/docs/gen_ref_pages.py:
--------------------------------------------------------------------------------
1 | """Generate the code reference pages."""
2 |
3 | from pathlib import Path
4 |
5 | import mkdocs_gen_files
6 |
7 | nav = mkdocs_gen_files.Nav()
8 |
9 |
10 | for path in sorted(Path("src").rglob("*.py")):
11 |     module_path = path.relative_to("src").with_suffix("")
12 |     doc_path = path.relative_to("src").with_suffix(".md")
13 |     full_doc_path = Path("api", doc_path)
14 |
15 |     parts = list(module_path.parts)
16 |
17 |     if parts[-1] == "__init__":
18 |         continue
19 |     elif parts[-1] == "__main__":
20 |         continue
21 |
22 |     nav[parts] = doc_path.as_posix()
23 |
24 |     print(full_doc_path)
25 |
26 |     with mkdocs_gen_files.open(full_doc_path, "w") as fd:
27 |         identifier = ".".join(parts)
28 |         print("::: " + identifier, file=fd)
29 |
30 |     mkdocs_gen_files.set_edit_path(full_doc_path, path)
31 |
32 | with mkdocs_gen_files.open("api/summary.md", "w") as nav_file:
33 | nav_file.writelines(nav.build_literate_nav())
34 |
--------------------------------------------------------------------------------
/src/flowMC/resource/optimizer.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.base import Resource
2 | import optax
3 | import equinox as eqx
4 |
5 |
6 | class Optimizer(Resource):
7 | optim: optax.GradientTransformation
8 | optim_state: optax.OptState
9 |
10 | def __repr__(self):
11 | return "Optimizer"
12 |
13 | def __init__(
14 | self,
15 | model: eqx.Module,
16 | learning_rate: float = 1e-3,
17 | momentum: float = 0.9,
18 | ):
19 | self.optim = optax.chain(
20 | optax.clip_by_global_norm(1.0),
21 | optax.adamw(learning_rate=learning_rate, b1=momentum),
22 | )
23 | self.optim_state = self.optim.init(eqx.filter(model, eqx.is_array))
24 |
25 | def __call__(self, params, grads):
26 | raise NotImplementedError
27 |
28 | def print_parameters(self):
29 | raise NotImplementedError
30 |
31 | def save_resource(self, path: str):
32 | raise NotImplementedError
33 |
34 | def load_resource(self, path: str):
35 | raise NotImplementedError
36 |
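37 | # Illustrative usage, assuming `model` is an equinox Module whose arrays the
38 | # optimizer state should track:
39 | # optimizer = Optimizer(model, learning_rate=1e-3, momentum=0.9)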
--------------------------------------------------------------------------------
/src/flowMC/resource/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from typing import Self
4 |
5 |
6 | class Resource(ABC):
7 | """Base class for resources. Resources are objects such as local sampler and neural
8 | networks.
9 |
10 | This is an abstract template that should not be directly used.
11 | """
12 |
13 | @abstractmethod
14 | def __init__(self):
15 | raise NotImplementedError
16 |
17 | @abstractmethod
18 | def print_parameters(self):
19 | """Function to print the tunable parameters of the resource."""
20 | raise NotImplementedError
21 |
22 | @abstractmethod
23 | def save_resource(self, path: str):
24 | """Function to save the resource.
25 |
26 | Args:
27 | path (str): Path to save the resource.
28 | """
29 | raise NotImplementedError
30 |
31 | @abstractmethod
32 | def load_resource(self, path: str) -> Self:
33 | """Function to load the resource.
34 |
35 | Args:
36 | path (str): Path to load the resource.
37 | """
38 | raise NotImplementedError
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Kaze Wong & contributor
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Run Tests
5 |
6 | on: [push, pull_request]
7 |
8 | jobs:
9 | build:
10 |
11 | runs-on: ubuntu-latest
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | python-version: ["3.11", "3.12"]
16 |
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 |
21 | - name: Install uv
22 | uses: astral-sh/setup-uv@v3
23 |
24 | - name: Set up Python ${{ matrix.python-version }}
25 | run: uv python install ${{ matrix.python-version }}
26 |
27 | - name: Install the project
28 | run: uv sync --all-extras --dev
29 |
30 | - name: Run tests and coverage
31 | run: |
32 | uv run coverage run --source=src -m pytest test
33 | uv run coveralls --service=github-actions
34 | env:
35 | COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }}
36 |
37 | - name: Run build
38 | run: uv build
--------------------------------------------------------------------------------
/docs/communityExamples.md:
--------------------------------------------------------------------------------
1 | # Community examples
2 |
3 | The core design philosophy of flowMC is to stay lean and simple, so we leave most of the use-case-specific optimization to the users. Hence, we keep most of the flowMC internals away from users,
4 | so the only interfaces with flowMC are really defining your likelihood and tuning the sampler parameters exposed at the top level.
5 | That said, it is useful to have references showing how to use and tune flowMC for different use cases, so on this page we host a number of community examples contributed by users.
6 | If you find flowMC useful, please consider contributing your example to this page. This will help other users (and perhaps your future students) get started quickly.
7 |
8 | ## Examples
9 |
10 | - [jim - A JAX-based gravitational-wave inference toolkit](https://github.com/kazewong/jim)
11 | - [Bayeux - Stitching together models and samplers](https://github.com/jax-ml/bayeux)
12 | - [Colab example](https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing)
13 | - [Markovian Flow Matching: Accelerating MCMC with Continuous Normalizing Flows](https://arxiv.org/pdf/2405.14392)
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 | workflow_dispatch:
15 |
16 |
17 | permissions:
18 | contents: read
19 |
20 | jobs:
21 | deploy:
22 |
23 | runs-on: ubuntu-latest
24 |
25 | steps:
26 | - uses: actions/checkout@v3
27 | - name: Set up Python
28 | uses: actions/setup-python@v3
29 | with:
30 | python-version: '3.x'
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install build
35 | - name: Build package
36 | run: python -m build
37 | - name: Publish package
38 | uses: pypa/gh-action-pypi-publish@release/v1
39 | with:
40 | user: __token__
41 | password: ${{ secrets.PYPI_API_TOKEN }}
42 |
--------------------------------------------------------------------------------
/src/flowMC/strategy/lambda_function.py:
--------------------------------------------------------------------------------
1 | from flowMC.strategy.base import Strategy
2 | from flowMC.resource.base import Resource
3 | from jaxtyping import Array, Float, PRNGKeyArray
4 | from typing import Callable
5 |
6 |
7 | class Lambda(Strategy):
8 | """A strategy that applies a function to the resources.
9 |
10 | This should be used for simple functions or calling methods from
11 | a class.
12 | If you find yourself writing a Lambda strategy that is more than a few lines
13 | long, consider writing a custom strategy instead.
14 |
15 | """
16 |
17 | def __init__(self, lambda_function: Callable):
18 | """Initialize the lambda strategy.
19 |
20 | Args:
21 |             lambda_function: A callable with the signature (rng_key, resources, initial_position, data).
22 | """
23 | self.lambda_function = lambda_function
24 |
25 | def __call__(
26 | self,
27 | rng_key: PRNGKeyArray,
28 | resources: dict[str, Resource],
29 | initial_position: Float[Array, "n_chains n_dim"],
30 | data: dict,
31 | ) -> tuple[
32 | PRNGKeyArray,
33 | dict[str, Resource],
34 | Float[Array, "n_chains n_dim"],
35 | ]:
36 | self.lambda_function(rng_key, resources, initial_position, data)
37 | return rng_key, resources, initial_position
38 |
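39 | # Illustrative usage: wrap a small side effect without writing a full Strategy.
40 | # inspect_resources = Lambda(
41 | #     lambda rng_key, resources, initial_position, data: print(resources.keys())
42 | # )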
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | flowMC
2 | ======
3 |
4 | **Normalizing-flow enhanced sampling package for probabilistic inference**
5 |
6 |
7 | 
8 |
9 | [](https://flowMC.readthedocs.io/en/latest/)
10 | [](https//github.com/kazewong/flowMC/blob/Packaging/LICENSE)
11 |
12 |
13 |
14 | flowMC is a JAX-based Python package for normalizing-flow enhanced Markov chain Monte Carlo (MCMC) sampling.
15 | The code is open source under the MIT license, and it is under active development.
16 |
17 | - Just-in-time compilation is supported.
18 | - Native support for GPU acceleration.
19 | - Suited for problems with multi-modality and complex geometry.
20 |
21 | Four steps to get started with flowMC
22 | =====================================
23 |
24 | 1. You can find installation info, a basic example, and some design and guiding principles of `flowMC` in the [quickstart](quickstart.md).
25 | 2. In the tutorials, we have a set of pedagogical notebooks that will give you a better understanding of the package infrastructure and a number of common use cases.
26 | 3. We list some community examples in [community_example](communityExamples.md), so users can check whether there is a similar use case whose code they can adapt quickly.
27 | 4. Finally, we have a list of frequently asked questions in [FAQ](FAQ.md).
28 |
29 |
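30 | For a taste of the interface, here is a minimal sketch distilled from `test/integration/test_quickstart.py`; treat the exact argument list as illustrative, since it may evolve between releases:
31 |
32 | ```python
33 | import jax
34 | import jax.numpy as jnp
35 |
36 | from flowMC.Sampler import Sampler
37 | from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle
38 |
39 |
40 | def log_posterior(x, data: dict):
41 |     # Standard Gaussian centered on the data.
42 |     return -0.5 * jnp.sum((x - data["data"]) ** 2)
43 |
44 |
45 | n_dims, n_chains = 2, 10
46 | data = {"data": jnp.arange(n_dims).astype(jnp.float32)}
47 |
48 | rng_key, subkey = jax.random.split(jax.random.PRNGKey(42))
49 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims))
50 |
51 | rng_key, subkey = jax.random.split(rng_key)
52 | bundle = RQSpline_MALA_Bundle(
53 |     subkey, n_chains, n_dims, log_posterior,
54 |     n_local_steps=10, n_global_steps=10, n_training_loops=5,
55 |     n_production_loops=5, n_epochs=10,
56 | )
57 |
58 | sampler = Sampler(n_dims, n_chains, rng_key, resource_strategy_bundles=bundle)
59 | sampler.sample(initial_position, data)
60 | ```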
--------------------------------------------------------------------------------
/test/integration/test_quickstart.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from flowMC.Sampler import Sampler
4 | from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle
5 |
6 |
7 | def log_posterior(x, data: dict):
8 | return -0.5 * jnp.sum((x - data["data"]) ** 2)
9 |
10 |
11 | n_dims = 2
12 | n_local_steps = 10
13 | n_global_steps = 10
14 | n_training_loops = 5
15 | n_production_loops = 5
16 | n_epochs = 10
17 | n_chains = 10
18 | rq_spline_hidden_units = [64, 64]
19 | rq_spline_n_bins = 8
20 | rq_spline_n_layers = 3
21 | data = {"data": jnp.arange(n_dims).astype(jnp.float32)}
22 |
23 | rng_key = jax.random.PRNGKey(42)
24 | rng_key, subkey = jax.random.split(rng_key)
25 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1
26 |
27 | rng_key, subkey = jax.random.split(rng_key)
28 | bundle = RQSpline_MALA_Bundle(
29 | subkey,
30 | n_chains,
31 | n_dims,
32 | log_posterior,
33 | n_local_steps,
34 | n_global_steps,
35 | n_training_loops,
36 | n_production_loops,
37 | n_epochs,
38 | rq_spline_hidden_units=rq_spline_hidden_units,
39 | rq_spline_n_bins=rq_spline_n_bins,
40 | rq_spline_n_layers=rq_spline_n_layers,
41 | verbose=True,
42 | )
43 |
44 | nf_sampler = Sampler(
45 | n_dims,
46 | n_chains,
47 | rng_key,
48 | resource_strategy_bundles=bundle,
49 | )
50 |
51 | nf_sampler.sample(initial_position, data)
52 |
--------------------------------------------------------------------------------
/test/unit/test_bundle.py:
--------------------------------------------------------------------------------
1 | import jax
2 | from flowMC.resource_strategy_bundle.RQSpline_MALA import RQSpline_MALA_Bundle
3 | from flowMC.resource_strategy_bundle.RQSpline_MALA_PT import RQSpline_MALA_PT_Bundle
4 |
5 |
6 | def logpdf(x, _):
7 | return -0.5 * (x**2).sum()
8 |
9 |
10 | def test_rqspline_mala_bundle_initialization():
11 | rng_key = jax.random.PRNGKey(0)
12 | n_chains = 2
13 | n_dims = 3
14 | n_local_steps = 10
15 | n_global_steps = 5
16 | n_training_loops = 2
17 | n_production_loops = 1
18 | n_epochs = 3
19 |
20 | bundle = RQSpline_MALA_Bundle(
21 | rng_key=rng_key,
22 | n_chains=n_chains,
23 | n_dims=n_dims,
24 | logpdf=logpdf,
25 | n_local_steps=n_local_steps,
26 | n_global_steps=n_global_steps,
27 | n_training_loops=n_training_loops,
28 | n_production_loops=n_production_loops,
29 | n_epochs=n_epochs,
30 | )
31 |
32 | assert repr(bundle) == "RQSpline_MALA Bundle"
33 |
34 |
35 | def test_rqspline_mala_pt_bundle_initialization():
36 | rng_key = jax.random.PRNGKey(0)
37 | n_chains = 2
38 | n_dims = 3
39 | n_local_steps = 10
40 | n_global_steps = 5
41 | n_training_loops = 2
42 | n_production_loops = 1
43 | n_epochs = 3
44 |
45 | bundle = RQSpline_MALA_PT_Bundle(
46 | rng_key=rng_key,
47 | n_chains=n_chains,
48 | n_dims=n_dims,
49 | logpdf=logpdf,
50 | n_local_steps=n_local_steps,
51 | n_global_steps=n_global_steps,
52 | n_training_loops=n_training_loops,
53 | n_production_loops=n_production_loops,
54 | n_epochs=n_epochs,
55 | )
56 |
57 | assert repr(bundle) == "RQSpline MALA PT Bundle"
58 |
--------------------------------------------------------------------------------
/src/flowMC/strategy/update_state.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.states import State
2 | from flowMC.strategy.base import Strategy
3 | from flowMC.resource.base import Resource
4 | from jaxtyping import Array, Float, PRNGKeyArray
5 |
6 |
7 | class UpdateState(Strategy):
8 | """Update a state resource in place.
9 |
10 |     This strategy is meant for infrequent state updates.
11 |     If you are looking for an option that iterates over some parameters,
12 |     say the parameters of a neural network, you should write a custom strategy
13 | that does that.
14 | """
15 |
16 | def __init__(
17 | self, state_name: str, keys: list[str], values: list[int | bool | str]
18 | ):
19 | """Initialize the update state strategy.
20 |
21 | Args:
22 | state_name (str): The name of the state resource to update.
23 | keys (list[str]): The keys to update in the state resource.
24 | values (list[int | bool | str]): The values to update in the state resource.
25 | """
26 | self.state_name = state_name
27 | self.keys = keys
28 | self.values = values
29 |
30 | def __call__(
31 | self,
32 | rng_key: PRNGKeyArray,
33 | resources: dict[str, Resource],
34 | initial_position: Float[Array, "n_chains n_dim"],
35 | data: dict,
36 | ) -> tuple[
37 | PRNGKeyArray,
38 | dict[str, Resource],
39 | Float[Array, "n_chains n_dim"],
40 | ]:
41 | """Update the state resource in place."""
42 | assert isinstance(
43 | state := resources[self.state_name], State
44 | ), f"Resource {self.state_name} is not a State resource."
45 |
46 | state.update(self.keys, self.values)
47 | return rng_key, resources, initial_position
48 |
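49 | # Illustrative usage, assuming a State resource named "sampler_state" holding a
50 | # "training" flag: switch the sampler out of its training phase.
51 | # to_production = UpdateState("sampler_state", ["training"], [False])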
--------------------------------------------------------------------------------
/test/integration/test_normalizingFlow.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx # Equinox utilities
2 | import jax
3 | import jax.numpy as jnp # JAX NumPy
4 | import optax # Optimizers
5 |
6 | from flowMC.resource.model.nf_model.realNVP import RealNVP
7 | from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline
8 |
9 |
10 | def test_realNVP():
11 | key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3)
12 | data = jax.random.normal(key1, (100, 2))
13 |
14 | num_epochs = 5
15 | batch_size = 100
16 | learning_rate = 0.001
17 | momentum = 0.9
18 |
19 | model = RealNVP(2, 4, 32, rng)
20 | optim = optax.adam(learning_rate, momentum)
21 | state = optim.init(eqx.filter(model, eqx.is_array))
22 |
23 | rng, best_model, state, loss_values = model.train(
24 | init_rng, data, optim, state, num_epochs, batch_size, verbose=True
25 | )
26 | rng_key_nf = jax.random.PRNGKey(124098)
27 | model.sample(rng_key_nf, 10000)
28 |
29 |
30 | def test_rqSpline():
31 | n_dim = 2
32 | num_epochs = 5
33 | batch_size = 100
34 | learning_rate = 0.001
35 | momentum = 0.9
36 |
37 | key1, rng, init_rng = jax.random.split(jax.random.PRNGKey(0), 3)
38 | data = jax.random.normal(key1, (batch_size, n_dim))
39 |
40 | n_layers = 4
41 | hidden_dim = 32
42 | num_bins = 4
43 |
44 | model = MaskedCouplingRQSpline(
45 | n_dim,
46 | n_layers,
47 | [hidden_dim, hidden_dim],
48 | num_bins,
49 | rng,
50 | data_mean=jnp.mean(data, axis=0),
51 | data_cov=jnp.cov(data.T),
52 | )
53 | optim = optax.adam(learning_rate, momentum)
54 | state = optim.init(eqx.filter(model, eqx.is_array))
55 |
56 | rng, best_model, state, loss_values = model.train(
57 | init_rng, data, optim, state, num_epochs, batch_size, verbose=True
58 | )
59 | rng_key_nf = jax.random.PRNGKey(124098)
60 | model.sample(rng_key_nf, 10000)
61 |
--------------------------------------------------------------------------------
/src/flowMC/resource/buffers.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.base import Resource
2 | from typing import TypeVar
3 | import numpy as np
4 | from jaxtyping import Array, Float
5 | import jax.numpy as jnp
6 | import jax
7 |
8 | TBuffer = TypeVar("TBuffer", bound="Buffer")
9 |
10 |
11 | class Buffer(Resource):
12 | name: str
13 | data: Float[Array, " ..."]
14 | cursor: int = 0
15 | cursor_dim: int = 0
16 |
17 | def __repr__(self):
18 | return "Buffer " + self.name + " with shape " + str(self.data.shape)
19 |
20 | @property
21 | def shape(self):
22 | return self.data.shape
23 |
24 | def __init__(self, name: str, shape: tuple[int, ...], cursor_dim: int = 0):
25 | self.cursor_dim = cursor_dim
26 | self.name = name
27 | self.data = jnp.zeros(shape) - jnp.inf
28 |
29 | def __call__(self):
30 | return self.data
31 |
32 | def update_buffer(self, updates: Array, start: int = 0):
33 | """Update the buffer with new data.
34 |
35 | This will modify the buffer in place.
36 |         The updates are written along the `cursor_dim` axis starting at `start`,
37 |         spanning the length of `updates` along that axis.
38 | """
39 | self.data = jax.lax.dynamic_update_slice_in_dim(
40 | self.data, updates, start, self.cursor_dim
41 | )
42 |
43 | def print_parameters(self):
44 | print(
45 | f"Buffer: {self.name} with shape {self.data.shape} and cursor"
46 | f" {self.cursor} at dimension {self.cursor_dim}"
47 | )
48 |
49 | def get_distribution(self, n_bins: int = 100):
50 | return np.histogram(self.data.flatten(), bins=n_bins)
51 |
52 | def save_resource(self, path: str):
53 | np.savez(
54 | path + self.name,
55 | name=self.name,
56 | data=self.data,
57 | )
58 |
59 | def load_resource(self: TBuffer, path: str) -> TBuffer:
60 | data = np.load(path)
61 | buffer: Float[Array, " ..."] = data["data"]
62 |         result = Buffer(str(data["name"]), buffer.shape)
63 | result.data = buffer
64 | return result # type: ignore
65 |
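66 | # Illustrative usage, mirroring the integration tests: a per-chain position
67 | # history that is advanced along axis 1 (hypothetical variable names).
68 | # positions = Buffer("positions", (n_chains, n_steps, n_dims), cursor_dim=1)
69 | # positions.update_buffer(new_block, start=current_step)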
--------------------------------------------------------------------------------
/test/unit/test_nf.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 | from flowMC.resource.model.nf_model.realNVP import AffineCoupling, RealNVP
5 | from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline
6 |
7 |
8 | def test_affine_coupling_forward_and_inverse():
9 | n_features = 2
10 | n_hidden = 4
11 | x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
12 | mask = jnp.where(jnp.arange(n_features) % 2 == 0, 1.0, 0.0)
13 | key = jax.random.PRNGKey(0)
14 | dt = 0.5
15 | layer = AffineCoupling(n_features, n_hidden, mask, key, dt)
16 |
17 | y_forward, log_det_forward = jax.vmap(layer.forward)(x)
18 | x_recon, log_det_inverse = jax.vmap(layer.inverse)(y_forward)
19 |
20 | assert jnp.allclose(x, jnp.round(x_recon, decimals=5))
21 | assert jnp.allclose(log_det_forward, -log_det_inverse)
22 |
23 |
24 | def test_realnvp():
25 | n_features = 3
26 | n_hidden = 4
27 | n_layers = 2
28 | x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
29 |
30 | rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2)
31 | model = RealNVP(n_features, n_layers, n_hidden, rng_key)
32 |
33 | y, log_det = jax.vmap(model)(x)
34 |
35 | assert y.shape == x.shape
36 | assert log_det.shape == (2,)
37 |
38 | y_inv, log_det_inv = jax.vmap(model.inverse)(y)
39 |
40 | assert y_inv.shape == x.shape
41 | assert log_det_inv.shape == (2,)
42 | assert jnp.allclose(x, y_inv)
43 | assert jnp.allclose(log_det, -log_det_inv)
44 |
45 | rng_key = jax.random.PRNGKey(0)
46 | samples = model.sample(rng_key, 2)
47 |
48 | assert samples.shape == (2, 3)
49 |
50 | log_prob = jax.vmap(model.log_prob)(samples)
51 |
52 | assert log_prob.shape == (2,)
53 |
54 |
55 | def test_rqspline():
56 | n_features = 3
57 |     hidden_layers = [16, 16]
58 | n_layers = 2
59 | n_bins = 8
60 |
61 | rng_key, rng_subkey = jax.random.split(jax.random.PRNGKey(0), 2)
62 | model = MaskedCouplingRQSpline(
63 |         n_features, n_layers, hidden_layers, n_bins, jax.random.PRNGKey(10)
64 | )
65 |
66 | rng_key = jax.random.PRNGKey(0)
67 | samples = model.sample(rng_key, 2)
68 |
69 | assert samples.shape == (2, 3)
70 |
71 | log_prob = jax.vmap(model.log_prob)(samples)
72 |
73 | assert log_prob.shape == (2,)
74 |
--------------------------------------------------------------------------------
/src/flowMC/resource/states.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.base import Resource
2 | from typing import TypeVar
3 | import numpy as np
4 |
5 | TState = TypeVar("TState", bound="State")
6 |
7 |
8 | class State(Resource):
9 | """A Resource class that holds the state of the system.
10 | This is essentially a wrapper around a dictionary such that it can be
11 | handled by flowMC.
12 |
13 |     We restrict the type of the state to simple types: integers, booleans, and strings.
14 |     The main reason for this is that State is expected to indicate the stage of individual
15 |     strategies, rather than store parameters for resources.
16 |     I.e. State should hold whether the sampler is in the training phase or the production phase,
17 |     but not, say, the mass matrix of a kernel.
18 | """
19 |
20 | name: str
21 | data: dict[str, int | bool | str]
22 |
23 | def __repr__(self):
24 | return "State " + self.name + " with shape " + str(len(self.data))
25 |
26 | def __init__(self, data: dict[str, int | bool | str], name: str = "State"):
27 | """Initialize the state.
28 |
29 | Args:
30 | data (dict): The data to initialize the state with.
31 | name (str): The name of the state.
32 | """
33 |
34 | self.name = name
35 | self.data = data
36 |
37 | def update(self, key: list[str], value: list[int | bool | str]):
38 | """Update the state with new data.
39 |
40 | This will modify the state in place.
41 |
42 | Args:
43 | key (str): The key to update.
44 | value (int | bool | str): The value to update.
45 | """
46 | for k, v in zip(key, value):
47 | self.data[k] = v
48 | print(f"Updated state {k} to {v}")
49 |
50 | def print_parameters(self):
51 | print(f"State: {self.name} with shape {len(self.data)} and data {self.data}")
52 |
53 | def save_resource(self, path: str):
54 | np.savez(
55 | path + self.name,
56 | name=self.name,
57 | data=self.data, # type: ignore
58 | )
59 |
60 | def load_resource(self: TState, path: str) -> TState:
61 | data = np.load(path)
62 |         result = State(data["data"].item(), name=str(data["name"]))
63 | return result # type: ignore
64 |
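65 | # Illustrative usage: flag which phase the sampler is in.
66 | # sampler_state = State({"training": True}, name="sampler_state")
67 | # sampler_state.update(["training"], [False])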
--------------------------------------------------------------------------------
/test/integration/test_MALA.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.special import logsumexp
4 | from jaxtyping import Array, Float
5 |
6 | from flowMC.resource.kernel.MALA import MALA
7 | from flowMC.strategy.take_steps import TakeSerialSteps
8 | from flowMC.resource.buffers import Buffer
9 | from flowMC.resource.states import State
10 | from flowMC.resource.logPDF import LogPDF
11 | from flowMC.Sampler import Sampler
12 |
13 |
14 | def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
15 | """
16 | Term 2 and 3 separate the distribution
17 | and smear it along the first and second dimension
18 | """
19 | print("compile count")
20 | term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2
21 | term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2
22 | term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2
23 | return -(term1 - logsumexp(term2) - logsumexp(term3))
24 |
25 |
26 | n_dims = 5
27 | n_chains = 15
28 | n_local_steps = 30
29 | step_size = 0.01
30 |
31 | data = {"data": jnp.arange(5)}
32 |
33 | rng_key = jax.random.PRNGKey(42)
34 | rng_key, subkey = jax.random.split(rng_key)
35 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1
36 |
37 | # Defining resources
38 |
39 | MALA_Sampler = MALA(step_size=step_size)
40 | positions = Buffer("positions", (n_chains, n_local_steps, n_dims), 1)
41 | log_prob = Buffer("log_prob", (n_chains, n_local_steps), 1)
42 | acceptance = Buffer("acceptance", (n_chains, n_local_steps), 1)
43 | sampler_state = State(
44 | {
45 | "positions": "positions",
46 | "log_prob": "log_prob",
47 | "acceptance": "acceptance",
48 | },
49 | name="sampler_state",
50 | )
51 |
52 | resource = {
53 | "positions": positions,
54 | "log_prob": log_prob,
55 | "acceptance": acceptance,
56 | "MALA": MALA_Sampler,
57 | "logpdf": LogPDF(dual_moon_pe, n_dims=n_dims),
58 | "sampler_state": sampler_state,
59 | }
60 |
61 | # Defining strategy
62 |
63 | strategy = TakeSerialSteps(
64 | "logpdf",
65 | kernel_name="MALA",
66 | state_name="sampler_state",
67 | buffer_names=["positions", "log_prob", "acceptance"],
68 | n_steps=n_local_steps,
69 | )
70 |
71 | nf_sampler = Sampler(
72 | n_dim=n_dims,
73 | n_chains=n_chains,
74 | rng_key=rng_key,
75 | resources=resource,
76 | strategies={"take_steps": strategy},
77 | strategy_order=["take_steps"],
78 | )
79 |
80 | nf_sampler.sample(initial_position, data)
81 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "flowMC"
3 | version = "0.4.5"
4 | description = "Normalizing flow enhanced sampler in jax"
5 | authors = [
6 | { name = "Kaze Wong", email = "kazewong.physics@gmail.com"},
7 | { name = "Marylou Gabrié"},
8 | { name = "Dan Foreman-Mackey"}
9 | ]
10 |
11 | classifiers = [
12 | "Programming Language :: Python :: 3.11",
13 | "Programming Language :: Python :: 3.12",
14 | "Programming Language :: Python :: 3.13",
15 | "License :: OSI Approved :: MIT License",
16 | "Intended Audience :: Science/Research",
17 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
18 | "Topic :: Scientific/Engineering :: Mathematics",
19 | "Topic :: Scientific/Engineering :: Information Analysis",
20 | ]
21 |
22 | readme = "README.md"
23 | requires-python = ">=3.11"
24 | keywords = ["sampling", "inference", "machine learning", "normalizing", "autodiff", "jax"]
25 | dependencies = [
26 | "chex>=0.1.87",
27 | "diffrax>=0.7.0",
28 | "equinox>=0.11.9",
29 | "jax[cpu]>=0.5.0",
30 | "jaxtyping>=0.2.36",
31 | "optax>=0.2.4",
32 | "scikit-learn>=1.6.0",
33 | "tqdm>=4.67.1",
34 | ]
35 | license = { file = "LICENSE" }
36 |
37 | [project.urls]
38 | Documentation = "https://github.com/kazewong/flowMC"
39 |
40 |
41 | [project.optional-dependencies]
42 | docs = [
43 | "mkdocs-gen-files==0.5.0",
44 | "mkdocs-jupyter==0.25.1",
45 | "mkdocs-literate-nav==0.6.1",
46 | "mkdocs-material==9.5.47",
47 | "mkdocs==1.6.1",
48 | "mkdocstrings[python]==0.27.0",
49 | "pymdown-extensions==10.12",
50 | ]
51 | visualize = [
52 | "arviz>=0.21.0",
53 | "corner>=2.2.3",
54 | "matplotlib>=3.9.3",
55 | ]
56 | cuda = [
57 | "jax[cuda12]>=0.5.0",
58 | ]
59 |
60 | [dependency-groups]
61 | dev = [
62 | "flowMC",
63 | "ipykernel>=6.29.5",
64 | "coveralls>=4.0.1",
65 | "pre-commit>=4.0.1",
66 | "pyright>=1.1.389",
67 | "pytest>=8.3.3",
68 | "ruff>=0.8.0",
69 | "ipython>=8.30.0",
70 | ]
71 |
72 | [tool.uv.sources]
73 | flowMC = { workspace = true }
74 |
75 |
76 | [build-system]
77 | requires = ["hatchling"]
78 | build-backend = "hatchling.build"
79 |
80 |
81 | [tool.pyright]
82 | include = [
83 | "src",
84 | "tests",
85 | ]
86 | exclude = [
87 | "docs"
88 | ]
89 |
90 | [tool.coverage.report]
91 | exclude_also = [
92 | 'def __repr__',
93 | "raise AssertionError",
94 | "raise NotImplementedError",
95 | "@(abc\\. )?abstractmethod",
96 | "def tree_flatten",
97 | "def tree_unflatten",
98 | ]
99 |
--------------------------------------------------------------------------------
/docs/FAQ.md:
--------------------------------------------------------------------------------
1 | FAQ
2 | ===
3 |
4 | **My local sampler is not accepting**
5 |
6 | This usually means the step size of your local sampler is too large.
7 | Try reducing the step size.
8 |
9 | Alternatively, this could also mean your sampler is proposing in a region where the likelihood is ill-defined (i.e. NaN in either the likelihood or its derivative, if you are using a gradient-based local sampler).
10 | It is worth making sure your likelihood is well-defined within the range of your prior.
11 |
12 | **In order for my local sampler to accept, I have to choose a very small step size, which makes my chain very correlated.**
13 |
14 | This usually indicates that some of your parameters are much better measured than others.
15 | Since taking a small step in those directions already changes your likelihood value by a lot, the exploration power of the local sampler in the other parameters is limited by those which are well measured.
16 | Currently, we support different step sizes for different parameters, which you can tune to see whether that improves the situation.
17 | If you know the scale of each parameter ahead of time, reparameterizing them to maintain roughly equal scales across parameters also helps.
18 |
19 | **My global sampler's loss is exploding/not decreasing**
20 |
21 | This usually means your learning rate used for training the normalizing flow is too large.
22 | Try reducing the learning rate by a factor of ten.
23 |
24 | Another reason for a flat loss is that your local sampler is not accepting at all.
25 | This is a bit rarer, since it means the data used to train the normalizing flow is just your prior, which the normalizing flow should still be able to learn.
26 |
27 | **The sampler is stuck a bit until it starts sampling**
28 |
29 | If you use the option ``Jit`` when constructing the local sampler, the code will be compiled to speed up execution.
30 | The sampler is not really stuck; it is compiling the code. Depending on how you write your likelihood function, the compilation can take a while.
31 | If you don't want to wait, you can set ``Jit=False``, at the cost of a longer sampling time.
32 |
33 | **The compilation is slow**
34 |
35 | If your likelihood involves many lines of code, Jax will take a long time to compile it.
36 | Jax is known to be slow in compilation, especially if your computational graph contains a loop that calls a function many times.
37 | While we cannot fundamentally get rid of the problem, [using a jax.lax.scan](https://docs.kidger.site/equinox/tricks/#low-overhead-training-loops) is usually how we deal with it, as in the sketch below.
38 |
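39 | As an illustrative, flowMC-agnostic sketch (the function names are hypothetical), here is the same loop written in plain Python, which is unrolled at trace time, and with `jax.lax.scan`, which keeps a single copy of the body in the graph:
40 |
41 | ```python
42 | import jax
43 | import jax.numpy as jnp
44 |
45 | def slow_logpdf(x):
46 |     # The Python loop is unrolled at trace time: 1000 copies of the body
47 |     # end up in the computational graph, so compilation time grows with n.
48 |     total = jnp.zeros(())
49 |     for i in range(1000):
50 |         total = total + jnp.sin(x + i)
51 |     return total
52 |
53 | def fast_logpdf(x):
54 |     # jax.lax.scan compiles the loop body once and runs it 1000 times.
55 |     def body(carry, i):
56 |         return carry + jnp.sin(x + i), None
57 |
58 |     total, _ = jax.lax.scan(body, jnp.zeros(()), jnp.arange(1000.0))
59 |     return total
60 | ```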
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 |
2 | ### Expectations
3 |
4 | flowMC is developed and maintained in my spare time and, while I try to be
5 | responsive, I don't always get to every issue immediately. If it has been more
6 | than a week or two, feel free to ping me (@kazewong) to get my attention. This is
7 | subject to change as the community grows.
8 |
9 | ### Did you find a bug?
10 |
11 | **Ensure the bug was not already reported** by searching on GitHub under
12 | [Issues](https://github.com/kazewong/flowMC/issues). If you're unable to find an
13 | open issue addressing the problem, [open a new
14 | one](https://github.com/kazewong/flowMC/issues/new). Be sure to include a **title
15 | and clear description**, as much relevant information as possible, and the
16 | simplest possible **code sample** demonstrating the expected behavior that is
17 | not occurring. Also label the issue with the bug label.
18 |
19 | ### Did you write a patch that fixes a bug?
20 |
21 | Open a new GitHub pull request with the patch. Ensure the PR description clearly
22 | describes the problem and solution. Include the relevant issue number if
23 | applicable.
24 |
25 | ### Do you intend to add a new feature or change an existing feature?
26 |
27 | Please follow these principles when you are thinking about adding a new
28 | feature or changing an existing one:
29 |
30 | 1. The new feature should be able to take advantage of `jax.jit` whenever possible.
31 | 2. Lightweight and modular implementations are preferred.
32 | 3. The core package only does sampling. If you have a concrete example that
33 | involves a complete analysis such as plotting and models, see the next
34 | contribution guide.
35 |
36 | Suggestions for new features are welcome on [flowMC support
37 | group](https://groups.google.com/u/1/g/flowmc). Note that features related to the
38 | core algorithm are unlikely to be accepted, since they may involve a lot of
39 | breaking changes.
40 |
41 | ### Do you intend to introduce an example or tutorial?
42 |
43 | Open a new GitHub pull request with the example or tutorial. The example should
44 | be self-contained and keep imports from other packages to a minimum. Leave the
45 | case-specific analysis details out. For more extensive tutorials, we encourage the
46 | community to link the minimal examples hosted in the flowMC documentation to
47 | documentation from other packages.
48 |
49 | ### Do you have a question about the code?
50 |
51 | Do not open an issue. Instead, check whether there are existing threads
52 | on the [flowMC support group](https://groups.google.com/u/1/g/flowmc). If not,
53 | please open a new conversation there.
54 |
--------------------------------------------------------------------------------
/test/integration/test_HMC.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.special import logsumexp
4 | from jaxtyping import Array, Float
5 |
6 | from flowMC.resource.kernel.HMC import HMC
7 | from flowMC.strategy.take_steps import TakeSerialSteps
8 | from flowMC.resource.buffers import Buffer
9 | from flowMC.resource.states import State
10 | from flowMC.resource.logPDF import LogPDF
11 | from flowMC.Sampler import Sampler
12 |
13 |
14 | def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
15 | """
16 | Term 2 and 3 separate the distribution
17 | and smear it along the first and second dimension
18 | """
19 | print("compile count")
20 | term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2
21 | term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2
22 | term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2
23 | return -(term1 - logsumexp(term2) - logsumexp(term3))
24 |
25 |
26 | # Setup parameters
27 | n_dims = 5
28 | n_chains = 15
29 | n_local_steps = 30
30 | step_size = 0.1
31 | n_leapfrog = 10
32 |
33 | data = {"data": jnp.arange(5)}
34 |
35 | # Initialize positions
36 | rng_key = jax.random.PRNGKey(42)
37 | rng_key, subkey = jax.random.split(rng_key)
38 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1
39 |
40 | # Define resources
41 | HMC_sampler = HMC(
42 | step_size=step_size,
43 | n_leapfrog=n_leapfrog,
44 | condition_matrix=jnp.eye(n_dims),
45 | )
46 |
47 | positions = Buffer("positions", (n_chains, n_local_steps, n_dims), 1)
48 | log_prob = Buffer("log_prob", (n_chains, n_local_steps), 1)
49 | acceptance = Buffer("acceptance", (n_chains, n_local_steps), 1)
50 |
51 | sampler_state = State(
52 | {
53 | "positions": "positions",
54 | "log_prob": "log_prob",
55 | "acceptance": "acceptance",
56 | },
57 | name="sampler_state",
58 | )
59 |
60 | resource = {
61 | "positions": positions,
62 | "log_prob": log_prob,
63 | "acceptance": acceptance,
64 | "HMC": HMC_sampler,
65 | "logpdf": LogPDF(dual_moon_pe, n_dims=n_dims),
66 | "sampler_state": sampler_state,
67 | }
68 |
69 | # Define strategy
70 | strategy = TakeSerialSteps(
71 | "logpdf",
72 | kernel_name="HMC",
73 | state_name="sampler_state",
74 | buffer_names=["positions", "log_prob", "acceptance"],
75 | n_steps=n_local_steps,
76 | )
77 |
78 | # Initialize and run sampler
79 | nf_sampler = Sampler(
80 | n_dim=n_dims,
81 | n_chains=n_chains,
82 | rng_key=rng_key,
83 | resources=resource,
84 | strategies={"take_steps": strategy},
85 | strategy_order=["take_steps"],
86 | )
87 |
88 | nf_sampler.sample(initial_position, data)
89 |
--------------------------------------------------------------------------------
/src/flowMC/resource/kernel/Gaussian_random_walk.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
4 | from typing import Callable
5 |
6 | from flowMC.resource.kernel.base import ProposalBase
7 | from flowMC.resource.logPDF import LogPDF
8 |
9 |
10 | class GaussianRandomWalk(ProposalBase):
11 | """Gaussian random walk sampler class."""
12 |
13 | step_size: Float
14 |
15 | def __repr__(self):
16 | return "Gaussian Random Walk with step size " + str(self.step_size)
17 |
18 | def __init__(
19 | self,
20 | step_size: Float,
21 | ):
22 | super().__init__()
23 | self.step_size = step_size
24 |
25 | def kernel(
26 | self,
27 | rng_key: PRNGKeyArray,
28 | position: Float[Array, " n_dim"],
29 | log_prob: Float[Array, "1"],
30 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
31 | data: PyTree,
32 | ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
33 | """Random walk gaussian kernel. This is a kernel that only evolve a single
34 | chain.
35 |
36 | Args:
37 | rng_key (PRNGKeyArray): Jax PRNGKey
38 | position (Float[Array, "n_dim"]): current position of the chain
39 | log_prob (Float[Array, "1"]): current log-probability of the chain
40 | data (PyTree): data to be passed to the logpdf function
41 |
42 | Returns:
43 | position (Float[Array, "n_dim"]): new position of the chain
44 | log_prob (Float[Array, "1"]): new log-probability of the chain
45 | do_accept (Int[Array, "1"]): whether the new position is accepted
46 | """
47 |
48 | key1, key2 = jax.random.split(rng_key)
49 | move_proposal: Float[Array, " n_dim"] = (
50 | jax.random.normal(key1, shape=position.shape) * self.step_size
51 | )
52 |
53 | proposal = position + move_proposal
54 |         proposal_log_prob: Float[Array, "1"] = logpdf(proposal, data)
55 |
56 | log_uniform = jnp.log(jax.random.uniform(key2))
57 | do_accept = log_uniform < proposal_log_prob - log_prob
58 |
59 | position = jnp.where(do_accept, proposal, position)
60 | log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
61 | return position, log_prob, do_accept
62 |
63 | def print_parameters(self):
64 | print("Gaussian Random Walk parameters:")
65 | print(f"step_size: {self.step_size}")
66 |
67 | def save_resource(self, path):
68 | raise NotImplementedError
69 |
70 | def load_resource(self, path):
71 | raise NotImplementedError
72 |
--------------------------------------------------------------------------------
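For reference, the kernel above can also be driven directly, outside of a `TakeSerialSteps` strategy. A minimal single-chain sketch, assuming an illustrative standard-normal target (`log_target` is not part of the library):

    import jax
    import jax.numpy as jnp
    from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk

    # Illustrative target: any (position, data) -> scalar log-density works here.
    def log_target(x, data):
        return -0.5 * jnp.sum(x**2)

    rw = GaussianRandomWalk(step_size=0.5)
    key = jax.random.PRNGKey(0)
    position = jnp.zeros(2)
    log_prob = log_target(position, None)
    # One Metropolis step: propose x + N(0, step_size^2 I), then accept or reject.
    position, log_prob, do_accept = rw.kernel(key, position, log_prob, log_target, None)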
/test/integration/test_RWMCMC.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.special import logsumexp
4 | from jaxtyping import Array, Float
5 |
6 | from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk
7 | from flowMC.strategy.take_steps import TakeSerialSteps
8 | from flowMC.resource.buffers import Buffer
9 | from flowMC.resource.states import State
10 | from flowMC.resource.logPDF import LogPDF
11 | from flowMC.Sampler import Sampler
12 |
13 |
14 | def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
15 | """
16 |     Terms 2 and 3 separate the distribution
17 |     and smear it along the first and second dimensions.
18 | """
19 | print("compile count")
20 | term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2
21 | term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2
22 | term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2
23 | return -(term1 - logsumexp(term2) - logsumexp(term3))
24 |
25 |
26 | # Test parameters
27 | n_dims = 5
28 | n_chains = 2
29 | n_local_steps = 3
30 | n_global_steps = 3
31 | step_size = 0.1
32 |
33 | data = {"data": jnp.arange(5)}
34 |
35 | # Initialize random key and position
36 | rng_key = jax.random.PRNGKey(43)
37 | rng_key, subkey = jax.random.split(rng_key)
38 | initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims)) * 1
39 |
40 | # Define resources
41 | RWMCMC_sampler = GaussianRandomWalk(step_size=step_size)
42 | positions = Buffer("positions", (n_chains, n_local_steps, n_dims), 1)
43 | log_prob = Buffer("log_prob", (n_chains, n_local_steps), 1)
44 | acceptance = Buffer("acceptance", (n_chains, n_local_steps), 1)
45 | sampler_state = State(
46 | {
47 | "positions": "positions",
48 | "log_prob": "log_prob",
49 | "acceptance": "acceptance",
50 | },
51 | name="sampler_state",
52 | )
53 |
54 | # Advance the random key before assembling resources
55 | rng_key, subkey = jax.random.split(rng_key)
56 |
57 | resource = {
58 | "logpdf": LogPDF(dual_moon_pe, n_dims=n_dims),
59 | "positions": positions,
60 | "log_prob": log_prob,
61 | "acceptance": acceptance,
62 | "RWMCMC": RWMCMC_sampler,
63 | "sampler_state": sampler_state,
64 | }
65 |
66 | # Define strategy
67 | strategy = TakeSerialSteps(
68 | "logpdf",
69 | kernel_name="RWMCMC",
70 | state_name="sampler_state",
71 | buffer_names=["positions", "log_prob", "acceptance"],
72 | n_steps=n_local_steps,
73 | )
74 |
75 | print("Initializing sampler class")
76 |
77 | # Initialize and run sampler
78 | nf_sampler = Sampler(
79 | n_dim=n_dims,
80 | n_chains=n_chains,
81 | rng_key=rng_key,
82 | resources=resource,
83 | strategies={"take_steps": strategy},
84 | strategy_order=["take_steps"],
85 | )
86 |
87 | nf_sampler.sample(initial_position, data)
88 |
--------------------------------------------------------------------------------
/.all-contributorsrc:
--------------------------------------------------------------------------------
1 | {
2 | "projectName": "flowMC",
3 | "projectOwner": "kazewong",
4 | "files": [
5 | "README.md"
6 | ],
7 | "commitType": "docs",
8 | "commitConvention": "angular",
9 | "contributorsPerLine": 7,
10 | "contributors": [
11 | {
12 | "login": "HajimeKawahara",
13 | "name": "Hajime Kawahara",
14 | "avatar_url": "https://avatars.githubusercontent.com/u/15956904?v=4",
15 | "profile": "http://secondearths.sakura.ne.jp/en/index.html",
16 | "contributions": [
17 | "bug"
18 | ]
19 | },
20 | {
21 | "login": "daniel-dodd",
22 | "name": "Daniel Dodd",
23 | "avatar_url": "https://avatars.githubusercontent.com/u/68821880?v=4",
24 | "profile": "https://github.com/daniel-dodd",
25 | "contributions": [
26 | "doc",
27 | "review",
28 | "test",
29 | "bug"
30 | ]
31 | },
32 | {
33 | "login": "matt-graham",
34 | "name": "Matt Graham",
35 | "avatar_url": "https://avatars.githubusercontent.com/u/6746980?v=4",
36 | "profile": "http://matt-graham.github.io",
37 | "contributions": [
38 | "bug",
39 | "test",
40 | "review",
41 | "doc"
42 | ]
43 | },
44 | {
45 | "login": "kazewong",
46 | "name": "Kaze Wong",
47 | "avatar_url": "https://avatars.githubusercontent.com/u/8803931?v=4",
48 | "profile": "https://www.kaze-wong.com/",
49 | "contributions": [
50 | "bug",
51 | "blog",
52 | "code",
53 | "content",
54 | "doc",
55 | "example",
56 | "infra",
57 | "maintenance",
58 | "research",
59 | "review",
60 | "test",
61 | "tutorial"
62 | ]
63 | },
64 | {
65 | "login": "marylou-gabrie",
66 | "name": "Marylou Gabrié",
67 | "avatar_url": "https://avatars.githubusercontent.com/u/11092071?v=4",
68 | "profile": "https://marylou-gabrie.github.io/",
69 | "contributions": [
70 | "bug",
71 | "code",
72 | "content",
73 | "doc",
74 | "example",
75 | "maintenance",
76 | "research",
77 | "test",
78 | "tutorial"
79 | ]
80 | },
81 | {
82 | "login": "Qazalbash",
83 | "name": "Meesum Qazalbash",
84 | "avatar_url": "https://avatars.githubusercontent.com/u/62182585?v=4",
85 | "profile": "https://github.com/Qazalbash",
86 | "contributions": [
87 | "code",
88 | "maintenance"
89 | ]
90 | },
91 | {
92 | "login": "thomasckng",
93 | "name": "Thomas Ng",
94 | "avatar_url": "https://avatars.githubusercontent.com/u/97585527?v=4",
95 | "profile": "https://github.com/thomasckng",
96 | "contributions": [
97 | "code",
98 | "maintenance"
99 | ]
100 | },
101 | {
102 | "login": "tedwards2412",
103 | "name": "Thomas Edwards",
104 | "avatar_url": "https://avatars.githubusercontent.com/u/6105841?v=4",
105 | "profile": "https://github.com/tedwards2412",
106 | "contributions": [
107 | "bug",
108 | "code"
109 | ]
110 | }
111 | ],
112 | "repoType": "github"
113 | }
114 |
--------------------------------------------------------------------------------
/test/unit/test_resources.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from flowMC.resource.buffers import Buffer
4 | from flowMC.resource.logPDF import LogPDF, Variable, TemperedPDF
5 | from flowMC.resource.kernel.MALA import MALA
6 | from flowMC.resource.states import State
7 | from flowMC.strategy.take_steps import TakeSerialSteps
8 |
9 |
10 | class TestLogPDF:
11 |
12 | n_dims = 5
13 |
14 | def posterior(self, x, data):
15 | return -jnp.sum(jnp.square(x - data["data"]))
16 |
17 | def test_value_and_grad(self):
18 | logpdf = LogPDF(
19 | self.posterior, [Variable("x_" + str(i), True) for i in range(self.n_dims)]
20 | )
21 | inputs = jnp.arange(self.n_dims).astype(jnp.float32)
22 | data = {"data": jnp.ones(self.n_dims)}
23 | values, grads = jax.value_and_grad(logpdf)(inputs, data)
24 | assert values == self.posterior(inputs, data)
25 |
26 | def test_resource(self):
27 | mala = MALA(1.0)
28 | logpdf = LogPDF(
29 | self.posterior, [Variable("x_" + str(i), True) for i in range(self.n_dims)]
30 | )
31 | rng_key = jax.random.PRNGKey(0)
32 | initial_position = jnp.zeros(self.n_dims)
33 | data = {"data": jnp.ones(self.n_dims)}
34 | sampler_state = State(
35 | {
36 | "target_positions": "test_positions",
37 | "target_log_probs": "test_log_probs",
38 | "target_acceptances": "test_acceptances",
39 | },
40 | name="sampler_state",
41 | )
42 | resources = {
43 | "test_positions": Buffer("test_positions", (self.n_dims, 1), 1),
44 | "test_log_probs": Buffer("test_log_probs", (self.n_dims, 1), 1),
45 | "test_acceptances": Buffer("test_acceptances", (self.n_dims, 1), 1),
46 | "MALA": mala,
47 | "logpdf": logpdf,
48 | "sampler_state": sampler_state,
49 | }
50 | stepper = TakeSerialSteps(
51 | "logpdf",
52 | "MALA",
53 | "sampler_state",
54 | ["target_positions", "target_log_probs", "target_acceptances"],
55 | 1,
56 | )
57 | key, resources, positions = stepper(rng_key, resources, initial_position, data)
58 |
59 | def test_tempered_pdf(self):
60 | n_temps = 5
61 | logpdf = TemperedPDF(
62 | self.posterior,
63 | lambda x, data: jnp.zeros(1),
64 | n_dims=self.n_dims,
65 | n_temps=n_temps,
66 | max_temp=100,
67 | )
68 | inputs = jnp.ones((n_temps, self.n_dims)).astype(jnp.float32)
69 | data = {"data": jnp.ones(self.n_dims)}
70 | temperatures = jnp.arange(n_temps) + 1.0
71 | values = jax.vmap(logpdf.tempered_log_pdf, in_axes=(0, 0, None))(
72 | temperatures, inputs, data
73 | )
74 | assert (
75 | values[:, 0] == jax.vmap(self.posterior, in_axes=(0, None))(inputs, data)
76 | ).all()
77 | assert values.shape == (5, 1)
78 |
79 |
80 | class TestBuffer:
81 | def test_buffer(self):
82 | buffer = Buffer("test", (10, 10), cursor_dim=1)
83 | assert buffer.name == "test"
84 | assert buffer.data.shape == (10, 10)
85 | assert buffer.cursor == 0
86 | assert buffer.cursor_dim == 1
87 |
88 | def test_update_buffer(self):
89 | buffer = Buffer("test", (10, 10), cursor_dim=0)
90 | buffer.update_buffer(jnp.ones((10, 10)))
91 | assert (buffer.data == jnp.ones((10, 10))).all()
92 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | kazewong.physics@gmail.com.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 |
70 | ## Attribution
71 |
72 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
73 | version 2.0, available at
74 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
75 |
76 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
77 | enforcement ladder](https://github.com/mozilla/diversity).
78 |
79 | [homepage]: https://www.contributor-covenant.org
80 |
81 | For answers to common questions about this code of conduct, see the FAQ at
82 | https://www.contributor-covenant.org/faq. Translations are available at
83 | https://www.contributor-covenant.org/translations.
84 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: flowMC
2 | site_description: Normalizing-flow enhanced sampling package for probabilistic inference in Jax
3 | site_author: Kaze Wong
4 | repo_url: https://github.com/kazewong/flowMC
5 | repo_name: kazewong/flowMC
6 |
7 | theme:
8 | name: material
9 | features:
10 | - navigation # Sections are included in the navigation on the left.
11 | - toc # Table of contents is integrated on the left; does not appear separately on the right.
12 | - header.autohide # header disappears as you scroll
13 | palette:
14 | # Light mode / dark mode
15 | # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as
16 | # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.
17 | - scheme: default
18 | primary: cyan
19 | accent: deep purple
20 | toggle:
21 | icon: material/brightness-5
22 | name: Dark mode
23 | - scheme: slate
24 | primary: custom
25 | accent: purple
26 | toggle:
27 | icon: material/brightness-2
28 | name: Light mode
29 |
30 | twitter_name: "@physicskaze"
31 | twitter_url: "https://twitter.com/physicskaze"
32 |
33 | markdown_extensions:
34 | - pymdownx.arithmatex: # Render LaTeX via MathJax
35 | generic: true
36 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme.
37 | - pymdownx.details
38 | - pymdownx.snippets: # Include one Markdown file into another
39 | base_path: docs
40 | - admonition
41 | - toc:
42 | permalink: "¤" # Adds a clickable permalink to each section heading
43 | toc_depth: 4
44 |
45 | plugins:
46 | - search # default search plugin; needs manually re-enabling when using any other plugins
47 | - autorefs # Cross-links to headings
48 | - mkdocs-jupyter: # Jupyter notebook support
49 | # show_input: False
50 | - gen-files:
51 | scripts:
52 | - docs/gen_ref_pages.py
53 | - literate-nav:
54 | nav_file: SUMMARY.md
55 | - mkdocstrings:
56 | handlers:
57 | python:
58 | setup_commands:
59 | - import pytkdocs_tweaks
60 | - pytkdocs_tweaks.main()
61 | - import jaxtyping
62 | - jaxtyping.set_array_name_format("array")
63 |
64 |                 options:
65 |                     docstring_style: google
66 |                     inherited_members: true # Allow looking up inherited methods
67 |                     show_root_heading: true # actually display anything at all...
68 |                     show_root_full_path: true # display "flowMC.asdf" not just "asdf"
69 |                     show_if_no_docstring: true
70 |                     show_signature_annotations: true
71 |                     show_source: false # don't include source code
72 |                     members_order: source # order methods according to their order of definition in the source code, not alphabetical order
73 | heading_level: 4
74 |
75 |
76 | extra_css:
77 | - stylesheets/extra.css
78 |
79 | nav:
80 | - Home: index.md
81 | - Quickstart: quickstart.md
82 | - Tutorial:
83 | - Step-by-step guide: tutorials/dualmoon.ipynb
84 | - Custom Strategy: tutorials/custom_strategy.ipynb
85 | - Training Normalizing Flow: tutorials/train_normalizing_flow.ipynb
86 | - Loading Resource: tutorials/loading_resources.ipynb
87 | - Community Examples: communityExamples.md
88 | - FAQ: FAQ.md
89 | - Contribution: contribution.md
90 | - API: api/flowMC/*
91 |
--------------------------------------------------------------------------------
/src/flowMC/resource/kernel/MALA.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.scipy.stats import multivariate_normal
4 | from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, PyTree
5 | from typing import Callable
6 |
7 | from flowMC.resource.logPDF import LogPDF
8 | from flowMC.resource.kernel.base import ProposalBase
9 |
10 |
11 | class MALA(ProposalBase):
12 | """Metropolis-adjusted Langevin algorithm sampler class."""
13 |
14 | step_size: Float
15 |
16 | def __repr__(self):
17 | return "MALA with step size " + str(self.step_size)
18 |
19 | def __init__(
20 | self,
21 | step_size: Float,
22 | ):
23 | super().__init__()
24 | self.step_size = step_size
25 |
26 | def kernel(
27 | self,
28 | rng_key: PRNGKeyArray,
29 | position: Float[Array, " n_dim"],
30 | log_prob: Float[Array, "1"],
31 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
32 | data: PyTree,
33 | ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
34 |         """Metropolis-adjusted Langevin algorithm kernel. This kernel only
35 |         evolves a single chain.
36 |
37 | Args:
38 | rng_key (PRNGKeyArray): Jax PRNGKey
39 | position (Float[Array, " n_dim"]): current position of the chain
40 | log_prob (Float[Array, "1"]): current log-probability of the chain
41 | data (PyTree): data to be passed to the logpdf function
42 |
43 | Returns:
44 | position (Float[Array, " n_dim"]): new position of the chain
45 | log_prob (Float[Array, "1"]): new log-probability of the chain
46 | do_accept (Int[Array, "1"]): whether the new position is accepted
47 | """
48 |
49 | def body(
50 | carry: tuple[Float[Array, " n_dim"], float, dict],
51 | this_key: PRNGKeyArray,
52 | ) -> tuple[
53 | tuple[Float[Array, " n_dim"], float, dict],
54 | tuple[Float[Array, " n_dim"], Float[Array, "1"], Float[Array, " n_dim"]],
55 | ]:
56 | print("Compiling MALA body")
57 | this_position, dt, data = carry
58 | dt2 = dt * dt
59 | this_log_prob, this_d_log = jax.value_and_grad(logpdf)(this_position, data)
60 | proposal = this_position + jnp.dot(dt2, this_d_log) / 2
61 | proposal += jnp.dot(
62 | dt, jax.random.normal(this_key, shape=this_position.shape)
63 | )
64 | return (proposal, dt, data), (proposal, this_log_prob, this_d_log)
65 |
66 | key1, key2 = jax.random.split(rng_key)
67 |
68 | dt: Float = self.step_size
69 | dt2 = dt * dt
70 |
71 | _, (proposal, logprob, d_logprob) = jax.lax.scan(
72 | body, (position, dt, data), jnp.array([key1, key1])
73 | )
74 |
75 | ratio = logprob[1] - logprob[0]
76 | ratio -= multivariate_normal.logpdf(
77 | proposal[0], position + jnp.dot(dt2, d_logprob[0]) / 2, dt2
78 | )
79 | ratio += multivariate_normal.logpdf(
80 | position, proposal[0] + jnp.dot(dt2, d_logprob[1]) / 2, dt2
81 | )
82 |
83 | log_uniform = jnp.log(jax.random.uniform(key2))
84 | do_accept: Bool[Array, " n_dim"] = log_uniform < ratio
85 |
86 | position = jnp.where(do_accept, proposal[0], position)
87 | log_prob = jnp.where(do_accept, logprob[1], logprob[0])
88 |
89 | return position, log_prob, do_accept
90 |
91 | def print_parameters(self):
92 | print("MALA parameters:")
93 | print(f"step_size: {self.step_size}")
94 |
95 | def save_resource(self, path):
96 | raise NotImplementedError
97 |
98 | def load_resource(self, path):
99 | raise NotImplementedError
100 |
--------------------------------------------------------------------------------
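For reference, the kernel above implements the standard MALA update. Writing \epsilon for `step_size` (so dt^2 = \epsilon^2 in the code), the proposal and log acceptance ratio are

    x' = x + \frac{\epsilon^2}{2} \nabla \log \pi(x) + \epsilon \xi, \qquad \xi \sim \mathcal{N}(0, I),

    \log r = \log \pi(x') - \log \pi(x) + \log q(x \mid x') - \log q(x' \mid x), \qquad
    q(y \mid x) = \mathcal{N}\!\left(y;\, x + \tfrac{\epsilon^2}{2} \nabla \log \pi(x),\, \epsilon^2 I\right),

and the move is accepted when \log u < \log r for u \sim \mathrm{Uniform}(0, 1), matching the `ratio` and `log_uniform` computation in `kernel`.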
/src/flowMC/resource/logPDF.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable, Optional
3 | from flowMC.resource.base import Resource
4 | from jaxtyping import Array, Float, PyTree
5 | import jax
6 |
7 |
8 | @dataclass
9 | class Variable:
10 | """A dataclass that holds the information of a variable in the log-pdf function.
11 |
12 |     The main purpose of this class is to let users name their variables,
13 | and specify whether they are continuous or not.
14 | """
15 |
16 | name: str
17 | continuous: bool
18 |
19 |
20 | @jax.tree_util.register_pytree_node_class
21 | class LogPDF(Resource):
22 | """A resource class that holds the log-pdf function.
23 | The main purpose of this class is to wrap the log-pdf function into the unified Resource interface.
24 |
25 | Args:
26 |         log_pdf (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]]): The log-pdf function
27 | variables (list[Variable]): The list of variables in the log-pdf function
28 | """
29 |
30 | log_pdf: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]]
31 | variables: list[Variable]
32 |
33 | @property
34 | def n_dims(self):
35 | return len(self.variables)
36 |
37 | def __repr__(self):
38 | return "LogPDF with " + str(self.n_dims) + " dimensions"
39 |
40 | def __init__(
41 | self,
42 | log_pdf: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
43 | variables: Optional[list[Variable]] = None,
44 | n_dims: Optional[int] = None,
45 | ):
46 | """
47 | Args:
48 |             log_pdf (Callable[[Float[Array, "n_dim"], PyTree], Float[Array, "1"]]): The log-pdf function
49 | variables (list[Variable], optional): The list of variables in the log-pdf function. Defaults to None. n_dims must be provided if this argument is None.
50 | n_dims (int, optional): The number of dimensions of the log-pdf function. Defaults to None. If variables is provided, this argument is ignored.
51 | """
52 | self.log_pdf = log_pdf
53 | if variables is None and n_dims is not None:
54 | self.variables = [Variable("x_" + str(i), True) for i in range(n_dims)]
55 | elif variables is not None:
56 | self.variables = variables
57 | else:
58 | raise ValueError("Either variables or n_dims must be provided")
59 |
60 | def __call__(self, x: Float[Array, " n_dim"], data: PyTree) -> Float[Array, "1"]:
61 | return self.log_pdf(x, data)
62 |
63 | def print_parameters(self):
64 | print("LogPDF with variables:")
65 | for var in self.variables:
66 | print(var.name, var.continuous)
67 |
68 | def save_resource(self, path):
69 | raise NotImplementedError
70 |
71 | def load_resource(self, path):
72 | raise NotImplementedError
73 |
74 | def tree_flatten(self):
75 | children = ()
76 | aux_data = (self.log_pdf, self.variables)
77 | return (children, aux_data)
78 |
79 | @classmethod
80 | def tree_unflatten(cls, aux_data, children):
81 | return cls(aux_data[0], aux_data[1])
82 |
83 |
84 | @jax.tree_util.register_pytree_node_class
85 | class TemperedPDF(LogPDF):
86 |
87 | log_prior: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]]
88 |
89 | def __init__(
90 | self,
91 | log_likelihood: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
92 | log_prior: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
93 | variables=None,
94 | n_dims=None,
95 | n_temps=5,
96 | max_temp=100,
97 | ):
98 | super().__init__(log_likelihood, variables, n_dims)
99 | self.log_prior = log_prior
100 |
101 | def __call__(self, x, data):
102 | return super().__call__(x, data)
103 |
104 | def tempered_log_pdf(self, temperatures, x, data):
105 | base_pdf = super().__call__(x, data)
106 | return (1.0 / temperatures) * base_pdf + self.log_prior(x, data)
107 |
108 | def tree_flatten(self): # type: ignore
109 | children = ()
110 | aux_data = (self.log_pdf, self.log_prior, self.variables)
111 | return (children, aux_data)
112 |
113 | @classmethod
114 | def tree_unflatten(cls, aux_data, children):
115 | return cls(*aux_data, *children)
116 |
--------------------------------------------------------------------------------
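A minimal construction sketch showing the two paths through `__init__` (the Gaussian `log_pdf` below is illustrative):

    import jax.numpy as jnp
    from flowMC.resource.logPDF import LogPDF, Variable

    def log_pdf(x, data):
        return -0.5 * jnp.sum((x - data["mu"]) ** 2)

    # Either name every variable explicitly...
    named = LogPDF(log_pdf, variables=[Variable("a", True), Variable("b", True)])
    # ...or let the class generate x_0, ..., x_{n-1} from n_dims.
    anonymous = LogPDF(log_pdf, n_dims=2)
    assert named.n_dims == anonymous.n_dims == 2
    value = named(jnp.zeros(2), {"mu": jnp.ones(2)})  # forwards to log_pdf(x, data)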
/src/flowMC/strategy/train_model.py:
--------------------------------------------------------------------------------
1 | from flowMC.strategy.base import Strategy
2 | from flowMC.resource.base import Resource
3 | from flowMC.resource.buffers import Buffer
4 | from flowMC.resource.model.nf_model.base import NFModel
5 | from flowMC.resource.optimizer import Optimizer
6 | from jaxtyping import Array, Float, PRNGKeyArray
7 | import jax
8 | import jax.numpy as jnp
9 |
10 |
11 | class TrainModel(Strategy):
12 | model_resource: str
13 | data_resource: str
14 | optimizer_resource: str
15 | n_epochs: int
16 | batch_size: int
17 | n_max_examples: int
18 | verbose: bool
19 |     history_window: int
20 |
21 | def __repr__(self):
22 | return "Train " + self.model_resource
23 |
24 | def __init__(
25 | self,
26 | model_resource: str,
27 | data_resource: str,
28 | optimizer_resource: str,
29 | loss_buffer_name: str = "",
30 | n_epochs: int = 100,
31 | batch_size: int = 64,
32 | n_max_examples: int = 10000,
33 | history_window: int = 100,
34 | verbose: bool = False,
35 | ):
36 | self.model_resource = model_resource
37 | self.data_resource = data_resource
38 | self.optimizer_resource = optimizer_resource
39 | self.loss_buffer_name = loss_buffer_name
40 |
41 | self.n_epochs = n_epochs
42 | self.batch_size = batch_size
43 | self.n_max_examples = n_max_examples
44 | self.verbose = verbose
45 | self.history_window = history_window
46 |
47 | def __call__(
48 | self,
49 | rng_key: PRNGKeyArray,
50 | resources: dict[str, Resource],
51 | initial_position: Float[Array, "n_chains n_dim"],
52 | data: dict,
53 | ) -> tuple[
54 | PRNGKeyArray,
55 | dict[str, Resource],
56 | Float[Array, "n_chains n_dim"],
57 | ]:
58 | model = resources[self.model_resource]
59 | assert isinstance(model, NFModel), "Target resource must be a NFModel"
60 | data_resource = resources[self.data_resource]
61 | assert isinstance(data_resource, Buffer), "Data resource must be a buffer"
62 | optimizer = resources[self.optimizer_resource]
63 | assert isinstance(
64 | optimizer, Optimizer
65 | ), "Optimizer resource must be an optimizer"
66 | n_chains = data_resource.data.shape[0]
67 | n_dims = data_resource.data.shape[-1]
68 | training_data = data_resource.data[
69 | jnp.isfinite(data_resource.data).all(axis=-1)
70 | ].reshape(n_chains, -1, n_dims)
71 | training_data = training_data[:, -self.history_window :].reshape(-1, n_dims)
72 | rng_key, subkey = jax.random.split(rng_key)
73 | training_data = training_data[
74 | jax.random.choice(
75 | subkey,
76 | jnp.arange(training_data.shape[0]),
77 | shape=(self.n_max_examples,),
78 | replace=True,
79 | )
80 | ]
81 | rng_key, subkey = jax.random.split(rng_key)
82 |
83 | if self.verbose:
84 | print("Training model")
85 | print(f"Training data shape: {training_data.shape}")
86 | print(f"n_epochs: {self.n_epochs}")
87 | print(f"batch_size: {self.batch_size}")
88 |
89 | (rng_key, model, optim_state, loss_values) = model.train(
90 | rng=subkey,
91 | data=training_data,
92 | optim=optimizer.optim,
93 | state=optimizer.optim_state,
94 | num_epochs=self.n_epochs,
95 | batch_size=self.batch_size,
96 | verbose=self.verbose,
97 | )
98 |
99 | if self.loss_buffer_name != "":
100 | loss_buffer = resources[self.loss_buffer_name]
101 | assert isinstance(
102 | loss_buffer, Buffer
103 | ), "Loss buffer resource must be a buffer"
104 | loss_buffer.update_buffer(loss_values, start=loss_buffer.cursor)
105 | loss_buffer.cursor += len(loss_values)
106 | resources[self.loss_buffer_name] = loss_buffer
107 |
108 | optimizer.optim_state = optim_state
109 | resources[self.model_resource] = model
110 | resources[self.optimizer_resource] = optimizer
111 | # print(f"Training loss: {loss_values}")
112 | return rng_key, resources, initial_position
113 |
--------------------------------------------------------------------------------
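As a rough wiring sketch, the strategy looks up three (optionally four) resources by name at call time; the resource names below are placeholders, and the values stored under them must be an NFModel, a Buffer of samples, an Optimizer, and optionally a loss Buffer:

    # Hypothetical resource names; only the keyword arguments are part of the API.
    train = TrainModel(
        model_resource="flow",
        data_resource="positions",
        optimizer_resource="optimizer",
        loss_buffer_name="loss",  # leave as "" to skip loss logging
        n_epochs=50,
        batch_size=128,
        n_max_examples=5000,
        history_window=100,
    )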
/src/flowMC/Sampler.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jaxtyping import Array, Float, PRNGKeyArray
3 | from typing import Optional
4 |
5 | from flowMC.strategy.base import Strategy
6 | from flowMC.resource.base import Resource
7 | from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle
8 |
9 |
10 | class Sampler:
11 |     """Top-level API that users primarily interact with.
12 |
13 | Args:
14 | n_dim (int): Dimension of the parameter space.
15 | n_chains (int): Number of chains to sample.
16 | rng_key (PRNGKeyArray): Jax PRNGKey.
17 |         resources (dict[str, Resource]): Resources to be used by the sampler.
18 |         strategies (dict[str, Strategy]): Strategies to be used by the sampler.
19 |         strategy_order (list[str]): Order in which the strategies are executed.
20 |         resource_strategy_bundles (ResourceStrategyBundle): Bundle of resources and strategies, used when resources and strategies are not provided.
21 | verbose (bool): Whether to print out progress. Defaults to False.
22 | logging (bool): Whether to log the progress. Defaults to True.
23 | outdir (str): Directory to save the logs. Defaults to "./outdir/".
24 | """
25 |
26 | # Essential parameters
27 | n_dim: int
28 | n_chains: int
29 | rng_key: PRNGKeyArray
30 | resources: dict[str, Resource]
31 | strategies: dict[str, Strategy]
32 | strategy_order: Optional[list[str]]
33 |
34 | # Logging hyperparameters
35 | verbose: bool = False
36 | logging: bool = True
37 | outdir: str = "./outdir/"
38 |
39 | def __init__(
40 | self,
41 | n_dim: int,
42 | n_chains: int,
43 | rng_key: PRNGKeyArray,
44 | resources: None | dict[str, Resource] = None,
45 | strategies: None | dict[str, Strategy] = None,
46 | strategy_order: None | list[str] = None,
47 | resource_strategy_bundles: None | ResourceStrategyBundle = None,
48 | **kwargs,
49 | ):
50 | # Copying input into the model
51 |
52 | self.n_dim = n_dim
53 | self.n_chains = n_chains
54 | self.rng_key = rng_key
55 |
56 | if resources is not None and strategies is not None:
57 | print(
58 | "Resources and strategies provided. Ignoring resource strategy bundles."
59 | )
60 | self.resources = resources
61 | self.strategies = strategies
62 | self.strategy_order = strategy_order
63 |
64 | else:
65 | print(
66 | "Resources or strategies not provided. Using resource strategy bundles."
67 | )
68 | if resource_strategy_bundles is None:
69 | raise ValueError(
70 | "Resource strategy bundles not provided."
71 | "Please provide either resources and strategies or resource strategy bundles."
72 | )
73 | self.resources = resource_strategy_bundles.resources
74 | self.strategies = resource_strategy_bundles.strategies
75 | self.strategy_order = resource_strategy_bundles.strategy_order
76 |
77 | # Set and override any given hyperparameters
78 | class_keys = list(self.__class__.__dict__.keys())
79 | for key, value in kwargs.items():
80 | if key in class_keys:
81 | if not key.startswith("__"):
82 | setattr(self, key, value)
83 |
84 | def sample(self, initial_position: Float[Array, "n_chains n_dim"], data: dict):
85 |         """Run the strategies in strategy_order to sample from the posterior.
86 |
87 | Args:
88 |             initial_position (Float[Array, "n_chains n_dim"]): Initial positions.
89 | data (dict): Data to be used by the likelihood functions
90 | """
91 |
92 | initial_position = jnp.atleast_2d(initial_position) # type: ignore
93 | rng_key = self.rng_key
94 | last_step = initial_position
95 | assert isinstance(self.strategy_order, list)
96 | for strategy in self.strategy_order:
97 | if strategy not in self.strategies:
98 | raise ValueError(
99 | f"Invalid strategy name '{strategy}' provided. "
100 | f"Available strategies are: {list(self.strategies.keys())}."
101 | )
102 |             (
103 |                 rng_key,
104 |                 self.resources,
105 |                 last_step,
106 |             ) = self.strategies[strategy](
107 |                 rng_key, self.resources, last_step, data
108 |             )
109 |
110 | # TODO: Implement quick access and summary functions that operates on buffer
111 |
112 | def serialize(self):
113 | """Serialize the sampler object."""
114 | raise NotImplementedError
115 |
116 | def deserialize(self):
117 | """Deserialize the sampler object."""
118 | raise NotImplementedError
119 |
--------------------------------------------------------------------------------
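Schematically, the constructor accepts either explicit resources and strategies (as in the integration tests above) or a pre-packaged bundle. In this sketch, `resource`, `strategy`, and `bundle` stand for objects built elsewhere:

    import jax

    rng_key = jax.random.PRNGKey(0)

    # 1. Explicit resources, strategies, and execution order:
    sampler = Sampler(
        n_dim=2,
        n_chains=4,
        rng_key=rng_key,
        resources=resource,                   # dict[str, Resource]
        strategies={"take_steps": strategy},  # dict[str, Strategy]
        strategy_order=["take_steps"],
    )

    # 2. A ResourceStrategyBundle supplying all three at once:
    sampler = Sampler(
        n_dim=2,
        n_chains=4,
        rng_key=rng_key,
        resource_strategy_bundles=bundle,
    )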
/src/flowMC/resource/kernel/HMC.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
6 |
7 | from flowMC.resource.kernel.base import ProposalBase
8 | from flowMC.resource.logPDF import LogPDF
9 |
10 |
11 | class HMC(ProposalBase):
12 |     """Hamiltonian Monte Carlo sampler class building the HMC kernel for a
13 |     target logpdf.
14 |
15 |     Args:
16 |         condition_matrix: inverse mass matrix used to precondition the kinetic energy
17 |         step_size: leapfrog step size
18 |         n_leapfrog: number of leapfrog steps per proposal
19 | """
20 |
21 | condition_matrix: Float[Array, " n_dim n_dim"]
22 | step_size: Float
23 | leapfrog_coefs: Float[Array, " n_leapfrog n_dim"]
24 |
25 | @property
26 | def n_leapfrog(self) -> Int:
27 | return self.leapfrog_coefs.shape[0] - 2
28 |
29 | def __repr__(self):
30 | return (
31 | "HMC with step size "
32 | + str(self.step_size)
33 | + " and "
34 | + str(self.n_leapfrog)
35 | + " leapfrog steps"
36 | )
37 |
38 | def __init__(
39 | self,
40 | condition_matrix: Float[Array, " n_dim n_dim"] | Float = 1,
41 | step_size: Float = 0.1,
42 | n_leapfrog: Int = 10,
43 | ):
44 | self.condition_matrix = condition_matrix
45 | self.step_size = step_size
46 |
47 | coefs = jnp.ones((n_leapfrog + 2, 2))
48 | coefs = coefs.at[0].set(jnp.array([0, 0.5]))
49 | coefs = coefs.at[-1].set(jnp.array([1, 0.5]))
50 | self.leapfrog_coefs = coefs
51 |
52 | def get_initial_hamiltonian(
53 | self,
54 | potential: Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
55 | kinetic: Callable[
56 | [Float[Array, " n_dim"], Float[Array, " n_dim"]], Float[Array, "1"]
57 | ],
58 | rng_key: PRNGKeyArray,
59 | position: Float[Array, " n_dim"],
60 | data: PyTree,
61 | ):
62 | """Compute the value of the Hamiltonian from positions with initial momentum
63 | draw at random from the standard normal distribution."""
64 |
65 | momentum = (
66 | jax.random.normal(rng_key, shape=position.shape)
67 | * self.condition_matrix**-0.5
68 | )
69 | return potential(position, data) + kinetic(momentum, self.condition_matrix)
70 |
71 | def leapfrog_kernel(self, kinetic, potential, carry, extras):
72 | position, momentum, data, metric, index = carry
73 | position = position + self.step_size * self.leapfrog_coefs[index][0] * jax.grad(
74 | kinetic
75 | )(momentum, metric)
76 | momentum = momentum - self.step_size * self.leapfrog_coefs[index][1] * jax.grad(
77 | potential
78 | )(position, data)
79 | index = index + 1
80 | return (position, momentum, data, metric, index), extras
81 |
82 | def leapfrog_step(
83 | self,
84 | leapfrog_kernel: Callable,
85 | position: Float[Array, " n_dim"],
86 | momentum: Float[Array, " n_dim"],
87 | data: PyTree,
88 | metric: Float[Array, " n_dim n_dim"],
89 | ) -> tuple[Float[Array, " n_dim"], Float[Array, " n_dim"]]:
90 | print("Compiling leapfrog step")
91 | (position, momentum, data, metric, index), _ = jax.lax.scan(
92 | leapfrog_kernel,
93 | (position, momentum, data, metric, 0),
94 | jnp.arange(self.n_leapfrog + 2),
95 | )
96 | return position, momentum
97 |
98 | def kernel(
99 | self,
100 | rng_key: PRNGKeyArray,
101 | position: Float[Array, " n_dim"],
102 | log_prob: Float[Array, "1"],
103 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
104 | data: PyTree,
105 | ) -> tuple[Float[Array, " n_dim"], Float[Array, "1"], Int[Array, "1"]]:
106 |         """Note that since the potential is the negative log-likelihood, a
107 |         decrease in the Hamiltonian corresponds to an increase in the likelihood.
108 |
109 | Args:
110 |             rng_key (PRNGKeyArray): random key
111 |             position (n_dim,): current position of the chain
112 |             log_prob (1,): current log-probability of the chain
113 | """
114 |
115 | def potential(x: Float[Array, " n_dim"], data: PyTree) -> Float[Array, "1"]:
116 | return -logpdf(x, data)
117 |
118 | def kinetic(
119 | p: Float[Array, " n_dim"], metric: Float[Array, " n_dim"]
120 | ) -> Float[Array, "1"]:
121 | return 0.5 * (p**2 * metric).sum()
122 |
123 | leapfrog_kernel = jax.tree_util.Partial(
124 | self.leapfrog_kernel, kinetic, potential
125 | )
126 | leapfrog_step = jax.tree_util.Partial(self.leapfrog_step, leapfrog_kernel)
127 |
128 | key1, key2 = jax.random.split(rng_key)
129 |
130 |         # Draw momentum from N(0, M), where condition_matrix plays the role
131 |         # of the inverse mass matrix M^{-1}.
132 |         momentum: Float[Array, " n_dim"] = jnp.dot(
133 |             jax.random.normal(key1, shape=position.shape),
134 |             jnp.linalg.cholesky(jnp.linalg.inv(self.condition_matrix)).T,
135 |         )
136 |
137 | H = -log_prob + kinetic(momentum, self.condition_matrix)
138 | proposed_position, proposed_momentum = leapfrog_step(
139 | position, momentum, data, self.condition_matrix
140 | )
141 | proposed_PE = potential(proposed_position, data)
142 | proposed_ham = proposed_PE + kinetic(proposed_momentum, self.condition_matrix)
143 | log_acc = H - proposed_ham
144 | log_uniform = jnp.log(jax.random.uniform(key2))
145 |
146 | do_accept = log_uniform < log_acc
147 |
148 | position = jnp.where(do_accept, proposed_position, position) # type: ignore
149 | log_prob = jnp.where(do_accept, -proposed_PE, log_prob) # type: ignore
150 |
151 | return position, log_prob, do_accept
152 |
153 | def print_parameters(self):
154 | print("HMC parameters:")
155 | print(f"step_size: {self.step_size}")
156 | print(f"n_leapfrog: {self.n_leapfrog}")
157 | print(f"condition_matrix: {self.condition_matrix}")
158 |
159 | def save_resource(self, path):
160 | raise NotImplementedError
161 |
162 | def load_resource(self, path):
163 | raise NotImplementedError
164 |
--------------------------------------------------------------------------------
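In equations, `leapfrog_coefs` encodes the usual leapfrog integrator for the Hamiltonian H(q, p) = U(q) + \tfrac{1}{2} p^\top M^{-1} p with U = -\log \pi, where `condition_matrix` plays the role of the inverse mass matrix M^{-1}:

    p_{1/2} = p_0 - \frac{\epsilon}{2} \nabla U(q_0), \qquad
    q_1 = q_0 + \epsilon\, M^{-1} p_{1/2}, \qquad
    p_1 = p_{1/2} - \frac{\epsilon}{2} \nabla U(q_1),

iterated over `n_leapfrog` steps. The proposal is then accepted with probability \min(1, e^{H_0 - H_1}), matching `log_acc = H - proposed_ham` above.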
/src/flowMC/strategy/optimization.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import optax
6 | from jaxtyping import Array, Float, PRNGKeyArray
7 |
8 | from flowMC.strategy.base import Strategy
9 | from flowMC.resource.base import Resource
10 |
11 |
12 | class AdamOptimization(Strategy):
13 | """Optimize a set of chains using Adam optimization. Note that if the posterior can
14 | go to infinity, this optimization scheme is likely to return NaNs.
15 |
16 | Args:
17 | n_steps: int = 100
18 | Number of optimization steps.
19 | learning_rate: float = 1e-2
20 | Learning rate for the optimization.
21 | noise_level: float = 10
22 |             Standard deviation of the multiplicative noise applied to the gradients.
23 | bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])
24 | Bounds for the optimization. The optimization will be projected to these bounds.
25 | If bounds has shape (1, 2), it will be broadcast to all dimensions. For n_dim > 1,
26 | passing a (1, 2) array applies the same bound to every dimension. To specify different
27 | bounds per dimension, provide an array of shape (n_dim, 2).
28 | """
29 |
30 | logpdf: Callable[[Float[Array, " n_dim"], dict], Float]
31 | n_steps: int = 100
32 | learning_rate: float = 1e-2
33 | noise_level: float = 10
34 | bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])
35 |
36 | def __repr__(self):
37 | return "AdamOptimization"
38 |
39 | def __init__(
40 | self,
41 | logpdf: Callable[[Float[Array, " n_dim"], dict], Float],
42 | n_steps: int = 100,
43 | learning_rate: float = 1e-2,
44 | noise_level: float = 10,
45 | bounds: Float[Array, " n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
46 | ):
47 | self.logpdf = logpdf
48 | self.n_steps = n_steps
49 | self.learning_rate = learning_rate
50 | self.noise_level = noise_level
51 | self.bounds = bounds
52 |
53 | # Validate bounds shape
54 | if bounds.ndim != 2 or bounds.shape[1] != 2:
55 | raise ValueError(
56 | f"bounds must have shape (n_dim, 2) or (1, 2), got {bounds.shape}"
57 | )
58 | # If bounds is (1, 2), it will be broadcast to all dimensions. If not, check compatibility.
59 | # Try to infer n_dim from logpdf signature or initial_position, but here we can't, so warn in runtime.
60 |
61 | self.solver = optax.chain(
62 | optax.adam(learning_rate=self.learning_rate),
63 | )
64 |
65 | def __call__(
66 | self,
67 | rng_key: PRNGKeyArray,
68 | resources: dict[str, Resource],
69 | initial_position: Float[Array, " n_chain n_dim"],
70 | data: dict,
71 | ) -> tuple[
72 | PRNGKeyArray,
73 | dict[str, Resource],
74 | Float[Array, " n_chain n_dim"],
75 | ]:
76 | def loss_fn(params: Float[Array, " n_dim"], data: dict) -> Float:
77 | return -self.logpdf(params, data)
78 |
79 | rng_key, optimized_positions, _ = self.optimize(
80 | rng_key, loss_fn, initial_position, data
81 | )
82 |
83 | return rng_key, resources, optimized_positions
84 |
85 | def optimize(
86 | self,
87 | rng_key: PRNGKeyArray,
88 | objective: Callable,
89 | initial_position: Float[Array, " n_chain n_dim"],
90 | data: dict,
91 | ):
92 |         """Optimization kernel. This can be used independently of the __call__ method.
93 |
94 |         Args:
95 |             rng_key: PRNGKeyArray
96 |                 Random key for the optimization.
97 |             objective: Callable
98 |                 Objective function to optimize.
99 |             initial_position: Float[Array, " n_chain n_dim"]
100 |                 Initial positions for the optimization.
101 |             data: dict
102 |                 Data to pass to the objective function.
103 |
104 |         Returns:
105 |             rng_key: PRNGKeyArray
106 |                 Updated random key.
107 |             optimized_positions: Float[Array, " n_chain n_dim"]
108 |                 Optimized positions.
109 |             final_log_prob: Float[Array, " n_chain"]
110 |                 Final log-probabilities of the optimized positions.
111 |         """
112 |         # Validate bounds shape against n_dim
113 |         n_dim = initial_position.shape[-1]
114 |         if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim):
115 |             raise ValueError(
116 |                 f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. "
117 |                 "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds."
118 |             )
119 |
120 | grad_fn = jax.jit(jax.grad(objective))
121 |
122 | def _kernel(carry, _step):
123 | key, params, opt_state = carry
124 |
125 | key, subkey = jax.random.split(key)
126 | grad = grad_fn(params, data) * (
127 | 1 + jax.random.normal(subkey) * self.noise_level
128 | )
129 | updates, opt_state = self.solver.update(grad, opt_state, params)
130 | params = optax.apply_updates(params, updates)
131 | params = optax.projections.projection_box(
132 | params, self.bounds[:, 0], self.bounds[:, 1]
133 | )
134 | return (key, params, opt_state), None
135 |
136 | def _single_optimize(
137 | key: PRNGKeyArray,
138 | initial_position: Float[Array, " n_dim"],
139 | ) -> Float[Array, " n_dim"]:
140 | opt_state = self.solver.init(initial_position)
141 |
142 | (key, params, opt_state), _ = jax.lax.scan(
143 | _kernel,
144 | (key, initial_position, opt_state),
145 | jnp.arange(self.n_steps),
146 | )
147 |
148 | return params # type: ignore
149 |
150 | print("Using Adam optimization")
151 | rng_key, subkey = jax.random.split(rng_key)
152 | keys = jax.random.split(subkey, initial_position.shape[0])
153 | optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
154 | keys, initial_position
155 | )
156 |
157 | final_log_prob = jax.vmap(self.logpdf, in_axes=(0, None))(
158 | optimized_positions, data
159 | )
160 |
161 | if jnp.isinf(final_log_prob).any() or jnp.isnan(final_log_prob).any():
162 |             print("Warning: Optimization reached infinite or NaN log-probabilities.")
163 |
164 | return rng_key, optimized_positions, final_log_prob
165 |
--------------------------------------------------------------------------------
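A minimal usage sketch of the bounds broadcasting described in the class docstring (the quadratic target and all values below are illustrative):

    import jax
    import jax.numpy as jnp
    from flowMC.strategy.optimization import AdamOptimization

    def log_posterior(x, data):
        return -jnp.sum((x - 2.0) ** 2)  # illustrative target peaked at x = 2

    # A single (1, 2) row is broadcast to every dimension...
    opt = AdamOptimization(log_posterior, n_steps=200, bounds=jnp.array([[-1.0, 1.0]]))
    # ...or each of the n_dim dimensions gets its own row:
    opt = AdamOptimization(
        log_posterior,
        n_steps=200,
        bounds=jnp.array([[-1.0, 1.0], [0.0, 3.0]]),
    )

    key = jax.random.PRNGKey(0)
    initial = jnp.zeros((8, 2))  # (n_chains, n_dim)
    key, resources, best = opt(key, {}, initial, {})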
/joss/paper.bib:
--------------------------------------------------------------------------------
1 | @article{Gabrie2021,
2 | author = {Gabri{\'{e}}, Marylou and Rotskoff, Grant M. and Vanden-Eijnden, Eric},
3 | doi = {10.1073/pnas.2109420119},
4 | eprint = {2105.12603},
5 | issn = {0027-8424},
6 | journal = {Proceedings of the National Academy of Sciences},
8 | month = {mar},
9 | number = {10},
10 | title = {{Adaptive Monte Carlo augmented with normalizing flows}},
11 | url = {https://pnas.org/doi/full/10.1073/pnas.2109420119},
12 | volume = {119},
13 | year = {2022}
14 | }
15 | @inproceedings{Hoffman2019,
16 | archivePrefix = {arXiv},
17 | arxivId = {1903.03704},
18 | author = {Hoffman, Matthew D and Sountsov, Pavel and Dillon, Joshua V. and Langmore, Ian and Tran, Dustin and Vasudevan, Srinivas},
19 | booktitle = {1st Symposium on Advances in Approximate Bayesian Inference (2018)},
20 | eprint = {1903.03704},
21 | title = {{NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport}},
22 | url = {http://arxiv.org/abs/1903.03704},
23 | year = {2019}
24 | }
25 | @article{Albergo2019,
26 | author = {Albergo, M. S. and Kanwar, G. and Shanahan, P. E.},
27 | doi = {10.1103/PhysRevD.100.034515},
28 | eprint = {1904.12072},
29 | issn = {2470-0010},
30 | journal = {Physical Review D},
32 | month = {aug},
33 | number = {3},
34 | pages = {034515},
35 | publisher = {American Physical Society},
36 | title = {{Flow-based generative models for Markov chain Monte Carlo in lattice field theory}},
37 | url = {https://link.aps.org/doi/10.1103/PhysRevD.100.034515},
38 | volume = {100},
39 | year = {2019}
40 | }
41 | @inproceedings{Gabrie2021a,
42 | author = {Gabri{\'{e}}, Marylou and Rotskoff, Grant M. and Vanden-Eijnden, Eric},
43 | booktitle = {Invertible Neural Networks, NormalizingFlows, and Explicit Likelihood Models (ICML Workshop).},
44 | title = {{Efficient Bayesian Sampling Using Normalizing Flows to Assist Markov Chain Monte Carlo Methods}},
45 | url = {https://arxiv.org/abs/2107.08001},
46 | year = {2021}
47 | }
48 | @article{Kobyzev2021,
49 | archivePrefix = {arXiv},
50 | arxivId = {1908.09257},
51 | author = {Kobyzev, Ivan and Prince, Simon J.D. and Brubaker, Marcus A.},
52 | doi = {10.1109/TPAMI.2020.2992934},
53 | eprint = {1908.09257},
54 | issn = {19393539},
55 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
56 | keywords = {Generative models,density estimation,invertible neural networks,normalizing flows,variational inference},
57 | number = {11},
58 | pages = {3964--3979},
59 | pmid = {32396070},
60 | publisher = {IEEE},
61 | title = {{Normalizing Flows: An Introduction and Review of Current Methods}},
62 | volume = {43},
63 | year = {2021}
64 | }
65 | @article{Papamakarios2019,
66 | archivePrefix = {arXiv},
67 | arxivId = {1912.02762},
68 | author = {Papamakarios, George and Nalisnick, Eric and Rezende, Danilo Jimenez and Mohamed, Shakir and Lakshminarayanan, Balaji},
69 | eprint = {1912.02762},
70 | journal = {Journal of Machine Learning Research},
71 | number = {57},
72 | pages = {1--64},
73 | title = {{Normalizing Flows for Probabilistic Modeling and Inference}},
74 | url = {https://jmlr.org/papers/v22/19-1028.html},
75 | volume = {22},
76 | year = {2021}
77 | }
78 | @article{Nicoli2020,
79 | archivePrefix = {arXiv},
80 | arxivId = {1910.13496},
81 | author = {Nicoli, Kim A. and Nakajima, Shinichi and Strodthoff, Nils and Samek, Wojciech and M{\"{u}}ller, Klaus Robert and Kessel, Pan},
82 | doi = {10.1103/PhysRevE.101.023304},
83 | eprint = {1910.13496},
84 | issn = {24700053},
85 | journal = {Physical Review E},
87 | number = {2},
88 | pmid = {32168605},
89 | title = {{Asymptotically unbiased estimation of physical observables with neural samplers}},
90 | volume = {101},
91 | year = {2020}
92 | }
93 | @article{McNaughton2020,
94 | archivePrefix = {arXiv},
95 | arxivId = {2002.04292},
96 | author = {McNaughton, B. and Milo{\v{s}}evi{\'{c}}, M. V. and Perali, A. and Pilati, S.},
97 | doi = {10.1103/PhysRevE.101.053312},
98 | eprint = {2002.04292},
99 | issn = {24700053},
100 | journal = {Physical Review E},
103 | pages = {1--13},
104 | pmid = {32575304},
105 | title = {{Boosting Monte Carlo simulations of spin glasses using autoregressive neural networks}},
106 | url = {http://arxiv.org/abs/2002.04292},
107 | volume = {101},
108 | year = {2020}
109 | }
110 | @article{Naesseth2020,
111 | archivePrefix = {arXiv},
112 | arxivId = {2003.10374},
113 | author = {Naesseth, Christian A. and Lindsten, Fredrik and Blei, David},
114 | eprint = {2003.10374},
116 | issn = {10495258},
117 | journal = {Advances in Neural Information Processing Systems},
120 | title = {{Markovian score climbing: Variational inference with KL(p||q)}},
121 | volume = {33},
122 | year = {2020}
123 | }
124 | @article{Andrieu2008,
125 | author = {Andrieu, Christophe and Thoms, Johannes},
126 | doi = {10.1007/s11222-008-9110-y},
128 | issn = {09603174},
129 | journal = {Statistics and Computing},
130 | keywords = {Adaptive MCMC,Controlled Markov chain,MCMC,Stochastic approximation},
131 | number = {4},
132 | pages = {343--373},
133 | title = {{A tutorial on adaptive MCMC}},
134 | volume = {18},
135 | year = {2008}
136 | }
137 | @article{bingham2019pyro,
138 | author = {Eli Bingham and
139 | Jonathan P. Chen and
140 | Martin Jankowiak and
141 | Fritz Obermeyer and
142 | Neeraj Pradhan and
143 | Theofanis Karaletsos and
144 | Rohit Singh and
145 | Paul A. Szerlip and
146 | Paul Horsfall and
147 | Noah D. Goodman},
148 | title = {Pyro: Deep Universal Probabilistic Programming},
149 | journal = {J. Mach. Learn. Res.},
150 | volume = {20},
151 | pages = {28:1--28:6},
152 | year = {2019},
153 | url = {http://jmlr.org/papers/v20/18-403.html}
154 | }
155 | @article{Karamanis2022,
156 | archivePrefix = {arXiv},
157 | arxivId = {2207.05652},
158 | author = {Karamanis, Minas and Beutler, Florian and Peacock, John A. and Nabergoj, David and Seljak, Uros},
159 | eprint = {2207.05652},
161 | journal = {arXiv preprint},
162 | keywords = {cosmology,data analysis,large-scale structure of universe,methods,statistical},
163 | title = {{Accelerating astronomical and cosmological inference with Preconditioned Monte Carlo}},
164 | url = {http://arxiv.org/abs/2207.05652},
165 | volume = {2207.05652},
166 | year = {2022}
167 | }
168 |
--------------------------------------------------------------------------------
/src/flowMC/resource/kernel/NF_proposal.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from jax import random
6 | from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
7 | from typing import Callable
8 | import equinox as eqx
9 |
10 | from flowMC.resource.model.nf_model.base import NFModel
11 | from flowMC.resource.kernel.base import ProposalBase
12 | from flowMC.resource.logPDF import LogPDF
13 |
14 |
15 | class NFProposal(ProposalBase):
16 | model: NFModel
17 | n_batch_size: int
18 |
19 | def __repr__(self):
20 | return "NF proposal with " + self.model.__repr__()
21 |
22 | def __init__(self, model: NFModel, n_NFproposal_batch_size: int = 100):
23 | super().__init__()
24 | self.model = model
25 | self.n_batch_size = n_NFproposal_batch_size
26 |
27 | def kernel(
28 | self,
29 | rng_key: PRNGKeyArray,
30 | position: Float[Array, " n_dim"],
31 | log_prob: Float[Array, "1"],
32 | logpdf: LogPDF | Callable[[Float[Array, " n_dim"], PyTree], Float[Array, "1"]],
33 | data: PyTree,
34 | ) -> tuple[
35 | Float[Array, "n_step n_dim"], Float[Array, "n_step 1"], Int[Array, "n_step 1"]
36 | ]:
37 |
38 | print("Compiling NF proposal kernel")
39 | n_steps = data["n_steps"]
40 |
41 | rng_key, subkey = random.split(rng_key)
42 |
43 |         # Log-probability of the current position under the flow (a scalar).
44 | log_prob_nf_current = eqx.filter_jit(self.model.log_prob)(position)
45 |
46 |         # proposed_position is (n_steps, n_dim); log_prob_nf_proposed is (n_steps,).
47 | proposed_position, log_prob_nf_proposed = eqx.filter_jit(self.sample_flow)(
48 | subkey, n_steps
49 | )
50 | if n_steps > self.n_batch_size:
51 | n_batch = ceil(proposed_position.shape[0] / self.n_batch_size)
52 | batched_proposed_position = proposed_position[
53 | : (n_batch - 1) * self.n_batch_size
54 | ].reshape(n_batch - 1, self.n_batch_size, self.model.n_features)
55 |
56 | def scan_sample(
57 | carry,
58 | aux,
59 | ):
60 | proposed_position = aux
61 | return carry, jax.vmap(logpdf, in_axes=(0, None))(
62 | proposed_position, data
63 | )
64 |
65 | _, log_prob_proposed = jax.lax.scan(
66 | scan_sample,
67 | (),
68 | batched_proposed_position,
69 | )
70 | log_prob_proposed = log_prob_proposed.reshape(-1)
71 | log_prob_proposed = jnp.concatenate(
72 | (
73 | log_prob_proposed,
74 | jax.vmap(logpdf, in_axes=(0, None))(
75 | jax.lax.dynamic_slice_in_dim(
76 | proposed_position,
77 | (n_batch - 1) * self.n_batch_size,
78 | n_steps - (n_batch - 1) * self.n_batch_size,
79 | ),
80 | data,
81 | ),
82 | ),
83 | axis=0,
84 | )
85 |
86 | else:
87 | log_prob_proposed = jax.vmap(logpdf, in_axes=(0, None))(
88 | proposed_position, data
89 | )
90 |
91 | def body(carry, data):
92 | (
93 | rng_key,
94 | position_initial,
95 | log_prob_initial,
96 | log_prob_nf_initial,
97 | ) = carry
98 | (position_proposal, log_prob_proposal, log_prob_nf_proposal) = data
99 | rng_key, subkey = random.split(rng_key)
100 | ratio = (log_prob_proposal - log_prob_initial) - (
101 | log_prob_nf_proposal - log_prob_nf_initial
102 | )
103 | uniform_random = jnp.log(jax.random.uniform(subkey))
104 | do_accept = uniform_random < ratio
105 | position_current = jnp.where(do_accept, position_proposal, position_initial)
106 | log_prob_current = jnp.where(do_accept, log_prob_proposal, log_prob_initial)
107 | log_prob_nf_current = jnp.where(
108 | do_accept, log_prob_nf_proposal, log_prob_nf_initial
109 | )
110 |
111 | return (rng_key, position_current, log_prob_current, log_prob_nf_current), (
112 | position_current,
113 | log_prob_current,
114 | do_accept,
115 | )
116 |
117 | _, (positions, log_prob, do_accept) = jax.lax.scan(
118 | body,
119 | (
120 | rng_key,
121 | position,
122 | log_prob,
123 | log_prob_nf_current,
124 | ),
125 | (proposed_position, log_prob_proposed, log_prob_nf_proposed),
126 | )
127 |
128 | return positions, log_prob, do_accept
129 |
130 | def sample_flow(
131 | self,
132 | rng_key: PRNGKeyArray,
133 | n_steps: int,
134 | ):
135 | if n_steps > self.n_batch_size:
136 |             # Sample in batches of n_batch_size to keep peak memory bounded.
137 | n_batch = ceil(n_steps / self.n_batch_size)
138 | n_sample = ceil(n_steps / n_batch)
139 | (dynamic, static) = eqx.partition(self.model, eqx.is_array)
140 |
141 | def scan_sample(
142 | carry: tuple[PRNGKeyArray, NFModel],
143 | data,
144 | ):
145 | print("Compiling sample_flow")
146 | rng_key, model = carry
147 | rng_key, subkey = random.split(rng_key)
148 | combined = eqx.combine(model, static)
149 | proposal_position = combined.sample(subkey, n_samples=n_sample)
150 | proposed_log_prob = eqx.filter_vmap(combined.log_prob)(
151 | proposal_position
152 | )
153 | return (rng_key, model), (proposal_position, proposed_log_prob)
154 |
155 | _, (proposal_position, proposed_log_prob) = jax.lax.scan(
156 | scan_sample,
157 | (rng_key, dynamic),
158 | length=n_batch,
159 | )
160 | proposal_position = proposal_position.reshape(-1, self.model.n_features)[
161 | :n_steps
162 | ]
163 | proposed_log_prob = proposed_log_prob.reshape(-1)[:n_steps]
164 |
165 | else:
166 | proposal_position = self.model.sample(rng_key, n_steps)
167 | proposed_log_prob = eqx.filter_vmap(self.model.log_prob)(proposal_position)
168 |
169 | proposal_position = proposal_position.reshape(n_steps, self.model.n_features)
170 | proposed_log_prob = proposed_log_prob.reshape(n_steps)
171 |
172 | return proposal_position, proposed_log_prob
173 |
174 | def print_parameters(self):
175 | # TODO: Implement this
176 | raise NotImplementedError
177 |
178 | def save_resource(self, path):
179 | # TODO: Implement this
180 | raise NotImplementedError
181 |
182 | def load_resource(self, path):
183 | # TODO: Implement this
184 | raise NotImplementedError
185 |
--------------------------------------------------------------------------------
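The accept/reject `body` above is an independence Metropolis-Hastings step with the flow q_\theta as the proposal distribution; its log acceptance ratio is

    \log r = \big(\log \pi(x') - \log \pi(x)\big) - \big(\log q_\theta(x') - \log q_\theta(x)\big),

which is exactly the `ratio` computed from `log_prob_proposal`, `log_prob_initial`, `log_prob_nf_proposal`, and `log_prob_nf_initial`.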
/test/unit/test_flowmatching.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pytest
4 |
5 | from flowMC.resource.model.flowmatching.base import (
6 | FlowMatchingModel,
7 | Solver,
8 | Path,
9 | CondOTScheduler,
10 | )
11 | from flowMC.resource.model.common import MLP
12 | from diffrax import Dopri5
13 | import equinox as eqx
14 | import optax
15 |
16 |
17 | def get_simple_mlp(n_input, n_hidden, n_output, key):
18 | """Simple 2-layer MLP for testing."""
19 | shape = (
20 | [n_input]
21 | + ([n_hidden] if isinstance(n_hidden, int) else list(n_hidden))
22 | + [n_output]
23 | )
24 | return MLP(shape=shape, key=key, activation=jax.nn.swish)
25 |
26 |
27 | ##############################
28 | # Solver Tests
29 | ##############################
30 |
31 |
32 | class TestSolver:
33 | @pytest.fixture
34 | def solver(self):
35 | key = jax.random.PRNGKey(0)
36 | n_dim = 3
37 | n_hidden = 4
38 | mlp = get_simple_mlp(
39 | n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key
40 | )
41 | return Solver(model=mlp, method=Dopri5()), key, n_dim
42 |
43 | def test_sample_shape_and_finiteness(self, solver):
44 | solver, key, n_dim = solver
45 | n_samples = 7
46 | samples = solver.sample(key, n_samples)
47 | assert samples.shape == (n_samples, n_dim)
48 | assert jnp.isfinite(samples).all()
49 |
50 | def test_log_prob_shape_and_finiteness(self, solver):
51 | solver, key, n_dim = solver
52 | x1 = jax.random.normal(key, (n_dim,))
53 | logp = solver.log_prob(x1)
54 | logp_arr = jnp.asarray(logp)
55 | assert logp_arr.size == 1
56 | assert jnp.isfinite(logp_arr).all()
57 |
58 | @pytest.mark.parametrize("dt", [1e-2, 1e-1, 0.5])
59 | def test_sample_various_dt(self, solver, dt):
60 | solver, key, n_dim = solver
61 | samples = solver.sample(key, 3, dt=dt)
62 | assert samples.shape == (3, n_dim)
63 | assert jnp.isfinite(samples).all()
64 |
65 |
66 | ##############################
67 | # Path & Scheduler Tests
68 | ##############################
69 |
70 |
71 | class TestPathAndScheduler:
72 | def test_path_sample_shapes_and_values(self):
73 | n_dim = 2
74 | scheduler = CondOTScheduler()
75 | path = Path(scheduler=scheduler)
76 | x0 = jnp.ones((5, n_dim))
77 | x1 = jnp.zeros((5, n_dim))
78 | for t_val in [0.0, 0.5, 1.0]:
79 | t = jnp.full((5, 1), t_val)
80 | x_t, dx_t = path.sample(x0, x1, t)
81 | assert x_t.shape == (5, n_dim)
82 | assert dx_t.shape == (5, n_dim)
83 |
84 | @pytest.mark.parametrize("t", [0.0, 1.0, 0.5, -0.1, 1.1])
85 | def test_condotscheduler_call_output(self, t):
86 | sched = CondOTScheduler()
87 | out = sched(jnp.array(t))
88 | assert isinstance(out, tuple)
89 | assert len(out) == 4
90 | assert all(isinstance(float(x), float) for x in out)
91 |
92 |
93 | ##############################
94 | # FlowMatchingModel Tests
95 | ##############################
96 |
97 |
98 | class TestFlowMatchingModel:
99 | @pytest.fixture
100 | def model(self):
101 | key = jax.random.PRNGKey(42)
102 | n_dim = 2
103 | n_hidden = 8
104 | mlp = get_simple_mlp(
105 | n_input=n_dim + 1, n_hidden=n_hidden, n_output=n_dim, key=key
106 | )
107 | solver = Solver(model=mlp, method=Dopri5())
108 | scheduler = CondOTScheduler()
109 | path = Path(scheduler=scheduler)
110 | model = FlowMatchingModel(
111 | solver=solver,
112 | path=path,
113 | data_mean=jnp.zeros(n_dim),
114 | data_cov=jnp.eye(n_dim),
115 | )
116 | return model, key, n_dim
117 |
118 | def test_sample_and_log_prob(self, model):
119 | model, key, n_dim = model
120 | n_samples = 4
121 | samples = model.sample(key, n_samples)
122 | assert samples.shape == (n_samples, n_dim)
123 | assert jnp.isfinite(samples).all()
124 | logp = eqx.filter_vmap(model.log_prob)(samples)
125 | assert logp.shape == (n_samples, 1)
126 | assert jnp.isfinite(logp).all()
127 |
128 | @pytest.mark.parametrize("n_samples", [1, 5, 10])
129 | def test_sample_various_shapes(self, model, n_samples):
130 | model, key, n_dim = model
131 | samples = model.sample(key, n_samples)
132 | assert samples.shape == (n_samples, n_dim)
133 | assert jnp.isfinite(samples).all()
134 | logp = eqx.filter_vmap(model.log_prob)(samples)
135 | assert logp.shape[0] == n_samples
136 | assert jnp.isfinite(logp).all()
137 |
138 | def test_log_prob_edge_cases(self, model):
139 | model, key, n_dim = model
140 | for arr in [jnp.zeros(n_dim), 1e6 * jnp.ones(n_dim), -1e6 * jnp.ones(n_dim)]:
141 | logp = model.log_prob(arr)
142 | logp_arr = jnp.asarray(logp)
143 | assert logp_arr.size == 1
144 | assert (
145 | jnp.isfinite(logp_arr).all() or jnp.isnan(logp_arr).all()
146 | ) # may be nan for extreme values
147 |
148 | def test_save_and_load(self, tmp_path, model):
149 | model, key, n_dim = model
150 | save_path = str(tmp_path / "test_model")
151 | model.save_model(save_path)
152 | loaded = model.load_model(save_path)
153 | x = jax.random.normal(key, (2, n_dim))
154 | assert jnp.allclose(
155 | eqx.filter_vmap(model.log_prob)(x), eqx.filter_vmap(loaded.log_prob)(x)
156 | )
157 |
158 | def test_properties(self, model):
159 | model, key, n_dim = model
160 | mean = jnp.arange(n_dim)
161 | cov = jnp.eye(n_dim) * 2
162 | model2 = FlowMatchingModel(
163 | solver=model.solver, path=model.path, data_mean=mean, data_cov=cov
164 | )
165 | assert model2.n_features == n_dim
166 | assert jnp.allclose(model2.data_mean, mean)
167 | assert jnp.allclose(model2.data_cov, cov)
168 |
169 | def test_print_parameters_notimplemented(self, model):
170 | model, key, n_dim = model
171 | with pytest.raises(NotImplementedError):
172 | model.print_parameters()
173 |
174 | def test_train_step_and_epoch(self, model):
175 | model, key, n_dim = model
176 | n_batch = 5
177 | x0 = jax.random.normal(key, (n_batch, n_dim))
178 | x1 = jax.random.normal(key, (n_batch, n_dim))
179 | t = jax.random.uniform(key, (n_batch, 1))
180 | optim = optax.adam(learning_rate=1e-3)
181 | state = optim.init(eqx.filter(model, eqx.is_array))
182 | std = jnp.sqrt(jnp.diag(model.data_cov))
183 | x1_whitened = (x1 - model.data_mean) / std
184 | x_t, dx_t = model.path.sample(x0, x1_whitened, t)
185 | loss, model2, state2 = model.train_step(x_t, t, dx_t, optim, state)
186 | assert jnp.isfinite(loss)
187 | assert isinstance(model2, FlowMatchingModel)
188 | data = (x0, x1, t)
189 | loss_epoch, model3, state3 = model.train_epoch(
190 | key, optim, state, data, batch_size=n_batch
191 | )
192 | assert jnp.isfinite(loss_epoch)
193 | assert isinstance(model3, FlowMatchingModel)
194 |
--------------------------------------------------------------------------------
/docs/configuration.md:
--------------------------------------------------------------------------------
1 | Configuration Guide
2 | ===================
3 |
4 | This page contains information about the most important hyperparameters which affect the behavior of the sampler.
5 |
6 |
7 | | Essential | Optional | Advanced |
8 | | --------------------------------- | ----------------------------------------- | ----------------------------------- |
9 | | [`n_dim`](#n_dim) | [`use_global`](#use_global) | [`keep_quantile`](#keep_quantile) |
10 | | [`rng_key_set`](#rng_key_set)     | [`n_chains`](#n_chains)                    | [`momentum`](#momentum)             |
11 | | [`local_sampler`](#local_sampler) | [`n_loop_training`](#n_loop_training) | [`nf_variable`](#nf_variable) |
12 | | [`data`](#data) | [`n_loop_production`](#n_loop_production) | [`local_autotune`](#local_autotune) |
13 | | [`nf_model`](#nf_model) | [`n_local_steps`](#n_local_steps) | [`train_thinning`](#train_thinning) |
14 | | | [`n_global_steps`](#n_global_steps) | |
15 | | | [`n_epochs`](#n_epochs) | |
16 | | | [`learning_rate`](#learning_rate) | |
17 | | | [`max_samples`](#max_samples) | |
18 | | | [`batch_size`](#batch_size) | |
19 | | | [`verbose`](#verbose) | |
20 |
21 |
22 |
23 |
24 | Essential arguments
25 | -------------------
26 |
27 | ## [n_dim](#n_dim)
28 |
29 | The dimension of the problem. The sampler will fail if `n_dim` does not match the input dimension of your likelihood function.
30 |
31 | ## [rng_key_set](#rng_key_set)
32 |
33 | The set of Jax-generated PRNG keys.
34 |
35 | ## [data](#data)
36 |
37 | The data passed to your likelihood function. This is used to precompile the kernels used during the sampling.
38 | Note that you should keep the shape of the data consistent between runs, otherwise it will trigger recompilation.
39 | If your likelihood does not take any data arguments, simply passing `None` should work.
40 |
41 | ## [local_sampler](#local_sampler)
42 | The specific local sampler you want to use.
43 |
44 | ## [nf_model](#nf_model)
45 | The specific normalizing flow model you want to use.
46 |
47 | Optional arguments that you might want to tune
48 | ----------------------------------------------
49 |
50 | ## [use_global](#use_global)
51 | Whether to use the global sampler or not. Default is ``True``.
52 | Turning off the global sampler also disables the training phase.
53 | This is useful when you want to test whether the local sampler is behaving normally, or to perform an ablation study on the benefits of the global sampler.
54 | In production-quality runs you probably always want to use the global sampler, since it improves convergence significantly.
55 |
56 | ## [n_chains](#n_chains)
57 | Number of parallel chains to run. Default is ``20``.
58 | Within your memory budget, and without oversubscribing your computational device, you should use as many chains as possible.
59 | The method benefits tremendously from parallelization.
60 |
61 | ## [n_loop_training](#n_loop_training)
62 | Number of local-global sampling loops to run during the training phase. Default is ``3``.
63 |
64 | ## [n_loop_production](#n_loop_production)
65 | Number of local-global sampling loops to run during the production phase. Default is ``3``.
66 | This is similar to ``n_loop_training``; the only difference is that during the production loops, the normalizing flow model is no longer updated. This saves computation time once the flow is sufficiently trained to power global moves. As the MCMC stops being adaptive, detailed balance is also restored in the Metropolis-Hastings steps, and the traditional diagnostics of MCMC convergence can be applied safely.
67 |
68 |
69 | ## [n_local_steps](#n_local_steps)
70 | Number of local steps to run during the local sampling phase. Default is ``50``.
71 |
72 | ## [n_global_steps](#n_global_steps)
73 | Number of global steps to run during the global sampling phase. Default is ``50``.
74 |
75 | ## [n_epochs](#n_epochs)
76 | Number of epochs to run during the training phase. Default is ``30``.
77 | The higher this number, the better the flow performs, at the cost of increasing the training time.
78 |
79 | ## [learning_rate](#learning_rate)
80 | Learning rate for the Adam optimizer. Default is ``1e-2``.
81 |
82 | ## [max_samples](#max_samples)
83 | Maximum number of samples used to train the normalizing flow model. Default is ``10000``.
84 |
85 | If the total number of samples accumulated along the chains exceeds ``max_samples`` when a training phase starts, only a subsample of size ``max_samples`` taken from the most recent steps of the chains will be used for training.
86 | The chain dimension has priority over the step dimension, meaning the sampler will try to take at least one sample from each chain before going back to previous steps to retrieve more samples.
87 | One usually chooses this number based on the memory capacity of the device.
88 | If the number is larger than what fits in your device's memory, each training loop will take longer to finish.
89 | On the other hand, the training time will not be affected if the entire dataset can fit on your device.
90 | If this number is small, only the most recent samples are used in the training.
91 | This may cause the normalizing flow model to forget features of the global landscape that were not visited recently by the chains. For instance, it can lead to mode collapse.
92 |
93 | ## [batch_size](#batch_size)
94 | Batch size for training the normalizing flow model. Default is ``10000``.
95 | Using a large batch size speeds up the training, since the training time is determined by the number of batched backward passes.
96 | Unlike typical deep learning use cases, since our training dataset is continuously evolving, we do not really have to worry about overfitting.
97 | Therefore, using a larger batch size is usually better.
98 | The rule of thumb here is: within memory and computational bandwidth, choose the largest number that does not increase the training time.
99 |
100 | ## [keep_quantile](#keep_quantile)
101 | Dictionary with keys ``params`` and ``variables`` allowing you to reuse the model trained during a previous run of the NFSampler. These variables can be retrieved from ``NFSampler.state`` after a run. An example is provided in the tutorials.
102 |
103 | ## [verbose](#verbose)
104 | Whether to print out extra info during the inference. Default is ``False``.
105 |
106 |
107 |
108 | Only-if-you-know-what-you-are-doing arguments
109 | ---------------------------------------------
110 |
111 |
112 | ## [keep_quantile](#keep_quantile)
113 |
114 | ## [momentum](#momentum)
115 |
116 | ## [nf_variable](#nf_variable)
117 |
118 | ## [local_autotune](#local_autotune)
119 |
120 | ## [train_thinning](#train_thinning)
121 |
122 | Thinning factor for the data used to train the normalizing flow.
123 | Given that we only keep ``max_samples`` samples, only the newest ``max_samples/n_chains`` samples in each chain are used for training the normalizing flow.
124 | This thinning factor keeps every ``train_thinning``-th sample in each chain.
125 | The larger the number, the less correlated the retained samples in each chain are.
126 | When ``max_samples*train_thinning/n_chains > n_local_steps``, samples generated over multiple global training loops are used in training the new normalizing flow. For example, with ``max_samples=10000``, ``train_thinning=10`` and ``n_chains=20``, up to ``10000*10/20 = 5000`` recent steps per chain can enter the training set, spanning many ``n_local_steps=50`` loops.
127 | This reduces the possibility of mode collapse, since the algorithm retains access to samples generated before any mode collapse would have occurred.
128 |
129 | This API is still experimental and might be combined with other hyperparameters into one big tuning parameter later.
130 |
--------------------------------------------------------------------------------
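To tie the guide above together, the sketch below shows how these hyperparameters might be wired into a sampler. It is a minimal sketch assuming the pre-v1 interface this page documents; the import paths and constructor signature (`Sampler`, `MALA`, `RQSpline`) are assumptions based on the same era of the API and may differ from your installed version:

```python
# A minimal sketch, assuming the pre-v1 API described on this page.
# Import paths and argument names are assumptions and may have moved.
import jax
import jax.numpy as jnp
from flowMC.sampler.Sampler import Sampler      # assumed path
from flowMC.sampler.MALA import MALA            # assumed path
from flowMC.nfmodel.rqSpline import RQSpline    # assumed path

n_dim, n_chains = 5, 20

def log_posterior(x, data):
    return -0.5 * jnp.sum(x**2)  # toy unimodal target

rng_key = jax.random.PRNGKey(42)
model = RQSpline(n_dim, 4, [32, 32], 8)
local_sampler = MALA(log_posterior, True, {"step_size": 0.1})

sampler = Sampler(
    n_dim,
    rng_key,
    None,                  # data: None is fine if the likelihood ignores it
    local_sampler,
    model,
    n_chains=n_chains,
    n_loop_training=3,     # defaults quoted from this page
    n_loop_production=3,
    n_local_steps=50,
    n_global_steps=50,
    n_epochs=30,
    learning_rate=1e-2,
    max_samples=10000,
    batch_size=10000,
    use_global=True,
)
initial_position = jax.random.normal(rng_key, (n_chains, n_dim))
sampler.sample(initial_position, None)
```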
/src/flowMC/strategy/take_steps.py:
--------------------------------------------------------------------------------
1 | from flowMC.resource.base import Resource
2 | from flowMC.resource.kernel.base import ProposalBase
3 | from flowMC.resource.buffers import Buffer
4 | from flowMC.resource.states import State
5 | from flowMC.resource.logPDF import LogPDF
6 | from flowMC.strategy.base import Strategy
7 | from jaxtyping import Array, Float, PRNGKeyArray
8 | import jax
9 | import jax.numpy as jnp
10 | import equinox as eqx
11 | from abc import abstractmethod
12 |
13 |
14 | class TakeSteps(Strategy):
15 | logpdf_name: str
16 | kernel_name: str
17 | state_name: str
18 | buffer_names: list[str]
19 | n_steps: int
20 | current_position: int
21 | thinning: int
22 |     chain_batch_size: int  # If vmapping over a large number of chains is memory bound, this splits the computation into batches
23 | verbose: bool
24 |
25 | def __init__(
26 | self,
27 | logpdf_name: str,
28 | kernel_name: str,
29 | state_name: str,
30 | buffer_names: list[str],
31 | n_steps: int,
32 | thinning: int = 1,
33 | chain_batch_size: int = 0,
34 | verbose: bool = False,
35 | ):
36 | self.logpdf_name = logpdf_name
37 | self.kernel_name = kernel_name
38 | self.state_name = state_name
39 | self.buffer_names = buffer_names
40 | self.n_steps = n_steps
41 | self.current_position = 0
42 | self.thinning = thinning
43 | self.chain_batch_size = chain_batch_size
44 | self.verbose = verbose
45 |
46 | @abstractmethod
47 | def sample(
48 | self,
49 | kernel: ProposalBase,
50 | rng_key: PRNGKeyArray,
51 | initial_position: Float[Array, " n_dim"],
52 | logpdf: LogPDF,
53 | data: dict,
54 | ):
55 | raise NotImplementedError
56 |
57 | def set_current_position(self, current_position: int):
58 | self.current_position = current_position
59 |
60 | def __call__(
61 | self,
62 | rng_key: PRNGKeyArray,
63 | resources: dict[str, Resource],
64 | initial_position: Float[Array, "n_chains n_dim"],
65 | data: dict,
66 | ) -> tuple[
67 | PRNGKeyArray,
68 | dict[str, Resource],
69 | Float[Array, "n_chains n_dim"],
70 | ]:
71 | rng_key, subkey = jax.random.split(rng_key)
72 | subkey = jax.random.split(subkey, initial_position.shape[0])
73 |
74 | assert isinstance(
75 | state_resource := resources[self.state_name], State
76 | ), "State resource must be a State"
77 |
78 | assert isinstance(
79 | position_buffer_name := state_resource.data[self.buffer_names[0]], str
80 | ), "Position buffer resource name must be a string"
81 |
82 | assert isinstance(
83 | log_prob_buffer_name := state_resource.data[self.buffer_names[1]], str
84 | ), "Log probability buffer resource name must be a string"
85 |
86 | assert isinstance(
87 | acceptance_buffer_name := state_resource.data[self.buffer_names[2]], str
88 | ), "Acceptance buffer resource name must be a string"
89 |
90 | assert isinstance(
91 | position_buffer := resources[position_buffer_name], Buffer
92 | ), "Position buffer resource must be a Buffer"
93 | assert isinstance(
94 | log_prob_buffer := resources[log_prob_buffer_name], Buffer
95 | ), "Log probability buffer resource must be a Buffer"
96 | assert isinstance(
97 | acceptance_buffer := resources[acceptance_buffer_name], Buffer
98 | ), "Acceptance buffer resource must be a Buffer"
99 |
100 | kernel = resources[self.kernel_name]
101 | logpdf = resources[self.logpdf_name]
102 |
103 |         # eqx.filter_jit caches the compiled function, so repeated calls
104 |         # reuse the compilation unless the cache is cleared.
105 | n_chains = initial_position.shape[0]
106 | if self.chain_batch_size > 1 and n_chains > self.chain_batch_size:
107 | positions_list = []
108 | log_probs_list = []
109 | do_accepts_list = []
110 | for i in range(0, n_chains, self.chain_batch_size):
111 | batch_slice = slice(i, min(i + self.chain_batch_size, n_chains))
112 | subkey_batch = subkey[batch_slice]
113 | initial_position_batch = initial_position[batch_slice]
114 | positions_batch, log_probs_batch, do_accepts_batch = eqx.filter_jit(
115 | eqx.filter_vmap(
116 | jax.tree_util.Partial(self.sample, kernel),
117 | in_axes=(0, 0, None, None),
118 | )
119 | )(subkey_batch, initial_position_batch, logpdf, data)
120 | positions_list.append(positions_batch)
121 | log_probs_list.append(log_probs_batch)
122 | do_accepts_list.append(do_accepts_batch)
123 | positions = jnp.concatenate(positions_list, axis=0)
124 | log_probs = jnp.concatenate(log_probs_list, axis=0)
125 | do_accepts = jnp.concatenate(do_accepts_list, axis=0)
126 | else:
127 | positions, log_probs, do_accepts = eqx.filter_jit(
128 | eqx.filter_vmap(
129 | jax.tree_util.Partial(self.sample, kernel),
130 | in_axes=(0, 0, None, None),
131 | )
132 | )(subkey, initial_position, logpdf, data)
133 |
134 | positions = positions[:, :: self.thinning]
135 | log_probs = log_probs[:, :: self.thinning]
136 | do_accepts = do_accepts[:, :: self.thinning].astype(
137 | acceptance_buffer.data.dtype
138 | )
139 |
140 | position_buffer.update_buffer(positions, self.current_position)
141 | log_prob_buffer.update_buffer(log_probs, self.current_position)
142 | acceptance_buffer.update_buffer(do_accepts, self.current_position)
143 | self.current_position += self.n_steps // self.thinning
144 | return rng_key, resources, positions[:, -1]
145 |
146 |
147 | class TakeSerialSteps(TakeSteps):
148 | """TakeSerialSteps is a strategy that takes a number of steps in a serial manner,
149 | i.e. one after the other.
150 |
151 | This uses jax.lax.scan to iterate over the steps and apply the kernel to the current
152 | position. This is intended to be used for most local kernels that are dependent on
153 | the previous step.
154 | """
155 |
156 | def body(self, kernel: ProposalBase, carry, aux):
157 | key, position, log_prob, logpdf, data = carry
158 | key, subkey = jax.random.split(key)
159 | position, log_prob, do_accept = kernel.kernel(
160 | subkey, position, log_prob, logpdf, data
161 | )
162 | return (key, position, log_prob, logpdf, data), (position, log_prob, do_accept)
163 |
164 | def sample(
165 | self,
166 | kernel: ProposalBase,
167 | rng_key: PRNGKeyArray,
168 | initial_position: Float[Array, " n_dim"],
169 | logpdf: LogPDF,
170 | data: dict,
171 | ):
172 | (
173 | (last_key, last_position, last_log_prob, logpdf, data),
174 | (positions, log_probs, do_accepts),
175 | ) = jax.lax.scan(
176 | jax.tree_util.Partial(self.body, kernel),
177 | (rng_key, initial_position, logpdf(initial_position, data), logpdf, data),
178 | length=self.n_steps,
179 | )
180 | return positions, log_probs, do_accepts
181 |
182 |
183 | class TakeGroupSteps(TakeSteps):
184 | """TakeGroupSteps is a strategy that takes a number of steps in a group manner, i.e.
185 | all steps are taken at once.
186 |
187 |     This is intended to be used for kernels such as the normalizing flow, whose
188 |     proposal steps are independent of each other and benefit from parallel computation.
189 | """
190 |
191 | def sample(
192 | self,
193 | kernel: ProposalBase,
194 | rng_key: PRNGKeyArray,
195 | initial_position: Float[Array, " n_dim"],
196 | logpdf: LogPDF,
197 | data: dict,
198 | ):
199 | (positions, log_probs, do_accepts) = kernel.kernel(
200 | rng_key,
201 | initial_position,
202 | logpdf(initial_position, data),
203 | logpdf,
204 | {**data, "n_steps": self.n_steps},
205 | )
206 | return positions, log_probs, do_accepts
207 |
--------------------------------------------------------------------------------
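The `chain_batch_size` branch in `TakeSteps.__call__` above trades wall-clock speed for peak memory by vmapping the kernel over slices of chains rather than all chains at once. Here is a self-contained sketch of the same pattern with a toy kernel (none of this is flowMC API):

```python
import jax
import jax.numpy as jnp

def toy_kernel(key, position):
    # Stand-in for one MCMC step of a single chain.
    return position + 0.1 * jax.random.normal(key, position.shape)

def step_all_chains(keys, positions, chain_batch_size=0):
    n_chains = positions.shape[0]
    step = jax.jit(jax.vmap(toy_kernel))
    if chain_batch_size > 1 and n_chains > chain_batch_size:
        # Process chains in slices to bound peak memory, then stitch the
        # per-batch results back together, mirroring TakeSteps.__call__.
        out = [
            step(keys[i : i + chain_batch_size], positions[i : i + chain_batch_size])
            for i in range(0, n_chains, chain_batch_size)
        ]
        return jnp.concatenate(out, axis=0)
    return step(keys, positions)

keys = jax.random.split(jax.random.PRNGKey(0), 100)
positions = jnp.zeros((100, 3))
new_positions = step_all_chains(keys, positions, chain_batch_size=32)
assert new_positions.shape == (100, 3)
```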
/src/flowMC/resource/model/nf_model/realNVP.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import equinox as eqx
4 | import jax
5 | import jax.numpy as jnp
6 | from jaxtyping import Array, Float, PRNGKeyArray
7 |
8 | from flowMC.resource.model.nf_model.base import NFModel
9 | from flowMC.resource.model.common import (
10 | Distribution,
11 | MLP,
12 | Gaussian,
13 | MaskedCouplingLayer,
14 | MLPAffine,
15 | )
16 |
17 |
18 | class AffineCoupling(eqx.Module):
19 | """
20 | Affine coupling layer.
21 | (Defined in the RealNVP paper https://arxiv.org/abs/1605.08803)
22 | We use tanh as the default activation function.
23 |
24 | Args:
25 | n_features: (int) The number of features in the input.
26 | n_hidden: (int) The number of hidden units in the MLP.
27 | mask: (ndarray) Alternating mask for the affine coupling layer.
28 | dt: (Float) Scaling factor for the affine coupling layer.
29 | """
30 |
31 | _mask: Array
32 | scale_MLP: MLP
33 | translate_MLP: MLP
34 | dt: Float = 1
35 |
36 | def __init__(
37 | self,
38 | n_features: int,
39 | n_hidden: int,
40 | mask: Array,
41 | key: PRNGKeyArray,
42 | dt: Float = 1,
43 | scale: Float = 1e-4,
44 | ):
45 | self._mask = mask
46 | self.dt = dt
47 | key, scale_subkey, translate_subkey = jax.random.split(key, 3)
48 | features = [n_features, n_hidden, n_features]
49 | self.scale_MLP = MLP(features, key=scale_subkey, scale=scale)
50 | self.translate_MLP = MLP(features, key=translate_subkey, scale=scale)
51 |
52 | @property
53 | def mask(self):
54 | return jax.lax.stop_gradient(self._mask)
55 |
56 | @property
57 | def n_features(self):
58 | return self.scale_MLP.n_input
59 |
60 | def __call__(self, x: Array):
61 | return self.forward(x)
62 |
63 | def forward(self, x: Array) -> Tuple[Array, Array]:
64 | """From latent space to data space.
65 |
66 | Args:
67 | x: (Array) Latent space.
68 |
69 | Returns:
70 | outputs: (Array) Data space.
71 | log_det: (Array) Log determinant of the Jacobian.
72 | """
73 | s = self.mask * self.scale_MLP(x * (1 - self.mask))
74 | s = jnp.tanh(s) * self.dt
75 | t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt
76 |
77 | # Compute log determinant of the Jacobian
78 | log_det = s.sum()
79 |
80 | # Apply the transformation
81 | outputs = (x + t) * jnp.exp(s)
82 | return outputs, log_det
83 |
84 | def inverse(self, x: Array) -> Tuple[Array, Array]:
85 | """From data space to latent space.
86 |
87 | Args:
88 | x: (Array) Data space.
89 |
90 | Returns:
91 | outputs: (Array) Latent space.
92 | log_det: (Array) Log determinant of the Jacobian.
93 | """
94 | s = self.mask * self.scale_MLP(x * (1 - self.mask))
95 | s = jnp.tanh(s) * self.dt
96 | t = self.mask * self.translate_MLP(x * (1 - self.mask)) * self.dt
97 | log_det = -s.sum()
98 | outputs = x * jnp.exp(-s) - t
99 | return outputs, log_det
100 |
101 |
102 | class RealNVP(NFModel):
103 | """
104 |     RealNVP model defined in the paper https://arxiv.org/abs/1605.08803.
105 |     The MLP initialization scale is needed to make sure the scaling between layers is more or less the same.
106 |
107 | Args:
108 | n_layers: (int) The number of affine coupling layers.
109 | n_features: (int) The number of features in the input.
110 | n_hidden: (int) The number of hidden units in the MLP.
111 | dt: (Float) Scaling factor for the affine coupling layer.
112 |
113 | Properties:
114 | data_mean: (ndarray) Mean of Gaussian base distribution
115 | data_cov: (ndarray) Covariance of Gaussian base distribution
116 | """
117 |
118 | base_dist: Distribution
119 | affine_coupling: List[MaskedCouplingLayer]
120 | _n_features: int
121 | _data_mean: Float[Array, " n_dim"]
122 | _data_cov: Float[Array, " n_dim n_dim"]
123 |
124 | @property
125 | def n_features(self) -> int:
126 | return self._n_features
127 |
128 | @property
129 | def data_mean(self):
130 | return jax.lax.stop_gradient(self._data_mean)
131 |
132 | @property
133 | def data_cov(self):
134 | return jax.lax.stop_gradient(jnp.atleast_2d(self._data_cov))
135 |
136 | def __init__(
137 | self, n_features: int, n_layers: int, n_hidden: int, key: PRNGKeyArray, **kwargs
138 | ):
139 | if kwargs.get("base_dist") is not None:
140 | self.base_dist = kwargs.get("base_dist") # type: ignore
141 | else:
142 | self.base_dist = Gaussian(
143 | jnp.zeros(n_features), jnp.eye(n_features), learnable=False
144 | )
145 |
146 | if kwargs.get("data_mean") is not None:
147 | data_mean = kwargs.get("data_mean")
148 | assert isinstance(data_mean, Array)
149 | self._data_mean = data_mean
150 | else:
151 | self._data_mean = jnp.zeros(n_features)
152 |
153 | if kwargs.get("data_cov") is not None:
154 | data_cov = kwargs.get("data_cov")
155 | assert isinstance(data_cov, Array)
156 | self._data_cov = data_cov
157 | else:
158 | self._data_cov = jnp.eye(n_features)
159 |
160 | self._n_features = n_features
161 |
162 | def make_layer(i: int, key: PRNGKeyArray):
163 | key, scale_subkey, shift_subkey = jax.random.split(key, 3)
164 | mask = jnp.ones(n_features)
165 | mask = mask.at[: int(n_features / 2)].set(0)
166 | mask = jax.lax.cond(i % 2 == 0, lambda x: 1 - x, lambda x: x, mask)
167 | scale_MLP = MLP([n_features, n_hidden, n_features], key=scale_subkey)
168 | shift_MLP = MLP([n_features, n_hidden, n_features], key=shift_subkey)
169 | return MaskedCouplingLayer(MLPAffine(scale_MLP, shift_MLP), mask)
170 |
171 | keys = jax.random.split(key, n_layers)
172 | self.affine_coupling = eqx.filter_vmap(make_layer)(jnp.arange(n_layers), keys)
173 |
174 | def forward(
175 | self,
176 | x: Float[Array, " n_dim"],
177 | key: Optional[PRNGKeyArray] = None,
178 | condition: Optional[Float[Array, " n_condition"]] = None,
179 | ) -> tuple[Float[Array, " n_dim"], Float]:
180 | log_det = 0.0
181 | dynamics, statics = eqx.partition(self.affine_coupling, eqx.is_array)
182 |
183 | def f(carry, data):
184 | x, log_det = carry
185 | layers = eqx.combine(data, statics)
186 | x, log_det_i = layers(x, condition)
187 | return (x, log_det + log_det_i), None
188 |
189 | (x, log_det), _ = jax.lax.scan(f, (x, log_det), dynamics)
190 | return x, log_det
191 |
192 | def inverse(
193 | self,
194 | x: Float[Array, " n_dim"],
195 | condition: Optional[Float[Array, " n_condition"]] = None,
196 | ) -> tuple[Float[Array, " n_dim"], Float]:
197 | """From latent space to data space."""
198 | log_det = 0.0
199 | dynamics, statics = eqx.partition(self.affine_coupling, eqx.is_array)
200 |
201 | def f(carry, data):
202 | x, log_det = carry
203 | layers = eqx.combine(data, statics)
204 | x, log_det_i = layers.inverse(x, condition)
205 | return (x, log_det + log_det_i), None
206 |
207 | (x, log_det), _ = jax.lax.scan(f, (x, log_det), dynamics, reverse=True)
208 | return x, log_det
209 |
210 | def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array:
211 | samples = self.base_dist.sample(rng_key, n_samples)
212 | samples = jax.vmap(self.inverse)(samples)[0]
213 | samples = samples * jnp.sqrt(jnp.diag(self.data_cov)) + self.data_mean
214 | return samples
215 |
216 | def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
217 | # TODO: Check whether taking away vmap hurts accuracy.
218 | x = (x - self.data_mean) / jnp.sqrt(jnp.diag(self.data_cov))
219 | y, log_det = self.__call__(x)
220 | log_det = log_det + jax.scipy.stats.multivariate_normal.logpdf(
221 | y, jnp.zeros(self.n_features), jnp.eye(self.n_features)
222 | )
223 | return log_det
224 |
225 | def print_parameters(self):
226 | print("RealNVP parameters:")
227 | print(f"Data mean: {self.data_mean}")
228 | print(f"Data covariance: {self.data_cov}")
229 |
--------------------------------------------------------------------------------
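A quick way to sanity-check the coupling algebra above is a round trip: `forward` applies `(x + t) * exp(s)` only on the masked coordinates, computing `s` and `t` from the untouched ones, so `inverse(forward(x))` should return `x` and the two log-determinants should cancel. A small sketch using the `AffineCoupling` defined in this file:

```python
import jax
import jax.numpy as jnp
from flowMC.resource.model.nf_model.realNVP import AffineCoupling

key = jax.random.PRNGKey(0)
n_features = 4
# Transform the first two coordinates, conditioning on the last two.
mask = jnp.array([1.0, 1.0, 0.0, 0.0])
layer = AffineCoupling(n_features, n_hidden=8, mask=mask, key=key)

x = jax.random.normal(jax.random.PRNGKey(1), (n_features,))
y, log_det_fwd = layer.forward(x)
x_back, log_det_inv = layer.inverse(y)

assert jnp.allclose(x, x_back, atol=1e-5)                       # round trip is identity
assert jnp.allclose(log_det_fwd + log_det_inv, 0.0, atol=1e-6)  # log-dets cancel
```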
/joss/paper.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: 'FlowMC'
3 | tags:
4 | - Python
5 | - Bayesian Inference
6 | - Machine Learning
7 | - Jax
8 |
9 | authors:
10 | - name: Kaze W. K. Wong
11 | orcid: 0000-0001-8432-7788
12 | # equal-contrib: true
13 | affiliation: 1
14 | - name: Marylou Gabrié
15 | orcid: 0000-0002-5989-1018
16 | # equal-contrib: true
17 | affiliation: "2, 3"
18 | - name: Dan Foreman-Mackey
19 | orcid: 0000-0002-9328-5652
20 | affiliation: 1
21 | affiliations:
22 | - name: Center for Computational Astrophysics, Flatiron Institute, New York, NY 10010, US
23 | index: 1
24 | - name: École Polytechnique, Palaiseau 91120, France
25 | index: 2
26 | - name: Center for Computational Mathematics, Flatiron Institute, New York, NY 10010, US
27 | index: 3
28 | date: 30 September 2022
29 | bibliography: paper.bib
30 | ---
31 |
32 | # Summary
33 |
34 | Across scientific fields, the modelling of increasingly complex physical processes requires more flexible models. Yet the estimation of models' parameters becomes more challenging as the dimension of the parameter space grows. A common strategy to explore the parameter space is to sample it with a Markov chain Monte Carlo (MCMC) method, but even MCMC methods can struggle to represent the parameter space fully when relying only on local updates.
35 |
36 | `FlowMC` is a Python library for accelerated Markov Chain Monte Carlo (MCMC) leveraging deep generative modelling, built on top of the machine learning libraries `Jax` and `Flax`. At its core, `FlowMC` uses a local sampler and a learnable global sampler in tandem to efficiently sample posterior distributions with non-trivial geometry, such as multimodal distributions and distributions with local correlations. While multiple chains of the local sampler generate samples over the region of interest in the target parameter space, the package uses these samples to train a normalizing flow model, then uses it to propose global jumps across the parameter space.
37 |
38 | The key features of `FlowMC` are summarized in the following list:
39 |
40 | ## Key features
41 |
42 | - Since `FlowMC` is built on top of `Jax`, it supports gradient-based samplers such as MALA and Hamiltonian Monte Carlo (HMC) through automatic differentiation.
43 | - `FlowMC` uses state-of-the-art normalizing flow models such as rational quadratic splines (RQS) for the global sampler, which are very efficient in capturing local features with relatively short training time.
44 | - Use of accelerators such as GPUs and TPUs is natively supported. The code also supports the use of multiple accelerators with SIMD parallelism.
45 | - By default, just-in-time (JIT) compilation is used to further speed up the sampling process.
46 | - We provide a simple black-box interface for users who want to use `FlowMC` with its default parameters, while at the same time providing an extensive guide explaining the trade-offs in tuning the sampler parameters.
47 |
48 | The tight integration of all the above features makes `FlowMC` a highly performant yet simple-to-use package for statistical inference.
49 |
50 | # Statement of need
51 |
52 | Bayesian inference requires computing expectations with respect to the posterior distribution of the parameters $\theta$ after collecting the observations $\mathcal{D}$. This posterior is given by
53 |
54 | $$
55 | p(\theta|\mathcal{D}) = \frac{\ell(\mathcal{D}|\theta) p_0(\theta)}{Z(\mathcal{D})}
56 | $$
57 |
58 | where $\ell(\mathcal{D}|\theta)$ is the likelihood induced by the model, $p_0(\theta)$ the prior on the parameters and $Z(\mathcal{D})$ the model evidence.
59 | As soon as the dimension of $\theta$ exceeds 3 or 4, it is necessary to resort to a robust sampling strategy such as MCMC. Drastic gains in computational efficiency can be obtained by a careful selection of the MCMC transition kernel, which can be assisted by machine learning methods and libraries.
60 |
61 | ***Gradient-based samplers***
62 | In a high-dimensional space, sampling methods that leverage gradient information of the target distribution have been shown to be efficient, as they propose new samples that are likely to be accepted.
63 | `FlowMC` supports gradient-based samplers such as MALA and HMC through automatic differentiation with `Jax`.
64 | The computational cost of obtaining a gradient in this way is often of the same order as evaluating the target function itself, making gradient-based samplers usually compare favorably to random walks with respect to the efficiency/accuracy trade-off.
65 |
66 | ***Learned transition kernels with normalizing flows***
67 | The posterior distributions of many real-world problems have non-trivial geometry, such as multi-modality and local correlations, which can drastically slow down the convergence of a sampler relying only on gradient information.
68 | To address this problem, we combine a gradient-based sampler with a normalizing flow, a class of generative models [@Papamakarios2019; @Kobyzev2021], which is trained to mimic the posterior distribution and used as a proposal in a Metropolis-Hastings step. Variants of this idea have been explored in the past few years (e.g. [@Albergo2019; @Hoffman2019; @Gabrie2021] and references therein).
69 | Despite the growing interest in these methods, few implementations accessible to non-experts exist, and none of them support GPUs and TPUs. Namely, the version of the NeuTra sampler [@Hoffman2019] available in Pyro [@bingham2019pyro] and the PocoMC package [@Karamanis2022] are both CPU-bound.
70 |
71 | `FlowMC` implements the method proposed in [@Gabrie2021a].
72 | As individual chains explore their local neighborhood through gradient-based MCMC steps, samples from multiple chains can be combined and fed to the normalizing flow so that it learns the global landscape of the posterior distribution. In turn, the chains can be propagated with a Metropolis-Hastings kernel that uses the normalizing flow to propose moves globally in the parameter space. The cycle of local sampling, normalizing flow tuning and global sampling is repeated until convergence of the chains.
73 | The entire algorithm belongs to the class of adaptive MCMC methods [@Andrieu2008], which collect information from the chains' previous steps to simultaneously improve the transition kernel.
74 | The usual MCMC diagnostics can be applied to assess the robustness of the inference results, without worrying about the validation of the normalizing flow model, which is a common problem in deep learning.
75 | If further sampling from the posterior is necessary, the flow trained during a previous run can be reused without further training.
76 | The mathematical details of the method are explained in [@Gabrie2021a].
77 |
78 | ***Use of accelerators***
79 | Modern accelerators such as GPUs and TPUs are designed to execute dense computation in parallel.
80 | Due to the sequential nature of MCMC, a common approach to leveraging accelerators is to run multiple chains in parallel, then combine their results to obtain the posterior distribution.
81 | However, a large portion of the computation comes from the burn-in phase, and simply parallelizing over many chains does not speed up the burn-in.
82 | To fully leverage the benefit of having many chains, ensemble methods such as (Cite) are often implemented.
83 | This comes with its own set of challenges, and implementing this class of methods on accelerators requires careful consideration.
84 |
85 |
86 | Since `FlowMC` is built on top of `Jax`, it supports the use of accelerators by default.
87 | Users can write code the same way they would on a CPU, and the library will automatically detect the available accelerators and use them at run time.
88 | Furthermore, the library leverages just-in-time (JIT) compilation to further improve the performance of the sampler.
89 |
90 | ***Simplicity and extensibility***
91 | Since we anticipate that most users would rather spend their time building models than optimizing the performance of the sampler,
92 | we provide a black-box interface with a few tuning parameters for users who intend to use `FlowMC` without much customization on the sampler side.
93 | The only inputs we require from the users are the log-likelihood function, the log-prior function, and the initial positions of the chains.
94 | On top of the black-box interface, the package offers automatic tuning for the local samplers, in order to reduce the number of hyperparameters the users have to manage.
95 |
96 | While we provide a high-level API for most users, the code is also designed to be extensible. In particular, custom local and global sampling kernels can be integrated into the `sampler` module.
97 |
98 |
99 | # Acknowledgements
100 | M.G. acknowledges support from Hi!Paris.
101 |
102 | # References
103 |
--------------------------------------------------------------------------------
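To make the paper's point about gradient-based proposals concrete, here is a generic MALA step in a few lines of JAX. This is an illustrative sketch, not flowMC's implementation: the proposal drifts along $\nabla \log p$, and the Metropolis-Hastings correction accounts for the asymmetry of the proposal:

```python
import jax
import jax.numpy as jnp

def mala_step(key, x, logpdf, step_size):
    """One Metropolis-adjusted Langevin step (illustrative sketch)."""
    grad_logpdf = jax.grad(logpdf)
    key_prop, key_accept = jax.random.split(key)
    noise = jax.random.normal(key_prop, x.shape)
    x_prop = x + 0.5 * step_size**2 * grad_logpdf(x) + step_size * noise

    def log_q(x_to, x_from):
        # Log density (up to a shared constant) of proposing x_to from x_from.
        mu = x_from + 0.5 * step_size**2 * grad_logpdf(x_from)
        return -jnp.sum((x_to - mu) ** 2) / (2.0 * step_size**2)

    log_ratio = logpdf(x_prop) - logpdf(x) + log_q(x, x_prop) - log_q(x_prop, x)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, x_prop, x)

# Toy usage on a standard normal target.
logpdf = lambda x: -0.5 * jnp.sum(x**2)
x = mala_step(jax.random.PRNGKey(0), jnp.zeros(3), logpdf, step_size=0.5)
```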
/src/flowMC/resource/model/nf_model/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Optional
3 |
4 | import equinox as eqx
5 | import jax
6 | import jax.numpy as jnp
7 | import optax
8 | from jaxtyping import Array, Float, PRNGKeyArray
9 | from tqdm import tqdm, trange
10 | from typing_extensions import Self
11 | from flowMC.resource.base import Resource
12 |
13 |
14 | class NFModel(eqx.Module, Resource):
15 | """Base class for normalizing flow models.
16 |
17 | This is an abstract template that should not be directly used.
18 | """
19 |
20 | _n_features: int
21 | _data_mean: Float[Array, " n_dim"]
22 | _data_cov: Float[Array, " n_dim n_dim"]
23 |
24 | @property
25 | def n_features(self):
26 | return self._n_features
27 |
28 | @property
29 | def data_mean(self):
30 | return jax.lax.stop_gradient(self._data_mean)
31 |
32 | @property
33 | def data_cov(self):
34 | return jax.lax.stop_gradient(jnp.atleast_2d(self._data_cov))
35 |
36 | @abstractmethod
37 | def __init__(self):
38 | raise NotImplementedError
39 |
40 | def __call__(
41 | self, x: Float[Array, " n_dim"]
42 | ) -> tuple[Float[Array, " n_dim"], Float]:
43 | """Forward pass of the model.
44 |
45 | Args:
46 | x (Float[Array, "n_dim"]): Input data.
47 |
48 | Returns:
49 | tuple[Float[Array, "n_dim"], Float]:
50 | Output data and log determinant of the Jacobian.
51 | """
52 | return self.forward(x)
53 |
54 | @abstractmethod
55 | def log_prob(self, x: Float[Array, " n_dim"]) -> Float:
56 | raise NotImplementedError
57 |
58 | @abstractmethod
59 | def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> Array:
60 | raise NotImplementedError
61 |
62 | @abstractmethod
63 | def forward(
64 | self, x: Float[Array, " n_dim"], key: Optional[PRNGKeyArray] = None
65 | ) -> tuple[Float[Array, " n_dim"], Float]:
66 | """Forward pass of the model.
67 |
68 | Args:
69 | x (Float[Array, "n_dim"]): Input data.
70 |
71 | Returns:
72 | tuple[Float[Array, "n_dim"], Float]:
73 | Output data and log determinant of the Jacobian.
74 | """
75 | raise NotImplementedError
76 |
77 | @abstractmethod
78 | def inverse(
79 | self, x: Float[Array, " n_dim"]
80 | ) -> tuple[Float[Array, " n_dim"], Float]:
81 | """Inverse pass of the model.
82 |
83 | Args:
84 | x (Float[Array, "n_dim"]): Input data.
85 |
86 | Returns:
87 | tuple[Float[Array, "n_dim"], Float]:
88 | Output data and log determinant of the Jacobian.
89 | """
90 | raise NotImplementedError
91 |
92 | def save_model(self, path: str):
93 | eqx.tree_serialise_leaves(path + ".eqx", self)
94 |
95 | def load_model(self, path: str) -> Self:
96 | return eqx.tree_deserialise_leaves(path + ".eqx", self)
97 |
98 | @eqx.filter_value_and_grad
99 | def loss_fn(self, x: Float[Array, "n_batch n_dim"]) -> Float:
100 | return -jnp.mean(jax.vmap(self.log_prob)(x))
101 |
102 | @eqx.filter_jit
103 | def train_step(
104 | model: Self,
105 | x: Float[Array, "n_batch n_dim"],
106 | optim: optax.GradientTransformation,
107 | state: optax.OptState,
108 | ) -> tuple[Float[Array, " 1"], Self, optax.OptState]:
109 | """Train for a single step.
110 |
111 | Args:
112 | model (eqx.Model): NF model to train.
113 | x (Array): Training data.
114 |             state (optax.OptState): Optimizer state.
115 |
116 | Returns:
117 | loss (Array): Loss value.
118 | model (eqx.Model): Updated model.
119 | opt_state (optax.OptState): Updated optimizer state.
120 | """
121 | print("Compiling training step")
122 | loss, grads = model.loss_fn(x)
123 | updates, state = optim.update(grads, state, model) # type: ignore
124 | model = eqx.apply_updates(model, updates)
125 | return loss, model, state
126 |
127 | def train_epoch(
128 | self: Self,
129 | rng: PRNGKeyArray,
130 | optim: optax.GradientTransformation,
131 | state: optax.OptState,
132 | data: Float[Array, "n_example n_dim"],
133 |         batch_size: int,
134 | ) -> tuple[Float, Self, optax.OptState]:
135 | """Train for a single epoch."""
136 | value = 1e9
137 | model = self
138 | train_ds_size = len(data)
139 | steps_per_epoch = train_ds_size // batch_size
140 | if steps_per_epoch > 0:
141 | perms = jax.random.permutation(rng, train_ds_size)
142 |
143 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
144 | perms = perms.reshape((steps_per_epoch, batch_size))
145 | for perm in perms:
146 | batch = data[perm, ...]
147 | value, model, state = model.train_step(batch, optim, state)
148 | else:
149 | value, model, state = model.train_step(data, optim, state)
150 |
151 | return value, model, state
152 |
153 | def train(
154 | self: Self,
155 | rng: PRNGKeyArray,
156 | data: Array,
157 | optim: optax.GradientTransformation,
158 | state: optax.OptState,
159 | num_epochs: int,
160 | batch_size: int,
161 | verbose: bool = True,
162 | ) -> tuple[PRNGKeyArray, Self, optax.OptState, Array]:
163 | """Train a normalizing flow model.
164 |
165 |         Args:
166 |             rng (PRNGKeyArray): JAX PRNGKey.
167 |             data (Array): Training data.
168 |             optim (optax.GradientTransformation): Optimizer.
169 |             state (optax.OptState): Optimizer state.
170 |             num_epochs (int): Number of epochs to train for.
171 |             batch_size (int): Batch size.
172 |             verbose (bool): Whether to print progress.
173 |
174 |         Returns:
175 |             rng (PRNGKeyArray): Updated JAX PRNGKey.
176 |             model, state, loss_values: Best model, best optimizer state, and losses.
177 |         """
178 | loss_values = jnp.zeros(num_epochs)
179 | if verbose:
180 | pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
181 | else:
182 | pbar = range(num_epochs)
183 |
184 | best_model = model = self
185 | best_state = state
186 | best_loss = 1e9
187 | model = eqx.tree_at(lambda m: m._data_mean, model, jnp.mean(data, axis=0))
188 | model = eqx.tree_at(lambda m: m._data_cov, model, jnp.cov(data.T))
189 | for epoch in pbar:
190 |             # Use a separate PRNG key to permute the data during shuffling
191 | rng, input_rng = jax.random.split(rng)
192 | # Run an optimization step over a training batch
193 | value, model, state = model.train_epoch(
194 | input_rng, optim, state, data, batch_size
195 | )
196 | loss_values = loss_values.at[epoch].set(value)
197 | if loss_values[epoch] < best_loss:
198 | best_model = model
199 | best_state = state
200 | best_loss = loss_values[epoch]
201 | if verbose:
202 | assert isinstance(pbar, tqdm)
203 | if num_epochs > 10:
204 | if epoch % int(num_epochs / 10) == 0:
205 | pbar.set_description(f"Training NF, current loss: {value:.3f}")
206 | else:
207 |                     if epoch == num_epochs - 1:
208 | pbar.set_description(f"Training NF, current loss: {value:.3f}")
209 |
210 | return rng, best_model, best_state, loss_values
211 |
212 | def to_precision(self, precision: str = "float32"):
213 | """Convert all parameters to a given precision.
214 |
215 | !!! warning
216 | This function is **experimental** and may change in the future.
217 |
218 | Args:
219 | precision (str): Precision to convert to.
220 |
221 | Returns:
222 | eqx.Module: Model with parameters converted to the given precision.
223 | """
224 |
225 | precisions_dict = {
226 | "float16": jnp.float16,
227 | "bfloat16": jnp.bfloat16,
228 | "float32": jnp.float32,
229 | "float64": jnp.float64,
230 | }
231 | try:
232 | precision_format = precisions_dict[precision.lower()]
233 | except KeyError:
234 | raise ValueError(
235 |                 f"Precision {precision} not supported. "
236 |                 f"Choose from {list(precisions_dict.keys())}"
237 | )
238 | dynamic_model, static_model = eqx.partition(self, eqx.is_array)
239 | dynamic_model = jax.tree.map(
240 | lambda x: x.astype(precision_format), dynamic_model
241 | )
242 | return eqx.combine(dynamic_model, static_model)
243 |
244 | save_resource = save_model
245 | load_resource = load_model
246 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # flowMC
2 |
3 | **Normalizing-flow enhanced sampling package for probabilistic inference**
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | > [!WARNING]
14 | > As I have new priorities in my life now, this code base is in very low maintenance mode. This means I am not actively developing the codebase, and only infrequently respond to the issue board. Small PRs that I can finish reviewing within, say, 15 minutes will still be looked at, but anything that changes more than 5-10 long files is unlikely to retain my attention. In the meantime, a group of friends of mine forked the repo here: https://github.com/GW-JAX-Team/flowMC. I am not officially associated with them and have no idea what their plans for the fork are, but feel free to look over there to see if they have something you need.
15 |
16 | > [!WARNING]
17 | > Note that `flowMC` has not reached v1.0.0, meaning the API could be subject to change. In general, the higher level the API, the less likely it is to change. However, intermediate-level APIs such as the resource strategy interface could be subject to major revision for performance reasons.
18 |
19 | 
20 |
21 | flowMC is a Jax-based Python package for normalizing-flow enhanced Markov chain Monte Carlo (MCMC) sampling.
22 | The code is open source under MIT license, and it is under active development.
23 |
24 | - Just-in-time compilation is supported.
25 | - Native support for GPU acceleration.
26 | - Suited for problems with multi-modality.
27 | - Minimal tuning.
28 |
29 | # Installation
30 |
31 | The simplest way to install the package is through pip:
32 |
33 | ```
34 | pip install flowMC
35 | ```
36 |
37 | This will install the latest stable release and its dependencies.
38 | flowMC is based on [Jax](https://github.com/google/jax) and [Equinox](https://github.com/patrick-kidger/equinox).
39 | Installing flowMC will automatically install the versions of Jax and Equinox available on [PyPI](https://pypi.org).
40 | By default this installs the CPU version of Jax. If you have a GPU and want to use it, you can install the GPU version of Jax by running:
41 |
42 | ```
43 | pip install flowMC[cuda]
44 | ```
45 |
46 | If you want to install the latest version of flowMC, you can clone this repo and install it locally:
47 |
48 | ```
49 | git clone https://github.com/kazewong/flowMC.git
50 | cd flowMC
51 | pip install -e .
52 | ```
53 |
54 | There are a couple more extras that you can install with flowMC, including:
55 | - `flowMC[docs]`: Install the documentation dependencies.
56 | - `flowMC[codeqa]`: Install the code quality dependencies.
57 | - `flowMC[visualize]`: Install the visualization dependencies.
58 |
59 | On top of `pip` installation, we highly encourage you to use [uv](https://docs.astral.sh/uv/) to manage your python environment. Once you clone the repo, you can run `uv sync` to create a virtual environment with all the dependencies installed.
60 | # Attribution
61 |
62 | If you used `flowMC` in your research, we would really appreciate it if you could at least cite the following papers:
63 |
64 | ```
65 | @article{Wong:2022xvh,
66 | author = "Wong, Kaze W. k. and Gabri\'e, Marylou and Foreman-Mackey, Daniel",
67 | title = "{flowMC: Normalizing flow enhanced sampling package for probabilistic inference in JAX}",
68 | eprint = "2211.06397",
69 | archivePrefix = "arXiv",
70 | primaryClass = "astro-ph.IM",
71 | doi = "10.21105/joss.05021",
72 | journal = "J. Open Source Softw.",
73 | volume = "8",
74 | number = "83",
75 | pages = "5021",
76 | year = "2023"
77 | }
78 |
79 | @article{Gabrie:2021tlu,
80 | author = "Gabri\'e, Marylou and Rotskoff, Grant M. and Vanden-Eijnden, Eric",
81 | title = "{Adaptive Monte Carlo augmented with normalizing flows}",
82 | eprint = "2105.12603",
83 | archivePrefix = "arXiv",
84 | primaryClass = "physics.data-an",
85 | doi = "10.1073/pnas.2109420119",
86 | journal = "Proc. Nat. Acad. Sci.",
87 | volume = "119",
88 | number = "10",
89 | pages = "e2109420119",
90 | year = "2022"
91 | }
92 | ```
93 |
94 | This will help `flowMC` gain more recognition, and the main benefit *for you* is that the `flowMC` community will grow and the package will be continuously improved. If you believe in the magic of open-source software, please support us by attributing our software in your work.
95 |
96 |
97 | `flowMC` is a Jax implementation of methods described in:
98 | > *Efficient Bayesian Sampling Using Normalizing Flows to Assist Markov Chain Monte Carlo Methods* Gabrié M., Rotskoff G. M., Vanden-Eijnden E. - ICML INNF+ workshop 2021 - [pdf](https://openreview.net/pdf?id=mvtooHbjOwx)
99 |
100 | > *Adaptive Monte Carlo augmented with normalizing flows.*
101 | Gabrié M., Rotskoff G. M., Vanden-Eijnden E. - PNAS 2022 - [doi](https://www.pnas.org/doi/10.1073/pnas.2109420119), [arxiv](https://arxiv.org/abs/2105.12603)
102 |
103 | ## Contributors
104 |
105 | - Hajime Kawahara 🐛
106 | - Daniel Dodd 📖 👀 ⚠️ 🐛
107 | - Matt Graham 🐛 ⚠️ 👀 📖
108 | - Kaze Wong 🐛 📝 💻 🖋 📖 💡 🚇 🚧 🔬 👀 ⚠️ ✅
109 | - Marylou Gabrié 🐛 💻 🖋 📖 💡 🚧 🔬 ⚠️ ✅
110 | - Meesum Qazalbash 💻 🚧
111 | - Thomas Ng 💻 🚧
112 | - Thomas Edwards 🐛 💻