├── .gitignore ├── LICENSE ├── README.md ├── models ├── __init__.py ├── autoregressive.py ├── bijectors.py ├── maf.py └── nsf.py └── notebooks ├── example.ipynb └── sbi.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Siddharth Mishra-Sharma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional normalizing flows in Jax 2 | 3 | Implementation of some common normalizing flow models allowing for a conditioning context using [Jax](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Distrax](https://github.com/deepmind/distrax). 
The following are currently implemented: 4 | - Masked/Inverse Autoregressive Flows (MAF/IAF; [Papamakarios et al, 2017](https://arxiv.org/abs/1705.07057) and [Kingma et al, 2016](https://arxiv.org/abs/1606.04934)) 5 | - Neural Spline Flows (NSF; [Durkan et al, 2019](https://arxiv.org/abs/1906.04032)) 6 | 7 | ## Examples 8 | - See [notebooks/example.ipynb](notebooks/example.ipynb) for a simple usage example. 9 | - See [notebooks/sbi.ipynb](notebooks/sbi.ipynb) for an example application to neural simulation-based inference (conditional posterior estimation). 10 | 11 | ## Basic usage 12 | 13 | ```python 14 | import jax 15 | import jax.numpy as jnp 16 | from models.maf import MaskedAutoregressiveFlow 17 | from models.nsf import NeuralSplineFlow 18 | 19 | n_dim = 2 # Feature dim 20 | n_context = 1 # Context dim 21 | n_samples = 1000 # Number of samples to draw 22 | 23 | ## Define flow model 24 | # model = MaskedAutoregressiveFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128,128], n_transforms=12, activation="tanh", use_random_permutations=False) 25 | model = NeuralSplineFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128,128], n_transforms=8, activation="gelu", n_bins=4) 26 | 27 | ## Initialize model and params 28 | key = jax.random.PRNGKey(42) 29 | x_test = jax.random.uniform(key=key, shape=(64, n_dim)) 30 | context = jax.random.uniform(key=key, shape=(64, n_context)) 31 | params = model.init(key, x_test, context) 32 | 33 | ## Log-prob and sampling 34 | log_prob = model.apply(params, x_test, jnp.ones((x_test.shape[0], n_context))) 35 | samples = model.apply(params, n_samples, key, jnp.ones((n_samples, n_context)), method=model.sample) 36 | ``` -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smsharma/jax-conditional-flows/f132113ce8d88e9cd75ed121876fc25e6af34caa/models/__init__.py -------------------------------------------------------------------------------- 
/models/autoregressive.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from flax.linen.module import compact 5 | import flax.linen as nn 6 | from flax.linen.dtypes import promote_dtype 7 | import distrax 8 | from tensorflow_probability.substrates import jax as tfp 9 | 10 | from typing import Any, List 11 | import dataclasses 12 | 13 | Array = Any 14 | tfb = tfp.bijectors 15 | 16 | 17 | class MaskedDense(nn.Dense): 18 | """A linear transformation applied over the last dimension of the input. 19 | 20 | Attributes: 21 | mask: mask to apply to the weights. 22 | """ 23 | 24 | mask: Array = None 25 | 26 | @compact 27 | def __call__(self, inputs: Array) -> Array: 28 | """Applies a linear transformation to the inputs along the last dimension. 29 | 30 | Args: 31 | inputs: The nd-array to be transformed. 32 | 33 | Returns: 34 | The transformed input. 35 | """ 36 | 37 | kernel = self.param("kernel", self.kernel_init, (jnp.shape(inputs)[-1], self.features), self.param_dtype) 38 | if self.use_bias: 39 | bias = self.param("bias", self.bias_init, (self.features,), self.param_dtype) 40 | else: 41 | bias = None 42 | inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) 43 | 44 | kernel = self.mask * kernel 45 | 46 | y = jax.lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision) 47 | if bias is not None: 48 | y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) 49 | return y 50 | 51 | 52 | class MADE(nn.Module): 53 | n_params: Any = 2 54 | n_context: Any = 0 55 | hidden_dims: List[int] = dataclasses.field(default_factory=lambda: [32, 32]) 56 | activation: str = "tanh" 57 | 58 | @compact 59 | def __call__(self, y: Array, context=None): 60 | if context is not None: 61 | # Stack with context on the left so that the parameters are autoregressively conditioned on it with left-to-right ordering 62 | y = 
jnp.hstack([context, y]) 63 | 64 | broadcast_dims = y.shape[:-1] 65 | 66 | masks = tfb.masked_autoregressive._make_dense_autoregressive_masks(params=2, event_size=self.n_params + self.n_context, hidden_units=self.hidden_dims, input_order="left-to-right") # 2 parameters for scele and shift factors 67 | 68 | for mask in masks[:-1]: 69 | y = MaskedDense(features=mask.shape[-1], mask=mask)(y) 70 | y = getattr(jax.nn, self.activation)(y) 71 | y = MaskedDense(features=masks[-1].shape[-1], mask=masks[-1])(y) 72 | 73 | # Unravel the inputs and parameters 74 | params = y.reshape(broadcast_dims + (self.n_params + self.n_context, 2)) 75 | 76 | # Only take the values corresponding to the parameters of interest for scale and shift; ignore context outputs 77 | params = params[..., self.n_context :, :] 78 | 79 | return params 80 | 81 | 82 | class MAF(distrax.Bijector): 83 | def __init__(self, bijector_fn, unroll_loop=False): 84 | super().__init__(event_ndims_in=1) 85 | 86 | self.autoregressive_fn = bijector_fn 87 | self.unroll_loop = unroll_loop 88 | 89 | def forward_and_log_det(self, x, context): 90 | event_ndims = x.shape[-1] 91 | 92 | if self.unroll_loop: 93 | y = jnp.zeros_like(x) 94 | log_det = None 95 | 96 | for _ in range(event_ndims): 97 | params = self.autoregressive_fn(y, context) 98 | shift, log_scale = params[..., 0], params[..., 1] 99 | y, log_det = distrax.ScalarAffine(shift=shift, log_scale=log_scale).forward_and_log_det(x) 100 | 101 | # TODO: Rewrite with Flax primitives rather than jax.lax; these cannot be mixed 102 | else: 103 | 104 | def update_fn(i, y_and_log_det): 105 | y, log_det = y_and_log_det 106 | params = self.autoregressive_fn(y) 107 | shift, log_scale = params[..., 0], params[..., 1] 108 | y, log_det = distrax.ScalarAffine(shift=shift, log_scale=log_scale).forward_and_log_det(x) 109 | return y, log_det 110 | 111 | y, log_det = jax.lax.fori_loop(0, event_ndims, update_fn, (jnp.zeros_like(x), jnp.zeros_like(x))) 112 | 113 | return y, log_det.sum(-1) 114 
| 115 | def inverse_and_log_det(self, y, context): 116 | params = self.autoregressive_fn(y, context) 117 | shift, log_scale = params[..., 0], params[..., 1] 118 | x, log_det = distrax.ScalarAffine(shift=shift, log_scale=log_scale).inverse_and_log_det(y) 119 | 120 | return x, log_det.sum(-1) 121 | -------------------------------------------------------------------------------- /models/bijectors.py: -------------------------------------------------------------------------------- 1 | # Bijectors with base from Distrax, additionally allowing for a conditioning context 2 | 3 | from typing import Any, List, Tuple, Optional 4 | 5 | import math 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | 10 | import distrax 11 | from distrax._src.bijectors.chain import Chain 12 | from distrax._src.bijectors.inverse import Inverse 13 | from distrax._src.distributions.transformed import Transformed 14 | from distrax._src.bijectors.masked_coupling import MaskedCoupling 15 | from distrax._src.utils import math 16 | 17 | Array = Any 18 | PRNGKey = Array 19 | 20 | 21 | class TransformedConditional(Transformed): 22 | def __init__(self, distribution, flow): 23 | super().__init__(distribution, flow) 24 | 25 | def sample(self, seed: PRNGKey, sample_shape: List[int], context: Optional[Array] = None) -> Array: 26 | x = self.distribution.sample(seed=seed, sample_shape=sample_shape) 27 | y, _ = self.bijector.forward_and_log_det(x, context) 28 | return y 29 | 30 | def log_prob(self, x: Array, context: Optional[Array] = None) -> Array: 31 | x, ildj_y = self.bijector.inverse_and_log_det(x, context) 32 | lp_x = self.distribution.log_prob(x) 33 | lp_y = lp_x + ildj_y 34 | return lp_y 35 | 36 | def sample_and_log_prob(self, seed: PRNGKey, sample_shape: List[int], context: Optional[Array] = None) -> Tuple[Array, Array]: 37 | x, lp_x = self.distribution.sample_and_log_prob(seed=seed, sample_shape=sample_shape) 38 | y, fldj = jax.vmap(self.bijector.forward_and_log_det)(x, context) 39 | lp_y = 
jax.vmap(jnp.subtract)(lp_x, fldj) 40 | return y, lp_y 41 | 42 | 43 | class InverseConditional(Inverse): 44 | def __init__(self, *args, **kwargs): 45 | super().__init__(*args, **kwargs) 46 | 47 | def forward(self, x: Array, context: Optional[Array] = None) -> Array: 48 | return self._bijector.inverse(x, context) 49 | 50 | def inverse(self, y: Array, context: Optional[Array] = None) -> Array: 51 | return self._bijector.forward(y, context) 52 | 53 | def forward_and_log_det(self, x: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 54 | return self._bijector.inverse_and_log_det(x, context) 55 | 56 | def inverse_and_log_det(self, y: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 57 | return self._bijector.forward_and_log_det(y, context) 58 | 59 | 60 | class ChainConditional(Chain): 61 | def __init__(self, *args): 62 | super().__init__(*args) 63 | 64 | def forward(self, x: Array, context: Optional[Array] = None) -> Array: 65 | for bijector in reversed(self._bijectors): 66 | x = bijector.forward(x, context) 67 | return x 68 | 69 | def inverse(self, y: Array, context: Optional[Array] = None) -> Array: 70 | for bijector in self._bijectors: 71 | y = bijector.inverse(y, context) 72 | return y 73 | 74 | def forward_and_log_det(self, x: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 75 | x, log_det = self._bijectors[-1].forward_and_log_det(x, context) 76 | for bijector in reversed(self._bijectors[:-1]): 77 | x, ld = bijector.forward_and_log_det(x, context) 78 | log_det += ld 79 | return x, log_det 80 | 81 | def inverse_and_log_det(self, y: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 82 | y, log_det = self._bijectors[0].inverse_and_log_det(y, context) 83 | for bijector in self._bijectors[1:]: 84 | y, ld = bijector.inverse_and_log_det(y, context) 85 | log_det += ld 86 | return y, log_det 87 | 88 | 89 | class Permute(distrax.Bijector): 90 | def __init__(self, permutation: Array, axis: int = -1): 91 | 92 | 
super().__init__(event_ndims_in=1) 93 | 94 | self.permutation = jnp.array(permutation) 95 | self.axis = axis 96 | 97 | def permute_along_axis(self, x: Array, permutation: Array, axis: int = -1) -> Array: 98 | x = jnp.moveaxis(x, axis, 0) 99 | x = x[permutation, ...] 100 | x = jnp.moveaxis(x, 0, axis) 101 | return x 102 | 103 | def forward_and_log_det(self, x: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 104 | y = self.permute_along_axis(x, self.permutation, axis=self.axis) 105 | return y, jnp.zeros(x.shape[: -self.event_ndims_in]) 106 | 107 | def inverse_and_log_det(self, y: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 108 | inv_permutation = jnp.zeros_like(self.permutation) 109 | inv_permutation = inv_permutation.at[self.permutation].set(jnp.arange(len(self.permutation))) 110 | x = self.permute_along_axis(y, inv_permutation) 111 | return x, jnp.zeros(y.shape[: -self.event_ndims_in]) 112 | 113 | 114 | class MaskedCouplingConditional(MaskedCoupling): 115 | def __init__(self, *args, **kwargs): 116 | super().__init__(*args, **kwargs) 117 | 118 | def forward_and_log_det(self, x: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 119 | self._check_forward_input_shape(x) 120 | masked_x = jnp.where(self._event_mask, x, 0.0) 121 | params = self._conditioner(masked_x, context) 122 | y0, log_d = self._inner_bijector(params).forward_and_log_det(x) 123 | y = jnp.where(self._event_mask, x, y0) 124 | logdet = math.sum_last(jnp.where(self._mask, 0.0, log_d), self._event_ndims - self._inner_event_ndims) 125 | return y, logdet 126 | 127 | def inverse_and_log_det(self, y: Array, context: Optional[Array] = None) -> Tuple[Array, Array]: 128 | self._check_inverse_input_shape(y) 129 | masked_y = jnp.where(self._event_mask, y, 0.0) 130 | params = self._conditioner(masked_y, context) 131 | x0, log_d = self._inner_bijector(params).inverse_and_log_det(y) 132 | x = jnp.where(self._event_mask, y, x0) 133 | logdet = 
math.sum_last(jnp.where(self._mask, 0.0, log_d), self._event_ndims - self._inner_event_ndims) 134 | return x, logdet 135 | -------------------------------------------------------------------------------- /models/maf.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | import dataclasses 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import flax.linen as nn 7 | import distrax 8 | 9 | from models.bijectors import InverseConditional, ChainConditional, TransformedConditional, Permute 10 | from models.autoregressive import MAF, MADE 11 | 12 | Array = Any 13 | 14 | 15 | class MaskedAutoregressiveFlow(nn.Module): 16 | # Note: Does not currently allow for general event shapes 17 | 18 | n_dim: int 19 | n_context: int = 0 20 | n_transforms: int = 4 21 | hidden_dims: List[int] = dataclasses.field(default_factory=lambda: [128, 128]) 22 | activation: str = "gelu" 23 | unroll_loop: bool = True 24 | use_random_permutations: bool = True 25 | rng_key: Array = jax.random.PRNGKey(42) 26 | inverse: bool = False 27 | 28 | def setup(self): 29 | self.made = [MADE(n_params=self.n_dim, n_context=self.n_context, activation=self.activation, hidden_dims=self.hidden_dims, name="made_{}".format(i)) for i in range(self.n_transforms)] 30 | 31 | bijectors = [] 32 | key = self.rng_key 33 | for i in range(self.n_transforms): 34 | # Permutation 35 | if self.use_random_permutations: 36 | permutation = jax.random.choice(key, jnp.arange(self.n_dim), shape=(self.n_dim,), replace=False) 37 | key, _ = jax.random.split(key) 38 | else: 39 | permutation = list(reversed(range(self.n_dim))) 40 | bijectors.append(Permute(permutation)) 41 | 42 | bijector_af = MAF(bijector_fn=self.made[i], unroll_loop=self.unroll_loop) 43 | if self.inverse: 44 | bijector_af = InverseConditional(bijector_af) # Flip forward and reverse directions for IAF 45 | bijectors.append(bijector_af) 46 | 47 | self.bijector = InverseConditional(ChainConditional(bijectors)) 
# Forward direction goes from target to base distribution 48 | self.base_dist = distrax.MultivariateNormalDiag(jnp.zeros(self.n_dim), jnp.ones(self.n_dim)) 49 | 50 | self.flow = TransformedConditional(self.base_dist, self.bijector) 51 | 52 | def __call__(self, x: Array, context: Array = None) -> Array: 53 | return self.flow.log_prob(x, context=context) 54 | 55 | def sample(self, num_samples: int, rng: Array, context: Array = None) -> Array: 56 | return self.flow.sample(seed=rng, sample_shape=(num_samples,), context=context) 57 | -------------------------------------------------------------------------------- /models/nsf.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | import dataclasses 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import flax.linen as nn 8 | from flax.linen.module import compact 9 | import distrax 10 | 11 | from models.bijectors import InverseConditional, ChainConditional, TransformedConditional, MaskedCouplingConditional 12 | 13 | Array = Any 14 | 15 | 16 | class Conditioner(nn.Module): 17 | event_shape: List[int] 18 | context_shape: List[int] 19 | hidden_dims: List[int] 20 | num_bijector_params: int 21 | activation: str = "tanh" 22 | 23 | @compact 24 | def __call__(self, x: Array, context=None): 25 | # Infer batch dims 26 | batch_shape = x.shape[: -len(self.event_shape)] 27 | batch_shape_context = context.shape[: -len(self.context_shape)] 28 | assert batch_shape == batch_shape_context 29 | 30 | # Flatten event dims 31 | x = x.reshape(*batch_shape, -1) 32 | context = context.reshape(*batch_shape, -1) 33 | 34 | x = jnp.hstack([context, x]) 35 | 36 | for hidden_dim in self.hidden_dims: 37 | x = nn.Dense(hidden_dim)(x) 38 | x = getattr(jax.nn, self.activation)(x) 39 | x = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params, kernel_init=jax.nn.initializers.zeros, bias_init=jax.nn.initializers.zeros)(x) 40 | 41 | x = 
x.reshape(*batch_shape, *(tuple(self.event_shape) + (self.num_bijector_params,))) 42 | 43 | return x 44 | 45 | 46 | class NeuralSplineFlow(nn.Module): 47 | """Based on the implementation in the Distrax repo, https://github.com/deepmind/distrax/blob/master/examples/flow.py""" 48 | 49 | n_dim: int 50 | n_context: int = 0 51 | n_transforms: int = 4 52 | hidden_dims: List[int] = dataclasses.field(default_factory=lambda: [128, 128]) 53 | activation: str = "gelu" 54 | n_bins: int = 8 55 | range_min: float = -1.0 56 | range_max: float = 1.0 57 | event_shape: Optional[List[int]] = None 58 | context_shape: Optional[List[int]] = None 59 | 60 | def setup(self): 61 | def bijector_fn(params: Array): 62 | return distrax.RationalQuadraticSpline(params, range_min=self.range_min, range_max=self.range_max) 63 | 64 | # If event shapes are not provided, assume single event and context dimensions 65 | event_shape = (self.n_dim,) if self.event_shape is None else self.event_shape 66 | context_shape = (self.n_context,) if self.context_shape is None else self.context_shape 67 | 68 | # Alternating binary mask 69 | mask = jnp.arange(0, np.prod(event_shape)) % 2 70 | mask = jnp.reshape(mask, event_shape) 71 | mask = mask.astype(bool) 72 | 73 | # Number of parameters for the rational-quadratic spline: 74 | # - `num_bins` bin widths 75 | # - `num_bins` bin heights 76 | # - `num_bins + 1` knot slopes 77 | # for a total of `3 * num_bins + 1` parameters 78 | num_bijector_params = 3 * self.n_bins + 1 79 | 80 | self.conditioner = [Conditioner(event_shape=event_shape, context_shape=context_shape, hidden_dims=self.hidden_dims, num_bijector_params=num_bijector_params, activation=self.activation, name="conditioner_{}".format(i)) for i in range(self.n_transforms)] 81 | 82 | bijectors = [] 83 | for i in range(self.n_transforms): 84 | bijectors.append(MaskedCouplingConditional(mask=mask, bijector=bijector_fn, conditioner=self.conditioner[i])) 85 | mask = jnp.logical_not(mask) # Flip the mask after each 
layer 86 | 87 | self.bijector = InverseConditional(ChainConditional(bijectors)) 88 | self.base_dist = distrax.MultivariateNormalDiag(jnp.zeros(event_shape), jnp.ones(event_shape)) 89 | 90 | self.flow = TransformedConditional(self.base_dist, self.bijector) 91 | 92 | def __call__(self, x: Array, context: Array = None) -> Array: 93 | return self.flow.log_prob(x, context=context) 94 | 95 | def sample(self, num_samples: int, rng: Array, context: Array = None) -> Array: 96 | return self.flow.sample(seed=rng, sample_shape=(num_samples,), context=context) 97 | -------------------------------------------------------------------------------- /notebooks/sbi.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 129, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "from scipy.stats import chi2\n", 12 | "from tqdm import tqdm, trange" 13 | ] 14 | }, 15 | { 16 | "attachments": {}, 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "## Power-law bump simulator" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 130, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "def bump_forward_model(y, amp_s, mu_s, std_s, amp_b, exp_b):\n", 30 | " \"\"\" Forward model for a Gaussian bump (amp_s, mu_s, std_s) on top of a power-law background (amp_b, exp_b).\n", 31 | " \"\"\"\n", 32 | " x_b = amp_b * (y ** exp_b) # Power-law background\n", 33 | " x_s = amp_s * np.exp(-((y - mu_s) ** 2) / (2 * std_s ** 2)) # Gaussian signal\n", 34 | "\n", 35 | " x = x_b + x_s # Total mean signal\n", 36 | "\n", 37 | " return x" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 131, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "" 49 | ] 50 | }, 51 | "execution_count": 131, 52 | "metadata": {}, 53 | 
"output_type": "execute_result" 54 | }, 55 | { 56 | "data": { 57 | "image/png": "", 58 | "text/plain": [ 59 | "
" 60 | ] 61 | }, 62 | "metadata": {}, 63 | "output_type": "display_data" 64 | } 65 | ], 66 | "source": [ 67 | "def poisson_interval(k, alpha=0.32): \n", 68 | " \"\"\" Uses chi2 to get the poisson interval.\n", 69 | " \"\"\"\n", 70 | " a = alpha\n", 71 | " low, high = (chi2.ppf(a/2, 2*k) / 2, chi2.ppf(1-a/2, 2*k + 2) / 2)\n", 72 | " if k == 0: \n", 73 | " low = 0.0\n", 74 | " return k - low, high - k\n", 75 | "\n", 76 | "y = np.linspace(0.1, 1, 50) # Dependent variable\n", 77 | "\n", 78 | "# Mean expected counts\n", 79 | "x_mu = bump_forward_model(y, \n", 80 | " amp_s=50, mu_s=0.8, std_s=0.05, # Signal params\n", 81 | " amp_b=50, exp_b=-0.5) # Background params\n", 82 | "\n", 83 | "# Realized counts\n", 84 | "x = np.random.poisson(x_mu)\n", 85 | "x_err = np.array([poisson_interval(k) for k in x.T]).T\n", 86 | "\n", 87 | "# Plot\n", 88 | "plt.plot(y, x_mu, color='k', ls='--', label=\"Mean expected counts\")\n", 89 | "plt.errorbar(y, x, yerr=x_err, fmt='o', color='k', label=\"Realized counts\")\n", 90 | "\n", 91 | "plt.xlabel(\"$y$\")\n", 92 | "plt.ylabel(\"Counts\")\n", 93 | "\n", 94 | "plt.legend()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 132, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "text/plain": [ 105 | "array([171, 159, 140, 118, 123, 114, 105, 124, 101, 86, 101, 86, 85,\n", 106 | " 81, 110, 101, 75, 80, 82, 68, 70, 63, 79, 71, 57, 75,\n", 107 | " 60, 78, 73, 64, 68, 63, 75, 68, 81, 71, 98, 101, 109,\n", 108 | " 100, 106, 78, 64, 54, 49, 62, 53, 50, 54, 51])" 109 | ] 110 | }, 111 | "execution_count": 132, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "def bump_simulator(thetas, y):\n", 118 | " \"\"\" Simulate samples from the bump forward model given theta = (amp_s, mu_s) and abscissa points y.\n", 119 | " \"\"\"\n", 120 | " amp_s, mu_s = thetas\n", 121 | " std_s, amp_b, exp_b = 0.05, 50, -0.5\n", 122 | " x_mu = bump_forward_model(y, amp_s, mu_s, 
std_s, amp_b, exp_b)\n", 123 | " x = np.random.poisson(x_mu)\n", 124 | " return x\n", 125 | "\n", 126 | "# Test it out\n", 127 | "bump_simulator([50, 0.8], y)" 128 | ] 129 | }, 130 | { 131 | "attachments": {}, 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "## Training data" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 133, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "100%|██████████| 50000/50000 [00:00<00:00, 62598.60it/s]\n" 148 | ] 149 | } 150 | ], 151 | "source": [ 152 | "n_train = 50_000\n", 153 | "\n", 154 | "# Simulate training data\n", 155 | "theta_samples = np.random.uniform(low=[0, 0], high=[200, 1], size=(n_train, 2)) # Parameter proposal\n", 156 | "x_samples = np.array([bump_simulator(theta, y) for theta in tqdm(theta_samples)])" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 134, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# Normalize the data\n", 166 | "x_mean = x_samples.mean()\n", 167 | "x_std = x_samples.std()\n", 168 | "x_samples = (x_samples - x_mean) / x_std\n", 169 | "\n", 170 | "theta_mean = theta_samples.mean(axis=0)\n", 171 | "theta_std = theta_samples.std(axis=0)\n", 172 | "theta_samples = (theta_samples - theta_mean) / theta_std" 173 | ] 174 | }, 175 | { 176 | "attachments": {}, 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "## Neural posterior estimator model" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 188, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "import sys\n", 190 | "sys.path.append(\"../\")\n", 191 | "\n", 192 | "import jax\n", 193 | "import optax\n", 194 | "import flax.linen as nn\n", 195 | "\n", 196 | "from models.maf import MaskedAutoregressiveFlow\n", 197 | "from models.nsf import NeuralSplineFlow" 198 | ] 199 | }, 200 | { 201 | "cell_type": 
"code", 202 | "execution_count": 236, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "class MLP(nn.Module):\n", 207 | " \"\"\" A simple MLP in Flax. This will be the feature extractor.\n", 208 | " \"\"\"\n", 209 | " hidden_dim: int = 32\n", 210 | " out_dim: int = 2\n", 211 | " n_layers: int = 2\n", 212 | "\n", 213 | " @nn.compact\n", 214 | " def __call__(self, x):\n", 215 | " for _ in range(self.n_layers):\n", 216 | " x = nn.Dense(features=self.hidden_dim)(x)\n", 217 | " x = nn.gelu(x)\n", 218 | " x = nn.Dense(features=self.out_dim)(x)\n", 219 | " return x" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 238, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "class NeuralPosteriorEstimator(nn.Module):\n", 229 | " \"\"\" A neural posterior estimator.\n", 230 | " \"\"\"\n", 231 | " d_param: int = 2 # Param dim\n", 232 | " d_hidden: int = 64 # Hidden dim of MLP and flow\n", 233 | " d_summaries: int = 16 # Number of summaries from MLP\n", 234 | " n_layers: int = 4 # Number of layers in MLP\n", 235 | " n_transforms: int = 6 # Number of flow transforms\n", 236 | "\n", 237 | " def setup(self):\n", 238 | " self.featurizer = MLP(hidden_dim=self.d_hidden, out_dim=self.d_summaries, n_layers=self.n_layers)\n", 239 | " self.flow = MaskedAutoregressiveFlow(n_dim=self.d_param, n_context=self.d_summaries, hidden_dims=2 * [self.d_hidden], n_transforms=self.n_transforms, activation=\"tanh\", use_random_permutations=False)\n", 240 | "\n", 241 | " @nn.compact\n", 242 | " def __call__(self, x, theta):\n", 243 | "\n", 244 | " # Pass data through MLP to get summaries\n", 245 | " context =self.featurizer(x)\n", 246 | "\n", 247 | " # Use summaries as context for flow\n", 248 | " log_prob = self.flow(theta, context)\n", 249 | "\n", 250 | " return log_prob" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 239, 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "data": { 260 | "text/plain": [ 261 
"Array([-1.7746974, -1.9767084, -3.3885179, -3.3411727, -2.5347002,\n", 262 | " -2.8596387, -2.7574067, -2.4476223, -2.9018464, -1.8135326,\n", 263 | " -2.7627273, -2.8596678, -2.3531556, -3.523026 , -2.882647 ,\n", 264 | " -1.8737776], dtype=float32)" 265 | ] 266 | }, 267 | "execution_count": 239, 268 | "metadata": {}, 269 | "output_type": "execute_result" 270 | } 271 | ], 272 | "source": [ 273 | "npe = NeuralPosteriorEstimator()\n", 274 | "\n", 275 | "key = jax.random.PRNGKey(0)\n", 276 | "log_prob, params = npe.init_with_output(rngs=key, x=x_samples[:16], theta=theta_samples[:16])\n", 277 | "\n", 278 | "log_prob" 279 | ] 280 | }, 281 | { 282 | "attachments": {}, 283 | "cell_type": "markdown", 284 | "metadata": {}, 285 | "source": [ 286 | "## Training" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 240, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "@jax.jit\n", 296 | "def loss_fn(params, x, theta):\n", 297 | " log_prob = npe.apply(params, x, theta)\n", 298 | " return -log_prob.mean()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 242, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "opt = optax.sgd(learning_rate=1e-4, momentum=0.99, nesterov=True)\n", 308 | "opt_state = opt.init(params)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 243, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "name": "stderr", 318 | "output_type": "stream", 319 | "text": [ 320 | "100%|██████████| 3000/3000 [00:51<00:00, 58.13it/s, val=0.6358305] \n" 321 | ] 322 | } 323 | ], 324 | "source": [ 325 | "n_steps = 3000\n", 326 | "n_batch = 128\n", 327 | "\n", 328 | "with trange(n_steps) as steps:\n", 329 | " for step in steps:\n", 330 | "\n", 331 | " # Draw a random batch of training examples\n", 332 | " key, subkey = jax.random.split(key)\n", 333 | " idx = jax.random.choice(subkey, x_samples.shape[0], shape=(n_batch,))\n", 334 | "\n", 335 | " x_batch, theta_batch = 
x_samples[idx], theta_samples[idx]\n", 336 | " \n", 337 | " loss, grads = jax.value_and_grad(loss_fn)(params, x_batch, theta_batch)\n", 338 | " updates, opt_state = opt.update(grads, opt_state, params)\n", 339 | "\n", 340 | " params = optax.apply_updates(params, updates)\n", 341 | "\n", 342 | " steps.set_postfix(val=loss)" 343 | ] 344 | }, 345 | { 346 | "attachments": {}, 347 | "cell_type": "markdown", 348 | "metadata": {}, 349 | "source": [ 350 | "## Test" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 244, 356 | "metadata": {}, 357 | "outputs": [ 358 | { 359 | "name": "stderr", 360 | "output_type": "stream", 361 | "text": [ 362 | "100%|██████████| 1000/1000 [00:00<00:00, 57510.58it/s]\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "n_test = 1000\n", 368 | "\n", 369 | "# Simulate training data\n", 370 | "theta_samples_test = np.random.uniform(low=[0, 0], high=[200, 1], size=(n_test, 2)) # Parameter proposal\n", 371 | "x_samples_test = np.array([bump_simulator(theta, y) for theta in tqdm(theta_samples_test)])\n", 372 | "\n", 373 | "# Normalize the data\n", 374 | "x_samples_test = (x_samples_test - x_mean) / x_std\n", 375 | "theta_samples_test = (theta_samples_test - theta_mean) / theta_std" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 247, 381 | "metadata": {}, 382 | "outputs": [ 383 | { 384 | "data": { 385 | "image/png": "", 386 | "text/plain": [ 387 | "
" 388 | ] 389 | }, 390 | "metadata": {}, 391 | "output_type": "display_data" 392 | } 393 | ], 394 | "source": [ 395 | "from einops import repeat\n", 396 | "\n", 397 | "def sample_posterior(params, x, key, n_samples=100):\n", 398 | " \"\"\" Draw samples from trained flow conditioned on `x`.\n", 399 | " \"\"\"\n", 400 | " def sample(npe):\n", 401 | " enc = npe.featurizer(x)\n", 402 | " enc = repeat(enc, 'n_summary -> n_samples n_summary', n_samples=n_samples) # Repeat enc to match n_samples\n", 403 | " samples = npe.flow.sample(n_samples, key, enc)\n", 404 | " return samples\n", 405 | " return nn.apply(sample, npe)(params)\n", 406 | "\n", 407 | "\n", 408 | "idx = 3\n", 409 | "samples_post = sample_posterior(params, x_samples_test[idx], key, n_samples=100_000)\n", 410 | "\n", 411 | "import corner\n", 412 | "\n", 413 | "corner.corner(np.array(samples_post) * theta_std + theta_mean, truths=theta_samples_test[idx] * theta_std + theta_mean, labels=[\"amp_s\", \"mu_s\"], show_titles=True, title_kwargs={\"fontsize\": 12});" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [] 422 | } 423 | ], 424 | "metadata": { 425 | "kernelspec": { 426 | "display_name": "torch-mps", 427 | "language": "python", 428 | "name": "python3" 429 | }, 430 | "language_info": { 431 | "codemirror_mode": { 432 | "name": "ipython", 433 | "version": 3 434 | }, 435 | "file_extension": ".py", 436 | "mimetype": "text/x-python", 437 | "name": "python", 438 | "nbconvert_exporter": "python", 439 | "pygments_lexer": "ipython3", 440 | "version": "3.9.13" 441 | }, 442 | "orig_nbformat": 4 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 2 446 | } 447 | --------------------------------------------------------------------------------