├── .coveragerc ├── .editorconfig ├── .flake8 ├── .github └── workflows │ ├── CI.yml │ └── publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE.rst ├── README.md ├── configs ├── __init__.py ├── default_config.py └── example.py ├── diffusionjax ├── __init__.py ├── inverse_problems.py ├── models.py ├── models │ ├── __init__.py │ └── networks_edm2.py ├── plot.py ├── run_lib.py ├── sde.py ├── solvers.py └── utils.py ├── examples ├── example.py ├── example1.py └── example2.py ├── mypy.ini ├── pyproject.toml ├── readme_empirical_score.png ├── readme_heatmap_bounded_perturbation.png ├── readme_heatmap_empirical_score.png ├── readme_heatmap_inpainted.png ├── readme_heatmap_trained_score.png ├── readme_nplan.png ├── readme_samples.png ├── readme_trained_score.png ├── ruff.toml ├── setup.py └── test ├── __init__.py ├── external └── test_benchmark.py ├── test_solvers.py └── test_utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = diffusionjax/test/* 3 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # 2 space indentation 2 | [*.py] 3 | indent_style = space 4 | indent_size = 2 5 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] # https://flake8.pycqa.org/en/6.0.0/user/options.html#cmdoption-flake8-select 2 | filename = 3 | */diffusionjax/*.py, 4 | */test/*.py 5 | select=F,W6,E71,E72,E112,E113,E124,E203,E272,E303,E304,E502,E702,E703,E731,W191 6 | indent-size=2 7 | # ignore = F722, F821 8 | per-file-ignores = 9 | test/*: F401, F403, F405, F541, E722, E731, F811, F821, F841 10 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.8", "3.9", "3.10", "3.11"] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | pip install --upgrade pip setuptools numpy 22 | pip install -e '.[linting,testing]' 23 | pip install --upgrade numpy 24 | - name: Test 25 | env: 26 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 27 | COVERALLS_FLAG_NAME: ${{ matrix.python-version }} 28 | COVERALLS_PARALLEL: true 29 | run: | 30 | JAX_ENABLE_X64=1 pytest -v --cov=diffusionjax --cov-report term-missing --ignore test/external 31 | coveralls --service=github 32 | coveralls: 33 | name: Finish coverage 34 | needs: test 35 | runs-on: ubuntu-latest 36 | container: python:3-slim 37 | steps: 38 | - name: Finished 39 | run: | 40 | pip3 install --upgrade coveralls 41 | coveralls --finish 42 | env: 43 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 44 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python package using Twine when a release is 2 | # created. 
For more information see the following link: 3 | # https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 4 | 5 | name: Publish to PyPI 6 | 7 | on: 8 | release: 9 | types: [published] 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | # Make sure tags are fetched so we can get a version. 19 | - run: | 20 | git fetch --prune --unshallow --tags 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.x' 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -U setuptools 'setuptools_scm[toml]' setuptools_scm_git_archive wheel twine 30 | pip install -U numpy 31 | - name: Build and publish 32 | env: 33 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | 36 | run: | 37 | python setup.py sdist 38 | twine upload dist/* 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Autogenerated files 2 | diffusionjax/_version.py 3 | examples/*.png 4 | */checkpoints* 5 | */samples* 6 | */wandb* 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Idea files 121 | .idea/ 122 | 123 | # Scratch files 124 | scratch.py 125 | refactoring.txt 126 | refactoring 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | # Mac OS directory information 147 | .DS_Store 148 | 149 | # cython 150 | *.c 151 | 152 | # Videos 153 | *.mp4 154 | 155 | .vscode/ 156 | *terminal* 157 | wandb/ 158 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: ruff 5 | name: ruff 6 | entry: ruff . 7 | language: system 8 | always_run: true 9 | pass_filenames: false 10 | - id: flake8 11 | name: flake8 12 | entry: flake8 --statistics -j4 13 | language: system 14 | always_run: true 15 | pass_filenames: false 16 | - id: mypy 17 | name: mypy 18 | entry: mypy diffusionjax/ 19 | language: system 20 | always_run: true 21 | pass_filenames: false 22 | - id: tests 23 | name: subset of tests 24 | entry: pytest test/test_utils.py 25 | language: system 26 | always_run: true 27 | pass_filenames: false 28 | -------------------------------------------------------------------------------- /LICENSE.rst: -------------------------------------------------------------------------------- 1 | *********** 2 | MIT License 3 | *********** 4 | 5 | Copyright (c) 2020 Benjamin Boys 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 
24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | diffusionjax 2 | ============ 3 | [![CI](https://github.com/bb515/diffusionjax/actions/workflows/CI.yml/badge.svg)](https://github.com/bb515/diffusionjax/actions/workflows/CI.yml) 4 | [![Coverage Status](https://coveralls.io/repos/github/bb515/diffusionjax/badge.svg?branch=master)](https://coveralls.io/github/bb515/diffusionjax?branch=master) 5 | 6 | diffusionjax is a simple, accessible introduction to diffusion models, also known as score-based generative models (SGMs). It is implemented in Python via the autodiff framework, [JAX](https://github.com/google/jax). In particular, diffusionjax uses the [Flax](https://github.com/google/flax) library for the neural network approximator of the score. diffusionjax focusses on the continuous time formulation during training. 7 | 8 | The development of diffusionjax has been supported by The Alan Turing Institute through the Theory and Methods Challenge Fortnights event "Accelerating generative models and nonconvex optimisation", which took place on 6-10 June 2022 and 5-9 Sep 2022 at The Alan Turing Institute headquarters. 9 | 10 | ![nPlan](readme_nplan.png) 11 | 12 | Thank you to [nPlan](https://www.nplan.io/), who are supporting this project. 13 | 14 | Contents: 15 | - [Installation](#installation) 16 | - [Examples](#examples) 17 | - [Introduction to diffusion models](#introduction-to-diffusion-models) 18 | - [Does haves](#does-haves) 19 | - [Doesn't haves](#doesn't-haves) 20 | - [References](#references) 21 | 22 | ## Installation 23 | The package requires Python 3.8+. First, it is recommended to [create a new python virtual environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands). 24 | diffusionjax depends on JAX. Because the JAX installation is different depending on your CUDA version, diffusionjax does not list JAX as a dependency in `setup.py`. 25 | First, [follow these instructions](https://github.com/google/jax#installation) to install JAX with the relevant accelerator support. 26 | To run the examples, you may optionally need to install [optax](https://optax.readthedocs.io/en/latest/), [orbax-checkpoint](https://orbax.readthedocs.io/en/latest/), [torch[cpu]](https://pytorch.org/get-started/locally/) and [mlkernels](https://github.com/wesselb/mlkernels#installation), which the package depends on only through the examples given. 27 | Then, `pip install diffusionjax` or for developers, 28 | - Clone the repository `git clone git@github.com:bb515/diffusionjax.git` 29 | - Install using pip `pip install -e .` from the root directory of the repository (see the `setup.py` for the requirements that this command installs). 30 | 31 | ## Examples 32 | 33 | ### Introduction to diffusion models 34 | Run the example by typing 35 | ```sh 36 | python examples/example.py: 37 | --config: Training configuration. 38 | (default: './configs/example.py') 39 | --workdir: Working directory 40 | (default: './examples/') 41 | ``` 42 | on the command line from the root directory of the repository. 43 | * `config` is the path to the config file. The default config files are provided in `configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections). 
44 | * `workdir` is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results via wandb. 45 | 46 | The example is based on the [Jupyter notebook](https://jakiw.com/sgm_intro) by Jakiw Pidstrigach, a tutorial on the theoretical and implementation aspects of diffusion models. 47 | ```python 48 | >>> num_epochs = 4000 49 | >>> num_samples = 8 50 | >>> samples = sample_circle(num_samples) 51 | >>> N = samples.shape[1] 52 | >>> plot_scatter(samples=samples, index=(0, 1), fname="samples", lims=((-3, 3), (-3, 3))) 53 | >>> rng = random.PRNGKey(2023) 54 | ``` 55 | ![Prediction](readme_samples.png) 56 | ```python 57 | >>> # Get variance preserving (VP) a.k.a. time-changed Ornstein-Uhlenbeck (OU) SDE model 58 | >>> sde = VP() 59 | >>> 60 | >>> def log_hat_pt(x, t): 61 | >>> """ 62 | >>> Empirical log density. 63 | >>> 64 | >>> Args: 65 | >>> x: One location in $\mathbb{R}^2$ 66 | >>> t: time 67 | >>> Returns: 68 | >>> The empirical log density, as described in the Jupyter notebook 69 | >>> .. math:: 70 | >>> \hat{p}_{t}(x) 71 | >>> """ 72 | >>> mean, std = sde.marginal_prob(samples, t) 73 | >>> potentials = jnp.sum(-(x - mean)**2 / (2 * std**2), axis=1) 74 | >>> return logsumexp(potentials, axis=0, b=1/num_samples) 75 | >>> 76 | >>> # Get a jax grad function, which can be batched with vmap 77 | >>> nabla_log_hat_pt = jit(vmap(grad(log_hat_pt), in_axes=(0, 0), out_axes=(0))) 78 | >>> 79 | >>> # Plot the empirical score; the reverse SDE is then run with this empirical score 80 | >>> plot_score(score=nabla_log_hat_pt, t=0.01, area_min=-3, area_max=3, fname="empirical score") 81 | ``` 82 | ![Prediction](readme_empirical_score.png) 83 | ```python 84 | >>> sampler = get_sampler((5760, N), EulerMaruyama(sde.reverse(nabla_log_hat_pt))) 85 | >>> rng, *sample_rng = random.split(rng, 2) 86 | >>> q_samples = sampler(jnp.array(sample_rng)) 87 | >>> q_samples = q_samples.reshape(5760, N) 88 | >>> plot_heatmap(samples=q_samples, area_min=-3, area_max=3, fname="heatmap empirical score") 89 | ``` 90 | ![Prediction](readme_heatmap_empirical_score.png) 91 | ```python 92 | >>> # What happens when I perturb the score with a constant?
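>>> # NOTE: adding a constant is a bounded perturbation of the empirical score; it only
>>> # biases the reverse-SDE drift by a bounded amount, so the sampled distribution is
>>> # shifted and smeared rather than destroyed (compare the heatmap below with the one above).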
93 | >>> perturbed_score = lambda x, t: nabla_log_hat_pt(x, t) + 1 94 | >>> sampler = get_sampler((5760, N), EulerMaruyama(sde.reverse(perturbed_score))) 95 | >>> rng, *sample_rng = random.split(rng, 2) 96 | >>> q_samples = sampler(jnp.array(sample_rng)) 97 | >>> q_samples = q_samples.reshape(5760, N) 98 | >>> plot_heatmap(samples=q_samples, area_min=-3, area_max=3, fname="heatmap bounded perturbation") 99 | ``` 100 | ![Prediction](readme_heatmap_bounded_perturbation.png) 101 | ```python 102 | >>> # Neural network training via score matching 103 | >>> batch_size=16 104 | >>> score_model = MLP() 105 | >>> # Initialize parameters 106 | >>> params = score_model.init(step_rng, jnp.zeros((batch_size, N)), jnp.ones((batch_size,))) 107 | >>> # Initialize optimizer 108 | >>> opt_state = optimizer.init(params) 109 | >>> # Get loss function 110 | >>> solver = EulerMaruyama(sde) 111 | >>> loss = get_loss( 112 | >>> sde, solver, score_model, score_scaling=True, likelihood_weighting=False) 113 | >>> # Train with score matching 114 | >>> score_model, params, opt_state, mean_losses = retrain_nn( 115 | >>> update_step=update_step, 116 | >>> num_epochs=num_epochs, 117 | >>> step_rng=step_rng, 118 | >>> samples=samples, 119 | >>> score_model=score_model, 120 | >>> params=params, 121 | >>> opt_state=opt_state, 122 | >>> loss=loss, 123 | >>> batch_size=batch_size) 124 | >>> # Get trained score 125 | >>> trained_score = get_score(sde, score_model, params, score_scaling=True) 126 | >>> plot_score(score=trained_score, t=0.01, area_min=-3, area_max=3, fname="trained score") 127 | ``` 128 | ![Prediction](readme_trained_score.png) 129 | ```python 130 | >>> solver = EulerMaruyama(sde.reverse(trained_score)) 131 | >>> sampler = get_sampler((720, N), solver, stack_samples=False) 132 | >>> rng, *sample_rng = random.split(rng, 2) 133 | >>> q_samples = sampler(jnp.array(sample_rng)) 134 | >>> q_samples = q_samples.reshape(720, N) 135 | >>> plot_heatmap(samples=q_samples, area_min=-3, area_max=3, fname="heatmap trained score") 136 | ``` 137 | ![Prediction](readme_heatmap_trained_score.png) 138 | ```python 139 | >>> # Condition on one of the coordinates 140 | >>> y = jnp.array([-0.5, 0.0]) 141 | >>> mask = jnp.array([1., 0.]) 142 | >>> # Get inpainter 143 | >>> sampler = get_sampler(sampling_shape, 144 | solver, 145 | Inpainted(sde.reverse(trained_score), mask, y), 146 | inverse_scaler=inverse_scaler, 147 | stack_samples=False, 148 | denoise=True) 149 | >>> q_samples, _ = sampler(sample_rng) 150 | >>> q_samples = q_samples.reshape(sampling_shape) 151 | >>> plot_heatmap(samples=q_samples, area_bounds=[-3., 3.], fname="heatmap inpainted") 152 | ``` 153 | ![Prediction](readme_heatmap_inpainted.png) 154 | 155 | ## Does haves 156 | - Training scores on (possibly, image) data and sampling from the generative model. Also inverse problems, such as inpainting. 157 | - jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via `config.training.n_jitted_steps`. 158 | - Not many lines of code. 159 | - Bayesian inversion (inverse problems) with linear observation maps. 160 | - Easy to use, extendable. Get started with the example, provided. 161 | - Implements a JAX port of the model and loss from [Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696) 162 | 163 | ## Doesn't haves 164 | - Geometry other than Euclidean space, such as Riemannian manifolds. 165 | - Diffusion in a latent space. 
166 | - Augmented with critically-damped Langevin diffusion. 167 | 168 | ## References 169 | This is the implementation for the paper [Tweedie Moment Projected Diffusions for Inverse Problems](https://arxiv.org/pdf/2310.06721.pdf) by Benjamin Boys, Mark Girolami, Jakiw Pidstrigach, Sebastian Reich, Alan Mosca and O. Deniz Akyildiz. 170 | 171 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/configs/__init__.py -------------------------------------------------------------------------------- /configs/default_config.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_default_configs(): 5 | config = ml_collections.ConfigDict() 6 | 7 | # training 8 | config.training = training = ml_collections.ConfigDict() 9 | config.training.batch_size = 64 10 | training.n_iters = 2400001 11 | training.snapshot_freq = 50000 12 | training.log_epochs_freq = 10 13 | training.log_step_freq = 8 14 | training.eval_freq = 100 15 | ## store additional checkpoints for preemption in cloud computing environments 16 | training.snapshot_freq_for_preemption = 5000 17 | ## produce samples at each snapshot. 18 | training.snapshot_sampling = True 19 | training.likelihood_weighting = False 20 | training.score_scaling = True 21 | training.n_jitted_steps = 1 22 | training.pmap = False 23 | training.reduce_mean = True 24 | training.pointwise_t = False 25 | 26 | # sampling 27 | config.sampling = sampling = ml_collections.ConfigDict() 28 | sampling.stack_samples = False 29 | sampling.denoise = True 30 | 31 | # evaluation 32 | config.eval = evaluate = ml_collections.ConfigDict() 33 | evaluate.batch_size = 128 34 | 35 | # data 36 | config.data = data = ml_collections.ConfigDict() 37 | data.num_channels = None 38 | data.image_size = 2 39 | 40 | # model 41 | config.model = model = ml_collections.ConfigDict() 42 | 43 | # for vp 44 | model.beta_min = 0.1 45 | model.beta_max = 20.0 46 | 47 | # for ve 48 | model.sigma_max = 378.0 49 | model.sigma_min = 0.01 50 | 51 | # solver 52 | config.solver = solver = ml_collections.ConfigDict() 53 | solver.num_outer_steps = 1000 54 | solver.num_inner_steps = 1 55 | solver.outer_solver = "EulerMaruyama" 56 | solver.eta = None # for DDIM 57 | solver.inner_solver = None 58 | solver.dt = None 59 | solver.epsilon = None 60 | solver.snr = None 61 | 62 | # optimization 63 | config.seed = 2023 64 | config.optim = optim = ml_collections.ConfigDict() 65 | optim.optimizer = "Adam" 66 | optim.lr = 2e-4 67 | optim.warmup = 5000 68 | optim.weight_decay = False 69 | optim.grad_clip = None 70 | optim.beta1 = 0.9 71 | optim.eps = 1e-8 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /configs/example.py: -------------------------------------------------------------------------------- 1 | """Config for `examples/example.py`.""" 2 | 3 | from configs.default_config import get_default_configs 4 | 5 | 6 | def get_config(): 7 | config = get_default_configs() 8 | # training 9 | training = config.training 10 | training.sde = "vpsde" 11 | # training.sde = 'vesde' 12 | training.n_iters = 4000 13 | training.batch_size = 8 14 | training.likelihood_weighting = False 15 | training.score_scaling = True 16 | training.reduce_mean = True 17 | training.log_epoch_freq = 1 18 | 
training.log_step_freq = 8000 19 | training.pmap = False 20 | training.n_jitted_steps = 1 21 | ## store additional checkpoints for preemption in cloud computing environments 22 | training.snapshot_freq = 8000 23 | training.snapshot_freq_for_preemption = 8000 24 | training.eval_freq = 8000 25 | 26 | # eval 27 | eval = config.eval 28 | eval.batch_size = 1000 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.denoise = True 33 | sampling.noise_std = 0.01 34 | 35 | # data 36 | data = config.data 37 | data.image_size = 2 38 | data.num_channels = None 39 | 40 | # model 41 | model = config.model 42 | # for vp 43 | model.beta_min = 0.01 44 | model.beta_max = 3.0 45 | # for ve 46 | model.sigma_min = 0.01 47 | model.sigma_max = 10.0 48 | 49 | # solver 50 | solver = config.solver 51 | solver.num_outer_steps = 1000 52 | solver.outer_solver = "EulerMaruyama" 53 | solver.inner_solver = None 54 | 55 | # optim 56 | optim = config.optim 57 | optim.optimizer = "Adam" 58 | optim.lr = 1e-3 59 | optim.warmup = False 60 | optim.weight_decay = False 61 | optim.grad_clip = None 62 | 63 | config.seed = 2023 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /diffusionjax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/diffusionjax/__init__.py -------------------------------------------------------------------------------- /diffusionjax/inverse_problems.py: -------------------------------------------------------------------------------- 1 | """Utility functions related to Bayesian inversion.""" 2 | 3 | import jax.numpy as jnp 4 | from jax import vmap, vjp, jacfwd, jacrev, grad 5 | from diffusionjax.utils import ( 6 | batch_mul, 7 | batch_matmul_A, 8 | batch_linalg_solve, 9 | batch_matmul, 10 | batch_mul_A, 11 | batch_linalg_solve_A, 12 | ) 13 | 14 | 15 | def get_dps(sde, observation_map, y, noise_std, scale=0.4): 16 | """ 17 | Implementation of score guidance suggested in 18 | `Diffusion Posterior Sampling for general noisy inverse problems' 19 | Chung et al. 2022, 20 | https://github.com/DPS2022/diffusion-posterior-sampling/blob/main/guided_diffusion/condition_methods.py 21 | 22 | Computes a single (batched) gradient. 23 | 24 | NOTE: This is not how Chung et al. 2022 implemented their method, but is a related 25 | continuous time method. 26 | 27 | Args: 28 | scale: Hyperparameter of the method. 29 | See https://arxiv.org/pdf/2209.14687.pdf#page=20&zoom=100,144,757 30 | """ 31 | 32 | def get_l2_norm(y, estimate_h_x_0): 33 | def l2_norm(x, t): 34 | h_x_0, (s, _) = estimate_h_x_0(x, t) 35 | innovation = y - h_x_0 36 | return jnp.linalg.norm(innovation), s 37 | 38 | return l2_norm 39 | 40 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 41 | l2_norm = get_l2_norm(y, estimate_h_x_0) 42 | likelihood_score = grad(l2_norm, has_aux=True) 43 | 44 | def guidance_score(x, t): 45 | ls, s = likelihood_score(x, t) 46 | gs = s - scale * ls 47 | return gs 48 | 49 | return guidance_score 50 | 51 | 52 | def get_diffusion_posterior_sampling(sde, observation_map, y, noise_std): 53 | """ 54 | Implementation of score guidance suggested in 55 | `Diffusion Posterior Sampling for general noisy inverse problems' 56 | Chung et al. 
2022, 57 | https://github.com/DPS2022/diffusion-posterior-sampling/blob/main/guided_diffusion/condition_methods.py 58 | guidance score for an observation_map that can be 59 | represented by either `def observation_map(x: Float[Array, dims]) -> y: Float[Array, d_x = dims.flatten()]: return mask * x # (d_x,)` 60 | or `def observation_map(x: Float[Array, dims]) -> y: Float[Array, d_y]: return H @ x # (d_y,)` 61 | Computes one vjps. 62 | 63 | NOTE: This is not how Chung et al. 2022 implemented their method, their method is `:meth:get_dps`. 64 | Whereas this method uses their approximation in Eq. 11 https://arxiv.org/pdf/2209.14687.pdf#page=20&zoom=100,144,757 65 | to directly calculate the score. 66 | """ 67 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 68 | 69 | def guidance_score(x, t): 70 | h_x_0, vjp_estimate_h_x_0, (s, _) = vjp( 71 | lambda x: estimate_h_x_0(x, t), x, has_aux=True 72 | ) 73 | innovation = y - h_x_0 74 | C_yy = ( 75 | noise_std**2 76 | ) # TODO: could investigate replacing with jnp.linalg.norm(innovation**2) 77 | ls = innovation / C_yy 78 | ls = vjp_estimate_h_x_0(ls)[0] 79 | gs = s + ls 80 | return gs 81 | 82 | return guidance_score 83 | 84 | 85 | def get_pseudo_inverse_guidance( 86 | sde, observation_map, y, noise_std, HHT=jnp.array([1.0]) 87 | ): 88 | """ 89 | `Pseudo-Inverse guided diffusion models for inverse problems` 90 | https://openreview.net/pdf?id=9_gsMA8MRKQ 91 | Song et al. 2023, 92 | guidance score for an observation_map that can be 93 | represented by either `def observation_map(x: Float[Array, dims]) -> y: Float[Array, d_x = dims.flatten()]: return mask * x # (d_x,)` 94 | or `def observation_map(x: Float[Array, dims]) -> y: Float[Array, d_y]: return H @ x # (d_y,)` 95 | Computes one vjps. 96 | """ 97 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 98 | 99 | def guidance_score(x, t): 100 | h_x_0, vjp_estimate_h_x_0, (s, _) = vjp( 101 | lambda x: estimate_h_x_0(x, t), x, has_aux=True 102 | ) 103 | innovation = y - h_x_0 104 | if HHT.shape == (y.shape[1], y.shape[1]): 105 | C_yy = sde.r2(t[0], data_variance=1.0) * HHT + noise_std**2 * jnp.eye(y.shape[1]) 106 | f = batch_linalg_solve_A(C_yy, innovation) 107 | elif HHT.shape == (1,): 108 | C_yy = sde.r2(t[0], data_variance=1.0) * HHT + noise_std**2 109 | f = innovation / C_yy 110 | ls = vjp_estimate_h_x_0(f)[0] 111 | gs = s + ls 112 | return gs 113 | 114 | return guidance_score 115 | 116 | 117 | def get_vjp_guidance_alt(sde, H, y, noise_std, shape): 118 | """ 119 | Uses full second moment approximation of the covariance of x_0|x_t. 120 | 121 | Computes using H.shape[0] vjps. 122 | 123 | NOTE: Alternate implementation to `meth:get_vjp_guidance` that does all reshaping here. 
124 | """ 125 | estimate_x_0 = sde.get_estimate_x_0(lambda x: x) 126 | _shape = (H.shape[0],) + shape[1:] 127 | axes = (1, 0) + tuple(range(len(shape) + 1)[2:]) 128 | batch_H = jnp.transpose( 129 | jnp.tile(H.reshape(_shape), (shape[0],) + len(shape) * (1,)), axes=axes 130 | ) 131 | 132 | def guidance_score(x, t): 133 | x_0, vjp_x_0, (s, _) = vjp(lambda x: estimate_x_0(x, t), x, has_aux=True) 134 | vec_vjp_x_0 = vmap(vjp_x_0) 135 | H_grad_x_0 = vec_vjp_x_0(batch_H)[0] 136 | H_grad_x_0 = H_grad_x_0.reshape(H.shape[0], shape[0], H.shape[1]) 137 | C_yy = sde.ratio(t[0]) * batch_matmul_A( 138 | H, H_grad_x_0.transpose(1, 2, 0) 139 | ) + noise_std**2 * jnp.eye(y.shape[1]) 140 | innovation = y - batch_matmul_A(H, x_0.reshape(shape[0], -1)) 141 | f = batch_linalg_solve(C_yy, innovation) 142 | ls = vjp_x_0(batch_matmul_A(H.T, f).reshape(shape))[0] 143 | gs = s + ls 144 | return gs 145 | 146 | return guidance_score 147 | 148 | 149 | def get_vjp_guidance(sde, H, y, noise_std, shape): 150 | """ 151 | Uses full second moment approximation of the covariance of x_0|x_t. 152 | 153 | Computes using H.shape[0] vjps. 154 | """ 155 | # TODO: necessary to use shape here? 156 | estimate_x_0 = sde.get_estimate_x_0(lambda x: x, shape=(shape[0], -1)) 157 | batch_H = jnp.transpose(jnp.tile(H, (shape[0], 1, 1)), axes=(1, 0, 2)) 158 | assert y.shape[0] == shape[0] 159 | assert y.shape[1] == H.shape[0] 160 | 161 | def guidance_score(x, t): 162 | x_0, vjp_x_0, (s, _) = vjp(lambda x: estimate_x_0(x, t), x, has_aux=True) 163 | vec_vjp_x_0 = vmap(vjp_x_0) 164 | H_grad_x_0 = vec_vjp_x_0(batch_H)[0] 165 | H_grad_x_0 = H_grad_x_0.reshape(H.shape[0], shape[0], H.shape[1]) 166 | C_yy = sde.ratio(t[0]) * batch_matmul_A( 167 | H, H_grad_x_0.transpose(1, 2, 0) 168 | ) + noise_std**2 * jnp.eye(y.shape[1]) 169 | innovation = y - batch_matmul_A(H, x_0) 170 | f = batch_linalg_solve(C_yy, innovation) 171 | # NOTE: in some early tests it's faster to calculate via H_grad_x_0, instead of another vjp 172 | ls = batch_matmul(H_grad_x_0.transpose(1, 2, 0), f).reshape(s.shape) 173 | # ls = vjp_x_0(batch_matmul_A(H.T, f))[0] 174 | gs = s + ls 175 | return gs 176 | 177 | return guidance_score 178 | 179 | 180 | def get_vjp_guidance_mask(sde, observation_map, y, noise_std): 181 | """ 182 | Uses row sum of second moment approximation of the covariance of x_0|x_t. 183 | 184 | Computes two vjps. 185 | """ 186 | # estimate_h_x_0_vmap = sde.get_estimate_x_0_vmap(observation_map) 187 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 188 | batch_observation_map = vmap(observation_map) 189 | 190 | def guidance_score(x, t): 191 | h_x_0, vjp_h_x_0, (s, _) = vjp(lambda x: estimate_h_x_0(x, t), x, has_aux=True) 192 | diag = batch_observation_map(vjp_h_x_0(batch_observation_map(jnp.ones_like(x)))[0]) 193 | C_yy = sde.ratio(t[0]) * diag + noise_std**2 194 | innovation = y - h_x_0 195 | ls = innovation / C_yy 196 | ls = vjp_h_x_0(ls)[0] 197 | gs = s + ls 198 | return gs 199 | 200 | return guidance_score 201 | 202 | 203 | def get_jacrev_guidance(sde, observation_map, y, noise_std, shape): 204 | """ 205 | Uses full second moment approximation of the covariance of x_0|x_t. 206 | 207 | Computes using d_y vjps. 
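Example (illustrative sketch; `sde`, `mask`, `y`, `noise_std`, `shape`, `x` and `t` are assumed to be defined by the user):
>>> observation_map = lambda x: mask * x  # or lambda x: H @ x for a linear map H
>>> guidance_score = get_jacrev_guidance(sde, observation_map, y, noise_std, shape)
>>> gs = guidance_score(x, t)  # same (x, t) signature as an unconditional score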
208 | """ 209 | batch_batch_observation_map = vmap(vmap(observation_map)) 210 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 211 | estimate_h_x_0_vmap = sde.get_estimate_x_0_vmap(observation_map) 212 | jacrev_vmap = vmap(jacrev(lambda x, t: estimate_h_x_0_vmap(x, t)[0])) 213 | 214 | # axes tuple for correct permutation of grad_H_x_0 array 215 | axes = (0,) + tuple(range(len(shape) + 1)[2:]) + (1,) 216 | 217 | def guidance_score(x, t): 218 | h_x_0, (s, _) = estimate_h_x_0( 219 | x, t 220 | ) # TODO: in python 3.8 this line can be removed by utilizing has_aux=True 221 | grad_H_x_0 = jacrev_vmap(x, t) 222 | H_grad_H_x_0 = batch_batch_observation_map(grad_H_x_0) 223 | C_yy = sde.ratio(t[0]) * H_grad_H_x_0 + noise_std**2 * jnp.eye(y.shape[1]) 224 | innovation = y - h_x_0 225 | f = batch_linalg_solve(C_yy, innovation) 226 | ls = batch_matmul(jnp.transpose(grad_H_x_0, axes), f).reshape(s.shape) 227 | gs = s + ls 228 | return gs 229 | 230 | return guidance_score 231 | 232 | 233 | def get_jacfwd_guidance(sde, observation_map, y, noise_std, shape): 234 | """ 235 | Uses full second moment approximation of the covariance of x_0|x_t. 236 | 237 | Computes using d_y jvps. 238 | """ 239 | batch_batch_observation_map = vmap(vmap(observation_map)) 240 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 241 | estimate_h_x_0_vmap = sde.get_estimate_x_0_vmap(observation_map) 242 | 243 | # axes tuple for correct permutation of grad_H_x_0 array 244 | axes = (0,) + tuple(range(len(shape) + 1)[2:]) + (1,) 245 | jacfwd_vmap = vmap(jacfwd(lambda x, t: estimate_h_x_0_vmap(x, t)[0])) 246 | 247 | def guidance_score(x, t): 248 | h_x_0, (s, _) = estimate_h_x_0( 249 | x, t 250 | ) # TODO: in python 3.8 this line can be removed by utilizing has_aux=True 251 | H_grad_x_0 = jacfwd_vmap(x, t) 252 | H_grad_H_x_0 = batch_batch_observation_map(H_grad_x_0) 253 | C_yy = sde.ratio(t[0]) * H_grad_H_x_0 + noise_std**2 * jnp.eye(y.shape[1]) 254 | innovation = y - h_x_0 255 | f = batch_linalg_solve(C_yy, innovation) 256 | ls = batch_matmul(jnp.transpose(H_grad_x_0, axes), f).reshape(s.shape) 257 | gs = s + ls 258 | return gs 259 | 260 | return guidance_score 261 | 262 | 263 | def get_diag_jacrev_guidance(sde, observation_map, y, noise_std, shape): 264 | """Use a diagonal approximation to the variance inside the likelihood, 265 | This produces similar results when the covariance is approximately diagonal 266 | """ 267 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 268 | batch_batch_observation_map = vmap(vmap(observation_map)) 269 | 270 | # axes tuple for correct permutation of grad_H_x_0 array 271 | axes = (0,) + tuple(range(len(shape) + 1)[2:]) + (1,) 272 | 273 | def vec_jacrev(x, t): 274 | return vmap( 275 | jacrev(lambda _x: estimate_h_x_0(jnp.expand_dims(_x, axis=0), t.reshape(1, 1))[0]) 276 | )(x) 277 | 278 | def guidance_score(x, t): 279 | h_x_0, (s, _) = estimate_h_x_0( 280 | x, t 281 | ) # TODO: in python 3.8 this line can be removed by utilizing has_aux=True 282 | grad_H_x_0 = jnp.squeeze(vec_jacrev(x, t[0]), axis=1) 283 | H_grad_H_x_0 = batch_batch_observation_map(grad_H_x_0) 284 | C_yy = sde.ratio(t[0]) * jnp.diagonal(H_grad_H_x_0, axis1=1, axis2=2) + noise_std**2 285 | innovation = y - h_x_0 286 | f = batch_mul(innovation, 1.0 / C_yy) 287 | ls = batch_matmul(jnp.transpose(grad_H_x_0, axes=axes), f).reshape(s.shape) 288 | gs = s + ls 289 | return gs 290 | 291 | return guidance_score 292 | 293 | 294 | def get_diag_vjp_guidance(sde, H, y, noise_std, shape): 295 | """ 296 | Uses full second moment 
approximation of the covariance of x_0|x_t. 297 | 298 | Computes using H.shape[0] vjps. 299 | """ 300 | # TODO: necessary to use shape here? 301 | estimate_x_0 = sde.get_estimate_x_0(lambda x: x, shape=(shape[0], -1)) 302 | batch_H = jnp.transpose(jnp.tile(H, (shape[0], 1, 1)), axes=(1, 0, 2)) 303 | 304 | def guidance_score(x, t): 305 | x_0, vjp_x_0, (s, _) = vjp(lambda x: estimate_x_0(x, t), x, has_aux=True) 306 | vec_vjp_x_0 = vmap(vjp_x_0) 307 | H_grad_x_0 = vec_vjp_x_0(batch_H)[0] 308 | H_grad_x_0 = H_grad_x_0.reshape(H.shape[0], shape[0], H.shape[1]) 309 | diag_H_grad_H_x_0 = jnp.sum(batch_mul_A(H, H_grad_x_0.transpose(1, 0, 2)), axis=-1) 310 | C_yy = sde.ratio(t[0]) * diag_H_grad_H_x_0 + noise_std**2 311 | innovation = y - batch_matmul_A(H, x_0) 312 | f = batch_mul(innovation, 1.0 / C_yy) 313 | ls = vjp_x_0(batch_matmul_A(H.T, f))[0] 314 | gs = s + ls 315 | return gs 316 | 317 | return guidance_score 318 | 319 | 320 | def get_diag_jacfwd_guidance(sde, observation_map, y, noise_std, shape): 321 | """Use a diagonal approximation to the variance inside the likelihood, 322 | This produces similar results when the covariance is approximately diagonal 323 | """ 324 | batch_batch_observation_map = vmap(vmap(observation_map)) 325 | estimate_h_x_0 = sde.get_estimate_x_0(observation_map) 326 | # axes tuple for correct permutation of grad_H_x_0 array 327 | axes = (0,) + tuple(range(len(shape) + 1)[2:]) + (1,) 328 | 329 | def vec_jacfwd(x, t): 330 | return vmap( 331 | jacfwd(lambda _x: estimate_h_x_0(jnp.expand_dims(_x, axis=0), t.reshape(1, 1))[0]) 332 | )(x) 333 | 334 | def guidance_score(x, t): 335 | h_x_0, (s, _) = estimate_h_x_0( 336 | x, t 337 | ) # TODO: in python 3.8 this line can be removed by utilizing has_aux=True 338 | H_grad_x_0 = jnp.squeeze(vec_jacfwd(x, t[0]), axis=(1)) 339 | H_grad_H_x_0 = batch_batch_observation_map(H_grad_x_0) 340 | C_yy = sde.ratio(t[0]) * jnp.diagonal(H_grad_H_x_0, axis1=1, axis2=2) + noise_std**2 341 | f = batch_mul(y - h_x_0, 1.0 / C_yy) 342 | ls = batch_matmul(jnp.transpose(H_grad_x_0, axes=axes), f).reshape(s.shape) 343 | gs = s + ls 344 | return gs 345 | 346 | return guidance_score 347 | -------------------------------------------------------------------------------- /diffusionjax/models.py: -------------------------------------------------------------------------------- 1 | """Functions are designed for a mini-batch of inputs.""" 2 | 3 | import flax.linen as nn 4 | import numpy as np 5 | import jax.numpy as jnp 6 | 7 | 8 | class MLP(nn.Module): 9 | @nn.compact 10 | def __call__(self, x, t): 11 | x_shape = x.shape 12 | in_size = np.prod(x_shape[1:]) 13 | n_hidden = 256 14 | t = t.reshape((t.shape[0], -1)) 15 | x = x.reshape((x.shape[0], -1)) # flatten 16 | t = jnp.concatenate([t - 0.5, jnp.cos(2 * jnp.pi * t)], axis=-1) 17 | x = jnp.concatenate([x, t], axis=-1) 18 | x = nn.Dense(n_hidden)(x) 19 | x = nn.relu(x) 20 | x = nn.Dense(n_hidden)(x) 21 | x = nn.relu(x) 22 | x = nn.Dense(n_hidden)(x) 23 | x = nn.relu(x) 24 | x = nn.Dense(in_size)(x) 25 | return x.reshape(x_shape) 26 | 27 | 28 | class CNN(nn.Module): 29 | @nn.compact 30 | def __call__(self, x, t): 31 | x_shape = x.shape 32 | ndim = x.ndim 33 | 34 | n_hidden = x_shape[1] 35 | n_time_channels = 1 36 | 37 | t = t.reshape((t.shape[0], -1)) 38 | t = jnp.concatenate([t - 0.5, jnp.cos(2 * jnp.pi * t)], axis=-1) 39 | t = nn.Dense(n_hidden**2 * n_time_channels)(t) 40 | t = nn.relu(t) 41 | t = nn.Dense(n_hidden**2 * n_time_channels)(t) 42 | t = nn.relu(t) 43 | t = t.reshape(t.shape[0], n_hidden, n_hidden, 
n_time_channels) 44 | # Add time as another channel 45 | x = jnp.concatenate((x, t), axis=-1) 46 | # A single convolution layer 47 | x = nn.Conv(x_shape[-1], kernel_size=(9,) * (ndim - 2))(x) 48 | return x 49 | -------------------------------------------------------------------------------- /diffusionjax/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/diffusionjax/models/__init__.py -------------------------------------------------------------------------------- /diffusionjax/models/networks_edm2.py: -------------------------------------------------------------------------------- 1 | """JAX port of Improved diffusion model architecture proposed in the paper 2 | "Analyzing and Improving the Training Dynamics of Diffusion Models". 3 | Ported from the code https://github.com/NVlabs/edm2/blob/main/training/networks_edm2.py 4 | """ 5 | import jax 6 | import jax.numpy as jnp 7 | import flax.linen as nn 8 | from typing import Any 9 | 10 | 11 | def jax_unstack(x, axis=0): 12 | """https://github.com/google/jax/discussions/11028""" 13 | return [ 14 | jax.lax.index_in_dim(x, i, axis, keepdims=False) for i in range(x.shape[axis]) 15 | ] 16 | 17 | 18 | def pixel_normalize(x, channel_axis, eps=1e-4): 19 | """ 20 | Normalize given tensor to unit magnitude with respect to the given 21 | channel axis. 22 | Args: 23 | x: Assume (N, C, H, W) 24 | """ 25 | norm = jnp.float32(jnp.linalg.vector_norm(x, axis=channel_axis, keepdims=True)) 26 | norm = eps + jnp.sqrt(norm.size / x.size) * norm 27 | return x / jnp.array(norm, dtype=x.dtype) 28 | 29 | 30 | def weight_normalize(x, eps=1e-4): 31 | """ 32 | Normalize given tensor to unit magnitude with respect to all the dimensions 33 | except the first. 34 | Args: 35 | x: Assume (N, C, H, W) 36 | """ 37 | norm = jnp.float32(jax.vmap(lambda x: jnp.linalg.vector_norm(x, keepdims=True))(x)) 38 | norm = eps + jnp.sqrt(norm.size / x.size) * norm 39 | return x / jnp.array(norm, dtype=x.dtype) 40 | 41 | 42 | def forced_weight_normalize(x, eps=1e-4): 43 | """ 44 | Normalize given tensor to unit magnitude with respect to all the dimensions 45 | except the first. Don't take gradients through the computation. 46 | Args: 47 | x: Assume (N, C, H, W) 48 | """ 49 | norm = jax.lax.stop_gradient( 50 | jnp.float32(jax.vmap(lambda x: jnp.linalg.vector_norm(x, keepdims=True))(x)) 51 | ) 52 | norm = eps + jnp.sqrt(norm.size / x.size) * norm 53 | return x / jnp.array(norm, dtype=x.dtype) 54 | 55 | 56 | def resample(x, f=[1, 1], mode="keep"): 57 | """ 58 | Upsample or downsample the given tensor with the given filter, 59 | or keep it as is. 
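mode is one of 'keep', 'up' or 'down'; 'up' and 'down' change the spatial resolution by a factor of two using the separable filter f.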
60 | 61 | Args: 62 | x: Assume (N, C, H, W) 63 | """ 64 | if mode == "keep": 65 | return x 66 | f = jnp.array(f, dtype=x.dtype) 67 | assert f.ndim == 1 and len(f) % 2 == 0 68 | f = f / f.sum() 69 | f = jnp.outer(f, f)[jnp.newaxis, jnp.newaxis, :, :] 70 | c = x.shape[1] 71 | 72 | if mode == "down": 73 | return jax.lax.conv_general_dilated( 74 | x, 75 | jnp.tile(f, (c, 1, 1, 1)), 76 | window_strides=(2, 2), 77 | feature_group_count=c, 78 | padding="SAME", 79 | ) 80 | assert mode == "up" 81 | 82 | pad = (len(f) - 1) // 2 + 1 83 | return jax.lax.conv_general_dilated( 84 | x, 85 | jnp.tile(f * 4, (c, 1, 1, 1)), 86 | dimension_numbers=("NCHW", "OIHW", "NCHW"), 87 | window_strides=(1, 1), 88 | lhs_dilation=(2, 2), 89 | feature_group_count=c, 90 | padding=((pad, pad), (pad, pad)), 91 | ) 92 | 93 | 94 | def mp_silu(x): 95 | """Magnitude-preserving SiLU (Equation 81).""" 96 | return nn.activation.silu(x) / 0.596 97 | 98 | 99 | def mp_sum(a, b, t=0.5): 100 | """Magnitude-preserving sum (Equation 88).""" 101 | return (a + t * (b - a)) / jnp.sqrt((1 - t) ** 2 + t**2) 102 | 103 | 104 | def mp_cat(a, b, dim=1, t=0.5): 105 | """Magnitude-preserving concatenation (Equation 103).""" 106 | Na = a.shape[dim] 107 | Nb = b.shape[dim] 108 | C = jnp.sqrt((Na + Nb) / ((1 - t) ** 2 + t**2)) 109 | wa = C / jnp.sqrt(Na) * (1 - t) 110 | wb = C / jnp.sqrt(Nb) * t 111 | return jax.lax.concatenate([wa * a, wb * b], dimension=dim) 112 | 113 | 114 | class MPFourier(nn.Module): 115 | """Magnitude-preserving Fourier features (Equation 75).""" 116 | 117 | num_channels: int 118 | bandwidth: float = 1.0 119 | 120 | @nn.compact 121 | def __call__(self, x): 122 | freqs = self.param( 123 | "freqs", 124 | jax.nn.initializers.normal(stddev=2 * jnp.pi * self.bandwidth), 125 | (self.num_channels,), 126 | ) 127 | freqs = jax.lax.stop_gradient(freqs) 128 | phases = self.param( 129 | "phases", jax.nn.initializers.normal(stddev=2 * jnp.pi), (self.num_channels,) 130 | ) 131 | phases = jax.lax.stop_gradient(phases) 132 | y = jnp.float32(x) 133 | y = jnp.float32(jnp.outer(x, freqs)) 134 | y = y + jnp.float32(phases) 135 | y = jnp.cos(y) * jnp.sqrt(2) 136 | return jnp.array(y, dtype=x.dtype) 137 | 138 | 139 | class MPConv(nn.Module): 140 | """Magnitude-preserving convolution or fully-connected layer (Equation 47) 141 | with force weight normalization (Equation 66). 142 | """ 143 | 144 | in_channels: int 145 | out_channels: int 146 | kernel_shape: tuple 147 | training: bool = True 148 | 149 | @nn.compact 150 | def __call__(self, x, gain=1.0): 151 | w = jnp.float32( 152 | self.param( 153 | "w", 154 | jax.nn.initializers.normal(stddev=1.0), 155 | (self.out_channels, self.in_channels, *self.kernel_shape), 156 | ) 157 | ) # TODO: type promotion required in JAX? 158 | if self.training: 159 | w = forced_weight_normalize(w) # forced weight normalization 160 | 161 | w = weight_normalize(w) # traditional weight normalization 162 | w = w * (gain / jnp.sqrt(w[0].size)) # magnitude-preserving scaling 163 | w = jnp.array(w, dtype=x.dtype) 164 | if w.ndim == 2: 165 | return x @ w.T 166 | assert w.ndim == 4 167 | 168 | return jax.lax.conv( 169 | x, 170 | w, 171 | window_strides=(1, 1), 172 | padding="SAME", 173 | ) 174 | 175 | 176 | class Block(nn.Module): 177 | """ 178 | U-Net encoder/decoder block with optional self-attention (Figure 21). 
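Called as block(x, emb), where x is an (N, C, H, W) feature map and emb is the per-sample embedding vector shared across the U-Net blocks.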
179 | """ 180 | 181 | in_channels: int # Number of input channels 182 | out_channels: int # Number of output channels 183 | emb_channels: int # Number of embedding channels 184 | flavor: str = "enc" # Flavor: 'enc' or 'dec' 185 | resample_mode: str = "keep" # Resampling: 'keep', 'up', or 'down'. 186 | resample_filter: tuple = (1, 1) # Resampling filter. 187 | attention: bool = False # Include self-attention? 188 | channels_per_head: int = 64 # Number of channels per attention head. 189 | dropout: float = 0.0 # Dropout probability. 190 | res_balance: float = 0.3 # Balance between main branch (0) and residual branch (1). 191 | attn_balance: float = 0.3 # Balance between main branch (0) and self-attention (1). 192 | clip_act: int = 256 # Clip output activations. None = do not clip. 193 | training: bool = True 194 | 195 | @nn.compact 196 | def __call__(self, x, emb): 197 | # Main branch 198 | x = resample(x, f=self.resample_filter, mode=self.resample_mode) 199 | if self.flavor == "enc": 200 | if self.in_channels != self.out_channels: 201 | x = MPConv( 202 | self.in_channels, self.out_channels, kernel_shape=(1, 1), name="conv_skip" 203 | )(x) 204 | x = pixel_normalize(x, channel_axis=1) # pixel norm 205 | 206 | # Residual branch 207 | y = MPConv( 208 | self.out_channels if self.flavor == "enc" else self.in_channels, 209 | self.out_channels, 210 | kernel_shape=(3, 3), 211 | name="conv_res0", 212 | )(mp_silu(x)) 213 | 214 | c = ( 215 | MPConv(self.emb_channels, self.out_channels, kernel_shape=(), name="emb_linear")( 216 | emb, gain=self.param("emb_gain", jax.nn.initializers.zeros, (1,)) 217 | ) 218 | + 1 219 | ) 220 | y = jnp.array( 221 | mp_silu(y * jnp.expand_dims(jnp.expand_dims(c, axis=2), axis=3)), dtype=y.dtype 222 | ) 223 | if self.dropout: 224 | y = nn.Dropout(self.dropout)(y, deterministic=not self.training) 225 | y = MPConv( 226 | self.out_channels, self.out_channels, kernel_shape=(3, 3), name="conv_res1" 227 | )(y) 228 | 229 | # Connect the branches 230 | if self.flavor == "dec" and self.in_channels != self.out_channels: 231 | x = MPConv( 232 | self.in_channels, self.out_channels, kernel_shape=(1, 1), name="conv_skip" 233 | )(x) 234 | x = mp_sum(x, y, t=self.res_balance) 235 | 236 | # Self-attention 237 | # TODO: test if flax.linen.SelfAttention can be used instead here? 238 | num_heads = self.out_channels // self.channels_per_head if self.attention else 0 239 | if num_heads != 0: 240 | y = MPConv( 241 | self.out_channels, self.out_channels * 3, kernel_shape=(1, 1), name="attn_qkv" 242 | )(x) 243 | y = y.reshape(y.shape[0], num_heads, -1, 3, y.shape[2] * y.shape[3]) 244 | q, k, v = jax_unstack( 245 | pixel_normalize(y, channel_axis=2), axis=3 246 | ) # pixel normalization and split 247 | # NOTE: quadratic cost in last dimension 248 | w = nn.softmax(jnp.einsum("nhcq,nhck->nhqk", q, k / jnp.sqrt(q.shape[2])), axis=3) 249 | y = jnp.einsum("nhqk,nhck->nhcq", w, v) 250 | y = MPConv( 251 | self.out_channels, self.out_channels, kernel_shape=(1, 1), name="attn_proj" 252 | )(y.reshape(*x.shape)) 253 | x = mp_sum(x, y, t=self.attn_balance) 254 | 255 | # Clip activations 256 | if self.clip_act is not None: 257 | x = jnp.clip(x, -self.clip_act, self.clip_act) 258 | return x 259 | 260 | 261 | class UNet(nn.Module): 262 | """EDM2 U-Net model (Figure 21).""" 263 | 264 | img_resolution: int # Image resolution. 265 | img_channels: int # Image channels. 266 | label_dim: int # Class label dimensionality. 0 = unconditional. 267 | model_channels: int = 192 # Base multiplier for the number of channels. 
268 | channel_mult: tuple = ( 269 | 1, 270 | 2, 271 | 3, 272 | 4, 273 | ) # Per-resolution multipliers for the number of channels. 274 | channel_mult_noise: Any = None # Multiplier for noise embedding dimensionality. None = select based on channel_mult. 275 | channel_mult_emb: Any = None # Multiplier for final embedding dimensionality. None = select based on channel_mult. 276 | num_blocks: int = 3 # Number of residual blocks per resolution. 277 | attn_resolutions: tuple = (16, 8) # List of resolutions with self-attention. 278 | label_balance: float = ( 279 | 0.5 # Balance between noise embedding (0) and class embedding (1). 280 | ) 281 | concat_balance: float = 0.5 # Balance between skip connections (0) and main path (1). 282 | 283 | # **block_kwargs - arguments for Block 284 | resample_filter: tuple = (1, 1) # Resampling filter 285 | channels_per_head: int = 64 # Number of channels per attention head 286 | dropout: float = 0.0 # Dropout probability 287 | res_balance: float = 0.3 # Balance between main branch (0) and residual branch (1) 288 | attn_balance: float = 0.3 # Balance between main branch (0) and self-attention (1) 289 | clip_act: int = 256 # Clip output activations. None = do not clip 290 | out_gain: Any = None 291 | block_kwargs = { 292 | "resample_filter": resample_filter, 293 | "channels_per_head": channels_per_head, 294 | "dropout": dropout, 295 | "res_balance": res_balance, 296 | "attn_balance": attn_balance, 297 | "clip_act": clip_act, 298 | } 299 | 300 | @nn.compact 301 | def __call__(self, x, noise_labels, class_labels): 302 | cblock = [self.model_channels * x for x in self.channel_mult] 303 | cnoise = ( 304 | self.model_channels * self.channel_mult_noise 305 | if self.channel_mult_noise is not None 306 | else cblock[0] 307 | ) 308 | cemb = ( 309 | self.model_channels * self.channel_mult_emb 310 | if self.channel_mult_emb is not None 311 | else max(cblock) 312 | ) 313 | 314 | if self.out_gain is None: 315 | out_gain = self.param("out_gain", jax.nn.initializers.zeros, (1,)) 316 | else: 317 | out_gain = self.out_gain 318 | 319 | # Encoder 320 | enc = {} 321 | cout = self.img_channels + 1 322 | for level, channels in enumerate(cblock): 323 | res = self.img_resolution >> level 324 | if level == 0: 325 | cin = cout 326 | cout = channels 327 | enc[f"{res}x{res}_conv"] = MPConv( 328 | cin, cout, kernel_shape=(3, 3), name=f"enc_{res}x{res}_conv" 329 | ) 330 | else: 331 | enc[f"{res}x{res}_down"] = Block( 332 | cout, 333 | cout, 334 | cemb, 335 | flavor="enc", 336 | resample_mode="down", 337 | name=f"enc_{res}x{res}_down", 338 | **self.block_kwargs, 339 | ) 340 | for idx in range(self.num_blocks): 341 | cin = cout 342 | cout = channels 343 | enc[f"{res}x{res}_block{idx}"] = Block( 344 | cin, 345 | cout, 346 | cemb, 347 | flavor="enc", 348 | attention=(res in self.attn_resolutions), 349 | name=f"enc_{res}x{res}_block{idx}", 350 | **self.block_kwargs, 351 | ) 352 | 353 | # Decoder 354 | dec = {} 355 | skips = [block.out_channels for block in enc.values()] 356 | for level, channels in reversed(list(enumerate(cblock))): 357 | res = self.img_resolution >> level 358 | if level == len(cblock) - 1: 359 | dec[f"{res}x{res}_in0"] = Block( 360 | cout, 361 | cout, 362 | cemb, 363 | flavor="dec", 364 | attention=True, 365 | name=f"dec_{res}x{res}_in0", 366 | **self.block_kwargs, 367 | ) 368 | dec[f"{res}x{res}_in1"] = Block( 369 | cout, 370 | cout, 371 | cemb, 372 | flavor="dec", 373 | name=f"dec_{res}x{res}_in1", 374 | **self.block_kwargs, 375 | ) 376 | else: 377 | dec[f"{res}x{res}_up"] = 
Block( 378 | cout, 379 | cout, 380 | cemb, 381 | flavor="dec", 382 | resample_mode="up", 383 | name=f"dec_{res}x{res}_up", 384 | **self.block_kwargs, 385 | ) 386 | for idx in range(self.num_blocks + 1): 387 | cin = cout + skips.pop() 388 | cout = channels 389 | dec[f"{res}x{res}_block{idx}"] = Block( 390 | cin, 391 | cout, 392 | cemb, 393 | flavor="dec", 394 | attention=(res in self.attn_resolutions), 395 | name=f"dec_{res}x{res}_block{idx}", 396 | **self.block_kwargs, 397 | ) 398 | 399 | # Embedding 400 | emb = MPConv(cnoise, cemb, kernel_shape=(), name="emb_noise")( 401 | MPFourier(cnoise, name="emb_fourier")(noise_labels) 402 | ) 403 | if self.label_dim != 0: 404 | emb = mp_sum( 405 | emb, 406 | MPConv(self.label_dim, cemb, kernel_shape=(), name="emb_label")( 407 | class_labels * jnp.sqrt(class_labels.shape[1]) 408 | ), 409 | t=self.label_balance, 410 | ) 411 | emb = mp_silu(emb) 412 | 413 | # Encoder 414 | x = jax.lax.concatenate([x, jnp.ones_like(x[:, :1])], dimension=1) 415 | skips = [] 416 | for name, block in enc.items(): 417 | x = block(x) if "conv" in name else block(x, emb) 418 | skips.append(x) 419 | 420 | # Decoder 421 | for name, block in dec.items(): 422 | if "block" in name: 423 | x = mp_cat(x, skips.pop(), t=self.concat_balance) 424 | x = block(x, emb) 425 | x = MPConv(cout, self.img_channels, kernel_shape=(3, 3), name="out_conv")( 426 | x, gain=out_gain 427 | ) 428 | return x 429 | 430 | 431 | class Precond(nn.Module): 432 | """Preconditioning and uncertainty estimation.""" 433 | 434 | img_resolution: int # Image resolution. 435 | img_channels: int # Image channels. 436 | label_dim: int # Class label dimensionality. 0 = unconditional. 437 | # **precond_kwargs 438 | use_fp16: bool = True # Run the model at FP16 precision? 439 | sigma_data: float = 0.5 # Expected standard deviation of the training data. 440 | logvar_channels: int = 128 # Intermediate dimensionality for uncertainty estimation. 441 | return_logvar: bool = False 442 | # **unet_kwargs # Keyword arguments for UNet. 443 | model_channels: int = 192 # Base multiplier for the number of channels. 444 | channel_mult: tuple = ( 445 | 1, 446 | 2, 447 | 3, 448 | 4, 449 | ) # Per-resolution multipliers for the number of channels. 450 | channel_mult_noise: Any = None # Multiplier for noise embedding dimensionality. None = select based on channel_mult. 451 | channel_mult_emb: Any = None # Multiplier for final embedding dimensionality. None = select based on channel_mult. 452 | num_blocks: int = 3 # Number of residual blocks per resolution. 453 | attn_resolutions: tuple = (16, 8) # List of resolutions with self-attention. 454 | label_balance: float = ( 455 | 0.5 # Balance between noise embedding (0) and class embedding (1). 456 | ) 457 | concat_balance: float = 0.5 # Balance between skip connections (0) and main path (1). 
458 | out_gain: float = 1.0 459 | unet_kwargs = { 460 | "model_channels": model_channels, 461 | "channel_mult": channel_mult, 462 | "channel_mult_noise": channel_mult_noise, 463 | "channel_mult_emb": channel_mult_emb, 464 | "num_blocks": num_blocks, 465 | "attn_resolutions": attn_resolutions, 466 | "label_balance": label_balance, 467 | "concat_balance": concat_balance, 468 | "out_gain": out_gain, 469 | } 470 | 471 | # **block_kwargs # Keyword arguments for Block 472 | resample_filter: tuple = (1, 1) # Resampling filter 473 | channels_per_head: int = 64 # Number of channels per attention head 474 | dropout: float = 0.0 # Dropout probability 475 | res_balance: float = 0.3 # Balance between main branch (0) and residual branch (1) 476 | attn_balance: float = 0.3 # Balance between main branch (0) and self-attention (1) 477 | clip_act: int = 256 # Clip output activations. None = do not clip 478 | out_gain: Any = None 479 | block_kwargs = { 480 | "resample_filter": resample_filter, 481 | "channels_per_head": channels_per_head, 482 | "dropout": dropout, 483 | "res_balance": res_balance, 484 | "attn_balance": attn_balance, 485 | "clip_act": clip_act, 486 | } 487 | 488 | @nn.compact 489 | def __call__( 490 | self, 491 | x, 492 | sigma, 493 | class_labels=None, 494 | force_fp32=False, 495 | ): 496 | x = jnp.float32(x) 497 | sigma = jnp.float32(sigma).reshape(-1, 1, 1, 1) 498 | class_labels = ( 499 | None 500 | if self.label_dim == 0 501 | else jnp.zeros((1, self.label_dim), device=x.device) 502 | if class_labels is None 503 | else jnp.float32(class_labels).reshape(-1, self.label_dim) 504 | ) 505 | dtype = jnp.float16 if (self.use_fp16 and not force_fp32) else jnp.float32 506 | 507 | # Preconditioning weights 508 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 509 | c_out = sigma * self.sigma_data / jnp.sqrt(sigma**2 + self.sigma_data**2) 510 | c_in = 1 / jnp.sqrt(self.sigma_data**2 + sigma**2) 511 | c_noise = jnp.log(sigma.flatten()) / 4 512 | 513 | # Run the model 514 | x_in = jnp.array(c_in * x, dtype=dtype) 515 | 516 | F_x = UNet( 517 | img_resolution=self.img_resolution, 518 | img_channels=self.img_channels, 519 | label_dim=self.label_dim, 520 | **self.unet_kwargs, 521 | **self.block_kwargs, 522 | name="unet", 523 | )(x_in, c_noise, class_labels) 524 | D_x = c_skip * x + c_out * jnp.float32(F_x) 525 | 526 | # Estimate uncertainty if requested 527 | if self.return_logvar: 528 | logvar = MPConv(self.logvar_channels, 1, kernel_shape=(), name="logvar_linear")( 529 | MPFourier(self.logvar_channels, name="logvar_fourier")(c_noise) 530 | ).reshape(-1, 1, 1, 1) 531 | return D_x, logvar # u(sigma) in Equation 21 532 | return D_x 533 | -------------------------------------------------------------------------------- /diffusionjax/plot.py: -------------------------------------------------------------------------------- 1 | """Plotting code for the examples.""" 2 | 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | from jax import jit, vmap 6 | from functools import partial 7 | import matplotlib.animation as animation 8 | import numpy as np 9 | 10 | 11 | BG_ALPHA = 1.0 12 | MG_ALPHA = 1.0 13 | FG_ALPHA = 0.3 14 | color_posterior = "#a2c4c9" 15 | color_algorithm = "#ff7878" 16 | dpi_val = 1200 17 | cmap = "magma" 18 | 19 | 20 | def plot_heatmap(samples, area_bounds, lengthscale=350.0, fname="plot_heatmap") -> None: 21 | """Plots a heatmap of all samples in the area area_bounds x area_bounds. 
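The heatmap is a kernel density estimate on a 512x512 grid: each sample contributes a Gaussian bump exp(-lengthscale * squared distance), so a larger lengthscale gives sharper peaks.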
22 | Args: 23 | samples: locations of particles shape (num_particles, 2) 24 | """ 25 | 26 | def small_kernel(z, area_bounds): 27 | a = jnp.linspace(area_bounds[0], area_bounds[1], 512) 28 | x, y = jnp.meshgrid(a, a) 29 | dist = (x - z[0]) ** 2 + (y - z[1]) ** 2 30 | hm = jnp.exp(-lengthscale * dist) 31 | return hm 32 | 33 | @jit # jit most of the code, but use the helper functions since cannot jit all of it because of plt 34 | def produce_heatmap(samples, area_bounds): 35 | return jnp.sum(vmap(small_kernel, in_axes=(0, None))(samples, area_bounds), axis=0) 36 | 37 | hm = produce_heatmap(samples, area_bounds) 38 | extent = area_bounds + area_bounds 39 | plt.imshow(hm, interpolation="nearest", extent=extent) 40 | ax = plt.gca() 41 | ax.invert_yaxis() 42 | plt.savefig(fname + ".png") 43 | plt.close() 44 | 45 | 46 | def image_grid(x, image_size, num_channels): 47 | img = x.reshape(-1, image_size, image_size, num_channels) 48 | w = int(np.sqrt(img.shape[0])) 49 | return ( 50 | img.reshape((w, w, image_size, image_size, num_channels)) 51 | .transpose((0, 2, 1, 3, 4)) 52 | .reshape((w * image_size, w * image_size, num_channels)) 53 | ) 54 | 55 | 56 | def plot_samples(x, image_size=32, num_channels=3, fname="samples"): 57 | img = image_grid(x, image_size, num_channels) 58 | plt.figure(figsize=(8, 8)) 59 | plt.axis("off") 60 | plt.imshow(img, cmap=cmap) 61 | plt.savefig(fname + ".png", bbox_inches="tight", pad_inches=0.0) 62 | # plt.savefig(fname + '.pdf', bbox_inches='tight', pad_inches=0.0) 63 | plt.close() 64 | 65 | 66 | def plot_scatter(samples, index, fname="samples", lims=None): 67 | fig, ax = plt.subplots(1, 1) 68 | fig.patch.set_facecolor("white") 69 | fig.patch.set_alpha(BG_ALPHA) 70 | ax.scatter(samples[:, index[0]], samples[:, index[1]], color="red", label=r"$x$") 71 | ax.legend() 72 | ax.set_xlabel(r"$x_{}$".format(index[0])) 73 | ax.set_ylabel(r"$x_{}$".format(index[1])) 74 | if lims is not None: 75 | ax.set_xlim(lims[0]) 76 | ax.set_ylim(lims[1]) 77 | plt.gca().set_aspect("equal", adjustable="box") 78 | plt.draw() 79 | fig.savefig(fname + ".png", facecolor=fig.get_facecolor(), edgecolor="none") 80 | plt.close() 81 | 82 | 83 | def plot_samples_1D(samples, image_size, x_max=5.0, fname="samples 1D", alpha=FG_ALPHA): 84 | x = np.linspace(-x_max, x_max, image_size) 85 | plt.plot(x, samples[..., 0].T, alpha=alpha) 86 | plt.xlim(-5.0, 5.0) 87 | plt.ylim(-5.0, 5.0) 88 | plt.savefig(fname + ".png") 89 | plt.close() 90 | 91 | 92 | def plot_animation(fig, ax, animate, frames, fname, fps=20, bitrate=800, dpi=300): 93 | ani = animation.FuncAnimation(fig, animate, frames=frames, interval=1, fargs=(ax,)) 94 | # Set up formatting for the movie files 95 | Writer = animation.writers["ffmpeg"] 96 | writer = Writer(fps=fps, metadata=dict(artist="Me"), bitrate=bitrate) 97 | # Note that mp4 does not work on pdf 98 | ani.save("{}.mp4".format(fname), writer=writer, dpi=dpi) 99 | 100 | 101 | def plot_score(score, scaler, t, area_bounds=[-3.0, 3.0], fname="plot_score"): 102 | fig, ax = plt.subplots(1, 1) 103 | 104 | # this helper function is here so that we can jit 105 | @partial( 106 | jit, 107 | static_argnums=[ 108 | 0, 109 | ], 110 | ) # We can not jit the whole function since plt.quiver cannot be jitted 111 | def helper(score, t, area_bounds): 112 | x = jnp.linspace(area_bounds[0], area_bounds[1], 16) 113 | x, y = jnp.meshgrid(x, x) 114 | grid = jnp.stack([x.flatten(), y.flatten()], axis=1) 115 | t = jnp.ones((grid.shape[0],)) * t 116 | scores = score(scaler(grid), t) 117 | return grid, scores 118 | 119 | 
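  # Evaluate the score on a 16x16 grid (the jitted `helper` above treats `score`
  # as a static argument) and draw the resulting vector field with plt.quiver.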
grid, scores = helper(score, t, area_bounds) 120 | ax.quiver(grid[:, 0], grid[:, 1], scores[:, 0], scores[:, 1]) 121 | ax.set_xlabel(r"$x_0$") 122 | ax.set_ylabel(r"$x_1$") 123 | plt.gca().set_aspect("equal", adjustable="box") 124 | fig.savefig(fname + ".png") 125 | plt.close() 126 | 127 | 128 | def plot_score_ax(ax, score, scaler, t, area_bounds=[-3.0, 3.0]): 129 | @partial( 130 | jit, 131 | static_argnums=[ 132 | 0, 133 | ], 134 | ) # We can not jit the whole function since plt.quiver cannot be jitted 135 | def helper(score, t, area_bounds): 136 | x = jnp.linspace(area_bounds[0], area_bounds[1], 16) 137 | x, y = jnp.meshgrid(x, x) 138 | grid = jnp.stack([x.flatten(), y.flatten()], axis=1) 139 | t = jnp.ones((grid.shape[0],)) * t 140 | scores = score(scaler(grid), t) 141 | return grid, scores 142 | 143 | grid, scores = helper(score, t, area_bounds) 144 | ax.quiver(grid[:, 0], grid[:, 1], scores[:, 0], scores[:, 3]) 145 | ax.set_xlabel(r"$x_0$") 146 | ax.set_ylabel(r"$x_1$") 147 | 148 | 149 | def plot_heatmap_ax(ax, samples, area_bounds=[-3.0, 3.0], lengthscale=350): 150 | """Plots a heatmap of all samples in the area area_bounds^{2}. 151 | Args: 152 | samples: locations of all particles in R^2, array (J, 2) 153 | """ 154 | 155 | def small_kernel(z, area_bounds): 156 | a = jnp.linspace(area_bounds[0], area_bounds[1], 512) 157 | x, y = jnp.meshgrid(a, a) 158 | dist = (x - z[0]) ** 2 + (y - z[1]) ** 2 159 | hm = jnp.exp(-lengthscale * dist) 160 | return hm 161 | 162 | @jit 163 | def produce_heatmap(samples, area_bounds): 164 | return jnp.sum( 165 | vmap(small_kernel, in_axes=(0, None, None))(samples, area_bounds), axis=0 166 | ) 167 | 168 | hm = produce_heatmap(samples, area_bounds) 169 | extent = area_bounds + area_bounds 170 | ax.imshow(hm, interpolation="nearest", extent=extent) 171 | ax = plt.gca() 172 | ax.invert_yaxis() 173 | ax.set_xlabel(r"$x_0$") 174 | ax.set_ylabel(r"$x_1$") 175 | 176 | 177 | def plot_temperature_schedule(sde, solver): 178 | """Plots the temperature schedule of the SDE marginals. 179 | 180 | Args: 181 | sde: a valid SDE class. 
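    solver: a valid Solver class; its `ts` attribute provides the time grid over
      which the squared mean coefficient and the variance are plotted.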
182 | """ 183 | m2 = sde.mean_coeff(solver.ts) ** 2 184 | v = sde.variance(solver.ts) 185 | plt.plot(solver.ts, m2, label="m2") 186 | plt.plot(solver.ts, v, label="v") 187 | plt.legend() 188 | plt.savefig("plot_temperature_schedule.png") 189 | plt.close() 190 | -------------------------------------------------------------------------------- /diffusionjax/run_lib.py: -------------------------------------------------------------------------------- 1 | """Training and evaluation for score-based generative models.""" 2 | 3 | import jax 4 | from jax import jit, value_and_grad 5 | import jax.random as random 6 | import jax.numpy as jnp 7 | from diffusionjax.utils import ( 8 | get_loss, 9 | get_score, 10 | get_sampler, 11 | get_times, 12 | get_exponential_sigma_function, 13 | get_linear_beta_function, 14 | ) 15 | import diffusionjax.sde as sde_lib 16 | from diffusionjax.solvers import EulerMaruyama, Annealed, DDIMVP, DDIMVE, SMLD, DDPM 17 | import numpy as np 18 | from functools import partial 19 | import flax 20 | import flax.training.orbax_utils as orbax_utils 21 | import flax.jax_utils as flax_utils 22 | from absl import flags 23 | from tqdm import tqdm, trange 24 | import os 25 | import time 26 | from typing import Any 27 | import logging 28 | import wandb 29 | 30 | # This run_library requires optax, https://optax.readthedocs.io/en/latest/ 31 | import optax 32 | 33 | # This run_library requires orbax, https://orbax.readthedocs.io/en/latest/ 34 | import orbax.checkpoint 35 | 36 | # This run_library requires torch[cpu], https://pytorch.org/get-started/locally/ 37 | from torch.utils.data import DataLoader 38 | 39 | 40 | FLAGS = flags.FLAGS 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | def get_step_fn(loss, optimizer, train, pmap): 45 | """Create a one-step training/evaluation function. 46 | 47 | Args: 48 | loss: A loss function. 49 | optimizer: An optimization function. 50 | train: `True` for training and `False` for evaluation. 51 | pmap: `True` for pmap across jax devices, `False` for single device. 52 | 53 | Returns: 54 | A one-step function for training or evaluation. 
55 | """ 56 | 57 | @jit 58 | def step_fn(carry, batch): 59 | (rng, params, opt_state) = carry 60 | rng, step_rng = random.split(rng) 61 | grad_fn = value_and_grad(loss) 62 | if train: 63 | loss_val, grads = grad_fn(params, step_rng, batch) 64 | if pmap: 65 | loss_val = jax.lax.pmean(loss_val, axis_name="batch") 66 | grads = jax.lax.pmean(grads, axis_name="batch") 67 | updates, opt_state = optimizer.update(grads, opt_state) 68 | params = optax.apply_updates(params, updates) 69 | else: 70 | loss_val = loss(params, step_rng, batch) 71 | if pmap: 72 | loss_val = jax.lax.pmean(loss_val, axis_name="batch") 73 | return (rng, params, opt_state), loss_val 74 | 75 | return step_fn 76 | 77 | 78 | # The dataclass that stores all training states 79 | @flax.struct.dataclass 80 | class State: 81 | step: int 82 | opt_state: Any 83 | params: Any 84 | rng: Any 85 | lr: float 86 | 87 | 88 | def get_sde(config): 89 | # Setup SDE 90 | if config.training.sde.lower() == "vpsde": 91 | beta, mean_coeff = get_linear_beta_function( 92 | config.model.beta_min, config.model.beta_max 93 | ) 94 | return sde_lib.VP(beta=beta, mean_coeff=mean_coeff) 95 | elif config.training.sde.lower() == "vesde": 96 | sigma = get_exponential_sigma_function(config.model.sigma_min, config.model.sigma_max) 97 | return sde_lib.VE(sigma=sigma) 98 | else: 99 | raise NotImplementedError(f"SDE {config.training.SDE} unknown.") 100 | 101 | 102 | def get_optimizer(config): 103 | """Returns an optax optimizer object based on `config`.""" 104 | if config.optim.warmup: 105 | schedule = optax.warmup_cosine_decay_schedule( 106 | init_value=0.0, 107 | peak_value=1.0, 108 | warmup_steps=config.optim.warmup, 109 | decay_steps=config.optim.warmup + 1, 110 | end_value=1.0, 111 | ) 112 | else: 113 | schedule = config.optim.lr 114 | if config.optim.optimizer == "Adam": 115 | if config.optim.weight_decay: 116 | optimizer = optax.adamw( 117 | learning_rate=schedule, b1=config.optim.beta1, eps=config.optim.eps 118 | ) 119 | else: 120 | optimizer = optax.adam( 121 | learning_rate=schedule, b1=config.optim.beta1, eps=config.optim.eps 122 | ) 123 | else: 124 | raise NotImplementedError( 125 | "Optimiser {} not supported yet!".format(config.optim.optimizer) 126 | ) 127 | if config.optim.grad_clip: 128 | optimizer = optax.chain(optax.clip(config.optim.grad_clip), optimizer) 129 | return optimizer 130 | 131 | 132 | def get_solver(config, sde, score): 133 | if config.solver.outer_solver.lower() == "eulermaruyama": 134 | ts, _ = get_times( 135 | num_steps=config.solver.num_outer_steps, 136 | dt=config.solver.dt, 137 | t0=config.solver.epsilon, 138 | ) 139 | outer_solver = EulerMaruyama(sde.reverse(score), ts) 140 | else: 141 | raise NotImplementedError(f"Solver {config.solver.outer_solver} unknown.") 142 | if config.solver.inner_solver is None: 143 | inner_solver = None 144 | elif config.solver.inner_solver.lower() == "annealed": 145 | ts, _ = get_times(num_steps=config.solver.num_inner_steps) 146 | inner_solver = Annealed( 147 | sde.corrector(sde_lib.ULangevin, score), snr=config.solver.snr, ts=ts 148 | ) 149 | else: 150 | raise NotImplementedError(f"Solver {config.solver.inner_solver} unknown.") 151 | return outer_solver, inner_solver 152 | 153 | 154 | def get_ddim_chain(config, model): 155 | """ 156 | Args: 157 | model: DDIM parameterizes the `epsilon(x, t) = -1. 
* fwd_marginal_std(t) * score(x, t)` function 158 | """ 159 | if config.solver.outer_solver.lower() == "ddimvp": 160 | ts, _ = get_times( 161 | config.solver.num_outer_steps, dt=config.solver.dt, t0=config.solver.epsilon 162 | ) 163 | beta, _ = get_linear_beta_function( 164 | beta_min=config.model.beta_min, beta_max=config.model.beta_max 165 | ) 166 | return DDIMVP(model, eta=config.solver.eta, beta=beta, ts=ts) 167 | elif config.solver.outer_solver.lower() == "ddimve": 168 | ts, _ = get_times( 169 | config.solver.num_outer_steps, dt=config.solver.dt, t0=config.solver.epsilon 170 | ) 171 | sigma = get_exponential_sigma_function( 172 | sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max 173 | ) 174 | return DDIMVE(model, eta=config.solver.eta, sigma=sigma, ts=ts) 175 | else: 176 | raise NotImplementedError(f"DDIM Chain {config.solver.outer_solver} unknown.") 177 | 178 | 179 | def get_markov_chain(config, score): 180 | """ 181 | Args: 182 | score: DDPM/SMLD(NCSN) parameterizes the `score(x, t)` function. 183 | """ 184 | if config.solver.outer_solver.lower() == "ddpm": 185 | ts, _ = get_times( 186 | num_steps=config.solver.num_outer_steps, 187 | dt=config.solver.dt, 188 | t0=config.solver.epsilon, 189 | ) 190 | beta, _ = get_linear_beta_function( 191 | beta_min=config.model.beta_min, beta_max=config.model.beta_max 192 | ) 193 | return DDPM(score, beta=beta, ts=ts) 194 | elif config.solver.outer_solver.lower() == "smld": 195 | ts, _ = get_times( 196 | config.solver.num_outer_steps, dt=config.solver.dt, t0=config.solver.epsilon 197 | ) 198 | sigma = get_exponential_sigma_function( 199 | sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max 200 | ) 201 | return SMLD(score, sigma=sigma, ts=ts) 202 | else: 203 | raise NotImplementedError(f"Markov Chain {config.solver.outer_solver} unknown.") 204 | 205 | 206 | def numpy_collate(batch): 207 | if isinstance(batch[0], np.ndarray): 208 | return np.stack(batch) 209 | elif isinstance(batch[0], (tuple, list)): 210 | transposed = zip(*batch) 211 | return [numpy_collate(samples) for samples in transposed] 212 | else: 213 | return np.array(batch) 214 | 215 | 216 | def jit_collate(n_jitted_steps, batch_size, batch): 217 | return np.reshape(batch, (n_jitted_steps, batch_size, -1)) 218 | 219 | 220 | def pmap_and_jit_collate(num_devices, n_jitted_steps, per_device_batch_size, batch): 221 | return np.reshape(batch, (num_devices, n_jitted_steps, per_device_batch_size, -1)) 222 | 223 | 224 | def pmap_collate(num_devices, per_device_batch_size, batch): 225 | return np.reshape(batch, (num_devices, per_device_batch_size, -1)) 226 | 227 | 228 | class NumpyLoader(DataLoader): 229 | def __init__( 230 | self, 231 | config, 232 | dataset, 233 | shuffle=False, 234 | sampler=None, 235 | batch_sampler=None, 236 | num_workers=0, 237 | pin_memory=False, 238 | drop_last=False, 239 | timeout=0, 240 | worker_init_fn=None, 241 | ): 242 | prod_batch_dims = config.training.batch_size * config.training.n_jitted_steps 243 | if config.training.pmap and config.training.n_jitted_steps != 1: 244 | collate_fn = partial( 245 | pmap_and_jit_collate, 246 | jax.local_device_count(), 247 | config.training.n_jitted_steps, 248 | config.training.batch_size // jax.local_device_count(), 249 | ) 250 | elif config.training.pmap and config.training.n_jitted_steps == 1: 251 | collate_fn = partial( 252 | pmap_collate, 253 | jax.local_device_count(), 254 | config.training.batch_size // jax.local_device_count(), 255 | ) 256 | elif config.training.n_jitted_steps != 1: 257 | 
collate_fn = partial( 258 | jit_collate, config.training.n_jitted_steps, config.training.batch_size 259 | ) 260 | else: 261 | collate_fn = numpy_collate # type: ignore 262 | 263 | super().__init__( 264 | dataset, 265 | batch_size=prod_batch_dims, 266 | shuffle=shuffle, 267 | sampler=sampler, 268 | batch_sampler=batch_sampler, 269 | num_workers=num_workers, 270 | collate_fn=collate_fn, 271 | pin_memory=pin_memory, 272 | drop_last=drop_last, 273 | timeout=timeout, 274 | worker_init_fn=worker_init_fn, 275 | ) 276 | 277 | 278 | def train(sampling_shape, config, model, dataset, workdir=None, use_wandb=False): 279 | """Train a score based generative model using stochastic gradient descent 280 | 281 | Args: 282 | sampling_shape : sampling shape may differ depending on the modality of data 283 | model: A valid flax nn.Module. 284 | config: An ml-collections configuration to use. 285 | dataset: a valid `torch.DataLoader` class. 286 | workdir: Optional working directory for checkpoints and TF summaries. If this 287 | contains checkpoint training will be resumed from the latest checkpoint. 288 | use_wandb: Bool. If set to `True`, uses weights and biases to store and visualize loss data. 289 | """ 290 | train_dataloader = NumpyLoader(config, dataset) 291 | eval_dataloader = NumpyLoader(config, dataset) 292 | 293 | jax.default_device = jax.devices()[0] # type: ignore 294 | # Tip: use `export CUDA_VISIBLE_DEVICES` to restrict the devices visible to jax 295 | # ... devices (GPUs/TPUs) must be all the same model for pmap to work 296 | num_devices = int(jax.local_device_count()) 297 | if jax.process_index() == 0: 298 | print("num_devices={}, pmap={}".format(num_devices, config.training.pmap)) 299 | 300 | # Create directories for experimental logs 301 | if workdir is not None: 302 | sample_dir = os.path.join(workdir, "samples") 303 | if not os.path.exists(sample_dir): 304 | os.mkdir(sample_dir) 305 | 306 | scaler = dataset.get_data_scaler(config) 307 | inverse_scaler = dataset.get_data_inverse_scaler(config) 308 | # eval_function = dataset.calculate_metrics_batch 309 | # metric_names = dataset.metric_names() 310 | 311 | if jax.process_index() == 0 and use_wandb: 312 | run = wandb.init( 313 | project="diffusionjax", 314 | config=config, 315 | ) 316 | rng = random.PRNGKey(config.seed) 317 | 318 | # Initialize model 319 | rng, model_rng = random.split(rng, 2) 320 | # Initialize parameters 321 | params = model.init( 322 | model_rng, jnp.zeros(sampling_shape), jnp.ones((sampling_shape[0],)) 323 | ) 324 | 325 | # Initialize optimizer 326 | optimizer = get_optimizer(config) 327 | opt_state = optimizer.init(params) 328 | state = State(step=0, opt_state=opt_state, params=params, lr=config.optim.lr, rng=rng) 329 | 330 | if workdir is not None: 331 | # Create checkpoints directory 332 | checkpoint_dir = os.path.join(workdir, "checkpoints") 333 | if not os.path.exists(checkpoint_dir): 334 | os.mkdir(checkpoint_dir) 335 | 336 | # Intermediate checkpoints to resume training after pre-emption in cloud environments 337 | checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta") 338 | if not os.path.exists(checkpoint_meta_dir): 339 | os.mkdir(checkpoint_meta_dir) 340 | 341 | # Orbax checkpointer boilerplate 342 | manager_options = orbax.checkpoint.CheckpointManagerOptions( 343 | create=True, max_to_keep=np.inf 344 | ) 345 | checkpoint_manager = orbax.checkpoint.CheckpointManager( 346 | checkpoint_dir, 347 | orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), 348 | manager_options, 349 | ) 350 | 
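    # The manager below keeps rolling "checkpoints-meta" snapshots so that a
    # pre-empted run can resume; the latest snapshot is restored a few lines
    # further down. A rough sketch of reloading the newest checkpoint outside
    # of `train` (assuming the same `State` pytree and directory layout):
    #   mngr = orbax.checkpoint.CheckpointManager(
    #     checkpoint_dir,
    #     orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()),
    #     manager_options)
    #   state = mngr.restore(mngr.latest_step(), items=state)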
351 | # meta_manager_options = orbax.checkpoint.CheckpointManagerOptions( 352 | # create=True, max_to_keep=1) 353 | meta_checkpoint_manager = orbax.checkpoint.CheckpointManager( 354 | checkpoint_meta_dir, 355 | orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), 356 | manager_options, 357 | ) 358 | 359 | # Resume training when intermediate checkpoints are detected 360 | restore_args = orbax_utils.restore_args_from_target(state, mesh=None) 361 | save_step = meta_checkpoint_manager.latest_step() 362 | if save_step is not None: 363 | meta_checkpoint_manager.restore( 364 | save_step, items=state, restore_kwargs={"restore_args": restore_args} 365 | ) 366 | 367 | # `state.step` is JAX integer on the GPU/TPU devices 368 | initial_step = int(state.step) 369 | rng = state.rng 370 | 371 | # Build one-step training and evaluation functions 372 | sde = get_sde(config) 373 | # Trained score 374 | score = get_score(sde, model, params, score_scaling=config.training.score_scaling) 375 | 376 | # Setup solver 377 | outer_solver, inner_solver = get_solver(config, sde, score) 378 | 379 | loss = get_loss( 380 | sde, 381 | outer_solver, 382 | model, 383 | score_scaling=config.training.score_scaling, 384 | likelihood_weighting=config.training.likelihood_weighting, 385 | ) 386 | train_step = get_step_fn(loss, optimizer, train=True, pmap=config.training.pmap) 387 | eval_step = get_step_fn(loss, optimizer, train=False, pmap=config.training.pmap) 388 | 389 | if config.training.n_jitted_steps > 1: 390 | train_step = partial(jax.lax.scan, train_step) 391 | eval_step = partial(jax.lax.scan, eval_step) 392 | 393 | if config.training.pmap: 394 | train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=1) 395 | eval_step = jax.pmap(eval_step, axis_name="batch", donate_argnums=1) 396 | 397 | # Replicate the training state to run on multiple devices 398 | state = flax_utils.replicate(state) 399 | 400 | # Probably want to train over multiple epochs 401 | # If num_epochs > num_batch, decides which tqdm to go over 402 | i_epoch = 0 403 | prev_time = time.time() 404 | 405 | # Deal with training in a number of steps 406 | # num_epochs = config.training.n_iters // (dataset_size / batch_size) 407 | num_epochs = config.training.n_iters 408 | 409 | step = initial_step 410 | # In case there are multiple hosts (e.g., TPU pods), only log to host 0 411 | if jax.process_index() == 0: 412 | logging.info("Starting training loop at step %d." % (initial_step,)) 413 | rng = jax.random.fold_in(rng, jax.process_index()) 414 | 415 | # JIT multiple training steps together for faster training 416 | n_jitted_steps = config.training.n_jitted_steps 417 | 418 | # Must be divisible by the number of steps jitted together 419 | assert ( 420 | config.training.log_step_freq % n_jitted_steps == 0 421 | and config.training.snapshot_freq_for_preemption % n_jitted_steps == 0 422 | and config.training.eval_freq % n_jitted_steps == 0 423 | and config.training.snapshot_freq % n_jitted_steps == 0 424 | ), "Missing logs or checkpoints!" 
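  # e.g. with n_jitted_steps=5, log_step_freq, eval_freq, snapshot_freq and
  # snapshot_freq_for_preemption must all be multiples of 5; otherwise the
  # scanned multi-step update would step over the iterations on which a log or
  # checkpoint is due.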
425 | 426 | mean_losses = jnp.zeros((num_epochs, 1)) 427 | for i_epoch in trange(1, num_epochs + 1, unit="epochs"): 428 | current_time = time.time() 429 | 430 | if i_epoch != 0 and (num_epochs < config.training.batch_size): 431 | print("Epoch took {:.1f} seconds".format(current_time - prev_time)) 432 | prev_time = time.time() 433 | 434 | eval_iter = iter(eval_dataloader) 435 | 436 | with tqdm(train_dataloader, unit=" batch", disable=True) as tepoch: 437 | tepoch.set_description(f"Epoch {i_epoch}") 438 | losses = jnp.empty((len(tepoch), 1)) 439 | 440 | for i_batch, batch in enumerate(tepoch): 441 | batch = jax.tree_util.tree_map(lambda x: scaler(x), batch) 442 | 443 | # Execute one training step TODO: don't need a split here and in loss? 444 | if config.training.pmap: 445 | rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1) 446 | next_rng = jnp.asarray(next_rng) # type: ignore 447 | else: 448 | rng, next_rng = jax.random.split(rng) # type: ignore 449 | 450 | (_, params, opt_state), loss_train = train_step( 451 | (next_rng, state.params, state.opt_state), batch 452 | ) 453 | # TODO: Can't just interate state? move inside train_step? should rng be part of state? 454 | state = state.replace(opt_state=opt_state, params=params) # type: ignore 455 | 456 | if config.training.pmap: 457 | loss_train = flax_utils.unreplicate( 458 | loss_train 459 | ).mean() # returns a single instance of replicated loss array 460 | else: 461 | loss_train = loss_train.mean() 462 | 463 | # Log to console, file and wandb on host 0 464 | if jax.process_index() == 0: 465 | step += config.training.n_jitted_steps 466 | losses = losses.at[i_batch].set(loss_train) 467 | if ( 468 | step % config.training.log_step_freq == 0 469 | and jax.process_index() == 0 470 | and use_wandb 471 | ): 472 | logging.info("step {:d}, training_loss {:.2e}".format(step, loss_train)) 473 | 474 | # Save a temporary checkpoint to resume training after pre-emption (for cloud computing environments) periodically 475 | if ( 476 | step != 0 477 | and step % config.training.snapshot_freq_for_preemption == 0 478 | and jax.process_index() == 0 479 | ): 480 | if config.training.pmap: 481 | saved_state = flax_utils.unreplicate(state) 482 | else: 483 | saved_state = state 484 | saved_state = saved_state.replace(rng=rng) 485 | if workdir: 486 | saved_args = orbax_utils.save_args_from_target(saved_state) 487 | meta_checkpoint_manager.save( 488 | step // config.training.snapshot_freq_for_preemption, 489 | saved_state, 490 | save_kwargs={"save_args": saved_args}, 491 | ) 492 | 493 | # Report the loss on an evaluation dataset periodically 494 | if step % config.training.eval_freq == 0: 495 | eval_batch = jax.tree_util.tree_map(lambda x: scaler(x), next(eval_iter)) 496 | if config.training.pmap: 497 | rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1) 498 | next_rng = jnp.asarray(next_rng) # type: ignore 499 | else: 500 | rng, next_rng = jax.random.split(rng) # type: ignore 501 | (_, _, _), loss_eval = eval_step( 502 | (next_rng, state.params, state.opt_state), eval_batch 503 | ) 504 | 505 | if config.training.pmap: 506 | loss_eval = flax_utils.unreplicate(loss_eval).mean() 507 | else: 508 | loss_eval = loss_eval.mean() 509 | 510 | if jax.process_index() == 0 and use_wandb: 511 | logging.info("batch: {:d}, eval_loss: {:.5e}".format(step, loss_eval)) 512 | wandb.log({"eval-loss": loss_eval}) 513 | 514 | # Save a checkpoint periodically and generate samples if needed 515 | if ( 516 | step != 0 517 | and step % 
config.training.snapshot_freq == 0 518 | or step == config.training.n_iters 519 | ): 520 | # Save the checkpoint 521 | if jax.process_index() == 0: 522 | if config.training.pmap: 523 | saved_state = flax_utils.unreplicate(state) 524 | else: 525 | saved_state = state 526 | saved_state = saved_state.replace(rng=rng) 527 | if workdir: 528 | saved_args = orbax_utils.save_args_from_target(saved_state) 529 | checkpoint_manager.save( 530 | step // config.training.snapshot_freq, 531 | saved_state, 532 | save_kwargs={"save_args": saved_args}, 533 | ) 534 | 535 | # Generate and save samples 536 | if config.training.snapshot_sampling: 537 | # Setup solver with new trained score 538 | # Use the unreplicated parameters of the saved state 539 | score = get_score( 540 | sde, model, saved_state.params, config.training.score_scaling 541 | ) 542 | outer_solver, inner_solver = get_solver(config, sde, score) 543 | sampler = get_sampler( 544 | sampling_shape, 545 | outer_solver, 546 | inner_solver, 547 | denoise=config.sampling.denoise, 548 | stack_samples=config.sampling.stack_samples, 549 | inverse_scaler=inverse_scaler, 550 | ) 551 | 552 | if config.training.pmap: 553 | sampler = jax.pmap(sampler, axis_name="batch") 554 | rng, *sample_rng = random.split(rng, 1 + jax.local_device_count()) 555 | sample_rng = jnp.asarray(sample_rng) # type: ignore 556 | else: 557 | rng, sample_rng = random.split(rng, 2) # type: ignore 558 | 559 | sample, _ = sampler(sample_rng) 560 | 561 | # eval_fn = eval_function(sample) 562 | # wandb.log({metric_names[0]: eval_fn}) 563 | 564 | if workdir: 565 | this_sample_dir = os.path.join( 566 | sample_dir, "iter_{}_host_{}".format(step, jax.process_index()) 567 | ) 568 | if not os.path.isdir(this_sample_dir): 569 | os.mkdir(this_sample_dir) 570 | 571 | with open(os.path.join(this_sample_dir, "sample.np"), "wb") as infile: 572 | np.save(infile, sample) 573 | 574 | if jax.process_index() == 0: 575 | mean_loss = jnp.mean(losses, axis=0) 576 | 577 | if jax.process_index() == 0 and i_epoch % config.training.log_epoch_freq == 0: 578 | mean_losses = mean_losses.at[i_epoch].set(mean_loss) 579 | if use_wandb: 580 | logging.info( 581 | "step {:d}, mean_loss {:.2e}".format(int(step), float(mean_loss[0])) 582 | ) 583 | wandb.log({"train-loss": mean_loss}) 584 | 585 | if workdir and use_wandb: 586 | artifact = wandb.Artifact(name="checkpoint", type="checkpoint") 587 | artifact.add_dir(local_path=checkpoint_dir) 588 | run.log_artifact(artifact) # type: ignore 589 | 590 | # Get the model and do test dataset 591 | if config.training.pmap: 592 | saved_state = flax_utils.unreplicate(state) 593 | else: 594 | saved_state = state 595 | return saved_state.params, saved_state.opt_state, mean_losses 596 | 597 | -------------------------------------------------------------------------------- /diffusionjax/sde.py: -------------------------------------------------------------------------------- 1 | """SDE class.""" 2 | 3 | import jax.numpy as jnp 4 | from jax import random, vmap 5 | from diffusionjax.utils import ( 6 | batch_mul, 7 | get_exponential_sigma_function, 8 | get_linear_beta_function, 9 | ) 10 | 11 | 12 | def ulangevin(score, x, t): 13 | drift = -score(x, t) 14 | diffusion = jnp.ones(x.shape) * jnp.sqrt(2) 15 | return drift, diffusion 16 | 17 | 18 | class ULangevin: 19 | """Unadjusted Langevin SDE.""" 20 | 21 | def __init__(self, score): 22 | self.score = score 23 | self.sde = lambda x, t: ulangevin(self.score, x, t) 24 | 25 | 26 | class RSDE: 27 | """Reverse SDE class.""" 28 | 29 | def __init__(self, 
score, forward_sde, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | self.score = score 32 | self.forward_sde = forward_sde 33 | 34 | def sde(self, x, t): 35 | drift, diffusion = self.forward_sde(x, t) 36 | drift = -drift + batch_mul(diffusion**2, self.score(x, t)) 37 | return drift, diffusion 38 | 39 | 40 | class VE: 41 | """Variance exploding (VE) SDE, a.k.a. diffusion process with a time dependent diffusion coefficient.""" 42 | 43 | def __init__(self, sigma=None): 44 | if sigma is None: 45 | self.sigma = get_exponential_sigma_function(sigma_min=0.01, sigma_max=378.0) 46 | else: 47 | self.sigma = sigma 48 | self.sigma_min = self.sigma(0.0) 49 | self.sigma_max = self.sigma(1.0) 50 | 51 | def sde(self, x, t): 52 | sigma_t = self.sigma(t) 53 | drift = jnp.zeros_like(x) 54 | diffusion = sigma_t * jnp.sqrt( 55 | 2 * (jnp.log(self.sigma_max) - jnp.log(self.sigma_min)) 56 | ) 57 | 58 | return drift, diffusion 59 | 60 | def mean_coeff(self, t): 61 | return jnp.ones_like(t) 62 | 63 | def variance(self, t): 64 | return self.sigma(t) ** 2 65 | 66 | def prior(self, rng, shape): 67 | return random.normal(rng, shape) * self.sigma_max 68 | 69 | def reverse(self, score): 70 | forward_sde = self.sde 71 | sigma = self.sigma 72 | 73 | return RVE(score, forward_sde, sigma) 74 | 75 | def r2(self, t, data_variance): 76 | r"""Analytic variance of the distribution at time zero conditioned on x_t, given crude assumption that 77 | the data distribution is isotropic-Gaussian. 78 | 79 | .. math:: 80 | \text{Variance of }p_{0}(x_{0}|x_{t}) \text{ if } p_{0}(x_{0}) = \mathcal{N}(0, \text{data_variance}I) 81 | \text{ and } p_{t|0}(x_{t}|x_{0}) = \mathcal{N}(x_0, \sigma_{t}^{2}I) 82 | """ 83 | variance = self.variance(t) 84 | return variance * data_variance / (variance + data_variance) 85 | 86 | def ratio(self, t): 87 | """Ratio of marginal variance and mean coeff.""" 88 | return self.variance(t) 89 | 90 | 91 | class VP: 92 | """Variance preserving (VP) SDE, a.k.a. time rescaled Ohrnstein Uhlenbeck (OU) SDE.""" 93 | 94 | def __init__(self, beta=None, mean_coeff=None): 95 | if beta is None: 96 | self.beta, self.mean_coeff = get_linear_beta_function( 97 | beta_min=0.1, beta_max=20.0 98 | ) 99 | else: 100 | self.beta = beta 101 | self.mean_coeff = mean_coeff 102 | 103 | def sde(self, x, t): 104 | beta_t = self.beta(t) 105 | drift = -0.5 * batch_mul(beta_t, x) 106 | diffusion = jnp.sqrt(beta_t) 107 | return drift, diffusion 108 | 109 | def std(self, t): 110 | return jnp.sqrt(self.variance(t)) 111 | 112 | def variance(self, t): 113 | return 1.0 - self.mean_coeff(t)**2 114 | 115 | def marginal_prob(self, x, t): 116 | return batch_mul(self.mean_coeff(t), x), jnp.sqrt(self.variance(t)) 117 | 118 | def prior(self, rng, shape): 119 | return random.normal(rng, shape) 120 | 121 | def reverse(self, score): 122 | fwd_sde = self.sde 123 | beta = self.beta 124 | mean_coeff = self.mean_coeff 125 | return RVP(score, fwd_sde, beta, mean_coeff) 126 | 127 | def r2(self, t, data_variance): 128 | r"""Analytic variance of the distribution at time zero conditioned on x_t, given crude assumption that 129 | the data distribution is isotropic-Gaussian. 130 | 131 | .. 
math:: 132 | \text{Variance of }p_{0}(x_{0}|x_{t}) \text{ if } p_{0}(x_{0}) = \mathcal{N}(0, \text{data_variance}I) 133 | \text{ and } p_{t|0}(x_{t}|x_{0}) = \mathcal{N}(\sqrt(\alpha_{t})x_0, (1 - \alpha_{t})I) 134 | """ 135 | alpha = self.mean_coeff(t)**2 136 | variance = 1.0 - alpha 137 | return variance * data_variance / (variance + alpha * data_variance) 138 | 139 | def ratio(self, t): 140 | """Ratio of marginal variance and mean coeff.""" 141 | return self.variance(t) / self.mean_coeff(t) 142 | 143 | 144 | class RVE(RSDE, VE): 145 | def get_estimate_x_0_vmap(self, observation_map): 146 | """ 147 | Get a function returning the MMSE estimate of x_0|x_t. 148 | 149 | Args: 150 | observation_map: function that operates on unbatched x. 151 | shape: optional tuple that reshapes x so that it can be operated on. 152 | """ 153 | 154 | def estimate_x_0(x, t): 155 | x = jnp.expand_dims(x, axis=0) 156 | t = jnp.expand_dims(t, axis=0) 157 | v_t = self.variance(t) 158 | s = self.score(x, t) 159 | x_0 = x + v_t * s 160 | return observation_map(x_0), (s, x_0) 161 | 162 | return estimate_x_0 163 | 164 | def get_estimate_x_0(self, observation_map, shape=None): 165 | """ 166 | Get a function returning the MMSE estimate of x_0|x_t. 167 | 168 | Args: 169 | observation_map: function that operates on unbatched x. 170 | shape: optional tuple that reshapes x so that it can be operated on. 171 | """ 172 | batch_observation_map = vmap(observation_map) 173 | 174 | def estimate_x_0(x, t): 175 | v_t = self.variance(t) 176 | s = self.score(x, t) 177 | x_0 = x + batch_mul(v_t, s) 178 | if shape: 179 | return batch_observation_map(x_0.reshape(shape)), (s, x_0) 180 | else: 181 | return batch_observation_map(x_0), (s, x_0) 182 | 183 | return estimate_x_0 184 | 185 | def guide(self, get_guidance_score, observation_map, *args, **kwargs): 186 | guidance_score = get_guidance_score(self, observation_map, *args, **kwargs) 187 | return RVE(guidance_score, self.forward_sde, self.sigma) 188 | 189 | def correct(self, corrector): 190 | class CVE(RVE): 191 | def sde(x, t): 192 | return corrector(self.score, x, t) 193 | 194 | return CVE(self.score, self.forward_sde, self.sigma) 195 | 196 | 197 | class RVP(RSDE, VP): 198 | def get_estimate_x_0_vmap(self, observation_map): 199 | """ 200 | Get a function returning the MMSE estimate of x_0|x_t. 201 | 202 | Args: 203 | observation_map: function that operates on unbatched x. 204 | shape: optional tuple that reshapes x so that it can be operated on. 205 | """ 206 | 207 | def estimate_x_0(x, t): 208 | x = jnp.expand_dims(x, axis=0) 209 | t = jnp.expand_dims(t, axis=0) 210 | m_t = self.mean_coeff(t) 211 | v_t = self.variance(t) 212 | s = self.score(x, t) 213 | x_0 = (x + v_t * s) / m_t 214 | return observation_map(x_0), (s, x_0) 215 | 216 | return estimate_x_0 217 | 218 | def get_estimate_x_0(self, observation_map, shape=None): 219 | """ 220 | Get a function returning the MMSE estimate of x_0|x_t. 221 | 222 | Args: 223 | observation_map: function that operates on unbatched x. 224 | shape: optional tuple that reshapes x so that it can be operated on. 
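    Returns:
      A function `estimate_x_0(x, t)` returning `(observation_map(x_0), (score, x_0))`,
      where `x_0 = (x + variance(t) * score(x, t)) / mean_coeff(t)` (Tweedie's formula
      under the VP marginals).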
225 | """ 226 | batch_observation_map = vmap(observation_map) 227 | 228 | def estimate_x_0(x, t): 229 | m_t = self.mean_coeff(t) 230 | v_t = self.variance(t) 231 | s = self.score(x, t) 232 | x_0 = batch_mul(x + batch_mul(v_t, s), 1.0 / m_t) 233 | if shape: 234 | return batch_observation_map(x_0.reshape(shape)), (s, x_0) 235 | else: 236 | return batch_observation_map(x_0), (s, x_0) 237 | 238 | return estimate_x_0 239 | 240 | def guide(self, get_guidance_score, observation_map, *args, **kwargs): 241 | guidance_score = get_guidance_score(self, observation_map, *args, **kwargs) 242 | return RVP(guidance_score, self.forward_sde, self.beta, self.mean_coeff) 243 | 244 | def correct(self, corrector): 245 | class CVP(RVP): 246 | def sde(x, t): 247 | return corrector(self.score, x, t) 248 | 249 | return CVP(self.score, self.forward_sde, self.beta, self.mean_coeff) 250 | -------------------------------------------------------------------------------- /diffusionjax/solvers.py: -------------------------------------------------------------------------------- 1 | """Solver classes, including Markov chains.""" 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax.random as random 6 | from jax import vmap 7 | from diffusionjax.utils import ( 8 | batch_mul, 9 | batch_mul_A, 10 | get_times, 11 | get_timestep, 12 | get_exponential_sigma_function, 13 | get_karras_sigma_function, 14 | get_karras_gamma_function, 15 | get_linear_beta_function, 16 | continuous_to_discrete, 17 | ) 18 | import abc 19 | 20 | 21 | class Solver(abc.ABC): 22 | """SDE solver abstract class. Functions are designed for a mini-batch of inputs.""" 23 | 24 | def __init__(self, ts=None): 25 | """Construct a Solver. Note that for continuous time we choose to control for numerical 26 | error by using a beta schedule instead of an adaptive time step schedule, since adaptive 27 | time steps are equivalent to a beta schedule, and beta schedule hyperparameters have 28 | been explored extensively in the literature. Therefore, the timesteps must be equally 29 | spaced by dt. 30 | Args: 31 | ts: JAX array of equally spaced, monotonically increasing values t in [t0, t1]. 32 | """ 33 | if ts is None: 34 | ts, _ = get_times(num_steps=1000) 35 | self.ts = ts 36 | self.t1 = ts[-1] 37 | self.t0 = ts[0] 38 | self.dt = ts[1] - ts[0] 39 | self.num_steps = ts.size 40 | 41 | @abc.abstractmethod 42 | def update(self, rng, x, t): 43 | """Return the update of the state and any auxilliary values. 44 | 45 | Args: 46 | rng: A JAX random state. 47 | x: A JAX array of the state. 48 | t: JAX array of the time. 49 | 50 | Returns: 51 | x: A JAX array of the next state. 52 | x_mean: A JAX array. The next state without random noise. Useful for denoising. 53 | """ 54 | 55 | 56 | class EulerMaruyama(Solver): 57 | """Euler Maruyama numerical solver of an SDE. 58 | Functions are designed for a mini-batch of inputs.""" 59 | 60 | def __init__(self, sde, ts=None): 61 | """Constructs an Euler-Maruyama Solver. 62 | Args: 63 | sde: A valid SDE class. 64 | """ 65 | super().__init__(ts) 66 | self.sde = sde 67 | self.prior = sde.prior 68 | 69 | def update(self, rng, x, t): 70 | drift, diffusion = self.sde.sde(x, t) 71 | f = drift * self.dt 72 | G = diffusion * jnp.sqrt(self.dt) 73 | noise = random.normal(rng, x.shape) 74 | x_mean = x + f 75 | x = x_mean + batch_mul(G, noise) 76 | return x, x_mean 77 | 78 | 79 | class Annealed(Solver): 80 | """Annealed Langevin numerical solver of an SDE. 81 | Functions are designed for a mini-batch of inputs. 
82 | Sampler must be `pmap` over "batch" axis as 83 | suggested by https://arxiv.org/abs/2011.13456 Song 84 | et al. 85 | """ 86 | 87 | def __init__(self, sde, snr=1e-2, ts=jnp.empty((2, 1))): 88 | """Constructs an Annealed Langevin Solver. 89 | Args: 90 | sde: A valid SDE class. 91 | snr: A hyperparameter representing a signal-to-noise ratio. 92 | ts: For a corrector, just need a placeholder JAX array with length 93 | number of timesteps of the inner solver. 94 | """ 95 | super().__init__(ts) 96 | self.sde = sde 97 | self.snr = snr 98 | self.prior = sde.prior 99 | 100 | def update(self, rng, x, t): 101 | grad = self.sde.score(x, t) 102 | grad_norm = jnp.linalg.norm(grad.reshape((grad.shape[0], -1)), axis=-1).mean() 103 | grad_norm = jax.lax.pmean(grad_norm, axis_name="batch") 104 | noise = random.normal(rng, x.shape) 105 | noise_norm = jnp.linalg.norm(noise.reshape((noise.shape[0], -1)), axis=-1).mean() 106 | noise_norm = jax.lax.pmean(noise_norm, axis_name="batch") 107 | alpha = self.sde.mean_coeff(t)**2 108 | dt = (self.snr * noise_norm / grad_norm) ** 2 * 2 * alpha 109 | x_mean = x + batch_mul(grad, dt) 110 | x = x_mean + batch_mul(2 * dt, noise) 111 | return x, x_mean 112 | 113 | 114 | class Inpainted(Solver): 115 | """Inpainting constraint for numerical solver of an SDE. 116 | Functions are designed for a mini-batch of inputs.""" 117 | 118 | def __init__(self, sde, mask, y, ts=jnp.empty((1, 1))): 119 | """Constructs an Annealed Langevin Solver. 120 | Args: 121 | sde: A valid SDE class. 122 | snr: A hyperparameter representing a signal-to-noise ratio. 123 | """ 124 | super().__init__(ts) 125 | self.sde = sde 126 | self.mask = mask 127 | self.y = y 128 | 129 | def prior(self, rng, shape): 130 | x = self.sde.prior(rng, shape) 131 | x = batch_mul_A((1.0 - self.mask), x) + self.y * self.mask 132 | return x 133 | 134 | def update(self, rng, x, t): 135 | mean_coeff = self.sde.mean_coeff(t) 136 | std = jnp.sqrt(self.sde.variance(t)) 137 | masked_data_mean = batch_mul_A(self.y, mean_coeff) 138 | masked_data = masked_data_mean + batch_mul(random.normal(rng, x.shape), std) 139 | x = batch_mul_A((1.0 - self.mask), x) + batch_mul_A(self.mask, masked_data) 140 | x_mean = batch_mul_A((1.0 - self.mask), x) + batch_mul_A( 141 | self.mask, masked_data_mean 142 | ) 143 | return x, x_mean 144 | 145 | 146 | class Projected(Solver): 147 | """Inpainting constraint for numerical solver of an SDE. 148 | Functions are designed for a mini-batch of inputs.""" 149 | 150 | def __init__(self, sde, mask, y, coeff=1.0, ts=jnp.empty((1, 1))): 151 | """Constructs an Annealed Langevin Solver. 152 | Args: 153 | sde: A valid SDE class. 154 | snr: A hyperparameter representing a signal-to-noise ratio. 
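      mask: Array of ones where the data `y` is observed and zeros elsewhere.
      y: The observed (masked) data that the state is projected onto.
      coeff: Strength with which the noised observed data is merged into the
        state at each step (`coeff=1.0` overwrites the masked region entirely).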
155 | """ 156 | super().__init__(ts) 157 | self.sde = sde 158 | self.mask = mask 159 | self.y = y 160 | self.coeff = coeff 161 | self.prior = sde.prior 162 | 163 | def merge_data_with_mask(self, x_space, data, mask, coeff): 164 | return batch_mul_A(mask * coeff, data) + batch_mul_A((1.0 - mask * coeff), x_space) 165 | 166 | def update(self, rng, x, t): 167 | mean_coeff = self.sde.mean_coeff(t) 168 | masked_data_mean = batch_mul_A(self.y, mean_coeff) 169 | std = jnp.sqrt(self.sde.variance(t)) 170 | z_data = masked_data_mean + batch_mul(std, random.normal(rng, x.shape)) 171 | x = self.merge_data_with_mask(x, z_data, self.mask, self.coeff) 172 | x_mean = self.merge_data_with_mask(x, masked_data_mean, self.mask, self.coeff) 173 | return x, x_mean 174 | 175 | 176 | class DDPM(Solver): 177 | """DDPM Markov chain using Ancestral sampling.""" 178 | 179 | def __init__(self, score, beta=None, ts=None): 180 | super().__init__(ts) 181 | if beta is None: 182 | beta, _ = get_linear_beta_function(beta_min=0.1, beta_max=20.0) 183 | self.discrete_betas = continuous_to_discrete(vmap(beta)(self.ts.flatten()), self.dt) 184 | self.score = score 185 | self.alphas = 1.0 - self.discrete_betas 186 | self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) 187 | self.sqrt_alphas_cumprod = jnp.sqrt(self.alphas_cumprod) 188 | self.sqrt_1m_alphas_cumprod = jnp.sqrt(1.0 - self.alphas_cumprod) 189 | self.alphas_cumprod_prev = jnp.append(1.0, self.alphas_cumprod[:-1]) 190 | self.sqrt_alphas_cumprod_prev = jnp.sqrt(self.alphas_cumprod_prev) 191 | self.sqrt_1m_alphas_cumprod_prev = jnp.sqrt(1.0 - self.alphas_cumprod_prev) 192 | 193 | def get_estimate_x_0_vmap(self, observation_map, clip=False, centered=True): 194 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 195 | 196 | def estimate_x_0(x, t, timestep): 197 | x = jnp.expand_dims(x, axis=0) 198 | t = jnp.expand_dims(t, axis=0) 199 | m = self.sqrt_alphas_cumprod[timestep] 200 | v = self.sqrt_1m_alphas_cumprod[timestep] ** 2 201 | s = self.score(x, t) 202 | x_0 = (x + v * s) / m 203 | if clip: 204 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 205 | return observation_map(x_0), (s, x_0) 206 | 207 | return estimate_x_0 208 | 209 | def get_estimate_x_0(self, observation_map, clip=False, centered=True): 210 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 211 | batch_observation_map = vmap(observation_map) 212 | 213 | def estimate_x_0(x, t, timestep): 214 | m = self.sqrt_alphas_cumprod[timestep] 215 | v = self.sqrt_1m_alphas_cumprod[timestep] ** 2 216 | s = self.score(x, t) 217 | x_0 = batch_mul(x + batch_mul(v, s), 1.0 / m) 218 | if clip: 219 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 220 | return batch_observation_map(x_0), (s, x_0) 221 | 222 | return estimate_x_0 223 | 224 | def prior(self, rng, shape): 225 | return random.normal(rng, shape) 226 | 227 | def posterior(self, score, x, timestep): 228 | beta = self.discrete_betas[timestep] 229 | # As implemented by Song 230 | # https://github.com/yang-song/score_sde/blob/0acb9e0ea3b8cccd935068cd9c657318fbc6ce4c/sampling.py#L237C5-L237C79 231 | # x_mean = batch_mul( 232 | # (x + batch_mul(beta, score)), 1. / jnp.sqrt(1. 
- beta)) # DDPM 233 | # std = jnp.sqrt(beta) 234 | 235 | # # As implemented by DPS2022 236 | # https://github.com/DPS2022/diffusion-posterior-sampling/blob/effbde7325b22ce8dc3e2c06c160c021e743a12d/guided_diffusion/gaussian_diffusion.py#L373 237 | m = self.sqrt_alphas_cumprod[timestep] 238 | v = self.sqrt_1m_alphas_cumprod[timestep] ** 2 239 | alpha = self.alphas[timestep] 240 | x_0 = batch_mul((x + batch_mul(v, score)), 1.0 / m) 241 | m_prev = self.sqrt_alphas_cumprod_prev[timestep] 242 | v_prev = self.sqrt_1m_alphas_cumprod_prev[timestep] ** 2 243 | x_mean = batch_mul(jnp.sqrt(alpha) * v_prev / v, x) + batch_mul( 244 | m_prev * beta / v, x_0 245 | ) 246 | std = jnp.sqrt(beta * v_prev / v) 247 | return x_mean, std 248 | 249 | def update(self, rng, x, t): 250 | score = self.score(x, t) 251 | timestep = get_timestep(t, self.t0, self.t1, self.num_steps) 252 | x_mean, std = self.posterior(score, x, timestep) 253 | z = random.normal(rng, x.shape) 254 | x = x_mean + batch_mul(std, z) 255 | return x, x_mean 256 | 257 | 258 | class SMLD(Solver): 259 | """SMLD(NCSN) Markov Chain using Ancestral sampling.""" 260 | 261 | def __init__(self, score, sigma=None, ts=None): 262 | super().__init__(ts) 263 | if sigma is None: 264 | sigma = get_exponential_sigma_function(sigma_min=0.01, sigma_max=378.0) 265 | sigmas = vmap(sigma)(self.ts.flatten()) 266 | self.discrete_sigmas = sigmas 267 | self.discrete_sigmas_prev = jnp.append(0.0, self.discrete_sigmas[:-1]) 268 | self.score = score 269 | 270 | def get_estimate_x_0_vmap(self, observation_map, clip=False, centered=False): 271 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 272 | 273 | def estimate_x_0(x, t, timestep): 274 | x = jnp.expand_dims(x, axis=0) 275 | t = jnp.expand_dims(t, axis=0) 276 | v = self.discrete_sigmas[timestep] ** 2 277 | s = self.score(x, t) 278 | x_0 = x + v * s 279 | if clip: 280 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 281 | return observation_map(x_0), (s, x_0) 282 | 283 | return estimate_x_0 284 | 285 | def get_estimate_x_0(self, observation_map, clip=False, centered=False): 286 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 287 | batch_observation_map = vmap(observation_map) 288 | 289 | def estimate_x_0(x, t, timestep): 290 | v = self.discrete_sigmas[timestep] ** 2 291 | s = self.score(x, t) 292 | x_0 = x + batch_mul(v, s) 293 | if clip: 294 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 295 | return batch_observation_map(x_0), (s, x_0) 296 | 297 | return estimate_x_0 298 | 299 | def prior(self, rng, shape): 300 | return random.normal(rng, shape) * self.discrete_sigmas[-1] 301 | 302 | def posterior(self, score, x, timestep): 303 | sigma = self.discrete_sigmas[timestep] 304 | sigma_prev = self.discrete_sigmas_prev[timestep] 305 | 306 | # As implemented by Song https://github.com/yang-song/score_sde/blob/0acb9e0ea3b8cccd935068cd9c657318fbc6ce4c/sampling.py#L220 307 | # x_mean = x + batch_mul(score, sigma**2 - sigma_prev**2) 308 | # std = jnp.sqrt((sigma_prev**2 * (sigma**2 - sigma_prev**2)) / (sigma**2)) 309 | 310 | # From posterior in Appendix F https://arxiv.org/pdf/2011.13456.pdf 311 | x_0 = x + batch_mul(sigma**2, score) 312 | x_mean = batch_mul(sigma_prev**2 / sigma**2, x) + batch_mul( 313 | 1 - sigma_prev**2 / sigma**2, x_0 314 | ) 315 | std = jnp.sqrt((sigma_prev**2 * (sigma**2 - sigma_prev**2)) / (sigma**2)) 316 | return x_mean, std 317 | 318 | def update(self, rng, x, t): 319 | timestep = get_timestep(t, self.t0, self.t1, self.num_steps) 320 | score = self.score(x, t) 321 | x_mean, std = 
self.posterior(score, x, timestep) 322 | z = random.normal(rng, x.shape) 323 | x = x_mean + batch_mul(std, z) 324 | return x, x_mean 325 | 326 | 327 | class DDIMVP(Solver): 328 | """DDIM Markov chain. For the DDPM Markov Chain or VP SDE.""" 329 | 330 | def __init__(self, model, eta=1.0, beta=None, ts=None): 331 | """ 332 | Args: 333 | model: DDIM parameterizes the `epsilon(x, t) = -1. * fwd_marginal_std(t) * score(x, t)` function. 334 | eta: the hyperparameter for DDIM, a value of `eta=0.0` is deterministic 'probability ODE' solver, `eta=1.0` is DDPMVP. 335 | """ 336 | super().__init__(ts) 337 | if beta is None: 338 | beta, _ = get_linear_beta_function(beta_min=0.1, beta_max=20.0) 339 | self.discrete_betas = continuous_to_discrete(vmap(beta)(self.ts.flatten()), self.dt) 340 | self.eta = eta 341 | self.model = model 342 | self.alphas = 1.0 - self.discrete_betas 343 | self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) 344 | self.sqrt_alphas_cumprod = jnp.sqrt(self.alphas_cumprod) 345 | self.sqrt_1m_alphas_cumprod = jnp.sqrt(1.0 - self.alphas_cumprod) 346 | self.alphas_cumprod_prev = jnp.append(1.0, self.alphas_cumprod[:-1]) 347 | self.sqrt_alphas_cumprod_prev = jnp.sqrt(self.alphas_cumprod_prev) 348 | self.sqrt_1m_alphas_cumprod_prev = jnp.sqrt(1.0 - self.alphas_cumprod_prev) 349 | 350 | def get_estimate_x_0_vmap(self, observation_map, clip=False, centered=True): 351 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 352 | 353 | def estimate_x_0(x, t, timestep): 354 | x = jnp.expand_dims(x, axis=0) 355 | t = jnp.expand_dims(t, axis=0) 356 | m = self.sqrt_alphas_cumprod[timestep] 357 | sqrt_1m_alpha = self.sqrt_1m_alphas_cumprod[timestep] 358 | epsilon = self.model(x, t) 359 | x_0 = (x - sqrt_1m_alpha * epsilon) / m 360 | if clip: 361 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 362 | return observation_map(x_0), (epsilon, x_0) 363 | 364 | return estimate_x_0 365 | 366 | def get_estimate_x_0(self, observation_map, clip=False, centered=True): 367 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 368 | batch_observation_map = vmap(observation_map) 369 | 370 | def estimate_x_0(x, t, timestep): 371 | m = self.sqrt_alphas_cumprod[timestep] 372 | sqrt_1m_alpha = self.sqrt_1m_alphas_cumprod[timestep] 373 | epsilon = self.model(x, t) 374 | x_0 = batch_mul(x - batch_mul(sqrt_1m_alpha, epsilon), 1.0 / m) 375 | if clip: 376 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 377 | return batch_observation_map(x_0), (epsilon, x_0) 378 | 379 | return estimate_x_0 380 | 381 | def prior(self, rng, shape): 382 | return random.normal(rng, shape) 383 | 384 | def posterior(self, x, t): 385 | # # As implemented by DPS2022 386 | # https://github.com/DPS2022/diffusion-posterior-sampling/blob/effbde7325b22ce8dc3e2c06c160c021e743a12d/guided_diffusion/gaussian_diffusion.py#L373 387 | # and as written in https://arxiv.org/pdf/2010.02502.pdf 388 | epsilon = self.model(x, t) 389 | timestep = get_timestep(t, self.t0, self.t1, self.num_steps) 390 | m = self.sqrt_alphas_cumprod[timestep] 391 | sqrt_1m_alpha = self.sqrt_1m_alphas_cumprod[timestep] 392 | v = sqrt_1m_alpha**2 393 | alpha_cumprod = self.alphas_cumprod[timestep] 394 | alpha_cumprod_prev = self.alphas_cumprod_prev[timestep] 395 | m_prev = self.sqrt_alphas_cumprod_prev[timestep] 396 | v_prev = self.sqrt_1m_alphas_cumprod_prev[timestep] ** 2 397 | x_0 = batch_mul((x - batch_mul(sqrt_1m_alpha, epsilon)), 1.0 / m) 398 | coeff1 = self.eta * jnp.sqrt( 399 | (v_prev / v) * (1 - alpha_cumprod / alpha_cumprod_prev) 400 | ) 401 | coeff2 = 
jnp.sqrt(v_prev - coeff1**2) 402 | x_mean = batch_mul(m_prev, x_0) + batch_mul(coeff2, epsilon) 403 | std = coeff1 404 | return x_mean, std 405 | 406 | def update(self, rng, x, t): 407 | x_mean, std = self.posterior(x, t) 408 | z = random.normal(rng, x.shape) 409 | x = x_mean + batch_mul(std, z) 410 | return x, x_mean 411 | 412 | 413 | class DDIMVE(Solver): 414 | """DDIM Markov chain. For the SMLD Markov Chain or VE SDE. 415 | Args: 416 | model: DDIM parameterizes the `epsilon(x, t) = -1. * fwd_marginal_std(t) * score(x, t)` function. 417 | eta: the hyperparameter for DDIM, a value of `eta=0.0` is deterministic 'probability ODE' solver, `eta=1.0` is DDPMVE. 418 | """ 419 | 420 | def __init__(self, model, eta=1.0, sigma=None, ts=None): 421 | super().__init__(ts) 422 | if sigma is None: 423 | sigma = get_exponential_sigma_function(sigma_min=0.01, sigma_max=378.0) 424 | sigmas = vmap(sigma)(self.ts.flatten()) 425 | self.discrete_sigmas = sigmas 426 | self.discrete_sigmas_prev = jnp.append(0.0, self.discrete_sigmas[:-1]) 427 | self.eta = eta 428 | self.model = model 429 | 430 | def get_estimate_x_0_vmap(self, observation_map, clip=False, centered=False): 431 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 432 | 433 | def estimate_x_0(x, t, timestep): 434 | x = jnp.expand_dims(x, axis=0) 435 | t = jnp.expand_dims(t, axis=0) 436 | std = self.discrete_sigmas[timestep] 437 | epsilon = self.model(x, t) 438 | x_0 = x - std * epsilon 439 | if clip: 440 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 441 | return observation_map(x_0), (epsilon, x_0) 442 | 443 | return estimate_x_0 444 | 445 | def get_estimate_x_0(self, observation_map, clip=False, centered=False): 446 | (a_min, a_max) = (-1.0, 1.0) if centered else (0.0, 1.0) 447 | batch_observation_map = vmap(observation_map) 448 | 449 | def estimate_x_0(x, t, timestep): 450 | std = self.discrete_sigmas[timestep] 451 | epsilon = self.model(x, t) 452 | x_0 = x - batch_mul(std, epsilon) 453 | if clip: 454 | x_0 = jnp.clip(x_0, a_min=a_min, a_max=a_max) 455 | return batch_observation_map(x_0), (epsilon, x_0) 456 | 457 | return estimate_x_0 458 | 459 | def prior(self, rng, shape): 460 | return random.normal(rng, shape) * self.discrete_sigmas[-1] 461 | 462 | def posterior(self, x, t): 463 | timestep = get_timestep(t, self.t0, self.t1, self.num_steps) 464 | epsilon = self.model(x, t) 465 | sigma = self.discrete_sigmas[timestep] 466 | sigma_prev = self.discrete_sigmas_prev[timestep] 467 | coeff1 = self.eta * jnp.sqrt( 468 | (sigma_prev**2 * (sigma**2 - sigma_prev**2)) / (sigma**2) 469 | ) 470 | coeff2 = jnp.sqrt(sigma_prev**2 - coeff1**2) 471 | 472 | # Eq.(18) Appendix A.4 https://openreview.net/pdf/210093330709030207aa90dbfe2a1f525ac5fb7d.pdf 473 | x_0 = x - batch_mul(sigma, epsilon) 474 | x_mean = x_0 + batch_mul(coeff2, epsilon) 475 | 476 | # Eq.(19) Appendix A.4 https://openreview.net/pdf/210093330709030207aa90dbfe2a1f525ac5fb7d.pdf 477 | # score = - batch_mul(1. / sigma, epsilon) 478 | # x_mean = x + batch_mul(sigma * (sigma - coeff2), score) 479 | 480 | std = coeff1 481 | return x_mean, std 482 | 483 | def update(self, rng, x, t): 484 | x_mean, std = self.posterior(x, t) 485 | z = random.normal(rng, x.shape) 486 | x = x_mean + batch_mul(std, z) 487 | return x, x_mean 488 | 489 | 490 | class EDMEuler(Solver): 491 | """ 492 | A solver from the paper Elucidating the Design space of Diffusion-Based 493 | Generative Models. 494 | 495 | Algorithm 2 (Euler steps) from Karras et al. 
(2022) arxiv.org/abs/2206.00364 496 | """ 497 | 498 | def __init__(self, denoise, sigma=None, gamma=None, ts=None, s_noise=1.0): 499 | """ 500 | The default `args:ts` to use is `ts, dt = diffusionjax.utils.get_times(num_steps, t0=0.0)`. 501 | """ 502 | super().__init__(ts) 503 | if sigma is None: 504 | sigma = get_karras_sigma_function(sigma_min=0.002, sigma_max=80.0, rho=7) 505 | self.discrete_sigmas = vmap(sigma)(self.ts.flatten()) 506 | if gamma is None: 507 | gamma = get_karras_gamma_function( 508 | num_steps=self.num_steps, s_churn=0.0, s_min=0.0, s_max=float("inf") 509 | ) 510 | self.gammas = gamma(self.discrete_sigmas) 511 | self.bool_gamma_greater_than_zero = jnp.where(self.gammas > 0, 1, 0) 512 | self.discrete_sigmas_prev = jnp.append(0.0, self.discrete_sigmas[:-1]) 513 | self.s_noise = s_noise 514 | self.denoise = denoise 515 | 516 | def prior(self, rng, shape): 517 | return random.normal(rng, shape) * self.discrete_sigmas[-1] 518 | 519 | def update(self, rng, x, t): 520 | timestep = get_timestep(t, self.t0, self.t1, self.num_steps) 521 | sigma = self.discrete_sigmas[timestep] 522 | # sigma_prev is the one that will finish with zero, and so it is the previous sigma in forward time 523 | sigma_prev = self.discrete_sigmas_prev[timestep] 524 | gamma = self.gammas[timestep] 525 | sigma_hat = sigma * (gamma + 1) 526 | 527 | # need to do this since get JAX tracer concretization error the naive way 528 | bool = self.bool_gamma_greater_than_zero[timestep[0]] 529 | z = random.normal(rng, x.shape) * self.s_noise 530 | std = jnp.sqrt(sigma_hat**2 - sigma**2) * bool 531 | x = x + batch_mul(std, z) 532 | 533 | # Convert the denoiser output to a Karras ODE derivative 534 | drift = batch_mul(x - self.denoise(x, sigma_hat), 1.0 / sigma) 535 | dt = sigma_prev - sigma_hat 536 | x = x + batch_mul(drift, dt) # Euler method 537 | return x, None 538 | 539 | 540 | class EDMHeun(EDMEuler): 541 | """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" 542 | 543 | def update(self, rng, x, t): 544 | timestep = get_timestep(t, self.t0, self.t1, self.num_steps) 545 | sigma = self.discrete_sigmas[timestep] 546 | sigma_prev = self.discrete_sigmas_prev[timestep] 547 | gamma = self.gammas[timestep] 548 | sigma_hat = sigma * (gamma + 1) 549 | # need to do this since get JAX tracer concretization error the naive way 550 | std = jnp.sqrt(sigma_hat**2 - sigma**2) 551 | bool = self.bool_gamma_greater_than_zero[timestep[0]] 552 | x = jnp.where( 553 | bool, x + batch_mul(std, random.normal(rng, x.shape) * self.s_noise), x 554 | ) 555 | 556 | # Convert the denoiser output to a Karras ODE derivative 557 | drift = batch_mul(x - self.denoise(x, sigma_hat), 1.0 / sigma) 558 | dt = sigma_prev - sigma_hat 559 | x_1 = x + batch_mul(drift, dt) # Euler step 560 | drift_1 = batch_mul(x_1 - self.denoise(x_1, sigma_prev), 1.0 / sigma_prev) 561 | drift_prime = (drift + drift_1) / 2 562 | x_2 = x_1 + batch_mul(drift_prime, dt) # 2nd order correction 563 | return x_2, x_1 564 | -------------------------------------------------------------------------------- /diffusionjax/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions, including all functions related to 2 | loss computation, optimization and sampling. 
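Also contains the time-discretisation helpers (`get_times`, `get_timestep`),
the beta/sigma noise schedules, and the batched linear-algebra helpers
(`batch_mul`, `batch_matmul`, `batch_linalg_solve`) used throughout the package.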
3 | """ 4 | 5 | import jax.numpy as jnp 6 | from jax.lax import scan 7 | from jax import vmap 8 | import jax.random as random 9 | from functools import partial 10 | from collections.abc import MutableMapping 11 | 12 | 13 | # Taken from https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys 14 | def flatten_nested_dict(nested_dict, parent_key="", sep="."): 15 | items = [] 16 | for name, cfg in nested_dict.items(): 17 | new_key = parent_key + sep + name if parent_key else name 18 | if isinstance(cfg, MutableMapping): 19 | items.extend(flatten_nested_dict(cfg, new_key, sep=sep).items()) 20 | else: 21 | items.append((new_key, cfg)) 22 | 23 | return dict(items) 24 | 25 | 26 | def get_timestep(t, t0, t1, num_steps): 27 | return (jnp.rint((t - t0) * (num_steps - 1) / (t1 - t0))).astype(jnp.int32) 28 | 29 | 30 | def continuous_to_discrete(betas, dt): 31 | discrete_betas = betas * dt 32 | return discrete_betas 33 | 34 | 35 | def get_exponential_sigma_function(sigma_min, sigma_max): 36 | log_sigma_min = jnp.log(sigma_min) 37 | log_sigma_max = jnp.log(sigma_max) 38 | 39 | def sigma(t): 40 | # return sigma_min * (sigma_max / sigma_min)**t # Has large relative error close to zero compared to alternative, below 41 | return jnp.exp(log_sigma_min + t * (log_sigma_max - log_sigma_min)) 42 | 43 | return sigma 44 | 45 | 46 | def get_linear_beta_function(beta_min, beta_max): 47 | """Returns: 48 | Linear beta (cooling rate parameter) as a function of time, 49 | It's integral multiplied by -0.5, which is the log mean coefficient of the VP SDE. 50 | """ 51 | 52 | def beta(t): 53 | return beta_min + t * (beta_max - beta_min) 54 | 55 | def mean_coeff(t): 56 | """..math: exp(-0.5 * \int_{0}^{t} \beta(s) ds)""" 57 | return jnp.exp(-0.5 * t * beta_min - 0.25 * t**2 * (beta_max - beta_min)) 58 | 59 | return beta, mean_coeff 60 | 61 | 62 | def get_cosine_beta_function(beta_max, offset=0.08): 63 | """Returns: 64 | Squared cosine beta (cooling rate parameter) as a function of time, 65 | It's integral multiplied by -0.5, which is the log mean coefficient of the VP SDE. 66 | Note: this implementation cannot perfectly replicate https://arxiv.org/abs/2102.09672 67 | since it deals with a continuous time formulation of beta(t). 68 | Args: 69 | offset: https://arxiv.org/abs/2102.09672 "Use a small offset to prevent 70 | $\beta(t)$ from being too small near 71 | $t = 0$, since we found that having tiny amounts of noise at the beginning 72 | of the process made it hard for the network to predict $\epsilon$ 73 | accurately enough" 74 | """ 75 | 76 | def beta(t): 77 | # clip to max_beta 78 | return jnp.clip(jnp.sin((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) / (jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) + 1e-5) * jnp.pi * (1.0 / (1.0 + offset)), a_max=beta_max) 79 | 80 | def mean_coeff(t): 81 | """..math: -0.5 * \int_{0}^{t} \beta(s) ds""" 82 | return jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) 83 | # return jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) / jnp.cos(offset / (1.0 + offset) * 0.5 * jnp.pi) 84 | 85 | return beta, mean_coeff 86 | 87 | 88 | def get_karras_sigma_function(sigma_min, sigma_max, rho=7): 89 | """ 90 | A sigma function from Algorithm 2 from Karras et al. (2022) arxiv.org/abs/2206.00364 91 | 92 | Returns: 93 | A function that can be used like `sigmas = vmap(sigma)(ts)` where `ts.shape = (num_steps,)`, see `test_utils.py` for usage. 94 | 95 | Args: 96 | sigma_min: Minimum standard deviation of forawrd transition kernel. 
97 | sigma_max: Maximum standard deviation of forward transition kernel. 98 | rho: Order of the polynomial in t (determines both smoothness and growth 99 | rate). 100 | """ 101 | min_inv_rho = sigma_min ** (1 / rho) 102 | max_inv_rho = sigma_max ** (1 / rho) 103 | 104 | def sigma(t): 105 | # NOTE: is defined in reverse time of the definition in arxiv.org/abs/2206.00364 106 | return (min_inv_rho + t * (max_inv_rho - min_inv_rho)) ** rho 107 | 108 | return sigma 109 | 110 | 111 | def get_karras_gamma_function(num_steps, s_churn, s_min, s_max): 112 | """ 113 | A gamma function from Algorithm 2 from Karras et al. (2022) arxiv.org/abs/2206.00364 114 | Returns: 115 | A function that can be used like `gammas = gamma(sigmas)` where `sigmas.shape = (num_steps,)`, see `test_utils.py` for usage. 116 | Args: 117 | num_steps: 118 | s_churn: "controls the overall amount of stochasticity" in Algorithm 2 from Karras et al. (2022) 119 | [s_min, s_max] : Range of noise levels that "stochasticity" is enabled. 120 | """ 121 | 122 | def gamma(sigmas): 123 | gammas = jnp.where(sigmas <= s_max, min(s_churn / num_steps, jnp.sqrt(2) - 1), 0.0) 124 | gammas = jnp.where(s_min <= sigmas, gammas, 0.0) 125 | return gammas 126 | 127 | return gamma 128 | 129 | 130 | def get_times(num_steps=1000, dt=None, t0=None): 131 | """ 132 | Get linear, monotonically increasing time schedule. 133 | Args: 134 | num_steps: number of discretization time steps. 135 | dt: time step duration, float or `None`. 136 | Optional, if provided then final time, t1 = dt * num_steps. 137 | t0: A small float 0. < t0 << 1. The SDE or ODE are integrated to 138 | t0 to avoid numerical issues. 139 | Return: 140 | ts: JAX array of monotonically increasing values t \in [t0, t1]. 141 | """ 142 | if dt is not None: 143 | if t0 is not None: 144 | t1 = dt * (num_steps - 1) + t0 145 | # Defined in forward time, t \in [t0, t1], 0 < t0 << t1 146 | ts, step = jnp.linspace(t0, t1, num_steps, retstep=True) 147 | ts = ts.reshape(-1, 1) 148 | assert jnp.isclose(step, (t1 - t0) / (num_steps - 1)) 149 | assert jnp.isclose(step, dt) 150 | dt = step 151 | assert t0 == ts[0] 152 | else: 153 | t1 = dt * num_steps 154 | # Defined in forward time, t \in [dt , t1], 0 < \t0 << t1 155 | ts, step = jnp.linspace(0.0, t1, num_steps + 1, retstep=True) 156 | ts = ts[1:].reshape(-1, 1) 157 | assert jnp.isclose(step, dt) 158 | dt = step 159 | t0 = ts[0] 160 | else: 161 | t1 = 1.0 162 | if t0 is not None: 163 | ts, dt = jnp.linspace(t0, 1.0, num_steps, retstep=True) 164 | ts = ts.reshape(-1, 1) 165 | assert jnp.isclose(dt, (1.0 - t0) / (num_steps - 1)) 166 | assert t0 == ts[0] 167 | else: 168 | # Defined in forward time, t \in [dt, 1.0], 0 < dt << 1 169 | ts, dt = jnp.linspace(0.0, 1.0, num_steps + 1, retstep=True) 170 | ts = ts[1:].reshape(-1, 1) 171 | assert jnp.isclose(dt, 1.0 / num_steps) 172 | t0 = ts[0] 173 | assert ts[0, 0] == t0 174 | assert ts[-1, 0] == t1 175 | dts = jnp.diff(ts) 176 | assert jnp.all(dts > 0.0) 177 | assert jnp.all(dts == dt) 178 | return ts, dt 179 | 180 | 181 | def batch_linalg_solve_A(A, b): 182 | return vmap(lambda b: jnp.linalg.solve(A, b))(b) 183 | 184 | 185 | def batch_linalg_solve(A, b): 186 | return vmap(jnp.linalg.solve)(A, b) 187 | 188 | 189 | def batch_mul(a, b): 190 | return vmap(lambda a, b: a * b)(a, b) 191 | 192 | 193 | def batch_mul_A(a, b): 194 | return vmap(lambda b: a * b)(b) 195 | 196 | 197 | def batch_matmul(A, b): 198 | return vmap(lambda A, b: A @ b)(A, b) 199 | 200 | 201 | def batch_matmul_A(A, b): 202 | return vmap(lambda b: A @ 
b)(b) 203 | 204 | 205 | def errors(t, sde, score, rng, data, likelihood_weighting=True): 206 | """ 207 | Args: 208 | ts: JAX array of times. 209 | sde: Instantiation of a valid SDE class. 210 | score: A function taking in (x, t) and returning the score. 211 | rng: Random number generator from JAX. 212 | data: A batch of samples from the training data, representing samples from the data distribution, shape (J, N). 213 | likelihood_weighting: Bool, set to `True` if likelihood weighting, as described in Song et al. 2020 (https://arxiv.org/abs/2011.13456), is applied. 214 | Returns: 215 | A Monte-Carlo approximation to the (likelihood weighted) score errors. 216 | """ 217 | m = sde.mean_coeff(t) 218 | mean = batch_mul(m, data) 219 | std = jnp.sqrt(sde.variance(t)) 220 | rng, step_rng = random.split(rng) 221 | noise = random.normal(step_rng, data.shape) 222 | x = mean + batch_mul(std, noise) 223 | if not likelihood_weighting: 224 | return noise + batch_mul(score(x, t), std) 225 | else: 226 | return batch_mul(noise, 1.0 / std) + score(x, t) 227 | 228 | 229 | def get_pointwise_loss( 230 | sde, 231 | model, 232 | score_scaling=True, 233 | likelihood_weighting=True, 234 | reduce_mean=True, 235 | ): 236 | """Create a loss function for score matching training, returning a function that can evaluate the loss pointwise over time. 237 | Args: 238 | sde: Instantiation of a valid SDE class. 239 | solver: Instantiation of a valid Solver class. 240 | model: A valid flax neural network `:class:flax.linen.Module` class. 241 | score_scaling: Bool, set to `True` if learning a score scaled by the marginal standard deviation. 242 | likelihood_weighting: Bool, set to `True` if likelihood weighting, as described in Song et al. 2020 (https://arxiv.org/abs/2011.13456), is applied. 243 | reduce_mean: Bool, set to `True` if taking the mean of the errors in the loss, set to `False` if taking the sum. 244 | 245 | Returns: 246 | A loss function that can be used for score matching training and can evaluate the loss pointwise over time. 247 | """ 248 | reduce_op = ( 249 | jnp.mean if reduce_mean else lambda *args, **kwargs: 0.5 * jnp.sum(*args, **kwargs) 250 | ) 251 | 252 | def pointwise_loss(t, params, rng, data): 253 | n_batch = data.shape[0] 254 | ts = jnp.ones((n_batch,)) * t 255 | score = get_score(sde, model, params, score_scaling) 256 | e = errors(ts, sde, score, rng, data, likelihood_weighting) 257 | losses = e**2 258 | losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) 259 | if likelihood_weighting: 260 | g2 = sde.sde(jnp.zeros_like(data), ts)[1] ** 2 261 | losses = losses * g2 262 | return jnp.mean(losses) 263 | 264 | return pointwise_loss 265 | 266 | 267 | def get_loss( 268 | sde, 269 | solver, 270 | model, 271 | score_scaling=True, 272 | likelihood_weighting=True, 273 | reduce_mean=True, 274 | ): 275 | """Create a loss function for score matching training. 276 | Args: 277 | sde: Instantiation of a valid SDE class. 278 | solver: Instantiation of a valid Solver class. 279 | model: A valid flax neural network `:class:flax.linen.Module` class. 280 | score_scaling: Bool, set to `True` if learning a score scaled by the marginal standard deviation. 281 | likelihood_weighting: Bool, set to `True` if likelihood weighting, as described in Song et al. 2020 (https://arxiv.org/abs/2011.13456), is applied. 282 | reduce_mean: Bool, set to `True` if taking the mean of the errors in the loss, set to `False` if taking the sum. 
283 | 284 | Returns: 285 | A loss function that can be used for score matching training and is an expectation of the regression loss over time. 286 | """ 287 | reduce_op = ( 288 | jnp.mean if reduce_mean else lambda *args, **kwargs: 0.5 * jnp.sum(*args, **kwargs) 289 | ) 290 | 291 | def loss(params, rng, data): 292 | rng, step_rng = random.split(rng) 293 | ts = random.uniform( 294 | step_rng, (data.shape[0],), minval=solver.ts[0], maxval=solver.t1 295 | ) 296 | score = get_score(sde, model, params, score_scaling) 297 | e = errors(ts, sde, score, rng, data, likelihood_weighting) 298 | losses = e**2 299 | losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) 300 | if likelihood_weighting: 301 | g2 = sde.sde(jnp.zeros_like(data), ts)[1] ** 2 302 | losses = losses * g2 303 | return jnp.mean(losses) 304 | 305 | return loss 306 | 307 | 308 | class EDM2Loss: 309 | """ 310 | Uncertainty-based loss function (Equations 14,15,16,21) proposed in the 311 | paper "Analyzing and Improving the Training Dynamics of Diffusion Models". 312 | """ 313 | 314 | def __init__( 315 | self, net, batch_gpu_total, loss_scaling=1.0, p_mean=-0.4, p_std=1.0, sigma_data=0.5 316 | ): 317 | self.net = net 318 | self.p_mean = p_mean 319 | self.p_std = p_std 320 | self.sigma_data = sigma_data 321 | self.loss_scaling = loss_scaling 322 | self.batch_gpu_total = batch_gpu_total 323 | 324 | def __call__(self, params, rng, data, labels=None): 325 | rng, step_rng = random.split(rng) 326 | random_normal = random.normal( 327 | step_rng, (data.shape[0],) + (1,) * (len(data.shape) - 1) 328 | ) 329 | print("r", random_normal.shape) 330 | sigma = jnp.exp(random_normal * self.p_std + self.p_mean) 331 | print("rs",sigma.shape) 332 | weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 333 | noise = random.normal(step_rng, data.shape) * sigma 334 | denoised, logvar = self.net.apply(params, data + noise, sigma, labels) 335 | loss = (weight / jnp.exp(logvar)) * ((denoised - data) ** 2) + logvar 336 | return jnp.sum(loss) * (self.loss_scaling / self.batch_gpu_total) 337 | 338 | 339 | def get_score(sde, model, params, score_scaling): 340 | if score_scaling is True: 341 | return lambda x, t: -batch_mul( 342 | model.apply(params, x, t), 1.0 / jnp.sqrt(sde.variance(t)) 343 | ) 344 | else: 345 | return lambda x, t: -model.apply(params, x, t) 346 | 347 | 348 | def get_net(model, params): 349 | # TODO: compare to edmv2 code and work out if it is correct 350 | return lambda x, t: -model.apply(params, x, t) 351 | 352 | 353 | def get_epsilon(sde, model, params, score_scaling): 354 | if score_scaling is True: 355 | return lambda x, t: model.apply(params, x, t) 356 | else: 357 | return lambda x, t: batch_mul(jnp.sqrt(sde.variance(t)), model.apply(params, x, t)) 358 | 359 | 360 | def shared_update(rng, x, t, solver, probability_flow=None): 361 | """A wrapper that configures and returns the update function of the solvers. 362 | 363 | :probablity_flow: Placeholder for probability flow ODE (TODO). 364 | """ 365 | return solver.update(rng, x, t) 366 | 367 | 368 | def get_sampler( 369 | shape, 370 | outer_solver, 371 | inner_solver=None, 372 | denoise=True, 373 | stack_samples=False, 374 | inverse_scaler=None, 375 | ): 376 | """Get a sampler from (possibly interleaved) numerical solver(s). 377 | 378 | Args: 379 | shape: Shape of array, x. 
(num_samples,) + x_shape, where x_shape is the shape 380 | of the object being sampled from, for example, an image may have 381 | x_shape==(H, W, C), and so shape==(N, H, W, C) where N is the number of samples. 382 | outer_solver: A valid numerical solver class that will act on an outer loop. 383 | inner_solver: '' that will act on an inner loop. 384 | denoise: Bool, that if `True` applies one-step denoising to final samples. 385 | stack_samples: Bool, that if `True` return all of the sample path or 386 | just returns the last sample. 387 | inverse_scaler: The inverse data normalizer function. 388 | Returns: 389 | A sampler. 390 | """ 391 | if inverse_scaler is None: 392 | inverse_scaler = lambda x: x 393 | 394 | def sampler(rng, x_0=None): 395 | """ 396 | Args: 397 | rng: A JAX random state. 398 | x_0: Initial condition. If `None`, then samples an initial condition from the 399 | sde's initial condition prior. Note that this initial condition represents 400 | `x_T sim Normal(O, I)` in reverse-time diffusion. 401 | Returns: 402 | Samples and the number of score function (model) evaluations. 403 | """ 404 | outer_update = partial(shared_update, solver=outer_solver) 405 | outer_ts = outer_solver.ts 406 | 407 | if inner_solver: 408 | inner_update = partial(shared_update, solver=inner_solver) 409 | inner_ts = inner_solver.ts 410 | num_function_evaluations = jnp.size(outer_ts) * (jnp.size(inner_ts) + 1) 411 | 412 | def inner_step(carry, t): 413 | rng, x, x_mean, vec_t = carry 414 | rng, step_rng = random.split(rng) 415 | x, x_mean = inner_update(step_rng, x, vec_t) 416 | return (rng, x, x_mean, vec_t), () 417 | 418 | def outer_step(carry, t): 419 | rng, x, x_mean = carry 420 | vec_t = jnp.full(shape[0], t) 421 | rng, step_rng = random.split(rng) 422 | x, x_mean = outer_update(step_rng, x, vec_t) 423 | (rng, x, x_mean, vec_t), _ = scan( 424 | inner_step, (step_rng, x, x_mean, vec_t), inner_ts 425 | ) 426 | if not stack_samples: 427 | return (rng, x, x_mean), () 428 | else: 429 | if denoise: 430 | return (rng, x, x_mean), x_mean 431 | else: 432 | return (rng, x, x_mean), x 433 | else: 434 | num_function_evaluations = jnp.size(outer_ts) 435 | 436 | def outer_step(carry, t): 437 | rng, x, x_mean = carry 438 | vec_t = jnp.full((shape[0],), t) 439 | rng, step_rng = random.split(rng) 440 | x, x_mean = outer_update(step_rng, x, vec_t) 441 | if not stack_samples: 442 | return (rng, x, x_mean), () 443 | else: 444 | return ((rng, x, x_mean), x_mean) if denoise else ((rng, x, x_mean), x) 445 | 446 | rng, step_rng = random.split(rng) 447 | if x_0 is None: 448 | if inner_solver: 449 | x = inner_solver.prior(step_rng, shape) 450 | else: 451 | x = outer_solver.prior(step_rng, shape) 452 | else: 453 | assert x_0.shape == shape 454 | x = x_0 455 | if not stack_samples: 456 | (_, x, x_mean), _ = scan(outer_step, (rng, x, x), outer_ts, reverse=True) 457 | return inverse_scaler(x_mean if denoise else x), num_function_evaluations 458 | else: 459 | (_, _, _), xs = scan(outer_step, (rng, x, x), outer_ts, reverse=True) 460 | return inverse_scaler(xs), num_function_evaluations 461 | 462 | # return jax.pmap(sampler, in_axes=(0), axis_name='batch') 463 | return sampler 464 | -------------------------------------------------------------------------------- /examples/example.py: -------------------------------------------------------------------------------- 1 | """Diffusion models introduction. 
2 | 3 | Based off the Jupyter notebook: https://jakiw.com/sgm_intro 4 | A tutorial on the theoretical and implementation aspects of score-based generative models, also called diffusion models. 5 | """ 6 | 7 | # Uncomment to enable double precision 8 | # from jax.config import config as jax_config 9 | # jax_config.update("jax_enable_x64", True) 10 | import jax 11 | from jax import jit, vmap, grad 12 | import jax.random as random 13 | import jax.numpy as jnp 14 | import flax.linen as nn 15 | from jax.scipy.special import logsumexp 16 | import numpy as np 17 | from diffusionjax.run_lib import train 18 | from diffusionjax.utils import get_score, get_sampler, get_times 19 | from diffusionjax.solvers import EulerMaruyama, Inpainted 20 | from diffusionjax.inverse_problems import get_pseudo_inverse_guidance, get_vjp_guidance 21 | from diffusionjax.plot import plot_scatter, plot_score, plot_heatmap 22 | import diffusionjax.sde as sde_lib 23 | from absl import app, flags 24 | from ml_collections.config_flags import config_flags 25 | from flax import serialization 26 | import time 27 | import os 28 | 29 | # Dependencies: 30 | # This example requires optax, https://optax.readthedocs.io/en/latest/ 31 | # This example requires orbax, https://orbax.readthedocs.io/en/latest/ 32 | # This example requires torch[cpu], https://pytorch.org/get-started/locally/ 33 | from torch.utils.data import Dataset 34 | 35 | 36 | FLAGS = flags.FLAGS 37 | config_flags.DEFINE_config_file( 38 | "config", "./configs/example.py", "Training configuration.", lock_config=True 39 | ) 40 | flags.DEFINE_string("workdir", "./examples/", "Work directory.") 41 | flags.mark_flags_as_required(["workdir", "config"]) 42 | 43 | 44 | class CircleDataset(Dataset): 45 | """Dataset containing samples from the circle.""" 46 | 47 | def __init__(self, num_samples): 48 | self.train_data = self.sample_circle(num_samples) 49 | 50 | def __len__(self): 51 | return self.train_data.shape[0] 52 | 53 | def __getitem__(self, idx): 54 | return self.train_data[idx] 55 | 56 | def sample_circle(self, num_samples): 57 | """Samples from the unit circle, angles split. 58 | 59 | Args: 60 | num_samples: The number of samples. 61 | 62 | Returns: 63 | An (num_samples, 2) array of samples. 
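For example (illustrative), `sample_circle(4)` returns approximately `[[1., 0.], [0., 1.], [-1., 0.], [0., -1.]]`, since the four angles are 0, pi/2, pi and 3*pi/2.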
64 | """ 65 | alphas = jnp.linspace(0, 2 * jnp.pi * (1 - 1 / num_samples), num_samples) 66 | xs = jnp.cos(alphas) 67 | ys = jnp.sin(alphas) 68 | samples = jnp.stack([xs, ys], axis=1) 69 | return samples 70 | 71 | def metric_names(self): 72 | return ["mean"] 73 | 74 | def calculate_metrics_batch(self, batch): 75 | return vmap(lambda x: jnp.mean(x, axis=0))(batch)[0, 0] 76 | 77 | def get_data_scaler(self, config): 78 | def data_scaler(x): 79 | return x / jnp.sqrt(2) 80 | 81 | return data_scaler 82 | 83 | def get_data_inverse_scaler(self, config): 84 | def data_inverse_scaler(x): 85 | return x * jnp.sqrt(2) 86 | 87 | return data_inverse_scaler 88 | 89 | 90 | class MLP(nn.Module): 91 | @nn.compact 92 | def __call__(self, x, t): 93 | x_shape = x.shape 94 | in_size = np.prod(x_shape[1:]) 95 | n_hidden = 256 96 | t = t.reshape((t.shape[0], -1)) 97 | x = x.reshape((x.shape[0], -1)) # flatten 98 | t = jnp.concatenate([t - 0.5, jnp.cos(2 * jnp.pi * t)], axis=-1) 99 | x = jnp.concatenate([x, t], axis=-1) 100 | x = nn.Dense(n_hidden)(x) 101 | x = nn.relu(x) 102 | x = nn.Dense(n_hidden)(x) 103 | x = nn.relu(x) 104 | x = nn.Dense(n_hidden)(x) 105 | x = nn.relu(x) 106 | x = nn.Dense(in_size)(x) 107 | return x.reshape(x_shape) 108 | 109 | 110 | def main(argv): 111 | workdir = FLAGS.workdir 112 | config = FLAGS.config 113 | jax.default_device = jax.devices()[0] 114 | # Tip: use CUDA_VISIBLE_DEVICES to restrict the devices visible to jax 115 | # ... they must be all the same model of device for pmap to work 116 | num_devices = int(jax.local_device_count()) if config.training.pmap else 1 117 | rng = random.PRNGKey(config.seed) 118 | 119 | # Setup SDE 120 | if config.training.sde.lower() == "vpsde": 121 | from diffusionjax.utils import get_linear_beta_function 122 | 123 | beta, mean_coeff = get_linear_beta_function( 124 | beta_min=config.model.beta_min, beta_max=config.model.beta_max 125 | ) 126 | sde = sde_lib.VP(beta, mean_coeff) 127 | elif config.training.sde.lower() == "vesde": 128 | from diffusionjax.utils import get_exponential_sigma_function 129 | 130 | sigma = get_exponential_sigma_function( 131 | sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max 132 | ) 133 | sde = sde_lib.VE(sigma) 134 | else: 135 | raise NotImplementedError(f"SDE {config.training.SDE} unknown.") 136 | 137 | # Build data iterators 138 | num_samples = 8 139 | dataset = CircleDataset(num_samples=num_samples) 140 | scaler = dataset.get_data_scaler(config) 141 | inverse_scaler = dataset.get_data_inverse_scaler(config) 142 | plot_scatter( 143 | samples=dataset.train_data, index=(0, 1), fname="samples", lims=((-3, 3), (-3, 3)) 144 | ) 145 | 146 | def log_hat_pt(x, t): 147 | """Empirical distribution score. 148 | 149 | Args: 150 | x: One location in $\mathbb{R}^2$ 151 | t: time 152 | Returns: 153 | The empirical log density, as described in the Jupyter notebook 154 | .. 
math:: 155 | \log\hat{p}_{t}(x) 156 | """ 157 | mean_coeff = sde.mean_coeff(t) # argument t can be scalar BatchTracer or JaxArray 158 | mean = mean_coeff * scaler(dataset.train_data) 159 | std = jnp.sqrt(sde.variance(t)) 160 | potentials = jnp.sum(-((x - mean) ** 2) / (2 * std**2), axis=1) 161 | return logsumexp(potentials, axis=0, b=1 / num_samples) 162 | 163 | # Get a jax grad function, which can be batched with vmap 164 | nabla_log_hat_pt = jit(vmap(grad(log_hat_pt))) 165 | 166 | # Running the reverse SDE with the empirical drift 167 | plot_score( 168 | score=nabla_log_hat_pt, 169 | scaler=scaler, 170 | t=0.01, 171 | area_bounds=[-3.0, 3], 172 | fname="empirical score", 173 | ) 174 | ts, _ = get_times( 175 | num_steps=config.solver.num_outer_steps, 176 | dt=config.solver.dt, 177 | t0=config.solver.epsilon, 178 | ) 179 | outer_solver = EulerMaruyama(sde.reverse(nabla_log_hat_pt), ts) 180 | sampler = get_sampler( 181 | (5760, config.data.image_size), 182 | outer_solver, 183 | denoise=config.sampling.denoise, 184 | stack_samples=False, 185 | inverse_scaler=inverse_scaler, 186 | ) 187 | rng, sample_rng = random.split(rng, 2) 188 | q_samples, _ = sampler(sample_rng) 189 | plot_heatmap( 190 | samples=q_samples, area_bounds=[-3.0, 3.0], fname="heatmap empirical score" 191 | ) 192 | 193 | # What happens when I perturb the score with a constant? 194 | perturbed_score = lambda x, t: nabla_log_hat_pt(x, t) + 1.0 195 | outer_solver = EulerMaruyama(sde.reverse(perturbed_score), ts) 196 | sampler = get_sampler( 197 | (5760, config.data.image_size), 198 | outer_solver, 199 | denoise=config.sampling.denoise, 200 | inverse_scaler=inverse_scaler, 201 | ) 202 | rng, sample_rng = random.split(rng, 2) 203 | q_samples, _ = sampler(sample_rng) 204 | plot_heatmap( 205 | samples=q_samples, area_bounds=[-3.0, 3.0], fname="heatmap bounded perturbation" 206 | ) 207 | 208 | if not os.path.exists("/tmp/output0"): 209 | time_prev = time.time() 210 | params, *_ = train( 211 | (config.training.batch_size // jax.local_device_count(), config.data.image_size), 212 | config, 213 | MLP(), 214 | dataset, 215 | workdir, 216 | use_wandb=False, 217 | ) # Optionally visualize results on weightsandbiases 218 | time_delta = time.time() - time_prev 219 | print("train time: {}s".format(time_delta)) 220 | 221 | # Save params 222 | output = serialization.to_bytes(params) 223 | f = open("/tmp/output0", "wb") 224 | f.write(output) 225 | else: # Load pre-trained model parameters 226 | params = MLP().init( 227 | rng, 228 | jnp.zeros( 229 | (config.training.batch_size // jax.local_device_count(), config.data.image_size) 230 | ), 231 | jnp.ones((config.training.batch_size // jax.local_device_count(),)), 232 | ) 233 | f = open("/tmp/output0", "rb") 234 | output = f.read() 235 | params = serialization.from_bytes(params, output) 236 | 237 | # Get trained score 238 | trained_score = get_score( 239 | sde, MLP(), params, score_scaling=config.training.score_scaling 240 | ) 241 | plot_score( 242 | score=trained_score, 243 | scaler=scaler, 244 | t=0.01, 245 | area_bounds=[-3.0, 3.0], 246 | fname="trained score", 247 | ) 248 | outer_solver = EulerMaruyama(sde.reverse(trained_score), ts) 249 | sampler = get_sampler( 250 | (config.eval.batch_size // num_devices, config.data.image_size), 251 | outer_solver, 252 | denoise=config.sampling.denoise, 253 | inverse_scaler=inverse_scaler, 254 | ) 255 | 256 | if config.training.pmap: 257 | sampler = jax.pmap(sampler, axis_name="batch") 258 | rng, *sample_rng = random.split(rng, 1 + num_devices) 259 | sample_rng 
= jnp.asarray(sample_rng) 260 | else: 261 | rng, sample_rng = random.split(rng, 2) 262 | 263 | q_samples, _ = sampler(sample_rng) 264 | q_samples = q_samples.reshape(config.eval.batch_size, config.data.image_size) 265 | plot_heatmap( 266 | samples=q_samples, area_bounds=[-3.0, 3.0], fname="heatmap trained score" 267 | ) 268 | 269 | 270 | # Inverse problems 271 | sampling_shape = (config.eval.batch_size, config.data.image_size) 272 | rsde = sde.reverse(trained_score) 273 | # Condition on one of the coordinates 274 | y = jnp.array([-0.5, 0.0]) 275 | mask = jnp.array([1.0, 0.0]) 276 | y = scaler(y) 277 | 278 | # Get inpainter 279 | sampler = get_sampler( 280 | sampling_shape, 281 | outer_solver, 282 | Inpainted(rsde, mask, y), 283 | inverse_scaler=inverse_scaler, 284 | stack_samples=False, 285 | denoise=True, 286 | ) 287 | q_samples, _ = sampler(sample_rng) 288 | q_samples = q_samples.reshape(sampling_shape) 289 | plot_heatmap(samples=q_samples, area_bounds=[-3.0, 3.0], fname="heatmap inpainted") 290 | 291 | # Get projection sampler 292 | sampler = get_sampler( 293 | sampling_shape, 294 | outer_solver, 295 | Inpainted(rsde, mask, y), 296 | inverse_scaler=inverse_scaler, 297 | stack_samples=False, 298 | denoise=True, 299 | ) 300 | q_samples, _ = sampler(sample_rng) 301 | q_samples = q_samples.reshape(sampling_shape) 302 | plot_heatmap(samples=q_samples, area_bounds=[-3.0, 3.0], fname="heatmap projected") 303 | 304 | def observation_map(x): 305 | return mask * x 306 | 307 | y = jnp.tile(y, (sampling_shape[0], 1)) 308 | # Get pseudo-inverse-guidance sampler 309 | sampler = get_sampler( 310 | sampling_shape, 311 | EulerMaruyama( 312 | sde.reverse(trained_score).guide( 313 | get_pseudo_inverse_guidance, observation_map, y, config.sampling.noise_std 314 | ) 315 | ), 316 | inverse_scaler=inverse_scaler, 317 | stack_samples=False, 318 | denoise=True, 319 | ) 320 | q_samples, _ = sampler(sample_rng) 321 | q_samples = q_samples.reshape(sampling_shape) 322 | plot_heatmap(samples=q_samples, area_bounds=[-3.0, 3.0], fname="heatmap guided") 323 | 324 | y = jnp.array([-0.5, 1.0]) 325 | H = jnp.eye(2) 326 | y = jnp.tile(y, (sampling_shape[0], 1)) 327 | # Get pseudo-inverse-guidance sampler 328 | sampler = get_sampler( 329 | sampling_shape, 330 | EulerMaruyama( 331 | sde.reverse(trained_score).guide( 332 | get_vjp_guidance, H, y, config.sampling.noise_std, sampling_shape 333 | ) 334 | ), 335 | inverse_scaler=inverse_scaler, 336 | stack_samples=False, 337 | denoise=True, 338 | ) 339 | q_samples, _ = sampler(sample_rng) 340 | q_samples = q_samples.reshape(sampling_shape) 341 | plot_heatmap(samples=q_samples, area_bounds=[-3.0, 3.0], fname="heatmap tmpd guided") 342 | 343 | 344 | if __name__ == "__main__": 345 | app.run(main) 346 | -------------------------------------------------------------------------------- /examples/example1.py: -------------------------------------------------------------------------------- 1 | """Diffusion models introduction. 
An example using 1 dimensional image data.""" 2 | 3 | from jax import jit, value_and_grad 4 | import jax.random as random 5 | import jax.numpy as jnp 6 | from flax import serialization 7 | from functools import partial 8 | from diffusionjax.inverse_problems import get_pseudo_inverse_guidance 9 | from diffusionjax.plot import plot_heatmap, plot_samples_1D 10 | from diffusionjax.utils import ( 11 | get_score, 12 | get_loss, 13 | get_sampler, 14 | get_exponential_sigma_function, 15 | ) 16 | from diffusionjax.solvers import EulerMaruyama, Inpainted, Projected 17 | from diffusionjax.sde import VE 18 | import numpy as np 19 | import flax.linen as nn 20 | import os 21 | 22 | # Dependencies: 23 | # This example requires mlkernels package, https://github.com/wesselb/mlkernels#installation 24 | import lab as B 25 | from mlkernels import Matern52 26 | 27 | # This example requires optax, https://optax.readthedocs.io/en/latest/ 28 | import optax 29 | 30 | 31 | x_max = 5.0 32 | epsilon = 1e-4 33 | 34 | 35 | # Initialize the optimizer 36 | optimizer = optax.adam(1e-3) 37 | 38 | 39 | class MLP(nn.Module): 40 | @nn.compact 41 | def __call__(self, x, t): 42 | x_shape = x.shape 43 | in_size = np.prod(x_shape[1:]) 44 | n_hidden = 256 45 | t = t.reshape((t.shape[0], -1)) 46 | x = x.reshape((x.shape[0], -1)) # flatten 47 | t = jnp.concatenate([t - 0.5, jnp.cos(2 * jnp.pi * t)], axis=-1) 48 | x = jnp.concatenate([x, t], axis=-1) 49 | x = nn.Dense(n_hidden)(x) 50 | x = nn.relu(x) 51 | x = nn.Dense(n_hidden)(x) 52 | x = nn.relu(x) 53 | x = nn.Dense(n_hidden)(x) 54 | x = nn.relu(x) 55 | x = nn.Dense(in_size)(x) 56 | return x.reshape(x_shape) 57 | 58 | 59 | @partial(jit, static_argnums=[4]) 60 | def update_step(params, rng, batch, opt_state, loss): 61 | """ 62 | Takes the gradient of the loss function and updates the model weights (params) using it. 63 | Args: 64 | params: the current weights of the model 65 | rng: random number generator from jax 66 | batch: a batch of samples from the training data, representing samples from \mu_\text{data}, shape (J, N) 67 | opt_state: the internal state of the optimizer 68 | loss: A loss function that can be used for score matching training. 69 | Returns: 70 | The value of the loss function (for metrics), 71 | the new params and the new optimizer state.
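For example (illustrative, mirroring the call made from `retrain_nn` below): `loss_eval, params, opt_state = update_step(params, step_rng, batch, opt_state, loss)`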
72 | """ 73 | val, grads = value_and_grad(loss)(params, rng, batch) 74 | updates, opt_state = optimizer.update(grads, opt_state) 75 | params = optax.apply_updates(params, updates) 76 | return val, params, opt_state 77 | 78 | 79 | def retrain_nn( 80 | update_step, num_epochs, step_rng, samples, params, opt_state, loss, batch_size=5 81 | ): 82 | train_size = samples.shape[0] 83 | batch_size = min(train_size, batch_size) 84 | steps_per_epoch = train_size // batch_size 85 | mean_losses = jnp.zeros((num_epochs, 1)) 86 | for i in range(num_epochs): 87 | rng, step_rng = random.split(step_rng) 88 | perms = random.permutation(step_rng, train_size) 89 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch 90 | perms = perms.reshape((steps_per_epoch, batch_size)) 91 | losses = jnp.zeros((jnp.shape(perms)[0], 1)) 92 | for j, perm in enumerate(perms): 93 | batch = samples[perm, :] 94 | rng, step_rng = random.split(rng) 95 | loss_eval, params, opt_state = update_step( 96 | params, step_rng, batch, opt_state, loss 97 | ) 98 | losses = losses.at[j].set(loss_eval) 99 | mean_loss = jnp.mean(losses, axis=0) 100 | mean_losses = mean_losses.at[i].set(mean_loss) 101 | if i % 10 == 0: 102 | print("Epoch {:d}, Loss {:.2f} ".format(i, mean_loss[0])) 103 | return params, opt_state, mean_losses 104 | 105 | 106 | def sample_image_rgb(rng, num_samples, image_size, kernel, num_channels=1): 107 | """Samples from a GMRF""" 108 | x = np.linspace(-x_max, x_max, image_size) 109 | x = x.reshape(image_size, 1) 110 | C = B.dense(kernel(x)) + epsilon * B.eye(image_size) 111 | u = random.multivariate_normal( 112 | rng, mean=jnp.zeros(x.shape[0]), cov=C, shape=(num_samples, num_channels) 113 | ) 114 | u = u.transpose((0, 2, 1)) 115 | return u, C 116 | 117 | 118 | def plot_score_ax_sample( 119 | ax, sample, score, t, area_min=-1, area_max=1, fname="plot_score" 120 | ): 121 | @partial( 122 | jit, 123 | static_argnums=[ 124 | 0, 125 | ], 126 | ) 127 | def helper(score, sample, t, area_min, area_max): 128 | x = jnp.linspace(area_min, area_max, 16) 129 | x, y = jnp.meshgrid(x, x) 130 | grid = jnp.stack([x.flatten(), y.flatten()], axis=1) 131 | sample = jnp.tile(sample, (len(x.flatten()), 1, 1, 1)) 132 | sample.at[:, [0, 1], 0, 0].set(grid) 133 | t = jnp.ones((grid.shape[0],)) * t 134 | scores = score(sample, t) 135 | return grid, scores 136 | 137 | grid, scores = helper(score, sample, t, area_min, area_max) 138 | ax.quiver(grid[:, 0], grid[:, 1], scores[:, 0, 0, 0], scores[:, 1, 0, 0]) 139 | 140 | 141 | def main(): 142 | num_epochs = 128 143 | rng = random.PRNGKey(2023) 144 | rng, step_rng = random.split(rng, 2) 145 | num_samples = 18000 146 | num_channels = 1 147 | image_size = 64 # image size 148 | 149 | samples, C = sample_image_rgb( 150 | rng, 151 | num_samples=num_samples, 152 | image_size=image_size, 153 | kernel=Matern52(), 154 | num_channels=num_channels, 155 | ) # (num_samples, image_size, num_channels) 156 | 157 | # Reshape image data 158 | samples = samples.reshape(-1, image_size, num_channels) 159 | plot_samples_1D(samples[:64], image_size, x_max=x_max, fname="samples") 160 | 161 | # Get sde model 162 | sigma = get_exponential_sigma_function(sigma_min=0.01, sigma_max=3.0) 163 | sde = VE(sigma) 164 | 165 | def nabla_log_pt(x, t): 166 | """Score. 167 | 168 | Returns: 169 | The true log density. 170 | .. 
math:: 171 | \nabla_{x} \log p_{t}(x) 172 | """ 173 | x_shape = x.shape 174 | v_t = sde.variance(t) 175 | m_t = sde.mean_coeff(t) 176 | x = x.flatten() 177 | score = -jnp.linalg.solve(m_t**2 * C + v_t * jnp.eye(x_shape[0]), x) 178 | return score.reshape(x_shape) 179 | 180 | # Neural network training via score matching 181 | batch_size = 64 182 | score_model = MLP() 183 | 184 | # Initialize parameters 185 | params = score_model.init( 186 | step_rng, jnp.zeros((batch_size, image_size, num_channels)), jnp.ones((batch_size,)) 187 | ) 188 | 189 | # Initialize optimizer 190 | opt_state = optimizer.init(params) 191 | 192 | if not os.path.exists("/tmp/output1"): 193 | solver = EulerMaruyama(sde) 194 | 195 | # Get loss function 196 | loss = get_loss( 197 | sde, 198 | solver, 199 | score_model, 200 | score_scaling=True, 201 | likelihood_weighting=False, 202 | reduce_mean=True, 203 | ) 204 | 205 | # Train with score matching 206 | params, opt_state, _ = retrain_nn( 207 | update_step=update_step, 208 | num_epochs=num_epochs, 209 | step_rng=step_rng, 210 | samples=samples, 211 | params=params, 212 | opt_state=opt_state, 213 | loss=loss, 214 | batch_size=batch_size, 215 | ) 216 | 217 | # Save params 218 | output = serialization.to_bytes(params) 219 | f = open("/tmp/output1", "wb") 220 | f.write(output) 221 | else: # Load pre-trained model parameters 222 | f = open("/tmp/output1", "rb") 223 | output = f.read() 224 | params = serialization.from_bytes(params, output) 225 | 226 | # Get trained score 227 | trained_score = get_score(sde, score_model, params, score_scaling=True) 228 | rsde = sde.reverse(trained_score) 229 | outer_solver = EulerMaruyama(rsde) 230 | sampling_shape = (512, image_size, num_channels) 231 | sampler = get_sampler(sampling_shape, outer_solver, denoise=True) 232 | 233 | rng, sample_rng = random.split(rng, 2) 234 | q_samples, num_function_evaluations = sampler(sample_rng) 235 | 236 | # C_emp = jnp.corrcoef(q_samples[:, :, 0].T) 237 | # delta = jnp.linalg.norm(C - C_emp) / image_size 238 | 239 | plot_samples_1D( 240 | q_samples[:64], image_size=image_size, x_max=x_max, fname="samples trained score" 241 | ) 242 | plot_heatmap( 243 | samples=q_samples[:, [0, 1], 0], 244 | area_bounds=[-3.0, 3.0], 245 | fname="heatmap trained score", 246 | ) 247 | 248 | # Condition on one of the coordinates 249 | y = jnp.zeros((image_size, num_channels)) 250 | y = y.at[[0, -1], 0].set([-1.0, 1.0]) 251 | mask = jnp.zeros((image_size, num_channels), dtype=float) 252 | mask = mask.at[[0, -1], 0].set([1.0, 1.0]) 253 | 254 | # Get inpainting sampler 255 | sampler = get_sampler( 256 | sampling_shape, 257 | outer_solver, 258 | Inpainted(rsde, mask, y), 259 | stack_samples=False, 260 | denoise=True, 261 | ) 262 | q_samples, _ = sampler(sample_rng) 263 | plot_samples_1D( 264 | q_samples, image_size=image_size, x_max=x_max, fname="samples inpainted" 265 | ) 266 | 267 | # Get projection sampler 268 | sampler = get_sampler( 269 | sampling_shape, 270 | outer_solver, 271 | Projected(rsde, mask, y, coeff=1e-2), 272 | stack_samples=False, 273 | denoise=True, 274 | ) 275 | q_samples, _ = sampler(sample_rng) 276 | plot_samples_1D( 277 | q_samples, image_size=image_size, x_max=x_max, fname="samples projected" 278 | ) 279 | 280 | def observation_map(x): 281 | return mask * x 282 | 283 | # Get pseudo-inverse-guidance sampler 284 | sampler = get_sampler( 285 | sampling_shape, 286 | EulerMaruyama( 287 | rsde.guide(get_pseudo_inverse_guidance, observation_map, y, noise_std=1e-5) 288 | ), 289 | stack_samples=False, 290 | 
denoise=True, 291 | ) 292 | q_samples, _ = sampler(sample_rng) 293 | q_samples = q_samples.reshape(sampling_shape) 294 | plot_samples_1D(q_samples, image_size=image_size, x_max=x_max, fname="samples guided") 295 | 296 | 297 | if __name__ == "__main__": 298 | main() 299 | -------------------------------------------------------------------------------- /examples/example2.py: -------------------------------------------------------------------------------- 1 | """Diffusion models introduction. An example using 2 dimensional image data.""" 2 | 3 | import jax 4 | from jax import jit, value_and_grad 5 | import jax.random as random 6 | import jax.numpy as jnp 7 | from flax import serialization 8 | from functools import partial 9 | from diffusionjax.plot import plot_samples, plot_heatmap, plot_samples_1D 10 | from diffusionjax.utils import ( 11 | get_score, 12 | get_loss, 13 | get_sampler, 14 | get_times, 15 | get_exponential_sigma_function, 16 | ) 17 | from diffusionjax.solvers import EulerMaruyama, Annealed, Inpainted, Projected 18 | from diffusionjax.inverse_problems import get_pseudo_inverse_guidance 19 | from diffusionjax.sde import VE, ulangevin 20 | import numpy as np 21 | import flax.linen as nn 22 | import os 23 | 24 | # Dependencies: 25 | # This example requires mlkernels package, https://github.com/wesselb/mlkernels#installation 26 | from mlkernels import Matern52 27 | import lab as B 28 | 29 | # This example requires optax, https://optax.readthedocs.io/en/latest/ 30 | import optax 31 | 32 | 33 | x_max = 5.0 34 | epsilon = 1e-4 35 | 36 | 37 | # Initialize the optimizer 38 | optimizer = optax.adam(1e-3) 39 | 40 | 41 | class CNN(nn.Module): 42 | @nn.compact 43 | def __call__(self, x, t): 44 | x_shape = x.shape 45 | ndim = x.ndim 46 | 47 | n_hidden = x_shape[1] 48 | n_time_channels = 1 49 | 50 | t = t.reshape((t.shape[0], -1)) 51 | t = jnp.concatenate([t - 0.5, jnp.cos(2 * jnp.pi * t)], axis=-1) 52 | t = nn.Dense(n_hidden**2 * n_time_channels)(t) 53 | t = nn.relu(t) 54 | t = nn.Dense(n_hidden**2 * n_time_channels)(t) 55 | t = nn.relu(t) 56 | t = t.reshape(t.shape[0], n_hidden, n_hidden, n_time_channels) 57 | # Add time as another channel 58 | x = jnp.concatenate((x, t), axis=-1) 59 | # A single convolution layer 60 | x = nn.Conv(x_shape[-1], kernel_size=(9,) * (ndim - 2))(x) 61 | return x 62 | 63 | 64 | @partial(jit, static_argnums=[4]) 65 | def update_step(params, rng, batch, opt_state, loss): 66 | """ 67 | Takes the gradient of the loss function and updates the model weights (params) using it. 68 | Args: 69 | params: the current weights of the model 70 | rng: random number generator from jax 71 | batch: a batch of samples from the training data, representing samples from \mu_\text{data}, shape (J, N) 72 | opt_state: the internal state of the optimizer 73 | loss: A loss function that can be used for score matching training. 74 | Returns: 75 | The value of the loss function (for metrics), 76 | the new params and the new optimizer state.
77 | """ 78 | val, grads = value_and_grad(loss)(params, rng, batch) 79 | updates, opt_state = optimizer.update(grads, opt_state) 80 | params = optax.apply_updates(params, updates) 81 | return val, params, opt_state 82 | 83 | 84 | def retrain_nn( 85 | update_step, num_epochs, step_rng, samples, params, opt_state, loss, batch_size=5 86 | ): 87 | train_size = samples.shape[0] 88 | batch_size = min(train_size, batch_size) 89 | steps_per_epoch = train_size // batch_size 90 | mean_losses = jnp.zeros((num_epochs, 1)) 91 | for i in range(num_epochs): 92 | rng, step_rng = random.split(step_rng) 93 | perms = random.permutation(step_rng, train_size) 94 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch 95 | perms = perms.reshape((steps_per_epoch, batch_size)) 96 | losses = jnp.zeros((jnp.shape(perms)[0], 1)) 97 | for j, perm in enumerate(perms): 98 | batch = samples[perm, :] 99 | rng, step_rng = random.split(rng) 100 | loss_eval, params, opt_state = update_step( 101 | params, step_rng, batch, opt_state, loss 102 | ) 103 | losses = losses.at[j].set(loss_eval) 104 | mean_loss = jnp.mean(losses, axis=0) 105 | mean_losses = mean_losses.at[i].set(mean_loss) 106 | if i % 10 == 0: 107 | print("Epoch {:d}, Loss {:.2f} ".format(i, mean_loss[0])) 108 | return params, opt_state, mean_losses 109 | 110 | 111 | def sample_image_rgb(rng, num_samples, image_size, kernel, num_channels): 112 | """Samples from a GMRF.""" 113 | x = np.linspace(-x_max, x_max, image_size) 114 | y = np.linspace(-x_max, x_max, image_size) 115 | xx, yy = np.meshgrid(x, y) 116 | xx = xx.reshape(image_size**2, 1) 117 | yy = yy.reshape(image_size**2, 1) 118 | z = np.hstack((xx, yy)) 119 | C = B.dense(kernel(z)) + epsilon * B.eye(image_size**2) 120 | u = random.multivariate_normal( 121 | rng, mean=jnp.zeros(xx.shape[0]), cov=C, shape=(num_samples, num_channels) 122 | ) 123 | u = u.transpose((0, 2, 1)) 124 | return u, C 125 | 126 | 127 | def main(): 128 | num_epochs = 200 129 | rng = random.PRNGKey(2023) 130 | rng, step_rng = random.split(rng, 2) 131 | num_samples = 144 132 | num_channels = 1 133 | image_size = 32 # image size 134 | num_steps = 1000 135 | 136 | # Get and handle image data 137 | samples, _ = sample_image_rgb( 138 | rng, 139 | num_samples=num_samples, 140 | image_size=image_size, 141 | kernel=Matern52(), 142 | num_channels=num_channels, 143 | ) # (num_samples, image_size**2, num_channels) 144 | plot_samples(samples[:64], image_size=image_size, num_channels=num_channels) 145 | samples = samples.reshape(-1, image_size, image_size, num_channels) 146 | plot_samples_1D(samples[:64, 0], image_size, x_max=x_max, fname="samples 1D") 147 | 148 | # Get sde model 149 | sigma = get_exponential_sigma_function(sigma_min=0.001, sigma_max=3.0) 150 | sde = VE(sigma) 151 | 152 | # Neural network training via score matching 153 | batch_size = 16 154 | score_model = CNN() 155 | 156 | # Initialize parameters 157 | params = score_model.init( 158 | step_rng, 159 | jnp.zeros((batch_size, image_size, image_size, num_channels)), 160 | jnp.ones((batch_size,)), 161 | ) 162 | 163 | # Initialize optimizer 164 | opt_state = optimizer.init(params) 165 | 166 | if not os.path.exists("/tmp/output2"): 167 | # Get loss function 168 | ts, _ = get_times(num_steps=num_steps) 169 | solver = EulerMaruyama(sde, ts=ts) 170 | loss = get_loss( 171 | sde, 172 | solver, 173 | score_model, 174 | score_scaling=True, 175 | likelihood_weighting=False, 176 | reduce_mean=True, 177 | ) 178 | 179 | # Train with score matching 180 | params, opt_state, _ = retrain_nn( 
181 | update_step=update_step, 182 | num_epochs=num_epochs, 183 | step_rng=step_rng, 184 | samples=samples, 185 | params=params, 186 | opt_state=opt_state, 187 | loss=loss, 188 | batch_size=batch_size, 189 | ) 190 | 191 | # Save params 192 | output = serialization.to_bytes(params) 193 | f = open("/tmp/output2", "wb") 194 | f.write(output) 195 | else: # Load pre-trained model parameters 196 | f = open("/tmp/output2", "rb") 197 | output = f.read() 198 | params = serialization.from_bytes(params, output) 199 | 200 | # Get trained score 201 | trained_score = get_score(sde, score_model, params, score_scaling=True) 202 | 203 | # Get the outer loop of a numerical solver, also known as "predictor" 204 | rsde = sde.reverse(trained_score) 205 | ts, _ = get_times(num_steps=num_steps) 206 | outer_solver = EulerMaruyama(rsde, ts) 207 | 208 | # Get the inner loop of a numerical solver, also known as "corrector" 209 | inner_solver = Annealed(rsde.correct(ulangevin), snr=0.01, ts=jnp.empty((2, 1))) 210 | 211 | # pmap across devices. pmap assumes devices are identical model. If this is not the case, 212 | # use the devices argument in pmap 213 | num_devices = jax.local_device_count() 214 | sampling_shape = (64, image_size, image_size, num_channels) 215 | sampler = jax.pmap( 216 | get_sampler( 217 | (sampling_shape[0] // num_devices,) + sampling_shape[1:], 218 | outer_solver, 219 | inner_solver, 220 | denoise=True, 221 | ), 222 | axis_name="batch", 223 | # devices = jax.devices()[:], 224 | ) 225 | rng, *sample_rng = random.split(rng, num_devices + 1) 226 | sample_rng = jnp.asarray(sample_rng) 227 | q_samples, _ = sampler(sample_rng) 228 | q_samples = q_samples.reshape(sampling_shape) 229 | plot_samples( 230 | q_samples, 231 | image_size=image_size, 232 | num_channels=num_channels, 233 | fname="samples trained score", 234 | ) 235 | plot_samples_1D( 236 | q_samples[:, 0], image_size, x_max=x_max, fname="samples 1D trained score" 237 | ) 238 | plot_heatmap( 239 | samples=q_samples[:, [0, 1], 0, 0], 240 | area_bounds=[-3.0, 3.0], 241 | fname="heatmap trained score", 242 | ) 243 | 244 | # Condition on one of the coordinates 245 | y = jnp.zeros(sampling_shape[1:]) 246 | y = y.at[[0, -1], [0, -1], 0].set([-1.0, 1.0]) 247 | mask = jnp.zeros(sampling_shape[1:], dtype=float) 248 | mask = mask.at[[0, -1], [0, -1], 0].set([1.0, 1.0]) 249 | 250 | # Get inpainting sampler 251 | sampler = get_sampler( 252 | sampling_shape, 253 | outer_solver, 254 | Inpainted(rsde, mask, y), 255 | stack_samples=False, 256 | denoise=True, 257 | ) 258 | q_samples, _ = sampler(rng) 259 | plot_samples_1D( 260 | q_samples[:, 0], image_size=image_size, x_max=x_max, fname="samples inpainted" 261 | ) 262 | # plot_samples(q_samples[:64], image_size=image_size, num_channels=num_channels, fname="samples inpainted") 263 | 264 | # Get projection sampler 265 | sampler = get_sampler( 266 | sampling_shape, 267 | outer_solver, 268 | Projected(rsde, mask, y, coeff=1e-2), 269 | stack_samples=False, 270 | denoise=True, 271 | ) 272 | q_samples, _ = sampler(rng) 273 | plot_samples_1D( 274 | q_samples[:, 0], image_size=image_size, x_max=x_max, fname="samples projected" 275 | ) 276 | # plot_samples(q_samples[:64], image_size=image_size, num_channels=num_channels, fname="samples projected") 277 | 278 | def observation_map(x): 279 | return mask * x 280 | 281 | # Get pseudo-inverse-guidance sampler 282 | sampler = get_sampler( 283 | sampling_shape, 284 | EulerMaruyama( 285 | rsde.guide(get_pseudo_inverse_guidance, observation_map, y, noise_std=1e-5) 286 | ), 287 | 
stack_samples=False, 288 | denoise=True, 289 | ) 290 | q_samples, _ = sampler(rng) 291 | q_samples = q_samples.reshape(sampling_shape) 292 | plot_samples_1D( 293 | q_samples[:, 0], image_size=image_size, x_max=x_max, fname="samples guided" 294 | ) 295 | # plot_samples(q_samples[:64], image_size=image_size, num_channels=num_channels, fname="samples guided") 296 | 297 | 298 | if __name__ == "__main__": 299 | main() 300 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_unused_configs = True 3 | files = diffusionjax 4 | ignore_missing_imports = True 5 | check_untyped_defs = True 6 | explicit_package_bases = True 7 | warn_unreachable = True 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=50.0", 4 | "setuptools_scm[toml]>=6.0", 5 | "setuptools_scm_git_archive", 6 | "wheel>=0.33", 7 | "numpy>=1.16", 8 | ] 9 | 10 | [tool.setuptools_scm] 11 | write_to = "diffusionjax/_version.py" 12 | -------------------------------------------------------------------------------- /readme_empirical_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_empirical_score.png -------------------------------------------------------------------------------- /readme_heatmap_bounded_perturbation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_heatmap_bounded_perturbation.png -------------------------------------------------------------------------------- /readme_heatmap_empirical_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_heatmap_empirical_score.png -------------------------------------------------------------------------------- /readme_heatmap_inpainted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_heatmap_inpainted.png -------------------------------------------------------------------------------- /readme_heatmap_trained_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_heatmap_trained_score.png -------------------------------------------------------------------------------- /readme_nplan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_nplan.png -------------------------------------------------------------------------------- /readme_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_samples.png -------------------------------------------------------------------------------- /readme_trained_score.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bb515/diffusionjax/24fb2c8ee0ca85f618caf797c24614d2ee686be4/readme_trained_score.png -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | tab-size = 2 2 | 3 | select = [ 4 | "F", 5 | "W6", 6 | # "E71", 7 | # "E72", 8 | # "E112", 9 | # "E113", 10 | # "E124", 11 | # "E203", 12 | # "E272", 13 | # "E303", 14 | # "E304", 15 | # "E502", 16 | # "E702", 17 | # "E703", 18 | # "E731", 19 | "W191", 20 | "UP039", # unnecessary-class-parentheses 21 | ] 22 | 23 | # ignore= ["F722", # 'Syntax error in forward annotation' seems incompatible with jaxtyping syntax 24 | # "F821", # 'Undefined name' seems incompatible with jaxtyping syntax 25 | # ] 26 | 27 | exclude = [ 28 | "examples/", 29 | "configs/", 30 | "wandb/", 31 | ] 32 | 33 | [per-file-ignores] 34 | "test/*" = [ 35 | "F401", 36 | "F403", 37 | "F405", 38 | "F541", 39 | "E722", 40 | "E731", 41 | "F821", 42 | "F841", 43 | ] 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup script for diffusionjax. 3 | 4 | This setup is required or else 5 | >> ModuleNotFoundError: No module named 'diffusionjax' 6 | will occur. 7 | """ 8 | from setuptools import setup, find_packages 9 | import pathlib 10 | 11 | 12 | # The directory containing this file 13 | HERE = pathlib.Path(__file__).parent 14 | 15 | # The text of the README file 16 | README = (HERE / "README.md").read_text() 17 | 18 | # The text of the LICENSE file 19 | LICENSE = (HERE / "LICENSE.rst").read_text() 20 | 21 | setup( 22 | name="diffusionjax", 23 | # python_requires=">=3.8", 24 | description="diffusionjax is a simple and accessible diffusion models package in JAX", 25 | long_description=README, 26 | long_description_content_type="text/markdown", 27 | url="https://github.com/bb515/diffusionjax", 28 | author="Benjamin Boys and Jakiw Pidstrigach", 29 | license="MIT", 30 | license_file=LICENSE, 31 | packages=find_packages(exclude=["*.test"]), 32 | install_requires=[ 33 | "numpy", 34 | "scipy", 35 | "matplotlib", 36 | "flax", 37 | "ml_collections", 38 | "tqdm", 39 | "absl-py", 40 | "wandb", 41 | ], 42 | extras_require={ 43 | 'linting': [ 44 | "flake8", 45 | "pylint", 46 | "mypy", 47 | "typing-extensions", 48 | "pre-commit", 49 | "ruff", 50 | 'jaxtyping', 51 | ], 52 | 'testing': [ 53 | "optax", 54 | "orbax-checkpoint", 55 | "torch", 56 | "pytest", 57 | "pytest-xdist", 58 | "pytest-cov", 59 | "coveralls", 60 | "jax>=0.4.1", 61 | "jaxlib>=0.4.1", 62 | "setuptools_scm[toml]", 63 | "setuptools_scm_git_archive", 64 | ], 65 | 'examples': [ 66 | "optax", 67 | "orbax-checkpoint", 68 | "torch", 69 | "mlkernels", 70 | ], 71 | }, 72 | include_package_data=True) 73 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | """tests.""" 2 | -------------------------------------------------------------------------------- /test/external/test_benchmark.py: -------------------------------------------------------------------------------- 1 | """Benchmark of example.py training time, sample time and regression test of summary statistic of samples.""" 2 | 3 | import pytest 4 | import jax 5 | from jax import vmap 6 | import 
jax.random as random 7 | import jax.numpy as jnp 8 | from diffusionjax.run_lib import train, get_solver 9 | from diffusionjax.utils import get_score, get_sampler 10 | import diffusionjax.sde as sde_lib 11 | from absl import app, flags 12 | from ml_collections.config_flags import config_flags 13 | import time 14 | 15 | # Dependencies: 16 | # This test requires optax, https://optax.readthedocs.io/en/latest/ 17 | # This test requires orbax, https://orbax.readthedocs.io/en/latest/ 18 | # This test requires torch[cpu], https://pytorch.org/get-started/locally/ 19 | from torch.utils.data import Dataset 20 | import flax.linen as nn 21 | import numpy as np 22 | 23 | 24 | FLAGS = flags.FLAGS 25 | config_flags.DEFINE_config_file( 26 | "config", "./configs/example.py", "Training configuration.", lock_config=True 27 | ) 28 | flags.mark_flags_as_required(["config"]) 29 | 30 | 31 | class MLP(nn.Module): 32 | @nn.compact 33 | def __call__(self, x, t): 34 | x_shape = x.shape 35 | in_size = np.prod(x_shape[1:]) 36 | n_hidden = 256 37 | t = t.reshape((t.shape[0], -1)) 38 | x = x.reshape((x.shape[0], -1)) # flatten 39 | t = jnp.concatenate([t - 0.5, jnp.cos(2 * jnp.pi * t)], axis=-1) 40 | x = jnp.concatenate([x, t], axis=-1) 41 | x = nn.Dense(n_hidden)(x) 42 | x = nn.relu(x) 43 | x = nn.Dense(n_hidden)(x) 44 | x = nn.relu(x) 45 | x = nn.Dense(n_hidden)(x) 46 | x = nn.relu(x) 47 | x = nn.Dense(in_size)(x) 48 | return x.reshape(x_shape) 49 | 50 | 51 | class CircleDataset(Dataset): 52 | """Dataset containing samples from the circle.""" 53 | 54 | def __init__(self, num_samples): 55 | self.train_data = self.sample_circle(num_samples) 56 | 57 | def __len__(self): 58 | return self.train_data.shape[0] 59 | 60 | def __getitem__(self, idx): 61 | return self.train_data[idx] 62 | 63 | def sample_circle(self, num_samples): 64 | """Samples from the unit circle, angles split. 65 | 66 | Args: 67 | num_samples: The number of samples. 68 | 69 | Returns: 70 | An (num_samples, 2) array of samples. 71 | """ 72 | alphas = jnp.linspace(0, 2 * jnp.pi * (1 - 1 / num_samples), num_samples) 73 | xs = jnp.cos(alphas) 74 | ys = jnp.sin(alphas) 75 | samples = jnp.stack([xs, ys], axis=1) 76 | return samples 77 | 78 | def metric_names(self): 79 | return ["mean"] 80 | 81 | def calculate_metrics_batch(self, batch): 82 | return vmap(lambda x: jnp.mean(x, axis=0))(batch)[0, 0] 83 | 84 | def get_data_scaler(self, config): 85 | def data_scaler(x): 86 | return x / jnp.sqrt(2) 87 | 88 | return data_scaler 89 | 90 | def get_data_inverse_scaler(self, config): 91 | def data_inverse_scaler(x): 92 | return x * jnp.sqrt(2) 93 | 94 | return data_inverse_scaler 95 | 96 | 97 | def main(argv): 98 | config = FLAGS.config 99 | jax.default_device = jax.devices()[0] 100 | # Tip: use CUDA_VISIBLE_DEVICES to restrict the devices visible to jax 101 | # ... 
they must be all the same model of device for pmap to work 102 | num_devices = int(jax.local_device_count()) if config.training.pmap else 1 103 | rng = random.PRNGKey(config.seed) 104 | 105 | # Setup SDE 106 | if config.training.sde.lower() == "vpsde": 107 | from diffusionjax.utils import get_linear_beta_function 108 | 109 | beta, mean_coeff = get_linear_beta_function( 110 | config.model.beta_min, config.model.beta_max 111 | ) 112 | sde = sde_lib.VP(beta=beta, mean_coeff=mean_coeff) 113 | elif config.training.sde.lower() == "vesde": 114 | from diffusionjax.utils import get_exponential_sigma_function 115 | 116 | sigma = get_exponential_sigma_function( 117 | config.model.sigma_min, config.model.sigma_max 118 | ) 119 | sde = sde_lib.VE(sigma=sigma) 120 | else: 121 | raise NotImplementedError(f"SDE {config.training.sde} unknown.") 122 | 123 | # Build data iterators 124 | num_samples = 8 125 | dataset = CircleDataset(num_samples=num_samples) 126 | inverse_scaler = dataset.get_data_inverse_scaler(config) 127 | 128 | time_prev = time.time() 129 | params, _, mean_losses = train( 130 | (config.training.batch_size // jax.local_device_count(), config.data.image_size), 131 | config, 132 | MLP(), 133 | dataset, 134 | workdir=None, 135 | use_wandb=False, 136 | ) 137 | train_time_delta = time.time() - time_prev 138 | print("train time: {}s".format(train_time_delta)) 139 | expected_mean_loss = 0.4081565 140 | mean_loss = jnp.mean(mean_losses) 141 | import matplotlib.pyplot as plt 142 | 143 | plt.plot(mean_losses) 144 | plt.show() 145 | 146 | # Get trained score 147 | trained_score = get_score( 148 | sde, MLP(), params, score_scaling=config.training.score_scaling 149 | ) 150 | outer_solver, inner_solver = get_solver(config, sde, trained_score) 151 | sampler = get_sampler( 152 | (config.eval.batch_size // num_devices, config.data.image_size), 153 | outer_solver, 154 | inner_solver, 155 | denoise=config.sampling.denoise, 156 | inverse_scaler=inverse_scaler, 157 | ) 158 | 159 | if config.training.pmap: 160 | sampler = jax.pmap(sampler, axis_name="batch") 161 | rng, *sample_rng = random.split(rng, 1 + num_devices) 162 | sample_rng = jnp.asarray(sample_rng) 163 | else: 164 | rng, sample_rng = random.split(rng, 2) 165 | 166 | time_prev = time.time() 167 | q_samples, _ = sampler(sample_rng) 168 | sample_time_delta = time.time() - time_prev 169 | print("sample time: {}s".format(sample_time_delta)) 170 | q_samples = q_samples.reshape(config.eval.batch_size, config.data.image_size) 171 | plt.scatter(q_samples[:, 0], q_samples[:, 1]) 172 | plt.show() 173 | radii = jnp.linalg.norm(q_samples, axis=1) 174 | expected_mean_radii = 1.0236381 175 | mean_radii = jnp.mean(radii) 176 | expected_std_radii = 0.09904917 177 | std_radii = jnp.std(radii) 178 | 179 | # Regression 180 | print(mean_radii, expected_mean_radii, "mean radii") 181 | print(std_radii, expected_std_radii, "std radii") 182 | assert jnp.isclose(mean_radii, expected_mean_radii) 183 | assert jnp.isclose(std_radii, expected_std_radii) 184 | assert jnp.isclose( 185 | mean_loss, expected_mean_loss 186 | ), "average loss (got {}, expected {})".format(mean_loss, expected_mean_loss) 187 | 188 | 189 | if __name__ == "__main__": 190 | app.run(main) 191 | -------------------------------------------------------------------------------- /test/test_solvers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from diffusionjax.utils import ( 3 | get_times, 4 | get_karras_sigma_function, 5 | get_karras_gamma_function, 6 | ) 7 |
import jax.numpy as jnp 8 | import jax.random as random 9 | from diffusionjax.solvers import EDMHeun 10 | from diffusionjax.utils import get_sampler 11 | 12 | 13 | def test_karras_heun_sampler(): 14 | num_steps = 100 15 | sigma_min = 0.002 16 | sigma_max = 80 17 | rho = 7 18 | 19 | batch_size = 4 20 | image_size = 1 21 | sample_shape = (batch_size, image_size) 22 | 23 | s_churn = 100 24 | s_min = 10.0 25 | s_max = 60.0 26 | s_noise = 1 27 | 28 | # NOTE The default ts to use is `diffusionjax.utils.get_times(num_steps, t0=0.0)`. 29 | ts, _ = get_times(num_steps, t0=0.0) 30 | sigma = get_karras_sigma_function(sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) 31 | gamma = get_karras_gamma_function(num_steps, s_churn=s_churn, s_min=s_min, s_max=s_max) 32 | 33 | step_indices = jnp.arange(num_steps) 34 | t_steps = ( 35 | sigma_max ** (1 / rho) 36 | + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) 37 | ) ** rho 38 | 39 | # t_N = 0 40 | t_steps = jnp.append(t_steps, jnp.zeros_like(t_steps[:1])) 41 | 42 | def denoise(x_hat, t_hat): 43 | return x_hat 44 | 45 | def edm_sampler(rng, denoise, t_steps, num_steps): 46 | # Main sampling loop. 47 | rng, step_rng = random.split(rng) 48 | noise = random.normal(step_rng, sample_shape) 49 | x_next = noise * t_steps[0] 50 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 51 | rng, step_rng = random.split(rng) 52 | x_cur = x_next 53 | 54 | # Increase noise temporarily. 55 | if s_churn > 0 and s_min <= t_cur <= s_max: 56 | gamma = min(s_churn / num_steps, jnp.sqrt(2) - 1) 57 | t_hat = t_cur + gamma * t_cur 58 | x_hat = x_cur + jnp.sqrt(t_hat ** 2 - t_cur ** 2) * s_noise * random.normal(step_rng, x_cur.shape) 59 | else: 60 | t_hat = t_cur 61 | x_hat = x_cur 62 | 63 | dt = t_next - t_hat 64 | 65 | # Euler step. 66 | d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat 67 | x_next = x_hat + dt * d_cur 68 | 69 | # Apply 2nd order correction. 
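# Heun correction: d_cur is the probability-flow slope at t_hat, and d_prime re-evaluates the
# slope at t_next using the Euler prediction x_next; averaging the two gives a locally
# second-order step, at the cost of a second denoise() call per iteration.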
70 | if i < num_steps - 1: 71 | d_prime = (x_next - denoise(x_next, t_next)) / t_next 72 | # note that this is not necessarily the same x_next as the one used to calculate d_cur at the next step 73 | # so there is actually no need to carry the score vector across 74 | x_next = x_hat + dt * (0.5 * d_cur + 0.5 * d_prime) 75 | 76 | return x_next 77 | 78 | solver = EDMHeun( 79 | denoise=denoise, sigma=sigma, gamma=gamma, ts=ts, s_noise=s_noise) 80 | 81 | sampler = get_sampler( 82 | (batch_size, image_size), 83 | solver, 84 | stack_samples=False, 85 | ) 86 | 87 | rng0 = random.PRNGKey(2023) 88 | 89 | x_expected = edm_sampler(rng0, denoise, t_steps, num_steps) 90 | x_actual, no_function_evaluations = sampler(rng0) 91 | # TODO: number_function_evaluations is incorrect and needs to be multiplied by two due to Heun step 92 | assert jnp.allclose(x_expected, x_actual) 93 | 94 | 95 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from diffusionjax.utils import ( 3 | batch_mul, 4 | get_times, 5 | get_linear_beta_function, 6 | get_timestep, 7 | continuous_to_discrete, 8 | get_exponential_sigma_function, 9 | get_karras_sigma_function, 10 | get_karras_gamma_function, 11 | ) 12 | import jax.numpy as jnp 13 | from jax import vmap 14 | 15 | 16 | def test_batch_mul(): 17 | """Placeholder test for `:meth:batch_mul` to test CI""" 18 | a = jnp.ones((2,)) * 2.0 19 | bs = [jnp.zeros((2,)), jnp.ones((2,)), jnp.ones((2,)) * jnp.pi] 20 | c_expecteds = [jnp.zeros((2,)), 2.0 * jnp.ones((2,)), 2.0 * jnp.ones((2,)) * jnp.pi] 21 | for i, b in enumerate(bs): 22 | c = batch_mul(a, b) 23 | assert jnp.allclose(c, c_expecteds[i]) 24 | 25 | 26 | def test_continuous_discrete_equivalence_linear_beta_schedule(): 27 | beta_min = 0.1 28 | beta_max = 20.0 29 | num_steps = 1000 30 | # https://github.com/yang-song/score_sde/blob/0acb9e0ea3b8cccd935068cd9c657318fbc6ce4c/sde_lib.py#L127 31 | # expected_discrete_betas = jnp.linspace(beta_min / num_steps, beta_max / num_steps, num_steps) # I think this is incorrect unless training in discrete time 32 | ts, dt = get_times(num_steps) 33 | beta, _ = get_linear_beta_function(beta_min=0.1, beta_max=20.0) 34 | actual_discrete_betas = continuous_to_discrete(vmap(beta)(ts), dt) 35 | expected_discrete_betas = dt * (beta_min + ts * (beta_max - beta_min)) 36 | assert jnp.allclose(expected_discrete_betas, actual_discrete_betas) 37 | 38 | 39 | def test_exponential_sigma_schedule(): 40 | num_steps = 1000 41 | sigma_min = 0.01 42 | sigma_max = 378.0 43 | ts, dt = get_times(num_steps) 44 | sigma = get_exponential_sigma_function(sigma_min=sigma_min, sigma_max=sigma_max) 45 | actual_discrete_sigmas = vmap(sigma)(ts) 46 | # https://github.com/yang-song/score_sde/blob/0acb9e0ea3b8cccd935068cd9c657318fbc6ce4c/sde_lib.py#L222 47 | # expected_sigmas = jnp.exp( # I think this is wrong 48 | # jnp.linspace(jnp.log(sigma_min), 49 | # jnp.log(sigma_max), 50 | # num_steps)) 51 | # 52 | ts, _ = get_times(num_steps, dt) 53 | expected_discrete_sigmas = jnp.exp( 54 | jnp.log(sigma_min) + ts * (jnp.log(sigma_max) - jnp.log(sigma_min)) 55 | ) 56 | 57 | assert jnp.allclose(expected_discrete_sigmas, actual_discrete_sigmas) 58 | 59 | 60 | def test_karras_sigma_schedule(): 61 | num_steps = 1000 62 | sigma_min = 0.002 63 | sigma_max = 80 64 | rho = 7 65 | 66 | # NOTE The default ts to use is `diffusionjax.utils.get_times(num_steps, t0=0.0)`. 
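# The reference schedule built below is sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho
# for i = 0, ..., N-1, flipped so that index 0 holds sigma_min; the test therefore expects
# get_karras_sigma_function to return a sigma(t) that increases from sigma_min towards sigma_max as t runs from 0 to 1.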
67 | ts, _ = get_times(num_steps, t0=0.0) 68 | ts = ts.flatten() 69 | sigma = get_karras_sigma_function( 70 | sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) 71 | actual_discrete_sigmas = vmap(sigma)(ts) 72 | 73 | step_indices = jnp.arange(num_steps) 74 | expected_discrete_sigmas = ( 75 | sigma_max ** (1 / rho) 76 | + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) 77 | ) ** rho 78 | expected_discrete_sigmas = jnp.flip(expected_discrete_sigmas) 79 | assert jnp.allclose(expected_discrete_sigmas, actual_discrete_sigmas) 80 | 81 | 82 | def test_get_timestep_continuous(): 83 | def unit(ts): 84 | t0 = ts[0] 85 | t1 = ts[-1] 86 | t = ts[0] 87 | num_steps = jnp.size(ts) 88 | timestep = get_timestep(t, t0, t1, num_steps) 89 | assert timestep == 0 90 | 91 | t = ts[-1] 92 | timestep = get_timestep(t, t0, t1, num_steps) 93 | assert timestep == num_steps - 1 94 | 95 | t = ts[num_steps - num_steps // 2] 96 | timestep = get_timestep(t, t0, t1, num_steps) 97 | assert timestep == num_steps - num_steps // 2 98 | 99 | ts, dt = get_times() 100 | ts = ts.flatten() 101 | assert jnp.size(ts) == 1000 102 | assert jnp.isclose(ts[1] - ts[0], 0.001) 103 | assert jnp.isclose(ts[1] - ts[0], dt) 104 | assert ts[0] == 0.001 105 | assert ts[-1] == 1.0 106 | unit(ts) 107 | 108 | ts, dt = get_times(dt=0.1) 109 | ts = ts.flatten() 110 | assert jnp.size(ts) == 1000 111 | assert jnp.isclose(ts[1] - ts[0], 0.1) 112 | assert jnp.isclose(ts[1] - ts[0], dt) 113 | assert ts[0] == 0.1 114 | assert ts[-1] == 0.1 * 1000 115 | unit(ts) 116 | 117 | ts, dt = get_times(t0=0.01) 118 | ts = ts.flatten() 119 | assert jnp.size(ts) == 1000 120 | assert jnp.isclose(ts[1] - ts[0], (1.0 - 0.01) / (1000 - 1)) 121 | assert jnp.isclose(ts[1] - ts[0], dt) 122 | assert ts[0] == 0.01 123 | assert ts[-1] == 1.0 124 | unit(ts) 125 | 126 | ts, dt = get_times(dt=0.1, t0=0.01) 127 | ts = ts.flatten() 128 | assert jnp.size(ts) == 1000 129 | assert jnp.isclose(ts[1] - ts[0], 0.1) 130 | assert jnp.isclose(ts[1] - ts[0], dt) 131 | assert ts[0] == 0.01 132 | assert ts[-1] == 0.1 * (1000 - 1) + 0.01 133 | unit(ts) 134 | 135 | ts, dt = get_times(num_steps=100, dt=0.1, t0=0.01) 136 | ts = ts.flatten() 137 | assert jnp.size(ts) == 100 138 | assert jnp.isclose(ts[1] - ts[0], 0.1) 139 | assert jnp.isclose(ts[1] - ts[0], dt) 140 | assert ts[0] == 0.01 141 | assert ts[-1] == 0.1 * (100 - 1) + 0.01 142 | unit(ts) 143 | 144 | # Catch any rounding errors for low number of steps 145 | 146 | ts, dt = get_times(num_steps=10) 147 | ts = ts.flatten() 148 | assert jnp.size(ts) == 10 149 | assert ts[1] - ts[0] == 0.1 150 | assert jnp.isclose(ts[1] - ts[0], dt) 151 | assert ts[0] == 0.1 152 | assert ts[-1] == 1.0 153 | unit(ts) 154 | 155 | ts, dt = get_times(dt=0.05, num_steps=10) 156 | ts = ts.flatten() 157 | assert jnp.size(ts) == 10 158 | assert ts[1] - ts[0] == 0.05 159 | assert jnp.isclose(ts[1] - ts[0], dt) 160 | assert ts[0] == 0.05 161 | assert ts[-1] == 0.05 * 10 162 | unit(ts) 163 | 164 | ts, dt = get_times(t0=0.01, num_steps=10) 165 | ts = ts.flatten() 166 | assert jnp.size(ts) == 10 167 | assert jnp.isclose(ts[1] - ts[0], (1.0 - 0.01) / (10 - 1)) 168 | assert jnp.isclose(ts[1] - ts[0], dt) 169 | assert ts[0] == 0.01 170 | assert ts[-1] == 1.0 171 | unit(ts) 172 | 173 | ts, dt = get_times(dt=0.1, t0=0.01, num_steps=10) 174 | ts = ts.flatten() 175 | assert jnp.size(ts) == 10 176 | assert ts[1] - ts[0] == 0.1 177 | assert jnp.isclose(ts[1] - ts[0], dt) 178 | assert ts[0] == 0.01 179 | assert ts[-1] == 0.1 * (10 - 1) + 0.01 180 | 
unit(ts) 181 | 182 | 183 | def test_karras_gamma_function(s_churn=50.0, s_min=10.0, s_max=float('inf')): 184 | num_steps = 100 185 | ts, _ = get_times(num_steps, t0=0.0) 186 | ts = ts.flatten() 187 | sigma = get_karras_sigma_function(sigma_min=0.01, sigma_max=80.0, rho=7) 188 | sigmas = vmap(sigma)(ts) 189 | gamma = get_karras_gamma_function(num_steps, s_churn=s_churn, s_min=s_min, s_max=s_max) 190 | gammas = gamma(sigmas) 191 | for gamma_actual, sigma_actual in zip(gammas, sigmas): 192 | gamma_expected = min(s_churn / num_steps, jnp.sqrt(2) - 1) if s_min <= sigma_actual <= s_max else 0.0 193 | assert gamma_actual == gamma_expected 194 | 195 | 196 | if __name__ == "__main__": 197 | test_karras_gamma_function() 198 | --------------------------------------------------------------------------------