├── .gitignore
├── README.md
├── assets
│   ├── diffusion.py
│   ├── images
│   │   ├── task1_distribution.png
│   │   ├── task1_ou.png
│   │   ├── task1_sampling.png
│   │   ├── task1_vesde.png
│   │   ├── task1_vpsde.png
│   │   ├── task2_1_ddpm_sampling_algorithm.png
│   │   ├── task2_1_teaser.png
│   │   ├── task2_2_teaser.png
│   │   ├── task2_3_repaint_algorithm8.png
│   │   ├── task2_algorithm.png
│   │   ├── task2_ddim.png
│   │   ├── task3_algorithm.png
│   │   ├── task_2_3_teaser.png
│   │   ├── teaser.gif
│   │   └── teaser.png
│   ├── sb_likelihood_training.pdf
│   └── summary_of_DDPM_and_DDIM.pdf
├── image_diffusion_todo
│   ├── __init__.py
│   ├── dataset.py
│   ├── ddpm.py
│   ├── fid
│   │   ├── afhq_inception_v3.ckpt
│   │   ├── inception.py
│   │   ├── measure_fid.py
│   │   └── train_classifier.py
│   ├── module.py
│   ├── network.py
│   ├── sampling.py
│   ├── scheduler.py
│   └── train.py
├── requirements.txt
└── sde_todo
    ├── HelloScore.ipynb
    ├── dataset.py
    ├── eval.py
    ├── loss.py
    ├── network.py
    ├── sampling.py
    ├── sde.py
    └── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.ipynb_checkpoints 4 | data/ 5 | results/ 6 | samples/ 7 | .DS_Store 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Introduction to Diffusion Models 4 |

5 |
6 |
7 | Juil Koo   Nguyen Minh Hieu 8 |
9 |
10 |

{63days, hieuristics} [at] kaist.ac.kr

11 |
12 |
13 |
14 | 15 |
16 | 17 |
18 | 19 |
20 | Table of Contents 21 | 22 | - [Task 0](#task-0-introduction) 23 | - [Task 1](#task-1-very-simple-sgm-pipeline-with-delicious-swiss-roll) 24 | - [Task 1.1](#11-forward-and-reverse-process) [(a)](#a-ou-process), [(b)](#b-vpsde--vesde) 25 | - [Task 1.2](#12-training) 26 | - [Task 1.3](#13-sampling) 27 | - [Task 1.4](#14-evaluation) 28 | - [Task 2](#task-2-image-diffusion) 29 | - [Task 2.1](#21-ddim) 30 | - [Task 2.2](#22-classifier-free-guidance) 31 | - [Task 2.3](#23-image-inpainting) 32 | 33 |
34 | 35 |
36 | Task Checklist 37 | 38 |
39 | 40 | **Task 1** 41 | - [ ] 1.1 Define Forward SDE 42 | - [ ] 1.1 Define Backward SDE 43 | - [ ] 1.1 Define VPSDE 44 | - [ ] 1.1 Define VESDE 45 | - [ ] 1.2 Implement MLP Network 46 | - [ ] 1.2 Implement DSM Loss 47 | - [ ] 1.2 Implement Training Loop 48 | - [ ] 1.3 Implement Discretization 49 | - [ ] 1.3 Implement Sampling Loop 50 | - [ ] 1.4 Evaluate Implementation 51 | 52 | **Task 2** 53 | - [ ] 2.1 Implement DDIM Variance Scheduling 54 | - [ ] 2.2 Implement CFG 55 | - [ ] 2.3 Implement Image Inpainting 56 | 57 | **Optional Tasks** 58 | - [X] Add any additional tasks that you did here. 59 | - [ ] Implement EMA Training 60 | - [ ] Implement ISM Loss 61 | - [ ] Implement ODE Sampling 62 | - [ ] Implement Schrödinger Bridge 63 | - [ ] Implement MCG Inpainting 64 | 65 |
66 | 67 | 68 | ## Setup 69 | 70 | Install the required packages listed in `requirements.txt`: 71 | ``` 72 | pip install -r requirements.txt 73 | ``` 74 | 75 | ## Code Structure 76 | ``` 77 | . 78 | ├── image_diffusion_todo (Task 2) 79 | │   ├── dataset.py <--- Ready-to-use AFHQ dataset code 80 | │   ├── train.py <--- DDPM training code 81 | │   ├── sampling.py <--- Image sampling code 82 | │   ├── ddpm.py <--- DDPM high-level wrapper code 83 | │   ├── module.py <--- Basic modules of a noise prediction network 84 | │   ├── network.py <--- Noise prediction network 85 | │   ├── scheduler.py <--- (TODO) Define variance schedulers 86 | │ └── fid 87 | │ ├── measure_fid.py <--- script measuring FID score 88 | │ └── afhq_inception_v3.ckpt <--- pre-trained classifier for FID 89 | └── sde_todo (Task 1) 90 | ├── HelloScore.ipynb <--- Main code 91 | ├── dataset.py <--- Define dataset (Swiss-roll, moon, gaussians, etc.) 92 | ├── eval.py <--- Evaluation code 93 | ├── loss.py <--- (TODO) Define Training Objective 94 | ├── network.py <--- (TODO) Define Network Architecture 95 | ├── sampling.py <--- (TODO) Define Discretization and Sampling 96 | ├── sde.py <--- (TODO) Define SDE Processes 97 | └── train.py <--- (TODO) Define Training Loop 98 | ``` 99 | 100 | ## Tutorial Tips 101 | 102 | Implementing diffusion models is typically very simple once you understand the theory. 103 | So, to learn the most from this tutorial, it's highly recommended to check out the details in the 104 | related papers and understand the equations **BEFORE** you start the tutorial. You can go through 105 | the resources in this order: 106 | 1. [[blog](https://min-hieu.github.io/blogs/blogs/brownian/)] Charlie's "Brownian Motion and SDE" 107 | 2. [[paper](https://arxiv.org/abs/2011.13456)] Score-Based Generative Modeling through Stochastic Differential Equations 108 | 3. [[blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)] Lilian Weng's "What are Diffusion Models?" 109 | 4. [[paper](https://arxiv.org/abs/2006.11239)] Denoising Diffusion Probabilistic Models 110 | 5. [[slide](./assets/summary_of_DDPM_and_DDIM.pdf)] Summary of DDPM and DDIM 111 | 112 | ## Task 0: Introduction 113 | The first part of this tutorial will introduce diffusion models through the lens of stochastic differential equations (SDEs). 114 | Prior to the [Yang Song et al. (2021)](https://arxiv.org/abs/2011.13456) paper, diffusion models were usually understood 115 | in terms of Markov processes with tractable transition kernels. Understanding SDEs can also help you develop more 116 | efficient variance schedules or give more flexibility to your diffusion model. 117 | 118 | We know that a stochastic differential equation has the following form: 119 | $$d\mathbf{X}_t = f(t,\mathbf{X}_t)dt + G(t)d\mathbf{B}_t$$ 120 | where $f$ and $G$ are the drift and diffusion coefficients respectively and $\mathbf{B}_t$ is 121 | standard Brownian motion. A popular choice of SDE is the Ornstein-Uhlenbeck (OU) process, 122 | which is defined as 123 | $$d\mathbf{X}_t = -\mu \mathbf{X}_tdt + \sigma d\mathbf{B}_t$$ 124 | where $\mu, \sigma$ are constants. In this tutorial, we will set $\mu = \frac{1}{2}, \sigma = 1$. 125 | Score-based generative modelling (SGM) aims to sample from the unknown distribution of a given dataset. 126 | We have the following two observations: 127 | - The OU process always converges to a unit Gaussian as $t\rightarrow\infty$ (see the simulation sketch below). 128 | - We can derive the equation for the reverse OU process.
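To make the first observation concrete, here is a minimal Euler–Maruyama simulation of the OU process above (a standalone sketch, separate from `sde.py`; the constants $\mu = \frac{1}{2}, \sigma = 1$ follow the text, everything else is illustrative):

```
import numpy as np

# Euler-Maruyama simulation of dX_t = -mu * X_t dt + sigma * dB_t.
mu, sigma = 0.5, 1.0
T, N = 10.0, 1000                      # terminal time and number of steps
dt = T / N

rng = np.random.default_rng(0)
x = rng.uniform(-4.0, 4.0, size=10_000)  # arbitrary (non-Gaussian) initial samples

for _ in range(N):
    x = x - mu * x * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)

# As t grows, the samples approach the stationary distribution N(0, sigma^2 / (2 * mu)),
# which is a unit Gaussian for mu = 1/2 and sigma = 1.
print(f"mean ~ {x.mean():.3f}, var ~ {x.var():.3f}")  # roughly 0 and 1
```

Comparing the empirical mean and variance against the closed-form $\mathbb{E}[\mathbf{X}_t \mid \mathbf{X}_0]$ and $\mathrm{Var}[\mathbf{X}_t \mid \mathbf{X}_0]$ you derive in the TODO below is a quick sanity check for your implementation.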
129 | 130 | From these facts, we can directly sample from the unknown distribution by 131 | 1. Sample from a unit Gaussian 132 | 2. Run the reverse process on samples from step 1. 133 | 134 | [Yang Song et al. (2021)](https://arxiv.org/abs/2011.13456) derived the likelihood training scheme 135 | for learning the reverse process. In summary, the reverse process for any SDE given above is 136 | of the form 137 | $$d\mathbf{X}_t = [f(t,\mathbf{X}_t) - G(t)^2\nabla_x\log p_t(\mathbf{X}_t)]dt + G(t)d\bar{\mathbf{B}}_t$$ 138 | where $\bar{\mathbf{B}}_t$ is the reverse-time Brownian motion. The only unknown term is the score function 139 | $\nabla_x\log p_t(\mathbf{X}_t)$, which we will approximate with a neural network. One main difference 140 | between SGMs and other generative models is that SGMs generate samples iteratively during the sampling process. 141 | 142 | **TODO:** 143 | ``` 144 | - Derive the expression for the mean and std of the OU process at time t given X0, 145 | i.e. find E[Xt|X0] and Var[Xt|X0]. You will need this for task 1.1(a). 146 | ``` 147 | *hint*: We know that the solution to the OU process is given as 148 | 149 | $\mathbf{X}_T = \mathbf{X}_0 e^{-\mu T} + \sigma \int_0^T e^{-\mu(T-t)} d\mathbf{B}_t$ 150 | 151 | and you can use the fact that $d\mathbf{B}_t^2 = dt$, and $\mathbb{E}[\int_0^T f(t) d\mathbf{B}_t] = 0$ where $f(t)$ is any 152 | deterministic function. 153 | 154 | ## Task 1: very simple SGM pipeline with delicious swiss-roll 155 | A typical diffusion pipeline is divided into three components: 156 | 1. [Forward Process and Reverse Process](#11-forward-and-reverse-process) 157 | 2. [Training](#12-training) 158 | 3. [Sampling](#13-sampling) 159 | 160 | In this task, we will look into each component one by one and implement them sequentially. 161 | #### 1.1. Forward and Reverse Process 162 |

163 | image 164 |

165 | 166 | Our first goal is to set up the forward and reverse processes. In the forward process, the final distribution should be 167 | the prior distribution, which is the standard normal distribution. 168 | #### (a) OU Process 169 | Following the formulation of the OU process introduced in the previous section, complete the `TODO` in 170 | `sde.py` and check that the final distribution approaches a unit Gaussian as $t\rightarrow \infty$. 171 | 172 |

173 | image 174 |

175 | 176 | 177 | **TODO:** 178 | ``` 179 | - implement the forward process using the given marginal probability p_t0(Xt|X0) in sde.py 180 | - implement the reverse process for a general SDE in sde.py 181 | - (optional) Play around with the terminal time (T) and the number of time steps (N) and observe their effect 182 | ``` 183 | 184 | #### (b) VPSDE & VESDE 185 | It's mentioned by [Yang Song et al. (2021)](https://arxiv.org/abs/2011.13456) that DDPM and SMLD are discretizations of SDEs. 186 | Implement both in `sde.py` and check their mean and std. 187 | 188 | *hint*: Although you can simulate the diffusion process through discretization, sampling with the explicit equation of the marginal probability $p_{t0}(\mathbf{X}_t \mid \mathbf{X}_0)$ is much faster. 189 | 190 | You should also obtain the following graphs for VPSDE and VESDE, respectively: 191 |

192 | image 193 | image 194 | 195 |

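For reference when working on the TODO below: both perturbation kernels $p_{0t}(\mathbf{X}_t \mid \mathbf{X}_0)$ are Gaussian, so their mean and std can be evaluated in closed form (a sketch assuming the linear $\beta(t)$ and geometric $\sigma(t)$ schedules from the SDE paper; the schedule constants here are placeholders, not required values):

```
import numpy as np

t = np.linspace(1e-3, 1.0, 200)

# VPSDE with a linear schedule beta(t) = beta_0 + t * (beta_1 - beta_0).
# p_0t(x_t | x_0) = N(x_0 * exp(-0.5 * B(t)), (1 - exp(-B(t))) I),  B(t) = int_0^t beta(s) ds.
beta_0, beta_1 = 0.1, 20.0            # placeholder schedule constants
B = beta_0 * t + 0.5 * (beta_1 - beta_0) * t ** 2
vp_mean_coeff = np.exp(-0.5 * B)      # multiplies x_0
vp_std = np.sqrt(1.0 - np.exp(-B))

# VESDE with a geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min) ** t.
# p_0t(x_t | x_0) = N(x_0, (sigma(t)^2 - sigma(0)^2) I): the mean stays at x_0.
sigma_min, sigma_max = 0.01, 10.0     # placeholder schedule constants
ve_std = np.sqrt(sigma_min ** 2 * ((sigma_max / sigma_min) ** (2 * t) - 1.0))

print(vp_mean_coeff[-1], vp_std[-1])  # mean coefficient -> 0, std -> 1 (variance preserving)
print(ve_std[-1])                     # std grows to roughly sigma_max (variance exploding)
```

The score of such a Gaussian kernel, $\nabla_{\mathbf{X}_t} \log p_{0t}(\mathbf{X}_t \mid \mathbf{X}_0) = -(\mathbf{X}_t - \text{mean})/\text{std}^2$, is exactly the regression target you will need for the DSM objective in Section 1.2.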
196 | 197 | **TODO:** 198 | ``` 199 | - implement VPSDE in sde.py 200 | - implement VESDE in sde.py 201 | - plot the mean and variance of VPSDE and VESDE vs. time. 202 | What can you say about the differences between OU, VPSDE, VESDE? 203 | ``` 204 | 205 | #### 1.2. Training 206 | The typical training objective of diffusion models uses the **D**enoising **S**core **M**atching (DSM) loss: 207 | 208 | $$f_{\theta^*} = \textrm{argmin}_{f_\theta} \mathbb{E} [||f_\theta(\mathbf{X}_s) - \nabla_{\mathbf{X}_s} \log p_{s0}(\mathbf{X}_s\mid \mathbf{X}_0)||^2] $$ 209 | 210 | where $f$ is the score prediction network with parameters $\theta^*$. 211 | Another popular training objective is the **I**mplicit **S**core **M**atching (ISM) loss, which can be derived from DSM. 212 | One main difference between ISM and DSM is that ISM doesn't require the marginal density but instead its divergence. 213 | Although DSM is easier to implement, ISM is used when the data lives on [exotic domains](https://arxiv.org/abs/2202.02763) or 214 | when the marginal density [doesn't have a closed form](https://openreview.net/pdf?id=nioAdKCEdXB). 215 | 216 | **(Important)** you need to derive a **different DSM objective** for each SDE since 217 | their marginal densities are different. You first need to obtain the closed form for $p_{0t}$, then you can find the equation for $\nabla \log p_{0t}$. 218 | For $p_{0t}$, you can refer to the appendix of the SDE paper. 219 | 220 | 221 | However, there are other training objectives with different trade-offs (SSM, EDM, etc.). We highly recommend checking out 222 | [A Variational Perspective on Diffusion-based Generative Models and Score Matching](https://arxiv.org/abs/2106.02808) 223 | and [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) for a more in-depth analysis of the recent training objectives. 224 | 225 | **TODO:** 226 | ``` 227 | - implement your own network in network.py 228 | (recommended: implement Positional Encoding and Residual Connections) 229 | - implement DSMLoss in loss.py 230 | - implement the training loop in train.py 231 | - (optional) implement ISMLoss in loss.py (hint: you will need to use torch.autograd.grad) 232 | - (optional) implement SSMLoss in loss.py 233 | ``` 234 | #### 1.3. Sampling 235 | Finally, we can now use the trained score prediction network to sample from the swiss-roll dataset. Unlike the forward process, there is no analytical form 236 | of the marginal probability. Therefore, we have to run the simulation process. Your final samples should be close to the target distribution 237 | **within 10000 training steps**. For this task, you are free to use **ANY** variation of the diffusion process that **was mentioned** above. 238 | 239 |

240 | image 241 |

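For the sampling TODO below, one update of the reverse process is simply an Euler–Maruyama step of the reverse SDE from Task 0, with the trained network standing in for the score. A minimal sketch, assuming `f`, `G`, and `score_model` are placeholders for your own `sde.py` and `network.py` objects:

```
import torch

@torch.no_grad()
def reverse_step(x_t, t, dt, f, G, score_model):
    """One Euler-Maruyama step of dX = [f(t, X) - G(t)^2 * score(X, t)] dt + G(t) dB,
    run backward in time (dt > 0 is the step size)."""
    score = score_model(x_t, t)                    # approximates grad_x log p_t(x)
    drift = f(t, x_t) - (G(t) ** 2) * score
    noise = torch.randn_like(x_t)
    return x_t - drift * dt + G(t) * noise * (dt ** 0.5)

# Sampling then amounts to starting from the prior and iterating from t = T down to 0:
#   x = torch.randn(batch_size, 2)
#   for each step: x = reverse_step(x, t, dt, f, G, score_model)
```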
242 | 243 | 244 | **TODO:** 245 | ``` 246 | - implement the predict_fn in sde.py 247 | - complete the code in sampling.py 248 | - (optional) train with EMA 249 | - (optional) implement the correct_fn (for VPSDE, VESDE) in sde.py 250 | - (optional) implement the ODE discretization and check out the differences 251 | ``` 252 | 253 | #### 1.4. Evaluation 254 | To evaluate your performance, we compute the Chamfer distance (CD) and earth mover's distance (EMD) between the target and generated point clouds. 255 | Your method should be on par with or better than the following metrics. For this task, you can use **ANY** variations, even ones that were **NOT** mentioned. 256 | 257 | | target distribution | CD | 258 | |---------------------|----------| 259 | | swiss-roll | 0.1975 | 260 | 261 | #### 1.5. [Coming Soon] Schrödinger Bridge (Optional) 262 | One restriction of typical diffusion processes is that they require the prior to be easy to sample from (Gaussian, uniform, etc.). 263 | The Schrödinger Bridge removes this limitation by making the forward process also learnable, allowing a diffusion to be defined between **two** unknown distributions. 264 | 265 | ## Task 2: Image Diffusion 266 | 267 |

268 | image 269 |

270 | 271 | In this task, we will play with diffusion models to generate 2D images. We first look into some background of DDPM and then dive into DDPM at the code level. 272 | 273 | ### Background 274 | From the SDE perspective, SGM and DDPM are the same model with different parameterizations. Just as SGM has forward and reverse processes, the forward process of DDPM, also called the _diffusion process_, is fixed to a Markov chain that gradually adds Gaussian noise to the data: 275 | 276 | $$ q(\mathbf{x}\_{1:T} | \mathbf{x}_0) := \prod\_{t=1}^T q(\mathbf{x}_t | \mathbf{x}\_{t-1}), \quad q(\mathbf{x}_t | \mathbf{x}\_{t-1}) := \mathcal{N} (\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}\_{t-1}, \beta_t \mathbf{I}).$$ 277 | 278 | 279 | Thanks to a nice property of Gaussian distributions, one can sample $\mathbf{x}_t$ at an arbitrary timestep $t$ from real data $\mathbf{x}_0$ in closed form: 280 | 281 | $$q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I}) $$ 282 | 283 | where $\alpha\_t := 1 - \beta\_t$ and $\bar{\alpha}\_t := \prod\_{s=1}^t \alpha\_s$. 284 | 285 | Given the diffusion process, we want to model the _reverse process_ that gradually denoises white Gaussian noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ to sample real data. It is also defined as a Markov chain with learned Gaussian transitions: 286 | 287 | $$p\_\theta(\mathbf{x}\_{0:T}) := p(\mathbf{x}_T) \prod\_{t=1}^T p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t), \quad p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t) := \mathcal{N}(\mathbf{x}\_{t-1}; \boldsymbol{\mu}\_\theta (\mathbf{x}_t, t), \boldsymbol{\Sigma}\_\theta (\mathbf{x}_t, t)).$$ 288 | 289 | To learn this reverse process, we set an objective function that minimizes the KL divergence between $p_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t)$ and $q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0)$, which is tractable when conditioned on $\mathbf{x}_0$: 290 | 291 | $$\mathcal{L} = \mathbb{E}_q \left[ \sum\_{t > 1} D\_{\text{KL}}( q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \Vert p\_\theta ( \mathbf{x}\_{t-1} | \mathbf{x}_t)) \right]$$ 292 | 293 | Refer to [the original paper](https://arxiv.org/abs/2006.11239) or our PPT material for more details. 294 | 295 | As a parameterization of DDPM, the authors set $\boldsymbol{\Sigma}\_\theta(\mathbf{x}_t, t) = \sigma_t^2 \mathbf{I}$ to untrained time-dependent constants, and they empirically found that predicting the noise injected into the data with a noise prediction network $\epsilon\_\theta$ is better than learning the mean function $\boldsymbol{\mu}\_\theta$. 296 | 297 | In short, the simplified objective function of DDPM is defined as follows: 298 | 299 | $$ \mathcal{L}\_{\text{simple}} := \mathbb{E}\_{t,\mathbf{x}_0,\boldsymbol{\epsilon}} [ \Vert \boldsymbol{\epsilon} - \boldsymbol{\epsilon}\_\theta( \mathbf{x}\_t(\mathbf{x}_0, t), t) \Vert^2 ],$$ 300 | 301 | where $\mathbf{x}_t (\mathbf{x}_0, t) = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}$ and $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$. 302 | 303 | #### Sampling 304 | 305 | Once we train the noise prediction network $\boldsymbol{\epsilon}\_\theta$, we can run sampling by gradually denoising white Gaussian noise. The algorithm of the DDPM sampling is shown below: 306 | 307 |

308 | image 309 |

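To connect the algorithm above (and Eq. 4 of the DDPM paper) with `scheduler.py`, here is a hedged sketch of the two core operations; it assumes `alphas` and `alphas_cumprod` buffers indexed by timestep, as in the provided `BaseScheduler`, and is illustrative rather than a drop-in solution:

```
import torch

def add_noise(x0, t, alphas_cumprod, noise=None):
    # Forward noising, Eq. (4): q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).
    noise = torch.randn_like(x0) if noise is None else noise
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise, noise

def ddpm_step(x_t, t, eps_pred, alphas, alphas_cumprod, sigma_t):
    # Algorithm 2: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t) * eps) + sigma_t * z.
    a_t, a_bar_t = alphas[t], alphas_cumprod[t]
    mean = (x_t - (1.0 - a_t) / (1.0 - a_bar_t).sqrt() * eps_pred) / a_t.sqrt()
    z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)   # no noise at the last step
    return mean + sigma_t * z
```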
310 | 311 | [DDIM](https://arxiv.org/abs/2010.02502) proposed a way to speed up the sampling using the same pre-trained DDPM. The reverse step of DDIM is below: 312 | 313 |

314 | image 315 |

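A sketch of this update in code, written with the DDPM paper's $\bar{\alpha}$ notation (see the note below); `alpha_bar_t` and `alpha_bar_prev` are assumed to be scalar tensors taken at the current and previous inference timesteps:

```
import torch

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, eta=0.0):
    # Eq. (12) of DDIM, with DDIM's alpha_t read as the DDPM paper's alpha_bar_t.
    x0_pred = (x_t - (1.0 - alpha_bar_t).sqrt() * eps_pred) / alpha_bar_t.sqrt()
    sigma = eta * ((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)).sqrt() \
                * (1.0 - alpha_bar_t / alpha_bar_prev).sqrt()
    dir_xt = (1.0 - alpha_bar_prev - sigma ** 2).sqrt() * eps_pred   # "direction pointing to x_t"
    noise = sigma * torch.randn_like(x_t) if eta > 0 else 0.0        # eta = 0 gives deterministic DDIM
    return alpha_bar_prev.sqrt() * x0_pred + dir_xt + noise
```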
316 | 317 | Note that the $\alpha_t$ notation in DDIM corresponds to $\bar{\alpha}_t$ in the DDPM paper. 318 | 319 | Please refer to the DDIM paper for more details. 320 | 321 | #### 2.1. DDIM 322 | **TODO** 323 | 324 | In this task, we will generate $64\times64$ animal images using a DDPM trained on the AFHQ dataset. We provide most of the code except for the variance scheduling code. You need to implement the DDPM scheduler and the DDIM scheduler in `scheduler.py`. After implementing the schedulers, train a model with `python train.py`. It will sample images and save a checkpoint every `args.log_interval` steps. After training a model, sample & save images by 325 | ``` 326 | python sampling.py --ckpt_path ${CKPT_PATH} --save_dir ${SAVE_DIR_PATH} 327 | ``` 328 | 329 | We recommend starting the training as soon as possible since it takes about half a day. Also, the DDPM scheduler is quite slow. We recommend **implementing the DDIM scheduler first** and setting `inference timesteps` to 20~50, which is enough to get high-quality images with much less sampling time. 330 | 331 | 332 | For evaluation, measure the FID score using the pre-trained classifier network we provide: 333 | 334 | ``` 335 | python dataset.py # to construct the eval directory. 336 | python fid/measure_fid.py /path/to/eval/dir /path/to/sample/dir 337 | ``` 338 | _**Success condition**_: Achieve an FID score lower than `30`. 339 | 340 | #### 2.2. Classifier-Free Guidance 341 | 342 | Now, we will implement a classifier-free guidance diffusion model. It trains an unconditional diffusion model and a conditional diffusion model jointly by randomly dropping out the conditioning term. The algorithm is below: 343 | 344 |

345 | image 346 |

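At sampling time, the only change is that the model is queried twice per step and the two noise predictions are blended with a guidance scale (a sketch; the training details, the null-label convention, and the `guidance_scale` value are described in the paragraph below, and equivalent formulations such as $(1+w)\epsilon_{\text{cond}} - w\epsilon_{\text{uncond}}$ are also fine):

```
import torch

@torch.no_grad()
def cfg_noise_pred(network, x_t, t, class_label, guidance_scale=7.5, null_label=0):
    """Classifier-free guided noise prediction (sketch).

    eps = eps_uncond + s * (eps_cond - eps_uncond); s = 1 reduces to the plain
    conditional model, larger s pushes samples toward the class condition.
    """
    null = torch.full_like(class_label, null_label)            # label 0 is the null class
    eps_uncond = network(x_t, timestep=t, class_label=null)
    eps_cond = network(x_t, timestep=t, class_label=class_label)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```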
347 | 348 | You need to train another diffusion model for classifier-free guidance by slightly modifying the network architecture so that it can take class labels as input. The network design is your choice. Our implementation uses `nn.Embedding` for the class label embeddings and simply adds the class label embeddings to the time embeddings. We set the conditioning dropout rate to 0.1 during training and `guidance_scale` to 7.5. 349 | 350 | Note that the provided code treats the null class label as 0. 351 | 352 | Generate 200 images per category, 600 in total. Measure FID with the same validation set used in 2.1. 353 | _**Success condition**_: Achieve an FID score lower than `30`. 354 | 355 | For more details, refer to the [paper](https://arxiv.org/abs/2207.12598). 356 | 357 | #### 2.3. Image Inpainting 358 | 359 |

360 | image 361 |

362 | 363 | DDPMs have zero-shot capabilities for handling various downstream tasks beyond unconditional generation. Among them, we will focus only on the image inpainting task. 364 | 365 | Note that there is **no base code** for image inpainting. 366 | 367 | Make a rectangular hole of size $32 \times 32$ in the images generated in 2.1. 368 | Then, do image inpainting based on Algorithm 8 in [RePaint](https://arxiv.org/abs/2201.09865). 369 | 370 |

371 | image 372 |

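One reverse step of RePaint-style inpainting can be sketched as follows (a simplified view of Algorithm 8 that omits the resampling loop; `mask` is 1 on known pixels, and the `scheduler.add_noise` / `scheduler.step` calls refer to your Task 2.1 scheduler):

```
import torch

@torch.no_grad()
def repaint_step(x_t, t, x0_known, mask, scheduler, eps_pred):
    """One inpainting step: re-noise the known region from the ground truth, denoise the
    unknown region with the model, and merge them with the mask (1 = known pixels).
    The resampling (jumping back and forth in t) from Algorithm 8 is omitted for brevity."""
    # Known region: sample x_{t-1} directly from q(x_{t-1} | x_0) using the ground-truth pixels.
    known_prev, _ = scheduler.add_noise(x0_known, torch.tensor([t - 1]))
    # Unknown region: one ordinary reverse (denoising) step from the model prediction.
    unknown_prev = scheduler.step(x_t, t, eps_pred)
    return mask * known_prev + (1.0 - mask) * unknown_prev
```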
373 | 374 | 375 | Report FID scores with 500 result images and the same validation set used in Task 2.1. 376 | You will get full credit if the **_FID is lower than 30_**. 377 | 378 | #### [Optional] Improving image inpainting by MCG 379 | 380 | A recent paper, [Improving Diffusion Models for Inverse Problems using Manifold Constraints](https://arxiv.org/abs/2206.00941), also known as MCG, proposed a way to improve the solving of various inverse problems, such as image inpainting, using DDPMs. At a high level, in the reverse process it takes an additional gradient descent step toward the subspace of the latent space that satisfies the given partial observation. Refer to the [original paper](https://arxiv.org/abs/2206.00941) for more details and implement MCG-based image inpainting code. 381 | 382 | Compare the image inpainting results between MCG and the baseline. 383 | 384 | ## Resources 385 | - [[paper](https://arxiv.org/abs/2011.13456)] Score-Based Generative Modeling through Stochastic Differential Equations 386 | - [[paper](https://arxiv.org/abs/2006.09011)] Improved Techniques for Training Score-Based Generative Models 387 | - [[paper](https://arxiv.org/abs/2006.11239)] Denoising Diffusion Probabilistic Models 388 | - [[paper](https://arxiv.org/abs/2105.05233)] Diffusion Models Beat GANs on Image Synthesis 389 | - [[paper](https://arxiv.org/abs/2207.12598)] Classifier-Free Diffusion Guidance 390 | - [[paper](https://arxiv.org/abs/2010.02502)] Denoising Diffusion Implicit Models 391 | - [[paper](https://arxiv.org/abs/2206.00364)] Elucidating the Design Space of Diffusion-Based Generative Models 392 | - [[paper](https://arxiv.org/abs/2106.02808)] A Variational Perspective on Diffusion-Based Generative Models and Score Matching 393 | - [[paper](https://arxiv.org/abs/2305.16261)] Trans-Dimensional Generative Modeling via Jump Diffusion Models 394 | - [[paper](https://openreview.net/pdf?id=nioAdKCEdXB)] Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory 395 | - [[blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)] What are Diffusion Models? 396 | - [[blog](https://yang-song.net/blog/2021/score/)] Generative Modeling by Estimating Gradients of the Data Distribution 397 | - [[lecture](https://youtube.com/playlist?list=PLCf12vHS8ONRpLNVGYBa_UbqWB_SeLsY2)] Charlie's Playlist on Diffusion Processes 398 | - [[slide](./assets/summary_of_DDPM_and_DDIM.pdf)] Juil's presentation slides on DDIM 399 | - [[slide](./assets/sb_likelihood_training.pdf)] Charlie's presentation on the Schrödinger Bridge
400 | -------------------------------------------------------------------------------- /assets/diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib import animation 4 | from scipy.ndimage import gaussian_filter 5 | from tqdm import tqdm 6 | from PIL import Image 7 | 8 | def step(x, dt=0.001, mu=1.7, sigma=1.2): 9 | z = np.random.randn(*x.shape) 10 | dx = -mu*x*dt + sigma*np.sqrt(dt)*z 11 | return x + dx 12 | 13 | def step_img(x, dt=0.001, mu=120, sigma=1): 14 | z = np.random.randn(*x.shape) 15 | dx = -mu*x*dt + sigma*np.sqrt(dt)*z 16 | return x + dx 17 | 18 | 19 | def sample_two_gaus(N=10000): 20 | N1 = N // 2 21 | N2 = N - N1 22 | samples1 = np.random.randn(N1) + 2 23 | samples2 = np.random.randn(N2) - 2 24 | samples = np.hstack((samples1, samples2)) 25 | return samples 26 | 27 | def p(xt, N=100, mi=-5, ma=5): 28 | return np.histogram(xt, bins=N, range=(mi,ma))[0] 29 | 30 | 31 | n_steps = 500 32 | bins = 100 33 | path = np.zeros((bins, n_steps)) 34 | single_path = np.zeros((n_steps)) 35 | xt = sample_two_gaus() 36 | single_xt = np.array([-2.8]) 37 | 38 | for i in range(n_steps): 39 | path[:,i] = p(xt, N=bins) 40 | xt = step(xt, dt=1/n_steps) 41 | 42 | single_xt = step(single_xt, dt=1/n_steps) 43 | single_path[i] = single_xt[0] 44 | 45 | smooth_path = gaussian_filter(path, sigma=3) 46 | 47 | thanos_img = np.array(Image.open('./thanos.png')) / 255. - 0.5 48 | 49 | fig = plt.Figure(figsize=(15,5)) 50 | fig.tight_layout() 51 | 52 | ax = fig.add_subplot(3,9,(4,27)) 53 | ax.get_xaxis().set_visible(False) 54 | ax.get_yaxis().set_visible(False) 55 | ax.imshow(smooth_path, interpolation='nearest', aspect='auto') 56 | 57 | ax2 = fig.add_subplot(3,9,(4,27)) 58 | ax2.get_xaxis().set_visible(False) 59 | ax2.get_yaxis().set_visible(False) 60 | ax2.set_ylim(-5, 5) 61 | ax2.set_xlim(0, n_steps) 62 | ax2.patch.set_alpha(0.) 
63 | point = ax2.plot(-6*np.ones(n_steps), 'ro')[0] 64 | line = ax2.plot(single_path, 'r-')[0] 65 | 66 | ax_img = fig.add_subplot(3,9,(1,21)) 67 | ax_img.get_xaxis().set_visible(False) 68 | ax_img.get_yaxis().set_visible(False) 69 | ax_img.margins(0,0) 70 | ax_img.imshow(thanos_img) 71 | 72 | n_animate = 50 73 | pbar = tqdm(total=n_animate) 74 | 75 | def update(i): 76 | point_y = -6*np.ones(n_steps) 77 | point_y[1000//n_animate*i] = single_path[1000//n_animate*i] 78 | point.set_ydata(point_y) 79 | 80 | line_y = single_path[:1000//n_animate*i] 81 | line.set_data(range(line_y.shape[0]), line_y) 82 | 83 | ax_img.cla() 84 | global thanos_img 85 | ax_img.imshow(thanos_img + 0.5, interpolation='nearest', aspect='auto') 86 | 87 | thanos_img = step_img(thanos_img) 88 | print(thanos_img.min(), thanos_img.max()) 89 | 90 | pbar.update(1) 91 | 92 | ani = animation.FuncAnimation(fig, update, range(n_animate)) 93 | ani.save('diff_traj.gif', writer='imagemagick', fps=25, savefig_kwargs=dict(pad_inches=0, bbox_inches='tight')); 94 | 95 | # fig.savefig('diff_traj.png') 96 | -------------------------------------------------------------------------------- /assets/images/task1_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_distribution.png -------------------------------------------------------------------------------- /assets/images/task1_ou.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_ou.png -------------------------------------------------------------------------------- /assets/images/task1_sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_sampling.png -------------------------------------------------------------------------------- /assets/images/task1_vesde.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_vesde.png -------------------------------------------------------------------------------- /assets/images/task1_vpsde.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_vpsde.png -------------------------------------------------------------------------------- /assets/images/task2_1_ddpm_sampling_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_1_ddpm_sampling_algorithm.png -------------------------------------------------------------------------------- /assets/images/task2_1_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_1_teaser.png -------------------------------------------------------------------------------- /assets/images/task2_2_teaser.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_2_teaser.png -------------------------------------------------------------------------------- /assets/images/task2_3_repaint_algorithm8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_3_repaint_algorithm8.png -------------------------------------------------------------------------------- /assets/images/task2_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_algorithm.png -------------------------------------------------------------------------------- /assets/images/task2_ddim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_ddim.png -------------------------------------------------------------------------------- /assets/images/task3_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task3_algorithm.png -------------------------------------------------------------------------------- /assets/images/task_2_3_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task_2_3_teaser.png -------------------------------------------------------------------------------- /assets/images/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/teaser.gif -------------------------------------------------------------------------------- /assets/images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/teaser.png -------------------------------------------------------------------------------- /assets/sb_likelihood_training.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/sb_likelihood_training.pdf -------------------------------------------------------------------------------- /assets/summary_of_DDPM_and_DDIM.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/summary_of_DDPM_and_DDIM.pdf -------------------------------------------------------------------------------- /image_diffusion_todo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/image_diffusion_todo/__init__.py -------------------------------------------------------------------------------- 
/image_diffusion_todo/dataset.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import os 3 | from itertools import chain 4 | from pathlib import Path 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | 10 | 11 | def listdir(dname): 12 | fnames = list( 13 | chain( 14 | *[ 15 | list(Path(dname).rglob("*." + ext)) 16 | for ext in ["png", "jpg", "jpeg", "JPG"] 17 | ] 18 | ) 19 | ) 20 | return fnames 21 | 22 | 23 | def tensor_to_pil_image(x: torch.Tensor, single_image=False): 24 | """ 25 | x: [B,C,H,W] 26 | """ 27 | if x.ndim == 3: 28 | x = x.unsqueeze(0) 29 | single_image = True 30 | 31 | x = (x * 0.5 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy() 32 | images = (x * 255).round().astype("uint8") 33 | images = [Image.fromarray(image) for image in images] 34 | if single_image: 35 | return images[0] 36 | return images 37 | 38 | 39 | def get_data_iterator(iterable): 40 | """Allows training with DataLoaders in a single infinite loop: 41 | for i, data in enumerate(inf_generator(train_loader)): 42 | """ 43 | iterator = iterable.__iter__() 44 | while True: 45 | try: 46 | yield iterator.__next__() 47 | except StopIteration: 48 | iterator = iterable.__iter__() 49 | 50 | 51 | class AFHQDataset(torch.utils.data.Dataset): 52 | def __init__( 53 | self, root: str, split: str, transform=None, max_num_images_per_cat=-1, label_offset=1 54 | ): 55 | super().__init__() 56 | self.root = root 57 | self.split = split 58 | self.transform = transform 59 | self.max_num_images_per_cat = max_num_images_per_cat 60 | self.label_offset = label_offset 61 | 62 | categories = os.listdir(os.path.join(root, split)) 63 | self.num_classes = len(categories) 64 | 65 | fnames, labels = [], [] 66 | for idx, cat in enumerate(sorted(categories)): 67 | category_dir = os.path.join(root, split, cat) 68 | cat_fnames = listdir(category_dir) 69 | cat_fnames = sorted(cat_fnames) 70 | if self.max_num_images_per_cat > 0: 71 | cat_fnames = cat_fnames[: self.max_num_images_per_cat] 72 | fnames += cat_fnames 73 | labels += [idx + label_offset] * len(cat_fnames) # label 0 is for null class. 74 | 75 | self.fnames = fnames 76 | self.labels = labels 77 | 78 | def __getitem__(self, idx): 79 | img = Image.open(self.fnames[idx]).convert("RGB") 80 | label = self.labels[idx] 81 | assert label >= self.label_offset 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | 85 | return img, label 86 | 87 | def __len__(self): 88 | return len(self.labels) 89 | 90 | 91 | class AFHQDataModule(object): 92 | def __init__( 93 | self, 94 | root: str = "data", 95 | batch_size: int = 32, 96 | num_workers: int = 4, 97 | max_num_images_per_cat: int = -1, 98 | image_resolution: int = 64, 99 | label_offset=1, 100 | ): 101 | self.root = root 102 | self.batch_size = batch_size 103 | self.num_workers = num_workers 104 | self.afhq_root = os.path.join(root, "afhq") 105 | self.max_num_images_per_cat = max_num_images_per_cat 106 | self.image_resolution = image_resolution 107 | self.label_offset = label_offset 108 | 109 | if not os.path.exists(self.afhq_root): 110 | print(f"{self.afhq_root} is empty. 
Downloading AFHQ dataset...") 111 | self._download_dataset() 112 | 113 | self._set_dataset() 114 | 115 | def _set_dataset(self): 116 | self.transform = transforms.Compose( 117 | [ 118 | transforms.Resize((self.image_resolution, self.image_resolution)), 119 | transforms.ToTensor(), 120 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 121 | ] 122 | ) 123 | self.train_ds = AFHQDataset( 124 | self.afhq_root, 125 | "train", 126 | self.transform, 127 | max_num_images_per_cat=self.max_num_images_per_cat, 128 | label_offset=self.label_offset 129 | ) 130 | self.val_ds = AFHQDataset( 131 | self.afhq_root, 132 | "val", 133 | self.transform, 134 | max_num_images_per_cat=self.max_num_images_per_cat, 135 | label_offset=self.label_offset, 136 | ) 137 | 138 | self.num_classes = self.train_ds.num_classes 139 | 140 | def _download_dataset(self): 141 | URL = "https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0" 142 | ZIP_FILE = f"./{self.root}/afhq.zip" 143 | os.system(f"mkdir -p {self.root}") 144 | os.system(f"wget -N {URL} -O {ZIP_FILE}") 145 | os.system(f"unzip {ZIP_FILE} -d {self.root}") 146 | os.system(f"rm {ZIP_FILE}") 147 | 148 | 149 | def train_dataloader(self): 150 | return torch.utils.data.DataLoader( 151 | self.train_ds, 152 | batch_size=self.batch_size, 153 | num_workers=self.num_workers, 154 | shuffle=True, 155 | drop_last=True, 156 | ) 157 | 158 | def val_dataloader(self): 159 | return torch.utils.data.DataLoader( 160 | self.val_ds, 161 | batch_size=self.batch_size, 162 | num_workers=self.num_workers, 163 | shuffle=False, 164 | drop_last=False, 165 | ) 166 | 167 | if __name__ == "__main__": 168 | data_module = AFHQDataModule("data", 32, 4, -1, 64, 1) 169 | 170 | eval_dir = Path(data_module.afhq_root) / "eval" 171 | eval_dir.mkdir(exist_ok=True) 172 | def func(path): 173 | fn = path.name 174 | cmd = f"cp {path} {eval_dir / fn}" 175 | os.system(cmd) 176 | img = Image.open(str(eval_dir / fn)) 177 | img = img.resize((64,64)) 178 | img.save(str(eval_dir / fn)) 179 | print(fn) 180 | 181 | with Pool(8) as pool: 182 | pool.map(func, data_module.val_ds.fnames) 183 | 184 | print(f"Constructed eval dir at {eval_dir}") 185 | -------------------------------------------------------------------------------- /image_diffusion_todo/ddpm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from scheduler import BaseScheduler 8 | 9 | 10 | class DiffusionModule(nn.Module): 11 | def __init__(self, network: nn.Module, var_scheduler: BaseScheduler, **kwargs): 12 | super().__init__() 13 | self.network = network 14 | self.var_scheduler = var_scheduler 15 | 16 | def get_loss(self, x0, class_label=None, noise=None): 17 | B = x0.shape[0] 18 | timestep = self.var_scheduler.uniform_sample_t(B, self.device) 19 | x_noisy, noise = self.var_scheduler.add_noise(x0, timestep) 20 | noise_pred = self.network(x_noisy, timestep=timestep, class_label=class_label) 21 | 22 | loss = F.mse_loss(noise_pred.flatten(), noise.flatten(), reduction="mean") 23 | return loss 24 | 25 | @property 26 | def device(self): 27 | return next(self.network.parameters()).device 28 | 29 | @property 30 | def image_resolution(self): 31 | return self.network.image_resolution 32 | 33 | @torch.no_grad() 34 | def sample( 35 | self, 36 | batch_size: int, 37 | return_traj: bool = False, 38 | ): 39 | """ 40 | Sample x_0 from a learned diffusion model. 
41 | """ 42 | x_T = torch.randn([batch_size, 3, self.image_resolution, self.image_resolution]).to( 43 | self.device 44 | ) 45 | 46 | traj = [x_T] 47 | for t in self.var_scheduler.timesteps: 48 | x_t = traj[-1] 49 | noise_pred = self.network(x_t, timestep=t.to(self.device)) 50 | 51 | x_t_prev = self.var_scheduler.step(x_t, t, noise_pred) 52 | 53 | traj[-1] = traj[-1].cpu() 54 | traj.append(x_t_prev.detach()) 55 | 56 | if return_traj: 57 | return traj 58 | else: 59 | return traj[-1] 60 | 61 | def save(self, file_path): 62 | hparams = { 63 | "network": self.network, 64 | "var_scheduler": self.var_scheduler, 65 | } 66 | state_dict = self.state_dict() 67 | 68 | dic = {"hparams": hparams, "state_dict": state_dict} 69 | torch.save(dic, file_path) 70 | 71 | def load(self, file_path): 72 | dic = torch.load(file_path, map_location="cpu") 73 | hparams = dic["hparams"] 74 | state_dict = dic["state_dict"] 75 | 76 | self.network = hparams["network"] 77 | self.var_scheduler = hparams["var_scheduler"] 78 | 79 | self.load_state_dict(state_dict) 80 | 81 | -------------------------------------------------------------------------------- /image_diffusion_todo/fid/afhq_inception_v3.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/image_diffusion_todo/fid/afhq_inception_v3.ckpt -------------------------------------------------------------------------------- /image_diffusion_todo/fid/inception.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | import numpy as np 11 | import torch.nn as nn 12 | from torchvision import models 13 | 14 | 15 | class InceptionV3(nn.Module): 16 | def __init__(self, for_train): 17 | super().__init__() 18 | self.for_train = for_train 19 | 20 | inception = models.inception_v3(pretrained=False) 21 | self.block1 = nn.Sequential( 22 | inception.Conv2d_1a_3x3, 23 | inception.Conv2d_2a_3x3, 24 | inception.Conv2d_2b_3x3, 25 | nn.MaxPool2d(kernel_size=3, stride=2), 26 | ) 27 | self.block2 = nn.Sequential( 28 | inception.Conv2d_3b_1x1, 29 | inception.Conv2d_4a_3x3, 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.block3 = nn.Sequential( 33 | inception.Mixed_5b, 34 | inception.Mixed_5c, 35 | inception.Mixed_5d, 36 | inception.Mixed_6a, 37 | inception.Mixed_6b, 38 | inception.Mixed_6c, 39 | inception.Mixed_6d, 40 | inception.Mixed_6e, 41 | ) 42 | self.block4 = nn.Sequential( 43 | inception.Mixed_7a, 44 | inception.Mixed_7b, 45 | inception.Mixed_7c, 46 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), 47 | ) 48 | 49 | self.final_fc = nn.Linear(2048, 3) 50 | 51 | def forward(self, x): 52 | x = self.block1(x) 53 | x = self.block2(x) 54 | x = self.block3(x) 55 | x = self.block4(x) 56 | x = x.view(x.size(0), -1) 57 | if self.for_train: 58 | return self.final_fc(x) 59 | else: 60 | return x 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /image_diffusion_todo/fid/measure_fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | from PIL import Image 7 | from scipy import linalg 8 | from torchvision import transforms 9 | from itertools import chain 10 | from pathlib import Path 11 | from inception import InceptionV3 12 | 13 | try: 14 | from tqdm import tqdm 15 | except ImportError: 16 | def tqdm(x): 17 | return x 18 | 19 | class ImagePathDataset(torch.utils.data.Dataset): 20 | def __init__(self, files, img_size): 21 | self.files = files 22 | self.img_size = img_size 23 | self.transforms = transforms.Compose( 24 | [ 25 | transforms.Resize((img_size, img_size)), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 28 | ] 29 | ) 30 | 31 | def __len__(self): 32 | return len(self.files) 33 | 34 | def __getitem__(self, i): 35 | path = self.files[i] 36 | img = Image.open(path).convert("RGB") 37 | if self.transforms is not None: 38 | img = self.transforms(img) 39 | return img 40 | 41 | 42 | def get_eval_loader(path, img_size, batch_size): 43 | def listdir(dname): 44 | fnames = list( 45 | chain( 46 | *[ 47 | list(Path(dname).rglob("*." + ext)) 48 | for ext in ["png", "jpg", "jpeg", "JPG"] 49 | ] 50 | ) 51 | ) 52 | return fnames 53 | 54 | files = listdir(path) 55 | ds = ImagePathDataset(files, img_size) 56 | dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4) 57 | return dl 58 | 59 | def frechet_distance(mu, cov, mu2, cov2): 60 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False) 61 | dist = np.sum((mu - mu2) ** 2) + np.trace(cov + cov2 - 2 * cc) 62 | return np.real(dist) 63 | 64 | 65 | 66 | @torch.no_grad() 67 | def calculate_fid_given_paths(paths, img_size=256, batch_size=50): 68 | print("Calculating FID given paths %s and %s..." 
% (paths[0], paths[1])) 69 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 70 | inception = InceptionV3(for_train=False) 71 | current_dir = Path(os.path.realpath(__file__)).parent 72 | ckpt = torch.load(current_dir / "afhq_inception_v3.ckpt") 73 | inception.load_state_dict(ckpt) 74 | inception = inception.eval().to(device) 75 | loaders = [get_eval_loader(path, img_size, batch_size) for path in paths] 76 | 77 | mu, cov = [], [] 78 | for loader in loaders: 79 | actvs = [] 80 | for x in tqdm(loader, total=len(loader)): 81 | actv = inception(x.to(device)) 82 | actvs.append(actv) 83 | actvs = torch.cat(actvs, dim=0).cpu().detach().numpy() 84 | mu.append(np.mean(actvs, axis=0)) 85 | cov.append(np.cov(actvs, rowvar=False)) 86 | fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1]) 87 | return fid_value 88 | 89 | if __name__ == "__main__": 90 | # python measure_fid /path/to/dir1 /path/to/dir2 91 | 92 | paths = [sys.argv[1], sys.argv[2]] 93 | fid_value = calculate_fid_given_paths(paths, img_size=256, batch_size=64) 94 | print("FID:", fid_value) 95 | 96 | -------------------------------------------------------------------------------- /image_diffusion_todo/fid/train_classifier.py: -------------------------------------------------------------------------------- 1 | from inception import InceptionV3 2 | import sys 3 | sys.path.append("..") 4 | import torch 5 | import torch.nn.functional as F 6 | from dataset import AFHQDataset, AFHQDataModule 7 | from tqdm import tqdm 8 | import cv2 9 | import numpy as np 10 | 11 | data_module = AFHQDataModule("/home/juil/workspace/23summer_tutorial/HelloScore/image_diffusion/data/", 32, 4, -1, 256, 0) 12 | 13 | train_dl = data_module.train_dataloader() 14 | val_dl = data_module.val_dataloader() 15 | 16 | device = f"cuda:1" 17 | 18 | net = InceptionV3(for_train=True) 19 | net = net.to(device) 20 | net.train() 21 | for n, p in net.named_parameters(): 22 | p.requires_grad_(True) 23 | 24 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) 25 | 26 | epochs = 10 27 | for epoch in range(epochs): 28 | pbar = tqdm(train_dl) 29 | net.train() 30 | for img, label in pbar: 31 | img, label = img.to(device), label.to(device) 32 | pred = net(img) 33 | loss = F.cross_entropy(pred, label) 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | optimizer.step() 38 | 39 | pred_label = pred.max(-1)[1] 40 | acc = (pred_label == label).float().mean() 41 | 42 | pbar.set_description(f"E {epoch} | loss: {loss:.4f} acc: {acc*100:.2f}%") 43 | 44 | net.eval() 45 | val_accs = [] 46 | for img, label in val_dl: 47 | img, label = img.to(device), label.to(device) 48 | pred = net(img) 49 | pred_label = pred.max(-1)[1] 50 | 51 | acc = (pred_label == label).float().mean() 52 | val_accs.append(acc) 53 | print(f"Val Acc: {sum(val_accs) / len(val_accs) * 100:.2f}%") 54 | 55 | torch.save(net.state_dict(), "afhq_inception_v3.ckpt") 56 | print("Saved model") 57 | 58 | -------------------------------------------------------------------------------- /image_diffusion_todo/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | 9 | class Swish(nn.Module): 10 | def forward(self, x): 11 | return x * torch.sigmoid(x) 12 | 13 | 14 | class DownSample(nn.Module): 15 | def __init__(self, in_ch): 16 | super().__init__() 17 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) 18 | self.initialize() 19 | 20 
| def initialize(self): 21 | init.xavier_uniform_(self.main.weight) 22 | init.zeros_(self.main.bias) 23 | 24 | def forward(self, x, temb): 25 | x = self.main(x) 26 | return x 27 | 28 | 29 | class UpSample(nn.Module): 30 | def __init__(self, in_ch): 31 | super().__init__() 32 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) 33 | self.initialize() 34 | 35 | def initialize(self): 36 | init.xavier_uniform_(self.main.weight) 37 | init.zeros_(self.main.bias) 38 | 39 | def forward(self, x, temb): 40 | _, _, H, W = x.shape 41 | x = F.interpolate(x, scale_factor=2, mode="nearest") 42 | x = self.main(x) 43 | return x 44 | 45 | 46 | class AttnBlock(nn.Module): 47 | def __init__(self, in_ch): 48 | super().__init__() 49 | self.group_norm = nn.GroupNorm(32, in_ch) 50 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 51 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 52 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 53 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 54 | self.initialize() 55 | 56 | def initialize(self): 57 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: 58 | init.xavier_uniform_(module.weight) 59 | init.zeros_(module.bias) 60 | init.xavier_uniform_(self.proj.weight, gain=1e-5) 61 | 62 | def forward(self, x): 63 | B, C, H, W = x.shape 64 | h = self.group_norm(x) 65 | q = self.proj_q(h) 66 | k = self.proj_k(h) 67 | v = self.proj_v(h) 68 | 69 | q = q.permute(0, 2, 3, 1).view(B, H * W, C) 70 | k = k.view(B, C, H * W) 71 | w = torch.bmm(q, k) * (int(C) ** (-0.5)) 72 | assert list(w.shape) == [B, H * W, H * W] 73 | w = F.softmax(w, dim=-1) 74 | 75 | v = v.permute(0, 2, 3, 1).view(B, H * W, C) 76 | h = torch.bmm(w, v) 77 | assert list(h.shape) == [B, H * W, C] 78 | h = h.view(B, H, W, C).permute(0, 3, 1, 2) 79 | h = self.proj(h) 80 | 81 | return x + h 82 | 83 | 84 | class ResBlock(nn.Module): 85 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): 86 | super().__init__() 87 | self.block1 = nn.Sequential( 88 | nn.GroupNorm(32, in_ch), 89 | Swish(), 90 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 91 | ) 92 | self.temb_proj = nn.Sequential( 93 | Swish(), 94 | nn.Linear(tdim, out_ch), 95 | ) 96 | self.block2 = nn.Sequential( 97 | nn.GroupNorm(32, out_ch), 98 | Swish(), 99 | nn.Dropout(dropout), 100 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 101 | ) 102 | if in_ch != out_ch: 103 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 104 | else: 105 | self.shortcut = nn.Identity() 106 | if attn: 107 | self.attn = AttnBlock(out_ch) 108 | else: 109 | self.attn = nn.Identity() 110 | self.initialize() 111 | 112 | def initialize(self): 113 | for module in self.modules(): 114 | if isinstance(module, (nn.Conv2d, nn.Linear)): 115 | init.xavier_uniform_(module.weight) 116 | init.zeros_(module.bias) 117 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) 118 | 119 | def forward(self, x, temb): 120 | h = self.block1(x) 121 | h += self.temb_proj(temb)[:, :, None, None] 122 | h = self.block2(h) 123 | 124 | h = h + self.shortcut(x) 125 | h = self.attn(h) 126 | return h 127 | 128 | 129 | class TimeEmbedding(nn.Module): 130 | def __init__(self, hidden_size, frequency_embedding_size=256): 131 | super().__init__() 132 | self.mlp = nn.Sequential( 133 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 134 | nn.SiLU(), 135 | nn.Linear(hidden_size, hidden_size, bias=True), 136 | ) 137 | self.frequency_embedding_size = frequency_embedding_size 138 | 139 | @staticmethod 140 
| def timestep_embedding(t, dim, max_period=10000): 141 | """ 142 | Create sinusoidal timestep embeddings. 143 | :param t: a 1-D Tensor of N indices, one per batch element. 144 | These may be fractional. 145 | :param dim: the dimension of the output. 146 | :param max_period: controls the minimum frequency of the embeddings. 147 | :return: an (N, D) Tensor of positional embeddings. 148 | """ 149 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 150 | half = dim // 2 151 | freqs = torch.exp( 152 | -math.log(max_period) 153 | * torch.arange(start=0, end=half, dtype=torch.float32) 154 | / half 155 | ).to(device=t.device) 156 | args = t[:, None].float() * freqs[None] 157 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 158 | if dim % 2: 159 | embedding = torch.cat( 160 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 161 | ) 162 | return embedding 163 | 164 | def forward(self, t): 165 | if t.ndim == 0: 166 | t = t.unsqueeze(-1) 167 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 168 | t_emb = self.mlp(t_freq) 169 | return t_emb 170 | -------------------------------------------------------------------------------- /image_diffusion_todo/network.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from module import DownSample, ResBlock, Swish, TimeEmbedding, UpSample 7 | from torch.nn import init 8 | 9 | 10 | class UNet(nn.Module): 11 | def __init__( 12 | self, 13 | T: int = 1000, 14 | image_resolution: int = 64, 15 | ch: int = 128, 16 | ch_mult: List[int] = [1, 2, 2, 2], 17 | attn: List[int] = [1], 18 | num_res_blocks: int = 4, 19 | dropout: float = 0.1, 20 | use_cfg: bool = False, 21 | cfg_dropout: float = 0.1, 22 | num_classes: Optional[int] = None, 23 | ): 24 | super().__init__() 25 | self.image_resolution = image_resolution 26 | 27 | # TODO: Implement an architecture according to the provided architecture diagram. 28 | # You can use the modules in `module.py`. 29 | 30 | def forward(self, x, timestep, class_label=None): 31 | """ 32 | Input: 33 | x (`torch.Tensor [B,C,H,W]`) 34 | timestep (`torch.Tensor [B]`) 35 | class_label (`torch.Tensor [B]`, optional) 36 | Output: 37 | out (`torch.Tensor [B,C,H,W]`): noise prediction. 38 | """ 39 | assert ( 40 | x.shape[-1] == x.shape[-2] == self.image_resolution 41 | ), f"The resolution of x ({x.shape[-2]}, {x.shape[-1]}) does not match with the image resolution ({self.image_resolution})." 42 | 43 | # TODO: Implement noise prediction network's forward function. 
44 | 45 | out = torch.randn_like(x) 46 | return out 47 | -------------------------------------------------------------------------------- /image_diffusion_todo/sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from dataset import tensor_to_pil_image 6 | from ddpm import DiffusionModule 7 | from network import UNet 8 | from scheduler import DDIMScheduler, DDPMScheduler 9 | from pathlib import Path 10 | 11 | 12 | def main(args): 13 | save_dir = Path(args.save_dir) 14 | save_dir.mkdir(exist_ok=True, parents=True) 15 | 16 | device = f"cuda:{args.gpu}" 17 | 18 | ddpm = DiffusionModule(None, None) 19 | ddpm.load(args.ckpt_path) 20 | ddpm.eval() 21 | ddpm = ddpm.to(device) 22 | 23 | if isinstance(ddpm.var_scheduler, DDIMScheduler): 24 | ddpm.var_scheduler.set_timesteps(20) 25 | 26 | total_num_samples = 500 27 | num_batches = int(np.ceil(total_num_samples / args.batch_size)) 28 | 29 | for i in range(num_batches): 30 | sidx = i * args.batch_size 31 | eidx = min(sidx + args.batch_size, total_num_samples) 32 | samples = ddpm.sample(eidx - sidx) 33 | pil_images = tensor_to_pil_image(samples) 34 | 35 | for j, img in zip(range(sidx, eidx), pil_images): 36 | img.save(save_dir / f"{j}.png") 37 | print(f"Saved the {j}-th image.") 38 | 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--batch_size", type=int, default=32) 44 | parser.add_argument("--gpu", type=int, default=0) 45 | parser.add_argument("--ckpt_path", type=str) 46 | parser.add_argument("--save_dir", type=str) 47 | 48 | args = parser.parse_args() 49 | main(args) 50 | -------------------------------------------------------------------------------- /image_diffusion_todo/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class BaseScheduler(nn.Module): 9 | def __init__( 10 | self, num_train_timesteps: int = 1000, beta_1: float = 1e-4, beta_T: float = 0.02, mode="linear" 11 | ): 12 | super().__init__() 13 | self.num_train_timesteps = num_train_timesteps 14 | self.num_inference_timesteps = num_train_timesteps 15 | self.timesteps = torch.from_numpy( 16 | np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64) 17 | ) 18 | 19 | if mode == "linear": 20 | betas = torch.linspace(beta_1, beta_T, steps=num_train_timesteps) 21 | elif mode == "quad": 22 | betas = ( 23 | torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps) ** 2 24 | ) 25 | else: 26 | raise NotImplementedError(f"{mode} is not implemented.") 27 | 28 | # TODO: Compute alphas and alphas_cumprod 29 | # alphas and alphas_cumprod correspond to $\alpha$ and $\bar{\alpha}$ in the DDPM paper (https://arxiv.org/abs/2006.11239). 30 | alphas = alphas_cumprod = betas 31 | 32 | self.register_buffer("betas", betas) 33 | self.register_buffer("alphas", alphas) 34 | self.register_buffer("alphas_cumprod", alphas_cumprod) 35 | 36 | def uniform_sample_t( 37 | self, batch_size, device: Optional[torch.device] = None 38 | ) -> torch.IntTensor: 39 | """ 40 | Uniformly sample timesteps. 
41 | """ 42 | ts = np.random.choice(np.arange(self.num_train_timesteps), batch_size) 43 | ts = torch.from_numpy(ts) 44 | if device is not None: 45 | ts = ts.to(device) 46 | return ts 47 | 48 | 49 | class DDPMScheduler(BaseScheduler): 50 | def __init__( 51 | self, 52 | num_train_timesteps: int, 53 | beta_1: float, 54 | beta_T: float, 55 | mode="linear", 56 | sigma_type="small", 57 | ): 58 | super().__init__(num_train_timesteps, beta_1, beta_T, mode) 59 | 60 | # sigmas correspond to $\sigma_t$ in the DDPM paper. 61 | self.sigma_type = sigma_type 62 | if sigma_type == "small": 63 | # when $\sigma_t^2 = \tilde{\beta}_t$. 64 | alphas_cumprod_t_prev = torch.cat( 65 | [torch.tensor(1.0), self.alphas_cumprod[-1:]] 66 | ) 67 | sigmas = ( 68 | (1 - alphas_cumprod_t_prev) / (1 - self.alphas_cumprod) * self.betas 69 | ) ** 0.5 70 | elif sigma_type == "large": 71 | # when $\sigma_t^2 = \beta_t$. 72 | sigmas = self.betas ** 0.5 73 | 74 | self.register_buffer("sigmas", sigmas) 75 | 76 | def step(self, sample: torch.Tensor, timestep: int, noise_pred: torch.Tensor): 77 | """ 78 | One step denoising function of DDPM: x_t -> x_{t-1}. 79 | 80 | Input: 81 | sample (`torch.Tensor [B,C,H,W]`): samples at arbitrary timestep t. 82 | timestep (`int`): current timestep in a reverse process. 83 | noise_pred (`torch.Tensor [B,C,H,W]`): predicted noise from a learned model. 84 | Ouptut: 85 | sample_prev (`torch.Tensor [B,C,H,W]`): one step denoised sample. (= x_{t-1}) 86 | """ 87 | 88 | # TODO: Implement the DDPM's one step denoising function. 89 | # Refer to Algorithm 2 in the DDPM paper (https://arxiv.org/abs/2006.11239). 90 | 91 | sample_prev = sample 92 | 93 | return sample_prev 94 | 95 | def add_noise( 96 | self, 97 | original_sample: torch.Tensor, 98 | timesteps: torch.IntTensor, 99 | noise: Optional[torch.Tensor] = None, 100 | ): 101 | """ 102 | A forward pass of a Markov chain, i.e., q(x_t | x_0). 103 | 104 | Input: 105 | sample (`torch.Tensor [B,C,H,W]`): samples from a real data distribution q(x_0). 106 | timesteps: (`torch.IntTensor [B]`) 107 | noise: (`torch.Tensor [B,C,H,W]`, optional): if None, randomly sample Gaussian noise in the function. 108 | Output: 109 | x_noisy: (`torch.Tensor [B,C,H,W]`): noisy samples 110 | noise: (`torch.Tensor [B,C,H,W]`): injected noise. 111 | """ 112 | 113 | # TODO: Implement the function that samples $\mathbf{x}_t$ from $\mathbf{x}_0$. 114 | # Refer to Equation 4 in the DDPM paper (https://arxiv.org/abs/2006.11239). 115 | 116 | noisy_sample = noise = original_sample 117 | 118 | return noisy_sample, noise 119 | 120 | 121 | class DDIMScheduler(BaseScheduler): 122 | def __init__(self, num_train_timesteps, beta_1, beta_T, mode="linear"): 123 | super().__init__(num_train_timesteps, beta_1, beta_T, mode) 124 | 125 | def set_timesteps( 126 | self, num_inference_timesteps: int, device: Union[str, torch.device] = None 127 | ): 128 | """ 129 | Sets the timesteps of a diffusion Markov chain. It is for accelerated generation process (Sec. 4.2) in the DDIM paper (https://arxiv.org/abs/2010.02502). 
130 |         """
131 |         if num_inference_timesteps > self.num_train_timesteps:
132 |             raise ValueError(
133 |                 f"num_inference_timesteps ({num_inference_timesteps}) cannot exceed self.num_train_timesteps ({self.num_train_timesteps})"
134 |             )
135 | 
136 |         self.num_inference_timesteps = num_inference_timesteps
137 | 
138 |         step_ratio = self.num_train_timesteps // num_inference_timesteps
139 |         timesteps = (
140 |             (np.arange(0, num_inference_timesteps) * step_ratio)
141 |             .round()[::-1]
142 |             .copy()
143 |             .astype(np.int64)
144 |         )
145 |         self.timesteps = torch.from_numpy(timesteps)
146 | 
147 |     def step(
148 |         self,
149 |         sample: torch.Tensor,
150 |         timestep: int,
151 |         noise_pred: torch.Tensor,
152 |         eta: float = 0.0,
153 |     ):
154 |         """
155 |         One step denoising function of DDIM: $x_{\tau_i}$ -> $x_{\tau_{i-1}}$.
156 | 
157 |         Input:
158 |             sample (`torch.Tensor [B,C,H,W]`): samples at arbitrary timestep $\tau_i$.
159 |             timestep (`int`): current timestep in a reverse process.
160 |             noise_pred (`torch.Tensor [B,C,H,W]`): predicted noise from a learned model.
161 |             eta (float): corresponds to η in DDIM, which controls the stochasticity of a reverse process.
162 |         Output:
163 |             sample_prev (`torch.Tensor [B,C,H,W]`): one step denoised sample. (= $x_{\tau_{i-1}}$)
164 |         """
165 |         # TODO: Implement the DDIM's one step denoising function.
166 |         # Refer to Equation 12 in the DDIM paper (https://arxiv.org/abs/2010.02502).
167 | 
168 |         sample_prev = sample
169 | 
170 |         return sample_prev
171 | 
--------------------------------------------------------------------------------
/image_diffusion_todo/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from datetime import datetime
4 | from pathlib import Path
5 | 
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | import torch
9 | from dataset import AFHQDataModule, get_data_iterator, tensor_to_pil_image
10 | from dotmap import DotMap
11 | from ddpm import DiffusionModule
12 | from network import UNet
13 | from pytorch_lightning import seed_everything
14 | from scheduler import DDIMScheduler, DDPMScheduler
15 | from torchvision.transforms.functional import to_pil_image
16 | from tqdm import tqdm
17 | 
18 | matplotlib.use("Agg")
19 | 
20 | 
21 | def get_current_time():
22 |     now = datetime.now().strftime("%m-%d-%H%M%S")
23 |     return now
24 | 
25 | 
26 | def main(args):
27 |     """config"""
28 |     config = DotMap()
29 |     config.update(vars(args))
30 |     config.device = f"cuda:{args.gpu}"
31 | 
32 |     now = get_current_time()
33 |     save_dir = Path(f"results/diffusion-{now}")
34 |     save_dir.mkdir(exist_ok=True, parents=True)
35 |     print(f"save_dir: {save_dir}")
36 | 
37 |     seed_everything(config.seed)
38 | 
39 |     with open(save_dir / "config.json", "w") as f:
40 |         json.dump(config, f, indent=2)
41 |     """######"""
42 | 
43 |     image_resolution = config.image_resolution
44 |     ds_module = AFHQDataModule(
45 |         "./data",
46 |         batch_size=config.batch_size,
47 |         num_workers=4,
48 |         max_num_images_per_cat=config.max_num_images_per_cat,
49 |         image_resolution=image_resolution
50 |     )
51 | 
52 |     train_dl = ds_module.train_dataloader()
53 |     train_it = get_data_iterator(train_dl)
54 | 
55 |     var_scheduler = DDPMScheduler(
56 |         config.num_diffusion_train_timesteps,
57 |         beta_1=config.beta_1,
58 |         beta_T=config.beta_T,
59 |         mode="linear",
60 |     )
61 |     if isinstance(var_scheduler, DDIMScheduler):
62 |         var_scheduler.set_timesteps(20)  # 20 steps are enough in the case of DDIM.
63 | 
64 |     network = UNet(
65 |         T=config.num_diffusion_train_timesteps,
66 |         image_resolution=image_resolution,
67 |         ch=128,
68 |         ch_mult=[1, 2, 2, 2],
69 |         attn=[1],
70 |         num_res_blocks=4,
71 |         dropout=0.1,
72 |         use_cfg=args.use_cfg,
73 |         cfg_dropout=args.cfg_dropout,
74 |         num_classes=getattr(ds_module, "num_classes", None),
75 |     )
76 | 
77 |     ddpm = DiffusionModule(network, var_scheduler)
78 |     ddpm = ddpm.to(config.device)
79 | 
80 |     optimizer = torch.optim.Adam(ddpm.network.parameters(), lr=2e-4)
81 |     scheduler = torch.optim.lr_scheduler.LambdaLR(
82 |         optimizer, lr_lambda=lambda t: min((t + 1) / config.warmup_steps, 1.0)
83 |     )
84 | 
85 |     step = 0
86 |     losses = []
87 |     with tqdm(initial=step, total=config.train_num_steps) as pbar:
88 |         while step < config.train_num_steps:
89 |             if step % config.log_interval == 0:
90 |                 ddpm.eval()
91 |                 plt.plot(losses)
92 |                 plt.savefig(f"{save_dir}/loss.png")
93 |                 plt.close()
94 | 
95 |                 samples = ddpm.sample(4, return_traj=False)
96 |                 pil_images = tensor_to_pil_image(samples)
97 |                 for i, img in enumerate(pil_images):
98 |                     img.save(save_dir / f"step={step}-{i}.png")
99 | 
100 |                 ddpm.save(f"{save_dir}/last.ckpt")
101 |                 ddpm.train()
102 | 
103 |             img, label = next(train_it)
104 |             img, label = img.to(config.device), label.to(config.device)
105 |             loss = ddpm.get_loss(img, class_label=label)
106 |             pbar.set_description(f"Loss: {loss.item():.4f}")
107 | 
108 |             optimizer.zero_grad()
109 |             loss.backward()
110 |             optimizer.step()
111 |             scheduler.step()
112 |             losses.append(loss.item())
113 | 
114 |             step += 1
115 |             pbar.update(1)
116 | 
117 |     print(f"last.ckpt is saved at {save_dir}")
118 | 
119 | if __name__ == "__main__":
120 |     parser = argparse.ArgumentParser()
121 |     parser.add_argument("--gpu", type=int, default=0)
122 |     parser.add_argument("--batch_size", type=int, default=32)
123 |     parser.add_argument(
124 |         "--train_num_steps",
125 |         type=int,
126 |         default=100000,
127 |         help="the number of model training steps.",
128 |     )
129 |     parser.add_argument("--warmup_steps", type=int, default=200)
130 |     parser.add_argument("--log_interval", type=int, default=200)
131 |     parser.add_argument(
132 |         "--max_num_images_per_cat",
133 |         type=int,
134 |         default=-1,
135 |         help="max number of images per category for AFHQ dataset",
136 |     )
137 |     parser.add_argument(
138 |         "--num_diffusion_train_timesteps",
139 |         type=int,
140 |         default=1000,
141 |         help="diffusion Markov chain num steps",
142 |     )
143 |     parser.add_argument("--beta_1", type=float, default=1e-4)
144 |     parser.add_argument("--beta_T", type=float, default=0.02)
145 |     parser.add_argument("--seed", type=int, default=63)
146 |     parser.add_argument("--image_resolution", type=int, default=64)
147 |     parser.add_argument("--use_cfg", action="store_true", help="train with classifier-free guidance (read by UNet via args.use_cfg)")
148 |     parser.add_argument("--cfg_dropout", type=float, default=0.1, help="probability of dropping the class condition for CFG (assumed default)")
149 | 
150 |     args = parser.parse_args()
151 |     main(args)
152 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==1.1.3
2 | ipython==8.12.0
3 | jupyter==1.0.0
4 | matplotlib==3.6.0
5 | torch==2.0.1
6 | torch-ema==0.3
7 | torchvision==0.15.2
8 | tqdm==4.64.1
9 | jupyterlab==4.0.2
10 | jaxtyping==0.2.20
11 | pytorch_lightning
12 | dotmap
13 | 
--------------------------------------------------------------------------------
/sde_todo/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | from sklearn import datasets
5 | 
6 | def normalize(ds, scaling_factor=2.):
7 |     return (ds - ds.mean()) / \
8 |            ds.std() * scaling_factor
9 | 
10 | 
11 | def sample_checkerboard(n):
12 |     # https://github.com/ghliu/SB-FBSDE/blob/main/data.py
13 |     n_points = 3*n
14 |     n_classes = 2
15 |     freq = 5
16 |     x = np.random.uniform(-(freq//2)*np.pi, (freq//2)*np.pi, size=(n_points, n_classes))
17 |     mask = np.logical_or(np.logical_and(np.sin(x[:,0]) > 0.0, np.sin(x[:,1]) > 0.0), \
18 |                          np.logical_and(np.sin(x[:,0]) < 0.0, np.sin(x[:,1]) < 0.0))
19 |     y = np.eye(n_classes)[1*mask]
20 |     x0 = x[:,0]*y[:,0]
21 |     x1 = x[:,1]*y[:,0]
22 |     sample = np.concatenate([x0[...,None],x1[...,None]],axis=-1)
23 |     sqr = np.sum(np.square(sample),axis=-1)
24 |     idxs = np.where(sqr==0)
25 |     sample = np.delete(sample,idxs,axis=0)
26 | 
27 |     return sample
28 | 
29 | 
30 | def load_twodim(num_samples: int,
31 |                 dataset: str,
32 |                 dimension: int = 2):
33 | 
34 |     if dataset == 'gaussian_centered':
35 |         sample = np.random.normal(size=(num_samples, dimension))
36 |         sample = sample
37 | 
38 |     if dataset == 'gaussian_shift':
39 |         sample = np.random.normal(size=(num_samples, dimension))
40 |         sample = sample + 1.5
41 | 
42 |     if dataset == 'circle':
43 |         X, y = datasets.make_circles(
44 |             n_samples=num_samples, noise=0.0, random_state=None, factor=.5)
45 |         sample = X * 4
46 | 
47 |     if dataset == 'scurve':
48 |         X, y = datasets.make_s_curve(
49 |             n_samples=num_samples, noise=0.0, random_state=None)
50 |         sample = normalize(X[:, [0, 2]])
51 | 
52 |     if dataset == 'moon':
53 |         X, y = datasets.make_moons(
54 |             n_samples=num_samples, noise=0.0, random_state=None)
55 |         sample = normalize(X)
56 | 
57 |     if dataset == 'swiss_roll':
58 |         X, y = datasets.make_swiss_roll(
59 |             n_samples=num_samples, noise=0.0, random_state=None, hole=True)
60 |         sample = normalize(X[:, [0, 2]])
61 | 
62 |     if dataset == 'checkerboard':
63 |         sample = normalize(sample_checkerboard(num_samples))
64 | 
65 |     return torch.tensor(sample).float()
66 | 
67 | 
68 | class TwoDimDataClass(Dataset):
69 | 
70 |     def __init__(self,
71 |                  dataset_type: str,
72 |                  N: int,
73 |                  batch_size: int,
74 |                  dimension = 2):
75 | 
76 |         self.X = load_twodim(N, dataset_type, dimension=dimension)
77 |         self.name = dataset_type
78 |         self.batch_size = batch_size
79 |         self.dimension = dimension
80 | 
81 |     def __len__(self):
82 |         return self.X.shape[0]
83 | 
84 |     def __getitem__(self, idx):
85 |         return self.X[idx]
86 | 
87 |     def get_dataloader(self, shuffle=True):
88 |         return DataLoader(self,
89 |                           batch_size=self.batch_size,
90 |                           shuffle=shuffle,
91 |                           pin_memory=True,
92 |                           )
93 | 
--------------------------------------------------------------------------------
/sde_todo/eval.py:
--------------------------------------------------------------------------------
1 | from scipy.optimize import linear_sum_assignment
2 | import numpy as np
3 | import torch
4 | 
5 | def CD(a, b):
6 |     x, y = a, b
7 |     bs, num_points, points_dim = x.size()
8 |     xx = torch.bmm(x, x.transpose(2, 1))
9 |     yy = torch.bmm(y, y.transpose(2, 1))
10 |     zz = torch.bmm(x, y.transpose(2, 1))
11 |     diag_ind = torch.arange(0, num_points).to(a).long()
12 |     rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
13 |     ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
14 |     P = (rx.transpose(2, 1) + ry - 2 * zz)
15 |     return P.min(1)[0], P.min(2)[0]
16 | 
17 | def EMD(x, y):
18 |     bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2)
19 |     assert npts == mpts, "EMD only works if two point clouds are equal size"
20 |     dim = x.shape[-1]
21 |     x = x.reshape(bs, npts, 1, dim)
22 |     y = y.reshape(bs, 1, mpts, dim)
23 |     dist = (x - y).norm(dim=-1, keepdim=False)  # (bs, npts, mpts)
24 | 
25 |     emd_lst = []
26 |     dist_np = dist.cpu().detach().numpy()
27 |     for i in range(bs):
28 |         d_i = dist_np[i]
29 |         r_idx, c_idx = linear_sum_assignment(d_i)
30 |         emd_i = d_i[r_idx, c_idx].mean()
31 |         emd_lst.append(emd_i)
32 |     emd = np.stack(emd_lst).reshape(-1)
33 |     emd_torch = torch.from_numpy(emd).to(x)
34 |     return emd_torch
35 | 
36 | 
--------------------------------------------------------------------------------
/sde_todo/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from jaxtyping import Array, Float
4 | from typing import Callable
5 | 
6 | def get_div(y: Array, x: Array):
7 |     """
8 |     (Optional)
9 |     Return the divergence of y wrt x. Let y = f(x). get_div returns div_x(y).
10 | 
11 |     Args:
12 |         x   Input of a differentiable function
13 |         y   Output of a differentiable function
14 | 
15 |     Returns:
16 |         div_x(y)
17 |     """
18 |     pass
19 | 
20 | class DSMLoss():
21 | 
22 |     def __init__(self, alpha: float, diff_weight: bool):
23 |         """
24 |         (TODO) Initialize the DSM Loss.
25 | 
26 |         Args:
27 |             alpha: regularization weight
28 |             diff_weight: scale loss by square of diffusion
29 | 
30 |         Returns:
31 |             None
32 |         """
33 | 
34 |     def __call__(self,
35 |                  t: Array,
36 |                  x: Array,
37 |                  model: Callable[[Array], Array],
38 |                  s: Array,
39 |                  diff_sq: Float):
40 |         """
41 |         Args:
42 |             t: uniformly sampled time period
43 |             x: samples after t diffusion
44 |             model: score prediction function s(x,t)
45 |             s: ground truth score
46 | 
47 |         Returns:
48 |             loss: average loss value
49 |         """
50 |         loss = None
51 |         return loss
52 | 
53 | 
54 | class ISMLoss():
55 |     """
56 |     (Optional) Implicit Score Matching Loss
57 |     """
58 | 
59 |     def __init__(self):
60 |         pass
61 | 
62 |     def __call__(self, t, x, model):
63 |         """
64 |         Args:
65 |             t: uniformly sampled time period
66 |             x: samples after t diffusion
67 |             model: score prediction function s(x,t)
68 | 
69 |         Returns:
70 |             loss: average loss value
71 |         """
72 |         return loss
73 | 
74 | 
75 | class SBJLoss():
76 |     """
77 |     (Optional) Joint Schrodinger Bridge Loss
78 | 
79 |     hint: You will need to implement the divergence.
80 |     """
81 | 
82 |     def __init__(self):
83 |         pass
84 | 
85 |     def __call__(self, t, xf, zf, zb_fn):
86 |         """
87 |         Compute the joint Schrodinger Bridge loss.
88 | 
89 |         Args:
90 |             t: uniformly sampled time period
91 |             xf: samples after t forward diffusion
92 |             zf: ground truth forward value
93 |             zb_fn: backward Z function
94 | 
95 |         Returns:
96 |             loss: average loss value
97 |         """
98 |         return loss
99 | 
100 | 
101 | class SBALoss():
102 |     """
103 |     (Optional) Alternating Schrodinger Bridge Loss
104 | 
105 |     hint: You will need to implement the divergence.
106 |     """
107 | 
108 |     def __init__(self):
109 |         pass
110 | 
111 |     def __call__(self, t, xf, zf, zb_fn):
112 |         """
113 |         Compute the alternating Schrodinger Bridge loss.
114 | 
115 |         Args:
116 |             t: uniformly sampled time period
117 |             xf: samples after t forward diffusion
118 |             zf: ground truth forward value
119 |             zb_fn: backward Z function
120 | 
121 |         Returns:
122 |             loss: average loss value
123 |         """
124 |         return loss
125 | 
--------------------------------------------------------------------------------
/sde_todo/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from jaxtyping import Array, Int, Float
6 | 
7 | class PositionalEncoding(nn.Module):
8 | 
9 |     def __init__(self, t_channel: Int):
10 |         """
11 |         (Optional) Initialize positional encoding network
12 | 
13 |         Args:
14 |             t_channel: number of modulation channel
15 |         """
16 |         super().__init__()
17 | 
18 |     def forward(self, t: Float):
19 |         """
20 |         Return the positional encoding of t.
21 | 
22 |         Args:
23 |             t: input time
24 | 
25 |         Returns:
26 |             emb: time embedding
27 |         """
28 |         emb = None
29 |         return emb
30 | 
31 | 
32 | class MLP(nn.Module):
33 | 
34 |     def __init__(self,
35 |                  in_dim: Int,
36 |                  out_dim: Int,
37 |                  hid_shapes: Int[Array]):
38 |         '''
39 |         (TODO) Build simple MLP
40 | 
41 |         Args:
42 |             in_dim: input dimension
43 |             out_dim: output dimension
44 |             hid_shapes: array of hidden layers' dimension
45 |         '''
46 |         super().__init__()
47 |         self.model = None
48 | 
49 |     def forward(self, x: Array):
50 |         return self.model(x)
51 | 
52 | 
53 | 
54 | class SimpleNet(nn.Module):
55 | 
56 |     def __init__(self,
57 |                  in_dim: Int,
58 |                  enc_shapes: Int[Array],
59 |                  dec_shapes: Int[Array],
60 |                  z_dim: Int):
61 |         super().__init__()
62 |         '''
63 |         (TODO) Build Score Estimation network.
64 |         You are free to modify this function signature.
65 |         You can design whatever architecture.
66 | 
67 |         hint: it's recommended to first encode the time and x to get
68 |         time and x embeddings then concatenate them before feeding it
69 |         to the decoder.
70 | 
71 |         Args:
72 |             in_dim: dimension of input
73 |             enc_shapes: array of dimensions of encoder
74 |             dec_shapes: array of dimensions of decoder
75 |             z_dim: output dimension of encoder
76 |         '''
77 | 
78 |     def forward(self, t: Array, x: Array):
79 |         '''
80 |         (TODO) Implement the forward pass. This should output
81 |         the score s of the noisy input x.
82 | 
83 |         hint: you are free to modify this function signature.
84 | 
85 |         Args:
86 |             t: the time that the forward diffusion has been running
87 |             x: the noisy data after t period diffusion
88 |         '''
89 |         s = None
90 |         return s
91 | 
--------------------------------------------------------------------------------
/sde_todo/sampling.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from tqdm import tqdm
4 | from jaxtyping import Float, Int
5 | 
6 | class Sampler():
7 | 
8 |     def __init__(self, eps: Float):
9 |         self.eps = eps
10 | 
11 |     def get_sampling_fn(self, sde, dataset):
12 | 
13 |         def sampling_fn(N_samples: Int):
14 |             """
15 |             return the final denoised sample, the number of steps,
16 |             the timesteps, and the full trajectory.
17 | 
18 |             Args:
19 |                 N_samples: number of samples
20 | 
21 |             Returns:
22 |                 out: the final denoised samples (out == x_traj[-1])
23 |                 ntot (int): the total number of timesteps
24 |                 timesteps Int[Array]: the array of timesteps used
25 |                 x_traj: the entire diffusion trajectory
26 |             """
27 |             x = dataset[range(N_samples)]  # initial sample
28 |             timesteps = torch.linspace(0, sde.T-self.eps, sde.N)
29 | 
30 |             x_traj = torch.zeros((sde.N, *x.shape))
31 |             with torch.no_grad():
32 |                 for i, t in enumerate(tqdm(timesteps, desc='sampling')):
33 |                     # TODO
34 |                     pass
35 | 
36 |             out = x
37 |             ntot = sde.N
38 |             return out, ntot, timesteps, x_traj
39 | 
40 |         return sampling_fn
41 | 
--------------------------------------------------------------------------------
/sde_todo/sde.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | import numpy as np
4 | from jaxtyping import Array, Float
5 | 
6 | class SDE(abc.ABC):
7 |     def __init__(self, N: int, T: int):
8 |         super().__init__()
9 |         self.N = N  # number of discretization steps
10 |         self.T = T  # terminal time
11 |         self.dt = T / N
12 |         self.is_reverse = False
13 |         self.is_bridge = False
14 | 
15 |     @abc.abstractmethod
16 |     def sde_coeff(self, t, x):
17 |         return NotImplemented
18 | 
19 |     @abc.abstractmethod
20 |     def marginal_prob(self, t, x):
21 |         return NotImplemented
22 | 
23 |     @abc.abstractmethod
24 |     def predict_fn(self, x):
25 |         return NotImplemented
26 | 
27 |     @abc.abstractmethod
28 |     def correct_fn(self, t, x):
29 |         return NotImplemented
30 | 
31 |     def dw(self, x, dt=None):
32 |         """
33 |         (TODO) Return the differential of Brownian motion
34 | 
35 |         Args:
36 |             x: input data
37 | 
38 |         Returns:
39 |             dw (same shape as x)
40 |         """
41 |         dt = self.dt if dt is None else dt
42 |         return None
43 | 
44 |     def prior_sampling(self, x: Array):
45 |         """
46 |         Sample from the prior distribution. Defaults to a unit Gaussian.
47 | 
48 |         Args:
49 |             x: input data
50 | 
51 |         Returns:
52 |             z: random variable with same shape as x
53 |         """
54 |         return torch.randn_like(x)
55 | 
56 |     def predict_fn(self,
57 |                    t: Array,
58 |                    x: Array,
59 |                    dt: Float = None):
60 |         """
61 |         (TODO) Perform single step diffusion.
62 | 
63 |         Args:
64 |             t: current diffusion time
65 |             x: input with noise level at time t
66 |             dt: the discrete time step. Defaults to T/N
67 | 
68 |         Returns:
69 |             x: input at time t+dt
70 |         """
71 |         dt = self.dt if dt is None else dt
72 |         pred = None
73 |         return pred
74 | 
75 |     def correct_fn(self, t: Array, x: Array):
76 |         return None
77 | 
78 |     def reverse(self, model):
79 |         N = self.N
80 |         T = self.T
81 |         forward_sde_coeff = self.sde_coeff
82 | 
83 |         class RSDE(self.__class__):
84 |             def __init__(self, score_fn):
85 |                 super().__init__(N, T)
86 |                 self.score_fn = score_fn
87 |                 self.is_reverse = True
88 |                 self.forward_sde_coeff = forward_sde_coeff
89 | 
90 |             def sde_coeff(self, t: Array, x: Array):
91 |                 """
92 |                 (TODO) Return the reverse drift and diffusion terms.
93 | 
94 |                 Args:
95 |                     t: current diffusion time
96 |                     x: current input at time t
97 | 
98 |                 Returns:
99 |                     reverse_f: reverse drift term
100 |                     g: reverse diffusion term
101 |                 """
102 |                 reverse_f = None
103 |                 g = None
104 |                 return reverse_f, g
105 | 
106 |             def ode_coeff(self, t: Array, x: Array):
107 |                 """
108 |                 (Optional) Return the reverse drift and diffusion terms in
109 |                 ODE sampling.
110 | 
111 |                 Args:
112 |                     t: current diffusion time
113 |                     x: current input at time t
114 | 
115 |                 Returns:
116 |                     reverse_f: reverse drift term
117 |                     g: reverse diffusion term
118 |                 """
119 |                 reverse_f = None
120 |                 g = None
121 |                 return reverse_f, g
122 | 
123 |             def predict_fn(self,
124 |                            t: Array,
125 |                            x,
126 |                            dt=None,
127 |                            ode=False):
128 |                 """
129 |                 (TODO) Perform single step reverse diffusion
130 | 
131 |                 """
132 |                 return x
133 | 
134 |         return RSDE(model)
135 | 
136 | class OU(SDE):
137 |     def __init__(self, N=1000, T=1):
138 |         super().__init__(N, T)
139 | 
140 |     def sde_coeff(self, t, x):
141 |         f, g = None, None
142 |         return f, g
143 | 
144 |     def marginal_prob(self, t, x):
145 |         mean, std = None, None
146 |         return mean, std
147 | 
148 | class VESDE(SDE):
149 |     def __init__(self, N=100, T=1, sigma_min=0.01, sigma_max=50):
150 |         super().__init__(N, T)
151 | 
152 |     def sde_coeff(self, t, x):
153 |         f, g = None, None
154 |         return f, g
155 | 
156 |     def marginal_prob(self, t, x):
157 |         mean, std = None, None
158 |         return mean, std
159 | 
160 | 
161 | class VPSDE(SDE):
162 |     def __init__(self, N=1000, T=1, beta_min=0.1, beta_max=20):
163 |         super().__init__(N, T)
164 | 
165 |     def sde_coeff(self, t, x):
166 |         f, g = None, None
167 |         return f, g
168 | 
169 |     def marginal_prob(self, t, x):
170 |         mean, std = None, None
171 |         return mean, std
172 | 
173 | 
174 | class SB(abc.ABC):
175 |     def __init__(self, N=1000, T=1, zf_model=None, zb_model=None):
176 |         super().__init__()
177 |         self.N = N  # number of time steps
178 |         self.T = T  # end time
179 |         self.dt = T / N
180 | 
181 |         self.is_reverse = False
182 |         self.is_bridge = True
183 | 
184 |         self.zf_model = zf_model
185 |         self.zb_model = zb_model
186 | 
187 |     def dw(self, x, dt=None):
188 |         dt = self.dt if dt is None else dt
189 |         return torch.randn_like(x) * (dt**0.5)
190 | 
191 |     @abc.abstractmethod
192 |     def sde_coeff(self, t, x):
193 |         return NotImplemented
194 | 
195 |     def sb_coeff(self, t, x):
196 |         """
197 |         (Optional) Return the SB reverse drift and diffusion terms.
198 | 
199 |         Args:
200 |         """
201 |         sb_f = None
202 |         g = None
203 |         return sb_f, g
204 | 
205 |     def predict_fn(self,
206 |                    t: Array,
207 |                    x: Array,
208 |                    dt: Float = None):
209 |         """
210 |         Args:
211 |             t:
212 |             x:
213 |             dt:
214 |         """
215 |         return x
216 | 
217 |     def correct_fn(self, t, x, dt=None):
218 |         return x
219 | 
220 |     def reverse(self, model):
221 |         """
222 |         (Optional) Initialize the reverse process
223 |         """
224 |         N, T, zf_model, zb_model = self.N, self.T, self.zf_model, self.zb_model
225 |         class RSB(self.__class__):
226 |             def __init__(self, model):
227 |                 super().__init__(N, T, zf_model, zb_model)
228 |                 """
229 |                 (Optional) Initialize the reverse process
230 |                 """
231 | 
232 |             def sb_coeff(self, t, x):
233 |                 """
234 |                 (Optional) Return the SB reverse drift and diffusion terms.
235 |                 """
236 |                 sb_f = None
237 |                 g = None
238 |                 return sb_f, g
239 | 
240 |         return RSB(model)
241 | 
242 | class OUSB(SB):
243 |     def __init__(self, N=1000, T=1, zf_model=None, zb_model=None):
244 |         super().__init__(N, T, zf_model, zb_model)
245 | 
246 |     def sde_coeff(self, t, x):
247 |         f = -0.5 * x
248 |         g = torch.ones(x.shape)
249 |         return f, g
250 | 
--------------------------------------------------------------------------------
/sde_todo/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from itertools import repeat
4 | import matplotlib.pyplot as plt
5 | from loss import ISMLoss, DSMLoss
6 | 
7 | def freeze(model):
8 |     """
9 |     (Optional) This is for Alternating Schrodinger Bridge.
10 |     """
11 |     for p in model.parameters():
12 |         p.requires_grad = False
13 |     model.eval()
14 |     return model
15 | 
16 | 
17 | def unfreeze(model):
18 |     """
19 |     (Optional) This is for Alternating Schrodinger Bridge.
20 |     """
21 |     for p in model.parameters():
22 |         p.requires_grad = True
23 |     model.train()
24 |     return model
25 | 
26 | 
27 | def get_sde_step_fn(model, ema, opt, loss_fn, sde):
28 |     def step_fn(batch):
29 |         # uniformly sample time step
30 |         t = sde.T*torch.rand(batch.shape[0])
31 | 
32 |         # TODO forward diffusion
33 |         xt = None
34 | 
35 |         # get loss
36 |         if isinstance(loss_fn, DSMLoss):
37 |             logp_grad, diff_sq = None, None  # TODO: ground-truth score and squared diffusion coefficient
38 |             loss = loss_fn(t, xt.float(), model, logp_grad, diff_sq)
39 |         elif isinstance(loss_fn, ISMLoss):
40 |             loss = loss_fn(t, xt.float(), model)
41 |         else:
42 |             print(loss_fn)
43 |             raise Exception("undefined loss")
44 | 
45 |         # optimize model
46 |         opt.zero_grad()
47 |         loss.backward()
48 |         opt.step()
49 | 
50 |         if ema is not None:
51 |             ema.update()
52 | 
53 |         return loss.item()
54 | 
55 |     return step_fn
56 | 
57 | 
58 | def get_sb_step_fn(model_f, model_b, ema_f, ema_b,
59 |                    opt_f, opt_b, loss_fn, sb, joint=True):
60 |     def step_fn_alter(batch, forward):
61 |         """
62 |         (Optional) Implement the optimization step for alternating
63 |         likelihood training of Schrodinger Bridge
64 |         """
65 |         pass
66 | 
67 |     def step_fn_joint(batch):
68 |         """
69 |         (Optional) Implement the optimization step for joint likelihood
70 |         training of Schrodinger Bridge
71 |         """
72 |         pass
73 | 
74 |     if joint:
75 |         return step_fn_joint
76 |     else:
77 |         return step_fn_alter
78 | 
79 | 
80 | def repeater(data_loader):
81 |     for loader in repeat(data_loader):
82 |         for data in loader:
83 |             yield data
84 | 
85 | 
86 | def train_diffusion(dataloader, step_fn, N_steps, plot=False):
87 |     pbar = tqdm(range(N_steps), bar_format="{desc}{bar}{r_bar}", mininterval=1)
88 |     loader = iter(repeater(dataloader))
89 | 
90 |     log_freq = 200
91 |     loss_history = torch.zeros(N_steps//log_freq)
92 |     for i, step in enumerate(pbar):
93 |         batch = next(loader)
94 |         loss = step_fn(batch)
95 | 
96 |         if step % log_freq == 0:
97 |             loss_history[i//log_freq] = loss
98 |             pbar.set_description("Loss: {:.3f}".format(loss))
99 | 
100 |     if plot:
101 |         plt.plot(range(len(loss_history)), loss_history)
102 |         plt.show()
103 | 
--------------------------------------------------------------------------------
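
The skeleton above keeps the Task 1 pieces (dataset, network, loss, SDE, sampler, training loop) in separate files; the sketch below is one minimal way they can be wired together once the `(TODO)` methods are implemented. It assumes it is run from inside `sde_todo/`. Every class and function name is taken verbatim from the files above, while the hyperparameters (hidden sizes, `alpha`, `eps`, batch size, step counts) are illustrative assumptions rather than prescribed values.

```python
# Hypothetical end-to-end wiring of the sde_todo skeleton (Swiss roll + OU + DSM).
# Only runs after the (TODO) methods in sde.py, network.py, loss.py, and sampling.py are filled in.
import torch

from dataset import TwoDimDataClass
from network import SimpleNet
from loss import DSMLoss
from sde import OU
from sampling import Sampler
from train import get_sde_step_fn, train_diffusion

# Swiss-roll data and its dataloader (batch size is an arbitrary choice).
dataset = TwoDimDataClass(dataset_type="swiss_roll", N=10000, batch_size=512)
dataloader = dataset.get_dataloader()

# Score network s_theta(t, x) for 2D points; hidden sizes are assumptions.
model = SimpleNet(in_dim=2, enc_shapes=[128, 128], dec_shapes=[128, 128], z_dim=128)

# Forward OU process and the denoising score matching objective.
sde = OU(N=1000, T=1)
loss_fn = DSMLoss(alpha=1.0, diff_weight=True)

# One optimization step per batch, then the generic training loop from train.py.
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
step_fn = get_sde_step_fn(model, ema=None, opt=opt, loss_fn=loss_fn, sde=sde)
train_diffusion(dataloader, step_fn, N_steps=5000, plot=True)

# Reverse-time sampling with the learned score.
rsde = sde.reverse(model)
sampling_fn = Sampler(eps=1e-3).get_sampling_fn(rsde, dataset)
samples, ntot, timesteps, traj = sampling_fn(2000)
print(samples.shape)
```

The small `eps` passed to `Sampler` mirrors the `sde.T - self.eps` end point used in `get_sampling_fn`: it stops the reverse integration slightly before t = 0, where the score estimate is typically least reliable.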