├── denoising_diffusion_pytorch
│   ├── version.py
│   ├── __init__.py
│   ├── weighted_objective_gaussian_diffusion.py
│   ├── learned_gaussian_diffusion.py
│   ├── v_param_continuous_time_gaussian_diffusion.py
│   ├── elucidated_diffusion.py
│   ├── continuous_time_gaussian_diffusion.py
│   ├── denoising_diffusion_pytorch_1d.py
│   ├── classifier_free_guidance.py
│   └── denoising_diffusion_pytorch.py
├── images
│   ├── sample.png
│   └── denoising-diffusion.png
├── .github
│   └── workflows
│       └── python-publish.yml
├── setup.py
├── LICENSE
├── .gitignore
└── README.md
/denoising_diffusion_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.0.6'
2 | 
--------------------------------------------------------------------------------
/images/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/denoising-diffusion-pytorch/main/images/sample.png
--------------------------------------------------------------------------------
/images/denoising-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/denoising-diffusion-pytorch/main/images/denoising-diffusion.png
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 | 
3 | from denoising_diffusion_pytorch.learned_gaussian_diffusion import LearnedGaussianDiffusion
4 | from denoising_diffusion_pytorch.continuous_time_gaussian_diffusion import ContinuousTimeGaussianDiffusion
5 | from denoising_diffusion_pytorch.weighted_objective_gaussian_diffusion import WeightedObjectiveGaussianDiffusion
6 | from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion
7 | from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion
8 | 
9 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D
10 | 
11 | 
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 | 
4 | name: Upload Python Package
5 | 
6 | on:
7 |   release:
8 |     types: [created]
9 | 
10 | jobs:
11 |   deploy:
12 | 
13 |     runs-on: ubuntu-latest
14 | 
15 |     steps:
16 |     - uses: actions/checkout@v2
17 |     - name: Set up Python
18 |       uses: actions/setup-python@v2
19 |       with:
20 |         python-version: '3.x'
21 |     - name: Install dependencies
22 |       run: |
23 |         python -m pip install --upgrade pip
24 |         pip install setuptools wheel twine
25 |     - name: Build and publish
26 |       env:
27 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 |       run: |
30 |         python setup.py sdist bdist_wheel
31 |         twine upload dist/*
32 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | 
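# executing version.py defines __version__ in this namespace, a common
# single-source-of-version pattern that avoids importing the not-yet-installed package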
exec(open('denoising_diffusion_pytorch/version.py').read()) 4 | 5 | setup( 6 | name = 'denoising-diffusion-pytorch', 7 | packages = find_packages(), 8 | version = __version__, 9 | license='MIT', 10 | description = 'Denoising Diffusion Probabilistic Models - Pytorch', 11 | author = 'Phil Wang', 12 | author_email = 'lucidrains@gmail.com', 13 | url = 'https://github.com/lucidrains/denoising-diffusion-pytorch', 14 | long_description_content_type = 'text/markdown', 15 | keywords = [ 16 | 'artificial intelligence', 17 | 'generative models' 18 | ], 19 | install_requires=[ 20 | 'accelerate', 21 | 'einops', 22 | 'ema-pytorch', 23 | 'pillow', 24 | 'torch', 25 | 'torchvision', 26 | 'tqdm' 27 | ], 28 | classifiers=[ 29 | 'Development Status :: 4 - Beta', 30 | 'Intended Audience :: Developers', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python :: 3.6', 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/weighted_objective_gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from inspect import isfunction 3 | from torch import nn, einsum 4 | from einops import rearrange 5 | 6 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion 7 | 8 | # helper functions 9 | 10 | def exists(x): 11 | return x is not None 12 | 13 | def default(val, d): 14 | if exists(val): 15 | return val 16 | return d() if isfunction(d) else d 17 | 18 | # some improvisation on my end 19 | # where i have the model learn to both predict noise and x0 20 | # and learn the weighted sum for each depending on time step 21 | 22 | class WeightedObjectiveGaussianDiffusion(GaussianDiffusion): 23 | def __init__( 24 | self, 25 | model, 26 | *args, 27 | pred_noise_loss_weight = 0.1, 28 | pred_x_start_loss_weight = 0.1, 29 | **kwargs 30 | ): 31 | super().__init__(model, *args, **kwargs) 32 | channels = model.channels 33 | assert model.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8' 34 | assert not model.self_condition, 'not supported yet' 35 | assert not self.is_ddim_sampling, 'ddim sampling cannot be used' 36 | 37 | self.split_dims = (channels, channels, 2) 38 | self.pred_noise_loss_weight = pred_noise_loss_weight 39 | self.pred_x_start_loss_weight = pred_x_start_loss_weight 40 | 41 | def p_mean_variance(self, *, x, t, clip_denoised, model_output = None): 42 | model_output = self.model(x, t) 43 | 
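        # per the assert in __init__, out_dim is (channels * 2 + 2); the output thus
        # packs predicted noise, predicted x_start, and two logits along dim 1,
        # with the logits softmaxed below into weights over the two x_start estimates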
44 | pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1) 45 | normalized_weights = weights.softmax(dim = 1) 46 | 47 | x_start_from_noise = self.predict_start_from_noise(x, t = t, noise = pred_noise) 48 | 49 | x_starts = torch.stack((x_start_from_noise, pred_x_start), dim = 1) 50 | weighted_x_start = einsum('b j h w, b j c h w -> b c h w', normalized_weights, x_starts) 51 | 52 | if clip_denoised: 53 | weighted_x_start.clamp_(-1., 1.) 54 | 55 | model_mean, model_variance, model_log_variance = self.q_posterior(weighted_x_start, x, t) 56 | 57 | return model_mean, model_variance, model_log_variance 58 | 59 | def p_losses(self, x_start, t, noise = None, clip_denoised = False): 60 | noise = default(noise, lambda: torch.randn_like(x_start)) 61 | x_t = self.q_sample(x_start = x_start, t = t, noise = noise) 62 | 63 | model_output = self.model(x_t, t) 64 | pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1) 65 | 66 | # get loss for predicted noise and x_start 67 | # with the loss weight given at initialization 68 | 69 | noise_loss = self.loss_fn(noise, pred_noise) * self.pred_noise_loss_weight 70 | x_start_loss = self.loss_fn(x_start, pred_x_start) * self.pred_x_start_loss_weight 71 | 72 | # calculate x_start from predicted noise 73 | # then do a weighted sum of the x_start prediction, weights also predicted by the model (softmax normalized) 74 | 75 | x_start_from_pred_noise = self.predict_start_from_noise(x_t, t, pred_noise) 76 | x_start_from_pred_noise = x_start_from_pred_noise.clamp(-2., 2.) 77 | weighted_x_start = einsum('b j h w, b j c h w -> b c h w', weights.softmax(dim = 1), torch.stack((x_start_from_pred_noise, pred_x_start), dim = 1)) 78 | 79 | # main loss to x_start with the weighted one 80 | 81 | weighted_x_start_loss = self.loss_fn(x_start, weighted_x_start) 82 | return weighted_x_start_loss + x_start_loss + noise_loss 83 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/learned_gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import namedtuple 3 | from math import pi, sqrt, log as ln 4 | from inspect import isfunction 5 | from torch import nn, einsum 6 | from einops import rearrange 7 | 8 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one 9 | 10 | # constants 11 | 12 | NAT = 1. / ln(2) 13 | 14 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance']) 15 | 16 | # helper functions 17 | 18 | def exists(x): 19 | return x is not None 20 | 21 | def default(val, d): 22 | if exists(val): 23 | return val 24 | return d() if isfunction(d) else d 25 | 26 | # tensor helpers 27 | 28 | def log(t, eps = 1e-15): 29 | return torch.log(t.clamp(min = eps)) 30 | 31 | def meanflat(x): 32 | return x.mean(dim = tuple(range(1, len(x.shape)))) 33 | 34 | def normal_kl(mean1, logvar1, mean2, logvar2): 35 | """ 36 | KL divergence between normal distributions parameterized by mean and log-variance. 
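    in closed form (matching the return below):
    KL = 0.5 * (logvar2 - logvar1 - 1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2))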
37 | """ 38 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)) 39 | 40 | def approx_standard_normal_cdf(x): 41 | return 0.5 * (1.0 + torch.tanh(sqrt(2.0 / pi) * (x + 0.044715 * (x ** 3)))) 42 | 43 | def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999): 44 | assert x.shape == means.shape == log_scales.shape 45 | 46 | centered_x = x - means 47 | inv_stdv = torch.exp(-log_scales) 48 | plus_in = inv_stdv * (centered_x + 1. / 255.) 49 | cdf_plus = approx_standard_normal_cdf(plus_in) 50 | min_in = inv_stdv * (centered_x - 1. / 255.) 51 | cdf_min = approx_standard_normal_cdf(min_in) 52 | log_cdf_plus = log(cdf_plus) 53 | log_one_minus_cdf_min = log(1. - cdf_min) 54 | cdf_delta = cdf_plus - cdf_min 55 | 56 | log_probs = torch.where(x < -thres, 57 | log_cdf_plus, 58 | torch.where(x > thres, 59 | log_one_minus_cdf_min, 60 | log(cdf_delta))) 61 | 62 | return log_probs 63 | 64 | # https://arxiv.org/abs/2102.09672 65 | 66 | # i thought the results were questionable, if one were to focus only on FID 67 | # but may as well get this in here for others to try, as GLIDE is using it (and DALL-E2 first stage of cascade) 68 | # gaussian diffusion for learned variance + hybrid eps simple + vb loss 69 | 70 | class LearnedGaussianDiffusion(GaussianDiffusion): 71 | def __init__( 72 | self, 73 | model, 74 | vb_loss_weight = 0.001, # lambda was 0.001 in the paper 75 | *args, 76 | **kwargs 77 | ): 78 | super().__init__(model, *args, **kwargs) 79 | assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`' 80 | assert not model.self_condition, 'not supported yet' 81 | 82 | self.vb_loss_weight = vb_loss_weight 83 | 84 | def model_predictions(self, x, t): 85 | model_output = self.model(x, t) 86 | model_output, pred_variance = model_output.chunk(2, dim = 1) 87 | 88 | if self.objective == 'pred_noise': 89 | pred_noise = model_output 90 | x_start = self.predict_start_from_noise(x, t, model_output) 91 | 92 | elif self.objective == 'pred_x0': 93 | pred_noise = self.predict_noise_from_start(x, t, model_output) 94 | x_start = model_output 95 | 96 | return ModelPrediction(pred_noise, x_start, pred_variance) 97 | 98 | def p_mean_variance(self, *, x, t, clip_denoised, model_output = None): 99 | model_output = default(model_output, lambda: self.model(x, t)) 100 | pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1) 101 | 102 | min_log = extract(self.posterior_log_variance_clipped, t, x.shape) 103 | max_log = extract(torch.log(self.betas), t, x.shape) 104 | var_interp_frac = unnormalize_to_zero_to_one(var_interp_frac_unnormalized) 105 | 106 | model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log 107 | model_variance = model_log_variance.exp() 108 | 109 | x_start = self.predict_start_from_noise(x, t, pred_noise) 110 | 111 | if clip_denoised: 112 | x_start.clamp_(-1., 1.) 
113 | 114 | model_mean, _, _ = self.q_posterior(x_start, x, t) 115 | 116 | return model_mean, model_variance, model_log_variance 117 | 118 | def p_losses(self, x_start, t, noise = None, clip_denoised = False): 119 | noise = default(noise, lambda: torch.randn_like(x_start)) 120 | x_t = self.q_sample(x_start = x_start, t = t, noise = noise) 121 | 122 | # model output 123 | 124 | model_output = self.model(x_t, t) 125 | 126 | # calculating kl loss for learned variance (interpolation) 127 | 128 | true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t) 129 | model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output) 130 | 131 | # kl loss with detached model predicted mean, for stability reasons as in paper 132 | 133 | detached_model_mean = model_mean.detach() 134 | 135 | kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance) 136 | kl = meanflat(kl) * NAT 137 | 138 | decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance) 139 | decoder_nll = meanflat(decoder_nll) * NAT 140 | 141 | # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 142 | 143 | vb_losses = torch.where(t == 0, decoder_nll, kl) 144 | 145 | # simple loss - predicting noise, x0, or x_prev 146 | 147 | pred_noise, _ = model_output.chunk(2, dim = 1) 148 | 149 | simple_losses = self.loss_fn(pred_noise, noise) 150 | 151 | return simple_losses + vb_losses.mean() * self.vb_loss_weight 152 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/v_param_continuous_time_gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import sqrt 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | from torch.special import expm1 7 | 8 | from tqdm import tqdm 9 | from einops import rearrange, repeat, reduce 10 | from einops.layers.torch import Rearrange 11 | 12 | # helpers 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | def default(val, d): 18 | if exists(val): 19 | return val 20 | return d() if callable(d) else d 21 | 22 | # normalization functions 23 | 24 | def normalize_to_neg_one_to_one(img): 25 | return img * 2 - 1 26 | 27 | def unnormalize_to_zero_to_one(t): 28 | return (t + 1) * 0.5 29 | 30 | # diffusion helpers 31 | 32 | def right_pad_dims_to(x, t): 33 | padding_dims = x.ndim - t.ndim 34 | if padding_dims <= 0: 35 | return t 36 | return t.view(*t.shape, *((1,) * padding_dims)) 37 | 38 | # continuous schedules 39 | # log(snr) that approximates the original linear schedule 40 | 41 | def log(t, eps = 1e-20): 42 | return torch.log(t.clamp(min = eps)) 43 | 44 | def alpha_cosine_log_snr(t, s = 0.008): 45 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) 46 | 47 | class VParamContinuousTimeGaussianDiffusion(nn.Module): 48 | """ 49 | a new type of parameterization in v-space proposed in https://arxiv.org/abs/2202.00512 that 50 | (1) allows for improved distillation over noise prediction objective and 51 | (2) noted in imagen-video to improve upsampling unets by removing the color shifting artifacts 52 | """ 53 | 54 | def __init__( 55 | self, 56 | model, 57 | *, 58 | image_size, 59 | channels = 3, 60 | num_sample_steps = 500, 61 | clip_sample_denoised = True, 62 
| ):
63 |         super().__init__()
64 |         assert model.random_or_learned_sinusoidal_cond
65 |         assert not model.self_condition, 'not supported yet'
66 | 
67 |         self.model = model
68 | 
69 |         # image dimensions
70 | 
71 |         self.channels = channels
72 |         self.image_size = image_size
73 | 
74 |         # continuous noise schedule related stuff
75 | 
76 |         self.log_snr = alpha_cosine_log_snr
77 | 
78 |         # sampling
79 | 
80 |         self.num_sample_steps = num_sample_steps
81 |         self.clip_sample_denoised = clip_sample_denoised
82 | 
83 |     @property
84 |     def device(self):
85 |         return next(self.model.parameters()).device
86 | 
87 |     def p_mean_variance(self, x, time, time_next):
88 |         # reviewer found an error in the equation in the paper (missing sigma)
89 |         # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt
90 | 
91 |         log_snr = self.log_snr(time)
92 |         log_snr_next = self.log_snr(time_next)
93 |         c = -expm1(log_snr - log_snr_next)
94 | 
95 |         squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
96 |         squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
97 | 
98 |         alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))
99 | 
100 |         batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
101 | 
102 |         pred_v = self.model(x, batch_log_snr)
103 | 
104 |         # shown in Appendix D in the paper
105 |         x_start = alpha * x - sigma * pred_v
106 | 
107 |         if self.clip_sample_denoised:
108 |             x_start.clamp_(-1., 1.)
109 | 
110 |         model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)
111 | 
112 |         posterior_variance = squared_sigma_next * c
113 | 
114 |         return model_mean, posterior_variance
115 | 
116 |     # sampling related functions
117 | 
118 |     @torch.no_grad()
119 |     def p_sample(self, x, time, time_next):
120 |         batch, *_, device = *x.shape, x.device
121 | 
122 |         model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next)
123 | 
124 |         if time_next == 0:
125 |             return model_mean
126 | 
127 |         noise = torch.randn_like(x)
128 |         return model_mean + sqrt(model_variance) * noise
129 | 
130 |     @torch.no_grad()
131 |     def p_sample_loop(self, shape):
132 |         batch = shape[0]
133 | 
134 |         img = torch.randn(shape, device = self.device)
135 |         steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device)
136 | 
137 |         for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps):
138 |             times = steps[i]
139 |             times_next = steps[i + 1]
140 |             img = self.p_sample(img, times, times_next)
141 | 
142 |         img.clamp_(-1., 1.)
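        # sampling ran in the model's [-1, 1] space; map back to [0, 1] for output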
143 | img = unnormalize_to_zero_to_one(img) 144 | return img 145 | 146 | @torch.no_grad() 147 | def sample(self, batch_size = 16): 148 | return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size)) 149 | 150 | # training related functions - noise prediction 151 | 152 | def q_sample(self, x_start, times, noise = None): 153 | noise = default(noise, lambda: torch.randn_like(x_start)) 154 | 155 | log_snr = self.log_snr(times) 156 | 157 | log_snr_padded = right_pad_dims_to(x_start, log_snr) 158 | alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid()) 159 | x_noised = x_start * alpha + noise * sigma 160 | 161 | return x_noised, log_snr, alpha, sigma 162 | 163 | def random_times(self, batch_size): 164 | return torch.zeros((batch_size,), device = self.device).float().uniform_(0, 1) 165 | 166 | def p_losses(self, x_start, times, noise = None): 167 | noise = default(noise, lambda: torch.randn_like(x_start)) 168 | 169 | x, log_snr, alpha, sigma = self.q_sample(x_start = x_start, times = times, noise = noise) 170 | 171 | # described in section 4 as the prediction objective, with derivation in Appendix D 172 | v = alpha * noise - sigma * x_start 173 | 174 | model_out = self.model(x, log_snr) 175 | 176 | return F.mse_loss(model_out, v) 177 | 178 | def forward(self, img, *args, **kwargs): 179 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size 180 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 181 | 182 | times = self.random_times(b) 183 | img = normalize_to_neg_one_to_one(img) 184 | return self.p_losses(img, times, *args, **kwargs) 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Denoising Diffusion Probabilistic Model, in Pytorch 4 | 5 | Implementation of Denoising Diffusion Probabilistic Model in Pytorch. It is a new approach to generative modeling that may have the potential to rival GANs. It uses denoising score matching to estimate the gradient of the data distribution, followed by Langevin sampling to sample from the true distribution. 
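For intuition, here is a minimal, self-contained sketch (illustrative only, not part of this library; the toy `score` function is a stand-in for a trained network) of the Langevin sampling idea referenced above: repeated small gradient steps on the log-density, each perturbed with Gaussian noise, converge to samples from the target distribution.

```python
import torch

# toy target: a unit Gaussian, whose score (gradient of the log-density) is exactly -x.
# in a diffusion model, a trained denoising network supplies an estimate of this.
def score(x):
    return -x

x = torch.randn(1024) * 5.   # start far from the target distribution
step_size = 0.01

for _ in range(1000):
    noise = torch.randn_like(x)
    # Langevin update: half a gradient step on log p(x), plus injected noise
    x = x + 0.5 * step_size * score(x) + step_size ** 0.5 * noise

print(x.mean().item(), x.std().item())   # approaches mean 0, std 1
```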
6 | 
7 | This implementation was transcribed from the official TensorFlow version here
8 | 
9 | YouTube AI Educators - Yannic Kilcher | AI Coffeebreak with Letitia | Outlier
10 | 
11 | Flax implementation from YiYi Xu
12 | 
13 | Annotated code by Research Scientists / Engineers from 🤗 Huggingface
14 | 
15 | Update: Turns out none of the technicalities really matters at all | "Cold Diffusion" paper | Muse
16 | 
17 | 
18 | 
19 | [![PyPI version](https://badge.fury.io/py/denoising-diffusion-pytorch.svg)](https://badge.fury.io/py/denoising-diffusion-pytorch)
20 | 
21 | ## Install
22 | 
23 | ```bash
24 | $ pip install denoising_diffusion_pytorch
25 | ```
26 | 
27 | ## Usage
28 | 
29 | ```python
30 | import torch
31 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion
32 | 
33 | model = Unet(
34 |     dim = 64,
35 |     dim_mults = (1, 2, 4, 8)
36 | )
37 | 
38 | diffusion = GaussianDiffusion(
39 |     model,
40 |     image_size = 128,
41 |     timesteps = 1000,   # number of steps
42 |     loss_type = 'l1'    # L1 or L2
43 | )
44 | 
45 | training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
46 | loss = diffusion(training_images)
47 | loss.backward()
48 | # after a lot of training
49 | 
50 | sampled_images = diffusion.sample(batch_size = 4)
51 | sampled_images.shape # (4, 3, 128, 128)
52 | ```
53 | 
54 | Or, if you simply want to pass in a folder name and the desired image dimensions, you can use the `Trainer` class to easily train a model.
55 | 
56 | ```python
57 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
58 | 
59 | model = Unet(
60 |     dim = 64,
61 |     dim_mults = (1, 2, 4, 8)
62 | ).cuda()
63 | 
64 | diffusion = GaussianDiffusion(
65 |     model,
66 |     image_size = 128,
67 |     timesteps = 1000,           # number of steps
68 |     sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
69 |     loss_type = 'l1'            # L1 or L2
70 | ).cuda()
71 | 
72 | trainer = Trainer(
73 |     diffusion,
74 |     'path/to/your/images',
75 |     train_batch_size = 32,
76 |     train_lr = 8e-5,
77 |     train_num_steps = 700000,        # total training steps
78 |     gradient_accumulate_every = 2,   # gradient accumulation steps
79 |     ema_decay = 0.995,               # exponential moving average decay
80 |     amp = True                       # turn on mixed precision
81 | )
82 | 
83 | trainer.train()
84 | ```
85 | 
86 | Samples and model checkpoints will be logged to `./results` periodically.
87 | 
88 | ## Multi-GPU Training
89 | 
90 | The `Trainer` class is now equipped with 🤗 Accelerator. You can easily do multi-GPU training in two steps using their `accelerate` CLI.
91 | 
92 | At the project root directory, where the training script is, run
93 | 
94 | ```bash
95 | $ accelerate config
96 | ```
97 | 
98 | Then, in the same directory
99 | 
100 | ```bash
101 | $ accelerate launch train.py
102 | ```
103 | 
104 | ## Miscellaneous
105 | 
106 | ### 1D Sequence
107 | 
108 | By popular request, a 1D Unet + Gaussian Diffusion implementation.
You will have to write the training code yourself.
109 | 
110 | ```python
111 | import torch
112 | from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D
113 | 
114 | model = Unet1D(
115 |     dim = 64,
116 |     dim_mults = (1, 2, 4, 8),
117 |     channels = 32
118 | )
119 | 
120 | diffusion = GaussianDiffusion1D(
121 |     model,
122 |     seq_length = 128,
123 |     timesteps = 1000,
124 |     objective = 'pred_v'
125 | )
126 | 
127 | training_seq = torch.rand(8, 32, 128) # features are normalized from 0 to 1
128 | loss = diffusion(training_seq)
129 | loss.backward()
130 | 
131 | # after a lot of training
132 | 
133 | sampled_seq = diffusion.sample(batch_size = 4)
134 | sampled_seq.shape # (4, 32, 128)
135 | ```
136 | 
137 | ## Citations
138 | 
139 | ```bibtex
140 | @inproceedings{NEURIPS2020_4c5bcfec,
141 |   author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
142 |   booktitle = {Advances in Neural Information Processing Systems},
143 |   editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
144 |   pages = {6840--6851},
145 |   publisher = {Curran Associates, Inc.},
146 |   title = {Denoising Diffusion Probabilistic Models},
147 |   url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
148 |   volume = {33},
149 |   year = {2020}
150 | }
151 | ```
152 | 
153 | ```bibtex
154 | @InProceedings{pmlr-v139-nichol21a,
155 |   title = {Improved Denoising Diffusion Probabilistic Models},
156 |   author = {Nichol, Alexander Quinn and Dhariwal, Prafulla},
157 |   booktitle = {Proceedings of the 38th International Conference on Machine Learning},
158 |   pages = {8162--8171},
159 |   year = {2021},
160 |   editor = {Meila, Marina and Zhang, Tong},
161 |   volume = {139},
162 |   series = {Proceedings of Machine Learning Research},
163 |   month = {18--24 Jul},
164 |   publisher = {PMLR},
165 |   pdf = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf},
166 |   url = {https://proceedings.mlr.press/v139/nichol21a.html},
167 | }
168 | ```
169 | 
170 | ```bibtex
171 | @inproceedings{kingma2021on,
172 |   title = {On Density Estimation with Diffusion Models},
173 |   author = {Diederik P Kingma and Tim Salimans and Ben Poole and Jonathan Ho},
174 |   booktitle = {Advances in Neural Information Processing Systems},
175 |   editor = {A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
176 |   year = {2021},
177 |   url = {https://openreview.net/forum?id=2LdBqxc1Yv}
178 | }
179 | ```
180 | 
181 | ```bibtex
182 | @article{Choi2022PerceptionPT,
183 |   title = {Perception Prioritized Training of Diffusion Models},
184 |   author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J.
Kim and Sung-Hoon Yoon}, 185 | journal = {ArXiv}, 186 | year = {2022}, 187 | volume = {abs/2204.00227} 188 | } 189 | ``` 190 | 191 | ```bibtex 192 | @article{Karras2022ElucidatingTD, 193 | title = {Elucidating the Design Space of Diffusion-Based Generative Models}, 194 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine}, 195 | journal = {ArXiv}, 196 | year = {2022}, 197 | volume = {abs/2206.00364} 198 | } 199 | ``` 200 | 201 | ```bibtex 202 | @article{Song2021DenoisingDI, 203 | title = {Denoising Diffusion Implicit Models}, 204 | author = {Jiaming Song and Chenlin Meng and Stefano Ermon}, 205 | journal = {ArXiv}, 206 | year = {2021}, 207 | volume = {abs/2010.02502} 208 | } 209 | ``` 210 | 211 | ```bibtex 212 | @misc{chen2022analog, 213 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning}, 214 | author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton}, 215 | year = {2022}, 216 | eprint = {2208.04202}, 217 | archivePrefix = {arXiv}, 218 | primaryClass = {cs.CV} 219 | } 220 | ``` 221 | 222 | ```bibtex 223 | @article{Qiao2019WeightS, 224 | title = {Weight Standardization}, 225 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille}, 226 | journal = {ArXiv}, 227 | year = {2019}, 228 | volume = {abs/1903.10520} 229 | } 230 | ``` 231 | 232 | ```bibtex 233 | @article{Salimans2022ProgressiveDF, 234 | title = {Progressive Distillation for Fast Sampling of Diffusion Models}, 235 | author = {Tim Salimans and Jonathan Ho}, 236 | journal = {ArXiv}, 237 | year = {2022}, 238 | volume = {abs/2202.00512} 239 | } 240 | ``` 241 | 242 | ```bibtex 243 | @article{Ho2022ClassifierFreeDG, 244 | title = {Classifier-Free Diffusion Guidance}, 245 | author = {Jonathan Ho}, 246 | journal = {ArXiv}, 247 | year = {2022}, 248 | volume = {abs/2207.12598} 249 | } 250 | ``` 251 | 252 | ```bibtex 253 | @article{Sunkara2022NoMS, 254 | title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects}, 255 | author = {Raja Sunkara and Tie Luo}, 256 | journal = {ArXiv}, 257 | year = {2022}, 258 | volume = {abs/2208.03641} 259 | } 260 | ``` 261 | 262 | ```bibtex 263 | @inproceedings{Jabri2022ScalableAC, 264 | title = {Scalable Adaptive Computation for Iterative Generation}, 265 | author = {A. Jabri and David J. 
Fleet and Ting Chen},
266 |   year = {2022}
267 | }
268 | ```
269 | 
270 | ```bibtex
271 | @article{Cheng2022DPMSolverPlusPlus,
272 |   title = {DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models},
273 |   author = {Cheng Lu and Yuhao Zhou and Fan Bao and Jianfei Chen and Chongxuan Li and Jun Zhu},
274 |   journal = {NeurIPS 2022 Oral},
275 |   year = {2022},
276 |   volume = {abs/2211.01095}
277 | }
278 | ```
279 | 
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/elucidated_diffusion.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from random import random
3 | import torch
4 | from torch import nn, einsum
5 | import torch.nn.functional as F
6 | 
7 | from tqdm import tqdm
8 | from einops import rearrange, repeat, reduce
9 | 
10 | # helpers
11 | 
12 | def exists(val):
13 |     return val is not None
14 | 
15 | def default(val, d):
16 |     if exists(val):
17 |         return val
18 |     return d() if callable(d) else d
19 | 
20 | # tensor helpers
21 | 
22 | def log(t, eps = 1e-20):
23 |     return torch.log(t.clamp(min = eps))
24 | 
25 | # normalization functions
26 | 
27 | def normalize_to_neg_one_to_one(img):
28 |     return img * 2 - 1
29 | 
30 | def unnormalize_to_zero_to_one(t):
31 |     return (t + 1) * 0.5
32 | 
33 | # main class
34 | 
35 | class ElucidatedDiffusion(nn.Module):
36 |     def __init__(
37 |         self,
38 |         net,
39 |         *,
40 |         image_size,
41 |         channels = 3,
42 |         num_sample_steps = 32, # number of sampling steps
43 |         sigma_min = 0.002,     # min noise level
44 |         sigma_max = 80,        # max noise level
45 |         sigma_data = 0.5,      # standard deviation of data distribution
46 |         rho = 7,               # controls the sampling schedule
47 |         P_mean = -1.2,         # mean of log-normal distribution from which noise is drawn for training
48 |         P_std = 1.2,           # standard deviation of log-normal distribution from which noise is drawn for training
49 |         S_churn = 80,          # parameters for stochastic sampling - depends on dataset, Table 5 in paper
50 |         S_tmin = 0.05,
51 |         S_tmax = 50,
52 |         S_noise = 1.003,
53 |     ):
54 |         super().__init__()
55 |         assert net.random_or_learned_sinusoidal_cond
56 |         self.self_condition = net.self_condition
57 | 
58 |         self.net = net
59 | 
60 |         # image dimensions
61 | 
62 |         self.channels = channels
63 |         self.image_size = image_size
64 | 
65 |         # parameters
66 | 
67 |         self.sigma_min = sigma_min
68 |         self.sigma_max = sigma_max
69 |         self.sigma_data = sigma_data
70 | 
71 |         self.rho = rho
72 | 
73 |         self.P_mean = P_mean
74 |         self.P_std = P_std
75 | 
76 |         self.num_sample_steps = num_sample_steps # otherwise known as N in the paper
77 | 
78 |         self.S_churn = S_churn
79 |         self.S_tmin = S_tmin
80 |         self.S_tmax = S_tmax
81 |         self.S_noise = S_noise
82 | 
83 |     @property
84 |     def device(self):
85 |         return next(self.net.parameters()).device
86 | 
87 |     # derived preconditioning params - Table 1
88 | 
89 |     def c_skip(self, sigma):
90 |         return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)
91 | 
92 |     def c_out(self, sigma):
93 |         return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5
94 | 
95 |     def c_in(self, sigma):
96 |         return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5
97 | 
98 |     def c_noise(self, sigma):
99 |         return log(sigma) * 0.25
100 | 
101 |     # preconditioned network output
102 |     # equation (7) in the paper
103 | 
104 |     def preconditioned_network_forward(self, noised_images, sigma, self_cond = None, clamp = False):
105 |         batch, device = noised_images.shape[0], noised_images.device
106 | 
107 |         if
isinstance(sigma, float): 108 | sigma = torch.full((batch,), sigma, device = device) 109 | 110 | padded_sigma = rearrange(sigma, 'b -> b 1 1 1') 111 | 112 | net_out = self.net( 113 | self.c_in(padded_sigma) * noised_images, 114 | self.c_noise(sigma), 115 | self_cond 116 | ) 117 | 118 | out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out 119 | 120 | if clamp: 121 | out = out.clamp(-1., 1.) 122 | 123 | return out 124 | 125 | # sampling 126 | 127 | # sample schedule 128 | # equation (5) in the paper 129 | 130 | def sample_schedule(self, num_sample_steps = None): 131 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 132 | 133 | N = num_sample_steps 134 | inv_rho = 1 / self.rho 135 | 136 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32) 137 | sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho 138 | 139 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0. 140 | return sigmas 141 | 142 | @torch.no_grad() 143 | def sample(self, batch_size = 16, num_sample_steps = None, clamp = True): 144 | num_sample_steps = default(num_sample_steps, self.num_sample_steps) 145 | 146 | shape = (batch_size, self.channels, self.image_size, self.image_size) 147 | 148 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 149 | 150 | sigmas = self.sample_schedule(num_sample_steps) 151 | 152 | gammas = torch.where( 153 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax), 154 | min(self.S_churn / num_sample_steps, sqrt(2) - 1), 155 | 0. 156 | ) 157 | 158 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 159 | 160 | # images is noise at the beginning 161 | 162 | init_sigma = sigmas[0] 163 | 164 | images = init_sigma * torch.randn(shape, device = self.device) 165 | 166 | # for self conditioning 167 | 168 | x_start = None 169 | 170 | # gradually denoise 171 | 172 | for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'): 173 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 174 | 175 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 176 | 177 | sigma_hat = sigma + gamma * sigma 178 | images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps 179 | 180 | self_cond = x_start if self.self_condition else None 181 | 182 | model_output = self.preconditioned_network_forward(images_hat, sigma_hat, self_cond, clamp = clamp) 183 | denoised_over_sigma = (images_hat - model_output) / sigma_hat 184 | 185 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma 186 | 187 | # second order correction, if not the last timestep 188 | 189 | if sigma_next != 0: 190 | self_cond = model_output if self.self_condition else None 191 | 192 | model_output_next = self.preconditioned_network_forward(images_next, sigma_next, self_cond, clamp = clamp) 193 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next 194 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 195 | 196 | images = images_next 197 | x_start = model_output_next if sigma_next != 0 else model_output 198 | 199 | images = images.clamp(-1., 1.) 
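        # the trajectory is finished (sigma has reached 0); clamp to the [-1, 1]
        # data range expected by the unnormalization below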
200 | return unnormalize_to_zero_to_one(images) 201 | 202 | @torch.no_grad() 203 | def sample_using_dpmpp(self, batch_size = 16, num_sample_steps = None): 204 | """ 205 | thanks to Katherine Crowson (https://github.com/crowsonkb) for figuring it all out! 206 | https://arxiv.org/abs/2211.01095 207 | """ 208 | 209 | device, num_sample_steps = self.device, default(num_sample_steps, self.num_sample_steps) 210 | 211 | sigmas = self.sample_schedule(num_sample_steps) 212 | 213 | shape = (batch_size, self.channels, self.image_size, self.image_size) 214 | images = sigmas[0] * torch.randn(shape, device = device) 215 | 216 | sigma_fn = lambda t: t.neg().exp() 217 | t_fn = lambda sigma: sigma.log().neg() 218 | 219 | old_denoised = None 220 | for i in tqdm(range(len(sigmas) - 1)): 221 | denoised = self.preconditioned_network_forward(images, sigmas[i].item()) 222 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) 223 | h = t_next - t 224 | 225 | if not exists(old_denoised) or sigmas[i + 1] == 0: 226 | denoised_d = denoised 227 | else: 228 | h_last = t - t_fn(sigmas[i - 1]) 229 | r = h_last / h 230 | gamma = - 1 / (2 * r) 231 | denoised_d = (1 - gamma) * denoised + gamma * old_denoised 232 | 233 | images = (sigma_fn(t_next) / sigma_fn(t)) * images - (-h).expm1() * denoised_d 234 | old_denoised = denoised 235 | 236 | images = images.clamp(-1., 1.) 237 | return unnormalize_to_zero_to_one(images) 238 | 239 | # training 240 | 241 | def loss_weight(self, sigma): 242 | return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2 243 | 244 | def noise_distribution(self, batch_size): 245 | return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp() 246 | 247 | def forward(self, images): 248 | batch_size, c, h, w, device, image_size, channels = *images.shape, images.device, self.image_size, self.channels 249 | 250 | assert h == image_size and w == image_size, f'height and width of image must be {image_size}' 251 | assert c == channels, 'mismatch of image channels' 252 | 253 | images = normalize_to_neg_one_to_one(images) 254 | 255 | sigmas = self.noise_distribution(batch_size) 256 | padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1') 257 | 258 | noise = torch.randn_like(images) 259 | 260 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper 261 | 262 | self_cond = None 263 | 264 | if self.self_condition and random() < 0.5: 265 | # from hinton's group's bit diffusion paper 266 | with torch.no_grad(): 267 | self_cond = self.preconditioned_network_forward(noised_images, sigmas) 268 | self_cond.detach_() 269 | 270 | denoised = self.preconditioned_network_forward(noised_images, sigmas, self_cond) 271 | 272 | losses = F.mse_loss(denoised, images, reduction = 'none') 273 | losses = reduce(losses, 'b ... 
-> b', 'mean') 274 | 275 | losses = losses * self.loss_weight(sigmas) 276 | 277 | return losses.mean() 278 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import sqrt 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | from torch.special import expm1 7 | 8 | from tqdm import tqdm 9 | from einops import rearrange, repeat, reduce 10 | from einops.layers.torch import Rearrange 11 | 12 | # helpers 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | def default(val, d): 18 | if exists(val): 19 | return val 20 | return d() if callable(d) else d 21 | 22 | # normalization functions 23 | 24 | def normalize_to_neg_one_to_one(img): 25 | return img * 2 - 1 26 | 27 | def unnormalize_to_zero_to_one(t): 28 | return (t + 1) * 0.5 29 | 30 | # diffusion helpers 31 | 32 | def right_pad_dims_to(x, t): 33 | padding_dims = x.ndim - t.ndim 34 | if padding_dims <= 0: 35 | return t 36 | return t.view(*t.shape, *((1,) * padding_dims)) 37 | 38 | # neural net helpers 39 | 40 | class Residual(nn.Module): 41 | def __init__(self, fn): 42 | super().__init__() 43 | self.fn = fn 44 | 45 | def forward(self, x): 46 | return x + self.fn(x) 47 | 48 | class MonotonicLinear(nn.Module): 49 | def __init__(self, *args, **kwargs): 50 | super().__init__() 51 | self.net = nn.Linear(*args, **kwargs) 52 | 53 | def forward(self, x): 54 | return F.linear(x, self.net.weight.abs(), self.net.bias.abs()) 55 | 56 | # continuous schedules 57 | 58 | # equations are taken from https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material 59 | # @crowsonkb Katherine's repository also helped here https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py 60 | 61 | # log(snr) that approximates the original linear schedule 62 | 63 | def log(t, eps = 1e-20): 64 | return torch.log(t.clamp(min = eps)) 65 | 66 | def beta_linear_log_snr(t): 67 | return -log(expm1(1e-4 + 10 * (t ** 2))) 68 | 69 | def alpha_cosine_log_snr(t, s = 0.008): 70 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) 71 | 72 | class learned_noise_schedule(nn.Module): 73 | """ described in section H and then I.2 of the supplementary material for variational ddpm paper """ 74 | 75 | def __init__( 76 | self, 77 | *, 78 | log_snr_max, 79 | log_snr_min, 80 | hidden_dim = 1024, 81 | frac_gradient = 1. 82 | ): 83 | super().__init__() 84 | self.slope = log_snr_min - log_snr_max 85 | self.intercept = log_snr_max 86 | 87 | self.net = nn.Sequential( 88 | Rearrange('... -> ... 1'), 89 | MonotonicLinear(1, 1), 90 | Residual(nn.Sequential( 91 | MonotonicLinear(1, hidden_dim), 92 | nn.Sigmoid(), 93 | MonotonicLinear(hidden_dim, 1) 94 | )), 95 | Rearrange('... 
1 -> ...'),
96 |         )
97 | 
98 |         self.frac_gradient = frac_gradient
99 | 
100 |     def forward(self, x):
101 |         frac_gradient = self.frac_gradient
102 |         device = x.device
103 | 
104 |         out_zero = self.net(torch.zeros_like(x))
105 |         out_one = self.net(torch.ones_like(x))
106 | 
107 |         x = self.net(x)
108 | 
109 |         normed = self.slope * ((x - out_zero) / (out_one - out_zero)) + self.intercept
110 |         return normed * frac_gradient + normed.detach() * (1 - frac_gradient)
111 | 
112 | class ContinuousTimeGaussianDiffusion(nn.Module):
113 |     def __init__(
114 |         self,
115 |         model,
116 |         *,
117 |         image_size,
118 |         channels = 3,
119 |         loss_type = 'l1',
120 |         noise_schedule = 'linear',
121 |         num_sample_steps = 500,
122 |         clip_sample_denoised = True,
123 |         learned_schedule_net_hidden_dim = 1024,
124 |         learned_noise_schedule_frac_gradient = 1., # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
125 |         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
126 |         p2_loss_weight_k = 1
127 |     ):
128 |         super().__init__()
129 |         assert model.random_or_learned_sinusoidal_cond
130 |         assert not model.self_condition, 'not supported yet'
131 | 
132 |         self.model = model
133 | 
134 |         # image dimensions
135 | 
136 |         self.channels = channels
137 |         self.image_size = image_size
138 | 
139 |         # continuous noise schedule related stuff
140 | 
141 |         self.loss_type = loss_type
142 | 
143 |         if noise_schedule == 'linear':
144 |             self.log_snr = beta_linear_log_snr
145 |         elif noise_schedule == 'cosine':
146 |             self.log_snr = alpha_cosine_log_snr
147 |         elif noise_schedule == 'learned':
148 |             log_snr_max, log_snr_min = [beta_linear_log_snr(torch.tensor([time])).item() for time in (0., 1.)]
149 | 
150 |             self.log_snr = learned_noise_schedule(
151 |                 log_snr_max = log_snr_max,
152 |                 log_snr_min = log_snr_min,
153 |                 hidden_dim = learned_schedule_net_hidden_dim,
154 |                 frac_gradient = learned_noise_schedule_frac_gradient
155 |             )
156 |         else:
157 |             raise ValueError(f'unknown noise schedule {noise_schedule}')
158 | 
159 |         # sampling
160 | 
161 |         self.num_sample_steps = num_sample_steps
162 |         self.clip_sample_denoised = clip_sample_denoised
163 | 
164 |         # p2 loss weight
165 |         # proposed https://arxiv.org/abs/2204.00227
166 | 
167 |         assert p2_loss_weight_gamma <= 2, 'in paper, they noticed any gamma greater than 2 is harmful'
168 | 
169 |         self.p2_loss_weight_gamma = p2_loss_weight_gamma # recommended to be 0.5 or 1
170 |         self.p2_loss_weight_k = p2_loss_weight_k
171 | 
172 |     @property
173 |     def device(self):
174 |         return next(self.model.parameters()).device
175 | 
176 |     @property
177 |     def loss_fn(self):
178 |         if self.loss_type == 'l1':
179 |             return F.l1_loss
180 |         elif self.loss_type == 'l2':
181 |             return F.mse_loss
182 |         else:
183 |             raise ValueError(f'invalid loss type {self.loss_type}')
184 | 
185 |     def p_mean_variance(self, x, time, time_next):
186 |         # reviewer found an error in the equation in the paper (missing sigma)
187 |         # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt
188 | 
189 |         log_snr = self.log_snr(time)
190 |         log_snr_next = self.log_snr(time_next)
191 |         c = -expm1(log_snr - log_snr_next)
192 | 
193 |         squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
194 |         squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
195 | 
196 |         alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))
197 | 
198 |         batch_log_snr = 
repeat(log_snr, ' -> b', b = x.shape[0]) 199 | pred_noise = self.model(x, batch_log_snr) 200 | 201 | if self.clip_sample_denoised: 202 | x_start = (x - sigma * pred_noise) / alpha 203 | 204 | # in Imagen, this was changed to dynamic thresholding 205 | x_start.clamp_(-1., 1.) 206 | 207 | model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start) 208 | else: 209 | model_mean = alpha_next / alpha * (x - c * sigma * pred_noise) 210 | 211 | posterior_variance = squared_sigma_next * c 212 | 213 | return model_mean, posterior_variance 214 | 215 | # sampling related functions 216 | 217 | @torch.no_grad() 218 | def p_sample(self, x, time, time_next): 219 | batch, *_, device = *x.shape, x.device 220 | 221 | model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next) 222 | 223 | if time_next == 0: 224 | return model_mean 225 | 226 | noise = torch.randn_like(x) 227 | return model_mean + sqrt(model_variance) * noise 228 | 229 | @torch.no_grad() 230 | def p_sample_loop(self, shape): 231 | batch = shape[0] 232 | 233 | img = torch.randn(shape, device = self.device) 234 | steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device) 235 | 236 | for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps): 237 | times = steps[i] 238 | times_next = steps[i + 1] 239 | img = self.p_sample(img, times, times_next) 240 | 241 | img.clamp_(-1., 1.) 242 | img = unnormalize_to_zero_to_one(img) 243 | return img 244 | 245 | @torch.no_grad() 246 | def sample(self, batch_size = 16): 247 | return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size)) 248 | 249 | # training related functions - noise prediction 250 | 251 | def q_sample(self, x_start, times, noise = None): 252 | noise = default(noise, lambda: torch.randn_like(x_start)) 253 | 254 | log_snr = self.log_snr(times) 255 | 256 | log_snr_padded = right_pad_dims_to(x_start, log_snr) 257 | alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid()) 258 | x_noised = x_start * alpha + noise * sigma 259 | 260 | return x_noised, log_snr 261 | 262 | def random_times(self, batch_size): 263 | # times are now uniform from 0 to 1 264 | return torch.zeros((batch_size,), device = self.device).float().uniform_(0, 1) 265 | 266 | def p_losses(self, x_start, times, noise = None): 267 | noise = default(noise, lambda: torch.randn_like(x_start)) 268 | 269 | x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise) 270 | model_out = self.model(x, log_snr) 271 | 272 | losses = self.loss_fn(model_out, noise, reduction = 'none') 273 | losses = reduce(losses, 'b ... -> b', 'mean') 274 | 275 | if self.p2_loss_weight_gamma >= 0: 276 | # following eq 8. 
in https://arxiv.org/abs/2204.00227 277 | loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -self.p2_loss_weight_gamma 278 | losses = losses * loss_weight 279 | 280 | return losses.mean() 281 | 282 | def forward(self, img, *args, **kwargs): 283 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size 284 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 285 | 286 | times = self.random_times(b) 287 | img = normalize_to_neg_one_to_one(img) 288 | return self.p_losses(img, times, *args, **kwargs) 289 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py: -------------------------------------------------------------------------------- 1 | import math 2 | from random import random 3 | from functools import partial 4 | from collections import namedtuple 5 | 6 | import torch 7 | from torch import nn, einsum 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange, reduce 11 | from einops.layers.torch import Rearrange 12 | 13 | from tqdm.auto import tqdm 14 | 15 | # constants 16 | 17 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) 18 | 19 | # helpers functions 20 | 21 | def exists(x): 22 | return x is not None 23 | 24 | def default(val, d): 25 | if exists(val): 26 | return val 27 | return d() if callable(d) else d 28 | 29 | def identity(t, *args, **kwargs): 30 | return t 31 | 32 | def cycle(dl): 33 | while True: 34 | for data in dl: 35 | yield data 36 | 37 | def has_int_squareroot(num): 38 | return (math.sqrt(num) ** 2) == num 39 | 40 | def num_to_groups(num, divisor): 41 | groups = num // divisor 42 | remainder = num % divisor 43 | arr = [divisor] * groups 44 | if remainder > 0: 45 | arr.append(remainder) 46 | return arr 47 | 48 | def convert_image_to_fn(img_type, image): 49 | if image.mode != img_type: 50 | return image.convert(img_type) 51 | return image 52 | 53 | # normalization functions 54 | 55 | def normalize_to_neg_one_to_one(img): 56 | return img * 2 - 1 57 | 58 | def unnormalize_to_zero_to_one(t): 59 | return (t + 1) * 0.5 60 | 61 | # small helper modules 62 | 63 | class Residual(nn.Module): 64 | def __init__(self, fn): 65 | super().__init__() 66 | self.fn = fn 67 | 68 | def forward(self, x, *args, **kwargs): 69 | return self.fn(x, *args, **kwargs) + x 70 | 71 | def Upsample(dim, dim_out = None): 72 | return nn.Sequential( 73 | nn.Upsample(scale_factor = 2, mode = 'nearest'), 74 | nn.Conv1d(dim, default(dim_out, dim), 3, padding = 1) 75 | ) 76 | 77 | def Downsample(dim, dim_out = None): 78 | return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1) 79 | 80 | class WeightStandardizedConv2d(nn.Conv1d): 81 | """ 82 | https://arxiv.org/abs/1903.10520 83 | weight standardization purportedly works synergistically with group normalization 84 | """ 85 | def forward(self, x): 86 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 87 | 88 | weight = self.weight 89 | mean = reduce(weight, 'o ... -> o 1 1', 'mean') 90 | var = reduce(weight, 'o ... 
-> o 1 1', partial(torch.var, unbiased = False)) 91 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 92 | 93 | return F.conv1d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 94 | 95 | class LayerNorm(nn.Module): 96 | def __init__(self, dim): 97 | super().__init__() 98 | self.g = nn.Parameter(torch.ones(1, dim, 1)) 99 | 100 | def forward(self, x): 101 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 102 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 103 | mean = torch.mean(x, dim = 1, keepdim = True) 104 | return (x - mean) * (var + eps).rsqrt() * self.g 105 | 106 | class PreNorm(nn.Module): 107 | def __init__(self, dim, fn): 108 | super().__init__() 109 | self.fn = fn 110 | self.norm = LayerNorm(dim) 111 | 112 | def forward(self, x): 113 | x = self.norm(x) 114 | return self.fn(x) 115 | 116 | # sinusoidal positional embeds 117 | 118 | class SinusoidalPosEmb(nn.Module): 119 | def __init__(self, dim): 120 | super().__init__() 121 | self.dim = dim 122 | 123 | def forward(self, x): 124 | device = x.device 125 | half_dim = self.dim // 2 126 | emb = math.log(10000) / (half_dim - 1) 127 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 128 | emb = x[:, None] * emb[None, :] 129 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 130 | return emb 131 | 132 | class RandomOrLearnedSinusoidalPosEmb(nn.Module): 133 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ 134 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 135 | 136 | def __init__(self, dim, is_random = False): 137 | super().__init__() 138 | assert (dim % 2) == 0 139 | half_dim = dim // 2 140 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) 141 | 142 | def forward(self, x): 143 | x = rearrange(x, 'b -> b 1') 144 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 145 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 146 | fouriered = torch.cat((x, fouriered), dim = -1) 147 | return fouriered 148 | 149 | # building block modules 150 | 151 | class Block(nn.Module): 152 | def __init__(self, dim, dim_out, groups = 8): 153 | super().__init__() 154 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) 155 | self.norm = nn.GroupNorm(groups, dim_out) 156 | self.act = nn.SiLU() 157 | 158 | def forward(self, x, scale_shift = None): 159 | x = self.proj(x) 160 | x = self.norm(x) 161 | 162 | if exists(scale_shift): 163 | scale, shift = scale_shift 164 | x = x * (scale + 1) + shift 165 | 166 | x = self.act(x) 167 | return x 168 | 169 | class ResnetBlock(nn.Module): 170 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): 171 | super().__init__() 172 | self.mlp = nn.Sequential( 173 | nn.SiLU(), 174 | nn.Linear(time_emb_dim, dim_out * 2) 175 | ) if exists(time_emb_dim) else None 176 | 177 | self.block1 = Block(dim, dim_out, groups = groups) 178 | self.block2 = Block(dim_out, dim_out, groups = groups) 179 | self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 180 | 181 | def forward(self, x, time_emb = None): 182 | 183 | scale_shift = None 184 | if exists(self.mlp) and exists(time_emb): 185 | time_emb = self.mlp(time_emb) 186 | time_emb = rearrange(time_emb, 'b c -> b c 1') 187 | scale_shift = time_emb.chunk(2, dim = 1) 188 | 189 | h = self.block1(x, scale_shift = scale_shift) 190 | 191 | h = self.block2(h) 192 | 193 | return h + self.res_conv(x) 194 | 195 | class 
LinearAttention(nn.Module): 196 | def __init__(self, dim, heads = 4, dim_head = 32): 197 | super().__init__() 198 | self.scale = dim_head ** -0.5 199 | self.heads = heads 200 | hidden_dim = dim_head * heads 201 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False) 202 | 203 | self.to_out = nn.Sequential( 204 | nn.Conv1d(hidden_dim, dim, 1), 205 | LayerNorm(dim) 206 | ) 207 | 208 | def forward(self, x): 209 | b, c, n = x.shape 210 | qkv = self.to_qkv(x).chunk(3, dim = 1) 211 | q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv) 212 | 213 | q = q.softmax(dim = -2) 214 | k = k.softmax(dim = -1) 215 | 216 | q = q * self.scale 217 | 218 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 219 | 220 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 221 | out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads) 222 | return self.to_out(out) 223 | 224 | class Attention(nn.Module): 225 | def __init__(self, dim, heads = 4, dim_head = 32): 226 | super().__init__() 227 | self.scale = dim_head ** -0.5 228 | self.heads = heads 229 | hidden_dim = dim_head * heads 230 | 231 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False) 232 | self.to_out = nn.Conv1d(hidden_dim, dim, 1) 233 | 234 | def forward(self, x): 235 | b, c, n = x.shape 236 | qkv = self.to_qkv(x).chunk(3, dim = 1) 237 | q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv) 238 | 239 | q = q * self.scale 240 | 241 | sim = einsum('b h d i, b h d j -> b h i j', q, k) 242 | attn = sim.softmax(dim = -1) 243 | out = einsum('b h i j, b h d j -> b h i d', attn, v) 244 | 245 | out = rearrange(out, 'b h n d -> b (h d) n') 246 | return self.to_out(out) 247 | 248 | # model 249 | 250 | class Unet1D(nn.Module): 251 | def __init__( 252 | self, 253 | dim, 254 | init_dim = None, 255 | out_dim = None, 256 | dim_mults=(1, 2, 4, 8), 257 | channels = 3, 258 | self_condition = False, 259 | resnet_block_groups = 8, 260 | learned_variance = False, 261 | learned_sinusoidal_cond = False, 262 | random_fourier_features = False, 263 | learned_sinusoidal_dim = 16 264 | ): 265 | super().__init__() 266 | 267 | # determine dimensions 268 | 269 | self.channels = channels 270 | self.self_condition = self_condition 271 | input_channels = channels * (2 if self_condition else 1) 272 | 273 | init_dim = default(init_dim, dim) 274 | self.init_conv = nn.Conv1d(input_channels, init_dim, 7, padding = 3) 275 | 276 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 277 | in_out = list(zip(dims[:-1], dims[1:])) 278 | 279 | block_klass = partial(ResnetBlock, groups = resnet_block_groups) 280 | 281 | # time embeddings 282 | 283 | time_dim = dim * 4 284 | 285 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features 286 | 287 | if self.random_or_learned_sinusoidal_cond: 288 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) 289 | fourier_dim = learned_sinusoidal_dim + 1 290 | else: 291 | sinu_pos_emb = SinusoidalPosEmb(dim) 292 | fourier_dim = dim 293 | 294 | self.time_mlp = nn.Sequential( 295 | sinu_pos_emb, 296 | nn.Linear(fourier_dim, time_dim), 297 | nn.GELU(), 298 | nn.Linear(time_dim, time_dim) 299 | ) 300 | 301 | # layers 302 | 303 | self.downs = nn.ModuleList([]) 304 | self.ups = nn.ModuleList([]) 305 | num_resolutions = len(in_out) 306 | 307 | for ind, (dim_in, dim_out) in enumerate(in_out): 308 | is_last = ind >= (num_resolutions - 1) 309 | 310 | self.downs.append(nn.ModuleList([ 311 | 
block_klass(dim_in, dim_in, time_emb_dim = time_dim), 312 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 313 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 314 | Downsample(dim_in, dim_out) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1) 315 | ])) 316 | 317 | mid_dim = dims[-1] 318 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 319 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 320 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 321 | 322 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 323 | is_last = ind == (len(in_out) - 1) 324 | 325 | self.ups.append(nn.ModuleList([ 326 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 327 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 328 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 329 | Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1) 330 | ])) 331 | 332 | default_out_dim = channels * (1 if not learned_variance else 2) 333 | self.out_dim = default(out_dim, default_out_dim) 334 | 335 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) 336 | self.final_conv = nn.Conv1d(dim, self.out_dim, 1) 337 | 338 | def forward(self, x, time, x_self_cond = None): 339 | if self.self_condition: 340 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 341 | x = torch.cat((x_self_cond, x), dim = 1) 342 | 343 | x = self.init_conv(x) 344 | r = x.clone() 345 | 346 | t = self.time_mlp(time) 347 | 348 | h = [] 349 | 350 | for block1, block2, attn, downsample in self.downs: 351 | x = block1(x, t) 352 | h.append(x) 353 | 354 | x = block2(x, t) 355 | x = attn(x) 356 | h.append(x) 357 | 358 | x = downsample(x) 359 | 360 | x = self.mid_block1(x, t) 361 | x = self.mid_attn(x) 362 | x = self.mid_block2(x, t) 363 | 364 | for block1, block2, attn, upsample in self.ups: 365 | x = torch.cat((x, h.pop()), dim = 1) 366 | x = block1(x, t) 367 | 368 | x = torch.cat((x, h.pop()), dim = 1) 369 | x = block2(x, t) 370 | x = attn(x) 371 | 372 | x = upsample(x) 373 | 374 | x = torch.cat((x, r), dim = 1) 375 | 376 | x = self.final_res_block(x, t) 377 | return self.final_conv(x) 378 | 379 | # gaussian diffusion trainer class 380 | 381 | def extract(a, t, x_shape): 382 | b, *_ = t.shape 383 | out = a.gather(-1, t) 384 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 385 | 386 | def linear_beta_schedule(timesteps): 387 | scale = 1000 / timesteps 388 | beta_start = scale * 0.0001 389 | beta_end = scale * 0.02 390 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 391 | 392 | def cosine_beta_schedule(timesteps, s = 0.008): 393 | """ 394 | cosine schedule 395 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 396 | """ 397 | steps = timesteps + 1 398 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) 399 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 400 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 401 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 402 | return torch.clip(betas, 0, 0.999) 403 | 404 | class GaussianDiffusion1D(nn.Module): 405 | def __init__( 406 | self, 407 | model, 408 | *, 409 | seq_length, 410 | timesteps = 1000, 411 | sampling_timesteps = None, 412 | loss_type = 'l1', 413 | objective = 'pred_noise', 414 | beta_schedule = 'cosine', 415 | p2_loss_weight_gamma = 0., 416 | p2_loss_weight_k = 1, 417 | ddim_sampling_eta = 0., 418 | 
auto_normalize = True 419 | ): 420 | super().__init__() 421 | self.model = model 422 | self.channels = self.model.channels 423 | self.self_condition = self.model.self_condition 424 | 425 | self.seq_length = seq_length 426 | 427 | self.objective = objective 428 | 429 | assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' 430 | 431 | if beta_schedule == 'linear': 432 | betas = linear_beta_schedule(timesteps) 433 | elif beta_schedule == 'cosine': 434 | betas = cosine_beta_schedule(timesteps) 435 | else: 436 | raise ValueError(f'unknown beta schedule {beta_schedule}') 437 | 438 | alphas = 1. - betas 439 | alphas_cumprod = torch.cumprod(alphas, dim=0) 440 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) 441 | 442 | timesteps, = betas.shape 443 | self.num_timesteps = int(timesteps) 444 | self.loss_type = loss_type 445 | 446 | # sampling related parameters 447 | 448 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training 449 | 450 | assert self.sampling_timesteps <= timesteps 451 | self.is_ddim_sampling = self.sampling_timesteps < timesteps 452 | self.ddim_sampling_eta = ddim_sampling_eta 453 | 454 | # helper function to register buffer from float64 to float32 455 | 456 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 457 | 458 | register_buffer('betas', betas) 459 | register_buffer('alphas_cumprod', alphas_cumprod) 460 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 461 | 462 | # calculations for diffusion q(x_t | x_{t-1}) and others 463 | 464 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 465 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 466 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 467 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 468 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 469 | 470 | # calculations for posterior q(x_{t-1} | x_t, x_0) 471 | 472 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 473 | 474 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 475 | 476 | register_buffer('posterior_variance', posterior_variance) 477 | 478 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 479 | 480 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) 481 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 482 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) 483 | 484 | # calculate p2 reweighting 485 | 486 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) 487 | 488 | # whether to autonormalize 489 | 490 | self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity 491 | self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity 492 | 493 | def predict_start_from_noise(self, x_t, t, noise): 494 | return ( 495 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 496 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 497 | ) 498 | 499 | def predict_noise_from_start(self, x_t, t, x0): 500 | return ( 501 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ 502 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 503 | ) 504 | 505 | def predict_v(self, x_start, t, noise): 506 | return ( 507 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - 508 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start 509 | ) 510 | 511 | def predict_start_from_v(self, x_t, t, v): 512 | return ( 513 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - 514 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 515 | ) 516 | 517 | def q_posterior(self, x_start, x_t, t): 518 | posterior_mean = ( 519 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 520 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 521 | ) 522 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 523 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 524 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 525 | 526 | def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): 527 | model_output = self.model(x, t, x_self_cond) 528 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity 529 | 530 | if self.objective == 'pred_noise': 531 | pred_noise = model_output 532 | x_start = self.predict_start_from_noise(x, t, pred_noise) 533 | x_start = maybe_clip(x_start) 534 | 535 | elif self.objective == 'pred_x0': 536 | x_start = model_output 537 | x_start = maybe_clip(x_start) 538 | pred_noise = self.predict_noise_from_start(x, t, x_start) 539 | 540 | elif self.objective == 'pred_v': 541 | v = model_output 542 | x_start = self.predict_start_from_v(x, t, v) 543 | x_start = maybe_clip(x_start) 544 | pred_noise = self.predict_noise_from_start(x, t, x_start) 545 | 546 | return ModelPrediction(pred_noise, x_start) 547 | 548 | def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): 549 | preds = self.model_predictions(x, t, x_self_cond) 550 | x_start = preds.pred_x_start 551 | 552 | if clip_denoised: 553 | x_start.clamp_(-1., 1.) 554 | 555 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) 556 | return model_mean, posterior_variance, posterior_log_variance, x_start 557 | 558 | @torch.no_grad() 559 | def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True): 560 | b, *_, device = *x.shape, x.device 561 | batched_times = torch.full((b,), t, device = x.device, dtype = torch.long) 562 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised) 563 | noise = torch.randn_like(x) if t > 0 else 0. 
# no noise if t == 0 564 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise 565 | return pred_img, x_start 566 | 567 | @torch.no_grad() 568 | def p_sample_loop(self, shape): 569 | batch, device = shape[0], self.betas.device 570 | 571 | img = torch.randn(shape, device=device) 572 | 573 | x_start = None 574 | 575 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): 576 | self_cond = x_start if self.self_condition else None 577 | img, x_start = self.p_sample(img, t, self_cond) 578 | 579 | img = self.unnormalize(img) 580 | return img 581 | 582 | @torch.no_grad() 583 | def ddim_sample(self, shape, clip_denoised = True): 584 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective 585 | 586 | times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps 587 | times = list(reversed(times.int().tolist())) 588 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] 589 | 590 | img = torch.randn(shape, device = device) 591 | 592 | x_start = None 593 | 594 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): 595 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long) 596 | self_cond = x_start if self.self_condition else None 597 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised) 598 | 599 | if time_next < 0: 600 | img = x_start 601 | continue 602 | 603 | alpha = self.alphas_cumprod[time] 604 | alpha_next = self.alphas_cumprod[time_next] 605 | 606 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 607 | c = (1 - alpha_next - sigma ** 2).sqrt() 608 | 609 | noise = torch.randn_like(img) 610 | 611 | img = x_start * alpha_next.sqrt() + \ 612 | c * pred_noise + \ 613 | sigma * noise 614 | 615 | img = self.unnormalize(img) 616 | return img 617 | 618 | @torch.no_grad() 619 | def sample(self, batch_size = 16): 620 | seq_length, channels = self.seq_length, self.channels 621 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample 622 | return sample_fn((batch_size, channels, seq_length)) 623 | 624 | @torch.no_grad() 625 | def interpolate(self, x1, x2, t = None, lam = 0.5): 626 | b, *_, device = *x1.shape, x1.device 627 | t = default(t, self.num_timesteps - 1) 628 | 629 | assert x1.shape == x2.shape 630 | 631 | t_batched = torch.full((b,), t, device = device) 632 | xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) 633 | 634 | img = (1 - lam) * xt1 + lam * xt2 635 | 636 | x_start = None 637 | 638 | for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): 639 | self_cond = x_start if self.self_condition else None 640 | img, x_start = self.p_sample(img, i, self_cond) 641 | 642 | return img 643 | 644 | def q_sample(self, x_start, t, noise=None): 645 | noise = default(noise, lambda: torch.randn_like(x_start)) 646 | 647 | return ( 648 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 649 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 650 | ) 651 | 652 | @property 653 | def loss_fn(self): 654 | if self.loss_type == 'l1': 655 | return F.l1_loss 656 | elif self.loss_type == 'l2': 657 | return F.mse_loss 658 | else: 659 | raise 
ValueError(f'invalid loss type {self.loss_type}') 660 | 661 | def p_losses(self, x_start, t, noise = None): 662 | b, c, n = x_start.shape 663 | noise = default(noise, lambda: torch.randn_like(x_start)) 664 | 665 | # noise sample 666 | 667 | x = self.q_sample(x_start = x_start, t = t, noise = noise) 668 | 669 | # if doing self-conditioning, 50% of the time predict x_start from the current timesteps 670 | # and condition the unet on that prediction 671 | # this technique will slow down training by 25%, but seems to lower FID significantly 672 | 673 | x_self_cond = None 674 | if self.self_condition and random() < 0.5: 675 | with torch.no_grad(): 676 | x_self_cond = self.model_predictions(x, t).pred_x_start 677 | x_self_cond.detach_() 678 | 679 | # predict and take gradient step 680 | 681 | model_out = self.model(x, t, x_self_cond) 682 | 683 | if self.objective == 'pred_noise': 684 | target = noise 685 | elif self.objective == 'pred_x0': 686 | target = x_start 687 | elif self.objective == 'pred_v': 688 | v = self.predict_v(x_start, t, noise) 689 | target = v 690 | else: 691 | raise ValueError(f'unknown objective {self.objective}') 692 | 693 | loss = self.loss_fn(model_out, target, reduction = 'none') 694 | loss = reduce(loss, 'b ... -> b (...)', 'mean') 695 | 696 | loss = loss * extract(self.p2_loss_weight, t, loss.shape) 697 | return loss.mean() 698 | 699 | def forward(self, img, *args, **kwargs): 700 | b, c, n, device, seq_length = *img.shape, img.device, self.seq_length 701 | assert n == seq_length, f'seq length must be {seq_length}' 702 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long() 703 | 704 | img = self.normalize(img) 705 | return self.p_losses(img, t, *args, **kwargs) 706 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/classifier_free_guidance.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | from pathlib import Path 4 | from random import random 5 | from functools import partial 6 | from collections import namedtuple 7 | from multiprocessing import cpu_count 8 | 9 | import torch 10 | from torch import nn, einsum 11 | import torch.nn.functional as F 12 | 13 | from einops import rearrange, reduce, repeat 14 | from einops.layers.torch import Rearrange 15 | 16 | from tqdm.auto import tqdm 17 | 18 | # constants 19 | 20 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) 21 | 22 | # helper functions 23 | 24 | def exists(x): 25 | return x is not None 26 | 27 | def default(val, d): 28 | if exists(val): 29 | return val 30 | return d() if callable(d) else d 31 | 32 | def identity(t, *args, **kwargs): 33 | return t 34 | 35 | def cycle(dl): 36 | while True: 37 | for data in dl: 38 | yield data 39 | 40 | def has_int_squareroot(num): 41 | return (math.sqrt(num) ** 2) == num 42 | 43 | def num_to_groups(num, divisor): 44 | groups = num // divisor 45 | remainder = num % divisor 46 | arr = [divisor] * groups 47 | if remainder > 0: 48 | arr.append(remainder) 49 | return arr 50 | 51 | def convert_image_to_fn(img_type, image): 52 | if image.mode != img_type: 53 | return image.convert(img_type) 54 | return image 55 | 56 | # normalization functions 57 | 58 | def normalize_to_neg_one_to_one(img): 59 | return img * 2 - 1 60 | 61 | def unnormalize_to_zero_to_one(t): 62 | return (t + 1) * 0.5 63 | 64 | # classifier free guidance functions 65 | 66 | def uniform(shape, device): 67 | return torch.zeros(shape, device = 
device).float().uniform_(0, 1) 68 | 69 | def prob_mask_like(shape, prob, device): 70 | if prob == 1: 71 | return torch.ones(shape, device = device, dtype = torch.bool) 72 | elif prob == 0: 73 | return torch.zeros(shape, device = device, dtype = torch.bool) 74 | else: 75 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob 76 | 77 | # small helper modules 78 | 79 | class Residual(nn.Module): 80 | def __init__(self, fn): 81 | super().__init__() 82 | self.fn = fn 83 | 84 | def forward(self, x, *args, **kwargs): 85 | return self.fn(x, *args, **kwargs) + x 86 | 87 | def Upsample(dim, dim_out = None): 88 | return nn.Sequential( 89 | nn.Upsample(scale_factor = 2, mode = 'nearest'), 90 | nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) 91 | ) 92 | 93 | def Downsample(dim, dim_out = None): 94 | return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) 95 | 96 | class WeightStandardizedConv2d(nn.Conv2d): 97 | """ 98 | https://arxiv.org/abs/1903.10520 99 | weight standardization purportedly works synergistically with group normalization 100 | """ 101 | def forward(self, x): 102 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 103 | 104 | weight = self.weight 105 | mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') 106 | var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False)) 107 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 108 | 109 | return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 110 | 111 | class LayerNorm(nn.Module): 112 | def __init__(self, dim): 113 | super().__init__() 114 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 115 | 116 | def forward(self, x): 117 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 118 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 119 | mean = torch.mean(x, dim = 1, keepdim = True) 120 | return (x - mean) * (var + eps).rsqrt() * self.g 121 | 122 | class PreNorm(nn.Module): 123 | def __init__(self, dim, fn): 124 | super().__init__() 125 | self.fn = fn 126 | self.norm = LayerNorm(dim) 127 | 128 | def forward(self, x): 129 | x = self.norm(x) 130 | return self.fn(x) 131 | 132 | # sinusoidal positional embeds 133 | 134 | class SinusoidalPosEmb(nn.Module): 135 | def __init__(self, dim): 136 | super().__init__() 137 | self.dim = dim 138 | 139 | def forward(self, x): 140 | device = x.device 141 | half_dim = self.dim // 2 142 | emb = math.log(10000) / (half_dim - 1) 143 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 144 | emb = x[:, None] * emb[None, :] 145 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 146 | return emb 147 | 148 | class RandomOrLearnedSinusoidalPosEmb(nn.Module): 149 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ 150 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 151 | 152 | def __init__(self, dim, is_random = False): 153 | super().__init__() 154 | assert (dim % 2) == 0 155 | half_dim = dim // 2 156 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) 157 | 158 | def forward(self, x): 159 | x = rearrange(x, 'b -> b 1') 160 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 161 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 162 | fouriered = torch.cat((x, fouriered), dim = -1) 163 | return fouriered 164 | 165 | # building block modules 166 | 167 | class Block(nn.Module): 168 | def __init__(self, dim, dim_out, groups = 8): 169 | super().__init__() 170 
| self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) 171 | self.norm = nn.GroupNorm(groups, dim_out) 172 | self.act = nn.SiLU() 173 | 174 | def forward(self, x, scale_shift = None): 175 | x = self.proj(x) 176 | x = self.norm(x) 177 | 178 | if exists(scale_shift): 179 | scale, shift = scale_shift 180 | x = x * (scale + 1) + shift 181 | 182 | x = self.act(x) 183 | return x 184 | 185 | class ResnetBlock(nn.Module): 186 | def __init__(self, dim, dim_out, *, time_emb_dim = None, classes_emb_dim = None, groups = 8): 187 | super().__init__() 188 | self.mlp = nn.Sequential( 189 | nn.SiLU(), 190 | nn.Linear(int(time_emb_dim) + int(classes_emb_dim), dim_out * 2) 191 | ) if exists(time_emb_dim) or exists(classes_emb_dim) else None 192 | 193 | self.block1 = Block(dim, dim_out, groups = groups) 194 | self.block2 = Block(dim_out, dim_out, groups = groups) 195 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 196 | 197 | def forward(self, x, time_emb = None, class_emb = None): 198 | 199 | scale_shift = None 200 | if exists(self.mlp) and (exists(time_emb) or exists(class_emb)): 201 | cond_emb = tuple(filter(exists, (time_emb, class_emb))) 202 | cond_emb = torch.cat(cond_emb, dim = -1) 203 | cond_emb = self.mlp(cond_emb) 204 | cond_emb = rearrange(cond_emb, 'b c -> b c 1 1') 205 | scale_shift = cond_emb.chunk(2, dim = 1) 206 | 207 | h = self.block1(x, scale_shift = scale_shift) 208 | 209 | h = self.block2(h) 210 | 211 | return h + self.res_conv(x) 212 | 213 | class LinearAttention(nn.Module): 214 | def __init__(self, dim, heads = 4, dim_head = 32): 215 | super().__init__() 216 | self.scale = dim_head ** -0.5 217 | self.heads = heads 218 | hidden_dim = dim_head * heads 219 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 220 | 221 | self.to_out = nn.Sequential( 222 | nn.Conv2d(hidden_dim, dim, 1), 223 | LayerNorm(dim) 224 | ) 225 | 226 | def forward(self, x): 227 | b, c, h, w = x.shape 228 | qkv = self.to_qkv(x).chunk(3, dim = 1) 229 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) 230 | 231 | q = q.softmax(dim = -2) 232 | k = k.softmax(dim = -1) 233 | 234 | q = q * self.scale 235 | v = v / (h * w) 236 | 237 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 238 | 239 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 240 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) 241 | return self.to_out(out) 242 | 243 | class Attention(nn.Module): 244 | def __init__(self, dim, heads = 4, dim_head = 32): 245 | super().__init__() 246 | self.scale = dim_head ** -0.5 247 | self.heads = heads 248 | hidden_dim = dim_head * heads 249 | 250 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 251 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 252 | 253 | def forward(self, x): 254 | b, c, h, w = x.shape 255 | qkv = self.to_qkv(x).chunk(3, dim = 1) 256 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) 257 | 258 | q = q * self.scale 259 | 260 | sim = einsum('b h d i, b h d j -> b h i j', q, k) 261 | attn = sim.softmax(dim = -1) 262 | out = einsum('b h i j, b h d j -> b h i d', attn, v) 263 | 264 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) 265 | return self.to_out(out) 266 | 267 | # model 268 | 269 | class Unet(nn.Module): 270 | def __init__( 271 | self, 272 | dim, 273 | num_classes, 274 | cond_drop_prob = 0.5, 275 | init_dim = None, 276 | out_dim = None, 277 | dim_mults=(1, 2, 4, 8), 278 | 
channels = 3, 279 | resnet_block_groups = 8, 280 | learned_variance = False, 281 | learned_sinusoidal_cond = False, 282 | random_fourier_features = False, 283 | learned_sinusoidal_dim = 16, 284 | ): 285 | super().__init__() 286 | 287 | # classifier free guidance stuff 288 | 289 | self.cond_drop_prob = cond_drop_prob 290 | 291 | # determine dimensions 292 | 293 | self.channels = channels 294 | input_channels = channels 295 | 296 | init_dim = default(init_dim, dim) 297 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) 298 | 299 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 300 | in_out = list(zip(dims[:-1], dims[1:])) 301 | 302 | block_klass = partial(ResnetBlock, groups = resnet_block_groups) 303 | 304 | # time embeddings 305 | 306 | time_dim = dim * 4 307 | 308 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features 309 | 310 | if self.random_or_learned_sinusoidal_cond: 311 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) 312 | fourier_dim = learned_sinusoidal_dim + 1 313 | else: 314 | sinu_pos_emb = SinusoidalPosEmb(dim) 315 | fourier_dim = dim 316 | 317 | self.time_mlp = nn.Sequential( 318 | sinu_pos_emb, 319 | nn.Linear(fourier_dim, time_dim), 320 | nn.GELU(), 321 | nn.Linear(time_dim, time_dim) 322 | ) 323 | 324 | # class embeddings 325 | 326 | self.classes_emb = nn.Embedding(num_classes, dim) 327 | self.null_classes_emb = nn.Parameter(torch.randn(dim)) 328 | 329 | classes_dim = dim * 4 330 | 331 | self.classes_mlp = nn.Sequential( 332 | nn.Linear(dim, classes_dim), 333 | nn.GELU(), 334 | nn.Linear(classes_dim, classes_dim) 335 | ) 336 | 337 | # layers 338 | 339 | self.downs = nn.ModuleList([]) 340 | self.ups = nn.ModuleList([]) 341 | num_resolutions = len(in_out) 342 | 343 | for ind, (dim_in, dim_out) in enumerate(in_out): 344 | is_last = ind >= (num_resolutions - 1) 345 | 346 | self.downs.append(nn.ModuleList([ 347 | block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim), 348 | block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim), 349 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 350 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) 351 | ])) 352 | 353 | mid_dim = dims[-1] 354 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) 355 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 356 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) 357 | 358 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 359 | is_last = ind == (len(in_out) - 1) 360 | 361 | self.ups.append(nn.ModuleList([ 362 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim), 363 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim), 364 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 365 | Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) 366 | ])) 367 | 368 | default_out_dim = channels * (1 if not learned_variance else 2) 369 | self.out_dim = default(out_dim, default_out_dim) 370 | 371 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) 372 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 373 | 374 | def forward_with_cond_scale( 375 | self, 376 | *args, 377 | cond_scale = 1., 
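# cond_scale = 1. returns the plain conditional prediction; values > 1. extrapolate away from the unconditional (null class) prediction per classifier-free guidance (https://arxiv.org/abs/2207.12598), as computed just below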
378 | **kwargs 379 | ): 380 | logits = self.forward(*args, **kwargs) 381 | 382 | if cond_scale == 1: 383 | return logits 384 | 385 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) 386 | return null_logits + (logits - null_logits) * cond_scale 387 | 388 | def forward( 389 | self, 390 | x, 391 | time, 392 | classes, 393 | cond_drop_prob = None 394 | ): 395 | batch, device = x.shape[0], x.device 396 | 397 | cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) 398 | 399 | # derive condition, with condition dropout for classifier free guidance 400 | 401 | classes_emb = self.classes_emb(classes) 402 | 403 | if cond_drop_prob > 0: 404 | keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device) 405 | null_classes_emb = repeat(self.null_classes_emb, 'd -> b d', b = batch) 406 | 407 | classes_emb = torch.where( 408 | rearrange(keep_mask, 'b -> b 1'), 409 | classes_emb, 410 | null_classes_emb 411 | ) 412 | 413 | c = self.classes_mlp(classes_emb) 414 | 415 | # unet 416 | 417 | x = self.init_conv(x) 418 | r = x.clone() 419 | 420 | t = self.time_mlp(time) 421 | 422 | h = [] 423 | 424 | for block1, block2, attn, downsample in self.downs: 425 | x = block1(x, t, c) 426 | h.append(x) 427 | 428 | x = block2(x, t, c) 429 | x = attn(x) 430 | h.append(x) 431 | 432 | x = downsample(x) 433 | 434 | x = self.mid_block1(x, t, c) 435 | x = self.mid_attn(x) 436 | x = self.mid_block2(x, t, c) 437 | 438 | for block1, block2, attn, upsample in self.ups: 439 | x = torch.cat((x, h.pop()), dim = 1) 440 | x = block1(x, t, c) 441 | 442 | x = torch.cat((x, h.pop()), dim = 1) 443 | x = block2(x, t, c) 444 | x = attn(x) 445 | 446 | x = upsample(x) 447 | 448 | x = torch.cat((x, r), dim = 1) 449 | 450 | x = self.final_res_block(x, t, c) 451 | return self.final_conv(x) 452 | 453 | # gaussian diffusion trainer class 454 | 455 | def extract(a, t, x_shape): 456 | b, *_ = t.shape 457 | out = a.gather(-1, t) 458 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 459 | 460 | def linear_beta_schedule(timesteps): 461 | scale = 1000 / timesteps 462 | beta_start = scale * 0.0001 463 | beta_end = scale * 0.02 464 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 465 | 466 | def cosine_beta_schedule(timesteps, s = 0.008): 467 | """ 468 | cosine schedule 469 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 470 | """ 471 | steps = timesteps + 1 472 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) 473 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 474 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 475 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 476 | return torch.clip(betas, 0, 0.999) 477 | 478 | class GaussianDiffusion(nn.Module): 479 | def __init__( 480 | self, 481 | model, 482 | *, 483 | image_size, 484 | timesteps = 1000, 485 | sampling_timesteps = None, 486 | loss_type = 'l1', 487 | objective = 'pred_noise', 488 | beta_schedule = 'cosine', 489 | p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended 490 | p2_loss_weight_k = 1, 491 | ddim_sampling_eta = 1. 
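# note: ddim_sampling_eta = 1. reintroduces the full DDPM noise at every DDIM step, while 0. makes DDIM sampling deterministic (https://arxiv.org/abs/2010.02502); passing e.g. sampling_timesteps = 250 with timesteps = 1000 activates the faster DDIM path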
492 | ): 493 | super().__init__() 494 | assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) 495 | assert not model.random_or_learned_sinusoidal_cond 496 | 497 | self.model = model 498 | self.channels = self.model.channels 499 | 500 | self.image_size = image_size 501 | 502 | self.objective = objective 503 | 504 | assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' 505 | 506 | if beta_schedule == 'linear': 507 | betas = linear_beta_schedule(timesteps) 508 | elif beta_schedule == 'cosine': 509 | betas = cosine_beta_schedule(timesteps) 510 | else: 511 | raise ValueError(f'unknown beta schedule {beta_schedule}') 512 | 513 | alphas = 1. - betas 514 | alphas_cumprod = torch.cumprod(alphas, dim=0) 515 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) 516 | 517 | timesteps, = betas.shape 518 | self.num_timesteps = int(timesteps) 519 | self.loss_type = loss_type 520 | 521 | # sampling related parameters 522 | 523 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training 524 | 525 | assert self.sampling_timesteps <= timesteps 526 | self.is_ddim_sampling = self.sampling_timesteps < timesteps 527 | self.ddim_sampling_eta = ddim_sampling_eta 528 | 529 | # helper function to register buffer from float64 to float32 530 | 531 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 532 | 533 | register_buffer('betas', betas) 534 | register_buffer('alphas_cumprod', alphas_cumprod) 535 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 536 | 537 | # calculations for diffusion q(x_t | x_{t-1}) and others 538 | 539 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 540 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 541 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 542 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 543 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 544 | 545 | # calculations for posterior q(x_{t-1} | x_t, x_0) 546 | 547 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 548 | 549 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 550 | 551 | register_buffer('posterior_variance', posterior_variance) 552 | 553 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 554 | 555 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) 556 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 557 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) 558 | 559 | # calculate p2 reweighting 560 | 561 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) 562 | 563 | def predict_start_from_noise(self, x_t, t, noise): 564 | return ( 565 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 566 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 567 | ) 568 | 569 | def predict_noise_from_start(self, x_t, t, x0): 570 | return ( 571 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ 572 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 573 | ) 574 | 575 | def predict_v(self, x_start, t, noise): 576 | return ( 577 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - 578 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start 579 | ) 580 | 581 | def predict_start_from_v(self, x_t, t, v): 582 | return ( 583 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - 584 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 585 | ) 586 | 587 | def q_posterior(self, x_start, x_t, t): 588 | posterior_mean = ( 589 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 590 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 591 | ) 592 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 593 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 594 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 595 | 596 | def model_predictions(self, x, t, classes, cond_scale = 3., clip_x_start = False): 597 | model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale) 598 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity 599 | 600 | if self.objective == 'pred_noise': 601 | pred_noise = model_output 602 | x_start = self.predict_start_from_noise(x, t, pred_noise) 603 | x_start = maybe_clip(x_start) 604 | 605 | elif self.objective == 'pred_x0': 606 | x_start = model_output 607 | x_start = maybe_clip(x_start) 608 | pred_noise = self.predict_noise_from_start(x, t, x_start) 609 | 610 | elif self.objective == 'pred_v': 611 | v = model_output 612 | x_start = self.predict_start_from_v(x, t, v) 613 | x_start = maybe_clip(x_start) 614 | pred_noise = self.predict_noise_from_start(x, t, x_start) 615 | 616 | return ModelPrediction(pred_noise, x_start) 617 | 618 | def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True): 619 | preds = self.model_predictions(x, t, classes, cond_scale) 620 | x_start = preds.pred_x_start 621 | 622 | if clip_denoised: 623 | x_start.clamp_(-1., 1.) 624 | 625 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) 626 | return model_mean, posterior_variance, posterior_log_variance, x_start 627 | 628 | @torch.no_grad() 629 | def p_sample(self, x, t: int, classes, cond_scale = 3., clip_denoised = True): 630 | b, *_, device = *x.shape, x.device 631 | batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) 632 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, classes = classes, cond_scale = cond_scale, clip_denoised = clip_denoised) 633 | noise = torch.randn_like(x) if t > 0 else 0. 
# no noise if t == 0 634 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise 635 | return pred_img, x_start 636 | 637 | @torch.no_grad() 638 | def p_sample_loop(self, classes, shape, cond_scale = 3.): 639 | batch, device = shape[0], self.betas.device 640 | 641 | img = torch.randn(shape, device=device) 642 | 643 | x_start = None 644 | 645 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): 646 | img, x_start = self.p_sample(img, t, classes, cond_scale) 647 | 648 | img = unnormalize_to_zero_to_one(img) 649 | return img 650 | 651 | @torch.no_grad() 652 | def ddim_sample(self, classes, shape, cond_scale = 3., clip_denoised = True): 653 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective 654 | 655 | times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps 656 | times = list(reversed(times.int().tolist())) 657 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] 658 | 659 | img = torch.randn(shape, device = device) 660 | 661 | x_start = None 662 | 663 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): 664 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long) 665 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, classes, cond_scale = cond_scale, clip_x_start = clip_denoised) 666 | 667 | if time_next < 0: 668 | img = x_start 669 | continue 670 | 671 | alpha = self.alphas_cumprod[time] 672 | alpha_next = self.alphas_cumprod[time_next] 673 | 674 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() 675 | c = (1 - alpha_next - sigma ** 2).sqrt() 676 | 677 | noise = torch.randn_like(img) 678 | 679 | img = x_start * alpha_next.sqrt() + \ 680 | c * pred_noise + \ 681 | sigma * noise 682 | 683 | img = unnormalize_to_zero_to_one(img) 684 | return img 685 | 686 | @torch.no_grad() 687 | def sample(self, classes, cond_scale = 3.): 688 | batch_size, image_size, channels = classes.shape[0], self.image_size, self.channels 689 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample 690 | return sample_fn(classes, (batch_size, channels, image_size, image_size), cond_scale) 691 | 692 | @torch.no_grad() 693 | def interpolate(self, x1, x2, classes, t = None, lam = 0.5): 694 | b, *_, device = *x1.shape, x1.device 695 | t = default(t, self.num_timesteps - 1) 696 | 697 | assert x1.shape == x2.shape 698 | 699 | t_batched = torch.full((b,), t, device = device, dtype = torch.long) 700 | xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) 701 | 702 | img = (1 - lam) * xt1 + lam * xt2 703 | for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): 704 | img, _ = self.p_sample(img, i, classes) 705 | 706 | return img 707 | 708 | def q_sample(self, x_start, t, noise=None): 709 | noise = default(noise, lambda: torch.randn_like(x_start)) 710 | 711 | return ( 712 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 713 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 714 | ) 715 | 716 | @property 717 | def loss_fn(self): 718 | if self.loss_type == 'l1': 719 | return F.l1_loss 720 | elif self.loss_type == 'l2': 721 | return F.mse_loss 722 | else: 723 | raise ValueError(f'invalid loss type {self.loss_type}')
724 | 725 | def p_losses(self, x_start, t, *, classes, noise = None): 726 | b, c, h, w = x_start.shape 727 | noise = default(noise, lambda: torch.randn_like(x_start)) 728 | 729 | # noise sample 730 | 731 | x = self.q_sample(x_start = x_start, t = t, noise = noise) 732 | 733 | # predict and take gradient step 734 | 735 | model_out = self.model(x, t, classes) 736 | 737 | if self.objective == 'pred_noise': 738 | target = noise 739 | elif self.objective == 'pred_x0': 740 | target = x_start 741 | elif self.objective == 'pred_v': 742 | v = self.predict_v(x_start, t, noise) 743 | target = v 744 | else: 745 | raise ValueError(f'unknown objective {self.objective}') 746 | 747 | loss = self.loss_fn(model_out, target, reduction = 'none') 748 | loss = reduce(loss, 'b ... -> b (...)', 'mean') 749 | 750 | loss = loss * extract(self.p2_loss_weight, t, loss.shape) 751 | return loss.mean() 752 | 753 | def forward(self, img, *args, **kwargs): 754 | b, c, h, w, device, img_size = *img.shape, img.device, self.image_size 755 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 756 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long() 757 | 758 | img = normalize_to_neg_one_to_one(img) 759 | return self.p_losses(img, t, *args, **kwargs) 760 | 761 | # example 762 | 763 | if __name__ == '__main__': 764 | num_classes = 10 765 | 766 | model = Unet( 767 | dim = 64, 768 | dim_mults = (1, 2, 4, 8), 769 | num_classes = num_classes, 770 | cond_drop_prob = 0.5 771 | ) 772 | 773 | diffusion = GaussianDiffusion( 774 | model, 775 | image_size = 128, 776 | timesteps = 1000 777 | ).cuda() 778 | 779 | training_images = torch.randn(8, 3, 128, 128).cuda() # dummy stand-in data - real training images should be normalized to the range 0 to 1 780 | image_classes = torch.randint(0, num_classes, (8,)).cuda() # say 10 classes 781 | 782 | loss = diffusion(training_images, classes = image_classes) 783 | loss.backward() 784 | 785 | # do above for many steps 786 | 787 | sampled_images = diffusion.sample( 788 | classes = image_classes, 789 | cond_scale = 3. # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
790 | ) 791 | 792 | sampled_images.shape # (8, 3, 128, 128) 793 | -------------------------------------------------------------------------------- /denoising_diffusion_pytorch/denoising_diffusion_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | from pathlib import Path 4 | from random import random 5 | from functools import partial 6 | from collections import namedtuple 7 | from multiprocessing import cpu_count 8 | 9 | import torch 10 | from torch import nn, einsum 11 | import torch.nn.functional as F 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | from torch.optim import Adam 15 | from torchvision import transforms as T, utils 16 | 17 | from einops import rearrange, reduce 18 | from einops.layers.torch import Rearrange 19 | 20 | from PIL import Image 21 | from tqdm.auto import tqdm 22 | from ema_pytorch import EMA 23 | 24 | from accelerate import Accelerator 25 | 26 | from denoising_diffusion_pytorch.version import __version__ 27 | 28 | # constants 29 | 30 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) 31 | 32 | # helper functions 33 | 34 | def exists(x): 35 | return x is not None 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if callable(d) else d 41 | 42 | def identity(t, *args, **kwargs): 43 | return t 44 | 45 | def cycle(dl): 46 | while True: 47 | for data in dl: 48 | yield data 49 | 50 | def has_int_squareroot(num): 51 | return (math.sqrt(num) ** 2) == num 52 | 53 | def num_to_groups(num, divisor): 54 | groups = num // divisor 55 | remainder = num % divisor 56 | arr = [divisor] * groups 57 | if remainder > 0: 58 | arr.append(remainder) 59 | return arr 60 | 61 | def convert_image_to_fn(img_type, image): 62 | if image.mode != img_type: 63 | return image.convert(img_type) 64 | return image 65 | 66 | # normalization functions 67 | 68 | def normalize_to_neg_one_to_one(img): 69 | return img * 2 - 1 70 | 71 | def unnormalize_to_zero_to_one(t): 72 | return (t + 1) * 0.5 73 | 74 | # small helper modules 75 | 76 | class Residual(nn.Module): 77 | def __init__(self, fn): 78 | super().__init__() 79 | self.fn = fn 80 | 81 | def forward(self, x, *args, **kwargs): 82 | return self.fn(x, *args, **kwargs) + x 83 | 84 | def Upsample(dim, dim_out = None): 85 | return nn.Sequential( 86 | nn.Upsample(scale_factor = 2, mode = 'nearest'), 87 | nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) 88 | ) 89 | 90 | def Downsample(dim, dim_out = None): 91 | return nn.Sequential( 92 | Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), 93 | nn.Conv2d(dim * 4, default(dim_out, dim), 1) 94 | ) 95 | 96 | class WeightStandardizedConv2d(nn.Conv2d): 97 | """ 98 | https://arxiv.org/abs/1903.10520 99 | weight standardization purportedly works synergistically with group normalization 100 | """ 101 | def forward(self, x): 102 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 103 | 104 | weight = self.weight 105 | mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') 106 | var = reduce(weight, 'o ... 
-> o 1 1 1', partial(torch.var, unbiased = False)) 107 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 108 | 109 | return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 110 | 111 | class LayerNorm(nn.Module): 112 | def __init__(self, dim): 113 | super().__init__() 114 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 115 | 116 | def forward(self, x): 117 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 118 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 119 | mean = torch.mean(x, dim = 1, keepdim = True) 120 | return (x - mean) * (var + eps).rsqrt() * self.g 121 | 122 | class PreNorm(nn.Module): 123 | def __init__(self, dim, fn): 124 | super().__init__() 125 | self.fn = fn 126 | self.norm = LayerNorm(dim) 127 | 128 | def forward(self, x): 129 | x = self.norm(x) 130 | return self.fn(x) 131 | 132 | # sinusoidal positional embeds 133 | 134 | class SinusoidalPosEmb(nn.Module): 135 | def __init__(self, dim): 136 | super().__init__() 137 | self.dim = dim 138 | 139 | def forward(self, x): 140 | device = x.device 141 | half_dim = self.dim // 2 142 | emb = math.log(10000) / (half_dim - 1) 143 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 144 | emb = x[:, None] * emb[None, :] 145 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 146 | return emb 147 | 148 | class RandomOrLearnedSinusoidalPosEmb(nn.Module): 149 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ 150 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 151 | 152 | def __init__(self, dim, is_random = False): 153 | super().__init__() 154 | assert (dim % 2) == 0 155 | half_dim = dim // 2 156 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) 157 | 158 | def forward(self, x): 159 | x = rearrange(x, 'b -> b 1') 160 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 161 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 162 | fouriered = torch.cat((x, fouriered), dim = -1) 163 | return fouriered 164 | 165 | # building block modules 166 | 167 | class Block(nn.Module): 168 | def __init__(self, dim, dim_out, groups = 8): 169 | super().__init__() 170 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) 171 | self.norm = nn.GroupNorm(groups, dim_out) 172 | self.act = nn.SiLU() 173 | 174 | def forward(self, x, scale_shift = None): 175 | x = self.proj(x) 176 | x = self.norm(x) 177 | 178 | if exists(scale_shift): 179 | scale, shift = scale_shift 180 | x = x * (scale + 1) + shift 181 | 182 | x = self.act(x) 183 | return x 184 | 185 | class ResnetBlock(nn.Module): 186 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): 187 | super().__init__() 188 | self.mlp = nn.Sequential( 189 | nn.SiLU(), 190 | nn.Linear(time_emb_dim, dim_out * 2) 191 | ) if exists(time_emb_dim) else None 192 | 193 | self.block1 = Block(dim, dim_out, groups = groups) 194 | self.block2 = Block(dim_out, dim_out, groups = groups) 195 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 196 | 197 | def forward(self, x, time_emb = None): 198 | 199 | scale_shift = None 200 | if exists(self.mlp) and exists(time_emb): 201 | time_emb = self.mlp(time_emb) 202 | time_emb = rearrange(time_emb, 'b c -> b c 1 1') 203 | scale_shift = time_emb.chunk(2, dim = 1) 204 | 205 | h = self.block1(x, scale_shift = scale_shift) 206 | 207 | h = self.block2(h) 208 | 209 | return h + self.res_conv(x) 210 | 
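# a minimal shape sketch for the blocks above (illustrative values only, assuming the module context of this file):
# a ResnetBlock maps (batch, dim, height, width) to (batch, dim_out, height, width), injecting the time
# embedding as a per-channel scale and shift after the first convolution's group norm, e.g.
#
#   block = ResnetBlock(64, 128, time_emb_dim = 256)
#   x = torch.randn(2, 64, 32, 32)    # feature map
#   t_emb = torch.randn(2, 256)       # output of time_mlp for a batch of timesteps
#   out = block(x, t_emb)             # (2, 128, 32, 32)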
211 | class LinearAttention(nn.Module): 212 | def __init__(self, dim, heads = 4, dim_head = 32): 213 | super().__init__() 214 | self.scale = dim_head ** -0.5 215 | self.heads = heads 216 | hidden_dim = dim_head * heads 217 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 218 | 219 | self.to_out = nn.Sequential( 220 | nn.Conv2d(hidden_dim, dim, 1), 221 | LayerNorm(dim) 222 | ) 223 | 224 | def forward(self, x): 225 | b, c, h, w = x.shape 226 | qkv = self.to_qkv(x).chunk(3, dim = 1) 227 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) 228 | 229 | q = q.softmax(dim = -2) 230 | k = k.softmax(dim = -1) 231 | 232 | q = q * self.scale 233 | v = v / (h * w) 234 | 235 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 236 | 237 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 238 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) 239 | return self.to_out(out) 240 | 241 | class Attention(nn.Module): 242 | def __init__(self, dim, heads = 4, dim_head = 32): 243 | super().__init__() 244 | self.scale = dim_head ** -0.5 245 | self.heads = heads 246 | hidden_dim = dim_head * heads 247 | 248 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 249 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 250 | 251 | def forward(self, x): 252 | b, c, h, w = x.shape 253 | qkv = self.to_qkv(x).chunk(3, dim = 1) 254 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) 255 | 256 | q = q * self.scale 257 | 258 | sim = einsum('b h d i, b h d j -> b h i j', q, k) 259 | attn = sim.softmax(dim = -1) 260 | out = einsum('b h i j, b h d j -> b h i d', attn, v) 261 | 262 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) 263 | return self.to_out(out) 264 | 265 | # model 266 | 267 | class Unet(nn.Module): 268 | def __init__( 269 | self, 270 | dim, 271 | init_dim = None, 272 | out_dim = None, 273 | dim_mults=(1, 2, 4, 8), 274 | channels = 3, 275 | self_condition = False, 276 | resnet_block_groups = 8, 277 | learned_variance = False, 278 | learned_sinusoidal_cond = False, 279 | random_fourier_features = False, 280 | learned_sinusoidal_dim = 16 281 | ): 282 | super().__init__() 283 | 284 | # determine dimensions 285 | 286 | self.channels = channels 287 | self.self_condition = self_condition 288 | input_channels = channels * (2 if self_condition else 1) 289 | 290 | init_dim = default(init_dim, dim) 291 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) 292 | 293 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 294 | in_out = list(zip(dims[:-1], dims[1:])) 295 | 296 | block_klass = partial(ResnetBlock, groups = resnet_block_groups) 297 | 298 | # time embeddings 299 | 300 | time_dim = dim * 4 301 | 302 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features 303 | 304 | if self.random_or_learned_sinusoidal_cond: 305 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) 306 | fourier_dim = learned_sinusoidal_dim + 1 307 | else: 308 | sinu_pos_emb = SinusoidalPosEmb(dim) 309 | fourier_dim = dim 310 | 311 | self.time_mlp = nn.Sequential( 312 | sinu_pos_emb, 313 | nn.Linear(fourier_dim, time_dim), 314 | nn.GELU(), 315 | nn.Linear(time_dim, time_dim) 316 | ) 317 | 318 | # layers 319 | 320 | self.downs = nn.ModuleList([]) 321 | self.ups = nn.ModuleList([]) 322 | num_resolutions = len(in_out) 323 | 324 | for ind, (dim_in, dim_out) in enumerate(in_out): 325 | 
is_last = ind >= (num_resolutions - 1) 326 | 327 | self.downs.append(nn.ModuleList([ 328 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 329 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 330 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 331 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) 332 | ])) 333 | 334 | mid_dim = dims[-1] 335 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 336 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 337 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) 338 | 339 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 340 | is_last = ind == (len(in_out) - 1) 341 | 342 | self.ups.append(nn.ModuleList([ 343 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 344 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 345 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 346 | Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) 347 | ])) 348 | 349 | default_out_dim = channels * (1 if not learned_variance else 2) 350 | self.out_dim = default(out_dim, default_out_dim) 351 | 352 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) 353 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 354 | 355 | def forward(self, x, time, x_self_cond = None): 356 | if self.self_condition: 357 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 358 | x = torch.cat((x_self_cond, x), dim = 1) 359 | 360 | x = self.init_conv(x) 361 | r = x.clone() 362 | 363 | t = self.time_mlp(time) 364 | 365 | h = [] 366 | 367 | for block1, block2, attn, downsample in self.downs: 368 | x = block1(x, t) 369 | h.append(x) 370 | 371 | x = block2(x, t) 372 | x = attn(x) 373 | h.append(x) 374 | 375 | x = downsample(x) 376 | 377 | x = self.mid_block1(x, t) 378 | x = self.mid_attn(x) 379 | x = self.mid_block2(x, t) 380 | 381 | for block1, block2, attn, upsample in self.ups: 382 | x = torch.cat((x, h.pop()), dim = 1) 383 | x = block1(x, t) 384 | 385 | x = torch.cat((x, h.pop()), dim = 1) 386 | x = block2(x, t) 387 | x = attn(x) 388 | 389 | x = upsample(x) 390 | 391 | x = torch.cat((x, r), dim = 1) 392 | 393 | x = self.final_res_block(x, t) 394 | return self.final_conv(x) 395 | 396 | # gaussian diffusion trainer class 397 | 398 | def extract(a, t, x_shape): 399 | b, *_ = t.shape 400 | out = a.gather(-1, t) 401 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 402 | 403 | def linear_beta_schedule(timesteps): 404 | """ 405 | linear schedule, proposed in original ddpm paper 406 | """ 407 | scale = 1000 / timesteps 408 | beta_start = scale * 0.0001 409 | beta_end = scale * 0.02 410 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 411 | 412 | def cosine_beta_schedule(timesteps, s = 0.008): 413 | """ 414 | cosine schedule 415 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 416 | """ 417 | steps = timesteps + 1 418 | t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps 419 | alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 420 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 421 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 422 | return torch.clip(betas, 0, 0.999) 423 | 424 | def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5): 425 | """ 426 | sigmoid schedule 427 | proposed in https://arxiv.org/abs/2212.11972 - Figure 8 428 | better for 
images > 64x64, when used during training 429 | """ 430 | steps = timesteps + 1 431 | t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps 432 | v_start = torch.tensor(start / tau).sigmoid() 433 | v_end = torch.tensor(end / tau).sigmoid() 434 | alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) 435 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 436 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 437 | return torch.clip(betas, 0, 0.999) 438 | 439 | class GaussianDiffusion(nn.Module): 440 | def __init__( 441 | self, 442 | model, 443 | *, 444 | image_size, 445 | timesteps = 1000, 446 | sampling_timesteps = None, 447 | loss_type = 'l1', 448 | objective = 'pred_noise', 449 | beta_schedule = 'sigmoid', 450 | schedule_fn_kwargs = dict(), 451 | p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended 452 | p2_loss_weight_k = 1, 453 | ddim_sampling_eta = 0., 454 | auto_normalize = True 455 | ): 456 | super().__init__() 457 | assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) 458 | assert not model.random_or_learned_sinusoidal_cond 459 | 460 | self.model = model 461 | self.channels = self.model.channels 462 | self.self_condition = self.model.self_condition 463 | 464 | self.image_size = image_size 465 | 466 | self.objective = objective 467 | 468 | assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' 469 | 470 | if beta_schedule == 'linear': 471 | beta_schedule_fn = linear_beta_schedule 472 | elif beta_schedule == 'cosine': 473 | beta_schedule_fn = cosine_beta_schedule 474 | elif beta_schedule == 'sigmoid': 475 | beta_schedule_fn = sigmoid_beta_schedule 476 | else: 477 | raise ValueError(f'unknown beta schedule {beta_schedule}') 478 | 479 | betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs) 480 | 481 | alphas = 1. - betas 482 | alphas_cumprod = torch.cumprod(alphas, dim=0) 483 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) 484 | 485 | timesteps, = betas.shape 486 | self.num_timesteps = int(timesteps) 487 | self.loss_type = loss_type 488 | 489 | # sampling related parameters 490 | 491 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training 492 | 493 | assert self.sampling_timesteps <= timesteps 494 | self.is_ddim_sampling = self.sampling_timesteps < timesteps 495 | self.ddim_sampling_eta = ddim_sampling_eta 496 | 497 | # helper function to register buffer from float64 to float32 498 | 499 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 500 | 501 | register_buffer('betas', betas) 502 | register_buffer('alphas_cumprod', alphas_cumprod) 503 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 504 | 505 | # calculations for diffusion q(x_t | x_{t-1}) and others 506 | 507 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 508 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 509 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 510 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. 
/ alphas_cumprod)) 511 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 512 | 513 | # calculations for posterior q(x_{t-1} | x_t, x_0) 514 | 515 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 516 | 517 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 518 | 519 | register_buffer('posterior_variance', posterior_variance) 520 | 521 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 522 | 523 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) 524 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 525 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) 526 | 527 | # calculate p2 reweighting 528 | 529 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) 530 | 531 | # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False 532 | 533 | self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity 534 | self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity 535 | 536 | def predict_start_from_noise(self, x_t, t, noise): 537 | return ( 538 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 539 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 540 | ) 541 | 542 | def predict_noise_from_start(self, x_t, t, x0): 543 | return ( 544 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ 545 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 546 | ) 547 | 548 | def predict_v(self, x_start, t, noise): 549 | return ( 550 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - 551 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start 552 | ) 553 | 554 | def predict_start_from_v(self, x_t, t, v): 555 | return ( 556 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - 557 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 558 | ) 559 | 560 | def q_posterior(self, x_start, x_t, t): 561 | posterior_mean = ( 562 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 563 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 564 | ) 565 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 566 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 567 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 568 | 569 | def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): 570 | model_output = self.model(x, t, x_self_cond) 571 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) 
if clip_x_start else identity 572 | 573 | if self.objective == 'pred_noise': 574 | pred_noise = model_output 575 | x_start = self.predict_start_from_noise(x, t, pred_noise) 576 | x_start = maybe_clip(x_start) 577 | 578 | elif self.objective == 'pred_x0': 579 | x_start = model_output 580 | x_start = maybe_clip(x_start) 581 | pred_noise = self.predict_noise_from_start(x, t, x_start) 582 | 583 | elif self.objective == 'pred_v': 584 | v = model_output 585 | x_start = self.predict_start_from_v(x, t, v) 586 | x_start = maybe_clip(x_start) 587 | pred_noise = self.predict_noise_from_start(x, t, x_start) 588 | 589 | return ModelPrediction(pred_noise, x_start) 590 | 591 | def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): 592 | preds = self.model_predictions(x, t, x_self_cond) 593 | x_start = preds.pred_x_start 594 | 595 | if clip_denoised: 596 | x_start.clamp_(-1., 1.) 597 | 598 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) 599 | return model_mean, posterior_variance, posterior_log_variance, x_start 600 | 601 | @torch.no_grad() 602 | def p_sample(self, x, t: int, x_self_cond = None): 603 | b, *_, device = *x.shape, x.device 604 | batched_times = torch.full((b,), t, device = x.device, dtype = torch.long) 605 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True) 606 | noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0 607 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise 608 | return pred_img, x_start 609 | 610 | @torch.no_grad() 611 | def p_sample_loop(self, shape, return_all_timesteps = False): 612 | batch, device = shape[0], self.betas.device 613 | 614 | img = torch.randn(shape, device = device) 615 | imgs = [img] 616 | 617 | x_start = None 618 | 619 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): 620 | self_cond = x_start if self.self_condition else None 621 | img, x_start = self.p_sample(img, t, self_cond) 622 | imgs.append(img) 623 | 624 | ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) 625 | 626 | ret = self.unnormalize(ret) 627 | return ret 628 | 629 | @torch.no_grad() 630 | def ddim_sample(self, shape, return_all_timesteps = False): 631 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective 632 | 633 | times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps 634 | times = list(reversed(times.int().tolist())) 635 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] 636 | 637 | img = torch.randn(shape, device = device) 638 | imgs = [] # the initial noise is recorded on the first pass through the loop 639 | 640 | x_start = None 641 | 642 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): 643 | time_cond = torch.full((batch,), time, device = device, dtype = torch.long) 644 | self_cond = x_start if self.self_condition else None 645 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True) 646 | 647 | imgs.append(img) # record the image entering this step 648 | 649 | if time_next < 0: 650 | img = x_start 651 | imgs.append(img) # record the final denoised image 652 | continue 653 | alpha = self.alphas_cumprod[time] 654 | alpha_next = self.alphas_cumprod[time_next] 655 | 656 | sigma = eta * ((1 - alpha / alpha_next) * (1 - 
alpha_next) / (1 - alpha)).sqrt() 657 | c = (1 - alpha_next - sigma ** 2).sqrt() 658 | 659 | noise = torch.randn_like(img) 660 | 661 | img = x_start * alpha_next.sqrt() + \ 662 | c * pred_noise + \ 663 | sigma * noise 664 | 665 | ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) 666 | 667 | ret = self.unnormalize(ret) 668 | return ret 669 | 670 | @torch.no_grad() 671 | def sample(self, batch_size = 16, return_all_timesteps = False): 672 | image_size, channels = self.image_size, self.channels 673 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample 674 | return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps) 675 | 676 | @torch.no_grad() 677 | def interpolate(self, x1, x2, t = None, lam = 0.5): 678 | b, *_, device = *x1.shape, x1.device 679 | t = default(t, self.num_timesteps - 1) 680 | 681 | assert x1.shape == x2.shape 682 | 683 | t_batched = torch.full((b,), t, device = device) 684 | xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) 685 | 686 | img = (1 - lam) * xt1 + lam * xt2 687 | 688 | x_start = None 689 | 690 | for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): 691 | self_cond = x_start if self.self_condition else None 692 | img, x_start = self.p_sample(img, i, self_cond) 693 | 694 | return img 695 | 696 | def q_sample(self, x_start, t, noise=None): 697 | noise = default(noise, lambda: torch.randn_like(x_start)) 698 | 699 | return ( 700 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 701 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 702 | ) 703 | 704 | @property 705 | def loss_fn(self): 706 | if self.loss_type == 'l1': 707 | return F.l1_loss 708 | elif self.loss_type == 'l2': 709 | return F.mse_loss 710 | else: 711 | raise ValueError(f'invalid loss type {self.loss_type}') 712 | 713 | def p_losses(self, x_start, t, noise = None): 714 | b, c, h, w = x_start.shape 715 | noise = default(noise, lambda: torch.randn_like(x_start)) 716 | 717 | # noise sample 718 | 719 | x = self.q_sample(x_start = x_start, t = t, noise = noise) 720 | 721 | # if doing self-conditioning, 50% of the time, predict x_start from current set of times 722 | # and condition with unet with that 723 | # this technique will slow down training by 25%, but seems to lower FID significantly 724 | 725 | x_self_cond = None 726 | if self.self_condition and random() < 0.5: 727 | with torch.no_grad(): 728 | x_self_cond = self.model_predictions(x, t).pred_x_start 729 | x_self_cond.detach_() 730 | 731 | # predict and take gradient step 732 | 733 | model_out = self.model(x, t, x_self_cond) 734 | 735 | if self.objective == 'pred_noise': 736 | target = noise 737 | elif self.objective == 'pred_x0': 738 | target = x_start 739 | elif self.objective == 'pred_v': 740 | v = self.predict_v(x_start, t, noise) 741 | target = v 742 | else: 743 | raise ValueError(f'unknown objective {self.objective}') 744 | 745 | loss = self.loss_fn(model_out, target, reduction = 'none') 746 | loss = reduce(loss, 'b ... 
-> b (...)', 'mean') 747 | 748 | loss = loss * extract(self.p2_loss_weight, t, loss.shape) 749 | return loss.mean() 750 | 751 | def forward(self, img, *args, **kwargs): 752 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size 753 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}' 754 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long() 755 | 756 | img = self.normalize(img) 757 | return self.p_losses(img, t, *args, **kwargs) 758 | 759 | # dataset classes 760 | 761 | class Dataset(Dataset): 762 | def __init__( 763 | self, 764 | folder, 765 | image_size, 766 | exts = ['jpg', 'jpeg', 'png', 'tiff'], 767 | augment_horizontal_flip = False, 768 | convert_image_to = None 769 | ): 770 | super().__init__() 771 | self.folder = folder 772 | self.image_size = image_size 773 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 774 | 775 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity() 776 | 777 | self.transform = T.Compose([ 778 | T.Lambda(maybe_convert_fn), 779 | T.Resize(image_size), 780 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(), 781 | T.CenterCrop(image_size), 782 | T.ToTensor() 783 | ]) 784 | 785 | def __len__(self): 786 | return len(self.paths) 787 | 788 | def __getitem__(self, index): 789 | path = self.paths[index] 790 | img = Image.open(path) 791 | return self.transform(img) 792 | 793 | # trainer class 794 | 795 | class Trainer(object): 796 | def __init__( 797 | self, 798 | diffusion_model, 799 | folder, 800 | *, 801 | train_batch_size = 16, 802 | gradient_accumulate_every = 1, 803 | augment_horizontal_flip = True, 804 | train_lr = 1e-4, 805 | train_num_steps = 100000, 806 | ema_update_every = 10, 807 | ema_decay = 0.995, 808 | adam_betas = (0.9, 0.99), 809 | save_and_sample_every = 1000, 810 | num_samples = 25, 811 | results_folder = './results', 812 | amp = False, 813 | fp16 = False, 814 | split_batches = True, 815 | convert_image_to = None 816 | ): 817 | super().__init__() 818 | 819 | self.accelerator = Accelerator( 820 | split_batches = split_batches, 821 | mixed_precision = 'fp16' if fp16 else 'no' 822 | ) 823 | 824 | self.accelerator.native_amp = amp 825 | 826 | self.model = diffusion_model 827 | 828 | assert has_int_squareroot(num_samples), 'number of samples must have an integer square root' 829 | self.num_samples = num_samples 830 | self.save_and_sample_every = save_and_sample_every 831 | 832 | self.batch_size = train_batch_size 833 | self.gradient_accumulate_every = gradient_accumulate_every 834 | 835 | self.train_num_steps = train_num_steps 836 | self.image_size = diffusion_model.image_size 837 | 838 | # dataset and dataloader 839 | 840 | self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to) 841 | dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count()) 842 | 843 | dl = self.accelerator.prepare(dl) 844 | self.dl = cycle(dl) 845 | 846 | # optimizer 847 | 848 | self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) 849 | 850 | # for logging results in a folder periodically 851 | 852 | if self.accelerator.is_main_process: 853 | self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) 854 | 855 | self.results_folder = Path(results_folder) 856 | self.results_folder.mkdir(exist_ok = True) 857 | 858 | # 
step counter state 859 | 860 | self.step = 0 861 | 862 | # prepare model, dataloader, optimizer with accelerator 863 | 864 | self.model, self.opt = self.accelerator.prepare(self.model, self.opt) 865 | 866 | def save(self, milestone): 867 | if not self.accelerator.is_local_main_process: 868 | return 869 | 870 | data = { 871 | 'step': self.step, 872 | 'model': self.accelerator.get_state_dict(self.model), 873 | 'opt': self.opt.state_dict(), 874 | 'ema': self.ema.state_dict(), 875 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, 876 | 'version': __version__ 877 | } 878 | 879 | torch.save(data, str(self.results_folder / f'model-{milestone}.pt')) 880 | 881 | def load(self, milestone): 882 | accelerator = self.accelerator 883 | device = accelerator.device 884 | 885 | data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device) 886 | 887 | model = self.accelerator.unwrap_model(self.model) 888 | model.load_state_dict(data['model']) 889 | 890 | self.step = data['step'] 891 | self.opt.load_state_dict(data['opt']) 892 | self.ema.load_state_dict(data['ema']) 893 | 894 | if 'version' in data: 895 | print(f"loading from version {data['version']}") 896 | 897 | if exists(self.accelerator.scaler) and exists(data['scaler']): 898 | self.accelerator.scaler.load_state_dict(data['scaler']) 899 | 900 | def train(self): 901 | accelerator = self.accelerator 902 | device = accelerator.device 903 | 904 | with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: 905 | 906 | while self.step < self.train_num_steps: 907 | 908 | total_loss = 0. 909 | 910 | for _ in range(self.gradient_accumulate_every): 911 | data = next(self.dl).to(device) 912 | 913 | with self.accelerator.autocast(): 914 | loss = self.model(data) 915 | loss = loss / self.gradient_accumulate_every 916 | total_loss += loss.item() 917 | 918 | self.accelerator.backward(loss) 919 | 920 | accelerator.clip_grad_norm_(self.model.parameters(), 1.0) 921 | pbar.set_description(f'loss: {total_loss:.4f}') 922 | 923 | accelerator.wait_for_everyone() 924 | 925 | self.opt.step() 926 | self.opt.zero_grad() 927 | 928 | accelerator.wait_for_everyone() 929 | 930 | self.step += 1 931 | if accelerator.is_main_process: 932 | self.ema.to(device) 933 | self.ema.update() 934 | 935 | if self.step != 0 and self.step % self.save_and_sample_every == 0: 936 | self.ema.ema_model.eval() 937 | 938 | with torch.no_grad(): 939 | milestone = self.step // self.save_and_sample_every 940 | batches = num_to_groups(self.num_samples, self.batch_size) 941 | all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches)) 942 | 943 | all_images = torch.cat(all_images_list, dim = 0) 944 | utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples))) 945 | self.save(milestone) 946 | 947 | pbar.update(1) 948 | 949 | accelerator.print('training complete') 950 | --------------------------------------------------------------------------------
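For reference, a minimal end-to-end usage sketch of the Unet, GaussianDiffusion and Trainer classes defined in denoising_diffusion_pytorch.py above (the image folder path and all hyperparameters below are placeholders, not recommendations):

import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # T used at training time
    sampling_timesteps = 250    # < T, so sample() dispatches to ddim_sample
)

# a single training step by hand: images are expected in [0, 1];
# forward() normalizes to [-1, 1], draws a random t per sample and returns the loss
training_images = torch.rand(8, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()

# or drive the full loop (dataset, EMA, checkpointing) with the Trainer
trainer = Trainer(
    diffusion,
    'path/to/your/images',      # folder of jpg / jpeg / png / tiff files
    train_batch_size = 32,
    train_lr = 1e-4,
    train_num_steps = 700000
)
trainer.train()

# sampling: returns images in [0, 1], here of shape (4, 3, 128, 128)
sampled_images = diffusion.sample(batch_size = 4)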