├── denoising_diffusion_pytorch
│   ├── version.py
│   ├── __init__.py
│   ├── weighted_objective_gaussian_diffusion.py
│   ├── learned_gaussian_diffusion.py
│   ├── v_param_continuous_time_gaussian_diffusion.py
│   ├── elucidated_diffusion.py
│   ├── continuous_time_gaussian_diffusion.py
│   ├── denoising_diffusion_pytorch_1d.py
│   ├── classifier_free_guidance.py
│   └── denoising_diffusion_pytorch.py
├── images
│   ├── sample.png
│   └── denoising-diffusion.png
├── .github
│   └── workflows
│       └── python-publish.yml
├── setup.py
├── LICENSE
├── .gitignore
└── README.md
/denoising_diffusion_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.0.6'
2 |
--------------------------------------------------------------------------------
/images/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/denoising-diffusion-pytorch/main/images/sample.png
--------------------------------------------------------------------------------
/images/denoising-diffusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/denoising-diffusion-pytorch/main/images/denoising-diffusion.png
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
2 |
3 | from denoising_diffusion_pytorch.learned_gaussian_diffusion import LearnedGaussianDiffusion
4 | from denoising_diffusion_pytorch.continuous_time_gaussian_diffusion import ContinuousTimeGaussianDiffusion
5 | from denoising_diffusion_pytorch.weighted_objective_gaussian_diffusion import WeightedObjectiveGaussianDiffusion
6 | from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion
7 | from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion
8 |
9 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D
10 |
11 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | exec(open('denoising_diffusion_pytorch/version.py').read())
4 |
5 | setup(
6 | name = 'denoising-diffusion-pytorch',
7 | packages = find_packages(),
8 | version = __version__,
9 | license='MIT',
10 | description = 'Denoising Diffusion Probabilistic Models - Pytorch',
11 | author = 'Phil Wang',
12 | author_email = 'lucidrains@gmail.com',
13 | url = 'https://github.com/lucidrains/denoising-diffusion-pytorch',
14 | long_description_content_type = 'text/markdown',
15 | keywords = [
16 | 'artificial intelligence',
17 | 'generative models'
18 | ],
19 | install_requires=[
20 | 'accelerate',
21 | 'einops',
22 | 'ema-pytorch',
23 | 'pillow',
24 | 'torch',
25 | 'torchvision',
26 | 'tqdm'
27 | ],
28 | classifiers=[
29 | 'Development Status :: 4 - Beta',
30 | 'Intended Audience :: Developers',
31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
32 | 'License :: OSI Approved :: MIT License',
33 | 'Programming Language :: Python :: 3.6',
34 | ],
35 | )
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generation results
2 | results/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/weighted_objective_gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from inspect import isfunction
3 | from torch import nn, einsum
4 | from einops import rearrange
5 |
6 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion
7 |
8 | # helper functions
9 |
10 | def exists(x):
11 | return x is not None
12 |
13 | def default(val, d):
14 | if exists(val):
15 | return val
16 | return d() if isfunction(d) else d
17 |
18 | # some improvisation on my end
19 | # where i have the model learn to both predict noise and x0
20 | # and learn the weighted sum for each depending on time step
21 |
22 | class WeightedObjectiveGaussianDiffusion(GaussianDiffusion):
23 | def __init__(
24 | self,
25 | model,
26 | *args,
27 | pred_noise_loss_weight = 0.1,
28 | pred_x_start_loss_weight = 0.1,
29 | **kwargs
30 | ):
31 | super().__init__(model, *args, **kwargs)
32 | channels = model.channels
33 | assert model.out_dim == (channels * 2 + 2), 'dimension out (out_dim) of unet must be twice the number of channels + 2 (for the softmax weighted sum) - for channels of 3, this should be (3 * 2) + 2 = 8'
34 | assert not model.self_condition, 'not supported yet'
35 | assert not self.is_ddim_sampling, 'ddim sampling cannot be used'
36 |
37 | self.split_dims = (channels, channels, 2)
38 | self.pred_noise_loss_weight = pred_noise_loss_weight
39 | self.pred_x_start_loss_weight = pred_x_start_loss_weight
40 |
41 | def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
42 | model_output = default(model_output, lambda: self.model(x, t))
43 |
44 | pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)
45 | normalized_weights = weights.softmax(dim = 1)
46 |
47 | x_start_from_noise = self.predict_start_from_noise(x, t = t, noise = pred_noise)
48 |
49 | x_starts = torch.stack((x_start_from_noise, pred_x_start), dim = 1)
50 | weighted_x_start = einsum('b j h w, b j c h w -> b c h w', normalized_weights, x_starts)
51 |
52 | if clip_denoised:
53 | weighted_x_start.clamp_(-1., 1.)
54 |
55 | model_mean, model_variance, model_log_variance = self.q_posterior(weighted_x_start, x, t)
56 |
57 | return model_mean, model_variance, model_log_variance
58 |
59 | def p_losses(self, x_start, t, noise = None, clip_denoised = False):
60 | noise = default(noise, lambda: torch.randn_like(x_start))
61 | x_t = self.q_sample(x_start = x_start, t = t, noise = noise)
62 |
63 | model_output = self.model(x_t, t)
64 | pred_noise, pred_x_start, weights = model_output.split(self.split_dims, dim = 1)
65 |
66 | # get loss for predicted noise and x_start
67 | # with the loss weight given at initialization
68 |
69 | noise_loss = self.loss_fn(noise, pred_noise) * self.pred_noise_loss_weight
70 | x_start_loss = self.loss_fn(x_start, pred_x_start) * self.pred_x_start_loss_weight
71 |
72 | # calculate x_start from predicted noise
73 | # then do a weighted sum of the x_start prediction, weights also predicted by the model (softmax normalized)
74 |
75 | x_start_from_pred_noise = self.predict_start_from_noise(x_t, t, pred_noise)
76 | x_start_from_pred_noise = x_start_from_pred_noise.clamp(-2., 2.)
77 | weighted_x_start = einsum('b j h w, b j c h w -> b c h w', weights.softmax(dim = 1), torch.stack((x_start_from_pred_noise, pred_x_start), dim = 1))
78 |
79 | # main loss to x_start with the weighted one
80 |
81 | weighted_x_start_loss = self.loss_fn(x_start, weighted_x_start)
82 | return weighted_x_start_loss + x_start_loss + noise_loss
83 |
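# a hedged usage sketch, not part of the module: `Unet` and its keyword arguments
# mirror the README, and `out_dim = 8` assumes 3 image channels, matching the assert
# in __init__ above ((3 * 2) + 2 = 8)
#
#   from denoising_diffusion_pytorch import Unet
#
#   model = Unet(dim = 64, dim_mults = (1, 2, 4, 8), out_dim = 8)
#   diffusion = WeightedObjectiveGaussianDiffusion(model, image_size = 128, timesteps = 1000)
#
#   loss = diffusion(torch.rand(8, 3, 128, 128))  # images normalized to [0, 1]
#   loss.backward()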
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/learned_gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import namedtuple
3 | from math import pi, sqrt, log as ln
4 | from inspect import isfunction
5 | from torch import nn, einsum
6 | from einops import rearrange
7 |
8 | from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one
9 |
10 | # constants
11 |
12 | NAT = 1. / ln(2)
13 |
14 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])
15 |
16 | # helper functions
17 |
18 | def exists(x):
19 | return x is not None
20 |
21 | def default(val, d):
22 | if exists(val):
23 | return val
24 | return d() if isfunction(d) else d
25 |
26 | # tensor helpers
27 |
28 | def log(t, eps = 1e-15):
29 | return torch.log(t.clamp(min = eps))
30 |
31 | def meanflat(x):
32 | return x.mean(dim = tuple(range(1, len(x.shape))))
33 |
34 | def normal_kl(mean1, logvar1, mean2, logvar2):
35 | """
36 | KL divergence between normal distributions parameterized by mean and log-variance.
37 | """
38 | return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
39 |
40 | def approx_standard_normal_cdf(x):
41 | return 0.5 * (1.0 + torch.tanh(sqrt(2.0 / pi) * (x + 0.044715 * (x ** 3))))
42 |
43 | def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
44 | assert x.shape == means.shape == log_scales.shape
45 |
46 | centered_x = x - means
47 | inv_stdv = torch.exp(-log_scales)
48 | plus_in = inv_stdv * (centered_x + 1. / 255.)
49 | cdf_plus = approx_standard_normal_cdf(plus_in)
50 | min_in = inv_stdv * (centered_x - 1. / 255.)
51 | cdf_min = approx_standard_normal_cdf(min_in)
52 | log_cdf_plus = log(cdf_plus)
53 | log_one_minus_cdf_min = log(1. - cdf_min)
54 | cdf_delta = cdf_plus - cdf_min
55 |
56 | log_probs = torch.where(x < -thres,
57 | log_cdf_plus,
58 | torch.where(x > thres,
59 | log_one_minus_cdf_min,
60 | log(cdf_delta)))
61 |
62 | return log_probs
63 |
64 | # https://arxiv.org/abs/2102.09672
65 |
66 | # i thought the results were questionable, if one were to focus only on FID
67 | # but may as well get this in here for others to try, as GLIDE is using it (and DALL-E2 first stage of cascade)
68 | # gaussian diffusion for learned variance + hybrid eps simple + vb loss
69 |
70 | class LearnedGaussianDiffusion(GaussianDiffusion):
71 | def __init__(
72 | self,
73 | model,
74 | vb_loss_weight = 0.001, # lambda was 0.001 in the paper
75 | *args,
76 | **kwargs
77 | ):
78 | super().__init__(model, *args, **kwargs)
79 | assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
80 | assert not model.self_condition, 'not supported yet'
81 |
82 | self.vb_loss_weight = vb_loss_weight
83 |
84 | def model_predictions(self, x, t):
85 | model_output = self.model(x, t)
86 | model_output, pred_variance = model_output.chunk(2, dim = 1)
87 |
88 | if self.objective == 'pred_noise':
89 | pred_noise = model_output
90 | x_start = self.predict_start_from_noise(x, t, model_output)
91 |
92 | elif self.objective == 'pred_x0':
93 | pred_noise = self.predict_noise_from_start(x, t, model_output)
94 | x_start = model_output
95 |
96 | return ModelPrediction(pred_noise, x_start, pred_variance)
97 |
98 | def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
99 | model_output = default(model_output, lambda: self.model(x, t))
100 | pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)
101 |
102 | min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
103 | max_log = extract(torch.log(self.betas), t, x.shape)
104 | var_interp_frac = unnormalize_to_zero_to_one(var_interp_frac_unnormalized)
105 |
106 | model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
107 | model_variance = model_log_variance.exp()
108 |
109 | x_start = self.predict_start_from_noise(x, t, pred_noise)
110 |
111 | if clip_denoised:
112 | x_start.clamp_(-1., 1.)
113 |
114 | model_mean, _, _ = self.q_posterior(x_start, x, t)
115 |
116 | return model_mean, model_variance, model_log_variance
117 |
118 | def p_losses(self, x_start, t, noise = None, clip_denoised = False):
119 | noise = default(noise, lambda: torch.randn_like(x_start))
120 | x_t = self.q_sample(x_start = x_start, t = t, noise = noise)
121 |
122 | # model output
123 |
124 | model_output = self.model(x_t, t)
125 |
126 | # calculating kl loss for learned variance (interpolation)
127 |
128 | true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
129 | model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)
130 |
131 | # kl loss with detached model predicted mean, for stability reasons as in paper
132 |
133 | detached_model_mean = model_mean.detach()
134 |
135 | kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
136 | kl = meanflat(kl) * NAT
137 |
138 | decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
139 | decoder_nll = meanflat(decoder_nll) * NAT
140 |
141 | # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
142 |
143 | vb_losses = torch.where(t == 0, decoder_nll, kl)
144 |
145 | # simple loss - predicting noise, x0, or x_prev
146 |
147 | pred_noise, _ = model_output.chunk(2, dim = 1)
148 |
149 | simple_losses = self.loss_fn(pred_noise, noise)
150 |
151 | return simple_losses + vb_losses.mean() * self.vb_loss_weight
152 |
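# a hedged usage sketch, not part of the module: per the assert in __init__ above,
# the unet needs out_dim = channels * 2, which the `learned_variance = True` flag
# on the `Unet` is assumed to set
#
#   from denoising_diffusion_pytorch import Unet
#
#   model = Unet(dim = 64, dim_mults = (1, 2, 4, 8), learned_variance = True)
#   diffusion = LearnedGaussianDiffusion(model, image_size = 128, timesteps = 1000)
#
#   loss = diffusion(torch.rand(8, 3, 128, 128))  # hybrid simple + vb loss
#   loss.backward()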
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/v_param_continuous_time_gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import sqrt
4 | from torch import nn, einsum
5 | import torch.nn.functional as F
6 | from torch.special import expm1
7 |
8 | from tqdm import tqdm
9 | from einops import rearrange, repeat, reduce
10 | from einops.layers.torch import Rearrange
11 |
12 | # helpers
13 |
14 | def exists(val):
15 | return val is not None
16 |
17 | def default(val, d):
18 | if exists(val):
19 | return val
20 | return d() if callable(d) else d
21 |
22 | # normalization functions
23 |
24 | def normalize_to_neg_one_to_one(img):
25 | return img * 2 - 1
26 |
27 | def unnormalize_to_zero_to_one(t):
28 | return (t + 1) * 0.5
29 |
30 | # diffusion helpers
31 |
32 | def right_pad_dims_to(x, t):
33 | padding_dims = x.ndim - t.ndim
34 | if padding_dims <= 0:
35 | return t
36 | return t.view(*t.shape, *((1,) * padding_dims))
37 |
38 | # continuous schedules
39 | # log(snr) that approximates the original linear schedule
40 |
41 | def log(t, eps = 1e-20):
42 | return torch.log(t.clamp(min = eps))
43 |
44 | def alpha_cosine_log_snr(t, s = 0.008):
45 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5)
46 |
47 | class VParamContinuousTimeGaussianDiffusion(nn.Module):
48 | """
49 | a new type of parameterization in v-space proposed in https://arxiv.org/abs/2202.00512 that
50 | (1) allows for improved distillation over noise prediction objective and
51 | (2) noted in imagen-video to improve upsampling unets by removing the color shifting artifacts
52 | """
53 |
54 | def __init__(
55 | self,
56 | model,
57 | *,
58 | image_size,
59 | channels = 3,
60 | num_sample_steps = 500,
61 | clip_sample_denoised = True,
62 | ):
63 | super().__init__()
64 | assert model.random_or_learned_sinusoidal_cond
65 | assert not model.self_condition, 'not supported yet'
66 |
67 | self.model = model
68 |
69 | # image dimensions
70 |
71 | self.channels = channels
72 | self.image_size = image_size
73 |
74 | # continuous noise schedule related stuff
75 |
76 | self.log_snr = alpha_cosine_log_snr
77 |
78 | # sampling
79 |
80 | self.num_sample_steps = num_sample_steps
81 | self.clip_sample_denoised = clip_sample_denoised
82 |
83 | @property
84 | def device(self):
85 | return next(self.model.parameters()).device
86 |
87 | def p_mean_variance(self, x, time, time_next):
88 | # reviewer found an error in the equation in the paper (missing sigma)
89 | # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt
90 |
91 | log_snr = self.log_snr(time)
92 | log_snr_next = self.log_snr(time_next)
93 | c = -expm1(log_snr - log_snr_next)
94 |
95 | squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
96 | squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
97 |
98 | alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))
99 |
100 | batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
101 |
102 | pred_v = self.model(x, batch_log_snr)
103 |
104 | # shown in Appendix D in the paper
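# since x = alpha * x_start + sigma * noise and v = alpha * noise - sigma * x_start,
# alpha * x - sigma * v = (alpha ** 2 + sigma ** 2) * x_start = x_start, as alpha ** 2 + sigma ** 2 = 1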
105 | x_start = alpha * x - sigma * pred_v
106 |
107 | if self.clip_sample_denoised:
108 | x_start.clamp_(-1., 1.)
109 |
110 | model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)
111 |
112 | posterior_variance = squared_sigma_next * c
113 |
114 | return model_mean, posterior_variance
115 |
116 | # sampling related functions
117 |
118 | @torch.no_grad()
119 | def p_sample(self, x, time, time_next):
120 | batch, *_, device = *x.shape, x.device
121 |
122 | model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next)
123 |
124 | if time_next == 0:
125 | return model_mean
126 |
127 | noise = torch.randn_like(x)
128 | return model_mean + sqrt(model_variance) * noise
129 |
130 | @torch.no_grad()
131 | def p_sample_loop(self, shape):
132 | batch = shape[0]
133 |
134 | img = torch.randn(shape, device = self.device)
135 | steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device)
136 |
137 | for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps):
138 | times = steps[i]
139 | times_next = steps[i + 1]
140 | img = self.p_sample(img, times, times_next)
141 |
142 | img.clamp_(-1., 1.)
143 | img = unnormalize_to_zero_to_one(img)
144 | return img
145 |
146 | @torch.no_grad()
147 | def sample(self, batch_size = 16):
148 | return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size))
149 |
150 | # training related functions - noise prediction
151 |
152 | def q_sample(self, x_start, times, noise = None):
153 | noise = default(noise, lambda: torch.randn_like(x_start))
154 |
155 | log_snr = self.log_snr(times)
156 |
157 | log_snr_padded = right_pad_dims_to(x_start, log_snr)
158 | alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid())
159 | x_noised = x_start * alpha + noise * sigma
160 |
161 | return x_noised, log_snr, alpha, sigma
162 |
163 | def random_times(self, batch_size):
164 | return torch.zeros((batch_size,), device = self.device).float().uniform_(0, 1)
165 |
166 | def p_losses(self, x_start, times, noise = None):
167 | noise = default(noise, lambda: torch.randn_like(x_start))
168 |
169 | x, log_snr, alpha, sigma = self.q_sample(x_start = x_start, times = times, noise = noise)
170 |
171 | # described in section 4 as the prediction objective, with derivation in Appendix D
172 | v = alpha * noise - sigma * x_start
173 |
174 | model_out = self.model(x, log_snr)
175 |
176 | return F.mse_loss(model_out, v)
177 |
178 | def forward(self, img, *args, **kwargs):
179 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
180 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
181 |
182 | times = self.random_times(b)
183 | img = normalize_to_neg_one_to_one(img)
184 | return self.p_losses(img, times, *args, **kwargs)
185 |
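# a hedged usage sketch, not part of the module: `learned_sinusoidal_cond = True`
# is assumed to be the `Unet` flag that satisfies the
# `random_or_learned_sinusoidal_cond` assert in __init__
#
#   from denoising_diffusion_pytorch import Unet
#
#   model = Unet(dim = 64, dim_mults = (1, 2, 4, 8), learned_sinusoidal_cond = True)
#   diffusion = VParamContinuousTimeGaussianDiffusion(model, image_size = 128)
#
#   loss = diffusion(torch.rand(8, 3, 128, 128))  # mse against the v target above
#   loss.backward()
#
#   sampled = diffusion.sample(batch_size = 4)  # (4, 3, 128, 128)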
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Denoising Diffusion Probabilistic Model, in Pytorch
4 |
5 | Implementation of Denoising Diffusion Probabilistic Model in Pytorch. It is a new approach to generative modeling that may have the potential to rival GANs. It uses denoising score matching to estimate the gradient of the data distribution, followed by Langevin sampling to sample from the true distribution.
6 |
7 | This implementation was transcribed from the official Tensorflow version here
8 |
9 | Youtube AI Educators - Yannic Kilcher | AI Coffeebreak with Letitia | Outlier
10 |
11 | Flax implementation from YiYi Xu
12 |
13 | Annotated code by Research Scientists / Engineers from 🤗 Huggingface
14 |
15 | Update: Turns out none of the technicalities really matters at all | "Cold Diffusion" paper | Muse
16 |
17 | 
18 |
19 | [](https://badge.fury.io/py/denoising-diffusion-pytorch)
20 |
21 | ## Install
22 |
23 | ```bash
24 | $ pip install denoising_diffusion_pytorch
25 | ```
26 |
27 | ## Usage
28 |
29 | ```python
30 | import torch
31 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion
32 |
33 | model = Unet(
34 | dim = 64,
35 | dim_mults = (1, 2, 4, 8)
36 | )
37 |
38 | diffusion = GaussianDiffusion(
39 | model,
40 | image_size = 128,
41 | timesteps = 1000, # number of steps
42 | loss_type = 'l1' # L1 or L2
43 | )
44 |
45 | training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
46 | loss = diffusion(training_images)
47 | loss.backward()
48 | # after a lot of training
49 |
50 | sampled_images = diffusion.sample(batch_size = 4)
51 | sampled_images.shape # (4, 3, 128, 128)
52 | ```
53 |
54 | Or, if you simply want to pass in a folder name and the desired image dimensions, you can use the `Trainer` class to easily train a model.
55 |
56 | ```python
57 | from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
58 |
59 | model = Unet(
60 | dim = 64,
61 | dim_mults = (1, 2, 4, 8)
62 | ).cuda()
63 |
64 | diffusion = GaussianDiffusion(
65 | model,
66 | image_size = 128,
67 | timesteps = 1000, # number of steps
68 | sampling_timesteps = 250, # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
69 | loss_type = 'l1' # L1 or L2
70 | ).cuda()
71 |
72 | trainer = Trainer(
73 | diffusion,
74 | 'path/to/your/images',
75 | train_batch_size = 32,
76 | train_lr = 8e-5,
77 | train_num_steps = 700000, # total training steps
78 | gradient_accumulate_every = 2, # gradient accumulation steps
79 | ema_decay = 0.995, # exponential moving average decay
80 | amp = True # turn on mixed precision
81 | )
82 |
83 | trainer.train()
84 | ```
85 |
86 | Samples and model checkpoints will be logged to `./results` periodically
87 |
88 | ## Multi-GPU Training
89 |
90 | The `Trainer` class is now equipped with 🤗 Accelerate. You can easily do multi-GPU training in two steps using its `accelerate` CLI
91 |
92 | At the project root directory, where the training script is, run
93 |
94 | ```bash
95 | $ accelerate config
96 | ```
97 |
98 | Then, in the same directory
99 |
100 | ```bash
101 | $ accelerate launch train.py
102 | ```
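
For reference, the `train.py` passed to `accelerate launch` can simply reuse the `Trainer` example from the previous section; a minimal sketch (the image folder path is a placeholder, and device placement is left to `accelerate`):

```python
# train.py - minimal sketch, composed only of APIs shown above
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    train_lr = 8e-5,
    amp = True
)

trainer.train()
```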
103 |
104 | ## Miscellaneous
105 |
106 | ### 1D Sequence
107 |
108 | By popular request, a 1D Unet + Gaussian Diffusion implementation. You will have to write the training code yourself (a minimal sketch follows the example below).
109 |
110 | ```python
111 | import torch
112 | from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D
113 |
114 | model = Unet1D(
115 | dim = 64,
116 | dim_mults = (1, 2, 4, 8),
117 | channels = 32
118 | )
119 |
120 | diffusion = GaussianDiffusion1D(
121 | model,
122 | seq_length = 128,
123 | timesteps = 1000,
124 | objective = 'pred_v'
125 | )
126 |
127 | training_seq = torch.rand(8, 32, 128) # features are normalized from 0 to 1
128 | loss = diffusion(training_seq)
129 | loss.backward()
130 |
131 | # after a lot of training
132 |
133 | sampled_seq = diffusion.sample(batch_size = 4)
134 | sampled_seq.shape # (4, 32, 128)
135 | ```
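
A minimal training loop sketch for the 1D case, continuing the example above (`Adam`, the learning rate, and the synthetic batches are illustrative choices, not part of the library):

```python
from torch.optim import Adam

opt = Adam(diffusion.parameters(), lr = 8e-5)

for step in range(10000):
    loss = diffusion(torch.rand(8, 32, 128))  # swap in real sequences, normalized from 0 to 1
    opt.zero_grad()
    loss.backward()
    opt.step()
```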
136 |
137 | ## Citations
138 |
139 | ```bibtex
140 | @inproceedings{NEURIPS2020_4c5bcfec,
141 | author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
142 | booktitle = {Advances in Neural Information Processing Systems},
143 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
144 | pages = {6840--6851},
145 | publisher = {Curran Associates, Inc.},
146 | title = {Denoising Diffusion Probabilistic Models},
147 | url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
148 | volume = {33},
149 | year = {2020}
150 | }
151 | ```
152 |
153 | ```bibtex
154 | @InProceedings{pmlr-v139-nichol21a,
155 | title = {Improved Denoising Diffusion Probabilistic Models},
156 | author = {Nichol, Alexander Quinn and Dhariwal, Prafulla},
157 | booktitle = {Proceedings of the 38th International Conference on Machine Learning},
158 | pages = {8162--8171},
159 | year = {2021},
160 | editor = {Meila, Marina and Zhang, Tong},
161 | volume = {139},
162 | series = {Proceedings of Machine Learning Research},
163 | month = {18--24 Jul},
164 | publisher = {PMLR},
165 | pdf = {http://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf},
166 | url = {https://proceedings.mlr.press/v139/nichol21a.html},
167 | }
168 | ```
169 |
170 | ```bibtex
171 | @inproceedings{kingma2021on,
172 | title = {On Density Estimation with Diffusion Models},
173 | author = {Diederik P Kingma and Tim Salimans and Ben Poole and Jonathan Ho},
174 | booktitle = {Advances in Neural Information Processing Systems},
175 | editor = {A. Beygelzimer and Y. Dauphin and P. Liang and J. Wortman Vaughan},
176 | year = {2021},
177 | url = {https://openreview.net/forum?id=2LdBqxc1Yv}
178 | }
179 | ```
180 |
181 | ```bibtex
182 | @article{Choi2022PerceptionPT,
183 | title = {Perception Prioritized Training of Diffusion Models},
184 |     author  = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sungroh Yoon},
185 | journal = {ArXiv},
186 | year = {2022},
187 | volume = {abs/2204.00227}
188 | }
189 | ```
190 |
191 | ```bibtex
192 | @article{Karras2022ElucidatingTD,
193 | title = {Elucidating the Design Space of Diffusion-Based Generative Models},
194 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
195 | journal = {ArXiv},
196 | year = {2022},
197 | volume = {abs/2206.00364}
198 | }
199 | ```
200 |
201 | ```bibtex
202 | @article{Song2021DenoisingDI,
203 | title = {Denoising Diffusion Implicit Models},
204 | author = {Jiaming Song and Chenlin Meng and Stefano Ermon},
205 | journal = {ArXiv},
206 | year = {2021},
207 | volume = {abs/2010.02502}
208 | }
209 | ```
210 |
211 | ```bibtex
212 | @misc{chen2022analog,
213 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
214 | author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
215 | year = {2022},
216 | eprint = {2208.04202},
217 | archivePrefix = {arXiv},
218 | primaryClass = {cs.CV}
219 | }
220 | ```
221 |
222 | ```bibtex
223 | @article{Qiao2019WeightS,
224 | title = {Weight Standardization},
225 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
226 | journal = {ArXiv},
227 | year = {2019},
228 | volume = {abs/1903.10520}
229 | }
230 | ```
231 |
232 | ```bibtex
233 | @article{Salimans2022ProgressiveDF,
234 | title = {Progressive Distillation for Fast Sampling of Diffusion Models},
235 | author = {Tim Salimans and Jonathan Ho},
236 | journal = {ArXiv},
237 | year = {2022},
238 | volume = {abs/2202.00512}
239 | }
240 | ```
241 |
242 | ```bibtex
243 | @article{Ho2022ClassifierFreeDG,
244 | title = {Classifier-Free Diffusion Guidance},
245 | author = {Jonathan Ho},
246 | journal = {ArXiv},
247 | year = {2022},
248 | volume = {abs/2207.12598}
249 | }
250 | ```
251 |
252 | ```bibtex
253 | @article{Sunkara2022NoMS,
254 | title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
255 | author = {Raja Sunkara and Tie Luo},
256 | journal = {ArXiv},
257 | year = {2022},
258 | volume = {abs/2208.03641}
259 | }
260 | ```
261 |
262 | ```bibtex
263 | @inproceedings{Jabri2022ScalableAC,
264 | title = {Scalable Adaptive Computation for Iterative Generation},
265 | author = {A. Jabri and David J. Fleet and Ting Chen},
266 | year = {2022}
267 | }
268 | ```
269 |
270 | ```bibtex
271 | @article{Cheng2022DPMSolverPlusPlus,
272 | title = {DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models},
273 | author = {Cheng Lu and Yuhao Zhou and Fan Bao and Jianfei Chen and Chongxuan Li and Jun Zhu},
274 |     journal = {NeurIPS 2022 Oral},
275 | year = {2022},
276 | volume = {abs/2211.01095}
277 | }
278 | ```
279 |
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/elucidated_diffusion.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from random import random
3 | import torch
4 | from torch import nn, einsum
5 | import torch.nn.functional as F
6 |
7 | from tqdm import tqdm
8 | from einops import rearrange, repeat, reduce
9 |
10 | # helpers
11 |
12 | def exists(val):
13 | return val is not None
14 |
15 | def default(val, d):
16 | if exists(val):
17 | return val
18 | return d() if callable(d) else d
19 |
20 | # tensor helpers
21 |
22 | def log(t, eps = 1e-20):
23 | return torch.log(t.clamp(min = eps))
24 |
25 | # normalization functions
26 |
27 | def normalize_to_neg_one_to_one(img):
28 | return img * 2 - 1
29 |
30 | def unnormalize_to_zero_to_one(t):
31 | return (t + 1) * 0.5
32 |
33 | # main class
34 |
35 | class ElucidatedDiffusion(nn.Module):
36 | def __init__(
37 | self,
38 | net,
39 | *,
40 | image_size,
41 | channels = 3,
42 | num_sample_steps = 32, # number of sampling steps
43 | sigma_min = 0.002, # min noise level
44 | sigma_max = 80, # max noise level
45 | sigma_data = 0.5, # standard deviation of data distribution
46 | rho = 7, # controls the sampling schedule
47 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
48 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
49 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in paper
50 | S_tmin = 0.05,
51 | S_tmax = 50,
52 | S_noise = 1.003,
53 | ):
54 | super().__init__()
55 | assert net.random_or_learned_sinusoidal_cond
56 | self.self_condition = net.self_condition
57 |
58 | self.net = net
59 |
60 | # image dimensions
61 |
62 | self.channels = channels
63 | self.image_size = image_size
64 |
65 | # parameters
66 |
67 | self.sigma_min = sigma_min
68 | self.sigma_max = sigma_max
69 | self.sigma_data = sigma_data
70 |
71 | self.rho = rho
72 |
73 | self.P_mean = P_mean
74 | self.P_std = P_std
75 |
76 | self.num_sample_steps = num_sample_steps # otherwise known as N in the paper
77 |
78 | self.S_churn = S_churn
79 | self.S_tmin = S_tmin
80 | self.S_tmax = S_tmax
81 | self.S_noise = S_noise
82 |
83 | @property
84 | def device(self):
85 | return next(self.net.parameters()).device
86 |
87 | # derived preconditioning params - Table 1
88 |
89 | def c_skip(self, sigma):
90 | return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)
91 |
92 | def c_out(self, sigma):
93 | return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5
94 |
95 | def c_in(self, sigma):
96 | return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5
97 |
98 | def c_noise(self, sigma):
99 | return log(sigma) * 0.25
100 |
101 | # preconditioned network output
102 | # equation (7) in the paper
103 |
104 | def preconditioned_network_forward(self, noised_images, sigma, self_cond = None, clamp = False):
105 | batch, device = noised_images.shape[0], noised_images.device
106 |
107 | if isinstance(sigma, float):
108 | sigma = torch.full((batch,), sigma, device = device)
109 |
110 | padded_sigma = rearrange(sigma, 'b -> b 1 1 1')
111 |
112 | net_out = self.net(
113 | self.c_in(padded_sigma) * noised_images,
114 | self.c_noise(sigma),
115 | self_cond
116 | )
117 |
118 | out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out
119 |
120 | if clamp:
121 | out = out.clamp(-1., 1.)
122 |
123 | return out
124 |
125 | # sampling
126 |
127 | # sample schedule
128 | # equation (5) in the paper
129 |
130 | def sample_schedule(self, num_sample_steps = None):
131 | num_sample_steps = default(num_sample_steps, self.num_sample_steps)
132 |
133 | N = num_sample_steps
134 | inv_rho = 1 / self.rho
135 |
136 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
137 | sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho
138 |
139 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
140 | return sigmas
141 |
142 | @torch.no_grad()
143 | def sample(self, batch_size = 16, num_sample_steps = None, clamp = True):
144 | num_sample_steps = default(num_sample_steps, self.num_sample_steps)
145 |
146 | shape = (batch_size, self.channels, self.image_size, self.image_size)
147 |
148 | # get the sigma schedule, compute the gamma (churn amount) for each step, and pair each sigma with the next sigma and its gamma
149 |
150 | sigmas = self.sample_schedule(num_sample_steps)
151 |
152 | gammas = torch.where(
153 | (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
154 | min(self.S_churn / num_sample_steps, sqrt(2) - 1),
155 | 0.
156 | )
157 |
158 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
159 |
160 | # images is noise at the beginning
161 |
162 | init_sigma = sigmas[0]
163 |
164 | images = init_sigma * torch.randn(shape, device = self.device)
165 |
166 | # for self conditioning
167 |
168 | x_start = None
169 |
170 | # gradually denoise
171 |
172 | for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
173 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
174 |
175 | eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
176 |
177 | sigma_hat = sigma + gamma * sigma
178 | images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps
179 |
180 | self_cond = x_start if self.self_condition else None
181 |
182 | model_output = self.preconditioned_network_forward(images_hat, sigma_hat, self_cond, clamp = clamp)
183 | denoised_over_sigma = (images_hat - model_output) / sigma_hat
184 |
185 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma
186 |
187 | # second order correction, if not the last timestep
188 |
189 | if sigma_next != 0:
190 | self_cond = model_output if self.self_condition else None
191 |
192 | model_output_next = self.preconditioned_network_forward(images_next, sigma_next, self_cond, clamp = clamp)
193 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
194 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
195 |
196 | images = images_next
197 | x_start = model_output_next if sigma_next != 0 else model_output
198 |
199 | images = images.clamp(-1., 1.)
200 | return unnormalize_to_zero_to_one(images)
201 |
202 | @torch.no_grad()
203 | def sample_using_dpmpp(self, batch_size = 16, num_sample_steps = None):
204 | """
205 | thanks to Katherine Crowson (https://github.com/crowsonkb) for figuring it all out!
206 | https://arxiv.org/abs/2211.01095
207 | """
208 |
209 | device, num_sample_steps = self.device, default(num_sample_steps, self.num_sample_steps)
210 |
211 | sigmas = self.sample_schedule(num_sample_steps)
212 |
213 | shape = (batch_size, self.channels, self.image_size, self.image_size)
214 | images = sigmas[0] * torch.randn(shape, device = device)
215 |
216 | sigma_fn = lambda t: t.neg().exp()
217 | t_fn = lambda sigma: sigma.log().neg()
218 |
219 | old_denoised = None
220 | for i in tqdm(range(len(sigmas) - 1)):
221 | denoised = self.preconditioned_network_forward(images, sigmas[i].item())
222 | t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
223 | h = t_next - t
224 |
225 | if not exists(old_denoised) or sigmas[i + 1] == 0:
226 | denoised_d = denoised
227 | else:
228 | h_last = t - t_fn(sigmas[i - 1])
229 | r = h_last / h
230 | gamma = - 1 / (2 * r)
231 | denoised_d = (1 - gamma) * denoised + gamma * old_denoised
232 |
233 | images = (sigma_fn(t_next) / sigma_fn(t)) * images - (-h).expm1() * denoised_d
234 | old_denoised = denoised
235 |
236 | images = images.clamp(-1., 1.)
237 | return unnormalize_to_zero_to_one(images)
238 |
239 | # training
240 |
241 | def loss_weight(self, sigma):
242 | return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2
243 |
244 | def noise_distribution(self, batch_size):
245 | return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp()
246 |
247 | def forward(self, images):
248 | batch_size, c, h, w, device, image_size, channels = *images.shape, images.device, self.image_size, self.channels
249 |
250 | assert h == image_size and w == image_size, f'height and width of image must be {image_size}'
251 | assert c == channels, 'mismatch of image channels'
252 |
253 | images = normalize_to_neg_one_to_one(images)
254 |
255 | sigmas = self.noise_distribution(batch_size)
256 | padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1')
257 |
258 | noise = torch.randn_like(images)
259 |
260 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
261 |
262 | self_cond = None
263 |
264 | if self.self_condition and random() < 0.5:
265 | # from hinton's group's bit diffusion paper
266 | with torch.no_grad():
267 | self_cond = self.preconditioned_network_forward(noised_images, sigmas)
268 | self_cond.detach_()
269 |
270 | denoised = self.preconditioned_network_forward(noised_images, sigmas, self_cond)
271 |
272 | losses = F.mse_loss(denoised, images, reduction = 'none')
273 | losses = reduce(losses, 'b ... -> b', 'mean')
274 |
275 | losses = losses * self.loss_weight(sigmas)
276 |
277 | return losses.mean()
278 |
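# a hedged usage sketch, not part of the module: `learned_sinusoidal_cond = True`
# is assumed to be the `Unet` flag that satisfies the
# `random_or_learned_sinusoidal_cond` assert in __init__
#
#   from denoising_diffusion_pytorch import Unet
#
#   net = Unet(dim = 64, dim_mults = (1, 2, 4, 8), learned_sinusoidal_cond = True)
#   diffusion = ElucidatedDiffusion(net, image_size = 128, num_sample_steps = 32)
#
#   loss = diffusion(torch.rand(8, 3, 128, 128))  # images normalized to [0, 1]
#   loss.backward()
#
#   sampled = diffusion.sample(batch_size = 4)              # stochastic sampler with 2nd order correction
#   sampled = diffusion.sample_using_dpmpp(batch_size = 4)  # dpm-solver++ sampler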
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import sqrt
4 | from torch import nn, einsum
5 | import torch.nn.functional as F
6 | from torch.special import expm1
7 |
8 | from tqdm import tqdm
9 | from einops import rearrange, repeat, reduce
10 | from einops.layers.torch import Rearrange
11 |
12 | # helpers
13 |
14 | def exists(val):
15 | return val is not None
16 |
17 | def default(val, d):
18 | if exists(val):
19 | return val
20 | return d() if callable(d) else d
21 |
22 | # normalization functions
23 |
24 | def normalize_to_neg_one_to_one(img):
25 | return img * 2 - 1
26 |
27 | def unnormalize_to_zero_to_one(t):
28 | return (t + 1) * 0.5
29 |
30 | # diffusion helpers
31 |
32 | def right_pad_dims_to(x, t):
33 | padding_dims = x.ndim - t.ndim
34 | if padding_dims <= 0:
35 | return t
36 | return t.view(*t.shape, *((1,) * padding_dims))
37 |
38 | # neural net helpers
39 |
40 | class Residual(nn.Module):
41 | def __init__(self, fn):
42 | super().__init__()
43 | self.fn = fn
44 |
45 | def forward(self, x):
46 | return x + self.fn(x)
47 |
48 | class MonotonicLinear(nn.Module):
49 | def __init__(self, *args, **kwargs):
50 | super().__init__()
51 | self.net = nn.Linear(*args, **kwargs)
52 |
53 | def forward(self, x):
54 | return F.linear(x, self.net.weight.abs(), self.net.bias.abs())
55 |
56 | # continuous schedules
57 |
58 | # equations are taken from https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material
59 | # @crowsonkb Katherine's repository also helped here https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
60 |
61 | # log(snr) that approximates the original linear schedule
62 |
63 | def log(t, eps = 1e-20):
64 | return torch.log(t.clamp(min = eps))
65 |
66 | def beta_linear_log_snr(t):
67 | return -log(expm1(1e-4 + 10 * (t ** 2)))
68 |
69 | def alpha_cosine_log_snr(t, s = 0.008):
70 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5)
71 |
72 | class learned_noise_schedule(nn.Module):
73 | """ described in section H and then I.2 of the supplementary material for variational ddpm paper """
74 |
75 | def __init__(
76 | self,
77 | *,
78 | log_snr_max,
79 | log_snr_min,
80 | hidden_dim = 1024,
81 | frac_gradient = 1.
82 | ):
83 | super().__init__()
84 | self.slope = log_snr_min - log_snr_max
85 | self.intercept = log_snr_max
86 |
87 | self.net = nn.Sequential(
88 | Rearrange('... -> ... 1'),
89 | MonotonicLinear(1, 1),
90 | Residual(nn.Sequential(
91 | MonotonicLinear(1, hidden_dim),
92 | nn.Sigmoid(),
93 | MonotonicLinear(hidden_dim, 1)
94 | )),
95 | Rearrange('... 1 -> ...'),
96 | )
97 |
98 | self.frac_gradient = frac_gradient
99 |
100 | def forward(self, x):
101 | frac_gradient = self.frac_gradient
102 | device = x.device
103 |
104 | out_zero = self.net(torch.zeros_like(x))
105 | out_one = self.net(torch.ones_like(x))
106 |
107 | x = self.net(x)
108 |
109 | normed = self.slope * ((x - out_zero) / (out_one - out_zero)) + self.intercept
110 | return normed * frac_gradient + normed.detach() * (1 - frac_gradient)
111 |
112 | class ContinuousTimeGaussianDiffusion(nn.Module):
113 | def __init__(
114 | self,
115 | model,
116 | *,
117 | image_size,
118 | channels = 3,
119 | loss_type = 'l1',
120 | noise_schedule = 'linear',
121 | num_sample_steps = 500,
122 | clip_sample_denoised = True,
123 | learned_schedule_net_hidden_dim = 1024,
124 | learned_noise_schedule_frac_gradient = 1., # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
125 | p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
126 | p2_loss_weight_k = 1
127 | ):
128 | super().__init__()
129 | assert model.random_or_learned_sinusoidal_cond
130 | assert not model.self_condition, 'not supported yet'
131 |
132 | self.model = model
133 |
134 | # image dimensions
135 |
136 | self.channels = channels
137 | self.image_size = image_size
138 |
139 | # continuous noise schedule related stuff
140 |
141 | self.loss_type = loss_type
142 |
143 | if noise_schedule == 'linear':
144 | self.log_snr = beta_linear_log_snr
145 | elif noise_schedule == 'cosine':
146 | self.log_snr = alpha_cosine_log_snr
147 | elif noise_schedule == 'learned':
148 | log_snr_max, log_snr_min = [beta_linear_log_snr(torch.tensor([time])).item() for time in (0., 1.)]
149 |
150 | self.log_snr = learned_noise_schedule(
151 | log_snr_max = log_snr_max,
152 | log_snr_min = log_snr_min,
153 | hidden_dim = learned_schedule_net_hidden_dim,
154 | frac_gradient = learned_noise_schedule_frac_gradient
155 | )
156 | else:
157 | raise ValueError(f'unknown noise schedule {noise_schedule}')
158 |
159 | # sampling
160 |
161 | self.num_sample_steps = num_sample_steps
162 | self.clip_sample_denoised = clip_sample_denoised
163 |
164 | # p2 loss weight
165 | # proposed https://arxiv.org/abs/2204.00227
166 |
167 | assert p2_loss_weight_gamma <= 2, 'in paper, they noticed any gamma greater than 2 is harmful'
168 |
169 | self.p2_loss_weight_gamma = p2_loss_weight_gamma # recommended to be 0.5 or 1
170 | self.p2_loss_weight_k = p2_loss_weight_k
171 |
172 | @property
173 | def device(self):
174 | return next(self.model.parameters()).device
175 |
176 | @property
177 | def loss_fn(self):
178 | if self.loss_type == 'l1':
179 | return F.l1_loss
180 | elif self.loss_type == 'l2':
181 | return F.mse_loss
182 | else:
183 | raise ValueError(f'invalid loss type {self.loss_type}')
184 |
185 | def p_mean_variance(self, x, time, time_next):
186 | # reviewer found an error in the equation in the paper (missing sigma)
187 | # following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt
188 |
189 | log_snr = self.log_snr(time)
190 | log_snr_next = self.log_snr(time_next)
191 | c = -expm1(log_snr - log_snr_next)
192 |
193 | squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
194 | squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()
195 |
196 | alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))
197 |
198 | batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
199 | pred_noise = self.model(x, batch_log_snr)
200 |
201 | if self.clip_sample_denoised:
202 | x_start = (x - sigma * pred_noise) / alpha
203 |
204 | # in Imagen, this was changed to dynamic thresholding
205 | x_start.clamp_(-1., 1.)
206 |
207 | model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)
208 | else:
209 | model_mean = alpha_next / alpha * (x - c * sigma * pred_noise)
210 |
211 | posterior_variance = squared_sigma_next * c
212 |
213 | return model_mean, posterior_variance
214 |
215 | # sampling related functions
216 |
217 | @torch.no_grad()
218 | def p_sample(self, x, time, time_next):
219 | batch, *_, device = *x.shape, x.device
220 |
221 | model_mean, model_variance = self.p_mean_variance(x = x, time = time, time_next = time_next)
222 |
223 | if time_next == 0:
224 | return model_mean
225 |
226 | noise = torch.randn_like(x)
227 | return model_mean + sqrt(model_variance) * noise
228 |
229 | @torch.no_grad()
230 | def p_sample_loop(self, shape):
231 | batch = shape[0]
232 |
233 | img = torch.randn(shape, device = self.device)
234 | steps = torch.linspace(1., 0., self.num_sample_steps + 1, device = self.device)
235 |
236 | for i in tqdm(range(self.num_sample_steps), desc = 'sampling loop time step', total = self.num_sample_steps):
237 | times = steps[i]
238 | times_next = steps[i + 1]
239 | img = self.p_sample(img, times, times_next)
240 |
241 | img.clamp_(-1., 1.)
242 | img = unnormalize_to_zero_to_one(img)
243 | return img
244 |
245 | @torch.no_grad()
246 | def sample(self, batch_size = 16):
247 | return self.p_sample_loop((batch_size, self.channels, self.image_size, self.image_size))
248 |
249 | # training related functions - noise prediction
250 |
251 | def q_sample(self, x_start, times, noise = None):
252 | noise = default(noise, lambda: torch.randn_like(x_start))
253 |
254 | log_snr = self.log_snr(times)
255 |
256 | log_snr_padded = right_pad_dims_to(x_start, log_snr)
257 | alpha, sigma = sqrt(log_snr_padded.sigmoid()), sqrt((-log_snr_padded).sigmoid())
258 | x_noised = x_start * alpha + noise * sigma
259 |
260 | return x_noised, log_snr
261 |
262 | def random_times(self, batch_size):
263 | # times are now uniform from 0 to 1
264 | return torch.zeros((batch_size,), device = self.device).float().uniform_(0, 1)
265 |
266 | def p_losses(self, x_start, times, noise = None):
267 | noise = default(noise, lambda: torch.randn_like(x_start))
268 |
269 | x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
270 | model_out = self.model(x, log_snr)
271 |
272 | losses = self.loss_fn(model_out, noise, reduction = 'none')
273 | losses = reduce(losses, 'b ... -> b', 'mean')
274 |
275 | if self.p2_loss_weight_gamma >= 0:
276 | # following eq 8. in https://arxiv.org/abs/2204.00227
277 | loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -self.p2_loss_weight_gamma
278 | losses = losses * loss_weight
279 |
280 | return losses.mean()
281 |
282 | def forward(self, img, *args, **kwargs):
283 | b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
284 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
285 |
286 | times = self.random_times(b)
287 | img = normalize_to_neg_one_to_one(img)
288 | return self.p_losses(img, times, *args, **kwargs)
289 |
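# a hedged usage sketch, not part of the module: same unet assumption as the other
# continuous time variants (`learned_sinusoidal_cond = True` to satisfy the
# `random_or_learned_sinusoidal_cond` assert)
#
#   from denoising_diffusion_pytorch import Unet, ContinuousTimeGaussianDiffusion
#
#   model = Unet(dim = 64, dim_mults = (1, 2, 4, 8), learned_sinusoidal_cond = True)
#   diffusion = ContinuousTimeGaussianDiffusion(model, image_size = 128, noise_schedule = 'linear')  # or 'cosine' / 'learned'
#
#   loss = diffusion(torch.rand(8, 3, 128, 128))
#   loss.backward()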
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py:
--------------------------------------------------------------------------------
1 | import math
2 | from random import random
3 | from functools import partial
4 | from collections import namedtuple
5 |
6 | import torch
7 | from torch import nn, einsum
8 | import torch.nn.functional as F
9 |
10 | from einops import rearrange, reduce
11 | from einops.layers.torch import Rearrange
12 |
13 | from tqdm.auto import tqdm
14 |
15 | # constants
16 |
17 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
18 |
19 | # helper functions
20 |
21 | def exists(x):
22 | return x is not None
23 |
24 | def default(val, d):
25 | if exists(val):
26 | return val
27 | return d() if callable(d) else d
28 |
29 | def identity(t, *args, **kwargs):
30 | return t
31 |
32 | def cycle(dl):
33 | while True:
34 | for data in dl:
35 | yield data
36 |
37 | def has_int_squareroot(num):
38 | return (math.sqrt(num) ** 2) == num
39 |
40 | def num_to_groups(num, divisor):
41 | groups = num // divisor
42 | remainder = num % divisor
43 | arr = [divisor] * groups
44 | if remainder > 0:
45 | arr.append(remainder)
46 | return arr
47 |
48 | def convert_image_to_fn(img_type, image):
49 | if image.mode != img_type:
50 | return image.convert(img_type)
51 | return image
52 |
53 | # normalization functions
54 |
55 | def normalize_to_neg_one_to_one(img):
56 | return img * 2 - 1
57 |
58 | def unnormalize_to_zero_to_one(t):
59 | return (t + 1) * 0.5
60 |
61 | # small helper modules
62 |
63 | class Residual(nn.Module):
64 | def __init__(self, fn):
65 | super().__init__()
66 | self.fn = fn
67 |
68 | def forward(self, x, *args, **kwargs):
69 | return self.fn(x, *args, **kwargs) + x
70 |
71 | def Upsample(dim, dim_out = None):
72 | return nn.Sequential(
73 | nn.Upsample(scale_factor = 2, mode = 'nearest'),
74 | nn.Conv1d(dim, default(dim_out, dim), 3, padding = 1)
75 | )
76 |
77 | def Downsample(dim, dim_out = None):
78 | return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
79 |
80 | class WeightStandardizedConv1d(nn.Conv1d):
81 | """
82 | https://arxiv.org/abs/1903.10520
83 | weight standardization purportedly works synergistically with group normalization
84 | """
85 | def forward(self, x):
86 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
87 |
88 | weight = self.weight
89 | mean = reduce(weight, 'o ... -> o 1 1', 'mean')
90 | var = reduce(weight, 'o ... -> o 1 1', partial(torch.var, unbiased = False))
91 | normalized_weight = (weight - mean) * (var + eps).rsqrt()
92 |
93 | return F.conv1d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
94 |
95 | class LayerNorm(nn.Module):
96 | def __init__(self, dim):
97 | super().__init__()
98 | self.g = nn.Parameter(torch.ones(1, dim, 1))
99 |
100 | def forward(self, x):
101 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
102 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
103 | mean = torch.mean(x, dim = 1, keepdim = True)
104 | return (x - mean) * (var + eps).rsqrt() * self.g
105 |
106 | class PreNorm(nn.Module):
107 | def __init__(self, dim, fn):
108 | super().__init__()
109 | self.fn = fn
110 | self.norm = LayerNorm(dim)
111 |
112 | def forward(self, x):
113 | x = self.norm(x)
114 | return self.fn(x)
115 |
116 | # sinusoidal positional embeds
117 |
118 | class SinusoidalPosEmb(nn.Module):
119 | def __init__(self, dim):
120 | super().__init__()
121 | self.dim = dim
122 |
123 | def forward(self, x):
124 | device = x.device
125 | half_dim = self.dim // 2
126 | emb = math.log(10000) / (half_dim - 1)
127 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
128 | emb = x[:, None] * emb[None, :]
129 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
130 | return emb
131 |
132 | class RandomOrLearnedSinusoidalPosEmb(nn.Module):
133 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
134 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
135 |
136 | def __init__(self, dim, is_random = False):
137 | super().__init__()
138 | assert (dim % 2) == 0
139 | half_dim = dim // 2
140 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
141 |
142 | def forward(self, x):
143 | x = rearrange(x, 'b -> b 1')
144 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
145 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
146 | fouriered = torch.cat((x, fouriered), dim = -1)
147 | return fouriered
148 |
149 | # building block modules
150 |
151 | class Block(nn.Module):
152 | def __init__(self, dim, dim_out, groups = 8):
153 | super().__init__()
154 |         self.proj = WeightStandardizedConv1d(dim, dim_out, 3, padding = 1)
155 | self.norm = nn.GroupNorm(groups, dim_out)
156 | self.act = nn.SiLU()
157 |
158 | def forward(self, x, scale_shift = None):
159 | x = self.proj(x)
160 | x = self.norm(x)
161 |
162 | if exists(scale_shift):
163 | scale, shift = scale_shift
164 | x = x * (scale + 1) + shift
165 |
166 | x = self.act(x)
167 | return x
168 |
169 | class ResnetBlock(nn.Module):
170 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
171 | super().__init__()
172 | self.mlp = nn.Sequential(
173 | nn.SiLU(),
174 | nn.Linear(time_emb_dim, dim_out * 2)
175 | ) if exists(time_emb_dim) else None
176 |
177 | self.block1 = Block(dim, dim_out, groups = groups)
178 | self.block2 = Block(dim_out, dim_out, groups = groups)
179 | self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
180 |
181 | def forward(self, x, time_emb = None):
182 |
183 | scale_shift = None
184 | if exists(self.mlp) and exists(time_emb):
185 | time_emb = self.mlp(time_emb)
186 | time_emb = rearrange(time_emb, 'b c -> b c 1')
187 | scale_shift = time_emb.chunk(2, dim = 1)
188 |
189 | h = self.block1(x, scale_shift = scale_shift)
190 |
191 | h = self.block2(h)
192 |
193 | return h + self.res_conv(x)
194 |
195 | class LinearAttention(nn.Module):
196 | def __init__(self, dim, heads = 4, dim_head = 32):
197 | super().__init__()
198 | self.scale = dim_head ** -0.5
199 | self.heads = heads
200 | hidden_dim = dim_head * heads
201 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False)
202 |
203 | self.to_out = nn.Sequential(
204 | nn.Conv1d(hidden_dim, dim, 1),
205 | LayerNorm(dim)
206 | )
207 |
208 | def forward(self, x):
209 | b, c, n = x.shape
210 | qkv = self.to_qkv(x).chunk(3, dim = 1)
211 | q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv)
212 |
213 | q = q.softmax(dim = -2)
214 | k = k.softmax(dim = -1)
215 |
216 | q = q * self.scale
217 |
218 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
219 |
220 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
221 | out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads)
222 | return self.to_out(out)
223 |
224 | class Attention(nn.Module):
225 | def __init__(self, dim, heads = 4, dim_head = 32):
226 | super().__init__()
227 | self.scale = dim_head ** -0.5
228 | self.heads = heads
229 | hidden_dim = dim_head * heads
230 |
231 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False)
232 | self.to_out = nn.Conv1d(hidden_dim, dim, 1)
233 |
234 | def forward(self, x):
235 | b, c, n = x.shape
236 | qkv = self.to_qkv(x).chunk(3, dim = 1)
237 | q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv)
238 |
239 | q = q * self.scale
240 |
241 | sim = einsum('b h d i, b h d j -> b h i j', q, k)
242 | attn = sim.softmax(dim = -1)
243 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
244 |
245 | out = rearrange(out, 'b h n d -> b (h d) n')
246 | return self.to_out(out)
247 |
248 | # model
249 |
250 | class Unet1D(nn.Module):
251 | def __init__(
252 | self,
253 | dim,
254 | init_dim = None,
255 | out_dim = None,
256 | dim_mults=(1, 2, 4, 8),
257 | channels = 3,
258 | self_condition = False,
259 | resnet_block_groups = 8,
260 | learned_variance = False,
261 | learned_sinusoidal_cond = False,
262 | random_fourier_features = False,
263 | learned_sinusoidal_dim = 16
264 | ):
265 | super().__init__()
266 |
267 | # determine dimensions
268 |
269 | self.channels = channels
270 | self.self_condition = self_condition
271 | input_channels = channels * (2 if self_condition else 1)
272 |
273 | init_dim = default(init_dim, dim)
274 | self.init_conv = nn.Conv1d(input_channels, init_dim, 7, padding = 3)
275 |
276 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
277 | in_out = list(zip(dims[:-1], dims[1:]))
278 |
279 | block_klass = partial(ResnetBlock, groups = resnet_block_groups)
280 |
281 | # time embeddings
282 |
283 | time_dim = dim * 4
284 |
285 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
286 |
287 | if self.random_or_learned_sinusoidal_cond:
288 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
289 | fourier_dim = learned_sinusoidal_dim + 1
290 | else:
291 | sinu_pos_emb = SinusoidalPosEmb(dim)
292 | fourier_dim = dim
293 |
294 | self.time_mlp = nn.Sequential(
295 | sinu_pos_emb,
296 | nn.Linear(fourier_dim, time_dim),
297 | nn.GELU(),
298 | nn.Linear(time_dim, time_dim)
299 | )
300 |
301 | # layers
302 |
303 | self.downs = nn.ModuleList([])
304 | self.ups = nn.ModuleList([])
305 | num_resolutions = len(in_out)
306 |
307 | for ind, (dim_in, dim_out) in enumerate(in_out):
308 | is_last = ind >= (num_resolutions - 1)
309 |
310 | self.downs.append(nn.ModuleList([
311 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
312 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
313 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
314 | Downsample(dim_in, dim_out) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
315 | ]))
316 |
317 | mid_dim = dims[-1]
318 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
319 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
320 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
321 |
322 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
323 | is_last = ind == (len(in_out) - 1)
324 |
325 | self.ups.append(nn.ModuleList([
326 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
327 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
328 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
329 | Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1)
330 | ]))
331 |
332 | default_out_dim = channels * (1 if not learned_variance else 2)
333 | self.out_dim = default(out_dim, default_out_dim)
334 |
335 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
336 | self.final_conv = nn.Conv1d(dim, self.out_dim, 1)
337 |
338 | def forward(self, x, time, x_self_cond = None):
339 | if self.self_condition:
340 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
341 | x = torch.cat((x_self_cond, x), dim = 1)
342 |
343 | x = self.init_conv(x)
344 | r = x.clone()
345 |
346 | t = self.time_mlp(time)
347 |
348 | h = []
349 |
350 | for block1, block2, attn, downsample in self.downs:
351 | x = block1(x, t)
352 | h.append(x)
353 |
354 | x = block2(x, t)
355 | x = attn(x)
356 | h.append(x)
357 |
358 | x = downsample(x)
359 |
360 | x = self.mid_block1(x, t)
361 | x = self.mid_attn(x)
362 | x = self.mid_block2(x, t)
363 |
364 | for block1, block2, attn, upsample in self.ups:
365 | x = torch.cat((x, h.pop()), dim = 1)
366 | x = block1(x, t)
367 |
368 | x = torch.cat((x, h.pop()), dim = 1)
369 | x = block2(x, t)
370 | x = attn(x)
371 |
372 | x = upsample(x)
373 |
374 | x = torch.cat((x, r), dim = 1)
375 |
376 | x = self.final_res_block(x, t)
377 | return self.final_conv(x)
378 |
379 | # gaussian diffusion trainer class
380 |
381 | def extract(a, t, x_shape):
382 | b, *_ = t.shape
383 | out = a.gather(-1, t)
384 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
385 |
386 | def linear_beta_schedule(timesteps):
387 | scale = 1000 / timesteps
388 | beta_start = scale * 0.0001
389 | beta_end = scale * 0.02
390 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
391 |
392 | def cosine_beta_schedule(timesteps, s = 0.008):
393 | """
394 | cosine schedule
395 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
396 | """
397 | steps = timesteps + 1
398 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
399 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
400 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
401 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
402 | return torch.clip(betas, 0, 0.999)
403 |
404 | class GaussianDiffusion1D(nn.Module):
405 | def __init__(
406 | self,
407 | model,
408 | *,
409 | seq_length,
410 | timesteps = 1000,
411 | sampling_timesteps = None,
412 | loss_type = 'l1',
413 | objective = 'pred_noise',
414 | beta_schedule = 'cosine',
415 | p2_loss_weight_gamma = 0.,
416 | p2_loss_weight_k = 1,
417 | ddim_sampling_eta = 0.,
418 | auto_normalize = True
419 | ):
420 | super().__init__()
421 | self.model = model
422 | self.channels = self.model.channels
423 | self.self_condition = self.model.self_condition
424 |
425 | self.seq_length = seq_length
426 |
427 | self.objective = objective
428 |
429 |         assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be one of pred_noise (predict noise), pred_x0 (predict image start), or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
430 |
431 | if beta_schedule == 'linear':
432 | betas = linear_beta_schedule(timesteps)
433 | elif beta_schedule == 'cosine':
434 | betas = cosine_beta_schedule(timesteps)
435 | else:
436 | raise ValueError(f'unknown beta schedule {beta_schedule}')
437 |
438 | alphas = 1. - betas
439 | alphas_cumprod = torch.cumprod(alphas, dim=0)
440 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
441 |
442 | timesteps, = betas.shape
443 | self.num_timesteps = int(timesteps)
444 | self.loss_type = loss_type
445 |
446 | # sampling related parameters
447 |
448 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
449 |
450 | assert self.sampling_timesteps <= timesteps
451 | self.is_ddim_sampling = self.sampling_timesteps < timesteps
452 | self.ddim_sampling_eta = ddim_sampling_eta
453 |
454 | # helper function to register buffer from float64 to float32
455 |
456 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
457 |
458 | register_buffer('betas', betas)
459 | register_buffer('alphas_cumprod', alphas_cumprod)
460 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
461 |
462 | # calculations for diffusion q(x_t | x_{t-1}) and others
463 |
464 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
465 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
466 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
467 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
468 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
469 |
470 | # calculations for posterior q(x_{t-1} | x_t, x_0)
471 |
472 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
473 |
474 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
475 |
476 | register_buffer('posterior_variance', posterior_variance)
477 |
478 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
479 |
480 |         register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
481 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
482 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
483 |
484 | # calculate p2 reweighting
485 |
486 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
487 |
488 | # whether to autonormalize
489 |
490 | self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
491 | self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
492 |
493 | def predict_start_from_noise(self, x_t, t, noise):
494 | return (
495 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
496 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
497 | )
498 |
499 | def predict_noise_from_start(self, x_t, t, x0):
500 | return (
501 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
502 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
503 | )
504 |
505 | def predict_v(self, x_start, t, noise):
506 | return (
507 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
508 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
509 | )
510 |
511 | def predict_start_from_v(self, x_t, t, v):
512 | return (
513 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
514 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
515 | )
516 |
517 | def q_posterior(self, x_start, x_t, t):
518 | posterior_mean = (
519 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
520 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
521 | )
522 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
523 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
524 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
525 |
526 | def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False):
527 | model_output = self.model(x, t, x_self_cond)
528 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
529 |
530 | if self.objective == 'pred_noise':
531 | pred_noise = model_output
532 | x_start = self.predict_start_from_noise(x, t, pred_noise)
533 | x_start = maybe_clip(x_start)
534 |
535 | elif self.objective == 'pred_x0':
536 | x_start = model_output
537 | x_start = maybe_clip(x_start)
538 | pred_noise = self.predict_noise_from_start(x, t, x_start)
539 |
540 | elif self.objective == 'pred_v':
541 | v = model_output
542 | x_start = self.predict_start_from_v(x, t, v)
543 | x_start = maybe_clip(x_start)
544 | pred_noise = self.predict_noise_from_start(x, t, x_start)
545 |
546 | return ModelPrediction(pred_noise, x_start)
547 |
548 | def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
549 | preds = self.model_predictions(x, t, x_self_cond)
550 | x_start = preds.pred_x_start
551 |
552 | if clip_denoised:
553 | x_start.clamp_(-1., 1.)
554 |
555 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
556 | return model_mean, posterior_variance, posterior_log_variance, x_start
557 |
558 | @torch.no_grad()
559 | def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
560 | b, *_, device = *x.shape, x.device
561 | batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
562 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised)
563 | noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
564 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
565 | return pred_img, x_start
566 |
567 | @torch.no_grad()
568 | def p_sample_loop(self, shape):
569 | batch, device = shape[0], self.betas.device
570 |
571 | img = torch.randn(shape, device=device)
572 |
573 | x_start = None
574 |
575 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
576 | self_cond = x_start if self.self_condition else None
577 | img, x_start = self.p_sample(img, t, self_cond)
578 |
579 | img = self.unnormalize(img)
580 | return img
581 |
582 | @torch.no_grad()
583 | def ddim_sample(self, shape, clip_denoised = True):
584 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
585 |
586 | times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
587 | times = list(reversed(times.int().tolist()))
588 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
589 |
590 | img = torch.randn(shape, device = device)
591 |
592 | x_start = None
593 |
594 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
595 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
596 | self_cond = x_start if self.self_condition else None
597 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised)
598 |
599 | if time_next < 0:
600 | img = x_start
601 | continue
602 |
603 | alpha = self.alphas_cumprod[time]
604 | alpha_next = self.alphas_cumprod[time_next]
605 |
606 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
607 | c = (1 - alpha_next - sigma ** 2).sqrt()
608 |
609 | noise = torch.randn_like(img)
610 |
611 | img = x_start * alpha_next.sqrt() + \
612 | c * pred_noise + \
613 | sigma * noise
614 |
615 | img = self.unnormalize(img)
616 | return img
617 |
618 | @torch.no_grad()
619 | def sample(self, batch_size = 16):
620 | seq_length, channels = self.seq_length, self.channels
621 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
622 | return sample_fn((batch_size, channels, seq_length))
623 |
624 | @torch.no_grad()
625 | def interpolate(self, x1, x2, t = None, lam = 0.5):
626 | b, *_, device = *x1.shape, x1.device
627 | t = default(t, self.num_timesteps - 1)
628 |
629 | assert x1.shape == x2.shape
630 |
631 | t_batched = torch.full((b,), t, device = device)
632 | xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
633 |
634 | img = (1 - lam) * xt1 + lam * xt2
635 |
636 | x_start = None
637 |
638 | for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
639 | self_cond = x_start if self.self_condition else None
640 | img, x_start = self.p_sample(img, i, self_cond)
641 |
642 | return img
643 |
644 | def q_sample(self, x_start, t, noise=None):
645 | noise = default(noise, lambda: torch.randn_like(x_start))
646 |
647 | return (
648 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
649 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
650 | )
651 |
652 | @property
653 | def loss_fn(self):
654 | if self.loss_type == 'l1':
655 | return F.l1_loss
656 | elif self.loss_type == 'l2':
657 | return F.mse_loss
658 | else:
659 | raise ValueError(f'invalid loss type {self.loss_type}')
660 |
661 | def p_losses(self, x_start, t, noise = None):
662 | b, c, n = x_start.shape
663 | noise = default(noise, lambda: torch.randn_like(x_start))
664 |
665 | # noise sample
666 |
667 | x = self.q_sample(x_start = x_start, t = t, noise = noise)
668 |
669 | # if doing self-conditioning, 50% of the time, predict x_start from current set of times
670 | # and condition with unet with that
671 | # this technique will slow down training by 25%, but seems to lower FID significantly
672 |
673 | x_self_cond = None
674 | if self.self_condition and random() < 0.5:
675 | with torch.no_grad():
676 | x_self_cond = self.model_predictions(x, t).pred_x_start
677 | x_self_cond.detach_()
678 |
679 | # predict and take gradient step
680 |
681 | model_out = self.model(x, t, x_self_cond)
682 |
683 | if self.objective == 'pred_noise':
684 | target = noise
685 | elif self.objective == 'pred_x0':
686 | target = x_start
687 | elif self.objective == 'pred_v':
688 | v = self.predict_v(x_start, t, noise)
689 | target = v
690 | else:
691 | raise ValueError(f'unknown objective {self.objective}')
692 |
693 | loss = self.loss_fn(model_out, target, reduction = 'none')
694 | loss = reduce(loss, 'b ... -> b (...)', 'mean')
695 |
696 | loss = loss * extract(self.p2_loss_weight, t, loss.shape)
697 | return loss.mean()
698 |
699 | def forward(self, img, *args, **kwargs):
700 |         b, c, n, device, seq_length = *img.shape, img.device, self.seq_length
701 | assert n == seq_length, f'seq length must be {seq_length}'
702 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
703 |
704 | img = self.normalize(img)
705 | return self.p_losses(img, t, *args, **kwargs)
706 |
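707 | # example
708 |
709 | # a minimal usage sketch for the classes above; the dim, channel and sequence length
710 | # values here are illustrative assumptions, not requirements of this module
711 |
712 | if __name__ == '__main__':
713 |     model = Unet1D(
714 |         dim = 64,
715 |         dim_mults = (1, 2, 4, 8),
716 |         channels = 32
717 |     )
718 |
719 |     diffusion = GaussianDiffusion1D(
720 |         model,
721 |         seq_length = 128,
722 |         timesteps = 1000
723 |     )
724 |
725 |     training_seq = torch.rand(8, 32, 128) # sequences normalized from 0 to 1; auto_normalize rescales them to [-1, 1]
726 |
727 |     loss = diffusion(training_seq)
728 |     loss.backward()
729 |
730 |     # do above for many steps
731 |
732 |     sampled_seq = diffusion.sample(batch_size = 4) # (4, 32, 128)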
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/classifier_free_guidance.py:
--------------------------------------------------------------------------------
1 | import math
2 | import copy
3 | from pathlib import Path
4 | from random import random
5 | from functools import partial
6 | from collections import namedtuple
7 | from multiprocessing import cpu_count
8 |
9 | import torch
10 | from torch import nn, einsum
11 | import torch.nn.functional as F
12 |
13 | from einops import rearrange, reduce, repeat
14 | from einops.layers.torch import Rearrange
15 |
16 | from tqdm.auto import tqdm
17 |
18 | # constants
19 |
20 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
21 |
22 | # helper functions
23 |
24 | def exists(x):
25 | return x is not None
26 |
27 | def default(val, d):
28 | if exists(val):
29 | return val
30 | return d() if callable(d) else d
31 |
32 | def identity(t, *args, **kwargs):
33 | return t
34 |
35 | def cycle(dl):
36 | while True:
37 | for data in dl:
38 | yield data
39 |
40 | def has_int_squareroot(num):
41 | return (math.sqrt(num) ** 2) == num
42 |
43 | def num_to_groups(num, divisor):
44 | groups = num // divisor
45 | remainder = num % divisor
46 | arr = [divisor] * groups
47 | if remainder > 0:
48 | arr.append(remainder)
49 | return arr
50 |
51 | def convert_image_to_fn(img_type, image):
52 | if image.mode != img_type:
53 | return image.convert(img_type)
54 | return image
55 |
56 | # normalization functions
57 |
58 | def normalize_to_neg_one_to_one(img):
59 | return img * 2 - 1
60 |
61 | def unnormalize_to_zero_to_one(t):
62 | return (t + 1) * 0.5
63 |
64 | # classifier free guidance functions
65 |
66 | def uniform(shape, device):
67 | return torch.zeros(shape, device = device).float().uniform_(0, 1)
68 |
69 | def prob_mask_like(shape, prob, device):
70 | if prob == 1:
71 | return torch.ones(shape, device = device, dtype = torch.bool)
72 | elif prob == 0:
73 | return torch.zeros(shape, device = device, dtype = torch.bool)
74 | else:
75 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
76 |
77 | # small helper modules
78 |
79 | class Residual(nn.Module):
80 | def __init__(self, fn):
81 | super().__init__()
82 | self.fn = fn
83 |
84 | def forward(self, x, *args, **kwargs):
85 | return self.fn(x, *args, **kwargs) + x
86 |
87 | def Upsample(dim, dim_out = None):
88 | return nn.Sequential(
89 | nn.Upsample(scale_factor = 2, mode = 'nearest'),
90 | nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
91 | )
92 |
93 | def Downsample(dim, dim_out = None):
94 | return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
95 |
96 | class WeightStandardizedConv2d(nn.Conv2d):
97 | """
98 | https://arxiv.org/abs/1903.10520
99 | weight standardization purportedly works synergistically with group normalization
100 | """
101 | def forward(self, x):
102 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
103 |
104 | weight = self.weight
105 | mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
106 | var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
107 | normalized_weight = (weight - mean) * (var + eps).rsqrt()
108 |
109 | return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
110 |
111 | class LayerNorm(nn.Module):
112 | def __init__(self, dim):
113 | super().__init__()
114 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
115 |
116 | def forward(self, x):
117 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
118 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
119 | mean = torch.mean(x, dim = 1, keepdim = True)
120 | return (x - mean) * (var + eps).rsqrt() * self.g
121 |
122 | class PreNorm(nn.Module):
123 | def __init__(self, dim, fn):
124 | super().__init__()
125 | self.fn = fn
126 | self.norm = LayerNorm(dim)
127 |
128 | def forward(self, x):
129 | x = self.norm(x)
130 | return self.fn(x)
131 |
132 | # sinusoidal positional embeds
133 |
134 | class SinusoidalPosEmb(nn.Module):
135 | def __init__(self, dim):
136 | super().__init__()
137 | self.dim = dim
138 |
139 | def forward(self, x):
140 | device = x.device
141 | half_dim = self.dim // 2
142 | emb = math.log(10000) / (half_dim - 1)
143 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
144 | emb = x[:, None] * emb[None, :]
145 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
146 | return emb
147 |
148 | class RandomOrLearnedSinusoidalPosEmb(nn.Module):
149 |     """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb
150 |     https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
151 |
152 | def __init__(self, dim, is_random = False):
153 | super().__init__()
154 | assert (dim % 2) == 0
155 | half_dim = dim // 2
156 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
157 |
158 | def forward(self, x):
159 | x = rearrange(x, 'b -> b 1')
160 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
161 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
162 | fouriered = torch.cat((x, fouriered), dim = -1)
163 | return fouriered
164 |
165 | # building block modules
166 |
167 | class Block(nn.Module):
168 | def __init__(self, dim, dim_out, groups = 8):
169 | super().__init__()
170 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
171 | self.norm = nn.GroupNorm(groups, dim_out)
172 | self.act = nn.SiLU()
173 |
174 | def forward(self, x, scale_shift = None):
175 | x = self.proj(x)
176 | x = self.norm(x)
177 |
178 | if exists(scale_shift):
179 | scale, shift = scale_shift
180 | x = x * (scale + 1) + shift
181 |
182 | x = self.act(x)
183 | return x
184 |
185 | class ResnetBlock(nn.Module):
186 | def __init__(self, dim, dim_out, *, time_emb_dim = None, classes_emb_dim = None, groups = 8):
187 | super().__init__()
188 | self.mlp = nn.Sequential(
189 | nn.SiLU(),
190 |             nn.Linear(int(default(time_emb_dim, 0)) + int(default(classes_emb_dim, 0)), dim_out * 2)
191 | ) if exists(time_emb_dim) or exists(classes_emb_dim) else None
192 |
193 | self.block1 = Block(dim, dim_out, groups = groups)
194 | self.block2 = Block(dim_out, dim_out, groups = groups)
195 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
196 |
197 | def forward(self, x, time_emb = None, class_emb = None):
198 |
199 | scale_shift = None
200 | if exists(self.mlp) and (exists(time_emb) or exists(class_emb)):
201 | cond_emb = tuple(filter(exists, (time_emb, class_emb)))
202 | cond_emb = torch.cat(cond_emb, dim = -1)
203 | cond_emb = self.mlp(cond_emb)
204 | cond_emb = rearrange(cond_emb, 'b c -> b c 1 1')
205 | scale_shift = cond_emb.chunk(2, dim = 1)
206 |
207 | h = self.block1(x, scale_shift = scale_shift)
208 |
209 | h = self.block2(h)
210 |
211 | return h + self.res_conv(x)
212 |
213 | class LinearAttention(nn.Module):
214 | def __init__(self, dim, heads = 4, dim_head = 32):
215 | super().__init__()
216 | self.scale = dim_head ** -0.5
217 | self.heads = heads
218 | hidden_dim = dim_head * heads
219 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
220 |
221 | self.to_out = nn.Sequential(
222 | nn.Conv2d(hidden_dim, dim, 1),
223 | LayerNorm(dim)
224 | )
225 |
226 | def forward(self, x):
227 | b, c, h, w = x.shape
228 | qkv = self.to_qkv(x).chunk(3, dim = 1)
229 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
230 |
231 | q = q.softmax(dim = -2)
232 | k = k.softmax(dim = -1)
233 |
234 | q = q * self.scale
235 | v = v / (h * w)
236 |
237 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
238 |
239 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
240 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
241 | return self.to_out(out)
242 |
243 | class Attention(nn.Module):
244 | def __init__(self, dim, heads = 4, dim_head = 32):
245 | super().__init__()
246 | self.scale = dim_head ** -0.5
247 | self.heads = heads
248 | hidden_dim = dim_head * heads
249 |
250 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
251 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
252 |
253 | def forward(self, x):
254 | b, c, h, w = x.shape
255 | qkv = self.to_qkv(x).chunk(3, dim = 1)
256 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
257 |
258 | q = q * self.scale
259 |
260 | sim = einsum('b h d i, b h d j -> b h i j', q, k)
261 | attn = sim.softmax(dim = -1)
262 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
263 |
264 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
265 | return self.to_out(out)
266 |
267 | # model
268 |
269 | class Unet(nn.Module):
270 | def __init__(
271 | self,
272 | dim,
273 | num_classes,
274 | cond_drop_prob = 0.5,
275 | init_dim = None,
276 | out_dim = None,
277 | dim_mults=(1, 2, 4, 8),
278 | channels = 3,
279 | resnet_block_groups = 8,
280 | learned_variance = False,
281 | learned_sinusoidal_cond = False,
282 | random_fourier_features = False,
283 | learned_sinusoidal_dim = 16,
284 | ):
285 | super().__init__()
286 |
287 | # classifier free guidance stuff
288 |
289 | self.cond_drop_prob = cond_drop_prob
290 |
291 | # determine dimensions
292 |
293 | self.channels = channels
294 | input_channels = channels
295 |
296 | init_dim = default(init_dim, dim)
297 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
298 |
299 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
300 | in_out = list(zip(dims[:-1], dims[1:]))
301 |
302 | block_klass = partial(ResnetBlock, groups = resnet_block_groups)
303 |
304 | # time embeddings
305 |
306 | time_dim = dim * 4
307 |
308 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
309 |
310 | if self.random_or_learned_sinusoidal_cond:
311 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
312 | fourier_dim = learned_sinusoidal_dim + 1
313 | else:
314 | sinu_pos_emb = SinusoidalPosEmb(dim)
315 | fourier_dim = dim
316 |
317 | self.time_mlp = nn.Sequential(
318 | sinu_pos_emb,
319 | nn.Linear(fourier_dim, time_dim),
320 | nn.GELU(),
321 | nn.Linear(time_dim, time_dim)
322 | )
323 |
324 | # class embeddings
325 |
326 | self.classes_emb = nn.Embedding(num_classes, dim)
327 | self.null_classes_emb = nn.Parameter(torch.randn(dim))
328 |
329 | classes_dim = dim * 4
330 |
331 | self.classes_mlp = nn.Sequential(
332 | nn.Linear(dim, classes_dim),
333 | nn.GELU(),
334 | nn.Linear(classes_dim, classes_dim)
335 | )
336 |
337 | # layers
338 |
339 | self.downs = nn.ModuleList([])
340 | self.ups = nn.ModuleList([])
341 | num_resolutions = len(in_out)
342 |
343 | for ind, (dim_in, dim_out) in enumerate(in_out):
344 | is_last = ind >= (num_resolutions - 1)
345 |
346 | self.downs.append(nn.ModuleList([
347 | block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
348 | block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
349 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
350 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
351 | ]))
352 |
353 | mid_dim = dims[-1]
354 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
355 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
356 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
357 |
358 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
359 | is_last = ind == (len(in_out) - 1)
360 |
361 | self.ups.append(nn.ModuleList([
362 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
363 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
364 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
365 | Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
366 | ]))
367 |
368 | default_out_dim = channels * (1 if not learned_variance else 2)
369 | self.out_dim = default(out_dim, default_out_dim)
370 |
371 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
372 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
373 |
374 | def forward_with_cond_scale(
375 | self,
376 | *args,
377 | cond_scale = 1.,
378 | **kwargs
379 | ):
380 | logits = self.forward(*args, **kwargs)
381 |
382 | if cond_scale == 1:
383 | return logits
384 |
385 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
386 | return null_logits + (logits - null_logits) * cond_scale
387 |
388 | def forward(
389 | self,
390 | x,
391 | time,
392 | classes,
393 | cond_drop_prob = None
394 | ):
395 | batch, device = x.shape[0], x.device
396 |
397 | cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
398 |
399 | # derive condition, with condition dropout for classifier free guidance
400 |
401 | classes_emb = self.classes_emb(classes)
402 |
403 | if cond_drop_prob > 0:
404 | keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
405 | null_classes_emb = repeat(self.null_classes_emb, 'd -> b d', b = batch)
406 |
407 | classes_emb = torch.where(
408 | rearrange(keep_mask, 'b -> b 1'),
409 | classes_emb,
410 | null_classes_emb
411 | )
412 |
413 | c = self.classes_mlp(classes_emb)
414 |
415 | # unet
416 |
417 | x = self.init_conv(x)
418 | r = x.clone()
419 |
420 | t = self.time_mlp(time)
421 |
422 | h = []
423 |
424 | for block1, block2, attn, downsample in self.downs:
425 | x = block1(x, t, c)
426 | h.append(x)
427 |
428 | x = block2(x, t, c)
429 | x = attn(x)
430 | h.append(x)
431 |
432 | x = downsample(x)
433 |
434 | x = self.mid_block1(x, t, c)
435 | x = self.mid_attn(x)
436 | x = self.mid_block2(x, t, c)
437 |
438 | for block1, block2, attn, upsample in self.ups:
439 | x = torch.cat((x, h.pop()), dim = 1)
440 | x = block1(x, t, c)
441 |
442 | x = torch.cat((x, h.pop()), dim = 1)
443 | x = block2(x, t, c)
444 | x = attn(x)
445 |
446 | x = upsample(x)
447 |
448 | x = torch.cat((x, r), dim = 1)
449 |
450 | x = self.final_res_block(x, t, c)
451 | return self.final_conv(x)
452 |
453 | # gaussian diffusion trainer class
454 |
455 | def extract(a, t, x_shape):
456 | b, *_ = t.shape
457 | out = a.gather(-1, t)
458 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
459 |
460 | def linear_beta_schedule(timesteps):
461 | scale = 1000 / timesteps
462 | beta_start = scale * 0.0001
463 | beta_end = scale * 0.02
464 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
465 |
466 | def cosine_beta_schedule(timesteps, s = 0.008):
467 | """
468 | cosine schedule
469 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
470 | """
471 | steps = timesteps + 1
472 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
473 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
474 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
475 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
476 | return torch.clip(betas, 0, 0.999)
477 |
478 | class GaussianDiffusion(nn.Module):
479 | def __init__(
480 | self,
481 | model,
482 | *,
483 | image_size,
484 | timesteps = 1000,
485 | sampling_timesteps = None,
486 | loss_type = 'l1',
487 | objective = 'pred_noise',
488 | beta_schedule = 'cosine',
489 | p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
490 | p2_loss_weight_k = 1,
491 | ddim_sampling_eta = 1.
492 | ):
493 | super().__init__()
494 | assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
495 | assert not model.random_or_learned_sinusoidal_cond
496 |
497 | self.model = model
498 | self.channels = self.model.channels
499 |
500 | self.image_size = image_size
501 |
502 | self.objective = objective
503 |
504 |         assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be one of pred_noise (predict noise), pred_x0 (predict image start), or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
505 |
506 | if beta_schedule == 'linear':
507 | betas = linear_beta_schedule(timesteps)
508 | elif beta_schedule == 'cosine':
509 | betas = cosine_beta_schedule(timesteps)
510 | else:
511 | raise ValueError(f'unknown beta schedule {beta_schedule}')
512 |
513 | alphas = 1. - betas
514 | alphas_cumprod = torch.cumprod(alphas, dim=0)
515 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
516 |
517 | timesteps, = betas.shape
518 | self.num_timesteps = int(timesteps)
519 | self.loss_type = loss_type
520 |
521 | # sampling related parameters
522 |
523 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
524 |
525 | assert self.sampling_timesteps <= timesteps
526 | self.is_ddim_sampling = self.sampling_timesteps < timesteps
527 | self.ddim_sampling_eta = ddim_sampling_eta
528 |
529 | # helper function to register buffer from float64 to float32
530 |
531 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
532 |
533 | register_buffer('betas', betas)
534 | register_buffer('alphas_cumprod', alphas_cumprod)
535 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
536 |
537 | # calculations for diffusion q(x_t | x_{t-1}) and others
538 |
539 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
540 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
541 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
542 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
543 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
544 |
545 | # calculations for posterior q(x_{t-1} | x_t, x_0)
546 |
547 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
548 |
549 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
550 |
551 | register_buffer('posterior_variance', posterior_variance)
552 |
553 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
554 |
555 |         register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
556 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
557 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
558 |
559 | # calculate p2 reweighting
560 |
561 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
562 |
563 | def predict_start_from_noise(self, x_t, t, noise):
564 | return (
565 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
566 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
567 | )
568 |
569 | def predict_noise_from_start(self, x_t, t, x0):
570 | return (
571 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
572 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
573 | )
574 |
575 | def predict_v(self, x_start, t, noise):
576 | return (
577 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
578 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
579 | )
580 |
581 | def predict_start_from_v(self, x_t, t, v):
582 | return (
583 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
584 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
585 | )
586 |
587 | def q_posterior(self, x_start, x_t, t):
588 | posterior_mean = (
589 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
590 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
591 | )
592 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
593 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
594 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
595 |
596 | def model_predictions(self, x, t, classes, cond_scale = 3., clip_x_start = False):
597 | model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale)
598 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
599 |
600 | if self.objective == 'pred_noise':
601 | pred_noise = model_output
602 | x_start = self.predict_start_from_noise(x, t, pred_noise)
603 | x_start = maybe_clip(x_start)
604 |
605 | elif self.objective == 'pred_x0':
606 | x_start = model_output
607 | x_start = maybe_clip(x_start)
608 | pred_noise = self.predict_noise_from_start(x, t, x_start)
609 |
610 | elif self.objective == 'pred_v':
611 | v = model_output
612 | x_start = self.predict_start_from_v(x, t, v)
613 | x_start = maybe_clip(x_start)
614 | pred_noise = self.predict_noise_from_start(x, t, x_start)
615 |
616 | return ModelPrediction(pred_noise, x_start)
617 |
618 | def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
619 | preds = self.model_predictions(x, t, classes, cond_scale)
620 | x_start = preds.pred_x_start
621 |
622 | if clip_denoised:
623 | x_start.clamp_(-1., 1.)
624 |
625 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
626 | return model_mean, posterior_variance, posterior_log_variance, x_start
627 |
628 | @torch.no_grad()
629 | def p_sample(self, x, t: int, classes, cond_scale = 3., clip_denoised = True):
630 | b, *_, device = *x.shape, x.device
631 |         batched_times = torch.full((b,), t, device = device, dtype = torch.long)
632 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, classes = classes, cond_scale = cond_scale, clip_denoised = clip_denoised)
633 | noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
634 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
635 | return pred_img, x_start
636 |
637 | @torch.no_grad()
638 | def p_sample_loop(self, classes, shape, cond_scale = 3.):
639 | batch, device = shape[0], self.betas.device
640 |
641 | img = torch.randn(shape, device=device)
642 |
643 | x_start = None
644 |
645 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
646 | img, x_start = self.p_sample(img, t, classes, cond_scale)
647 |
648 | img = unnormalize_to_zero_to_one(img)
649 | return img
650 |
651 | @torch.no_grad()
652 | def ddim_sample(self, classes, shape, cond_scale = 3., clip_denoised = True):
653 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
654 |
655 | times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
656 | times = list(reversed(times.int().tolist()))
657 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
658 |
659 | img = torch.randn(shape, device = device)
660 |
661 | x_start = None
662 |
663 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
664 | time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
665 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, classes, cond_scale = cond_scale, clip_x_start = clip_denoised)
666 |
667 | if time_next < 0:
668 | img = x_start
669 | continue
670 |
671 | alpha = self.alphas_cumprod[time]
672 | alpha_next = self.alphas_cumprod[time_next]
673 |
674 | sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
675 | c = (1 - alpha_next - sigma ** 2).sqrt()
676 |
677 | noise = torch.randn_like(img)
678 |
679 | img = x_start * alpha_next.sqrt() + \
680 | c * pred_noise + \
681 | sigma * noise
682 |
683 | img = unnormalize_to_zero_to_one(img)
684 | return img
685 |
686 | @torch.no_grad()
687 | def sample(self, classes, cond_scale = 3.):
688 | batch_size, image_size, channels = classes.shape[0], self.image_size, self.channels
689 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
690 | return sample_fn(classes, (batch_size, channels, image_size, image_size), cond_scale)
691 |
692 | @torch.no_grad()
693 |     def interpolate(self, x1, x2, classes, t = None, lam = 0.5):
694 |         b, *_, device = *x1.shape, x1.device
695 |         t = default(t, self.num_timesteps - 1)
696 |
697 |         assert x1.shape == x2.shape
698 |
699 |         t_batched = torch.full((b,), t, device = device)
700 |         xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
701 |
702 |         img = (1 - lam) * xt1 + lam * xt2
703 |         for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
704 |             img, _ = self.p_sample(img, i, classes)
705 |
706 | return img
707 |
708 | def q_sample(self, x_start, t, noise=None):
709 | noise = default(noise, lambda: torch.randn_like(x_start))
710 |
711 | return (
712 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
713 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
714 | )
715 |
716 | @property
717 | def loss_fn(self):
718 | if self.loss_type == 'l1':
719 | return F.l1_loss
720 | elif self.loss_type == 'l2':
721 | return F.mse_loss
722 | else:
723 | raise ValueError(f'invalid loss type {self.loss_type}')
724 |
725 | def p_losses(self, x_start, t, *, classes, noise = None):
726 | b, c, h, w = x_start.shape
727 | noise = default(noise, lambda: torch.randn_like(x_start))
728 |
729 | # noise sample
730 |
731 | x = self.q_sample(x_start = x_start, t = t, noise = noise)
732 |
733 | # predict and take gradient step
734 |
735 | model_out = self.model(x, t, classes)
736 |
737 | if self.objective == 'pred_noise':
738 | target = noise
739 | elif self.objective == 'pred_x0':
740 | target = x_start
741 | elif self.objective == 'pred_v':
742 | v = self.predict_v(x_start, t, noise)
743 | target = v
744 | else:
745 | raise ValueError(f'unknown objective {self.objective}')
746 |
747 | loss = self.loss_fn(model_out, target, reduction = 'none')
748 | loss = reduce(loss, 'b ... -> b (...)', 'mean')
749 |
750 | loss = loss * extract(self.p2_loss_weight, t, loss.shape)
751 | return loss.mean()
752 |
753 | def forward(self, img, *args, **kwargs):
754 |         b, c, h, w, device, img_size = *img.shape, img.device, self.image_size
755 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
756 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
757 |
758 | img = normalize_to_neg_one_to_one(img)
759 | return self.p_losses(img, t, *args, **kwargs)
760 |
761 | # example
762 |
763 | if __name__ == '__main__':
764 | num_classes = 10
765 |
766 | model = Unet(
767 | dim = 64,
768 | dim_mults = (1, 2, 4, 8),
769 | num_classes = num_classes,
770 | cond_drop_prob = 0.5
771 | )
772 |
773 | diffusion = GaussianDiffusion(
774 | model,
775 | image_size = 128,
776 | timesteps = 1000
777 | ).cuda()
778 |
779 |     training_images = torch.rand(8, 3, 128, 128).cuda() # images are normalized from 0 to 1
780 | image_classes = torch.randint(0, num_classes, (8,)).cuda() # say 10 classes
781 |
782 | loss = diffusion(training_images, classes = image_classes)
783 | loss.backward()
784 |
785 | # do above for many steps
786 |
787 | sampled_images = diffusion.sample(
788 | classes = image_classes,
789 | cond_scale = 3. # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
790 | )
791 |
792 | sampled_images.shape # (8, 3, 128, 128)
793 |
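794 | # note on the guidance arithmetic above (a reading aid, not additional behavior):
795 | # forward_with_cond_scale returns null_logits + (logits - null_logits) * cond_scale,
796 | # i.e. the classifier free guidance combination e_uncond + s * (e_cond - e_uncond)
797 | # from https://arxiv.org/abs/2207.12598 - cond_scale = 1 returns the conditional
798 | # prediction unchanged, and larger scales extrapolate away from the unconditional one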
--------------------------------------------------------------------------------
/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import copy
3 | from pathlib import Path
4 | from random import random
5 | from functools import partial
6 | from collections import namedtuple
7 | from multiprocessing import cpu_count
8 |
9 | import torch
10 | from torch import nn, einsum
11 | import torch.nn.functional as F
12 | from torch.utils.data import Dataset, DataLoader
13 |
14 | from torch.optim import Adam
15 | from torchvision import transforms as T, utils
16 |
17 | from einops import rearrange, reduce
18 | from einops.layers.torch import Rearrange
19 |
20 | from PIL import Image
21 | from tqdm.auto import tqdm
22 | from ema_pytorch import EMA
23 |
24 | from accelerate import Accelerator
25 |
26 | from denoising_diffusion_pytorch.version import __version__
27 |
28 | # constants
29 |
30 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
31 |
32 | # helper functions
33 |
34 | def exists(x):
35 | return x is not None
36 |
37 | def default(val, d):
38 | if exists(val):
39 | return val
40 | return d() if callable(d) else d
41 |
42 | def identity(t, *args, **kwargs):
43 | return t
44 |
45 | def cycle(dl):
46 | while True:
47 | for data in dl:
48 | yield data
49 |
50 | def has_int_squareroot(num):
51 | return (math.sqrt(num) ** 2) == num
52 |
53 | def num_to_groups(num, divisor):
54 | groups = num // divisor
55 | remainder = num % divisor
56 | arr = [divisor] * groups
57 | if remainder > 0:
58 | arr.append(remainder)
59 | return arr
60 |
61 | def convert_image_to_fn(img_type, image):
62 | if image.mode != img_type:
63 | return image.convert(img_type)
64 | return image
65 |
66 | # normalization functions
67 |
68 | def normalize_to_neg_one_to_one(img):
69 | return img * 2 - 1
70 |
71 | def unnormalize_to_zero_to_one(t):
72 | return (t + 1) * 0.5
73 |
74 | # small helper modules
75 |
76 | class Residual(nn.Module):
77 | def __init__(self, fn):
78 | super().__init__()
79 | self.fn = fn
80 |
81 | def forward(self, x, *args, **kwargs):
82 | return self.fn(x, *args, **kwargs) + x
83 |
84 | def Upsample(dim, dim_out = None):
85 | return nn.Sequential(
86 | nn.Upsample(scale_factor = 2, mode = 'nearest'),
87 | nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
88 | )
89 |
90 | def Downsample(dim, dim_out = None):
91 | return nn.Sequential(
92 | Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
93 | nn.Conv2d(dim * 4, default(dim_out, dim), 1)
94 | )
95 |
96 | class WeightStandardizedConv2d(nn.Conv2d):
97 | """
98 | https://arxiv.org/abs/1903.10520
99 | weight standardization purportedly works synergistically with group normalization
100 | """
101 | def forward(self, x):
102 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
103 |
104 | weight = self.weight
105 | mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
106 | var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
107 | normalized_weight = (weight - mean) * (var + eps).rsqrt()
108 |
109 | return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
110 |
111 | class LayerNorm(nn.Module):
112 | def __init__(self, dim):
113 | super().__init__()
114 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
115 |
116 | def forward(self, x):
117 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
118 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
119 | mean = torch.mean(x, dim = 1, keepdim = True)
120 | return (x - mean) * (var + eps).rsqrt() * self.g
121 |
122 | class PreNorm(nn.Module):
123 | def __init__(self, dim, fn):
124 | super().__init__()
125 | self.fn = fn
126 | self.norm = LayerNorm(dim)
127 |
128 | def forward(self, x):
129 | x = self.norm(x)
130 | return self.fn(x)
131 |
132 | # sinusoidal positional embeds
133 |
134 | class SinusoidalPosEmb(nn.Module):
135 | def __init__(self, dim):
136 | super().__init__()
137 | self.dim = dim
138 |
139 | def forward(self, x):
140 | device = x.device
141 | half_dim = self.dim // 2
142 | emb = math.log(10000) / (half_dim - 1)
143 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
144 | emb = x[:, None] * emb[None, :]
145 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
146 | return emb
147 |
148 | class RandomOrLearnedSinusoidalPosEmb(nn.Module):
149 |     """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb
150 |     https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
151 |
152 | def __init__(self, dim, is_random = False):
153 | super().__init__()
154 | assert (dim % 2) == 0
155 | half_dim = dim // 2
156 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
157 |
158 | def forward(self, x):
159 | x = rearrange(x, 'b -> b 1')
160 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
161 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
162 | fouriered = torch.cat((x, fouriered), dim = -1)
163 | return fouriered
164 |
165 | # building block modules
166 |
167 | class Block(nn.Module):
168 | def __init__(self, dim, dim_out, groups = 8):
169 | super().__init__()
170 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
171 | self.norm = nn.GroupNorm(groups, dim_out)
172 | self.act = nn.SiLU()
173 |
174 | def forward(self, x, scale_shift = None):
175 | x = self.proj(x)
176 | x = self.norm(x)
177 |
178 | if exists(scale_shift):
179 | scale, shift = scale_shift
180 | x = x * (scale + 1) + shift
181 |
182 | x = self.act(x)
183 | return x
184 |
185 | class ResnetBlock(nn.Module):
186 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
187 | super().__init__()
188 | self.mlp = nn.Sequential(
189 | nn.SiLU(),
190 | nn.Linear(time_emb_dim, dim_out * 2)
191 | ) if exists(time_emb_dim) else None
192 |
193 | self.block1 = Block(dim, dim_out, groups = groups)
194 | self.block2 = Block(dim_out, dim_out, groups = groups)
195 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
196 |
197 | def forward(self, x, time_emb = None):
198 |
199 | scale_shift = None
200 | if exists(self.mlp) and exists(time_emb):
201 | time_emb = self.mlp(time_emb)
202 | time_emb = rearrange(time_emb, 'b c -> b c 1 1')
203 | scale_shift = time_emb.chunk(2, dim = 1)
204 |
205 | h = self.block1(x, scale_shift = scale_shift)
206 |
207 | h = self.block2(h)
208 |
209 | return h + self.res_conv(x)
210 |
211 | class LinearAttention(nn.Module):
212 | def __init__(self, dim, heads = 4, dim_head = 32):
213 | super().__init__()
214 | self.scale = dim_head ** -0.5
215 | self.heads = heads
216 | hidden_dim = dim_head * heads
217 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
218 |
219 | self.to_out = nn.Sequential(
220 | nn.Conv2d(hidden_dim, dim, 1),
221 | LayerNorm(dim)
222 | )
223 |
224 | def forward(self, x):
225 | b, c, h, w = x.shape
226 | qkv = self.to_qkv(x).chunk(3, dim = 1)
227 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
228 |
229 | q = q.softmax(dim = -2)
230 | k = k.softmax(dim = -1)
231 |
232 | q = q * self.scale
233 | v = v / (h * w)
234 |
235 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
236 |
237 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
238 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
239 | return self.to_out(out)
240 |
241 | class Attention(nn.Module):
242 | def __init__(self, dim, heads = 4, dim_head = 32):
243 | super().__init__()
244 | self.scale = dim_head ** -0.5
245 | self.heads = heads
246 | hidden_dim = dim_head * heads
247 |
248 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
249 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
250 |
251 | def forward(self, x):
252 | b, c, h, w = x.shape
253 | qkv = self.to_qkv(x).chunk(3, dim = 1)
254 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
255 |
256 | q = q * self.scale
257 |
258 | sim = einsum('b h d i, b h d j -> b h i j', q, k)
259 | attn = sim.softmax(dim = -1)
260 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
261 |
262 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
263 | return self.to_out(out)
264 |
265 | # model
266 |
267 | class Unet(nn.Module):
268 | def __init__(
269 | self,
270 | dim,
271 | init_dim = None,
272 | out_dim = None,
273 | dim_mults=(1, 2, 4, 8),
274 | channels = 3,
275 | self_condition = False,
276 | resnet_block_groups = 8,
277 | learned_variance = False,
278 | learned_sinusoidal_cond = False,
279 | random_fourier_features = False,
280 | learned_sinusoidal_dim = 16
281 | ):
282 | super().__init__()
283 |
284 | # determine dimensions
285 |
286 | self.channels = channels
287 | self.self_condition = self_condition
288 | input_channels = channels * (2 if self_condition else 1)
289 |
290 | init_dim = default(init_dim, dim)
291 | self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
292 |
293 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
294 | in_out = list(zip(dims[:-1], dims[1:]))
295 |
296 | block_klass = partial(ResnetBlock, groups = resnet_block_groups)
297 |
298 | # time embeddings
299 |
300 | time_dim = dim * 4
301 |
302 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
303 |
304 | if self.random_or_learned_sinusoidal_cond:
305 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
306 | fourier_dim = learned_sinusoidal_dim + 1
307 | else:
308 | sinu_pos_emb = SinusoidalPosEmb(dim)
309 | fourier_dim = dim
310 |
311 | self.time_mlp = nn.Sequential(
312 | sinu_pos_emb,
313 | nn.Linear(fourier_dim, time_dim),
314 | nn.GELU(),
315 | nn.Linear(time_dim, time_dim)
316 | )
317 |
318 | # layers
319 |
320 | self.downs = nn.ModuleList([])
321 | self.ups = nn.ModuleList([])
322 | num_resolutions = len(in_out)
323 |
324 | for ind, (dim_in, dim_out) in enumerate(in_out):
325 | is_last = ind >= (num_resolutions - 1)
326 |
327 | self.downs.append(nn.ModuleList([
328 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
329 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
330 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
331 | Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
332 | ]))
333 |
334 | mid_dim = dims[-1]
335 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
336 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
337 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
338 |
339 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
340 | is_last = ind == (len(in_out) - 1)
341 |
342 | self.ups.append(nn.ModuleList([
343 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
344 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
345 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
346 | Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
347 | ]))
348 |
349 | default_out_dim = channels * (1 if not learned_variance else 2)
350 | self.out_dim = default(out_dim, default_out_dim)
351 |
352 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
353 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
354 |
355 | def forward(self, x, time, x_self_cond = None):
356 | if self.self_condition:
357 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
358 | x = torch.cat((x_self_cond, x), dim = 1)
359 |
360 | x = self.init_conv(x)
361 | r = x.clone()
362 |
363 | t = self.time_mlp(time)
364 |
365 | h = []
366 |
367 | for block1, block2, attn, downsample in self.downs:
368 | x = block1(x, t)
369 | h.append(x)
370 |
371 | x = block2(x, t)
372 | x = attn(x)
373 | h.append(x)
374 |
375 | x = downsample(x)
376 |
377 | x = self.mid_block1(x, t)
378 | x = self.mid_attn(x)
379 | x = self.mid_block2(x, t)
380 |
381 | for block1, block2, attn, upsample in self.ups:
382 | x = torch.cat((x, h.pop()), dim = 1)
383 | x = block1(x, t)
384 |
385 | x = torch.cat((x, h.pop()), dim = 1)
386 | x = block2(x, t)
387 | x = attn(x)
388 |
389 | x = upsample(x)
390 |
391 | x = torch.cat((x, r), dim = 1)
392 |
393 | x = self.final_res_block(x, t)
394 | return self.final_conv(x)
395 |
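# forward-pass sketch (illustrative): height and width must be divisible by
# 2 ** (len(dim_mults) - 1) so the skip connections line up after down/upsampling
#
#     model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))
#     img = torch.randn(1, 3, 128, 128)
#     time = torch.randint(0, 1000, (1,))
#     pred = model(img, time)             # (1, 3, 128, 128), read as noise, x0 or v by GaussianDiffusion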
396 | # gaussian diffusion trainer class
397 |
398 | def extract(a, t, x_shape):
399 | b, *_ = t.shape
400 | out = a.gather(-1, t)
401 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
402 |
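# extract gathers one coefficient per batch element at its timestep and reshapes for
# broadcasting over the remaining dims, e.g. (illustrative):
#
#     a = torch.linspace(0, 1, 1000)              # per-timestep coefficients
#     t = torch.tensor([10, 500])
#     extract(a, t, (2, 3, 64, 64)).shape         # (2, 1, 1, 1)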
403 | def linear_beta_schedule(timesteps):
404 | """
405 |     linear schedule, proposed in the original ddpm paper https://arxiv.org/abs/2006.11239
406 | """
407 | scale = 1000 / timesteps
408 | beta_start = scale * 0.0001
409 | beta_end = scale * 0.02
410 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
411 |
412 | def cosine_beta_schedule(timesteps, s = 0.008):
413 | """
414 | cosine schedule
415 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
416 | """
417 | steps = timesteps + 1
418 | t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
419 | alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
420 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
421 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
422 | return torch.clip(betas, 0, 0.999)
423 |
424 | def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
425 | """
426 | sigmoid schedule
427 | proposed in https://arxiv.org/abs/2212.11972 - Figure 8
428 |     works better for images larger than 64x64 when used during training
429 | """
430 | steps = timesteps + 1
431 | t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
432 | v_start = torch.tensor(start / tau).sigmoid()
433 | v_end = torch.tensor(end / tau).sigmoid()
434 | alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
435 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
436 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
437 | return torch.clip(betas, 0, 0.999)
438 |
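# comparison sketch (illustrative): each schedule returns a (timesteps,) float64
# tensor of betas, and the implied alphas_cumprod decays monotonically from ~1 to ~0
#
#     betas = sigmoid_beta_schedule(1000)         # the default used by GaussianDiffusion below
#     alphas_cumprod = torch.cumprod(1. - betas, dim = 0)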
439 | class GaussianDiffusion(nn.Module):
440 | def __init__(
441 | self,
442 | model,
443 | *,
444 | image_size,
445 | timesteps = 1000,
446 | sampling_timesteps = None,
447 | loss_type = 'l1',
448 | objective = 'pred_noise',
449 | beta_schedule = 'sigmoid',
450 | schedule_fn_kwargs = dict(),
451 |         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to a weight of 1 across time; 1. is recommended
452 | p2_loss_weight_k = 1,
453 | ddim_sampling_eta = 0.,
454 | auto_normalize = True
455 | ):
456 | super().__init__()
457 | assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
458 | assert not model.random_or_learned_sinusoidal_cond
459 |
460 | self.model = model
461 | self.channels = self.model.channels
462 | self.self_condition = self.model.self_condition
463 |
464 | self.image_size = image_size
465 |
466 | self.objective = objective
467 |
468 | assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
469 |
470 | if beta_schedule == 'linear':
471 | beta_schedule_fn = linear_beta_schedule
472 | elif beta_schedule == 'cosine':
473 | beta_schedule_fn = cosine_beta_schedule
474 | elif beta_schedule == 'sigmoid':
475 | beta_schedule_fn = sigmoid_beta_schedule
476 | else:
477 | raise ValueError(f'unknown beta schedule {beta_schedule}')
478 |
479 | betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
480 |
481 | alphas = 1. - betas
482 | alphas_cumprod = torch.cumprod(alphas, dim=0)
483 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
484 |
485 | timesteps, = betas.shape
486 | self.num_timesteps = int(timesteps)
487 | self.loss_type = loss_type
488 |
489 | # sampling related parameters
490 |
491 | self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
492 |
493 | assert self.sampling_timesteps <= timesteps
494 | self.is_ddim_sampling = self.sampling_timesteps < timesteps
495 | self.ddim_sampling_eta = ddim_sampling_eta
496 |
497 | # helper function to register buffer from float64 to float32
498 |
499 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
500 |
501 | register_buffer('betas', betas)
502 | register_buffer('alphas_cumprod', alphas_cumprod)
503 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
504 |
505 | # calculations for diffusion q(x_t | x_{t-1}) and others
506 |
507 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
508 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
509 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
510 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
511 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
512 |
513 | # calculations for posterior q(x_{t-1} | x_t, x_0)
514 |
515 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
516 |
517 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
518 |
519 | register_buffer('posterior_variance', posterior_variance)
520 |
521 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
522 |
523 |         register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
524 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
525 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
526 |
527 | # calculate p2 reweighting
528 |
529 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
530 |
531 |         # auto-normalization of data [0, 1] -> [-1, 1] - can be turned off by setting auto_normalize = False
532 |
533 | self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
534 | self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
535 |
536 | def predict_start_from_noise(self, x_t, t, noise):
537 | return (
538 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
539 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
540 | )
541 |
542 | def predict_noise_from_start(self, x_t, t, x0):
543 | return (
544 | (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
545 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
546 | )
547 |
548 | def predict_v(self, x_start, t, noise):
549 | return (
550 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
551 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
552 | )
553 |
554 | def predict_start_from_v(self, x_t, t, v):
555 | return (
556 | extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
557 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
558 | )
559 |
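    # the helpers above are algebraic rearrangements of x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * noise
    # and v = sqrt(a_t) * noise - sqrt(1 - a_t) * x0, with a_t = alphas_cumprod at t; as a
    # sanity-check sketch (illustrative; diffusion, x_start, noise and t assumed constructed):
    #
    #     x_t = diffusion.q_sample(x_start, t, noise)
    #     v = diffusion.predict_v(x_start, t, noise)
    #     assert torch.allclose(diffusion.predict_start_from_v(x_t, t, v), x_start, atol = 1e-4)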
560 | def q_posterior(self, x_start, x_t, t):
561 | posterior_mean = (
562 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
563 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
564 | )
565 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
566 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
567 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
568 |
569 | def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False):
570 | model_output = self.model(x, t, x_self_cond)
571 | maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
572 |
573 | if self.objective == 'pred_noise':
574 | pred_noise = model_output
575 | x_start = self.predict_start_from_noise(x, t, pred_noise)
576 | x_start = maybe_clip(x_start)
577 |
578 | elif self.objective == 'pred_x0':
579 | x_start = model_output
580 | x_start = maybe_clip(x_start)
581 | pred_noise = self.predict_noise_from_start(x, t, x_start)
582 |
583 | elif self.objective == 'pred_v':
584 | v = model_output
585 | x_start = self.predict_start_from_v(x, t, v)
586 | x_start = maybe_clip(x_start)
587 | pred_noise = self.predict_noise_from_start(x, t, x_start)
588 |
589 | return ModelPrediction(pred_noise, x_start)
590 |
591 | def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
592 | preds = self.model_predictions(x, t, x_self_cond)
593 | x_start = preds.pred_x_start
594 |
595 | if clip_denoised:
596 | x_start.clamp_(-1., 1.)
597 |
598 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
599 | return model_mean, posterior_variance, posterior_log_variance, x_start
600 |
601 | @torch.no_grad()
602 | def p_sample(self, x, t: int, x_self_cond = None):
603 | b, *_, device = *x.shape, x.device
604 | batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
605 | model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
606 | noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
607 | pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
608 | return pred_img, x_start
609 |
610 | @torch.no_grad()
611 | def p_sample_loop(self, shape, return_all_timesteps = False):
612 | batch, device = shape[0], self.betas.device
613 |
614 | img = torch.randn(shape, device = device)
615 | imgs = [img]
616 |
617 | x_start = None
618 |
619 | for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
620 | self_cond = x_start if self.self_condition else None
621 | img, x_start = self.p_sample(img, t, self_cond)
622 | imgs.append(img)
623 |
624 | ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
625 |
626 | ret = self.unnormalize(ret)
627 | return ret
628 |
629 | @torch.no_grad()
630 | def ddim_sample(self, shape, return_all_timesteps = False):
631 | batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
632 |
633 | times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
634 | times = list(reversed(times.int().tolist()))
635 | time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
636 |
637 | img = torch.randn(shape, device = device)
638 | imgs = [img]
639 |
640 | x_start = None
641 |
642 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
643 | time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
644 | self_cond = x_start if self.self_condition else None
645 | pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True)
646 |
647 |             if time_next < 0:
648 |                 img = x_start
649 |                 imgs.append(img)
650 |                 continue
651 | 
652 |             alpha = self.alphas_cumprod[time]
653 |             alpha_next = self.alphas_cumprod[time_next]
654 | 
655 |             sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
656 |             c = (1 - alpha_next - sigma ** 2).sqrt()
657 | 
658 |             noise = torch.randn_like(img)
659 | 
660 |             img = x_start * alpha_next.sqrt() + \
661 |                   c * pred_noise + \
662 |                   sigma * noise
663 |             imgs.append(img)
664 | 
665 | ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
666 |
667 | ret = self.unnormalize(ret)
668 | return ret
669 |
670 | @torch.no_grad()
671 | def sample(self, batch_size = 16, return_all_timesteps = False):
672 | image_size, channels = self.image_size, self.channels
673 | sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
674 | return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
675 |
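    # sampling sketch (illustrative, model as above): constructing with
    # sampling_timesteps < timesteps sets is_ddim_sampling, so sample() uses the faster ddim_sample
    #
    #     diffusion = GaussianDiffusion(model, image_size = 128, timesteps = 1000, sampling_timesteps = 250)
    #     images = diffusion.sample(batch_size = 16)  # (16, 3, 128, 128), values in [0, 1]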
676 | @torch.no_grad()
677 | def interpolate(self, x1, x2, t = None, lam = 0.5):
678 | b, *_, device = *x1.shape, x1.device
679 | t = default(t, self.num_timesteps - 1)
680 |
681 | assert x1.shape == x2.shape
682 |
683 | t_batched = torch.full((b,), t, device = device)
684 | xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
685 |
686 | img = (1 - lam) * xt1 + lam * xt2
687 |
688 | x_start = None
689 |
690 | for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
691 | self_cond = x_start if self.self_condition else None
692 | img, x_start = self.p_sample(img, i, self_cond)
693 |
694 | return img
695 |
696 | def q_sample(self, x_start, t, noise=None):
697 | noise = default(noise, lambda: torch.randn_like(x_start))
698 |
699 | return (
700 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
701 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
702 | )
703 |
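    # q_sample is the closed form of the forward process, q(x_t | x_0) =
    # N(sqrt(alphas_cumprod_t) * x_0, (1 - alphas_cumprod_t) * I), so a noisy training
    # pair at any t is produced in one call rather than by t sequential noisings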
704 | @property
705 | def loss_fn(self):
706 | if self.loss_type == 'l1':
707 | return F.l1_loss
708 | elif self.loss_type == 'l2':
709 | return F.mse_loss
710 | else:
711 | raise ValueError(f'invalid loss type {self.loss_type}')
712 |
713 | def p_losses(self, x_start, t, noise = None):
714 | b, c, h, w = x_start.shape
715 | noise = default(noise, lambda: torch.randn_like(x_start))
716 |
717 | # noise sample
718 |
719 | x = self.q_sample(x_start = x_start, t = t, noise = noise)
720 |
721 | # if doing self-conditioning, 50% of the time, predict x_start from current set of times
722 |         # and condition the unet with that
723 | # this technique will slow down training by 25%, but seems to lower FID significantly
724 |
725 | x_self_cond = None
726 | if self.self_condition and random() < 0.5:
727 | with torch.no_grad():
728 | x_self_cond = self.model_predictions(x, t).pred_x_start
729 | x_self_cond.detach_()
730 |
731 | # predict and take gradient step
732 |
733 | model_out = self.model(x, t, x_self_cond)
734 |
735 | if self.objective == 'pred_noise':
736 | target = noise
737 | elif self.objective == 'pred_x0':
738 | target = x_start
739 | elif self.objective == 'pred_v':
740 | v = self.predict_v(x_start, t, noise)
741 | target = v
742 | else:
743 | raise ValueError(f'unknown objective {self.objective}')
744 |
745 | loss = self.loss_fn(model_out, target, reduction = 'none')
746 | loss = reduce(loss, 'b ... -> b (...)', 'mean')
747 |
748 | loss = loss * extract(self.p2_loss_weight, t, loss.shape)
749 | return loss.mean()
750 |
751 | def forward(self, img, *args, **kwargs):
752 |         b, c, h, w, device, img_size = *img.shape, img.device, self.image_size
753 | assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
754 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
755 |
756 | img = self.normalize(img)
757 | return self.p_losses(img, t, *args, **kwargs)
758 |
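# end-to-end training sketch (illustrative), mirroring the README usage:
#
#     model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))
#     diffusion = GaussianDiffusion(model, image_size = 128, timesteps = 1000)
#     training_images = torch.rand(8, 3, 128, 128)    # [0, 1]; auto_normalize rescales to [-1, 1]
#     loss = diffusion(training_images)
#     loss.backward()
#     sampled_images = diffusion.sample(batch_size = 4)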
759 | # dataset classes
760 |
761 | class Dataset(Dataset):
762 | def __init__(
763 | self,
764 | folder,
765 | image_size,
766 | exts = ['jpg', 'jpeg', 'png', 'tiff'],
767 | augment_horizontal_flip = False,
768 | convert_image_to = None
769 | ):
770 | super().__init__()
771 | self.folder = folder
772 | self.image_size = image_size
773 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
774 |
775 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()
776 |
777 | self.transform = T.Compose([
778 | T.Lambda(maybe_convert_fn),
779 | T.Resize(image_size),
780 | T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
781 | T.CenterCrop(image_size),
782 | T.ToTensor()
783 | ])
784 |
785 | def __len__(self):
786 | return len(self.paths)
787 |
788 | def __getitem__(self, index):
789 | path = self.paths[index]
790 | img = Image.open(path)
791 | return self.transform(img)
792 |
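# dataset sketch (illustrative, hypothetical folder): images are not converted to a
# fixed mode by default, so e.g. rgba pngs keep 4 channels unless convert_image_to = 'RGB'
#
#     ds = Dataset('/path/to/images', image_size = 128, augment_horizontal_flip = True)
#     img = ds[0]                                 # float tensor in [0, 1]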
793 | # trainer class
794 |
795 | class Trainer(object):
796 | def __init__(
797 | self,
798 | diffusion_model,
799 | folder,
800 | *,
801 | train_batch_size = 16,
802 | gradient_accumulate_every = 1,
803 | augment_horizontal_flip = True,
804 | train_lr = 1e-4,
805 | train_num_steps = 100000,
806 | ema_update_every = 10,
807 | ema_decay = 0.995,
808 | adam_betas = (0.9, 0.99),
809 | save_and_sample_every = 1000,
810 | num_samples = 25,
811 | results_folder = './results',
812 | amp = False,
813 | fp16 = False,
814 | split_batches = True,
815 | convert_image_to = None
816 | ):
817 | super().__init__()
818 |
819 | self.accelerator = Accelerator(
820 | split_batches = split_batches,
821 | mixed_precision = 'fp16' if fp16 else 'no'
822 | )
823 |
824 | self.accelerator.native_amp = amp
825 |
826 | self.model = diffusion_model
827 |
828 | assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
829 | self.num_samples = num_samples
830 | self.save_and_sample_every = save_and_sample_every
831 |
832 | self.batch_size = train_batch_size
833 | self.gradient_accumulate_every = gradient_accumulate_every
834 |
835 | self.train_num_steps = train_num_steps
836 | self.image_size = diffusion_model.image_size
837 |
838 | # dataset and dataloader
839 |
840 | self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
841 | dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
842 |
843 | dl = self.accelerator.prepare(dl)
844 | self.dl = cycle(dl)
845 |
846 | # optimizer
847 |
848 | self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
849 |
850 | # for logging results in a folder periodically
851 |
852 | if self.accelerator.is_main_process:
853 | self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
854 |
855 | self.results_folder = Path(results_folder)
856 | self.results_folder.mkdir(exist_ok = True)
857 |
858 | # step counter state
859 |
860 | self.step = 0
861 |
862 | # prepare model, dataloader, optimizer with accelerator
863 |
864 | self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
865 |
866 | def save(self, milestone):
867 | if not self.accelerator.is_local_main_process:
868 | return
869 |
870 | data = {
871 | 'step': self.step,
872 | 'model': self.accelerator.get_state_dict(self.model),
873 | 'opt': self.opt.state_dict(),
874 | 'ema': self.ema.state_dict(),
875 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
876 | 'version': __version__
877 | }
878 |
879 | torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
880 |
881 | def load(self, milestone):
882 | accelerator = self.accelerator
883 | device = accelerator.device
884 |
885 | data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)
886 |
887 | model = self.accelerator.unwrap_model(self.model)
888 | model.load_state_dict(data['model'])
889 |
890 | self.step = data['step']
891 | self.opt.load_state_dict(data['opt'])
892 | self.ema.load_state_dict(data['ema'])
893 |
894 | if 'version' in data:
895 | print(f"loading from version {data['version']}")
896 |
897 | if exists(self.accelerator.scaler) and exists(data['scaler']):
898 | self.accelerator.scaler.load_state_dict(data['scaler'])
899 |
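    # checkpoint sketch (illustrative): save(n) / load(n) round-trip through
    # results_folder / f'model-{n}.pt', restoring model, ema, optimizer, scaler and step
    #
    #     trainer.save(1)
    #     trainer.load(1)                         # resumes from results/model-1.pt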
900 | def train(self):
901 | accelerator = self.accelerator
902 | device = accelerator.device
903 |
904 | with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
905 |
906 | while self.step < self.train_num_steps:
907 |
908 | total_loss = 0.
909 |
910 | for _ in range(self.gradient_accumulate_every):
911 | data = next(self.dl).to(device)
912 |
913 | with self.accelerator.autocast():
914 | loss = self.model(data)
915 | loss = loss / self.gradient_accumulate_every
916 | total_loss += loss.item()
917 |
918 | self.accelerator.backward(loss)
919 |
920 | accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
921 | pbar.set_description(f'loss: {total_loss:.4f}')
922 |
923 | accelerator.wait_for_everyone()
924 |
925 | self.opt.step()
926 | self.opt.zero_grad()
927 |
928 | accelerator.wait_for_everyone()
929 |
930 | self.step += 1
931 | if accelerator.is_main_process:
932 | self.ema.to(device)
933 | self.ema.update()
934 |
935 | if self.step != 0 and self.step % self.save_and_sample_every == 0:
936 | self.ema.ema_model.eval()
937 |
938 | with torch.no_grad():
939 | milestone = self.step // self.save_and_sample_every
940 | batches = num_to_groups(self.num_samples, self.batch_size)
941 | all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
942 |
943 | all_images = torch.cat(all_images_list, dim = 0)
944 | utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
945 | self.save(milestone)
946 |
947 | pbar.update(1)
948 |
949 | accelerator.print('training complete')
950 |
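# training sketch (illustrative, hypothetical folder), mirroring the README usage:
#
#     trainer = Trainer(
#         diffusion,
#         '/path/to/images',
#         train_batch_size = 32,
#         train_lr = 1e-4,
#         train_num_steps = 700000,           # total training steps
#         gradient_accumulate_every = 2,      # gradient accumulation steps
#         ema_decay = 0.995,                  # exponential moving average decay
#         amp = True                          # turn on mixed precision
#     )
#     trainer.train()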
--------------------------------------------------------------------------------