├── .github └── workflows │ ├── CI-openmm.yml │ ├── CI.yml │ ├── codeql.yml │ └── lint.yml ├── .gitignore ├── LICENSE ├── README.md ├── bgflow ├── __init__.py ├── _version.py ├── bg.py ├── distribution │ ├── __init__.py │ ├── distributions.py │ ├── energy │ │ ├── __init__.py │ │ ├── ase.py │ │ ├── base.py │ │ ├── clipped.py │ │ ├── double_well.py │ │ ├── lennard_jones.py │ │ ├── multi_double_well_potential.py │ │ ├── openmm.py │ │ ├── particles.py │ │ └── xtb.py │ ├── mixture.py │ ├── normal.py │ ├── product.py │ └── sampling │ │ ├── __init__.py │ │ ├── _iterative_helpers.py │ │ ├── _mcmc │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── latent_sampling.py │ │ ├── metropolis.py │ │ ├── permutation.py │ │ └── umbrella_sampling.py │ │ ├── base.py │ │ ├── buffer.py │ │ ├── dataset.py │ │ ├── iterative.py │ │ └── mcmc.py ├── factory │ ├── GNN_factory.py │ ├── __init__.py │ ├── conditioner_factory.py │ ├── distribution_factory.py │ ├── generator_builder.py │ ├── icmarginals.py │ ├── tensor_info.py │ └── transformer_factory.py ├── nn │ ├── __init__.py │ ├── dense.py │ ├── flow │ │ ├── __init__.py │ │ ├── affine.py │ │ ├── base.py │ │ ├── bnaf.py │ │ ├── cdf.py │ │ ├── checkerboard.py │ │ ├── circular.py │ │ ├── coupling.py │ │ ├── crd_transform │ │ │ ├── __init__.py │ │ │ ├── _deprecated_ic.py │ │ │ ├── ic.py │ │ │ ├── ic_helper.py │ │ │ └── pca.py │ │ ├── diffeq.py │ │ ├── dynamics │ │ │ ├── __init__.py │ │ │ ├── anode_dynamic.py │ │ │ ├── blackbox.py │ │ │ ├── density.py │ │ │ ├── inversed.py │ │ │ ├── kernel_dynamic.py │ │ │ └── simple.py │ │ ├── elementwise.py │ │ ├── estimator │ │ │ ├── __init__.py │ │ │ ├── brute_force_estimator.py │ │ │ └── hutchinson_estimator.py │ │ ├── funnel.py │ │ ├── inverted.py │ │ ├── kronecker.py │ │ ├── modulo.py │ │ ├── orthogonal.py │ │ ├── pppp.py │ │ ├── sequential.py │ │ ├── spline.py │ │ ├── stochastic │ │ │ ├── __init__.py │ │ │ ├── augment.py │ │ │ ├── langevin.py │ │ │ ├── mcmc.py │ │ │ └── snf_openmm.py │ │ ├── torchtransform.py 
│ │ ├── transformer │ │ │ ├── __init__.py │ │ │ ├── affine.py │ │ │ ├── base.py │ │ │ ├── entropy_scaling.py │ │ │ ├── gaussian.py │ │ │ ├── jax.py │ │ │ ├── jax_bridge.py │ │ │ └── spline.py │ │ └── triangular.py │ ├── periodic.py │ └── training │ │ ├── __init__.py │ │ └── trainers.py └── utils │ ├── __init__.py │ ├── autograd.py │ ├── free_energy.py │ ├── geometry.py │ ├── internal_coordinates.py │ ├── openmm.py │ ├── rbf_kernels.py │ ├── shape.py │ ├── tensorops.py │ ├── train.py │ └── types.py ├── devtools └── conda-env.yml ├── docs ├── Makefile ├── README.md ├── _static │ └── README.md ├── _templates │ ├── README.md │ └── class.rst ├── api │ ├── bg.rst │ ├── energies.rst │ ├── flows.rst │ ├── samplers.rst │ └── utils.rst ├── conf.py ├── examples.rst ├── getting_started.rst ├── index.rst ├── installation.rst ├── literature.bib ├── make.bat ├── requirements.rst └── requirements.yaml ├── examples ├── datasets │ └── README.rst ├── general_examples │ ├── README.rst │ └── plot_simple_bg.py └── nb_examples │ ├── README.rst │ └── example_bg_coupling.ipynb ├── notebooks ├── alanine_dipeptide_augmented.ipynb ├── alanine_dipeptide_basics.ipynb ├── alanine_dipeptide_basics.py ├── alanine_dipeptide_spline.ipynb ├── cgn_GNN_example.ipynb ├── example.ipynb ├── example_equivariant_RNVP.ipynb ├── example_equivariant_nODE.ipynb ├── iterative_umbrella_sampling.ipynb └── samplers.ipynb ├── readthedocs.yml ├── setup.cfg ├── setup.py ├── tests ├── conftest.py ├── data │ └── alanine-dipeptide-nowater.pdb ├── distribution │ ├── energy │ │ ├── test_ase.py │ │ ├── test_base.py │ │ ├── test_clipped.py │ │ ├── test_lennard_jones.py │ │ ├── test_multi_double_well_potential.py │ │ ├── test_openmm.py │ │ └── test_xtb.py │ ├── sampling │ │ ├── test_buffer.py │ │ ├── test_dataset.py │ │ ├── test_iterative.py │ │ ├── test_iterative_helpers.py │ │ └── test_mcmc.py │ ├── test_distribution.py │ ├── test_normal.py │ └── test_product.py ├── factory │ ├── test_conditioner_factory.py │ ├── 
test_distribution_factory.py │ ├── test_generator_builder.py │ ├── test_icmarginals.py │ ├── test_tensor_info.py │ └── test_transformer_factory.py ├── nn │ ├── flow │ │ ├── crd_transform │ │ │ └── test_ic.py │ │ ├── dynamics │ │ │ └── test_kernel_dynamics.py │ │ ├── estimators │ │ │ └── test_hutchinson_estimator.py │ │ ├── stochastic │ │ │ └── test_snf_openmm.py │ │ ├── test_cdf.py │ │ ├── test_coupling.py │ │ ├── test_inverted.py │ │ ├── test_modulo.py │ │ ├── test_nODE.py │ │ ├── test_pppp.py │ │ ├── test_sequential.py │ │ ├── test_torchtransform.py │ │ ├── test_triangular.py │ │ └── transformer │ │ │ ├── test_affine.py │ │ │ ├── test_gaussian.py │ │ │ ├── test_jax_bridge.py │ │ │ └── test_spline.py │ └── test_wrap_distances.py ├── test_bg.py ├── test_readme.py └── utils │ ├── test_autograd.py │ ├── test_free_energy.py │ ├── test_geometry.py │ ├── test_rbf_kernels.py │ ├── test_train.py │ └── test_types.py └── versioneer.py /.github/workflows/CI-openmm.yml: -------------------------------------------------------------------------------- 1 | name: CI with OpenMM on conda 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | pull_request: 8 | branches: 9 | - "main" 10 | schedule: 11 | # Nightly tests run on master by default: 12 | # Scheduled workflows run on the latest commit on the default or base branch. 
13 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule) 14 | - cron: "0 0 * * *" 15 | 16 | 17 | jobs: 18 | test: 19 | runs-on: ${{ matrix.os }} 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: [ubuntu-latest] 24 | python-version: [3.9] 25 | 26 | steps: 27 | 28 | - uses: actions/checkout@v2 29 | 30 | # More info on options: https://github.com/conda-incubator/setup-miniconda 31 | - uses: conda-incubator/setup-miniconda@v2 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | environment-file: devtools/conda-env.yml 35 | channels: conda-forge, pytorch, defaults 36 | activate-environment: test 37 | auto-update-conda: true 38 | auto-activate-base: false 39 | show-channel-urls: true 40 | 41 | - name: Install pip dependencies 42 | shell: bash -l {0} 43 | run: | 44 | pip install einops 45 | pip install nflows 46 | 47 | - name: Install package 48 | shell: bash -l {0} 49 | run: | 50 | python setup.py install 51 | 52 | - name: Test with pytest 53 | shell: bash -l {0} 54 | run: | 55 | pytest -vs 56 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI without OpenMM 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | pull_request: 8 | branches: 9 | - "main" 10 | schedule: 11 | # Nightly tests run on master by default: 12 | # Scheduled workflows run on the latest commit on the default or base branch. 
13 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule) 14 | - cron: "0 0 * * *" 15 | 16 | 17 | jobs: 18 | test: 19 | runs-on: ${{ matrix.cfg.os }} 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | cfg: 24 | - { os: ubuntu-latest, python-version: 3.7, torch-version: 'torch>=1.9,<1.11' } 25 | - { os: ubuntu-latest, python-version: 3.8, torch-version: 'torch>=1.9,<1.11' } 26 | - { os: ubuntu-latest, python-version: 3.9, torch-version: 'torch>=1.9,<1.11' } 27 | - { os: ubuntu-latest, python-version: 3.7, torch-version: 'torch>=1.11' } 28 | - { os: ubuntu-latest, python-version: 3.8, torch-version: 'torch>=1.11' } 29 | - { os: ubuntu-latest, python-version: 3.9, torch-version: 'torch>=1.11' } 30 | - { os: windows-latest, python-version: 3.9, torch-version: 'torch>=1.11' } 31 | - { os: macos-latest, python-version: 3.9, torch-version: 'torch>=1.11' } 32 | 33 | steps: 34 | 35 | # WITHOUT OPENMM 36 | - uses: actions/checkout@v2 37 | - name: Set up Python ${{ matrix.cfg.python-version }} 38 | uses: actions/setup-python@v2 39 | with: 40 | python-version: ${{ matrix.cfg.python-version }} 41 | - name: Install dependencies 42 | run: | 43 | python -m pip install --upgrade pip 44 | pip install pytest "${{ matrix.cfg.torch-version }}" numpy nflows torchdiffeq einops netCDF4 45 | - name: Install package 46 | run: | 47 | python setup.py install 48 | - name: Test with pytest 49 | run: | 50 | pytest -vs 51 | 52 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | schedule: 9 | - cron: "52 8 * * 3" 10 | 11 | jobs: 12 | analyze: 13 | name: Analyze 14 | runs-on: ubuntu-latest 15 | permissions: 16 | actions: read 17 | contents: read 18 | security-events: write 19 | 20 | 
strategy: 21 | fail-fast: false 22 | matrix: 23 | language: [ python ] 24 | 25 | steps: 26 | - name: Checkout 27 | uses: actions/checkout@v3 28 | 29 | - name: Initialize CodeQL 30 | uses: github/codeql-action/init@v2 31 | with: 32 | languages: ${{ matrix.language }} 33 | queries: +security-and-quality 34 | 35 | - name: Autobuild 36 | uses: github/codeql-action/autobuild@v2 37 | 38 | - name: Perform CodeQL Analysis 39 | uses: github/codeql-action/analyze@v2 40 | with: 41 | category: "/language:${{ matrix.language }}" 42 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Flake8 Linting 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | pull_request: 8 | branches: 9 | - "main" 10 | schedule: 11 | # Nightly tests run on master by default: 12 | # Scheduled workflows run on the latest commit on the default or base branch. 13 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule) 14 | - cron: "0 0 * * *" 15 | 16 | 17 | jobs: 18 | lint: 19 | runs-on: ${{ matrix.os }} 20 | strategy: 21 | matrix: 22 | os: [ubuntu-latest] 23 | python-version: [3.9] 24 | torch-version: ['torch>=1.9'] 25 | 26 | steps: 27 | 28 | # WITHOUT OPENMM 29 | - uses: actions/checkout@v2 30 | - name: Set up Python ${{ matrix.python-version }} 31 | uses: actions/setup-python@v2 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | - name: Install dependencies 35 | run: | 36 | python -m pip install --upgrade pip 37 | pip install flake8 pytest "${{ matrix.torch-version }}" numpy nflows torchdiffeq 38 | - name: Lint with flake8 39 | run: | 40 | # stop the build if there are Python syntax errors or undefined names 41 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 42 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 43 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.sw[a-z] 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | docs/api/generated 76 | docs/datasets 77 | docs/examples 78 | docs/nb_examples 79 | 80 | 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2021 Jonas Köhler, Andreas Krämer, Manuel Dibak, Leon Klein, Frank Noé 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /bgflow/__init__.py: -------------------------------------------------------------------------------- 1 | """Boltzmann Generators and Normalizing Flows in PyTorch""" 2 | 3 | # Handle versioneer 4 | from ._version import get_versions 5 | versions = get_versions() 6 | __version__ = versions['version'] 7 | __git_revision__ = versions['full-revisionid'] 8 | del get_versions, versions 9 | 10 | from .distribution import * 11 | from .nn import * 12 | from .factory import * 13 | from .bg import * 14 | -------------------------------------------------------------------------------- /bgflow/distribution/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | =============================================================================== 4 | Samplers 5 | =============================================================================== 6 | 7 | .. autosummary:: 8 | :toctree: generated/ 9 | :template: class.rst 10 | 11 | Sampler 12 | DataSetSampler 13 | GaussianMCMCSampler 14 | 15 | =============================================================================== 16 | Distributions 17 | =============================================================================== 18 | 19 | .. 
autosummary:: 20 | :toctree: generated/ 21 | :template: class.rst 22 | 23 | TorchDistribution 24 | CustomDistribution 25 | UniformDistribution 26 | MixtureDistribution 27 | NormalDistribution 28 | TruncatedNormalDistribution 29 | MeanFreeNormalDistribution 30 | ProductEnergy 31 | ProductSampler 32 | ProductDistribution 33 | 34 | """ 35 | 36 | from .distributions import * 37 | from .energy import * 38 | from .sampling import * 39 | from .normal import * 40 | from .mixture import * 41 | from .product import * 42 | -------------------------------------------------------------------------------- /bgflow/distribution/energy/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. currentmodule: bgflow.distribution.energy 3 | 4 | =============================================================================== 5 | Double Well Potential 6 | =============================================================================== 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | :template: class.rst 11 | 12 | DoubleWellEnergy 13 | MultiDoubleWellPotential 14 | 15 | =============================================================================== 16 | Lennard Jones Potential 17 | =============================================================================== 18 | 19 | .. autosummary:: 20 | :toctree: generated/ 21 | :template: class.rst 22 | 23 | LennardJonesPotential 24 | 25 | =============================================================================== 26 | OpenMMBridge 27 | =============================================================================== 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | :template: class.rst 32 | 33 | OpenMMBridge 34 | OpenMMEnergy 35 | 36 | =============================================================================== 37 | Particle Box 38 | =============================================================================== 39 | 40 | .. 
autosummary:: 41 | :toctree: generated/ 42 | :template: class.rst 43 | 44 | RepulsiveParticles 45 | HarmonicParticles 46 | 47 | =============================================================================== 48 | Clipped Energies 49 | =============================================================================== 50 | 51 | .. autosummary:: 52 | :toctree: generated/ 53 | :template: class.rst 54 | 55 | LinLogCutEnergy 56 | GradientClippedEnergy 57 | 58 | =============================================================================== 59 | Base 60 | =============================================================================== 61 | 62 | .. autosummary:: 63 | :toctree: generated/ 64 | :template: class.rst 65 | 66 | Energy 67 | 68 | """ 69 | 70 | 71 | from .base import * 72 | from .double_well import * 73 | from .particles import * 74 | from .lennard_jones import * 75 | from .openmm import * 76 | from .multi_double_well_potential import * 77 | from .clipped import * 78 | from .openmm import * 79 | from .xtb import * 80 | from .ase import * 81 | -------------------------------------------------------------------------------- /bgflow/distribution/energy/ase.py: -------------------------------------------------------------------------------- 1 | """Wrapper around ASE (atomic simulation environment) 2 | """ 3 | __all__ = ["ASEBridge", "ASEEnergy"] 4 | 5 | 6 | import warnings 7 | import torch 8 | import numpy as np 9 | from .base import _BridgeEnergy, _Bridge 10 | 11 | 12 | class ASEBridge(_Bridge): 13 | """Wrapper around Atomic Simulation Environment. 14 | 15 | Parameters 16 | ---------- 17 | atoms : ase.Atoms 18 | An `Atoms` object that has a calculator attached to it. 19 | temperature : float 20 | Temperature in Kelvin. 21 | err_handling : str 22 | How to deal with exceptions inside ase. One of `["ignore", "warning", "error"]` 23 | 24 | Notes 25 | ----- 26 | Requires the ase package (installable with `conda install -c conda-forge ase`). 
27 | 28 | """ 29 | def __init__( 30 | self, 31 | atoms, 32 | temperature: float, 33 | err_handling: str = "warning" 34 | ): 35 | super().__init__() 36 | assert hasattr(atoms, "calc") 37 | self.atoms = atoms 38 | self.temperature = temperature 39 | self.err_handling = err_handling 40 | 41 | @property 42 | def n_atoms(self): 43 | return len(self.atoms) 44 | 45 | def _evaluate_single( 46 | self, 47 | positions: torch.Tensor, 48 | evaluate_force=True, 49 | evaluate_energy=True, 50 | ): 51 | from ase.units import kB, nm 52 | kbt = kB * self.temperature 53 | energy, force = None, None 54 | try: 55 | self.atoms.positions = positions * nm 56 | if evaluate_energy: 57 | energy = self.atoms.get_potential_energy() / kbt 58 | if evaluate_force: 59 | force = self.atoms.get_forces() / (kbt / nm) 60 | assert not np.isnan(energy) 61 | assert not np.isnan(force).any() 62 | except AssertionError as e: 63 | force[np.isnan(force)] = 0. 64 | energy = np.infty 65 | if self.err_handling == "warning": 66 | warnings.warn("Found nan in ase force or energy. Returning infinite energy and zero force.") 67 | elif self.err_handling == "error": 68 | raise e 69 | return energy, force 70 | 71 | 72 | class ASEEnergy(_BridgeEnergy): 73 | """Energy computation with calculators from the atomic simulation environment (ASE). 74 | Various molecular simulation programs provide wrappers for ASE, 75 | see https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html 76 | for a list of available calculators. 77 | 78 | Examples 79 | -------- 80 | Use the calculator from the xtb package to compute the energy of a water molecule with the GFN2-xTB method. 
81 | >>> from ase.build import molecule 82 | >>> from xtb.ase.calculator import XTB 83 | >>> water = molecule("H2O") 84 | >>> water.calc = XTB() 85 | >>> target = ASEEnergy(ASEBridge(water, 300.)) 86 | >>> pos = torch.tensor(0.1*water.positions, **ctx) 87 | >>> energy = target.energy(pos) 88 | 89 | Parameters 90 | ---------- 91 | ase_bridge : ASEBridge 92 | The wrapper object. 93 | two_event_dims : bool 94 | Whether to use two event dimensions. 95 | In this case, the energy call expects positions of shape (*batch_shape, n_atoms, 3). 96 | Otherwise, it expects positions of shape (*batch_shape, n_atoms * 3). 97 | """ 98 | pass 99 | -------------------------------------------------------------------------------- /bgflow/distribution/energy/clipped.py: -------------------------------------------------------------------------------- 1 | from ...utils.train import linlogcut, ClipGradient 2 | from .base import Energy 3 | 4 | 5 | __all__ = ["LinLogCutEnergy", "GradientClippedEnergy"] 6 | 7 | 8 | class LinLogCutEnergy(Energy): 9 | """Cut off energy at singularities. 10 | 11 | Parameters 12 | ---------- 13 | energy : Energy 14 | high_energy : float 15 | Energies beyond this value are replaced by `u = high_energy + log(1 + energy - high_energy)` 16 | max_energy : float 17 | Upper bound for energies returned by this object. 18 | """ 19 | def __init__(self, energy, high_energy=1e3, max_energy=1e9): 20 | super().__init__(energy.event_shapes) 21 | self.delegate = energy 22 | self.high_energy = high_energy 23 | self.max_energy = max_energy 24 | 25 | def _energy(self, *xs, **kwargs): 26 | u = self.delegate.energy(*xs, **kwargs) 27 | return linlogcut(u, high_val=self.high_energy, max_val=self.max_energy) 28 | 29 | 30 | class GradientClippedEnergy(Energy): 31 | """An Energy with clipped gradients. 
See `ClipGradient` for details.""" 32 | def __init__(self, energy: Energy, gradient_clipping: ClipGradient): 33 | super().__init__(energy.event_shapes) 34 | self.delegate = energy 35 | self.clipping = gradient_clipping 36 | 37 | def _energy(self, *xs, **kwargs): 38 | return self.delegate.energy(*((self.clipping(x) for x in xs)), **kwargs) 39 | -------------------------------------------------------------------------------- /bgflow/distribution/energy/double_well.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base import Energy 4 | from numpy import pi as PI 5 | 6 | 7 | __all__ = ["DoubleWellEnergy", "MultiDimensionalDoubleWell", "MuellerEnergy", "ModifiedWolfeQuapp"] 8 | 9 | 10 | class DoubleWellEnergy(Energy): 11 | def __init__(self, dim, a=0, b=-4.0, c=1.0): 12 | super().__init__(dim) 13 | self._a = a 14 | self._b = b 15 | self._c = c 16 | 17 | def _energy(self, x): 18 | d = x[..., [0]] 19 | v = x[..., 1:] 20 | e1 = self._a * d + self._b * d.pow(2) + self._c * d.pow(4) 21 | e2 = 0.5 * v.pow(2).sum(dim=-1, keepdim=True) 22 | return e1 + e2 23 | 24 | 25 | class MultiDimensionalDoubleWell(Energy): 26 | def __init__(self, dim, a=0.0, b=-4.0, c=1.0, transformer=None): 27 | super().__init__(dim) 28 | if not isinstance(a, torch.Tensor): 29 | a = torch.tensor(a) 30 | if not isinstance(b, torch.Tensor): 31 | b = torch.tensor(b) 32 | if not isinstance(c, torch.Tensor): 33 | c = torch.tensor(c) 34 | self.register_buffer("_a", a) 35 | self.register_buffer("_b", b) 36 | self.register_buffer("_c", c) 37 | if transformer is not None: 38 | self.register_buffer("_transformer", transformer) 39 | else: 40 | self._transformer = None 41 | 42 | def _energy(self, x): 43 | if self._transformer is not None: 44 | x = torch.matmul(x, self._transformer) 45 | e1 = self._a * x + self._b * x.pow(2) + self._c * x.pow(4) 46 | return e1.sum(dim=1, keepdim=True) 47 | 48 | 49 | class MuellerEnergy(Energy): 50 | def __init__(self, 
dim=2, scale1=0.15, scale2=15, beta=1.): 51 | super().__init__(dim) 52 | assert dim >= 2 53 | self._scale1 = scale1 54 | self._scale2 = scale2 55 | self._beta = beta 56 | 57 | def _energy(self, x): 58 | xx = x[..., [0]] 59 | yy = x[..., [1]] 60 | e1 = -200 * torch.exp(-(xx -1).pow(2) -10 * yy.pow(2)) 61 | e2 = -100 * torch.exp(-xx.pow(2) -10 * (yy -0.5).pow(2)) 62 | e3 = -170 * torch.exp(-6.5 * (0.5 + xx).pow(2) +11 * (xx +0.5) * (yy -1.5) -6.5 * (yy -1.5).pow(2)) 63 | e4 = 15.0 * torch.exp(0.7 * (1 +xx).pow(2) +0.6 * (xx +1) * (yy -1) +0.7 * (yy -1).pow(2)) +146.7 64 | v = x[..., 2:] 65 | ev = self._scale2 * 0.5 * v.pow(2).sum(dim=-1, keepdim=True) 66 | return self._beta * (self._scale1 * (e1 + e2 + e3 + e4) + ev) 67 | 68 | @property 69 | def potential_str(self): 70 | pot_str = f'{self._scale1:g}*(-200*exp(-(x-1)^2-10*y^2)-100*exp(-x^2-10*(y-0.5)^2)-170*exp(-6.5*(0.5+x)^2+11*(x+0.5)*(y-1.5)-6.5*(y-1.5)^2)+15*exp(0.7*(1+x)^2+0.6*(x+1)*(y-1)+0.7*(y-1)^2)+146.7)' 71 | if self.dim >= 3: 72 | pot_str += f'+{self._scale2:g}*0.5*z^2' 73 | return pot_str 74 | 75 | class ModifiedWolfeQuapp(Energy): 76 | def __init__(self, dim=2, theta=-0.3*PI/2, scale1=2, scale2=15, beta=1.): 77 | super().__init__(dim) 78 | assert dim >= 2 79 | self._scale1 = scale1 80 | self._scale2 = scale2 81 | self._beta = beta 82 | self._c = torch.cos(torch.as_tensor(theta)) 83 | self._s = torch.sin(torch.as_tensor(theta)) 84 | 85 | def _energy(self, x): 86 | xx = self._c * x[..., [0]] - self._s * x[..., [1]] 87 | yy = self._s * x[..., [0]] + self._c * x[..., [1]] 88 | e4 = xx.pow(4) + yy.pow(4) 89 | e2 = -2 * xx.pow(2) - 4 * yy.pow(2) + 2 * xx * yy 90 | e1 = 0.8 * xx + 0.1 * yy + 9.28 91 | v = x[..., 2:] 92 | ev = self._scale2 * 0.5 * v.pow(2).sum(dim=-1, keepdim=True) 93 | return self._beta * (self._scale1 * (e4 + e2 + e1) + ev) 94 | 95 | @property 96 | def potential_str(self): 97 | x_str = f'({self._c:g}*x-{self._s:g}*y)' 98 | y_str = f'({self._s:g}*x+{self._c:g}*y)' 99 | pot_str = 
f'{self._scale1:g}*({x_str}^4+{y_str}^4-2*{x_str}^2-4*{y_str}^2+2*{x_str}*{y_str}+0.8*{x_str}+0.1*{y_str}+9.28)' 100 | if self.dim >= 3: 101 | pot_str += f'+{self._scale2:g}*0.5*z^2' 102 | return pot_str 103 | -------------------------------------------------------------------------------- /bgflow/distribution/energy/lennard_jones.py: -------------------------------------------------------------------------------- 1 | from .base import Energy 2 | from bgflow.utils import distance_vectors, distances_from_vectors 3 | import torch 4 | 5 | 6 | __all__ = ["LennardJonesPotential"] 7 | 8 | 9 | def lennard_jones_energy_torch(r, eps=1.0, rm=1.0): 10 | lj = eps * ((rm / r) ** 12 - 2 * (rm / r) ** 6) 11 | return lj 12 | 13 | 14 | class LennardJonesPotential(Energy): 15 | def __init__( 16 | self, dim, n_particles, eps=1.0, rm=1.0, oscillator=True, oscillator_scale=1., two_event_dims=True): 17 | """Energy for a Lennard-Jones cluster 18 | 19 | Parameters 20 | ---------- 21 | dim : int 22 | Number of degrees of freedom ( = space dimension x n_particles) 23 | n_particles : int 24 | Number of Lennard-Jones particles 25 | eps : float 26 | LJ well depth epsilon 27 | rm : float 28 | LJ well radius R_min 29 | oscillator : bool 30 | Whether to use a harmonic oscillator as an external force 31 | oscillator_scale : float 32 | Force constant of the harmonic oscillator energy 33 | two_event_dims : bool 34 | If True, the energy expects inputs with two event dimensions (particle_id, coordinate). 35 | Else, use only one event dimension. 
36 | """ 37 | if two_event_dims: 38 | super().__init__([n_particles, dim//n_particles]) 39 | else: 40 | super().__init__(dim) 41 | self._n_particles = n_particles 42 | self._n_dims = dim // n_particles 43 | 44 | self._eps = eps 45 | self._rm = rm 46 | self.oscillator = oscillator 47 | self._oscillator_scale = oscillator_scale 48 | 49 | def _energy(self, x): 50 | batch_shape = x.shape[:-len(self.event_shape)] 51 | x = x.view(*batch_shape, self._n_particles, self._n_dims) 52 | 53 | dists = distances_from_vectors( 54 | distance_vectors(x.view(-1, self._n_particles, self._n_dims)) 55 | ) 56 | 57 | lj_energies = lennard_jones_energy_torch(dists, self._eps, self._rm) 58 | lj_energies = lj_energies.view(*batch_shape, -1).sum(dim=-1) / 2 59 | 60 | if self.oscillator: 61 | osc_energies = 0.5 * self._remove_mean(x).pow(2).sum(dim=(-2, -1)).view(*batch_shape) 62 | lj_energies = lj_energies + osc_energies * self._oscillator_scale 63 | 64 | return lj_energies[:, None] 65 | 66 | def _remove_mean(self, x): 67 | x = x.view(-1, self._n_particles, self._n_dims) 68 | return x - torch.mean(x, dim=1, keepdim=True) 69 | 70 | def _energy_numpy(self, x): 71 | x = torch.Tensor(x) 72 | return self._energy(x).cpu().numpy() 73 | -------------------------------------------------------------------------------- /bgflow/distribution/energy/multi_double_well_potential.py: -------------------------------------------------------------------------------- 1 | from .base import Energy 2 | from bgflow.utils import compute_distances 3 | 4 | __all__ = ["MultiDoubleWellPotential"] 5 | 6 | 7 | class MultiDoubleWellPotential(Energy): 8 | """Energy for a many particle system with pair wise double-well interactions. 9 | The energy of the double-well is given via 10 | 11 | .. math:: 12 | E_{DW}(d) = a \cdot (d-d_{\text{offset})^4 + b \cdot (d-d_{\text{offset})^2 + c. 
13 | 14 | Parameters 15 | ---------- 16 | dim : int 17 | Number of degrees of freedom ( = space dimension x n_particles) 18 | n_particles : int 19 | Number of particles 20 | a, b, c, offset : float 21 | parameters of the potential 22 | """ 23 | 24 | def __init__(self, dim, n_particles, a, b, c, offset, two_event_dims=True): 25 | if two_event_dims: 26 | super().__init__([n_particles, dim // n_particles]) 27 | else: 28 | super().__init__(dim) 29 | self._dim = dim 30 | self._n_particles = n_particles 31 | self._n_dimensions = dim // n_particles 32 | self._a = a 33 | self._b = b 34 | self._c = c 35 | self._offset = offset 36 | 37 | def _energy(self, x): 38 | x = x.contiguous() 39 | dists = compute_distances(x, self._n_particles, self._n_dimensions) 40 | dists = dists - self._offset 41 | 42 | energies = self._a * dists ** 4 + self._b * dists ** 2 + self._c 43 | return energies.sum(-1, keepdim=True) 44 | -------------------------------------------------------------------------------- /bgflow/distribution/mixture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .energy import Energy 5 | from .sampling import Sampler 6 | from ..utils.types import assert_numpy 7 | 8 | 9 | __all__ = ["MixtureDistribution"] 10 | 11 | 12 | class MixtureDistribution(Energy, Sampler): 13 | def __init__(self, components, unnormed_log_weights=None, trainable_weights=False): 14 | assert all([c.dim == components[0].dim for c in components]),\ 15 | "All mixture components must have the same dimensionality." 16 | super().__init__(components[0].dim) 17 | self._components = torch.nn.ModuleList(components) 18 | 19 | if unnormed_log_weights is None: 20 | unnormed_log_weights = torch.zeros(len(components)) 21 | else: 22 | assert len(unnormed_log_weights.shape) == 1,\ 23 | "Mixture weights must be a Tensor of shape `[n_components]`." 
24 | assert len(unnormed_log_weights) == len(components),\ 25 | "Number of mixture weights does not match number of components." 26 | if trainable_weights: 27 | self._unnormed_log_weights = torch.nn.Parameter(unnormed_log_weights) 28 | else: 29 | self.register_buffer("_unnormed_log_weights", unnormed_log_weights) 30 | 31 | @property 32 | def _log_weights(self): 33 | return torch.log_softmax(self._unnormed_log_weights, dim=-1) 34 | 35 | def _sample(self, n_samples): 36 | weights_numpy = assert_numpy(self._log_weights.exp()) 37 | ns = np.random.multinomial(n_samples, weights_numpy, 1)[0] 38 | samples = [c.sample(n) for n, c in zip(ns, self._components)] 39 | return torch.cat(samples, dim=0) 40 | 41 | def _energy(self, x): 42 | energies = torch.stack([c.energy(x) for c in self._components], dim=-1) 43 | return -torch.logsumexp(-energies + self._log_weights.view(1, 1, -1), dim=-1) 44 | 45 | def _log_assignments(self, x): 46 | energies = torch.stack([c.energy(x) for c in self._components], dim=-1) 47 | return -energies -------------------------------------------------------------------------------- /bgflow/distribution/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .mcmc import * 3 | from .dataset import * 4 | from .buffer import * 5 | from .iterative import * -------------------------------------------------------------------------------- /bgflow/distribution/sampling/_iterative_helpers.py: -------------------------------------------------------------------------------- 1 | """Helper classes and functions for iterative samplers.""" 2 | 3 | import torch 4 | 5 | 6 | __all__ = ["AbstractSamplerState", "default_set_samples_hook", "default_extract_sample_hook"] 7 | 8 | 9 | class AbstractSamplerState: 10 | """Defines the interface for implementations of the internal state of iterative samplers.""" 11 | 12 | def as_dict(self): 13 | """Return a dictionary representing this instance. 
The dictionary has to define the 14 | keys that are used within `SamplerStep`s of an `IterativeSampler`, such as "samples", "energies", ... 15 | """ 16 | raise NotImplementedError() 17 | 18 | def _replace(self, **kwargs): 19 | """Return a new object with changed fields. 20 | This function has to support all the keys that are used 21 | within `SamplerStep`s of an `IterativeSampler` as well as the keys "energies_up_to_date" and 22 | "forces_up_to_date" 23 | """ 24 | raise NotImplementedError() 25 | 26 | def evaluate_energy_force(self, energy_model, evaluate_energies=True, evaluate_forces=True): 27 | """Return a new state with updated energies/forces.""" 28 | state = self.as_dict() 29 | evaluate_energies = evaluate_energies and not state["energies_up_to_date"] 30 | energies = energy_model.energy(*state["samples"])[..., 0] if evaluate_energies else state["energies"] 31 | 32 | evaluate_forces = evaluate_forces and not state["forces_up_to_date"] 33 | forces = energy_model.force(*state["samples"]) if evaluate_forces else state["forces"] 34 | return self.replace(energies=energies, forces=forces) 35 | 36 | def replace(self, **kwargs): 37 | """Return a new state with updated fields.""" 38 | 39 | # keep track of energies and forces 40 | state_dict = self.as_dict() 41 | if "energies" in kwargs: 42 | kwargs = {**kwargs, "energies_up_to_date": True} 43 | elif "samples" in kwargs: 44 | kwargs = {**kwargs, "energies_up_to_date": False} 45 | if "forces" in kwargs: 46 | kwargs = {**kwargs, "forces_up_to_date": True} 47 | elif "samples" in kwargs: 48 | kwargs = {**kwargs, "forces_up_to_date": False} 49 | 50 | # map to primary unit cell 51 | box_vectors = None 52 | if "box_vectors" in kwargs: 53 | box_vectors = kwargs["box_vectors"] 54 | elif "box_vectors" in state_dict: 55 | box_vectors = state_dict["box_vectors"] 56 | if "samples" in kwargs and box_vectors is not None: 57 | kwargs = { 58 | **kwargs, 59 | "samples": tuple( 60 | _map_to_primary_cell(x, cell) 61 | for x, cell in 
zip(kwargs["samples"], box_vectors) 62 | ) 63 | } 64 | return self._replace(**kwargs) 65 | 66 | 67 | def default_set_samples_hook(x): 68 | """by default, use samples as is""" 69 | return x 70 | 71 | 72 | def default_extract_sample_hook(state: AbstractSamplerState): 73 | """Default extraction of samples from a SamplerState.""" 74 | return state.as_dict()["samples"] 75 | 76 | 77 | def _bmv(m, bv): 78 | """Batched matrix-vector multiply.""" 79 | return torch.einsum("ij,...j->...i", m, bv) 80 | 81 | 82 | def _map_to_primary_cell(x, cell): 83 | """Map coordinates to the primary unit cell of a periodic lattice. 84 | 85 | Parameters 86 | ---------- 87 | x : torch.Tensor 88 | n-dimensional coordinates of shape (..., n), where n is the spatial dimension and ... denote an 89 | arbitrary number of batch dimensions. 90 | cell : torch.Tensor 91 | Lattice vectors (column-wise). Has to be upper triangular. 92 | """ 93 | if cell is None: 94 | return x 95 | n = _bmv(torch.inverse(cell), x) 96 | n = torch.floor(n) 97 | return x - _bmv(cell, n) 98 | -------------------------------------------------------------------------------- /bgflow/distribution/sampling/_mcmc/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = "noe" 2 | 3 | from .analysis import * 4 | from .latent_sampling import * 5 | from .metropolis import * 6 | from .permutation import * 7 | from .umbrella_sampling import * 8 | -------------------------------------------------------------------------------- /bgflow/distribution/sampling/_mcmc/permutation.py: -------------------------------------------------------------------------------- 1 | __author__ = "noe" 2 | 3 | import numpy as np 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | from deep_boltzmann.util import ensure_traj, distance_matrix_squared 7 | 8 | 9 | class HungarianMapper: 10 | def __init__(self, xref, dim=2, identical_particles=None): 11 | """ Permutes identical particles to minimize 
distance to reference structure. 12 | 13 | For a given structure or set of structures finds the permutation of identical particles 14 | that minimizes the mean square distance to a given reference structure. The optimization 15 | is done by solving the linear sum assignment problem with the Hungarian algorithm. 16 | 17 | Parameters 18 | ---------- 19 | xref : array 20 | reference structure 21 | dim : int 22 | number of dimensions of particle system to define relation between vector position and 23 | particle index. If dim=2, coordinate vectors are [x1, y1, x2, y2, ...]. 24 | indentical_particles : None or array 25 | indices of particles subject to permutation. If None, all particles are used 26 | 27 | """ 28 | self.xref = xref 29 | self.dim = dim 30 | if identical_particles is None: 31 | identical_particles = np.arange(xref.size) 32 | self.identical_particles = identical_particles 33 | self.ip_indices = np.concatenate( 34 | [dim * self.identical_particles + i for i in range(dim)] 35 | ) 36 | self.ip_indices.sort() 37 | 38 | def map(self, X): 39 | """ Maps X (configuration or trajectory) to reference structure by permuting identical particles """ 40 | X = ensure_traj(X) 41 | Y = X.copy() 42 | C = distance_matrix_squared( 43 | np.tile(self.xref[:, self.ip_indices], (X.shape[0], 1)), 44 | X[:, self.ip_indices], 45 | ) 46 | 47 | for i in range(C.shape[0]): # for each configuration 48 | _, col_assignment = linear_sum_assignment(C[i]) 49 | assignment_components = [ 50 | self.dim * col_assignment + i for i in range(self.dim) 51 | ] 52 | col_assignment = np.vstack(assignment_components).T.flatten() 53 | Y[i, self.ip_indices] = X[i, self.ip_indices[col_assignment]] 54 | return Y 55 | 56 | def is_permuted(self, X): 57 | """ Returns True for permuted configurations """ 58 | X = ensure_traj(X) 59 | C = distance_matrix_squared( 60 | np.tile(self.xref[:, self.ip_indices], (X.shape[0], 1)), 61 | X[:, self.ip_indices], 62 | ) 63 | isP = np.zeros(X.shape[0], dtype=bool) 64 | 65 | 
for i in range(C.shape[0]): # for each configuration 66 | _, col_assignment = linear_sum_assignment(C[i]) 67 | assignment_components = [ 68 | self.dim * col_assignment + i for i in range(self.dim) 69 | ] 70 | col_assignment = np.vstack(assignment_components).T.flatten() 71 | if not np.all(col_assignment == np.arange(col_assignment.size)): 72 | isP[i] = True 73 | return isP 74 | -------------------------------------------------------------------------------- /bgflow/distribution/sampling/base.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Tuple 3 | import torch 4 | from ...utils.types import unpack_tensor_tuple, pack_tensor_in_list 5 | 6 | __all__ = ["Sampler"] 7 | 8 | 9 | class Sampler(torch.nn.Module): 10 | """Abstract base class for samplers. 11 | 12 | Parameters 13 | ---------- 14 | return_hook : Callable, optional 15 | A function to postprocess the samples. This can (for example) be used to 16 | only return samples at a selected thermodynamic state of a replica exchange sampler 17 | or to combine the batch and sample dimension. 18 | The function takes a list of tensors and should return a list of tensors. 19 | Each tensor contains a batch of samples. 20 | """ 21 | 22 | def __init__(self, return_hook=lambda x: x, **kwargs): 23 | super().__init__(**kwargs) 24 | self.return_hook = return_hook 25 | 26 | def _sample_with_temperature(self, n_samples, temperature, *args, **kwargs): 27 | raise NotImplementedError() 28 | 29 | def _sample(self, n_samples, *args, **kwargs): 30 | raise NotImplementedError() 31 | 32 | def sample(self, n_samples, temperature=1.0, *args, **kwargs): 33 | """Create a number of samples. 34 | 35 | Parameters 36 | ---------- 37 | n_samples : int 38 | The number of samples to be created. 39 | temperature : float, optional 40 | The relative temperature at which to create samples. 41 | Only available for sampler that implement `_sample_with_temperature`. 
42 | 43 | Returns 44 | ------- 45 | samples : Union[torch.Tensor, Tuple[torch.Tensor, ...]] 46 | If this sampler reflects a joint distribution of multiple tensors, 47 | it returns a tuple of tensors, each of which have length n_samples. 48 | Otherwise it returns a single tensor of length n_samples. 49 | """ 50 | if isinstance(temperature, float) and temperature == 1.0: 51 | samples = self._sample(n_samples, *args, **kwargs) 52 | else: 53 | samples = self._sample_with_temperature(n_samples, temperature, *args, **kwargs) 54 | samples = pack_tensor_in_list(samples) 55 | return unpack_tensor_tuple(self.return_hook(samples)) 56 | 57 | def sample_to_cpu(self, n_samples, batch_size=64, *args, **kwargs): 58 | """A utility method for creating many samples that might not fit into GPU memory.""" 59 | with torch.no_grad(): 60 | samples = self.sample(min(n_samples, batch_size), *args, **kwargs) 61 | samples = pack_tensor_in_list(samples) 62 | samples = [tensor.detach().cpu() for tensor in samples] 63 | while len(samples[0]) < n_samples: 64 | new_samples = self.sample(min(n_samples-len(samples[0]), batch_size), *args, **kwargs) 65 | new_samples = pack_tensor_in_list(new_samples) 66 | for i, new in enumerate(new_samples): 67 | samples[i] = torch.cat([samples[i], new.detach().cpu()], dim=0) 68 | return unpack_tensor_tuple(samples) 69 | -------------------------------------------------------------------------------- /bgflow/factory/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .tensor_info import * 4 | from .transformer_factory import * 5 | from .conditioner_factory import * 6 | from .distribution_factory import * 7 | from .icmarginals import * 8 | from .generator_builder import * 9 | -------------------------------------------------------------------------------- /bgflow/factory/distribution_factory.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import 
bgflow as bg 4 | 5 | 6 | __all__ = ["make_distribution"] 7 | 8 | # === Prior Factory === 9 | 10 | 11 | def make_distribution(distribution_type, shape, **kwargs): 12 | factory = DISTRIBUTION_FACTORIES[distribution_type] 13 | return factory(shape=shape, **kwargs) 14 | 15 | 16 | def _make_uniform_distribution(shape, device=None, dtype=None, **kwargs): 17 | defaults = { 18 | "low": torch.zeros(shape), 19 | "high": torch.ones(shape) 20 | } 21 | defaults.update(kwargs) 22 | for key in defaults: 23 | if isinstance(defaults[key], torch.Tensor): 24 | defaults[key] = defaults[key].to(device=device, dtype=dtype) 25 | return bg.UniformDistribution(**defaults) 26 | 27 | 28 | def _make_normal_distribution(shape, device=None, dtype=None, **kwargs): 29 | defaults = { 30 | "dim": shape, 31 | "mean": torch.zeros(shape), 32 | } 33 | defaults.update(kwargs) 34 | for key in defaults: 35 | if isinstance(defaults[key], torch.Tensor): 36 | defaults[key] = defaults[key].to(device=device, dtype=dtype) 37 | return bg.NormalDistribution(**defaults) 38 | 39 | 40 | def _make_truncated_normal_distribution(shape, device=None, dtype=None, **kwargs): 41 | defaults = { 42 | "mu": torch.zeros(shape), 43 | "sigma": torch.ones(shape), 44 | } 45 | defaults.update(kwargs) 46 | for key in defaults: 47 | if isinstance(defaults[key], torch.Tensor): 48 | defaults[key] = defaults[key].to(device=device, dtype=dtype) 49 | return bg.TruncatedNormalDistribution(**defaults) 50 | 51 | 52 | DISTRIBUTION_FACTORIES = { 53 | bg.UniformDistribution: _make_uniform_distribution, 54 | bg.NormalDistribution: _make_normal_distribution, 55 | bg.TruncatedNormalDistribution: _make_truncated_normal_distribution 56 | } 57 | 58 | -------------------------------------------------------------------------------- /bgflow/factory/transformer_factory.py: -------------------------------------------------------------------------------- 1 | """Factory for flow transformations.""" 2 | 3 | import torch 4 | from ..nn.flow.inverted import 
InverseFlow 5 | from ..nn.flow.transformer.affine import AffineTransformer 6 | from ..nn.flow.transformer.spline import ConditionalSplineTransformer 7 | 8 | __all__ = ["make_transformer"] 9 | 10 | 11 | def make_transformer(transformer_type, what, shape_info, conditioners, inverse=False, **kwargs): 12 | """Factory function. 13 | 14 | Parameters 15 | ---------- 16 | transformer_type : bgflow. 17 | """ 18 | factory = TRANSFORMER_FACTORIES[transformer_type] 19 | transformer = factory(what=what, shape_info=shape_info, conditioners=conditioners, **kwargs) 20 | if inverse: 21 | transformer = InverseFlow(transformer) 22 | return transformer 23 | 24 | 25 | def _make_spline_transformer(what, shape_info, conditioners, **kwargs): 26 | return ConditionalSplineTransformer( 27 | is_circular=shape_info.is_circular(what), 28 | **conditioners, 29 | **kwargs 30 | ) 31 | 32 | 33 | def _make_affine_transformer(what, shape_info, conditioners, **kwargs): 34 | if shape_info.dim_circular(what) not in [0, shape_info[what[0]][-1]]: 35 | raise NotImplementedError( 36 | "Circular affine transformers are currently " 37 | "not supported for partly circular indices." 
38 | ) 39 | return AffineTransformer( 40 | **conditioners, 41 | is_circular=shape_info.dim_circular(what) > 0, 42 | **kwargs 43 | ) 44 | 45 | # def _make_sigmoid_transformer( 46 | # what, 47 | # shape_info, 48 | # conditioners, 49 | # smoothness_type="type1", 50 | # zero_boundary_left=False, 51 | # zero_boundary_right=False, 52 | # **kwargs 53 | # ): 54 | # assert all(field.is_circular == what[0].is_circular for field in what) 55 | # transformer = bg.MixtureCDFTransformer( 56 | # compute_weights=conditioners["weights"], 57 | # compute_components=bg.AffineSigmoidComponents( 58 | # conditional_ramp=bg.SmoothRamp( 59 | # compute_alpha=conditioners["alphas"], 60 | # unimodal=True, 61 | # ramp_type=smoothness_type 62 | # ), 63 | # compute_params=conditioners["params"], 64 | # periodic=what[0].is_circular, 65 | # min_density=torch.tensor(1e-6), 66 | # log_sigma_bound=torch.tensor(1.), 67 | # zero_boundary_left=zero_boundary_left, 68 | # zero_boundary_right=zero_boundary_right, 69 | # **kwargs 70 | # ) 71 | # ) 72 | # transformer = bg.WrapCDFTransformerWithInverse( 73 | # transformer=transformer, 74 | # oracle=bg.GridInversion( #bg.BisectionRootFinder( 75 | # transformer=transformer, 76 | # compute_init_grid=lambda x, y: torch.linspace(0, 1, 100).view(-1, 1, 1).repeat(1, *y.shape).to(y) 77 | # #abs_tol=torch.tensor(1e-5), max_iters=max_iters, verbose=False, raise_exception=True 78 | # ) 79 | # ) 80 | # return transformer 81 | 82 | 83 | TRANSFORMER_FACTORIES = { 84 | ConditionalSplineTransformer: _make_spline_transformer, 85 | AffineTransformer: _make_affine_transformer, 86 | # MixtureCDFTransformer: _make_sigmoid_transformer 87 | } 88 | 89 | -------------------------------------------------------------------------------- /bgflow/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .dense import * 2 | from .periodic import * 3 | from .flow import * 4 | from .training import * 5 | 
-------------------------------------------------------------------------------- /bgflow/nn/dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..utils.types import is_list_or_tuple 4 | 5 | 6 | __all__ = ["DenseNet", "MeanFreeDenseNet"] 7 | 8 | 9 | class DenseNet(torch.nn.Module): 10 | def __init__(self, n_units, activation=None, weight_scale=1.0, bias_scale=0.0): 11 | """ 12 | Simple multi-layer perceptron. 13 | 14 | Parameters: 15 | ----------- 16 | n_units : List / Tuple of integers. 17 | activation : Non-linearity or List / Tuple of non-linearities. 18 | If List / Tuple then each nonlinearity will be placed after each respective hidden layer. 19 | If just a single non-linearity, will be applied to all hidden layers. 20 | If set to None no non-linearity will be applied. 21 | """ 22 | super().__init__() 23 | 24 | dims_in = n_units[:-1] 25 | dims_out = n_units[1:] 26 | 27 | if is_list_or_tuple(activation): 28 | assert len(activation) == len(n_units) - 2 29 | 30 | layers = [] 31 | for i, (dim_in, dim_out) in enumerate(zip(dims_in, dims_out)): 32 | layers.append(torch.nn.Linear(dim_in, dim_out)) 33 | layers[-1].weight.data *= weight_scale 34 | if bias_scale > 0.0: 35 | layers[-1].bias.data = ( 36 | torch.Tensor(layers[-1].bias.data).uniform_() * bias_scale 37 | ) 38 | if i < len(n_units) - 2: 39 | if activation is not None: 40 | if is_list_or_tuple(activation): 41 | layers.append(activation[i]) 42 | else: 43 | layers.append(activation) 44 | 45 | self._layers = torch.nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | return self._layers(x) 49 | 50 | 51 | class MeanFreeDenseNet(DenseNet): 52 | def forward(self, x): 53 | y = self._layers(x) 54 | return y - y.mean(dim=1, keepdim=True) 55 | -------------------------------------------------------------------------------- /bgflow/nn/flow/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
currentmodule: bgflow.nn.flow 3 | 4 | =============================================================================== 5 | Coupling flows 6 | =============================================================================== 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | :template: class.rst 11 | 12 | CouplingFlow 13 | SplitFlow 14 | MergeFlow 15 | SwapFlow 16 | WrapFlow 17 | Transformer 18 | AffineTransformer 19 | TruncatedGaussianTransformer 20 | ConditionalSplineTransformer 21 | ScalingLayer 22 | EntropyScalingLayer 23 | 24 | =============================================================================== 25 | Continuous Normalizing Flows 26 | =============================================================================== 27 | 28 | .. autosummary:: 29 | :toctree: generated/ 30 | :template: class.rst 31 | 32 | DiffEqFlow 33 | 34 | Dynamics Functions 35 | --------------------- 36 | 37 | .. autosummary:: 38 | :toctree: generated/ 39 | :template: class.rst 40 | 41 | BlackBoxDynamics 42 | TimeIndependentDynamics 43 | KernelDynamics 44 | DensityDynamics 45 | 46 | 47 | Jacobian Trace Estimators 48 | ------------------------------ 49 | 50 | .. autosummary:: 51 | :toctree: generated/ 52 | :template: class.rst 53 | 54 | BruteForceEstimator 55 | HutchinsonEstimator 56 | 57 | =============================================================================== 58 | Stochastic Normalizing Flows 59 | =============================================================================== 60 | 61 | .. autosummary:: 62 | :toctree: generated/ 63 | :template: class.rst 64 | 65 | MetropolisMCFlow 66 | BrownianFlow 67 | LangevinFlow 68 | StochasticAugmentation 69 | OpenMMStochasticFlow 70 | PathProbabilityIntegrator 71 | BrownianPathProbabilityIntegrator 72 | 73 | =============================================================================== 74 | Internal Coordinate Transformations 75 | =============================================================================== 76 | 77 | .. 
autosummary:: 78 | :toctree: generated/ 79 | :template: class.rst 80 | 81 | RelativeInternalCoordinateTransformation 82 | GlobalInternalCoordinateTransformation 83 | MixedCoordinateTransformation 84 | WhitenFlow 85 | 86 | =============================================================================== 87 | CDF Transformations 88 | =============================================================================== 89 | 90 | .. autosummary:: 91 | :toctree: generated/ 92 | :template: class.rst 93 | 94 | CDFTransform 95 | DistributionTransferFlow 96 | ConstrainGaussianFlow 97 | 98 | =============================================================================== 99 | Base 100 | =============================================================================== 101 | 102 | .. autosummary:: 103 | :toctree: generated/ 104 | :template: class.rst 105 | 106 | Flow 107 | InverseFlow 108 | SequentialFlow 109 | 110 | =============================================================================== 111 | Other 112 | =============================================================================== 113 | Docs and/or classification required 114 | 115 | .. 
autosummary:: 116 | :toctree: generated/ 117 | :template: class.rst 118 | 119 | AffineFlow 120 | CheckerboardFlow 121 | BentIdentity 122 | FunnelFlow 123 | KroneckerProductFlow 124 | PseudoOrthogonalFlow 125 | InvertiblePPPP 126 | PPPPScheduler 127 | TorchTransform 128 | TriuFlow 129 | BNARFlow 130 | """ 131 | 132 | from .base import * 133 | from .crd_transform import * 134 | from .dynamics import * 135 | from .estimator import * 136 | from .stochastic import * 137 | from .transformer import * 138 | 139 | from .affine import * 140 | from .coupling import * 141 | from .funnel import FunnelFlow 142 | from .kronecker import KroneckerProductFlow 143 | from .sequential import SequentialFlow 144 | from .inverted import * 145 | from .checkerboard import CheckerboardFlow 146 | from .bnaf import BNARFlow 147 | from .elementwise import * 148 | from .orthogonal import * 149 | from .triangular import * 150 | from .pppp import * 151 | from .diffeq import DiffEqFlow 152 | from .cdf import * 153 | from .torchtransform import * 154 | from .modulo import * 155 | -------------------------------------------------------------------------------- /bgflow/nn/flow/affine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base import Flow 4 | 5 | 6 | __all__ = ["AffineFlow"] 7 | 8 | 9 | class AffineFlow(Flow): 10 | def __init__(self, n_dims, use_scaling=True, use_translation=True): 11 | super().__init__() 12 | self._n_dims = n_dims 13 | self._log_sigma = None 14 | if use_scaling: 15 | self._log_sigma = torch.nn.Parameter(torch.zeros(self._n_dims)) 16 | else: 17 | self._log_sigma = None 18 | if use_translation: 19 | self._mu = torch.nn.Parameter(torch.zeros(self._n_dims)) 20 | else: 21 | self._mu = None 22 | 23 | def _forward(self, x, **kwargs): 24 | assert x.shape[-1] == self._n_dims, "dimension `x` does not match `n_dims`" 25 | dlogp = torch.zeros(*x.shape[:-1], 1).to(x) 26 | if self._log_sigma is not None: 27 | sigma = 
torch.exp(self._log_sigma.to(x)) 28 | dlogp = dlogp + self._log_sigma.sum() 29 | x = sigma * x 30 | if self._mu is not None: 31 | x = x + self._mu.to(x) 32 | return x, dlogp 33 | 34 | def _inverse(self, x, **kwargs): 35 | assert x.shape[-1] == self._n_dims, "dimension `x` does not match `n_dims`" 36 | dlogp = torch.zeros(*x.shape[:-1], 1).to(x) 37 | if self._mu is not None: 38 | x = x - self._mu.to(x) 39 | if self._log_sigma is not None: 40 | sigma_inv = torch.exp(-self._log_sigma.to(x)) 41 | dlogp = dlogp - self._log_sigma.sum() 42 | x = sigma_inv * x 43 | return x, dlogp 44 | -------------------------------------------------------------------------------- /bgflow/nn/flow/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = ["Flow"] 5 | 6 | 7 | class Flow(torch.nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def _forward(self, *xs, **kwargs): 12 | raise NotImplementedError() 13 | 14 | def _inverse(self, *xs, **kwargs): 15 | raise NotImplementedError() 16 | 17 | def forward(self, *xs, inverse=False, **kwargs): 18 | """ 19 | Forward method of the flow. 20 | Computes the forward or inverse direction of the flow. 
21 | 22 | Parameters 23 | ---------- 24 | xs : torch.Tensor 25 | Input of the flow 26 | 27 | inverse : boolean 28 | Whether to compute the forward or inverse 29 | """ 30 | if inverse: 31 | return self._inverse(*xs, **kwargs) 32 | else: 33 | return self._forward(*xs, **kwargs) 34 | -------------------------------------------------------------------------------- /bgflow/nn/flow/checkerboard.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from itertools import product 4 | 5 | from .base import Flow 6 | 7 | 8 | def _make_checkerboard_idxs(sz): 9 | even = np.arange(sz, dtype=np.int64) % 2 10 | odd = 1 - even 11 | grid = np.arange(sz * sz, dtype=np.int64) 12 | idxs = [] 13 | for i, j in product([odd, even], repeat=2): 14 | mask = np.outer(i, j).astype(bool).reshape(-1) 15 | chunk = grid[mask] 16 | idxs.append(chunk) 17 | return np.concatenate(idxs) 18 | 19 | 20 | def _checkerboard_2x2_masks(sz): 21 | mask = _make_checkerboard_idxs(sz) 22 | inv_mask = np.argsort(mask) 23 | offset = sz ** 2 // 4 24 | sub_masks = [ 25 | mask[i * offset:(i+1) * offset] 26 | for i in range(4) 27 | ] 28 | return inv_mask, sub_masks 29 | 30 | 31 | class CheckerboardFlow(Flow): 32 | def __init__(self, size): 33 | super().__init__() 34 | self._size = size 35 | inv_mask, submasks = _checkerboard_2x2_masks(size) 36 | self.register_buffer("_sub_masks", torch.LongTensor(submasks)) 37 | self.register_buffer("_inv_mask", torch.LongTensor(inv_mask)) 38 | 39 | def _forward(self, *xs, **kwargs): 40 | n_batch = xs[0].shape[0] 41 | dlogp = torch.zeros(n_batch) 42 | sz = self._size // 2 43 | assert len(xs) == 1 44 | x = xs[0] 45 | assert len(x.shape) == 4 and x.shape[1] == self._size and x.shape[2] == self._size,\ 46 | "`x` needs to be of shape `[n_batch, size, size, n_filters]`" 47 | x = x.view(n_batch, self._size ** 2, -1) 48 | xs = [] 49 | for i in range(4): 50 | patch = x[:, self._sub_masks[i], :].view(n_batch, sz, sz, -1) 
51 | xs.append(patch) 52 | return (*xs, dlogp) 53 | return x, dlogp 54 | 55 | def _inverse(self, *xs, **kwargs): 56 | n_batch = xs[0].shape[0] 57 | dlogp = torch.zeros(n_batch) 58 | sz = self._size // 2 59 | assert len(xs) == 4 60 | assert all(x.shape[1] == self._size // 2 and x.shape[2] == self._size // 2 for x in xs),\ 61 | "all `xs` needs to be of shape `[n_batch, size, size, n_filters]`" 62 | xs = [x.view(n_batch, sz ** 2, -1) for x in xs] 63 | x = torch.cat(xs, axis=-2)[:, self._inv_mask, :].view( 64 | n_batch, self._size, self._size, -1 65 | ) 66 | return x, dlogp 67 | -------------------------------------------------------------------------------- /bgflow/nn/flow/crd_transform/__init__.py: -------------------------------------------------------------------------------- 1 | from .pca import * 2 | from .ic import * 3 | -------------------------------------------------------------------------------- /bgflow/nn/flow/crd_transform/pca.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from bgflow.nn.flow.base import Flow 4 | 5 | 6 | __all__ = ["WhitenFlow"] 7 | 8 | 9 | def _pca(X0, keepdims=None): 10 | """Implements PCA in Numpy. 
11 | 12 | This is not written for training in torch because derivatives of eig are not implemented 13 | 14 | """ 15 | if keepdims is None: 16 | keepdims = X0.shape[1] 17 | 18 | # pca 19 | X0mean = X0.mean(axis=0) 20 | X0meanfree = X0 - X0mean 21 | C = np.matmul(X0meanfree.T, X0meanfree) / (X0meanfree.shape[0] - 1.0) 22 | eigval, eigvec = np.linalg.eigh(C) 23 | 24 | # sort in descending order and keep only the wanted eigenpairs 25 | I = np.argsort(eigval)[::-1] 26 | I = I[:keepdims] 27 | eigval = eigval[I] 28 | std = np.sqrt(eigval) 29 | eigvec = eigvec[:, I] 30 | 31 | # whiten and unwhiten matrices 32 | Twhiten = np.matmul(eigvec, np.diag(1.0 / std)) 33 | Tblacken = np.matmul(np.diag(std), eigvec.T) 34 | return X0mean, Twhiten, Tblacken, std 35 | 36 | 37 | class WhitenFlow(Flow): 38 | def __init__(self, X0, keepdims=None, whiten_inverse=True): 39 | """Performs static whitening of the data given PCA of X0 40 | 41 | Parameters: 42 | ----------- 43 | X0 : array 44 | Initial Data on which PCA will be computed. 45 | keepdims : int or None 46 | Number of dimensions to keep. By default, all dimensions will be kept 47 | whiten_inverse : bool 48 | Whitens when calling inverse (default). 
Otherwise when calling forward 49 | 50 | """ 51 | super().__init__() 52 | if keepdims is None: 53 | keepdims = X0.shape[1] 54 | self.dim = X0.shape[1] 55 | self.keepdims = keepdims 56 | self.whiten_inverse = whiten_inverse 57 | 58 | X0_np = X0.detach().cpu().numpy() 59 | X0mean, Twhiten, Tblacken, std = _pca(X0_np, keepdims=keepdims) 60 | # self.X0mean = torch.tensor(X0mean) 61 | self.register_buffer("X0mean", torch.tensor(X0mean).to(X0)) 62 | # self.Twhiten = torch.tensor(Twhiten) 63 | self.register_buffer("Twhiten", torch.tensor(Twhiten).to(X0)) 64 | # self.Tblacken = torch.tensor(Tblacken) 65 | self.register_buffer("Tblacken", torch.tensor(Tblacken).to(X0)) 66 | # self.std = torch.tensor(std) 67 | self.register_buffer("std", torch.tensor(std).to(X0)) 68 | if torch.any(self.std <= 0): 69 | raise ValueError( 70 | "Cannot construct whiten layer because trying to keep nonpositive eigenvalues." 71 | ) 72 | self.jacobian_xz = -torch.sum(torch.log(self.std)) 73 | 74 | def _whiten(self, x): 75 | # Whiten 76 | output_z = torch.matmul(x - self.X0mean, self.Twhiten) 77 | # if self.keepdims < self.dim: 78 | # junk_dims = self.dim - self.keepdims 79 | # output_z = torch.cat([output_z, torch.Tensor(x.shape[0], junk_dims).normal_()], dim=1) 80 | # Jacobian 81 | dlogp = self.jacobian_xz * torch.ones((x.shape[0], 1)).to(x) 82 | 83 | return output_z, dlogp 84 | 85 | def _blacken(self, x): 86 | # if we have reduced the dimension, we ignore the last dimensions from the z-direction. 
87 | # if self.keepdims < self.dim: 88 | # x = x[:, 0:self.keepdims] 89 | output_x = torch.matmul(x, self.Tblacken) + self.X0mean 90 | # Jacobian 91 | dlogp = -self.jacobian_xz * torch.ones((x.shape[0], 1)).to(x) 92 | 93 | return output_x, dlogp 94 | 95 | def _forward(self, x, *args, **kwargs): 96 | if self.whiten_inverse: 97 | y, dlogp = self._blacken(x) 98 | else: 99 | y, dlogp = self._whiten(x) 100 | return y, dlogp 101 | 102 | def _inverse(self, x, *args, **kwargs): 103 | if self.whiten_inverse: 104 | y, dlogp = self._whiten(x) 105 | else: 106 | y, dlogp = self._blacken(x) 107 | return y, dlogp 108 | -------------------------------------------------------------------------------- /bgflow/nn/flow/diffeq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base import Flow 4 | from .dynamics import ( 5 | DensityDynamics, 6 | InversedDynamics, 7 | AnodeDynamics 8 | ) 9 | 10 | 11 | class DiffEqFlow(Flow): 12 | """ 13 | Neural Ordinary Differential Equations flow :footcite:`chen2018neural` 14 | with the choice of optimize than discretize (use_checkpoints=False) 15 | and discretize than optimize :footcite:`gholami2019anode` (use_checkpoints=True) for the ODE solver. 16 | 17 | References 18 | ---------- 19 | .. 
footbibliography:: 20 | 21 | """ 22 | 23 | def __init__( 24 | self, 25 | dynamics, 26 | integrator="dopri5", 27 | atol=1e-10, 28 | rtol=1e-5, 29 | n_time_steps=2, 30 | t_max=1., 31 | use_checkpoints=False, 32 | **kwargs 33 | ): 34 | super().__init__() 35 | self._dynamics = DensityDynamics(dynamics) 36 | self._inverse_dynamics = DensityDynamics(InversedDynamics(dynamics, t_max)) 37 | self._integrator_method = integrator 38 | self._integrator_atol = atol 39 | self._integrator_rtol = rtol 40 | self._n_time_steps = n_time_steps 41 | self._t_max = t_max 42 | self._use_checkpoints = use_checkpoints 43 | self._kwargs = kwargs 44 | 45 | def _forward(self, *xs, **kwargs): 46 | return self._run_ode(*xs, dynamics=self._dynamics, **kwargs) 47 | 48 | def _inverse(self, *xs, **kwargs): 49 | return self._run_ode(*xs, dynamics=self._inverse_dynamics, **kwargs) 50 | 51 | def _run_ode(self, *xs, dynamics, **kwargs): 52 | """ 53 | Runs the ODE solver. 54 | 55 | Parameters 56 | ---------- 57 | xs : PyTorch tensor 58 | The current configuration of the system 59 | dynamics : PyTorch module 60 | A dynamics function that computes the change of the system and its density. 61 | 62 | Returns 63 | ------- 64 | ys : PyTorch tensor 65 | The new configuration of the system after being propagated by the dynamics. 66 | dlogp : PyTorch tensor 67 | The change in log density due to the dynamics. 
68 | """ 69 | 70 | assert (all(x.shape[0] == xs[0].shape[0] for x in xs[1:])) 71 | n_batch = xs[0].shape[0] 72 | logp_init = torch.zeros(n_batch, 1).to(xs[0]) 73 | state = (*xs, logp_init) 74 | ts = torch.linspace(0.0, self._t_max, self._n_time_steps).to(xs[0]) 75 | kwargs = {**self._kwargs, **kwargs} 76 | if not self._use_checkpoints: 77 | from torchdiffeq import odeint_adjoint 78 | *ys, dlogp = odeint_adjoint( 79 | dynamics, 80 | state, 81 | t=ts, 82 | method=self._integrator_method, 83 | rtol=self._integrator_rtol, 84 | atol=self._integrator_atol, 85 | options=kwargs 86 | ) 87 | ys = [y[-1] for y in ys] 88 | else: 89 | from anode.adjoint import odesolver_adjoint 90 | state = torch.cat(state, dim=-1) 91 | anode_dynamics = AnodeDynamics(dynamics) 92 | state = odesolver_adjoint(anode_dynamics, state, options=kwargs) 93 | ys = [state[:, :-1]] 94 | dlogp = [state[:, -1:]] 95 | dlogp = dlogp[-1] 96 | return (*ys, dlogp) 97 | -------------------------------------------------------------------------------- /bgflow/nn/flow/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from .density import DensityDynamics 2 | from .inversed import InversedDynamics 3 | from .blackbox import BlackBoxDynamics 4 | from .simple import TimeIndependentDynamics 5 | from .kernel_dynamic import KernelDynamics 6 | from .anode_dynamic import AnodeDynamics -------------------------------------------------------------------------------- /bgflow/nn/flow/dynamics/anode_dynamic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AnodeDynamics(torch.nn.Module): 5 | """Wrapper class to allow the use of the ANODE ode solver. 6 | 7 | Attributes 8 | ---------- 9 | dynamics : torch.nn.Module 10 | A dynamics function that computes the change of the system and its density. 
11 | """ 12 | 13 | def __init__(self, dynamics): 14 | super().__init__() 15 | self._dynamics = dynamics 16 | 17 | def forward(self, t, state): 18 | """ 19 | Converts the the concatenated state, which is required for the ANODE ode solver, 20 | to the tuple (`xs`, `dlogp`) for the following dynamics function. 21 | Then the output is concatenated again for the ANODE ode solver. 22 | 23 | Parameters 24 | ---------- 25 | t : PyTorch tensor 26 | The current time 27 | state : PyTorch tensor 28 | The current state of the system 29 | 30 | Returns 31 | ------- 32 | state : PyTorch tensor 33 | The combined state update of shape `[n_batch, n_dimensions]` 34 | containing the state update of the system state `dx/dt` 35 | (`state[:, :-1]`) and the update of the log density (`state[:, -1]`). 36 | """ 37 | xs = state[:, :-1] 38 | dlogp = state[:, -1:] 39 | state = (xs, dlogp) 40 | *dxs, div = self._dynamics(t, state) 41 | state = torch.cat([*dxs, div], dim=-1) 42 | return state 43 | -------------------------------------------------------------------------------- /bgflow/nn/flow/dynamics/blackbox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BlackBoxDynamics(torch.nn.Module): 5 | """Black box dynamics that allows to use any dynamics function. 6 | The divergence of the dynamics is computed with a divergence estimator. 7 | """ 8 | 9 | def __init__(self, dynamics_function, divergence_estimator, compute_divergence=True): 10 | super().__init__() 11 | self._dynamics_function = dynamics_function 12 | self._divergence_estimator = divergence_estimator 13 | self._compute_divergence = compute_divergence 14 | 15 | def forward(self, t, *xs): 16 | """ 17 | Computes the change of the system `dxs` at state `xs` and 18 | time `t`. Furthermore, can also compute the change of log density 19 | which is equal to the divergence of the change. 
20 | 21 | Parameters 22 | ---------- 23 | t : PyTorch tensor 24 | The current time 25 | xs : PyTorch tensor 26 | The current configuration of the system 27 | 28 | Returns 29 | ------- 30 | (*dxs, divergence): Tuple of PyTorch tensors 31 | The combined state update of shape `[n_batch, n_dimensions]` 32 | containing the state update of the system state `dx/dt` 33 | (`dxs`) and the update of the log density (`dlogp`) 34 | """ 35 | if self._compute_divergence: 36 | *dxs, divergence = self._divergence_estimator( 37 | self._dynamics_function, t, *xs 38 | ) 39 | else: 40 | dxs = self._dynamics_function(t, xs) 41 | divergence = None 42 | return (*dxs, divergence) 43 | -------------------------------------------------------------------------------- /bgflow/nn/flow/dynamics/density.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DensityDynamics(torch.nn.Module): 5 | 6 | def __init__(self, dynamics): 7 | super().__init__() 8 | self._dynamics = dynamics 9 | self._n_evals = 0 10 | 11 | def forward(self, t, state): 12 | """ 13 | Computes the change of the system `dx/dt` at state `x` and 14 | time `t`. Furthermore, computes the change of density, happening 15 | due to moving `x` infinitesimally in the direction `dx/dt` 16 | according to the "instantaneous change of variables rule" [1] 17 | `dlogp(p(x(t))/dt = -div(dx(t)/dt)` 18 | [1] Neural Ordinary Differential Equations, Chen et. al, 19 | https://arxiv.org/abs/1806.07366 20 | 21 | Parameters 22 | ---------- 23 | t : PyTorch tensor 24 | The current time 25 | state : PyTorch tensor 26 | The current state of the system. 27 | Consisting of the configuration `xs` and the log density change `dlogp`. 28 | 29 | Returns 30 | ------- 31 | (*dxs, -dlogp) : Tuple of PyTorch tensors 32 | The combined state update of shape `[n_batch, n_dimensions]` 33 | containing the state update of the system state `dx/dt` 34 | (`dxs`) and the update of the log density (`dlogp`). 
35 | """ 36 | xs = state[:-1] 37 | *dxs, dlogp = self._dynamics(t, *xs) 38 | return (*dxs, -dlogp) 39 | -------------------------------------------------------------------------------- /bgflow/nn/flow/dynamics/inversed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class InversedDynamics(torch.nn.Module): 5 | """ 6 | Inverse of a dynamics for the inverse flow. 7 | """ 8 | 9 | def __init__(self, dynamics, t_max=1.0): 10 | super().__init__() 11 | self._dynamics = dynamics 12 | self._t_max = t_max 13 | 14 | def forward(self, t, state): 15 | """ 16 | Evaluates the change of the system `dxs` at time `t_max` - `t` for the inverse dynamics. 17 | 18 | Parameters 19 | ---------- 20 | t : PyTorch tensor 21 | The current time 22 | state : PyTorch tensor 23 | The current state of the system 24 | 25 | Returns 26 | ------- 27 | [-*dxs, -dlogp] : List of PyTorch tensors 28 | The combined state update of shape `[n_batch, n_dimensions]` 29 | containing the state update of the system state `dx/dt` 30 | (`-dxs`) and the update of the log density (`-dlogp`). 31 | """ 32 | 33 | *dxs, dlogp = self._dynamics(self._t_max - t, state) 34 | return [-dx for dx in dxs] + [-dlogp] 35 | -------------------------------------------------------------------------------- /bgflow/nn/flow/dynamics/simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TimeIndependentDynamics(torch.nn.Module): 5 | """ 6 | Time independent dynamics function. 7 | """ 8 | 9 | def __init__(self, dynamics): 10 | super().__init__() 11 | self._dynamics = dynamics 12 | 13 | def forward(self, t, xs): 14 | """ 15 | Computes the change of the system `dxs` due to a time independent dynamics function. 
16 | 17 | Parameters 18 | ---------- 19 | t : PyTorch tensor 20 | The current time 21 | xs : PyTorch tensor 22 | The current configuration of the system 23 | 24 | Returns 25 | ------- 26 | dxs : PyTorch tensor 27 | The change of the system due to the dynamics function 28 | """ 29 | 30 | dxs = self._dynamics(xs) 31 | return dxs 32 | -------------------------------------------------------------------------------- /bgflow/nn/flow/elementwise.py: -------------------------------------------------------------------------------- 1 | """Nonlinear One-dimensional Diffeomorphisms""" 2 | 3 | import torch 4 | from bgflow.nn.flow.base import Flow 5 | 6 | 7 | __all__ = ["BentIdentity"] 8 | 9 | 10 | class BentIdentity(Flow): 11 | """Bent identity. A nonlinear diffeomorphism with analytic gradients and inverse. 12 | See https://towardsdatascience.com/secret-sauce-behind-the-beauty-of-deep-learning-beginners-guide-to-activation-functions-a8e23a57d046 . 13 | """ 14 | def __init__(self): 15 | super(BentIdentity, self).__init__() 16 | 17 | def _forward(self, x, **kwargs): 18 | """Forward transform 19 | 20 | Parameters 21 | ---------- 22 | x : torch.tensor 23 | Input tensor 24 | 25 | kwargs : dict 26 | Miscellaneous arguments to satisfy the interface. 27 | 28 | Returns 29 | ------- 30 | y : torch.tensor 31 | Elementwise transformed tensor with the same shape as x. 32 | 33 | dlogp : torch.tensor 34 | Natural log of the Jacobian determinant. 35 | """ 36 | dlogp = torch.log(self.derivative(x)).sum(dim=-1, keepdim=True) 37 | return (torch.sqrt(x ** 2 + 1) - 1) / 2 + x, dlogp 38 | 39 | def _inverse(self, x, **kwargs): 40 | """Inverse transform 41 | 42 | Parameters 43 | ---------- 44 | x : torch.tensor 45 | Input tensor 46 | 47 | kwargs : dict 48 | Miscellaneous arguments to satisfy the interface. 49 | 50 | Returns 51 | ------- 52 | y : torch.tensor 53 | Elementwise transformed tensor with the same shape as x. 54 | 55 | dlogp : torch.tensor 56 | Natural log of the Jacobian determinant. 
57 | """ 58 | dlogp = torch.log(self.inverse_derivative(x)).sum(dim=-1, keepdim=True) 59 | return 2 / 3 * (2 * x + 1 - torch.sqrt(x ** 2 + x + 1)), dlogp 60 | 61 | @staticmethod 62 | def derivative(x): 63 | """Elementwise derivative of the activation function.""" 64 | return x / (2 * torch.sqrt(x ** 2 + 1)) + 1 65 | 66 | @staticmethod 67 | def inverse_derivative(x): 68 | """Elementwise derivative of the inverse activation function.""" 69 | return 4 / 3 - (2 * x + 1) / (3 * torch.sqrt(x ** 2 + x + 1)) 70 | -------------------------------------------------------------------------------- /bgflow/nn/flow/estimator/__init__.py: -------------------------------------------------------------------------------- 1 | from . brute_force_estimator import BruteForceEstimator 2 | from .hutchinson_estimator import HutchinsonEstimator -------------------------------------------------------------------------------- /bgflow/nn/flow/estimator/brute_force_estimator.py: -------------------------------------------------------------------------------- 1 | from bgflow.utils.autograd import brute_force_jacobian_trace 2 | import torch 3 | 4 | 5 | class BruteForceEstimator(torch.nn.Module): 6 | """ 7 | Exact bruteforce estimation of the divergence of a dynamics function. 8 | """ 9 | 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, dynamics, t, xs): 14 | """ 15 | Computes the change of the system `dxs` due to a time independent dynamics function. 16 | Furthermore, also computes the exact change of log density 17 | which is equal to the divergence of the change `dxs`. 18 | 19 | Parameters 20 | ---------- 21 | dynamics : torch.nn.Module 22 | A dynamics function that computes the change of the system and its density. 
23 | t : PyTorch tensor 24 | The current time 25 | xs : PyTorch tensor 26 | The current configuration of the system 27 | 28 | Returns 29 | ------- 30 | dxs, -divergence: PyTorch tensors 31 | The combined state update of shape `[n_batch, n_dimensions]` 32 | containing the state update of the system state `dx/dt` 33 | (`dxs`) and the negative update of the log density (`-divergence`) 34 | """ 35 | 36 | with torch.set_grad_enabled(True): 37 | xs.requires_grad_(True) 38 | dxs = dynamics(t, xs) 39 | 40 | assert len(dxs.shape) == 2, "`dxs` must have shape [n_btach, system_dim]" 41 | divergence = brute_force_jacobian_trace(dxs, xs) 42 | 43 | return dxs, -divergence.view(-1, 1) 44 | -------------------------------------------------------------------------------- /bgflow/nn/flow/estimator/hutchinson_estimator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class HutchinsonEstimator(torch.nn.Module): 5 | """ 6 | Estimation of the divergence of a dynamics function with the Hutchinson Estimator [1]. 7 | [1] A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines, Hutchinson 8 | """ 9 | 10 | def __init__(self, rademacher=True): 11 | super().__init__() 12 | self._rademacher = rademacher 13 | self._reset_noise = True 14 | 15 | def reset_noise(self, reset_noise=True): 16 | """ 17 | Resets the noise vector. 18 | """ 19 | 20 | self._reset_noise = reset_noise 21 | 22 | def forward(self, dynamics, t, xs): 23 | """ 24 | Computes the change of the system `dxs` due to a time independent dynamics function. 25 | Furthermore, also estimates the change of log density, which is equal to the divergence of the change `dxs`, 26 | with the Hutchinson Estimator. 27 | This is done with either Rademacher or Gaussian noise. 28 | 29 | Parameters 30 | ---------- 31 | dynamics : torch.nn.Module 32 | A dynamics function that computes the change of the system and its density. 
33 | t : PyTorch tensor 34 | The current time 35 | xs : PyTorch tensor 36 | The current configuration of the system 37 | 38 | Returns 39 | ------- 40 | dxs, -divergence: PyTorch tensors 41 | The combined state update of shape `[n_batch, n_dimensions]` 42 | containing the state update of the system state `dx/dt` 43 | (`dxs`) and the negative update of the log density (`-divergence`) 44 | """ 45 | 46 | with torch.set_grad_enabled(True): 47 | xs.requires_grad_(True) 48 | dxs = dynamics(t, xs) 49 | 50 | assert len(dxs.shape) == 2, "`dxs` must have shape [n_btach, system_dim]" 51 | system_dim = dxs.shape[-1] 52 | 53 | if self._reset_noise == True: 54 | self._reset_noise = False 55 | if self._rademacher == True: 56 | self._noise = torch.randint(low=0, high=2, size=xs.shape).to(xs) * 2 - 1 57 | else: 58 | self._noise = torch.randn_like(xs) 59 | 60 | noise_ddxs = torch.autograd.grad(dxs, xs, self._noise, create_graph=True)[0] 61 | divergence = torch.sum((noise_ddxs * self._noise).view(-1, system_dim), 1, keepdim=True) 62 | 63 | return dxs, -divergence 64 | -------------------------------------------------------------------------------- /bgflow/nn/flow/funnel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .base import Flow 5 | 6 | # TODO: write docstrings 7 | # TODO: refactor messy implementation 8 | 9 | 10 | class FunnelFlow(Flow): 11 | def __init__(self, eps=1e-6, min_val=-1.0, max_val=1.0): 12 | super().__init__() 13 | self._eps = eps 14 | self._min_val = min_val 15 | self._max_val = max_val 16 | 17 | def _forward(self, x, **kwargs): 18 | dlogp = ( 19 | torch.nn.functional.logsigmoid(x) 20 | - torch.nn.functional.softplus(x) 21 | + np.log(self._max_val - self._min_val) 22 | ).sum(dim=-1, keepdim=True) 23 | x = torch.sigmoid(x) 24 | x = x * (self._max_val - self._min_val) + self._min_val 25 | x = torch.clamp(x, self._min_val + self._eps, self._max_val - self._eps) 26 | return x, 
dlogp 27 | 28 | def _inverse(self, x, **kwargs): 29 | x = torch.clamp(x, self._min_val + self._eps, self._max_val - self._eps) 30 | x = (x - self._min_val) / (self._max_val - self._min_val) 31 | dlogp = ( 32 | -torch.log(x - x.pow(2)) 33 | - np.log(self._max_val - self._min_val) 34 | ).sum(dim=-1, keepdim=True) 35 | x = torch.log(x) - torch.log(1 - x) 36 | return x, dlogp -------------------------------------------------------------------------------- /bgflow/nn/flow/inverted.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import Flow 3 | 4 | __all__ = ["InverseFlow"] 5 | 6 | 7 | class InverseFlow(Flow): 8 | """The inverse of a given transform. 9 | 10 | Parameters 11 | ---------- 12 | delegate : Flow 13 | The flow to invert. 14 | """ 15 | def __init__(self, delegate): 16 | super().__init__() 17 | self._delegate = delegate 18 | 19 | def _forward(self, *xs, **kwargs): 20 | return self._delegate._inverse(*xs, **kwargs) 21 | 22 | def _inverse(self, *xs, **kwargs): 23 | return self._delegate._forward(*xs, **kwargs) -------------------------------------------------------------------------------- /bgflow/nn/flow/kronecker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | from .base import Flow 6 | 7 | 8 | # TODO: write docstrings 9 | 10 | 11 | def _is_power2(x): 12 | return x != 0 and ((x & (x - 1)) == 0) 13 | 14 | 15 | def _kronecker(A, B): 16 | return torch.einsum("ab,cd->acbd", A, B).view( 17 | A.size(0) * B.size(0), A.size(1) * B.size(1) 18 | ) 19 | 20 | 21 | def _batch_determinant_2x2(As, log=False): 22 | result = As[:, 0, 0] * As[:, 1, 1] - As[:, 1, 0] * As[:, 0, 1] 23 | if log: 24 | result = result.abs().log() 25 | return result 26 | 27 | 28 | def _create_ortho_matrices(n, d): 29 | qs = [] 30 | for i in range(n): 31 | q, _ = np.linalg.qr(np.random.normal(size=(d, d))) 32 | qs.append(q) 33 | qs = np.array(qs) 34 | return qs 
35 | 36 | 37 | class KroneckerProductFlow(Flow): 38 | def __init__(self, n_dim): 39 | super().__init__() 40 | 41 | assert _is_power2(n_dim) 42 | 43 | self._n_dim = n_dim 44 | self._n_factors = int(np.log2(n_dim)) 45 | 46 | self._factors = torch.nn.Parameter( 47 | torch.Tensor(_create_ortho_matrices(self._n_factors, 2)) 48 | ) 49 | self._bias = torch.nn.Parameter(torch.Tensor(1, n_dim).zero_()) 50 | 51 | def _forward(self, x, **kwargs): 52 | n_batch = x.shape[0] 53 | factors = self._factors.to(x) 54 | M = factors[0] 55 | dets = _batch_determinant_2x2(factors) 56 | det = dets[0] 57 | power = 2 58 | for new_det, factor in zip(dets[1:], factors[1:]): 59 | det = det.pow(2) * new_det.pow(power) 60 | M = _kronecker(M, factor) 61 | power = power * 2 62 | dlogp = torch.zeros(n_batch, 1).to(x) 63 | dlogp = dlogp + det.abs().log().sum(dim=-1, keepdim=True) 64 | return x @ M + self._bias.to(x), dlogp 65 | 66 | def _inverse(self, x, **kwargs): 67 | n_batch = x.shape[0] 68 | factors = self._factors.to(x) 69 | inv_factors = torch.inverse(factors) 70 | M = inv_factors[0] 71 | inv_dets = _batch_determinant_2x2(inv_factors) 72 | inv_det = inv_dets[0] 73 | power = 2 74 | for new_inv_det, factor in zip(inv_dets[1:], inv_factors[1:]): 75 | inv_det = inv_det.pow(2) * new_inv_det.pow(power) 76 | M = _kronecker(M, factor) 77 | power = power * 2 78 | dlogp = torch.zeros(n_batch, 1).to(x) 79 | dlogp = dlogp + inv_det.abs().log().sum(dim=-1, keepdim=True) 80 | return (x - self._bias.to(x)) @ M, dlogp 81 | -------------------------------------------------------------------------------- /bgflow/nn/flow/modulo.py: -------------------------------------------------------------------------------- 1 | __all__ = ["IncreaseMultiplicityFlow", "CircularShiftFlow"] 2 | 3 | import torch 4 | from bgflow.nn.flow.transformer.base import Flow 5 | 6 | 7 | class IncreaseMultiplicityFlow(Flow): 8 | """A flow that increases the multiplicity of torsional degrees of freedom. 
9 | The input and output tensors are expected to be in [0,1]. 10 | The output represents the sum over sheaves from the input. 11 | 12 | Parameters 13 | ---------- 14 | multiplicities : Union[torch.Tensor, int] 15 | A tensor of integers that define the number of periods in the unit interval. 16 | """ 17 | 18 | def __init__(self, multiplicities): 19 | super().__init__() 20 | self.register_buffer("_multiplicities", torch.as_tensor(multiplicities)) 21 | 22 | def _forward(self, x, **kwargs): 23 | _assert_in_unit_interval(x) 24 | multiplicities = torch.ones_like(x) * self._multiplicities 25 | sheaves = _randint(multiplicities) 26 | y = (x + sheaves) / self._multiplicities 27 | dlogp = torch.zeros_like(x[..., [0]]) 28 | return y, dlogp 29 | 30 | def _inverse(self, x, **kwargs): 31 | _assert_in_unit_interval(x) 32 | y = (x % (1 / self._multiplicities)) * self._multiplicities 33 | dlogp = torch.zeros_like(x[..., [0]]) 34 | return y, dlogp 35 | 36 | 37 | def _randint(high): 38 | with torch.no_grad(): 39 | return torch.floor(torch.rand(high.shape, device=high.device) * high) 40 | 41 | 42 | def _assert_in_unit_interval(x): 43 | if (x > 1 + 1e-6).any() or (x < - 1e-6).any(): 44 | raise ValueError(f'IncreaseMultiplicityFlow operates on [0,1] but input was {x}') 45 | 46 | 47 | class CircularShiftFlow(Flow): 48 | """A flow that shifts the position of torsional degrees of freedom. 49 | The input and output tensors are expected to be in [0,1]. 50 | The output is a translated version of the input, respecting circulariry. 
51 | 52 | Parameters 53 | ---------- 54 | shift : Union[torch.Tensor, float] 55 | A tensor that defines the translation of the circular interval 56 | """ 57 | 58 | def __init__(self, shift): 59 | super().__init__() 60 | self.register_buffer("_shift", torch.as_tensor(shift)) 61 | 62 | def _forward(self, x, **kwargs): 63 | _assert_in_unit_interval(x) 64 | y = (x + self._shift) % 1 65 | dlogp = torch.zeros_like(x[..., [0]]) 66 | return y, dlogp 67 | 68 | def _inverse(self, x, **kwargs): 69 | _assert_in_unit_interval(x) 70 | y = (x - self._shift) % 1 71 | dlogp = torch.zeros_like(x[..., [0]]) 72 | return y, dlogp 73 | -------------------------------------------------------------------------------- /bgflow/nn/flow/orthogonal.py: -------------------------------------------------------------------------------- 1 | """ 2 | (Pseudo-) Orthogonal Linear Layers. Advantage: Jacobian determinant is unity. 3 | """ 4 | 5 | import torch 6 | from bgflow.nn.flow.base import Flow 7 | 8 | __all__ = ["PseudoOrthogonalFlow"] 9 | 10 | # Note: OrthogonalPPPP is implemented in pppp.py 11 | 12 | 13 | class PseudoOrthogonalFlow(Flow): 14 | """Linear flow W*x+b with a penalty function 15 | penalty_parameter*||W^T W - I||^2 16 | 17 | Attributes 18 | ---------- 19 | dim : int 20 | dimension 21 | shift : boolean 22 | Whether to use a shift parameter (+b). If False, b=0. 23 | penalty_parameter : float 24 | Scaling factor for the orthogonality constraint. 25 | """ 26 | def __init__(self, dim, shift=True, penalty_parameter=1e5): 27 | super(PseudoOrthogonalFlow, self).__init__() 28 | self.dim = dim 29 | self.W = torch.nn.Parameter(torch.eye(dim)) 30 | if shift: 31 | self.b = torch.nn.Parameter(torch.zeros(dim)) 32 | else: 33 | self.register_buffer("b", torch.tensor(0.0)) 34 | self.register_buffer("penalty_parameter", torch.tensor(penalty_parameter)) 35 | 36 | def _forward(self, x, **kwargs): 37 | """Forward transform. 38 | 39 | Attributes 40 | ---------- 41 | x : torch.tensor 42 | The input vector. 
The transform is applied to the last dimension. 43 | kwargs : dict 44 | keyword arguments to satisfy the interface 45 | 46 | Returns 47 | ------- 48 | y : torch.tensor 49 | W*x + b 50 | dlogp : torch.tensor 51 | natural log of the Jacobian determinant 52 | """ 53 | dlogp = torch.zeros(*x.shape[:-1], 1).to(x) 54 | y = torch.einsum("ab,...b->...a", self.W, x) 55 | return y + self.b, dlogp 56 | 57 | def _inverse(self, y, **kwargs): 58 | """Inverse transform assuming that W is orthogonal. 59 | 60 | Attributes 61 | ---------- 62 | y : torch.tensor 63 | The input vector. The transform is applied to the last dimension. 64 | kwargs : dict 65 | keyword arguments to satisfy the interface 66 | 67 | Returns 68 | ------- 69 | x : torch.tensor 70 | W^T*(y-b) 71 | dlogp : torch.tensor 72 | natural log of the Jacobian determinant 73 | """ 74 | dlogp = torch.zeros(*y.shape[:-1], 1).to(y) 75 | x = torch.einsum("ab,...b->...a", self.W.transpose(1, 0), y - self.b) 76 | return x, dlogp 77 | 78 | def penalty(self): 79 | """Penalty function for the orthogonality constraint 80 | 81 | p(W) = penalty_parameter * ||W^T*W - I||^2. 82 | 83 | Returns 84 | ------- 85 | penalty : float 86 | Value of the penalty function 87 | """ 88 | return self.penalty_parameter * torch.sum((torch.eye(self.dim) - torch.mm(self.W.transpose(1, 0), self.W)) ** 2) 89 | -------------------------------------------------------------------------------- /bgflow/nn/flow/sequential.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | 5 | from .base import Flow 6 | 7 | logger = logging.getLogger('bgflow') 8 | 9 | 10 | class SequentialFlow(Flow): 11 | def __init__(self, blocks): 12 | """ 13 | Represents a diffeomorphism that can be computed 14 | as a discrete finite stack of layers. 15 | 16 | Returns the transformed variable and the log determinant 17 | of the Jacobian matrix. 
18 | 19 | Parameters 20 | ---------- 21 | blocks : Tuple / List of flow blocks 22 | """ 23 | super().__init__() 24 | self._blocks = torch.nn.ModuleList(blocks) 25 | 26 | def forward(self, *xs, inverse=False, **kwargs): 27 | """ 28 | Transforms the input along the diffeomorphism and returns 29 | the transformed variable together with the volume change. 30 | 31 | Parameters 32 | ---------- 33 | x : PyTorch Floating Tensor. 34 | Input variable to be transformed. 35 | Tensor of shape `[..., n_dimensions]`. 36 | inverse: boolean. 37 | Indicates whether forward or inverse transformation shall be performed. 38 | If `True` computes the inverse transformation. 39 | 40 | Returns 41 | ------- 42 | z: PyTorch Floating Tensor. 43 | Transformed variable. 44 | Tensor of shape `[..., n_dimensions]`. 45 | dlogp : PyTorch Floating Tensor. 46 | Total volume change as a result of the transformation. 47 | Corresponds to the log determinant of the Jacobian matrix. 48 | """ 49 | dlogp = 0.0 50 | blocks = self._blocks 51 | if inverse: 52 | blocks = reversed(blocks) 53 | for i, block in enumerate(blocks): 54 | logger.debug(f"Input shapes {[x.shape for x in xs]}") 55 | *xs, ddlogp = block(*xs, inverse=inverse, **kwargs) 56 | logger.debug(f"Flow block {i} (inverse={inverse}): {block}") 57 | logger.debug(f"Output shapes {[x.shape for x in xs]}") 58 | dlogp += ddlogp 59 | return (*xs, dlogp) 60 | 61 | def _forward(self, *args, **kwargs): 62 | return self.forward(*args, **kwargs, inverse=False) 63 | 64 | def _inverse(self, *args, **kwargs): 65 | return self.forward(*args, **kwargs, inverse=True) 66 | 67 | def trigger(self, function_name): 68 | """ 69 | Evaluate functions for all blocks that have a function with that name and return a tensor of the stacked results. 
70 | """ 71 | results = [ 72 | getattr(block, function_name)() 73 | for block in self._blocks 74 | if hasattr(block, function_name) and callable(getattr(block, function_name)) 75 | ] 76 | if len(results) > 0 and all(res is not None for res in results): 77 | return torch.stack(results) 78 | else: 79 | return torch.zeros(0) 80 | 81 | def __iter__(self): 82 | return iter(self._blocks) 83 | 84 | def __getitem__(self, index): 85 | if isinstance(index, int): 86 | return self._blocks[index] 87 | else: 88 | indices = np.arange(len(self))[index] 89 | return SequentialFlow([self._blocks[i] for i in indices]) 90 | 91 | def __len__(self): 92 | return len(self._blocks) 93 | -------------------------------------------------------------------------------- /bgflow/nn/flow/stochastic/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcmc import MetropolisMCFlow 2 | from .langevin import LangevinFlow, BrownianFlow, OverdampedLangevinFlow 3 | from .augment import StochasticAugmentation 4 | from .snf_openmm import * 5 | -------------------------------------------------------------------------------- /bgflow/nn/flow/stochastic/augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bgflow.nn.flow.base import Flow 4 | 5 | 6 | class StochasticAugmentation(Flow): 7 | def __init__(self, distribution): 8 | """ 9 | Stochastic augmentation layer 10 | 11 | Adds additional coordinates to the state vector by sampling them from distribution. 12 | Pre-sampled momenta can be passed through the layer by the kwarg "momenta" and 13 | transformed momenta are returned if kwarg "return_momenta" is set True. 14 | If momenta are returned, their contribution is not added to the Jacobian. 15 | This behavior is required in some MCMC samplers. 16 | 17 | Parameters 18 | ---------- 19 | distribution : Energy 20 | Energy object, needs sample and energy method. 
21 | """ 22 | super().__init__() 23 | self.distribution = distribution 24 | self._cached_momenta_forward = None 25 | self._cached_momenta_backward = None 26 | 27 | def _forward(self, q, **kwargs): 28 | batch_size = q.shape[0] 29 | temperature = kwargs.get("temperature", 1.0) 30 | cache_momenta = kwargs.get("cache_momenta", False) 31 | # Add option to pass pre-sampled momenta as key word argument 32 | p = kwargs.get("momenta", None) 33 | if p is None: 34 | p = self.distribution.sample(batch_size, temperature=temperature) 35 | dlogp = self.distribution.energy(p, temperature=temperature) 36 | else: 37 | dlogp = torch.zeros(p.shape[0], 1).to(p) 38 | if cache_momenta: 39 | self._cached_momenta_forward = p 40 | x = torch.cat([q, p], dim=1) 41 | return x, dlogp 42 | 43 | def _inverse(self, x, **kwargs): 44 | return_momenta = kwargs.get("return_momenta", False) 45 | cache_momenta = kwargs.get("cache_momenta", False) 46 | p = x[:, self.distribution.dim :] 47 | temperature = kwargs.get("temperature", 1.0) 48 | # Add option to return transformed momenta as key word argument. 
49 | # Momenta will be returned in same tensor as configurations 50 | if cache_momenta: 51 | self._cached_momenta_backward = p 52 | if return_momenta: 53 | return x, torch.zeros(p.shape[0], 1).to(p) 54 | dlogp = self.distribution.energy(p, temperature=temperature) 55 | return x[:, : self.distribution.dim], -dlogp 56 | -------------------------------------------------------------------------------- /bgflow/nn/flow/stochastic/mcmc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bgflow.nn.flow.base import Flow 3 | 4 | class MetropolisMCFlow(Flow): 5 | def __init__(self, energy_model, nsteps=1, stepsize=0.01): 6 | """ Stochastic Flow layer that simulates Metropolis Monte Carlo 7 | 8 | """ 9 | super().__init__() 10 | self.energy_model = energy_model 11 | self.nsteps = nsteps 12 | self.stepsize = stepsize 13 | 14 | def _forward(self, x, **kwargs): 15 | """ Run a stochastic trajectory forward 16 | 17 | Parameters 18 | ---------- 19 | x : PyTorch Tensor 20 | Batch of input configurations 21 | 22 | Returns 23 | ------- 24 | x' : PyTorch Tensor 25 | Transformed configurations 26 | dW : PyTorch Tensor 27 | Nonequilibrium work done, always 0 for this process 28 | 29 | """ 30 | E0 = self.energy_model.energy(x) 31 | E = E0 32 | 33 | for i in range(self.nsteps): 34 | # proposal step 35 | dx = self.stepsize * torch.zeros_like(x).normal_() 36 | xprop = x + dx 37 | Eprop = self.energy_model.energy(xprop) 38 | 39 | # acceptance step 40 | acc = (torch.rand(x.shape[0], 1) < torch.exp(-(Eprop - E))).float() # selection variable: 0 or 1. 
41 | x = (1-acc) * x + acc * xprop 42 | E = (1-acc) * E + acc * Eprop 43 | 44 | # Work is energy difference 45 | dW = E - E0 46 | 47 | return x, dW 48 | 49 | def _inverse(self, x, **kwargs): 50 | """ Same as forward """ 51 | return self._forward(x, **kwargs) 52 | -------------------------------------------------------------------------------- /bgflow/nn/flow/torchtransform.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from .base import Flow 4 | 5 | 6 | __all__ = ["TorchTransform"] 7 | 8 | 9 | class TorchTransform(Flow): 10 | """Wrap a torch.distributions.Transform as a Flow instance 11 | 12 | Parameters 13 | ---------- 14 | transform : torch.distributions.Transform 15 | The transform instance that should be wrapped as a Flow instance. 16 | reinterpreted_batch_ndims : int, optional 17 | Number of batch dimensions to be reinterpreted as event dimensions. 18 | If >0, this transform is wrapped in an torch.distributions.IndependentTransform instance. 
19 | """ 20 | 21 | def __init__(self, transform, reinterpreted_batch_ndims=0): 22 | super().__init__() 23 | if reinterpreted_batch_ndims > 0: 24 | transform = torch.distributions.IndependentTransform(transform, reinterpreted_batch_ndims) 25 | self._delegate_transform = transform 26 | 27 | def _forward(self, x, **kwargs): 28 | y = self._delegate_transform(x) 29 | dlogp = self._delegate_transform.log_abs_det_jacobian(x, y) 30 | return y, dlogp[..., None] # append a trailing singleton dim to the log-det 31 | 32 | def _inverse(self, y, **kwargs): 33 | x = self._delegate_transform.inv(y) 34 | dlogp = - self._delegate_transform.log_abs_det_jacobian(x, y) # inverse log-det is the negated forward log-det at (x, y) 35 | return x, dlogp[..., None] 36 | -------------------------------------------------------------------------------- /bgflow/nn/flow/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import* 2 | from .affine import * 3 | from .entropy_scaling import * 4 | from .spline import * 5 | from .gaussian import * 6 | from .jax import * 7 | from .jax_bridge import * 8 | -------------------------------------------------------------------------------- /bgflow/nn/flow/transformer/affine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base import Transformer 4 | 5 | # TODO: write docstring 6 | 7 | __all__ = ["AffineTransformer"] 8 | 9 | 10 | class AffineTransformer(Transformer): 11 | """Affine (RealNVP/NICE-style) transformer: y -> exp(log_sigma) * y + mu, where mu and log_sigma are computed from (x, *cond). 12 | 13 | Parameters 14 | ---------- 15 | is_circular : bool 16 | Whether this transform is periodic on [0,1].
17 | """ 18 | def __init__( 19 | self, 20 | shift_transformation=None, 21 | scale_transformation=None, 22 | init_downscale=1.0, 23 | preserve_volume=False, 24 | is_circular=False, 25 | ): 26 | if scale_transformation is not None and is_circular: 27 | raise ValueError("Scaling is not compatible with periodicity.") 28 | super().__init__() 29 | self._shift_transformation = shift_transformation 30 | self._scale_transformation = scale_transformation 31 | self._log_alpha = torch.nn.Parameter(torch.zeros(1) - init_downscale) # log of alpha, the learnable bound on |log_sigma| 32 | self._preserve_volume = preserve_volume 33 | self._is_circular = is_circular 34 | 35 | def _get_mu_and_log_sigma(self, x, y, *cond): 36 | if self._shift_transformation is not None: 37 | mu = self._shift_transformation(x, *cond) 38 | else: 39 | mu = torch.zeros_like(y).to(x) 40 | if self._scale_transformation is not None: 41 | alpha = torch.exp(self._log_alpha.to(x)) 42 | log_sigma = torch.tanh(self._scale_transformation(x, *cond)) # tanh keeps raw log-scales in (-1, 1) before scaling by alpha 43 | log_sigma = log_sigma * alpha 44 | if self._preserve_volume: 45 | log_sigma = log_sigma - log_sigma.mean(dim=-1, keepdim=True) # zero mean over the last dim => log-scales sum to 0 (volume preserving) 46 | else: 47 | log_sigma = torch.zeros_like(y).to(x) 48 | return mu, log_sigma 49 | 50 | def _forward(self, x, y, *cond, **kwargs): 51 | mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond) 52 | assert mu.shape[-1] == y.shape[-1] 53 | assert log_sigma.shape[-1] == y.shape[-1] 54 | sigma = torch.exp(log_sigma) 55 | dlogp = (log_sigma).sum(dim=-1, keepdim=True) # log|det J| of the elementwise scaling 56 | y = sigma * y + mu 57 | if self._is_circular: 58 | y = y % 1.0 # wrap back onto the periodic interval [0, 1) 59 | return y, dlogp 60 | 61 | def _inverse(self, x, y, *cond, **kwargs): 62 | mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond) 63 | assert mu.shape[-1] == y.shape[-1] 64 | assert log_sigma.shape[-1] == y.shape[-1] 65 | sigma_inv = torch.exp(-log_sigma) 66 | dlogp = (-log_sigma).sum(dim=-1, keepdim=True) 67 | y = sigma_inv * (y - mu) 68 | if self._is_circular: 69 | y = y % 1.0 # wrap back onto the periodic interval [0, 1) 70 | return y, dlogp 71 |
-------------------------------------------------------------------------------- /bgflow/nn/flow/transformer/base.py: -------------------------------------------------------------------------------- 1 | from ..base import Flow 2 | 3 | 4 | __all__ = ["Transformer"] 5 | 6 | 7 | class Transformer(Flow): 8 | # Abstract interface: subclasses implement _forward/_inverse(x, y, *args, **kwargs); e.g. AffineTransformer transforms y using parameters computed from x 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def _forward(self, x, y, *args, **kwargs): 13 | raise NotImplementedError() 14 | 15 | def _inverse(self, x, y, *args, **kwargs): 16 | raise NotImplementedError() 17 | -------------------------------------------------------------------------------- /bgflow/nn/flow/transformer/entropy_scaling.py: -------------------------------------------------------------------------------- 1 | from .base import Flow 2 | from torch.nn.parameter import Parameter 3 | import torch 4 | 5 | 6 | __all__ = ["ScalingLayer", "EntropyScalingLayer"] 7 | 8 | 9 | class ScalingLayer(Flow): 10 | def __init__(self, init_factor=1.0, dim=1): 11 | super().__init__() 12 | self._scalefactor = Parameter(init_factor * torch.ones(1)) 13 | self.dim = dim 14 | 15 | def _forward(self, x, *cond, **kwargs): 16 | n_batch = x.shape[0] 17 | y = torch.zeros_like(x) 18 | y[:, : self.dim] = x[:, : self.dim] * self._scalefactor 19 | y[:, self.dim :] = x[:, self.dim :] 20 | return ( 21 | y, 22 | (self.dim * self._scalefactor.log()).repeat(n_batch, 1), # log|det J| = dim * log(scalefactor): only the first self.dim features are scaled 23 | ) 24 | 25 | def _inverse(self, x, *cond, **kwargs): 26 | n_batch = x.shape[0] 27 | y = torch.zeros_like(x) 28 | y[:, : self.dim] = x[:, : self.dim] / self._scalefactor 29 | y[:, self.dim :] = x[:, self.dim :] 30 | return ( 31 | y, 32 | (-self.dim * self._scalefactor.log()).repeat(n_batch, 1), 33 | ) 34 | 35 | 36 | class EntropyScalingLayer(Flow): 37 | def __init__(self, init_factor=1.0, dim=1): 38 | super().__init__() 39 | self._scalefactor = Parameter(init_factor * torch.ones(1)) 40 | self.dim = dim 41 | 42 | def _forward(self, x, y, *cond, **kwargs): 43 | n_batch = x.shape[0] 44 | return ( 45 |
self._scalefactor * x, 46 | y, 47 | (self.dim * self._scalefactor.log()).repeat(n_batch, 1), # NOTE(review): all of x is scaled but the log-det counts only self.dim features -- assumes x has exactly self.dim features; confirm 48 | ) 49 | 50 | def _inverse(self, x, y, *cond, **kwargs): 51 | n_batch = x.shape[0] 52 | return ( 53 | x / self._scalefactor, 54 | y, 55 | (-self.dim * self._scalefactor.log()).repeat(n_batch, 1), 56 | ) 57 | -------------------------------------------------------------------------------- /bgflow/nn/flow/transformer/jax.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | try: 4 | import jax 5 | import jax.numpy as jnp 6 | except ImportError: 7 | jax = None 8 | jnp = None 9 | 10 | 11 | __all__ = [ 12 | 'affine_transform', 13 | 'smooth_ramp', 14 | 'monomial_ramp', 15 | 'ramp_to_sigmoid', 16 | 'affine_sigmoid', 17 | 'wrap_around', 18 | 'remap_to_unit', 19 | 'mixture', 20 | ] 21 | 22 | 23 | def affine_transform(x, a, b): 24 | """Affine transform.""" 25 | return x * jnp.exp(a) + b 26 | 27 | 28 | def smooth_ramp(x, logalpha, power=1, eps=1e-9): 29 | """Smooth ramp.""" 30 | assert power > 0 31 | assert isinstance(power, int) 32 | assert eps > 0 33 | alpha = jnp.exp(logalpha) 34 | # double `where` trick to avoid NaN in backward pass 35 | z = jnp.where(x > eps, x, jnp.ones_like(x) * eps) 36 | normalizer = jnp.exp(-alpha * 1.) 37 | return jnp.where( 38 | x > eps, 39 | jnp.exp(-alpha * jnp.power(z, -power)) / normalizer, 40 | jnp.zeros_like(z)) 41 | 42 | 43 | def monomial_ramp(x, order=2): 44 | assert order > 0 and isinstance(order, int) 45 | return jnp.power(x, order) 46 | 47 | 48 | def ramp_to_sigmoid(ramp): 49 | """Generalized sigmoid, given a ramp.""" 50 | def _sigmoid(x, *params): 51 | numer = ramp(x, *params) 52 | denom = numer + ramp(1.
- x, *params) 53 | return numer / denom 54 | return _sigmoid 55 | 56 | 57 | def affine_sigmoid(sigmoid, eps=1e-8): 58 | """Generalized affine sigmoid transform.""" 59 | assert eps > 0 60 | 61 | def _affine_sigmoid(x, shift, slope, mix, *params): 62 | slope = jnp.exp(slope) 63 | mix = jax.nn.sigmoid(mix) * (1. - eps) + eps 64 | return (mix * sigmoid(slope * (x - shift), *params) 65 | + (1. - mix) * x) 66 | return _affine_sigmoid 67 | 68 | 69 | def wrap_around(bijector, sheaves=None, weights=None): 70 | """Wraps affine sigmoid around circle.""" 71 | if sheaves is None: 72 | sheaves = jnp.array([-1, 0, 1]) 73 | if weights is None: 74 | weights = jnp.zeros_like(sheaves) 75 | mixture_ = mixture(bijector) 76 | 77 | def _wrapped(x, *params): 78 | x = x - sheaves[None] 79 | params = [jnp.repeat(p[..., None], len(sheaves), axis=-1) for p in params] 80 | return mixture_(x, weights, *params) 81 | return remap_to_unit(_wrapped) 82 | 83 | 84 | def remap_to_unit(fun): 85 | """Maps transformation back to [0, 1].""" 86 | @functools.wraps(fun) 87 | def _remapped(x, *params): 88 | y1 = fun(jnp.ones_like(x), *params) 89 | y0 = fun(jnp.zeros_like(x), *params) 90 | return (fun(x, *params) - y0) / (y1 - y0) 91 | return _remapped 92 | 93 | 94 | def mixture(bijector): 95 | """Combines multiple bijectors into a mixture.""" 96 | def _mixture_bijector(x, weights, *params): 97 | components = jax.vmap( 98 | functools.partial(bijector, x), 99 | in_axes=-1, 100 | out_axes=-1)(*params) 101 | return jnp.sum(jax.nn.softmax(weights) * components) # NOTE(review): reduces over all axes, not only the component axis -- assumes a scalar x per call (e.g. under jax.vmap); confirm 102 | return _mixture_bijector 103 | -------------------------------------------------------------------------------- /bgflow/nn/flow/triangular.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from bgflow.nn.flow.base import Flow 5 | 6 | 7 | __all__ = ["TriuFlow"] 8 | 9 | 10 | class TriuFlow(Flow): 11 | """Linear flow (I+R)*x+b with an upper triangular matrix R.
12 | 13 | Attributes 14 | ---------- 15 | dim : int 16 | dimension 17 | shift : boolean 18 | Whether to use a shift parameter (+b). If False, b=0. 19 | """ 20 | def __init__(self, dim, shift=True): 21 | super(TriuFlow, self).__init__() 22 | self.dim = dim 23 | self.register_buffer("indices", torch.triu_indices(dim, dim)) 24 | n_matrix_parameters = self.indices.shape[1] 25 | self._unique_elements = torch.nn.Parameter(torch.zeros(n_matrix_parameters)) 26 | if shift: 27 | self.b = torch.nn.Parameter(torch.zeros(dim)) 28 | else: 29 | self.register_buffer("b", torch.tensor(0.0)) 30 | self.register_buffer("R", torch.zeros((self.dim, self.dim))) 31 | 32 | def _make_r(self): 33 | self.R[:] = 0 34 | self.R[self.indices[0], self.indices[1]] = self._unique_elements 35 | self.R += torch.eye(self.dim) 36 | return self.R 37 | 38 | def _forward(self, x, **kwargs): 39 | """Forward transform. 40 | 41 | Attributes 42 | ---------- 43 | x : torch.tensor 44 | The input vector. The transform is applied to the last dimension. 45 | kwargs : dict 46 | keyword arguments to satisfy the interface 47 | 48 | Returns 49 | ------- 50 | y : torch.tensor 51 | W*x + b 52 | dlogp : torch.tensor 53 | natural log of the Jacobian determinant 54 | """ 55 | R = self._make_r() 56 | dlogp = torch.ones_like(x[...,0,None])*torch.sum(torch.log(torch.abs(torch.diagonal(R)))) 57 | y = torch.einsum("ab,...b->...a", R, x) 58 | return y + self.b, dlogp 59 | 60 | def _inverse(self, y, **kwargs): 61 | """Inverse transform. 62 | 63 | Attributes 64 | ---------- 65 | y : torch.tensor 66 | The input vector. The transform is applied to the last dimension. 
67 | kwargs : dict 68 | keyword arguments to satisfy the interface 69 | 70 | Returns 71 | ------- 72 | x : torch.tensor 73 | W^T*(y-b) 74 | dlogp : torch.tensor 75 | natural log of the Jacobian determinant 76 | """ 77 | R = self._make_r() 78 | dlogp = torch.ones_like(y[...,0,None])*(-torch.sum(torch.log(torch.abs(torch.diagonal(R))))) 79 | try: 80 | x = torch.linalg.solve_triangular(R, (y-self.b)[...,None], upper=True) 81 | except AttributeError: 82 | # legacy call for torch < 1.11 83 | x, _ = torch.triangular_solve((y-self.b)[...,None], R) 84 | return x[...,0], dlogp 85 | -------------------------------------------------------------------------------- /bgflow/nn/periodic.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | __all__ = ["WrapPeriodic"] 5 | 6 | 7 | class WrapPeriodic(torch.nn.Module): 8 | """Wrap network inputs around a unit sphere. 9 | 10 | Parameters 11 | ---------- 12 | net : torch.nn.Module 13 | The module, whose inputs are wrapped around the unit sphere. 14 | left : float, optional 15 | Left boundary of periodic interval. 16 | right : float, optional 17 | Right boundary of periodic interval. 18 | indices : Union[Sequence[int], slice], optional 19 | Array of periodic input indices. 20 | Only indices covered by this index array are wrapped around the sphere. 21 | The default corresponds with all indices. 
22 | """ 23 | def __init__(self, net, left=0.0, right=1.0, indices=slice(None)): 24 | super().__init__() 25 | self.net = net 26 | self.left = left 27 | self.right = right 28 | self.indices = indices 29 | 30 | def forward(self, x): 31 | indices = np.arange(x.shape[-1])[self.indices] 32 | other_indices = np.setdiff1d(np.arange(x.shape[-1]), indices) 33 | y = x[..., indices] 34 | cos = torch.cos(2 * np.pi * (y - self.left) / (self.right - self.left)) 35 | sin = torch.sin(2 * np.pi * (y - self.left) / (self.right - self.left)) 36 | x = torch.cat([cos, sin, x[..., other_indices]], dim=-1) 37 | return self.net.forward(x) 38 | 39 | 40 | class WrapDistances(torch.nn.Module): 41 | """TODO: TEST!!!""" 42 | def __init__(self, net, left=0.0, right=1.0, indices=slice(None)): 43 | super().__init__() 44 | self.net = net 45 | self.left = left 46 | self.right = right 47 | self.indices = indices 48 | 49 | def forward(self, x): 50 | indices = np.arange(x.shape[-1])[self.indices] 51 | other_indices = np.setdiff1d(np.arange(x.shape[-1]), indices) 52 | y = x[..., indices].view(x.shape[0],-1,3) 53 | distance_matrix = torch.cdist(y,y) 54 | mask = ~torch.tril(torch.ones_like(distance_matrix)).bool() 55 | 56 | distances = distance_matrix[mask].view(x.shape[0], -1) 57 | x = torch.cat([x[..., other_indices], distances], dim=-1) 58 | return self.net.forward(x) 59 | -------------------------------------------------------------------------------- /bgflow/nn/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainers import * 2 | -------------------------------------------------------------------------------- /bgflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
currentmodule: bgflow.utils 3 | 4 | =============================================================================== 5 | Geometry Utilities 6 | =============================================================================== 7 | 8 | .. autosummary:: 9 | :toctree: generated/ 10 | :template: class.rst 11 | 12 | distance_vectors 13 | distances_from_vectors 14 | remove_mean 15 | compute_distances 16 | compute_gammas 17 | kernelize_with_rbf 18 | RbfEncoder 19 | rbf_kernels 20 | 21 | =============================================================================== 22 | Jacobian Computation 23 | =============================================================================== 24 | 25 | .. autosummary:: 26 | :toctree: generated/ 27 | :template: class.rst 28 | 29 | brute_force_jacobian_trace 30 | brute_force_jacobian 31 | "batch_jacobian 32 | get_jacobian 33 | requires_grad 34 | 35 | =============================================================================== 36 | Types 37 | =============================================================================== 38 | 39 | .. autosummary:: 40 | :toctree: generated/ 41 | :template: class.rst 42 | 43 | is_list_or_tuple 44 | assert_numpy 45 | as_numpy 46 | 47 | =============================================================================== 48 | Free Energy Estimation 49 | =============================================================================== 50 | 51 | .. autosummary:: 52 | :toctree: generated/ 53 | :template: class.rst 54 | 55 | bennett_acceptance_ratio 56 | 57 | =============================================================================== 58 | Training utilities 59 | =============================================================================== 60 | 61 | .. 
autosummary:: 62 | :toctree: generated/ 63 | :template: class.rst 64 | 65 | IndexBatchIterator 66 | LossReporter 67 | 68 | 69 | =============================================================================== 70 | Training utilities 71 | =============================================================================== 72 | 73 | .. autosummary:: 74 | :toctree: generated/ 75 | :template: class.rst 76 | 77 | ClipGradient 78 | 79 | """ 80 | 81 | from .train import IndexBatchIterator, LossReporter, ClipGradient 82 | from .shape import tile 83 | from .types import * 84 | from .autograd import * 85 | 86 | from .geometry import ( 87 | distance_vectors, 88 | distances_from_vectors, 89 | remove_mean, 90 | compute_distances 91 | ) 92 | from .rbf_kernels import ( 93 | kernelize_with_rbf, 94 | compute_gammas, 95 | RbfEncoder, 96 | rbf_kernels 97 | ) 98 | 99 | from .free_energy import * 100 | -------------------------------------------------------------------------------- /bgflow/utils/openmm.py: -------------------------------------------------------------------------------- 1 | __author__ = 'solsson' 2 | 3 | import numpy as np 4 | 5 | 6 | def save_latent_samples_as_trajectory(samples, mdtraj_topology, filename=None, topology_fn=None, return_openmm_traj=True): 7 | """ 8 | Save Boltzmann Generator samples as a molecular dynamics trajectory. 9 | `samples`: posterior (Nsamples, n_atoms*n_dim) 10 | `mdtraj_topology`: an MDTraj Topology object of the molecular system 11 | `filename=None`: output filename with extension (all MDTraj compatible formats) 12 | `topology_fn=None`: outputs a PDB-file of the molecular topology for external visualization and analysis. 
13 | """ 14 | import mdtraj as md 15 | trajectory = md.Trajectory(samples.reshape(-1, mdtraj_topology.n_atoms, 3), mdtraj_topology) 16 | if isinstance(topology_fn, str): 17 | trajectory[0].save_pdb(topology_fn) 18 | if isinstance(filename, str): 19 | trajectory.save(filename) 20 | if return_openmm_traj: 21 | return trajectory 22 | 23 | 24 | class NumpyReporter(object): 25 | def __init__(self, reportInterval, enforcePeriodicBox=True): 26 | self._coords = [] 27 | self._reportInterval = reportInterval 28 | self.enforcePeriodicBox = enforcePeriodicBox 29 | 30 | def describeNextReport(self, simulation): 31 | steps = self._reportInterval - simulation.currentStep%self._reportInterval 32 | return (steps, True, False, False, False, self.enforcePeriodicBox) 33 | 34 | def report(self, simulation, state): 35 | self._coords.append(state.getPositions(asNumpy=True).ravel()) 36 | 37 | def get_coordinates(self, superimpose=None): 38 | """ 39 | return saved coordinates as numpy array 40 | `superimpose`: openmm/mdtraj topology, will superimpose on first frame 41 | """ 42 | import mdtraj as md 43 | try: 44 | from openmm.app.topology import Topology as _Topology 45 | except ImportError: # fall back to older version < 7.6 46 | from simtk.openmm.app.topology import Topology as _Topology 47 | if superimpose is None: 48 | return np.array(self._coords) 49 | elif isinstance(superimpose, _Topology): 50 | trajectory = md.Trajectory(np.array(self._coords).reshape(-1, superimpose.getNumAtoms(), 3), 51 | md.Topology().from_openmm(superimpose)) 52 | else: 53 | trajectory = md.Trajectory(np.array(self._coords).reshape(-1, superimpose.n_atoms, 3), 54 | superimpose) 55 | 56 | trajectory.superpose(trajectory[0]) 57 | return trajectory.xyz.reshape(-1, superimpose.n_atoms * 3) 58 | -------------------------------------------------------------------------------- /bgflow/utils/shape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 
3 | 4 | 5 | def tile(a, dim, n_tile): 6 | """ 7 | Tiles a pytorch tensor along one an arbitrary dimension. 8 | 9 | Parameters 10 | ---------- 11 | a : PyTorch tensor 12 | the tensor which is to be tiled 13 | dim : Integer 14 | dimension along the tensor is tiled 15 | n_tile : Integer 16 | number of tiles 17 | 18 | Returns 19 | ------- 20 | b : PyTorch tensor 21 | the tensor with dimension `dim` tiled `n_tile` times 22 | """ 23 | init_dim = a.size(dim) 24 | repeat_idx = [1] * a.dim() 25 | repeat_idx[dim] = n_tile 26 | a = a.repeat(*(repeat_idx)) 27 | order_index = np.concatenate( 28 | [init_dim * np.arange(n_tile) + i for i in range(init_dim)] 29 | ) 30 | order_index = torch.LongTensor(order_index).to(a).long() 31 | return torch.index_select(a, dim, order_index) 32 | -------------------------------------------------------------------------------- /bgflow/utils/tensorops.py: -------------------------------------------------------------------------------- 1 | def log_dot_exp(logA, logB): 2 | """Fast and stable matrix log multiplication""" 3 | maxA = logA.max(dim=-1, keepdim=True).values 4 | maxB = logB.max(dim=-2, keepdim=True).values 5 | A = (logA - maxA).exp() 6 | B = (logB - maxB).exp() 7 | batch_shape = A.shape[:-2] 8 | A = A.view(-1, A.shape[-2], A.shape[-1]) 9 | B = B.view(-1, B.shape[-2], B.shape[-1]) 10 | logC = A.bmm(B).log().view(*batch_shape, A.shape[-2], B.shape[-1]) 11 | logC.add_(maxA + maxB) 12 | return logC 13 | -------------------------------------------------------------------------------- /bgflow/utils/types.py: -------------------------------------------------------------------------------- 1 | 2 | from collections.abc import Iterable 3 | import numpy as np 4 | import torch 5 | 6 | __all__ = [ 7 | "is_list_or_tuple", "assert_numpy", "as_numpy", 8 | "unpack_tensor_tuple", "pack_tensor_in_tuple", 9 | "pack_tensor_in_list", 10 | ] 11 | 12 | 13 | def is_list_or_tuple(x): 14 | return isinstance(x, list) or isinstance(x, tuple) 15 | 16 | 17 | def 
assert_numpy(x, arr_type=None): 18 | if isinstance(x, torch.Tensor): 19 | if x.is_cuda: 20 | x = x.cpu() 21 | x = x.detach().numpy() 22 | if is_list_or_tuple(x): 23 | x = np.array(x) 24 | assert isinstance(x, np.ndarray) 25 | if arr_type is not None: 26 | x = x.astype(arr_type) 27 | return x 28 | 29 | 30 | def as_numpy(tensor): 31 | """convert tensor to numpy""" 32 | return torch.as_tensor(tensor).detach().cpu().numpy() 33 | 34 | 35 | def unpack_tensor_tuple(seq): 36 | """unpack a tuple containing one tensor to a tensor""" 37 | if isinstance(seq, torch.Tensor): 38 | return seq 39 | else: 40 | if len(seq) == 1: 41 | return seq[0] 42 | else: 43 | return (*seq, ) 44 | 45 | 46 | def pack_tensor_in_tuple(seq): 47 | """pack a tensor into a tuple of Tensor of length 1""" 48 | if isinstance(seq, torch.Tensor): 49 | return seq, 50 | elif isinstance(seq, Iterable): 51 | return (*seq, ) 52 | else: 53 | return seq 54 | 55 | 56 | def pack_tensor_in_list(seq): 57 | """pack a tensor into a list of Tensor of length 1""" 58 | if isinstance(seq, torch.Tensor): 59 | return [seq] 60 | elif isinstance(seq, Iterable): 61 | return list(seq) 62 | else: 63 | return seq 64 | -------------------------------------------------------------------------------- /devtools/conda-env.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | 6 | dependencies: 7 | - python 8 | - pip 9 | 10 | - pytest 11 | - numpy 12 | 13 | - openmm 14 | - xtb-python 15 | - ase 16 | - openmmtools 17 | - pytorch 18 | - jax 19 | - nequip 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = bgflow 8 | SOURCEDIR = . 
9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Compiling bgflow's Documentation 2 | 3 | The docs for this project are built with [Sphinx](http://www.sphinx-doc.org/en/master/). 4 | To compile the docs, first ensure that Sphinx and corresponding packages are installed. 5 | 6 | 7 | ```bash 8 | pip install sphinx sphinx_rtd_theme sphinx_gallery sphinx_nbexamples sphinxcontrib-katex sphinxcontrib-bibtex 9 | ``` 10 | 11 | Once installed, you can use the `Makefile` in this directory to compile static HTML pages by running 12 | ```bash 13 | make html 14 | ``` 15 | 16 | The compiled docs will be in the `docs/_build/html/` directory and can be viewed by opening `index.html`. 17 | 18 | The documentation rst-files can be found in `docs` and `docs/api`. To include a class/function in the documentation, it 19 | needs to be referenced in a corresponding `__init__.py`. That way the documentation is directly where the code is. For 20 | an example see `nn/flow/__init__.py`. If the `__init__.py` does not have any docstrings so far, it most likely needs to 21 | be referenced in a rst-file. The rst-files need to look like the ones in `docs/api` and must also be included 22 | in `docs/index.rst`. 23 | 24 | Notebooks can also be part of the documentation. To include them they need to be in `examples/nb_examples` 25 | and named `example_{name}.ipynb`.
Explanatory Markdown placed in separate cells works especially well. 26 | 27 | 28 | A configuration file for [Read The Docs](https://readthedocs.org/) (readthedocs.yaml) is included in the top level of 29 | the repository. To use Read the Docs to host your documentation, go to https://readthedocs.org/ 30 | and connect this repository. You may need to change your default branch to `main` under Advanced Settings for the 31 | project. 32 | 33 | If you would like to use Read The Docs with `autodoc` (included automatically) 34 | and your package has dependencies, you will need to include those dependencies in your documentation yaml 35 | file (`docs/requirements.yaml`). 36 | 37 | -------------------------------------------------------------------------------- /docs/_static/README.md: -------------------------------------------------------------------------------- 1 | # Static Doc Directory 2 | 3 | Add any paths that contain custom static files (such as style sheets) here, 4 | relative to the `conf.py` file's directory. 5 | They are copied after the builtin static files, 6 | so a file named "default.css" will overwrite the builtin "default.css". 7 | 8 | The path to this folder is set in the Sphinx `conf.py` file in the line: 9 | ```python 10 | html_static_path = ['_static'] 11 | ``` 12 | 13 | ## Examples of files to add to this directory 14 | * Custom Cascading Style Sheets 15 | * Custom JavaScript code 16 | * Static logo images 17 | -------------------------------------------------------------------------------- /docs/_templates/README.md: -------------------------------------------------------------------------------- 1 | # Templates Doc Directory 2 | 3 | Add any paths that contain templates here, relative to 4 | the `conf.py` file's directory. 5 | They are copied after the builtin template files, 6 | so a file named "page.html" will overwrite the builtin "page.html".
7 | 8 | The path to this folder is set in the Sphinx `conf.py` file in the line: 9 | ```python 10 | templates_path = ['_templates'] 11 | ``` 12 | 13 | ## Examples of files to add to this directory 14 | * HTML extensions of stock pages like `page.html` or `layout.html` 15 | -------------------------------------------------------------------------------- /docs/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {% set escapedname = objname|escape %} 2 | {% set title = "*" ~ objtype ~ "* " ~ escapedname %} 3 | {{ title | underline }} 4 | 5 | .. currentmodule:: {{ module }} 6 | 7 | .. auto{{objtype}}:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | .. rubric:: {{ _('Attributes') }} 12 | 13 | .. autosummary:: 14 | {% for item in attributes %} 15 | ~{{ name }}.{{ item }} 16 | {%- endfor %} 17 | {% endif %} 18 | {% endblock %} 19 | 20 | {% block methods %} 21 | 22 | {% if methods %} 23 | .. rubric:: {{ _('Methods') }} 24 | 25 | .. autosummary:: 26 | {% for item in methods %} 27 | ~{{ name }}.{{ item }} 28 | {%- endfor %} 29 | {% endif %} 30 | {% endblock %} 31 | -------------------------------------------------------------------------------- /docs/api/bg.rst: -------------------------------------------------------------------------------- 1 | Boltzmann Generators 2 | ==================== 3 | 4 | 5 | .. currentmodule:: bgflow 6 | .. autosummary:: 7 | :toctree: generated/ 8 | :template: class.rst 9 | 10 | bg.BoltzmannGenerator 11 | 12 | 13 | -------------------------------------------------------------------------------- /docs/api/energies.rst: -------------------------------------------------------------------------------- 1 | Energies 2 | ========= 3 | 4 | 5 | .. automodule:: bgflow.distribution.energy 6 | 7 | ..
toctree:: 8 | :maxdepth: 1 9 | 10 | -------------------------------------------------------------------------------- /docs/api/flows.rst: -------------------------------------------------------------------------------- 1 | Flows 2 | ======= 3 | 4 | 5 | .. automodule:: bgflow.nn.flow 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | -------------------------------------------------------------------------------- /docs/api/samplers.rst: -------------------------------------------------------------------------------- 1 | Sampling 2 | ======== 3 | 4 | 5 | .. automodule:: bgflow.distribution 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | -------------------------------------------------------------------------------- /docs/api/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ======= 3 | 4 | 5 | .. automodule:: bgflow.utils 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ========= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | examples/index.rst 8 | nb_examples/index.rst 9 | datasets/index.rst 10 | -------------------------------------------------------------------------------- /docs/getting_started.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | Coming soon ... 5 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Bgflow 2 | ========================================================= 3 | Bgflow is a pytorch framework for Boltzmann Generators :footcite:`noe2019boltzmann` and other sampling methods. 4 | 5 | 6 | .. 
toctree:: 7 | :maxdepth: 1 8 | 9 | installation.rst 10 | getting_started.rst 11 | requirements.rst 12 | examples.rst 13 | 14 | .. toctree:: 15 | :maxdepth: 1 16 | :caption: API docs: 17 | 18 | api/bg.rst 19 | api/flows.rst 20 | api/samplers.rst 21 | api/energies.rst 22 | api/utils.rst 23 | 24 | References 25 | ---------- 26 | .. footbibliography:: 27 | 28 | 29 | Indices and tables 30 | ================== 31 | 32 | * :ref:`genindex` 33 | * :ref:`modindex` 34 | * :ref:`search` 35 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | To come via pip 5 | 6 | .. code-block:: console 7 | 8 | pip install bgtorch 9 | 10 | or conda-forge 11 | 12 | .. code-block:: console 13 | 14 | conda install -c conda-forge bgtorch -------------------------------------------------------------------------------- /docs/literature.bib: -------------------------------------------------------------------------------- 1 | % Encoding: UTF-8 2 | 3 | @article{Khler2020EquivariantFE, 4 | title={Equivariant Flows: exact likelihood generative learning for symmetric densities}, 5 | author={Jonas K{\"o}hler and Leon Klein and F. 
No{\'e}}, 6 | journal={ArXiv}, 7 | year={2020}, 8 | volume={abs/2006.02425} 9 | } 10 | 11 | @article{noe2019boltzmann, 12 | title={Boltzmann generators-sampling equilibrium states of many-body systems with deep learning}, 13 | author={No{\'e}, Frank and Olsson, Simon and K{\"o}hler, Jonas and Wu, Hao}, 14 | journal={Science}, 15 | volume={365}, 16 | number={6457}, 17 | pages={eaaw1147}, 18 | year={2019} 19 | } 20 | 21 | @inproceedings{chen2018neural, 22 | title={Neural ordinary differential equations}, 23 | author={Chen, Tian Qi and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David K}, 24 | booktitle={Advances in Neural Information Processing Systems}, 25 | pages={6571--6583}, 26 | year={2018} 27 | } 28 | 29 | @article{gholami2019anode, 30 | title={ANODE: Unconditionally Accurate Memory-Efficient Gradients for Neural ODEs}, 31 | author={Gholami, Amir and Keutzer, Kurt and Biros, George}, 32 | journal={arXiv preprint arXiv:1902.10298}, 33 | year={2019} 34 | } 35 | 36 | @article{garcia2021n, 37 | title={{E(n)} Equivariant Normalizing Flows}, 38 | author={Garcia Satorras, Victor and Hoogeboom, Emiel and Fuchs, Fabian and Posner, Ingmar and Welling, Max}, 39 | journal={Advances in Neural Information Processing Systems}, 40 | volume={34}, 41 | year={2021} 42 | } 43 | 44 | @article{satorras2021n, 45 | title={{E(n)} equivariant graph neural networks}, 46 | author={Satorras, Victor Garcia and Hoogeboom, Emiel and Welling, Max}, 47 | journal={arXiv preprint arXiv:2102.09844}, 48 | year={2021} 49 | } -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=.
11 | set BUILDDIR=_build 12 | set SPHINXPROJ=bgflow 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.https://www.sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.rst: -------------------------------------------------------------------------------- 1 | Requirements 2 | ============ 3 | 4 | Mandatory 5 | ********** 6 | 7 | - `einops <https://github.com/arogozhnikov/einops>`_ 8 | - `pytorch <https://pytorch.org>`_ 9 | - `numpy <https://numpy.org>`_ 10 | - `matplotlib <https://matplotlib.org>`_ 11 | 12 | Optional 13 | ******** 14 | 15 | - `pytest <https://docs.pytest.org>`_ (for testing) 16 | - `nflows <https://github.com/bayesiains/nflows>`_ (for Neural Spline Flows) 17 | - `OpenMM <https://openmm.org>`_ (for molecular examples) 18 | - `torchdiffeq <https://github.com/rtqichen/torchdiffeq>`_ (for neural ODEs) 19 | - `ANODE <https://github.com/amirgholami/anode>`_ (for neural ODEs) 20 | -------------------------------------------------------------------------------- /docs/requirements.yaml: -------------------------------------------------------------------------------- 1 | name: docs 2 | channels: 3 | dependencies: 4 | # Base depends 5 | - python 6 | - pip 7 | - numpy 8 | - nflows 9 | - torchdiffeq 10 | - einops 11 | - torch 12 | # Doc depends 13 | - sphinx 14 | - sphinx-gallery 15 | - sphinx-nbexamples 16 | - sphinx-rtd-theme 17 | - sphinxcontrib-applehelp 18 | - sphinxcontrib-bibtex 19 | - 
sphinxcontrib-devhelp 20 | - sphinxcontrib-htmlhelp 21 | - sphinxcontrib-katex 22 | - sphinxcontrib-jsmath 23 | - sphinxcontrib-qthelp 24 | - 
sphinxcontrib-serializinghtml 25 | 26 | -------------------------------------------------------------------------------- /examples/datasets/README.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ================== 3 | 4 | Coming soon... -------------------------------------------------------------------------------- /examples/general_examples/README.rst: -------------------------------------------------------------------------------- 1 | General Examples 2 | ================== 3 | 4 | Below is a gallery of examples -------------------------------------------------------------------------------- /examples/general_examples/plot_simple_bg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal Boltzmann Generator Example 3 | ==================================== 4 | 5 | In this example a simple Boltzmann Generator is created using coupling layers. 6 | """ 7 | 8 | import torch 9 | import matplotlib.pyplot as plt 10 | import bgflow as bg 11 | 12 | # define prior and target 13 | dim = 2 14 | prior = bg.NormalDistribution(dim) 15 | target = bg.DoubleWellEnergy(dim) 16 | 17 | # here we aggregate all layers of the flow 18 | layers = [] 19 | layers.append(bg.SplitFlow(dim // 2)) 20 | layers.append(bg.CouplingFlow( 21 | # we use a affine transformation to transform 22 | # the RHS conditioned on the LHS 23 | bg.AffineTransformer( 24 | # use simple dense nets for the affine shift/scale 25 | shift_transformation=bg.DenseNet( 26 | [dim // 2, 4, dim // 2], 27 | activation=torch.nn.ReLU() 28 | ), 29 | scale_transformation=bg.DenseNet( 30 | [dim // 2, 4, dim // 2], 31 | activation=torch.nn.Tanh() 32 | ) 33 | ) 34 | )) 35 | layers.append(bg.InverseFlow(bg.SplitFlow(dim // 2))) 36 | 37 | # now define the flow as a sequence of all operations stored in layers 38 | flow = bg.SequentialFlow(layers) 39 | 40 | # The BG is defined by a prior, target and a flow 41 | generator = 
bg.BoltzmannGenerator(prior, flow, target) 42 | 43 | # sample from the BG 44 | samples = generator.sample(1000) 45 | _ = plt.hist2d( 46 | samples[:, 0].detach().numpy(), 47 | samples[:, 1].detach().numpy(), bins=100 48 | ) 49 | -------------------------------------------------------------------------------- /examples/nb_examples/README.rst: -------------------------------------------------------------------------------- 1 | Example Notebooks 2 | ================== 3 | 4 | Below is a gallery of example jupyter notebooks 5 | They are only rendered if they are named example_{name}.ipynb -------------------------------------------------------------------------------- /notebooks/alanine_dipeptide_augmented.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "basic-river", 6 | "metadata": {}, 7 | "source": [ 8 | "# Alanine Dipeptide with Augmented Normalizing Flows\n", 9 | "\n", 10 | "This notebook introduces augmented normalizing flows. \n", 11 | "\n", 12 | "At first, let us import everything from the basics tutorial so that we can focus on the newly introduced concepts." 
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "fuzzy-petite", 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Using downloaded and verified file: /tmp/alanine-dipeptide-nowater.pdb\n" 26 | ] 27 | }, 28 | { 29 | "data": { 30 | "application/vnd.jupyter.widget-view+json": { 31 | "model_id": "51cb1200f7e64d5c83575f8c5e60a3be", 32 | "version_major": 2, 33 | "version_minor": 0 34 | }, 35 | "text/plain": [ 36 | "_ColormakerRegistry()" 37 | ] 38 | }, 39 | "metadata": {}, 40 | "output_type": "display_data" 41 | } 42 | ], 43 | "source": [ 44 | "import alanine_dipeptide_basics as basic" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "increased-serbia", 50 | "metadata": {}, 51 | "source": [ 52 | "## Augmented Prior" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "id": "vertical-whale", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "OpenMMEnergy()" 65 | ] 66 | }, 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "basic.target_energy" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "distant-banks", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "basic.coordinate_transform" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "innocent-envelope", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "basic.dim_ics" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "correct-tribe", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python [conda env:ml] *", 108 | "language": "python", 109 | "name": "conda-env-ml-py" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | 
"name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.7.6" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | # readthedocs.yml 2 | 3 | version: 2 4 | 5 | build: 6 | image: latest 7 | 8 | python: 9 | version: 3.8 10 | install: 11 | - method: pip 12 | path: . 13 | 14 | conda: 15 | environment: docs/requirements.yaml -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | 2 | [yapf] 3 | # YAPF, in .style.yapf files this shows up as "[style]" header 4 | COLUMN_LIMIT = 119 5 | INDENT_WIDTH = 4 6 | USE_TABS = False 7 | 8 | 9 | [versioneer] 10 | # Automatic version numbering scheme 11 | VCS = git 12 | style = pep440 13 | versionfile_source = bgflow/_version.py 14 | versionfile_build = bgflow/_version.py 15 | tag_prefix = '' 16 | 17 | [aliases] 18 | test = pytest -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | bgflow 3 | Boltzmann Generators and Normalizing Flows in PyTorch 4 | """ 5 | import sys 6 | from setuptools import setup, find_packages 7 | import versioneer 8 | 9 | short_description = "Boltzmann Generators and Normalizing Flows in PyTorch".split("\n")[0] 10 | 11 | # from https://github.com/pytest-dev/pytest-runner#conditional-requirement 12 | needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) 13 | pytest_runner = ['pytest-runner'] if needs_pytest else [] 14 | 15 | try: 16 | with open("README.md", "r") as handle: 17 | 
long_description = handle.read() 18 | except: 19 | long_description = None 20 | 21 | 22 | setup( 23 | # Self-descriptive entries which should always be present 24 | name='bgflow', 25 | author='Jonas Köhler, Andreas Krämer, Leon Klein, Manuel Dibak, Frank Noé', 26 | author_email='kraemer.research@gmail.com', 27 | description=short_description, 28 | long_description=long_description, 29 | long_description_content_type="text/markdown", 30 | version=versioneer.get_version(), 31 | cmdclass=versioneer.get_cmdclass(), 32 | license='MIT', 33 | 34 | # Which Python importable modules should be included when your package is installed 35 | # Handled automatically by setuptools. Use 'exclude' to prevent some specific 36 | # subpackage(s) from being added, if needed 37 | packages=find_packages(), 38 | 39 | # Optional include package data to ship with your package 40 | # Customize MANIFEST.in if the general case does not suit your needs 41 | # Comment out this line to prevent the files from being packaged with your software 42 | include_package_data=True, 43 | 44 | # Allows `setup.py test` to work correctly with pytest 45 | setup_requires=[] + pytest_runner, 46 | 47 | # Additional entries you may want simply uncomment the lines you want and fill in the data 48 | # url='http://www.my_package.com', # Website 49 | # install_requires=[], # Required packages, pulls from pip if needed; do not use for Conda deployment 50 | # platforms=['Linux', 51 | # 'Mac OS-X', 52 | # 'Unix', 53 | # 'Windows'], # Valid platforms your code works on, adjust to your flavor 54 | # python_requires=">=3.5", # Python version restrictions 55 | 56 | # Manual control if final package is compressible or not, set False to prevent the .egg from being made 57 | # zip_safe=False, 58 | 59 | ) 60 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from types import 
SimpleNamespace 4 | import pytest 5 | import numpy as np 6 | import torch 7 | from bgflow import MixedCoordinateTransformation, OpenMMBridge, OpenMMEnergy, RelativeInternalCoordinateTransformation 8 | 9 | 10 | @pytest.fixture( 11 | params=[ 12 | "cpu", 13 | pytest.param( 14 | "cuda:0", 15 | marks=pytest.mark.skipif( 16 | not torch.cuda.is_available(), 17 | reason="CUDA not available." 18 | ) 19 | ) 20 | ] 21 | ) 22 | def device(request): 23 | """Run a test case for all available devices.""" 24 | return torch.device(request.param) 25 | 26 | 27 | @pytest.fixture(params=[torch.float32, torch.float64]) 28 | def dtype(request, device): 29 | """Run a test case in single and double precision.""" 30 | return request.param 31 | 32 | 33 | @pytest.fixture() 34 | def ctx(dtype, device): 35 | return {"dtype": dtype, "device": device} 36 | 37 | 38 | @pytest.fixture(params=[torch.enable_grad, torch.no_grad]) 39 | def with_grad_and_no_grad(request): 40 | """Run a test with and without torch grad enabled""" 41 | with request.param(): 42 | yield 43 | 44 | 45 | @pytest.fixture(scope="session") 46 | def ala2(): 47 | """Mock bgmol dataset.""" 48 | mm = pytest.importorskip("simtk.openmm") 49 | md = pytest.importorskip("mdtraj") 50 | pdb = mm.app.PDBFile(os.path.join(os.path.dirname(__file__), "data/alanine-dipeptide-nowater.pdb")) 51 | system = SimpleNamespace() 52 | system.topology = pdb.getTopology() 53 | system.mdtraj_topology = md.Topology.from_openmm(system.topology) 54 | system.system = mm.app.ForceField("amber99sbildn.xml").createSystem( 55 | pdb.getTopology(), 56 | removeCMMotion=True, 57 | nonbondedMethod=mm.app.NoCutoff, 58 | constraints=mm.app.HBonds, 59 | rigidWater=True 60 | ) 61 | system.energy_model = OpenMMEnergy( 62 | bridge=OpenMMBridge( 63 | system.system, 64 | mm.LangevinIntegrator(300, 1, 0.001), 65 | n_workers=1 66 | ) 67 | ) 68 | system.positions = pdb.getPositions() 69 | system.rigid_block = np.array([6, 8, 9, 10, 14]) 70 | system.z_matrix = np.array([ 71 | [0, 
1, 4, 6], 72 | [1, 4, 6, 8], 73 | [2, 1, 4, 0], 74 | [3, 1, 4, 0], 75 | [4, 6, 8, 14], 76 | [5, 4, 6, 8], 77 | [7, 6, 8, 4], 78 | [11, 10, 8, 6], 79 | [12, 10, 8, 11], 80 | [13, 10, 8, 11], 81 | [15, 14, 8, 16], 82 | [16, 14, 8, 6], 83 | [17, 16, 14, 15], 84 | [18, 16, 14, 8], 85 | [19, 18, 16, 14], 86 | [20, 18, 16, 19], 87 | [21, 18, 16, 19] 88 | ]) 89 | system.global_z_matrix = np.row_stack([ 90 | system.z_matrix, 91 | np.array([ 92 | [9, 8, 6, 14], 93 | [10, 8, 14, 6], 94 | [6, 8, 14, -1], 95 | [8, 14, -1, -1], 96 | [14, -1, -1, -1] 97 | ]) 98 | ]) 99 | dataset = SimpleNamespace() 100 | dataset.system = system 101 | # super-short simulation 102 | xyz = [] 103 | simulation = mm.app.Simulation(system.topology, system.system, mm.LangevinIntegrator(300,1,0.001)) 104 | simulation.context.setPositions(system.positions) 105 | for i in range(100): 106 | simulation.step(10) 107 | pos = simulation.context.getState(getPositions=True).getPositions(asNumpy=True) 108 | xyz.append(pos._value) 109 | dataset.xyz = np.stack(xyz, axis=0) 110 | return dataset 111 | 112 | 113 | @pytest.fixture() 114 | def crd_trafo(ala2, ctx): 115 | z_matrix = ala2.system.z_matrix 116 | fixed_atoms = ala2.system.rigid_block 117 | crd_transform = MixedCoordinateTransformation(torch.tensor(ala2.xyz, **ctx), z_matrix, fixed_atoms) 118 | return crd_transform 119 | 120 | 121 | @pytest.fixture() 122 | def crd_trafo_unwhitened(ala2, ctx): 123 | z_matrix = ala2.system.z_matrix 124 | fixed_atoms = ala2.system.rigid_block 125 | crd_transform = RelativeInternalCoordinateTransformation(z_matrix, fixed_atoms) 126 | return crd_transform 127 | 128 | -------------------------------------------------------------------------------- /tests/data/alanine-dipeptide-nowater.pdb: -------------------------------------------------------------------------------- 1 | CRYST1 27.222 27.222 27.222 90.00 90.00 90.00 P 1 1 2 | ATOM 1 HH31 ACE X 1 3.225 27.427 2.566 1.00 0.00 3 | ATOM 2 CH3 ACE X 1 3.720 26.570 2.110 1.00 0.00 4 | 
ATOM 3 HH32 ACE X 1 4.088 25.905 2.891 1.00 0.00 5 | ATOM 4 HH33 ACE X 1 4.557 26.914 1.502 1.00 0.00 6 | ATOM 5 C ACE X 1 2.770 25.800 1.230 1.00 0.00 7 | ATOM 6 O ACE X 1 1.600 26.150 1.090 1.00 0.00 8 | ATOM 7 N ALA X 2 3.270 24.640 0.690 1.00 0.00 9 | ATOM 8 H ALA X 2 4.259 24.471 0.810 1.00 0.00 10 | ATOM 9 CA ALA X 2 2.480 23.690 -0.190 1.00 0.00 11 | ATOM 10 HA ALA X 2 1.733 24.315 -0.679 1.00 0.00 12 | ATOM 11 CB ALA X 2 3.470 23.160 -1.270 1.00 0.00 13 | ATOM 12 HB1 ALA X 2 4.219 22.525 -0.797 1.00 0.00 14 | ATOM 13 HB2 ALA X 2 2.922 22.582 -2.014 1.00 0.00 15 | ATOM 14 HB3 ALA X 2 3.963 24.002 -1.756 1.00 0.00 16 | ATOM 15 C ALA X 2 1.730 22.590 0.490 1.00 0.00 17 | ATOM 16 O ALA X 2 2.340 21.880 1.280 1.00 0.00 18 | ATOM 17 N NME X 3 0.400 22.430 0.210 1.00 0.00 19 | ATOM 18 H NME X 3 -0.008 23.118 -0.407 1.00 0.00 20 | ATOM 19 CH3 NME X 3 -0.470 21.350 0.730 1.00 0.00 21 | ATOM 20 HH31 NME X 3 0.112 20.693 1.376 1.00 0.00 22 | ATOM 21 HH32 NME X 3 -1.290 21.786 1.300 1.00 0.00 23 | ATOM 22 HH33 NME X 3 -0.873 20.775 -0.103 1.00 0.00 24 | END 25 | -------------------------------------------------------------------------------- /tests/distribution/energy/test_ase.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | from bgflow import ASEBridge, ASEEnergy, XTBBridge, XTBEnergy 5 | 6 | 7 | try: 8 | import ase 9 | import xtb 10 | ase_and_xtb_imported = True 11 | except ImportError: 12 | ase_and_xtb_imported = False 13 | 14 | pytestmark = pytest.mark.skipif(not ase_and_xtb_imported, reason="Tests require ASE and XTB") 15 | 16 | 17 | def test_ase_energy(ctx): 18 | from ase.build import molecule 19 | from xtb.ase.calculator import XTB 20 | water = molecule("H2O") 21 | water.calc = XTB() 22 | target = ASEEnergy(ASEBridge(water, 300.)) 23 | pos = torch.tensor(0.1*water.positions, **ctx) 24 | e = target.energy(pos) 25 | f = target.force(pos) 26 | 27 | 28 | def test_ase_vs_xtb(ctx): 29 | # 
to make sure that unit conversion is the same, etc. 30 | from ase.build import molecule 31 | from xtb.ase.calculator import XTB 32 | water = molecule("H2O") 33 | water.calc = XTB() 34 | target1 = ASEEnergy(ASEBridge(water, 300.)) 35 | target2 = XTBEnergy(XTBBridge(water.numbers, 300.)) 36 | pos = torch.tensor(0.1 * water.positions[None, ...], **ctx) 37 | assert torch.allclose(target1.energy(pos), target2.energy(pos)) 38 | assert torch.allclose(target1.force(pos), target2.force(pos), atol=1e-6) 39 | 40 | -------------------------------------------------------------------------------- /tests/distribution/energy/test_clipped.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import warnings 4 | import torch 5 | from bgflow import Energy, LinLogCutEnergy, GradientClippedEnergy, DoubleWellEnergy 6 | from bgflow.utils import ClipGradient 7 | 8 | 9 | class StrongRepulsion(Energy): 10 | def __init__(self): 11 | super().__init__([2, 2]) 12 | 13 | def _energy(self, x): 14 | dist = torch.cdist(x, x) 15 | return (dist ** -12)[..., 0, 1][:, None] 16 | 17 | 18 | def test_linlogcut(ctx): 19 | lj = StrongRepulsion() 20 | llc = LinLogCutEnergy(lj, high_energy=1e3, max_energy=1e10) 21 | x = torch.tensor([ 22 | [[0., 0.], [0.0, 0.0]], # > max energy 23 | [[0., 0.], [0.0, 0.3]], # < max_energy, > high_energy 24 | [[0., 0.], [0.0, 1.]], # < high_energy 25 | ], **ctx) 26 | raw = lj.energy(x)[:, 0] 27 | cut = llc.energy(x)[:, 0] 28 | 29 | # first energy is clamped 30 | assert not (raw <= 1e10).all() 31 | assert (cut <= 1e10).all() 32 | assert cut[0].item() == pytest.approx(1e10, abs=1e-5) 33 | # second energy is softened, but force points in the right direction 34 | assert 1e3 < cut[1].item() < 1e10 35 | assert llc.force(x)[1][1, 1] > 0.0 36 | assert llc.force(x)[1][0, 1] < 0.0 37 | # third energy is unchanged 38 | assert torch.allclose(raw[2], cut[2], atol=1e-5) 39 | 40 | 41 | def openmm_example(grad_clipping, ctx): 42 | try: 
43 | with warnings.catch_warnings(): 44 | warnings.simplefilter( 45 | "ignore", DeprecationWarning 46 | ) # ignore warnings inside OpenMM 47 | from simtk import openmm, unit 48 | except ImportError: 49 | pytest.skip("Test requires OpenMM.") 50 | 51 | system = openmm.System() 52 | system.addParticle(1.) 53 | system.addParticle(2.) 54 | nonbonded = openmm.NonbondedForce() 55 | nonbonded.addParticle(0.0, 1.0, 2.0) 56 | nonbonded.addParticle(0.0, 1.0, 2.0) 57 | system.addForce(nonbonded) 58 | 59 | from bgflow import OpenMMEnergy, OpenMMBridge 60 | bridge = OpenMMBridge(system, openmm.LangevinIntegrator(300., 0.1, 0.001), n_workers=1) 61 | 62 | energy = OpenMMEnergy(bridge=bridge, two_event_dims=False) 63 | energy = GradientClippedEnergy(energy, grad_clipping).to(**ctx) 64 | positions = torch.tensor([[0.0, 0.0, 0.0, 0.1, 0.2, 0.6]]).to(**ctx) 65 | positions.requires_grad = True 66 | force = energy.force(positions) 67 | energy.energy(positions).sum().backward() 68 | #force = torch.tensor([[-1908.0890, -3816.1780, -11448.5342, 1908.0890, 3816.1780, 11448.5342]]).to(**ctx) 69 | return positions.grad, force 70 | 71 | 72 | def test_openmm_clip_by_value(ctx): 73 | grad_clipping = ClipGradient(clip=3000.0, norm_dim=1) 74 | grad, force = openmm_example(grad_clipping, ctx) 75 | expected = - torch.as_tensor([[-1908.0890, -3000., -3000., 1908.0890, 3000., 3000.]], **ctx) 76 | assert torch.allclose(grad.flatten(), expected, atol=1e-3) 77 | 78 | 79 | def test_openmm_clip_by_atom(ctx): 80 | grad_clipping = ClipGradient(clip=torch.as_tensor([3000.0, 1.0]), norm_dim=3) 81 | grad, force = openmm_example(grad_clipping, ctx) 82 | norm_ratio = torch.linalg.norm(grad[..., :3], dim=-1).item() 83 | assert norm_ratio == pytest.approx(3000.) 
84 | assert torch.allclose(grad[..., :3] / 3000., - grad[..., 3:], atol=1e-6) 85 | 86 | 87 | def test_openmm_clip_by_batch(ctx): 88 | grad_clipping = ClipGradient(clip=1.0, norm_dim=-1) 89 | grad, force = openmm_example(grad_clipping, ctx) 90 | ratio = force / grad 91 | assert torch.allclose(ratio, ratio[0, 0] * torch.ones_like(ratio)) 92 | assert torch.linalg.norm(grad).item() == pytest.approx(1.) 93 | 94 | 95 | def test_openmm_clip_no_grad(ctx): 96 | energy = GradientClippedEnergy( 97 | energy=DoubleWellEnergy(2), 98 | gradient_clipping=ClipGradient(clip=1.0, norm_dim=1) 99 | ) 100 | x = torch.randn(12,2).to(**ctx) 101 | x.requires_grad = False 102 | energy.energy(x) 103 | -------------------------------------------------------------------------------- /tests/distribution/energy/test_lennard_jones.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from bgflow.distribution.energy import LennardJonesPotential 5 | from bgflow.distribution.energy.lennard_jones import lennard_jones_energy_torch 6 | 7 | 8 | def test_lennard_jones_energy_torch(): 9 | energy_large = lennard_jones_energy_torch(torch.tensor(1e10), eps=1, rm=1) 10 | energy_min = lennard_jones_energy_torch(torch.tensor(5.), eps=3, rm=5) 11 | energy_zero = lennard_jones_energy_torch(torch.tensor(2 ** (-1 / 6)), eps=3, rm=1) 12 | assert torch.allclose(energy_large, torch.tensor(0.)) 13 | assert torch.allclose(energy_min, torch.tensor(-3.)) 14 | assert torch.allclose(energy_zero, torch.tensor(0.), atol=1e-5) 15 | 16 | 17 | @pytest.mark.parametrize("oscillator", [True, False]) 18 | @pytest.mark.parametrize("two_event_dims", [True, False]) 19 | def test_lennard_jones_potential(oscillator, two_event_dims): 20 | eps = 5. 
21 | 22 | # 2 particles in 3D 23 | lj_pot = LennardJonesPotential( 24 | dim=6, n_particles=2, eps=eps, rm=2.0, 25 | oscillator=oscillator, oscillator_scale=1., 26 | two_event_dims=two_event_dims 27 | ) 28 | 29 | batch_shape = (5, 7) 30 | data3d = torch.tensor([[[[-1., 0, 0], [1, 0, 0]]]]).repeat(*batch_shape, 1, 1) 31 | if not two_event_dims: 32 | data3d = data3d.view(*batch_shape, 6) 33 | energy3d = torch.tensor([[- eps]]).repeat(*batch_shape) 34 | if oscillator: 35 | energy3d += 1 36 | lj_energy_3d = lj_pot.energy(data3d) 37 | assert torch.allclose(energy3d[:, None], lj_energy_3d) 38 | 39 | # 3 particles in 2D 40 | lj_pot = LennardJonesPotential( 41 | dim=6, n_particles=3, eps=eps, rm=1.0, 42 | oscillator=oscillator, oscillator_scale=1., 43 | two_event_dims=two_event_dims 44 | ) 45 | h = np.sqrt(0.75) 46 | data2d = torch.tensor([[[0, 2 / 3 * h], [0.5, -1 / 3 * h], [-0.5, -1 / 3 * h]]], dtype=torch.float) 47 | if not two_event_dims: 48 | data2d = data2d.view(-1, 6) 49 | energy2d = torch.tensor([- 3 * eps]) 50 | if oscillator: 51 | energy2d += 0.5 * (data2d ** 2).sum() 52 | lj_energy_2d = lj_pot.energy(data2d) 53 | assert torch.allclose(energy2d, lj_energy_2d) 54 | lj_energy2d_np = lj_pot._energy_numpy(data2d) 55 | assert energy2d[:, None].numpy() == pytest.approx(lj_energy2d_np, abs=1e-6) 56 | -------------------------------------------------------------------------------- /tests/distribution/energy/test_multi_double_well_potential.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from bgflow.distribution import MultiDoubleWellPotential 4 | 5 | 6 | def test_multi_double_well_potential(ctx): 7 | target = MultiDoubleWellPotential(dim=4, n_particles=2, a=0.9, b=-4, c=0, offset=4) 8 | x = torch.tensor([[[2., 0, ], [-2, 0]]], **ctx) 9 | energy = target.energy(x) 10 | target.force(x) 11 | assert torch.allclose(energy, torch.tensor([[0.]], **ctx), atol=1e-5) 12 | 
-------------------------------------------------------------------------------- /tests/distribution/energy/test_xtb.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from bgflow import XTBEnergy, XTBBridge 5 | 6 | try: 7 | import xtb 8 | xtb_imported = True 9 | except ImportError: 10 | xtb_imported = False 11 | 12 | pytestmark = pytest.mark.skipif(not xtb_imported, reason="Test requires XTB") 13 | 14 | 15 | @pytest.mark.parametrize("pos_shape", [(1, 3, 3), (1, 9)]) 16 | def test_xtb_water(pos_shape, ctx): 17 | unit = pytest.importorskip("openmm.unit") 18 | temperature = 300 19 | numbers = np.array([8, 1, 1]) 20 | positions = torch.tensor([ 21 | [0.00000000000000, 0.00000000000000, -0.73578586109551], 22 | [1.44183152868459, 0.00000000000000, 0.36789293054775], 23 | [-1.44183152868459, 0.00000000000000, 0.36789293054775]], 24 | **ctx 25 | ) 26 | positions = (positions * unit.bohr).value_in_unit(unit.nanometer) # reference geometry is given in bohr; convert to nanometers 27 | target = XTBEnergy( 28 | XTBBridge(numbers=numbers, temperature=temperature), 29 | two_event_dims=(pos_shape == (1, 3, 3)) 30 | ) 31 | energy = target.energy(positions.reshape(pos_shape)) 32 | force = target.force(positions.reshape(pos_shape)) 33 | assert energy.shape == (1, 1) 34 | assert force.shape == pos_shape 35 | 36 | kbt = unit.BOLTZMANN_CONSTANT_kB * temperature * unit.kelvin 37 | expected_energy = torch.tensor(-5.070451354836705, **ctx) * unit.hartree / kbt 38 | expected_force = - torch.tensor([ 39 | [6.24500451e-17, - 3.47909735e-17, - 5.07156941e-03], 40 | [-1.24839222e-03, 2.43536791e-17, 2.53578470e-03], 41 | [1.24839222e-03, 1.04372944e-17, 2.53578470e-03], 42 | ], **ctx) * unit.hartree/unit.bohr/(kbt/unit.nanometer) 43 | assert torch.allclose(energy.flatten(), expected_energy.flatten(), atol=1e-5) 44 | assert torch.allclose(force.flatten(), expected_force.flatten(), atol=1e-5) 45 | 46 | 47 | def _eval_invalid(ctx, err_handling): # all-zero coordinates (every atom at the 
origin); the tests below expect xtb to fail on them 48 | pos = 
torch.zeros(1, 3, 3, **ctx) 49 | target = XTBEnergy( 50 | XTBBridge(numbers=np.array([8, 1, 1]), temperature=300, err_handling=err_handling) 51 | ) 52 | return target.energy(pos), target.force(pos) 53 | 54 | 55 | def test_xtb_error(ctx): 56 | from xtb.interface import XTBException 57 | with pytest.raises(XTBException): 58 | _eval_invalid(ctx, err_handling="error") 59 | 60 | 61 | def test_xtb_warning(ctx): 62 | with pytest.warns(UserWarning, match="Caught exception in xtb"): 63 | e, f = _eval_invalid(ctx, err_handling="warning") 64 | assert torch.isinf(e).all() 65 | assert torch.allclose(f, torch.zeros_like(f)) 66 | 67 | 68 | def test_xtb_ignore(ctx): 69 | e, f = _eval_invalid(ctx, err_handling="ignore") 70 | assert torch.isinf(e).all() 71 | assert torch.allclose(f, torch.zeros_like(f)) 72 | -------------------------------------------------------------------------------- /tests/distribution/sampling/test_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from bgflow import DataSetSampler, DataLoaderSampler 5 | 6 | 7 | def test_dataset_sampler(ctx): 8 | data = torch.arange(12).reshape(4,3).to(**ctx) 9 | sampler = DataSetSampler(data).to(**ctx) 10 | idxs = sampler._idxs.copy() 11 | # test sampling out of range 12 | assert torch.allclose(sampler.sample(3), data[idxs[:3]]) 13 | assert torch.allclose(sampler.sample(3)[0], data[idxs[3]]) 14 | for i in range(10): 15 | assert sampler.sample(3).shape == (3, 3) 16 | # test sampling larger #samples than len(data) 17 | sampler._current_index = 0 18 | samples = sampler.sample(12) 19 | assert samples.shape == (12, 3) 20 | # check that rewinding works 21 | for i in range(3): 22 | assert torch.allclose( 23 | data.flatten(), 24 | torch.sort(samples[4*i: 4*(i+1)].flatten())[0] 25 | ) 26 | 27 | 28 | def test_dataset_to_device_sampler(ctx): 29 | data = torch.arange(12).reshape(4,3) 30 | sampler = DataSetSampler(data).to(**ctx) 31 | assert sampler.sample(10).device 
== ctx["device"] 32 | 33 | 34 | def test_multiple_dataset_sampler(ctx): 35 | data = torch.arange(12).reshape(4,3).to(**ctx) 36 | data2 = torch.arange(8).reshape(4,2).to(**ctx) 37 | sampler = DataSetSampler(data, data2).to(**ctx) 38 | samples = sampler.sample(3) 39 | assert len(samples) == 2 40 | assert samples[0].shape == (3, 3) 41 | assert samples[1].shape == (3, 2) 42 | assert samples[0].device == ctx["device"] 43 | 44 | 45 | def test_resizing(ctx): 46 | data = torch.arange(12).reshape(4, 3).to(**ctx) 47 | sampler = DataSetSampler(data) 48 | sampler.resize_(5) 49 | assert len(sampler) == 5 50 | assert sampler.sample(2).shape == (2, 3) 51 | sampler.resize_(3) 52 | assert len(sampler) == 3 53 | assert sampler.sample(2).shape == (2, 3) 54 | 55 | 56 | def test_dataloader_sampler(ctx): 57 | loader = torch.utils.data.DataLoader( 58 | torch.utils.data.TensorDataset(torch.randn(10, 2, 2, **ctx)), 59 | batch_size=4, 60 | ) 61 | sampler = DataLoaderSampler(loader, **ctx) 62 | samples = sampler.sample(4) 63 | assert samples.shape == (4, 2, 2) 64 | assert samples.device == ctx["device"] 65 | 66 | -------------------------------------------------------------------------------- /tests/distribution/sampling/test_iterative.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from bgflow import IterativeSampler, SamplerState, SamplerStep 5 | from bgflow.distribution.sampling._iterative_helpers import AbstractSamplerState 6 | 7 | 8 | class AddOne(SamplerStep): # toy sampler step that adds 1.0 to every sample tensor 9 | def _step(self, state: AbstractSamplerState): 10 | statedict = state.as_dict() 11 | samples = tuple( 12 | x + 1.0 13 | for x in statedict["samples"] 14 | ) 15 | return state.replace(samples=samples) 16 | 17 | 18 | def test_iterative_sampler(ctx): 19 | state = SamplerState(samples=[torch.zeros(2, **ctx), ]) 20 | 21 | # test burnin 22 | sampler = IterativeSampler(state, sampler_steps=[AddOne()], n_burnin=10) 23 | assert 
torch.allclose(sampler.state.samples[0], 
10*torch.ones_like(sampler.state.samples[0])) 24 | 25 | # test sampling 26 | samples = sampler.sample(2) 27 | assert torch.allclose(samples, torch.tensor([[11., 11.], [12., 12.]], **ctx)) 28 | 29 | # test stride 30 | sampler.stride = 5 31 | samples = sampler.sample(2) 32 | assert sampler.i == 14 33 | assert torch.allclose(samples, torch.tensor([[17., 17.], [22., 22.]], **ctx)) 34 | 35 | # test iteration 36 | sampler.max_iterations = 15 37 | for batch in sampler: # only called once 38 | assert torch.allclose(batch.samples[0], torch.tensor([[27., 27.]], **ctx)) 39 | sampler.max_iterations = None 40 | 41 | -------------------------------------------------------------------------------- /tests/distribution/sampling/test_iterative_helpers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from bgflow.distribution.sampling._iterative_helpers import _map_to_primary_cell 4 | 5 | 6 | def test_map_to_primary_cell(): 7 | cell = torch.eye(3) 8 | x = torch.tensor([[1.2, -0.1, 4.5]]) # with the unit cell every coordinate is wrapped into [0, 1) 9 | assert torch.allclose(_map_to_primary_cell(x, cell), torch.tensor([[0.2, 0.9, 0.5]])) 10 | 11 | cell = torch.tensor( 12 | [ 13 | [1.0, 2.0, 0.0], 14 | [0.0, 2.0, 0.0], 15 | [0.0, 0.0, 1.0] 16 | ] 17 | ) 18 | assert torch.allclose(_map_to_primary_cell(x, cell), torch.tensor([[2.2, 1.9, 0.5]])) 19 | 20 | -------------------------------------------------------------------------------- /tests/distribution/sampling/test_mcmc.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | import torch 5 | from bgflow import ( 6 | GaussianProposal, SamplerState, NormalDistribution, 7 | IterativeSampler, MCMCStep, LatentProposal, BentIdentity, GaussianMCMCSampler 8 | ) 9 | from bgflow.distribution.sampling.mcmc import _GaussianMCMCSampler 10 | 11 | 12 | @pytest.mark.parametrize("proposal", [ 13 | GaussianProposal(noise_std=0.2), 14 | LatentProposal( 15 | 
flow=BentIdentity(), 16 | 
base_proposal=GaussianProposal(noise_std=0.4)) 17 | ]) 18 | @pytest.mark.parametrize("temperatures", [torch.ones(3), torch.arange(1, 10, 100)]) 19 | def test_mcmc(ctx, proposal, temperatures): 20 | """sample from a normal distribution with mu=3 and std=1,2,3 using MCMC""" 21 | try: 22 | import tqdm 23 | progress_bar = tqdm.tqdm 24 | except ImportError: 25 | progress_bar = lambda x: x 26 | target = NormalDistribution(4, 3.0*torch.ones(4, **ctx)) 27 | temperatures = temperatures.to(**ctx) 28 | # for testing efficiency we have a second batch dimension 29 | # this is not required; we could remove the first batch dimension (256) 30 | # and just sample longer 31 | state = SamplerState(samples=0.0*torch.ones(512, 3, 4, **ctx)) 32 | mcmc = IterativeSampler( 33 | sampler_state=state, 34 | sampler_steps=[ 35 | MCMCStep( 36 | target, 37 | proposal=proposal.to(**ctx), 38 | target_temperatures=temperatures 39 | ) 40 | ], 41 | stride=2, 42 | n_burnin=100, 43 | progress_bar=progress_bar 44 | ) 45 | samples = mcmc.sample(100) 46 | assert torch.allclose(samples.mean(dim=(0, 1, 3)), torch.tensor([3.0]*3, **ctx), atol=0.1) 47 | std = samples.std(dim=(0, 1, 3)) 48 | assert torch.allclose(std, temperatures.sqrt(), rtol=0.05, atol=0.0) 49 | 50 | 51 | def test_old_vs_new_mcmc(ctx): 52 | energy = NormalDistribution(dim=4) 53 | x0 = torch.randn(64, 4) 54 | 55 | def constraint(x): 56 | return torch.fmod(x,torch.ones(4)) 57 | 58 | with warnings.catch_warnings(): 59 | warnings.simplefilter("ignore", DeprecationWarning) 60 | old_mc = _GaussianMCMCSampler( 61 | energy, x0, n_stride=10, n_burnin=10, 62 | noise_std=0.3, box_constraint=constraint 63 | ) 64 | new_mc = GaussianMCMCSampler( 65 | energy, x0, stride=10, n_burnin=10, 66 | noise_std=0.3, box_constraint=constraint 67 | ) 68 | old_samples = old_mc.sample(1000) 69 | new_samples = new_mc.sample(100) 70 | assert old_samples.shape == (6400, 4) 71 | assert old_samples.shape == new_samples.shape 72 | assert old_samples.mean().item() == 
pytest.approx(new_samples.mean().item(), abs=1e-2) 73 | assert old_samples.std().item() == pytest.approx(new_samples.std().item(), abs=1e-1) 74 | -------------------------------------------------------------------------------- /tests/distribution/test_distribution.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | from torch.distributions import MultivariateNormal 5 | from bgflow.distribution import TorchDistribution, NormalDistribution, UniformDistribution, ProductDistribution 6 | 7 | 8 | def _random_mean_cov(dim, device, dtype): 9 | mean = 10*torch.randn(dim).to(device, dtype) 10 | # generate a symmetric, positive definite matrix 11 | cov = torch.triu(torch.rand(dim, dim)) 12 | cov = cov + cov.T + dim * torch.diag(torch.ones(dim)) 13 | cov = cov.to(device, dtype) 14 | return mean, cov 15 | 16 | 17 | @pytest.mark.parametrize("dim", (2, 10)) 18 | def test_distribution_energy(dim, device, dtype): 19 | """compare torch's normal distribution with bgflow's normal distribution""" 20 | n_samples = 7 21 | mean, cov = _random_mean_cov(dim, device, dtype) 22 | samples = torch.randn((n_samples, dim)).to(device, dtype) 23 | normal_trch = TorchDistribution(MultivariateNormal(loc=mean, covariance_matrix=cov)) 24 | normal_bgtrch = NormalDistribution(dim, mean, cov) 25 | assert torch.allclose(normal_trch.energy(samples), normal_bgtrch.energy(samples), rtol=2e-2, atol=1e-2) 26 | 27 | 28 | @pytest.mark.parametrize("dim", (2, 10)) 29 | @pytest.mark.parametrize("sample_shape", (50000, torch.Size([10,1]))) 30 | def test_distribution_samples(dim, sample_shape, device, dtype): 31 | """compare torch's normal distribution with bgflow's normal distribution""" 32 | mean, cov = _random_mean_cov(dim, device, dtype) 33 | normal_trch = TorchDistribution(MultivariateNormal(loc=mean, covariance_matrix=cov)) 34 | normal_bgtrch = NormalDistribution(dim, mean, cov) 35 | samples_trch = normal_trch.sample(sample_shape) 36 | 
target_shape = torch.Size([sample_shape]) if isinstance(sample_shape, int) else sample_shape 37 | assert samples_trch.size() == target_shape + torch.Size([dim]) 38 | if isinstance(sample_shape, int): 39 | samples_bgtrch = normal_bgtrch.sample(sample_shape) 40 | # to make sure that both sample from the same distribution, compute divergences 41 | for p in [normal_trch, normal_bgtrch]: 42 | for q in [normal_trch, normal_bgtrch]: 43 | for x in [samples_bgtrch, samples_trch]: 44 | for y in [samples_bgtrch, samples_trch]: 45 | div = torch.mean( 46 | (-p.energy(x) + q.energy(y)) 47 | ) 48 | assert torch.abs(div) < 5e-2 49 | 50 | 51 | def test_sample_uniform_with_temperature(ctx): 52 | uniform = UniformDistribution(low=torch.zeros(100, **ctx), high=torch.ones(100, **ctx)) 53 | assert uniform.sample(20).mean().item() == pytest.approx(0.5, abs=0.05) 54 | assert uniform.sample(20, temperature=100.).mean().item() == pytest.approx(0.5, abs=0.05) 55 | 56 | 57 | def test_sample_product_with_temperature(ctx): 58 | normal = NormalDistribution(dim=100, mean=torch.zeros(100, **ctx)) 59 | product = ProductDistribution([normal, normal]) 60 | x1, y1 = product.sample(20, temperature=1.) 61 | x2, y2 = product.sample(20, temperature=100.) 
62 | 63 | assert (x1.std() / x2.std()).item() == pytest.approx(0.1, abs=0.05) 64 | assert (y1.std() / y2.std()).item() == pytest.approx(0.1, abs=0.05) 65 | 66 | 67 | -------------------------------------------------------------------------------- /tests/distribution/test_product.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | from bgflow import NormalDistribution, ProductDistribution 5 | 6 | 7 | def test_multi_distribution(): 8 | """Test that a compound normal distribution behaves identical to a multivariate standard normal distribution.""" 9 | n1 = NormalDistribution(3) 10 | n2 = NormalDistribution(2) 11 | n3 = NormalDistribution(1) 12 | n = NormalDistribution(6) 13 | compound = ProductDistribution([n1, n2, n3], cat_dim=-1) 14 | n_samples = 10 15 | samples = compound.sample(n_samples) 16 | assert samples.shape == n.sample(n_samples).shape 17 | assert n.energy(samples).shape == compound.energy(samples).shape 18 | assert n.energy(samples).numpy() == pytest.approx(compound.energy(samples).numpy()) 19 | 20 | 21 | def test_multi_distribution_no_cat(): 22 | """Test that a compound normal distribution behaves identical to a multivariate standard normal distribution.""" 23 | n1 = NormalDistribution(3) 24 | n2 = NormalDistribution(2) 25 | n3 = NormalDistribution(1) 26 | n = NormalDistribution(6) 27 | compound = ProductDistribution([n1, n2, n3], cat_dim=None) 28 | n_samples = 10 29 | samples = compound.sample(n_samples) 30 | assert len(samples) == 3 31 | assert isinstance(samples, tuple) 32 | assert samples[0].shape == (10, 3) 33 | assert samples[1].shape == (10, 2) 34 | assert samples[2].shape == (10, 1) 35 | assert compound.energy(*samples).shape == (10, 1) 36 | assert n.energy(torch.cat(samples, dim=-1)).numpy() == pytest.approx(compound.energy(*samples).numpy()) 37 | 38 | 39 | def test_sample_to_cpu(ctx): 40 | cpu = torch.device("cpu") 41 | normal_distribution = 
NormalDistribution(dim=2).to(**ctx) 42 | samples = normal_distribution.sample_to_cpu(150, batch_size=17) 43 | assert samples.shape == (150, 2) 44 | assert samples.device == cpu 45 | 46 | product_distribution = ProductDistribution([NormalDistribution(2), NormalDistribution(3)], cat_dim=None) 47 | samples = product_distribution.sample_to_cpu(150, batch_size=17) 48 | assert len(samples) == 2 49 | assert samples[0].shape == (150, 2) 50 | assert samples[1].shape == (150, 3) 51 | assert samples[0].device == cpu 52 | assert samples[1].device == cpu 53 | -------------------------------------------------------------------------------- /tests/factory/test_distribution_factory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from bgflow import UniformDistribution, NormalDistribution, TruncatedNormalDistribution, make_distribution 4 | 5 | 6 | @pytest.mark.parametrize("prior_type", [UniformDistribution, NormalDistribution, TruncatedNormalDistribution]) 7 | def test_prior_factory(prior_type, ctx): 8 | prior = make_distribution(prior_type, 2, **ctx) 9 | samples = prior.sample(10) 10 | assert torch.device(samples.device) == torch.device(ctx["device"]) 11 | assert samples.dtype == ctx["dtype"] 12 | assert samples.shape == (10, 2) 13 | 14 | 15 | def test_prior_factory_with_kwargs(ctx): 16 | prior = make_distribution(UniformDistribution, 2, low=torch.tensor([2.0, 2.0]), high=torch.tensor([3.0, 3.0]), **ctx) 17 | samples = prior.sample(5) 18 | assert torch.device(samples.device) == torch.device(ctx["device"]) 19 | assert samples.dtype == ctx["dtype"] 20 | assert (samples > 1.0).all() 21 | -------------------------------------------------------------------------------- /tests/factory/test_icmarginals.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | 5 | from bgflow import ( 6 | GlobalInternalCoordinateTransformation, 7 | 
InternalCoordinateMarginals, BONDS, ANGLES, 8 | ShapeDictionary, TensorInfo 9 | ) 10 | 11 | 12 | @pytest.mark.parametrize("with_data", [True, False]) 13 | def test_icmarginals_inform_api(tmpdir, ctx, with_data): 14 | """API test""" 15 | bgmol = pytest.importorskip("bgmol") 16 | dataset = bgmol.datasets.Ala2Implicit1000Test( 17 | root=tmpdir, 18 | download=True, 19 | read=True 20 | ) 21 | coordinate_transform = GlobalInternalCoordinateTransformation( 22 | bgmol.systems.ala2.DEFAULT_GLOBAL_Z_MATRIX 23 | ) 24 | current_dims = ShapeDictionary() 25 | current_dims[BONDS] = (coordinate_transform.dim_bonds - dataset.system.system.getNumConstraints(), ) 26 | current_dims[ANGLES] = (coordinate_transform.dim_angles, ) 27 | marginals = InternalCoordinateMarginals(current_dims, ctx) 28 | if with_data: 29 | constrained_indices, _ = bgmol.bond_constraints(dataset.system.system, coordinate_transform) 30 | marginals.inform_with_data( 31 | torch.tensor(dataset.xyz, **ctx), coordinate_transform, 32 | constrained_bond_indices=constrained_indices 33 | ) 34 | else: 35 | marginals.inform_with_force_field( 36 | dataset.system.system, coordinate_transform, 1000., 37 | ) 38 | -------------------------------------------------------------------------------- /tests/factory/test_tensor_info.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from bgflow.factory.tensor_info import ShapeDictionary, TensorInfo, BONDS, ANGLES, TORSIONS, FIXED 4 | 5 | def test_shape_info(crd_trafo): 6 | shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo) 7 | 8 | for key in [BONDS, ANGLES, TORSIONS]: 9 | assert shape_info[key] == (len(crd_trafo.z_matrix), ) 10 | assert shape_info[FIXED][0] == 3*len(crd_trafo.fixed_atoms) 11 | assert not (shape_info.is_circular([BONDS, TORSIONS])[: shape_info[BONDS][0]]).any() 12 | assert (shape_info.is_circular([BONDS, TORSIONS])[shape_info[BONDS][0]:]).all() 13 | assert shape_info.is_circular().sum() == 
shape_info[TORSIONS][0] 14 | assert ( 15 | shape_info.circular_indices([FIXED, TORSIONS]) 16 | == np.arange(shape_info[FIXED][0], shape_info[FIXED][0]+shape_info[TORSIONS][0]) 17 | ).all() 18 | assert ( 19 | shape_info.circular_indices() 20 | == np.arange(shape_info[BONDS][0]+shape_info[ANGLES][0], 21 | shape_info[BONDS][0]+shape_info[ANGLES][0]+shape_info[TORSIONS][0] 22 | ) 23 | ).all() 24 | assert shape_info.dim_all([BONDS, TORSIONS]) == shape_info[BONDS][0] + shape_info[TORSIONS][0] 25 | assert shape_info.dim_all() == 66 26 | assert shape_info.dim_circular([ANGLES, BONDS]) == 0 27 | assert shape_info.dim_circular() == shape_info[TORSIONS][0] 28 | assert shape_info.dim_noncircular([ANGLES, BONDS]) == shape_info[ANGLES][0] + shape_info[BONDS][0] 29 | assert shape_info.dim_noncircular() == 66 - shape_info[TORSIONS][0] 30 | 31 | assert shape_info.dim_cartesian([ANGLES, BONDS]) == 0 32 | assert shape_info.dim_cartesian([FIXED]) == shape_info[FIXED][0] 33 | assert shape_info.dim_noncartesian([ANGLES, BONDS]) == shape_info[ANGLES][0] + shape_info[BONDS][0] 34 | assert shape_info.dim_noncartesian([FIXED]) == 0 35 | assert not (shape_info.is_cartesian([BONDS, FIXED])[: shape_info[BONDS][0]]).any() 36 | assert ( 37 | shape_info.cartesian_indices() 38 | == np.arange(shape_info[BONDS][0]+shape_info[ANGLES][0]+shape_info[TORSIONS][0], 39 | shape_info[BONDS][0]+shape_info[ANGLES][0]+shape_info[TORSIONS][0]+shape_info[FIXED][0] 40 | ) 41 | ).all() 42 | 43 | 44 | 45 | 46 | def test_shape_info_insert(): 47 | shape_info = ShapeDictionary() 48 | for i in range(5): 49 | shape_info[i] = (i, ) 50 | shape_info.insert(100, 2, (100, )) 51 | assert list(shape_info) == [0, 1, 100, 2, 3, 4] 52 | assert list(shape_info.values()) == [(i, ) for i in [0, 1, 100, 2, 3, 4]] 53 | 54 | 55 | def test_shape_info_split_merge(): 56 | shape_info = ShapeDictionary() 57 | for i in range(8): 58 | shape_info[i] = (i, ) 59 | shape_info.split(4, into=("a", "b"), sizes=(1, 3)) 60 | assert list(shape_info) == 
[0, 1, 2, 3, "a", "b", 5, 6, 7] 61 | assert list(shape_info.values()) == [(i, ) for i in [0, 1, 2, 3, 1, 3, 5, 6, 7]] 62 | 63 | shape_info.merge(("a", "b"), to=4) 64 | assert list(shape_info) == [0, 1, 2, 3, 4, 5, 6, 7] 65 | assert list(shape_info.values()) == [(i, ) for i in [0, 1, 2, 3, 4, 5, 6, 7]] 66 | -------------------------------------------------------------------------------- /tests/factory/test_transformer_factory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import bgflow 5 | from bgflow import ( 6 | make_transformer, make_conditioners, 7 | ShapeDictionary, BONDS, FIXED, ANGLES, TORSIONS, 8 | ConditionalSplineTransformer, AffineTransformer 9 | ) 10 | 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "transformer_type", 15 | [ 16 | ConditionalSplineTransformer, 17 | AffineTransformer, 18 | # TODO: MixtureCDFTransformer 19 | ] 20 | ) 21 | def test_transformers(crd_trafo, transformer_type): 22 | pytest.importorskip("nflows") 23 | 24 | shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo) 25 | conditioners = make_conditioners(transformer_type, (BONDS,), (FIXED,), shape_info) 26 | transformer = make_transformer(transformer_type, (BONDS,), shape_info, conditioners=conditioners) 27 | out = transformer.forward(torch.zeros(2, shape_info[FIXED][0]), torch.zeros(2, shape_info[BONDS][0])) 28 | assert out[0].shape == (2, shape_info[BONDS][0]) 29 | 30 | 31 | def test_circular_affine(crd_trafo): 32 | shape_info = ShapeDictionary.from_coordinate_transform(crd_trafo) 33 | 34 | with pytest.raises(ValueError): 35 | conditioners = make_conditioners( 36 | bgflow.AffineTransformer, 37 | (TORSIONS,), (FIXED,), shape_info=shape_info 38 | ) 39 | make_transformer(bgflow.AffineTransformer, (TORSIONS,), shape_info, conditioners=conditioners) 40 | 41 | conditioners = make_conditioners( 42 | bgflow.AffineTransformer, 43 | (TORSIONS,), (FIXED,), shape_info=shape_info, use_scaling=False 44 | ) 45 
| assert list(conditioners.keys()) == ["shift_transformation"] 46 | transformer = make_transformer(bgflow.AffineTransformer, (TORSIONS,), shape_info, conditioners=conditioners) 47 | assert transformer._is_circular 48 | out = transformer.forward(torch.zeros(2, shape_info[FIXED][0]), torch.zeros(2, shape_info[TORSIONS][0])) 49 | assert out[0].shape == (2, shape_info[TORSIONS][0]) 50 | -------------------------------------------------------------------------------- /tests/nn/flow/dynamics/test_kernel_dynamics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from bgflow.distribution import NormalDistribution 4 | from bgflow.nn.flow import DiffEqFlow 5 | from bgflow.nn.flow.dynamics import KernelDynamics 6 | from bgflow.utils import brute_force_jacobian_trace 7 | 8 | 9 | @pytest.mark.parametrize("n_particles", [2, 3]) 10 | @pytest.mark.parametrize("n_dimensions", [2, 3]) 11 | @pytest.mark.parametrize("use_checkpoints", [True, False]) 12 | def test_kernel_dynamics(n_particles, n_dimensions, use_checkpoints, device): 13 | # Build flow with kernel dynamics and run initial config. 
14 | 15 | dim = n_particles * n_dimensions 16 | n_samples = 100 17 | prior = NormalDistribution(dim).to(device) 18 | latent = prior.sample(n_samples) 19 | 20 | d_max = 8 21 | mus = torch.linspace(0, d_max, 10).to(device) 22 | gammas = 0.3 * torch.ones(len(mus)) 23 | 24 | mus_time = torch.linspace(0, 1, 5).to(device) 25 | gammas_time = 0.3 * torch.ones(len(mus_time)) 26 | 27 | kernel_dynamics = KernelDynamics(n_particles, n_dimensions, mus, gammas, optimize_d_gammas=True, 28 | optimize_t_gammas=True, mus_time=mus_time, gammas_time=gammas_time) 29 | 30 | flow = DiffEqFlow( 31 | dynamics=kernel_dynamics 32 | ).to(device) 33 | 34 | if not use_checkpoints: 35 | pytest.importorskip("torchdiffeq") 36 | 37 | samples, dlogp = flow(latent) 38 | latent2, ndlogp = flow.forward(samples, inverse=True) 39 | 40 | assert samples.shape == torch.Size([n_samples, dim]) 41 | assert dlogp.shape == torch.Size([n_samples, 1]) 42 | # assert (latent - latent2).abs().mean() < 0.002 43 | # assert (latent - samples).abs().mean() > 0.01 44 | # assert (dlogp + ndlogp).abs().mean() < 0.002 45 | 46 | if use_checkpoints: 47 | pytest.importorskip("anode") 48 | flow._use_checkpoints = True 49 | options = { 50 | "Nt": 20, 51 | "method": "RK4" 52 | } 53 | flow._kwargs = options 54 | 55 | samples, dlogp = flow(latent) 56 | latent2, ndlogp = flow.forward(samples, inverse=True) 57 | 58 | assert samples.shape == torch.Size([n_samples, dim]) 59 | assert dlogp.shape == torch.Size([n_samples, 1]) 60 | # assert (latent - latent2).abs().mean() < 0.002 61 | # assert (latent - samples).abs().mean() > 0.01 62 | # assert (dlogp + ndlogp).abs().mean() < 0.002 63 | 64 | 65 | @pytest.mark.parametrize("n_particles", [2, 3]) 66 | @pytest.mark.parametrize("n_dimensions", [2, 3]) 67 | def test_kernel_dynamics_trace(n_particles, n_dimensions): 68 | # Test if the trace computation of the kernel dynamics is correct. 
69 | 70 | d_max = 8 71 | mus = torch.linspace(0, d_max, 10) 72 | gammas = 0.3 * torch.ones(len(mus)) 73 | 74 | mus_time = torch.linspace(0, 1, 5) 75 | gammas_time = 0.3 * torch.ones(len(mus_time)) 76 | 77 | kernel_dynamics = KernelDynamics(n_particles, n_dimensions, mus, gammas, mus_time=mus_time, gammas_time=gammas_time) 78 | x = torch.Tensor(1, n_particles * n_dimensions).normal_().requires_grad_(True) 79 | y, trace = kernel_dynamics(1., x) 80 | brute_force_trace = brute_force_jacobian_trace(y, x) 81 | 82 | # The kernel dynamics outputs the negative trace 83 | assert torch.allclose(trace.sum(), -brute_force_trace[0], atol=1e-4) 84 | 85 | # test kernel dynamics without trace 86 | kernel_dynamics(1., x, compute_divergence=False) 87 | -------------------------------------------------------------------------------- /tests/nn/flow/estimators/test_hutchinson_estimator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from bgflow.distribution import NormalDistribution 4 | from bgflow.nn.flow.dynamics import TimeIndependentDynamics 5 | from bgflow.nn.flow.estimator import HutchinsonEstimator 6 | from bgflow.nn import DenseNet 7 | from bgflow.utils import brute_force_jacobian_trace 8 | 9 | 10 | @pytest.mark.parametrize("dim", [1, 2]) 11 | @pytest.mark.parametrize("rademacher", [True, False]) 12 | def test_hutchinson_estimator(dim, rademacher): 13 | # Test trace estimation of the hutchinson estimator for small dimensions, where it is less noisy 14 | n_batch = 1024 15 | time_independent_dynamics = TimeIndependentDynamics( 16 | DenseNet([dim, 16, 16, dim], activation=torch.nn.Tanh())) 17 | hutchinson_estimator = HutchinsonEstimator(rademacher) 18 | normal_distribution = NormalDistribution(dim) 19 | x = normal_distribution.sample(n_batch) 20 | y, trace = hutchinson_estimator(time_independent_dynamics, None, x) 21 | brute_force_trace = brute_force_jacobian_trace(y, x) 22 | if rademacher and dim == 1: 23 
| # Hutchinson is exact for rademacher noise and dim=1 24 | assert torch.allclose(trace.mean(), -brute_force_trace.mean(), atol=1e-6) 25 | else: 26 | assert torch.allclose(trace.mean(), -brute_force_trace.mean(), atol=1e-1) 27 | 28 | 29 | @pytest.mark.parametrize("rademacher", [True, False]) 30 | def test_test_hutchinson_estimator_reset_noise(rademacher): 31 | # Test if the noise vector is resetted to deal with different shape 32 | dim = 10 33 | time_independent_dynamics = TimeIndependentDynamics( 34 | DenseNet([dim, 16, 16, dim], activation=torch.nn.Tanh())) 35 | hutchinson_estimator = HutchinsonEstimator(rademacher) 36 | normal_distribution = NormalDistribution(dim) 37 | 38 | x = normal_distribution.sample(100) 39 | _, _ = hutchinson_estimator(time_independent_dynamics, None, x) 40 | x = normal_distribution.sample(10) 41 | hutchinson_estimator.reset_noise() 42 | # this will fail if the noise is not resetted 43 | _, _ = hutchinson_estimator(time_independent_dynamics, None, x) 44 | -------------------------------------------------------------------------------- /tests/nn/flow/test_cdf.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | from bgflow import ( 5 | DistributionTransferFlow, ConstrainGaussianFlow, 6 | CDFTransform, InverseFlow, TruncatedNormalDistribution 7 | ) 8 | from torch.distributions import Normal 9 | 10 | 11 | def test_distribution_transfer(ctx): 12 | src = Normal(torch.zeros(2, **ctx), torch.ones(2, **ctx)) 13 | target = Normal(torch.ones(2, **ctx), torch.ones(2, **ctx)) 14 | swap = DistributionTransferFlow(src, target) 15 | # forward 16 | out, dlogp = swap.forward(torch.zeros((2,2), **ctx)) 17 | assert torch.allclose(out, torch.ones(2,2, **ctx)) 18 | assert torch.allclose(dlogp, torch.zeros(2,1, **ctx)) 19 | # inverse 20 | out2, dlogp = swap.forward(out, inverse=True) 21 | assert torch.allclose(out2, torch.zeros(2,2, **ctx)) 22 | assert torch.allclose(dlogp, 
torch.zeros(2,1, **ctx)) 23 | 24 | 25 | def test_constrain_positivity(ctx): 26 | """Make sure that the bonds are obeyed.""" 27 | torch.manual_seed(1) 28 | constrain_flow = ConstrainGaussianFlow(mu=torch.ones(10, **ctx), lower_bound=1e-10) 29 | samples = (1.0+torch.randn((10,10), **ctx)) * 1000. 30 | y, dlogp = constrain_flow.forward(samples) 31 | assert y.shape == (10, 10) 32 | assert dlogp.shape == (10, 1) 33 | assert (y >= 0.0).all() 34 | assert (dlogp.sum() < 0.0).all() 35 | 36 | 37 | def test_constrain_slightly_pertubed(ctx): 38 | """Check that samples are not changed much when the bounds are generous.""" 39 | torch.manual_seed(1) 40 | constrain_flow = ConstrainGaussianFlow(mu=torch.ones(10, **ctx), sigma=torch.ones(10, **ctx), lower_bound=-1000., upper_bound=1000.) 41 | samples = (1.0+torch.randn((10,10), **ctx)) 42 | y, dlogp = constrain_flow.forward(samples) 43 | assert torch.allclose(samples, y, atol=1e-4, rtol=0.0) 44 | assert torch.allclose(dlogp, torch.zeros_like(dlogp), atol=1e-4, rtol=0.0) 45 | 46 | x2, dlogp = constrain_flow.forward(y, inverse=True) 47 | assert torch.allclose(x2, y, atol=1e-4, rtol=0.0) 48 | assert torch.allclose(dlogp, torch.zeros_like(dlogp), atol=1e-4, rtol=0.0) 49 | 50 | 51 | def test_cdf_transform(ctx): 52 | input = torch.arange(0.1, 1.0, 0.1, **ctx)[:,None] 53 | input.requires_grad = True 54 | truncated_normal = TruncatedNormalDistribution( 55 | mu=torch.tensor([0.5], **ctx), 56 | upper_bound=torch.tensor([1.0], **ctx), 57 | is_learnable=True 58 | ) 59 | flow = InverseFlow(CDFTransform((truncated_normal))) 60 | output, dlogp = flow.forward(input) 61 | assert output.mean().item() == pytest.approx(0.5) 62 | # try computing the grad 63 | output.mean().backward(create_graph=True) 64 | dlogp.mean().backward() 65 | 66 | 67 | -------------------------------------------------------------------------------- /tests/nn/flow/test_inverted.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test 
inverse and its derivative 3 | """ 4 | 5 | import torch 6 | import pytest 7 | from bgflow.nn import flow 8 | 9 | 10 | @pytest.fixture(params=[ 11 | flow.KroneckerProductFlow(2), 12 | flow.PseudoOrthogonalFlow(2), 13 | flow.BentIdentity(), 14 | flow.FunnelFlow(), 15 | flow.AffineFlow(2), 16 | flow.SplitFlow(1), 17 | flow.SplitFlow(1,1), 18 | flow.TriuFlow(2), 19 | flow.InvertiblePPPP(2), 20 | flow.TorchTransform(torch.distributions.IndependentTransform(torch.distributions.SigmoidTransform(), 1)) 21 | ]) 22 | def simpleflow2d(request): 23 | return request.param 24 | 25 | 26 | def test_inverse(simpleflow2d): 27 | """Test inverse and inverse logDet of simple 2d flow blocks.""" 28 | inverse = flow.InverseFlow(simpleflow2d) 29 | x = torch.tensor([[1., 2.]]) 30 | *y, dlogp = simpleflow2d._forward(x) 31 | x2, dlogpinv = inverse._forward(*y) 32 | assert (dlogp + dlogpinv).detach().numpy() == pytest.approx(0.0, abs=1e-6) 33 | assert torch.norm(x2 - x).item() == pytest.approx(0.0, abs=1e-6) 34 | 35 | # test dimensions 36 | assert x2.shape == x.shape 37 | assert dlogp.shape == x[..., 0, None].shape 38 | assert dlogpinv.shape == x[..., 0, None].shape 39 | 40 | -------------------------------------------------------------------------------- /tests/nn/flow/test_modulo.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | from bgflow import IncreaseMultiplicityFlow 5 | from bgflow import CircularShiftFlow 6 | 7 | 8 | @pytest.mark.parametrize('mult', [1, torch.ones(3, dtype=torch.int)]) 9 | def test_IncreaseMultiplicityFlow(ctx, mult): 10 | m = 3 11 | mult = m * mult 12 | flow = IncreaseMultiplicityFlow(mult).to(**ctx) 13 | 14 | x = 1/6 + torch.linspace(0, 1, m + 1)[:-1].to(**ctx) 15 | x = torch.tile(x[None, ...], (1000, 1)) 16 | y, dlogp = flow.forward(x, inverse=True) 17 | assert y.shape == x.shape 18 | assert torch.allclose(y, 1/2 * torch.ones_like(y)) 19 | 20 | x2, dlogp2 = flow.forward(y) 21 | for point 
in (1/6, 3/6, 5/6): 22 | count = torch.sum(torch.isclose(x2, point * torch.ones_like(x2))) 23 | assert count > 800 24 | assert count < 1200 25 | assert torch.allclose(dlogp, torch.zeros_like(dlogp)) 26 | assert torch.allclose(dlogp2, torch.zeros_like(dlogp)) 27 | 28 | @pytest.mark.parametrize('shift', [1, torch.ones(3)]) 29 | def test_CircularShiftFlow(ctx, shift): 30 | m = 0.2 31 | shift = m * shift 32 | flow = CircularShiftFlow(shift).to(**ctx) 33 | 34 | x = torch.tensor([0.0, 0.2, 0.9]).to(**ctx) 35 | x_shifted = torch.tensor([0.2, 0.4, 0.1]).to(**ctx) 36 | y, dlogp = flow.forward(x) 37 | assert y.shape == x.shape 38 | assert torch.allclose(y, x_shifted) 39 | 40 | x2, dlogp2 = flow.forward(y, inverse=True) 41 | assert torch.allclose(x, x2) 42 | assert torch.allclose(dlogp, torch.zeros_like(dlogp)) 43 | assert torch.allclose(dlogp2, torch.zeros_like(dlogp)) 44 | -------------------------------------------------------------------------------- /tests/nn/flow/test_nODE.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from bgflow.distribution import NormalDistribution 4 | from bgflow.nn.flow import DiffEqFlow 5 | from bgflow.nn.flow.dynamics import BlackBoxDynamics, TimeIndependentDynamics 6 | from bgflow.nn.flow.estimator import BruteForceEstimator 7 | 8 | dim = 1 9 | n_samples = 100 10 | prior = NormalDistribution(dim) 11 | latent = prior.sample(n_samples) 12 | 13 | 14 | class SimpleDynamics(torch.nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, xs): 19 | dxs = - 1 * xs 20 | return dxs 21 | 22 | 23 | def make_black_box_flow(): 24 | black_box_dynamics = BlackBoxDynamics( 25 | dynamics_function=TimeIndependentDynamics(SimpleDynamics()), 26 | divergence_estimator=BruteForceEstimator() 27 | ) 28 | 29 | flow = DiffEqFlow( 30 | dynamics=black_box_dynamics 31 | ) 32 | return flow 33 | 34 | 35 | def test_nODE_flow_OTD(): 36 | # Test forward pass of simple nODE 
with the OTD solver 37 | flow = make_black_box_flow() 38 | 39 | try: 40 | samples, dlogp = flow(latent) 41 | except ImportError: 42 | pytest.skip("Test requires torchdiffeq.") 43 | 44 | assert samples.std() < 1 45 | assert torch.allclose(dlogp, -torch.ones(n_samples)) 46 | 47 | # Test backward pass of simple nODE with the OTD solver 48 | try: 49 | samples, dlogp = flow(latent, inverse=True) 50 | except ImportError: 51 | pytest.skip("Test requires torchdiffeq.") 52 | 53 | assert samples.std() > 1 54 | assert torch.allclose(dlogp, torch.ones(n_samples)) 55 | 56 | 57 | def test_nODE_flow_DTO(): 58 | # Test forward pass of simple nODE with the DTO solver 59 | flow = make_black_box_flow() 60 | 61 | flow._use_checkpoints = True 62 | options = { 63 | "Nt": 20, 64 | "method": "RK4" 65 | } 66 | flow._kwargs = options 67 | 68 | try: 69 | samples, dlogp = flow(latent) 70 | except ImportError: 71 | pytest.skip("Test requires anode.") 72 | 73 | assert samples.std() < 1 74 | assert torch.allclose(dlogp, -torch.ones(n_samples)) 75 | 76 | # Test backward pass of simple nODE with the DTO solver 77 | try: 78 | samples, dlogp = flow(latent, inverse=True) 79 | except ImportError: 80 | pytest.skip("Test requires torchdiffeq.") 81 | 82 | assert samples.std() > 1 83 | assert torch.allclose(dlogp, torch.ones(n_samples)) 84 | -------------------------------------------------------------------------------- /tests/nn/flow/test_pppp.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | from bgflow.nn.flow.pppp import InvertiblePPPP, PPPPScheduler, _iterative_solve 5 | from bgflow.nn.flow.sequential import SequentialFlow 6 | 7 | 8 | def test_invertible_pppp(): 9 | flow = InvertiblePPPP(2, penalty_parameter=0.1) 10 | x = torch.tensor([[3.1, -2.0]]) 11 | with torch.no_grad(): 12 | flow.u[:] = torch.tensor([0.4, -0.3]) 13 | flow.v[:] = torch.tensor([0.1, 0.2]) 14 | y, logdet = flow.forward(x) 15 | 16 | x2, logdetinv = 
flow.forward(y, inverse=True) 17 | assert torch.isclose(x2, x, atol=1e-5).all() 18 | 19 | assert(flow.penalty() > -1) 20 | flow.pppp_merge() 21 | y2, logdet2 = flow.forward(x) 22 | x3, logdetinv2 = flow._inverse(y2, inverse=True) 23 | 24 | assert torch.isclose(torch.mm(flow.A, flow.Ainv), torch.eye(2)).all() 25 | assert torch.isclose(x3, x, atol=1e-5).all() 26 | assert torch.isclose(logdet, -logdetinv, atol=1e-5) 27 | assert torch.isclose(logdet, logdet2, atol=1e-5) 28 | assert torch.isclose(logdet, -logdetinv2, atol=1e-5) 29 | assert torch.isclose(y, y2, atol=1e-5).all() 30 | assert torch.isclose(flow.u, torch.zeros(1), atol=1e-5).all() 31 | 32 | # test training mode 33 | flow.train(False) 34 | y_test, logdet_test = flow.forward(x) 35 | assert torch.isclose(y, y_test, atol=1e-5).all() 36 | assert torch.isclose(logdet, logdet_test, atol=1e-5) 37 | x_test, logdetinv_test = flow.forward(y, inverse=True) 38 | assert torch.isclose(x, x_test, atol=1e-5).all() 39 | assert torch.isclose(-logdet, logdetinv_test, atol=1e-5) 40 | 41 | 42 | @pytest.mark.parametrize("mode", ["eye", "reverse"]) 43 | @pytest.mark.parametrize("dim", [1, 3, 5, 6, 7, 10, 20]) 44 | def test_initialization(mode, dim): 45 | flow = InvertiblePPPP(dim, init=mode) 46 | assert flow.A.inverse().numpy() == pytest.approx(flow.Ainv.numpy()) 47 | assert torch.det(flow.A).item() == pytest.approx(flow.detA.item()) 48 | 49 | 50 | @pytest.mark.parametrize("order", [2, 3, 7]) 51 | def test_iterative_solve(order): 52 | torch.seed = 0 53 | a = torch.eye(3)+0.05*torch.randn(3, 3) 54 | inv = torch.inverse(a) + 0.2*torch.randn(3,3) 55 | for i in range(10): 56 | inv = _iterative_solve(a, inv, order) 57 | assert torch.mm(inv, a).numpy() == pytest.approx(torch.eye(3).numpy(), abs=1e-5) 58 | 59 | 60 | @pytest.mark.parametrize("model", [InvertiblePPPP(3), SequentialFlow([InvertiblePPPP(3), InvertiblePPPP(3)])]) 61 | def test_scheduler(model): 62 | """API test for a full optimization workflow.""" 63 | optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-1) 64 | scheduler = PPPPScheduler(model, optimizer, n_force_merge=5, n_correct=5, n_recompute_det=5) 65 | torch.manual_seed(0) 66 | a = 1e-6*torch.eye(3) 67 | data = torch.ones(800, 3) 68 | target = torch.einsum("ij,...j->...i", a, data) 69 | loss = torch.nn.MSELoss() 70 | assert loss(data, target) > 1e-1 71 | for iter in range(100): 72 | optimizer.zero_grad() 73 | batch = data[8*iter:8*(iter+1)] 74 | y, _ = model.forward(batch) 75 | mse = loss(y, target[8*iter:8*(iter+1)]) 76 | mse_plus_penalty = mse + scheduler.penalty() 77 | mse_plus_penalty.backward() 78 | optimizer.step() 79 | if iter % 10 == 0: 80 | scheduler.step() 81 | assert scheduler.penalty().item() >= 0.0 82 | assert mse < 2e-2 # check that the model has improved 83 | assert scheduler.i == 10 84 | 85 | 86 | -------------------------------------------------------------------------------- /tests/nn/flow/test_sequential.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from bgflow.nn.flow.sequential import SequentialFlow 4 | from bgflow.nn.flow.orthogonal import PseudoOrthogonalFlow 5 | from bgflow.nn.flow.elementwise import BentIdentity 6 | from bgflow.nn.flow.triangular import TriuFlow 7 | 8 | 9 | def test_trigger_penalty(): 10 | flow = SequentialFlow([ 11 | PseudoOrthogonalFlow(3), 12 | PseudoOrthogonalFlow(3), 13 | PseudoOrthogonalFlow(3), 14 | ]) 15 | penalties = flow.trigger("penalty") 16 | assert len(penalties) == 3 17 | 18 | 19 | def test_getitem(): 20 | a = BentIdentity() 21 | b = TriuFlow(2) 22 | flow = SequentialFlow([a,b]) 23 | assert flow[0] == a 24 | assert flow[1] == b 25 | subflow = flow[[0]] 26 | assert len(subflow) == 1 27 | assert isinstance(subflow, SequentialFlow) 28 | assert subflow[0] == a -------------------------------------------------------------------------------- /tests/nn/flow/test_torchtransform.py: -------------------------------------------------------------------------------- 1 | 2 | 
import torch 3 | from torch.distributions import SigmoidTransform, AffineTransform, IndependentTransform 4 | from bgflow import TorchTransform, SequentialFlow, BentIdentity 5 | 6 | 7 | def test_torch_transform(ctx): 8 | """try using torch.Transform in combination with bgflow.Flow""" 9 | torch.manual_seed(10) 10 | x = torch.torch.randn(10, 3, **ctx) 11 | flow = SequentialFlow([ 12 | TorchTransform(IndependentTransform(SigmoidTransform(), 1)), 13 | TorchTransform( 14 | AffineTransform( 15 | loc=torch.randn(3, **ctx), 16 | scale=2.0+torch.rand(3, **ctx), event_dim=1 17 | ), 18 | ), 19 | BentIdentity(), 20 | # test the reinterpret_batch_ndims arguments 21 | TorchTransform(SigmoidTransform(), 1) 22 | ]) 23 | z, dlogp = flow.forward(x) 24 | y, neg_dlogp = flow.forward(z, inverse=True) 25 | tol = 1e-7 if ctx["dtype"] is torch.float64 else 1e-5 26 | assert torch.allclose(x, y, atol=tol) 27 | assert torch.allclose(dlogp, -neg_dlogp, atol=tol) 28 | -------------------------------------------------------------------------------- /tests/nn/flow/test_triangular.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import pytest 4 | from bgflow.nn.flow.triangular import TriuFlow 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "b", [ 9 | torch.tensor(0.), 10 | torch.nn.Parameter(torch.tensor([1.,4.,1.])), 11 | torch.nn.Parameter(torch.zeros(3)) 12 | ]) 13 | def test_invert(b): 14 | tf = TriuFlow(3, shift=(isinstance(b, torch.nn.Parameter))) 15 | tf._unique_elements = torch.nn.Parameter(torch.rand_like(tf._unique_elements)) 16 | tf._make_r() 17 | tf.b = b 18 | x = torch.randn(10,3) 19 | y, dlogp = tf._forward(x) 20 | x2, dlogpinv = tf._inverse(y) 21 | assert torch.norm(dlogp + dlogpinv).item() == pytest.approx(0.0) 22 | assert torch.norm(x-x2).item() == pytest.approx(0.0, abs=1e-6) 23 | -------------------------------------------------------------------------------- /tests/nn/flow/transformer/test_affine.py: 
-------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import torch 4 | from bgflow import AffineTransformer, DenseNet 5 | 6 | 7 | class ShiftModule(torch.nn.Module): 8 | def forward(self, x): 9 | return torch.ones_like(x) 10 | 11 | 12 | @pytest.mark.parametrize("is_circular", [True, False]) 13 | @pytest.mark.parametrize("use_scale_transform", [True, False]) 14 | def test_affine(is_circular, use_scale_transform): 15 | 16 | if use_scale_transform: 17 | scale = DenseNet([2, 2]) 18 | else: 19 | scale = None 20 | 21 | if use_scale_transform and is_circular: 22 | with pytest.raises(ValueError): 23 | trafo = AffineTransformer( 24 | shift_transformation=ShiftModule(), 25 | scale_transformation=scale, 26 | is_circular=is_circular 27 | ) 28 | 29 | else: 30 | trafo = AffineTransformer( 31 | shift_transformation=ShiftModule(), 32 | scale_transformation=scale, 33 | is_circular=is_circular 34 | ) 35 | x = torch.rand(100, 2) 36 | y = torch.rand(100, 2) 37 | y2, dlogp = trafo.forward(x, y) 38 | assert y2.shape == y.shape 39 | if is_circular: 40 | assert (y2 < 1).all() 41 | else: 42 | assert (y2 > 1).any() 43 | 44 | -------------------------------------------------------------------------------- /tests/nn/flow/transformer/test_gaussian.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from bgflow.nn.flow.transformer.gaussian import TruncatedGaussianTransformer 4 | from bgflow.distribution.normal import TruncatedNormalDistribution 5 | 6 | 7 | def test_constrained_affine_transformer(ctx): 8 | tol = 5e-4 if ctx["dtype"] == torch.float32 else 5e-7 9 | mu_net = torch.nn.Linear(3, 2, bias=False) 10 | sigma_net = torch.nn.Linear(3, 2, bias=False) 11 | mu_net.weight.data = torch.ones_like(mu_net.weight.data) 12 | sigma_net.weight.data = torch.ones_like(sigma_net.weight.data) 13 | constrained = TruncatedGaussianTransformer(mu_net, sigma_net, -5.0, 5.0, -8.0, 0.0).to(**ctx) 14 | 
15 | # test if forward and inverse are compatible 16 | x = torch.ones(1, 3, **ctx) 17 | y = torch.tensor([[-2.5, 2.5]], **ctx) 18 | y.requires_grad = True 19 | out, dlogp = constrained.forward(x, y) 20 | assert not torch.allclose(dlogp, torch.zeros_like(dlogp)) 21 | assert (out >= torch.tensor(-8.0, **ctx)).all() 22 | assert (out <= torch.tensor(0.20, **ctx)).all() 23 | y2, neg_dlogp = constrained.forward(x, out, inverse=True) 24 | assert torch.allclose(y, y2, atol=tol) 25 | assert torch.allclose(dlogp + neg_dlogp, torch.zeros_like(dlogp), atol=tol) 26 | 27 | # test if the log det agrees with the log prob of a truncated normal distribution 28 | mu = torch.einsum("ij,...j->...i", mu_net.weight.data, x) 29 | _, logsigma = constrained._get_mu_and_log_sigma(x, y) # not reiterating the tanh stuff 30 | sigma = torch.exp(logsigma) 31 | trunc_gaussian = TruncatedNormalDistribution(mu, sigma, torch.tensor(-5, **ctx), torch.tensor(5, **ctx)) 32 | log_prob = trunc_gaussian.log_prob(y) 33 | log_scale = torch.log(torch.tensor(8., **ctx)) 34 | assert torch.allclose(dlogp, (log_prob + log_scale).sum(dim=-1, keepdim=True), atol=tol) 35 | 36 | # try backward pass and assert reasonable gradients 37 | y2.sum().backward(create_graph=True) 38 | neg_dlogp.backward() 39 | for tensor in [y, mu_net.weight]: 40 | assert (tensor.grad > -1e6).all() 41 | assert (tensor.grad < 1e6).all() 42 | -------------------------------------------------------------------------------- /tests/nn/flow/transformer/test_spline.py: -------------------------------------------------------------------------------- 1 | """Test spline transformer""" 2 | 3 | import pytest 4 | import torch 5 | from bgflow import ConditionalSplineTransformer, CouplingFlow, SplitFlow, NormalDistribution, DenseNet 6 | 7 | 8 | @pytest.mark.parametrize("is_circular", [True, False]) 9 | def test_conditional_spline_transformer_api(is_circular, ctx): 10 | pytest.importorskip("nflows") 11 | 12 | n_bins = 4 13 | dim_trans = 10 14 | n_samples = 
10 15 | dim_cond = 9 16 | x_cond = torch.rand((n_samples, dim_cond), **ctx) 17 | x_trans = torch.rand((n_samples, dim_trans)).to(x_cond) 18 | 19 | if is_circular: 20 | dim_net_out = 3 * n_bins * dim_trans 21 | else: 22 | dim_net_out = (3 * n_bins + 1) * dim_trans 23 | conditioner = DenseNet([dim_cond, dim_net_out]) 24 | 25 | transformer = ConditionalSplineTransformer( 26 | params_net=conditioner, 27 | is_circular=is_circular, 28 | ).to(x_cond) 29 | 30 | y, dlogp = transformer.forward(x_cond, x_trans) 31 | 32 | assert (y > 0.0).all() 33 | assert (y < 1.0).all() 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "is_circular", [torch.tensor(True), torch.tensor(False), torch.tensor([True, False], dtype=torch.bool)] 38 | ) 39 | def test_conditional_spline_continuity(is_circular, ctx): 40 | pytest.importorskip("nflows") 41 | torch.manual_seed(2150) 42 | 43 | n_bins = 3 44 | dim_trans = 2 45 | n_samples = 1 46 | dim_cond = 1 47 | x_cond = torch.rand((n_samples, dim_cond), **ctx) 48 | 49 | if is_circular.all(): 50 | dim_net_out = 3 * n_bins * dim_trans 51 | elif not is_circular.any(): 52 | dim_net_out = (3 * n_bins + 1) * dim_trans 53 | else: 54 | dim_net_out = 3 * n_bins * dim_trans + int(is_circular.sum()) 55 | conditioner = DenseNet([dim_cond, dim_net_out], bias_scale=2.) 
56 | 57 | transformer = ConditionalSplineTransformer( 58 | params_net=conditioner, 59 | is_circular=is_circular, 60 | ).to(x_cond) 61 | 62 | slopes = transformer._compute_params(x_cond, dim_trans)[2] 63 | continuous = torch.isclose(slopes[0,:,0], slopes[0,:,-1]).tolist() 64 | if is_circular.all(): 65 | assert continuous == [True, True] 66 | elif not is_circular.any(): 67 | assert continuous == [False, False] 68 | else: 69 | assert continuous == [True, False] -------------------------------------------------------------------------------- /tests/nn/test_wrap_distances.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from bgflow.nn.periodic import WrapDistances 4 | 5 | 6 | def test_wrap_distances(): 7 | positions = torch.tensor([[[0.,0.,0.],[0.,0.,1.],[0.,2.,0.]]]) 8 | positions_flat = positions.view(positions.shape[0],-1) 9 | module = torch.nn.ReLU() 10 | wrapper = WrapDistances(module) 11 | result = wrapper.forward(positions_flat) 12 | expected = torch.tensor([[1.,2.,np.sqrt(5)]]).to(positions) 13 | assert torch.allclose(result, expected) 14 | -------------------------------------------------------------------------------- /tests/test_readme.py: -------------------------------------------------------------------------------- 1 | 2 | # If this fails the example in the readme wont work! 
3 | 4 | def test_readme(): 5 | import torch 6 | import matplotlib.pyplot as plt 7 | import bgflow as bg 8 | 9 | # define prior and target 10 | dim = 2 11 | prior = bg.NormalDistribution(dim) 12 | target = bg.DoubleWellEnergy(dim) 13 | 14 | # here we aggregate all layers of the flow 15 | layers = [] 16 | layers.append(bg.SplitFlow(dim // 2)) 17 | layers.append(bg.CouplingFlow( 18 | # we use a affine transformation to transform the RHS conditioned on the LHS 19 | bg.AffineTransformer( 20 | # use simple dense nets for the affine shift/scale 21 | shift_transformation=bg.DenseNet( 22 | [dim // 2, 4, dim // 2], 23 | activation=torch.nn.ReLU() 24 | ), 25 | scale_transformation=bg.DenseNet( 26 | [dim // 2, 4, dim // 2], 27 | activation=torch.nn.Tanh() 28 | ) 29 | ) 30 | )) 31 | layers.append(bg.InverseFlow(bg.SplitFlow(dim // 2))) 32 | 33 | # now define the flow as a sequence of all operations stored in layers 34 | flow = bg.SequentialFlow(layers) 35 | 36 | # The BG is defined by a prior, target and a flow 37 | generator = bg.BoltzmannGenerator(prior, flow, target) 38 | 39 | # sample from the BG 40 | samples = generator.sample(1000) 41 | plt.hist2d( 42 | samples[:, 0].detach().numpy(), 43 | samples[:, 1].detach().numpy(), bins=100 44 | ) -------------------------------------------------------------------------------- /tests/utils/test_autograd.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from bgflow.utils import ( 5 | brute_force_jacobian, brute_force_jacobian_trace, get_jacobian, 6 | batch_jacobian, requires_grad 7 | ) 8 | 9 | 10 | x = torch.tensor([[1., 2, 3]], requires_grad=True) 11 | y = x.pow(2) + x[:, 1] 12 | true_jacobian = np.array([[[2., 1, 0], [0, 5, 0], [0, 1, 6]]]) 13 | 14 | 15 | def test_brute_force_jacobian_trace(): 16 | jacobian_trace = brute_force_jacobian_trace(y, x) 17 | assert jacobian_trace.detach().numpy() == pytest.approx(np.array([13.]), abs=1e-6) 18 
| 19 | 20 | def test_brute_force_jacobian(): 21 | jacobian = brute_force_jacobian(y, x) 22 | assert jacobian.detach().numpy() == pytest.approx(true_jacobian, abs=1e-6) 23 | 24 | 25 | def test_batch_jacobian(ctx): 26 | t = x.repeat(10, 1).to(**ctx) 27 | out = t.pow(2) + t[:, [1]] 28 | expected = torch.tensor(true_jacobian, **ctx).repeat(10, 1, 1) 29 | assert torch.allclose(batch_jacobian(out, t), expected) 30 | 31 | 32 | def test_get_jacobian(ctx): 33 | t = torch.ones((2,), **ctx) 34 | func = lambda s: s**2 35 | expected = 2*t*torch.eye(2, **ctx) 36 | assert torch.allclose(get_jacobian(func, t).jac, expected) 37 | 38 | 39 | def test_requires_grad(ctx): 40 | t = torch.zeros((2,), **ctx) 41 | assert not t.requires_grad 42 | with requires_grad(t): 43 | assert t.requires_grad 44 | -------------------------------------------------------------------------------- /tests/utils/test_free_energy.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from bgflow.distribution import NormalDistribution 7 | 8 | from bgflow.utils.free_energy import bennett_acceptance_ratio 9 | 10 | 11 | @pytest.mark.parametrize("method", ["torch", "pymbar"]) 12 | @pytest.mark.parametrize("compute_uncertainty", [True, False]) 13 | def test_bar(ctx, method, compute_uncertainty): 14 | pytest.importorskip("pymbar") 15 | dim = 1 16 | energy1 = NormalDistribution(dim, mean=torch.zeros(dim, **ctx)) 17 | energy2 = NormalDistribution(dim, mean=0.2*torch.ones(dim, **ctx)) 18 | samples1 = energy1.sample(10000) 19 | samples2 = energy2.sample(20000) 20 | 21 | free_energy, uncertainty = bennett_acceptance_ratio( 22 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1), 23 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)), 24 | implementation=method, 25 | compute_uncertainty=compute_uncertainty 26 | ) 27 | assert free_energy.item() == pytest.approx(1., 
abs=1e-2) 28 | if compute_uncertainty: 29 | assert uncertainty.item() < 1e-2 30 | else: 31 | assert uncertainty is None 32 | 33 | 34 | @pytest.mark.parametrize("method", ["torch", "pymbar"]) 35 | @pytest.mark.parametrize("warn", [True, False]) 36 | def test_bar_no_convergence(ctx, method, warn): 37 | with warnings.catch_warnings(): 38 | warnings.simplefilter("ignore", RuntimeWarning) 39 | pytest.importorskip("pymbar") 40 | dim = 1 41 | energy1 = NormalDistribution(dim, mean=-1e20*torch.ones(dim, **ctx)) 42 | energy2 = NormalDistribution(dim, mean=1e20*torch.ones(dim, **ctx)) 43 | samples1 = energy1.sample(5) 44 | samples2 = energy2.sample(5) 45 | 46 | if warn: 47 | # test if warning is raised 48 | with pytest.warns(UserWarning, match="BAR could not"): 49 | free_energy, uncertainty = bennett_acceptance_ratio( 50 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1), 51 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)), 52 | implementation=method, 53 | warn=warn 54 | ) 55 | else: 56 | free_energy, uncertainty = bennett_acceptance_ratio( 57 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1), 58 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)), 59 | implementation=method, 60 | warn=warn 61 | ) 62 | assert np.isnan(free_energy.item()) 63 | assert np.isnan(uncertainty.item()) 64 | 65 | 66 | def test_bar_uncertainty(ctx): 67 | """test consistency with the reference implementation""" 68 | pytest.importorskip("pymbar") 69 | dim = 1 70 | energy1 = NormalDistribution(dim, mean=torch.zeros(dim, **ctx)) 71 | energy2 = NormalDistribution(dim, mean=0.2*torch.ones(dim, **ctx)) # will be multiplied by e 72 | samples1 = energy1.sample(1000) 73 | samples2 = energy2.sample(2000) 74 | 75 | free_energy1, uncertainty1 = bennett_acceptance_ratio( 76 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1), 77 | reverse_work=energy1.energy(samples2) - (1.0 + 
energy2.energy(samples2)), 78 | implementation="torch" 79 | ) 80 | free_energy2, uncertainty2 = bennett_acceptance_ratio( 81 | forward_work=(1.0 + energy2.energy(samples1)) - energy1.energy(samples1), 82 | reverse_work=energy1.energy(samples2) - (1.0 + energy2.energy(samples2)), 83 | implementation="pymbar" 84 | ) 85 | assert free_energy1.item() == pytest.approx(free_energy2.item(), rel=1e-3) 86 | assert uncertainty1.item() == pytest.approx(uncertainty2.item(), rel=1e-3) 87 | -------------------------------------------------------------------------------- /tests/utils/test_geometry.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from bgflow.utils import distance_vectors, distances_from_vectors, remove_mean, compute_distances 4 | 5 | 6 | @pytest.mark.parametrize("remove_diagonal", [True, False]) 7 | def test_distances_from_vectors(remove_diagonal): 8 | """Test if distances are calculated correctly from a particle configuration.""" 9 | particles = torch.tensor([[[3., 0], [3, 0], [0, 4]]]) 10 | distances = distances_from_vectors( 11 | distance_vectors(particles, remove_diagonal=remove_diagonal)) 12 | if remove_diagonal == True: 13 | assert torch.allclose(distances, torch.tensor([[[0., 5], [0, 5], [5, 5]]]), atol=1e-2) 14 | else: 15 | assert torch.allclose(distances, torch.tensor([[[0., 0, 5], [0, 0, 5], [5, 5, 0]]]), atol=1e-2) 16 | 17 | 18 | def test_mean_free(ctx): 19 | """Test if the mean of random configurations is removed correctly""" 20 | samples = torch.rand(100, 100, 3, **ctx) - 0.3 21 | samples = remove_mean(samples, 100, 3) 22 | mean_deviation = samples.mean(dim=(1, 2)) 23 | threshold = 1e-5 24 | assert torch.all(mean_deviation.abs() < threshold) 25 | 26 | @pytest.mark.parametrize("remove_duplicates", [True, False]) 27 | def test_compute_distances(remove_duplicates, ctx): 28 | """Test if distances are calculated correctly from a particle configuration.""" 29 | particles = 
torch.tensor([[[3., 0], [3, 0], [0, 4]]], **ctx) 30 | distances = compute_distances(particles, n_particles=3, n_dimensions=2, remove_duplicates=remove_duplicates) 31 | if remove_duplicates == True: 32 | assert torch.allclose(distances, torch.tensor([[0., 5, 5]], **ctx), atol=1e-5) 33 | else: 34 | assert torch.allclose(distances, torch.tensor([[[0., 0, 5], [0, 0, 5], [5, 5, 0]]], **ctx), atol=1e-5) 35 | -------------------------------------------------------------------------------- /tests/utils/test_rbf_kernels.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from bgflow.utils import kernelize_with_rbf, rbf_kernels 4 | 5 | 6 | def test_kernelize_with_rbf(): 7 | # distance vector of shape `[n_batch, n_particles, n_particles - 1, 1]` 8 | distances = torch.tensor([[[[0.], [5]], [[0], [5]], [[5], [5]]]]) 9 | mus = torch.tensor([[[[0., 5]]]]) 10 | gammas = torch.tensor([[[[1., 0.5]]]]) 11 | 12 | distances2 = torch.tensor([[[[1.]]]], requires_grad=True) 13 | mus2 = torch.tensor([[[[1., 3]]]]) 14 | gammas2 = torch.tensor([[[[1., 0.5]]]]) 15 | 16 | rbf1 = torch.exp(- (distances2 - mus2[0, 0, 0, 0]) ** 2 / gammas2[0, 0, 0, 0] ** 2) 17 | rbf2 = torch.exp(- (distances2 - mus2[0, 0, 0, 1]) ** 2 / gammas2[0, 0, 0, 1] ** 2) 18 | r1 = rbf1 / (rbf1 + rbf2) 19 | r2 = rbf2 / (rbf1 + rbf2) 20 | 21 | # Test shapes and math for simple configurations 22 | rbfs = kernelize_with_rbf(distances, mus, gammas) 23 | assert torch.allclose(rbfs, torch.tensor([[[[1., 0], [0, 1]], [[1., 0], [0, 1]], [[0, 1], [0, 1]]]]), atol=1e-5) 24 | 25 | rbfs2 = kernelize_with_rbf(distances2, mus2, gammas2) 26 | assert torch.allclose(rbfs2, torch.cat([r1, r2], dim=-1), atol=1e-5) 27 | 28 | 29 | def test_rbf_kernels(): 30 | # distance vector of shape `[n_batch, n_particles, n_particles - 1, 1]` 31 | distances = torch.tensor([[[[0.], [5]], [[0], [5]], [[5], [5]]]]) 32 | mus = torch.tensor([[[[0., 5]]]]) 33 | gammas = torch.tensor([[[[1., 
0.5]]]]) 34 | 35 | distances2 = torch.tensor([[[[1.]]]], requires_grad=True) 36 | mus2 = torch.tensor([[[[1., 3]]]]) 37 | gammas2 = torch.tensor([[[[1., 0.5]]]]) 38 | 39 | rbf1 = torch.exp(- (distances2 - mus2[0, 0, 0, 0]) ** 2 / gammas2[0, 0, 0, 0] ** 2) 40 | rbf2 = torch.exp(- (distances2 - mus2[0, 0, 0, 1]) ** 2 / gammas2[0, 0, 0, 1] ** 2) 41 | r1 = rbf1 / (rbf1 + rbf2) 42 | r2 = rbf2 / (rbf1 + rbf2) 43 | 44 | # Test shapes, math and the derivative for simple configurations 45 | neg_log_gammas = - torch.log(gammas) 46 | rbfs, derivatives_rbfs = rbf_kernels(distances, mus, neg_log_gammas, derivative=True) 47 | assert torch.allclose(rbfs, torch.tensor([[[[1., 0], [0, 1]], [[1., 0], [0, 1]], [[0, 1], [0, 1]]]]), atol=1e-5) 48 | assert torch.allclose(derivatives_rbfs, torch.zeros((1, 3, 2, 2)), atol=1e-5) 49 | 50 | neg_log_gammas2 = - torch.log(gammas2) 51 | rbfs2, derivatives_rbfs2 = rbf_kernels(distances2, mus2, neg_log_gammas2, derivative=True) 52 | assert torch.allclose(rbfs2, torch.cat([r1, r2], dim=-1), atol=1e-5) 53 | 54 | # Check derivative of rbf 55 | dr1 = torch.autograd.grad(r1, distances2, retain_graph=True) 56 | dr2 = torch.autograd.grad(r2, distances2) 57 | assert torch.allclose(derivatives_rbfs2, torch.cat([dr1[0], dr2[0]], dim=-1), atol=1e-5) 58 | -------------------------------------------------------------------------------- /tests/utils/test_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bgflow.utils import ClipGradient 3 | 4 | 5 | def torch_example(grad_clipping, ctx): 6 | positions = torch.arange(6).reshape(2, 3).to(**ctx) 7 | positions.requires_grad = True 8 | positions = grad_clipping.to(**ctx)(positions) 9 | (0.5 * positions ** 2).sum().backward() 10 | return positions.grad 11 | 12 | 13 | def test_clip_by_val(ctx): 14 | grad_clipping = ClipGradient(clip=3., norm_dim=1) 15 | assert torch.allclose( 16 | torch_example(grad_clipping, ctx), 17 | torch.tensor([[0., 1., 2.], [3., 
3., 3.]], **ctx) 18 | ) 19 | 20 | 21 | def test_clip_by_atom(ctx): 22 | grad_clipping = ClipGradient(clip=3., norm_dim=3) 23 | norm2 = torch.linalg.norm(torch.arange(3, 6, **ctx)).item() 24 | assert torch.allclose( 25 | torch_example(grad_clipping, ctx), 26 | torch.tensor([[0., 1., 2.], [3/norm2*3, 4/norm2*3, 5/norm2*3]], **ctx) 27 | ) 28 | 29 | 30 | def test_clip_by_batch(ctx): 31 | grad_clipping = ClipGradient(clip=3., norm_dim=-1) 32 | norm2 = torch.linalg.norm(torch.arange(6, **ctx)).item() 33 | assert torch.allclose( 34 | torch_example(grad_clipping, ctx), 35 | (torch.arange(6, **ctx) / norm2 * 3.).reshape(2, 3) 36 | ) 37 | 38 | -------------------------------------------------------------------------------- /tests/utils/test_types.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from bgflow.utils import as_numpy 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def test_as_numpy(): 9 | assert np.allclose(as_numpy(2.0), np.array(2.0)) 10 | assert np.allclose(as_numpy(np.ones(2)), np.ones(2)) 11 | assert as_numpy(1) == np.array(1) 12 | 13 | 14 | def test_tensor_as_numpy(ctx): 15 | out = as_numpy(torch.zeros(2, **ctx)) 16 | assert isinstance(out, np.ndarray) 17 | assert np.allclose(out, np.zeros(2)) 18 | --------------------------------------------------------------------------------