├── .gitmodules ├── src └── glasflow │ ├── nets │ ├── __init__.py │ └── mlp.py │ ├── transforms │ ├── __init__.py │ ├── utils.py │ └── coupling.py │ ├── distributions │ ├── __init__.py │ ├── uniform.py │ └── resampled.py │ ├── flows │ ├── __init__.py │ ├── realnvp.py │ ├── nsf.py │ ├── base.py │ ├── coupling.py │ └── autoregressive.py │ ├── utils.py │ └── __init__.py ├── .pre-commit-config.yaml ├── tests ├── test_utils.py ├── test_integration.py ├── conftest.py ├── test_transforms │ ├── test_utils.py │ └── test_coupling_transforms.py ├── test_import.py ├── test_flows │ ├── test_autoregressive.py │ ├── test_realnvp.py │ ├── test_nsf.py │ └── test_coupling.py ├── test_nets │ └── test_mlp.py └── test_distributions │ ├── test_resampled.py │ └── test_uniform.py ├── .github └── workflows │ ├── lint.yml │ ├── publish-to-pypi.yml │ ├── tests_nflows_fork.yml │ ├── tests.yml │ └── integration-tests.yml ├── pyproject.toml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── CHANGELOG.md └── examples ├── moons_nvp_example.ipynb └── conditional_example.ipynb /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/nflows"] 2 | path = submodules/nflows 3 | url = https://github.com/uofgravity/nflows.git 4 | -------------------------------------------------------------------------------- /src/glasflow/nets/__init__.py: -------------------------------------------------------------------------------- 1 | """Basic neural networks.""" 2 | 3 | from .mlp import MLP 4 | 5 | __all__ = [ 6 | "MLP", 7 | ] 8 | -------------------------------------------------------------------------------- /src/glasflow/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .coupling import AffineCouplingTransform 3 | 4 | __all__ = ["AffineCouplingTransform"] 5 | -------------------------------------------------------------------------------- /src/glasflow/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | """Distributions for use with normalising flows""" 2 | 3 | from .resampled import ResampledGaussian 4 | from .uniform import MultivariateUniform 5 | 6 | __all__ = [ 7 | "MultivariateUniform", 8 | "ResampledGaussian", 9 | ] 10 | -------------------------------------------------------------------------------- /src/glasflow/flows/__init__.py: -------------------------------------------------------------------------------- 1 | from .realnvp import RealNVP 2 | from .nsf import CouplingNSF 3 | from .autoregressive import ( 4 | MaskedAffineAutoregressiveFlow, 5 | MaskedPiecewiseLinearAutoregressiveFlow, 6 | MaskedPiecewiseQuadraticAutoregressiveFlow, 7 | MaskedPiecewiseCubicAutoregressiveAutoregressiveFlow, 8 | MaskedPiecewiseRationalQuadraticAutoregressiveFlow, 9 | ) 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | - repo: https://github.com/pycqa/flake8 8 | rev: 6.1.0 9 | hooks: 10 | - id: flake8 11 | additional_dependencies: 12 | - Flake8-pyproject 13 | - repo: https://github.com/kynan/nbstripout 14 | rev: 0.6.0 15 | hooks: 16 | - id: nbstripout 17 | -------------------------------------------------------------------------------- 
/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for the glasflow utilities""" 2 | 3 | from glasflow.utils import get_torch_size 4 | import pytest 5 | import torch 6 | 7 | 8 | def test_get_torch_size_int(): 9 | """Assert an int is converted to a `torch.Size`""" 10 | assert get_torch_size(2) == torch.Size((2,)) 11 | 12 | 13 | @pytest.mark.parametrize("shape", [(2, 2), [2, 2], torch.tensor([2, 2])]) 14 | def test_get_torch_size_iterables(shape): 15 | """Assert iterables still work""" 16 | assert get_torch_size(shape) == torch.Size((2, 2)) 17 | -------------------------------------------------------------------------------- /src/glasflow/utils.py: -------------------------------------------------------------------------------- 1 | """General utilities""" 2 | 3 | from collections.abc import Iterable 4 | from typing import Union 5 | 6 | import torch 7 | 8 | 9 | def get_torch_size(shape: Union[int, Iterable]) -> torch.Size: 10 | """Get a torch size from a more flexible input shape. 11 | 12 | Parameters 13 | ---------- 14 | shape 15 | The shape to convert to an instance of `torch.Size`. 16 | 17 | Returns 18 | ------- 19 | The torch size 20 | """ 21 | if isinstance(shape, int): 22 | shape = (shape,) 23 | elif isinstance(shape, Iterable): 24 | shape = tuple(shape) 25 | return torch.Size(shape) 26 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Overall integration tests. 3 | """ 4 | 5 | import pytest 6 | import torch 7 | 8 | 9 | @pytest.mark.slow_integration_test 10 | def test_flow_training(FlowClass): 11 | """General integration test for all flows""" 12 | n_inputs = 2 13 | flow = FlowClass(n_inputs=n_inputs, n_transforms=2) 14 | 15 | # Draw samples in [0, 1) since some flows only support the unit interval 16 | x = torch.rand(10, n_inputs) 17 | 18 | opt = torch.optim.Adam(flow.parameters()) 19 | 20 | for param in flow.parameters(): 21 | param.grad = None 22 | loss = -flow.log_prob(x).mean() 23 | loss.backward() 24 | opt.step() 25 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | General configuration for tests.
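Defines the `FlowClass` fixture, which parametrises the integration tests over every flow implementation provided by glasflow.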
3 | """ 4 | 5 | import glasflow 6 | import glasflow.flows 7 | import pytest 8 | 9 | 10 | @pytest.fixture( 11 | params=[ 12 | glasflow.flows.RealNVP, 13 | glasflow.flows.CouplingNSF, 14 | glasflow.flows.MaskedAffineAutoregressiveFlow, 15 | glasflow.flows.MaskedPiecewiseCubicAutoregressiveAutoregressiveFlow, 16 | glasflow.flows.MaskedPiecewiseLinearAutoregressiveFlow, 17 | glasflow.flows.MaskedPiecewiseQuadraticAutoregressiveFlow, 18 | glasflow.flows.MaskedPiecewiseRationalQuadraticAutoregressiveFlow, 19 | ] 20 | ) 21 | def FlowClass(request): 22 | return request.param 23 | 24 | 25 | def pytest_sessionstart(): 26 | """Log which nflows backend is being using""" 27 | print(f"glasflow config: USE_NFLOWS={glasflow.USE_NFLOWS}") 28 | -------------------------------------------------------------------------------- /tests/test_transforms/test_utils.py: -------------------------------------------------------------------------------- 1 | from glasflow.transforms.utils import SCALE_ACTIVATIONS, get_scale_activation 2 | import torch 3 | import pytest 4 | 5 | 6 | @pytest.mark.parametrize("name", list(SCALE_ACTIVATIONS.keys())) 7 | def test_get_scale_activation_str(name): 8 | assert get_scale_activation(name) is SCALE_ACTIVATIONS[name] 9 | 10 | 11 | def test_get_scale_activation_fn(): 12 | def fn(x): 13 | return torch.sigmoid(x) 14 | 15 | assert get_scale_activation(fn) is fn 16 | 17 | 18 | def test_get_scale_activation_log(): 19 | fn = get_scale_activation("log10") 20 | inputs = torch.tensor([-torch.inf, 0.0, torch.inf]) 21 | expected = torch.exp(torch.tensor([-10.0, 0.0, 10.0])) 22 | assert torch.equal(fn(inputs), expected) 23 | 24 | 25 | def test_get_scale_activation_invalid(): 26 | with pytest.raises(ValueError, match=r"Unknown activation: .*"): 27 | get_scale_activation("invalid") 28 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [ main, release* ] 6 | pull_request: 7 | branches: [ main, release* ] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | black: 15 | name: Black 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: psf/black@stable 20 | with: 21 | jupyter: true 22 | options: "--check --diff" 23 | version: "24.2" 24 | flake8: 25 | name: Flake8 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Set up Python 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: '3.x' 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | python -m pip install flake8 Flake8-pyproject 37 | - name: Lint with flake8 38 | run: | 39 | python -m flake8 . 
40 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build-n-publish: 9 | name: Build and publish Python 🐍 distributions 📦 to PyPI 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout repository and submodules 13 | uses: actions/checkout@v4 14 | with: 15 | submodules: recursive 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.x" 20 | - name: Install pypa/build 21 | run: >- 22 | python -m 23 | pip install 24 | build 25 | --user 26 | - name: Build a binary wheel and a source tarball 27 | run: >- 28 | python -m 29 | build 30 | --sdist 31 | --wheel 32 | --outdir dist/ 33 | . 34 | - name: Publish distribution 📦 to PyPI 35 | uses: pypa/gh-action-pypi-publish@master 36 | with: 37 | password: ${{ secrets.PYPI_API_TOKEN }} 38 | -------------------------------------------------------------------------------- /.github/workflows/tests_nflows_fork.yml: -------------------------------------------------------------------------------- 1 | name: Check nflows fork 2 | 3 | on: 4 | push: 5 | branches: [ main, release* ] 6 | schedule: 7 | # Run tests at 7:00 UTC everyday 8 | - cron: '0 7 * * *' 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | test-nflows: 16 | name: Test (${{ matrix.os }}) 17 | 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | os: [macOS, Ubuntu, Windows] 22 | runs-on: ${{ matrix.os }}-latest 23 | steps: 24 | - name: Checkout repository and submodules 25 | uses: actions/checkout@v4 26 | with: 27 | submodules: recursive 28 | - name: Set up Python 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: '3.10' 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install -e .[nflows-test] 36 | - name: Downgrade numpy 37 | if: runner.os == 'Windows' 38 | run: | 39 | python -m pip install "numpy<2.0" 40 | - name: Test with pytest 41 | run: | 42 | python -m pytest submodules/nflows 43 | -------------------------------------------------------------------------------- /tests/test_import.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Test importing glasflow""" 3 | from importlib import reload 4 | import os 5 | 6 | import glasflow 7 | import pytest 8 | 9 | 10 | @pytest.fixture(autouse=True) 11 | def reload_glasflow(): 12 | """Make sure glasflow is reloaded after these tests""" 13 | original_version = glasflow.__version__ 14 | use_nflows_default = os.environ.get("GLASFLOW_USE_NFLOWS", None) 15 | yield 16 | if use_nflows_default is None: 17 | os.environ.pop("GLASFLOW_USE_NFLOWS") 18 | else: 19 | os.environ["GLASFLOW_USE_NFLOWS"] = use_nflows_default 20 | assert os.environ.get("GLASFLOW_USE_NFLOWS") == use_nflows_default 21 | reload(glasflow) 22 | assert glasflow.__version__ == original_version 23 | 24 | 25 | @pytest.mark.requires("nflows") 26 | @pytest.mark.integration_test 27 | def test_glasflow_use_external_nflows(caplog): 28 | """Assert the import works with the external version of nflows""" 29 | os.environ["GLASFLOW_USE_NFLOWS"] = "True" 30 | reload(glasflow) 31 | assert "using an externally installed version" in str(caplog.text) 32 | 33 | 34 | @pytest.mark.integration_test
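# The internal fork is the default backend, so unlike the test above this
# case does not require an external nflows installation.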
35 | def test_glasflow_use_internal_nflows(caplog): 36 | """Assert the import works with the internal version of nflows""" 37 | caplog.set_level("INFO") 38 | os.environ["GLASFLOW_USE_NFLOWS"] = "False" 39 | reload(glasflow) 40 | assert "using its own internal version" in str(caplog.text) 41 | -------------------------------------------------------------------------------- /tests/test_flows/test_autoregressive.py: -------------------------------------------------------------------------------- 1 | from glasflow.flows.autoregressive import ( 2 | MaskedAffineAutoregressiveFlow, 3 | MaskedPiecewiseCubicAutoregressiveAutoregressiveFlow, 4 | MaskedPiecewiseLinearAutoregressiveFlow, 5 | MaskedPiecewiseQuadraticAutoregressiveFlow, 6 | MaskedPiecewiseRationalQuadraticAutoregressiveFlow, 7 | ) 8 | import pytest 9 | import torch 10 | 11 | 12 | @pytest.fixture( 13 | params=[ 14 | MaskedAffineAutoregressiveFlow, 15 | MaskedPiecewiseCubicAutoregressiveAutoregressiveFlow, 16 | MaskedPiecewiseLinearAutoregressiveFlow, 17 | MaskedPiecewiseQuadraticAutoregressiveFlow, 18 | MaskedPiecewiseRationalQuadraticAutoregressiveFlow, 19 | ] 20 | ) 21 | def FlowClass(request): 22 | return request.param 23 | 24 | 25 | @pytest.mark.parametrize("use_random_permutations", [False, True]) 26 | @pytest.mark.parametrize("use_random_masks", [False, True]) 27 | def test_random_mask_and_perms( 28 | FlowClass, use_random_permutations, use_random_masks 29 | ): 30 | n = 10 31 | dims = 2 32 | flow = FlowClass( 33 | n_inputs=dims, 34 | n_transforms=4, 35 | use_random_masks=use_random_masks, 36 | use_random_permutations=use_random_permutations, 37 | use_residual_blocks=not use_random_masks, 38 | ) 39 | x = torch.rand(n, dims) 40 | z, logj = flow.forward(x) 41 | assert z.shape == (n, dims) 42 | assert len(logj) == n 43 | -------------------------------------------------------------------------------- /tests/test_nets/test_mlp.py: -------------------------------------------------------------------------------- 1 | """Tests for the MLP submodule""" 2 | 3 | from unittest.mock import create_autospec 4 | from glasflow.nets.mlp import MLP 5 | import pytest 6 | import torch 7 | 8 | 9 | @pytest.fixture 10 | def mlp(): 11 | return create_autospec(MLP) 12 | 13 | 14 | def test_mlp_init_hidden_layers(mlp): 15 | """Make sure the correct hidden layers are added""" 16 | MLP.__init__(mlp, 2, 1, [20, 10]) 17 | 18 | assert mlp._input_layer.in_features == 2 19 | assert mlp._input_layer.out_features == 20 20 | # One layer will be the input layer 21 | assert len(mlp._hidden_layers) == 1 22 | assert mlp._output_layer.in_features == 10 23 | assert mlp._output_layer.out_features == 1 24 | 25 | 26 | @pytest.mark.parametrize("activate_output", [False, True, torch.sigmoid]) 27 | def test_mlp_init_activate_output(mlp, activate_output): 28 | """Assert the different possible inputs are valid""" 29 | MLP.__init__( 30 | mlp, 31 | 1, 32 | 1, 33 | [ 34 | 1, 35 | ], 36 | activate_output=activate_output, 37 | ) 38 | assert mlp._activate_output is bool(activate_output) 39 | 40 | 41 | @pytest.mark.integration_test 42 | def test_update_weights(): 43 | """Check that the weights can be updated.""" 44 | net = MLP(2, 1, [4]) 45 | 46 | x = torch.randn(10, 2) 47 | y = torch.randn(10, 1) 48 | 49 | net.zero_grad() 50 | y_pred = net(x) 51 | loss = torch.mean((y - y_pred) ** 2.0) 52 | loss.backward() 53 | -------------------------------------------------------------------------------- /src/glasflow/transforms/utils.py:
-------------------------------------------------------------------------------- 1 | """Utilities for the transforms submodule""" 2 | 3 | from typing import Callable, Union 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | SCALE_ACTIVATIONS = dict( 9 | nflows=lambda x: torch.sigmoid(x + 2) + 1e-3, 10 | nflows_general=lambda x: (F.softplus(x) + 1e-3).clamp(0, 3), 11 | wide=lambda x: (F.softplus(x) + 1e-3).clamp(0, 3), 12 | log1=lambda x: torch.exp(2 * (torch.sigmoid(x) - 0.5)), 13 | log2=lambda x: torch.exp(4 * (torch.sigmoid(x) - 0.5)), 14 | log3=lambda x: torch.exp(6 * (torch.sigmoid(x) - 0.5)), 15 | ) 16 | 17 | 18 | def get_scale_activation(activation: Union[str, Callable]) -> Callable: 19 | """Get the scale activation function. 20 | 21 | If `activation` is not a string, it is assumed to be a callable and 22 | returned unchanged. If `activation` is a string of the form `logN` where 23 | `N` is a number, then the activation is `exp(2N * (sigmoid(x) - 0.5))`. 24 | In the current implementation this corresponds to estimating the 25 | log-scale. 26 | """ 27 | if not isinstance(activation, str): 28 | return activation 29 | elif activation in SCALE_ACTIVATIONS: 30 | return SCALE_ACTIVATIONS.get(activation) 31 | elif activation.startswith("log"): 32 | scale = float(activation[3:]) 33 | return lambda x: torch.exp(2 * scale * (torch.sigmoid(x) - 0.5)) 34 | else: 35 | raise ValueError( 36 | f"Unknown activation: {activation}. " 37 | f"Choose from: {list(SCALE_ACTIVATIONS.keys())} or logN, " 38 | "where N is a number." 39 | ) 40 | -------------------------------------------------------------------------------- /tests/test_transforms/test_coupling_transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from glasflow.transforms.coupling import AffineCouplingTransform 3 | import pytest 4 | from unittest.mock import create_autospec, patch 5 | 6 | 7 | @pytest.fixture 8 | def affine_transform(): 9 | return create_autospec(AffineCouplingTransform) 10 | 11 | 12 | @patch( 13 | "glasflow.transforms.coupling.get_scale_activation", 14 | return_value="act_fn", 15 | ) 16 | @patch("glasflow.nflows.transforms.coupling.AffineCouplingTransform.__init__") 17 | def test_affine_coupling_init(mock_init, mock_get, affine_transform): 18 | mask = [1, -1] 19 | create_fn = object() 20 | 21 | AffineCouplingTransform.__init__( 22 | affine_transform, 23 | mask, 24 | create_fn, 25 | unconditional_transform=None, 26 | scale_activation="test", 27 | ) 28 | 29 | mock_get.assert_called_once_with("test") 30 | mock_init.assert_called_once_with( 31 | mask, 32 | create_fn, 33 | unconditional_transform=None, 34 | scale_activation="act_fn", 35 | ) 36 | 37 | 38 | def test_affine_coupling_invalid(affine_transform): 39 | """Assert an error is raised for a keyword argument that is not `scale_activation`""" 40 | mask = [1, -1] 41 | create_fn = object() 42 | 43 | with pytest.raises( 44 | TypeError, 45 | match=r"unexpected keyword argument 'invalid_test'", 46 | ): 47 | AffineCouplingTransform.__init__( 48 | affine_transform, 49 | mask, 50 | create_fn, 51 | invalid_test=None, 52 | ) 53 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Unit tests 2 | 3 | on: 4 | push: 5 | branches: [ main, release* ] 6 | pull_request: 7 | branches: [ main, release* ] 8 | schedule: 9 | # Run tests at 7:00 UTC everyday 10 | - cron: '0 7 * * *' 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref
}} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | unittests: 18 | name: Python ${{ matrix.python-version }} (nflows=${{ matrix.use-nflows }}) (${{ matrix.os }}) 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: [macOS, Ubuntu, Windows] 24 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 25 | use-nflows: [True, False] 26 | runs-on: ${{ matrix.os }}-latest 27 | 28 | steps: 29 | - name: Checkout repository and submodules 30 | uses: actions/checkout@v4 31 | with: 32 | submodules: recursive 33 | - name: Set up Python ${{ matrix.python-version }} 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install -e .[dev] 41 | - name: Downgrade numpy 42 | if: runner.os == 'Windows' 43 | run: | 44 | python -m pip install "numpy<2.0" 45 | - name: Optionally install nflows 46 | if: ${{ matrix.use-nflows }} 47 | run: | 48 | pip install nflows 49 | - name: Set nflows version 50 | run: | 51 | echo "GLASFLOW_USE_NFLOWS=${{ matrix.use-nflows }}" >> $GITHUB_ENV 52 | - name: Print environment variables 53 | run: | 54 | env 55 | - name: Test with pytest 56 | run: | 57 | python -m pytest --without-integration --without-slow-integration 58 | -------------------------------------------------------------------------------- /.github/workflows/integration-tests.yml: -------------------------------------------------------------------------------- 1 | name: Integration tests 2 | 3 | on: 4 | push: 5 | branches: [ main, release* ] 6 | pull_request: 7 | branches: [ main, release* ] 8 | schedule: 9 | # Run tests at 7:00 UTC everyday 10 | - cron: '0 7 * * *' 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | integration-tests: 18 | name: Python ${{ matrix.python-version }} (nflows=${{ matrix.use-nflows }}) (${{ matrix.os }}) 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: [macOS, Ubuntu, Windows] 24 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 25 | use-nflows: [True, False] 26 | runs-on: ${{ matrix.os }}-latest 27 | 28 | steps: 29 | - name: Checkout repository and submodules 30 | uses: actions/checkout@v4 31 | with: 32 | submodules: recursive 33 | - name: Set up Python ${{ matrix.python-version }} 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install -e .[dev] 41 | - name: Downgrade numpy 42 | if: runner.os == 'Windows' 43 | run: | 44 | python -m pip install "numpy<2.0" 45 | - name: Optionally install nflows 46 | if: ${{ matrix.use-nflows }} 47 | run: | 48 | pip install nflows 49 | - name: Set nflows version 50 | run: | 51 | echo "GLASFLOW_USE_NFLOWS=${{ matrix.use-nflows }}" >> $GITHUB_ENV 52 | - name: Print environment variables 53 | run: | 54 | env 55 | - name: Run integration tests (fast) 56 | run: | 57 | python -m pytest -m "integration_test" 58 | - name: Run integration tests (slow) 59 | run: | 60 | python -m pytest -m "slow_integration_test" 61 | -------------------------------------------------------------------------------- /tests/test_flows/test_realnvp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from glasflow import USE_NFLOWS 4 | from glasflow.flows import RealNVP 5 | from glasflow.transforms.utils import 
SCALE_ACTIVATIONS 6 | 7 | import pytest 8 | 9 | 10 | @pytest.mark.parametrize("volume_preserving", [False, True]) 11 | def test_coupling_flow_init(volume_preserving): 12 | """Test the initialise method""" 13 | RealNVP(2, 2, volume_preserving=volume_preserving) 14 | 15 | 16 | def test_real_nvp_defaults(): 17 | flow = RealNVP(2, 2) 18 | x = torch.randn(10, 2) 19 | flow.forward(x) 20 | 21 | 22 | @pytest.mark.integration_test 23 | @pytest.mark.flaky(reruns=5) 24 | def test_realnvp_wide_scale(): 25 | """Test RealNVP with the 'wide' scaling method""" 26 | flow = RealNVP(2, 2, scaling_method="wide") 27 | x = torch.randn(10, 2) 28 | z, log_j = flow.forward(x) 29 | x_out, log_j_out = flow.inverse(z) 30 | 31 | assert torch.allclose(x, x_out) 32 | assert torch.allclose(log_j, -log_j_out) 33 | 34 | 35 | @pytest.mark.integration_test 36 | @pytest.mark.flaky(reruns=5) 37 | @pytest.mark.parametrize("scale_activation", list(SCALE_ACTIVATIONS.keys())) 38 | def test_realnvp_scale_activation(scale_activation): 39 | flow = RealNVP(2, 2, scale_activation=scale_activation) 40 | x = torch.randn(10, 2) 41 | z, log_j = flow.forward(x) 42 | x_out, log_j_out = flow.inverse(z) 43 | assert torch.allclose(x, x_out) 44 | assert torch.allclose(log_j, -log_j_out) 45 | 46 | 47 | @pytest.mark.skipif( 48 | USE_NFLOWS is False, reason="Test only applies when using nflows" 49 | ) 50 | def test_affine_coupling_warning_nflows(caplog): 51 | """Assert a warning is printed if using `scale_activation` with nflows""" 52 | RealNVP(2, 2, scale_activation="log1") 53 | assert "Trying without `scale_activation`" in caplog.text 54 | assert ( 55 | "Using affine coupling transform without `scale_activation`" 56 | in caplog.text 57 | ) 58 | -------------------------------------------------------------------------------- /tests/test_distributions/test_resampled.py: -------------------------------------------------------------------------------- 1 | """Tests for the resampled distributions""" 2 | 3 | from glasflow.distributions.resampled import ResampledGaussian 4 | from glasflow.nflows.distributions import StandardNormal 5 | from glasflow.nets.mlp import MLP 6 | import pytest 7 | import torch 8 | 9 | 10 | @pytest.mark.parametrize("trainable", [False, True]) 11 | def test_resampled_gaussian_update_weights(trainable): 12 | """Assert the weights can be updated. 13 | 14 | Test with both fixed and trainable mean and variance. 
15 | """ 16 | dims = 2 17 | acc_fn = MLP( 18 | dims, 19 | 1, 20 | [ 21 | 10, 22 | ], 23 | activate_output=torch.sigmoid, 24 | ) 25 | dist = ResampledGaussian(dims, acc_fn, trainable=trainable) 26 | 27 | x = torch.randn(10, 2) 28 | 29 | dist.zero_grad() 30 | # Loss is normal flow loss 31 | loss = -torch.mean(dist.log_prob(x)) 32 | loss.backward() 33 | dist.estimate_normalisation_constant() 34 | 35 | 36 | def test_log_prob_gaussian(): 37 | """Assert the gaussian log-probability is correct""" 38 | shape = (2,) 39 | dist = ResampledGaussian(shape, lambda x: 1) 40 | ref_dist = StandardNormal(shape) 41 | x = ref_dist.sample(10) 42 | out = dist._log_prob_gaussian(x) 43 | expected = ref_dist.log_prob(x) 44 | torch.equal(out, expected) 45 | 46 | 47 | def test_sample(): 48 | """Assert samples are drawn with the correct shape""" 49 | dims = 2 50 | n = 10 51 | dist = ResampledGaussian( 52 | dims, lambda x: torch.ones(len(x), 1), trainable=False 53 | ) 54 | out = dist._sample(n) 55 | assert out.shape == (n, dims) 56 | 57 | 58 | def test_log_prob(): 59 | """Test the log-prob method""" 60 | dims = 2 61 | n = 10 62 | x = torch.randn(n, 2) 63 | dist = ResampledGaussian( 64 | dims, lambda x: torch.ones(len(x), 1), trainable=False 65 | ) 66 | out = dist.log_prob(x) 67 | assert len(out) == n 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64.0.3", "wheel", "setuptools_scm[toml]"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "glasflow" 7 | description = "Normalising flows using nflows" 8 | authors = [ 9 | {name = "Michael J. Williams", email = "michaeljw1@googlemail.com"}, 10 | {name = "Federico Stachurski", email = "f.stachurski.1@research.gla.ac.uk"}, 11 | {name = "Jordan McGinn"}, 12 | {name = "John Veitch", email = "john.veitch@glasgow.ac.uk"}, 13 | ] 14 | readme = "README.md" 15 | requires-python = ">=3.8" 16 | license = {text = "MIT"} 17 | classifiers = [ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ] 22 | keywords = [ 23 | "normalising flows", 24 | "normalizing flows", 25 | "machine learning", 26 | ] 27 | dependencies = [ 28 | "numpy", 29 | "torch>=1.11.0", 30 | ] 31 | 32 | dynamic = [ 33 | "version", 34 | ] 35 | [project.urls] 36 | "Homepage" = "https://github.com/uofgravity/glasflow" 37 | 38 | [project.optional-dependencies] 39 | nflows = ["nflows"] 40 | nflows-test = [ 41 | "pytest", 42 | "pytest-rerunfailures", 43 | "torchtestcase", 44 | "UMNN", 45 | ] 46 | dev = [ 47 | "black[jupyter]", 48 | "pre-commit", 49 | "pytest", 50 | "pytest-cov", 51 | "pytest-integration", 52 | "pytest-requires", 53 | "pytest-rerunfailures", 54 | ] 55 | examples = [ 56 | "ipykernel", 57 | "matplotlib", 58 | "scikit-learn", 59 | "scipy", 60 | "seaborn", 61 | ] 62 | 63 | [tool.setuptools.package-dir] 64 | glasflow = "src/glasflow" 65 | "glasflow.nflows" = "submodules/nflows/nflows" 66 | 67 | [tool.setuptools_scm] 68 | 69 | [tool.black] 70 | line-length = 79 71 | target-version = ["py38", "py39", "py310", "py311", "py312"] 72 | extend-exclude = "submodules" 73 | 74 | [tool.flake8] 75 | exclude = [ 76 | "submodules", 77 | "build", 78 | ] 79 | ignore = ["E203", "E266", "E501", "W503", "F403", "F401"] 80 | max-line-length = 79 81 | max-complexity = 18 82 | select = ["B", "C", "E", "F", "W", "T4", "B9"] 83 | 84 | 
[tool.pytest.ini_options] 85 | addopts = [ 86 | "--import-mode=importlib", 87 | ] 88 | testpaths = [ 89 | "tests" 90 | ] 91 | -------------------------------------------------------------------------------- /src/glasflow/distributions/uniform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multivariate uniform distribution. 3 | """ 4 | 5 | from typing import Union 6 | 7 | from glasflow.nflows.distributions import Distribution 8 | import torch 9 | 10 | 11 | class MultivariateUniform(Distribution): 12 | def __init__( 13 | self, low: Union[torch.Tensor, float], high: Union[torch.Tensor, float] 14 | ): 15 | """Multivariate uniform distribution defined on a box. 16 | 17 | Based on this implementation: \ 18 | https://github.com/bayesiains/nflows/pull/17 but with fixes for 19 | CUDA support. 20 | 21 | Does not support conditional inputs. 22 | 23 | Parameters 24 | ---------- 25 | low : Union[torch.Tensor, float] 26 | Lower range (inclusive). 27 | high : Union[torch.Tensor, float] 28 | Upper range (exclusive). 29 | """ 30 | super().__init__() 31 | 32 | low, high = map(torch.as_tensor, [low, high]) 33 | 34 | if low.shape != high.shape: 35 | raise ValueError("low and high are not the same shape") 36 | 37 | if not (low < high).all(): 38 | raise ValueError("low has elements that are larger than high") 39 | 40 | self._shape = low.shape 41 | self.register_buffer("low", low) 42 | self.register_buffer("high", high) 43 | self.register_buffer( 44 | "_log_prob_value", -torch.sum(torch.log(high - low)) 45 | ) 46 | 47 | def _log_prob(self, inputs, context): 48 | if context is not None: 49 | raise NotImplementedError( 50 | "Context is not supported by MultivariateUniform!" 51 | ) 52 | lb = self.low.le(inputs).type_as(self.low).prod(-1) 53 | ub = self.high.gt(inputs).type_as(self.low).prod(-1) 54 | return torch.log(lb.mul(ub)) + self._log_prob_value 55 | 56 | def _sample(self, num_samples, context): 57 | if context is not None: 58 | raise NotImplementedError( 59 | "Context is not supported by MultivariateUniform!" 60 | ) 61 | low_expanded = self.low.expand(num_samples, *self._shape) 62 | high_expanded = self.high.expand(num_samples, *self._shape) 63 | samples = low_expanded + torch.rand( 64 | num_samples, *self._shape, device=self.low.device 65 | ) * (high_expanded - low_expanded) 66 | return samples 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to glasflow 2 | 3 | ## Installation 4 | 5 | To contribute to `glasflow`, clone the repo and install it along with the additional development dependencies: 6 | 7 | ```shell 8 | pip install -e .[dev,nflows] 9 | ``` 10 | 11 | **Note:** Make sure the submodules are up-to-date; see the [git documentation on submodules](https://git-scm.com/book/en/v2/Git-Tools-Submodules) for details. 12 | 13 | **Note:** because of a long-standing bug in `setuptools` ([#230](https://github.com/pypa/setuptools/issues/230)), editable installs (`pip install -e`) are not supported by older versions of `setuptools`. Editable installs require `setuptools>=64.0.3`. You may also require an up-to-date version of `pip`. 14 | 15 | ## Format checking 16 | 17 | We use [pre-commit](https://pre-commit.com/) to re-format code using `black` and check the quality of code using `flake8` before committing. 18 | 19 | This requires some setup: 20 | 21 | ```shell 22 | pip install pre-commit # Should already be installed 23 | cd glasflow 24 | pre-commit install 25 | ``` 26 | 27 | Now when you run `git commit`, `pre-commit` will run a series of checks. Some checks will automatically change the code and others will print warnings that you must address before re-committing. 28 | 29 | ## Testing glasflow 30 | 31 | When contributing code to `glasflow` please ensure that you also contribute corresponding unit tests and integration tests where applicable. We test `glasflow` using `pytest` and strive to test all of the core functionality in `glasflow`.
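As a rough sketch, a unit test in this style asserts one well-defined behaviour (the flow class, shapes, and test name below are purely illustrative):

```python
import torch

from glasflow import RealNVP


def test_forward_returns_correct_shapes():
    """Assert the forward pass returns outputs with the expected shapes."""
    flow = RealNVP(n_inputs=2, n_transforms=2)
    x = torch.randn(10, 2)
    z, log_j = flow.forward(x)
    assert z.shape == (10, 2)
    assert log_j.shape == (10,)
```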
Tests should be contained within the `tests` directory and follow the naming convention `test_<name>.py`. We also welcome improvements to the existing tests and testing infrastructure. 32 | 33 | The tests can be run from the root directory using 34 | 35 | ```console 36 | pytest 37 | ``` 38 | 39 | Specific tests can be run using 40 | 41 | ```console 42 | pytest tests/test_<name>.py 43 | ``` 44 | 45 | **Note:** the configuration for `pytest` is pulled from `pyproject.toml`. 46 | 47 | See the `pytest` [documentation](https://docs.pytest.org/) for further details on how to write tests using `pytest`. 48 | 49 | ### Testing with nflows 50 | 51 | The continuous integration for glasflow runs the tests with the environment variable `GLASFLOW_USE_NFLOWS` set to `True` and to `False` independently. We recommend running the test suite with both values prior to opening a pull request. For example: 52 | 53 | ```console 54 | $ export GLASFLOW_USE_NFLOWS=false 55 | $ pytest 56 | ... 57 | $ export GLASFLOW_USE_NFLOWS=true 58 | $ pytest 59 | ... 60 | ``` 61 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # glasflow 2 | 3 | MIT License 4 | 5 | Copyright (c) 2021-2022 Jordan McGinn, Federico Stachurski, John Veitch, Michael J. Williams 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | # nflows 26 | 27 | This software includes a fork of [nflows](https://github.com/uofgravity/nflows) 28 | which is licensed under an MIT License: 29 | 30 | MIT License 31 | 32 | Copyright (c) 2020 Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios 33 | 34 | Permission is hereby granted, free of charge, to any person obtaining a copy 35 | of this software and associated documentation files (the "Software"), to deal 36 | in the Software without restriction, including without limitation the rights 37 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 38 | copies of the Software, and to permit persons to whom the Software is 39 | furnished to do so, subject to the following conditions: 40 | 41 | The above copyright notice and this permission notice shall be included in all 42 | copies or substantial portions of the Software. 43 | 44 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 45 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 46 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 47 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 48 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 49 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 50 | SOFTWARE. 51 | -------------------------------------------------------------------------------- /src/glasflow/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | glasflow 4 | -------- 5 | 6 | Implementations of normalising flows in PyTorch based on nflows. 7 | 8 | Code is hosted at: https://github.com/uofgravity/glasflow 9 | 10 | nflows: https://github.com/bayesiains/nflows 11 | """ 12 | import importlib.util 13 | import logging 14 | import os 15 | import pkgutil 16 | import sys 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def _import_submodules(module): 22 | """Recursively import all submodules from a module. 23 | 24 | Imports all of the submodules and registers them as 25 | glasflow.<module>. 26 | 27 | Based on https://stackoverflow.com/a/25562415 28 | """ 29 | for _, name, is_pkg in pkgutil.walk_packages(module.__path__): 30 | full_name = module.__name__ + "." + name 31 | submodule = importlib.import_module(full_name) 32 | sys.modules["glasflow." + full_name] = submodule 33 | if is_pkg: 34 | _import_submodules(submodule) 35 | 36 | 37 | if "nflows" in sys.modules or importlib.util.find_spec("nflows"): 38 | NFLOWS_INSTALLED = True 39 | else: 40 | NFLOWS_INSTALLED = False 41 | 42 | USE_NFLOWS = os.environ.get("GLASFLOW_USE_NFLOWS", "False").lower() in [ 43 | "true", 44 | "1", 45 | ] 46 | if USE_NFLOWS: 47 | logger.warning( 48 | "glasflow is using an externally installed version of nflows" 49 | ) 50 | if not NFLOWS_INSTALLED: 51 | raise RuntimeError( 52 | "nflows is not installed. Set the environment variable " 53 | "`GLASFLOW_USE_NFLOWS=False` to use the included fork of nflows." 54 | ) 55 | # Register glasflow.nflows so it points to nflows 56 | import nflows 57 | 58 | sys.modules["glasflow.nflows"] = nflows 59 | # Register all submodules in nflows so glasflow.nflows.<module> points to 60 | # the nflows installation 61 | _import_submodules(nflows) 62 | else: 63 | logger.info("glasflow is using its own internal version of nflows") 64 | 65 | from .flows import ( # noqa 66 | CouplingNSF, 67 | RealNVP, 68 | ) 69 | 70 | try: 71 | from importlib.metadata import version, PackageNotFoundError 72 | except ImportError: # for Python < 3.8 73 | from importlib_metadata import version, PackageNotFoundError 74 | 75 | try: 76 | __version__ = version(__name__) 77 | if USE_NFLOWS: 78 | __version__ += "+nflows-ext" 79 | else: 80 | __version__ += "+nflows-int" 81 | except PackageNotFoundError: 82 | # package is not installed 83 | pass 84 | 85 | 86 | __all__ = [ 87 | "CouplingNSF", 88 | "RealNVP", 89 | ] 90 | -------------------------------------------------------------------------------- /src/glasflow/transforms/coupling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Alternative implementations of coupling transforms""" 3 | import logging 4 | import warnings 5 | 6 | from glasflow.nflows.transforms.coupling import ( 7 | AffineCouplingTransform as BaseAffineCouplingTransform, 8 | ) 9 | import torch.nn.functional as F 10 | 11 | from .utils import get_scale_activation 12 | from ..
 import USE_NFLOWS 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class AffineCouplingTransform(BaseAffineCouplingTransform): 19 | """Modified affine coupling transform that includes predefined scale 20 | activations. 21 | 22 | Adds the option to specify `scale_activation` as a string, which is 23 | passed to `get_scale_activation` to get the corresponding function. A 24 | callable can also be passed directly instead of a string. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | mask, 30 | transform_net_create_fn, 31 | unconditional_transform=None, 32 | scaling_method=None, 33 | scale_activation="nflows_general", 34 | **kwargs, 35 | ): 36 | if scaling_method is not None: 37 | warnings.warn( 38 | ( 39 | "scaling_method is deprecated and will be removed in a " 40 | "future release. Use `scale_activation` instead." 41 | ), 42 | FutureWarning, 43 | ) 44 | scale_activation = scaling_method 45 | 46 | scale_activation = get_scale_activation(scale_activation) 47 | 48 | try: 49 | super().__init__( 50 | mask, 51 | transform_net_create_fn, 52 | unconditional_transform=unconditional_transform, 53 | scale_activation=scale_activation, 54 | **kwargs, 55 | ) 56 | except TypeError as e: 57 | if USE_NFLOWS: 58 | logger.error( 59 | ( 60 | f"Could not initialise transform with error: {e}. " 61 | "The version of `nflows` being used may not support " 62 | "`scale_activation`. Trying without `scale_activation`." 63 | " Full traceback:" 64 | ), 65 | exc_info=True, 66 | ) 67 | super().__init__( 68 | mask, 69 | transform_net_create_fn, 70 | unconditional_transform=unconditional_transform, 71 | **kwargs, 72 | ) 73 | logger.warning( 74 | "Using affine coupling transform without " 75 | "`scale_activation`, this is not recommended!" 76 | ) 77 | else: 78 | raise e 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7108558.svg)](https://doi.org/10.5281/zenodo.7108558) 2 | [![PyPI](https://img.shields.io/pypi/v/glasflow)](https://pypi.org/project/glasflow/) 3 | [![Conda Version](https://img.shields.io/conda/vn/conda-forge/glasflow.svg)](https://anaconda.org/conda-forge/glasflow) 4 | 5 | # Glasflow 6 | 7 | glasflow is a Python library containing a collection of [Normalizing flows](https://arxiv.org/abs/1912.02762) using [PyTorch](https://pytorch.org). It builds upon [nflows](https://github.com/bayesiains/nflows). 8 | 9 | ## Installation 10 | 11 | glasflow is available to install via `pip`: 12 | 13 | ```shell 14 | pip install glasflow 15 | ``` 16 | 17 | or via `conda`: 18 | 19 | ```shell 20 | conda install glasflow -c conda-forge 21 | ``` 22 | 23 | ## PyTorch 24 | 25 | **Important:** `glasflow` supports using CUDA devices but it is not a requirement and in most use cases it provides little to no benefit. 26 | 27 | By default the version of PyTorch installed by `pip` or `conda` will not necessarily match the drivers on your system. To install a different version with the correct CUDA support, see the PyTorch homepage for instructions: <https://pytorch.org/>. 28 | 29 | ## Usage 30 | 31 | To define a RealNVP flow: 32 | 33 | ```python 34 | from glasflow import RealNVP 35 | 36 | # define RealNVP flow. Change hyperparameters as necessary.
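# n_inputs sets the dimensionality of the data, n_transforms the number of
# coupling transforms and n_neurons the width of each residual block.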
37 | flow = RealNVP( 38 | n_inputs=2, 39 | n_transforms=5, 40 | n_neurons=32, 41 | batch_norm_between_transforms=True 42 | ) 43 | ``` 44 | 45 | Please see [glasflow/examples](https://github.com/uofgravity/glasflow/tree/main/examples) for a typical training regime example. 46 | 47 | ## nflows 48 | 49 | glasflow uses a fork of nflows which is included as a submodule in glasflow and can be imported as follows: 50 | 51 | ```python 52 | import glasflow.nflows as nflows 53 | ``` 54 | 55 | It contains various bug fixes which, at the time of writing, are not included in a release of `nflows`. 56 | 57 | ### Using standard nflows 58 | 59 | There is also the option to use an independent installation of nflows (if it is installed) by setting an environment variable. 60 | 61 | ```shell 62 | export GLASFLOW_USE_NFLOWS=True 63 | ``` 64 | 65 | After setting this variable `glasflow.nflows` will point to the version of nflows installed in the current Python environment. 66 | 67 | **Note:** this must be set prior to importing glasflow. 68 | 69 | ## Contributing 70 | 71 | Pull requests are welcome. You can review the contribution guidelines [here](https://github.com/uofgravity/glasflow/blob/main/CONTRIBUTING.md). For major changes, please open an issue first to discuss what you would like to change. 72 | 73 | ## Citing 74 | 75 | If you use glasflow in your work please cite [our DOI](https://doi.org/10.5281/zenodo.7108558). We also recommend citing nflows, following the guidelines in the [nflows readme](https://github.com/uofgravity/nflows#citing-nflows). 76 | -------------------------------------------------------------------------------- /tests/test_flows/test_nsf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Tests for neural spline flows.
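Covers initialisation, invertibility and the uniform latent space option.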
4 | """ 5 | from unittest.mock import MagicMock, patch 6 | 7 | import numpy as np 8 | import pytest 9 | import torch 10 | 11 | from glasflow.flows import CouplingNSF 12 | from glasflow.distributions import MultivariateUniform 13 | 14 | 15 | @pytest.mark.parametrize("num_bins", [4, 10]) 16 | def test_coupling_nsf_init(num_bins): 17 | """Test the initialise method""" 18 | CouplingNSF(2, 2, num_bins=num_bins) 19 | 20 | 21 | def test_init_uniform_distribution(): 22 | """Assert a uniform distribution is created and used""" 23 | expected_low = torch.zeros(2) 24 | expected_high = torch.ones(2) 25 | dist = MagicMock(spec=MultivariateUniform) 26 | 27 | with patch( 28 | "glasflow.distributions.MultivariateUniform", return_value=dist 29 | ) as mock_dist, patch( 30 | "glasflow.flows.nsf.CouplingFlow.__init__" 31 | ) as mock_init: 32 | CouplingNSF( 33 | n_inputs=2, 34 | n_transforms=2, 35 | distribution="uniform", 36 | tail_bound=10.0, 37 | tail_type="linear", 38 | ) 39 | 40 | dist_kwargs = mock_dist.call_args[1] 41 | assert torch.equal(dist_kwargs["low"], expected_low) 42 | assert torch.equal(dist_kwargs["high"], expected_high) 43 | 44 | kwargs = mock_init.call_args[1] 45 | assert kwargs["tail_bound"] == 1.0 46 | assert kwargs["tails"] is None 47 | assert kwargs["distribution"] is dist 48 | assert kwargs["batch_norm_between_transforms"] is False 49 | 50 | 51 | @pytest.mark.integration_test 52 | @pytest.mark.flaky(reruns=5) 53 | def test_coupling_nsf_forward_inverse(): 54 | """Make sure the flow is invertible""" 55 | x = torch.randn(10, 2) 56 | flow = CouplingNSF(2, 2) 57 | 58 | with torch.no_grad(): 59 | x_prime, log_prob = flow.forward(x) 60 | x_out, log_prob_inv = flow.inverse(x_prime) 61 | 62 | np.testing.assert_array_almost_equal(x, x_out) 63 | np.testing.assert_array_almost_equal(log_prob, -log_prob_inv) 64 | 65 | 66 | @pytest.mark.integration_test 67 | @pytest.mark.flaky(reruns=5) 68 | def test_coupling_nsf_uniform(): 69 | """Integration test to make sure core functions work as intended with the 70 | uniform latent space. 
71 | """ 72 | flow = CouplingNSF(2, 2, distribution="uniform") 73 | 74 | x = torch.rand(100, 2) 75 | 76 | with torch.no_grad(): 77 | z, log_j = flow.forward(x) 78 | x_inv, log_j_inv = flow.inverse(z) 79 | log_prob = flow.log_prob(x) 80 | x_out = flow.sample(10) 81 | 82 | # n_dims * log(1 - 0) = 0 83 | # So just Jacobian 84 | expected_log_prob = log_j.numpy() 85 | 86 | np.testing.assert_array_almost_equal(x, x_inv) 87 | np.testing.assert_array_almost_equal(log_j, -log_j_inv) 88 | np.testing.assert_array_equal(log_prob, expected_log_prob) 89 | assert x_out.shape == (10, 2) 90 | -------------------------------------------------------------------------------- /tests/test_distributions/test_uniform.py: -------------------------------------------------------------------------------- 1 | "Tests for the Multivariate uniform distribution" 2 | from unittest.mock import create_autospec 3 | 4 | import pytest 5 | import torch 6 | 7 | from glasflow.distributions.uniform import MultivariateUniform 8 | 9 | 10 | @pytest.fixture() 11 | def dist(): 12 | return create_autospec(MultivariateUniform) 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "low, high", [(0.0, 1.0), (torch.zeros(2), torch.ones(2))] 17 | ) 18 | def test_init(low, high): 19 | """Assert different input types are valid""" 20 | dist = MultivariateUniform(low, high) 21 | assert dist._log_prob_value is not None 22 | 23 | 24 | def test_init_invalid_shapes(): 25 | """Assert an error is raised if the shapes are different""" 26 | with pytest.raises( 27 | ValueError, match=r"low and high are not the same shape" 28 | ): 29 | MultivariateUniform(torch.tensor(0), torch.tensor([1, 1])) 30 | 31 | 32 | def test_init_invalid_bounds(): 33 | """Assert an error is raised if low !< high""" 34 | with pytest.raises( 35 | ValueError, match=r"low has elements that are larger than high" 36 | ): 37 | MultivariateUniform(torch.tensor([1, 0]), torch.tensor([1, 1])) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "inputs, target", 42 | [ 43 | (torch.tensor([0.5, 0.5]), -torch.log(torch.tensor(4.0))), 44 | (torch.tensor([-1.0, -1.0]), -torch.tensor(torch.inf)), 45 | (torch.tensor([-1.0, 0.5]), -torch.tensor(torch.inf)), 46 | ( 47 | torch.tensor([[-1.0, 0.5], [1.0, 1.0]]), 48 | torch.log(torch.tensor([0.0, 1.0 / 4.0])), 49 | ), 50 | ], 51 | ) 52 | def test_log_prob(dist, inputs, target): 53 | """Assert log prob returns the correct value""" 54 | dist.low = torch.zeros(2) 55 | dist.high = 2 * torch.ones(2) 56 | dist._log_prob_value = -torch.log(torch.tensor(4.0)) 57 | 58 | out = MultivariateUniform._log_prob(dist, inputs, None) 59 | assert torch.equal(out, target) 60 | 61 | 62 | @pytest.mark.parametrize("num_samples", [1, 10]) 63 | def test_sample(dist, num_samples): 64 | """Assert the correct number of samples are returned and they're in the 65 | correct range. 
66 | """ 67 | n_dims = 2 68 | dist.low = torch.zeros(n_dims) 69 | dist.high = 2 * torch.ones(n_dims) 70 | dist._shape = dist.low.shape 71 | 72 | samples = MultivariateUniform._sample(dist, num_samples, None) 73 | 74 | assert samples.shape == (num_samples, n_dims) 75 | assert (samples < dist.high).all() 76 | assert (samples >= dist.low).all() 77 | 78 | 79 | def test_context_error_log_prob(dist): 80 | """Assert an error is raised if log_prob is called and is not None.""" 81 | with pytest.raises( 82 | NotImplementedError, 83 | match="Context is not supported by MultidimensionalUniform!", 84 | ): 85 | MultivariateUniform._log_prob(dist, torch.ones(1), torch.ones(1)) 86 | 87 | 88 | def test_context_error_sample(dist): 89 | """Assert an error is raised if sample is called and context is not None.""" 90 | with pytest.raises( 91 | NotImplementedError, 92 | match="Context is not supported by MultidimensionalUniform!", 93 | ): 94 | MultivariateUniform._sample(dist, 10, torch.ones(1)) 95 | -------------------------------------------------------------------------------- /src/glasflow/flows/realnvp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Implementation of RealNVP. 4 | """ 5 | from glasflow.nflows.transforms.coupling import ( 6 | AdditiveCouplingTransform, 7 | ) 8 | import torch.nn.functional as F 9 | from .coupling import CouplingFlow 10 | from ..transforms import AffineCouplingTransform 11 | 12 | 13 | class RealNVP(CouplingFlow): 14 | """Implementation of Real Non-Volume Preserving Flows. 15 | 16 | See: https://arxiv.org/abs/1605.08803 17 | 18 | Parameters 19 | ---------- 20 | n_inputs : int 21 | Number of inputs 22 | n_transforms : int 23 | Number of transforms 24 | n_conditional_inputs: int 25 | Number of conditionals inputs 26 | n_neurons : int 27 | Number of neurons per residual block in each transform 28 | n_blocks_per_transform : int 29 | Number of residual blocks per transform 30 | batch_norm_within_blocks : bool 31 | Enable batch normalisation within each residual block 32 | batch_norm_between_transforms : bool 33 | Enable batch norm between transforms 34 | activation : function 35 | Activation function to use. Defaults to ReLU 36 | dropout_probability : float 37 | Amount of dropout to apply. 0 being no dropout and 1 being drop 38 | all connections 39 | linear_transform : str, {'permutation', 'lu', 'svd', None} 40 | Linear transform to apply before each coupling transform. 41 | distribution : :obj:`nflows.distribution.Distribution` 42 | Distribution object to use for that latent spae. If None, an n-d 43 | Gaussian is used. 44 | mask : Union[torch.Tensor, list, numpy.ndarray] 45 | Mask or array of masks to use to construct the flow. If not specified, 46 | an alternating binary mask will be used. 47 | volume_preserving : bool, optional 48 | If True use additive transforms that preserve volume. 49 | kwargs : 50 | Keyword arguments passed to either 51 | :py:obj:`nflows.transforms.coupling.AdditiveCouplingTransform` or 52 | :py:obj:`glasflow.transforms.coupling.AffineCouplingTransform`. 
53 | """ 54 | 55 | def __init__( 56 | self, 57 | n_inputs, 58 | n_transforms, 59 | n_conditional_inputs=None, 60 | n_neurons=32, 61 | n_blocks_per_transform=2, 62 | batch_norm_within_blocks=False, 63 | batch_norm_between_transforms=False, 64 | activation=F.relu, 65 | dropout_probability=0.0, 66 | linear_transform=None, 67 | distribution=None, 68 | mask=None, 69 | volume_preserving=False, 70 | **kwargs, 71 | ): 72 | if volume_preserving: 73 | transform_class = AdditiveCouplingTransform 74 | else: 75 | transform_class = AffineCouplingTransform 76 | super().__init__( 77 | transform_class, 78 | n_inputs, 79 | n_transforms, 80 | n_conditional_inputs=n_conditional_inputs, 81 | n_neurons=n_neurons, 82 | n_blocks_per_transform=n_blocks_per_transform, 83 | batch_norm_within_blocks=batch_norm_within_blocks, 84 | batch_norm_between_transforms=batch_norm_between_transforms, 85 | activation=activation, 86 | dropout_probability=dropout_probability, 87 | linear_transform=linear_transform, 88 | distribution=distribution, 89 | mask=mask, 90 | **kwargs, 91 | ) 92 | -------------------------------------------------------------------------------- /src/glasflow/flows/nsf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Implementation of Neural Spline Flows. 4 | 5 | See: https://arxiv.org/abs/1906.04032 6 | """ 7 | from glasflow.nflows.transforms.coupling import ( 8 | PiecewiseRationalQuadraticCouplingTransform, 9 | ) 10 | import torch 11 | import torch.nn.functional as F 12 | from .coupling import CouplingFlow 13 | 14 | 15 | class CouplingNSF(CouplingFlow): 16 | """Implementation of Neural Spline Flows using a coupling transform. 17 | 18 | Supports use of a uniform distribution for the latent space. This 19 | automatically disables the tails and sets the bounds to [0, 1). 20 | 21 | Parameters 22 | ---------- 23 | n_inputs : int 24 | Number of inputs 25 | n_transforms : int 26 | Number of transforms 27 | n_conditional_inputs: int 28 | Number of conditionals inputs 29 | n_neurons : int 30 | Number of neurons per residual block in each transform 31 | n_blocks_per_transform : int 32 | Number of residual blocks per transform 33 | batch_norm_within_blocks : bool 34 | Enable batch normalisation within each residual block 35 | batch_norm_between_transforms : bool 36 | Enable batch norm between transforms 37 | activation : function 38 | Activation function to use. Defaults to ReLU 39 | dropout_probability : float 40 | Amount of dropout to apply. 0 being no dropout and 1 being drop 41 | all connections 42 | linear_transform : str, {'permutation', 'lu', 'svd', None} 43 | Linear transform to apply before each coupling transform. 44 | distribution : :obj:`nflows.distribution.Distribution` 45 | Distribution object to use for that latent spae. If None, an n-d 46 | Gaussian is used. 47 | mask : Union[torch.Tensor, list, numpy.ndarray] 48 | Mask or array of masks to use to construct the flow. If not specified, 49 | an alternating binary mask will be used. 50 | num_bins : int 51 | Number of bins for the spline in each dimension. 52 | tail_type : {None, 'linear'} 53 | Type of tails to use outside the bounds on which the splines are 54 | defined. 55 | tail_bound : float 56 | Bound that defines the region over which the splines are defined. 57 | I.e. [-tail_bound, tail_bound] 58 | kwargs : 59 | Keyword arguments passed to the transform when is it initialised. 
60 | """ 61 | 62 | def __init__( 63 | self, 64 | n_inputs, 65 | n_transforms, 66 | n_conditional_inputs=None, 67 | n_neurons=32, 68 | n_blocks_per_transform=2, 69 | batch_norm_within_blocks=False, 70 | batch_norm_between_transforms=False, 71 | activation=F.relu, 72 | dropout_probability=0.0, 73 | linear_transform=None, 74 | distribution=None, 75 | mask=None, 76 | num_bins=4, 77 | tail_type="linear", 78 | tail_bound=5.0, 79 | **kwargs, 80 | ): 81 | transform_class = PiecewiseRationalQuadraticCouplingTransform 82 | 83 | if distribution == "uniform": 84 | from ..distributions import MultivariateUniform 85 | 86 | tail_bound = 1.0 87 | tail_type = None 88 | distribution = MultivariateUniform( 89 | low=torch.Tensor(n_inputs * [0.0]), 90 | high=torch.Tensor(n_inputs * [1.0]), 91 | ) 92 | batch_norm_between_transforms = False 93 | 94 | super().__init__( 95 | transform_class, 96 | n_inputs, 97 | n_transforms, 98 | n_conditional_inputs=n_conditional_inputs, 99 | n_neurons=n_neurons, 100 | n_blocks_per_transform=n_blocks_per_transform, 101 | batch_norm_within_blocks=batch_norm_within_blocks, 102 | batch_norm_between_transforms=batch_norm_between_transforms, 103 | activation=activation, 104 | dropout_probability=dropout_probability, 105 | linear_transform=linear_transform, 106 | distribution=distribution, 107 | mask=mask, 108 | num_bins=num_bins, 109 | tails=tail_type, 110 | tail_bound=tail_bound, 111 | **kwargs, 112 | ) 113 | -------------------------------------------------------------------------------- /tests/test_flows/test_coupling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from glasflow.flows.coupling import CouplingFlow 4 | from glasflow.nflows.transforms.coupling import AffineCouplingTransform 5 | import numpy as np 6 | import pytest 7 | import torch 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "kwargs", 12 | [ 13 | {}, 14 | dict(linear_transform="svd"), 15 | dict(n_conditional_inputs=1), 16 | dict(mask=[[-1, 1], [1, -1]]), 17 | ], 18 | ) 19 | def test_coupling_flow_init(kwargs): 20 | """Test the init method""" 21 | CouplingFlow(AffineCouplingTransform, 2, 2, **kwargs) 22 | 23 | 24 | @pytest.mark.integration_test 25 | def test_coupling_flow_w_mask(): 26 | """Test the forward pass with a custom mask""" 27 | n_inputs = 2 28 | flow = CouplingFlow( 29 | AffineCouplingTransform, n_inputs, 2, mask=[[-1, 1], [1, -1]] 30 | ) 31 | 32 | x = torch.randn(10, 2) 33 | 34 | z, log_jac = flow.forward(x) 35 | z = z.detach().numpy() 36 | log_jac = log_jac.detach().numpy() 37 | assert z.shape == (10, 2) 38 | assert log_jac.shape == (10,) 39 | 40 | log_prob = flow.log_prob(x) 41 | log_prob = log_prob.detach().numpy() 42 | assert log_prob.shape == (10,) 43 | 44 | x = flow.sample(10) 45 | x = x.detach().numpy() 46 | assert x.shape == (10, 2) 47 | 48 | 49 | @pytest.mark.integration_test 50 | def test_coupling_flow_forward_w_conditional(): 51 | """Test the forward pass with a conditional input""" 52 | n_inputs = 2 53 | n_conditionals = 2 54 | flow = CouplingFlow( 55 | AffineCouplingTransform, 56 | n_inputs, 57 | 2, 58 | n_conditional_inputs=n_conditionals, 59 | ) 60 | 61 | x = torch.randn(10, 2) 62 | conditional = torch.randn(10, 2) 63 | 64 | z, log_prob = flow.forward(x, conditional=conditional) 65 | 66 | z = z.detach().numpy() 67 | log_prob = log_prob.detach().numpy() 68 | assert z.shape == (10, 2) 69 | assert log_prob.shape == (10,) 70 | 71 | 72 | @pytest.mark.integration_test 73 | def test_coupling_flow_sample_w_conditional(): 74 | 
"""Test sampling with a conditional input""" 75 | n = 10 76 | n_inputs = 2 77 | n_conditionals = 2 78 | flow = CouplingFlow( 79 | AffineCouplingTransform, 80 | n_inputs, 81 | 2, 82 | n_conditional_inputs=n_conditionals, 83 | ) 84 | conditional = torch.randn(n, 2) 85 | x = flow.sample(n, conditional=conditional) 86 | x = x.detach().numpy() 87 | assert x.shape == (n, 2) 88 | 89 | 90 | @pytest.mark.integration_test 91 | def test_coupling_flow_sample_and_log_prob_w_conditional(): 92 | """Test sampling with a conditional input""" 93 | n = 10 94 | n_inputs = 2 95 | n_conditionals = 2 96 | flow = CouplingFlow( 97 | AffineCouplingTransform, 98 | n_inputs, 99 | 2, 100 | n_conditional_inputs=n_conditionals, 101 | ) 102 | 103 | conditional = torch.randn(n, 2) 104 | x, log_prob = flow.sample_and_log_prob(n, conditional=conditional) 105 | x = x.detach().numpy() 106 | log_prob = log_prob.detach().numpy() 107 | assert x.shape == (n, 2) 108 | assert log_prob.shape == (n,) 109 | 110 | 111 | @pytest.mark.parametrize( 112 | "mask, expected", 113 | [ 114 | (None, torch.tensor([[-1, 1], [1, -1], [-1, 1]]).int()), 115 | ( 116 | torch.tensor([-1, 1]), 117 | torch.tensor([[-1, 1], [1, -1], [-1, 1]]).int(), 118 | ), 119 | ([-1, 1], torch.tensor([[-1, 1], [1, -1], [-1, 1]]).int()), 120 | (np.array([-1, 1]), torch.tensor([[-1, 1], [1, -1], [-1, 1]]).int()), 121 | ( 122 | torch.tensor([[1, -1], [1, -1], [-1, 1]]), 123 | torch.tensor([[1, -1], [1, -1], [-1, 1]]).int(), 124 | ), 125 | ], 126 | ) 127 | def test_validate_mask(mask, expected): 128 | """Assert the correct mask is returned""" 129 | out = CouplingFlow.validate_mask(mask, 2, 3) 130 | assert torch.equal(out, expected) 131 | 132 | 133 | def test_validate_mask_invalid_length(): 134 | """Assert a mask that is an invalid length raises an error""" 135 | with pytest.raises(ValueError) as excinfo: 136 | CouplingFlow.validate_mask([1, -1, 1], 2, 3) 137 | assert "does not match number of inputs" in str(excinfo.value) 138 | 139 | 140 | def test_validate_mask_invalid_depth(): 141 | """Assert a mask that is an invalid depth raises an error""" 142 | with pytest.raises(ValueError) as excinfo: 143 | CouplingFlow.validate_mask([[1, -1], [-1, 1]], 2, 3) 144 | assert "does not match number of transforms" in str(excinfo.value) 145 | -------------------------------------------------------------------------------- /src/glasflow/flows/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Base class for all normalising flows.""" 3 | from torch.nn import Module 4 | 5 | 6 | class Flow(Module): 7 | """ 8 | Base class for flow objects implemented according to the outline in nflows. 9 | 10 | Supports conditonal transforms but not conditional latent distributions. 11 | 12 | Parameters 13 | ---------- 14 | transform : :obj: `nflows.transforms.Transform` 15 | Object that applys the transformation, must have`forward` and 16 | `inverse` methods. See nflows for more details. 17 | distribution : :obj: `nflows.distributions.Distribution` 18 | Object the serves as the base distribution used when sampling 19 | and computing the log probrability. Must have `log_prob` and 20 | `sample` methods. 
See nflows for details 21 | """ 22 | 23 | def __init__(self, transform, distribution): 24 | super().__init__() 25 | 26 | for method in ["forward", "inverse"]: 27 | if not hasattr(transform, method): 28 | raise RuntimeError( 29 | f"Transform does not have `{method}` method" 30 | ) 31 | 32 | for method in ["log_prob", "sample"]: 33 | if not hasattr(distribution, method): 34 | raise RuntimeError( 35 | f"Distribution does not have `{method}` method" 36 | ) 37 | 38 | self._transform = transform 39 | self._distribution = distribution 40 | 41 | def forward(self, x, conditional=None): 42 | """ 43 | Apply the forward transformation and return samples in the latent 44 | space and log |J| 45 | """ 46 | return self._transform.forward(x, context=conditional) 47 | 48 | def inverse(self, z, conditional=None): 49 | """ 50 | Apply the inverse transformation and return samples in the 51 | data space and log |J| (not log probability) 52 | """ 53 | return self._transform.inverse(z, context=conditional) 54 | 55 | def sample(self, num_samples, conditional=None): 56 | """ 57 | Produces N samples in the data space by drawing from the base 58 | distribution and then applying the inverse transform. 59 | Does NOT need to be specified by the user 60 | """ 61 | noise = self._distribution.sample(num_samples) 62 | samples, _ = self._transform.inverse(noise, context=conditional) 63 | return samples 64 | 65 | def log_prob(self, inputs, conditional=None): 66 | """ 67 | Computes the log probability of the input samples by applying the 68 | transform. 69 | Does NOT need to be specified by the user 70 | """ 71 | noise, logabsdet = self._transform(inputs, context=conditional) 72 | log_prob = self._distribution.log_prob(noise) 73 | return log_prob + logabsdet 74 | 75 | def base_distribution_log_prob(self, z): 76 | """ 77 | Computes the log probability of samples in the latent space for 78 | the base distribution of the flow. 79 | 80 | Does not accept conditional inputs 81 | 82 | Parameters 83 | ---------- 84 | z : :obj:`torch.Tensor` 85 | Tensor of latent samples 86 | 87 | Returns 88 | ------- 89 | :obj:`torch.Tensor` 90 | Tensor of log-probabilities 91 | """ 92 | return self._distribution.log_prob(z) 93 | 94 | def forward_and_log_prob(self, x, conditional=None): 95 | """ 96 | Apply the forward transformation and compute the log probability 97 | of each sample 98 | 99 | Conditional inputs are only used for the forward transform. 100 | 101 | Returns 102 | ------- 103 | :obj:`torch.Tensor` 104 | Tensor of samples in the latent space 105 | :obj:`torch.Tensor` 106 | Tensor of log probabilities of the samples 107 | """ 108 | z, log_J = self.forward(x, conditional=conditional) 109 | log_prob = self.base_distribution_log_prob(z) 110 | return z, log_prob + log_J 111 | 112 | def sample_and_log_prob(self, N, conditional=None): 113 | """ 114 | Generates samples from the flow, together with their log probabilities 115 | in the data space log p(x) = log p(z) + log|J|. 116 | For flows, this is more efficient than calling `sample` and `log_prob` 117 | separately. 118 | 119 | Conditional inputs are only used for the inverse transform.
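
        Returns
        -------
        :obj:`torch.Tensor`
            Tensor of samples in the data space
        :obj:`torch.Tensor`
            Tensor of log probabilities of the samples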
120 | """ 121 | z, log_prob = self._distribution.sample_and_log_prob(N) 122 | samples, logabsdet = self._transform.inverse(z, context=conditional) 123 | return samples, log_prob - logabsdet 124 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | 10 | ## [0.4.1] 11 | 12 | ### Fixed 13 | 14 | - Fix error when using automatic mixed precision with spline transforms (https://github.com/uofgravity/glasflow/pull/66) 15 | 16 | ## [0.4.0] 17 | 18 | ### Added 19 | 20 | - Add various autoregressive flows using the existing transforms in `nflows` (https://github.com/uofgravity/glasflow/pull/62) 21 | - Add `scale_activation` keyword argument to `nflows.transforms.autoregressive.MaskedAffineAutoregressiveTransform` (https://github.com/uofgravity/nflows/pull/11) 22 | 23 | ### Changed 24 | 25 | - Drop support for Python 3.7 (https://github.com/uofgravity/glasflow/pull/61) 26 | 27 | 28 | ## [0.3.1] 29 | 30 | ### Fixed 31 | 32 | - Addressed a deprecation warning in the `nflows` submodule when using LU decomposition (https://github.com/uofgravity/nflows/pull/10, https://github.com/uofgravity/glasflow/pull/57) 33 | 34 | ## [0.3.0] 35 | 36 | ### Added 37 | 38 | - Keyword arguments passed to `glasflow.transform.coupling.AffineCouplingTransform` are now propogated to the parent class. ([#51](https://github.com/uofgravity/glasflow/pull/51)) 39 | - Add support `scale_activation` to `glasflow.transform.coupling.AffineCouplingTransform` and set the default to `nflows_general`. ([#52](https://github.com/uofgravity/glasflow/pull/52), [#54](https://github.com/uofgravity/glasflow/pull/54)) 40 | 41 | ### Changed 42 | 43 | - Default scale activation for `glasflow.transform.coupling.AffineCouplingTransform` is changed from `DEFAULT_SCALE_ACTIVATION` in nflows to `nflows_general` from glasflow. This changes the default behaviour, the previous behaviour can be recovered by setting `scale_activation='nflows'`. ([#52](https://github.com/uofgravity/glasflow/pull/52), [#54](https://github.com/uofgravity/glasflow/pull/54)) 44 | 45 | ### Fixed 46 | 47 | - fix a bug in `glasflow.nflows/utils/torchutils.searchsorted`, see https://github.com/uofgravity/nflows/pull/9 for details. ([#53](https://github.com/uofgravity/glasflow/pull/53)) 48 | 49 | ### Deprecated 50 | 51 | - The `scaling_method` argument in `glasflow.transform.coupling.AffineCouplingTransform` is now deprecated in favour of `scale_activation` and will be removed in a future release. ([#52](https://github.com/uofgravity/glasflow/pull/52)) 52 | 53 | ## [0.2.0] 54 | 55 | ### Added 56 | 57 | - Add a multi-layer perceptron (`glasflow.nets.mlp.MLP`). ([#40](https://github.com/uofgravity/glasflow/pull/40)) 58 | - Add a resampled Gaussian distribution that uses Learnt Accept/Reject Sampling (`glasflow.distributions.resampled.ResampledGaussian`). ([#40](https://github.com/uofgravity/glasflow/pull/40)) 59 | - Add `nessai.utils.get_torch_size`. ([#40](https://github.com/uofgravity/glasflow/pull/40)) 60 | - Add a multivariate uniform distribution for Neural Spline Flows (`glasflow.distributions.uniform.MultivariateUniform`). 
([#47](https://github.com/uofgravity/glasflow/pull/47)) 61 | 62 | ### Changed 63 | 64 | - Change logging statements on import to, by default, only appear when an external version of nflows is being used. ([#44](https://github.com/uofgravity/glasflow/pull/44)) 65 | 66 | ## [0.1.2] 67 | 68 | Another patch to fix CI not uploading release to PyPI 69 | ### Changed 70 | 71 | - Update `nflows` submodule ([#36](https://github.com/uofgravity/glasflow/pull/36)) 72 | - Remove LFS from `publish-to-pypi` workflow ([#36](https://github.com/uofgravity/glasflow/pull/36)) 73 | 74 | ## [0.1.1] 75 | 76 | Patch to fix CI not uploading release to PyPI 77 | 78 | ### Changed 79 | 80 | - Add LFS to `publish-to-pypi` workflow ([#35](https://github.com/uofgravity/glasflow/pull/35)) 81 | 82 | ## [0.1.0] 83 | 84 | ### Added 85 | 86 | - Add `RealNVP` 87 | - Add `CouplingNSF` (Coupling Neural Spline Flow) 88 | - Add `nflows` submodule that replaces `nflows` dependency 89 | - Add option for user-defined masks in coupling-based flows 90 | 91 | [Unreleased]: https://github.com/uofgravity/glasflow/compare/v0.4.1...HEAD 92 | [0.4.1]: https://github.com/uofgravity/glasflow/compare/v0.4.0...v0.4.1 93 | [0.4.0]: https://github.com/uofgravity/glasflow/compare/v0.3.1...v0.4.0 94 | [0.3.1]: https://github.com/uofgravity/glasflow/compare/v0.3.0...v0.3.1 95 | [0.3.0]: https://github.com/uofgravity/glasflow/compare/v0.2.0...v0.3.0 96 | [0.2.0]: https://github.com/uofgravity/glasflow/compare/v0.1.2...v0.2.0 97 | [0.1.2]: https://github.com/uofgravity/glasflow/compare/v0.1.1...v0.1.2 98 | [0.1.1]: https://github.com/uofgravity/glasflow/compare/v0.1.0...v0.1.1 99 | [0.1.0]: https://github.com/uofgravity/glasflow/releases/tag/v0.1.0 100 | -------------------------------------------------------------------------------- /src/glasflow/nets/mlp.py: -------------------------------------------------------------------------------- 1 | """Multi-layer perceptrons""" 2 | 3 | from typing import Callable, List, Optional, Union 4 | import numpy as np 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from ..utils import get_torch_size 11 | 12 | 13 | class MLP(nn.Module): 14 | """A standard multi-layer perceptron. 15 | 16 | Based on the implementation in nflows and nessai 17 | 18 | Parameters 19 | ---------- 20 | input_shape 21 | Input shape. 22 | output_shape 23 | Output shape. 24 | n_neurons_per_layer 25 | Number of neurons in the hidden layers. 26 | activation_fn 27 | Activation function 28 | activate_output 29 | Whether to activate the output layer. If a bool is specified, the same 30 | activation function is used. If a callable is specified, it 31 | will be used as the activation. 32 | dropout_probability : float 33 | Amount of dropout to apply after the hidden layers. 34 | 35 | Raises 36 | ------ 37 | ValueError 38 | If the list of neurons per layer is empty. 39 | TypeError 40 | If :code:`activate_output` is an invalid type.
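
    Examples
    --------
    A minimal, illustrative sketch; the shapes and layer sizes are
    placeholders:

    >>> import torch
    >>> from glasflow.nets import MLP
    >>> net = MLP(2, 1, [32, 32])
    >>> out = net(torch.randn(10, 2))  # out has shape (10, 1)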
41 | """ 42 | 43 | def __init__( 44 | self, 45 | input_shape: Union[tuple, int], 46 | output_shape: Union[tuple, int], 47 | n_neurons_per_layer: List[int], 48 | activation_fn: Union[bool, Callable] = F.relu, 49 | activate_output: bool = False, 50 | dropout_probability: float = 0.0, 51 | ): 52 | super().__init__() 53 | 54 | self._input_shape = get_torch_size(input_shape) 55 | self._output_shape = get_torch_size(output_shape) 56 | self._n_neurons_per_layer = n_neurons_per_layer 57 | self._activation_fn = activation_fn 58 | self._activate_output = activate_output 59 | 60 | if len(n_neurons_per_layer) == 0: 61 | raise ValueError("List of n neurons per layer cannot be empty") 62 | 63 | self._input_layer = nn.Linear( 64 | np.prod(input_shape), n_neurons_per_layer[0] 65 | ) 66 | self._hidden_layers = nn.ModuleList( 67 | [ 68 | nn.Linear(input_size, output_size) 69 | for input_size, output_size in zip( 70 | n_neurons_per_layer[:-1], n_neurons_per_layer[1:] 71 | ) 72 | ] 73 | ) 74 | self._dropout_layers = nn.ModuleList( 75 | nn.Dropout(dropout_probability) 76 | for _ in range(len(self._hidden_layers)) 77 | ) 78 | self._output_layer = nn.Linear( 79 | n_neurons_per_layer[-1], np.prod(output_shape) 80 | ) 81 | 82 | if activate_output: 83 | self._activate_output = True 84 | if activate_output is True: 85 | self._output_activation = self._activation_fn 86 | elif callable(activate_output): 87 | self._output_activation = activate_output 88 | else: 89 | raise TypeError( 90 | "activate_output must be a boolean or a callable." 91 | f"Got input of type {type(activate_output)}." 92 | ) 93 | else: 94 | self._activate_output = False 95 | 96 | def forward( 97 | self, inputs: torch.tensor, context: Optional[torch.tensor] = None 98 | ) -> torch.tensor: 99 | """Forward method that allows for kwargs such as context. 100 | 101 | Parameters 102 | ---------- 103 | inputs 104 | Inputs to the MLP 105 | context 106 | Conditional inputs, must be None. Only implemented for 107 | compatibility. 108 | 109 | Raises 110 | ------ 111 | ValueError 112 | If the context is not None. 113 | ValueError 114 | If the input shape is incorrect. 
115 | """ 116 | if context is not None: 117 | raise ValueError("MLP with conditional inputs is not implemented.") 118 | if inputs.shape[1:] != self._input_shape: 119 | raise ValueError( 120 | "Expected inputs of shape {}, got {}.".format( 121 | self._input_shape, inputs.shape[1:] 122 | ) 123 | ) 124 | 125 | inputs = inputs.reshape(-1, np.prod(self._input_shape)) 126 | outputs = self._input_layer(inputs) 127 | outputs = self._activation_fn(outputs) 128 | 129 | for hidden_layer, dropout in zip( 130 | self._hidden_layers, self._dropout_layers 131 | ): 132 | outputs = hidden_layer(outputs) 133 | outputs = self._activation_fn(outputs) 134 | outputs = dropout(outputs) 135 | 136 | outputs = self._output_layer(outputs) 137 | if self._activate_output: 138 | outputs = self._output_activation(outputs) 139 | outputs = outputs.reshape(-1, *self._output_shape) 140 | 141 | return outputs 142 | -------------------------------------------------------------------------------- /examples/moons_nvp_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Example of training a flow using `glasflow`" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from glasflow.flows import RealNVP\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import numpy as np\n", 19 | "from scipy.stats import norm\n", 20 | "import seaborn as sns\n", 21 | "import sklearn.datasets as datasets\n", 22 | "import torch\n", 23 | "from torch import optim\n", 24 | "\n", 25 | "# Update the plotting style\n", 26 | "sns.set_context(\"notebook\")\n", 27 | "sns.set_palette(\"colorblind\")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "x, y = datasets.make_moons(128, noise=0.05)\n", 37 | "plt.scatter(x[:, 0], x[:, 1])\n", 38 | "plt.show()" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "flow = RealNVP(\n", 48 | " n_inputs=2,\n", 49 | " n_transforms=5,\n", 50 | " n_neurons=32,\n", 51 | " batch_norm_between_transforms=True,\n", 52 | ")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "optimizer = optim.Adam(flow.parameters())" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "num_iter = 5000\n", 71 | "train_loss = []\n", 72 | "\n", 73 | "for i in range(num_iter):\n", 74 | " t_loss = 0\n", 75 | "\n", 76 | " x, y = datasets.make_moons(128, noise=0.1)\n", 77 | " x = torch.tensor(x, dtype=torch.float32)\n", 78 | " optimizer.zero_grad()\n", 79 | " loss = -flow.log_prob(inputs=x).mean()\n", 80 | " loss.backward()\n", 81 | " optimizer.step()\n", 82 | " t_loss += loss.item()\n", 83 | "\n", 84 | " if (i + 1) % 500 == 0:\n", 85 | " xline = torch.linspace(-1.5, 2.5, 100)\n", 86 | " yline = torch.linspace(-0.75, 1.25, 100)\n", 87 | " xgrid, ygrid = torch.meshgrid(xline, yline)\n", 88 | " xyinput = torch.cat(\n", 89 | " [xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1\n", 90 | " )\n", 91 | "\n", 92 | " with torch.no_grad():\n", 93 | " zgrid = flow.log_prob(xyinput).exp().reshape(100, 100)\n", 94 | "\n", 95 | " plt.contourf(xgrid.numpy(), ygrid.numpy(), zgrid.numpy())\n", 
96 | " plt.title(\"iteration {}\".format(i + 1))\n", 97 | " plt.show()\n", 98 | "\n", 99 | " train_loss.append(t_loss)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "plt.plot(train_loss)\n", 109 | "plt.xlabel(\"Iteration\", fontsize=12)\n", 110 | "plt.ylabel(\"Training loss\", fontsize=12)\n", 111 | "plt.show()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Drawing samples from the flow\n", 119 | "\n", 120 | "We can now draw samples from the trained flow." 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "n = 1000\n", 130 | "flow.eval()\n", 131 | "with torch.no_grad():\n", 132 | " generated_samples = flow.sample(1000)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "plt.scatter(generated_samples[:, 0], generated_samples[:, 1])\n", 142 | "plt.show()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "## Plotting the latent space\n", 150 | "\n", 151 | "We can pass samples through the flow and produces samples in the latent space. These samples (z) should be Gaussian." 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "flow.eval()\n", 161 | "with torch.no_grad():\n", 162 | " z_, _ = flow.forward(x)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "g = np.linspace(-5, 5, 100)\n", 172 | "plt.plot(g, norm.pdf(g), label=\"Standard Gaussian\")\n", 173 | "\n", 174 | "plt.hist(z_[:, 0], density=True, label=\"z_0\")\n", 175 | "plt.hist(z_[:, 1], density=True, label=\"z_1\")\n", 176 | "plt.legend()\n", 177 | "plt.show()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [] 186 | } 187 | ], 188 | "metadata": { 189 | "kernelspec": { 190 | "display_name": "Python 3.9.13 ('test-example')", 191 | "language": "python", 192 | "name": "python3" 193 | }, 194 | "language_info": { 195 | "codemirror_mode": { 196 | "name": "ipython", 197 | "version": 3 198 | }, 199 | "file_extension": ".py", 200 | "mimetype": "text/x-python", 201 | "name": "python", 202 | "nbconvert_exporter": "python", 203 | "pygments_lexer": "ipython3", 204 | "version": "3.9.13" 205 | }, 206 | "vscode": { 207 | "interpreter": { 208 | "hash": "3517700c92508b2e59eb9f265f5f8465660b244c7bbb91180936f922ef327d4b" 209 | } 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 4 214 | } 215 | -------------------------------------------------------------------------------- /src/glasflow/distributions/resampled.py: -------------------------------------------------------------------------------- 1 | """Distributions that include Learned Accept/Reject Sampling (LARS).""" 2 | 3 | from typing import Callable, Union 4 | 5 | from glasflow.nflows.distributions import Distribution 6 | from glasflow.nflows.utils import torchutils 7 | from glasflow.utils import get_torch_size 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class ResampledGaussian(Distribution): 14 | """Gaussian distribution that includes LARS. 
15 | 16 | For details see: https://arxiv.org/abs/2110.15828 17 | 18 | Based on the implementation here: \ 19 | https://github.com/VincentStimper/resampled-base-flows 20 | 21 | Does not support conditional inputs. 22 | 23 | Parameters 24 | ---------- 25 | shape 26 | Shape of the distribution 27 | acceptance_fn 28 | Function that computes the acceptance. Typically a neural network. 29 | eps 30 | Decay parameter for the exponential moving average used to update 31 | the estimate of Z. 32 | truncation 33 | Maximum number of rejection steps. Called T in the original paper. 34 | trainable 35 | Boolean to indicate if the mean and standard deviation of the 36 | distribution are learnable parameters. 37 | """ 38 | 39 | def __init__( 40 | self, 41 | shape: Union[int, tuple], 42 | acceptance_fn: Callable, 43 | eps: float = 0.05, 44 | truncation: int = 100, 45 | trainable: bool = False, 46 | ) -> None: 47 | super().__init__() 48 | self._shape = get_torch_size(shape) 49 | self.truncation = truncation 50 | self.acceptance_fn = acceptance_fn 51 | self.eps = eps 52 | 53 | self.register_buffer("norm", torch.tensor(-1.0)) 54 | self.register_buffer( 55 | "_log_z", 56 | torch.tensor( 57 | 0.5 * np.prod(self._shape) * np.log(2 * np.pi), 58 | dtype=torch.float64, 59 | ), 60 | ) 61 | if trainable: 62 | self.loc = nn.Parameter(torch.zeros(1, *self._shape)) 63 | self.log_scale = nn.Parameter(torch.zeros(1, *self._shape)) 64 | else: 65 | self.register_buffer("loc", torch.zeros(1, *self._shape)) 66 | self.register_buffer("log_scale", torch.zeros(1, *self._shape)) 67 | 68 | def _log_prob_gaussian(self, norm_inputs: torch.tensor) -> torch.tensor: 69 | """Base Gaussian log probability""" 70 | log_prob = ( 71 | -0.5 72 | * torchutils.sum_except_batch(norm_inputs**2, num_batch_dims=1) 73 | - torchutils.sum_except_batch(self.log_scale, num_batch_dims=1) 74 | - self._log_z 75 | ) 76 | return log_prob 77 | 78 | def _log_prob( 79 | self, inputs: torch.tensor, context: torch.tensor = None 80 | ) -> torch.tensor: 81 | """Log probability""" 82 | if context is not None: 83 | raise ValueError("Conditional inputs not supported") 84 | 85 | norm_inputs = (inputs - self.loc) / self.log_scale.exp() 86 | log_p_gaussian = self._log_prob_gaussian(norm_inputs) 87 | acc = self.acceptance_fn(norm_inputs) 88 | 89 | # Compute the normalisation 90 | if self.training or self.norm < 0.0: 91 | eps_ = torch.randn_like(inputs) 92 | norm_batch = torch.mean(self.acceptance_fn(eps_)) 93 | if self.norm < 0.0: 94 | self.norm = norm_batch.detach() 95 | else: 96 | # Update the normalisation estimate 97 | # eps defines the weight between the current estimate 98 | # and the new estimated value 99 | self.norm = ( 100 | 1 - self.eps 101 | ) * self.norm + self.eps * norm_batch.detach() 102 | # Why this? 
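# The next line evaluates to `self.norm` (the EMA estimate) in the
# forward pass, but subtracting the detached batch estimate and adding
# it back lets gradients flow through `norm_batch`, i.e. a
# straight-through-style estimator for the normalisation constant.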
103 | norm = norm_batch - norm_batch.detach() + self.norm 104 | else: 105 | norm = self.norm 106 | 107 | alpha = (1 - norm) ** (self.truncation - 1) 108 | return ( 109 | torch.log((1 - alpha) * acc[:, 0] / norm + alpha) + log_p_gaussian 110 | ) 111 | 112 | def _sample( 113 | self, num_samples: int, context: torch.tensor = None 114 | ) -> torch.tensor: 115 | if context is not None: 116 | raise ValueError("Conditional inputs not supported") 117 | 118 | device = self._log_z.device 119 | samples = torch.zeros(num_samples, *self._shape, device=device) 120 | 121 | t = 0 122 | s = 0 123 | n = 0 124 | norm_sum = 0 125 | 126 | for _ in range(self.truncation): 127 | samples_ = torch.randn(num_samples, *self._shape, device=device) 128 | acc = self.acceptance_fn(samples_) 129 | if self.training or self.norm < 0: 130 | norm_sum = norm_sum + acc.sum().detach() 131 | n += num_samples 132 | 133 | dec = torch.rand_like(acc) < acc 134 | for j, dec_ in enumerate(dec[:, 0]): 135 | if dec_ or (t == (self.truncation - 1)): 136 | samples[s, :] = samples_[j, :] 137 | s = s + 1 138 | t = 0 139 | else: 140 | t = t + 1 141 | if s == num_samples: 142 | break 143 | if s == num_samples: 144 | break 145 | 146 | samples = self.loc + self.log_scale.exp() * samples 147 | return samples 148 | 149 | def estimate_normalisation_constant( 150 | self, n_samples: int = 1000, n_batches: int = 1 151 | ) -> None: 152 | """Estimate the normalisation constant via Monte Carlo sampling. 153 | 154 | Should be called once the training is complete. 155 | 156 | Parameters 157 | ---------- 158 | n_samples 159 | Number of samples to draw in each batch. 160 | n_batches 161 | Number of batches to use. 162 | """ 163 | with torch.no_grad(): 164 | self.norm = self.norm * 0.0 165 | # Get dtype and device 166 | dtype = self.norm.dtype 167 | device = self.norm.device 168 | for _ in range(n_batches): 169 | eps = torch.randn( 170 | n_samples, *self._shape, dtype=dtype, device=device 171 | ) 172 | acc_ = self.acceptance_fn(eps) 173 | norm_batch = torch.mean(acc_) 174 | self.norm = self.norm + norm_batch / n_batches 175 | -------------------------------------------------------------------------------- /src/glasflow/flows/coupling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from glasflow.nflows.nn.nets import ResidualNet 3 | from glasflow.nflows import transforms 4 | from glasflow.nflows.transforms.coupling import CouplingTransform 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from .base import Flow 9 | 10 | 11 | class CouplingFlow(Flow): 12 | """Base class for coupling transform based flows. 13 | 14 | Uses an alternating binary mask and residual neural network. Other 15 | settings can be configured. See parameters for details. 16 | 17 | Parameters 18 | ---------- 19 | transform_class : :obj:`nflows.transforms.coupling.CouplingTransform` 20 | Class that inherits from `CouplingTransform` and implements the 21 | actual transformation.
22 | n_inputs : int 23 | Number of inputs 24 | n_transforms : int 25 | Number of transforms 26 | n_conditional_inputs: int 27 | Number of conditional inputs 28 | n_neurons : int 29 | Number of neurons per residual block in each transform 30 | n_blocks_per_transform : int 31 | Number of residual blocks per transform 32 | batch_norm_within_blocks : bool 33 | Enable batch normalisation within each residual block 34 | batch_norm_between_transforms : bool 35 | Enable batch norm between transforms 36 | activation : function 37 | Activation function to use. Defaults to ReLU 38 | dropout_probability : float 39 | Amount of dropout to apply. 0 being no dropout and 1 being drop 40 | all connections 41 | linear_transform : str, {'permutation', 'lu', 'svd', None} 42 | Linear transform to apply before each coupling transform. 43 | distribution : :obj:`nflows.distribution.Distribution` 44 | Distribution object to use for the latent space. If None, an n-d 45 | Gaussian is used. 46 | mask : Union[torch.Tensor, list, numpy.ndarray] 47 | Mask or array of masks to use to construct the flow. If not specified, 48 | an alternating binary mask will be used. 49 | kwargs : 50 | Keyword arguments passed to `transform_class` when it is initialised. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | transform_class, 56 | n_inputs, 57 | n_transforms, 58 | n_conditional_inputs=None, 59 | n_neurons=32, 60 | n_blocks_per_transform=2, 61 | batch_norm_within_blocks=False, 62 | batch_norm_between_transforms=False, 63 | activation=F.relu, 64 | dropout_probability=0.0, 65 | linear_transform=None, 66 | distribution=None, 67 | mask=None, 68 | **kwargs, 69 | ): 70 | if not issubclass(transform_class, CouplingTransform): 71 | raise RuntimeError( 72 | "Transform class does not inherit from `CouplingTransform`" 73 | ) 74 | 75 | def create_net(n_in, n_out): 76 | return ResidualNet( 77 | n_in, 78 | n_out, 79 | hidden_features=n_neurons, 80 | context_features=n_conditional_inputs, 81 | num_blocks=n_blocks_per_transform, 82 | activation=activation, 83 | dropout_probability=dropout_probability, 84 | use_batch_norm=batch_norm_within_blocks, 85 | ) 86 | 87 | def create_linear_transform(): 88 | if linear_transform == "permutation": 89 | return transforms.RandomPermutation(features=n_inputs) 90 | elif linear_transform == "lu": 91 | return transforms.CompositeTransform( 92 | [ 93 | transforms.RandomPermutation(features=n_inputs), 94 | transforms.LULinear( 95 | n_inputs, identity_init=True, using_cache=False 96 | ), 97 | ] 98 | ) 99 | elif linear_transform == "svd": 100 | return transforms.CompositeTransform( 101 | [ 102 | transforms.RandomPermutation(features=n_inputs), 103 | transforms.SVDLinear( 104 | n_inputs, num_householder=10, identity_init=True 105 | ), 106 | ] 107 | ) 108 | else: 109 | raise ValueError( 110 | f"Unknown linear transform: {linear_transform}."
111 | ) 112 | 113 | def create_transform(mask): 114 | return transform_class( 115 | mask=mask, transform_net_create_fn=create_net, **kwargs 116 | ) 117 | 118 | mask = self.validate_mask(mask, n_inputs, n_transforms) 119 | 120 | transforms_list = [] 121 | 122 | for i in range(n_transforms): 123 | if linear_transform is not None: 124 | transforms_list.append(create_linear_transform()) 125 | transforms_list.append(create_transform(mask[i])) 126 | if batch_norm_between_transforms: 127 | transforms_list.append(transforms.BatchNorm(n_inputs)) 128 | 129 | if distribution is None: 130 | from glasflow.nflows.distributions import StandardNormal 131 | 132 | distribution = StandardNormal([n_inputs]) 133 | 134 | super().__init__( 135 | transform=transforms.CompositeTransform(transforms_list), 136 | distribution=distribution, 137 | ) 138 | 139 | @staticmethod 140 | def validate_mask(mask, n_inputs, n_transforms): 141 | """Validate the mask. 142 | 143 | Returns 144 | ------- 145 | torch.Tensor 146 | A tensor of shape (n_transforms, n_inputs) with the mask for each 147 | transform. 148 | """ 149 | if mask is None: 150 | mask = torch.ones(n_inputs).int() 151 | mask[::2] = -1 152 | else: 153 | if not isinstance(mask, torch.Tensor): 154 | mask = torch.tensor(mask) 155 | if not mask.shape[-1] == n_inputs: 156 | raise ValueError("Mask does not match number of inputs") 157 | if mask.dim() == 2 and not mask.shape[0] == n_transforms: 158 | raise ValueError("Mask does not match number of transforms") 159 | mask = mask.int() 160 | 161 | # If mask is 1-d make a complete set of masks 162 | if mask.dim() == 1: 163 | mask_array = torch.empty([n_transforms, n_inputs]).int() 164 | for i in range(n_transforms): 165 | mask_array[i] = mask 166 | mask *= -1 167 | mask = mask_array 168 | return mask 169 | -------------------------------------------------------------------------------- /examples/conditional_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Example of using conditional flows with `glasflow` using a dataset from `sklearn`" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from glasflow import RealNVP\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import numpy as np\n", 19 | "import seaborn as sns\n", 20 | "from sklearn.datasets import make_blobs\n", 21 | "from sklearn.model_selection import train_test_split\n", 22 | "import torch\n", 23 | "\n", 24 | "torch.manual_seed(1451)\n", 25 | "np.random.seed(1451)\n", 26 | "sns.set_context(\"notebook\")\n", 27 | "sns.set_palette(\"colorblind\")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "Use `make_blobs` to make a set of Gaussian blobs corresponding to different classes." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "data, labels = make_blobs(\n", 44 | " n_samples=10000,\n", 45 | " n_features=2,\n", 46 | " centers=4,\n", 47 | " cluster_std=[1.7, 5.0, 3.1, 0.2],\n", 48 | " random_state=314159,\n", 49 | ")\n", 50 | "classes = np.unique(labels)\n", 51 | "print(f\"Classes are: {classes}\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "fig = plt.figure(dpi=100)\n", 61 | "markers = [\".\", \"x\", \"+\", \"^\"]\n", 62 | "for c, m in zip(classes, markers):\n", 63 | " idx = labels == c\n", 64 | " plt.scatter(\n", 65 | " data[idx, 0], data[idx, 1], label=f\"Class {c}\", marker=m, alpha=0.8\n", 66 | " )\n", 67 | "plt.legend()\n", 68 | "plt.show()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "device = \"cpu\"\n", 78 | "flow = RealNVP(\n", 79 | " n_inputs=2,\n", 80 | " n_transforms=4,\n", 81 | " n_conditional_inputs=1,\n", 82 | " n_neurons=32,\n", 83 | " batch_norm_between_transforms=True,\n", 84 | ")\n", 85 | "flow.to(device)\n", 86 | "print(f\"Created flow and sent to {device}\")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "optimiser = torch.optim.Adam(flow.parameters())" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "batch_size = 1000\n", 105 | "x_train, x_val, y_train, y_val = train_test_split(data, labels[:, np.newaxis])" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "Prepare the data using dataloaders" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "x_train_tensor = torch.from_numpy(x_train.astype(np.float32))\n", 122 | "y_train_tensor = torch.from_numpy(y_train.astype(np.float32))\n", 123 | "train_dataset = torch.utils.data.TensorDataset(x_train_tensor, y_train_tensor)\n", 124 | "train_loader = torch.utils.data.DataLoader(\n", 125 | " train_dataset, batch_size=batch_size, shuffle=True\n", 126 | ")\n", 127 | "\n", 128 | "x_val_tensor = torch.from_numpy(x_val.astype(np.float32))\n", 129 | "y_val_tensor = torch.from_numpy(y_val.astype(np.float32))\n", 130 | "val_dataset = torch.utils.data.TensorDataset(x_val_tensor, y_val_tensor)\n", 131 | "val_loader = torch.utils.data.DataLoader(\n", 132 | " val_dataset, batch_size=batch_size, shuffle=False\n", 133 | ")" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "Train the flow." 
141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "epochs = 200\n", 150 | "loss = dict(train=[], val=[])\n", 151 | "\n", 152 | "for i in range(epochs):\n", 153 | " flow.train()\n", 154 | " train_loss = 0.0\n", 155 | " for batch in train_loader:\n", 156 | " x, y = batch\n", 157 | " x = x.to(device)\n", 158 | " y = y.to(device)\n", 159 | " optimiser.zero_grad()\n", 160 | " _loss = -flow.log_prob(x, conditional=y).mean()\n", 161 | " _loss.backward()\n", 162 | " optimiser.step()\n", 163 | " train_loss += _loss.item()\n", 164 | " loss[\"train\"].append(train_loss / len(train_loader))\n", 165 | "\n", 166 | " flow.eval()\n", 167 | " val_loss = 0.0\n", 168 | " for batch in val_loader:\n", 169 | " x, y = batch\n", 170 | " x = x.to(device)\n", 171 | " y = y.to(device)\n", 172 | " with torch.no_grad():\n", 173 | " _loss = -flow.log_prob(x, conditional=y).mean().item()\n", 174 | " val_loss += _loss\n", 175 | " loss[\"val\"].append(val_loss / len(val_loader))\n", 176 | " if not i % 10:\n", 177 | " print(\n", 178 | " f\"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}\"\n", 179 | " )\n", 180 | "\n", 181 | "flow.eval()\n", 182 | "print(\"Finished training\")" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "plt.plot(loss[\"train\"], label=\"Train\")\n", 192 | "plt.plot(loss[\"val\"], label=\"Val.\")\n", 193 | "plt.xlabel(\"Epoch\")\n", 194 | "plt.ylabel(\"Loss\")\n", 195 | "plt.legend()\n", 196 | "plt.show()" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "n = 10000\n", 206 | "conditional = torch.from_numpy(\n", 207 | " np.random.choice(4, size=(n, 1)).astype(np.float32)\n", 208 | ").to(device)\n", 209 | "with torch.no_grad():\n", 210 | " samples = flow.sample(n, conditional=conditional)\n", 211 | "samples = samples.cpu().numpy()\n", 212 | "conditional = conditional.cpu().numpy()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "From the figure below we can see that the flow probably requires a bit more training but it does show how the flow can learn each distribution using the conditional inputs." 
220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "fig, ax = plt.subplots(\n", 229 | " 1, 2, sharex=True, sharey=True, figsize=(10, 5), dpi=100\n", 230 | ")\n", 231 | "markers = [\".\", \"x\", \"+\", \"^\"]\n", 232 | "for c, m in zip(classes, markers):\n", 233 | " idx = labels == c\n", 234 | " ax[0].scatter(\n", 235 | " data[idx, 0], data[idx, 1], label=f\"Class {c}\", marker=m, alpha=0.5\n", 236 | " )\n", 237 | "\n", 238 | " idx = conditional[:, 0] == c\n", 239 | " ax[1].scatter(\n", 240 | " samples[idx, 0],\n", 241 | " samples[idx, 1],\n", 242 | " label=f\"Class {c}\",\n", 243 | " marker=m,\n", 244 | " alpha=0.5,\n", 245 | " )\n", 246 | "ax[0].set_title(\"Data\")\n", 247 | "ax[1].set_title(\"Samples from flow\")\n", 248 | "plt.legend()\n", 249 | "plt.show()" 250 | ] 251 | } 252 | ], 253 | "metadata": { 254 | "kernelspec": { 255 | "display_name": "Python 3.9.13 ('test-example')", 256 | "language": "python", 257 | "name": "python3" 258 | }, 259 | "language_info": { 260 | "codemirror_mode": { 261 | "name": "ipython", 262 | "version": 3 263 | }, 264 | "file_extension": ".py", 265 | "mimetype": "text/x-python", 266 | "name": "python", 267 | "nbconvert_exporter": "python", 268 | "pygments_lexer": "ipython3", 269 | "version": "3.9.13" 270 | }, 271 | "vscode": { 272 | "interpreter": { 273 | "hash": "3517700c92508b2e59eb9f265f5f8465660b244c7bbb91180936f922ef327d4b" 274 | } 275 | } 276 | }, 277 | "nbformat": 4, 278 | "nbformat_minor": 2 279 | } 280 | -------------------------------------------------------------------------------- /src/glasflow/flows/autoregressive.py: -------------------------------------------------------------------------------- 1 | from glasflow.nflows import transforms 2 | import logging 3 | import torch.nn.functional as F 4 | 5 | from .base import Flow 6 | from ..transforms.utils import get_scale_activation 7 | from .. import USE_NFLOWS 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class MaskedAutoregressiveFlow(Flow): 14 | """Base class for masked autoregressive flows. 15 | 16 | Parameters 17 | ---------- 18 | transform_class : :obj:`nflows.transforms.autoregressive.AutoregressiveTransform` 19 | Class that inherits from `AutoregressiveTransform` and implements the 20 | actual transformation. 21 | n_inputs : int 22 | Number of inputs 23 | n_transforms : int 24 | Number of transforms 25 | n_conditional_inputs: int 26 | Number of conditional inputs 27 | n_neurons : int 28 | Number of neurons per residual block in each transform 29 | n_blocks_per_transform : int 30 | Number of residual blocks per transform 31 | batch_norm_within_blocks : bool 32 | Enable batch normalisation within each residual block 33 | batch_norm_between_transforms : bool 34 | Enable batch norm between transforms 35 | activation : function 36 | Activation function to use. Defaults to ReLU 37 | dropout_probability : float 38 | Amount of dropout to apply. 0 being no dropout and 1 being drop 39 | all connections 40 | use_random_permutations : bool 41 | If True, the order of the inputs is randomly permuted between 42 | transforms. 43 | use_random_masks : bool 44 | If True, random masks are used for the autoregressive transform. 45 | distribution : :obj:`nflows.distribution.Distribution` 46 | Distribution object to use for the latent space. If None, an n-d 47 | Gaussian is used. 48 | kwargs : 49 | Keyword arguments passed to `transform_class` when it is initialised.
50 | """ 51 | 52 | def __init__( 53 | self, 54 | transform_class, 55 | n_inputs, 56 | n_transforms, 57 | n_conditional_inputs=None, 58 | n_neurons=32, 59 | n_blocks_per_transform=2, 60 | batch_norm_within_blocks=False, 61 | batch_norm_between_transforms=False, 62 | activation=F.relu, 63 | dropout_probability=0.0, 64 | use_random_permutations=False, 65 | use_random_masks=False, 66 | distribution=None, 67 | **kwargs, 68 | ): 69 | 70 | if use_random_permutations: 71 | permutation_constructor = transforms.RandomPermutation 72 | else: 73 | permutation_constructor = transforms.ReversePermutation 74 | 75 | layers = [] 76 | for _ in range(n_transforms): 77 | layers.append(permutation_constructor(n_inputs)) 78 | layers.append( 79 | transform_class( 80 | features=n_inputs, 81 | hidden_features=n_neurons, 82 | context_features=n_conditional_inputs, 83 | num_blocks=n_blocks_per_transform, 84 | random_mask=use_random_masks, 85 | activation=activation, 86 | dropout_probability=dropout_probability, 87 | use_batch_norm=batch_norm_within_blocks, 88 | **kwargs, 89 | ) 90 | ) 91 | if batch_norm_between_transforms: 92 | layers.append(transforms.BatchNorm(n_inputs)) 93 | 94 | if distribution is None: 95 | from glasflow.nflows.distributions import StandardNormal 96 | 97 | distribution = StandardNormal([n_inputs]) 98 | 99 | super().__init__( 100 | transform=transforms.CompositeTransform(layers), 101 | distribution=distribution, 102 | ) 103 | 104 | 105 | class MaskedAffineAutoregressiveFlow(MaskedAutoregressiveFlow): 106 | """Masked autoregressive flow with affine transforms. 107 | 108 | Parameters 109 | ---------- 110 | n_inputs : int 111 | Number of inputs 112 | n_transforms : int 113 | Number of transforms 114 | n_conditional_inputs: int 115 | Number of conditionals inputs 116 | n_neurons : int 117 | Number of neurons per residual block in each transform 118 | n_blocks_per_transform : int 119 | Number of residual blocks per transform 120 | batch_norm_within_blocks : bool 121 | Enable batch normalisation within each residual block 122 | batch_norm_between_transforms : bool 123 | Enable batch norm between transforms 124 | activation : function 125 | Activation function to use. Defaults to ReLU 126 | dropout_probability : float 127 | Amount of dropout to apply. 0 being no dropout and 1 being drop 128 | all connections 129 | use_random_permutations : bool 130 | If True, the order of the inputs is randomly permuted between 131 | transforms. 132 | use_random_masks : bool 133 | If True, random masks are used for the autoregressive transform. 134 | distribution : :obj:`nflows.distribution.Distribution` 135 | Distribution object to use for that latent spae. If None, an n-d 136 | Gaussian is used. 137 | scale_activation : Optional[str, Callable] 138 | Activation for constraining the scale parameter. 139 | kwargs : 140 | Keyword arguments passed to `transform_class` when is it initialised. 
141 | """ 142 | 143 | def __init__( 144 | self, 145 | n_inputs, 146 | n_transforms, 147 | n_conditional_inputs=None, 148 | n_neurons=32, 149 | n_blocks_per_transform=2, 150 | batch_norm_within_blocks=False, 151 | batch_norm_between_transforms=False, 152 | activation=F.relu, 153 | dropout_probability=0, 154 | use_random_permutations=False, 155 | use_random_masks=False, 156 | distribution=None, 157 | **kwargs, 158 | ): 159 | 160 | if USE_NFLOWS and kwargs.get("scale_activation", None) is not None: 161 | logger.error("nflows backend does not support scale activation") 162 | elif kwargs.get("scale_activation", None) is not None: 163 | kwargs["scale_activation"] = get_scale_activation( 164 | kwargs["scale_activation"] 165 | ) 166 | 167 | super().__init__( 168 | transforms.autoregressive.MaskedAffineAutoregressiveTransform, 169 | n_inputs=n_inputs, 170 | n_transforms=n_transforms, 171 | n_conditional_inputs=n_conditional_inputs, 172 | n_neurons=n_neurons, 173 | n_blocks_per_transform=n_blocks_per_transform, 174 | batch_norm_within_blocks=batch_norm_within_blocks, 175 | batch_norm_between_transforms=batch_norm_between_transforms, 176 | activation=activation, 177 | dropout_probability=dropout_probability, 178 | use_random_permutations=use_random_permutations, 179 | use_random_masks=use_random_masks, 180 | distribution=distribution, 181 | **kwargs, 182 | ) 183 | 184 | 185 | class MaskedPiecewiseLinearAutoregressiveFlow(MaskedAutoregressiveFlow): 186 | """Masked autoregressive flow with piecewise linear splines. 187 | 188 | Parameters 189 | ---------- 190 | n_inputs : int 191 | Number of inputs 192 | n_transforms : int 193 | Number of transforms 194 | n_conditional_inputs: int 195 | Number of conditionals inputs 196 | n_neurons : int 197 | Number of neurons per residual block in each transform 198 | n_blocks_per_transform : int 199 | Number of residual blocks per transform 200 | batch_norm_within_blocks : bool 201 | Enable batch normalisation within each residual block 202 | batch_norm_between_transforms : bool 203 | Enable batch norm between transforms 204 | activation : function 205 | Activation function to use. Defaults to ReLU 206 | dropout_probability : float 207 | Amount of dropout to apply. 0 being no dropout and 1 being drop 208 | all connections 209 | use_random_permutations : bool 210 | If True, the order of the inputs is randomly permuted between 211 | transforms. 212 | use_random_masks : bool 213 | If True, random masks are used for the autoregressive transform. 214 | distribution : :obj:`nflows.distribution.Distribution` 215 | Distribution object to use for that latent spae. If None, an n-d 216 | Gaussian is used. 217 | num_bins : int 218 | The number of bins. 219 | kwargs : 220 | Keyword arguments passed to `transform_class` when is it initialised. 
221 | """ 222 | 223 | def __init__( 224 | self, 225 | n_inputs, 226 | n_transforms, 227 | n_conditional_inputs=None, 228 | n_neurons=32, 229 | n_blocks_per_transform=2, 230 | batch_norm_within_blocks=False, 231 | batch_norm_between_transforms=False, 232 | activation=F.relu, 233 | dropout_probability=0, 234 | use_random_permutations=False, 235 | use_random_masks=False, 236 | distribution=None, 237 | num_bins=10, 238 | **kwargs, 239 | ): 240 | super().__init__( 241 | transforms.autoregressive.MaskedPiecewiseLinearAutoregressiveTransform, 242 | n_inputs=n_inputs, 243 | n_transforms=n_transforms, 244 | n_conditional_inputs=n_conditional_inputs, 245 | n_neurons=n_neurons, 246 | n_blocks_per_transform=n_blocks_per_transform, 247 | batch_norm_within_blocks=batch_norm_within_blocks, 248 | batch_norm_between_transforms=batch_norm_between_transforms, 249 | activation=activation, 250 | dropout_probability=dropout_probability, 251 | use_random_permutations=use_random_permutations, 252 | use_random_masks=use_random_masks, 253 | distribution=distribution, 254 | num_bins=num_bins, 255 | **kwargs, 256 | ) 257 | 258 | 259 | class MaskedPiecewiseQuadraticAutoregressiveFlow(MaskedAutoregressiveFlow): 260 | """Masked autoregressive flow with piecewise quadratic splines. 261 | 262 | Parameters 263 | ---------- 264 | n_inputs : int 265 | Number of inputs 266 | n_transforms : int 267 | Number of transforms 268 | n_conditional_inputs: int 269 | Number of conditionals inputs 270 | n_neurons : int 271 | Number of neurons per residual block in each transform 272 | n_blocks_per_transform : int 273 | Number of residual blocks per transform 274 | batch_norm_within_blocks : bool 275 | Enable batch normalisation within each residual block 276 | batch_norm_between_transforms : bool 277 | Enable batch norm between transforms 278 | activation : function 279 | Activation function to use. Defaults to ReLU 280 | dropout_probability : float 281 | Amount of dropout to apply. 0 being no dropout and 1 being drop 282 | all connections 283 | use_random_permutations : bool 284 | If True, the order of the inputs is randomly permuted between 285 | transforms. 286 | use_random_masks : bool 287 | If True, random masks are used for the autoregressive transform. 288 | distribution : :obj:`nflows.distribution.Distribution` 289 | Distribution object to use for that latent spae. If None, an n-d 290 | Gaussian is used. 291 | num_bins : int 292 | The number of bins. 293 | kwargs : 294 | Keyword arguments passed to `transform_class` when is it initialised. 
295 | """ 296 | 297 | def __init__( 298 | self, 299 | n_inputs, 300 | n_transforms, 301 | n_conditional_inputs=None, 302 | n_neurons=32, 303 | n_blocks_per_transform=2, 304 | batch_norm_within_blocks=False, 305 | batch_norm_between_transforms=False, 306 | activation=F.relu, 307 | dropout_probability=0, 308 | use_random_permutations=False, 309 | use_random_masks=False, 310 | distribution=None, 311 | num_bins=10, 312 | **kwargs, 313 | ): 314 | super().__init__( 315 | transforms.autoregressive.MaskedPiecewiseQuadraticAutoregressiveTransform, 316 | n_inputs=n_inputs, 317 | n_transforms=n_transforms, 318 | n_conditional_inputs=n_conditional_inputs, 319 | n_neurons=n_neurons, 320 | n_blocks_per_transform=n_blocks_per_transform, 321 | batch_norm_within_blocks=batch_norm_within_blocks, 322 | batch_norm_between_transforms=batch_norm_between_transforms, 323 | activation=activation, 324 | dropout_probability=dropout_probability, 325 | use_random_permutations=use_random_permutations, 326 | use_random_masks=use_random_masks, 327 | distribution=distribution, 328 | num_bins=num_bins, 329 | **kwargs, 330 | ) 331 | 332 | 333 | class MaskedPiecewiseCubicAutoregressiveAutoregressiveFlow( 334 | MaskedAutoregressiveFlow 335 | ): 336 | """Masked autoregressive flow with piecewise cubic splines. 337 | 338 | Parameters 339 | ---------- 340 | n_inputs : int 341 | Number of inputs 342 | n_transforms : int 343 | Number of transforms 344 | n_conditional_inputs: int 345 | Number of conditionals inputs 346 | n_neurons : int 347 | Number of neurons per residual block in each transform 348 | n_blocks_per_transform : int 349 | Number of residual blocks per transform 350 | batch_norm_within_blocks : bool 351 | Enable batch normalisation within each residual block 352 | batch_norm_between_transforms : bool 353 | Enable batch norm between transforms 354 | activation : function 355 | Activation function to use. Defaults to ReLU 356 | dropout_probability : float 357 | Amount of dropout to apply. 0 being no dropout and 1 being drop 358 | all connections 359 | use_random_permutations : bool 360 | If True, the order of the inputs is randomly permuted between 361 | transforms. 362 | use_random_masks : bool 363 | If True, random masks are used for the autoregressive transform. 364 | distribution : :obj:`nflows.distribution.Distribution` 365 | Distribution object to use for that latent spae. If None, an n-d 366 | Gaussian is used. 367 | num_bins : int 368 | The number of bins. 369 | kwargs : 370 | Keyword arguments passed to `transform_class` when is it initialised. 
371 | """ 372 | 373 | def __init__( 374 | self, 375 | n_inputs, 376 | n_transforms, 377 | n_conditional_inputs=None, 378 | n_neurons=32, 379 | n_blocks_per_transform=2, 380 | batch_norm_within_blocks=False, 381 | batch_norm_between_transforms=False, 382 | activation=F.relu, 383 | dropout_probability=0, 384 | use_random_permutations=False, 385 | use_random_masks=False, 386 | distribution=None, 387 | num_bins=10, 388 | **kwargs, 389 | ): 390 | super().__init__( 391 | transforms.autoregressive.MaskedPiecewiseCubicAutoregressiveTransform, 392 | n_inputs=n_inputs, 393 | n_transforms=n_transforms, 394 | n_conditional_inputs=n_conditional_inputs, 395 | n_neurons=n_neurons, 396 | n_blocks_per_transform=n_blocks_per_transform, 397 | batch_norm_within_blocks=batch_norm_within_blocks, 398 | batch_norm_between_transforms=batch_norm_between_transforms, 399 | activation=activation, 400 | dropout_probability=dropout_probability, 401 | use_random_permutations=use_random_permutations, 402 | use_random_masks=use_random_masks, 403 | distribution=distribution, 404 | num_bins=num_bins, 405 | **kwargs, 406 | ) 407 | 408 | 409 | class MaskedPiecewiseRationalQuadraticAutoregressiveFlow( 410 | MaskedAutoregressiveFlow 411 | ): 412 | """Masked autoregressive flow with piecewise rational quadratic splines. 413 | 414 | Parameters 415 | ---------- 416 | n_inputs : int 417 | Number of inputs 418 | n_transforms : int 419 | Number of transforms 420 | n_conditional_inputs: int 421 | Number of conditionals inputs 422 | n_neurons : int 423 | Number of neurons per residual block in each transform 424 | n_blocks_per_transform : int 425 | Number of residual blocks per transform 426 | batch_norm_within_blocks : bool 427 | Enable batch normalisation within each residual block 428 | batch_norm_between_transforms : bool 429 | Enable batch norm between transforms 430 | activation : function 431 | Activation function to use. Defaults to ReLU 432 | dropout_probability : float 433 | Amount of dropout to apply. 0 being no dropout and 1 being drop 434 | all connections 435 | use_random_permutations : bool 436 | If True, the order of the inputs is randomly permuted between 437 | transforms. 438 | use_random_masks : bool 439 | If True, random masks are used for the autoregressive transform. 440 | distribution : :obj:`nflows.distribution.Distribution` 441 | Distribution object to use for that latent spae. If None, an n-d 442 | Gaussian is used. 443 | num_bins : int 444 | The number of bins. 445 | kwargs : 446 | Keyword arguments passed to `transform_class` when is it initialised. 
447 | """ 448 | 449 | def __init__( 450 | self, 451 | n_inputs, 452 | n_transforms, 453 | n_conditional_inputs=None, 454 | n_neurons=32, 455 | n_blocks_per_transform=2, 456 | batch_norm_within_blocks=False, 457 | batch_norm_between_transforms=False, 458 | activation=F.relu, 459 | dropout_probability=0, 460 | use_random_permutations=False, 461 | use_random_masks=False, 462 | distribution=None, 463 | num_bins=10, 464 | **kwargs, 465 | ): 466 | super().__init__( 467 | transforms.autoregressive.MaskedPiecewiseRationalQuadraticAutoregressiveTransform, 468 | n_inputs=n_inputs, 469 | n_transforms=n_transforms, 470 | n_conditional_inputs=n_conditional_inputs, 471 | n_neurons=n_neurons, 472 | n_blocks_per_transform=n_blocks_per_transform, 473 | batch_norm_within_blocks=batch_norm_within_blocks, 474 | batch_norm_between_transforms=batch_norm_between_transforms, 475 | activation=activation, 476 | dropout_probability=dropout_probability, 477 | use_random_permutations=use_random_permutations, 478 | use_random_masks=use_random_masks, 479 | distribution=distribution, 480 | num_bins=num_bins, 481 | **kwargs, 482 | ) 483 | --------------------------------------------------------------------------------