├── .gitignore
├── 2d_plot_diffusion_todo
│   ├── __init__.py
│   ├── chamferdist.py
│   ├── dataset.py
│   ├── ddpm.py
│   ├── ddpm_tutorial.ipynb
│   └── network.py
├── LICENSE
├── README.md
├── assets
│   ├── images
│   │   ├── cfg_test.png
│   │   ├── cfg_train.png
│   │   ├── fid_command.png
│   │   ├── qs.png
│   │   ├── sampling_command.png
│   │   ├── task1_ddim_sample.png
│   │   ├── task1_ddpm_sample.png
│   │   ├── task1_distribution.png
│   │   ├── task1_loss_curve.png
│   │   ├── task1_output_example.png
│   │   ├── task2_1_ddpm_sampling_algorithm.png
│   │   ├── task2_algorithm.png
│   │   ├── task2_ddim.png
│   │   ├── task2_output_example.png
│   │   ├── task2_teaser.png
│   │   ├── teaser.gif
│   │   └── teaser.png
│   └── summary_of_DDPM_and_DDIM.pdf
├── image_diffusion_todo
│   ├── dataset.py
│   ├── fid
│   │   ├── afhq_inception_v3.ckpt
│   │   ├── inception.py
│   │   └── measure_fid.py
│   ├── model.py
│   ├── module.py
│   ├── network.py
│   ├── sampling.py
│   ├── scheduler.py
│   └── train.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *.ipynb_checkpoints
4 | data/
5 | results/
6 | outputs
7 | samples/
8 | .DS_Store
9 |
10 | # Solutions
11 | 2d_plot_diffusion
12 | image_diffusion
--------------------------------------------------------------------------------
/2d_plot_diffusion_todo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/2d_plot_diffusion_todo/__init__.py
--------------------------------------------------------------------------------
/2d_plot_diffusion_todo/chamferdist.py:
--------------------------------------------------------------------------------
1 | from scipy.spatial import KDTree
2 | from scipy.spatial.distance import cdist
3 |
4 | def chamfer_distance(S1, S2) -> float:
5 | """
6 | Computes the Chamfer distance between two point clouds defined as:
 7 | d_CD(S1, S2) = \sum_{x \in S1} \min_{y \in S2} ||x - y||^2 + \sum_{y \in S2} \min_{x \in S1} ||x - y||^2
8 | """
9 | dist = cdist(S1, S2)
10 | dist1 = dist.min(axis=1) ** 2
11 | dist2 = dist.min(axis=0) ** 2
12 | return dist1.sum() + dist2.sum()
13 |
--------------------------------------------------------------------------------
/2d_plot_diffusion_todo/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn import datasets
4 | from torch.utils.data import DataLoader, Dataset
5 |
6 |
7 | def normalize(ds, scaling_factor=2.0):
8 | return (ds - ds.mean()) / ds.std() * scaling_factor
9 |
10 |
11 | def sample_checkerboard(n):
12 | # https://github.com/ghliu/SB-FBSDE/blob/main/data.py
13 | n_points = 3 * n
14 | n_classes = 2
15 | freq = 5
16 | x = np.random.uniform(
17 | -(freq // 2) * np.pi, (freq // 2) * np.pi, size=(n_points, n_classes)
18 | )
19 | mask = np.logical_or(
20 | np.logical_and(np.sin(x[:, 0]) > 0.0, np.sin(x[:, 1]) > 0.0),
21 | np.logical_and(np.sin(x[:, 0]) < 0.0, np.sin(x[:, 1]) < 0.0),
22 | )
23 | y = np.eye(n_classes)[1 * mask]
24 | x0 = x[:, 0] * y[:, 0]
25 | x1 = x[:, 1] * y[:, 0]
26 | sample = np.concatenate([x0[..., None], x1[..., None]], axis=-1)
27 | sqr = np.sum(np.square(sample), axis=-1)
28 | idxs = np.where(sqr == 0)
29 | sample = np.delete(sample, idxs, axis=0)
30 |
31 | return sample
32 |
33 |
34 | def load_twodim(num_samples: int, dataset: str, dimension: int = 2):
35 |
36 | if dataset == "gaussian_centered":
37 | sample = np.random.normal(size=(num_samples, dimension))
38 | sample = sample
39 |
40 | if dataset == "gaussian_shift":
41 | sample = np.random.normal(size=(num_samples, dimension))
42 | sample = sample + 1.5
43 |
44 | if dataset == "circle":
45 | X, y = datasets.make_circles(
46 | n_samples=num_samples, noise=0.0, random_state=None, factor=0.5
47 | )
48 | sample = X * 4
49 |
50 | if dataset == "scurve":
51 | X, y = datasets.make_s_curve(
52 | n_samples=num_samples, noise=0.0, random_state=None
53 | )
54 | sample = normalize(X[:, [0, 2]])
55 |
56 | if dataset == "moon":
57 | X, y = datasets.make_moons(n_samples=num_samples, noise=0.0, random_state=None)
58 | sample = normalize(X)
59 |
60 | if dataset == "swiss_roll":
61 | X, y = datasets.make_swiss_roll(
62 | n_samples=num_samples, noise=0.0, random_state=None, hole=True
63 | )
64 | sample = normalize(X[:, [0, 2]])
65 |
66 | if dataset == "checkerboard":
67 | sample = normalize(sample_checkerboard(num_samples))
68 |
69 | return torch.tensor(sample).float()
70 |
71 |
72 | class TwoDimDataClass(Dataset):
73 | def __init__(self, dataset_type: str, N: int, batch_size: int, dimension=2):
74 |
75 | self.X = load_twodim(N, dataset_type, dimension=dimension)
76 | self.name = dataset_type
77 | self.batch_size = batch_size
78 | self.dimension = 2
79 |
80 | def __len__(self):
81 | return self.X.shape[0]
82 |
83 | def __getitem__(self, idx):
84 | return self.X[idx]
85 |
86 | def get_dataloader(self, shuffle=True):
87 | return DataLoader(
88 | self,
89 | batch_size=self.batch_size,
90 | shuffle=shuffle,
91 | pin_memory=True,
92 | )
93 |
94 |
95 | def get_data_iterator(iterable):
96 | iterator = iterable.__iter__()
97 | while True:
98 | try:
99 | yield iterator.__next__()
100 | except StopIteration:
101 | iterator = iterable.__iter__()
102 |
--------------------------------------------------------------------------------
/2d_plot_diffusion_todo/ddpm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def extract(input, t: torch.Tensor, x: torch.Tensor):
8 | if t.ndim == 0:
9 | t = t.unsqueeze(0)
10 | shape = x.shape
11 | t = t.long().to(input.device)
12 | out = torch.gather(input, 0, t)
13 | reshape = [t.shape[0]] + [1] * (len(shape) - 1)
14 | return out.reshape(*reshape)
15 |
16 |
17 | class BaseScheduler(nn.Module):
18 | """
19 | Variance scheduler of DDPM.
20 | """
21 |
22 | def __init__(
23 | self,
24 | num_train_timesteps: int,
25 | beta_1: float = 1e-4,
26 | beta_T: float = 0.02,
27 | mode: str = "linear",
28 | ):
29 | super().__init__()
30 | self.num_train_timesteps = num_train_timesteps
31 | self.timesteps = torch.from_numpy(
32 | np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64)
33 | )
34 |
35 | if mode == "linear":
36 | betas = torch.linspace(beta_1, beta_T, steps=num_train_timesteps)
37 | elif mode == "quad":
38 | betas = (
39 | torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps) ** 2
40 | )
41 | else:
42 | raise NotImplementedError(f"{mode} is not implemented.")
43 |
44 | alphas = 1 - betas
45 | alphas_cumprod = torch.cumprod(alphas, dim=0)
46 |
47 | self.register_buffer("betas", betas)
48 | self.register_buffer("alphas", alphas)
49 | self.register_buffer("alphas_cumprod", alphas_cumprod)
50 |
51 |
52 | class DiffusionModule(nn.Module):
53 | """
54 | A high-level wrapper of DDPM and DDIM.
55 | If you want to sample data based on the DDIM's reverse process, use `ddim_p_sample()` and `ddim_p_sample_loop()`.
56 | """
57 |
58 | def __init__(self, network: nn.Module, var_scheduler: BaseScheduler):
59 | super().__init__()
60 | self.network = network
61 | self.var_scheduler = var_scheduler
62 |
63 | @property
64 | def device(self):
65 | return next(self.network.parameters()).device
66 |
67 | @property
68 | def image_resolution(self):
69 | # For image diffusion model.
70 | return getattr(self.network, "image_resolution", None)
71 |
72 | def q_sample(self, x0, t, noise=None):
73 | """
74 | sample x_t from q(x_t | x_0) of DDPM.
75 |
76 | Input:
77 | x0 (`torch.Tensor`): clean data to be mapped to timestep t in the forward process of DDPM.
78 | t (`torch.Tensor`): timestep
79 | noise (`torch.Tensor`, optional): random Gaussian noise. if None, randomly sample Gaussian noise in the function.
80 | Output:
81 | xt (`torch.Tensor`): noisy samples
82 | """
83 | if noise is None:
84 | noise = torch.randn_like(x0)
85 |
86 | ######## TODO ########
87 | # DO NOT change the code outside this part.
88 | # Compute xt.
89 | alphas_prod_t = extract(self.var_scheduler.alphas_cumprod, t, x0)
90 | xt = x0
91 |
92 | #######################
93 |
94 | return xt
95 |
96 | @torch.no_grad()
97 | def p_sample(self, xt, t):
98 | """
99 | One step denoising function of DDPM: x_t -> x_{t-1}.
100 |
101 | Input:
102 | xt (`torch.Tensor`): samples at arbitrary timestep t.
103 | t (`torch.Tensor`): current timestep in a reverse process.
104 | Output:
105 | x_t_prev (`torch.Tensor`): one step denoised sample. (= x_{t-1})
106 |
107 | """
108 | ######## TODO ########
109 | # DO NOT change the code outside this part.
110 | # compute x_t_prev.
111 | if isinstance(t, int):
112 | t = torch.tensor([t]).to(self.device)
113 | eps_factor = (1 - extract(self.var_scheduler.alphas, t, xt)) / (
114 | 1 - extract(self.var_scheduler.alphas_cumprod, t, xt)
115 | ).sqrt()
116 | eps_theta = self.network(xt, t)
117 |
118 | x_t_prev = xt
119 |
120 | #######################
121 | return x_t_prev
122 |
123 | @torch.no_grad()
124 | def p_sample_loop(self, shape):
125 | """
126 | The loop of the reverse process of DDPM.
127 |
128 | Input:
129 | shape (`Tuple`): The shape of output. e.g., (num particles, 2)
130 | Output:
131 | x0_pred (`torch.Tensor`): The final denoised output through the DDPM reverse process.
132 | """
133 | ######## TODO ########
134 | # DO NOT change the code outside this part.
135 | # sample x0 based on Algorithm 2 of DDPM paper.
136 | x0_pred = torch.zeros(shape).to(self.device)
137 |
138 | ######################
139 | return x0_pred
140 |
141 | @torch.no_grad()
142 | def ddim_p_sample(self, xt, t, t_prev, eta=0.0):
143 | """
144 | One step denoising function of DDIM: $x_{\tau_i}$ -> $x_{\tau_{i-1}}$.
145 |
146 | Input:
147 | xt (`torch.Tensor`): noisy data at timestep $\tau_i$.
148 | t (`torch.Tensor`): current timestep (=\tau_i)
149 | t_prev (`torch.Tensor`): next timestep in a reverse process (=\tau_{i-1})
150 | eta (float): corresponds to η in DDIM, which controls the stochasticity of the reverse process.
151 | Output:
152 | x_t_prev (`torch.Tensor`): one step denoised sample. (= $x_{\tau_{i-1}}$)
153 | """
154 | ######## TODO ########
155 | # NOTE: This code is used for assignment 2. You don't need to implement this part for assignment 1.
156 | # DO NOT change the code outside this part.
157 | # compute x_t_prev based on ddim reverse process.
158 | alpha_prod_t = extract(self.var_scheduler.alphas_cumprod, t, xt)
159 | if t_prev >= 0:
160 | alpha_prod_t_prev = extract(self.var_scheduler.alphas_cumprod, t_prev, xt)
161 | else:
162 | alpha_prod_t_prev = torch.ones_like(alpha_prod_t)
163 |
164 | x_t_prev = xt
165 |
166 | ######################
167 | return x_t_prev
168 |
169 | @torch.no_grad()
170 | def ddim_p_sample_loop(self, shape, num_inference_timesteps=50, eta=0.0):
171 | """
172 | The loop of the reverse process of DDIM.
173 |
174 | Input:
175 | shape (`Tuple`): The shape of output. e.g., (num particles, 2)
176 | num_inference_timesteps (`int`): the number of timesteps in the reverse process.
177 | eta (`float`): corresponds to η in DDIM, which controls the stochasticity of the reverse process.
178 | Output:
179 | x0_pred (`torch.Tensor`): The final denoised output through the DDIM reverse process.
180 | """
181 | ######## TODO ########
182 | # NOTE: This code is used for assignment 2. You don't need to implement this part for assignment 1.
183 | # DO NOT change the code outside this part.
184 | # sample x0 based on the DDIM reverse process.
185 | step_ratio = self.var_scheduler.num_train_timesteps // num_inference_timesteps
186 | timesteps = (
187 | (np.arange(0, num_inference_timesteps) * step_ratio)
188 | .round()[::-1]
189 | .copy()
190 | .astype(np.int64)
191 | )
192 | timesteps = torch.from_numpy(timesteps)
193 | prev_timesteps = timesteps - step_ratio
194 |
195 | xt = torch.zeros(shape).to(self.device)
196 | for t, t_prev in zip(timesteps, prev_timesteps):
197 | pass
198 |
199 | x0_pred = xt
200 |
201 | ######################
202 |
203 | return x0_pred
204 |
205 | def compute_loss(self, x0):
206 | """
207 | The simplified noise matching loss corresponding to Equation 14 in the DDPM paper.
208 |
209 | Input:
210 | x0 (`torch.Tensor`): clean data
211 | Output:
212 | loss: the computed loss to be backpropagated.
213 | """
214 | ######## TODO ########
215 | # DO NOT change the code outside this part.
216 | # compute noise matching loss.
217 | batch_size = x0.shape[0]
218 | t = (
219 | torch.randint(0, self.var_scheduler.num_train_timesteps, size=(batch_size,))
220 | .to(x0.device)
221 | .long()
222 | )
223 |
224 | loss = x0.mean()
225 |
226 | ######################
227 | return loss
228 |
229 | def save(self, file_path):
230 | hparams = {
231 | "network": self.network,
232 | "var_scheduler": self.var_scheduler,
233 | }
234 | state_dict = self.state_dict()
235 |
236 | dic = {"hparams": hparams, "state_dict": state_dict}
237 | torch.save(dic, file_path)
238 |
239 | def load(self, file_path):
240 | dic = torch.load(file_path, map_location="cpu")
241 | hparams = dic["hparams"]
242 | state_dict = dic["state_dict"]
243 |
244 | self.network = hparams["network"]
245 | self.var_scheduler = hparams["var_scheduler"]
246 |
247 | self.load_state_dict(state_dict)
248 |
--------------------------------------------------------------------------------
/2d_plot_diffusion_todo/network.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class TimeEmbedding(nn.Module):
10 | def __init__(self, hidden_size, frequency_embedding_size=256):
11 | super().__init__()
12 | self.mlp = nn.Sequential(
13 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
14 | nn.SiLU(),
15 | nn.Linear(hidden_size, hidden_size, bias=True),
16 | )
17 | self.frequency_embedding_size = frequency_embedding_size
18 |
19 | @staticmethod
20 | def timestep_embedding(t, dim, max_period=10000):
21 | """
22 | Create sinusoidal timestep embeddings.
23 | :param t: a 1-D Tensor of N indices, one per batch element.
24 | These may be fractional.
25 | :param dim: the dimension of the output.
26 | :param max_period: controls the minimum frequency of the embeddings.
27 | :return: an (N, D) Tensor of positional embeddings.
28 | """
29 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
30 | half = dim // 2
31 | freqs = torch.exp(
32 | -math.log(max_period)
33 | * torch.arange(start=0, end=half, dtype=torch.float32)
34 | / half
35 | ).to(device=t.device)
36 | args = t[:, None].float() * freqs[None]
37 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
38 | if dim % 2:
39 | embedding = torch.cat(
40 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
41 | )
42 | return embedding
43 |
44 | def forward(self, t: torch.Tensor):
45 | if t.ndim == 0:
46 | t = t.unsqueeze(-1)
47 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
48 | t_emb = self.mlp(t_freq)
49 | return t_emb
50 |
51 |
52 | class TimeLinear(nn.Module):
53 | def __init__(self, dim_in: int, dim_out: int, num_timesteps: int):
54 | super().__init__()
55 | self.dim_in = dim_in
56 | self.dim_out = dim_out
57 | self.num_timesteps = num_timesteps
58 |
59 | self.time_embedding = TimeEmbedding(dim_out)
60 | self.fc = nn.Linear(dim_in, dim_out)
61 |
62 | def forward(self, x: torch.Tensor, t: torch.Tensor):
63 | x = self.fc(x)
64 | alpha = self.time_embedding(t).view(-1, self.dim_out)
65 |
66 | return alpha * x
67 |
68 |
69 | class SimpleNet(nn.Module):
70 | def __init__(
71 | self, dim_in: int, dim_out: int, dim_hids: List[int], num_timesteps: int
72 | ):
73 | super().__init__()
74 | """
75 | (TODO) Build a noise estimating network.
76 |
77 | Args:
78 | dim_in: dimension of input
79 | dim_out: dimension of output
80 | dim_hids: dimensions of hidden features
81 | num_timesteps: number of timesteps
82 | """
83 |
84 | ######## TODO ########
85 | # DO NOT change the code outside this part.
86 |
87 | ######################
88 |
89 | def forward(self, x: torch.Tensor, t: torch.Tensor):
90 | """
91 | (TODO) Implement the forward pass. This should output
92 | the noise prediction of the noisy input x at timestep t.
93 |
94 | Args:
95 | x: the noisy data after t steps of the forward diffusion
96 | t: the timestep up to which the forward diffusion has been run
97 | """
98 | ######## TODO ########
99 | # DO NOT change the code outside this part.
100 |
101 | ######################
102 | return x
103 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 KAIST Visual AI Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | 
 10 | 
 11 | 
 18 | 
 19 | 
 20 | 
 21 | 
 22 | 
 23 | 
24 | ## Abstract
25 | In this programming assignment, you will implement the Denoising Diffusion Probabilistic Model (DDPM), a fundamental building block that empowers today's diffusion-based generative modeling. While DDPM provides the technical foundation for popular generative frameworks like [Stable Diffusion](https://github.com/CompVis/stable-diffusion), its implementation is surprisingly straightforward, making it an excellent starting point for gaining hands-on experience in building diffusion models. We will begin with a relatively simple example: modeling the distribution of 2D points on a spiral (known as the "Swiss Roll"). Following that, we will develop an image generator using the AFHQ dataset to explore how DDPM and diffusion models seamlessly adapt to changes in data format and dimensionality with minimal code changes.
26 |
27 | ## Setup
28 |
29 | Create a `conda` environment named `ddpm` and install PyTorch:
30 | ```
31 | conda create --name ddpm python=3.10
32 | conda activate ddpm
33 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
34 | ```
35 |
36 | Install the required package within the `requirements.txt`
37 | ```
38 | pip install -r requirements.txt
39 | ```
40 |
41 | > **NOTE: We have removed the dependency on `chamferdist` due to issues during installation.**
42 |
43 | ## Code Structure
44 | ```
45 | .
46 | ├── 2d_plot_diffusion_todo (Task 1)
47 | │ ├── ddpm_tutorial.ipynb <--- Main code
48 | │ ├── dataset.py <--- Define dataset (Swiss-roll, moon, gaussians, etc.)
49 | │ ├── network.py <--- (TODO) Implement a noise prediction network
50 | │ └── ddpm.py <--- (TODO) Define a DDPM pipeline
51 | │
52 | └── image_diffusion_todo (Task 2)
53 | ├── dataset.py <--- Ready-to-use AFHQ dataset code
54 | ├── model.py <--- Diffusion model including its backbone and scheduler
55 | ├── module.py <--- Basic modules of a noise prediction network
56 | ├── network.py <--- Definition of the U-Net architecture
57 | ├── sampling.py <--- Image sampling code
58 | ├── scheduler.py <--- (TODO) Implement the forward/reverse step of DDPM
59 | ├── train.py <--- DDPM training code
60 | └── fid
61 | ├── measure_fid.py <--- script measuring FID score
 62 | └── afhq_inception_v3.ckpt <--- pre-trained classifier for FID
63 | ```
64 |
65 |
66 | ## Task 0: Introduction
67 | ### Assignment Tips
68 |
 69 | Implementing diffusion models is straightforward once you understand the theory.
 70 | To get the most out of this tutorial, we highly recommend reading the details in the
 71 | related papers and understanding the equations **BEFORE** you start the assignment. You can go through
 72 | the resources in this order:
73 |
74 | 1. [[Paper](https://arxiv.org/abs/2006.11239)] Denoising Diffusion Probabilistic Models
 75 | 2. [[Blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)] Lilian Weng's "What are Diffusion Models?"
76 |
77 | ### Forward Process
 78 | The Denoising Diffusion Probabilistic Model (DDPM) is a latent-variable generative model built on a Markov chain. In this Markov chain, we define a _forward process_ that gradually adds noise to data sampled from the data distribution $\mathbf{x}_0 \sim q(\mathbf{x}_0)$ so that the data becomes pure white Gaussian noise at $t=T$. Each transition of the forward process is as follows:
79 |
80 | $$ q(\mathbf{x}_t | \mathbf{x}\_{t-1}) := \mathcal{N}(\mathbf{x}_t; \sqrt{1-\beta_t}\mathbf{x}\_{t-1}, \beta_t \mathbf{I}), $$
81 |
 82 | where a variance schedule $\beta_1, \dots, \beta_T$ controls the step sizes.
83 |
84 | Thanks to a nice property of a Gaussian distribution, one can directly sample $\mathbf{x}_t$ at an arbitrary timestep $t$ from real data $\mathbf{x}_0$ in closed form:
85 |
86 | $$q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I}). $$
87 |
 88 | where $\alpha\_t := 1 - \beta\_t$ and $\bar{\alpha}\_t := \prod\_{s=1}^{t} \alpha\_s$.
89 |
90 | Refer to [our slide](./assets/summary_of_DDPM_and_DDIM.pdf) or [blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) for more details.
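
The closed-form expression above is exactly what `q_sample()` computes. As a rough illustration only (assuming the cumulative products $\bar{\alpha}_t$ are stored in a 1-D tensor `alphas_cumprod`, as in `BaseScheduler`; this is a sketch, not the required implementation):

```python
import torch

def q_sample_sketch(x0, t, alphas_cumprod, noise=None):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    if noise is None:
        noise = torch.randn_like(x0)
    # Broadcast alpha_bar_t over the non-batch dimensions of x0.
    alpha_bar_t = alphas_cumprod[t].reshape(-1, *([1] * (x0.ndim - 1)))
    return alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * noise
```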
91 |
92 | ### Reverse Process
 93 | If we can reverse the forward process, i.e., sample $\mathbf{x}\_{t-1} \sim q(\mathbf{x}\_{t-1} | \mathbf{x}_t)$ iteratively until $t=0$, we will be able to generate $\mathbf{x}_0$ close to the unknown data distribution $q(\mathbf{x}_0)$ starting from white Gaussian noise $\mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I})$. You can think of this _reverse process_ as a denoising process that gradually removes noise so that the result looks like a true sample from $q(\mathbf{x}_0)$ at the end.
94 | The reverse process is also a Markov chain with learned Gaussian transitions:
95 |
96 | $$p\_\theta(\mathbf{x}\_{0:T}) := p(\mathbf{x}_T) \prod\_{t=1}^T p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t), $$
97 |
 98 | where $p(\mathbf{x}_T) = \mathcal{N}(0, \mathbf{I})$ and $p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t) := \mathcal{N}(\mathbf{x}\_{t-1}; \boldsymbol{\mu}\_\theta (\mathbf{x}_t, t), \boldsymbol{\Sigma}\_\theta (\mathbf{x}_t, t)).$
99 |
100 | ### Training
101 | To learn this reverse process, we set an objective function that minimizes KL divergence between $p_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t)$ and $q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0), \sigma_t^2 \mathbf{I})$ which is also a Gaussian distribution when conditioned on $\mathbf{x}_0$:
102 |
103 | $$\mathcal{L} = \mathbb{E}_q \left[ \sum\_{t > 1} D\_{\text{KL}}( q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \Vert p\_\theta ( \mathbf{x}\_{t-1} | \mathbf{x}_t)) \right].$$
104 |
105 | As a parameterization of DDPM, the authors fix $\boldsymbol{\Sigma}\_\theta(\mathbf{x}_t, t) = \sigma_t^2 \mathbf{I}$ to untrained time-dependent constants. As a result, we can rewrite the objective function:
106 |
107 | $$\mathcal{L} = \mathbb{E}\_q \left[ \frac{1}{2\sigma\_t^2} \Vert \tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) - \boldsymbol{\mu}\_{\theta}(\mathbf{x}_t, t) \Vert^2 \right] + C $$
108 |
109 | The authors empirically found that predicting the noise $\boldsymbol{\epsilon}$ injected into the data with a noise prediction network $\boldsymbol{\epsilon}\_\theta$ works better than learning the mean function $\boldsymbol{\mu}\_\theta$ directly.
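
Concretely (Eq. 11 in the DDPM paper), the mean is parameterized in terms of the predicted noise as

$$\boldsymbol{\mu}\_\theta(\mathbf{x}\_t, t) = \frac{1}{\sqrt{\alpha\_t}} \left( \mathbf{x}\_t - \frac{\beta\_t}{\sqrt{1 - \bar{\alpha}\_t}} \boldsymbol{\epsilon}\_\theta(\mathbf{x}\_t, t) \right),$$

so learning $\boldsymbol{\epsilon}\_\theta$ is equivalent to learning $\boldsymbol{\mu}\_\theta$ up to fixed, known scaling factors.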
110 |
111 | In short, the simplified objective function of DDPM is defined as follows:
112 |
113 | $$ \mathcal{L}\_{\text{simple}} := \mathbb{E}\_{t,\mathbf{x}_0,\boldsymbol{\epsilon}} [ \Vert \boldsymbol{\epsilon} - \boldsymbol{\epsilon}\_\theta( \mathbf{x}\_t(\mathbf{x}_0, t), t) \Vert^2 ],$$
114 |
115 | where $\mathbf{x}_t (\mathbf{x}_0, t) = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}$ and $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$.
116 |
117 | Refer to [the original paper](https://arxiv.org/abs/2006.11239) for more details.
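
In code, the simplified objective is little more than an MSE between the injected noise and the network's prediction. A minimal sketch, assuming placeholder names `model` and `alphas_cumprod` (not part of the provided code):

```python
import torch
import torch.nn.functional as F

def simple_loss_sketch(model, x0, alphas_cumprod, num_train_timesteps):
    # Draw a random timestep per sample and the corresponding Gaussian noise.
    t = torch.randint(0, num_train_timesteps, (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    # Forward process in closed form: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps.
    alpha_bar_t = alphas_cumprod[t].reshape(-1, *([1] * (x0.ndim - 1)))
    xt = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * eps
    # Noise-matching objective: predict eps from (x_t, t).
    return F.mse_loss(model(xt, t), eps)
```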
118 |
119 | ### Sampling
120 |
121 | Once we train the noise prediction network $\boldsymbol{\epsilon}\_\theta$, we can run sampling by gradually denoising white Gaussian noise. The algorithm of the DDPM sampling is shown below:
122 |
123 |
124 |
125 |
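In code, DDPM sampling is a loop that starts from $\mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I})$ and repeatedly applies the one-step denoising transition down to $t = 1$. Below is a minimal sketch using the same placeholder names as above (`model`, `betas`, `alphas`, `alphas_cumprod`); it illustrates the algorithm rather than the exact interface you will implement:

```python
import torch

@torch.no_grad()
def ddpm_sample_sketch(model, shape, betas, alphas, alphas_cumprod):
    device = next(model.parameters()).device
    x = torch.randn(shape, device=device)  # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)  # predicted noise
        # Posterior mean: (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - (1 - alphas[t]) / (1 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            # Add sigma_t * z with sigma_t^2 = beta_t; the final step (t = 0) is deterministic.
            x = mean + betas[t].sqrt() * torch.randn_like(x)
        else:
            x = mean
    return x
```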
126 |
127 | ## Task 1: Simple DDPM pipeline with Swiss-Roll
128 |
129 |
130 |
131 |
132 |
133 | A typical diffusion pipeline is divided into three components:
134 | 1. [Forward Process](#forward-process) and [Reverse Process](#reverse-process)
135 | 2. [Training](#training)
136 | 3. [Sampling](#sampling)
137 |
138 | In this task, we will look into each component one by one in a toy experiment and implement them sequentially.
139 | After finishing the implementation, you will be able to train DDPM and evaluate its performance in `ddpm_tutorial.ipynb` under the `2d_plot_diffusion_todo` directory.
140 |
141 | ❗️❗️❗️ **You are only allowed to edit the part marked by TODO.** ❗️❗️❗️
142 |
143 | ### TODO
144 | #### 1-1: Build a noise prediction network
145 | You first need to implement a noise prediction network in `network.py`.
146 | The network should consist of `TimeLinear` layers whose feature dimensions are a sequence of [`dim_in`, `dim_hids[0]`, ..., `dim_hids[-1]`, `dim_out`].
147 | Every `TimeLinear` layer except for the last `TimeLinear` layer should be followed by a ReLU activation.
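
One possible way to wire up such a stack is sketched below. It chains `TimeLinear` layers (defined in `2d_plot_diffusion_todo/network.py`) with ReLU activations in between; treat it as an illustration of the layer ordering, not as the required `SimpleNet` code:

```python
import torch
import torch.nn as nn

from network import TimeLinear  # as defined in 2d_plot_diffusion_todo/network.py

class TimeConditionedMLP(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hids, num_timesteps):
        super().__init__()
        dims = [dim_in] + list(dim_hids) + [dim_out]
        # One TimeLinear per consecutive pair of feature dimensions.
        self.layers = nn.ModuleList(
            [TimeLinear(d_in, d_out, num_timesteps)
             for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, t):
        for i, layer in enumerate(self.layers):
            x = layer(x, t)
            if i < len(self.layers) - 1:  # no activation after the last layer
                x = torch.relu(x)
        return x
```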
148 |
149 | #### 1-2: Construct the forward and reverse process of DDPM
150 | Now you should construct a forward and reverse process of DDPM in `ddpm.py`.
151 | `q_sample()` is a forward function that maps $\mathbf{x}_0$ to $\mathbf{x}_t$.
152 |
153 | `p_sample()` is a one-step reverse transition from $\mathbf{x}\_{t}$ to $\mathbf{x}\_{t-1}$ and `p_sample_loop()` is the full reverse process corresponding to DDPM sampling algorithm.
154 |
155 | #### 1-3: Implement the training objective function
156 | In `ddpm.py`, `compute_loss()` function should return the simplified noise matching loss in DDPM paper.
157 |
158 | #### 1-4: Training and Evaluation
159 | Once you finish the implementation above, open and run `ddpm_tutorial.ipynb` via Jupyter Notebook. It will automatically train a diffusion model and measure the Chamfer distance between 2D particles sampled by the diffusion model and 2D particles sampled from the target distribution.
160 |
161 | Take screenshots of:
162 |
163 | 1. the training loss curve
164 | 2. the Chamfer Distance reported after executing the Jupyter Notebook
165 | 3. the visualization of the sampled particles
166 |
167 | Below are the examples of (1) and (3).
168 |
169 |
170 |
171 |
172 |
173 | ## Task 2: Image Diffusion
174 |
175 |
176 |
177 |
178 |
179 | ### TODO
180 |
181 | Once you have finished Task 1, implement the methods `add_noise` and `step` of the class `DDPMScheduler` defined in `image_diffusion_todo/scheduler.py`. You also need to implement the method `get_loss` of the class `DiffusionModule` defined in `image_diffusion_todo/model.py`. Refer to your implementation of the methods `q_sample`, `p_sample`, and `compute_loss` from the 2D experiment.
182 |
183 | In this task, we will generate $64\times64$ animal images by training a DDPM using the AFHQ dataset.
184 |
185 | To train your model, simply execute the command: `python train.py`.
186 |
187 | ❗️❗️❗️ You are NOT allowed to modify any given hyperparameters. ❗️❗️❗️
188 |
189 | It will sample images and save a checkpoint every `args.log_interval` iterations. After training a model, sample and save images with
190 | ```
191 | python sampling.py --ckpt_path ${CKPT_PATH} --save_dir ${SAVE_DIR_PATH}
192 | ```
193 | 
194 |
195 | We recommend starting the training as soon as possible since it takes about **14 hours**.
196 |
197 | As an evaluation, measure FID score using the pre-trained classifier network we provide:
198 | ```
199 | python dataset.py # to construct the eval directory.
200 | python fid/measure_fid.py @GT_IMG_DIR @GEN_IMG_DIR
201 | ```
202 |
203 | > **Do NOT forget to execute `dataset.py` before measuring FID score. Otherwise, the output will be incorrect due to the discrepancy between the image resolutions.**
204 |
205 | For instance:
206 | 
207 | Use the validation set of the AFHQ dataset (e.g., `data/afhq/eval`) as @GT_IMG_DIR. The script will automatically search for and load the images. The path @GEN_IMG_DIR should be the same as the `--save_dir` you provided when running the script `sampling.py`.
208 |
209 | Take a screenshot of a FID score and include at least 8 sampled images.
210 |
211 |
212 |
213 |
214 | ## What to Submit
215 |
216 |
217 | Submission Item List
218 |
219 |
220 | - [ ] Code without model checkpoints
221 |
222 | **Task 1**
223 | - [ ] Loss curve screenshot
224 | - [ ] Chamfer distance result of DDPM sampling
225 | - [ ] Visualization of DDPM sampling
226 |
227 | **Task 2**
228 | - [ ] FID score result
229 | - [ ] At least 8 images generated by your DDPM model
230 |
231 |
232 |
233 | In a single document, write your name and student ID, and include submission items listed above. Refer to more detailed instructions written in each task section about what to submit.
234 | Name the document `{NAME}_{STUDENT_ID}.pdf` and submit **both your code and the document** as a **ZIP** file named `{NAME}_{STUDENT_ID}.zip`.
235 | **When creating your zip file**, exclude data (e.g., files in AFHQ dataset) and any model checkpoints, including the provided pre-trained classifier checkpoint when compressing the files.
236 | Submit the zip file on GradeScope.
237 |
238 | ## Grading
239 | **You will receive a zero score if:**
240 | - **you do not submit,**
241 | - **your code is not executable in the Python environment we provided, or**
242 | - **you modify any code outside of the sections marked with `TODO` or use different hyperparameters that are supposed to be fixed as given.**
243 |
244 | **Plagiarism in any form will also result in a zero score and will be reported to the university.**
245 |
246 | **Your score will incur a 10% deduction for each missing item in the submission item list.**
247 |
248 | Otherwise, you will receive up to 20 points from this assignment that count toward your final grade.
249 |
250 | - Task 1
251 | - 10 points: Achieve CD lower than **20** from DDPM sampling.
252 | - 5 points: Achieve CD greater than or equal to **20** and less than **40** from DDPM sampling.
253 | - 0 point: otherwise.
254 | - Task 2
255 | - 10 points: Achieve FID less than **20**.
256 | - 5 points: Achieve FID between **20** and **40**.
257 | - 0 point: otherwise.
258 |
259 | ## Further Readings
260 |
261 | If you are interested in this topic, we encourage you to check out the materials below.
262 |
263 | - [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
264 | - [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
265 | - [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233)
266 | - [Score-Based Generative Modeling through Stochastic Differential Equations](https://arxiv.org/abs/2011.13456)
267 | - [What are Diffusion Models?](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)
268 | - [Generative Modeling by Estimating Gradients of the Data Distribution](https://yang-song.net/blog/2021/score/)
269 |
--------------------------------------------------------------------------------
/assets/images/cfg_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/cfg_test.png
--------------------------------------------------------------------------------
/assets/images/cfg_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/cfg_train.png
--------------------------------------------------------------------------------
/assets/images/fid_command.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/fid_command.png
--------------------------------------------------------------------------------
/assets/images/qs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/qs.png
--------------------------------------------------------------------------------
/assets/images/sampling_command.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/sampling_command.png
--------------------------------------------------------------------------------
/assets/images/task1_ddim_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_ddim_sample.png
--------------------------------------------------------------------------------
/assets/images/task1_ddpm_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_ddpm_sample.png
--------------------------------------------------------------------------------
/assets/images/task1_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_distribution.png
--------------------------------------------------------------------------------
/assets/images/task1_loss_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_loss_curve.png
--------------------------------------------------------------------------------
/assets/images/task1_output_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task1_output_example.png
--------------------------------------------------------------------------------
/assets/images/task2_1_ddpm_sampling_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_1_ddpm_sampling_algorithm.png
--------------------------------------------------------------------------------
/assets/images/task2_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_algorithm.png
--------------------------------------------------------------------------------
/assets/images/task2_ddim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_ddim.png
--------------------------------------------------------------------------------
/assets/images/task2_output_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_output_example.png
--------------------------------------------------------------------------------
/assets/images/task2_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/task2_teaser.png
--------------------------------------------------------------------------------
/assets/images/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/teaser.gif
--------------------------------------------------------------------------------
/assets/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/images/teaser.png
--------------------------------------------------------------------------------
/assets/summary_of_DDPM_and_DDIM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/assets/summary_of_DDPM_and_DDIM.pdf
--------------------------------------------------------------------------------
/image_diffusion_todo/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 | from multiprocessing.pool import Pool
4 | from pathlib import Path
5 |
6 | import torch
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 |
10 |
11 | def listdir(dname):
12 | fnames = list(
13 | chain(
14 | *[
15 | list(Path(dname).rglob("*." + ext))
16 | for ext in ["png", "jpg", "jpeg", "JPG"]
17 | ]
18 | )
19 | )
20 | return fnames
21 |
22 |
23 | def tensor_to_pil_image(x: torch.Tensor, single_image=False):
24 | """
25 | x: [B,C,H,W]
26 | """
27 | if x.ndim == 3:
28 | x = x.unsqueeze(0)
29 | single_image = True
30 |
31 | x = (x * 0.5 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
32 | images = (x * 255).round().astype("uint8")
33 | images = [Image.fromarray(image) for image in images]
34 | if single_image:
35 | return images[0]
36 | return images
37 |
38 |
39 | def get_data_iterator(iterable):
40 | """Allows training with DataLoaders in a single infinite loop:
41 | for i, data in enumerate(inf_generator(train_loader)):
42 | """
43 | iterator = iterable.__iter__()
44 | while True:
45 | try:
46 | yield iterator.__next__()
47 | except StopIteration:
48 | iterator = iterable.__iter__()
49 |
50 |
51 | class AFHQDataset(torch.utils.data.Dataset):
52 | def __init__(
53 | self, root: str, split: str, transform=None, max_num_images_per_cat=-1, label_offset=1
54 | ):
55 | super().__init__()
56 | self.root = root
57 | self.split = split
58 | self.transform = transform
59 | self.max_num_images_per_cat = max_num_images_per_cat
60 | self.label_offset = label_offset
61 |
62 | categories = os.listdir(os.path.join(root, split))
63 | self.num_classes = len(categories)
64 |
65 | fnames, labels = [], []
66 | for idx, cat in enumerate(sorted(categories)):
67 | category_dir = os.path.join(root, split, cat)
68 | cat_fnames = listdir(category_dir)
69 | cat_fnames = sorted(cat_fnames)
70 | if self.max_num_images_per_cat > 0:
71 | cat_fnames = cat_fnames[: self.max_num_images_per_cat]
72 | fnames += cat_fnames
73 | labels += [idx + label_offset] * len(cat_fnames) # label 0 is for null class.
74 |
75 | self.fnames = fnames
76 | self.labels = labels
77 |
78 | def __getitem__(self, idx):
79 | img = Image.open(self.fnames[idx]).convert("RGB")
80 | label = self.labels[idx]
81 | assert label >= self.label_offset
82 | if self.transform is not None:
83 | img = self.transform(img)
84 |
85 | return img, label
86 |
87 | def __len__(self):
88 | return len(self.labels)
89 |
90 |
91 | class AFHQDataModule(object):
92 | def __init__(
93 | self,
94 | root: str = "data",
95 | batch_size: int = 32,
96 | num_workers: int = 4,
97 | max_num_images_per_cat: int = 1000,
98 | image_resolution: int = 64,
99 | label_offset=1,
100 | transform=None
101 | ):
102 | self.root = root
103 | self.batch_size = batch_size
104 | self.num_workers = num_workers
105 | self.afhq_root = os.path.join(root, "afhq")
106 | self.max_num_images_per_cat = max_num_images_per_cat
107 | self.image_resolution = image_resolution
108 | self.label_offset = label_offset
109 | self.transform = transform
110 |
111 | if not os.path.exists(self.afhq_root):
112 | print(f"{self.afhq_root} is empty. Downloading AFHQ dataset...")
113 | self._download_dataset()
114 |
115 | self._set_dataset()
116 |
117 | def _set_dataset(self):
118 | if self.transform is None:
119 | self.transform = transforms.Compose(
120 | [
121 | transforms.Resize((self.image_resolution, self.image_resolution)),
122 | transforms.ToTensor(),
123 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
124 | ]
125 | )
126 | self.train_ds = AFHQDataset(
127 | self.afhq_root,
128 | "train",
129 | self.transform,
130 | max_num_images_per_cat=self.max_num_images_per_cat,
131 | label_offset=self.label_offset
132 | )
133 | self.val_ds = AFHQDataset(
134 | self.afhq_root,
135 | "val",
136 | self.transform,
137 | max_num_images_per_cat=self.max_num_images_per_cat,
138 | label_offset=self.label_offset,
139 | )
140 |
141 | self.num_classes = self.train_ds.num_classes
142 |
143 | def _download_dataset(self):
144 | URL = "https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0"
145 | ZIP_FILE = f"./{self.root}/afhq.zip"
146 | os.system(f"mkdir -p {self.root}")
147 | os.system(f"wget -N {URL} -O {ZIP_FILE}")
148 | os.system(f"unzip {ZIP_FILE} -d {self.root}")
149 | os.system(f"rm {ZIP_FILE}")
150 |
151 | def train_dataloader(self):
152 | return torch.utils.data.DataLoader(
153 | self.train_ds,
154 | batch_size=self.batch_size,
155 | num_workers=self.num_workers,
156 | shuffle=True,
157 | drop_last=True,
158 | )
159 |
160 | def val_dataloader(self):
161 | return torch.utils.data.DataLoader(
162 | self.val_ds,
163 | batch_size=self.batch_size,
164 | num_workers=self.num_workers,
165 | shuffle=False,
166 | drop_last=False,
167 | )
168 |
169 |
170 | if __name__ == "__main__":
171 | data_module = AFHQDataModule("data", 32, 4, -1, 64, 1)
172 |
173 | eval_dir = Path(data_module.afhq_root) / "eval"
174 | eval_dir.mkdir(exist_ok=True)
175 | def func(path):
176 | fn = path.name
177 | cmd = f"cp {path} {eval_dir / fn}"
178 | os.system(cmd)
179 | img = Image.open(str(eval_dir / fn))
180 | img = img.resize((64,64))
181 | img.save(str(eval_dir / fn))
182 | print(fn)
183 |
184 | with Pool(8) as pool:
185 | pool.map(func, data_module.val_ds.fnames)
186 |
187 | print(f"Constructed eval dir at {eval_dir}")
188 |
--------------------------------------------------------------------------------
/image_diffusion_todo/fid/afhq_inception_v3.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/Diffusion-Assignment1-DDPM/5a63ae5c53b576e271290a22021a14b33f23f36d/image_diffusion_todo/fid/afhq_inception_v3.ckpt
--------------------------------------------------------------------------------
/image_diffusion_todo/fid/inception.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 | import numpy as np
11 | import torch.nn as nn
12 | from torchvision import models
13 |
14 |
15 | class InceptionV3(nn.Module):
16 | def __init__(self, for_train):
17 | super().__init__()
18 | self.for_train = for_train
19 |
20 | inception = models.inception_v3(pretrained=False)
21 | self.block1 = nn.Sequential(
22 | inception.Conv2d_1a_3x3,
23 | inception.Conv2d_2a_3x3,
24 | inception.Conv2d_2b_3x3,
25 | nn.MaxPool2d(kernel_size=3, stride=2),
26 | )
27 | self.block2 = nn.Sequential(
28 | inception.Conv2d_3b_1x1,
29 | inception.Conv2d_4a_3x3,
30 | nn.MaxPool2d(kernel_size=3, stride=2),
31 | )
32 | self.block3 = nn.Sequential(
33 | inception.Mixed_5b,
34 | inception.Mixed_5c,
35 | inception.Mixed_5d,
36 | inception.Mixed_6a,
37 | inception.Mixed_6b,
38 | inception.Mixed_6c,
39 | inception.Mixed_6d,
40 | inception.Mixed_6e,
41 | )
42 | self.block4 = nn.Sequential(
43 | inception.Mixed_7a,
44 | inception.Mixed_7b,
45 | inception.Mixed_7c,
46 | nn.AdaptiveAvgPool2d(output_size=(1, 1)),
47 | )
48 |
49 | self.final_fc = nn.Linear(2048, 3)
50 |
51 | def forward(self, x):
52 | x = self.block1(x)
53 | x = self.block2(x)
54 | x = self.block3(x)
55 | x = self.block4(x)
56 | x = x.view(x.size(0), -1)
57 | if self.for_train:
58 | return self.final_fc(x)
59 | else:
60 | return x
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/image_diffusion_todo/fid/measure_fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import sys
6 | from PIL import Image
7 | from scipy import linalg
8 | from torchvision import transforms
9 | from itertools import chain
10 | from pathlib import Path
11 | from inception import InceptionV3
12 |
13 | try:
14 | from tqdm import tqdm
15 | except ImportError:
16 | def tqdm(x):
17 | return x
18 |
19 | class ImagePathDataset(torch.utils.data.Dataset):
20 | def __init__(self, files, img_size):
21 | self.files = files
22 | self.img_size = img_size
23 | self.transforms = transforms.Compose(
24 | [
25 | transforms.Resize((img_size, img_size)),
26 | transforms.ToTensor(),
27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
28 | ]
29 | )
30 |
31 | def __len__(self):
32 | return len(self.files)
33 |
34 | def __getitem__(self, i):
35 | path = self.files[i]
36 | img = Image.open(path).convert("RGB")
37 | if self.transforms is not None:
38 | img = self.transforms(img)
39 | return img
40 |
41 |
42 | def get_eval_loader(path, img_size, batch_size):
43 | def listdir(dname):
44 | fnames = list(
45 | chain(
46 | *[
47 | list(Path(dname).rglob("*." + ext))
48 | for ext in ["png", "jpg", "jpeg", "JPG"]
49 | ]
50 | )
51 | )
52 | return fnames
53 |
54 | files = listdir(path)
55 | ds = ImagePathDataset(files, img_size)
56 | dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
57 | return dl
58 |
59 | def frechet_distance(mu, cov, mu2, cov2):
60 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
61 | dist = np.sum((mu - mu2) ** 2) + np.trace(cov + cov2 - 2 * cc)
62 | return np.real(dist)
63 |
64 |
65 |
66 | @torch.no_grad()
67 | def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
68 | print("Calculating FID given paths %s and %s..." % (paths[0], paths[1]))
69 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70 | inception = InceptionV3(for_train=False)
71 | current_dir = Path(os.path.realpath(__file__)).parent
72 | ckpt = torch.load(current_dir / "afhq_inception_v3.ckpt", map_location="cpu")
73 | inception.load_state_dict(ckpt)
74 | inception = inception.eval().to(device)
75 | loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]
76 |
77 | mu, cov = [], []
78 | for loader in loaders:
79 | actvs = []
80 | for x in tqdm(loader, total=len(loader)):
81 | actv = inception(x.to(device))
82 | actvs.append(actv)
83 | actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
84 | mu.append(np.mean(actvs, axis=0))
85 | cov.append(np.cov(actvs, rowvar=False))
86 | fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
87 | return fid_value
88 |
89 | if __name__ == "__main__":
 90 | # python measure_fid.py /path/to/dir1 /path/to/dir2
91 |
92 | paths = [sys.argv[1], sys.argv[2]]
93 | fid_value = calculate_fid_given_paths(paths, img_size=256, batch_size=64)
94 | print("FID:", fid_value)
95 |
96 |
--------------------------------------------------------------------------------
/image_diffusion_todo/model.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from tqdm import tqdm
8 |
9 |
10 | class DiffusionModule(nn.Module):
11 | def __init__(self, network, var_scheduler, **kwargs):
12 | super().__init__()
13 | self.network = network
14 | self.var_scheduler = var_scheduler
15 |
16 | def get_loss(self, x0, class_label=None, noise=None):
17 | ######## TODO ########
18 | # DO NOT change the code outside this part.
19 | # compute noise matching loss.
20 | B = x0.shape[0]
21 | timestep = self.var_scheduler.uniform_sample_t(B, self.device)
22 | loss = x0.mean()
23 | ######################
24 | return loss
25 |
26 | @property
27 | def device(self):
28 | return next(self.network.parameters()).device
29 |
30 | @property
31 | def image_resolution(self):
32 | return self.network.image_resolution
33 |
34 | @torch.no_grad()
35 | def sample(
36 | self,
37 | batch_size,
38 | return_traj=False,
39 | class_label: Optional[torch.Tensor] = None,
40 | guidance_scale: Optional[float] = 1.0,
41 | ):
42 | x_T = torch.randn([batch_size, 3, self.image_resolution, self.image_resolution]).to(self.device)
43 |
44 | do_classifier_free_guidance = guidance_scale > 1.0
45 |
46 | if do_classifier_free_guidance:
47 |
48 | ######## TODO ########
49 | # Assignment 2. Implement the classifier-free guidance.
50 | # Specifically, given a tensor of shape (batch_size,) containing class labels,
51 | # create a tensor of shape (2*batch_size,) where the first half is filled with zeros (i.e., null condition).
52 | assert class_label is not None
53 | assert len(class_label) == batch_size, f"len(class_label) != batch_size. {len(class_label)} != {batch_size}"
54 | raise NotImplementedError("TODO")
55 | #######################
56 |
57 | traj = [x_T]
58 | for t in tqdm(self.var_scheduler.timesteps):
59 | x_t = traj[-1]
60 | if do_classifier_free_guidance:
61 | ######## TODO ########
62 | # Assignment 2. Implement the classifier-free guidance.
63 | raise NotImplementedError("TODO")
64 | #######################
65 | else:
66 | noise_pred = self.network(x_t, timestep=t.to(self.device))
67 |
68 | x_t_prev = self.var_scheduler.step(x_t, t, noise_pred)
69 |
70 | traj[-1] = traj[-1].cpu()
71 | traj.append(x_t_prev.detach())
72 |
73 | if return_traj:
74 | return traj
75 | else:
76 | return traj[-1]
77 |
78 | def save(self, file_path):
79 | hparams = {
80 | "network": self.network,
81 | "var_scheduler": self.var_scheduler,
82 | }
83 | state_dict = self.state_dict()
84 |
85 | dic = {"hparams": hparams, "state_dict": state_dict}
86 | torch.save(dic, file_path)
87 |
88 | def load(self, file_path):
89 | dic = torch.load(file_path, map_location="cpu")
90 | hparams = dic["hparams"]
91 | state_dict = dic["state_dict"]
92 |
93 | self.network = hparams["network"]
94 | self.var_scheduler = hparams["var_scheduler"]
95 |
96 | self.load_state_dict(state_dict)
97 |
--------------------------------------------------------------------------------
/image_diffusion_todo/module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 |
8 |
9 | class Swish(nn.Module):
10 | def forward(self, x):
11 | return x * torch.sigmoid(x)
12 |
13 |
14 | class DownSample(nn.Module):
15 | def __init__(self, in_ch):
16 | super().__init__()
17 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
18 | self.initialize()
19 |
20 | def initialize(self):
21 | init.xavier_uniform_(self.main.weight)
22 | init.zeros_(self.main.bias)
23 |
24 | def forward(self, x, temb):
25 | x = self.main(x)
26 | return x
27 |
28 |
29 | class UpSample(nn.Module):
30 | def __init__(self, in_ch):
31 | super().__init__()
32 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
33 | self.initialize()
34 |
35 | def initialize(self):
36 | init.xavier_uniform_(self.main.weight)
37 | init.zeros_(self.main.bias)
38 |
39 | def forward(self, x, temb):
40 | _, _, H, W = x.shape
41 | x = F.interpolate(x, scale_factor=2, mode="nearest")
42 | x = self.main(x)
43 | return x
44 |
45 |
46 | class AttnBlock(nn.Module):
47 | def __init__(self, in_ch):
48 | super().__init__()
49 | self.group_norm = nn.GroupNorm(32, in_ch)
50 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
51 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
52 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
53 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
54 | self.initialize()
55 |
56 | def initialize(self):
57 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
58 | init.xavier_uniform_(module.weight)
59 | init.zeros_(module.bias)
60 | init.xavier_uniform_(self.proj.weight, gain=1e-5)
61 |
62 | def forward(self, x):
63 | B, C, H, W = x.shape
64 | h = self.group_norm(x)
65 | q = self.proj_q(h)
66 | k = self.proj_k(h)
67 | v = self.proj_v(h)
68 |
69 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
70 | k = k.view(B, C, H * W)
71 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
72 | assert list(w.shape) == [B, H * W, H * W]
73 | w = F.softmax(w, dim=-1)
74 |
75 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
76 | h = torch.bmm(w, v)
77 | assert list(h.shape) == [B, H * W, C]
78 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
79 | h = self.proj(h)
80 |
81 | return x + h
82 |
83 |
84 | class ResBlock(nn.Module):
85 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
86 | super().__init__()
87 | self.block1 = nn.Sequential(
88 | nn.GroupNorm(32, in_ch),
89 | Swish(),
90 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
91 | )
92 | self.temb_proj = nn.Sequential(
93 | Swish(),
94 | nn.Linear(tdim, out_ch),
95 | )
96 | self.block2 = nn.Sequential(
97 | nn.GroupNorm(32, out_ch),
98 | Swish(),
99 | nn.Dropout(dropout),
100 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
101 | )
102 | if in_ch != out_ch:
103 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
104 | else:
105 | self.shortcut = nn.Identity()
106 | if attn:
107 | self.attn = AttnBlock(out_ch)
108 | else:
109 | self.attn = nn.Identity()
110 | self.initialize()
111 |
112 | def initialize(self):
113 | for module in self.modules():
114 | if isinstance(module, (nn.Conv2d, nn.Linear)):
115 | init.xavier_uniform_(module.weight)
116 | init.zeros_(module.bias)
117 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
118 |
119 | def forward(self, x, temb):
120 | h = self.block1(x)
121 | h += self.temb_proj(temb)[:, :, None, None]
122 | h = self.block2(h)
123 |
124 | h = h + self.shortcut(x)
125 | h = self.attn(h)
126 | return h
127 |
128 |
129 | class TimeEmbedding(nn.Module):
130 | def __init__(self, hidden_size, frequency_embedding_size=256):
131 | super().__init__()
132 | self.mlp = nn.Sequential(
133 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
134 | nn.SiLU(),
135 | nn.Linear(hidden_size, hidden_size, bias=True),
136 | )
137 | self.frequency_embedding_size = frequency_embedding_size
138 |
139 | @staticmethod
140 | def timestep_embedding(t, dim, max_period=10000):
141 | """
142 | Create sinusoidal timestep embeddings.
143 | :param t: a 1-D Tensor of N indices, one per batch element.
144 | These may be fractional.
145 | :param dim: the dimension of the output.
146 | :param max_period: controls the minimum frequency of the embeddings.
147 | :return: an (N, D) Tensor of positional embeddings.
148 | """
149 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
150 | half = dim // 2
151 | freqs = torch.exp(
152 | -math.log(max_period)
153 | * torch.arange(start=0, end=half, dtype=torch.float32)
154 | / half
155 | ).to(device=t.device)
156 | args = t[:, None].float() * freqs[None]
157 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
158 | if dim % 2:
159 | embedding = torch.cat(
160 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
161 | )
162 | return embedding
163 |
164 | def forward(self, t):
165 | if t.ndim == 0:
166 | t = t.unsqueeze(-1)
167 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
168 | t_emb = self.mlp(t_freq)
169 | return t_emb
170 |
--------------------------------------------------------------------------------
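The sinusoidal embedding defined in TimeEmbedding.timestep_embedding above maps each scalar timestep to interleaved cosine/sine features at geometrically spaced frequencies, which the two-layer MLP then projects to the hidden size. A minimal standalone sketch of the same construction, useful as a quick shape check (it assumes only torch and math and is not part of the assignment code):

import math
import torch

def sinusoidal_embedding(t: torch.Tensor, dim: int = 256, max_period: int = 10000) -> torch.Tensor:
    # Geometrically spaced frequencies followed by concatenated cos/sin features,
    # mirroring TimeEmbedding.timestep_embedding for even `dim`.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

t = torch.randint(0, 1000, (8,))   # one timestep per batch element
emb = sinusoidal_embedding(t)      # -> torch.Size([8, 256])
print(emb.shape)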
/image_diffusion_todo/network.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from module import DownSample, ResBlock, Swish, TimeEmbedding, UpSample
8 | from torch.nn import init
9 |
10 |
11 | class UNet(nn.Module):
12 | def __init__(self, T=1000, image_resolution=64, ch=128, ch_mult=[1,2,2,2], attn=[1], num_res_blocks=4, dropout=0.1, use_cfg=False, cfg_dropout=0.1, num_classes=None):
13 | super().__init__()
14 | self.image_resolution = image_resolution
15 |         assert all([i < len(ch_mult) for i in attn]), 'attn index out of bounds'
16 | tdim = ch * 4
17 | # self.time_embedding = TimeEmbedding(T, ch, tdim)
18 | self.time_embedding = TimeEmbedding(tdim)
19 |
20 | # classifier-free guidance
21 | self.use_cfg = use_cfg
22 | self.cfg_dropout = cfg_dropout
23 | if use_cfg:
24 | assert num_classes is not None
25 | cdim = tdim
26 | self.class_embedding = nn.Embedding(num_classes+1, cdim)
27 |
28 | self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
29 | self.downblocks = nn.ModuleList()
30 |         chs = [ch]  # record output channels during downsampling for the upsampling skip connections
31 | now_ch = ch
32 | for i, mult in enumerate(ch_mult):
33 | out_ch = ch * mult
34 | for _ in range(num_res_blocks):
35 | self.downblocks.append(ResBlock(
36 | in_ch=now_ch, out_ch=out_ch, tdim=tdim,
37 | dropout=dropout, attn=(i in attn)))
38 | now_ch = out_ch
39 | chs.append(now_ch)
40 | if i != len(ch_mult) - 1:
41 | self.downblocks.append(DownSample(now_ch))
42 | chs.append(now_ch)
43 |
44 | self.middleblocks = nn.ModuleList([
45 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
46 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
47 | ])
48 |
49 | self.upblocks = nn.ModuleList()
50 | for i, mult in reversed(list(enumerate(ch_mult))):
51 | out_ch = ch * mult
52 | for _ in range(num_res_blocks + 1):
53 | self.upblocks.append(ResBlock(
54 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
55 | dropout=dropout, attn=(i in attn)))
56 | now_ch = out_ch
57 | if i != 0:
58 | self.upblocks.append(UpSample(now_ch))
59 | assert len(chs) == 0
60 |
61 | self.tail = nn.Sequential(
62 | nn.GroupNorm(32, now_ch),
63 | Swish(),
64 | nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
65 | )
66 | self.initialize()
67 |
68 | def initialize(self):
69 | init.xavier_uniform_(self.head.weight)
70 | init.zeros_(self.head.bias)
71 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
72 | init.zeros_(self.tail[-1].bias)
73 |
74 | def forward(self, x, timestep, class_label=None):
75 | # Timestep embedding
76 | temb = self.time_embedding(timestep)
77 |
78 | if self.use_cfg and class_label is not None:
79 | if self.training:
80 |                 assert not torch.any(class_label == 0)  # label 0 is reserved for the null (unconditional) class.
81 |
82 | ######## TODO ########
83 | # DO NOT change the code outside this part.
84 | # Assignment 2. Implement random null conditioning in CFG training.
85 | raise NotImplementedError("TODO")
86 | #######################
87 |
88 | ######## TODO ########
89 | # DO NOT change the code outside this part.
90 | # Assignment 2. Implement class conditioning
91 | raise NotImplementedError("TODO")
92 | #######################
93 |
94 | # Downsampling
95 | h = self.head(x)
96 | hs = [h]
97 | for layer in self.downblocks:
98 | h = layer(h, temb)
99 | hs.append(h)
100 | # Middle
101 | for layer in self.middleblocks:
102 | h = layer(h, temb)
103 | # Upsampling
104 | for layer in self.upblocks:
105 | if isinstance(layer, ResBlock):
106 | h = torch.cat([h, hs.pop()], dim=1)
107 | h = layer(h, temb)
108 | h = self.tail(h)
109 |
110 | assert len(hs) == 0
111 | return h
112 |
--------------------------------------------------------------------------------
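Since the classifier-free-guidance branch is entered only when use_cfg is enabled and a class label is passed, the unconditional path of this UNet runs end to end as-is. A small smoke test, run from inside image_diffusion_todo/ so that network.py and module.py are importable (the reduced widths are only to keep the check fast and are not the training configuration):

import torch
from network import UNet

net = UNet(T=1000, image_resolution=64, ch=32, ch_mult=[1, 2], attn=[1],
           num_res_blocks=1, dropout=0.0, use_cfg=False)
x = torch.randn(2, 3, 64, 64)       # dummy batch of 64x64 RGB images
t = torch.randint(0, 1000, (2,))    # one diffusion timestep per sample
with torch.no_grad():
    eps_pred = net(x, t)            # predicted noise, same shape as the input
print(eps_pred.shape)               # torch.Size([2, 3, 64, 64])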
/image_diffusion_todo/sampling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import torch
5 | from dataset import tensor_to_pil_image
6 | from model import DiffusionModule
7 | from scheduler import DDPMScheduler
8 | from pathlib import Path
9 |
10 |
11 | def main(args):
12 | save_dir = Path(args.save_dir)
13 | save_dir.mkdir(exist_ok=True, parents=True)
14 |
15 | device = f"cuda:{args.gpu}"
16 |
17 | ddpm = DiffusionModule(None, None)
18 | ddpm.load(args.ckpt_path)
19 | ddpm.eval()
20 | ddpm = ddpm.to(device)
21 |
22 | num_train_timesteps = ddpm.var_scheduler.num_train_timesteps
23 | ddpm.var_scheduler = DDPMScheduler(
24 | num_train_timesteps,
25 | beta_1=1e-4,
26 | beta_T=0.02,
27 | mode="linear",
28 | ).to(device)
29 |
30 | total_num_samples = 500
31 | num_batches = int(np.ceil(total_num_samples / args.batch_size))
32 |
33 | for i in range(num_batches):
34 | sidx = i * args.batch_size
35 | eidx = min(sidx + args.batch_size, total_num_samples)
36 | B = eidx - sidx
37 |
38 | if args.use_cfg: # Enable CFG sampling
39 |             assert ddpm.network.use_cfg, "The model was not trained to support CFG."
40 | samples = ddpm.sample(
41 | B,
42 |                 class_label=torch.randint(1, 4, (B,)),  # AFHQ labels 1..3; 0 is the null class
43 | guidance_scale=args.cfg_scale,
44 | )
45 | else:
46 | samples = ddpm.sample(B)
47 |
48 | pil_images = tensor_to_pil_image(samples)
49 |
50 | for j, img in zip(range(sidx, eidx), pil_images):
51 | img.save(save_dir / f"{j}.png")
52 | print(f"Saved the {j}-th image.")
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--batch_size", type=int, default=64)
58 | parser.add_argument("--gpu", type=int, default=0)
59 | parser.add_argument("--ckpt_path", type=str)
60 | parser.add_argument("--save_dir", type=str)
61 | parser.add_argument("--use_cfg", action="store_true")
62 | parser.add_argument("--sample_method", type=str, default="ddpm")
63 | parser.add_argument("--cfg_scale", type=float, default=7.5)
64 |
65 | args = parser.parse_args()
66 | main(args)
67 |
--------------------------------------------------------------------------------
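Given the argument parser above, sampling is invoked along these lines; the checkpoint paths are placeholders to be replaced with whatever train.py wrote under results/:

# Unconditional DDPM sampling (writes 500 images into samples/)
python sampling.py --gpu 0 --batch_size 64 \
    --ckpt_path results/diffusion-ddpm-<timestamp>/last.ckpt --save_dir samples/

# Classifier-free-guidance sampling (the checkpoint must come from a --use_cfg run)
python sampling.py --gpu 0 --use_cfg --cfg_scale 7.5 \
    --ckpt_path results/cfg_diffusion-ddpm-<timestamp>/last.ckpt --save_dir samples_cfg/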
/image_diffusion_todo/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class BaseScheduler(nn.Module):
9 | def __init__(
10 | self, num_train_timesteps: int, beta_1: float, beta_T: float, mode="linear"
11 | ):
12 | super().__init__()
13 | self.num_train_timesteps = num_train_timesteps
14 | self.num_inference_timesteps = num_train_timesteps
15 | self.timesteps = torch.from_numpy(
16 | np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64)
17 | )
18 |
19 | if mode == "linear":
20 | betas = torch.linspace(beta_1, beta_T, steps=num_train_timesteps)
21 | elif mode == "quad":
22 | betas = (
23 | torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps) ** 2
24 | )
25 | else:
26 | raise NotImplementedError(f"{mode} is not implemented.")
27 |
28 | alphas = 1 - betas
29 | alphas_cumprod = torch.cumprod(alphas, dim=0)
30 |
31 | self.register_buffer("betas", betas)
32 | self.register_buffer("alphas", alphas)
33 | self.register_buffer("alphas_cumprod", alphas_cumprod)
34 |
35 | def uniform_sample_t(
36 | self, batch_size, device: Optional[torch.device] = None
37 | ) -> torch.IntTensor:
38 | """
39 | Uniformly sample timesteps.
40 | """
41 | ts = np.random.choice(np.arange(self.num_train_timesteps), batch_size)
42 | ts = torch.from_numpy(ts)
43 | if device is not None:
44 | ts = ts.to(device)
45 | return ts
46 |
47 | class DDPMScheduler(BaseScheduler):
48 | def __init__(
49 | self,
50 | num_train_timesteps: int,
51 | beta_1: float,
52 | beta_T: float,
53 | mode="linear",
54 | sigma_type="small",
55 | ):
56 | super().__init__(num_train_timesteps, beta_1, beta_T, mode)
57 |
58 | # sigmas correspond to $\sigma_t$ in the DDPM paper.
59 | self.sigma_type = sigma_type
60 | if sigma_type == "small":
61 | # when $\sigma_t^2 = \tilde{\beta}_t$.
62 | alphas_cumprod_t_prev = torch.cat(
63 | [torch.tensor([1.0]), self.alphas_cumprod[:-1]]
64 | )
65 | sigmas = (
66 | (1 - alphas_cumprod_t_prev) / (1 - self.alphas_cumprod) * self.betas
67 | ) ** 0.5
68 | elif sigma_type == "large":
69 | # when $\sigma_t^2 = \beta_t$.
70 | sigmas = self.betas ** 0.5
71 |
72 | self.register_buffer("sigmas", sigmas)
73 |
74 | def step(self, x_t: torch.Tensor, t: int, eps_theta: torch.Tensor):
75 | """
76 | One step denoising function of DDPM: x_t -> x_{t-1}.
77 |
78 | Input:
79 | x_t (`torch.Tensor [B,C,H,W]`): samples at arbitrary timestep t.
80 | t (`int`): current timestep in a reverse process.
81 | eps_theta (`torch.Tensor [B,C,H,W]`): predicted noise from a learned model.
82 |         Output:
83 | sample_prev (`torch.Tensor [B,C,H,W]`): one step denoised sample. (= x_{t-1})
84 | """
85 |
86 | ######## TODO ########
87 | # DO NOT change the code outside this part.
88 | # Assignment 1. Implement the DDPM reverse step.
89 | sample_prev = None
90 | #######################
91 |
92 | return sample_prev
93 |
94 | # https://nn.labml.ai/diffusion/ddpm/utils.html
95 |     def _get_teeth(self, consts: torch.Tensor, t: torch.Tensor):  # gather the t-th constant per sample
96 | const = consts.gather(-1, t)
97 | return const.reshape(-1, 1, 1, 1)
98 |
99 | def add_noise(
100 | self,
101 | x_0: torch.Tensor,
102 | t: torch.IntTensor,
103 | eps: Optional[torch.Tensor] = None,
104 | ):
105 | """
106 | A forward pass of a Markov chain, i.e., q(x_t | x_0).
107 |
108 | Input:
109 | x_0 (`torch.Tensor [B,C,H,W]`): samples from a real data distribution q(x_0).
110 | t: (`torch.IntTensor [B]`)
111 | eps: (`torch.Tensor [B,C,H,W]`, optional): if None, randomly sample Gaussian noise in the function.
112 | Output:
113 | x_t: (`torch.Tensor [B,C,H,W]`): noisy samples at timestep t.
114 | eps: (`torch.Tensor [B,C,H,W]`): injected noise.
115 | """
116 |
117 | if eps is None:
118 |             eps = torch.randn(x_0.shape, device=x_0.device)
119 |
120 | ######## TODO ########
121 | # DO NOT change the code outside this part.
122 | # Assignment 1. Implement the DDPM forward step.
123 | x_t = None
124 | #######################
125 |
126 | return x_t, eps
127 |
--------------------------------------------------------------------------------
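For reference, the two TODO stubs above (add_noise and step) target the standard DDPM relations from Ho et al. (2020), restated here only as a notation reminder with \bar{\alpha}_t = alphas_cumprod[t] and \alpha_t = alphas[t]:

    forward (add_noise):  x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon,  \epsilon \sim \mathcal{N}(0, I)

    reverse (step):  x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z,  z \sim \mathcal{N}(0, I)  (with z = 0 at t = 0)

Here \sigma_t is read from the sigmas buffer, and _get_teeth gathers the per-sample coefficients into shape [B, 1, 1, 1] so they broadcast against [B, C, H, W] tensors.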
/image_diffusion_todo/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from datetime import datetime
4 | from pathlib import Path
5 |
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | import torch
9 | from dataset import AFHQDataModule, get_data_iterator, tensor_to_pil_image
10 | from dotmap import DotMap
11 | from model import DiffusionModule
12 | from network import UNet
13 | from pytorch_lightning import seed_everything
14 | from scheduler import DDPMScheduler
15 | from torchvision.transforms.functional import to_pil_image
16 | from tqdm import tqdm
17 |
18 | matplotlib.use("Agg")
19 |
20 |
21 | def get_current_time():
22 | now = datetime.now().strftime("%m-%d-%H%M%S")
23 | return now
24 |
25 |
26 | def main(args):
27 | """config"""
28 | config = DotMap()
29 | config.update(vars(args))
30 | config.device = f"cuda:{args.gpu}"
31 |
32 | now = get_current_time()
33 | if args.use_cfg:
34 | save_dir = Path(f"results/cfg_diffusion-{args.sample_method}-{now}")
35 | else:
36 | save_dir = Path(f"results/diffusion-{args.sample_method}-{now}")
37 | save_dir.mkdir(exist_ok=True, parents=True)
38 | print(f"save_dir: {save_dir}")
39 |
40 | seed_everything(config.seed)
41 |
42 | with open(save_dir / "config.json", "w") as f:
43 | json.dump(config, f, indent=2)
44 | """######"""
45 |
46 |     image_resolution = config.image_resolution
47 | ds_module = AFHQDataModule(
48 | "./data",
49 | batch_size=config.batch_size,
50 | num_workers=4,
51 | max_num_images_per_cat=config.max_num_images_per_cat,
52 | image_resolution=image_resolution
53 | )
54 |
55 | train_dl = ds_module.train_dataloader()
56 | train_it = get_data_iterator(train_dl)
57 |
58 | # Set up the scheduler
59 | var_scheduler = DDPMScheduler(
60 | config.num_diffusion_train_timesteps,
61 | beta_1=config.beta_1,
62 | beta_T=config.beta_T,
63 | mode="linear",
64 | )
65 |
66 | network = UNet(
67 | T=config.num_diffusion_train_timesteps,
68 | image_resolution=image_resolution,
69 | ch=128,
70 | ch_mult=[1, 2, 2, 2],
71 | attn=[1],
72 | num_res_blocks=4,
73 | dropout=0.1,
74 | use_cfg=args.use_cfg,
75 | cfg_dropout=args.cfg_dropout,
76 | num_classes=getattr(ds_module, "num_classes", None),
77 | )
78 |
79 | ddpm = DiffusionModule(network, var_scheduler)
80 | ddpm = ddpm.to(config.device)
81 |
82 | optimizer = torch.optim.Adam(ddpm.network.parameters(), lr=2e-4)
83 | scheduler = torch.optim.lr_scheduler.LambdaLR(
84 | optimizer, lr_lambda=lambda t: min((t + 1) / config.warmup_steps, 1.0)
85 | )
86 |
87 | step = 0
88 | losses = []
89 | with tqdm(initial=step, total=config.train_num_steps) as pbar:
90 | while step < config.train_num_steps:
91 | if step % config.log_interval == 0:
92 | ddpm.eval()
93 | plt.plot(losses)
94 | plt.savefig(f"{save_dir}/loss.png")
95 | plt.close()
96 | samples = ddpm.sample(4, return_traj=False)
97 | pil_images = tensor_to_pil_image(samples)
98 | for i, img in enumerate(pil_images):
99 | img.save(save_dir / f"step={step}-{i}.png")
100 |
101 | ddpm.save(f"{save_dir}/last.ckpt")
102 | ddpm.train()
103 |
104 | img, label = next(train_it)
105 | img, label = img.to(config.device), label.to(config.device)
106 | if args.use_cfg: # Conditional, CFG training
107 | loss = ddpm.get_loss(img, class_label=label)
108 | else: # Unconditional training
109 | loss = ddpm.get_loss(img)
110 | pbar.set_description(f"Loss: {loss.item():.4f}")
111 |
112 | optimizer.zero_grad()
113 | loss.backward()
114 | optimizer.step()
115 | scheduler.step()
116 | losses.append(loss.item())
117 |
118 | step += 1
119 | pbar.update(1)
120 |
121 |
122 | if __name__ == "__main__":
123 | parser = argparse.ArgumentParser()
124 | parser.add_argument("--gpu", type=int, default=0)
125 | parser.add_argument("--batch_size", type=int, default=16)
126 | parser.add_argument(
127 | "--train_num_steps",
128 | type=int,
129 | default=100000,
130 | help="the number of model training steps.",
131 | )
132 | parser.add_argument("--warmup_steps", type=int, default=200)
133 | parser.add_argument("--log_interval", type=int, default=200)
134 | parser.add_argument(
135 | "--max_num_images_per_cat",
136 | type=int,
137 | default=3000,
138 | help="max number of images per category for AFHQ dataset",
139 | )
140 | parser.add_argument(
141 | "--num_diffusion_train_timesteps",
142 | type=int,
143 | default=1000,
144 |         help="the number of timesteps in the diffusion Markov chain",
145 | )
146 | parser.add_argument("--beta_1", type=float, default=1e-4)
147 | parser.add_argument("--beta_T", type=float, default=0.02)
148 | parser.add_argument("--seed", type=int, default=63)
149 | parser.add_argument("--image_resolution", type=int, default=64)
150 | parser.add_argument("--sample_method", type=str, default="ddpm")
151 | parser.add_argument("--use_cfg", action="store_true")
152 | parser.add_argument("--cfg_dropout", type=float, default=0.1)
153 | args = parser.parse_args()
154 | main(args)
155 |
--------------------------------------------------------------------------------
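With the defaults above, typical training invocations look like the following (assuming the AFHQ images are available under ./data, as AFHQDataModule expects):

# Unconditional DDPM training
python train.py --gpu 0 --batch_size 16 --train_num_steps 100000

# Classifier-free-guidance training
python train.py --gpu 0 --use_cfg --cfg_dropout 0.1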
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==1.1.3
2 | ipython==8.12.0
3 | jupyter==1.0.0
4 | matplotlib==3.6.0
5 | torch-ema==0.3
6 | tqdm==4.64.1
7 | jupyterlab==4.0.2
8 | jaxtyping==0.2.20
9 | pytorch_lightning
10 | dotmap
11 | numpy==1.26.0
--------------------------------------------------------------------------------
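These pinned packages are typically installed into a fresh environment with:

pip install -r requirements.txt

Note that torch itself is not pinned here; a recent PyTorch build matching the local CUDA setup is assumed.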