├── .github
│   └── workflows
│       ├── CI-openmm.yml
│       ├── CI.yml
│       ├── codeql.yml
│       └── lint.yml
├── .gitignore
├── LICENSE
├── README.md
├── bgflow
│   ├── __init__.py
│   ├── _version.py
│   ├── bg.py
│   ├── distribution
│   │   ├── __init__.py
│   │   ├── distributions.py
│   │   ├── energy
│   │   │   ├── __init__.py
│   │   │   ├── ase.py
│   │   │   ├── base.py
│   │   │   ├── clipped.py
│   │   │   ├── double_well.py
│   │   │   ├── lennard_jones.py
│   │   │   ├── multi_double_well_potential.py
│   │   │   ├── openmm.py
│   │   │   ├── particles.py
│   │   │   └── xtb.py
│   │   ├── mixture.py
│   │   ├── normal.py
│   │   ├── product.py
│   │   └── sampling
│   │       ├── __init__.py
│   │       ├── _iterative_helpers.py
│   │       ├── _mcmc
│   │       │   ├── __init__.py
│   │       │   ├── analysis.py
│   │       │   ├── latent_sampling.py
│   │       │   ├── metropolis.py
│   │       │   ├── permutation.py
│   │       │   └── umbrella_sampling.py
│   │       ├── base.py
│   │       ├── buffer.py
│   │       ├── dataset.py
│   │       ├── iterative.py
│   │       └── mcmc.py
│   ├── factory
│   │   ├── GNN_factory.py
│   │   ├── __init__.py
│   │   ├── conditioner_factory.py
│   │   ├── distribution_factory.py
│   │   ├── generator_builder.py
│   │   ├── icmarginals.py
│   │   ├── tensor_info.py
│   │   └── transformer_factory.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── dense.py
│   │   ├── flow
│   │   │   ├── __init__.py
│   │   │   ├── affine.py
│   │   │   ├── base.py
│   │   │   ├── bnaf.py
│   │   │   ├── cdf.py
│   │   │   ├── checkerboard.py
│   │   │   ├── circular.py
│   │   │   ├── coupling.py
│   │   │   ├── crd_transform
│   │   │   │   ├── __init__.py
│   │   │   │   ├── _deprecated_ic.py
│   │   │   │   ├── ic.py
│   │   │   │   ├── ic_helper.py
│   │   │   │   └── pca.py
│   │   │   ├── diffeq.py
│   │   │   ├── dynamics
│   │   │   │   ├── __init__.py
│   │   │   │   ├── anode_dynamic.py
│   │   │   │   ├── blackbox.py
│   │   │   │   ├── density.py
│   │   │   │   ├── inversed.py
│   │   │   │   ├── kernel_dynamic.py
│   │   │   │   └── simple.py
│   │   │   ├── elementwise.py
│   │   │   ├── estimator
│   │   │   │   ├── __init__.py
│   │   │   │   ├── brute_force_estimator.py
│   │   │   │   └── hutchinson_estimator.py
│   │   │   ├── funnel.py
│   │   │   ├── inverted.py
│   │   │   ├── kronecker.py
│   │   │   ├── modulo.py
│   │   │   ├── orthogonal.py
│   │   │   ├── pppp.py
│   │   │   ├── sequential.py
│   │   │   ├── spline.py
│   │   │   ├── stochastic
│   │   │   │   ├── __init__.py
│   │   │   │   ├── augment.py
│   │   │   │   ├── langevin.py
│   │   │   │   ├── mcmc.py
│   │   │   │   └── snf_openmm.py
│   │   │   ├── torchtransform.py
│   │   │   ├── transformer
│   │   │   │   ├── __init__.py
│   │   │   │   ├── affine.py
│   │   │   │   ├── base.py
│   │   │   │   ├── entropy_scaling.py
│   │   │   │   ├── gaussian.py
│   │   │   │   ├── jax.py
│   │   │   │   ├── jax_bridge.py
│   │   │   │   └── spline.py
│   │   │   └── triangular.py
│   │   ├── periodic.py
│   │   └── training
│   │       ├── __init__.py
│   │       └── trainers.py
│   └── utils
│       ├── __init__.py
│       ├── autograd.py
│       ├── free_energy.py
│       ├── geometry.py
│       ├── internal_coordinates.py
│       ├── openmm.py
│       ├── rbf_kernels.py
│       ├── shape.py
│       ├── tensorops.py
│       ├── train.py
│       └── types.py
├── devtools
│   └── conda-env.yml
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── _static
│   │   └── README.md
│   ├── _templates
│   │   ├── README.md
│   │   └── class.rst
│   ├── api
│   │   ├── bg.rst
│   │   ├── energies.rst
│   │   ├── flows.rst
│   │   ├── samplers.rst
│   │   └── utils.rst
│   ├── conf.py
│   ├── examples.rst
│   ├── getting_started.rst
│   ├── index.rst
│   ├── installation.rst
│   ├── literature.bib
│   ├── make.bat
│   ├── requirements.rst
│   └── requirements.yaml
├── examples
│   ├── datasets
│   │   └── README.rst
│   ├── general_examples
│   │   ├── README.rst
│   │   └── plot_simple_bg.py
│   └── nb_examples
│       ├── README.rst
│       └── example_bg_coupling.ipynb
├── notebooks
│   ├── alanine_dipeptide_augmented.ipynb
│   ├── alanine_dipeptide_basics.ipynb
│   ├── alanine_dipeptide_basics.py
│   ├── alanine_dipeptide_spline.ipynb
│   ├── cgn_GNN_example.ipynb
│   ├── example.ipynb
│   ├── example_equivariant_RNVP.ipynb
│   ├── example_equivariant_nODE.ipynb
│   ├── iterative_umbrella_sampling.ipynb
│   └── samplers.ipynb
├── readthedocs.yml
├── setup.cfg
├── setup.py
├── tests
│   ├── conftest.py
│   ├── data
│   │   └── alanine-dipeptide-nowater.pdb
│   ├── distribution
│   │   ├── energy
│   │   │   ├── test_ase.py
│   │   │   ├── test_base.py
│   │   │   ├── test_clipped.py
│   │   │   ├── test_lennard_jones.py
│   │   │   ├── test_multi_double_well_potential.py
│   │   │   ├── test_openmm.py
│   │   │   └── test_xtb.py
│   │   ├── sampling
│   │   │   ├── test_buffer.py
│   │   │   ├── test_dataset.py
│   │   │   ├── test_iterative.py
│   │   │   ├── test_iterative_helpers.py
│   │   │   └── test_mcmc.py
│   │   ├── test_distribution.py
│   │   ├── test_normal.py
│   │   └── test_product.py
│   ├── factory
│   │   ├── test_conditioner_factory.py
│   │   ├── test_distribution_factory.py
│   │   ├── test_generator_builder.py
│   │   ├── test_icmarginals.py
│   │   ├── test_tensor_info.py
│   │   └── test_transformer_factory.py
│   ├── nn
│   │   ├── flow
│   │   │   ├── crd_transform
│   │   │   │   └── test_ic.py
│   │   │   ├── dynamics
│   │   │   │   └── test_kernel_dynamics.py
│   │   │   ├── estimators
│   │   │   │   └── test_hutchinson_estimator.py
│   │   │   ├── stochastic
│   │   │   │   └── test_snf_openmm.py
│   │   │   ├── test_cdf.py
│   │   │   ├── test_coupling.py
│   │   │   ├── test_inverted.py
│   │   │   ├── test_modulo.py
│   │   │   ├── test_nODE.py
│   │   │   ├── test_pppp.py
│   │   │   ├── test_sequential.py
│   │   │   ├── test_torchtransform.py
│   │   │   ├── test_triangular.py
│   │   │   └── transformer
│   │   │       ├── test_affine.py
│   │   │       ├── test_gaussian.py
│   │   │       ├── test_jax_bridge.py
│   │   │       └── test_spline.py
│   │   └── test_wrap_distances.py
│   ├── test_bg.py
│   ├── test_readme.py
│   └── utils
│       ├── test_autograd.py
│       ├── test_free_energy.py
│       ├── test_geometry.py
│       ├── test_rbf_kernels.py
│       ├── test_train.py
│       └── test_types.py
└── versioneer.py
/.github/workflows/CI-openmm.yml:
--------------------------------------------------------------------------------
1 | name: CI with OpenMM on conda
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | pull_request:
8 | branches:
9 | - "main"
10 | schedule:
 11 |     # Nightly tests run on the default branch (main):
12 | # Scheduled workflows run on the latest commit on the default or base branch.
13 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
14 | - cron: "0 0 * * *"
15 |
16 |
17 | jobs:
18 | test:
19 | runs-on: ${{ matrix.os }}
20 | strategy:
21 | fail-fast: false
22 | matrix:
23 | os: [ubuntu-latest]
24 | python-version: [3.9]
25 |
26 | steps:
27 |
28 | - uses: actions/checkout@v2
29 |
30 | # More info on options: https://github.com/conda-incubator/setup-miniconda
31 | - uses: conda-incubator/setup-miniconda@v2
32 | with:
33 | python-version: ${{ matrix.python-version }}
34 | environment-file: devtools/conda-env.yml
35 | channels: conda-forge, pytorch, defaults
36 | activate-environment: test
37 | auto-update-conda: true
38 | auto-activate-base: false
39 | show-channel-urls: true
40 |
41 | - name: Install pip dependencies
42 | shell: bash -l {0}
43 | run: |
44 | pip install einops
45 | pip install nflows
46 |
47 | - name: Install package
48 | shell: bash -l {0}
49 | run: |
50 | python setup.py install
51 |
52 | - name: Test with pytest
53 | shell: bash -l {0}
54 | run: |
55 | pytest -vs
56 |
--------------------------------------------------------------------------------
/.github/workflows/CI.yml:
--------------------------------------------------------------------------------
1 | name: CI without OpenMM
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | pull_request:
8 | branches:
9 | - "main"
10 | schedule:
 11 |     # Nightly tests run on the default branch (main):
12 | # Scheduled workflows run on the latest commit on the default or base branch.
13 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
14 | - cron: "0 0 * * *"
15 |
16 |
17 | jobs:
18 | test:
19 | runs-on: ${{ matrix.cfg.os }}
20 | strategy:
21 | fail-fast: false
22 | matrix:
23 | cfg:
24 | - { os: ubuntu-latest, python-version: 3.7, torch-version: 'torch>=1.9,<1.11' }
25 | - { os: ubuntu-latest, python-version: 3.8, torch-version: 'torch>=1.9,<1.11' }
26 | - { os: ubuntu-latest, python-version: 3.9, torch-version: 'torch>=1.9,<1.11' }
27 | - { os: ubuntu-latest, python-version: 3.7, torch-version: 'torch>=1.11' }
28 | - { os: ubuntu-latest, python-version: 3.8, torch-version: 'torch>=1.11' }
29 | - { os: ubuntu-latest, python-version: 3.9, torch-version: 'torch>=1.11' }
30 | - { os: windows-latest, python-version: 3.9, torch-version: 'torch>=1.11' }
31 | - { os: macos-latest, python-version: 3.9, torch-version: 'torch>=1.11' }
32 |
33 | steps:
34 |
35 | # WITHOUT OPENMM
36 | - uses: actions/checkout@v2
37 | - name: Set up Python ${{ matrix.cfg.python-version }}
38 | uses: actions/setup-python@v2
39 | with:
40 | python-version: ${{ matrix.cfg.python-version }}
41 | - name: Install dependencies
42 | run: |
43 | python -m pip install --upgrade pip
44 | pip install pytest "${{ matrix.cfg.torch-version }}" numpy nflows torchdiffeq einops netCDF4
45 | - name: Install package
46 | run: |
47 | python setup.py install
48 | - name: Test with pytest
49 | run: |
50 | pytest -vs
51 |
52 |
--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | name: "CodeQL"
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 | schedule:
9 | - cron: "52 8 * * 3"
10 |
11 | jobs:
12 | analyze:
13 | name: Analyze
14 | runs-on: ubuntu-latest
15 | permissions:
16 | actions: read
17 | contents: read
18 | security-events: write
19 |
20 | strategy:
21 | fail-fast: false
22 | matrix:
23 | language: [ python ]
24 |
25 | steps:
26 | - name: Checkout
27 | uses: actions/checkout@v3
28 |
29 | - name: Initialize CodeQL
30 | uses: github/codeql-action/init@v2
31 | with:
32 | languages: ${{ matrix.language }}
33 | queries: +security-and-quality
34 |
35 | - name: Autobuild
36 | uses: github/codeql-action/autobuild@v2
37 |
38 | - name: Perform CodeQL Analysis
39 | uses: github/codeql-action/analyze@v2
40 | with:
41 | category: "/language:${{ matrix.language }}"
42 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Flake8 Linting
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | pull_request:
8 | branches:
9 | - "main"
10 | schedule:
 11 |     # Nightly tests run on the default branch (main):
12 | # Scheduled workflows run on the latest commit on the default or base branch.
13 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
14 | - cron: "0 0 * * *"
15 |
16 |
17 | jobs:
18 | lint:
19 | runs-on: ${{ matrix.os }}
20 | strategy:
21 | matrix:
22 | os: [ubuntu-latest]
23 | python-version: [3.9]
24 | torch-version: ['torch>=1.9']
25 |
26 | steps:
27 |
28 | # WITHOUT OPENMM
29 | - uses: actions/checkout@v2
30 | - name: Set up Python ${{ matrix.python-version }}
31 | uses: actions/setup-python@v2
32 | with:
33 | python-version: ${{ matrix.python-version }}
34 | - name: Install dependencies
35 | run: |
36 | python -m pip install --upgrade pip
37 | pip install flake8 pytest "${{ matrix.torch-version }}" numpy nflows torchdiffeq
38 | - name: Lint with flake8
39 | run: |
40 | # stop the build if there are Python syntax errors or undefined names
41 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
42 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
43 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.sw[a-z]
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 | docs/api/generated
76 | docs/datasets
77 | docs/examples
78 | docs/nb_examples
79 |
80 |
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # celery beat schedule file
103 | celerybeat-schedule
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) 2021 Jonas Köhler, Andreas Krämer, Manuel Dibak, Leon Klein, Frank Noé
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/bgflow/__init__.py:
--------------------------------------------------------------------------------
1 | """Boltzmann Generators and Normalizing Flows in PyTorch"""
2 |
3 | # Handle versioneer
4 | from ._version import get_versions
5 | versions = get_versions()
6 | __version__ = versions['version']
7 | __git_revision__ = versions['full-revisionid']
8 | del get_versions, versions
9 |
10 | from .distribution import *
11 | from .nn import *
12 | from .factory import *
13 | from .bg import *
14 |
--------------------------------------------------------------------------------
/bgflow/distribution/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | ===============================================================================
4 | Samplers
5 | ===============================================================================
6 |
7 | .. autosummary::
8 | :toctree: generated/
9 | :template: class.rst
10 |
11 | Sampler
12 | DataSetSampler
13 | GaussianMCMCSampler
14 |
15 | ===============================================================================
16 | Distributions
17 | ===============================================================================
18 |
19 | .. autosummary::
20 | :toctree: generated/
21 | :template: class.rst
22 |
23 | TorchDistribution
24 | CustomDistribution
25 | UniformDistribution
26 | MixtureDistribution
27 | NormalDistribution
28 | TruncatedNormalDistribution
29 | MeanFreeNormalDistribution
30 | ProductEnergy
31 | ProductSampler
32 | ProductDistribution
33 |
34 | """
35 |
36 | from .distributions import *
37 | from .energy import *
38 | from .sampling import *
39 | from .normal import *
40 | from .mixture import *
41 | from .product import *
42 |
--------------------------------------------------------------------------------
/bgflow/distribution/energy/__init__.py:
--------------------------------------------------------------------------------
1 | """
 2 | .. currentmodule:: bgflow.distribution.energy
3 |
4 | ===============================================================================
5 | Double Well Potential
6 | ===============================================================================
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 | :template: class.rst
11 |
12 | DoubleWellEnergy
13 | MultiDoubleWellPotential
14 |
15 | ===============================================================================
16 | Lennard Jones Potential
17 | ===============================================================================
18 |
19 | .. autosummary::
20 | :toctree: generated/
21 | :template: class.rst
22 |
23 | LennardJonesPotential
24 |
25 | ===============================================================================
26 | OpenMMBridge
27 | ===============================================================================
28 |
29 | .. autosummary::
30 | :toctree: generated/
31 | :template: class.rst
32 |
33 | OpenMMBridge
34 | OpenMMEnergy
35 |
36 | ===============================================================================
37 | Particle Box
38 | ===============================================================================
39 |
40 | .. autosummary::
41 | :toctree: generated/
42 | :template: class.rst
43 |
44 | RepulsiveParticles
45 | HarmonicParticles
46 |
47 | ===============================================================================
48 | Clipped Energies
49 | ===============================================================================
50 |
51 | .. autosummary::
52 | :toctree: generated/
53 | :template: class.rst
54 |
55 | LinLogCutEnergy
56 | GradientClippedEnergy
57 |
58 | ===============================================================================
59 | Base
60 | ===============================================================================
61 |
62 | .. autosummary::
63 | :toctree: generated/
64 | :template: class.rst
65 |
66 | Energy
67 |
68 | """
69 |
70 |
71 | from .base import *
72 | from .double_well import *
73 | from .particles import *
74 | from .lennard_jones import *
75 | from .openmm import *
76 | from .multi_double_well_potential import *
77 | from .clipped import *
79 | from .xtb import *
80 | from .ase import *
81 |
--------------------------------------------------------------------------------
/bgflow/distribution/energy/ase.py:
--------------------------------------------------------------------------------
1 | """Wrapper around ASE (atomic simulation environment)
2 | """
3 | __all__ = ["ASEBridge", "ASEEnergy"]
4 |
5 |
6 | import warnings
7 | import torch
8 | import numpy as np
9 | from .base import _BridgeEnergy, _Bridge
10 |
11 |
12 | class ASEBridge(_Bridge):
13 | """Wrapper around Atomic Simulation Environment.
14 |
15 | Parameters
16 | ----------
17 | atoms : ase.Atoms
18 | An `Atoms` object that has a calculator attached to it.
19 | temperature : float
20 | Temperature in Kelvin.
21 | err_handling : str
22 | How to deal with exceptions inside ase. One of `["ignore", "warning", "error"]`
23 |
24 | Notes
25 | -----
26 | Requires the ase package (installable with `conda install -c conda-forge ase`).
27 |
28 | """
29 | def __init__(
30 | self,
31 | atoms,
32 | temperature: float,
33 | err_handling: str = "warning"
34 | ):
35 | super().__init__()
36 | assert hasattr(atoms, "calc")
37 | self.atoms = atoms
38 | self.temperature = temperature
39 | self.err_handling = err_handling
40 |
41 | @property
42 | def n_atoms(self):
43 | return len(self.atoms)
44 |
45 | def _evaluate_single(
46 | self,
47 | positions: torch.Tensor,
48 | evaluate_force=True,
49 | evaluate_energy=True,
50 | ):
51 | from ase.units import kB, nm
52 | kbt = kB * self.temperature
53 | energy, force = None, None
54 | try:
55 | self.atoms.positions = positions * nm
56 | if evaluate_energy:
57 | energy = self.atoms.get_potential_energy() / kbt
58 | if evaluate_force:
59 | force = self.atoms.get_forces() / (kbt / nm)
 60 |             assert energy is None or not np.isnan(energy)
 61 |             assert force is None or not np.isnan(force).any()
62 | except AssertionError as e:
 63 |             if force is not None: force[np.isnan(force)] = 0.
 64 |             energy = np.inf
65 | if self.err_handling == "warning":
66 | warnings.warn("Found nan in ase force or energy. Returning infinite energy and zero force.")
67 | elif self.err_handling == "error":
68 | raise e
69 | return energy, force
70 |
71 |
72 | class ASEEnergy(_BridgeEnergy):
73 | """Energy computation with calculators from the atomic simulation environment (ASE).
74 | Various molecular simulation programs provide wrappers for ASE,
75 | see https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html
76 | for a list of available calculators.
77 |
78 | Examples
79 | --------
80 | Use the calculator from the xtb package to compute the energy of a water molecule with the GFN2-xTB method.
81 | >>> from ase.build import molecule
82 | >>> from xtb.ase.calculator import XTB
83 | >>> water = molecule("H2O")
84 | >>> water.calc = XTB()
85 | >>> target = ASEEnergy(ASEBridge(water, 300.))
86 | >>> pos = torch.tensor(0.1*water.positions, **ctx)
87 | >>> energy = target.energy(pos)
88 |
89 | Parameters
90 | ----------
91 | ase_bridge : ASEBridge
92 | The wrapper object.
93 | two_event_dims : bool
94 | Whether to use two event dimensions.
95 | In this case, the energy call expects positions of shape (*batch_shape, n_atoms, 3).
96 | Otherwise, it expects positions of shape (*batch_shape, n_atoms * 3).
97 | """
98 | pass
99 |
--------------------------------------------------------------------------------
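The doctest above relies on a `ctx` fixture defined elsewhere in the test suite. Spelled out as a self-contained script, the usage might look like this (a sketch, assuming the optional `ase` and `xtb-python` packages are installed):

```python
import torch
from ase.build import molecule
from xtb.ase.calculator import XTB
from bgflow import ASEBridge, ASEEnergy

# ctx is just a dtype/device mapping, as used throughout the bgflow doctests
ctx = {"dtype": torch.float64, "device": "cpu"}

water = molecule("H2O")
water.calc = XTB()  # GFN2-xTB by default
target = ASEEnergy(ASEBridge(water, temperature=300.0))

# ASE positions are in Angstrom; the bridge works in nm (see _evaluate_single)
pos = torch.tensor(0.1 * water.positions, **ctx)
energy = target.energy(pos)  # dimensionless, in units of kT
```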
/bgflow/distribution/energy/clipped.py:
--------------------------------------------------------------------------------
1 | from ...utils.train import linlogcut, ClipGradient
2 | from .base import Energy
3 |
4 |
5 | __all__ = ["LinLogCutEnergy", "GradientClippedEnergy"]
6 |
7 |
8 | class LinLogCutEnergy(Energy):
9 | """Cut off energy at singularities.
10 |
11 | Parameters
12 | ----------
13 | energy : Energy
14 | high_energy : float
15 | Energies beyond this value are replaced by `u = high_energy + log(1 + energy - high_energy)`
16 | max_energy : float
17 | Upper bound for energies returned by this object.
18 | """
19 | def __init__(self, energy, high_energy=1e3, max_energy=1e9):
20 | super().__init__(energy.event_shapes)
21 | self.delegate = energy
22 | self.high_energy = high_energy
23 | self.max_energy = max_energy
24 |
25 | def _energy(self, *xs, **kwargs):
26 | u = self.delegate.energy(*xs, **kwargs)
27 | return linlogcut(u, high_val=self.high_energy, max_val=self.max_energy)
28 |
29 |
30 | class GradientClippedEnergy(Energy):
31 | """An Energy with clipped gradients. See `ClipGradient` for details."""
32 | def __init__(self, energy: Energy, gradient_clipping: ClipGradient):
33 | super().__init__(energy.event_shapes)
34 | self.delegate = energy
35 | self.clipping = gradient_clipping
36 |
37 | def _energy(self, *xs, **kwargs):
38 | return self.delegate.energy(*((self.clipping(x) for x in xs)), **kwargs)
39 |
--------------------------------------------------------------------------------
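A sketch of what the clipping buys you, using the `LennardJonesPotential` from this package (parameter values illustrative): near a particle overlap the raw energy explodes, while the wrapped energy grows only logarithmically beyond `high_energy` and is capped at `max_energy`:

```python
import torch
from bgflow import LennardJonesPotential, LinLogCutEnergy

lj = LennardJonesPotential(dim=6, n_particles=2, two_event_dims=False)
safe = LinLogCutEnergy(lj, high_energy=1e3, max_energy=1e9)

# two particles almost on top of each other -> (rm/r)^12 blows up
x = torch.tensor([[0.0, 0.0, 0.0, 1e-3, 0.0, 0.0]])
print(lj.energy(x))    # astronomically large
print(safe.energy(x))  # ~ high_energy + log(1 + u - high_energy), capped at 1e9
```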
/bgflow/distribution/energy/double_well.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Energy
4 | from numpy import pi as PI
5 |
6 |
7 | __all__ = ["DoubleWellEnergy", "MultiDimensionalDoubleWell", "MuellerEnergy", "ModifiedWolfeQuapp"]
8 |
9 |
10 | class DoubleWellEnergy(Energy):
11 | def __init__(self, dim, a=0, b=-4.0, c=1.0):
12 | super().__init__(dim)
13 | self._a = a
14 | self._b = b
15 | self._c = c
16 |
17 | def _energy(self, x):
18 | d = x[..., [0]]
19 | v = x[..., 1:]
20 | e1 = self._a * d + self._b * d.pow(2) + self._c * d.pow(4)
21 | e2 = 0.5 * v.pow(2).sum(dim=-1, keepdim=True)
22 | return e1 + e2
23 |
24 |
25 | class MultiDimensionalDoubleWell(Energy):
26 | def __init__(self, dim, a=0.0, b=-4.0, c=1.0, transformer=None):
27 | super().__init__(dim)
28 | if not isinstance(a, torch.Tensor):
29 | a = torch.tensor(a)
30 | if not isinstance(b, torch.Tensor):
31 | b = torch.tensor(b)
32 | if not isinstance(c, torch.Tensor):
33 | c = torch.tensor(c)
34 | self.register_buffer("_a", a)
35 | self.register_buffer("_b", b)
36 | self.register_buffer("_c", c)
37 | if transformer is not None:
38 | self.register_buffer("_transformer", transformer)
39 | else:
40 | self._transformer = None
41 |
42 | def _energy(self, x):
43 | if self._transformer is not None:
44 | x = torch.matmul(x, self._transformer)
45 | e1 = self._a * x + self._b * x.pow(2) + self._c * x.pow(4)
46 | return e1.sum(dim=1, keepdim=True)
47 |
48 |
49 | class MuellerEnergy(Energy):
50 | def __init__(self, dim=2, scale1=0.15, scale2=15, beta=1.):
51 | super().__init__(dim)
52 | assert dim >= 2
53 | self._scale1 = scale1
54 | self._scale2 = scale2
55 | self._beta = beta
56 |
57 | def _energy(self, x):
58 | xx = x[..., [0]]
59 | yy = x[..., [1]]
 60 |         e1 = -200 * torch.exp(-(xx - 1).pow(2) - 10 * yy.pow(2))
 61 |         e2 = -100 * torch.exp(-xx.pow(2) - 10 * (yy - 0.5).pow(2))
 62 |         e3 = -170 * torch.exp(-6.5 * (0.5 + xx).pow(2) + 11 * (xx + 0.5) * (yy - 1.5) - 6.5 * (yy - 1.5).pow(2))
 63 |         e4 = 15.0 * torch.exp(0.7 * (1 + xx).pow(2) + 0.6 * (xx + 1) * (yy - 1) + 0.7 * (yy - 1).pow(2)) + 146.7
64 | v = x[..., 2:]
65 | ev = self._scale2 * 0.5 * v.pow(2).sum(dim=-1, keepdim=True)
66 | return self._beta * (self._scale1 * (e1 + e2 + e3 + e4) + ev)
67 |
68 | @property
69 | def potential_str(self):
70 | pot_str = f'{self._scale1:g}*(-200*exp(-(x-1)^2-10*y^2)-100*exp(-x^2-10*(y-0.5)^2)-170*exp(-6.5*(0.5+x)^2+11*(x+0.5)*(y-1.5)-6.5*(y-1.5)^2)+15*exp(0.7*(1+x)^2+0.6*(x+1)*(y-1)+0.7*(y-1)^2)+146.7)'
71 | if self.dim >= 3:
72 | pot_str += f'+{self._scale2:g}*0.5*z^2'
73 | return pot_str
74 |
75 | class ModifiedWolfeQuapp(Energy):
76 | def __init__(self, dim=2, theta=-0.3*PI/2, scale1=2, scale2=15, beta=1.):
77 | super().__init__(dim)
78 | assert dim >= 2
79 | self._scale1 = scale1
80 | self._scale2 = scale2
81 | self._beta = beta
82 | self._c = torch.cos(torch.as_tensor(theta))
83 | self._s = torch.sin(torch.as_tensor(theta))
84 |
85 | def _energy(self, x):
86 | xx = self._c * x[..., [0]] - self._s * x[..., [1]]
87 | yy = self._s * x[..., [0]] + self._c * x[..., [1]]
88 | e4 = xx.pow(4) + yy.pow(4)
89 | e2 = -2 * xx.pow(2) - 4 * yy.pow(2) + 2 * xx * yy
90 | e1 = 0.8 * xx + 0.1 * yy + 9.28
91 | v = x[..., 2:]
92 | ev = self._scale2 * 0.5 * v.pow(2).sum(dim=-1, keepdim=True)
93 | return self._beta * (self._scale1 * (e4 + e2 + e1) + ev)
94 |
95 | @property
96 | def potential_str(self):
97 | x_str = f'({self._c:g}*x-{self._s:g}*y)'
98 | y_str = f'({self._s:g}*x+{self._c:g}*y)'
99 | pot_str = f'{self._scale1:g}*({x_str}^4+{y_str}^4-2*{x_str}^2-4*{y_str}^2+2*{x_str}*{y_str}+0.8*{x_str}+0.1*{y_str}+9.28)'
100 | if self.dim >= 3:
101 | pot_str += f'+{self._scale2:g}*0.5*z^2'
102 | return pot_str
103 |
--------------------------------------------------------------------------------
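As a quick sanity check on `DoubleWellEnergy` with its default parameters (`a=0`, `b=-4`, `c=1`): the marginal along the first coordinate is u(d) = -4d² + d⁴, which has minima at d = ±√2 (u = -4) and a barrier at d = 0 (u = 0); all remaining coordinates feel a unit harmonic potential:

```python
import torch
from bgflow import DoubleWellEnergy

dw = DoubleWellEnergy(dim=2)
x = torch.tensor([[2.0 ** 0.5, 0.0],   # at a well minimum
                  [0.0, 0.0]])         # on the barrier
print(dw.energy(x))  # tensor([[-4.], [0.]])
```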
/bgflow/distribution/energy/lennard_jones.py:
--------------------------------------------------------------------------------
1 | from .base import Energy
2 | from bgflow.utils import distance_vectors, distances_from_vectors
3 | import torch
4 |
5 |
6 | __all__ = ["LennardJonesPotential"]
7 |
8 |
9 | def lennard_jones_energy_torch(r, eps=1.0, rm=1.0):
10 | lj = eps * ((rm / r) ** 12 - 2 * (rm / r) ** 6)
11 | return lj
12 |
13 |
14 | class LennardJonesPotential(Energy):
15 | def __init__(
16 | self, dim, n_particles, eps=1.0, rm=1.0, oscillator=True, oscillator_scale=1., two_event_dims=True):
17 | """Energy for a Lennard-Jones cluster
18 |
19 | Parameters
20 | ----------
21 | dim : int
22 | Number of degrees of freedom ( = space dimension x n_particles)
23 | n_particles : int
24 | Number of Lennard-Jones particles
25 | eps : float
26 | LJ well depth epsilon
27 | rm : float
28 | LJ well radius R_min
29 | oscillator : bool
30 | Whether to use a harmonic oscillator as an external force
31 | oscillator_scale : float
32 | Force constant of the harmonic oscillator energy
33 | two_event_dims : bool
34 | If True, the energy expects inputs with two event dimensions (particle_id, coordinate).
35 | Else, use only one event dimension.
36 | """
37 | if two_event_dims:
38 | super().__init__([n_particles, dim//n_particles])
39 | else:
40 | super().__init__(dim)
41 | self._n_particles = n_particles
42 | self._n_dims = dim // n_particles
43 |
44 | self._eps = eps
45 | self._rm = rm
46 | self.oscillator = oscillator
47 | self._oscillator_scale = oscillator_scale
48 |
49 | def _energy(self, x):
50 | batch_shape = x.shape[:-len(self.event_shape)]
51 | x = x.view(*batch_shape, self._n_particles, self._n_dims)
52 |
53 | dists = distances_from_vectors(
54 | distance_vectors(x.view(-1, self._n_particles, self._n_dims))
55 | )
56 |
57 | lj_energies = lennard_jones_energy_torch(dists, self._eps, self._rm)
58 | lj_energies = lj_energies.view(*batch_shape, -1).sum(dim=-1) / 2
59 |
60 | if self.oscillator:
61 | osc_energies = 0.5 * self._remove_mean(x).pow(2).sum(dim=(-2, -1)).view(*batch_shape)
62 | lj_energies = lj_energies + osc_energies * self._oscillator_scale
63 |
64 | return lj_energies[:, None]
65 |
66 | def _remove_mean(self, x):
67 | x = x.view(-1, self._n_particles, self._n_dims)
68 | return x - torch.mean(x, dim=1, keepdim=True)
69 |
70 | def _energy_numpy(self, x):
71 | x = torch.Tensor(x)
72 | return self._energy(x).cpu().numpy()
73 |
--------------------------------------------------------------------------------
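A sketch checking the well depth: a single dimer at the pair distance r = rm has pair energy -eps (the double-counted pair sum is halved in `_energy`), so with the oscillator switched off the total is exactly -1:

```python
import torch
from bgflow import LennardJonesPotential

lj = LennardJonesPotential(dim=6, n_particles=2, eps=1.0, rm=1.0,
                           oscillator=False, two_event_dims=True)
x = torch.tensor([[[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0]]])  # (batch, n_particles, n_dims)
print(lj.energy(x))  # tensor([[-1.]])
```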
/bgflow/distribution/energy/multi_double_well_potential.py:
--------------------------------------------------------------------------------
1 | from .base import Energy
2 | from bgflow.utils import compute_distances
3 |
4 | __all__ = ["MultiDoubleWellPotential"]
5 |
6 |
7 | class MultiDoubleWellPotential(Energy):
 8 |     r"""Energy for a many-particle system with pairwise double-well interactions.
 9 |     The energy of the double well is given by
 10 |
 11 |     .. math::
 12 |         E_{DW}(d) = a \cdot (d - d_{\text{offset}})^4 + b \cdot (d - d_{\text{offset}})^2 + c.
13 |
14 | Parameters
15 | ----------
16 | dim : int
17 | Number of degrees of freedom ( = space dimension x n_particles)
18 | n_particles : int
19 | Number of particles
20 | a, b, c, offset : float
21 | parameters of the potential
22 | """
23 |
24 | def __init__(self, dim, n_particles, a, b, c, offset, two_event_dims=True):
25 | if two_event_dims:
26 | super().__init__([n_particles, dim // n_particles])
27 | else:
28 | super().__init__(dim)
29 | self._dim = dim
30 | self._n_particles = n_particles
31 | self._n_dimensions = dim // n_particles
32 | self._a = a
33 | self._b = b
34 | self._c = c
35 | self._offset = offset
36 |
37 | def _energy(self, x):
38 | x = x.contiguous()
39 | dists = compute_distances(x, self._n_particles, self._n_dimensions)
40 | dists = dists - self._offset
41 |
42 | energies = self._a * dists ** 4 + self._b * dists ** 2 + self._c
43 | return energies.sum(-1, keepdim=True)
44 |
--------------------------------------------------------------------------------
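A hedged usage sketch (the parameter values below are illustrative, not prescribed by the class): four particles in 2D, where each of the six pair distances contributes one double-well term:

```python
import torch
from bgflow import MultiDoubleWellPotential

mdw = MultiDoubleWellPotential(dim=8, n_particles=4,
                               a=0.9, b=-4.0, c=0.0, offset=4.0)
x = torch.randn(16, 4, 2)   # (batch, n_particles, n_dimensions)
u = mdw.energy(x)           # shape (16, 1): sum over all particle pairs
```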
/bgflow/distribution/mixture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from .energy import Energy
5 | from .sampling import Sampler
6 | from ..utils.types import assert_numpy
7 |
8 |
9 | __all__ = ["MixtureDistribution"]
10 |
11 |
12 | class MixtureDistribution(Energy, Sampler):
13 | def __init__(self, components, unnormed_log_weights=None, trainable_weights=False):
14 | assert all([c.dim == components[0].dim for c in components]),\
15 | "All mixture components must have the same dimensionality."
16 | super().__init__(components[0].dim)
17 | self._components = torch.nn.ModuleList(components)
18 |
19 | if unnormed_log_weights is None:
20 | unnormed_log_weights = torch.zeros(len(components))
21 | else:
22 | assert len(unnormed_log_weights.shape) == 1,\
23 | "Mixture weights must be a Tensor of shape `[n_components]`."
24 | assert len(unnormed_log_weights) == len(components),\
25 | "Number of mixture weights does not match number of components."
26 | if trainable_weights:
27 | self._unnormed_log_weights = torch.nn.Parameter(unnormed_log_weights)
28 | else:
29 | self.register_buffer("_unnormed_log_weights", unnormed_log_weights)
30 |
31 | @property
32 | def _log_weights(self):
33 | return torch.log_softmax(self._unnormed_log_weights, dim=-1)
34 |
35 | def _sample(self, n_samples):
36 | weights_numpy = assert_numpy(self._log_weights.exp())
37 | ns = np.random.multinomial(n_samples, weights_numpy, 1)[0]
38 | samples = [c.sample(n) for n, c in zip(ns, self._components)]
39 | return torch.cat(samples, dim=0)
40 |
41 | def _energy(self, x):
42 | energies = torch.stack([c.energy(x) for c in self._components], dim=-1)
43 | return -torch.logsumexp(-energies + self._log_weights.view(1, 1, -1), dim=-1)
44 |
45 | def _log_assignments(self, x):
46 | energies = torch.stack([c.energy(x) for c in self._components], dim=-1)
47 | return -energies
--------------------------------------------------------------------------------
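A minimal sketch of a two-component mixture (assuming `NormalDistribution(dim, mean=...)`, the signature used in `factory/distribution_factory.py`): `_energy` above computes -log Σ_k w_k exp(-E_k(x)) via `logsumexp`, and `_sample` draws per-component counts from a multinomial before sampling each component:

```python
import torch
from bgflow import NormalDistribution, MixtureDistribution

components = [
    NormalDistribution(1, mean=torch.tensor([-2.0])),
    NormalDistribution(1, mean=torch.tensor([2.0])),
]
mixture = MixtureDistribution(components)  # equal weights by default

samples = mixture.sample(1000)   # roughly half from each mode
u = mixture.energy(samples)      # mixture energy, not per-component
```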
/bgflow/distribution/sampling/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .mcmc import *
3 | from .dataset import *
4 | from .buffer import *
5 | from .iterative import *
--------------------------------------------------------------------------------
/bgflow/distribution/sampling/_iterative_helpers.py:
--------------------------------------------------------------------------------
1 | """Helper classes and functions for iterative samplers."""
2 |
3 | import torch
4 |
5 |
6 | __all__ = ["AbstractSamplerState", "default_set_samples_hook", "default_extract_sample_hook"]
7 |
8 |
9 | class AbstractSamplerState:
10 | """Defines the interface for implementations of the internal state of iterative samplers."""
11 |
12 | def as_dict(self):
13 | """Return a dictionary representing this instance. The dictionary has to define the
14 | keys that are used within `SamplerStep`s of an `IterativeSampler`, such as "samples", "energies", ...
15 | """
16 | raise NotImplementedError()
17 |
18 | def _replace(self, **kwargs):
19 | """Return a new object with changed fields.
20 | This function has to support all the keys that are used
21 | within `SamplerStep`s of an `IterativeSampler` as well as the keys "energies_up_to_date" and
22 | "forces_up_to_date"
23 | """
24 | raise NotImplementedError()
25 |
26 | def evaluate_energy_force(self, energy_model, evaluate_energies=True, evaluate_forces=True):
27 | """Return a new state with updated energies/forces."""
28 | state = self.as_dict()
29 | evaluate_energies = evaluate_energies and not state["energies_up_to_date"]
30 | energies = energy_model.energy(*state["samples"])[..., 0] if evaluate_energies else state["energies"]
31 |
32 | evaluate_forces = evaluate_forces and not state["forces_up_to_date"]
33 | forces = energy_model.force(*state["samples"]) if evaluate_forces else state["forces"]
34 | return self.replace(energies=energies, forces=forces)
35 |
36 | def replace(self, **kwargs):
37 | """Return a new state with updated fields."""
38 |
39 | # keep track of energies and forces
40 | state_dict = self.as_dict()
41 | if "energies" in kwargs:
42 | kwargs = {**kwargs, "energies_up_to_date": True}
43 | elif "samples" in kwargs:
44 | kwargs = {**kwargs, "energies_up_to_date": False}
45 | if "forces" in kwargs:
46 | kwargs = {**kwargs, "forces_up_to_date": True}
47 | elif "samples" in kwargs:
48 | kwargs = {**kwargs, "forces_up_to_date": False}
49 |
50 | # map to primary unit cell
51 | box_vectors = None
52 | if "box_vectors" in kwargs:
53 | box_vectors = kwargs["box_vectors"]
54 | elif "box_vectors" in state_dict:
55 | box_vectors = state_dict["box_vectors"]
56 | if "samples" in kwargs and box_vectors is not None:
57 | kwargs = {
58 | **kwargs,
59 | "samples": tuple(
60 | _map_to_primary_cell(x, cell)
61 | for x, cell in zip(kwargs["samples"], box_vectors)
62 | )
63 | }
64 | return self._replace(**kwargs)
65 |
66 |
67 | def default_set_samples_hook(x):
68 | """by default, use samples as is"""
69 | return x
70 |
71 |
72 | def default_extract_sample_hook(state: AbstractSamplerState):
73 | """Default extraction of samples from a SamplerState."""
74 | return state.as_dict()["samples"]
75 |
76 |
77 | def _bmv(m, bv):
78 | """Batched matrix-vector multiply."""
79 | return torch.einsum("ij,...j->...i", m, bv)
80 |
81 |
82 | def _map_to_primary_cell(x, cell):
83 | """Map coordinates to the primary unit cell of a periodic lattice.
84 |
85 | Parameters
86 | ----------
87 | x : torch.Tensor
88 | n-dimensional coordinates of shape (..., n), where n is the spatial dimension and ... denote an
89 | arbitrary number of batch dimensions.
90 | cell : torch.Tensor
91 | Lattice vectors (column-wise). Has to be upper triangular.
92 | """
93 | if cell is None:
94 | return x
95 | n = _bmv(torch.inverse(cell), x)
96 | n = torch.floor(n)
97 | return x - _bmv(cell, n)
98 |
--------------------------------------------------------------------------------
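A worked example for `_map_to_primary_cell`: with lattice vectors (2, 0) and (0, 3) stored column-wise, the fractional coordinates of (5, -1) are (2.5, -1/3); flooring and subtracting wraps the point to (1, 2) inside the primary cell:

```python
import torch
from bgflow.distribution.sampling._iterative_helpers import _map_to_primary_cell

cell = torch.tensor([[2.0, 0.0],
                     [0.0, 3.0]])       # upper-triangular lattice vectors
x = torch.tensor([5.0, -1.0])
print(_map_to_primary_cell(x, cell))    # tensor([1., 2.])
```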
/bgflow/distribution/sampling/_mcmc/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = "noe"
2 |
3 | from .analysis import *
4 | from .latent_sampling import *
5 | from .metropolis import *
6 | from .permutation import *
7 | from .umbrella_sampling import *
8 |
--------------------------------------------------------------------------------
/bgflow/distribution/sampling/_mcmc/permutation.py:
--------------------------------------------------------------------------------
1 | __author__ = "noe"
2 |
3 | import numpy as np
4 | from scipy.optimize import linear_sum_assignment
5 |
6 | from deep_boltzmann.util import ensure_traj, distance_matrix_squared
7 |
8 |
9 | class HungarianMapper:
10 | def __init__(self, xref, dim=2, identical_particles=None):
11 | """ Permutes identical particles to minimize distance to reference structure.
12 |
 13 |         For a given structure or set of structures, this finds the permutation of identical particles
14 | that minimizes the mean square distance to a given reference structure. The optimization
15 | is done by solving the linear sum assignment problem with the Hungarian algorithm.
16 |
17 | Parameters
18 | ----------
19 | xref : array
20 | reference structure
21 | dim : int
22 | number of dimensions of particle system to define relation between vector position and
23 | particle index. If dim=2, coordinate vectors are [x1, y1, x2, y2, ...].
 24 |         identical_particles : None or array
25 | indices of particles subject to permutation. If None, all particles are used
26 |
27 | """
28 | self.xref = xref
29 | self.dim = dim
30 | if identical_particles is None:
 31 |             identical_particles = np.arange(xref.size // dim)  # all particle indices
32 | self.identical_particles = identical_particles
33 | self.ip_indices = np.concatenate(
34 | [dim * self.identical_particles + i for i in range(dim)]
35 | )
36 | self.ip_indices.sort()
37 |
38 | def map(self, X):
39 | """ Maps X (configuration or trajectory) to reference structure by permuting identical particles """
40 | X = ensure_traj(X)
41 | Y = X.copy()
42 | C = distance_matrix_squared(
43 | np.tile(self.xref[:, self.ip_indices], (X.shape[0], 1)),
44 | X[:, self.ip_indices],
45 | )
46 |
47 | for i in range(C.shape[0]): # for each configuration
48 | _, col_assignment = linear_sum_assignment(C[i])
49 | assignment_components = [
50 | self.dim * col_assignment + i for i in range(self.dim)
51 | ]
52 | col_assignment = np.vstack(assignment_components).T.flatten()
53 | Y[i, self.ip_indices] = X[i, self.ip_indices[col_assignment]]
54 | return Y
55 |
56 | def is_permuted(self, X):
57 | """ Returns True for permuted configurations """
58 | X = ensure_traj(X)
59 | C = distance_matrix_squared(
60 | np.tile(self.xref[:, self.ip_indices], (X.shape[0], 1)),
61 | X[:, self.ip_indices],
62 | )
63 | isP = np.zeros(X.shape[0], dtype=bool)
64 |
65 | for i in range(C.shape[0]): # for each configuration
66 | _, col_assignment = linear_sum_assignment(C[i])
67 | assignment_components = [
68 | self.dim * col_assignment + i for i in range(self.dim)
69 | ]
70 | col_assignment = np.vstack(assignment_components).T.flatten()
71 | if not np.all(col_assignment == np.arange(col_assignment.size)):
72 | isP[i] = True
73 | return isP
74 |
--------------------------------------------------------------------------------
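A toy illustration of the assignment step (deliberately not using `HungarianMapper` itself, since it depends on the `deep_boltzmann` helpers imported above): two identical particles in 2D are matched to a reference by solving the linear sum assignment problem on the squared-distance cost matrix:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

xref = np.array([[0.0, 0.0], [1.0, 1.0]])  # reference positions
x = np.array([[1.1, 0.9], [0.0, 0.1]])     # same particles, order swapped
C = ((xref[:, None, :] - x[None, :, :]) ** 2).sum(-1)  # cost: squared distances
_, col = linear_sum_assignment(C)
print(col)     # [1 0] -> the optimal matching swaps the particle order
print(x[col])  # re-permuted configuration, now aligned with xref
```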
/bgflow/distribution/sampling/base.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Tuple
3 | import torch
4 | from ...utils.types import unpack_tensor_tuple, pack_tensor_in_list
5 |
6 | __all__ = ["Sampler"]
7 |
8 |
9 | class Sampler(torch.nn.Module):
10 | """Abstract base class for samplers.
11 |
12 | Parameters
13 | ----------
14 | return_hook : Callable, optional
15 | A function to postprocess the samples. This can (for example) be used to
16 | only return samples at a selected thermodynamic state of a replica exchange sampler
17 | or to combine the batch and sample dimension.
18 | The function takes a list of tensors and should return a list of tensors.
19 | Each tensor contains a batch of samples.
20 | """
21 |
22 | def __init__(self, return_hook=lambda x: x, **kwargs):
23 | super().__init__(**kwargs)
24 | self.return_hook = return_hook
25 |
26 | def _sample_with_temperature(self, n_samples, temperature, *args, **kwargs):
27 | raise NotImplementedError()
28 |
29 | def _sample(self, n_samples, *args, **kwargs):
30 | raise NotImplementedError()
31 |
32 | def sample(self, n_samples, temperature=1.0, *args, **kwargs):
33 | """Create a number of samples.
34 |
35 | Parameters
36 | ----------
37 | n_samples : int
38 | The number of samples to be created.
39 | temperature : float, optional
40 | The relative temperature at which to create samples.
 41 |             Only available for samplers that implement `_sample_with_temperature`.
42 |
43 | Returns
44 | -------
45 | samples : Union[torch.Tensor, Tuple[torch.Tensor, ...]]
46 | If this sampler reflects a joint distribution of multiple tensors,
47 | it returns a tuple of tensors, each of which have length n_samples.
48 | Otherwise it returns a single tensor of length n_samples.
49 | """
50 | if isinstance(temperature, float) and temperature == 1.0:
51 | samples = self._sample(n_samples, *args, **kwargs)
52 | else:
53 | samples = self._sample_with_temperature(n_samples, temperature, *args, **kwargs)
54 | samples = pack_tensor_in_list(samples)
55 | return unpack_tensor_tuple(self.return_hook(samples))
56 |
57 | def sample_to_cpu(self, n_samples, batch_size=64, *args, **kwargs):
58 | """A utility method for creating many samples that might not fit into GPU memory."""
59 | with torch.no_grad():
60 | samples = self.sample(min(n_samples, batch_size), *args, **kwargs)
61 | samples = pack_tensor_in_list(samples)
62 | samples = [tensor.detach().cpu() for tensor in samples]
63 | while len(samples[0]) < n_samples:
64 | new_samples = self.sample(min(n_samples-len(samples[0]), batch_size), *args, **kwargs)
65 | new_samples = pack_tensor_in_list(new_samples)
66 | for i, new in enumerate(new_samples):
67 | samples[i] = torch.cat([samples[i], new.detach().cpu()], dim=0)
68 | return unpack_tensor_tuple(samples)
69 |
--------------------------------------------------------------------------------
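A minimal sketch of the `Sampler` contract: a subclass only has to implement `_sample`; `_sample_with_temperature` stays optional, and `return_hook` and `sample_to_cpu` come for free:

```python
import torch
from bgflow import Sampler

class StandardNormalSampler(Sampler):
    """Toy sampler drawing from a standard normal (for illustration only)."""
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim

    def _sample(self, n_samples):
        return torch.randn(n_samples, self.dim)

sampler = StandardNormalSampler(2)
x = sampler.sample(10)                             # shape (10, 2)
big = sampler.sample_to_cpu(1000, batch_size=128)  # batched, detached, on CPU
```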
/bgflow/factory/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from .tensor_info import *
4 | from .transformer_factory import *
5 | from .conditioner_factory import *
6 | from .distribution_factory import *
7 | from .icmarginals import *
8 | from .generator_builder import *
9 |
--------------------------------------------------------------------------------
/bgflow/factory/distribution_factory.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import bgflow as bg
4 |
5 |
6 | __all__ = ["make_distribution"]
7 |
8 | # === Prior Factory ===
9 |
10 |
11 | def make_distribution(distribution_type, shape, **kwargs):
12 | factory = DISTRIBUTION_FACTORIES[distribution_type]
13 | return factory(shape=shape, **kwargs)
14 |
15 |
16 | def _make_uniform_distribution(shape, device=None, dtype=None, **kwargs):
17 | defaults = {
18 | "low": torch.zeros(shape),
19 | "high": torch.ones(shape)
20 | }
21 | defaults.update(kwargs)
22 | for key in defaults:
23 | if isinstance(defaults[key], torch.Tensor):
24 | defaults[key] = defaults[key].to(device=device, dtype=dtype)
25 | return bg.UniformDistribution(**defaults)
26 |
27 |
28 | def _make_normal_distribution(shape, device=None, dtype=None, **kwargs):
29 | defaults = {
30 | "dim": shape,
31 | "mean": torch.zeros(shape),
32 | }
33 | defaults.update(kwargs)
34 | for key in defaults:
35 | if isinstance(defaults[key], torch.Tensor):
36 | defaults[key] = defaults[key].to(device=device, dtype=dtype)
37 | return bg.NormalDistribution(**defaults)
38 |
39 |
40 | def _make_truncated_normal_distribution(shape, device=None, dtype=None, **kwargs):
41 | defaults = {
42 | "mu": torch.zeros(shape),
43 | "sigma": torch.ones(shape),
44 | }
45 | defaults.update(kwargs)
46 | for key in defaults:
47 | if isinstance(defaults[key], torch.Tensor):
48 | defaults[key] = defaults[key].to(device=device, dtype=dtype)
49 | return bg.TruncatedNormalDistribution(**defaults)
50 |
51 |
52 | DISTRIBUTION_FACTORIES = {
53 | bg.UniformDistribution: _make_uniform_distribution,
54 | bg.NormalDistribution: _make_normal_distribution,
55 | bg.TruncatedNormalDistribution: _make_truncated_normal_distribution
56 | }
57 |
58 |
--------------------------------------------------------------------------------
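A short usage sketch: the factory dispatches on the distribution class itself, so a prior can be swapped by configuration while its tensor-valued defaults are moved to the requested device/dtype:

```python
import torch
import bgflow as bg

# builds NormalDistribution(dim=8, mean=zeros(8)); defaults overridable via kwargs
prior = bg.make_distribution(bg.NormalDistribution, shape=8,
                             device="cpu", dtype=torch.float32)
uniform = bg.make_distribution(bg.UniformDistribution, shape=8)
```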
/bgflow/factory/transformer_factory.py:
--------------------------------------------------------------------------------
1 | """Factory for flow transformations."""
2 |
3 | import torch
4 | from ..nn.flow.inverted import InverseFlow
5 | from ..nn.flow.transformer.affine import AffineTransformer
6 | from ..nn.flow.transformer.spline import ConditionalSplineTransformer
7 |
8 | __all__ = ["make_transformer"]
9 |
10 |
11 | def make_transformer(transformer_type, what, shape_info, conditioners, inverse=False, **kwargs):
12 | """Factory function.
13 |
14 | Parameters
15 | ----------
 16 |     transformer_type : type (a key of TRANSFORMER_FACTORIES, e.g. AffineTransformer)
17 | """
18 | factory = TRANSFORMER_FACTORIES[transformer_type]
19 | transformer = factory(what=what, shape_info=shape_info, conditioners=conditioners, **kwargs)
20 | if inverse:
21 | transformer = InverseFlow(transformer)
22 | return transformer
23 |
24 |
25 | def _make_spline_transformer(what, shape_info, conditioners, **kwargs):
26 | return ConditionalSplineTransformer(
27 | is_circular=shape_info.is_circular(what),
28 | **conditioners,
29 | **kwargs
30 | )
31 |
32 |
33 | def _make_affine_transformer(what, shape_info, conditioners, **kwargs):
34 | if shape_info.dim_circular(what) not in [0, shape_info[what[0]][-1]]:
35 | raise NotImplementedError(
36 | "Circular affine transformers are currently "
37 | "not supported for partly circular indices."
38 | )
39 | return AffineTransformer(
40 | **conditioners,
41 | is_circular=shape_info.dim_circular(what) > 0,
42 | **kwargs
43 | )
44 |
45 | # def _make_sigmoid_transformer(
46 | # what,
47 | # shape_info,
48 | # conditioners,
49 | # smoothness_type="type1",
50 | # zero_boundary_left=False,
51 | # zero_boundary_right=False,
52 | # **kwargs
53 | # ):
54 | # assert all(field.is_circular == what[0].is_circular for field in what)
55 | # transformer = bg.MixtureCDFTransformer(
56 | # compute_weights=conditioners["weights"],
57 | # compute_components=bg.AffineSigmoidComponents(
58 | # conditional_ramp=bg.SmoothRamp(
59 | # compute_alpha=conditioners["alphas"],
60 | # unimodal=True,
61 | # ramp_type=smoothness_type
62 | # ),
63 | # compute_params=conditioners["params"],
64 | # periodic=what[0].is_circular,
65 | # min_density=torch.tensor(1e-6),
66 | # log_sigma_bound=torch.tensor(1.),
67 | # zero_boundary_left=zero_boundary_left,
68 | # zero_boundary_right=zero_boundary_right,
69 | # **kwargs
70 | # )
71 | # )
72 | # transformer = bg.WrapCDFTransformerWithInverse(
73 | # transformer=transformer,
74 | # oracle=bg.GridInversion( #bg.BisectionRootFinder(
75 | # transformer=transformer,
76 | # compute_init_grid=lambda x, y: torch.linspace(0, 1, 100).view(-1, 1, 1).repeat(1, *y.shape).to(y)
77 | # #abs_tol=torch.tensor(1e-5), max_iters=max_iters, verbose=False, raise_exception=True
78 | # )
79 | # )
80 | # return transformer
81 |
82 |
83 | TRANSFORMER_FACTORIES = {
84 | ConditionalSplineTransformer: _make_spline_transformer,
85 | AffineTransformer: _make_affine_transformer,
86 | # MixtureCDFTransformer: _make_sigmoid_transformer
87 | }
88 |
89 |
--------------------------------------------------------------------------------
/bgflow/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .dense import *
2 | from .periodic import *
3 | from .flow import *
4 | from .training import *
5 |
--------------------------------------------------------------------------------
/bgflow/nn/dense.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..utils.types import is_list_or_tuple
4 |
5 |
6 | __all__ = ["DenseNet", "MeanFreeDenseNet"]
7 |
8 |
9 | class DenseNet(torch.nn.Module):
10 | def __init__(self, n_units, activation=None, weight_scale=1.0, bias_scale=0.0):
11 | """
12 | Simple multi-layer perceptron.
13 |
 14 |         Parameters
 15 |         ----------
16 | n_units : List / Tuple of integers.
17 | activation : Non-linearity or List / Tuple of non-linearities.
18 | If List / Tuple then each nonlinearity will be placed after each respective hidden layer.
19 | If just a single non-linearity, will be applied to all hidden layers.
20 | If set to None no non-linearity will be applied.
21 | """
22 | super().__init__()
23 |
24 | dims_in = n_units[:-1]
25 | dims_out = n_units[1:]
26 |
27 | if is_list_or_tuple(activation):
28 | assert len(activation) == len(n_units) - 2
29 |
30 | layers = []
31 | for i, (dim_in, dim_out) in enumerate(zip(dims_in, dims_out)):
32 | layers.append(torch.nn.Linear(dim_in, dim_out))
33 | layers[-1].weight.data *= weight_scale
34 | if bias_scale > 0.0:
35 | layers[-1].bias.data = (
36 | torch.Tensor(layers[-1].bias.data).uniform_() * bias_scale
37 | )
38 | if i < len(n_units) - 2:
39 | if activation is not None:
40 | if is_list_or_tuple(activation):
41 | layers.append(activation[i])
42 | else:
43 | layers.append(activation)
44 |
45 | self._layers = torch.nn.Sequential(*layers)
46 |
47 | def forward(self, x):
48 | return self._layers(x)
49 |
50 |
51 | class MeanFreeDenseNet(DenseNet):
52 | def forward(self, x):
53 | y = self._layers(x)
54 | return y - y.mean(dim=1, keepdim=True)
55 |
--------------------------------------------------------------------------------
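For example, a 2-64-64-2 network with a ReLU after each hidden layer (per the docstring, no non-linearity is placed after the output layer):

```python
import torch
from bgflow import DenseNet

net = DenseNet([2, 64, 64, 2], activation=torch.nn.ReLU())
y = net(torch.randn(10, 2))  # shape (10, 2)
```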
/bgflow/nn/flow/__init__.py:
--------------------------------------------------------------------------------
1 | """
 2 | .. currentmodule:: bgflow.nn.flow
3 |
4 | ===============================================================================
5 | Coupling flows
6 | ===============================================================================
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 | :template: class.rst
11 |
12 | CouplingFlow
13 | SplitFlow
14 | MergeFlow
15 | SwapFlow
16 | WrapFlow
17 | Transformer
18 | AffineTransformer
19 | TruncatedGaussianTransformer
20 | ConditionalSplineTransformer
21 | ScalingLayer
22 | EntropyScalingLayer
23 |
24 | ===============================================================================
25 | Continuous Normalizing Flows
26 | ===============================================================================
27 |
28 | .. autosummary::
29 | :toctree: generated/
30 | :template: class.rst
31 |
32 | DiffEqFlow
33 |
34 | Dynamics Functions
35 | ---------------------
36 |
37 | .. autosummary::
38 | :toctree: generated/
39 | :template: class.rst
40 |
41 | BlackBoxDynamics
42 | TimeIndependentDynamics
43 | KernelDynamics
44 | DensityDynamics
45 |
46 |
47 | Jacobian Trace Estimators
48 | ------------------------------
49 |
50 | .. autosummary::
51 | :toctree: generated/
52 | :template: class.rst
53 |
54 | BruteForceEstimator
55 | HutchinsonEstimator
56 |
57 | ===============================================================================
58 | Stochastic Normalizing Flows
59 | ===============================================================================
60 |
61 | .. autosummary::
62 | :toctree: generated/
63 | :template: class.rst
64 |
65 | MetropolisMCFlow
66 | BrownianFlow
67 | LangevinFlow
68 | StochasticAugmentation
69 | OpenMMStochasticFlow
70 | PathProbabilityIntegrator
71 | BrownianPathProbabilityIntegrator
72 |
73 | ===============================================================================
74 | Internal Coordinate Transformations
75 | ===============================================================================
76 |
77 | .. autosummary::
78 | :toctree: generated/
79 | :template: class.rst
80 |
81 | RelativeInternalCoordinateTransformation
82 | GlobalInternalCoordinateTransformation
83 | MixedCoordinateTransformation
84 | WhitenFlow
85 |
86 | ===============================================================================
87 | CDF Transformations
88 | ===============================================================================
89 |
90 | .. autosummary::
91 | :toctree: generated/
92 | :template: class.rst
93 |
94 | CDFTransform
95 | DistributionTransferFlow
96 | ConstrainGaussianFlow
97 |
98 | ===============================================================================
99 | Base
100 | ===============================================================================
101 |
102 | .. autosummary::
103 | :toctree: generated/
104 | :template: class.rst
105 |
106 | Flow
107 | InverseFlow
108 | SequentialFlow
109 |
110 | ===============================================================================
111 | Other
112 | ===============================================================================
113 | Docs and/or classification required
114 |
115 | .. autosummary::
116 | :toctree: generated/
117 | :template: class.rst
118 |
119 | AffineFlow
120 | CheckerboardFlow
121 | BentIdentity
122 | FunnelFlow
123 | KroneckerProductFlow
124 | PseudoOrthogonalFlow
125 | InvertiblePPPP
126 | PPPPScheduler
127 | TorchTransform
128 | TriuFlow
129 | BNARFlow
130 | """
131 |
132 | from .base import *
133 | from .crd_transform import *
134 | from .dynamics import *
135 | from .estimator import *
136 | from .stochastic import *
137 | from .transformer import *
138 |
139 | from .affine import *
140 | from .coupling import *
141 | from .funnel import FunnelFlow
142 | from .kronecker import KroneckerProductFlow
143 | from .sequential import SequentialFlow
144 | from .inverted import *
145 | from .checkerboard import CheckerboardFlow
146 | from .bnaf import BNARFlow
147 | from .elementwise import *
148 | from .orthogonal import *
149 | from .triangular import *
150 | from .pppp import *
151 | from .diffeq import DiffEqFlow
152 | from .cdf import *
153 | from .torchtransform import *
154 | from .modulo import *
155 |
--------------------------------------------------------------------------------
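A hedged sketch of composing the classes listed above (assuming `SequentialFlow` accepts a list of flow blocks, as in `sequential.py`): `InverseFlow` swaps a block's forward and inverse directions, and every flow call returns the transformed tensor together with its log-Jacobian determinant:

```python
import torch
from bgflow import AffineFlow, InverseFlow, SequentialFlow

flow = SequentialFlow([AffineFlow(2), InverseFlow(AffineFlow(2))])
z = torch.randn(10, 2)
x, dlogp = flow(z)                     # forward direction
z2, dlogp_inv = flow(x, inverse=True)  # inverse direction
```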
/bgflow/nn/flow/affine.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Flow
4 |
5 |
6 | __all__ = ["AffineFlow"]
7 |
8 |
9 | class AffineFlow(Flow):
10 | def __init__(self, n_dims, use_scaling=True, use_translation=True):
11 | super().__init__()
12 | self._n_dims = n_dims
 13 |         # scaling and translation parameters are optional; unused ones stay None
14 | if use_scaling:
15 | self._log_sigma = torch.nn.Parameter(torch.zeros(self._n_dims))
16 | else:
17 | self._log_sigma = None
18 | if use_translation:
19 | self._mu = torch.nn.Parameter(torch.zeros(self._n_dims))
20 | else:
21 | self._mu = None
22 |
23 | def _forward(self, x, **kwargs):
24 | assert x.shape[-1] == self._n_dims, "dimension `x` does not match `n_dims`"
25 | dlogp = torch.zeros(*x.shape[:-1], 1).to(x)
26 | if self._log_sigma is not None:
27 | sigma = torch.exp(self._log_sigma.to(x))
28 | dlogp = dlogp + self._log_sigma.sum()
29 | x = sigma * x
30 | if self._mu is not None:
31 | x = x + self._mu.to(x)
32 | return x, dlogp
33 |
34 | def _inverse(self, x, **kwargs):
35 | assert x.shape[-1] == self._n_dims, "dimension `x` does not match `n_dims`"
36 | dlogp = torch.zeros(*x.shape[:-1], 1).to(x)
37 | if self._mu is not None:
38 | x = x - self._mu.to(x)
39 | if self._log_sigma is not None:
40 | sigma_inv = torch.exp(-self._log_sigma.to(x))
41 | dlogp = dlogp - self._log_sigma.sum()
42 | x = sigma_inv * x
43 | return x, dlogp
44 |
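A minimal usage sketch, assuming the default initialization (log_sigma = 0, mu = 0, i.e. the identity map): forward and inverse compose to the identity and their log-determinants cancel.

    import torch
    from bgflow.nn.flow.affine import AffineFlow

    flow = AffineFlow(n_dims=3)
    x = torch.randn(5, 3)
    y, dlogp = flow.forward(x)                       # y = exp(log_sigma) * x + mu
    x_back, dlogp_inv = flow.forward(y, inverse=True)
    assert torch.allclose(x, x_back, atol=1e-6)      # round trip recovers x
    assert torch.allclose(dlogp + dlogp_inv, torch.zeros(5, 1))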
--------------------------------------------------------------------------------
/bgflow/nn/flow/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | __all__ = ["Flow"]
5 |
6 |
7 | class Flow(torch.nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def _forward(self, *xs, **kwargs):
12 | raise NotImplementedError()
13 |
14 | def _inverse(self, *xs, **kwargs):
15 | raise NotImplementedError()
16 |
17 | def forward(self, *xs, inverse=False, **kwargs):
18 | """
19 | Forward method of the flow.
20 | Computes the forward or inverse direction of the flow.
21 |
22 | Parameters
23 | ----------
24 | xs : torch.Tensor
25 | Input of the flow
26 |
27 | inverse : boolean
28 |             If True, computes the inverse transform; otherwise the forward transform.
29 | """
30 | if inverse:
31 | return self._inverse(*xs, **kwargs)
32 | else:
33 | return self._forward(*xs, **kwargs)
34 |
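The intended extension pattern, sketched with a hypothetical subclass (not part of bgflow): implement `_forward` and `_inverse`, each returning the transformed tensor together with the log Jacobian determinant.

    import torch
    from bgflow.nn.flow.base import Flow

    class ShiftFlow(Flow):
        """Volume-preserving translation y = x + b (illustrative example)."""
        def __init__(self, n_dims):
            super().__init__()
            self.b = torch.nn.Parameter(torch.zeros(n_dims))

        def _forward(self, x, **kwargs):
            # a pure translation leaves volume unchanged: log|det J| = 0
            return x + self.b, torch.zeros(*x.shape[:-1], 1).to(x)

        def _inverse(self, x, **kwargs):
            return x - self.b, torch.zeros(*x.shape[:-1], 1).to(x)

    y, dlogp = ShiftFlow(2).forward(torch.randn(4, 2))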
--------------------------------------------------------------------------------
/bgflow/nn/flow/checkerboard.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from itertools import product
4 |
5 | from .base import Flow
6 |
7 |
8 | def _make_checkerboard_idxs(sz):
9 | even = np.arange(sz, dtype=np.int64) % 2
10 | odd = 1 - even
11 | grid = np.arange(sz * sz, dtype=np.int64)
12 | idxs = []
13 | for i, j in product([odd, even], repeat=2):
14 | mask = np.outer(i, j).astype(bool).reshape(-1)
15 | chunk = grid[mask]
16 | idxs.append(chunk)
17 | return np.concatenate(idxs)
18 |
19 |
20 | def _checkerboard_2x2_masks(sz):
21 | mask = _make_checkerboard_idxs(sz)
22 | inv_mask = np.argsort(mask)
23 | offset = sz ** 2 // 4
24 | sub_masks = [
25 | mask[i * offset:(i+1) * offset]
26 | for i in range(4)
27 | ]
28 | return inv_mask, sub_masks
29 |
30 |
31 | class CheckerboardFlow(Flow):
32 | def __init__(self, size):
33 | super().__init__()
34 | self._size = size
35 | inv_mask, submasks = _checkerboard_2x2_masks(size)
36 | self.register_buffer("_sub_masks", torch.LongTensor(submasks))
37 | self.register_buffer("_inv_mask", torch.LongTensor(inv_mask))
38 |
39 | def _forward(self, *xs, **kwargs):
40 | n_batch = xs[0].shape[0]
41 |         dlogp = torch.zeros(n_batch).to(xs[0])
42 | sz = self._size // 2
43 | assert len(xs) == 1
44 | x = xs[0]
45 | assert len(x.shape) == 4 and x.shape[1] == self._size and x.shape[2] == self._size,\
46 | "`x` needs to be of shape `[n_batch, size, size, n_filters]`"
47 | x = x.view(n_batch, self._size ** 2, -1)
48 | xs = []
49 | for i in range(4):
50 | patch = x[:, self._sub_masks[i], :].view(n_batch, sz, sz, -1)
51 | xs.append(patch)
52 |         return (*xs, dlogp)
53 |
54 |
55 | def _inverse(self, *xs, **kwargs):
56 | n_batch = xs[0].shape[0]
57 |         dlogp = torch.zeros(n_batch).to(xs[0])
58 | sz = self._size // 2
59 | assert len(xs) == 4
60 | assert all(x.shape[1] == self._size // 2 and x.shape[2] == self._size // 2 for x in xs),\
61 |             "each `x` must have shape `[n_batch, size/2, size/2, n_filters]`"
62 | xs = [x.view(n_batch, sz ** 2, -1) for x in xs]
63 |         x = torch.cat(xs, dim=-2)[:, self._inv_mask, :].view(
64 | n_batch, self._size, self._size, -1
65 | )
66 | return x, dlogp
67 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/crd_transform/__init__.py:
--------------------------------------------------------------------------------
1 | from .pca import *
2 | from .ic import *
3 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/crd_transform/pca.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from bgflow.nn.flow.base import Flow
4 |
5 |
6 | __all__ = ["WhitenFlow"]
7 |
8 |
9 | def _pca(X0, keepdims=None):
10 | """Implements PCA in Numpy.
11 |
12 | This is not written for training in torch because derivatives of eig are not implemented
13 |
14 | """
15 | if keepdims is None:
16 | keepdims = X0.shape[1]
17 |
18 | # pca
19 | X0mean = X0.mean(axis=0)
20 | X0meanfree = X0 - X0mean
21 | C = np.matmul(X0meanfree.T, X0meanfree) / (X0meanfree.shape[0] - 1.0)
22 | eigval, eigvec = np.linalg.eigh(C)
23 |
24 | # sort in descending order and keep only the wanted eigenpairs
25 | I = np.argsort(eigval)[::-1]
26 | I = I[:keepdims]
27 | eigval = eigval[I]
28 | std = np.sqrt(eigval)
29 | eigvec = eigvec[:, I]
30 |
31 | # whiten and unwhiten matrices
32 | Twhiten = np.matmul(eigvec, np.diag(1.0 / std))
33 | Tblacken = np.matmul(np.diag(std), eigvec.T)
34 | return X0mean, Twhiten, Tblacken, std
35 |
36 |
37 | class WhitenFlow(Flow):
38 | def __init__(self, X0, keepdims=None, whiten_inverse=True):
39 | """Performs static whitening of the data given PCA of X0
40 |
41 |         Parameters
42 |         ----------
43 | X0 : array
44 | Initial Data on which PCA will be computed.
45 | keepdims : int or None
46 | Number of dimensions to keep. By default, all dimensions will be kept
47 | whiten_inverse : bool
48 | Whitens when calling inverse (default). Otherwise when calling forward
49 |
50 | """
51 | super().__init__()
52 | if keepdims is None:
53 | keepdims = X0.shape[1]
54 | self.dim = X0.shape[1]
55 | self.keepdims = keepdims
56 | self.whiten_inverse = whiten_inverse
57 |
58 | X0_np = X0.detach().cpu().numpy()
59 | X0mean, Twhiten, Tblacken, std = _pca(X0_np, keepdims=keepdims)
60 | # self.X0mean = torch.tensor(X0mean)
61 | self.register_buffer("X0mean", torch.tensor(X0mean).to(X0))
62 | # self.Twhiten = torch.tensor(Twhiten)
63 | self.register_buffer("Twhiten", torch.tensor(Twhiten).to(X0))
64 | # self.Tblacken = torch.tensor(Tblacken)
65 | self.register_buffer("Tblacken", torch.tensor(Tblacken).to(X0))
66 | # self.std = torch.tensor(std)
67 | self.register_buffer("std", torch.tensor(std).to(X0))
68 | if torch.any(self.std <= 0):
69 | raise ValueError(
70 | "Cannot construct whiten layer because trying to keep nonpositive eigenvalues."
71 | )
72 | self.jacobian_xz = -torch.sum(torch.log(self.std))
73 |
74 | def _whiten(self, x):
75 | # Whiten
76 | output_z = torch.matmul(x - self.X0mean, self.Twhiten)
77 | # if self.keepdims < self.dim:
78 | # junk_dims = self.dim - self.keepdims
79 | # output_z = torch.cat([output_z, torch.Tensor(x.shape[0], junk_dims).normal_()], dim=1)
80 | # Jacobian
81 | dlogp = self.jacobian_xz * torch.ones((x.shape[0], 1)).to(x)
82 |
83 | return output_z, dlogp
84 |
85 | def _blacken(self, x):
86 | # if we have reduced the dimension, we ignore the last dimensions from the z-direction.
87 | # if self.keepdims < self.dim:
88 | # x = x[:, 0:self.keepdims]
89 | output_x = torch.matmul(x, self.Tblacken) + self.X0mean
90 | # Jacobian
91 | dlogp = -self.jacobian_xz * torch.ones((x.shape[0], 1)).to(x)
92 |
93 | return output_x, dlogp
94 |
95 | def _forward(self, x, *args, **kwargs):
96 | if self.whiten_inverse:
97 | y, dlogp = self._blacken(x)
98 | else:
99 | y, dlogp = self._whiten(x)
100 | return y, dlogp
101 |
102 | def _inverse(self, x, *args, **kwargs):
103 | if self.whiten_inverse:
104 | y, dlogp = self._whiten(x)
105 | else:
106 | y, dlogp = self._blacken(x)
107 | return y, dlogp
108 |
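A usage sketch on synthetic correlated data; with `whiten_inverse=False` the forward direction performs the whitening:

    import torch
    from bgflow.nn.flow.crd_transform.pca import WhitenFlow

    torch.manual_seed(0)
    X0 = torch.randn(1000, 3) @ torch.tensor([[2.0, 0.0, 0.0],
                                              [0.5, 1.0, 0.0],
                                              [0.0, 0.3, 0.1]])
    flow = WhitenFlow(X0, whiten_inverse=False)
    z, dlogp = flow.forward(X0)         # approximately zero mean, unit covariance
    x, _ = flow.forward(z, inverse=True)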
--------------------------------------------------------------------------------
/bgflow/nn/flow/diffeq.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Flow
4 | from .dynamics import (
5 | DensityDynamics,
6 | InversedDynamics,
7 | AnodeDynamics
8 | )
9 |
10 |
11 | class DiffEqFlow(Flow):
12 | """
13 | Neural Ordinary Differential Equations flow :footcite:`chen2018neural`
14 |     with the choice of optimize-then-discretize (use_checkpoints=False)
15 |     or discretize-then-optimize :footcite:`gholami2019anode` (use_checkpoints=True) for the ODE solver.
16 |
17 | References
18 | ----------
19 | .. footbibliography::
20 |
21 | """
22 |
23 | def __init__(
24 | self,
25 | dynamics,
26 | integrator="dopri5",
27 | atol=1e-10,
28 | rtol=1e-5,
29 | n_time_steps=2,
30 | t_max=1.,
31 | use_checkpoints=False,
32 | **kwargs
33 | ):
34 | super().__init__()
35 | self._dynamics = DensityDynamics(dynamics)
36 | self._inverse_dynamics = DensityDynamics(InversedDynamics(dynamics, t_max))
37 | self._integrator_method = integrator
38 | self._integrator_atol = atol
39 | self._integrator_rtol = rtol
40 | self._n_time_steps = n_time_steps
41 | self._t_max = t_max
42 | self._use_checkpoints = use_checkpoints
43 | self._kwargs = kwargs
44 |
45 | def _forward(self, *xs, **kwargs):
46 | return self._run_ode(*xs, dynamics=self._dynamics, **kwargs)
47 |
48 | def _inverse(self, *xs, **kwargs):
49 | return self._run_ode(*xs, dynamics=self._inverse_dynamics, **kwargs)
50 |
51 | def _run_ode(self, *xs, dynamics, **kwargs):
52 | """
53 | Runs the ODE solver.
54 |
55 | Parameters
56 | ----------
57 | xs : PyTorch tensor
58 | The current configuration of the system
59 | dynamics : PyTorch module
60 | A dynamics function that computes the change of the system and its density.
61 |
62 | Returns
63 | -------
64 | ys : PyTorch tensor
65 | The new configuration of the system after being propagated by the dynamics.
66 | dlogp : PyTorch tensor
67 | The change in log density due to the dynamics.
68 | """
69 |
70 | assert (all(x.shape[0] == xs[0].shape[0] for x in xs[1:]))
71 | n_batch = xs[0].shape[0]
72 | logp_init = torch.zeros(n_batch, 1).to(xs[0])
73 | state = (*xs, logp_init)
74 | ts = torch.linspace(0.0, self._t_max, self._n_time_steps).to(xs[0])
75 | kwargs = {**self._kwargs, **kwargs}
76 | if not self._use_checkpoints:
77 | from torchdiffeq import odeint_adjoint
78 | *ys, dlogp = odeint_adjoint(
79 | dynamics,
80 | state,
81 | t=ts,
82 | method=self._integrator_method,
83 | rtol=self._integrator_rtol,
84 | atol=self._integrator_atol,
85 | options=kwargs
86 | )
87 | ys = [y[-1] for y in ys]
88 | else:
89 | from anode.adjoint import odesolver_adjoint
90 | state = torch.cat(state, dim=-1)
91 | anode_dynamics = AnodeDynamics(dynamics)
92 | state = odesolver_adjoint(anode_dynamics, state, options=kwargs)
93 | ys = [state[:, :-1]]
94 | dlogp = [state[:, -1:]]
95 | dlogp = dlogp[-1]
96 | return (*ys, dlogp)
97 |
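A sketch of assembling a continuous normalizing flow from the pieces in this package (`BlackBoxDynamics`, `BruteForceEstimator`, `TimeIndependentDynamics`); the default code path requires the optional `torchdiffeq` dependency:

    import torch
    from bgflow.nn.flow.diffeq import DiffEqFlow
    from bgflow.nn.flow.dynamics import BlackBoxDynamics, TimeIndependentDynamics
    from bgflow.nn.flow.estimator import BruteForceEstimator

    net = torch.nn.Sequential(
        torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 2)
    )
    dynamics = BlackBoxDynamics(
        dynamics_function=TimeIndependentDynamics(net),
        divergence_estimator=BruteForceEstimator(),
    )
    flow = DiffEqFlow(dynamics)
    x = torch.randn(8, 2)
    y, dlogp = flow.forward(x)   # integrates dx/dt = net(x) from t=0 to t_max=1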
--------------------------------------------------------------------------------
/bgflow/nn/flow/dynamics/__init__.py:
--------------------------------------------------------------------------------
1 | from .density import DensityDynamics
2 | from .inversed import InversedDynamics
3 | from .blackbox import BlackBoxDynamics
4 | from .simple import TimeIndependentDynamics
5 | from .kernel_dynamic import KernelDynamics
6 | from .anode_dynamic import AnodeDynamics
--------------------------------------------------------------------------------
/bgflow/nn/flow/dynamics/anode_dynamic.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AnodeDynamics(torch.nn.Module):
5 |     """Wrapper class to allow the use of the ANODE ODE solver.
6 |
7 | Attributes
8 | ----------
9 | dynamics : torch.nn.Module
10 | A dynamics function that computes the change of the system and its density.
11 | """
12 |
13 | def __init__(self, dynamics):
14 | super().__init__()
15 | self._dynamics = dynamics
16 |
17 | def forward(self, t, state):
18 | """
19 |         Converts the concatenated state, which is required by the ANODE ODE solver,
20 |         into the tuple (`xs`, `dlogp`) expected by the wrapped dynamics function.
21 |         The output is then concatenated again for the ANODE ODE solver.
22 |
23 | Parameters
24 | ----------
25 | t : PyTorch tensor
26 | The current time
27 | state : PyTorch tensor
28 | The current state of the system
29 |
30 | Returns
31 | -------
32 | state : PyTorch tensor
33 | The combined state update of shape `[n_batch, n_dimensions]`
34 | containing the state update of the system state `dx/dt`
35 | (`state[:, :-1]`) and the update of the log density (`state[:, -1]`).
36 | """
37 | xs = state[:, :-1]
38 | dlogp = state[:, -1:]
39 | state = (xs, dlogp)
40 | *dxs, div = self._dynamics(t, state)
41 | state = torch.cat([*dxs, div], dim=-1)
42 | return state
43 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/dynamics/blackbox.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BlackBoxDynamics(torch.nn.Module):
5 |     """Black-box dynamics that wraps an arbitrary dynamics function.
6 | The divergence of the dynamics is computed with a divergence estimator.
7 | """
8 |
9 | def __init__(self, dynamics_function, divergence_estimator, compute_divergence=True):
10 | super().__init__()
11 | self._dynamics_function = dynamics_function
12 | self._divergence_estimator = divergence_estimator
13 | self._compute_divergence = compute_divergence
14 |
15 | def forward(self, t, *xs):
16 | """
17 | Computes the change of the system `dxs` at state `xs` and
18 | time `t`. Furthermore, can also compute the change of log density
19 | which is equal to the divergence of the change.
20 |
21 | Parameters
22 | ----------
23 | t : PyTorch tensor
24 | The current time
25 | xs : PyTorch tensor
26 | The current configuration of the system
27 |
28 | Returns
29 | -------
30 | (*dxs, divergence): Tuple of PyTorch tensors
31 | The combined state update of shape `[n_batch, n_dimensions]`
32 | containing the state update of the system state `dx/dt`
33 | (`dxs`) and the update of the log density (`dlogp`)
34 | """
35 | if self._compute_divergence:
36 | *dxs, divergence = self._divergence_estimator(
37 | self._dynamics_function, t, *xs
38 | )
39 | else:
40 |             dxs = [self._dynamics_function(t, *xs)]  # wrap so (*dxs, ...) unpacks like the estimator branch
41 | divergence = None
42 | return (*dxs, divergence)
43 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/dynamics/density.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class DensityDynamics(torch.nn.Module):
5 |
6 | def __init__(self, dynamics):
7 | super().__init__()
8 | self._dynamics = dynamics
9 | self._n_evals = 0
10 |
11 | def forward(self, t, state):
12 | """
13 | Computes the change of the system `dx/dt` at state `x` and
14 | time `t`. Furthermore, computes the change of density, happening
15 | due to moving `x` infinitesimally in the direction `dx/dt`
16 | according to the "instantaneous change of variables rule" [1]
17 |         `d log p(x(t)) / dt = -div(dx(t)/dt)`
18 |         [1] Neural Ordinary Differential Equations, Chen et al.,
19 | https://arxiv.org/abs/1806.07366
20 |
21 | Parameters
22 | ----------
23 | t : PyTorch tensor
24 | The current time
25 | state : PyTorch tensor
26 | The current state of the system.
27 | Consisting of the configuration `xs` and the log density change `dlogp`.
28 |
29 | Returns
30 | -------
31 | (*dxs, -dlogp) : Tuple of PyTorch tensors
32 | The combined state update of shape `[n_batch, n_dimensions]`
33 | containing the state update of the system state `dx/dt`
34 | (`dxs`) and the update of the log density (`dlogp`).
35 | """
36 | xs = state[:-1]
37 | *dxs, dlogp = self._dynamics(t, *xs)
38 | return (*dxs, -dlogp)
39 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/dynamics/inversed.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class InversedDynamics(torch.nn.Module):
5 | """
6 | Inverse of a dynamics for the inverse flow.
7 | """
8 |
9 | def __init__(self, dynamics, t_max=1.0):
10 | super().__init__()
11 | self._dynamics = dynamics
12 | self._t_max = t_max
13 |
14 | def forward(self, t, state):
15 | """
16 | Evaluates the change of the system `dxs` at time `t_max` - `t` for the inverse dynamics.
17 |
18 | Parameters
19 | ----------
20 | t : PyTorch tensor
21 | The current time
22 | state : PyTorch tensor
23 | The current state of the system
24 |
25 | Returns
26 | -------
27 | [-*dxs, -dlogp] : List of PyTorch tensors
28 | The combined state update of shape `[n_batch, n_dimensions]`
29 | containing the state update of the system state `dx/dt`
30 | (`-dxs`) and the update of the log density (`-dlogp`).
31 | """
32 |
33 | *dxs, dlogp = self._dynamics(self._t_max - t, state)
34 | return [-dx for dx in dxs] + [-dlogp]
35 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/dynamics/simple.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TimeIndependentDynamics(torch.nn.Module):
5 | """
6 | Time independent dynamics function.
7 | """
8 |
9 | def __init__(self, dynamics):
10 | super().__init__()
11 | self._dynamics = dynamics
12 |
13 | def forward(self, t, xs):
14 | """
15 | Computes the change of the system `dxs` due to a time independent dynamics function.
16 |
17 | Parameters
18 | ----------
19 | t : PyTorch tensor
20 | The current time
21 | xs : PyTorch tensor
22 | The current configuration of the system
23 |
24 | Returns
25 | -------
26 | dxs : PyTorch tensor
27 | The change of the system due to the dynamics function
28 | """
29 |
30 | dxs = self._dynamics(xs)
31 | return dxs
32 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/elementwise.py:
--------------------------------------------------------------------------------
1 | """Nonlinear One-dimensional Diffeomorphisms"""
2 |
3 | import torch
4 | from bgflow.nn.flow.base import Flow
5 |
6 |
7 | __all__ = ["BentIdentity"]
8 |
9 |
10 | class BentIdentity(Flow):
11 | """Bent identity. A nonlinear diffeomorphism with analytic gradients and inverse.
12 | See https://towardsdatascience.com/secret-sauce-behind-the-beauty-of-deep-learning-beginners-guide-to-activation-functions-a8e23a57d046 .
13 | """
14 | def __init__(self):
15 | super(BentIdentity, self).__init__()
16 |
17 | def _forward(self, x, **kwargs):
18 | """Forward transform
19 |
20 | Parameters
21 | ----------
22 | x : torch.tensor
23 | Input tensor
24 |
25 | kwargs : dict
26 | Miscellaneous arguments to satisfy the interface.
27 |
28 | Returns
29 | -------
30 | y : torch.tensor
31 | Elementwise transformed tensor with the same shape as x.
32 |
33 | dlogp : torch.tensor
34 | Natural log of the Jacobian determinant.
35 | """
36 | dlogp = torch.log(self.derivative(x)).sum(dim=-1, keepdim=True)
37 | return (torch.sqrt(x ** 2 + 1) - 1) / 2 + x, dlogp
38 |
39 | def _inverse(self, x, **kwargs):
40 | """Inverse transform
41 |
42 | Parameters
43 | ----------
44 | x : torch.tensor
45 | Input tensor
46 |
47 | kwargs : dict
48 | Miscellaneous arguments to satisfy the interface.
49 |
50 | Returns
51 | -------
52 | y : torch.tensor
53 | Elementwise transformed tensor with the same shape as x.
54 |
55 | dlogp : torch.tensor
56 | Natural log of the Jacobian determinant.
57 | """
58 | dlogp = torch.log(self.inverse_derivative(x)).sum(dim=-1, keepdim=True)
59 | return 2 / 3 * (2 * x + 1 - torch.sqrt(x ** 2 + x + 1)), dlogp
60 |
61 | @staticmethod
62 | def derivative(x):
63 | """Elementwise derivative of the activation function."""
64 | return x / (2 * torch.sqrt(x ** 2 + 1)) + 1
65 |
66 | @staticmethod
67 | def inverse_derivative(x):
68 | """Elementwise derivative of the inverse activation function."""
69 | return 4 / 3 - (2 * x + 1) / (3 * torch.sqrt(x ** 2 + x + 1))
70 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/estimator/__init__.py:
--------------------------------------------------------------------------------
1 | from .brute_force_estimator import BruteForceEstimator
2 | from .hutchinson_estimator import HutchinsonEstimator
--------------------------------------------------------------------------------
/bgflow/nn/flow/estimator/brute_force_estimator.py:
--------------------------------------------------------------------------------
1 | from bgflow.utils.autograd import brute_force_jacobian_trace
2 | import torch
3 |
4 |
5 | class BruteForceEstimator(torch.nn.Module):
6 | """
7 | Exact bruteforce estimation of the divergence of a dynamics function.
8 | """
9 |
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, dynamics, t, xs):
14 | """
15 | Computes the change of the system `dxs` due to a time independent dynamics function.
16 | Furthermore, also computes the exact change of log density
17 | which is equal to the divergence of the change `dxs`.
18 |
19 | Parameters
20 | ----------
21 | dynamics : torch.nn.Module
22 | A dynamics function that computes the change of the system and its density.
23 | t : PyTorch tensor
24 | The current time
25 | xs : PyTorch tensor
26 | The current configuration of the system
27 |
28 | Returns
29 | -------
30 | dxs, -divergence: PyTorch tensors
31 | The combined state update of shape `[n_batch, n_dimensions]`
32 | containing the state update of the system state `dx/dt`
33 | (`dxs`) and the negative update of the log density (`-divergence`)
34 | """
35 |
36 | with torch.set_grad_enabled(True):
37 | xs.requires_grad_(True)
38 | dxs = dynamics(t, xs)
39 |
40 |             assert len(dxs.shape) == 2, "`dxs` must have shape [n_batch, system_dim]"
41 | divergence = brute_force_jacobian_trace(dxs, xs)
42 |
43 | return dxs, -divergence.view(-1, 1)
44 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/estimator/hutchinson_estimator.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class HutchinsonEstimator(torch.nn.Module):
5 | """
6 | Estimation of the divergence of a dynamics function with the Hutchinson Estimator [1].
7 | [1] A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines, Hutchinson
8 | """
9 |
10 | def __init__(self, rademacher=True):
11 | super().__init__()
12 | self._rademacher = rademacher
13 | self._reset_noise = True
14 |
15 | def reset_noise(self, reset_noise=True):
16 | """
17 |         Flags the noise vector to be re-drawn at the next forward call.
18 | """
19 |
20 | self._reset_noise = reset_noise
21 |
22 | def forward(self, dynamics, t, xs):
23 | """
24 | Computes the change of the system `dxs` due to a time independent dynamics function.
25 | Furthermore, also estimates the change of log density, which is equal to the divergence of the change `dxs`,
26 | with the Hutchinson Estimator.
27 | This is done with either Rademacher or Gaussian noise.
28 |
29 | Parameters
30 | ----------
31 | dynamics : torch.nn.Module
32 | A dynamics function that computes the change of the system and its density.
33 | t : PyTorch tensor
34 | The current time
35 | xs : PyTorch tensor
36 | The current configuration of the system
37 |
38 | Returns
39 | -------
40 | dxs, -divergence: PyTorch tensors
41 | The combined state update of shape `[n_batch, n_dimensions]`
42 | containing the state update of the system state `dx/dt`
43 | (`dxs`) and the negative update of the log density (`-divergence`)
44 | """
45 |
46 | with torch.set_grad_enabled(True):
47 | xs.requires_grad_(True)
48 | dxs = dynamics(t, xs)
49 |
50 |             assert len(dxs.shape) == 2, "`dxs` must have shape [n_batch, system_dim]"
51 | system_dim = dxs.shape[-1]
52 |
53 |         if self._reset_noise:
54 |             self._reset_noise = False
55 |             if self._rademacher:
56 | self._noise = torch.randint(low=0, high=2, size=xs.shape).to(xs) * 2 - 1
57 | else:
58 | self._noise = torch.randn_like(xs)
59 |
60 | noise_ddxs = torch.autograd.grad(dxs, xs, self._noise, create_graph=True)[0]
61 | divergence = torch.sum((noise_ddxs * self._noise).view(-1, system_dim), 1, keepdim=True)
62 |
63 | return dxs, -divergence
64 |
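A sketch comparing both estimators on a linear vector field dx/dt = W x, whose exact divergence is trace(W); the Hutchinson value is a single-sample, unbiased estimate of it:

    import torch
    from bgflow.nn.flow.dynamics import TimeIndependentDynamics
    from bgflow.nn.flow.estimator import BruteForceEstimator, HutchinsonEstimator

    lin = torch.nn.Linear(5, 5, bias=False)
    dynamics = TimeIndependentDynamics(lin)
    x = torch.randn(16, 5)
    _, neg_div = BruteForceEstimator()(dynamics, None, x)        # exactly -trace(W) per sample
    _, neg_div_est = HutchinsonEstimator()(dynamics, None, x)    # noisy, unbiased estimate
    print(neg_div.mean().item(), neg_div_est.mean().item())
    print(-torch.trace(lin.weight).item())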
--------------------------------------------------------------------------------
/bgflow/nn/flow/funnel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from .base import Flow
5 |
6 | # TODO: write docstrings
7 | # TODO: refactor messy implementation
8 |
9 |
10 | class FunnelFlow(Flow):
11 | def __init__(self, eps=1e-6, min_val=-1.0, max_val=1.0):
12 | super().__init__()
13 | self._eps = eps
14 | self._min_val = min_val
15 | self._max_val = max_val
16 |
17 | def _forward(self, x, **kwargs):
18 | dlogp = (
19 | torch.nn.functional.logsigmoid(x)
20 | - torch.nn.functional.softplus(x)
21 | + np.log(self._max_val - self._min_val)
22 | ).sum(dim=-1, keepdim=True)
23 | x = torch.sigmoid(x)
24 | x = x * (self._max_val - self._min_val) + self._min_val
25 | x = torch.clamp(x, self._min_val + self._eps, self._max_val - self._eps)
26 | return x, dlogp
27 |
28 | def _inverse(self, x, **kwargs):
29 | x = torch.clamp(x, self._min_val + self._eps, self._max_val - self._eps)
30 | x = (x - self._min_val) / (self._max_val - self._min_val)
31 | dlogp = (
32 | -torch.log(x - x.pow(2))
33 | - np.log(self._max_val - self._min_val)
34 | ).sum(dim=-1, keepdim=True)
35 | x = torch.log(x) - torch.log(1 - x)
36 | return x, dlogp
--------------------------------------------------------------------------------
/bgflow/nn/flow/inverted.py:
--------------------------------------------------------------------------------
1 |
2 | from .base import Flow
3 |
4 | __all__ = ["InverseFlow"]
5 |
6 |
7 | class InverseFlow(Flow):
8 | """The inverse of a given transform.
9 |
10 | Parameters
11 | ----------
12 | delegate : Flow
13 | The flow to invert.
14 | """
15 | def __init__(self, delegate):
16 | super().__init__()
17 | self._delegate = delegate
18 |
19 | def _forward(self, *xs, **kwargs):
20 | return self._delegate._inverse(*xs, **kwargs)
21 |
22 | def _inverse(self, *xs, **kwargs):
23 | return self._delegate._forward(*xs, **kwargs)
--------------------------------------------------------------------------------
/bgflow/nn/flow/kronecker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | from .base import Flow
6 |
7 |
8 | # TODO: write docstrings
9 |
10 |
11 | def _is_power2(x):
12 | return x != 0 and ((x & (x - 1)) == 0)
13 |
14 |
15 | def _kronecker(A, B):
16 | return torch.einsum("ab,cd->acbd", A, B).view(
17 | A.size(0) * B.size(0), A.size(1) * B.size(1)
18 | )
19 |
20 |
21 | def _batch_determinant_2x2(As, log=False):
22 | result = As[:, 0, 0] * As[:, 1, 1] - As[:, 1, 0] * As[:, 0, 1]
23 | if log:
24 | result = result.abs().log()
25 | return result
26 |
27 |
28 | def _create_ortho_matrices(n, d):
29 | qs = []
30 | for i in range(n):
31 | q, _ = np.linalg.qr(np.random.normal(size=(d, d)))
32 | qs.append(q)
33 | qs = np.array(qs)
34 | return qs
35 |
36 |
37 | class KroneckerProductFlow(Flow):
38 | def __init__(self, n_dim):
39 | super().__init__()
40 |
41 | assert _is_power2(n_dim)
42 |
43 | self._n_dim = n_dim
44 | self._n_factors = int(np.log2(n_dim))
45 |
46 | self._factors = torch.nn.Parameter(
47 | torch.Tensor(_create_ortho_matrices(self._n_factors, 2))
48 | )
49 | self._bias = torch.nn.Parameter(torch.Tensor(1, n_dim).zero_())
50 |
51 | def _forward(self, x, **kwargs):
52 | n_batch = x.shape[0]
53 | factors = self._factors.to(x)
54 | M = factors[0]
55 | dets = _batch_determinant_2x2(factors)
56 | det = dets[0]
57 | power = 2
58 | for new_det, factor in zip(dets[1:], factors[1:]):
59 | det = det.pow(2) * new_det.pow(power)
60 | M = _kronecker(M, factor)
61 | power = power * 2
62 | dlogp = torch.zeros(n_batch, 1).to(x)
63 | dlogp = dlogp + det.abs().log().sum(dim=-1, keepdim=True)
64 | return x @ M + self._bias.to(x), dlogp
65 |
66 | def _inverse(self, x, **kwargs):
67 | n_batch = x.shape[0]
68 | factors = self._factors.to(x)
69 | inv_factors = torch.inverse(factors)
70 | M = inv_factors[0]
71 | inv_dets = _batch_determinant_2x2(inv_factors)
72 | inv_det = inv_dets[0]
73 | power = 2
74 | for new_inv_det, factor in zip(inv_dets[1:], inv_factors[1:]):
75 | inv_det = inv_det.pow(2) * new_inv_det.pow(power)
76 | M = _kronecker(M, factor)
77 | power = power * 2
78 | dlogp = torch.zeros(n_batch, 1).to(x)
79 | dlogp = dlogp + inv_det.abs().log().sum(dim=-1, keepdim=True)
80 | return (x - self._bias.to(x)) @ M, dlogp
81 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/modulo.py:
--------------------------------------------------------------------------------
1 | __all__ = ["IncreaseMultiplicityFlow", "CircularShiftFlow"]
2 |
3 | import torch
4 | from bgflow.nn.flow.base import Flow
5 |
6 |
7 | class IncreaseMultiplicityFlow(Flow):
8 | """A flow that increases the multiplicity of torsional degrees of freedom.
9 |     The input and output tensors are expected to lie in [0,1].
10 |     The forward pass maps each input into one of its periodic copies (sheaves), chosen at random.
11 |
12 | Parameters
13 | ----------
14 | multiplicities : Union[torch.Tensor, int]
15 | A tensor of integers that define the number of periods in the unit interval.
16 | """
17 |
18 | def __init__(self, multiplicities):
19 | super().__init__()
20 | self.register_buffer("_multiplicities", torch.as_tensor(multiplicities))
21 |
22 | def _forward(self, x, **kwargs):
23 | _assert_in_unit_interval(x)
24 | multiplicities = torch.ones_like(x) * self._multiplicities
25 | sheaves = _randint(multiplicities)
26 | y = (x + sheaves) / self._multiplicities
27 | dlogp = torch.zeros_like(x[..., [0]])
28 | return y, dlogp
29 |
30 | def _inverse(self, x, **kwargs):
31 | _assert_in_unit_interval(x)
32 | y = (x % (1 / self._multiplicities)) * self._multiplicities
33 | dlogp = torch.zeros_like(x[..., [0]])
34 | return y, dlogp
35 |
36 |
37 | def _randint(high):
38 | with torch.no_grad():
39 | return torch.floor(torch.rand(high.shape, device=high.device) * high)
40 |
41 |
42 | def _assert_in_unit_interval(x):
43 | if (x > 1 + 1e-6).any() or (x < - 1e-6).any():
44 | raise ValueError(f'IncreaseMultiplicityFlow operates on [0,1] but input was {x}')
45 |
46 |
47 | class CircularShiftFlow(Flow):
48 | """A flow that shifts the position of torsional degrees of freedom.
49 | The input and output tensors are expected to be in [0,1].
50 |     The output is a translated version of the input, respecting circularity.
51 |
52 | Parameters
53 | ----------
54 | shift : Union[torch.Tensor, float]
55 | A tensor that defines the translation of the circular interval
56 | """
57 |
58 | def __init__(self, shift):
59 | super().__init__()
60 | self.register_buffer("_shift", torch.as_tensor(shift))
61 |
62 | def _forward(self, x, **kwargs):
63 | _assert_in_unit_interval(x)
64 | y = (x + self._shift) % 1
65 | dlogp = torch.zeros_like(x[..., [0]])
66 | return y, dlogp
67 |
68 | def _inverse(self, x, **kwargs):
69 | _assert_in_unit_interval(x)
70 | y = (x - self._shift) % 1
71 | dlogp = torch.zeros_like(x[..., [0]])
72 | return y, dlogp
73 |
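A round-trip sketch for both flows on torsions living in [0, 1]:

    import torch
    from bgflow.nn.flow.modulo import CircularShiftFlow, IncreaseMultiplicityFlow

    x = torch.rand(10, 3)
    shift = CircularShiftFlow(torch.tensor([0.25, 0.50, 0.75]))
    y, dlogp = shift.forward(x)                # y = (x + shift) % 1, dlogp = 0
    x_back, _ = shift.forward(y, inverse=True)
    assert torch.allclose(x, x_back, atol=1e-6)

    mult = IncreaseMultiplicityFlow(torch.tensor([3, 3, 3]))
    z, _ = mult.forward(x)                     # place x in one of 3 random sheaves
    x_back, _ = mult.forward(z, inverse=True)  # fold back onto the base interval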
--------------------------------------------------------------------------------
/bgflow/nn/flow/orthogonal.py:
--------------------------------------------------------------------------------
1 | """
2 | (Pseudo-) Orthogonal Linear Layers. Advantage: Jacobian determinant is unity.
3 | """
4 |
5 | import torch
6 | from bgflow.nn.flow.base import Flow
7 |
8 | __all__ = ["PseudoOrthogonalFlow"]
9 |
10 | # Note: OrthogonalPPPP is implemented in pppp.py
11 |
12 |
13 | class PseudoOrthogonalFlow(Flow):
14 | """Linear flow W*x+b with a penalty function
15 | penalty_parameter*||W^T W - I||^2
16 |
17 | Attributes
18 | ----------
19 | dim : int
20 | dimension
21 | shift : boolean
22 | Whether to use a shift parameter (+b). If False, b=0.
23 | penalty_parameter : float
24 | Scaling factor for the orthogonality constraint.
25 | """
26 | def __init__(self, dim, shift=True, penalty_parameter=1e5):
27 | super(PseudoOrthogonalFlow, self).__init__()
28 | self.dim = dim
29 | self.W = torch.nn.Parameter(torch.eye(dim))
30 | if shift:
31 | self.b = torch.nn.Parameter(torch.zeros(dim))
32 | else:
33 | self.register_buffer("b", torch.tensor(0.0))
34 | self.register_buffer("penalty_parameter", torch.tensor(penalty_parameter))
35 |
36 | def _forward(self, x, **kwargs):
37 | """Forward transform.
38 |
39 |         Parameters
40 | ----------
41 | x : torch.tensor
42 | The input vector. The transform is applied to the last dimension.
43 | kwargs : dict
44 | keyword arguments to satisfy the interface
45 |
46 | Returns
47 | -------
48 | y : torch.tensor
49 | W*x + b
50 | dlogp : torch.tensor
51 | natural log of the Jacobian determinant
52 | """
53 | dlogp = torch.zeros(*x.shape[:-1], 1).to(x)
54 | y = torch.einsum("ab,...b->...a", self.W, x)
55 | return y + self.b, dlogp
56 |
57 | def _inverse(self, y, **kwargs):
58 | """Inverse transform assuming that W is orthogonal.
59 |
60 |         Parameters
61 | ----------
62 | y : torch.tensor
63 | The input vector. The transform is applied to the last dimension.
64 | kwargs : dict
65 | keyword arguments to satisfy the interface
66 |
67 | Returns
68 | -------
69 | x : torch.tensor
70 | W^T*(y-b)
71 | dlogp : torch.tensor
72 | natural log of the Jacobian determinant
73 | """
74 | dlogp = torch.zeros(*y.shape[:-1], 1).to(y)
75 | x = torch.einsum("ab,...b->...a", self.W.transpose(1, 0), y - self.b)
76 | return x, dlogp
77 |
78 | def penalty(self):
79 | """Penalty function for the orthogonality constraint
80 |
81 | p(W) = penalty_parameter * ||W^T*W - I||^2.
82 |
83 | Returns
84 | -------
85 | penalty : float
86 | Value of the penalty function
87 | """
88 |         return self.penalty_parameter * torch.sum((torch.eye(self.dim, device=self.W.device, dtype=self.W.dtype) - torch.mm(self.W.transpose(1, 0), self.W)) ** 2)
89 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/sequential.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import torch
4 |
5 | from .base import Flow
6 |
7 | logger = logging.getLogger('bgflow')
8 |
9 |
10 | class SequentialFlow(Flow):
11 | def __init__(self, blocks):
12 | """
13 | Represents a diffeomorphism that can be computed
14 | as a discrete finite stack of layers.
15 |
16 | Returns the transformed variable and the log determinant
17 | of the Jacobian matrix.
18 |
19 | Parameters
20 | ----------
21 | blocks : Tuple / List of flow blocks
22 | """
23 | super().__init__()
24 | self._blocks = torch.nn.ModuleList(blocks)
25 |
26 | def forward(self, *xs, inverse=False, **kwargs):
27 | """
28 | Transforms the input along the diffeomorphism and returns
29 | the transformed variable together with the volume change.
30 |
31 | Parameters
32 | ----------
33 |         xs : PyTorch floating tensors.
34 |             Input variables to be transformed,
35 |             each a tensor of shape `[..., n_dimensions]`.
36 | inverse: boolean.
37 | Indicates whether forward or inverse transformation shall be performed.
38 | If `True` computes the inverse transformation.
39 |
40 | Returns
41 | -------
42 | z: PyTorch Floating Tensor.
43 | Transformed variable.
44 | Tensor of shape `[..., n_dimensions]`.
45 | dlogp : PyTorch Floating Tensor.
46 | Total volume change as a result of the transformation.
47 | Corresponds to the log determinant of the Jacobian matrix.
48 | """
49 | dlogp = 0.0
50 | blocks = self._blocks
51 | if inverse:
52 | blocks = reversed(blocks)
53 | for i, block in enumerate(blocks):
54 | logger.debug(f"Input shapes {[x.shape for x in xs]}")
55 | *xs, ddlogp = block(*xs, inverse=inverse, **kwargs)
56 | logger.debug(f"Flow block {i} (inverse={inverse}): {block}")
57 | logger.debug(f"Output shapes {[x.shape for x in xs]}")
58 | dlogp += ddlogp
59 | return (*xs, dlogp)
60 |
61 | def _forward(self, *args, **kwargs):
62 | return self.forward(*args, **kwargs, inverse=False)
63 |
64 | def _inverse(self, *args, **kwargs):
65 | return self.forward(*args, **kwargs, inverse=True)
66 |
67 | def trigger(self, function_name):
68 | """
69 | Evaluate functions for all blocks that have a function with that name and return a tensor of the stacked results.
70 | """
71 | results = [
72 | getattr(block, function_name)()
73 | for block in self._blocks
74 | if hasattr(block, function_name) and callable(getattr(block, function_name))
75 | ]
76 | if len(results) > 0 and all(res is not None for res in results):
77 | return torch.stack(results)
78 | else:
79 | return torch.zeros(0)
80 |
81 | def __iter__(self):
82 | return iter(self._blocks)
83 |
84 | def __getitem__(self, index):
85 | if isinstance(index, int):
86 | return self._blocks[index]
87 | else:
88 | indices = np.arange(len(self))[index]
89 | return SequentialFlow([self._blocks[i] for i in indices])
90 |
91 | def __len__(self):
92 | return len(self._blocks)
93 |
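A sketch composing blocks defined elsewhere in this package into one diffeomorphism; slicing returns a new `SequentialFlow`:

    import torch
    from bgflow.nn.flow.sequential import SequentialFlow
    from bgflow.nn.flow.affine import AffineFlow
    from bgflow.nn.flow.elementwise import BentIdentity

    flow = SequentialFlow([AffineFlow(2), BentIdentity(), AffineFlow(2)])
    x = torch.randn(7, 2)
    y, dlogp = flow.forward(x)                   # forward pass through all blocks
    x_back, dlogp_inv = flow.forward(y, inverse=True)
    assert torch.allclose(x, x_back, atol=1e-5)
    sub = flow[1:]                               # a new SequentialFlow over a slice of blocks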
--------------------------------------------------------------------------------
/bgflow/nn/flow/stochastic/__init__.py:
--------------------------------------------------------------------------------
1 | from .mcmc import MetropolisMCFlow
2 | from .langevin import LangevinFlow, BrownianFlow, OverdampedLangevinFlow
3 | from .augment import StochasticAugmentation
4 | from .snf_openmm import *
5 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/stochastic/augment.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from bgflow.nn.flow.base import Flow
4 |
5 |
6 | class StochasticAugmentation(Flow):
7 | def __init__(self, distribution):
8 | """
9 | Stochastic augmentation layer
10 |
11 | Adds additional coordinates to the state vector by sampling them from distribution.
12 | Pre-sampled momenta can be passed through the layer by the kwarg "momenta" and
13 | transformed momenta are returned if kwarg "return_momenta" is set True.
14 | If momenta are returned, their contribution is not added to the Jacobian.
15 | This behavior is required in some MCMC samplers.
16 |
17 | Parameters
18 | ----------
19 | distribution : Energy
20 | Energy object, needs sample and energy method.
21 | """
22 | super().__init__()
23 | self.distribution = distribution
24 | self._cached_momenta_forward = None
25 | self._cached_momenta_backward = None
26 |
27 | def _forward(self, q, **kwargs):
28 | batch_size = q.shape[0]
29 | temperature = kwargs.get("temperature", 1.0)
30 | cache_momenta = kwargs.get("cache_momenta", False)
31 | # Add option to pass pre-sampled momenta as key word argument
32 | p = kwargs.get("momenta", None)
33 | if p is None:
34 | p = self.distribution.sample(batch_size, temperature=temperature)
35 | dlogp = self.distribution.energy(p, temperature=temperature)
36 | else:
37 | dlogp = torch.zeros(p.shape[0], 1).to(p)
38 | if cache_momenta:
39 | self._cached_momenta_forward = p
40 | x = torch.cat([q, p], dim=1)
41 | return x, dlogp
42 |
43 | def _inverse(self, x, **kwargs):
44 | return_momenta = kwargs.get("return_momenta", False)
45 | cache_momenta = kwargs.get("cache_momenta", False)
46 | p = x[:, self.distribution.dim :]
47 | temperature = kwargs.get("temperature", 1.0)
48 | # Add option to return transformed momenta as key word argument.
49 | # Momenta will be returned in same tensor as configurations
50 | if cache_momenta:
51 | self._cached_momenta_backward = p
52 | if return_momenta:
53 | return x, torch.zeros(p.shape[0], 1).to(p)
54 | dlogp = self.distribution.energy(p, temperature=temperature)
55 | return x[:, : self.distribution.dim], -dlogp
56 |
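A sketch with a minimal stand-in momentum distribution (a hypothetical class defined here for illustration; in practice a bgflow Energy such as a normal distribution would be passed):

    import torch
    from bgflow.nn.flow.stochastic.augment import StochasticAugmentation

    class StandardNormal:
        """Stand-in momentum distribution with the required sample/energy interface."""
        def __init__(self, dim):
            self.dim = dim
        def sample(self, n, temperature=1.0):
            return torch.randn(n, self.dim) * temperature ** 0.5
        def energy(self, p, temperature=1.0):
            return 0.5 * p.pow(2).sum(dim=-1, keepdim=True) / temperature

    aug = StochasticAugmentation(StandardNormal(2))
    q = torch.randn(8, 2)
    x, dlogp = aug.forward(q)                  # x = [q, p] with freshly sampled momenta p
    q_back, dlogp_inv = aug.forward(x, inverse=True)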
--------------------------------------------------------------------------------
/bgflow/nn/flow/stochastic/mcmc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bgflow.nn.flow.base import Flow
3 |
4 | class MetropolisMCFlow(Flow):
5 | def __init__(self, energy_model, nsteps=1, stepsize=0.01):
6 | """ Stochastic Flow layer that simulates Metropolis Monte Carlo
7 |
8 | """
9 | super().__init__()
10 | self.energy_model = energy_model
11 | self.nsteps = nsteps
12 | self.stepsize = stepsize
13 |
14 | def _forward(self, x, **kwargs):
15 | """ Run a stochastic trajectory forward
16 |
17 | Parameters
18 | ----------
19 | x : PyTorch Tensor
20 | Batch of input configurations
21 |
22 | Returns
23 | -------
24 | x' : PyTorch Tensor
25 | Transformed configurations
26 | dW : PyTorch Tensor
27 |             Nonequilibrium work done, here the accumulated energy difference E(x') - E(x)
28 |
29 | """
30 | E0 = self.energy_model.energy(x)
31 | E = E0
32 |
33 | for i in range(self.nsteps):
34 | # proposal step
35 | dx = self.stepsize * torch.zeros_like(x).normal_()
36 | xprop = x + dx
37 | Eprop = self.energy_model.energy(xprop)
38 |
39 | # acceptance step
40 |             acc = (torch.rand(x.shape[0], 1).to(x) < torch.exp(-(Eprop - E))).float()  # selection variable: 0 or 1.
41 | x = (1-acc) * x + acc * xprop
42 | E = (1-acc) * E + acc * Eprop
43 |
44 | # Work is energy difference
45 | dW = E - E0
46 |
47 | return x, dW
48 |
49 | def _inverse(self, x, **kwargs):
50 | """ Same as forward """
51 | return self._forward(x, **kwargs)
52 |
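A sketch with a toy quadratic energy (a hypothetical class defined here for illustration; any object with an `energy` method works):

    import torch
    from bgflow.nn.flow.stochastic.mcmc import MetropolisMCFlow

    class HarmonicEnergy:
        def energy(self, x):
            return 0.5 * x.pow(2).sum(dim=-1, keepdim=True)

    flow = MetropolisMCFlow(HarmonicEnergy(), nsteps=10, stepsize=0.1)
    x = torch.randn(32, 2)
    x_new, dW = flow.forward(x)   # dW is the accumulated energy difference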
--------------------------------------------------------------------------------
/bgflow/nn/flow/torchtransform.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from .base import Flow
4 |
5 |
6 | __all__ = ["TorchTransform"]
7 |
8 |
9 | class TorchTransform(Flow):
10 | """Wrap a torch.distributions.Transform as a Flow instance
11 |
12 | Parameters
13 | ----------
14 | transform : torch.distributions.Transform
15 | The transform instance that should be wrapped as a Flow instance.
16 | reinterpreted_batch_ndims : int, optional
17 | Number of batch dimensions to be reinterpreted as event dimensions.
18 | If >0, this transform is wrapped in an torch.distributions.IndependentTransform instance.
19 | """
20 |
21 | def __init__(self, transform, reinterpreted_batch_ndims=0):
22 | super().__init__()
23 | if reinterpreted_batch_ndims > 0:
24 | transform = torch.distributions.IndependentTransform(transform, reinterpreted_batch_ndims)
25 | self._delegate_transform = transform
26 |
27 | def _forward(self, x, **kwargs):
28 | y = self._delegate_transform(x)
29 | dlogp = self._delegate_transform.log_abs_det_jacobian(x, y)
30 | return y, dlogp[..., None]
31 |
32 | def _inverse(self, y, **kwargs):
33 | x = self._delegate_transform.inv(y)
34 | dlogp = - self._delegate_transform.log_abs_det_jacobian(x, y)
35 | return x, dlogp[..., None]
36 |
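A sketch wrapping a stock `torch.distributions` transform; `reinterpreted_batch_ndims=1` sums the log-determinant over the event dimension so that `dlogp` has shape `[n_batch, 1]`:

    import torch
    from bgflow.nn.flow.torchtransform import TorchTransform

    flow = TorchTransform(
        torch.distributions.SigmoidTransform(), reinterpreted_batch_ndims=1
    )
    x = torch.randn(4, 3)
    y, dlogp = flow.forward(x)                 # y = sigmoid(x), dlogp shape [4, 1]
    x_back, _ = flow.forward(y, inverse=True)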
--------------------------------------------------------------------------------
/bgflow/nn/flow/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .affine import *
3 | from .entropy_scaling import *
4 | from .spline import *
5 | from .gaussian import *
6 | from .jax import *
7 | from .jax_bridge import *
8 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/transformer/affine.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Transformer
4 |
5 | # TODO: write docstring
6 |
7 | __all__ = ["AffineTransformer"]
8 |
9 |
10 | class AffineTransformer(Transformer):
11 | """RealNVP/NICE
12 |
13 | Parameters
14 | ----------
15 | is_circular : bool
16 | Whether this transform is periodic on [0,1].
17 | """
18 | def __init__(
19 | self,
20 | shift_transformation=None,
21 | scale_transformation=None,
22 | init_downscale=1.0,
23 | preserve_volume=False,
24 | is_circular=False,
25 | ):
26 | if scale_transformation is not None and is_circular:
27 | raise ValueError("Scaling is not compatible with periodicity.")
28 | super().__init__()
29 | self._shift_transformation = shift_transformation
30 | self._scale_transformation = scale_transformation
31 | self._log_alpha = torch.nn.Parameter(torch.zeros(1) - init_downscale)
32 | self._preserve_volume = preserve_volume
33 | self._is_circular = is_circular
34 |
35 | def _get_mu_and_log_sigma(self, x, y, *cond):
36 | if self._shift_transformation is not None:
37 | mu = self._shift_transformation(x, *cond)
38 | else:
39 | mu = torch.zeros_like(y).to(x)
40 | if self._scale_transformation is not None:
41 | alpha = torch.exp(self._log_alpha.to(x))
42 | log_sigma = torch.tanh(self._scale_transformation(x, *cond))
43 | log_sigma = log_sigma * alpha
44 | if self._preserve_volume:
45 | log_sigma = log_sigma - log_sigma.mean(dim=-1, keepdim=True)
46 | else:
47 | log_sigma = torch.zeros_like(y).to(x)
48 | return mu, log_sigma
49 |
50 | def _forward(self, x, y, *cond, **kwargs):
51 | mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond)
52 | assert mu.shape[-1] == y.shape[-1]
53 | assert log_sigma.shape[-1] == y.shape[-1]
54 | sigma = torch.exp(log_sigma)
55 | dlogp = (log_sigma).sum(dim=-1, keepdim=True)
56 | y = sigma * y + mu
57 | if self._is_circular:
58 | y = y % 1.0
59 | return y, dlogp
60 |
61 | def _inverse(self, x, y, *cond, **kwargs):
62 | mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond)
63 | assert mu.shape[-1] == y.shape[-1]
64 | assert log_sigma.shape[-1] == y.shape[-1]
65 | sigma_inv = torch.exp(-log_sigma)
66 | dlogp = (-log_sigma).sum(dim=-1, keepdim=True)
67 | y = sigma_inv * (y - mu)
68 | if self._is_circular:
69 | y = y % 1.0
70 | return y, dlogp
71 |
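A sketch of the transformer in isolation, using a small `torch.nn.Linear` as a stand-in conditioner (an assumption for illustration; in practice the conditioner comes from a coupling layer). With only a shift transformation this reduces to a NICE-style, volume-preserving update:

    import torch
    from bgflow.nn.flow.transformer.affine import AffineTransformer

    shift_net = torch.nn.Linear(3, 2)          # maps dim(x)=3 -> dim(y)=2
    transformer = AffineTransformer(shift_transformation=shift_net)
    x, y = torch.randn(8, 3), torch.randn(8, 2)
    y_out, dlogp = transformer.forward(x, y)   # shift-only: y + mu(x), dlogp = 0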
--------------------------------------------------------------------------------
/bgflow/nn/flow/transformer/base.py:
--------------------------------------------------------------------------------
1 | from ..base import Flow
2 |
3 |
4 | __all__ = ["Transformer"]
5 |
6 |
7 | class Transformer(Flow):
8 |
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def _forward(self, x, y, *args, **kwargs):
13 | raise NotImplementedError()
14 |
15 | def _inverse(self, x, y, *args, **kwargs):
16 | raise NotImplementedError()
17 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/transformer/entropy_scaling.py:
--------------------------------------------------------------------------------
1 | from .base import Flow
2 | from torch.nn.parameter import Parameter
3 | import torch
4 |
5 |
6 | __all__ = ["ScalingLayer", "EntropyScalingLayer"]
7 |
8 |
9 | class ScalingLayer(Flow):
10 | def __init__(self, init_factor=1.0, dim=1):
11 | super().__init__()
12 | self._scalefactor = Parameter(init_factor * torch.ones(1))
13 | self.dim = dim
14 |
15 | def _forward(self, x, *cond, **kwargs):
16 | n_batch = x.shape[0]
17 | y = torch.zeros_like(x)
18 | y[:, : self.dim] = x[:, : self.dim] * self._scalefactor
19 | y[:, self.dim :] = x[:, self.dim :]
20 | return (
21 | y,
22 | (self.dim * self._scalefactor.log()).repeat(n_batch, 1),
23 | )
24 |
25 | def _inverse(self, x, *cond, **kwargs):
26 | n_batch = x.shape[0]
27 | y = torch.zeros_like(x)
28 | y[:, : self.dim] = x[:, : self.dim] / self._scalefactor
29 | y[:, self.dim :] = x[:, self.dim :]
30 | return (
31 | y,
32 | (-self.dim * self._scalefactor.log()).repeat(n_batch, 1),
33 | )
34 |
35 |
36 | class EntropyScalingLayer(Flow):
37 | def __init__(self, init_factor=1.0, dim=1):
38 | super().__init__()
39 | self._scalefactor = Parameter(init_factor * torch.ones(1))
40 | self.dim = dim
41 |
42 | def _forward(self, x, y, *cond, **kwargs):
43 | n_batch = x.shape[0]
44 | return (
45 | self._scalefactor * x,
46 | y,
47 | (self.dim * self._scalefactor.log()).repeat(n_batch, 1),
48 | )
49 |
50 | def _inverse(self, x, y, *cond, **kwargs):
51 | n_batch = x.shape[0]
52 | return (
53 | x / self._scalefactor,
54 | y,
55 | (-self.dim * self._scalefactor.log()).repeat(n_batch, 1),
56 | )
57 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/transformer/jax.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | try:
4 | import jax
5 | import jax.numpy as jnp
6 | except ImportError:
7 | jax = None
8 | jnp = None
9 |
10 |
11 | __all__ = [
12 | 'affine_transform',
13 | 'smooth_ramp',
14 | 'monomial_ramp',
15 | 'ramp_to_sigmoid',
16 | 'affine_sigmoid',
17 | 'wrap_around',
18 | 'remap_to_unit',
19 | 'mixture',
20 | ]
21 |
22 |
23 | def affine_transform(x, a, b):
24 | """Affine transform."""
25 | return x * jnp.exp(a) + b
26 |
27 |
28 | def smooth_ramp(x, logalpha, power=1, eps=1e-9):
29 | """Smooth ramp."""
30 | assert power > 0
31 | assert isinstance(power, int)
32 | assert eps > 0
33 | alpha = jnp.exp(logalpha)
34 | # double `where` trick to avoid NaN in backward pass
35 | z = jnp.where(x > eps, x, jnp.ones_like(x) * eps)
36 | normalizer = jnp.exp(-alpha * 1.)
37 | return jnp.where(
38 | x > eps,
39 | jnp.exp(-alpha * jnp.power(z, -power)) / normalizer,
40 | jnp.zeros_like(z))
41 |
42 |
43 | def monomial_ramp(x, order=2):
44 | assert order > 0 and isinstance(order, int)
45 | return jnp.power(x, order)
46 |
47 |
48 | def ramp_to_sigmoid(ramp):
49 | """Generalized sigmoid, given a ramp."""
50 | def _sigmoid(x, *params):
51 | numer = ramp(x, *params)
52 | denom = numer + ramp(1. - x, *params)
53 | return numer / denom
54 | return _sigmoid
55 |
56 |
57 | def affine_sigmoid(sigmoid, eps=1e-8):
58 | """Generalized affine sigmoid transform."""
59 | assert eps > 0
60 |
61 | def _affine_sigmoid(x, shift, slope, mix, *params):
62 | slope = jnp.exp(slope)
63 | mix = jax.nn.sigmoid(mix) * (1. - eps) + eps
64 | return (mix * sigmoid(slope * (x - shift), *params)
65 | + (1. - mix) * x)
66 | return _affine_sigmoid
67 |
68 |
69 | def wrap_around(bijector, sheaves=None, weights=None):
70 | """Wraps affine sigmoid around circle."""
71 | if sheaves is None:
72 | sheaves = jnp.array([-1, 0, 1])
73 | if weights is None:
74 | weights = jnp.zeros_like(sheaves)
75 | mixture_ = mixture(bijector)
76 |
77 | def _wrapped(x, *params):
78 | x = x - sheaves[None]
79 | params = [jnp.repeat(p[..., None], len(sheaves), axis=-1) for p in params]
80 | return mixture_(x, weights, *params)
81 | return remap_to_unit(_wrapped)
82 |
83 |
84 | def remap_to_unit(fun):
85 | """Maps transformation back to [0, 1]."""
86 | @functools.wraps(fun)
87 | def _remapped(x, *params):
88 | y1 = fun(jnp.ones_like(x), *params)
89 | y0 = fun(jnp.zeros_like(x), *params)
90 | return (fun(x, *params) - y0) / (y1 - y0)
91 | return _remapped
92 |
93 |
94 | def mixture(bijector):
95 | """Combines multiple bijectors into a mixture."""
96 | def _mixture_bijector(x, weights, *params):
97 | components = jax.vmap(
98 | functools.partial(bijector, x),
99 | in_axes=-1,
100 | out_axes=-1)(*params)
101 | return jnp.sum(jax.nn.softmax(weights) * components)
102 | return _mixture_bijector
103 |
--------------------------------------------------------------------------------
/bgflow/nn/flow/triangular.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | from bgflow.nn.flow.base import Flow
5 |
6 |
7 | __all__ = ["TriuFlow"]
8 |
9 |
10 | class TriuFlow(Flow):
11 |     """Linear flow (I+R)*x+b with an upper triangular matrix R.
12 |
13 | Attributes
14 | ----------
15 | dim : int
16 | dimension
17 | shift : boolean
18 | Whether to use a shift parameter (+b). If False, b=0.
19 | """
20 | def __init__(self, dim, shift=True):
21 | super(TriuFlow, self).__init__()
22 | self.dim = dim
23 | self.register_buffer("indices", torch.triu_indices(dim, dim))
24 | n_matrix_parameters = self.indices.shape[1]
25 | self._unique_elements = torch.nn.Parameter(torch.zeros(n_matrix_parameters))
26 | if shift:
27 | self.b = torch.nn.Parameter(torch.zeros(dim))
28 | else:
29 | self.register_buffer("b", torch.tensor(0.0))
30 | self.register_buffer("R", torch.zeros((self.dim, self.dim)))
31 |
32 | def _make_r(self):
33 | self.R[:] = 0
34 | self.R[self.indices[0], self.indices[1]] = self._unique_elements
35 |         self.R += torch.eye(self.dim, device=self.R.device, dtype=self.R.dtype)
36 | return self.R
37 |
38 | def _forward(self, x, **kwargs):
39 | """Forward transform.
40 |
41 |         Parameters
42 | ----------
43 | x : torch.tensor
44 | The input vector. The transform is applied to the last dimension.
45 | kwargs : dict
46 | keyword arguments to satisfy the interface
47 |
48 | Returns
49 | -------
50 | y : torch.tensor
51 |             (I+R)*x + b
52 | dlogp : torch.tensor
53 | natural log of the Jacobian determinant
54 | """
55 | R = self._make_r()
56 | dlogp = torch.ones_like(x[...,0,None])*torch.sum(torch.log(torch.abs(torch.diagonal(R))))
57 | y = torch.einsum("ab,...b->...a", R, x)
58 | return y + self.b, dlogp
59 |
60 | def _inverse(self, y, **kwargs):
61 | """Inverse transform.
62 |
63 |         Parameters
64 | ----------
65 | y : torch.tensor
66 | The input vector. The transform is applied to the last dimension.
67 | kwargs : dict
68 | keyword arguments to satisfy the interface
69 |
70 | Returns
71 | -------
72 | x : torch.tensor
73 |             (I+R)^-1 * (y-b)
74 | dlogp : torch.tensor
75 | natural log of the Jacobian determinant
76 | """
77 | R = self._make_r()
78 | dlogp = torch.ones_like(y[...,0,None])*(-torch.sum(torch.log(torch.abs(torch.diagonal(R)))))
79 | try:
80 | x = torch.linalg.solve_triangular(R, (y-self.b)[...,None], upper=True)
81 | except AttributeError:
82 | # legacy call for torch < 1.11
83 | x, _ = torch.triangular_solve((y-self.b)[...,None], R)
84 | return x[...,0], dlogp
85 |
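A usage sketch; with freshly initialized parameters R = I, and the triangular solve keeps the inverse exact:

    import torch
    from bgflow.nn.flow.triangular import TriuFlow

    flow = TriuFlow(dim=3)
    x = torch.randn(6, 3)
    y, dlogp = flow.forward(x)
    x_back, _ = flow.forward(y, inverse=True)
    assert torch.allclose(x, x_back, atol=1e-5)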
--------------------------------------------------------------------------------
/bgflow/nn/periodic.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | __all__ = ["WrapPeriodic"]
5 |
6 |
7 | class WrapPeriodic(torch.nn.Module):
8 |     """Wrap periodic network inputs onto the unit circle.
9 |
10 | Parameters
11 | ----------
12 | net : torch.nn.Module
13 |         The module whose periodic inputs are wrapped onto the unit circle.
14 | left : float, optional
15 | Left boundary of periodic interval.
16 | right : float, optional
17 | Right boundary of periodic interval.
18 | indices : Union[Sequence[int], slice], optional
19 | Array of periodic input indices.
20 |         Only indices covered by this index array are wrapped onto the circle.
21 |         The default corresponds to all indices.
22 | """
23 | def __init__(self, net, left=0.0, right=1.0, indices=slice(None)):
24 | super().__init__()
25 | self.net = net
26 | self.left = left
27 | self.right = right
28 | self.indices = indices
29 |
30 | def forward(self, x):
31 | indices = np.arange(x.shape[-1])[self.indices]
32 | other_indices = np.setdiff1d(np.arange(x.shape[-1]), indices)
33 | y = x[..., indices]
34 | cos = torch.cos(2 * np.pi * (y - self.left) / (self.right - self.left))
35 | sin = torch.sin(2 * np.pi * (y - self.left) / (self.right - self.left))
36 | x = torch.cat([cos, sin, x[..., other_indices]], dim=-1)
37 | return self.net.forward(x)
38 |
39 |
40 | class WrapDistances(torch.nn.Module):
41 | """TODO: TEST!!!"""
42 | def __init__(self, net, left=0.0, right=1.0, indices=slice(None)):
43 | super().__init__()
44 | self.net = net
45 | self.left = left
46 | self.right = right
47 | self.indices = indices
48 |
49 | def forward(self, x):
50 | indices = np.arange(x.shape[-1])[self.indices]
51 | other_indices = np.setdiff1d(np.arange(x.shape[-1]), indices)
52 | y = x[..., indices].view(x.shape[0],-1,3)
53 | distance_matrix = torch.cdist(y,y)
54 | mask = ~torch.tril(torch.ones_like(distance_matrix)).bool()
55 |
56 | distances = distance_matrix[mask].view(x.shape[0], -1)
57 | x = torch.cat([x[..., other_indices], distances], dim=-1)
58 | return self.net.forward(x)
59 |
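A sketch of `WrapPeriodic` around a conditioner net: each periodic input is replaced by a (cos, sin) pair, so the wrapped net needs one extra input per periodic index:

    import torch
    from bgflow.nn.periodic import WrapPeriodic

    net = torch.nn.Linear(4, 2)                # 3 raw features -> 4 wrapped features
    wrapped = WrapPeriodic(net, indices=[0])   # treat feature 0 as periodic on [0, 1]
    y = wrapped(torch.rand(5, 3))              # net sees [cos, sin, feature1, feature2]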
--------------------------------------------------------------------------------
/bgflow/nn/training/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainers import *
2 |
--------------------------------------------------------------------------------
/bgflow/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | .. currentmodule:: bgflow.utils
3 |
4 | ===============================================================================
5 | Geometry Utilities
6 | ===============================================================================
7 |
8 | .. autosummary::
9 | :toctree: generated/
10 | :template: class.rst
11 |
12 | distance_vectors
13 | distances_from_vectors
14 | remove_mean
15 | compute_distances
16 | compute_gammas
17 | kernelize_with_rbf
18 | RbfEncoder
19 | rbf_kernels
20 |
21 | ===============================================================================
22 | Jacobian Computation
23 | ===============================================================================
24 |
25 | .. autosummary::
26 | :toctree: generated/
27 | :template: class.rst
28 |
29 | brute_force_jacobian_trace
30 | brute_force_jacobian
31 | "batch_jacobian
32 | get_jacobian
33 | requires_grad
34 |
35 | ===============================================================================
36 | Types
37 | ===============================================================================
38 |
39 | .. autosummary::
40 | :toctree: generated/
41 | :template: class.rst
42 |
43 | is_list_or_tuple
44 | assert_numpy
45 | as_numpy
46 |
47 | ===============================================================================
48 | Free Energy Estimation
49 | ===============================================================================
50 |
51 | .. autosummary::
52 | :toctree: generated/
53 | :template: class.rst
54 |
55 | bennett_acceptance_ratio
56 |
57 | ===============================================================================
58 | Training utilities
59 | ===============================================================================
60 |
61 | .. autosummary::
62 | :toctree: generated/
63 | :template: class.rst
64 |
65 | IndexBatchIterator
66 | LossReporter
67 |
68 |
69 | ===============================================================================
70 | Gradient Clipping
71 | ===============================================================================
72 |
73 | .. autosummary::
74 | :toctree: generated/
75 | :template: class.rst
76 |
77 | ClipGradient
78 |
79 | """
80 |
81 | from .train import IndexBatchIterator, LossReporter, ClipGradient
82 | from .shape import tile
83 | from .types import *
84 | from .autograd import *
85 |
86 | from .geometry import (
87 | distance_vectors,
88 | distances_from_vectors,
89 | remove_mean,
90 | compute_distances
91 | )
92 | from .rbf_kernels import (
93 | kernelize_with_rbf,
94 | compute_gammas,
95 | RbfEncoder,
96 | rbf_kernels
97 | )
98 |
99 | from .free_energy import *
100 |
--------------------------------------------------------------------------------
/bgflow/utils/openmm.py:
--------------------------------------------------------------------------------
1 | __author__ = 'solsson'
2 |
3 | import numpy as np
4 |
5 |
6 | def save_latent_samples_as_trajectory(samples, mdtraj_topology, filename=None, topology_fn=None, return_openmm_traj=True):
7 | """
8 | Save Boltzmann Generator samples as a molecular dynamics trajectory.
9 | `samples`: posterior samples with shape (n_samples, n_atoms*n_dim)
10 | `mdtraj_topology`: an MDTraj Topology object of the molecular system
11 | `filename=None`: output filename with extension (all MDTraj-compatible formats)
12 | `topology_fn=None`: if given, write a PDB file of the molecular topology for external visualization and analysis.
13 | """
14 | import mdtraj as md
15 | trajectory = md.Trajectory(samples.reshape(-1, mdtraj_topology.n_atoms, 3), mdtraj_topology)
16 | if isinstance(topology_fn, str):
17 | trajectory[0].save_pdb(topology_fn)
18 | if isinstance(filename, str):
19 | trajectory.save(filename)
20 | if return_openmm_traj:
21 | return trajectory
22 |
23 |
24 | class NumpyReporter(object):
25 | def __init__(self, reportInterval, enforcePeriodicBox=True):
26 | self._coords = []
27 | self._reportInterval = reportInterval
28 | self.enforcePeriodicBox = enforcePeriodicBox
29 |
30 | def describeNextReport(self, simulation):
31 | steps = self._reportInterval - simulation.currentStep%self._reportInterval
32 | return (steps, True, False, False, False, self.enforcePeriodicBox)
33 |
34 | def report(self, simulation, state):
35 | self._coords.append(state.getPositions(asNumpy=True).ravel())
36 |
37 | def get_coordinates(self, superimpose=None):
38 | """
39 | Return saved coordinates as a numpy array.
40 | `superimpose`: OpenMM/MDTraj topology; if given, superimpose all frames on the first frame
41 | """
42 | import mdtraj as md
43 | try:
44 | from openmm.app.topology import Topology as _Topology
45 | except ImportError: # fall back to older version < 7.6
46 | from simtk.openmm.app.topology import Topology as _Topology
47 | if superimpose is None:
48 | return np.array(self._coords)
49 | elif isinstance(superimpose, _Topology):
50 | trajectory = md.Trajectory(np.array(self._coords).reshape(-1, superimpose.getNumAtoms(), 3),
51 | md.Topology().from_openmm(superimpose))
52 | else:
53 | trajectory = md.Trajectory(np.array(self._coords).reshape(-1, superimpose.n_atoms, 3),
54 | superimpose)
55 |
56 | trajectory.superpose(trajectory[0])
57 | return trajectory.xyz.reshape(-1, trajectory.n_atoms * 3)
58 |
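59 |
60 | # A minimal usage sketch for NumpyReporter (assuming an existing OpenMM
61 | # `simulation` object); the class implements OpenMM's custom-reporter
62 | # protocol through `describeNextReport` and `report`:
63 | #
64 | #     reporter = NumpyReporter(reportInterval=100)
65 | #     simulation.reporters.append(reporter)
66 | #     simulation.step(1000)
67 | #     coords = reporter.get_coordinates()  # shape (10, n_atoms * 3)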
--------------------------------------------------------------------------------
/bgflow/utils/shape.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def tile(a, dim, n_tile):
6 | """
7 | Tiles a pytorch tensor along an arbitrary dimension.
8 |
9 | Parameters
10 | ----------
11 | a : PyTorch tensor
12 | the tensor which is to be tiled
13 | dim : Integer
14 | dimension along which the tensor is tiled
15 | n_tile : Integer
16 | number of tiles
17 |
18 | Returns
19 | -------
20 | b : PyTorch tensor
21 | the tensor with dimension `dim` tiled `n_tile` times
22 | """
23 | init_dim = a.size(dim)
24 | repeat_idx = [1] * a.dim()
25 | repeat_idx[dim] = n_tile
26 | a = a.repeat(*(repeat_idx))
27 | order_index = np.concatenate(
28 | [init_dim * np.arange(n_tile) + i for i in range(init_dim)]
29 | )
30 | order_index = torch.tensor(order_index, dtype=torch.long, device=a.device)
31 | return torch.index_select(a, dim, order_index)
32 |
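33 |
34 | if __name__ == "__main__":
35 |     # A quick sanity check: `tile` repeats each element `n_tile` times
36 |     # consecutively along `dim`, which matches torch.repeat_interleave
37 |     # (not torch.Tensor.repeat, which concatenates whole copies).
38 |     a = torch.tensor([[1, 2], [3, 4]])
39 |     assert torch.equal(tile(a, dim=0, n_tile=2), torch.repeat_interleave(a, 2, dim=0))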
--------------------------------------------------------------------------------
/bgflow/utils/tensorops.py:
--------------------------------------------------------------------------------
1 | def log_dot_exp(logA, logB):
2 | """Fast and stable matrix log multiplication"""
3 | maxA = logA.max(dim=-1, keepdim=True).values
4 | maxB = logB.max(dim=-2, keepdim=True).values
5 | A = (logA - maxA).exp()
6 | B = (logB - maxB).exp()
7 | batch_shape = A.shape[:-2]
8 | A = A.view(-1, A.shape[-2], A.shape[-1])
9 | B = B.view(-1, B.shape[-2], B.shape[-1])
10 | logC = A.bmm(B).log().view(*batch_shape, A.shape[-2], B.shape[-1])
11 | logC.add_(maxA + maxB)
12 | return logC
13 |
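14 |
15 | if __name__ == "__main__":
16 |     # A quick sanity check: in exact arithmetic,
17 |     # log_dot_exp(logA, logB) == log(exp(logA) @ exp(logB));
18 |     # subtracting the row/column maxima only adds numerical stability.
19 |     import torch
20 |     logA, logB = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
21 |     naive = torch.log(torch.exp(logA) @ torch.exp(logB))
22 |     assert torch.allclose(log_dot_exp(logA, logB), naive, atol=1e-5)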
--------------------------------------------------------------------------------
/bgflow/utils/types.py:
--------------------------------------------------------------------------------
1 |
2 | from collections.abc import Iterable
3 | import numpy as np
4 | import torch
5 |
6 | __all__ = [
7 | "is_list_or_tuple", "assert_numpy", "as_numpy",
8 | "unpack_tensor_tuple", "pack_tensor_in_tuple",
9 | "pack_tensor_in_list",
10 | ]
11 |
12 |
13 | def is_list_or_tuple(x):
14 | return isinstance(x, list) or isinstance(x, tuple)
15 |
16 |
17 | def assert_numpy(x, arr_type=None):
18 | if isinstance(x, torch.Tensor):
19 | if x.is_cuda:
20 | x = x.cpu()
21 | x = x.detach().numpy()
22 | if is_list_or_tuple(x):
23 | x = np.array(x)
24 | assert isinstance(x, np.ndarray)
25 | if arr_type is not None:
26 | x = x.astype(arr_type)
27 | return x
28 |
29 |
30 | def as_numpy(tensor):
31 | """convert tensor to numpy"""
32 | return torch.as_tensor(tensor).detach().cpu().numpy()
33 |
34 |
35 | def unpack_tensor_tuple(seq):
36 | """unpack a tuple containing one tensor to a tensor"""
37 | if isinstance(seq, torch.Tensor):
38 | return seq
39 | else:
40 | if len(seq) == 1:
41 | return seq[0]
42 | else:
43 | return (*seq, )
44 |
45 |
46 | def pack_tensor_in_tuple(seq):
47 | """pack a tensor into a tuple of Tensor of length 1"""
48 | if isinstance(seq, torch.Tensor):
49 | return seq,
50 | elif isinstance(seq, Iterable):
51 | return (*seq, )
52 | else:
53 | return seq
54 |
55 |
56 | def pack_tensor_in_list(seq):
57 | """pack a tensor into a list of Tensor of length 1"""
58 | if isinstance(seq, torch.Tensor):
59 | return [seq]
60 | elif isinstance(seq, Iterable):
61 | return list(seq)
62 | else:
63 | return seq
64 |
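65 |
66 | if __name__ == "__main__":
67 |     # A minimal sketch of the helpers: a bare tensor round-trips through
68 |     # pack/unpack unchanged, and other iterables become tuples/lists.
69 |     t = torch.zeros(3)
70 |     assert unpack_tensor_tuple(pack_tensor_in_tuple(t)) is t
71 |     assert isinstance(pack_tensor_in_tuple([t, t]), tuple)
72 |     assert isinstance(pack_tensor_in_list((t, t)), list)
73 |     assert isinstance(assert_numpy(torch.zeros(2)), np.ndarray)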
--------------------------------------------------------------------------------
/devtools/conda-env.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | channels:
3 | - conda-forge
4 | - pytorch
5 |
6 | dependencies:
7 | - python
8 | - pip
9 |
10 | - pytest
11 | - numpy
12 |
13 | - openmm
14 | - xtb-python
15 | - ase
16 | - openmmtools
17 | - pytorch
18 | - jax
19 | - nequip
20 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = bgflow
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Compiling bgflow's Documentation
2 |
3 | The docs for this project are built with [Sphinx](http://www.sphinx-doc.org/en/master/).
4 | To compile the docs, first ensure that Sphinx and corresponding packages are installed.
5 |
6 |
7 | ```bash
8 | pip install sphinx sphinx-rtd-theme sphinx-gallery sphinx-nbexamples sphinxcontrib-katex sphinxcontrib-bibtex
9 | ```
10 |
11 | Once installed, you can use the `Makefile` in this directory to compile static HTML pages by running
12 | ```bash
13 | make html
14 | ```
15 |
16 | The compiled docs will be in the `docs/_build/html/` directory and can be viewed by opening `index.html`.
17 |
18 | The documentation rst-files can be found in `docs` and `docs/api`. To include a class or function in the
19 | documentation, reference it in the docstring of the corresponding `__init__.py`. That way the documentation lives
20 | directly next to the code; see `nn/flow/__init__.py` for an example. If the `__init__.py` does not have any docstrings
21 | yet, the module most likely needs to be referenced in an rst-file. The rst-files need to look like the ones in
22 | `docs/api` and must also be included in `docs/index.rst`.
23 |
24 | Notebooks can be part of the documentation as well. To include them, place them in `examples/nb_examples`
25 | and name them `example_{name}.ipynb`. Markdown comments in separate cells render especially well.
26 |
27 |
28 | A configuration file for [Read The Docs](https://readthedocs.org/) (`readthedocs.yml`) is included in the top level of
29 | the repository. To use Read the Docs to host your documentation, go to https://readthedocs.org/
30 | and connect this repository. You may need to change your default branch to `main` under Advanced Settings for the
31 | project.
32 |
33 | If you would like to use Read The Docs with `autodoc` (included automatically)
34 | and your package has dependencies, you will need to include those dependencies in your documentation yaml
35 | file (`docs/requirements.yaml`).
36 |
37 |
--------------------------------------------------------------------------------
/docs/_static/README.md:
--------------------------------------------------------------------------------
1 | # Static Doc Directory
2 |
3 | Add any paths that contain custom static files (such as style sheets) here,
4 | relative to the `conf.py` file's directory.
5 | They are copied after the builtin static files,
6 | so a file named "default.css" will overwrite the builtin "default.css".
7 |
8 | The path to this folder is set in the Sphinx `conf.py` file in the line:
9 | ```python
10 | html_static_path = ['_static']
11 | ```
12 |
13 | ## Examples of files to add to this directory
14 | * Custom Cascading Style Sheets
15 | * Custom JavaScript code
16 | * Static logo images
17 |
--------------------------------------------------------------------------------
/docs/_templates/README.md:
--------------------------------------------------------------------------------
1 | # Templates Doc Directory
2 |
3 | Add any paths that contain templates here, relative to
4 | the `conf.py` file's directory.
5 | They are copied after the builtin template files,
6 | so a file named "page.html" will overwrite the builtin "page.html".
7 |
8 | The path to this folder is set in the Sphinx `conf.py` file in the line:
9 | ```python
10 | templates_path = ['_templates']
11 | ```
12 |
13 | ## Examples of files to add to this directory
14 | * HTML extensions of stock pages like `page.html` or `layout.html`
15 |
--------------------------------------------------------------------------------
/docs/_templates/class.rst:
--------------------------------------------------------------------------------
1 | {% set escapedname = objname|escape %}
2 | {% set title = "*" ~ objtype ~ "* " ~ escapedname %}
3 | {{ title | underline }}
4 |
5 | .. currentmodule:: {{ module }}
6 |
7 | .. auto{{objtype}}:: {{ objname }}
8 |
9 | {% block attributes %}
10 | {% if attributes %}
11 | .. rubric:: {{ _('Attributes') }}
12 |
13 | .. autosummary::
14 | {% for item in attributes %}
15 | ~{{ name }}.{{ item }}
16 | {%- endfor %}
17 | {% endif %}
18 | {% endblock %}
19 |
20 | {% block methods %}
21 |
22 | {% if methods %}
23 | .. rubric:: {{ _('Methods') }}
24 |
25 | .. autosummary::
26 | {% for item in methods %}
27 | ~{{ name }}.{{ item }}
28 | {%- endfor %}
29 | {% endif %}
30 | {% endblock %}
31 |
--------------------------------------------------------------------------------
/docs/api/bg.rst:
--------------------------------------------------------------------------------
1 | Boltzmann Generators
2 | ====================
3 |
4 |
5 | .. currentmodule:: bgflow
6 | .. autosummary::
7 | :toctree: generated/
8 | :template: class.rst
9 |
10 | bg.BoltzmannGenerator
11 |
12 |
13 |
--------------------------------------------------------------------------------
/docs/api/energies.rst:
--------------------------------------------------------------------------------
1 | Energies
2 | =========
3 |
4 |
5 | .. automodule:: bgflow.distribution.energy
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 |
10 |
--------------------------------------------------------------------------------
/docs/api/flows.rst:
--------------------------------------------------------------------------------
1 | Flows
2 | =======
3 |
4 |
5 | .. automodule:: bgflow.nn.flow
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 |
10 |
--------------------------------------------------------------------------------
/docs/api/samplers.rst:
--------------------------------------------------------------------------------
1 | Sampling
2 | ========
3 |
4 |
5 | .. automodule:: bgflow.distribution
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 |
--------------------------------------------------------------------------------
/docs/api/utils.rst:
--------------------------------------------------------------------------------
1 | Utils
2 | =======
3 |
4 |
5 | .. automodule:: bgflow.utils
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 |
10 |
--------------------------------------------------------------------------------
/docs/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | =========
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | examples/index.rst
8 | nb_examples/index.rst
9 | datasets/index.rst
10 |
--------------------------------------------------------------------------------
/docs/getting_started.rst:
--------------------------------------------------------------------------------
1 | Getting Started
2 | ===============
3 |
4 | Coming soon ...
5 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Bgflow
2 | =========================================================
3 | Bgflow is a pytorch framework for Boltzmann Generators :footcite:`noe2019boltzmann` and other sampling methods.
4 |
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 |
9 | installation.rst
10 | getting_started.rst
11 | requirements.rst
12 | examples.rst
13 |
14 | .. toctree::
15 | :maxdepth: 1
16 | :caption: API docs:
17 |
18 | api/bg.rst
19 | api/flows.rst
20 | api/samplers.rst
21 | api/energies.rst
22 | api/utils.rst
23 |
24 | References
25 | ----------
26 | .. footbibliography::
27 |
28 |
29 | Indices and tables
30 | ==================
31 |
32 | * :ref:`genindex`
33 | * :ref:`modindex`
34 | * :ref:`search`
35 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | Installation will be available via pip
5 |
6 | .. code-block:: console
7 |
8 | pip install bgflow
9 |
10 | or via conda-forge
11 |
12 | .. code-block:: console
13 |
14 | conda install -c conda-forge bgflow
--------------------------------------------------------------------------------
/docs/literature.bib:
--------------------------------------------------------------------------------
1 | % Encoding: UTF-8
2 |
3 | @article{Khler2020EquivariantFE,
4 | title={Equivariant Flows: exact likelihood generative learning for symmetric densities},
5 | author={Jonas K{\"o}hler and Leon Klein and F. No{\'e}},
6 | journal={ArXiv},
7 | year={2020},
8 | volume={abs/2006.02425}
9 | }
10 |
11 | @article{noe2019boltzmann,
12 | title={Boltzmann generators: Sampling equilibrium states of many-body systems with deep learning},
13 | author={No{\'e}, Frank and Olsson, Simon and K{\"o}hler, Jonas and Wu, Hao},
14 | journal={Science},
15 | volume={365},
16 | issue={6457},
17 | pages={eaaw1147},
18 | year={2019}
19 | }
20 |
21 | @inproceedings{chen2018neural,
22 | title={Neural ordinary differential equations},
23 | author={Chen, Tian Qi and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David K},
24 | booktitle={Advances in Neural Information Processing Systems},
25 | pages={6571--6583},
26 | year={2018}
27 | }
28 |
29 | @article{gholami2019anode,
30 | title={ANODE: Unconditionally Accurate Memory-Efficient Gradients for Neural ODEs},
31 | author={Gholami, Amir and Keutzer, Kurt and Biros, George},
32 | journal={arXiv preprint arXiv:1902.10298},
33 | year={2019}
34 | }
35 |
36 | @article{garcia2021n,
37 | title={E (n) Equivariant Normalizing Flows},
38 | author={Garcia Satorras, Victor and Hoogeboom, Emiel and Fuchs, Fabian and Posner, Ingmar and Welling, Max},
39 | journal={Advances in Neural Information Processing Systems},
40 | volume={34},
41 | year={2021}
42 | }
43 |
44 | @article{satorras2021n,
45 | title={E (n) equivariant graph neural networks},
46 | author={Satorras, Victor Garcia and Hoogeboom, Emiel and Welling, Max},
47 | journal={arXiv preprint arXiv:2102.09844},
48 | year={2021}
49 | }
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=bgflow
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/requirements.rst:
--------------------------------------------------------------------------------
1 | Requirements
2 | ============
3 |
4 | Mandatory
5 | **********
6 |
7 | - einops
8 | - pytorch
9 | - numpy
10 | - matplotlib
11 |
12 | Optional
13 | ********
14 |
15 | - pytest (for testing)
16 | - nflows (for Neural Spline Flows)
17 | - OpenMM (for molecular examples)
18 | - torchdiffeq (for neural ODEs)
19 | - ANODE (for neural ODEs)
20 |
--------------------------------------------------------------------------------
/docs/requirements.yaml:
--------------------------------------------------------------------------------
1 | name: docs
2 | channels:
3 | dependencies:
4 | # Base depends
5 | - python
6 | - pip
7 | - numpy
8 | - nflows
9 | - torchdiffeq
10 | - einops
11 | - torch
12 | # Doc depends
13 | - sphinx
14 | - sphinx-gallery
15 | - sphinx-nbexamples
16 | - sphinx-rtd-theme
17 | - sphinxcontrib-applehelp
18 | - sphinxcontrib-bibtex
19 | - sphinxcontrib-devhelp
20 | - sphinxcontrib-htmlhelp
21 | - sphinxcontrib-katex
22 | - sphinxcontrib-jsmath
23 | - sphinxcontrib-qthelp
24 | - sphinxcontrib-serializinghtml
25 |
26 |
--------------------------------------------------------------------------------
/examples/datasets/README.rst:
--------------------------------------------------------------------------------
1 | Datasets
2 | ==================
3 |
4 | Coming soon...
--------------------------------------------------------------------------------
/examples/general_examples/README.rst:
--------------------------------------------------------------------------------
1 | General Examples
2 | ==================
3 |
4 | Below is a gallery of examples.
--------------------------------------------------------------------------------
/examples/general_examples/plot_simple_bg.py:
--------------------------------------------------------------------------------
1 | """
2 | Minimal Boltzmann Generator Example
3 | ====================================
4 |
5 | In this example a simple Boltzmann Generator is created using coupling layers.
6 | """
7 |
8 | import torch
9 | import matplotlib.pyplot as plt
10 | import bgflow as bg
11 |
12 | # define prior and target
13 | dim = 2
14 | prior = bg.NormalDistribution(dim)
15 | target = bg.DoubleWellEnergy(dim)
16 |
17 | # here we aggregate all layers of the flow
18 | layers = []
19 | layers.append(bg.SplitFlow(dim // 2))
20 | layers.append(bg.CouplingFlow(
22 | # we use an affine transformation to transform
22 | # the RHS conditioned on the LHS
23 | bg.AffineTransformer(
24 | # use simple dense nets for the affine shift/scale
25 | shift_transformation=bg.DenseNet(
26 | [dim // 2, 4, dim // 2],
27 | activation=torch.nn.ReLU()
28 | ),
29 | scale_transformation=bg.DenseNet(
30 | [dim // 2, 4, dim // 2],
31 | activation=torch.nn.Tanh()
32 | )
33 | )
34 | ))
35 | layers.append(bg.InverseFlow(bg.SplitFlow(dim // 2)))
36 |
37 | # now define the flow as a sequence of all operations stored in layers
38 | flow = bg.SequentialFlow(layers)
39 |
40 | # The BG is defined by a prior, target and a flow
41 | generator = bg.BoltzmannGenerator(prior, flow, target)
42 |
43 | # sample from the BG
44 | samples = generator.sample(1000)
45 | _ = plt.hist2d(
46 | samples[:, 0].detach().numpy(),
47 | samples[:, 1].detach().numpy(), bins=100
48 | )
49 |
--------------------------------------------------------------------------------
/examples/nb_examples/README.rst:
--------------------------------------------------------------------------------
1 | Example Notebooks
2 | ==================
3 |
4 | Below is a gallery of example Jupyter notebooks.
5 | They are only rendered if they are named `example_{name}.ipynb`.
--------------------------------------------------------------------------------
/notebooks/alanine_dipeptide_augmented.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "basic-river",
6 | "metadata": {},
7 | "source": [
8 | "# Alanine Dipeptide with Augmented Normalizing Flows\n",
9 | "\n",
10 | "This notebook introduces augmented normalizing flows. \n",
11 | "\n",
12 | "At first, let us import everything from the basics tutorial so that we can focus on the newly introduced concepts."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "fuzzy-petite",
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "name": "stdout",
23 | "output_type": "stream",
24 | "text": [
25 | "Using downloaded and verified file: /tmp/alanine-dipeptide-nowater.pdb\n"
26 | ]
27 | },
28 | {
29 | "data": {
30 | "application/vnd.jupyter.widget-view+json": {
31 | "model_id": "51cb1200f7e64d5c83575f8c5e60a3be",
32 | "version_major": 2,
33 | "version_minor": 0
34 | },
35 | "text/plain": [
36 | "_ColormakerRegistry()"
37 | ]
38 | },
39 | "metadata": {},
40 | "output_type": "display_data"
41 | }
42 | ],
43 | "source": [
44 | "import alanine_dipeptide_basics as basic"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "id": "increased-serbia",
50 | "metadata": {},
51 | "source": [
52 | "## Augmented Prior"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 3,
58 | "id": "vertical-whale",
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "data": {
63 | "text/plain": [
64 | "OpenMMEnergy()"
65 | ]
66 | },
67 | "execution_count": 3,
68 | "metadata": {},
69 | "output_type": "execute_result"
70 | }
71 | ],
72 | "source": [
73 | "basic.target_energy"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "id": "distant-banks",
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "basic.coordinate_transform"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "id": "innocent-envelope",
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "basic.dim_ics"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "id": "correct-tribe",
100 | "metadata": {},
101 | "outputs": [],
102 | "source": []
103 | }
104 | ],
105 | "metadata": {
106 | "kernelspec": {
107 | "display_name": "Python [conda env:ml] *",
108 | "language": "python",
109 | "name": "conda-env-ml-py"
110 | },
111 | "language_info": {
112 | "codemirror_mode": {
113 | "name": "ipython",
114 | "version": 3
115 | },
116 | "file_extension": ".py",
117 | "mimetype": "text/x-python",
118 | "name": "python",
119 | "nbconvert_exporter": "python",
120 | "pygments_lexer": "ipython3",
121 | "version": "3.7.6"
122 | }
123 | },
124 | "nbformat": 4,
125 | "nbformat_minor": 5
126 | }
127 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | # readthedocs.yml
2 |
3 | version: 2
4 |
5 | build:
6 | image: latest
7 |
8 | python:
9 | version: 3.8
10 | install:
11 | - method: pip
12 | path: .
13 |
14 | conda:
15 | environment: docs/requirements.yaml
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 |
2 | [yapf]
3 | # YAPF, in .style.yapf files this shows up as "[style]" header
4 | COLUMN_LIMIT = 119
5 | INDENT_WIDTH = 4
6 | USE_TABS = False
7 |
8 |
9 | [versioneer]
10 | # Automatic version numbering scheme
11 | VCS = git
12 | style = pep440
13 | versionfile_source = bgflow/_version.py
14 | versionfile_build = bgflow/_version.py
15 | tag_prefix = ''
16 |
17 | [aliases]
18 | test = pytest
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | bgflow
3 | Boltzmann Generators and Normalizing Flows in PyTorch
4 | """
5 | import sys
6 | from setuptools import setup, find_packages
7 | import versioneer
8 |
9 | short_description = "Boltzmann Generators and Normalizing Flows in PyTorch".split("\n")[0]
10 |
11 | # from https://github.com/pytest-dev/pytest-runner#conditional-requirement
12 | needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv)
13 | pytest_runner = ['pytest-runner'] if needs_pytest else []
14 |
15 | try:
16 | with open("README.md", "r") as handle:
17 | long_description = handle.read()
18 | except:
19 | long_description = None
20 |
21 |
22 | setup(
23 | # Self-descriptive entries which should always be present
24 | name='bgflow',
25 | author='Jonas Köhler, Andreas Krämer, Leon Klein, Manuel Dibak, Frank Noé',
26 | author_email='kraemer.research@gmail.com',
27 | description=short_description,
28 | long_description=long_description,
29 | long_description_content_type="text/markdown",
30 | version=versioneer.get_version(),
31 | cmdclass=versioneer.get_cmdclass(),
32 | license='MIT',
33 |
34 | # Which Python importable modules should be included when your package is installed
35 | # Handled automatically by setuptools. Use 'exclude' to prevent some specific
36 | # subpackage(s) from being added, if needed
37 | packages=find_packages(),
38 |
39 | # Optional include package data to ship with your package
40 | # Customize MANIFEST.in if the general case does not suit your needs
41 | # Comment out this line to prevent the files from being packaged with your software
42 | include_package_data=True,
43 |
44 | # Allows `setup.py test` to work correctly with pytest
45 | setup_requires=[] + pytest_runner,
46 |
47 | # Additional entries you may want: simply uncomment the lines you want and fill in the data
48 | # url='http://www.my_package.com', # Website
49 | # install_requires=[], # Required packages, pulls from pip if needed; do not use for Conda deployment
50 | # platforms=['Linux',
51 | # 'Mac OS-X',
52 | # 'Unix',
53 | # 'Windows'], # Valid platforms your code works on, adjust to your flavor
54 | # python_requires=">=3.5", # Python version restrictions
55 |
56 | # Manual control if final package is compressible or not, set False to prevent the .egg from being made
57 | # zip_safe=False,
58 |
59 | )
60 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from types import SimpleNamespace
4 | import pytest
5 | import numpy as np
6 | import torch
7 | from bgflow import MixedCoordinateTransformation, OpenMMBridge, OpenMMEnergy, RelativeInternalCoordinateTransformation
8 |
9 |
10 | @pytest.fixture(
11 | params=[
12 | "cpu",
13 | pytest.param(
14 | "cuda:0",
15 | marks=pytest.mark.skipif(
16 | not torch.cuda.is_available(),
17 | reason="CUDA not available."
18 | )
19 | )
20 | ]
21 | )
22 | def device(request):
23 | """Run a test case for all available devices."""
24 | return torch.device(request.param)
25 |
26 |
27 | @pytest.fixture(params=[torch.float32, torch.float64])
28 | def dtype(request, device):
29 | """Run a test case in single and double precision."""
30 | return request.param
31 |
32 |
33 | @pytest.fixture()
34 | def ctx(dtype, device):
35 | return {"dtype": dtype, "device": device}
36 |
37 |
38 | @pytest.fixture(params=[torch.enable_grad, torch.no_grad])
39 | def with_grad_and_no_grad(request):
40 | """Run a test with and without torch grad enabled"""
41 | with request.param():
42 | yield
43 |
44 |
45 | @pytest.fixture(scope="session")
46 | def ala2():
47 | """Mock bgmol dataset."""
48 | mm = pytest.importorskip("simtk.openmm")
49 | md = pytest.importorskip("mdtraj")
50 | pdb = mm.app.PDBFile(os.path.join(os.path.dirname(__file__), "data/alanine-dipeptide-nowater.pdb"))
51 | system = SimpleNamespace()
52 | system.topology = pdb.getTopology()
53 | system.mdtraj_topology = md.Topology.from_openmm(system.topology)
54 | system.system = mm.app.ForceField("amber99sbildn.xml").createSystem(
55 | pdb.getTopology(),
56 | removeCMMotion=True,
57 | nonbondedMethod=mm.app.NoCutoff,
58 | constraints=mm.app.HBonds,
59 | rigidWater=True
60 | )
61 | system.energy_model = OpenMMEnergy(
62 | bridge=OpenMMBridge(
63 | system.system,
64 | mm.LangevinIntegrator(300, 1, 0.001),
65 | n_workers=1
66 | )
67 | )
68 | system.positions = pdb.getPositions()
69 | system.rigid_block = np.array([6, 8, 9, 10, 14])
70 | system.z_matrix = np.array([
71 | [0, 1, 4, 6],
72 | [1, 4, 6, 8],
73 | [2, 1, 4, 0],
74 | [3, 1, 4, 0],
75 | [4, 6, 8, 14],
76 | [5, 4, 6, 8],
77 | [7, 6, 8, 4],
78 | [11, 10, 8, 6],
79 | [12, 10, 8, 11],
80 | [13, 10, 8, 11],
81 | [15, 14, 8, 16],
82 | [16, 14, 8, 6],
83 | [17, 16, 14, 15],
84 | [18, 16, 14, 8],
85 | [19, 18, 16, 14],
86 | [20, 18, 16, 19],
87 | [21, 18, 16, 19]
88 | ])
89 | system.global_z_matrix = np.row_stack([
90 | system.z_matrix,
91 | np.array([
92 | [9, 8, 6, 14],
93 | [10, 8, 14, 6],
94 | [6, 8, 14, -1],
95 | [8, 14, -1, -1],
96 | [14, -1, -1, -1]
97 | ])
98 | ])
99 | dataset = SimpleNamespace()
100 | dataset.system = system
101 | # super-short simulation
102 | xyz = []
103 | simulation = mm.app.Simulation(system.topology, system.system, mm.LangevinIntegrator(300,1,0.001))
104 | simulation.context.setPositions(system.positions)
105 | for i in range(100):
106 | simulation.step(10)
107 | pos = simulation.context.getState(getPositions=True).getPositions(asNumpy=True)
108 | xyz.append(pos._value)
109 | dataset.xyz = np.stack(xyz, axis=0)
110 | return dataset
111 |
112 |
113 | @pytest.fixture()
114 | def crd_trafo(ala2, ctx):
115 | z_matrix = ala2.system.z_matrix
116 | fixed_atoms = ala2.system.rigid_block
117 | crd_transform = MixedCoordinateTransformation(torch.tensor(ala2.xyz, **ctx), z_matrix, fixed_atoms)
118 | return crd_transform
119 |
120 |
121 | @pytest.fixture()
122 | def crd_trafo_unwhitened(ala2, ctx):
123 | z_matrix = ala2.system.z_matrix
124 | fixed_atoms = ala2.system.rigid_block
125 | crd_transform = RelativeInternalCoordinateTransformation(z_matrix, fixed_atoms)
126 | return crd_transform
127 |
128 |
--------------------------------------------------------------------------------
/tests/data/alanine-dipeptide-nowater.pdb:
--------------------------------------------------------------------------------
1 | CRYST1 27.222 27.222 27.222 90.00 90.00 90.00 P 1 1
2 | ATOM 1 HH31 ACE X 1 3.225 27.427 2.566 1.00 0.00
3 | ATOM 2 CH3 ACE X 1 3.720 26.570 2.110 1.00 0.00
4 | ATOM 3 HH32 ACE X 1 4.088 25.905 2.891 1.00 0.00
5 | ATOM 4 HH33 ACE X 1 4.557 26.914 1.502 1.00 0.00
6 | ATOM 5 C ACE X 1 2.770 25.800 1.230 1.00 0.00
7 | ATOM 6 O ACE X 1 1.600 26.150 1.090 1.00 0.00
8 | ATOM 7 N ALA X 2 3.270 24.640 0.690 1.00 0.00
9 | ATOM 8 H ALA X 2 4.259 24.471 0.810 1.00 0.00
10 | ATOM 9 CA ALA X 2 2.480 23.690 -0.190 1.00 0.00
11 | ATOM 10 HA ALA X 2 1.733 24.315 -0.679 1.00 0.00
12 | ATOM 11 CB ALA X 2 3.470 23.160 -1.270 1.00 0.00
13 | ATOM 12 HB1 ALA X 2 4.219 22.525 -0.797 1.00 0.00
14 | ATOM 13 HB2 ALA X 2 2.922 22.582 -2.014 1.00 0.00
15 | ATOM 14 HB3 ALA X 2 3.963 24.002 -1.756 1.00 0.00
16 | ATOM 15 C ALA X 2 1.730 22.590 0.490 1.00 0.00
17 | ATOM 16 O ALA X 2 2.340 21.880 1.280 1.00 0.00
18 | ATOM 17 N NME X 3 0.400 22.430 0.210 1.00 0.00
19 | ATOM 18 H NME X 3 -0.008 23.118 -0.407 1.00 0.00
20 | ATOM 19 CH3 NME X 3 -0.470 21.350 0.730 1.00 0.00
21 | ATOM 20 HH31 NME X 3 0.112 20.693 1.376 1.00 0.00
22 | ATOM 21 HH32 NME X 3 -1.290 21.786 1.300 1.00 0.00
23 | ATOM 22 HH33 NME X 3 -0.873 20.775 -0.103 1.00 0.00
24 | END
25 |
--------------------------------------------------------------------------------
/tests/distribution/energy/test_ase.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 | from bgflow import ASEBridge, ASEEnergy, XTBBridge, XTBEnergy
5 |
6 |
7 | try:
8 | import ase
9 | import xtb
10 | ase_and_xtb_imported = True
11 | except ImportError:
12 | ase_and_xtb_imported = False
13 |
14 | pytestmark = pytest.mark.skipif(not ase_and_xtb_imported, reason="Tests require ASE and XTB")
15 |
16 |
17 | def test_ase_energy(ctx):
18 | from ase.build import molecule
19 | from xtb.ase.calculator import XTB
20 | water = molecule("H2O")
21 | water.calc = XTB()
22 | target = ASEEnergy(ASEBridge(water, 300.))
23 | pos = torch.tensor(0.1*water.positions, **ctx)
24 | e = target.energy(pos)
25 | f = target.force(pos)
26 |
27 |
28 | def test_ase_vs_xtb(ctx):
29 | # to make sure that unit conversion is the same, etc.
30 | from ase.build import molecule
31 | from xtb.ase.calculator import XTB
32 | water = molecule("H2O")
33 | water.calc = XTB()
34 | target1 = ASEEnergy(ASEBridge(water, 300.))
35 | target2 = XTBEnergy(XTBBridge(water.numbers, 300.))
36 | pos = torch.tensor(0.1 * water.positions[None, ...], **ctx)
37 | assert torch.allclose(target1.energy(pos), target2.energy(pos))
38 | assert torch.allclose(target1.force(pos), target2.force(pos), atol=1e-6)
39 |
40 |
--------------------------------------------------------------------------------
/tests/distribution/energy/test_clipped.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import warnings
4 | import torch
5 | from bgflow import Energy, LinLogCutEnergy, GradientClippedEnergy, DoubleWellEnergy
6 | from bgflow.utils import ClipGradient
7 |
8 |
9 | class StrongRepulsion(Energy):
10 | def __init__(self):
11 | super().__init__([2, 2])
12 |
13 | def _energy(self, x):
14 | dist = torch.cdist(x, x)
15 | return (dist ** -12)[..., 0, 1][:, None]
16 |
17 |
18 | def test_linlogcut(ctx):
19 | lj = StrongRepulsion()
20 | llc = LinLogCutEnergy(lj, high_energy=1e3, max_energy=1e10)
21 | x = torch.tensor([
22 | [[0., 0.], [0.0, 0.0]], # > max energy
23 | [[0., 0.], [0.0, 0.3]], # < max_energy, > high_energy
24 | [[0., 0.], [0.0, 1.]], # < high_energy
25 | ], **ctx)
26 | raw = lj.energy(x)[:, 0]
27 | cut = llc.energy(x)[:, 0]
28 |
29 | # first energy is clamped
30 | assert not (raw <= 1e10).all()
31 | assert (cut <= 1e10).all()
32 | assert cut[0].item() == pytest.approx(1e10, abs=1e-5)
33 | # second energy is softened, but force points in the right direction
34 | assert 1e3 < cut[1].item() < 1e10
35 | assert llc.force(x)[1][1, 1] > 0.0
36 | assert llc.force(x)[1][0, 1] < 0.0
37 | # third energy is unchanged
38 | assert torch.allclose(raw[2], cut[2], atol=1e-5)
39 |
40 |
41 | def openmm_example(grad_clipping, ctx):
42 | try:
43 | with warnings.catch_warnings():
44 | warnings.simplefilter(
45 | "ignore", DeprecationWarning
46 | ) # ignore warnings inside OpenMM
47 | from simtk import openmm, unit
48 | except ImportError:
49 | pytest.skip("Test requires OpenMM.")
50 |
51 | system = openmm.System()
52 | system.addParticle(1.)
53 | system.addParticle(2.)
54 | nonbonded = openmm.NonbondedForce()
55 | nonbonded.addParticle(0.0, 1.0, 2.0)
56 | nonbonded.addParticle(0.0, 1.0, 2.0)
57 | system.addForce(nonbonded)
58 |
59 | from bgflow import OpenMMEnergy, OpenMMBridge
60 | bridge = OpenMMBridge(system, openmm.LangevinIntegrator(300., 0.1, 0.001), n_workers=1)
61 |
62 | energy = OpenMMEnergy(bridge=bridge, two_event_dims=False)
63 | energy = GradientClippedEnergy(energy, grad_clipping).to(**ctx)
64 | positions = torch.tensor([[0.0, 0.0, 0.0, 0.1, 0.2, 0.6]]).to(**ctx)
65 | positions.requires_grad = True
66 | force = energy.force(positions)
67 | energy.energy(positions).sum().backward()
68 | #force = torch.tensor([[-1908.0890, -3816.1780, -11448.5342, 1908.0890, 3816.1780, 11448.5342]]).to(**ctx)
69 | return positions.grad, force
70 |
71 |
72 | def test_openmm_clip_by_value(ctx):
73 | grad_clipping = ClipGradient(clip=3000.0, norm_dim=1)
74 | grad, force = openmm_example(grad_clipping, ctx)
75 | expected = - torch.as_tensor([[-1908.0890, -3000., -3000., 1908.0890, 3000., 3000.]], **ctx)
76 | assert torch.allclose(grad.flatten(), expected, atol=1e-3)
77 |
78 |
79 | def test_openmm_clip_by_atom(ctx):
80 | grad_clipping = ClipGradient(clip=torch.as_tensor([3000.0, 1.0]), norm_dim=3)
81 | grad, force = openmm_example(grad_clipping, ctx)
82 | norm_ratio = torch.linalg.norm(grad[..., :3], dim=-1).item()
83 | assert norm_ratio == pytest.approx(3000.)
84 | assert torch.allclose(grad[..., :3] / 3000., - grad[..., 3:], atol=1e-6)
85 |
86 |
87 | def test_openmm_clip_by_batch(ctx):
88 | grad_clipping = ClipGradient(clip=1.0, norm_dim=-1)
89 | grad, force = openmm_example(grad_clipping, ctx)
90 | ratio = force / grad
91 | assert torch.allclose(ratio, ratio[0, 0] * torch.ones_like(ratio))
92 | assert torch.linalg.norm(grad).item() == pytest.approx(1.)
93 |
94 |
95 | def test_openmm_clip_no_grad(ctx):
96 | energy = GradientClippedEnergy(
97 | energy=DoubleWellEnergy(2),
98 | gradient_clipping=ClipGradient(clip=1.0, norm_dim=1)
99 | )
100 | x = torch.randn(12,2).to(**ctx)
101 | x.requires_grad = False
102 | energy.energy(x)
103 |
--------------------------------------------------------------------------------
/tests/distribution/energy/test_lennard_jones.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from bgflow.distribution.energy import LennardJonesPotential
5 | from bgflow.distribution.energy.lennard_jones import lennard_jones_energy_torch
6 |
7 |
8 | def test_lennard_jones_energy_torch():
9 | energy_large = lennard_jones_energy_torch(torch.tensor(1e10), eps=1, rm=1)
10 | energy_min = lennard_jones_energy_torch(torch.tensor(5.), eps=3, rm=5)
11 | energy_zero = lennard_jones_energy_torch(torch.tensor(2 ** (-1 / 6)), eps=3, rm=1)
12 | assert torch.allclose(energy_large, torch.tensor(0.))
13 | assert torch.allclose(energy_min, torch.tensor(-3.))
14 | assert torch.allclose(energy_zero, torch.tensor(0.), atol=1e-5)
15 |
16 |
17 | @pytest.mark.parametrize("oscillator", [True, False])
18 | @pytest.mark.parametrize("two_event_dims", [True, False])
19 | def test_lennard_jones_potential(oscillator, two_event_dims):
20 | eps = 5.
21 |
22 | # 2 particles in 3D
23 | lj_pot = LennardJonesPotential(
24 | dim=6, n_particles=2, eps=eps, rm=2.0,
25 | oscillator=oscillator, oscillator_scale=1.,
26 | two_event_dims=two_event_dims
27 | )
28 |
29 | batch_shape = (5, 7)
30 | data3d = torch.tensor([[[[-1., 0, 0], [1, 0, 0]]]]).repeat(*batch_shape, 1, 1)
31 | if not two_event_dims:
32 | data3d = data3d.view(*batch_shape, 6)
33 | energy3d = torch.tensor([[- eps]]).repeat(*batch_shape)
34 | if oscillator:
35 | energy3d += 1
36 | lj_energy_3d = lj_pot.energy(data3d)
37 | assert torch.allclose(energy3d[:, None], lj_energy_3d)
38 |
39 | # 3 particles in 2D
40 | lj_pot = LennardJonesPotential(
41 | dim=6, n_particles=3, eps=eps, rm=1.0,
42 | oscillator=oscillator, oscillator_scale=1.,
43 | two_event_dims=two_event_dims
44 | )
45 | h = np.sqrt(0.75)
46 | data2d = torch.tensor([[[0, 2 / 3 * h], [0.5, -1 / 3 * h], [-0.5, -1 / 3 * h]]], dtype=torch.float)
47 | if not two_event_dims:
48 | data2d = data2d.view(-1, 6)
49 | energy2d = torch.tensor([- 3 * eps])
50 | if oscillator:
51 | energy2d += 0.5 * (data2d ** 2).sum()
52 | lj_energy_2d = lj_pot.energy(data2d)
53 | assert torch.allclose(energy2d, lj_energy_2d)
54 | lj_energy2d_np = lj_pot._energy_numpy(data2d)
55 | assert energy2d[:, None].numpy() == pytest.approx(lj_energy2d_np, abs=1e-6)
56 |
--------------------------------------------------------------------------------
/tests/distribution/energy/test_multi_double_well_potential.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from bgflow.distribution import MultiDoubleWellPotential
4 |
5 |
6 | def test_multi_double_well_potential(ctx):
7 | target = MultiDoubleWellPotential(dim=4, n_particles=2, a=0.9, b=-4, c=0, offset=4)
8 | x = torch.tensor([[[2., 0, ], [-2, 0]]], **ctx)
9 | energy = target.energy(x)
10 | target.force(x)
11 | assert torch.allclose(energy, torch.tensor([[0.]], **ctx), atol=1e-5)
12 |
--------------------------------------------------------------------------------
/tests/distribution/energy/test_xtb.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from bgflow import XTBEnergy, XTBBridge
5 |
6 | try:
7 | import xtb
8 | xtb_imported = True
9 | except ImportError:
10 | xtb_imported = False
11 |
12 | pytestmark = pytest.mark.skipif(not xtb_imported, reason="Test requires XTB")
13 |
14 |
15 | @pytest.mark.parametrize("pos_shape", [(1, 3, 3), (1, 9)])
16 | def test_xtb_water(pos_shape, ctx):
17 | unit = pytest.importorskip("openmm.unit")
18 | temperature = 300
19 | numbers = np.array([8, 1, 1])
20 | positions = torch.tensor([
21 | [0.00000000000000, 0.00000000000000, -0.73578586109551],
22 | [1.44183152868459, 0.00000000000000, 0.36789293054775],
23 | [-1.44183152868459, 0.00000000000000, 0.36789293054775]],
24 | **ctx
25 | )
26 | positions = (positions * unit.bohr).value_in_unit(unit.nanometer)
27 | target = XTBEnergy(
28 | XTBBridge(numbers=numbers, temperature=temperature),
29 | two_event_dims=(pos_shape == (1, 3, 3))
30 | )
31 | energy = target.energy(positions.reshape(pos_shape))
32 | force = target.force(positions.reshape(pos_shape))
33 | assert energy.shape == (1, 1)
34 | assert force.shape == pos_shape
35 |
36 | kbt = unit.BOLTZMANN_CONSTANT_kB * temperature * unit.kelvin
37 | expected_energy = torch.tensor(-5.070451354836705, **ctx) * unit.hartree / kbt
38 | expected_force = - torch.tensor([
39 | [6.24500451e-17, - 3.47909735e-17, - 5.07156941e-03],
40 | [-1.24839222e-03, 2.43536791e-17, 2.53578470e-03],
41 | [1.24839222e-03, 1.04372944e-17, 2.53578470e-03],
42 | ], **ctx) * unit.hartree/unit.bohr/(kbt/unit.nanometer)
43 | assert torch.allclose(energy.flatten(), expected_energy.flatten(), atol=1e-5)
44 | assert torch.allclose(force.flatten(), expected_force.flatten(), atol=1e-5)
45 |
46 |
47 | def _eval_invalid(ctx, err_handling):
48 | pos = torch.zeros(1, 3, 3, **ctx)
49 | target = XTBEnergy(
50 | XTBBridge(numbers=np.array([8, 1, 1]), temperature=300, err_handling=err_handling)
51 | )
52 | return target.energy(pos), target.force(pos)
53 |
54 |
55 | def test_xtb_error(ctx):
56 | from xtb.interface import XTBException
57 | with pytest.raises(XTBException):
58 | _eval_invalid(ctx, err_handling="error")
59 |
60 |
61 | def test_xtb_warning(ctx):
62 | with pytest.warns(UserWarning, match="Caught exception in xtb"):
63 | e, f = _eval_invalid(ctx, err_handling="warning")
64 | assert torch.isinf(e).all()
65 | assert torch.allclose(f, torch.zeros_like(f))
66 |
67 |
68 | def test_xtb_ignore(ctx):
69 | e, f = _eval_invalid(ctx, err_handling="ignore")
70 | assert torch.isinf(e).all()
71 | assert torch.allclose(f, torch.zeros_like(f))
72 |
--------------------------------------------------------------------------------
/tests/distribution/sampling/test_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | from bgflow import DataSetSampler, DataLoaderSampler
5 |
6 |
7 | def test_dataset_sampler(ctx):
8 | data = torch.arange(12).reshape(4,3).to(**ctx)
9 | sampler = DataSetSampler(data).to(**ctx)
10 | idxs = sampler._idxs.copy()
11 | # test sampling out of range
12 | assert torch.allclose(sampler.sample(3), data[idxs[:3]])
13 | assert torch.allclose(sampler.sample(3)[0], data[idxs[3]])
14 | for i in range(10):
15 | assert sampler.sample(3).shape == (3, 3)
16 | # test sampling larger #samples than len(data)
17 | sampler._current_index = 0
18 | samples = sampler.sample(12)
19 | assert samples.shape == (12, 3)
20 | # check that rewinding works
21 | for i in range(3):
22 | assert torch.allclose(
23 | data.flatten(),
24 | torch.sort(samples[4*i: 4*(i+1)].flatten())[0]
25 | )
26 |
27 |
28 | def test_dataset_to_device_sampler(ctx):
29 | data = torch.arange(12).reshape(4,3)
30 | sampler = DataSetSampler(data).to(**ctx)
31 | assert sampler.sample(10).device == ctx["device"]
32 |
33 |
34 | def test_multiple_dataset_sampler(ctx):
35 | data = torch.arange(12).reshape(4,3).to(**ctx)
36 | data2 = torch.arange(8).reshape(4,2).to(**ctx)
37 | sampler = DataSetSampler(data, data2).to(**ctx)
38 | samples = sampler.sample(3)
39 | assert len(samples) == 2
40 | assert samples[0].shape == (3, 3)
41 | assert samples[1].shape == (3, 2)
42 | assert samples[0].device == ctx["device"]
43 |
44 |
45 | def test_resizing(ctx):
46 | data = torch.arange(12).reshape(4, 3).to(**ctx)
47 | sampler = DataSetSampler(data)
48 | sampler.resize_(5)
49 | assert len(sampler) == 5
50 | assert sampler.sample(2).shape == (2, 3)
51 | sampler.resize_(3)
52 | assert len(sampler) == 3
53 | assert sampler.sample(2).shape == (2, 3)
54 |
55 |
56 | def test_dataloader_sampler(ctx):
57 | loader = torch.utils.data.DataLoader(
58 | torch.utils.data.TensorDataset(torch.randn(10, 2, 2, **ctx)),
59 | batch_size=4,
60 | )
61 | sampler = DataLoaderSampler(loader, **ctx)
62 | samples = sampler.sample(4)
63 | assert samples.shape == (4, 2, 2)
64 | assert samples.device == ctx["device"]
65 |
66 |
--------------------------------------------------------------------------------
/tests/distribution/sampling/test_iterative.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | from bgflow import IterativeSampler, SamplerState, SamplerStep
5 | from bgflow.distribution.sampling._iterative_helpers import AbstractSamplerState
6 |
7 |
8 | class AddOne(SamplerStep):
9 | def _step(self, state: AbstractSamplerState):
10 | statedict = state.as_dict()
11 | samples = tuple(
12 | x + 1.0
13 | for x in statedict["samples"]
14 | )
15 | return state.replace(samples=samples)
16 |
17 |
18 | def test_iterative_sampler(ctx):
19 | state = SamplerState(samples=[torch.zeros(2, **ctx), ])
20 |
21 | # test burnin
22 | sampler = IterativeSampler(state, sampler_steps=[AddOne()], n_burnin=10)
23 | assert torch.allclose(sampler.state.samples[0], 10*torch.ones_like(sampler.state.samples[0]))
24 |
25 | # test sampling
26 | samples = sampler.sample(2)
27 | assert torch.allclose(samples, torch.tensor([[11., 11.], [12., 12.]], **ctx))
28 |
29 | # test stride
30 | sampler.stride = 5
31 | samples = sampler.sample(2)
32 | assert sampler.i == 14
33 | assert torch.allclose(samples, torch.tensor([[17., 17.], [22., 22.]], **ctx))
34 |
35 | # test iteration
36 | sampler.max_iterations = 15
37 | for batch in sampler: # only called once
38 | assert torch.allclose(batch.samples[0], torch.tensor([[27., 27.]], **ctx))
39 | sampler.max_iterations = None
40 |
41 |
--------------------------------------------------------------------------------
/tests/distribution/sampling/test_iterative_helpers.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bgflow.distribution.sampling._iterative_helpers import _map_to_primary_cell
4 |
5 |
6 | def test_map_to_primary_cell():
7 | cell = torch.eye(3)
8 | x = torch.tensor([[1.2, -0.1, 4.5]])
9 | assert torch.allclose(_map_to_primary_cell(x, cell), torch.tensor([[0.2, 0.9, 0.5]]))
10 |
11 | cell = torch.tensor(
12 | [
13 | [1.0, 2.0, 0.0],
14 | [0.0, 2.0, 0.0],
15 | [0.0, 0.0, 1.0]
16 | ]
17 | )
18 | assert torch.allclose(_map_to_primary_cell(x, cell), torch.tensor([[2.2, 1.9, 0.5]]))
19 |
20 |
--------------------------------------------------------------------------------
/tests/distribution/sampling/test_mcmc.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import pytest
4 | import torch
5 | from bgflow import (
6 | GaussianProposal, SamplerState, NormalDistribution,
7 | IterativeSampler, MCMCStep, LatentProposal, BentIdentity, GaussianMCMCSampler
8 | )
9 | from bgflow.distribution.sampling.mcmc import _GaussianMCMCSampler
10 |
11 |
12 | @pytest.mark.parametrize("proposal", [
13 | GaussianProposal(noise_std=0.2),
14 | LatentProposal(
15 | flow=BentIdentity(),
16 | base_proposal=GaussianProposal(noise_std=0.4))
17 | ])
18 | @pytest.mark.parametrize("temperatures", [torch.ones(3), torch.arange(1, 10, 100)])
19 | def test_mcmc(ctx, proposal, temperatures):
20 | """sample from a normal distribution with mu=3 and std=1,2,3 using MCMC"""
21 | try:
22 | import tqdm
23 | progress_bar = tqdm.tqdm
24 | except ImportError:
25 | progress_bar = lambda x: x
26 | target = NormalDistribution(4, 3.0*torch.ones(4, **ctx))
27 | temperatures = temperatures.to(**ctx)
28 | # for testing efficiency we have a second batch dimension
29 | # this is not required; we could remove the first batch dimension (256)
30 | # and just sample longer
31 | state = SamplerState(samples=0.0*torch.ones(512, 3, 4, **ctx))
32 | mcmc = IterativeSampler(
33 | sampler_state=state,
34 | sampler_steps=[
35 | MCMCStep(
36 | target,
37 | proposal=proposal.to(**ctx),
38 | target_temperatures=temperatures
39 | )
40 | ],
41 | stride=2,
42 | n_burnin=100,
43 | progress_bar=progress_bar
44 | )
45 | samples = mcmc.sample(100)
46 | assert torch.allclose(samples.mean(dim=(0, 1, 3)), torch.tensor([3.0]*3, **ctx), atol=0.1)
47 | std = samples.std(dim=(0, 1, 3))
48 | assert torch.allclose(std, temperatures.sqrt(), rtol=0.05, atol=0.0)
49 |
50 |
51 | def test_old_vs_new_mcmc(ctx):
52 | energy = NormalDistribution(dim=4)
53 | x0 = torch.randn(64, 4)
54 |
55 | def constraint(x):
56 | return torch.fmod(x,torch.ones(4))
57 |
58 | with warnings.catch_warnings():
59 | warnings.simplefilter("ignore", DeprecationWarning)
60 | old_mc = _GaussianMCMCSampler(
61 | energy, x0, n_stride=10, n_burnin=10,
62 | noise_std=0.3, box_constraint=constraint
63 | )
64 | new_mc = GaussianMCMCSampler(
65 | energy, x0, stride=10, n_burnin=10,
66 | noise_std=0.3, box_constraint=constraint
67 | )
68 | old_samples = old_mc.sample(1000)
69 | new_samples = new_mc.sample(100)
70 | assert old_samples.shape == (6400, 4)
71 | assert old_samples.shape == new_samples.shape
72 | assert old_samples.mean().item() == pytest.approx(new_samples.mean().item(), abs=1e-2)
73 | assert old_samples.std().item() == pytest.approx(new_samples.std().item(), abs=1e-1)
74 |
--------------------------------------------------------------------------------
/tests/distribution/test_distribution.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 | from torch.distributions import MultivariateNormal
5 | from bgflow.distribution import TorchDistribution, NormalDistribution, UniformDistribution, ProductDistribution
6 |
7 |
8 | def _random_mean_cov(dim, device, dtype):
9 | mean = 10*torch.randn(dim).to(device, dtype)
10 | # generate a symmetric, positive definite matrix
11 | cov = torch.triu(torch.rand(dim, dim))
12 | cov = cov + cov.T + dim * torch.diag(torch.ones(dim))
13 | cov = cov.to(device, dtype)
14 | return mean, cov
15 |
16 |
17 | @pytest.mark.parametrize("dim", (2, 10))
18 | def test_distribution_energy(dim, device, dtype):
19 | """compare torch's normal distribution with bgflow's normal distribution"""
20 | n_samples = 7
21 | mean, cov = _random_mean_cov(dim, device, dtype)
22 | samples = torch.randn((n_samples, dim)).to(device, dtype)
23 | normal_trch = TorchDistribution(MultivariateNormal(loc=mean, covariance_matrix=cov))
24 | normal_bgtrch = NormalDistribution(dim, mean, cov)
25 | assert torch.allclose(normal_trch.energy(samples), normal_bgtrch.energy(samples), rtol=2e-2, atol=1e-2)
26 |
27 |
28 | @pytest.mark.parametrize("dim", (2, 10))
29 | @pytest.mark.parametrize("sample_shape", (50000, torch.Size([10,1])))
30 | def test_distribution_samples(dim, sample_shape, device, dtype):
31 | """compare torch's normal distribution with bgflow's normal distribution"""
32 | mean, cov = _random_mean_cov(dim, device, dtype)
33 | normal_trch = TorchDistribution(MultivariateNormal(loc=mean, covariance_matrix=cov))
34 | normal_bgtrch = NormalDistribution(dim, mean, cov)
35 | samples_trch = normal_trch.sample(sample_shape)
36 | target_shape = torch.Size([sample_shape]) if isinstance(sample_shape, int) else sample_shape
37 | assert samples_trch.size() == target_shape + torch.Size([dim])
38 | if isinstance(sample_shape, int):
39 | samples_bgtrch = normal_bgtrch.sample(sample_shape)
40 | # to make sure that both sample from the same distribution, compute divergences
41 | for p in [normal_trch, normal_bgtrch]:
42 | for q in [normal_trch, normal_bgtrch]:
43 | for x in [samples_bgtrch, samples_trch]:
44 | for y in [samples_bgtrch, samples_trch]:
45 | div = torch.mean(
46 | (-p.energy(x) + q.energy(y))
47 | )
48 | assert torch.abs(div) < 5e-2
49 |
50 |
51 | def test_sample_uniform_with_temperature(ctx):
52 | uniform = UniformDistribution(low=torch.zeros(100, **ctx), high=torch.ones(100, **ctx))
53 | assert uniform.sample(20).mean().item() == pytest.approx(0.5, abs=0.05)
54 | assert uniform.sample(20, temperature=100.).mean().item() == pytest.approx(0.5, abs=0.05)
55 |
56 |
57 | def test_sample_product_with_temperature(ctx):
58 | normal = NormalDistribution(dim=100, mean=torch.zeros(100, **ctx))
59 | product = ProductDistribution([normal, normal])
60 | x1, y1 = product.sample(20, temperature=1.)
61 | x2, y2 = product.sample(20, temperature=100.)
62 |
63 | assert (x1.std() / x2.std()).item() == pytest.approx(0.1, abs=0.05)
64 | assert (y1.std() / y2.std()).item() == pytest.approx(0.1, abs=0.05)
65 |
66 |
67 |
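Note: the tests above hinge on bgflow's sign convention, namely that an energy is a negative log-density (up to an additive constant), u(x) = -log p(x). A minimal sketch of that convention with a plain torch distribution:

import torch
from torch.distributions import MultivariateNormal

dist = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.eye(2))
x = torch.randn(5, 2)
energy = -dist.log_prob(x).unsqueeze(-1)  # shape (5, 1), matching bgflow's (batch, 1) energies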
--------------------------------------------------------------------------------
/tests/distribution/test_product.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 | from bgflow import NormalDistribution, ProductDistribution
5 |
6 |
7 | def test_multi_distribution():
8 | """Test that a compound normal distribution behaves identical to a multivariate standard normal distribution."""
9 | n1 = NormalDistribution(3)
10 | n2 = NormalDistribution(2)
11 | n3 = NormalDistribution(1)
12 | n = NormalDistribution(6)
13 | compound = ProductDistribution([n1, n2, n3], cat_dim=-1)
14 | n_samples = 10
15 | samples = compound.sample(n_samples)
16 | assert samples.shape == n.sample(n_samples).shape
17 | assert n.energy(samples).shape == compound.energy(samples).shape
18 | assert n.energy(samples).numpy() == pytest.approx(compound.energy(samples).numpy())
19 |
20 |
21 | def test_multi_distribution_no_cat():
22 | """Test that a compound normal distribution behaves identical to a multivariate standard normal distribution."""
23 | n1 = NormalDistribution(3)
24 | n2 = NormalDistribution(2)
25 | n3 = NormalDistribution(1)
26 | n = NormalDistribution(6)
27 | compound = ProductDistribution([n1, n2, n3], cat_dim=None)
28 | n_samples = 10
29 | samples = compound.sample(n_samples)
30 | assert len(samples) == 3
31 | assert isinstance(samples, tuple)
32 | assert samples[0].shape == (10, 3)
33 | assert samples[1].shape == (10, 2)
34 | assert samples[2].shape == (10, 1)
35 | assert compound.energy(*samples).shape == (10, 1)
36 | assert n.energy(torch.cat(samples, dim=-1)).numpy() == pytest.approx(compound.energy(*samples).numpy())
37 |
38 |
39 | def test_sample_to_cpu(ctx):
40 | cpu = torch.device("cpu")
41 | normal_distribution = NormalDistribution(dim=2).to(**ctx)
42 | samples = normal_distribution.sample_to_cpu(150, batch_size=17)
43 | assert samples.shape == (150, 2)
44 | assert samples.device == cpu
45 |
46 | product_distribution = ProductDistribution([NormalDistribution(2), NormalDistribution(3)], cat_dim=None)
47 | samples = product_distribution.sample_to_cpu(150, batch_size=17)
48 | assert len(samples) == 2
49 | assert samples[0].shape == (150, 2)
50 | assert samples[1].shape == (150, 3)
51 | assert samples[0].device == cpu
52 | assert samples[1].device == cpu
53 |
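Note: the tests above pin down the `cat_dim` semantics of `ProductDistribution`; a condensed sketch of the same behavior:

from bgflow import NormalDistribution, ProductDistribution

# cat_dim=-1: factors are concatenated into one event tensor
joined = ProductDistribution([NormalDistribution(3), NormalDistribution(2)], cat_dim=-1)
x = joined.sample(7)        # shape (7, 5)

# cat_dim=None: sample returns a tuple, energy takes one argument per factor
split = ProductDistribution([NormalDistribution(3), NormalDistribution(2)], cat_dim=None)
a, b = split.sample(7)      # shapes (7, 3) and (7, 2)
e = split.energy(a, b)      # shape (7, 1)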
--------------------------------------------------------------------------------
/tests/factory/test_distribution_factory.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from bgflow import UniformDistribution, NormalDistribution, TruncatedNormalDistribution, make_distribution
4 |
5 |
6 | @pytest.mark.parametrize("prior_type", [UniformDistribution, NormalDistribution, TruncatedNormalDistribution])
7 | def test_prior_factory(prior_type, ctx):
8 | prior = make_distribution(prior_type, 2, **ctx)
9 | samples = prior.sample(10)
10 | assert torch.device(samples.device) == torch.device(ctx["device"])
11 | assert samples.dtype == ctx["dtype"]
12 | assert samples.shape == (10, 2)
13 |
14 |
15 | def test_prior_factory_with_kwargs(ctx):
16 | prior = make_distribution(UniformDistribution, 2, low=torch.tensor([2.0, 2.0]), high=torch.tensor([3.0, 3.0]), **ctx)
17 | samples = prior.sample(5)
18 | assert torch.device(samples.device) == torch.device(ctx["device"])
19 | assert samples.dtype == ctx["dtype"]
20 | assert (samples > 1.0).all()
21 |
--------------------------------------------------------------------------------
/tests/factory/test_icmarginals.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 |
5 | from bgflow import (
6 | GlobalInternalCoordinateTransformation,
7 | InternalCoordinateMarginals, BONDS, ANGLES,
8 | ShapeDictionary, TensorInfo
9 | )
10 |
11 |
12 | @pytest.mark.parametrize("with_data", [True, False])
13 | def test_icmarginals_inform_api(tmpdir, ctx, with_data):
14 | """API test"""
15 | bgmol = pytest.importorskip("bgmol")
16 | dataset = bgmol.datasets.Ala2Implicit1000Test(
17 | root=tmpdir,
18 | download=True,
19 | read=True
20 | )
21 | coordinate_transform = GlobalInternalCoordinateTransformation(
22 | bgmol.systems.ala2.DEFAULT_GLOBAL_Z_MATRIX
23 | )
24 | current_dims = ShapeDictionary()
25 | current_dims[BONDS] = (coordinate_transform.dim_bonds - dataset.system.system.getNumConstraints(), )
26 | current_dims[ANGLES] = (coordinate_transform.dim_angles, )
27 | marginals = InternalCoordinateMarginals(current_dims, ctx)
28 | if with_data:
29 | constrained_indices, _ = bgmol.bond_constraints(dataset.system.system, coordinate_transform)
30 | marginals.inform_with_data(
31 | torch.tensor(dataset.xyz, **ctx), coordinate_transform,
32 | constrained_bond_indices=constrained_indices
33 | )
34 | else:
35 | marginals.inform_with_force_field(
36 | dataset.system.system, coordinate_transform, 1000.,
37 | )
38 |
--------------------------------------------------------------------------------
/tests/factory/test_tensor_info.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | from bgflow.factory.tensor_info import ShapeDictionary, TensorInfo, BONDS, ANGLES, TORSIONS, FIXED
4 |
5 | def test_shape_info(crd_trafo):
6 | shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
7 |
8 | for key in [BONDS, ANGLES, TORSIONS]:
9 | assert shape_info[key] == (len(crd_trafo.z_matrix), )
10 | assert shape_info[FIXED][0] == 3*len(crd_trafo.fixed_atoms)
11 | assert not (shape_info.is_circular([BONDS, TORSIONS])[: shape_info[BONDS][0]]).any()
12 | assert (shape_info.is_circular([BONDS, TORSIONS])[shape_info[BONDS][0]:]).all()
13 | assert shape_info.is_circular().sum() == shape_info[TORSIONS][0]
14 | assert (
15 | shape_info.circular_indices([FIXED, TORSIONS])
16 | == np.arange(shape_info[FIXED][0], shape_info[FIXED][0]+shape_info[TORSIONS][0])
17 | ).all()
18 | assert (
19 | shape_info.circular_indices()
20 | == np.arange(shape_info[BONDS][0]+shape_info[ANGLES][0],
21 | shape_info[BONDS][0]+shape_info[ANGLES][0]+shape_info[TORSIONS][0]
22 | )
23 | ).all()
24 | assert shape_info.dim_all([BONDS, TORSIONS]) == shape_info[BONDS][0] + shape_info[TORSIONS][0]
25 | assert shape_info.dim_all() == 66
26 | assert shape_info.dim_circular([ANGLES, BONDS]) == 0
27 | assert shape_info.dim_circular() == shape_info[TORSIONS][0]
28 | assert shape_info.dim_noncircular([ANGLES, BONDS]) == shape_info[ANGLES][0] + shape_info[BONDS][0]
29 | assert shape_info.dim_noncircular() == 66 - shape_info[TORSIONS][0]
30 |
31 | assert shape_info.dim_cartesian([ANGLES, BONDS]) == 0
32 | assert shape_info.dim_cartesian([FIXED]) == shape_info[FIXED][0]
33 | assert shape_info.dim_noncartesian([ANGLES, BONDS]) == shape_info[ANGLES][0] + shape_info[BONDS][0]
34 | assert shape_info.dim_noncartesian([FIXED]) == 0
35 | assert not (shape_info.is_cartesian([BONDS, FIXED])[: shape_info[BONDS][0]]).any()
36 | assert (
37 | shape_info.cartesian_indices()
38 | == np.arange(shape_info[BONDS][0]+shape_info[ANGLES][0]+shape_info[TORSIONS][0],
39 | shape_info[BONDS][0]+shape_info[ANGLES][0]+shape_info[TORSIONS][0]+shape_info[FIXED][0]
40 | )
41 | ).all()
42 |
43 |
44 |
45 |
46 | def test_shape_info_insert():
47 | shape_info = ShapeDictionary()
48 | for i in range(5):
49 | shape_info[i] = (i, )
50 | shape_info.insert(100, 2, (100, ))
51 | assert list(shape_info) == [0, 1, 100, 2, 3, 4]
52 | assert list(shape_info.values()) == [(i, ) for i in [0, 1, 100, 2, 3, 4]]
53 |
54 |
55 | def test_shape_info_split_merge():
56 | shape_info = ShapeDictionary()
57 | for i in range(8):
58 | shape_info[i] = (i, )
59 | shape_info.split(4, into=("a", "b"), sizes=(1, 3))
60 | assert list(shape_info) == [0, 1, 2, 3, "a", "b", 5, 6, 7]
61 | assert list(shape_info.values()) == [(i, ) for i in [0, 1, 2, 3, 1, 3, 5, 6, 7]]
62 |
63 | shape_info.merge(("a", "b"), to=4)
64 | assert list(shape_info) == [0, 1, 2, 3, 4, 5, 6, 7]
65 | assert list(shape_info.values()) == [(i, ) for i in [0, 1, 2, 3, 4, 5, 6, 7]]
66 |
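Note: distilled from the insert/split/merge tests above, a short sketch of how `ShapeDictionary` tracks per-field event shapes while editing the field layout in place:

from bgflow.factory.tensor_info import ShapeDictionary

shapes = ShapeDictionary()
for i in range(3):
    shapes[i] = (i, )
shapes.insert(99, 1, (7, ))                        # new key 99 with shape (7,) at position 1
assert list(shapes) == [0, 99, 1, 2]
shapes.split(99, into=("a", "b"), sizes=(3, 4))    # sizes must add up to 7
assert shapes["a"] == (3, ) and shapes["b"] == (4, )
shapes.merge(("a", "b"), to=99)                    # re-join into a single field of size 7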
--------------------------------------------------------------------------------
/tests/factory/test_transformer_factory.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | import bgflow
5 | from bgflow import (
6 | make_transformer, make_conditioners,
7 | ShapeDictionary, BONDS, FIXED, ANGLES, TORSIONS,
8 | ConditionalSplineTransformer, AffineTransformer
9 | )
10 |
11 |
12 |
13 | @pytest.mark.parametrize(
14 | "transformer_type",
15 | [
16 | ConditionalSplineTransformer,
17 | AffineTransformer,
18 | # TODO: MixtureCDFTransformer
19 | ]
20 | )
21 | def test_transformers(crd_trafo, transformer_type):
22 | pytest.importorskip("nflows")
23 |
24 | shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
25 | conditioners = make_conditioners(transformer_type, (BONDS,), (FIXED,), shape_info)
26 | transformer = make_transformer(transformer_type, (BONDS,), shape_info, conditioners=conditioners)
27 | out = transformer.forward(torch.zeros(2, shape_info[FIXED][0]), torch.zeros(2, shape_info[BONDS][0]))
28 | assert out[0].shape == (2, shape_info[BONDS][0])
29 |
30 |
31 | def test_circular_affine(crd_trafo):
32 | shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo)
33 |
34 | with pytest.raises(ValueError):
35 | conditioners = make_conditioners(
36 | bgflow.AffineTransformer,
37 | (TORSIONS,), (FIXED,), shape_info=shape_info
38 | )
39 | make_transformer(bgflow.AffineTransformer, (TORSIONS,), shape_info, conditioners=conditioners)
40 |
41 | conditioners = make_conditioners(
42 | bgflow.AffineTransformer,
43 | (TORSIONS,), (FIXED,), shape_info=shape_info, use_scaling=False
44 | )
45 | assert list(conditioners.keys()) == ["shift_transformation"]
46 | transformer = make_transformer(bgflow.AffineTransformer, (TORSIONS,), shape_info, conditioners=conditioners)
47 | assert transformer._is_circular
48 | out = transformer.forward(torch.zeros(2, shape_info[FIXED][0]), torch.zeros(2, shape_info[TORSIONS][0]))
49 | assert out[0].shape == (2, shape_info[TORSIONS][0])
50 |
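Note: the ValueError asserted above is presumably raised because an affine scale cannot act on circular variables: a shift modulo 1 keeps the two ends of the periodic interval glued together, while a scale tears them apart. A plain-torch illustration:

import torch

x = torch.tensor([0.999, 0.001])   # two torsions that are nearly identical on the circle
shift = (x + 0.25) % 1.0           # circular shift: still nearly identical
scale = (1.7 * x) % 1.0            # scaling: the two points end up far apart
print(shift)  # tensor([0.2490, 0.2510])
print(scale)  # tensor([0.6983, 0.0017])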
--------------------------------------------------------------------------------
/tests/nn/flow/dynamics/test_kernel_dynamics.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from bgflow.distribution import NormalDistribution
4 | from bgflow.nn.flow import DiffEqFlow
5 | from bgflow.nn.flow.dynamics import KernelDynamics
6 | from bgflow.utils import brute_force_jacobian_trace
7 |
8 |
9 | @pytest.mark.parametrize("n_particles", [2, 3])
10 | @pytest.mark.parametrize("n_dimensions", [2, 3])
11 | @pytest.mark.parametrize("use_checkpoints", [True, False])
12 | def test_kernel_dynamics(n_particles, n_dimensions, use_checkpoints, device):
13 | # Build flow with kernel dynamics and run initial config.
14 |
15 | dim = n_particles * n_dimensions
16 | n_samples = 100
17 | prior = NormalDistribution(dim).to(device)
18 | latent = prior.sample(n_samples)
19 |
20 | d_max = 8
21 | mus = torch.linspace(0, d_max, 10).to(device)
22 | gammas = 0.3 * torch.ones(len(mus))
23 |
24 | mus_time = torch.linspace(0, 1, 5).to(device)
25 | gammas_time = 0.3 * torch.ones(len(mus_time))
26 |
27 | kernel_dynamics = KernelDynamics(n_particles, n_dimensions, mus, gammas, optimize_d_gammas=True,
28 | optimize_t_gammas=True, mus_time=mus_time, gammas_time=gammas_time)
29 |
30 | flow = DiffEqFlow(
31 | dynamics=kernel_dynamics
32 | ).to(device)
33 |
34 | if not use_checkpoints:
35 | pytest.importorskip("torchdiffeq")
36 |
37 | samples, dlogp = flow(latent)
38 | latent2, ndlogp = flow.forward(samples, inverse=True)
39 |
40 | assert samples.shape == torch.Size([n_samples, dim])
41 | assert dlogp.shape == torch.Size([n_samples, 1])
42 | # assert (latent - latent2).abs().mean() < 0.002
43 | # assert (latent - samples).abs().mean() > 0.01
44 | # assert (dlogp + ndlogp).abs().mean() < 0.002
45 |
46 | if use_checkpoints:
47 | pytest.importorskip("anode")
48 | flow._use_checkpoints = True
49 | options = {
50 | "Nt": 20,
51 | "method": "RK4"
52 | }
53 | flow._kwargs = options
54 |
55 | samples, dlogp = flow(latent)
56 | latent2, ndlogp = flow.forward(samples, inverse=True)
57 |
58 | assert samples.shape == torch.Size([n_samples, dim])
59 | assert dlogp.shape == torch.Size([n_samples, 1])
60 | # assert (latent - latent2).abs().mean() < 0.002
61 | # assert (latent - samples).abs().mean() > 0.01
62 | # assert (dlogp + ndlogp).abs().mean() < 0.002
63 |
64 |
65 | @pytest.mark.parametrize("n_particles", [2, 3])
66 | @pytest.mark.parametrize("n_dimensions", [2, 3])
67 | def test_kernel_dynamics_trace(n_particles, n_dimensions):
68 | # Test if the trace computation of the kernel dynamics is correct.
69 |
70 | d_max = 8
71 | mus = torch.linspace(0, d_max, 10)
72 | gammas = 0.3 * torch.ones(len(mus))
73 |
74 | mus_time = torch.linspace(0, 1, 5)
75 | gammas_time = 0.3 * torch.ones(len(mus_time))
76 |
77 | kernel_dynamics = KernelDynamics(n_particles, n_dimensions, mus, gammas, mus_time=mus_time, gammas_time=gammas_time)
78 | x = torch.Tensor(1, n_particles * n_dimensions).normal_().requires_grad_(True)
79 | y, trace = kernel_dynamics(1., x)
80 | brute_force_trace = brute_force_jacobian_trace(y, x)
81 |
82 | # The kernel dynamics outputs the negative trace
83 | assert torch.allclose(trace.sum(), -brute_force_trace[0], atol=1e-4)
84 |
85 | # test kernel dynamics without trace
86 | kernel_dynamics(1., x, compute_divergence=False)
87 |
--------------------------------------------------------------------------------
/tests/nn/flow/estimators/test_hutchinson_estimator.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from bgflow.distribution import NormalDistribution
4 | from bgflow.nn.flow.dynamics import TimeIndependentDynamics
5 | from bgflow.nn.flow.estimator import HutchinsonEstimator
6 | from bgflow.nn import DenseNet
7 | from bgflow.utils import brute_force_jacobian_trace
8 |
9 |
10 | @pytest.mark.parametrize("dim", [1, 2])
11 | @pytest.mark.parametrize("rademacher", [True, False])
12 | def test_hutchinson_estimator(dim, rademacher):
13 | # Test trace estimation of the Hutchinson estimator for small dimensions, where it is less noisy
14 | n_batch = 1024
15 | time_independent_dynamics = TimeIndependentDynamics(
16 | DenseNet([dim, 16, 16, dim], activation=torch.nn.Tanh()))
17 | hutchinson_estimator = HutchinsonEstimator(rademacher)
18 | normal_distribution = NormalDistribution(dim)
19 | x = normal_distribution.sample(n_batch)
20 | y, trace = hutchinson_estimator(time_independent_dynamics, None, x)
21 | brute_force_trace = brute_force_jacobian_trace(y, x)
22 | if rademacher and dim == 1:
23 | # Hutchinson is exact for Rademacher noise and dim=1
24 | assert torch.allclose(trace.mean(), -brute_force_trace.mean(), atol=1e-6)
25 | else:
26 | assert torch.allclose(trace.mean(), -brute_force_trace.mean(), atol=1e-1)
27 |
28 |
29 | @pytest.mark.parametrize("rademacher", [True, False])
30 | def test_hutchinson_estimator_reset_noise(rademacher):
31 | # Test that the noise vector is reset to deal with a different shape
32 | dim = 10
33 | time_independent_dynamics = TimeIndependentDynamics(
34 | DenseNet([dim, 16, 16, dim], activation=torch.nn.Tanh()))
35 | hutchinson_estimator = HutchinsonEstimator(rademacher)
36 | normal_distribution = NormalDistribution(dim)
37 |
38 | x = normal_distribution.sample(100)
39 | _, _ = hutchinson_estimator(time_independent_dynamics, None, x)
40 | x = normal_distribution.sample(10)
41 | hutchinson_estimator.reset_noise()
42 | # this will fail if the noise is not reset
43 | _, _ = hutchinson_estimator(time_independent_dynamics, None, x)
44 |
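Note: the estimator tested above relies on Hutchinson's identity, tr(A) = E[v^T A v] for noise v with E[v v^T] = I. A self-contained numeric check (up to Monte-Carlo error):

import torch

torch.manual_seed(0)
A = torch.randn(5, 5)
v = torch.randint(0, 2, (100000, 5)).float() * 2 - 1   # Rademacher noise, +/-1
estimate = torch.einsum("ni,ij,nj->n", v, A, v).mean()
assert torch.isclose(estimate, torch.trace(A), atol=0.1)

With Rademacher noise in one dimension, v^T A v = A[0, 0] for every draw, which is the exactness the dim=1 branch above exploits.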
--------------------------------------------------------------------------------
/tests/nn/flow/test_cdf.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 | from bgflow import (
5 | DistributionTransferFlow, ConstrainGaussianFlow,
6 | CDFTransform, InverseFlow, TruncatedNormalDistribution
7 | )
8 | from torch.distributions import Normal
9 |
10 |
11 | def test_distribution_transfer(ctx):
12 | src = Normal(torch.zeros(2, **ctx), torch.ones(2, **ctx))
13 | target = Normal(torch.ones(2, **ctx), torch.ones(2, **ctx))
14 | swap = DistributionTransferFlow(src, target)
15 | # forward
16 | out, dlogp = swap.forward(torch.zeros((2,2), **ctx))
17 | assert torch.allclose(out, torch.ones(2,2, **ctx))
18 | assert torch.allclose(dlogp, torch.zeros(2,1, **ctx))
19 | # inverse
20 | out2, dlogp = swap.forward(out, inverse=True)
21 | assert torch.allclose(out2, torch.zeros(2,2, **ctx))
22 | assert torch.allclose(dlogp, torch.zeros(2,1, **ctx))
23 |
24 |
25 | def test_constrain_positivity(ctx):
26 | """Make sure that the bonds are obeyed."""
27 | torch.manual_seed(1)
28 | constrain_flow = ConstrainGaussianFlow(mu=torch.ones(10, **ctx), lower_bound=1e-10)
29 | samples = (1.0+torch.randn((10,10), **ctx)) * 1000.
30 | y, dlogp = constrain_flow.forward(samples)
31 | assert y.shape == (10, 10)
32 | assert dlogp.shape == (10, 1)
33 | assert (y >= 0.0).all()
34 | assert (dlogp.sum() < 0.0).all()
35 |
36 |
37 | def test_constrain_slightly_perturbed(ctx):
38 | """Check that samples are not changed much when the bounds are generous."""
39 | torch.manual_seed(1)
40 | constrain_flow = ConstrainGaussianFlow(mu=torch.ones(10, **ctx), sigma=torch.ones(10, **ctx), lower_bound=-1000., upper_bound=1000.)
41 | samples = (1.0+torch.randn((10,10), **ctx))
42 | y, dlogp = constrain_flow.forward(samples)
43 | assert torch.allclose(samples, y, atol=1e-4, rtol=0.0)
44 | assert torch.allclose(dlogp, torch.zeros_like(dlogp), atol=1e-4, rtol=0.0)
45 |
46 | x2, dlogp = constrain_flow.forward(y, inverse=True)
47 | assert torch.allclose(x2, y, atol=1e-4, rtol=0.0)
48 | assert torch.allclose(dlogp, torch.zeros_like(dlogp), atol=1e-4, rtol=0.0)
49 |
50 |
51 | def test_cdf_transform(ctx):
52 | input = torch.arange(0.1, 1.0, 0.1, **ctx)[:,None]
53 | input.requires_grad = True
54 | truncated_normal = TruncatedNormalDistribution(
55 | mu=torch.tensor([0.5], **ctx),
56 | upper_bound=torch.tensor([1.0], **ctx),
57 | is_learnable=True
58 | )
59 | flow = InverseFlow(CDFTransform(truncated_normal))
60 | output, dlogp = flow.forward(input)
61 | assert output.mean().item() == pytest.approx(0.5)
62 | # try computing the grad
63 | output.mean().backward(create_graph=True)
64 | dlogp.mean().backward()
65 |
66 |
67 |
--------------------------------------------------------------------------------
/tests/nn/flow/test_inverted.py:
--------------------------------------------------------------------------------
1 | """
2 | Test inverse and its derivative
3 | """
4 |
5 | import torch
6 | import pytest
7 | from bgflow.nn import flow
8 |
9 |
10 | @pytest.fixture(params=[
11 | flow.KroneckerProductFlow(2),
12 | flow.PseudoOrthogonalFlow(2),
13 | flow.BentIdentity(),
14 | flow.FunnelFlow(),
15 | flow.AffineFlow(2),
16 | flow.SplitFlow(1),
17 | flow.SplitFlow(1,1),
18 | flow.TriuFlow(2),
19 | flow.InvertiblePPPP(2),
20 | flow.TorchTransform(torch.distributions.IndependentTransform(torch.distributions.SigmoidTransform(), 1))
21 | ])
22 | def simpleflow2d(request):
23 | return request.param
24 |
25 |
26 | def test_inverse(simpleflow2d):
27 | """Test inverse and inverse logDet of simple 2d flow blocks."""
28 | inverse = flow.InverseFlow(simpleflow2d)
29 | x = torch.tensor([[1., 2.]])
30 | *y, dlogp = simpleflow2d._forward(x)
31 | x2, dlogpinv = inverse._forward(*y)
32 | assert (dlogp + dlogpinv).detach().numpy() == pytest.approx(0.0, abs=1e-6)
33 | assert torch.norm(x2 - x).item() == pytest.approx(0.0, abs=1e-6)
34 |
35 | # test dimensions
36 | assert x2.shape == x.shape
37 | assert dlogp.shape == x[..., 0, None].shape
38 | assert dlogpinv.shape == x[..., 0, None].shape
39 |
40 |
--------------------------------------------------------------------------------
/tests/nn/flow/test_modulo.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 | from bgflow import IncreaseMultiplicityFlow
5 | from bgflow import CircularShiftFlow
6 |
7 |
8 | @pytest.mark.parametrize('mult', [1, torch.ones(3, dtype=torch.int)])
9 | def test_IncreaseMultiplicityFlow(ctx, mult):
10 | m = 3
11 | mult = m * mult
12 | flow = IncreaseMultiplicityFlow(mult).to(**ctx)
13 |
14 | x = 1/6 + torch.linspace(0, 1, m + 1)[:-1].to(**ctx)
15 | x = torch.tile(x[None, ...], (1000, 1))
16 | y, dlogp = flow.forward(x, inverse=True)
17 | assert y.shape == x.shape
18 | assert torch.allclose(y, 1/2 * torch.ones_like(y))
19 |
20 | x2, dlogp2 = flow.forward(y)
21 | for point in (1/6, 3/6, 5/6):
22 | count = torch.sum(torch.isclose(x2, point * torch.ones_like(x2)))
23 | assert count > 800
24 | assert count < 1200
25 | assert torch.allclose(dlogp, torch.zeros_like(dlogp))
26 | assert torch.allclose(dlogp2, torch.zeros_like(dlogp))
27 |
28 | @pytest.mark.parametrize('shift', [1, torch.ones(3)])
29 | def test_CircularShiftFlow(ctx, shift):
30 | m = 0.2
31 | shift = m * shift
32 | flow = CircularShiftFlow(shift).to(**ctx)
33 |
34 | x = torch.tensor([0.0, 0.2, 0.9]).to(**ctx)
35 | x_shifted = torch.tensor([0.2, 0.4, 0.1]).to(**ctx)
36 | y, dlogp = flow.forward(x)
37 | assert y.shape == x.shape
38 | assert torch.allclose(y, x_shifted)
39 |
40 | x2, dlogp2 = flow.forward(y, inverse=True)
41 | assert torch.allclose(x, x2)
42 | assert torch.allclose(dlogp, torch.zeros_like(dlogp))
43 | assert torch.allclose(dlogp2, torch.zeros_like(dlogp))
44 |
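Note: the zero log-determinants asserted in both tests above follow from the maps being volume-preserving. For `CircularShiftFlow` the map is just a rigid rotation of the unit circle:

import torch

x = torch.tensor([0.0, 0.2, 0.9])
y = torch.remainder(x + 0.2, 1.0)    # y = (x + shift) mod 1, so |det J| = 1
assert torch.allclose(y, torch.tensor([0.2, 0.4, 0.1]))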
--------------------------------------------------------------------------------
/tests/nn/flow/test_nODE.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from bgflow.distribution import NormalDistribution
4 | from bgflow.nn.flow import DiffEqFlow
5 | from bgflow.nn.flow.dynamics import BlackBoxDynamics, TimeIndependentDynamics
6 | from bgflow.nn.flow.estimator import BruteForceEstimator
7 |
8 | dim = 1
9 | n_samples = 100
10 | prior = NormalDistribution(dim)
11 | latent = prior.sample(n_samples)
12 |
13 |
14 | class SimpleDynamics(torch.nn.Module):
15 | def __init__(self):
16 | super().__init__()
17 |
18 | def forward(self, xs):
19 | dxs = -xs
20 | return dxs
21 |
22 |
23 | def make_black_box_flow():
24 | black_box_dynamics = BlackBoxDynamics(
25 | dynamics_function=TimeIndependentDynamics(SimpleDynamics()),
26 | divergence_estimator=BruteForceEstimator()
27 | )
28 |
29 | flow = DiffEqFlow(
30 | dynamics=black_box_dynamics
31 | )
32 | return flow
33 |
34 |
35 | def test_nODE_flow_OTD():
36 | # Test forward pass of simple nODE with the OTD solver
37 | flow = make_black_box_flow()
38 |
39 | try:
40 | samples, dlogp = flow(latent)
41 | except ImportError:
42 | pytest.skip("Test requires torchdiffeq.")
43 |
44 | assert samples.std() < 1
45 | assert torch.allclose(dlogp, -torch.ones(n_samples))
46 |
47 | # Test backward pass of simple nODE with the OTD solver
48 | try:
49 | samples, dlogp = flow(latent, inverse=True)
50 | except ImportError:
51 | pytest.skip("Test requires torchdiffeq.")
52 |
53 | assert samples.std() > 1
54 | assert torch.allclose(dlogp, torch.ones(n_samples))
55 |
56 |
57 | def test_nODE_flow_DTO():
58 | # Test forward pass of simple nODE with the DTO solver
59 | flow = make_black_box_flow()
60 |
61 | flow._use_checkpoints = True
62 | options = {
63 | "Nt": 20,
64 | "method": "RK4"
65 | }
66 | flow._kwargs = options
67 |
68 | try:
69 | samples, dlogp = flow(latent)
70 | except ImportError:
71 | pytest.skip("Test requires anode.")
72 |
73 | assert samples.std() < 1
74 | assert torch.allclose(dlogp, -torch.ones(n_samples))
75 |
76 | # Test backward pass of simple nODE with the DTO solver
77 | try:
78 | samples, dlogp = flow(latent, inverse=True)
79 | except ImportError:
80 | pytest.skip("Test requires torchdiffeq.")
81 |
82 | assert samples.std() > 1
83 | assert torch.allclose(dlogp, torch.ones(n_samples))
84 |
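Note: the assertions above can be checked analytically. For dx/dt = -x integrated over t in [0, 1], the flow map is x -> x * exp(-1): samples contract (std shrinks below 1) and the per-dimension log-Jacobian is exactly -1, matching `dlogp == -ones` with dim = 1.

import math
import torch

x0 = torch.randn(100, 1)
x1 = x0 * math.exp(-1.0)   # flow map after unit time
log_det = -1.0             # d/dx [x * e**-1] = e**-1, and log(e**-1) = -1
assert x1.std() < x0.std()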
--------------------------------------------------------------------------------
/tests/nn/flow/test_pppp.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 | from bgflow.nn.flow.pppp import InvertiblePPPP, PPPPScheduler, _iterative_solve
5 | from bgflow.nn.flow.sequential import SequentialFlow
6 |
7 |
8 | def test_invertible_pppp():
9 | flow = InvertiblePPPP(2, penalty_parameter=0.1)
10 | x = torch.tensor([[3.1, -2.0]])
11 | with torch.no_grad():
12 | flow.u[:] = torch.tensor([0.4, -0.3])
13 | flow.v[:] = torch.tensor([0.1, 0.2])
14 | y, logdet = flow.forward(x)
15 |
16 | x2, logdetinv = flow.forward(y, inverse=True)
17 | assert torch.isclose(x2, x, atol=1e-5).all()
18 |
19 | assert flow.penalty() > -1
20 | flow.pppp_merge()
21 | y2, logdet2 = flow.forward(x)
22 | x3, logdetinv2 = flow._inverse(y2, inverse=True)
23 |
24 | assert torch.isclose(torch.mm(flow.A, flow.Ainv), torch.eye(2)).all()
25 | assert torch.isclose(x3, x, atol=1e-5).all()
26 | assert torch.isclose(logdet, -logdetinv, atol=1e-5)
27 | assert torch.isclose(logdet, logdet2, atol=1e-5)
28 | assert torch.isclose(logdet, -logdetinv2, atol=1e-5)
29 | assert torch.isclose(y, y2, atol=1e-5).all()
30 | assert torch.isclose(flow.u, torch.zeros(1), atol=1e-5).all()
31 |
32 | # test training mode
33 | flow.train(False)
34 | y_test, logdet_test = flow.forward(x)
35 | assert torch.isclose(y, y_test, atol=1e-5).all()
36 | assert torch.isclose(logdet, logdet_test, atol=1e-5)
37 | x_test, logdetinv_test = flow.forward(y, inverse=True)
38 | assert torch.isclose(x, x_test, atol=1e-5).all()
39 | assert torch.isclose(-logdet, logdetinv_test, atol=1e-5)
40 |
41 |
42 | @pytest.mark.parametrize("mode", ["eye", "reverse"])
43 | @pytest.mark.parametrize("dim", [1, 3, 5, 6, 7, 10, 20])
44 | def test_initialization(mode, dim):
45 | flow = InvertiblePPPP(dim, init=mode)
46 | assert flow.A.inverse().numpy() == pytest.approx(flow.Ainv.numpy())
47 | assert torch.det(flow.A).item() == pytest.approx(flow.detA.item())
48 |
49 |
50 | @pytest.mark.parametrize("order", [2, 3, 7])
51 | def test_iterative_solve(order):
52 | torch.seed = 0
53 | a = torch.eye(3)+0.05*torch.randn(3, 3)
54 | inv = torch.inverse(a) + 0.2*torch.randn(3,3)
55 | for i in range(10):
56 | inv = _iterative_solve(a, inv, order)
57 | assert torch.mm(inv, a).numpy() == pytest.approx(torch.eye(3).numpy(), abs=1e-5)
58 |
59 |
60 | @pytest.mark.parametrize("model", [InvertiblePPPP(3), SequentialFlow([InvertiblePPPP(3), InvertiblePPPP(3)])])
61 | def test_scheduler(model):
62 | """API test for a full optimization workflow."""
63 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
64 | scheduler = PPPPScheduler(model, optimizer, n_force_merge=5, n_correct=5, n_recompute_det=5)
65 | torch.manual_seed(0)
66 | a = 1e-6*torch.eye(3)
67 | data = torch.ones(800, 3)
68 | target = torch.einsum("ij,...j->...i", a, data)
69 | loss = torch.nn.MSELoss()
70 | assert loss(data, target) > 1e-1
71 | for iter in range(100):
72 | optimizer.zero_grad()
73 | batch = data[8*iter:8*(iter+1)]
74 | y, _ = model.forward(batch)
75 | mse = loss(y, target[8*iter:8*(iter+1)])
76 | mse_plus_penalty = mse + scheduler.penalty()
77 | mse_plus_penalty.backward()
78 | optimizer.step()
79 | if iter % 10 == 0:
80 | scheduler.step()
81 | assert scheduler.penalty().item() >= 0.0
82 | assert mse < 2e-2 # check that the model has improved
83 | assert scheduler.i == 10
84 |
85 |
86 |
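Note: `_iterative_solve` as exercised above refines an approximate matrix inverse; the order-2 case is presumably the classic Newton-Schulz iteration, which converges quadratically whenever the initial residual ||I - A X_0|| is below 1:

import torch

torch.manual_seed(0)
a = torch.eye(3) + 0.05 * torch.randn(3, 3)
inv = torch.eye(3)                            # crude initial guess
for _ in range(10):
    inv = inv @ (2 * torch.eye(3) - a @ inv)  # X <- X (2I - A X)
assert torch.allclose(inv @ a, torch.eye(3), atol=1e-5)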
--------------------------------------------------------------------------------
/tests/nn/flow/test_sequential.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from bgflow.nn.flow.sequential import SequentialFlow
4 | from bgflow.nn.flow.orthogonal import PseudoOrthogonalFlow
5 | from bgflow.nn.flow.elementwise import BentIdentity
6 | from bgflow.nn.flow.triangular import TriuFlow
7 |
8 |
9 | def test_trigger_penalty():
10 | flow = SequentialFlow([
11 | PseudoOrthogonalFlow(3),
12 | PseudoOrthogonalFlow(3),
13 | PseudoOrthogonalFlow(3),
14 | ])
15 | penalties = flow.trigger("penalty")
16 | assert len(penalties) == 3
17 |
18 |
19 | def test_getitem():
20 | a = BentIdentity()
21 | b = TriuFlow(2)
22 | flow = SequentialFlow([a,b])
23 | assert flow[0] == a
24 | assert flow[1] == b
25 | subflow = flow[[0]]
26 | assert len(subflow) == 1
27 | assert isinstance(subflow, SequentialFlow)
28 | assert subflow[0] == a
--------------------------------------------------------------------------------
/tests/nn/flow/test_torchtransform.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.distributions import SigmoidTransform, AffineTransform, IndependentTransform
4 | from bgflow import TorchTransform, SequentialFlow, BentIdentity
5 |
6 |
7 | def test_torch_transform(ctx):
8 | """try using torch.Transform in combination with bgflow.Flow"""
9 | torch.manual_seed(10)
10 | x = torch.randn(10, 3, **ctx)
11 | flow = SequentialFlow([
12 | TorchTransform(IndependentTransform(SigmoidTransform(), 1)),
13 | TorchTransform(
14 | AffineTransform(
15 | loc=torch.randn(3, **ctx),
16 | scale=2.0+torch.rand(3, **ctx), event_dim=1
17 | ),
18 | ),
19 | BentIdentity(),
20 | # test the reinterpret_batch_ndims arguments
21 | TorchTransform(SigmoidTransform(), 1)
22 | ])
23 | z, dlogp = flow.forward(x)
24 | y, neg_dlogp = flow.forward(z, inverse=True)
25 | tol = 1e-7 if ctx["dtype"] is torch.float64 else 1e-5
26 | assert torch.allclose(x, y, atol=tol)
27 | assert torch.allclose(dlogp, -neg_dlogp, atol=tol)
28 |
--------------------------------------------------------------------------------
/tests/nn/flow/test_triangular.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import pytest
4 | from bgflow.nn.flow.triangular import TriuFlow
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "b", [
9 | torch.tensor(0.),
10 | torch.nn.Parameter(torch.tensor([1.,4.,1.])),
11 | torch.nn.Parameter(torch.zeros(3))
12 | ])
13 | def test_invert(b):
14 | tf = TriuFlow(3, shift=(isinstance(b, torch.nn.Parameter)))
15 | tf._unique_elements = torch.nn.Parameter(torch.rand_like(tf._unique_elements))
16 | tf._make_r()
17 | tf.b = b
18 | x = torch.randn(10,3)
19 | y, dlogp = tf._forward(x)
20 | x2, dlogpinv = tf._inverse(y)
21 | assert torch.norm(dlogp + dlogpinv).item() == pytest.approx(0.0)
22 | assert torch.norm(x-x2).item() == pytest.approx(0.0, abs=1e-6)
23 |
--------------------------------------------------------------------------------
/tests/nn/flow/transformer/test_affine.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import torch
4 | from bgflow import AffineTransformer, DenseNet
5 |
6 |
7 | class ShiftModule(torch.nn.Module):
8 | def forward(self, x):
9 | return torch.ones_like(x)
10 |
11 |
12 | @pytest.mark.parametrize("is_circular", [True, False])
13 | @pytest.mark.parametrize("use_scale_transform", [True, False])
14 | def test_affine(is_circular, use_scale_transform):
15 |
16 | if use_scale_transform:
17 | scale = DenseNet([2, 2])
18 | else:
19 | scale = None
20 |
21 | if use_scale_transform and is_circular:
22 | with pytest.raises(ValueError):
23 | trafo = AffineTransformer(
24 | shift_transformation=ShiftModule(),
25 | scale_transformation=scale,
26 | is_circular=is_circular
27 | )
28 |
29 | else:
30 | trafo = AffineTransformer(
31 | shift_transformation=ShiftModule(),
32 | scale_transformation=scale,
33 | is_circular=is_circular
34 | )
35 | x = torch.rand(100, 2)
36 | y = torch.rand(100, 2)
37 | y2, dlogp = trafo.forward(x, y)
38 | assert y2.shape == y.shape
39 | if is_circular:
40 | assert (y2 < 1).all()
41 | else:
42 | assert (y2 > 1).any()
43 |
44 |
--------------------------------------------------------------------------------
/tests/nn/flow/transformer/test_gaussian.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from bgflow.nn.flow.transformer.gaussian import TruncatedGaussianTransformer
4 | from bgflow.distribution.normal import TruncatedNormalDistribution
5 |
6 |
7 | def test_constrained_affine_transformer(ctx):
8 | tol = 5e-4 if ctx["dtype"] == torch.float32 else 5e-7
9 | mu_net = torch.nn.Linear(3, 2, bias=False)
10 | sigma_net = torch.nn.Linear(3, 2, bias=False)
11 | mu_net.weight.data = torch.ones_like(mu_net.weight.data)
12 | sigma_net.weight.data = torch.ones_like(sigma_net.weight.data)
13 | constrained = TruncatedGaussianTransformer(mu_net, sigma_net, -5.0, 5.0, -8.0, 0.0).to(**ctx)
14 |
15 | # test if forward and inverse are compatible
16 | x = torch.ones(1, 3, **ctx)
17 | y = torch.tensor([[-2.5, 2.5]], **ctx)
18 | y.requires_grad = True
19 | out, dlogp = constrained.forward(x, y)
20 | assert not torch.allclose(dlogp, torch.zeros_like(dlogp))
21 | assert (out >= torch.tensor(-8.0, **ctx)).all()
22 | assert (out <= torch.tensor(0.20, **ctx)).all()
23 | y2, neg_dlogp = constrained.forward(x, out, inverse=True)
24 | assert torch.allclose(y, y2, atol=tol)
25 | assert torch.allclose(dlogp + neg_dlogp, torch.zeros_like(dlogp), atol=tol)
26 |
27 | # test if the log det agrees with the log prob of a truncated normal distribution
28 | mu = torch.einsum("ij,...j->...i", mu_net.weight.data, x)
29 | _, logsigma = constrained._get_mu_and_log_sigma(x, y) # not reiterating the tanh stuff
30 | sigma = torch.exp(logsigma)
31 | trunc_gaussian = TruncatedNormalDistribution(mu, sigma, torch.tensor(-5, **ctx), torch.tensor(5, **ctx))
32 | log_prob = trunc_gaussian.log_prob(y)
33 | log_scale = torch.log(torch.tensor(8., **ctx))
34 | assert torch.allclose(dlogp, (log_prob + log_scale).sum(dim=-1, keepdim=True), atol=tol)
35 |
36 | # try backward pass and assert reasonable gradients
37 | y2.sum().backward(create_graph=True)
38 | neg_dlogp.backward()
39 | for tensor in [y, mu_net.weight]:
40 | assert (tensor.grad > -1e6).all()
41 | assert (tensor.grad < 1e6).all()
42 |
--------------------------------------------------------------------------------
/tests/nn/flow/transformer/test_spline.py:
--------------------------------------------------------------------------------
1 | """Test spline transformer"""
2 |
3 | import pytest
4 | import torch
5 | from bgflow import ConditionalSplineTransformer, CouplingFlow, SplitFlow, NormalDistribution, DenseNet
6 |
7 |
8 | @pytest.mark.parametrize("is_circular", [True, False])
9 | def test_conditional_spline_transformer_api(is_circular, ctx):
10 | pytest.importorskip("nflows")
11 |
12 | n_bins = 4
13 | dim_trans = 10
14 | n_samples = 10
15 | dim_cond = 9
16 | x_cond = torch.rand((n_samples, dim_cond), **ctx)
17 | x_trans = torch.rand((n_samples, dim_trans)).to(x_cond)
18 |
19 | if is_circular:
20 | dim_net_out = 3 * n_bins * dim_trans
21 | else:
22 | dim_net_out = (3 * n_bins + 1) * dim_trans
23 | conditioner = DenseNet([dim_cond, dim_net_out])
24 |
25 | transformer = ConditionalSplineTransformer(
26 | params_net=conditioner,
27 | is_circular=is_circular,
28 | ).to(x_cond)
29 |
30 | y, dlogp = transformer.forward(x_cond, x_trans)
31 |
32 | assert (y > 0.0).all()
33 | assert (y < 1.0).all()
34 |
35 |
36 | @pytest.mark.parametrize(
37 | "is_circular", [torch.tensor(True), torch.tensor(False), torch.tensor([True, False], dtype=torch.bool)]
38 | )
39 | def test_conditional_spline_continuity(is_circular, ctx):
40 | pytest.importorskip("nflows")
41 | torch.manual_seed(2150)
42 |
43 | n_bins = 3
44 | dim_trans = 2
45 | n_samples = 1
46 | dim_cond = 1
47 | x_cond = torch.rand((n_samples, dim_cond), **ctx)
48 |
49 | if is_circular.all():
50 | dim_net_out = 3 * n_bins * dim_trans
51 | elif not is_circular.any():
52 | dim_net_out = (3 * n_bins + 1) * dim_trans
53 | else:
54 | dim_net_out = 3 * n_bins * dim_trans + int(is_circular.sum())
55 | conditioner = DenseNet([dim_cond, dim_net_out], bias_scale=2.)
56 |
57 | transformer = ConditionalSplineTransformer(
58 | params_net=conditioner,
59 | is_circular=is_circular,
60 | ).to(x_cond)
61 |
62 | slopes = transformer._compute_params(x_cond, dim_trans)[2]
63 | continuous = torch.isclose(slopes[0,:,0], slopes[0,:,-1]).tolist()
64 | if is_circular.all():
65 | assert continuous == [True, True]
66 | elif not is_circular.any():
67 | assert continuous == [False, False]
68 | else:
69 | assert continuous == [True, False]
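Note: the `dim_net_out` bookkeeping in both tests above follows from the parameter count of a monotone rational-quadratic spline: per transformed dimension, n_bins widths, n_bins heights, and n_bins + 1 knot slopes; a circular dimension ties the first slope to the last, saving one parameter.

n_bins, dim_trans = 4, 10
params_noncircular = (3 * n_bins + 1) * dim_trans   # widths + heights + (n_bins + 1) slopes
params_circular = 3 * n_bins * dim_trans            # first slope == last slope
assert params_noncircular == 130 and params_circular == 120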
--------------------------------------------------------------------------------
/tests/nn/test_wrap_distances.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from bgflow.nn.periodic import WrapDistances
4 |
5 |
6 | def test_wrap_distances():
7 | positions = torch.tensor([[[0.,0.,0.],[0.,0.,1.],[0.,2.,0.]]])
8 | positions_flat = positions.view(positions.shape[0],-1)
9 | module = torch.nn.ReLU()
10 | wrapper = WrapDistances(module)
11 | result = wrapper.forward(positions_flat)
12 | expected = torch.tensor([[1.,2.,np.sqrt(5)]]).to(positions)
13 | assert torch.allclose(result, expected)
14 |
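Note: the expected values can be cross-checked with torch's built-in pairwise distance, which enumerates the pairs (0,1), (0,2), (1,2) in the same order:

import torch

positions = torch.tensor([[0., 0., 0.], [0., 0., 1.], [0., 2., 0.]])
print(torch.pdist(positions))   # tensor([1.0000, 2.0000, 2.2361]), sqrt(5) ~ 2.2361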
--------------------------------------------------------------------------------
/tests/test_readme.py:
--------------------------------------------------------------------------------
1 |
2 | # If this fails, the example in the README won't work!
3 |
4 | def test_readme():
5 | import torch
6 | import matplotlib.pyplot as plt
7 | import bgflow as bg
8 |
9 | # define prior and target
10 | dim = 2
11 | prior = bg.NormalDistribution(dim)
12 | target = bg.DoubleWellEnergy(dim)
13 |
14 | # here we aggregate all layers of the flow
15 | layers = []
16 | layers.append(bg.SplitFlow(dim // 2))
17 | layers.append(bg.CouplingFlow(
18 | # we use a affine transformation to transform the RHS conditioned on the LHS
19 | bg.AffineTransformer(
20 | # use simple dense nets for the affine shift/scale
21 | shift_transformation=bg.DenseNet(
22 | [dim // 2, 4, dim // 2],
23 | activation=torch.nn.ReLU()
24 | ),
25 | scale_transformation=bg.DenseNet(
26 | [dim // 2, 4, dim // 2],
27 | activation=torch.nn.Tanh()
28 | )
29 | )
30 | ))
31 | layers.append(bg.InverseFlow(bg.SplitFlow(dim // 2)))
32 |
33 | # now define the flow as a sequence of all operations stored in layers
34 | flow = bg.SequentialFlow(layers)
35 |
36 | # The BG is defined by a prior, target and a flow
37 | generator = bg.BoltzmannGenerator(prior, flow, target)
38 |
39 | # sample from the BG
40 | samples = generator.sample(1000)
41 | plt.hist2d(
42 | samples[:, 0].detach().numpy(),
43 | samples[:, 1].detach().numpy(), bins=100
44 | )
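Note: the README example stops after sampling from the untrained generator. The natural next step is energy-based training; a minimal, hedged sketch, assuming `BoltzmannGenerator.kldiv` is available as used in the bgflow notebooks:

optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)
for _ in range(100):
    optimizer.zero_grad()
    # reverse KL between flow samples and the target Boltzmann distribution
    loss = generator.kldiv(128).mean()
    loss.backward()
    optimizer.step()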
--------------------------------------------------------------------------------
/tests/utils/test_autograd.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from bgflow.utils import (
5 | brute_force_jacobian, brute_force_jacobian_trace, get_jacobian,
6 | batch_jacobian, requires_grad
7 | )
8 |
9 |
10 | x = torch.tensor([[1., 2, 3]], requires_grad=True)
11 | y = x.pow(2) + x[:, 1]
12 | true_jacobian = np.array([[[2., 1, 0], [0, 5, 0], [0, 1, 6]]])
13 |
14 |
15 | def test_brute_force_jacobian_trace():
16 | jacobian_trace = brute_force_jacobian_trace(y, x)
17 | assert jacobian_trace.detach().numpy() == pytest.approx(np.array([13.]), abs=1e-6)
18 |
19 |
20 | def test_brute_force_jacobian():
21 | jacobian = brute_force_jacobian(y, x)
22 | assert jacobian.detach().numpy() == pytest.approx(true_jacobian, abs=1e-6)
23 |
24 |
25 | def test_batch_jacobian(ctx):
26 | t = x.repeat(10, 1).to(**ctx)
27 | out = t.pow(2) + t[:, [1]]
28 | expected = torch.tensor(true_jacobian, **ctx).repeat(10, 1, 1)
29 | assert torch.allclose(batch_jacobian(out, t), expected)
30 |
31 |
32 | def test_get_jacobian(ctx):
33 | t = torch.ones((2,), **ctx)
34 | func = lambda s: s**2
35 | expected = 2*t*torch.eye(2, **ctx)
36 | assert torch.allclose(get_jacobian(func, t).jac, expected)
37 |
38 |
39 | def test_requires_grad(ctx):
40 | t = torch.zeros((2,), **ctx)
41 | assert not t.requires_grad
42 | with requires_grad(t):
43 | assert t.requires_grad
44 |
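Note: what `brute_force_jacobian_trace` computes, spelled out with plain autograd: sum_i dy_i/dx_i, accumulated one output component at a time:

import torch

x = torch.tensor([[1., 2., 3.]], requires_grad=True)
y = x.pow(2) + x[:, 1]
trace = sum(
    torch.autograd.grad(y[:, i].sum(), x, retain_graph=True)[0][:, i]
    for i in range(y.shape[1])
)
assert torch.allclose(trace, torch.tensor([13.]))   # 2*1 + (2*2 + 1) + 2*3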
--------------------------------------------------------------------------------
/tests/utils/test_free_energy.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import pytest
5 | import torch
6 | from bgflow.distribution import NormalDistribution
7 |
8 | from bgflow.utils.free_energy import bennett_acceptance_ratio
9 |
10 |
11 | @pytest.mark.parametrize("method", ["torch", "pymbar"])
12 | @pytest.mark.parametrize("compute_uncertainty", [True, False])
13 | def test_bar(ctx, method, compute_uncertainty):
14 | pytest.importorskip("pymbar")
15 | dim = 1
16 | energy1 = NormalDistribution(dim, mean=torch.zeros(dim, **ctx))
17 | energy2 = NormalDistribution(dim, mean=0.2*torch.ones(dim, **ctx))
18 | samples1 = energy1.sample(10000)
19 | samples2 = energy2.sample(20000)
20 |
21 | free_energy, uncertainty = bennett_acceptance_ratio(
22 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1),
23 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)),
24 | implementation=method,
25 | compute_uncertainty=compute_uncertainty
26 | )
27 | assert free_energy.item() == pytest.approx(1., abs=1e-2)
28 | if compute_uncertainty:
29 | assert uncertainty.item() < 1e-2
30 | else:
31 | assert uncertainty is None
32 |
33 |
34 | @pytest.mark.parametrize("method", ["torch", "pymbar"])
35 | @pytest.mark.parametrize("warn", [True, False])
36 | def test_bar_no_convergence(ctx, method, warn):
37 | with warnings.catch_warnings():
38 | warnings.simplefilter("ignore", RuntimeWarning)
39 | pytest.importorskip("pymbar")
40 | dim = 1
41 | energy1 = NormalDistribution(dim, mean=-1e20*torch.ones(dim, **ctx))
42 | energy2 = NormalDistribution(dim, mean=1e20*torch.ones(dim, **ctx))
43 | samples1 = energy1.sample(5)
44 | samples2 = energy2.sample(5)
45 |
46 | if warn:
47 | # test if warning is raised
48 | with pytest.warns(UserWarning, match="BAR could not"):
49 | free_energy, uncertainty = bennett_acceptance_ratio(
50 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1),
51 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)),
52 | implementation=method,
53 | warn=warn
54 | )
55 | else:
56 | free_energy, uncertainty = bennett_acceptance_ratio(
57 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1),
58 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)),
59 | implementation=method,
60 | warn=warn
61 | )
62 | assert np.isnan(free_energy.item())
63 | assert np.isnan(uncertainty.item())
64 |
65 |
66 | def test_bar_uncertainty(ctx):
67 | """test consistency with the reference implementation"""
68 | pytest.importorskip("pymbar")
69 | dim = 1
70 | energy1 = NormalDistribution(dim, mean=torch.zeros(dim, **ctx))
71 | energy2 = NormalDistribution(dim, mean=0.2*torch.ones(dim, **ctx)) # will be multiplied by e
72 | samples1 = energy1.sample(1000)
73 | samples2 = energy2.sample(2000)
74 |
75 | free_energy1, uncertainty1 = bennett_acceptance_ratio(
76 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1),
77 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)),
78 | implementation="torch"
79 | )
80 | free_energy2, uncertainty2 = bennett_acceptance_ratio(
81 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1),
82 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)),
83 | implementation="pymbar"
84 | )
85 | assert free_energy1.item() == pytest.approx(free_energy2.item(), rel=1e-3)
86 | assert uncertainty1.item() == pytest.approx(uncertainty2.item(), rel=1e-3)
87 |
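Note: the expected value of 1 in these tests is exact: both states are unit-variance Gaussians with the same normalization constant, and state 2's reduced energy is offset by the constant +1.0, so Delta F = 1. A quick numeric check with simple free-energy perturbation (Zwanzig) instead of BAR:

import torch

torch.manual_seed(0)
x = torch.randn(100000)                  # samples from state 1, N(0, 1)
u1 = 0.5 * x ** 2
u2 = 1.0 + 0.5 * (x - 0.2) ** 2          # shifted mean plus a constant offset
delta_f = -torch.log(torch.exp(-(u2 - u1)).mean())
assert abs(delta_f.item() - 1.0) < 0.05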
--------------------------------------------------------------------------------
/tests/utils/test_geometry.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from bgflow.utils import distance_vectors, distances_from_vectors, remove_mean, compute_distances
4 |
5 |
6 | @pytest.mark.parametrize("remove_diagonal", [True, False])
7 | def test_distances_from_vectors(remove_diagonal):
8 | """Test if distances are calculated correctly from a particle configuration."""
9 | particles = torch.tensor([[[3., 0], [3, 0], [0, 4]]])
10 | distances = distances_from_vectors(
11 | distance_vectors(particles, remove_diagonal=remove_diagonal))
12 | if remove_diagonal:
13 | assert torch.allclose(distances, torch.tensor([[[0., 5], [0, 5], [5, 5]]]), atol=1e-2)
14 | else:
15 | assert torch.allclose(distances, torch.tensor([[[0., 0, 5], [0, 0, 5], [5, 5, 0]]]), atol=1e-2)
16 |
17 |
18 | def test_mean_free(ctx):
19 | """Test if the mean of random configurations is removed correctly"""
20 | samples = torch.rand(100, 100, 3, **ctx) - 0.3
21 | samples = remove_mean(samples, 100, 3)
22 | mean_deviation = samples.mean(dim=(1, 2))
23 | threshold = 1e-5
24 | assert torch.all(mean_deviation.abs() < threshold)
25 |
26 | @pytest.mark.parametrize("remove_duplicates", [True, False])
27 | def test_compute_distances(remove_duplicates, ctx):
28 | """Test if distances are calculated correctly from a particle configuration."""
29 | particles = torch.tensor([[[3., 0], [3, 0], [0, 4]]], **ctx)
30 | distances = compute_distances(particles, n_particles=3, n_dimensions=2, remove_duplicates=remove_duplicates)
31 | if remove_duplicates:
32 | assert torch.allclose(distances, torch.tensor([[0., 5, 5]], **ctx), atol=1e-5)
33 | else:
34 | assert torch.allclose(distances, torch.tensor([[[0., 0, 5], [0, 0, 5], [5, 5, 0]]], **ctx), atol=1e-5)
35 |
--------------------------------------------------------------------------------
/tests/utils/test_rbf_kernels.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from bgflow.utils import kernelize_with_rbf, rbf_kernels
4 |
5 |
6 | def test_kernelize_with_rbf():
7 | # distance vector of shape `[n_batch, n_particles, n_particles - 1, 1]`
8 | distances = torch.tensor([[[[0.], [5]], [[0], [5]], [[5], [5]]]])
9 | mus = torch.tensor([[[[0., 5]]]])
10 | gammas = torch.tensor([[[[1., 0.5]]]])
11 |
12 | distances2 = torch.tensor([[[[1.]]]], requires_grad=True)
13 | mus2 = torch.tensor([[[[1., 3]]]])
14 | gammas2 = torch.tensor([[[[1., 0.5]]]])
15 |
16 | rbf1 = torch.exp(- (distances2 - mus2[0, 0, 0, 0]) ** 2 / gammas2[0, 0, 0, 0] ** 2)
17 | rbf2 = torch.exp(- (distances2 - mus2[0, 0, 0, 1]) ** 2 / gammas2[0, 0, 0, 1] ** 2)
18 | r1 = rbf1 / (rbf1 + rbf2)
19 | r2 = rbf2 / (rbf1 + rbf2)
20 |
21 | # Test shapes and math for simple configurations
22 | rbfs = kernelize_with_rbf(distances, mus, gammas)
23 | assert torch.allclose(rbfs, torch.tensor([[[[1., 0], [0, 1]], [[1., 0], [0, 1]], [[0, 1], [0, 1]]]]), atol=1e-5)
24 |
25 | rbfs2 = kernelize_with_rbf(distances2, mus2, gammas2)
26 | assert torch.allclose(rbfs2, torch.cat([r1, r2], dim=-1), atol=1e-5)
27 |
28 |
29 | def test_rbf_kernels():
30 | # distance vector of shape `[n_batch, n_particles, n_particles - 1, 1]`
31 | distances = torch.tensor([[[[0.], [5]], [[0], [5]], [[5], [5]]]])
32 | mus = torch.tensor([[[[0., 5]]]])
33 | gammas = torch.tensor([[[[1., 0.5]]]])
34 |
35 | distances2 = torch.tensor([[[[1.]]]], requires_grad=True)
36 | mus2 = torch.tensor([[[[1., 3]]]])
37 | gammas2 = torch.tensor([[[[1., 0.5]]]])
38 |
39 | rbf1 = torch.exp(- (distances2 - mus2[0, 0, 0, 0]) ** 2 / gammas2[0, 0, 0, 0] ** 2)
40 | rbf2 = torch.exp(- (distances2 - mus2[0, 0, 0, 1]) ** 2 / gammas2[0, 0, 0, 1] ** 2)
41 | r1 = rbf1 / (rbf1 + rbf2)
42 | r2 = rbf2 / (rbf1 + rbf2)
43 |
44 | # Test shapes, math and the derivative for simple configurations
45 | neg_log_gammas = - torch.log(gammas)
46 | rbfs, derivatives_rbfs = rbf_kernels(distances, mus, neg_log_gammas, derivative=True)
47 | assert torch.allclose(rbfs, torch.tensor([[[[1., 0], [0, 1]], [[1., 0], [0, 1]], [[0, 1], [0, 1]]]]), atol=1e-5)
48 | assert torch.allclose(derivatives_rbfs, torch.zeros((1, 3, 2, 2)), atol=1e-5)
49 |
50 | neg_log_gammas2 = - torch.log(gammas2)
51 | rbfs2, derivatives_rbfs2 = rbf_kernels(distances2, mus2, neg_log_gammas2, derivative=True)
52 | assert torch.allclose(rbfs2, torch.cat([r1, r2], dim=-1), atol=1e-5)
53 |
54 | # Check derivative of rbf
55 | dr1 = torch.autograd.grad(r1, distances2, retain_graph=True)
56 | dr2 = torch.autograd.grad(r2, distances2)
57 | assert torch.allclose(derivatives_rbfs2, torch.cat([dr1[0], dr2[0]], dim=-1), atol=1e-5)
58 |
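Note: the reference values above come from normalized Gaussian radial basis functions, phi_i(d) = exp(-(d - mu_i)^2 / gamma_i^2), rescaled to sum to one across kernels:

import torch

d = torch.tensor(1.0)
mus = torch.tensor([1., 3.])
gammas = torch.tensor([1., 0.5])
phi = torch.exp(-(d - mus) ** 2 / gammas ** 2)
phi = phi / phi.sum()    # matches r1, r2 in the tests above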
--------------------------------------------------------------------------------
/tests/utils/test_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bgflow.utils import ClipGradient
3 |
4 |
5 | def torch_example(grad_clipping, ctx):
6 | positions = torch.arange(6).reshape(2, 3).to(**ctx)
7 | positions.requires_grad = True
8 | positions = grad_clipping.to(**ctx)(positions)
9 | (0.5 * positions ** 2).sum().backward()
10 | return positions.grad
11 |
12 |
13 | def test_clip_by_val(ctx):
14 | grad_clipping = ClipGradient(clip=3., norm_dim=1)
15 | assert torch.allclose(
16 | torch_example(grad_clipping, ctx),
17 | torch.tensor([[0., 1., 2.], [3., 3., 3.]], **ctx)
18 | )
19 |
20 |
21 | def test_clip_by_atom(ctx):
22 | grad_clipping = ClipGradient(clip=3., norm_dim=3)
23 | norm2 = torch.linalg.norm(torch.arange(3, 6, **ctx)).item()
24 | assert torch.allclose(
25 | torch_example(grad_clipping, ctx),
26 | torch.tensor([[0., 1., 2.], [3/norm2*3, 4/norm2*3, 5/norm2*3]], **ctx)
27 | )
28 |
29 |
30 | def test_clip_by_batch(ctx):
31 | grad_clipping = ClipGradient(clip=3., norm_dim=-1)
32 | norm2 = torch.linalg.norm(torch.arange(6, **ctx)).item()
33 | assert torch.allclose(
34 | torch_example(grad_clipping, ctx),
35 | (torch.arange(6, **ctx) / norm2 * 3.).reshape(2, 3)
36 | )
37 |
38 |
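Note: the clip-by-atom case above spelled out: the gradient of the second "atom" is (3, 4, 5) with norm sqrt(50) ~ 7.07 > 3, so it is rescaled to norm 3, while the first row (norm sqrt(5) < 3) passes through unchanged:

import torch

g = torch.tensor([3., 4., 5.])
clipped = g / g.norm() * 3.0
assert torch.allclose(clipped.norm(), torch.tensor(3.0))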
--------------------------------------------------------------------------------
/tests/utils/test_types.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from bgflow.utils import as_numpy
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def test_as_numpy():
9 | assert np.allclose(as_numpy(2.0), np.array(2.0))
10 | assert np.allclose(as_numpy(np.ones(2)), np.ones(2))
11 | assert as_numpy(1) == np.array(1)
12 |
13 |
14 | def test_tensor_as_numpy(ctx):
15 | out = as_numpy(torch.zeros(2, **ctx))
16 | assert isinstance(out, np.ndarray)
17 | assert np.allclose(out, np.zeros(2))
18 |
--------------------------------------------------------------------------------