├── .gitignore ├── 2d_plot_diffusion_todo ├── __init__.py ├── chamferdist.py ├── dataset.py ├── ddpm.py ├── ddpm_tutorial.ipynb └── network.py ├── LICENSE ├── README.md ├── assets ├── images │ ├── cfg_test.png │ ├── cfg_train.png │ ├── fid_command.png │ ├── qs.png │ ├── sampling_command.png │ ├── task1_ddim_sample.png │ ├── task1_ddpm_sample.png │ ├── task1_distribution.png │ ├── task1_loss_curve.png │ ├── task1_output_example.png │ ├── task2_1_ddpm_sampling_algorithm.png │ ├── task2_algorithm.png │ ├── task2_ddim.png │ ├── task2_output_example.png │ ├── task2_teaser.png │ ├── teaser.gif │ └── teaser.png └── summary_of_DDPM_and_DDIM.pdf ├── image_diffusion_todo ├── dataset.py ├── fid │ ├── afhq_inception_v3.ckpt │ ├── inception.py │ └── measure_fid.py ├── model.py ├── module.py ├── network.py ├── sampling.py ├── scheduler.py └── train.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.ipynb_checkpoints 4 | data/ 5 | results/ 6 | outputs 7 | samples/ 8 | .DS_Store 9 | 10 | # Solutions 11 | 2d_plot_diffusion 12 | image_diffusion -------------------------------------------------------------------------------- /2d_plot_diffusion_todo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/2d_plot_diffusion_todo/__init__.py -------------------------------------------------------------------------------- /2d_plot_diffusion_todo/chamferdist.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import KDTree 2 | from scipy.spatial.distance import cdist 3 | 4 | def chamfer_distance(S1, S2) -> float: 5 | """ 6 | Computes the Chamfer distance between two point clouds defined as: 7 | d_CD(S1, S2) = \sigma_{x \in S1} min_{y in S2} ||x - y||^2 + \sigma_{y \in S2} min_{x in S1} ||x - y||^2 8 | """ 9 | dist = cdist(S1, S2) 10 | dist1 = dist.min(axis=1) ** 2 11 | dist2 = dist.min(axis=0) ** 2 12 | return dist1.sum() + dist2.sum() 13 | -------------------------------------------------------------------------------- /2d_plot_diffusion_todo/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn import datasets 4 | from torch.utils.data import DataLoader, Dataset 5 | 6 | 7 | def normalize(ds, scaling_factor=2.0): 8 | return (ds - ds.mean()) / ds.std() * scaling_factor 9 | 10 | 11 | def sample_checkerboard(n): 12 | # https://github.com/ghliu/SB-FBSDE/blob/main/data.py 13 | n_points = 3 * n 14 | n_classes = 2 15 | freq = 5 16 | x = np.random.uniform( 17 | -(freq // 2) * np.pi, (freq // 2) * np.pi, size=(n_points, n_classes) 18 | ) 19 | mask = np.logical_or( 20 | np.logical_and(np.sin(x[:, 0]) > 0.0, np.sin(x[:, 1]) > 0.0), 21 | np.logical_and(np.sin(x[:, 0]) < 0.0, np.sin(x[:, 1]) < 0.0), 22 | ) 23 | y = np.eye(n_classes)[1 * mask] 24 | x0 = x[:, 0] * y[:, 0] 25 | x1 = x[:, 1] * y[:, 0] 26 | sample = np.concatenate([x0[..., None], x1[..., None]], axis=-1) 27 | sqr = np.sum(np.square(sample), axis=-1) 28 | idxs = np.where(sqr == 0) 29 | sample = np.delete(sample, idxs, axis=0) 30 | 31 | return sample 32 | 33 | 34 | def load_twodim(num_samples: int, dataset: str, dimension: int = 2): 35 | 36 | if dataset == "gaussian_centered": 37 | sample = np.random.normal(size=(num_samples, dimension)) 38 | sample = 
sample 39 | 40 | if dataset == "gaussian_shift": 41 | sample = np.random.normal(size=(num_samples, dimension)) 42 | sample = sample + 1.5 43 | 44 | if dataset == "circle": 45 | X, y = datasets.make_circles( 46 | n_samples=num_samples, noise=0.0, random_state=None, factor=0.5 47 | ) 48 | sample = X * 4 49 | 50 | if dataset == "scurve": 51 | X, y = datasets.make_s_curve( 52 | n_samples=num_samples, noise=0.0, random_state=None 53 | ) 54 | sample = normalize(X[:, [0, 2]]) 55 | 56 | if dataset == "moon": 57 | X, y = datasets.make_moons(n_samples=num_samples, noise=0.0, random_state=None) 58 | sample = normalize(X) 59 | 60 | if dataset == "swiss_roll": 61 | X, y = datasets.make_swiss_roll( 62 | n_samples=num_samples, noise=0.0, random_state=None, hole=True 63 | ) 64 | sample = normalize(X[:, [0, 2]]) 65 | 66 | if dataset == "checkerboard": 67 | sample = normalize(sample_checkerboard(num_samples)) 68 | 69 | return torch.tensor(sample).float() 70 | 71 | 72 | class TwoDimDataClass(Dataset): 73 | def __init__(self, dataset_type: str, N: int, batch_size: int, dimension=2): 74 | 75 | self.X = load_twodim(N, dataset_type, dimension=dimension) 76 | self.name = dataset_type 77 | self.batch_size = batch_size 78 | self.dimension = 2 79 | 80 | def __len__(self): 81 | return self.X.shape[0] 82 | 83 | def __getitem__(self, idx): 84 | return self.X[idx] 85 | 86 | def get_dataloader(self, shuffle=True): 87 | return DataLoader( 88 | self, 89 | batch_size=self.batch_size, 90 | shuffle=shuffle, 91 | pin_memory=True, 92 | ) 93 | 94 | 95 | def get_data_iterator(iterable): 96 | iterator = iterable.__iter__() 97 | while True: 98 | try: 99 | yield iterator.__next__() 100 | except StopIteration: 101 | iterator = iterable.__iter__() 102 | -------------------------------------------------------------------------------- /2d_plot_diffusion_todo/ddpm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def extract(input, t: torch.Tensor, x: torch.Tensor): 8 | if t.ndim == 0: 9 | t = t.unsqueeze(0) 10 | shape = x.shape 11 | t = t.long().to(input.device) 12 | out = torch.gather(input, 0, t) 13 | reshape = [t.shape[0]] + [1] * (len(shape) - 1) 14 | return out.reshape(*reshape) 15 | 16 | 17 | class BaseScheduler(nn.Module): 18 | """ 19 | Variance scheduler of DDPM. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_train_timesteps: int, 25 | beta_1: float = 1e-4, 26 | beta_T: float = 0.02, 27 | mode: str = "linear", 28 | ): 29 | super().__init__() 30 | self.num_train_timesteps = num_train_timesteps 31 | self.timesteps = torch.from_numpy( 32 | np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64) 33 | ) 34 | 35 | if mode == "linear": 36 | betas = torch.linspace(beta_1, beta_T, steps=num_train_timesteps) 37 | elif mode == "quad": 38 | betas = ( 39 | torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps) ** 2 40 | ) 41 | else: 42 | raise NotImplementedError(f"{mode} is not implemented.") 43 | 44 | alphas = 1 - betas 45 | alphas_cumprod = torch.cumprod(alphas, dim=0) 46 | 47 | self.register_buffer("betas", betas) 48 | self.register_buffer("alphas", alphas) 49 | self.register_buffer("alphas_cumprod", alphas_cumprod) 50 | 51 | 52 | class DiffusionModule(nn.Module): 53 | """ 54 | A high-level wrapper of DDPM and DDIM. 55 | If you want to sample data based on the DDIM's reverse process, use `ddim_p_sample()` and `ddim_p_sample_loop()`. 
56 | """ 57 | 58 | def __init__(self, network: nn.Module, var_scheduler: BaseScheduler): 59 | super().__init__() 60 | self.network = network 61 | self.var_scheduler = var_scheduler 62 | 63 | @property 64 | def device(self): 65 | return next(self.network.parameters()).device 66 | 67 | @property 68 | def image_resolution(self): 69 | # For image diffusion model. 70 | return getattr(self.network, "image_resolution", None) 71 | 72 | def q_sample(self, x0, t, noise=None): 73 | """ 74 | sample x_t from q(x_t | x_0) of DDPM. 75 | 76 | Input: 77 | x0 (`torch.Tensor`): clean data to be mapped to timestep t in the forward process of DDPM. 78 | t (`torch.Tensor`): timestep 79 | noise (`torch.Tensor`, optional): random Gaussian noise. if None, randomly sample Gaussian noise in the function. 80 | Output: 81 | xt (`torch.Tensor`): noisy samples 82 | """ 83 | if noise is None: 84 | noise = torch.randn_like(x0) 85 | 86 | ######## TODO ######## 87 | # DO NOT change the code outside this part. 88 | # Compute xt. 89 | alphas_prod_t = extract(self.var_scheduler.alphas_cumprod, t, x0) 90 | xt = x0 91 | 92 | ####################### 93 | 94 | return xt 95 | 96 | @torch.no_grad() 97 | def p_sample(self, xt, t): 98 | """ 99 | One step denoising function of DDPM: x_t -> x_{t-1}. 100 | 101 | Input: 102 | xt (`torch.Tensor`): samples at arbitrary timestep t. 103 | t (`torch.Tensor`): current timestep in a reverse process. 104 | Ouptut: 105 | x_t_prev (`torch.Tensor`): one step denoised sample. (= x_{t-1}) 106 | 107 | """ 108 | ######## TODO ######## 109 | # DO NOT change the code outside this part. 110 | # compute x_t_prev. 111 | if isinstance(t, int): 112 | t = torch.tensor([t]).to(self.device) 113 | eps_factor = (1 - extract(self.var_scheduler.alphas, t, xt)) / ( 114 | 1 - extract(self.var_scheduler.alphas_cumprod, t, xt) 115 | ).sqrt() 116 | eps_theta = self.network(xt, t) 117 | 118 | x_t_prev = xt 119 | 120 | ####################### 121 | return x_t_prev 122 | 123 | @torch.no_grad() 124 | def p_sample_loop(self, shape): 125 | """ 126 | The loop of the reverse process of DDPM. 127 | 128 | Input: 129 | shape (`Tuple`): The shape of output. e.g., (num particles, 2) 130 | Output: 131 | x0_pred (`torch.Tensor`): The final denoised output through the DDPM reverse process. 132 | """ 133 | ######## TODO ######## 134 | # DO NOT change the code outside this part. 135 | # sample x0 based on Algorithm 2 of DDPM paper. 136 | x0_pred = torch.zeros(shape).to(self.device) 137 | 138 | ###################### 139 | return x0_pred 140 | 141 | @torch.no_grad() 142 | def ddim_p_sample(self, xt, t, t_prev, eta=0.0): 143 | """ 144 | One step denoising function of DDIM: $x_t{\tau_i}$ -> $x_{\tau{i-1}}$. 145 | 146 | Input: 147 | xt (`torch.Tensor`): noisy data at timestep $\tau_i$. 148 | t (`torch.Tensor`): current timestep (=\tau_i) 149 | t_prev (`torch.Tensor`): next timestep in a reverse process (=\tau_{i-1}) 150 | eta (float): correspond to η in DDIM which controls the stochasticity of a reverse process. 151 | Output: 152 | x_t_prev (`torch.Tensor`): one step denoised sample. (= $x_{\tau_{i-1}}$) 153 | """ 154 | ######## TODO ######## 155 | # NOTE: This code is used for assignment 2. You don't need to implement this part for assignment 1. 156 | # DO NOT change the code outside this part. 157 | # compute x_t_prev based on ddim reverse process. 
158 | alpha_prod_t = extract(self.var_scheduler.alphas_cumprod, t, xt) 159 | if t_prev >= 0: 160 | alpha_prod_t_prev = extract(self.var_scheduler.alphas_cumprod, t_prev, xt) 161 | else: 162 | alpha_prod_t_prev = torch.ones_like(alpha_prod_t) 163 | 164 | x_t_prev = xt 165 | 166 | ###################### 167 | return x_t_prev 168 | 169 | @torch.no_grad() 170 | def ddim_p_sample_loop(self, shape, num_inference_timesteps=50, eta=0.0): 171 | """ 172 | The loop of the reverse process of DDIM. 173 | 174 | Input: 175 | shape (`Tuple`): The shape of output. e.g., (num particles, 2) 176 | num_inference_timesteps (`int`): the number of timesteps in the reverse process. 177 | eta (`float`): correspond to η in DDIM which controls the stochasticity of a reverse process. 178 | Output: 179 | x0_pred (`torch.Tensor`): The final denoised output through the DDPM reverse process. 180 | """ 181 | ######## TODO ######## 182 | # NOTE: This code is used for assignment 2. You don't need to implement this part for assignment 1. 183 | # DO NOT change the code outside this part. 184 | # sample x0 based on Algorithm 2 of DDPM paper. 185 | step_ratio = self.var_scheduler.num_train_timesteps // num_inference_timesteps 186 | timesteps = ( 187 | (np.arange(0, num_inference_timesteps) * step_ratio) 188 | .round()[::-1] 189 | .copy() 190 | .astype(np.int64) 191 | ) 192 | timesteps = torch.from_numpy(timesteps) 193 | prev_timesteps = timesteps - step_ratio 194 | 195 | xt = torch.zeros(shape).to(self.device) 196 | for t, t_prev in zip(timesteps, prev_timesteps): 197 | pass 198 | 199 | x0_pred = xt 200 | 201 | ###################### 202 | 203 | return x0_pred 204 | 205 | def compute_loss(self, x0): 206 | """ 207 | The simplified noise matching loss corresponding Equation 14 in DDPM paper. 208 | 209 | Input: 210 | x0 (`torch.Tensor`): clean data 211 | Output: 212 | loss: the computed loss to be backpropagated. 213 | """ 214 | ######## TODO ######## 215 | # DO NOT change the code outside this part. 216 | # compute noise matching loss. 
217 | batch_size = x0.shape[0] 218 | t = ( 219 | torch.randint(0, self.var_scheduler.num_train_timesteps, size=(batch_size,)) 220 | .to(x0.device) 221 | .long() 222 | ) 223 | 224 | loss = x0.mean() 225 | 226 | ###################### 227 | return loss 228 | 229 | def save(self, file_path): 230 | hparams = { 231 | "network": self.network, 232 | "var_scheduler": self.var_scheduler, 233 | } 234 | state_dict = self.state_dict() 235 | 236 | dic = {"hparams": hparams, "state_dict": state_dict} 237 | torch.save(dic, file_path) 238 | 239 | def load(self, file_path): 240 | dic = torch.load(file_path, map_location="cpu") 241 | hparams = dic["hparams"] 242 | state_dict = dic["state_dict"] 243 | 244 | self.network = hparams["network"] 245 | self.var_scheduler = hparams["var_scheduler"] 246 | 247 | self.load_state_dict(state_dict) 248 | -------------------------------------------------------------------------------- /2d_plot_diffusion_todo/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class TimeEmbedding(nn.Module): 10 | def __init__(self, hidden_size, frequency_embedding_size=256): 11 | super().__init__() 12 | self.mlp = nn.Sequential( 13 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 14 | nn.SiLU(), 15 | nn.Linear(hidden_size, hidden_size, bias=True), 16 | ) 17 | self.frequency_embedding_size = frequency_embedding_size 18 | 19 | @staticmethod 20 | def timestep_embedding(t, dim, max_period=10000): 21 | """ 22 | Create sinusoidal timestep embeddings. 23 | :param t: a 1-D Tensor of N indices, one per batch element. 24 | These may be fractional. 25 | :param dim: the dimension of the output. 26 | :param max_period: controls the minimum frequency of the embeddings. 27 | :return: an (N, D) Tensor of positional embeddings. 28 | """ 29 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 30 | half = dim // 2 31 | freqs = torch.exp( 32 | -math.log(max_period) 33 | * torch.arange(start=0, end=half, dtype=torch.float32) 34 | / half 35 | ).to(device=t.device) 36 | args = t[:, None].float() * freqs[None] 37 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 38 | if dim % 2: 39 | embedding = torch.cat( 40 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 41 | ) 42 | return embedding 43 | 44 | def forward(self, t: torch.Tensor): 45 | if t.ndim == 0: 46 | t = t.unsqueeze(-1) 47 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 48 | t_emb = self.mlp(t_freq) 49 | return t_emb 50 | 51 | 52 | class TimeLinear(nn.Module): 53 | def __init__(self, dim_in: int, dim_out: int, num_timesteps: int): 54 | super().__init__() 55 | self.dim_in = dim_in 56 | self.dim_out = dim_out 57 | self.num_timesteps = num_timesteps 58 | 59 | self.time_embedding = TimeEmbedding(dim_out) 60 | self.fc = nn.Linear(dim_in, dim_out) 61 | 62 | def forward(self, x: torch.Tensor, t: torch.Tensor): 63 | x = self.fc(x) 64 | alpha = self.time_embedding(t).view(-1, self.dim_out) 65 | 66 | return alpha * x 67 | 68 | 69 | class SimpleNet(nn.Module): 70 | def __init__( 71 | self, dim_in: int, dim_out: int, dim_hids: List[int], num_timesteps: int 72 | ): 73 | super().__init__() 74 | """ 75 | (TODO) Build a noise estimating network. 
76 | 77 | Args: 78 | dim_in: dimension of input 79 | dim_out: dimension of output 80 | dim_hids: dimensions of hidden features 81 | num_timesteps: number of timesteps 82 | """ 83 | 84 | ######## TODO ######## 85 | # DO NOT change the code outside this part. 86 | 87 | ###################### 88 | 89 | def forward(self, x: torch.Tensor, t: torch.Tensor): 90 | """ 91 | (TODO) Implement the forward pass. This should output 92 | the noise prediction of the noisy input x at timestep t. 93 | 94 | Args: 95 | x: the noisy data after t period diffusion 96 | t: the time that the forward diffusion has been running 97 | """ 98 | ######## TODO ######## 99 | # DO NOT change the code outside this part. 100 | 101 | ###################### 102 | return x 103 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 KAIST Visual AI Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 | Denoising Diffusion Probabilistic Models (DDPM)
4 |
5 |
6 | KAIST CS492(D): Diffusion Models and Their Applications (Fall 2024)
7 | Programming Assignment 1
8 |
9 |
10 |
11 |
12 |
13 | Instructor: Minhyuk Sung (mhsung [at] kaist.ac.kr)
14 | TA: Seungwoo Yoo (dreamy1534 [at] kaist.ac.kr)
15 | Credit: Juil Koo (63days [at] kaist.ac.kr) & Nguyen Minh Hieu (hieuristics [at] kaist.ac.kr)
16 |
17 |
18 |
19 |
20 |
21 |
22 | 23 | 24 | ## Abstract 25 | In this programming assignment, you will implement the Denoising Diffusion Probabilistic Model (DDPM), a fundamental building block that empowers today's diffusion-based generative modeling. While DDPM provides the technical foundation for popular generative frameworks like [Stable Diffusion](https://github.com/CompVis/stable-diffusion), its implementation is surprisingly straightforward, making it an excellent starting point for gaining hands-on experience in building diffusion models. We will begin with a relatively simple example: modeling the distribution of 2D points on a spiral (known as the "Swiss Roll"). Following that, we will develop an image generator using the AFHQ dataset to explore how DDPM and diffusion models seamlessly adapt to changes in data format and dimensionality with minimal code changes. 26 | 27 | ## Setup 28 | 29 | Create a `conda` environment named `ddpm` and install PyTorch: 30 | ``` 31 | conda create --name ddpm python=3.10 32 | conda activate ddpm 33 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 34 | ``` 35 | 36 | Install the required packages listed in `requirements.txt`: 37 | ``` 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | > **NOTE: We have removed the dependency on `chamferdist` due to issues during installation.** 42 | 43 | ## Code Structure 44 | ``` 45 | . 46 | ├── 2d_plot_diffusion_todo (Task 1) 47 | │ ├── ddpm_tutorial.ipynb <--- Main code 48 | │ ├── dataset.py <--- Define dataset (Swiss-roll, moon, gaussians, etc.) 49 | │ ├── network.py <--- (TODO) Implement a noise prediction network 50 | │ └── ddpm.py <--- (TODO) Define a DDPM pipeline 51 | │ 52 | └── image_diffusion_todo (Task 2) 53 | ├── dataset.py <--- Ready-to-use AFHQ dataset code 54 | ├── model.py <--- Diffusion model including its backbone and scheduler 55 | ├── module.py <--- Basic modules of a noise prediction network 56 | ├── network.py <--- Definition of the U-Net architecture 57 | ├── sampling.py <--- Image sampling code 58 | ├── scheduler.py <--- (TODO) Implement the forward/reverse step of DDPM 59 | ├── train.py <--- DDPM training code 60 | └── fid 61 | ├── measure_fid.py <--- Script measuring the FID score 62 | └── afhq_inception_v3.ckpt <--- Pre-trained classifier for FID 63 | ``` 64 | 65 | 66 | ## Task 0: Introduction 67 | ### Assignment Tips 68 | 69 | Implementing diffusion models is simple once you understand the theory. 70 | So, to learn the most from this tutorial, it's highly recommended to check out the details in the 71 | related papers and understand the equations **BEFORE** you start the assignment. You can check out 72 | the resources in this order: 73 | 74 | 1. [[Paper](https://arxiv.org/abs/2006.11239)] Denoising Diffusion Probabilistic Models 75 | 2. [[Blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)] Lilian Weng's "What are Diffusion Models?" 76 | 77 | ### Forward Process 78 | The Denoising Diffusion Probabilistic Model (DDPM) is a latent-variable generative model built on a Markov chain. In this Markov chain, we define a _forward process_ that gradually adds noise to data sampled from a data distribution $\mathbf{x}_0 \sim q(\mathbf{x}_0)$ so that $\mathbf{x}_0$ becomes pure white Gaussian noise at $t=T$. 
Each transition of the forward process is as follows: 79 | 80 | $$ q(\mathbf{x}_t | \mathbf{x}\_{t-1}) := \mathcal{N}(\mathbf{x}_t; \sqrt{1-\beta_t}\mathbf{x}\_{t-1}, \beta_t \mathbf{I}), $$ 81 | 82 | where a variance schedule $\beta_1, \dots, \beta_T$ controls the step sizes. 83 | 84 | Thanks to a nice property of Gaussian distributions, one can directly sample $\mathbf{x}_t$ at an arbitrary timestep $t$ from real data $\mathbf{x}_0$ in closed form: 85 | 86 | $$q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I}), $$ 87 | 88 | where $\alpha\_t := 1 - \beta\_t$ and $\bar{\alpha}\_t := \prod\_{s=1}^{t} \alpha\_s$. 89 | 90 | Refer to [our slide](./assets/summary_of_DDPM_and_DDIM.pdf) or [blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) for more details. 91 | 92 | ### Reverse Process 93 | If we can reverse the forward process, i.e., sample $\mathbf{x}\_{t-1} \sim q(\mathbf{x}\_{t-1} | \mathbf{x}_t)$ iteratively until $t=0$, we will be able to generate $\mathbf{x}_0$, which is close to the unknown data distribution $q(\mathbf{x}_0)$, from white Gaussian noise $\mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I})$. You can think of this _reverse process_ as a denoising process that gradually removes noise so that the result looks like a true sample from $q(\mathbf{x}_0)$ at the end. 94 | The reverse process is also a Markov chain with learned Gaussian transitions: 95 | 96 | $$p\_\theta(\mathbf{x}\_{0:T}) := p(\mathbf{x}_T) \prod\_{t=1}^T p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t), $$ 97 | 98 | where $p(\mathbf{x}_T) = \mathcal{N}(0, \mathbf{I})$ and $p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t) := \mathcal{N}(\mathbf{x}\_{t-1}; \boldsymbol{\mu}\_\theta (\mathbf{x}_t, t), \boldsymbol{\Sigma}\_\theta (\mathbf{x}_t, t)).$ 99 | 100 | ### Training 101 | To learn this reverse process, we set an objective function that minimizes the KL divergence between $p_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t)$ and $q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0), \sigma_t^2 \mathbf{I})$, which is also a Gaussian distribution when conditioned on $\mathbf{x}_0$: 102 | 103 | $$\mathcal{L} = \mathbb{E}_q \left[ \sum\_{t > 1} D\_{\text{KL}}( q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \Vert p\_\theta ( \mathbf{x}\_{t-1} | \mathbf{x}_t)) \right].$$ 104 | 105 | As a parameterization of DDPM, the authors set $\boldsymbol{\Sigma}\_\theta(\mathbf{x}_t, t) = \sigma_t^2 \mathbf{I}$ to untrained time-dependent constants. As a result, we can rewrite the objective function: 106 | 107 | $$\mathcal{L} = \mathbb{E}\_q \left[ \frac{1}{2\sigma\_t^2} \Vert \tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}\_{\theta}(\mathbf{x}_t, t) \Vert^2 \right] + C $$ 108 | 109 | The authors empirically found that predicting the noise $\boldsymbol{\epsilon}$ injected into the data with a noise prediction network $\epsilon\_\theta$ works better than learning the mean function $\boldsymbol{\mu}\_\theta$. 110 | 111 | In short, the simplified objective function of DDPM is defined as follows: 112 | 113 | $$ \mathcal{L}\_{\text{simple}} := \mathbb{E}\_{t,\mathbf{x}_0,\boldsymbol{\epsilon}} [ \Vert \boldsymbol{\epsilon} - \boldsymbol{\epsilon}\_\theta( \mathbf{x}\_t(\mathbf{x}_0, t), t) \Vert^2 ],$$ 114 | 115 | where $\mathbf{x}_t (\mathbf{x}_0, t) = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}$ and $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$. 
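To make the simplified objective concrete, here is a minimal, self-contained sketch of how $\mathcal{L}\_{\text{simple}}$ can be computed for one batch. It is illustrative only: the function name `simplified_ddpm_loss` and its arguments are placeholders and do not correspond to the TODO structure of this assignment.

```python
import torch
import torch.nn.functional as F


def simplified_ddpm_loss(network, x0, alphas_cumprod):
    """Sketch of the simplified DDPM objective (Eq. 14 of the DDPM paper).

    network:        a callable predicting the injected noise, eps_theta(x_t, t)
    x0:             a batch of clean data of shape (B, ...)
    alphas_cumprod: 1-D tensor holding alpha_bar_t for t = 0, ..., T-1
    """
    B, T = x0.shape[0], alphas_cumprod.shape[0]
    alphas_cumprod = alphas_cumprod.to(x0.device)

    # Sample a timestep uniformly at random for each batch element.
    t = torch.randint(0, T, (B,), device=x0.device)

    # Sample x_t ~ q(x_t | x_0) in closed form:
    # x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, with eps ~ N(0, I).
    eps = torch.randn_like(x0)
    abar_t = alphas_cumprod[t].view(B, *([1] * (x0.ndim - 1)))
    xt = abar_t.sqrt() * x0 + (1.0 - abar_t).sqrt() * eps

    # Regress the injected noise with an L2 loss.
    return F.mse_loss(network(xt, t), eps)
```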
116 | 117 | Refer to [the original paper](https://arxiv.org/abs/2006.11239) for more details. 118 | 119 | ### Sampling 120 | 121 | Once we train the noise prediction network $\boldsymbol{\epsilon}\_\theta$, we can run sampling by gradually denoising white Gaussian noise. The DDPM sampling algorithm is shown below, followed by a short illustrative sketch: 122 | 123 |

124 | image 125 |
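Concretely, each reverse step computes

$$ \mathbf{x}\_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}\_\theta(\mathbf{x}_t, t) \right) + \sigma_t \mathbf{z}, $$

with $\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})$ for all but the last step, where no noise is added. Below is a minimal sketch of this loop, assuming the common choice $\sigma_t^2 = \beta_t$; the function and argument names are placeholders, not the API of this codebase.

```python
import torch


@torch.no_grad()
def ddpm_ancestral_sampling(network, shape, betas, device="cpu"):
    """Sketch of DDPM sampling (Algorithm 2), using sigma_t^2 = beta_t."""
    betas = betas.to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    x = torch.randn(shape, device=device)      # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):       # t = T-1, ..., 0 (0-indexed)
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps_pred = network(x, t_batch)

        # Mean of p_theta(x_{t-1} | x_t) under the eps-parameterization.
        coef = (1.0 - alphas[t]) / (1.0 - alphas_cumprod[t]).sqrt()
        mean = (x - coef * eps_pred) / alphas[t].sqrt()

        if t > 0:
            # Add noise except at the final step.
            x = mean + betas[t].sqrt() * torch.randn_like(x)
        else:
            x = mean
    return x
```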

126 | 127 | ## Task 1: Simple DDPM pipeline with Swiss-Roll 128 | 129 |

130 | image 131 |

132 | 133 | A typical diffusion pipeline is divided into three components: 134 | 1. [Forward Process](#forward-process) and [Reverse Process](#reverse-process) 135 | 2. [Training](#training) 136 | 3. [Sampling](#sampling) 137 | 138 | In this task, we will look into each component one by one in a toy experiment and implement them sequentially. 139 | After finishing the implementation, you will be able to train DDPM and evaluate its performance in `ddpm_tutorial.ipynb` under the `2d_plot_diffusion_todo` directory. 140 | 141 | ❗️❗️❗️ **You are only allowed to edit the part marked by TODO.** ❗️❗️❗️ 142 | 143 | ### TODO 144 | #### 1-1: Build a noise prediction network 145 | You first need to implement a noise prediction network in `network.py`. 146 | The network should consist of `TimeLinear` layers whose feature dimensions are a sequence of [`dim_in`, `dim_hids[0]`, ..., `dim_hids[-1]`, `dim_out`]. 147 | Every `TimeLinear` layer except for the last `TimeLinear` layer should be followed by a ReLU activation (an illustrative sketch is shown at the end of this task section). 148 | 149 | #### 1-2: Construct the forward and reverse process of DDPM 150 | Now you should construct the forward and reverse processes of DDPM in `ddpm.py`. 151 | `q_sample()` is a forward function that maps $\mathbf{x}_0$ to $\mathbf{x}_t$. 152 | 153 | `p_sample()` is a one-step reverse transition from $\mathbf{x}\_{t}$ to $\mathbf{x}\_{t-1}$, and `p_sample_loop()` is the full reverse process corresponding to the DDPM sampling algorithm. 154 | 155 | #### 1-3: Implement the training objective function 156 | In `ddpm.py`, the `compute_loss()` function should return the simplified noise matching loss from the DDPM paper. 157 | 158 | #### 1-4: Training and Evaluation 159 | Once you finish the implementation above, open and run `ddpm_tutorial.ipynb` via Jupyter Notebook. It will automatically train a diffusion model and measure the Chamfer distance between 2D particles sampled by the diffusion model and 2D particles sampled from the target distribution. 160 | 161 | Take screenshots of: 162 | 163 | 1. the training loss curve 164 | 2. the Chamfer Distance reported after executing the Jupyter Notebook 165 | 3. the visualization of the sampled particles 166 | 167 | Below are the examples of (1) and (3). 168 |

169 | image 170 | image 171 |
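For reference, one way to realize the layer specification in 1-1 is sketched below. This is an illustrative example only: the class name `NoisePredictionMLP` is made up, and the snippet assumes it lives next to `network.py` so that the provided `TimeLinear` module can be imported. Your own implementation must go inside the marked TODO block of `network.py`.

```python
from typing import List

import torch
import torch.nn as nn

from network import TimeLinear  # the TimeLinear module provided in network.py


class NoisePredictionMLP(nn.Module):
    """Stack of TimeLinear layers over [dim_in, *dim_hids, dim_out], ReLU in between."""

    def __init__(self, dim_in: int, dim_out: int, dim_hids: List[int], num_timesteps: int):
        super().__init__()
        dims = [dim_in] + list(dim_hids) + [dim_out]
        self.layers = nn.ModuleList(
            TimeLinear(d_in, d_out, num_timesteps)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x, t)
            if i < len(self.layers) - 1:  # ReLU after every layer except the last
                x = torch.relu(x)
        return x
```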

172 | 173 | ## Task 2: Image Diffusion 174 | 175 |

176 | image 177 |

178 | 179 | ### TODO 180 | 181 | If you successfully finish Task 1, implement the methods `add_noise` and `step` of the class `DDPMScheduler` defined in `image_diffusion_todo/scheduler.py`. You also need to implement the method `get_loss` of the class `DiffusionModule` defined in `image_diffusion_todo/model.py`. Refer to your implementation of the methods `q_sample`, `p_sample`, and `compute_loss` from the 2D experiment. 182 | 183 | In this task, we will generate $64\times64$ animal images by training a DDPM on the AFHQ dataset. 184 | 185 | To train your model, simply execute the command: `python train.py`. 186 | 187 | ❗️❗️❗️ You are NOT allowed to modify any given hyperparameters. ❗️❗️❗️ 188 | 189 | It will sample images and save a checkpoint every `args.log_interval`. After training a model, sample & save images by 190 | ``` 191 | python sampling.py --ckpt_path ${CKPT_PATH} --save_dir ${SAVE_DIR_PATH} 192 | ``` 193 | ![sampling_command](./assets/images/sampling_command.png) 194 | 195 | We recommend starting the training as soon as possible since the training takes about **14 hours**. 196 | 197 | As an evaluation, measure the FID score using the pre-trained classifier network we provide: 198 | ``` 199 | python dataset.py # to construct the eval directory. 200 | python fid/measure_fid.py @GT_IMG_DIR @GEN_IMG_DIR 201 | ``` 202 | 203 | > **Do NOT forget to execute `dataset.py` before measuring the FID score. Otherwise, the output will be incorrect due to the discrepancy between the image resolutions.** 204 | 205 | For instance: 206 | ![fid_command](./assets/images/fid_command.png) 207 | Use the validation set of the AFHQ dataset (e.g., `data/afhq/eval`) as @GT_IMG_DIR. The script will automatically search and load the images. The path @GEN_IMG_DIR should be the same as the one you provided when running the script `sampling.py`. 208 | 209 | Take a screenshot of the FID score and include at least 8 sampled images. 210 |

211 | image 212 |

213 | 214 | ## What to Submit 215 | 216 |
217 | Submission Item List 218 |
219 | 220 | - [ ] Code without model checkpoints 221 | 222 | **Task 1** 223 | - [ ] Loss curve screenshot 224 | - [ ] Chamfer distance result of DDPM sampling 225 | - [ ] Visualization of DDPM sampling 226 | 227 | **Task 2** 228 | - [ ] FID score result 229 | - [ ] At least 8 images generated by your DDPM model 230 | 231 |
232 | 233 | In a single document, write your name and student ID, and include the submission items listed above. Refer to the more detailed instructions written in each task section about what to submit. 234 | Name the document `{NAME}_{STUDENT_ID}.pdf` and submit **both your code and the document** as a **ZIP** file named `{NAME}_{STUDENT_ID}.zip`. 235 | **When creating your zip file**, exclude data (e.g., files in the AFHQ dataset) and any model checkpoints, including the provided pre-trained classifier checkpoint, when compressing the files. 236 | Submit the zip file on GradeScope. 237 | 238 | ## Grading 239 | **You will receive a zero score if:** 240 | - **you do not submit,** 241 | - **your code is not executable in the Python environment we provided, or** 242 | - **you modify any code outside of the sections marked with `TODO` or use different hyperparameters that are supposed to be fixed as given.** 243 | 244 | **Plagiarism in any form will also result in a zero score and will be reported to the university.** 245 | 246 | **Your score will incur a 10% deduction for each missing item in the submission item list.** 247 | 248 | Otherwise, you will receive up to 20 points from this assignment that count toward your final grade. 249 | 250 | - Task 1 251 | - 10 points: Achieve CD lower than **20** from DDPM sampling. 252 | - 5 points: Achieve CD greater than or equal to **20** and less than **40** from DDPM sampling. 253 | - 0 points: otherwise. 254 | - Task 2 255 | - 10 points: Achieve FID less than **20**. 256 | - 5 points: Achieve FID between **20** and **40**. 257 | - 0 points: otherwise. 258 | 259 | ## Further Readings 260 | 261 | If you are interested in this topic, we encourage you to check out the materials below. 262 | 263 | - [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) 264 | - [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) 265 | - [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233) 266 | - [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456) 267 | - [What are Diffusion Models?](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) 268 | - [Generative Modeling by Estimating Gradients of the Data Distribution](https://yang-song.net/blog/2021/score/) 269 | -------------------------------------------------------------------------------- /assets/images/cfg_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/cfg_test.png -------------------------------------------------------------------------------- /assets/images/cfg_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/cfg_train.png -------------------------------------------------------------------------------- /assets/images/fid_command.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/fid_command.png -------------------------------------------------------------------------------- /assets/images/qs.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/qs.png -------------------------------------------------------------------------------- /assets/images/sampling_command.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/sampling_command.png -------------------------------------------------------------------------------- /assets/images/task1_ddim_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_ddim_sample.png -------------------------------------------------------------------------------- /assets/images/task1_ddpm_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_ddpm_sample.png -------------------------------------------------------------------------------- /assets/images/task1_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_distribution.png -------------------------------------------------------------------------------- /assets/images/task1_loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_loss_curve.png -------------------------------------------------------------------------------- /assets/images/task1_output_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_output_example.png -------------------------------------------------------------------------------- /assets/images/task2_1_ddpm_sampling_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_1_ddpm_sampling_algorithm.png -------------------------------------------------------------------------------- /assets/images/task2_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_algorithm.png -------------------------------------------------------------------------------- /assets/images/task2_ddim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_ddim.png -------------------------------------------------------------------------------- /assets/images/task2_output_example.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_output_example.png -------------------------------------------------------------------------------- /assets/images/task2_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_teaser.png -------------------------------------------------------------------------------- /assets/images/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/teaser.gif -------------------------------------------------------------------------------- /assets/images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/teaser.png -------------------------------------------------------------------------------- /assets/summary_of_DDPM_and_DDIM.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/summary_of_DDPM_and_DDIM.pdf -------------------------------------------------------------------------------- /image_diffusion_todo/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import chain 3 | from multiprocessing.pool import Pool 4 | from pathlib import Path 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | 10 | 11 | def listdir(dname): 12 | fnames = list( 13 | chain( 14 | *[ 15 | list(Path(dname).rglob("*." 
+ ext)) 16 | for ext in ["png", "jpg", "jpeg", "JPG"] 17 | ] 18 | ) 19 | ) 20 | return fnames 21 | 22 | 23 | def tensor_to_pil_image(x: torch.Tensor, single_image=False): 24 | """ 25 | x: [B,C,H,W] 26 | """ 27 | if x.ndim == 3: 28 | x = x.unsqueeze(0) 29 | single_image = True 30 | 31 | x = (x * 0.5 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy() 32 | images = (x * 255).round().astype("uint8") 33 | images = [Image.fromarray(image) for image in images] 34 | if single_image: 35 | return images[0] 36 | return images 37 | 38 | 39 | def get_data_iterator(iterable): 40 | """Allows training with DataLoaders in a single infinite loop: 41 | for i, data in enumerate(inf_generator(train_loader)): 42 | """ 43 | iterator = iterable.__iter__() 44 | while True: 45 | try: 46 | yield iterator.__next__() 47 | except StopIteration: 48 | iterator = iterable.__iter__() 49 | 50 | 51 | class AFHQDataset(torch.utils.data.Dataset): 52 | def __init__( 53 | self, root: str, split: str, transform=None, max_num_images_per_cat=-1, label_offset=1 54 | ): 55 | super().__init__() 56 | self.root = root 57 | self.split = split 58 | self.transform = transform 59 | self.max_num_images_per_cat = max_num_images_per_cat 60 | self.label_offset = label_offset 61 | 62 | categories = os.listdir(os.path.join(root, split)) 63 | self.num_classes = len(categories) 64 | 65 | fnames, labels = [], [] 66 | for idx, cat in enumerate(sorted(categories)): 67 | category_dir = os.path.join(root, split, cat) 68 | cat_fnames = listdir(category_dir) 69 | cat_fnames = sorted(cat_fnames) 70 | if self.max_num_images_per_cat > 0: 71 | cat_fnames = cat_fnames[: self.max_num_images_per_cat] 72 | fnames += cat_fnames 73 | labels += [idx + label_offset] * len(cat_fnames) # label 0 is for null class. 74 | 75 | self.fnames = fnames 76 | self.labels = labels 77 | 78 | def __getitem__(self, idx): 79 | img = Image.open(self.fnames[idx]).convert("RGB") 80 | label = self.labels[idx] 81 | assert label >= self.label_offset 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | 85 | return img, label 86 | 87 | def __len__(self): 88 | return len(self.labels) 89 | 90 | 91 | class AFHQDataModule(object): 92 | def __init__( 93 | self, 94 | root: str = "data", 95 | batch_size: int = 32, 96 | num_workers: int = 4, 97 | max_num_images_per_cat: int = 1000, 98 | image_resolution: int = 64, 99 | label_offset=1, 100 | transform=None 101 | ): 102 | self.root = root 103 | self.batch_size = batch_size 104 | self.num_workers = num_workers 105 | self.afhq_root = os.path.join(root, "afhq") 106 | self.max_num_images_per_cat = max_num_images_per_cat 107 | self.image_resolution = image_resolution 108 | self.label_offset = label_offset 109 | self.transform = transform 110 | 111 | if not os.path.exists(self.afhq_root): 112 | print(f"{self.afhq_root} is empty. 
Downloading AFHQ dataset...") 113 | self._download_dataset() 114 | 115 | self._set_dataset() 116 | 117 | def _set_dataset(self): 118 | if self.transform is None: 119 | self.transform = transforms.Compose( 120 | [ 121 | transforms.Resize((self.image_resolution, self.image_resolution)), 122 | transforms.ToTensor(), 123 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 124 | ] 125 | ) 126 | self.train_ds = AFHQDataset( 127 | self.afhq_root, 128 | "train", 129 | self.transform, 130 | max_num_images_per_cat=self.max_num_images_per_cat, 131 | label_offset=self.label_offset 132 | ) 133 | self.val_ds = AFHQDataset( 134 | self.afhq_root, 135 | "val", 136 | self.transform, 137 | max_num_images_per_cat=self.max_num_images_per_cat, 138 | label_offset=self.label_offset, 139 | ) 140 | 141 | self.num_classes = self.train_ds.num_classes 142 | 143 | def _download_dataset(self): 144 | URL = "https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0" 145 | ZIP_FILE = f"./{self.root}/afhq.zip" 146 | os.system(f"mkdir -p {self.root}") 147 | os.system(f"wget -N {URL} -O {ZIP_FILE}") 148 | os.system(f"unzip {ZIP_FILE} -d {self.root}") 149 | os.system(f"rm {ZIP_FILE}") 150 | 151 | def train_dataloader(self): 152 | return torch.utils.data.DataLoader( 153 | self.train_ds, 154 | batch_size=self.batch_size, 155 | num_workers=self.num_workers, 156 | shuffle=True, 157 | drop_last=True, 158 | ) 159 | 160 | def val_dataloader(self): 161 | return torch.utils.data.DataLoader( 162 | self.val_ds, 163 | batch_size=self.batch_size, 164 | num_workers=self.num_workers, 165 | shuffle=False, 166 | drop_last=False, 167 | ) 168 | 169 | 170 | if __name__ == "__main__": 171 | data_module = AFHQDataModule("data", 32, 4, -1, 64, 1) 172 | 173 | eval_dir = Path(data_module.afhq_root) / "eval" 174 | eval_dir.mkdir(exist_ok=True) 175 | def func(path): 176 | fn = path.name 177 | cmd = f"cp {path} {eval_dir / fn}" 178 | os.system(cmd) 179 | img = Image.open(str(eval_dir / fn)) 180 | img = img.resize((64,64)) 181 | img.save(str(eval_dir / fn)) 182 | print(fn) 183 | 184 | with Pool(8) as pool: 185 | pool.map(func, data_module.val_ds.fnames) 186 | 187 | print(f"Constructed eval dir at {eval_dir}") 188 | -------------------------------------------------------------------------------- /image_diffusion_todo/fid/afhq_inception_v3.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/image_diffusion_todo/fid/afhq_inception_v3.ckpt -------------------------------------------------------------------------------- /image_diffusion_todo/fid/inception.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | import numpy as np 11 | import torch.nn as nn 12 | from torchvision import models 13 | 14 | 15 | class InceptionV3(nn.Module): 16 | def __init__(self, for_train): 17 | super().__init__() 18 | self.for_train = for_train 19 | 20 | inception = models.inception_v3(pretrained=False) 21 | self.block1 = nn.Sequential( 22 | inception.Conv2d_1a_3x3, 23 | inception.Conv2d_2a_3x3, 24 | inception.Conv2d_2b_3x3, 25 | nn.MaxPool2d(kernel_size=3, stride=2), 26 | ) 27 | self.block2 = nn.Sequential( 28 | inception.Conv2d_3b_1x1, 29 | inception.Conv2d_4a_3x3, 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.block3 = nn.Sequential( 33 | inception.Mixed_5b, 34 | inception.Mixed_5c, 35 | inception.Mixed_5d, 36 | inception.Mixed_6a, 37 | inception.Mixed_6b, 38 | inception.Mixed_6c, 39 | inception.Mixed_6d, 40 | inception.Mixed_6e, 41 | ) 42 | self.block4 = nn.Sequential( 43 | inception.Mixed_7a, 44 | inception.Mixed_7b, 45 | inception.Mixed_7c, 46 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), 47 | ) 48 | 49 | self.final_fc = nn.Linear(2048, 3) 50 | 51 | def forward(self, x): 52 | x = self.block1(x) 53 | x = self.block2(x) 54 | x = self.block3(x) 55 | x = self.block4(x) 56 | x = x.view(x.size(0), -1) 57 | if self.for_train: 58 | return self.final_fc(x) 59 | else: 60 | return x 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /image_diffusion_todo/fid/measure_fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | from PIL import Image 7 | from scipy import linalg 8 | from torchvision import transforms 9 | from itertools import chain 10 | from pathlib import Path 11 | from inception import InceptionV3 12 | 13 | try: 14 | from tqdm import tqdm 15 | except ImportError: 16 | def tqdm(x): 17 | return x 18 | 19 | class ImagePathDataset(torch.utils.data.Dataset): 20 | def __init__(self, files, img_size): 21 | self.files = files 22 | self.img_size = img_size 23 | self.transforms = transforms.Compose( 24 | [ 25 | transforms.Resize((img_size, img_size)), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 28 | ] 29 | ) 30 | 31 | def __len__(self): 32 | return len(self.files) 33 | 34 | def __getitem__(self, i): 35 | path = self.files[i] 36 | img = Image.open(path).convert("RGB") 37 | if self.transforms is not None: 38 | img = self.transforms(img) 39 | return img 40 | 41 | 42 | def get_eval_loader(path, img_size, batch_size): 43 | def listdir(dname): 44 | fnames = list( 45 | chain( 46 | *[ 47 | list(Path(dname).rglob("*." + ext)) 48 | for ext in ["png", "jpg", "jpeg", "JPG"] 49 | ] 50 | ) 51 | ) 52 | return fnames 53 | 54 | files = listdir(path) 55 | ds = ImagePathDataset(files, img_size) 56 | dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4) 57 | return dl 58 | 59 | def frechet_distance(mu, cov, mu2, cov2): 60 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False) 61 | dist = np.sum((mu - mu2) ** 2) + np.trace(cov + cov2 - 2 * cc) 62 | return np.real(dist) 63 | 64 | 65 | 66 | @torch.no_grad() 67 | def calculate_fid_given_paths(paths, img_size=256, batch_size=50): 68 | print("Calculating FID given paths %s and %s..." 
% (paths[0], paths[1])) 69 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 70 | inception = InceptionV3(for_train=False) 71 | current_dir = Path(os.path.realpath(__file__)).parent 72 | ckpt = torch.load(current_dir / "afhq_inception_v3.ckpt", map_location="cpu") 73 | inception.load_state_dict(ckpt) 74 | inception = inception.eval().to(device) 75 | loaders = [get_eval_loader(path, img_size, batch_size) for path in paths] 76 | 77 | mu, cov = [], [] 78 | for loader in loaders: 79 | actvs = [] 80 | for x in tqdm(loader, total=len(loader)): 81 | actv = inception(x.to(device)) 82 | actvs.append(actv) 83 | actvs = torch.cat(actvs, dim=0).cpu().detach().numpy() 84 | mu.append(np.mean(actvs, axis=0)) 85 | cov.append(np.cov(actvs, rowvar=False)) 86 | fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1]) 87 | return fid_value 88 | 89 | if __name__ == "__main__": 90 | # python measure_fid /path/to/dir1 /path/to/dir2 91 | 92 | paths = [sys.argv[1], sys.argv[2]] 93 | fid_value = calculate_fid_given_paths(paths, img_size=256, batch_size=64) 94 | print("FID:", fid_value) 95 | 96 | -------------------------------------------------------------------------------- /image_diffusion_todo/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | 10 | class DiffusionModule(nn.Module): 11 | def __init__(self, network, var_scheduler, **kwargs): 12 | super().__init__() 13 | self.network = network 14 | self.var_scheduler = var_scheduler 15 | 16 | def get_loss(self, x0, class_label=None, noise=None): 17 | ######## TODO ######## 18 | # DO NOT change the code outside this part. 19 | # compute noise matching loss. 20 | B = x0.shape[0] 21 | timestep = self.var_scheduler.uniform_sample_t(B, self.device) 22 | loss = x0.mean() 23 | ###################### 24 | return loss 25 | 26 | @property 27 | def device(self): 28 | return next(self.network.parameters()).device 29 | 30 | @property 31 | def image_resolution(self): 32 | return self.network.image_resolution 33 | 34 | @torch.no_grad() 35 | def sample( 36 | self, 37 | batch_size, 38 | return_traj=False, 39 | class_label: Optional[torch.Tensor] = None, 40 | guidance_scale: Optional[float] = 1.0, 41 | ): 42 | x_T = torch.randn([batch_size, 3, self.image_resolution, self.image_resolution]).to(self.device) 43 | 44 | do_classifier_free_guidance = guidance_scale > 1.0 45 | 46 | if do_classifier_free_guidance: 47 | 48 | ######## TODO ######## 49 | # Assignment 2. Implement the classifier-free guidance. 50 | # Specifically, given a tensor of shape (batch_size,) containing class labels, 51 | # create a tensor of shape (2*batch_size,) where the first half is filled with zeros (i.e., null condition). 52 | assert class_label is not None 53 | assert len(class_label) == batch_size, f"len(class_label) != batch_size. {len(class_label)} != {batch_size}" 54 | raise NotImplementedError("TODO") 55 | ####################### 56 | 57 | traj = [x_T] 58 | for t in tqdm(self.var_scheduler.timesteps): 59 | x_t = traj[-1] 60 | if do_classifier_free_guidance: 61 | ######## TODO ######## 62 | # Assignment 2. Implement the classifier-free guidance. 
63 | raise NotImplementedError("TODO") 64 | ####################### 65 | else: 66 | noise_pred = self.network(x_t, timestep=t.to(self.device)) 67 | 68 | x_t_prev = self.var_scheduler.step(x_t, t, noise_pred) 69 | 70 | traj[-1] = traj[-1].cpu() 71 | traj.append(x_t_prev.detach()) 72 | 73 | if return_traj: 74 | return traj 75 | else: 76 | return traj[-1] 77 | 78 | def save(self, file_path): 79 | hparams = { 80 | "network": self.network, 81 | "var_scheduler": self.var_scheduler, 82 | } 83 | state_dict = self.state_dict() 84 | 85 | dic = {"hparams": hparams, "state_dict": state_dict} 86 | torch.save(dic, file_path) 87 | 88 | def load(self, file_path): 89 | dic = torch.load(file_path, map_location="cpu") 90 | hparams = dic["hparams"] 91 | state_dict = dic["state_dict"] 92 | 93 | self.network = hparams["network"] 94 | self.var_scheduler = hparams["var_scheduler"] 95 | 96 | self.load_state_dict(state_dict) 97 | -------------------------------------------------------------------------------- /image_diffusion_todo/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | 9 | class Swish(nn.Module): 10 | def forward(self, x): 11 | return x * torch.sigmoid(x) 12 | 13 | 14 | class DownSample(nn.Module): 15 | def __init__(self, in_ch): 16 | super().__init__() 17 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) 18 | self.initialize() 19 | 20 | def initialize(self): 21 | init.xavier_uniform_(self.main.weight) 22 | init.zeros_(self.main.bias) 23 | 24 | def forward(self, x, temb): 25 | x = self.main(x) 26 | return x 27 | 28 | 29 | class UpSample(nn.Module): 30 | def __init__(self, in_ch): 31 | super().__init__() 32 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) 33 | self.initialize() 34 | 35 | def initialize(self): 36 | init.xavier_uniform_(self.main.weight) 37 | init.zeros_(self.main.bias) 38 | 39 | def forward(self, x, temb): 40 | _, _, H, W = x.shape 41 | x = F.interpolate(x, scale_factor=2, mode="nearest") 42 | x = self.main(x) 43 | return x 44 | 45 | 46 | class AttnBlock(nn.Module): 47 | def __init__(self, in_ch): 48 | super().__init__() 49 | self.group_norm = nn.GroupNorm(32, in_ch) 50 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 51 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 52 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 53 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 54 | self.initialize() 55 | 56 | def initialize(self): 57 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: 58 | init.xavier_uniform_(module.weight) 59 | init.zeros_(module.bias) 60 | init.xavier_uniform_(self.proj.weight, gain=1e-5) 61 | 62 | def forward(self, x): 63 | B, C, H, W = x.shape 64 | h = self.group_norm(x) 65 | q = self.proj_q(h) 66 | k = self.proj_k(h) 67 | v = self.proj_v(h) 68 | 69 | q = q.permute(0, 2, 3, 1).view(B, H * W, C) 70 | k = k.view(B, C, H * W) 71 | w = torch.bmm(q, k) * (int(C) ** (-0.5)) 72 | assert list(w.shape) == [B, H * W, H * W] 73 | w = F.softmax(w, dim=-1) 74 | 75 | v = v.permute(0, 2, 3, 1).view(B, H * W, C) 76 | h = torch.bmm(w, v) 77 | assert list(h.shape) == [B, H * W, C] 78 | h = h.view(B, H, W, C).permute(0, 3, 1, 2) 79 | h = self.proj(h) 80 | 81 | return x + h 82 | 83 | 84 | class ResBlock(nn.Module): 85 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): 86 | super().__init__() 87 | self.block1 
= nn.Sequential( 88 | nn.GroupNorm(32, in_ch), 89 | Swish(), 90 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 91 | ) 92 | self.temb_proj = nn.Sequential( 93 | Swish(), 94 | nn.Linear(tdim, out_ch), 95 | ) 96 | self.block2 = nn.Sequential( 97 | nn.GroupNorm(32, out_ch), 98 | Swish(), 99 | nn.Dropout(dropout), 100 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 101 | ) 102 | if in_ch != out_ch: 103 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 104 | else: 105 | self.shortcut = nn.Identity() 106 | if attn: 107 | self.attn = AttnBlock(out_ch) 108 | else: 109 | self.attn = nn.Identity() 110 | self.initialize() 111 | 112 | def initialize(self): 113 | for module in self.modules(): 114 | if isinstance(module, (nn.Conv2d, nn.Linear)): 115 | init.xavier_uniform_(module.weight) 116 | init.zeros_(module.bias) 117 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) 118 | 119 | def forward(self, x, temb): 120 | h = self.block1(x) 121 | h += self.temb_proj(temb)[:, :, None, None] 122 | h = self.block2(h) 123 | 124 | h = h + self.shortcut(x) 125 | h = self.attn(h) 126 | return h 127 | 128 | 129 | class TimeEmbedding(nn.Module): 130 | def __init__(self, hidden_size, frequency_embedding_size=256): 131 | super().__init__() 132 | self.mlp = nn.Sequential( 133 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, hidden_size, bias=True), 136 | ) 137 | self.frequency_embedding_size = frequency_embedding_size 138 | 139 | @staticmethod 140 | def timestep_embedding(t, dim, max_period=10000): 141 | """ 142 | Create sinusoidal timestep embeddings. 143 | :param t: a 1-D Tensor of N indices, one per batch element. 144 | These may be fractional. 145 | :param dim: the dimension of the output. 146 | :param max_period: controls the minimum frequency of the embeddings. 147 | :return: an (N, D) Tensor of positional embeddings. 
148 | """ 149 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 150 | half = dim // 2 151 | freqs = torch.exp( 152 | -math.log(max_period) 153 | * torch.arange(start=0, end=half, dtype=torch.float32) 154 | / half 155 | ).to(device=t.device) 156 | args = t[:, None].float() * freqs[None] 157 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 158 | if dim % 2: 159 | embedding = torch.cat( 160 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 161 | ) 162 | return embedding 163 | 164 | def forward(self, t): 165 | if t.ndim == 0: 166 | t = t.unsqueeze(-1) 167 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 168 | t_emb = self.mlp(t_freq) 169 | return t_emb 170 | -------------------------------------------------------------------------------- /image_diffusion_todo/network.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from module import DownSample, ResBlock, Swish, TimeEmbedding, UpSample 8 | from torch.nn import init 9 | 10 | 11 | class UNet(nn.Module): 12 | def __init__(self, T=1000, image_resolution=64, ch=128, ch_mult=[1,2,2,2], attn=[1], num_res_blocks=4, dropout=0.1, use_cfg=False, cfg_dropout=0.1, num_classes=None): 13 | super().__init__() 14 | self.image_resolution = image_resolution 15 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 16 | tdim = ch * 4 17 | # self.time_embedding = TimeEmbedding(T, ch, tdim) 18 | self.time_embedding = TimeEmbedding(tdim) 19 | 20 | # classifier-free guidance 21 | self.use_cfg = use_cfg 22 | self.cfg_dropout = cfg_dropout 23 | if use_cfg: 24 | assert num_classes is not None 25 | cdim = tdim 26 | self.class_embedding = nn.Embedding(num_classes+1, cdim) 27 | 28 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) 29 | self.downblocks = nn.ModuleList() 30 | chs = [ch] # record output channel when dowmsample for upsample 31 | now_ch = ch 32 | for i, mult in enumerate(ch_mult): 33 | out_ch = ch * mult 34 | for _ in range(num_res_blocks): 35 | self.downblocks.append(ResBlock( 36 | in_ch=now_ch, out_ch=out_ch, tdim=tdim, 37 | dropout=dropout, attn=(i in attn))) 38 | now_ch = out_ch 39 | chs.append(now_ch) 40 | if i != len(ch_mult) - 1: 41 | self.downblocks.append(DownSample(now_ch)) 42 | chs.append(now_ch) 43 | 44 | self.middleblocks = nn.ModuleList([ 45 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True), 46 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False), 47 | ]) 48 | 49 | self.upblocks = nn.ModuleList() 50 | for i, mult in reversed(list(enumerate(ch_mult))): 51 | out_ch = ch * mult 52 | for _ in range(num_res_blocks + 1): 53 | self.upblocks.append(ResBlock( 54 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, 55 | dropout=dropout, attn=(i in attn))) 56 | now_ch = out_ch 57 | if i != 0: 58 | self.upblocks.append(UpSample(now_ch)) 59 | assert len(chs) == 0 60 | 61 | self.tail = nn.Sequential( 62 | nn.GroupNorm(32, now_ch), 63 | Swish(), 64 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) 65 | ) 66 | self.initialize() 67 | 68 | def initialize(self): 69 | init.xavier_uniform_(self.head.weight) 70 | init.zeros_(self.head.bias) 71 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 72 | init.zeros_(self.tail[-1].bias) 73 | 74 | def forward(self, x, timestep, class_label=None): 75 | # Timestep embedding 76 | temb = self.time_embedding(timestep) 77 | 78 | if 
self.use_cfg and class_label is not None: 79 | if self.training: 80 | assert not torch.any(class_label == 0) # 0 for null. 81 | 82 | ######## TODO ######## 83 | # DO NOT change the code outside this part. 84 | # Assignment 2. Implement random null conditioning in CFG training. 85 | raise NotImplementedError("TODO") 86 | ####################### 87 | 88 | ######## TODO ######## 89 | # DO NOT change the code outside this part. 90 | # Assignment 2. Implement class conditioning 91 | raise NotImplementedError("TODO") 92 | ####################### 93 | 94 | # Downsampling 95 | h = self.head(x) 96 | hs = [h] 97 | for layer in self.downblocks: 98 | h = layer(h, temb) 99 | hs.append(h) 100 | # Middle 101 | for layer in self.middleblocks: 102 | h = layer(h, temb) 103 | # Upsampling 104 | for layer in self.upblocks: 105 | if isinstance(layer, ResBlock): 106 | h = torch.cat([h, hs.pop()], dim=1) 107 | h = layer(h, temb) 108 | h = self.tail(h) 109 | 110 | assert len(hs) == 0 111 | return h 112 | -------------------------------------------------------------------------------- /image_diffusion_todo/sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from dataset import tensor_to_pil_image 6 | from model import DiffusionModule 7 | from scheduler import DDPMScheduler 8 | from pathlib import Path 9 | 10 | 11 | def main(args): 12 | save_dir = Path(args.save_dir) 13 | save_dir.mkdir(exist_ok=True, parents=True) 14 | 15 | device = f"cuda:{args.gpu}" 16 | 17 | ddpm = DiffusionModule(None, None) 18 | ddpm.load(args.ckpt_path) 19 | ddpm.eval() 20 | ddpm = ddpm.to(device) 21 | 22 | num_train_timesteps = ddpm.var_scheduler.num_train_timesteps 23 | ddpm.var_scheduler = DDPMScheduler( 24 | num_train_timesteps, 25 | beta_1=1e-4, 26 | beta_T=0.02, 27 | mode="linear", 28 | ).to(device) 29 | 30 | total_num_samples = 500 31 | num_batches = int(np.ceil(total_num_samples / args.batch_size)) 32 | 33 | for i in range(num_batches): 34 | sidx = i * args.batch_size 35 | eidx = min(sidx + args.batch_size, total_num_samples) 36 | B = eidx - sidx 37 | 38 | if args.use_cfg: # Enable CFG sampling 39 | assert ddpm.network.use_cfg, f"The model was not trained to support CFG." 
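# [Editor's aside; not part of sampling.py] A hedged sketch of the classifier-free
# guidance rule that the `guidance_scale` argument below presumably controls inside
# `DiffusionModule.sample()`: eps_hat = eps_uncond + w * (eps_cond - eps_uncond),
# where w is the guidance scale. The helper name `cfg_combine` is an illustrative
# assumption, not an identifier defined in this repository.
import torch

def cfg_combine(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, w: float) -> torch.Tensor:
    # w = 0 recovers the unconditional prediction, w = 1 the conditional one, and
    # w > 1 (e.g. the default cfg_scale of 7.5) pushes samples toward the class condition.
    return eps_uncond + w * (eps_cond - eps_uncond)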
40 | samples = ddpm.sample( 41 | B, 42 | class_label=torch.randint(1, 4, (B,)), 43 | guidance_scale=args.cfg_scale, 44 | ) 45 | else: 46 | samples = ddpm.sample(B) 47 | 48 | pil_images = tensor_to_pil_image(samples) 49 | 50 | for j, img in zip(range(sidx, eidx), pil_images): 51 | img.save(save_dir / f"{j}.png") 52 | print(f"Saved the {j}-th image.") 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--batch_size", type=int, default=64) 58 | parser.add_argument("--gpu", type=int, default=0) 59 | parser.add_argument("--ckpt_path", type=str) 60 | parser.add_argument("--save_dir", type=str) 61 | parser.add_argument("--use_cfg", action="store_true") 62 | parser.add_argument("--sample_method", type=str, default="ddpm") 63 | parser.add_argument("--cfg_scale", type=float, default=7.5) 64 | 65 | args = parser.parse_args() 66 | main(args) 67 | -------------------------------------------------------------------------------- /image_diffusion_todo/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class BaseScheduler(nn.Module): 9 | def __init__( 10 | self, num_train_timesteps: int, beta_1: float, beta_T: float, mode="linear" 11 | ): 12 | super().__init__() 13 | self.num_train_timesteps = num_train_timesteps 14 | self.num_inference_timesteps = num_train_timesteps 15 | self.timesteps = torch.from_numpy( 16 | np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64) 17 | ) 18 | 19 | if mode == "linear": 20 | betas = torch.linspace(beta_1, beta_T, steps=num_train_timesteps) 21 | elif mode == "quad": 22 | betas = ( 23 | torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps) ** 2 24 | ) 25 | else: 26 | raise NotImplementedError(f"{mode} is not implemented.") 27 | 28 | alphas = 1 - betas 29 | alphas_cumprod = torch.cumprod(alphas, dim=0) 30 | 31 | self.register_buffer("betas", betas) 32 | self.register_buffer("alphas", alphas) 33 | self.register_buffer("alphas_cumprod", alphas_cumprod) 34 | 35 | def uniform_sample_t( 36 | self, batch_size, device: Optional[torch.device] = None 37 | ) -> torch.IntTensor: 38 | """ 39 | Uniformly sample timesteps. 40 | """ 41 | ts = np.random.choice(np.arange(self.num_train_timesteps), batch_size) 42 | ts = torch.from_numpy(ts) 43 | if device is not None: 44 | ts = ts.to(device) 45 | return ts 46 | 47 | class DDPMScheduler(BaseScheduler): 48 | def __init__( 49 | self, 50 | num_train_timesteps: int, 51 | beta_1: float, 52 | beta_T: float, 53 | mode="linear", 54 | sigma_type="small", 55 | ): 56 | super().__init__(num_train_timesteps, beta_1, beta_T, mode) 57 | 58 | # sigmas correspond to $\sigma_t$ in the DDPM paper. 59 | self.sigma_type = sigma_type 60 | if sigma_type == "small": 61 | # when $\sigma_t^2 = \tilde{\beta}_t$. 62 | alphas_cumprod_t_prev = torch.cat( 63 | [torch.tensor([1.0]), self.alphas_cumprod[:-1]] 64 | ) 65 | sigmas = ( 66 | (1 - alphas_cumprod_t_prev) / (1 - self.alphas_cumprod) * self.betas 67 | ) ** 0.5 68 | elif sigma_type == "large": 69 | # when $\sigma_t^2 = \beta_t$. 70 | sigmas = self.betas ** 0.5 71 | 72 | self.register_buffer("sigmas", sigmas) 73 | 74 | def step(self, x_t: torch.Tensor, t: int, eps_theta: torch.Tensor): 75 | """ 76 | One step denoising function of DDPM: x_t -> x_{t-1}. 77 | 78 | Input: 79 | x_t (`torch.Tensor [B,C,H,W]`): samples at arbitrary timestep t. 
80 | t (`int`): current timestep in a reverse process.
81 | eps_theta (`torch.Tensor [B,C,H,W]`): predicted noise from a learned model.
82 | Output:
83 | sample_prev (`torch.Tensor [B,C,H,W]`): one step denoised sample. (= x_{t-1})
84 | """
85 | 
86 | ######## TODO ########
87 | # DO NOT change the code outside this part.
88 | # Assignment 1. Implement the DDPM reverse step. (See the hedged sketch at the end of this document.)
89 | sample_prev = None
90 | #######################
91 | 
92 | return sample_prev
93 | 
94 | # https://nn.labml.ai/diffusion/ddpm/utils.html
95 | def _get_teeth(self, consts: torch.Tensor, t: torch.Tensor): # get the t-th constant
96 | const = consts.gather(-1, t)
97 | return const.reshape(-1, 1, 1, 1)
98 | 
99 | def add_noise(
100 | self,
101 | x_0: torch.Tensor,
102 | t: torch.IntTensor,
103 | eps: Optional[torch.Tensor] = None,
104 | ):
105 | """
106 | A forward pass of a Markov chain, i.e., q(x_t | x_0).
107 | 
108 | Input:
109 | x_0 (`torch.Tensor [B,C,H,W]`): samples from a real data distribution q(x_0).
110 | t: (`torch.IntTensor [B]`): timesteps at which x_0 is diffused.
111 | eps: (`torch.Tensor [B,C,H,W]`, optional): if None, randomly sample Gaussian noise in the function.
112 | Output:
113 | x_t: (`torch.Tensor [B,C,H,W]`): noisy samples at timestep t.
114 | eps: (`torch.Tensor [B,C,H,W]`): injected noise.
115 | """
116 | 
117 | if eps is None:
118 | eps = torch.randn(x_0.shape, device=x_0.device)
119 | 
120 | ######## TODO ########
121 | # DO NOT change the code outside this part.
122 | # Assignment 1. Implement the DDPM forward step. (See the hedged sketch at the end of this document.)
123 | x_t = None
124 | #######################
125 | 
126 | return x_t, eps
127 | -------------------------------------------------------------------------------- /image_diffusion_todo/train.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from datetime import datetime
4 | from pathlib import Path
5 | 
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | import torch
9 | from dataset import AFHQDataModule, get_data_iterator, tensor_to_pil_image
10 | from dotmap import DotMap
11 | from model import DiffusionModule
12 | from network import UNet
13 | from pytorch_lightning import seed_everything
14 | from scheduler import DDPMScheduler
15 | from torchvision.transforms.functional import to_pil_image
16 | from tqdm import tqdm
17 | 
18 | matplotlib.use("Agg")
19 | 
20 | 
21 | def get_current_time():
22 | now = datetime.now().strftime("%m-%d-%H%M%S")
23 | return now
24 | 
25 | 
26 | def main(args):
27 | """config"""
28 | config = DotMap()
29 | config.update(vars(args))
30 | config.device = f"cuda:{args.gpu}"
31 | 
32 | now = get_current_time()
33 | if args.use_cfg:
34 | save_dir = Path(f"results/cfg_diffusion-{args.sample_method}-{now}")
35 | else:
36 | save_dir = Path(f"results/diffusion-{args.sample_method}-{now}")
37 | save_dir.mkdir(exist_ok=True, parents=True)
38 | print(f"save_dir: {save_dir}")
39 | 
40 | seed_everything(config.seed)
41 | 
42 | with open(save_dir / "config.json", "w") as f:
43 | json.dump(config, f, indent=2)
44 | """######"""
45 | 
46 | image_resolution = 64
47 | ds_module = AFHQDataModule(
48 | "./data",
49 | batch_size=config.batch_size,
50 | num_workers=4,
51 | max_num_images_per_cat=config.max_num_images_per_cat,
52 | image_resolution=image_resolution
53 | )
54 | 
55 | train_dl = ds_module.train_dataloader()
56 | train_it = get_data_iterator(train_dl)
57 | 
58 | # Set up the scheduler
59 | var_scheduler = DDPMScheduler(
60 | config.num_diffusion_train_timesteps,
61 | beta_1=config.beta_1,
62 | beta_T=config.beta_T,
63 | mode="linear",
64 
| ) 65 | 66 | network = UNet( 67 | T=config.num_diffusion_train_timesteps, 68 | image_resolution=image_resolution, 69 | ch=128, 70 | ch_mult=[1, 2, 2, 2], 71 | attn=[1], 72 | num_res_blocks=4, 73 | dropout=0.1, 74 | use_cfg=args.use_cfg, 75 | cfg_dropout=args.cfg_dropout, 76 | num_classes=getattr(ds_module, "num_classes", None), 77 | ) 78 | 79 | ddpm = DiffusionModule(network, var_scheduler) 80 | ddpm = ddpm.to(config.device) 81 | 82 | optimizer = torch.optim.Adam(ddpm.network.parameters(), lr=2e-4) 83 | scheduler = torch.optim.lr_scheduler.LambdaLR( 84 | optimizer, lr_lambda=lambda t: min((t + 1) / config.warmup_steps, 1.0) 85 | ) 86 | 87 | step = 0 88 | losses = [] 89 | with tqdm(initial=step, total=config.train_num_steps) as pbar: 90 | while step < config.train_num_steps: 91 | if step % config.log_interval == 0: 92 | ddpm.eval() 93 | plt.plot(losses) 94 | plt.savefig(f"{save_dir}/loss.png") 95 | plt.close() 96 | samples = ddpm.sample(4, return_traj=False) 97 | pil_images = tensor_to_pil_image(samples) 98 | for i, img in enumerate(pil_images): 99 | img.save(save_dir / f"step={step}-{i}.png") 100 | 101 | ddpm.save(f"{save_dir}/last.ckpt") 102 | ddpm.train() 103 | 104 | img, label = next(train_it) 105 | img, label = img.to(config.device), label.to(config.device) 106 | if args.use_cfg: # Conditional, CFG training 107 | loss = ddpm.get_loss(img, class_label=label) 108 | else: # Unconditional training 109 | loss = ddpm.get_loss(img) 110 | pbar.set_description(f"Loss: {loss.item():.4f}") 111 | 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | scheduler.step() 116 | losses.append(loss.item()) 117 | 118 | step += 1 119 | pbar.update(1) 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--gpu", type=int, default=0) 125 | parser.add_argument("--batch_size", type=int, default=16) 126 | parser.add_argument( 127 | "--train_num_steps", 128 | type=int, 129 | default=100000, 130 | help="the number of model training steps.", 131 | ) 132 | parser.add_argument("--warmup_steps", type=int, default=200) 133 | parser.add_argument("--log_interval", type=int, default=200) 134 | parser.add_argument( 135 | "--max_num_images_per_cat", 136 | type=int, 137 | default=3000, 138 | help="max number of images per category for AFHQ dataset", 139 | ) 140 | parser.add_argument( 141 | "--num_diffusion_train_timesteps", 142 | type=int, 143 | default=1000, 144 | help="diffusion Markov chain num steps", 145 | ) 146 | parser.add_argument("--beta_1", type=float, default=1e-4) 147 | parser.add_argument("--beta_T", type=float, default=0.02) 148 | parser.add_argument("--seed", type=int, default=63) 149 | parser.add_argument("--image_resolution", type=int, default=64) 150 | parser.add_argument("--sample_method", type=str, default="ddpm") 151 | parser.add_argument("--use_cfg", action="store_true") 152 | parser.add_argument("--cfg_dropout", type=float, default=0.1) 153 | args = parser.parse_args() 154 | main(args) 155 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==1.1.3 2 | ipython==8.12.0 3 | jupyter==1.0.0 4 | matplotlib==3.6.0 5 | torch-ema==0.3 6 | tqdm==4.64.1 7 | jupyterlab==4.0.2 8 | jaxtyping==0.2.20 9 | pytorch_lightning 10 | dotmap 11 | numpy==1.26.0 --------------------------------------------------------------------------------
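Editor's note: the TODO blocks left open above in image_diffusion_todo/scheduler.py (the DDPM forward and reverse updates) and image_diffusion_todo/network.py (class conditioning with CFG dropout) are assignment placeholders. The self-contained sketch below illustrates the standard DDPM equations and one common conditioning scheme under the repository's conventions (label 0 as the null class, class embedding width equal to the timestep embedding width). All names in it (toy_add_noise, toy_step, toy_condition, the toy beta schedule) are the editor's illustrative assumptions, not the course's reference solution.

import torch

# Toy linear beta schedule, mirroring BaseScheduler's defaults (beta_1=1e-4, beta_T=0.02).
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)


def toy_add_noise(x0: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
    """Forward process q(x_t | x_0): x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps."""
    a_bar = alphas_cumprod[t].reshape(-1, 1, 1, 1)  # broadcast over [B, C, H, W]
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps


def toy_step(xt: torch.Tensor, t: int, eps_theta: torch.Tensor) -> torch.Tensor:
    """One reverse step x_t -> x_{t-1} using the small-sigma choice, sigma_t^2 = beta_tilde_t."""
    beta_t = betas[t]
    a_t = alphas[t]
    a_bar_t = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
    # Posterior mean: 1/sqrt(a_t) * (x_t - beta_t / sqrt(1 - a_bar_t) * eps_theta)
    mean = (xt - beta_t / (1.0 - a_bar_t).sqrt() * eps_theta) / a_t.sqrt()
    if t == 0:
        return mean  # no noise is added at the final step
    sigma_t = ((1.0 - a_bar_prev) / (1.0 - a_bar_t) * beta_t).sqrt()
    return mean + sigma_t * torch.randn_like(xt)


def toy_condition(temb: torch.Tensor, class_label: torch.Tensor,
                  class_emb: torch.nn.Embedding, cfg_dropout: float,
                  training: bool) -> torch.Tensor:
    """Class conditioning with random null (index 0) dropout, assuming cdim == tdim."""
    if training:
        # Randomly replace labels with the null token so the network also learns
        # the unconditional prediction needed for classifier-free guidance.
        drop = torch.rand(class_label.shape, device=class_label.device) < cfg_dropout
        class_label = torch.where(drop, torch.zeros_like(class_label), class_label)
    return temb + class_emb(class_label)


if __name__ == "__main__":
    x0 = torch.randn(2, 3, 64, 64)
    t = torch.tensor([10, 500])
    eps = torch.randn_like(x0)
    xt = toy_add_noise(x0, t, eps)
    xprev = toy_step(xt[:1], 500, torch.randn(1, 3, 64, 64))
    print(xt.shape, xprev.shape)  # torch.Size([2, 3, 64, 64]) torch.Size([1, 3, 64, 64])

The small sigma used in toy_step mirrors the sigma_type="small" branch already implemented in DDPMScheduler; replacing it with sigma_t = sqrt(beta_t) gives the "large" variant.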