├── ddim-tutorial
│   ├── README.md
│   ├── ddim_sampling.py
│   └── interpolation.py
├── ddpm-tutorial
│   ├── README.md
│   ├── ddpm_sampling.py
│   └── ddpm_training.py
├── iddpm-tutorial
│   ├── README.md
│   ├── iddpm_sampling.py
│   └── iddpm_training.py
└── LICENSE

/ddim-tutorial/README.md:
--------------------------------------------------------------------------------
 1 | See: [original post](https://littlenyima.github.io/posts/14-denoising-diffusion-implicit-models/)
--------------------------------------------------------------------------------
/ddpm-tutorial/README.md:
--------------------------------------------------------------------------------
 1 | See: [original post](https://littlenyima.github.io/posts/13-denoising-diffusion-probabilistic-models/)
--------------------------------------------------------------------------------
/iddpm-tutorial/README.md:
--------------------------------------------------------------------------------
 1 | See: [original post](https://littlenyima.github.io/posts/15-improved-denoising-diffusion-probabilistic-models/)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2024 LittleNyima
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
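All three tutorials below share the same forward-diffusion bookkeeping: a clean image x_0 is noised in closed form as x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, which is what the add_noise methods in the training scripts implement and what every sampler inverts. As a reading aid, here is a minimal, self-contained sketch of that single step (the function name is illustrative and not part of the repository; the linear beta schedule mirrors the DDPM/DDIM classes, while the IDDPM files swap in a cosine schedule):

import torch

def add_noise_sketch(
    x0: torch.Tensor,          # clean images, shape (B, C, H, W), values in [-1, 1]
    noise: torch.Tensor,       # standard Gaussian noise, same shape as x0
    t: torch.Tensor,           # integer timesteps, shape (B,)
    num_train_timesteps: int = 1000,
    beta_start: float = 1e-4,
    beta_end: float = 0.02,
) -> torch.Tensor:
    # Linear beta schedule and cumulative product of alphas, mirroring the DDPM class below.
    betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).to(x0.device)
    t = t.to(x0.device)
    # Per-sample coefficients, broadcast over channel and spatial dimensions.
    sqrt_alpha_prod = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_prod = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    return sqrt_alpha_prod * x0 + sqrt_one_minus_alpha_prod * noise

Calling this with x0 of shape (B, 3, 64, 64), noise = torch.randn_like(x0), and integer t of shape (B,) reproduces what ddpm.add_noise(...) does during training in the scripts below.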
22 | -------------------------------------------------------------------------------- /ddpm-tutorial/ddpm_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from diffusers import UNet2DModel 4 | 5 | class DDPM: 6 | def __init__( 7 | self, 8 | num_train_timesteps:int = 1000, 9 | beta_start: float = 0.0001, 10 | beta_end: float = 0.02, 11 | ): 12 | self.num_train_timesteps = num_train_timesteps 13 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 14 | self.alphas = 1.0 - self.betas 15 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 16 | self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) 17 | 18 | @torch.no_grad() 19 | def sample( 20 | self, 21 | unet: UNet2DModel, 22 | batch_size: int, 23 | in_channels: int, 24 | sample_size: int, 25 | ): 26 | betas = self.betas.to(unet.device) 27 | alphas = self.alphas.to(unet.device) 28 | alphas_cumprod = self.alphas_cumprod.to(unet.device) 29 | timesteps = self.timesteps.to(unet.device) 30 | images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device) 31 | for timestep in tqdm(timesteps, desc='Sampling'): 32 | pred_noise: torch.Tensor = unet(images, timestep).sample 33 | 34 | # mean of q(x_{t-1}|x_t) 35 | alpha_t = alphas[timestep] 36 | alpha_cumprod_t = alphas_cumprod[timestep] 37 | sqrt_alpha_t = alpha_t ** 0.5 38 | one_minus_alpha_t = 1.0 - alpha_t 39 | sqrt_one_minus_alpha_cumprod_t = (1 - alpha_cumprod_t) ** 0.5 40 | mean = (images - one_minus_alpha_t / sqrt_one_minus_alpha_cumprod_t * pred_noise) / sqrt_alpha_t 41 | 42 | # variance of q(x_{t-1}|x_t) 43 | if timestep > 0: 44 | beta_t = betas[timestep] 45 | one_minus_alpha_cumprod_t_minus_one = 1.0 - alphas_cumprod[timestep - 1] 46 | one_divided_by_sigma_square = alpha_t / beta_t + 1.0 / one_minus_alpha_cumprod_t_minus_one 47 | variance = (1.0 / one_divided_by_sigma_square) ** 0.5 48 | else: 49 | variance = torch.zeros_like(timestep) 50 | 51 | epsilon = torch.randn_like(images) 52 | images = mean + variance * epsilon 53 | images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy() 54 | return images 55 | 56 | model = UNet2DModel.from_pretrained('ddpm-animefaces-64').cuda() 57 | ddpm = DDPM() 58 | images = ddpm.sample(model, 32, 3, 64) 59 | 60 | from diffusers.utils import make_image_grid, numpy_to_pil 61 | image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8) 62 | image_grid.save('ddpm-sample-results.png') 63 | -------------------------------------------------------------------------------- /ddim-tutorial/ddim_sampling.py: -------------------------------------------------------------------------------- 1 | # model 2 | 3 | from diffusers import UNet2DModel 4 | 5 | model = UNet2DModel.from_pretrained('ddpm-anime-faces-64').cuda() 6 | 7 | 8 | # core 9 | 10 | import torch 11 | import math 12 | from tqdm import tqdm 13 | 14 | class DDIM: 15 | def __init__( 16 | self, 17 | num_train_timesteps:int = 1000, 18 | beta_start: float = 0.0001, 19 | beta_end: float = 0.02, 20 | sample_steps: int = 20, 21 | ): 22 | self.num_train_timesteps = num_train_timesteps 23 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 24 | self.alphas = 1.0 - self.betas 25 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 26 | self.timesteps = torch.linspace(num_train_timesteps - 1, 0, sample_steps).long() 27 | 28 | @torch.no_grad() 29 | def sample( 30 | self, 31 | unet: 
UNet2DModel, 32 | batch_size: int, 33 | in_channels: int, 34 | sample_size: int, 35 | eta: float = 0.0, 36 | ): 37 | alphas = self.alphas.to(unet.device) 38 | alphas_cumprod = self.alphas_cumprod.to(unet.device) 39 | timesteps = self.timesteps.to(unet.device) 40 | images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device) 41 | for t, tau in tqdm(list(zip(timesteps[:-1], timesteps[1:])), desc='Sampling'): 42 | pred_noise: torch.Tensor = unet(images, t).sample 43 | 44 | # sigma_t 45 | if not math.isclose(eta, 0.0): 46 | one_minus_alpha_prod_tau = 1.0 - alphas_cumprod[tau] 47 | one_minus_alpha_prod_t = 1.0 - alphas_cumprod[t] 48 | one_minus_alpha_t = 1.0 - alphas[t] 49 | sigma_t = eta * (one_minus_alpha_prod_tau * one_minus_alpha_t / one_minus_alpha_prod_t) ** 0.5 50 | else: 51 | sigma_t = torch.zeros_like(alphas[0]) 52 | 53 | # first term of x_tau 54 | alphas_cumprod_tau = alphas_cumprod[tau] 55 | sqrt_alphas_cumprod_tau = alphas_cumprod_tau ** 0.5 56 | alphas_cumprod_t = alphas_cumprod[t] 57 | sqrt_alphas_cumprod_t = alphas_cumprod_t ** 0.5 58 | sqrt_one_minus_alphas_cumprod_t = (1.0 - alphas_cumprod_t) ** 0.5 59 | first_term = sqrt_alphas_cumprod_tau * (images - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t 60 | 61 | # second term of x_tau 62 | coeff = (1.0 - alphas_cumprod_tau - sigma_t ** 2) ** 0.5 63 | second_term = coeff * pred_noise 64 | 65 | epsilon = torch.randn_like(images) 66 | images = first_term + second_term + sigma_t * epsilon 67 | images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy() 68 | return images 69 | 70 | ddim = DDIM() 71 | images = ddim.sample(model, 32, 3, 64) 72 | 73 | from diffusers.utils import make_image_grid, numpy_to_pil 74 | image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8) 75 | image_grid.save('ddim-sample-results.png') 76 | -------------------------------------------------------------------------------- /iddpm-tutorial/iddpm_sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | from tqdm import tqdm 5 | from diffusers import UNet2DModel 6 | 7 | def make_betas_cosine_schedule( 8 | num_diffusion_timesteps: int = 1000, 9 | beta_max: float = 0.999, 10 | s: float = 8e-3, 11 | ): 12 | fn = lambda t: math.cos((t + s) / (1 + s) * math.pi / 2) ** 2 13 | betas = [] 14 | for i in range(num_diffusion_timesteps): 15 | t1 = i / num_diffusion_timesteps 16 | t2 = (i + 1) / num_diffusion_timesteps 17 | betas.append(1.0 - fn(t2) / fn(t1)) 18 | return torch.tensor(betas, dtype=torch.float32).clamp_max(beta_max) 19 | 20 | def extract(arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: torch.Size): 21 | arr = arr[timesteps] 22 | while len(arr.shape) < len(broadcast_shape): 23 | arr = arr.unsqueeze(-1) 24 | return arr.expand(broadcast_shape) 25 | 26 | class IDDPM: 27 | 28 | def __init__( 29 | self, 30 | num_diffusion_timesteps: int = 1000, 31 | beta_max: float = 0.999, 32 | ): 33 | self.num_diffusion_timesteps = num_diffusion_timesteps 34 | self.betas = make_betas_cosine_schedule(num_diffusion_timesteps, beta_max) 35 | self.alphas = 1.0 - self.betas 36 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 37 | self.alphas_cumprod_prev = torch.concat((torch.ones(1).to(self.alphas_cumprod), self.alphas_cumprod[:-1])) 38 | self.alphas_cumprod_next = torch.concat((self.alphas_cumprod[1:], torch.zeros(1).to(self.alphas_cumprod))) 39 | self.timesteps = 
torch.arange(num_diffusion_timesteps - 1, -1, -1) 40 | 41 | @torch.no_grad() 42 | def sample( 43 | self, 44 | unet: UNet2DModel, 45 | batch_size: int, 46 | in_channels: int, 47 | sample_size: int, 48 | ): 49 | images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device) 50 | 51 | betas = self.betas.to(unet.device) 52 | alphas = self.alphas.to(unet.device) 53 | alphas_cumprod = self.alphas_cumprod.to(unet.device) 54 | alphas_cumprod_prev = self.alphas_cumprod_prev.to(unet.device) 55 | timesteps = self.timesteps.to(unet.device) 56 | 57 | sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod) ** 0.5 58 | sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1.0) ** 0.5 59 | 60 | posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 61 | posterior_log_variance_clipped = torch.log(torch.concat((posterior_variance[1:2], posterior_variance[1:]))) 62 | posterior_mean_coef1 = betas * alphas_cumprod_prev ** 0.5 / (1.0 - alphas_cumprod) 63 | posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * alphas ** 0.5 / (1.0 - alphas_cumprod) 64 | 65 | for timestep in tqdm(timesteps, desc='Sampling'): 66 | _extract = partial(extract, timesteps=timestep, broadcast_shape=images.shape) 67 | preds: torch.Tensor = unet(images, timestep).sample 68 | pred_noises, pred_vars = torch.split(preds, in_channels, dim=1) 69 | 70 | # mean of p(x_{t-1}|x_t), same to DDPM 71 | x_0 = _extract(sqrt_recip_alphas_cumprod) * images - _extract(sqrt_recipm1_alphas_cumprod) * pred_noises 72 | mean = _extract(posterior_mean_coef1) * x_0.clamp(-1, 1) + _extract(posterior_mean_coef2) * images 73 | 74 | # variance of p(x_{t-1}|x_t), learned 75 | if timestep > 0: 76 | min_log = _extract(posterior_log_variance_clipped) 77 | max_log = _extract(torch.log(betas)) 78 | frac = (pred_vars + 1.0) / 2.0 79 | log_variance = frac * max_log + (1.0 - frac) * min_log 80 | stddev = torch.exp(0.5 * log_variance) 81 | else: 82 | stddev = torch.zeros_like(timestep) 83 | 84 | epsilon = torch.randn_like(images) 85 | images = mean + stddev * epsilon 86 | images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy() 87 | return images 88 | 89 | model = UNet2DModel.from_pretrained('iddpm-animefaces-64').cuda() 90 | ddpm = IDDPM() 91 | images = ddpm.sample(model, 32, 3, 64) 92 | 93 | from diffusers.utils import make_image_grid, numpy_to_pil 94 | image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8) 95 | image_grid.save('iddpm-sample-results.png') 96 | -------------------------------------------------------------------------------- /ddim-tutorial/interpolation.py: -------------------------------------------------------------------------------- 1 | # model 2 | 3 | from diffusers import UNet2DModel 4 | 5 | model = UNet2DModel.from_pretrained('ddpm-anime-faces-64').cuda() 6 | 7 | 8 | # interpolation 9 | 10 | import torch 11 | 12 | def slerp( 13 | x0: torch.Tensor, 14 | x1: torch.Tensor, 15 | alpha: float, 16 | ): 17 | # todo: separate each batch? 18 | theta = torch.acos(torch.sum(x0 * x1) / (torch.norm(x0) * torch.norm(x1))) 19 | w0 = torch.sin((1.0 - alpha) * theta) / torch.sin(theta) 20 | w1 = torch.sin(alpha * theta) / torch.sin(theta) 21 | return w0 * x0 + w1 * x1 22 | 23 | def interpolation_grid( 24 | rows: int, 25 | cols: int, 26 | in_channels: int, 27 | sample_size: int, 28 | ): 29 | images = torch.zeros((rows * cols, in_channels, sample_size, sample_size), dtype=torch.float32) 30 | images[0, ...] = torch.randn_like(images[0, ...]) # top left 31 | images[cols - 1, ...] 
= torch.randn_like(images[0, ...]) # top right 32 | images[(rows - 1) * cols, ...] = torch.randn_like(images[0, ...]) # bottom left 33 | images[-1] = torch.randn_like(images[0, ...]) # bottom right 34 | for row in range(1, rows - 1): # interpolate left most column and right most column 35 | alpha = row / (rows - 1) 36 | images[row * cols, ...] = slerp(images[0, ...], images[(rows - 1) * cols, ...], alpha) 37 | images[(row + 1) * cols - 1, ...] = slerp(images[cols - 1, ...], images[-1, ...], alpha) 38 | for col in range(1, cols - 1): # interpolate others 39 | alpha = col / (cols - 1) 40 | images[col::cols, ...] = slerp(images[0::cols, ...], images[cols - 1::cols, ...], alpha) 41 | return images 42 | 43 | 44 | # core 45 | 46 | import torch 47 | import math 48 | from tqdm import tqdm 49 | 50 | class DDIM: 51 | def __init__( 52 | self, 53 | num_train_timesteps:int = 1000, 54 | beta_start: float = 0.0001, 55 | beta_end: float = 0.02, 56 | sample_steps: int = 20, 57 | ): 58 | self.num_train_timesteps = num_train_timesteps 59 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 60 | self.alphas = 1.0 - self.betas 61 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 62 | self.timesteps = torch.linspace(num_train_timesteps - 1, 0, sample_steps).long() 63 | 64 | @torch.no_grad() 65 | def sample( 66 | self, 67 | unet: UNet2DModel, 68 | rows: int, 69 | cols: int, 70 | in_channels: int, 71 | sample_size: int, 72 | eta: float = 0.0, 73 | ): 74 | alphas = self.alphas.to(unet.device) 75 | alphas_cumprod = self.alphas_cumprod.to(unet.device) 76 | timesteps = self.timesteps.to(unet.device) 77 | images = interpolation_grid(rows, cols, in_channels, sample_size).to(unet.device) 78 | for t, tau in tqdm(list(zip(timesteps[:-1], timesteps[1:])), desc='Sampling'): 79 | pred_noise: torch.Tensor = unet(images, t).sample 80 | 81 | # sigma_t 82 | if not math.isclose(eta, 0.0): 83 | one_minus_alpha_prod_tau = 1.0 - alphas_cumprod[tau] 84 | one_minus_alpha_prod_t = 1.0 - alphas_cumprod[t] 85 | one_minus_alpha_t = 1.0 - alphas[t] 86 | sigma_t = eta * (one_minus_alpha_prod_tau * one_minus_alpha_t / one_minus_alpha_prod_t) ** 0.5 87 | else: 88 | sigma_t = torch.zeros_like(alphas[0]) 89 | 90 | # first term of x_tau 91 | alphas_cumprod_tau = alphas_cumprod[tau] 92 | sqrt_alphas_cumprod_tau = alphas_cumprod_tau ** 0.5 93 | alphas_cumprod_t = alphas_cumprod[t] 94 | sqrt_alphas_cumprod_t = alphas_cumprod_t ** 0.5 95 | sqrt_one_minus_alphas_cumprod_t = (1.0 - alphas_cumprod_t) ** 0.5 96 | first_term = sqrt_alphas_cumprod_tau * (images - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t 97 | 98 | # second term of x_tau 99 | coeff = (1.0 - alphas_cumprod_tau - sigma_t ** 2) ** 0.5 100 | second_term = coeff * pred_noise 101 | 102 | epsilon = torch.randn_like(images) 103 | images = first_term + second_term + sigma_t * epsilon 104 | images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy() 105 | return images 106 | 107 | ddim = DDIM() 108 | images = ddim.sample(model, 4, 8, 3, 64) 109 | 110 | from diffusers.utils import make_image_grid, numpy_to_pil 111 | image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8) 112 | image_grid.save('ddim-interpolation-results.png') 113 | -------------------------------------------------------------------------------- /ddpm-tutorial/ddpm_training.py: -------------------------------------------------------------------------------- 1 | # config 2 | 3 | from dataclasses import dataclass 4 | 5 | @dataclass 6 
| class TrainingConfig: 7 | image_size = 64 8 | train_batch_size = 16 9 | eval_batch_size = 16 10 | num_epochs = 50 11 | gradient_accumulation_steps = 1 12 | learning_rate = 1e-4 13 | lr_warmup_steps = 500 14 | mixed_precision = "fp16" 15 | output_dir = "ddpm-animefaces-64" 16 | overwrite_output_dir = True 17 | 18 | config = TrainingConfig() 19 | 20 | # data 21 | 22 | from datasets import load_dataset 23 | 24 | dataset = load_dataset("huggan/anime-faces", split="train") 25 | dataset = dataset.select(range(21551)) 26 | 27 | ## preprocess 28 | 29 | from torchvision import transforms 30 | 31 | def get_transform(): 32 | preprocess = transforms.Compose([ 33 | transforms.Resize(config.image_size), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5], [0.5]), 36 | ]) 37 | def transform(samples): 38 | images = [preprocess(img.convert("RGB")) for img in samples["image"]] 39 | return dict(images=images) 40 | return transform 41 | 42 | dataset.set_transform(get_transform()) 43 | 44 | ## dataloader 45 | 46 | from torch.utils.data import DataLoader 47 | 48 | dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True) 49 | 50 | from diffusers import UNet2DModel 51 | 52 | model = UNet2DModel( 53 | sample_size=config.image_size, 54 | in_channels=3, 55 | out_channels=3, 56 | layers_per_block=2, 57 | block_out_channels=(128, 128, 256, 256, 512, 512), 58 | down_block_types=( 59 | "DownBlock2D", 60 | "DownBlock2D", 61 | "DownBlock2D", 62 | "DownBlock2D", 63 | "AttnDownBlock2D", 64 | "DownBlock2D", 65 | ), 66 | up_block_types=( 67 | "UpBlock2D", 68 | "AttnUpBlock2D", 69 | "UpBlock2D", 70 | "UpBlock2D", 71 | "UpBlock2D", 72 | "UpBlock2D", 73 | ), 74 | ) 75 | 76 | # core 77 | 78 | import torch 79 | from tqdm import tqdm 80 | 81 | class DDPM: 82 | def __init__( 83 | self, 84 | num_train_timesteps:int = 1000, 85 | beta_start: float = 0.0001, 86 | beta_end: float = 0.02, 87 | ): 88 | self.num_train_timesteps = num_train_timesteps 89 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 90 | self.alphas = 1.0 - self.betas 91 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 92 | self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) 93 | 94 | def add_noise( 95 | self, 96 | original_samples: torch.Tensor, 97 | noise: torch.Tensor, 98 | timesteps: torch.Tensor, 99 | ): 100 | alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device ,dtype=original_samples.dtype) 101 | noise = noise.to(original_samples.device) 102 | timesteps = timesteps.to(original_samples.device) 103 | 104 | # \sqrt{\bar\alpha_t} 105 | sqrt_alpha_prod = alphas_cumprod[timesteps].flatten() ** 0.5 106 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 107 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 108 | 109 | # \sqrt{1 - \bar\alpha_t} 110 | sqrt_one_minus_alpha_prod = (1.0 - alphas_cumprod[timesteps]).flatten() ** 0.5 111 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 112 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 113 | 114 | return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 115 | 116 | @torch.no_grad() 117 | def sample( 118 | self, 119 | unet: UNet2DModel, 120 | batch_size: int, 121 | in_channels: int, 122 | sample_size: int, 123 | ): 124 | betas = self.betas.to(unet.device) 125 | alphas = self.alphas.to(unet.device) 126 | alphas_cumprod = self.alphas_cumprod.to(unet.device) 127 | timesteps = self.timesteps.to(unet.device) 128 | images = 
torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device) 129 | for timestep in tqdm(timesteps, desc='Sampling'): 130 | pred_noise: torch.Tensor = unet(images, timestep).sample 131 | 132 | # mean of q(x_{t-1}|x_t) 133 | alpha_t = alphas[timestep] 134 | alpha_cumprod_t = alphas_cumprod[timestep] 135 | sqrt_alpha_t = alpha_t ** 0.5 136 | one_minus_alpha_t = 1.0 - alpha_t 137 | sqrt_one_minus_alpha_cumprod_t = (1 - alpha_cumprod_t) ** 0.5 138 | mean = (images - one_minus_alpha_t / sqrt_one_minus_alpha_cumprod_t * pred_noise) / sqrt_alpha_t 139 | 140 | # variance of q(x_{t-1}|x_t) 141 | if timestep > 0: 142 | beta_t = betas[timestep] 143 | one_minus_alpha_cumprod_t_minus_one = 1.0 - alphas_cumprod[timestep - 1] 144 | one_divided_by_sigma_square = alpha_t / beta_t + 1.0 / one_minus_alpha_cumprod_t_minus_one 145 | variance = (1.0 / one_divided_by_sigma_square) ** 0.5 146 | else: 147 | variance = torch.zeros_like(timestep) 148 | 149 | epsilon = torch.randn_like(images) 150 | images = mean + variance * epsilon 151 | images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy() 152 | return images 153 | 154 | # training 155 | 156 | from accelerate import Accelerator 157 | from diffusers.optimization import get_cosine_schedule_with_warmup 158 | from diffusers.utils import make_image_grid, numpy_to_pil 159 | import torch.nn.functional as F 160 | import os 161 | 162 | model = model.cuda() 163 | ddpm = DDPM() 164 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) 165 | lr_scheduler = get_cosine_schedule_with_warmup( 166 | optimizer=optimizer, 167 | num_warmup_steps=config.lr_warmup_steps, 168 | num_training_steps=(len(dataloader) * config.num_epochs), 169 | ) 170 | accelerator = Accelerator( 171 | mixed_precision=config.mixed_precision, 172 | gradient_accumulation_steps=config.gradient_accumulation_steps, 173 | log_with="tensorboard", 174 | project_dir=os.path.join(config.output_dir, "logs"), 175 | ) 176 | model, optimizer, dataloader, lr_scheduler = accelerator.prepare( 177 | model, optimizer, dataloader, lr_scheduler 178 | ) 179 | global_step = 0 180 | for epoch in range(config.num_epochs): 181 | progress_bar = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process, desc=f'Epoch {epoch}') 182 | 183 | for step, batch in enumerate(dataloader): 184 | clean_images = batch["images"] 185 | # Sample noise to add to the images 186 | noise = torch.randn(clean_images.shape, device=clean_images.device) 187 | bs = clean_images.shape[0] 188 | # Sample a random timestep for each image 189 | timesteps = torch.randint( 190 | 0, ddpm.num_train_timesteps, (bs,), device=clean_images.device, 191 | dtype=torch.int64 192 | ) 193 | # Add noise to the clean images according to the noise magnitude at each timestep 194 | noisy_images = ddpm.add_noise(clean_images, noise, timesteps) 195 | with accelerator.accumulate(model): 196 | # Predict the noise residual 197 | noise_pred = model(noisy_images, timesteps, return_dict=False)[0] 198 | loss = F.mse_loss(noise_pred, noise) 199 | accelerator.backward(loss) 200 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 201 | optimizer.step() 202 | lr_scheduler.step() 203 | optimizer.zero_grad() 204 | progress_bar.update(1) 205 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} 206 | progress_bar.set_postfix(**logs) 207 | accelerator.log(logs, step=global_step) 208 | global_step += 1 209 | 210 | if accelerator.is_main_process: 211 | # evaluate 212 | 
images = ddpm.sample(model, config.eval_batch_size, 3, config.image_size) 213 | image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=4) 214 | samples_dir = os.path.join(config.output_dir, 'samples') 215 | os.makedirs(samples_dir, exist_ok=True) 216 | image_grid.save(os.path.join(samples_dir, f'{global_step}.png')) 217 | # save models 218 | model.save_pretrained(config.output_dir) 219 | -------------------------------------------------------------------------------- /iddpm-tutorial/iddpm_training.py: -------------------------------------------------------------------------------- 1 | # config 2 | 3 | from dataclasses import dataclass 4 | 5 | import torch.types 6 | 7 | @dataclass 8 | class TrainingConfig: 9 | image_size = 64 10 | train_batch_size = 16 11 | eval_batch_size = 16 12 | num_epochs = 50 13 | gradient_accumulation_steps = 1 14 | learning_rate = 1e-4 15 | lr_warmup_steps = 500 16 | mixed_precision = "fp16" 17 | output_dir = "iddpm-animefaces-64" 18 | overwrite_output_dir = True 19 | 20 | config = TrainingConfig() 21 | 22 | # data 23 | 24 | from datasets import load_dataset 25 | 26 | dataset = load_dataset("huggan/anime-faces", split="train") 27 | dataset = dataset.select(range(21551)) 28 | 29 | ## preprocess 30 | 31 | from torchvision import transforms 32 | 33 | def get_transform(): 34 | preprocess = transforms.Compose([ 35 | transforms.Resize(config.image_size), 36 | transforms.ToTensor(), 37 | transforms.Normalize([0.5], [0.5]), 38 | ]) 39 | def transform(samples): 40 | images = [preprocess(img.convert("RGB")) for img in samples["image"]] 41 | return dict(images=images) 42 | return transform 43 | 44 | dataset.set_transform(get_transform()) 45 | 46 | ## dataloader 47 | 48 | from torch.utils.data import DataLoader 49 | 50 | dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True) 51 | 52 | # unet model 53 | 54 | from diffusers import UNet2DModel 55 | 56 | model = UNet2DModel( 57 | sample_size=config.image_size, 58 | in_channels=3, 59 | out_channels=6, 60 | layers_per_block=2, 61 | block_out_channels=(128, 128, 256, 256, 512, 512), 62 | down_block_types=( 63 | "DownBlock2D", 64 | "DownBlock2D", 65 | "DownBlock2D", 66 | "DownBlock2D", 67 | "AttnDownBlock2D", 68 | "DownBlock2D", 69 | ), 70 | up_block_types=( 71 | "UpBlock2D", 72 | "AttnUpBlock2D", 73 | "UpBlock2D", 74 | "UpBlock2D", 75 | "UpBlock2D", 76 | "UpBlock2D", 77 | ), 78 | ) 79 | 80 | # core 81 | 82 | import math 83 | import torch 84 | from tqdm import tqdm 85 | from functools import partial 86 | 87 | def make_betas_cosine_schedule( 88 | num_diffusion_timesteps: int = 1000, 89 | beta_max: float = 0.999, 90 | s: float = 8e-3, 91 | ): 92 | fn = lambda t: math.cos((t + s) / (1 + s) * math.pi / 2) ** 2 93 | betas = [] 94 | for i in range(num_diffusion_timesteps): 95 | t1 = i / num_diffusion_timesteps 96 | t2 = (i + 1) / num_diffusion_timesteps 97 | betas.append(1.0 - fn(t2) / fn(t1)) 98 | return torch.tensor(betas, dtype=torch.float32).clamp_max(beta_max) 99 | 100 | def extract(arr: torch.Tensor, timesteps: torch.Tensor, broadcast_shape: torch.Size): 101 | arr = arr[timesteps] 102 | while len(arr.shape) < len(broadcast_shape): 103 | arr = arr.unsqueeze(-1) 104 | return arr.expand(broadcast_shape) 105 | 106 | class IDDPM: 107 | 108 | def __init__( 109 | self, 110 | num_diffusion_timesteps: int = 1000, 111 | beta_max: float = 0.999, 112 | ): 113 | self.num_diffusion_timesteps = num_diffusion_timesteps 114 | self.betas = make_betas_cosine_schedule(num_diffusion_timesteps, beta_max) 115 | 
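# Note on the cosine schedule: with f(t) = cos^2(((t + s) / (1 + s)) * pi / 2),
# make_betas_cosine_schedule sets beta_i = 1 - f((i + 1) / T) / f(i / T), clamped to
# beta_max, so the cumulative product of (1 - beta_i) computed just below telescopes
# to alpha_bar_t ~= f(t / T) / f(0): the cosine noise schedule of Improved DDPM,
# in place of the linear schedule used by the DDPM and DDIM tutorials.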
self.alphas = 1.0 - self.betas 116 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 117 | self.alphas_cumprod_prev = torch.concat((torch.ones(1).to(self.alphas_cumprod), self.alphas_cumprod[:-1])) 118 | self.alphas_cumprod_next = torch.concat((self.alphas_cumprod[1:], torch.zeros(1).to(self.alphas_cumprod))) 119 | self.timesteps = torch.arange(num_diffusion_timesteps - 1, -1, -1) 120 | 121 | def add_noise( 122 | self, 123 | original_samples: torch.Tensor, 124 | noise: torch.Tensor, 125 | timesteps: torch.Tensor, 126 | ): 127 | alphas_cumprod = self.alphas_cumprod.to(original_samples) 128 | noise = noise.to(original_samples.device) 129 | timesteps = timesteps.to(original_samples.device) 130 | 131 | # \sqrt{\bar\alpha_t} 132 | sqrt_alpha_prod = alphas_cumprod[timesteps].flatten() ** 0.5 133 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 134 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 135 | 136 | # \sqrt{1 - \bar\alpha_t} 137 | sqrt_one_minus_alpha_prod = (1.0 - alphas_cumprod[timesteps]).flatten() ** 0.5 138 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 139 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 140 | 141 | return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 142 | 143 | @torch.no_grad() 144 | def sample( 145 | self, 146 | unet: UNet2DModel, 147 | batch_size: int, 148 | in_channels: int, 149 | sample_size: int, 150 | ): 151 | images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device) 152 | 153 | betas = self.betas.to(unet.device) 154 | alphas = self.alphas.to(unet.device) 155 | alphas_cumprod = self.alphas_cumprod.to(unet.device) 156 | alphas_cumprod_prev = self.alphas_cumprod_prev.to(unet.device) 157 | timesteps = self.timesteps.to(unet.device) 158 | 159 | sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod) ** 0.5 160 | sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1.0) ** 0.5 161 | 162 | posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 163 | posterior_log_variance_clipped = torch.log(torch.concat((posterior_variance[1:2], posterior_variance[1:]))) 164 | posterior_mean_coef1 = betas * alphas_cumprod_prev ** 0.5 / (1.0 - alphas_cumprod) 165 | posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * alphas ** 0.5 / (1.0 - alphas_cumprod) 166 | 167 | for timestep in tqdm(timesteps, desc='Sampling'): 168 | _extract = partial(extract, timesteps=timestep, broadcast_shape=images.shape) 169 | preds: torch.Tensor = unet(images, timestep).sample 170 | pred_noises, pred_vars = torch.split(preds, in_channels, dim=1) 171 | 172 | # mean of p(x_{t-1}|x_t), same to DDPM 173 | x_0 = _extract(sqrt_recip_alphas_cumprod) * images - _extract(sqrt_recipm1_alphas_cumprod) * pred_noises 174 | mean = _extract(posterior_mean_coef1) * x_0.clamp(-1, 1) + _extract(posterior_mean_coef2) * images 175 | 176 | # variance of p(x_{t-1}|x_t), learned 177 | if timestep > 0: 178 | min_log = _extract(posterior_log_variance_clipped) 179 | max_log = _extract(torch.log(betas)) 180 | frac = (pred_vars + 1.0) / 2.0 181 | log_variance = frac * max_log + (1.0 - frac) * min_log 182 | stddev = torch.exp(0.5 * log_variance) 183 | else: 184 | stddev = torch.zeros_like(timestep) 185 | 186 | epsilon = torch.randn_like(images) 187 | images = mean + stddev * epsilon 188 | images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy() 189 | return images 190 | 191 | # training 192 | 193 | from accelerate import Accelerator 194 | from 
diffusers.optimization import get_cosine_schedule_with_warmup 195 | from diffusers.utils import make_image_grid, numpy_to_pil 196 | import os 197 | 198 | ## training losses 199 | 200 | def pred_mean_logvar( 201 | iddpm: IDDPM, 202 | pred_noises: torch.Tensor, 203 | pred_vars: torch.Tensor, 204 | noisy_images: torch.Tensor, 205 | timesteps: torch.Tensor, 206 | ): 207 | betas = iddpm.betas.to(timesteps.device) 208 | alphas = iddpm.alphas.to(timesteps.device) 209 | alphas_cumprod = iddpm.alphas_cumprod.to(timesteps.device) 210 | alphas_cumprod_prev = iddpm.alphas_cumprod_prev.to(timesteps.device) 211 | 212 | sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod) ** 0.5 213 | sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1.0) ** 0.5 214 | 215 | posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 216 | posterior_log_variance_clipped = torch.log(torch.concat((posterior_variance[1:2], posterior_variance[1:]))) 217 | posterior_mean_coef1 = betas * alphas_cumprod_prev ** 0.5 / (1.0 - alphas_cumprod) 218 | posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * alphas ** 0.5 / (1.0 - alphas_cumprod) 219 | 220 | _extract = partial(extract, timesteps=timesteps, broadcast_shape=noisy_images.shape) 221 | # mean of p(x_{t-1}|x_t), same to DDPM 222 | x_0 = _extract(sqrt_recip_alphas_cumprod) * noisy_images - _extract(sqrt_recipm1_alphas_cumprod) * pred_noises 223 | mean = _extract(posterior_mean_coef1) * x_0.clamp(-1, 1) + _extract(posterior_mean_coef2) * noisy_images 224 | # variance of p(x_{t-1}|x_t), learned 225 | min_log = _extract(posterior_log_variance_clipped) 226 | max_log = _extract(torch.log(betas)) 227 | frac = (pred_vars + 1.0) / 2.0 228 | log_variance = frac * max_log + (1.0 - frac) * min_log 229 | 230 | return mean, log_variance 231 | 232 | def true_mean_logvar( 233 | iddpm: IDDPM, 234 | clean_images: torch.Tensor, 235 | noisy_images: torch.Tensor, 236 | timesteps: torch.Tensor, 237 | ): 238 | betas = iddpm.betas.to(timesteps.device) 239 | alphas = iddpm.alphas.to(timesteps.device) 240 | alphas_cumprod = iddpm.alphas_cumprod.to(timesteps.device) 241 | alphas_cumprod_prev = iddpm.alphas_cumprod_prev.to(timesteps.device) 242 | 243 | posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 244 | posterior_log_variance_clipped = torch.log(torch.concat((posterior_variance[1:2], posterior_variance[1:]))) 245 | posterior_mean_coef1 = betas * alphas_cumprod_prev ** 0.5 / (1.0 - alphas_cumprod) 246 | posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * alphas ** 0.5 / (1.0 - alphas_cumprod) 247 | 248 | _extract = partial(extract, timesteps=timesteps, broadcast_shape=noisy_images.shape) 249 | posterior_mean = _extract(posterior_mean_coef1) * clean_images + _extract(posterior_mean_coef2) * noisy_images 250 | posterior_log_variance_clipped = _extract(posterior_log_variance_clipped) 251 | 252 | return posterior_mean, posterior_log_variance_clipped 253 | 254 | def gaussian_kl_divergence( 255 | mean_1: torch.Tensor, 256 | logvar_1: torch.Tensor, 257 | mean_2: torch.Tensor, 258 | logvar_2: torch.Tensor, 259 | ): 260 | return 0.5 * ( 261 | -1.0 262 | + logvar_2 263 | - logvar_1 264 | + torch.exp(logvar_1 - logvar_2) 265 | + ((mean_1 - mean_2) ** 2) * torch.exp(-logvar_2) 266 | ) 267 | 268 | def approx_standard_normal_cdf(x): 269 | return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 270 | 271 | def gaussian_nll( 272 | clean_images: torch.Tensor, 273 | pred_mean: torch.Tensor, 274 | pred_logvar: torch.Tensor, 275 | 
): 276 | # stdnorm = Normal(torch.Tensor([0.0]), torch.Tensor([1.0])) 277 | centered_x = clean_images - pred_mean 278 | inv_stdv = torch.exp(-pred_logvar) 279 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 280 | cdf_plus = approx_standard_normal_cdf(plus_in).clamp_min(1e-12) 281 | # cdf_plus = stdnorm.cdf(plus_in).clamp_min(1e-12) 282 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 283 | cdf_min = approx_standard_normal_cdf(min_in).clamp_min(1e-12) 284 | # cdf_min = stdnorm.cdf(min_in).clamp_min(1e-12) 285 | cdf_delta = (cdf_plus - cdf_min).clamp_min(1e-12) 286 | log_cdf_plus = torch.log(cdf_plus) 287 | log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp_min(1e-12)) 288 | log_probs = torch.log(cdf_delta.clamp_min(1e-12)) 289 | log_probs[clean_images < -0.999] = log_cdf_plus[clean_images < -0.999] 290 | log_probs[clean_images > 0.999] = log_one_minus_cdf_min[clean_images > 0.999] 291 | return log_probs 292 | 293 | def vlb_loss( 294 | iddpm: IDDPM, 295 | pred_noises: torch.Tensor, 296 | pred_vars: torch.Tensor, 297 | clean_images: torch.Tensor, # x_0 298 | noisy_images: torch.Tensor, # x_t 299 | timesteps: torch.Tensor, # t 300 | ): 301 | # 1. calculate predicted mean and log var, same to sampling 302 | pred_mean, pred_logvar = pred_mean_logvar(iddpm, pred_noises, pred_vars, noisy_images, timesteps) 303 | # 2. calculate the true mean and log var with q(x_{t-1}|x_t,x_0) 304 | true_mean, true_logvar = true_mean_logvar(iddpm, clean_images, noisy_images, timesteps) 305 | # 3. calculate the KL divergences 306 | kl = gaussian_kl_divergence(true_mean, true_logvar, pred_mean, pred_logvar) 307 | kl = kl.mean(dim=list(range(1, len(kl.shape)))) / math.log(2.0) 308 | # 4. calculate the NLL 309 | nll = gaussian_nll(clean_images, pred_mean, pred_logvar * 0.5) 310 | nll = nll.mean(dim=list(range(1, len(nll.shape)))) / math.log(2.0) 311 | # 5. gather results 312 | results = torch.where(timesteps == 0, nll, kl) 313 | return results 314 | 315 | def training_losses( 316 | iddpm: IDDPM, 317 | model: UNet2DModel, 318 | clean_images: torch.Tensor, 319 | noise: torch.Tensor, 320 | noisy_images: torch.Tensor, 321 | timesteps: torch.Tensor, 322 | vlb_weight: float = 1e-3, 323 | ) -> torch.Tensor: 324 | _, channels, _, _ = noisy_images.shape 325 | pred: torch.Tensor = model(noisy_images, timesteps, return_dict=False)[0] 326 | pred_noises, pred_vars = torch.split(pred, channels, dim=1) 327 | # 1. L_simple 328 | l_simple = (pred_noises - noise) ** 2 329 | l_simple = l_simple.mean(dim=list(range(1, len(l_simple.shape)))) 330 | # 2. 
L_vlb 331 | l_vlb = vlb_loss(iddpm, pred_noises.detach(), pred_vars, clean_images, noisy_images, timesteps) 332 | return l_simple + vlb_weight * l_vlb 333 | 334 | ## importance sampler 335 | 336 | import numpy as np 337 | from typing import List 338 | from torch.distributed.nn import dist 339 | 340 | class ImportanceSampler: 341 | 342 | def __init__( 343 | self, 344 | num_diffusion_timesteps: int = 1000, 345 | history_per_term: int = 10, 346 | ): 347 | self.num_diffusion_timesteps = num_diffusion_timesteps 348 | self.history_per_term = history_per_term 349 | self.uniform_prob = 1.0 / num_diffusion_timesteps 350 | self.loss_history = np.zeros([num_diffusion_timesteps, history_per_term], dtype=np.float64) 351 | self.loss_counts = np.zeros([num_diffusion_timesteps], dtype=int) 352 | 353 | def sample(self, batch_size: int): 354 | weights = self.weights 355 | prob = weights / np.sum(weights) 356 | timesteps = np.random.choice(self.num_diffusion_timesteps, size=(batch_size,), p=prob) 357 | weights = 1.0 / (self.num_diffusion_timesteps * prob[timesteps]) 358 | return torch.from_numpy(timesteps).long(), torch.from_numpy(weights).float() 359 | 360 | @property 361 | def weights(self): 362 | if not np.all(self.loss_counts == self.history_per_term): 363 | return np.ones([self.num_diffusion_timesteps], dtype=np.float64) 364 | weights = np.sqrt(np.mean(self.loss_history ** 2, axis=-1)) 365 | weights /= np.sum(weights) 366 | weights *= 1.0 - self.uniform_prob 367 | weights += self.uniform_prob / len(weights) 368 | return weights 369 | 370 | def update(self, timesteps: torch.Tensor, losses: torch.Tensor): 371 | # collect 372 | if dist.is_initialized(): 373 | world_size = dist.get_world_size() 374 | # get batch sizes for padding to the maximum bs 375 | batch_sizes = [torch.tensor([0], dtype=torch.int32, device=timesteps.device) for _ in range(world_size)] 376 | dist.all_gather(batch_sizes, torch.full_like(batch_sizes[0], timesteps.size(0))) 377 | max_batch_size = max([bs.item() for bs in batch_sizes]) 378 | # gather all timesteps and losses 379 | timestep_batches: List[torch.Tensor] = [torch.zeros(max_batch_size).to(timesteps) for _ in range(world_size)] 380 | loss_batches: List[torch.Tensor] = [torch.zeros(max_batch_size).to(losses) for _ in range(world_size)] 381 | dist.all_gather(timestep_batches, timesteps) 382 | dist.all_gather(loss_batches, losses) 383 | all_timesteps = [ts.item() for ts_batch, bs in zip(timestep_batches, batch_sizes) for ts in ts_batch[:bs]] 384 | all_losses = [loss.item() for loss_batch, bs in zip(loss_batches, batch_sizes) for loss in loss_batch[:bs]] 385 | else: 386 | all_timesteps = timesteps.tolist() 387 | all_losses = losses.tolist() 388 | # update 389 | for timestep, loss in zip(all_timesteps, all_losses): 390 | if self.loss_counts[timestep] == self.history_per_term: 391 | self.loss_history[timestep, :-1] = self.loss_history[timestep, 1:] 392 | self.loss_history[timestep, -1] = loss 393 | else: 394 | self.loss_history[timestep, self.loss_counts[timestep]] = loss 395 | self.loss_counts[timestep] += 1 396 | 397 | model = model.cuda() 398 | iddpm = IDDPM() 399 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) 400 | lr_scheduler = get_cosine_schedule_with_warmup( 401 | optimizer=optimizer, 402 | num_warmup_steps=config.lr_warmup_steps, 403 | num_training_steps=(len(dataloader) * config.num_epochs), 404 | ) 405 | importance_sampler = ImportanceSampler() 406 | accelerator = Accelerator( 407 | mixed_precision=config.mixed_precision, 408 | 
gradient_accumulation_steps=config.gradient_accumulation_steps, 409 | log_with="tensorboard", 410 | project_dir=os.path.join(config.output_dir, "logs"), 411 | ) 412 | model, optimizer, dataloader, lr_scheduler = accelerator.prepare( 413 | model, optimizer, dataloader, lr_scheduler 414 | ) 415 | global_step = 0 416 | for epoch in range(config.num_epochs): 417 | progress_bar = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process, desc=f'Epoch {epoch}') 418 | 419 | for step, batch in enumerate(dataloader): 420 | clean_images = batch["images"] 421 | # Sample noise to add to the images 422 | noise = torch.randn(clean_images.shape, device=clean_images.device) 423 | bs = clean_images.shape[0] 424 | # Sample a random timestep for each image 425 | timesteps, weights = importance_sampler.sample(bs) 426 | timesteps = timesteps.to(clean_images.device) 427 | weights = weights.to(clean_images.device) 428 | # Add noise to the clean images according to the noise magnitude at each timestep 429 | noisy_images = iddpm.add_noise(clean_images, noise, timesteps) 430 | with accelerator.accumulate(model): 431 | # Predict the noise residual 432 | losses = training_losses(iddpm, model, clean_images, noise, noisy_images, timesteps) 433 | importance_sampler.update(timesteps, losses) 434 | loss = (losses * weights).mean() 435 | accelerator.backward(loss) 436 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 437 | optimizer.step() 438 | lr_scheduler.step() 439 | optimizer.zero_grad() 440 | progress_bar.update(1) 441 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} 442 | progress_bar.set_postfix(**logs) 443 | accelerator.log(logs, step=global_step) 444 | global_step += 1 445 | 446 | if accelerator.is_main_process: 447 | # evaluate 448 | images = iddpm.sample(model, config.eval_batch_size, 3, config.image_size) 449 | image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=4) 450 | samples_dir = os.path.join(config.output_dir, 'samples') 451 | os.makedirs(samples_dir, exist_ok=True) 452 | image_grid.save(os.path.join(samples_dir, f'{global_step}.png')) 453 | # save models 454 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 455 | model.module.save_pretrained(config.output_dir) 456 | else: 457 | model.save_pretrained(config.output_dir) 458 | --------------------------------------------------------------------------------
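A note on the learned variance in the IDDPM files above: the UNet is built with out_channels=6, and the second half of its output (pred_vars) is a raw per-pixel value v. The code maps it to frac = (v + 1) / 2 and interpolates between the two analytic log-variance bounds, log(beta_t) and the clipped posterior log-variance, following Improved DDPM. A minimal sketch of just that mapping (the function name is illustrative, not part of the repository):

import torch

def learned_log_variance(
    pred_vars: torch.Tensor,             # second half of the UNet output, roughly in [-1, 1]
    log_beta_t: torch.Tensor,            # upper bound: log of the forward-process variance beta_t
    log_posterior_beta_t: torch.Tensor,  # lower bound: clipped posterior log-variance
) -> torch.Tensor:
    # Same interpolation as in IDDPM.sample and pred_mean_logvar above.
    frac = (pred_vars + 1.0) / 2.0
    return frac * log_beta_t + (1.0 - frac) * log_posterior_beta_t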