├── .gitignore
├── README.md
├── assets
│   ├── diffusion.py
│   ├── images
│   │   ├── task1_distribution.png
│   │   ├── task1_ou.png
│   │   ├── task1_sampling.png
│   │   ├── task1_vesde.png
│   │   ├── task1_vpsde.png
│   │   ├── task2_1_ddpm_sampling_algorithm.png
│   │   ├── task2_1_teaser.png
│   │   ├── task2_2_teaser.png
│   │   ├── task2_3_repaint_algorithm8.png
│   │   ├── task2_algorithm.png
│   │   ├── task2_ddim.png
│   │   ├── task3_algorithm.png
│   │   ├── task_2_3_teaser.png
│   │   ├── teaser.gif
│   │   └── teaser.png
│   ├── sb_likelihood_training.pdf
│   └── summary_of_DDPM_and_DDIM.pdf
├── image_diffusion_todo
│   ├── __init__.py
│   ├── dataset.py
│   ├── ddpm.py
│   ├── fid
│   │   ├── afhq_inception_v3.ckpt
│   │   ├── inception.py
│   │   ├── measure_fid.py
│   │   └── train_classifier.py
│   ├── module.py
│   ├── network.py
│   ├── sampling.py
│   ├── scheduler.py
│   └── train.py
├── requirements.txt
└── sde_todo
    ├── HelloScore.ipynb
    ├── dataset.py
    ├── eval.py
    ├── loss.py
    ├── network.py
    ├── sampling.py
    ├── sde.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *.ipynb_checkpoints
4 | data/
5 | results/
6 | samples/
7 | .DS_Store
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Introduction to Diffusion Models
4 |
5 |
6 |
7 | **Juil Koo** and **Nguyen Minh Hieu**
8 |
9 |
10 | {63days, hieuristics} [at] kaist.ac.kr
11 |
12 |
13 |
14 |
15 |
16 |

17 |
18 |
19 |
20 | ## Table of Contents
21 |
22 | - [Task 0](#task-0-introduction)
23 | - [Task 1](#task-1-very-simple-sgm-pipeline-with-delicious-swiss-roll)
24 | - [Task 1.1](#11-forward-and-reverse-process) [(a)](#a-ou-process), [(b)](#b-vpsde--vesde)
25 | - [Task 1.2](#12-training)
26 | - [Task 1.3](#13-sampling)
27 | - [Task 1.4](#14-evaluation)
28 | - [Task 2](#task-2-image-diffusion)
29 | - [Task 2.1](#21-ddim)
30 | - [Task 2.2](#22-classifier-free-guidance)
31 | - [Task 2.3](#23-image-inpainting)
32 |
33 |
34 |
35 |
36 | ## Task Checklist
37 |
38 |
39 |
40 | **Task 1**
41 | - [ ] 1.1 Define Forward SDE
42 | - [ ] 1.1 Define Backward SDE
43 | - [ ] 1.1 Define VPSDE
44 | - [ ] 1.1 Define VESDE
45 | - [ ] 1.2 Implement MLP Network
46 | - [ ] 1.2 Implement DSM Loss
47 | - [ ] 1.2 Implement Training Loop
48 | - [ ] 1.3 Implement Discretization
49 | - [ ] 1.3 Implement Sampling Loop
50 | - [ ] 1.4 Evaluate Implementation
51 |
52 | **Task 2**
53 | - [ ] 2.1 Implement DDIM Variance Scheduling
54 | - [ ] 2.2 Implement CFG
55 | - [ ] 2.3 Implement Image Inpainting
56 |
57 | **Optional Tasks**
58 | - [X] Add any additional tasks that you did here.
59 | - [ ] Implement EMA Training
60 | - [ ] Implement ISM Loss
61 | - [ ] Implement ODE Sampling
62 | - [ ] Implement Schrodinger Bridge
63 | - [ ] Implement MCG Inpainting
64 |
65 |
66 |
67 |
68 | ## Setup
69 |
70 | Install the required packages listed in `requirements.txt`:
71 | ```
72 | pip install -r requirements.txt
73 | ```
74 |
75 | ## Code Structure
76 | ```
77 | .
78 | ├── image_diffusion_todo (Task 2)
79 | │   ├── dataset.py        <--- Ready-to-use AFHQ dataset code
80 | │   ├── train.py          <--- DDPM training code
81 | │   ├── sampling.py       <--- Image sampling code
82 | │   ├── ddpm.py           <--- DDPM high-level wrapper code
83 | │   ├── module.py         <--- Basic modules of a noise prediction network
84 | │   ├── network.py        <--- Noise prediction network
85 | │   ├── scheduler.py      <--- (TODO) Define variance schedulers
86 | │   └── fid
87 | │       ├── measure_fid.py           <--- Script measuring FID score
88 | │       └── afhq_inception_v3.ckpt   <--- Pre-trained classifier for FID
89 | └── sde_todo (Task 1)
90 |     ├── HelloScore.ipynb  <--- Main code
91 |     ├── dataset.py        <--- Define dataset (Swiss-roll, moon, Gaussians, etc.)
92 |     ├── eval.py           <--- Evaluation code
93 |     ├── loss.py           <--- (TODO) Define Training Objective
94 |     ├── network.py        <--- (TODO) Define Network Architecture
95 |     ├── sampling.py       <--- (TODO) Define Discretization and Sampling
96 |     ├── sde.py            <--- (TODO) Define SDE Processes
97 |     └── train.py          <--- (TODO) Define Training Loop
98 | ```
99 |
100 | ## Tutorial Tips
101 |
102 | Implementing diffusion models is typically very simple once you understand the theory.
103 | So, to learn the most from this tutorial, it's highly recommended to check out the details in the
104 | related papers and understand the equations **BEFORE** you start the tutorial. You can check out
105 | the resources in this order:
106 | 1. [[blog](https://min-hieu.github.io/blogs/blogs/brownian/)] Charlie's "Brownian Motion and SDE"
107 | 2. [[paper](https://arxiv.org/abs/2011.13456)] Score-Based Generative Modeling through Stochastic Differential Equations
108 | 3. [[blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)] Lilian Weng's "What is Diffusion Model?"
109 | 4. [[paper](https://arxiv.org/abs/2006.11239)] Denoising Diffusion Probabilistic Models
110 | 5. [[slide](./assets/summary_of_DDPM_and_DDIM.pdf)] Summary of DDPM and DDIM
111 |
112 | ## Task 0: Introduction
113 | The first part of this tutorial will introduce diffusion models through the lens of stochastic differential equations (SDE).
114 | Prior to the [Yang Song et al. (2021)](https://arxiv.org/abs/2011.13456) paper, diffusion models were often understood
115 | in terms of Markov processes with tractable transition kernels. Understanding SDEs can also help you develop more
116 | efficient variance schedules or give your diffusion model more flexibility.
117 |
118 | We know that a stochastic differential equation has the following form:
119 | $$d\mathbf{X}_t = f(t,\mathbf{X}_t)dt + G(t)d\mathbf{B}_t$$
120 | where $f$ and $G$ are the drift and diffusion coefficients, respectively, and $\mathbf{B}_t$ is
121 | standard Brownian motion. A popular SDE is the Ornstein-Uhlenbeck (OU) process,
122 | which is defined as
123 | $$d\mathbf{X}_t = -\mu \mathbf{X}_tdt + \sigma d\mathbf{B}_t$$
124 | where $\mu, \sigma$ are constants. In this tutorial, we will set $\mu = \frac{1}{2}, \sigma = 1$.
125 | Score-based generative modelling (SGM) aims to sample from an unknown distribution of a given dataset.
126 | We have the following two observations:
127 | - As $t \rightarrow \infty$, the OU process converges to a unit Gaussian.
128 | - We can derive the equation for the reverse OU process.
129 |
130 | From these facts, we can directly sample from the unknown distribution by
131 | 1. Sampling from a unit Gaussian.
132 | 2. Running the reverse process on the samples from step 1.
133 |
134 | [Yang Song et al. (2021)](https://arxiv.org/abs/2011.13456) derived the likelihood training scheme
135 | for learning the reverse process. In summary, the reverse process of any SDE of the form above is
136 | given by
137 | $$d\mathbf{X}_t = [f(t,\mathbf{X}_t) - G(t)^2\nabla_x\log p_t(\mathbf{X}_t)]dt + G(t)d\bar{\mathbf{B}}_t$$
138 | where $\bar{\mathbf{B}}_t$ is reverse-time Brownian motion. The only unknown term is the score function
139 | $\nabla_x\log p_t(\mathbf{X}_t)$, which we will approximate with a neural network. One main difference
140 | between SGMs and other generative models is that SGMs generate samples iteratively during the sampling process.
141 |
142 | **TODO:**
143 | ```
144 | - Derive the expression for the mean and std of the OU process at time t conditioned on X0,
145 |   i.e. find E[Xt|X0] and Var[Xt|X0]. You will need this for task 1.1(a).
146 | ```
147 | *hint*: We know that the solution to the OU process is given as
148 |
149 | $\mathbf{X}_T = \mathbf{X}_0 e^{-\mu T} + \sigma \int_0^T e^{-\mu(T-t)} d\mathbf{B}_t$
150 |
151 | and you can use the fact that $d\mathbf{B}_t^2 = dt$, and $\mathbb{E}[\int_0^T f(t) d\mathbf{B}_t] = 0$ where $f(t)$ is any
152 | deterministic function.
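
As a sanity check of the first observation above, here is a minimal sketch (separate from the provided `sde_todo` code) that simulates the OU process with a simple Euler-Maruyama discretization and verifies that the samples approach a unit Gaussian; the horizon `T`, step count `N`, and initial distribution are arbitrary choices.

```python
import numpy as np

def ou_step(x, dt, mu=0.5, sigma=1.0):
    # One Euler-Maruyama step of dX = -mu * X dt + sigma dB.
    return x - mu * x * dt + sigma * np.sqrt(dt) * np.random.randn(*x.shape)

T, N = 10.0, 1000                                   # terminal time and number of steps (arbitrary)
dt = T / N
x = np.random.uniform(-4.0, 4.0, size=(100_000,))   # any non-Gaussian initial data

for _ in range(N):
    x = ou_step(x, dt)

# With mu = 1/2 and sigma = 1, the terminal distribution should be close to N(0, 1).
print(f"mean ~ {x.mean():.3f}, std ~ {x.std():.3f}")
```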
153 |
154 | ## Task 1: very simple SGM pipeline with delicious swiss-roll
155 | A typical diffusion pipeline is divided into three components:
156 | 1. [Forward Process and Reverse Process](#11-forward-and-reverse-process)
157 | 2. [Training](#12-training)
158 | 3. [Sampling](#13-sampling)
159 |
160 | In this task, we will look into each component one by one and implement them sequentially.
161 | #### 1.1. Forward and Reverse Process
162 |
163 |
164 |
165 |
166 | Our first goal is to set up the forward and reverse processes. In the forward process, the final distribution should be
167 | the prior distribution, which is the standard normal distribution.
168 | #### (a) OU Process
169 | Following the formulation of the OU process introduced in the previous section, complete the `TODO`s in
170 | `sde.py` and check whether the final distribution approaches a unit Gaussian as $t\rightarrow \infty$.
171 |
172 |
173 |
174 |
175 |
176 |
177 | **TODO:**
178 | ```
179 | - implement the forward process using the given marginal probability p_t0(Xt|X0) in sde.py
180 | - implement the reverse process for a general SDE in sde.py (see the sketch below)
181 | - (optional) Play around with the terminal time (T) and the number of time steps (N) and observe their effect
182 | ```
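
For the reverse-process item, the reverse SDE given in Task 0 translates almost directly into a single discretized update. Below is a minimal sketch of one reverse Euler-Maruyama step; `f`, `G`, and `score_fn` are placeholders for whatever drift, diffusion, and score function you define in `sde.py`.

```python
import math
import torch

def reverse_step(x, t, dt, f, G, score_fn):
    # One Euler-Maruyama step of the reverse SDE
    #   dX = [f(t, X) - G(t)^2 * score(t, X)] dt + G(t) dB_bar,
    # integrated backwards in time with a positive step size dt.
    drift = f(t, x) - G(t) ** 2 * score_fn(t, x)
    return x - drift * dt + G(t) * math.sqrt(dt) * torch.randn_like(x)
```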
183 |
184 | #### (b) VPSDE & VESDE
185 | It is mentioned by [Yang Song et al. (2021)](https://arxiv.org/abs/2011.13456) that DDPM and SMLD are discretizations of SDEs (VPSDE and VESDE, respectively).
186 | Implement them in `sde.py` and check their mean and std.
187 |
188 | *hint*: Although you can simulate the diffusion process through discretization, sampling with the explicit equation of the marginal probability $p_{t0}(\mathbf{X}_t \mid \mathbf{X}_0)$ is much faster.
189 |
190 | You should also obtain the following graphs for VPSDE and VESDE, respectively:
191 |
192 | ![VPSDE mean and std](./assets/images/task1_vpsde.png)
193 | ![VESDE mean and std](./assets/images/task1_vesde.png)
196 |
197 | **TODO:**
198 | ```
199 | - implement VPSDE in sde.py
200 | - implement VESDE in sde.py
201 | - plot the mean and variance of VPSDE and VESDE vs. time (see the sketch below).
202 |   What can you say about the differences between OU, VPSDE, and VESDE?
203 | ```
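
For the plotting item in the TODO list, a minimal sketch is given below. It assumes the closed-form perturbation kernels from the appendix of the SDE paper and uses arbitrary schedule constants (`beta_min`, `beta_max`, `sigma_min`, `sigma_max`); adapt them to whatever you use in `sde.py`.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(1e-3, 1.0, 200)
x0 = 1.0  # arbitrary starting point

# VPSDE with a linear schedule beta(t) = beta_min + t * (beta_max - beta_min).
beta_min, beta_max = 0.1, 20.0
log_mean_coeff = -0.5 * (beta_min * t + 0.5 * t**2 * (beta_max - beta_min))
vp_mean = x0 * np.exp(log_mean_coeff)
vp_std = np.sqrt(1.0 - np.exp(2.0 * log_mean_coeff))

# VESDE with sigma(t) = sigma_min * (sigma_max / sigma_min) ** t.
sigma_min, sigma_max = 0.01, 50.0
sigma_t = sigma_min * (sigma_max / sigma_min) ** t
ve_mean = x0 * np.ones_like(t)
ve_std = np.sqrt(sigma_t**2 - sigma_min**2)

for name, mean, std in [("VPSDE", vp_mean, vp_std), ("VESDE", ve_mean, ve_std)]:
    plt.plot(t, mean, label=f"{name} mean")
    plt.plot(t, std, label=f"{name} std")
plt.xlabel("t"); plt.legend(); plt.savefig("vp_ve_stats.png")
```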
204 |
205 | #### 1.2. Training
206 | The typical training objective of a diffusion model is the **D**enoising **S**core **M**atching (DSM) loss:
207 |
208 | $$f_{\theta^*} = \underset{f_\theta}{\textrm{argmin }} \mathbb{E} [||f_\theta(\mathbf{X}_s) - \nabla_{\mathbf{X}_s} \log p_{s0}(\mathbf{X}_s \mid \mathbf{X}_0)||^2] $$
209 |
210 | where $f_\theta$ is the score prediction network with parameters $\theta$.
211 | Another popular training objective is the **I**mplicit **S**core **M**atching (ISM) loss, which can be derived from DSM.
212 | One main difference between ISM and DSM is that ISM doesn't require the marginal density but instead the divergence of the score network.
213 | Although DSM is easier to implement, ISM is used for [exotic domains](https://arxiv.org/abs/2202.02763) or
214 | when the marginal density [doesn't have a closed form](https://openreview.net/pdf?id=nioAdKCEdXB).
215 |
216 | **(Important)** You need to derive a **different DSM objective** for each SDE since
217 | their marginal densities are different. You first need to obtain the closed form of $p_{0t}$; then you can find the equation for $\nabla \log p_{0t}$.
218 | For $p_{0t}$, you can refer to the appendix of the SDE paper.
219 |
220 |
221 | There are also other training objectives with different trade-offs (SSM, EDM, etc.). We highly recommend checking out
222 | [A Variational Perspective on Diffusion-based Generative Models and Score Matching](https://arxiv.org/abs/2106.02808)
223 | and [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) for a more in-depth analysis of the recent training objectives.
224 |
225 | **TODO:**
226 | ```
227 | - implement your own network in network.py
228 |   (recommended: implement positional encoding and residual connections)
229 | - implement DSMLoss in loss.py (see the sketch below)
230 | - implement the training loop in train.py
231 | - (optional) implement ISMLoss in loss.py (hint: you will need to use torch.autograd.grad)
232 | - (optional) implement SSMLoss in loss.py
233 | ```
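
A minimal sketch of what `DSMLoss` could look like for a Gaussian perturbation kernel is shown below. The `sde.marginal_prob(t, x0)` call, `sde.T`, and the `score_net(xt, t)` signature are placeholders for your own API; for a Gaussian kernel, the conditional score $\nabla \log p_{t0}$ has the simple closed form used in the comment.

```python
import torch

def dsm_loss(score_net, sde, x0):
    # Sample a random time and perturb x0 with the closed-form marginal
    #   X_t = mean(t, X_0) + std(t) * eps,  eps ~ N(0, I).
    t = torch.rand(x0.shape[0], 1) * sde.T          # placeholder: sde.T is the terminal time
    mean, std = sde.marginal_prob(t, x0)            # placeholder API for p_{t0}(X_t | X_0)
    eps = torch.randn_like(x0)
    xt = mean + std * eps
    # For a Gaussian kernel, grad_x log p_{t0}(X_t | X_0) = -(X_t - mean) / std^2 = -eps / std.
    target = -eps / std
    return ((score_net(xt, t) - target) ** 2).sum(dim=-1).mean()
```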
234 | #### 1.3. Sampling
235 | Finally, we can now use the trained score prediction network to sample from the Swiss-roll dataset. Unlike the forward process, there is no analytical form
236 | of the marginal probability, so we have to simulate the reverse process. Your final samples should be close to the target distribution
237 | **within 10000 training steps**. For this task, you are free to use **ANY** variation of the diffusion process that **was mentioned** above.
238 |
239 |
240 |
241 |
242 |
243 |
244 | **TODO:**
245 | ```
246 | - implement the predict_fn in sde.py
247 | - complete the code in sampling.py (see the sketch below)
248 | - (optional) train with ema
249 | - (optional) implement the correct_fn (for VPSDE, VESDE) in sde.py
250 | - (optional) implement the ODE discretization and check out their differences
251 | ```
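
Putting the pieces together, the sampling loop is just a prior sample followed by repeated reverse steps (as in the single-step sketch in 1.1). A minimal sketch, with `sde.f`, `sde.G`, `sde.T`, and `score_net` as placeholders for your own code:

```python
import math
import torch

@torch.no_grad()
def sample(score_net, sde, n_samples=2000, n_steps=1000, dim=2):
    x = torch.randn(n_samples, dim)                  # start from the prior N(0, I)
    dt = sde.T / n_steps
    for i in range(n_steps, 0, -1):
        t = torch.full((n_samples, 1), i * dt)
        drift = sde.f(t, x) - sde.G(t) ** 2 * score_net(x, t)
        x = x - drift * dt + sde.G(t) * math.sqrt(dt) * torch.randn_like(x)
    return x
```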
252 |
253 | #### 1.4. Evaluation
254 | To evaluate your performance, we compute the Chamfer distance (CD) and earth mover's distance (EMD) between the target and generated point clouds.
255 | Your method should be on par with or better than the following metrics. For this task, you can use **ANY** variation, even ones that were **NOT** mentioned.
256 |
257 | | target distribution | CD |
258 | |---------------------|----------|
259 | | swiss-roll | 0.1975 |
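
For reference, a minimal sketch of a symmetric Chamfer distance between two point clouds is given below; the exact normalization used by `eval.py` may differ (e.g., squared vs. unsquared distances).

```python
import torch

def chamfer_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: [N, D], b: [M, D] point clouds.
    d = torch.cdist(a, b)  # [N, M] pairwise Euclidean distances
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()
```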
260 |
261 | #### 1.5. [Coming Soon] Schrödinger Bridge (Optional)
262 | One restriction of typical diffusion processes is that they require the prior to be easy to sample from (Gaussian, uniform, etc.).
263 | The Schrödinger Bridge removes this limitation by making the forward process also learnable, allowing a diffusion to be defined between **two** unknown distributions.
264 |
265 | ## Task 2: Image Diffusion
266 |
267 |
268 |
269 |
270 |
271 | In this task, we will play with diffusion models to generate 2D images. We first look into some background on DDPM and then dive into DDPM at the code level.
272 |
273 | ### Background
274 | From the perspective of SDEs, SGM and DDPM are the same model with different parameterizations. Just as SGM has forward and reverse processes, the forward process of DDPM, also called the _diffusion process_, is fixed to a Markov chain that gradually adds Gaussian noise to the data:
275 |
276 | $$ q(\mathbf{x}\_{1:T} | \mathbf{x}_0) := \prod\_{t=1}^T q(\mathbf{x}_t | \mathbf{x}\_{t-1}), \quad q(\mathbf{x}_t | \mathbf{x}\_{t-1}) := \mathcal{N} (\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}\_{t-1}, \beta_t \mathbf{I}).$$
277 |
278 |
279 | Thanks to a nice property of a Gaussian distribution, one can sample $\mathbf{x}_t$ at an arbitrary timestep $t$ from real data $\mathbf{x}_0$ in closed form:
280 |
281 | $$q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t ; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I}) $$
282 |
283 | where $\alpha\_t := 1 - \beta\_t$ and $\bar{\alpha}\_t := \prod\_{s=1}^t \alpha\_s$.
284 |
285 | Given the diffusion process, we want to model the _reverse process_ that gradually denoises white Gaussian noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ to sample real data. It is also defined as a Markov chain with learned Gaussian transitions:
286 |
287 | $$p\_\theta(\mathbf{x}\_{0:T}) := p(\mathbf{x}_T) \prod\_{t=1}^T p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t), \quad p\_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t) := \mathcal{N}(\mathbf{x}\_{t-1}; \mathbf{\boldsymbol{\mu}}\_\theta (\mathbf{x}_t, t), \boldsymbol{\Sigma}\_\theta (\mathbf{x}_t, t)).$$
288 |
289 | To learn this reverse process, we set an objective function that minimizes KL divergence between $p_\theta(\mathbf{x}\_{t-1} | \mathbf{x}_t)$ and $q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0)$ which is tractable when conditioned on $\mathbf{x}_0$:
290 |
291 | $$\mathcal{L} = \mathbb{E}_q \left[ \sum\_{t > 1} D\_{\text{KL}}( q(\mathbf{x}\_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \Vert p\_\theta ( \mathbf{x}\_{t-1} | \mathbf{x}_t)) \right]$$
292 |
293 | Refer to [the original paper](https://arxiv.org/abs/2006.11239) or our PPT material for more details.
294 |
295 | As a parameterization of DDPM, the authors set $\boldsymbol{\Sigma}\_\theta(\mathbf{x}_t, t) = \sigma_t^2 \mathbf{I}$ to untrained time-dependent constants, and they empirically found that predicting the noise injected into the data with a noise prediction network $\epsilon\_\theta$ works better than learning the mean function $\boldsymbol{\mu}\_\theta$.
296 |
297 | In short, the simplified objective function of DDPM is defined as follows:
298 |
299 | $$ \mathcal{L}\_{\text{simple}} := \mathbb{E}\_{t,\mathbf{x}_0,\boldsymbol{\epsilon}} [ \Vert \boldsymbol{\epsilon} - \boldsymbol{\epsilon}\_\theta( \mathbf{x}\_t(\mathbf{x}_0, t), t) \Vert^2 ],$$
300 |
301 | where $\mathbf{x}_t (\mathbf{x}_0, t) = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}$ and $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$.
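
For reference, the closed-form forward sampling and the simplified objective above map almost one-to-one to code. A minimal sketch, assuming `alphas_cumprod` is a `[T]` tensor of $\bar{\alpha}\_t$ values and `eps_net` is a noise prediction network:

```python
import torch
import torch.nn.functional as F

def add_noise(x0, t, alphas_cumprod):
    # q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)
    eps = torch.randn_like(x0)
    xt = abar.sqrt() * x0 + (1.0 - abar).sqrt() * eps
    return xt, eps

def simple_loss(eps_net, x0, t, alphas_cumprod):
    xt, eps = add_noise(x0, t, alphas_cumprod)
    return F.mse_loss(eps_net(xt, t), eps)
```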
302 |
303 | #### Sampling
304 |
305 | Once we train the noise prediction network $\boldsymbol{\epsilon}\_\theta$, we can run sampling by gradually denoising white Gaussian noise. The DDPM sampling algorithm is shown below:
306 |
307 | ![DDPM sampling algorithm](./assets/images/task2_1_ddpm_sampling_algorithm.png)
310 |
311 | [DDIM](https://arxiv.org/abs/2010.02502) proposed a way to speed up the sampling using the same pre-trained DDPM. The reverse step of DDIM is below:
312 |
313 | ![DDIM reverse step](./assets/images/task2_ddim.png)
316 |
317 | Note that the $\alpha_t$ notation in DDIM corresponds to $\bar{\alpha}_t$ in the DDPM paper.
318 |
319 | Please refer to the DDIM paper for more details.
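
In code, one DDIM reverse step (Eq. 12 of the paper, written in DDPM's $\bar{\alpha}$ notation as noted above) roughly looks like the following sketch; `abar_t` and `abar_prev` are tensors holding the cumulative products at the current and previous (sub-sampled) timesteps, and with `eta = 0` the update becomes deterministic.

```python
import torch

def ddim_step(xt, eps_pred, abar_t, abar_prev, eta=0.0):
    # Predict x_0 from the current sample and the predicted noise.
    x0_pred = (xt - (1.0 - abar_t).sqrt() * eps_pred) / abar_t.sqrt()
    # Noise scale sigma_t; zero when eta = 0 (deterministic DDIM).
    sigma = eta * ((1.0 - abar_prev) / (1.0 - abar_t)).sqrt() * (1.0 - abar_t / abar_prev).sqrt()
    # "Direction pointing to x_t" term plus fresh noise.
    dir_xt = (1.0 - abar_prev - sigma**2).sqrt() * eps_pred
    return abar_prev.sqrt() * x0_pred + dir_xt + sigma * torch.randn_like(xt)
```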
320 |
321 | #### 2.1. DDIM
322 | ### TODO
323 |
324 | In this task, we will generate $64\times64$ animal images using DDPM with the AFHQ dataset. We provide most of the code except for the variance scheduling code. You need to implement the DDPM scheduler and the DDIM scheduler in `scheduler.py`. After implementing the schedulers, train a model with `python train.py`. It will sample images and save a checkpoint every `args.log_interval`. After training a model, sample & save images by
325 | ```
326 | python sampling.py --ckpt_path ${CKPT_PATH} --save_dir ${SAVE_DIR_PATH}
327 | ```
328 |
329 | We recommend starting the training as soon as possible since it takes about half a day. Also, the DDPM scheduler is really slow. We recommend **implementing the DDIM scheduler first** and setting `inference timesteps` to 20~50, which is enough to get high-quality images with much less sampling time.
330 |
331 |
332 | As an evaluation, measure FID score using the pre-trained classifier network we provide:
333 |
334 | ```
335 | python dataset.py # to construct the eval directory.
336 | python fid/measure_fid.py /path/to/eval/dir /path/to/sample/dir
337 | ```
338 | _**Success condition**_: Achieve FID score lower than `30`.
339 |
340 | #### 2.2. Classifier-Free Guidance
341 |
342 | Now, we will implement a classifier-free guidance diffusion model. It trains an unconditional diffusion model and a conditional diffusion model jointly by randomly dropping out the conditioning term. The algorithm is below:
343 |
344 | ![Classifier-free guidance training algorithm](./assets/images/task2_algorithm.png)
347 |
348 | You need to train another diffusion model for classifier-free guidance by slightly modifying the network architecture so that it can take class labels as input. The network design is your choice. Our implementation uses `nn.Embedding` for class label embeddings and simply adds the class label embeddings to the time embeddings. We set the conditioning dropout rate to 0.1 during training and `guidance_scale` to 7.5.
349 |
350 | Note that the provided code treats the null class label as 0.
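
A minimal sketch of the two classifier-free guidance pieces, under the conventions above (null label 0, conditioning dropout 0.1, `guidance_scale` 7.5) and assuming the network takes `class_label` as in `network.py`; the guided combination below is one common formulation of the CFG update.

```python
import torch

def drop_labels(class_label, p_uncond=0.1):
    # During training: replace labels with the null label (0) with probability p_uncond.
    mask = torch.rand(class_label.shape, device=class_label.device) < p_uncond
    return torch.where(mask, torch.zeros_like(class_label), class_label)

def cfg_noise_pred(net, xt, t, class_label, guidance_scale=7.5):
    # During sampling: combine conditional and unconditional noise predictions.
    eps_cond = net(xt, timestep=t, class_label=class_label)
    eps_uncond = net(xt, timestep=t, class_label=torch.zeros_like(class_label))
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```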
351 |
352 | Generate 200 images per category, 600 in total. Measure FID with the same validation set used in 2.1.
353 | _**Success condition**_: Achieve FID score lower than `30`.
354 |
355 | For more details, refer to the [paper](https://arxiv.org/abs/2207.12598).
356 |
357 | #### 2.3. Image Inpainting
358 |
359 |
360 |
361 |
362 |
363 | DDPMs have zero-shot capabilities for handling various downstream tasks beyond unconditional generation. Among them, we will focus only on the image inpainting task.
364 |
365 | Note that there is **no base code** for image inpainting.
366 |
367 | Make a rectangular hole of size $32 \times 32$ in the images generated in 2.1.
368 | Then, do image inpainting based on Algorithm 8 in [RePaint](https://arxiv.org/abs/2201.09865); a rough sketch of its core step is shown after the algorithm below.
369 |
370 | ![RePaint Algorithm 8](./assets/images/task2_3_repaint_algorithm8.png)
373 |
374 |
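Since there is no base code for this task, here is a rough sketch of the core of RePaint's Algorithm 8 (without the resampling loop), assuming a DDPM-style scheduler with `add_noise` and `step` as in `scheduler.py`; `mask` is 1 on known pixels and 0 inside the hole.

```python
import torch

@torch.no_grad()
def repaint_step(scheduler, net, xt, t, x0_known, mask):
    # One combined reverse step of RePaint (Algorithm 8, without the resampling loop).
    B = xt.shape[0]
    t_prev = torch.full((B,), max(int(t) - 1, 0), device=xt.device)
    # Known region: sample x_{t-1} directly from q(x_{t-1} | x_0) of the ground-truth image.
    x_known, _ = scheduler.add_noise(x0_known, t_prev)
    # Unknown region: one ordinary reverse step p_theta(x_{t-1} | x_t) of the trained model.
    x_unknown = scheduler.step(xt, t, net(xt, timestep=t))
    # Combine the two regions with the inpainting mask.
    return mask * x_known + (1.0 - mask) * x_unknown
```
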
375 | Report FID scores with 500 result images and the same validation set used in Task 2.1.
376 | You will get full credit if the **_FID is lower than 30_**.
377 |
378 | #### [Optional] Improving image inpainting by MCG
379 |
380 | A recent paper, [Improving Diffusion Models for Inverse Problems using Manifold Constraints](https://arxiv.org/abs/2206.00941), also known as MCG, proposed a way to improve solving various inverse problems, such as image inpainting, with DDPMs. At a high level, during the reverse process it takes an additional gradient descent step towards the subspace of the latent space that satisfies the given partial observation. Refer to the [original paper](https://arxiv.org/abs/2206.00941) for more details and implement MCG-based image inpainting code.
381 |
382 | Compare image inpainting results between MCG and the baseline.
383 |
384 | ## Resources
385 | - [[paper](https://arxiv.org/abs/2011.13456)] Score-Based Generative Modeling through Stochastic Differential Equations
386 | - [[paper](https://arxiv.org/abs/2006.09011)] Improved Techniques for Training Score-Based Generative Models
387 | - [[paper](https://arxiv.org/abs/2006.11239)] Denoising Diffusion Probabilistic Models
388 | - [[paper](https://arxiv.org/abs/2105.05233)] Diffusion Models Beat GANs on Image Synthesis
389 | - [[paper](https://arxiv.org/abs/2207.12598)] Classifier-Free Diffusion Guidance
390 | - [[paper](https://arxiv.org/abs/2010.02502)] Denoising Diffusion Implicit Models
391 | - [[paper](https://arxiv.org/abs/2206.00364)] Elucidating the Design Space of Diffusion-Based Generative Models
392 | - [[paper](https://arxiv.org/abs/2106.02808)] A Variational Perspective on Diffusion-Based Generative Models and Score Matching
393 | - [[paper](https://arxiv.org/abs/2305.16261)] Trans-Dimensional Generative Modeling via Jump Diffusion Models
394 | - [[paper](https://openreview.net/pdf?id=nioAdKCEdXB)] Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory
395 | - [[blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)] What is Diffusion Model?
396 | - [[blog](https://yang-song.net/blog/2021/score/)] Generative Modeling by Estimating Gradients of the Data Distribution
397 | - [[lecture](https://youtube.com/playlist?list=PLCf12vHS8ONRpLNVGYBa_UbqWB_SeLsY2)] Charlie's Playlist on Diffusion Processes
398 | - [[slide](./assets/summary_of_DDPM_and_DDIM.pdf)] Juil's presentation slide of DDIM
399 | - [[slide](./assets/sb_likelihood_training.pdf)] Charlie's presentation of Schrödinger Bridge.
400 |
--------------------------------------------------------------------------------
/assets/diffusion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib import animation
4 | from scipy.ndimage import gaussian_filter
5 | from tqdm import tqdm
6 | from PIL import Image
7 |
8 | def step(x, dt=0.001, mu=1.7, sigma=1.2):
9 | z = np.random.randn(*x.shape)
10 | dx = -mu*x*dt + sigma*np.sqrt(dt)*z
11 | return x + dx
12 |
13 | def step_img(x, dt=0.001, mu=120, sigma=1):
14 | z = np.random.randn(*x.shape)
15 | dx = -mu*x*dt + sigma*np.sqrt(dt)*z
16 | return x + dx
17 |
18 |
19 | def sample_two_gaus(N=10000):
20 | N1 = N // 2
21 | N2 = N - N1
22 | samples1 = np.random.randn(N1) + 2
23 | samples2 = np.random.randn(N2) - 2
24 | samples = np.hstack((samples1, samples2))
25 | return samples
26 |
27 | def p(xt, N=100, mi=-5, ma=5):
28 | return np.histogram(xt, bins=N, range=(mi,ma))[0]
29 |
30 |
31 | n_steps = 500
32 | bins = 100
33 | path = np.zeros((bins, n_steps))
34 | single_path = np.zeros((n_steps))
35 | xt = sample_two_gaus()
36 | single_xt = np.array([-2.8])
37 |
38 | for i in range(n_steps):
39 | path[:,i] = p(xt, N=bins)
40 | xt = step(xt, dt=1/n_steps)
41 |
42 | single_xt = step(single_xt, dt=1/n_steps)
43 | single_path[i] = single_xt[0]
44 |
45 | smooth_path = gaussian_filter(path, sigma=3)
46 |
47 | thanos_img = np.array(Image.open('./thanos.png')) / 255. - 0.5
48 |
49 | fig = plt.Figure(figsize=(15,5))
50 | fig.tight_layout()
51 |
52 | ax = fig.add_subplot(3,9,(4,27))
53 | ax.get_xaxis().set_visible(False)
54 | ax.get_yaxis().set_visible(False)
55 | ax.imshow(smooth_path, interpolation='nearest', aspect='auto')
56 |
57 | ax2 = fig.add_subplot(3,9,(4,27))
58 | ax2.get_xaxis().set_visible(False)
59 | ax2.get_yaxis().set_visible(False)
60 | ax2.set_ylim(-5, 5)
61 | ax2.set_xlim(0, n_steps)
62 | ax2.patch.set_alpha(0.)
63 | point = ax2.plot(-6*np.ones(n_steps), 'ro')[0]
64 | line = ax2.plot(single_path, 'r-')[0]
65 |
66 | ax_img = fig.add_subplot(3,9,(1,21))
67 | ax_img.get_xaxis().set_visible(False)
68 | ax_img.get_yaxis().set_visible(False)
69 | ax_img.margins(0,0)
70 | ax_img.imshow(thanos_img)
71 |
72 | n_animate = 50
73 | pbar = tqdm(total=n_animate)
74 |
75 | def update(i):
76 | point_y = -6*np.ones(n_steps)
77 |     point_y[n_steps//n_animate*i] = single_path[n_steps//n_animate*i]
78 | point.set_ydata(point_y)
79 |
80 |     line_y = single_path[:n_steps//n_animate*i]
81 | line.set_data(range(line_y.shape[0]), line_y)
82 |
83 | ax_img.cla()
84 | global thanos_img
85 | ax_img.imshow(thanos_img + 0.5, interpolation='nearest', aspect='auto')
86 |
87 | thanos_img = step_img(thanos_img)
88 | print(thanos_img.min(), thanos_img.max())
89 |
90 | pbar.update(1)
91 |
92 | ani = animation.FuncAnimation(fig, update, range(n_animate))
93 | ani.save('diff_traj.gif', writer='imagemagick', fps=25, savefig_kwargs=dict(pad_inches=0, bbox_inches='tight'));
94 |
95 | # fig.savefig('diff_traj.png')
96 |
--------------------------------------------------------------------------------
/assets/images/task1_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_distribution.png
--------------------------------------------------------------------------------
/assets/images/task1_ou.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_ou.png
--------------------------------------------------------------------------------
/assets/images/task1_sampling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_sampling.png
--------------------------------------------------------------------------------
/assets/images/task1_vesde.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_vesde.png
--------------------------------------------------------------------------------
/assets/images/task1_vpsde.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task1_vpsde.png
--------------------------------------------------------------------------------
/assets/images/task2_1_ddpm_sampling_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_1_ddpm_sampling_algorithm.png
--------------------------------------------------------------------------------
/assets/images/task2_1_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_1_teaser.png
--------------------------------------------------------------------------------
/assets/images/task2_2_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_2_teaser.png
--------------------------------------------------------------------------------
/assets/images/task2_3_repaint_algorithm8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_3_repaint_algorithm8.png
--------------------------------------------------------------------------------
/assets/images/task2_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_algorithm.png
--------------------------------------------------------------------------------
/assets/images/task2_ddim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task2_ddim.png
--------------------------------------------------------------------------------
/assets/images/task3_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task3_algorithm.png
--------------------------------------------------------------------------------
/assets/images/task_2_3_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/task_2_3_teaser.png
--------------------------------------------------------------------------------
/assets/images/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/teaser.gif
--------------------------------------------------------------------------------
/assets/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/images/teaser.png
--------------------------------------------------------------------------------
/assets/sb_likelihood_training.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/sb_likelihood_training.pdf
--------------------------------------------------------------------------------
/assets/summary_of_DDPM_and_DDIM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/assets/summary_of_DDPM_and_DDIM.pdf
--------------------------------------------------------------------------------
/image_diffusion_todo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/image_diffusion_todo/__init__.py
--------------------------------------------------------------------------------
/image_diffusion_todo/dataset.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import os
3 | from itertools import chain
4 | from pathlib import Path
5 |
6 | import torch
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 |
10 |
11 | def listdir(dname):
12 | fnames = list(
13 | chain(
14 | *[
15 | list(Path(dname).rglob("*." + ext))
16 | for ext in ["png", "jpg", "jpeg", "JPG"]
17 | ]
18 | )
19 | )
20 | return fnames
21 |
22 |
23 | def tensor_to_pil_image(x: torch.Tensor, single_image=False):
24 | """
25 | x: [B,C,H,W]
26 | """
27 | if x.ndim == 3:
28 | x = x.unsqueeze(0)
29 | single_image = True
30 |
31 | x = (x * 0.5 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
32 | images = (x * 255).round().astype("uint8")
33 | images = [Image.fromarray(image) for image in images]
34 | if single_image:
35 | return images[0]
36 | return images
37 |
38 |
39 | def get_data_iterator(iterable):
40 | """Allows training with DataLoaders in a single infinite loop:
41 | for i, data in enumerate(inf_generator(train_loader)):
42 | """
43 | iterator = iterable.__iter__()
44 | while True:
45 | try:
46 | yield iterator.__next__()
47 | except StopIteration:
48 | iterator = iterable.__iter__()
49 |
50 |
51 | class AFHQDataset(torch.utils.data.Dataset):
52 | def __init__(
53 | self, root: str, split: str, transform=None, max_num_images_per_cat=-1, label_offset=1
54 | ):
55 | super().__init__()
56 | self.root = root
57 | self.split = split
58 | self.transform = transform
59 | self.max_num_images_per_cat = max_num_images_per_cat
60 | self.label_offset = label_offset
61 |
62 | categories = os.listdir(os.path.join(root, split))
63 | self.num_classes = len(categories)
64 |
65 | fnames, labels = [], []
66 | for idx, cat in enumerate(sorted(categories)):
67 | category_dir = os.path.join(root, split, cat)
68 | cat_fnames = listdir(category_dir)
69 | cat_fnames = sorted(cat_fnames)
70 | if self.max_num_images_per_cat > 0:
71 | cat_fnames = cat_fnames[: self.max_num_images_per_cat]
72 | fnames += cat_fnames
73 | labels += [idx + label_offset] * len(cat_fnames) # label 0 is for null class.
74 |
75 | self.fnames = fnames
76 | self.labels = labels
77 |
78 | def __getitem__(self, idx):
79 | img = Image.open(self.fnames[idx]).convert("RGB")
80 | label = self.labels[idx]
81 | assert label >= self.label_offset
82 | if self.transform is not None:
83 | img = self.transform(img)
84 |
85 | return img, label
86 |
87 | def __len__(self):
88 | return len(self.labels)
89 |
90 |
91 | class AFHQDataModule(object):
92 | def __init__(
93 | self,
94 | root: str = "data",
95 | batch_size: int = 32,
96 | num_workers: int = 4,
97 | max_num_images_per_cat: int = -1,
98 | image_resolution: int = 64,
99 | label_offset=1,
100 | ):
101 | self.root = root
102 | self.batch_size = batch_size
103 | self.num_workers = num_workers
104 | self.afhq_root = os.path.join(root, "afhq")
105 | self.max_num_images_per_cat = max_num_images_per_cat
106 | self.image_resolution = image_resolution
107 | self.label_offset = label_offset
108 |
109 | if not os.path.exists(self.afhq_root):
110 | print(f"{self.afhq_root} is empty. Downloading AFHQ dataset...")
111 | self._download_dataset()
112 |
113 | self._set_dataset()
114 |
115 | def _set_dataset(self):
116 | self.transform = transforms.Compose(
117 | [
118 | transforms.Resize((self.image_resolution, self.image_resolution)),
119 | transforms.ToTensor(),
120 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
121 | ]
122 | )
123 | self.train_ds = AFHQDataset(
124 | self.afhq_root,
125 | "train",
126 | self.transform,
127 | max_num_images_per_cat=self.max_num_images_per_cat,
128 | label_offset=self.label_offset
129 | )
130 | self.val_ds = AFHQDataset(
131 | self.afhq_root,
132 | "val",
133 | self.transform,
134 | max_num_images_per_cat=self.max_num_images_per_cat,
135 | label_offset=self.label_offset,
136 | )
137 |
138 | self.num_classes = self.train_ds.num_classes
139 |
140 | def _download_dataset(self):
141 | URL = "https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0"
142 | ZIP_FILE = f"./{self.root}/afhq.zip"
143 | os.system(f"mkdir -p {self.root}")
144 | os.system(f"wget -N {URL} -O {ZIP_FILE}")
145 | os.system(f"unzip {ZIP_FILE} -d {self.root}")
146 | os.system(f"rm {ZIP_FILE}")
147 |
148 |
149 | def train_dataloader(self):
150 | return torch.utils.data.DataLoader(
151 | self.train_ds,
152 | batch_size=self.batch_size,
153 | num_workers=self.num_workers,
154 | shuffle=True,
155 | drop_last=True,
156 | )
157 |
158 | def val_dataloader(self):
159 | return torch.utils.data.DataLoader(
160 | self.val_ds,
161 | batch_size=self.batch_size,
162 | num_workers=self.num_workers,
163 | shuffle=False,
164 | drop_last=False,
165 | )
166 |
167 | if __name__ == "__main__":
168 | data_module = AFHQDataModule("data", 32, 4, -1, 64, 1)
169 |
170 | eval_dir = Path(data_module.afhq_root) / "eval"
171 | eval_dir.mkdir(exist_ok=True)
172 | def func(path):
173 | fn = path.name
174 | cmd = f"cp {path} {eval_dir / fn}"
175 | os.system(cmd)
176 | img = Image.open(str(eval_dir / fn))
177 | img = img.resize((64,64))
178 | img.save(str(eval_dir / fn))
179 | print(fn)
180 |
181 | with Pool(8) as pool:
182 | pool.map(func, data_module.val_ds.fnames)
183 |
184 | print(f"Constructed eval dir at {eval_dir}")
185 |
--------------------------------------------------------------------------------
/image_diffusion_todo/ddpm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from scheduler import BaseScheduler
8 |
9 |
10 | class DiffusionModule(nn.Module):
11 | def __init__(self, network: nn.Module, var_scheduler: BaseScheduler, **kwargs):
12 | super().__init__()
13 | self.network = network
14 | self.var_scheduler = var_scheduler
15 |
16 | def get_loss(self, x0, class_label=None, noise=None):
17 | B = x0.shape[0]
18 | timestep = self.var_scheduler.uniform_sample_t(B, self.device)
19 | x_noisy, noise = self.var_scheduler.add_noise(x0, timestep)
20 | noise_pred = self.network(x_noisy, timestep=timestep, class_label=class_label)
21 |
22 | loss = F.mse_loss(noise_pred.flatten(), noise.flatten(), reduction="mean")
23 | return loss
24 |
25 | @property
26 | def device(self):
27 | return next(self.network.parameters()).device
28 |
29 | @property
30 | def image_resolution(self):
31 | return self.network.image_resolution
32 |
33 | @torch.no_grad()
34 | def sample(
35 | self,
36 | batch_size: int,
37 | return_traj: bool = False,
38 | ):
39 | """
40 | Sample x_0 from a learned diffusion model.
41 | """
42 | x_T = torch.randn([batch_size, 3, self.image_resolution, self.image_resolution]).to(
43 | self.device
44 | )
45 |
46 | traj = [x_T]
47 | for t in self.var_scheduler.timesteps:
48 | x_t = traj[-1]
49 | noise_pred = self.network(x_t, timestep=t.to(self.device))
50 |
51 | x_t_prev = self.var_scheduler.step(x_t, t, noise_pred)
52 |
53 | traj[-1] = traj[-1].cpu()
54 | traj.append(x_t_prev.detach())
55 |
56 | if return_traj:
57 | return traj
58 | else:
59 | return traj[-1]
60 |
61 | def save(self, file_path):
62 | hparams = {
63 | "network": self.network,
64 | "var_scheduler": self.var_scheduler,
65 | }
66 | state_dict = self.state_dict()
67 |
68 | dic = {"hparams": hparams, "state_dict": state_dict}
69 | torch.save(dic, file_path)
70 |
71 | def load(self, file_path):
72 | dic = torch.load(file_path, map_location="cpu")
73 | hparams = dic["hparams"]
74 | state_dict = dic["state_dict"]
75 |
76 | self.network = hparams["network"]
77 | self.var_scheduler = hparams["var_scheduler"]
78 |
79 | self.load_state_dict(state_dict)
80 |
81 |
--------------------------------------------------------------------------------
/image_diffusion_todo/fid/afhq_inception_v3.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/min-hieu/Tutorial_4/1ae266c00d78a1afb0c329ebbe5e441316891ffd/image_diffusion_todo/fid/afhq_inception_v3.ckpt
--------------------------------------------------------------------------------
/image_diffusion_todo/fid/inception.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 | import numpy as np
11 | import torch.nn as nn
12 | from torchvision import models
13 |
14 |
15 | class InceptionV3(nn.Module):
16 | def __init__(self, for_train):
17 | super().__init__()
18 | self.for_train = for_train
19 |
20 | inception = models.inception_v3(pretrained=False)
21 | self.block1 = nn.Sequential(
22 | inception.Conv2d_1a_3x3,
23 | inception.Conv2d_2a_3x3,
24 | inception.Conv2d_2b_3x3,
25 | nn.MaxPool2d(kernel_size=3, stride=2),
26 | )
27 | self.block2 = nn.Sequential(
28 | inception.Conv2d_3b_1x1,
29 | inception.Conv2d_4a_3x3,
30 | nn.MaxPool2d(kernel_size=3, stride=2),
31 | )
32 | self.block3 = nn.Sequential(
33 | inception.Mixed_5b,
34 | inception.Mixed_5c,
35 | inception.Mixed_5d,
36 | inception.Mixed_6a,
37 | inception.Mixed_6b,
38 | inception.Mixed_6c,
39 | inception.Mixed_6d,
40 | inception.Mixed_6e,
41 | )
42 | self.block4 = nn.Sequential(
43 | inception.Mixed_7a,
44 | inception.Mixed_7b,
45 | inception.Mixed_7c,
46 | nn.AdaptiveAvgPool2d(output_size=(1, 1)),
47 | )
48 |
49 | self.final_fc = nn.Linear(2048, 3)
50 |
51 | def forward(self, x):
52 | x = self.block1(x)
53 | x = self.block2(x)
54 | x = self.block3(x)
55 | x = self.block4(x)
56 | x = x.view(x.size(0), -1)
57 | if self.for_train:
58 | return self.final_fc(x)
59 | else:
60 | return x
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/image_diffusion_todo/fid/measure_fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import sys
6 | from PIL import Image
7 | from scipy import linalg
8 | from torchvision import transforms
9 | from itertools import chain
10 | from pathlib import Path
11 | from inception import InceptionV3
12 |
13 | try:
14 | from tqdm import tqdm
15 | except ImportError:
16 | def tqdm(x):
17 | return x
18 |
19 | class ImagePathDataset(torch.utils.data.Dataset):
20 | def __init__(self, files, img_size):
21 | self.files = files
22 | self.img_size = img_size
23 | self.transforms = transforms.Compose(
24 | [
25 | transforms.Resize((img_size, img_size)),
26 | transforms.ToTensor(),
27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
28 | ]
29 | )
30 |
31 | def __len__(self):
32 | return len(self.files)
33 |
34 | def __getitem__(self, i):
35 | path = self.files[i]
36 | img = Image.open(path).convert("RGB")
37 | if self.transforms is not None:
38 | img = self.transforms(img)
39 | return img
40 |
41 |
42 | def get_eval_loader(path, img_size, batch_size):
43 | def listdir(dname):
44 | fnames = list(
45 | chain(
46 | *[
47 | list(Path(dname).rglob("*." + ext))
48 | for ext in ["png", "jpg", "jpeg", "JPG"]
49 | ]
50 | )
51 | )
52 | return fnames
53 |
54 | files = listdir(path)
55 | ds = ImagePathDataset(files, img_size)
56 | dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
57 | return dl
58 |
59 | def frechet_distance(mu, cov, mu2, cov2):
60 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
61 | dist = np.sum((mu - mu2) ** 2) + np.trace(cov + cov2 - 2 * cc)
62 | return np.real(dist)
63 |
64 |
65 |
66 | @torch.no_grad()
67 | def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
68 | print("Calculating FID given paths %s and %s..." % (paths[0], paths[1]))
69 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70 | inception = InceptionV3(for_train=False)
71 | current_dir = Path(os.path.realpath(__file__)).parent
72 | ckpt = torch.load(current_dir / "afhq_inception_v3.ckpt")
73 | inception.load_state_dict(ckpt)
74 | inception = inception.eval().to(device)
75 | loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]
76 |
77 | mu, cov = [], []
78 | for loader in loaders:
79 | actvs = []
80 | for x in tqdm(loader, total=len(loader)):
81 | actv = inception(x.to(device))
82 | actvs.append(actv)
83 | actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
84 | mu.append(np.mean(actvs, axis=0))
85 | cov.append(np.cov(actvs, rowvar=False))
86 | fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
87 | return fid_value
88 |
89 | if __name__ == "__main__":
90 |     # python measure_fid.py /path/to/dir1 /path/to/dir2
91 |
92 | paths = [sys.argv[1], sys.argv[2]]
93 | fid_value = calculate_fid_given_paths(paths, img_size=256, batch_size=64)
94 | print("FID:", fid_value)
95 |
96 |
--------------------------------------------------------------------------------
/image_diffusion_todo/fid/train_classifier.py:
--------------------------------------------------------------------------------
1 | from inception import InceptionV3
2 | import sys
3 | sys.path.append("..")
4 | import torch
5 | import torch.nn.functional as F
6 | from dataset import AFHQDataset, AFHQDataModule
7 | from tqdm import tqdm
8 | import cv2
9 | import numpy as np
10 |
11 | data_module = AFHQDataModule("/home/juil/workspace/23summer_tutorial/HelloScore/image_diffusion/data/", 32, 4, -1, 256, 0)
12 |
13 | train_dl = data_module.train_dataloader()
14 | val_dl = data_module.val_dataloader()
15 |
16 | device = f"cuda:1"
17 |
18 | net = InceptionV3(for_train=True)
19 | net = net.to(device)
20 | net.train()
21 | for n, p in net.named_parameters():
22 | p.requires_grad_(True)
23 |
24 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
25 |
26 | epochs = 10
27 | for epoch in range(epochs):
28 | pbar = tqdm(train_dl)
29 | net.train()
30 | for img, label in pbar:
31 | img, label = img.to(device), label.to(device)
32 | pred = net(img)
33 | loss = F.cross_entropy(pred, label)
34 |
35 | optimizer.zero_grad()
36 | loss.backward()
37 | optimizer.step()
38 |
39 | pred_label = pred.max(-1)[1]
40 | acc = (pred_label == label).float().mean()
41 |
42 | pbar.set_description(f"E {epoch} | loss: {loss:.4f} acc: {acc*100:.2f}%")
43 |
44 | net.eval()
45 | val_accs = []
46 | for img, label in val_dl:
47 | img, label = img.to(device), label.to(device)
48 | pred = net(img)
49 | pred_label = pred.max(-1)[1]
50 |
51 | acc = (pred_label == label).float().mean()
52 | val_accs.append(acc)
53 | print(f"Val Acc: {sum(val_accs) / len(val_accs) * 100:.2f}%")
54 |
55 | torch.save(net.state_dict(), "afhq_inception_v3.ckpt")
56 | print("Saved model")
57 |
58 |
--------------------------------------------------------------------------------
/image_diffusion_todo/module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 |
8 |
9 | class Swish(nn.Module):
10 | def forward(self, x):
11 | return x * torch.sigmoid(x)
12 |
13 |
14 | class DownSample(nn.Module):
15 | def __init__(self, in_ch):
16 | super().__init__()
17 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
18 | self.initialize()
19 |
20 | def initialize(self):
21 | init.xavier_uniform_(self.main.weight)
22 | init.zeros_(self.main.bias)
23 |
24 | def forward(self, x, temb):
25 | x = self.main(x)
26 | return x
27 |
28 |
29 | class UpSample(nn.Module):
30 | def __init__(self, in_ch):
31 | super().__init__()
32 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
33 | self.initialize()
34 |
35 | def initialize(self):
36 | init.xavier_uniform_(self.main.weight)
37 | init.zeros_(self.main.bias)
38 |
39 | def forward(self, x, temb):
40 | _, _, H, W = x.shape
41 | x = F.interpolate(x, scale_factor=2, mode="nearest")
42 | x = self.main(x)
43 | return x
44 |
45 |
46 | class AttnBlock(nn.Module):
47 | def __init__(self, in_ch):
48 | super().__init__()
49 | self.group_norm = nn.GroupNorm(32, in_ch)
50 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
51 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
52 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
53 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
54 | self.initialize()
55 |
56 | def initialize(self):
57 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
58 | init.xavier_uniform_(module.weight)
59 | init.zeros_(module.bias)
60 | init.xavier_uniform_(self.proj.weight, gain=1e-5)
61 |
62 | def forward(self, x):
63 | B, C, H, W = x.shape
64 | h = self.group_norm(x)
65 | q = self.proj_q(h)
66 | k = self.proj_k(h)
67 | v = self.proj_v(h)
68 |
69 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
70 | k = k.view(B, C, H * W)
71 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
72 | assert list(w.shape) == [B, H * W, H * W]
73 | w = F.softmax(w, dim=-1)
74 |
75 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
76 | h = torch.bmm(w, v)
77 | assert list(h.shape) == [B, H * W, C]
78 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
79 | h = self.proj(h)
80 |
81 | return x + h
82 |
83 |
84 | class ResBlock(nn.Module):
85 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
86 | super().__init__()
87 | self.block1 = nn.Sequential(
88 | nn.GroupNorm(32, in_ch),
89 | Swish(),
90 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
91 | )
92 | self.temb_proj = nn.Sequential(
93 | Swish(),
94 | nn.Linear(tdim, out_ch),
95 | )
96 | self.block2 = nn.Sequential(
97 | nn.GroupNorm(32, out_ch),
98 | Swish(),
99 | nn.Dropout(dropout),
100 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
101 | )
102 | if in_ch != out_ch:
103 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
104 | else:
105 | self.shortcut = nn.Identity()
106 | if attn:
107 | self.attn = AttnBlock(out_ch)
108 | else:
109 | self.attn = nn.Identity()
110 | self.initialize()
111 |
112 | def initialize(self):
113 | for module in self.modules():
114 | if isinstance(module, (nn.Conv2d, nn.Linear)):
115 | init.xavier_uniform_(module.weight)
116 | init.zeros_(module.bias)
117 | init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
118 |
119 | def forward(self, x, temb):
120 | h = self.block1(x)
121 | h += self.temb_proj(temb)[:, :, None, None]
122 | h = self.block2(h)
123 |
124 | h = h + self.shortcut(x)
125 | h = self.attn(h)
126 | return h
127 |
128 |
129 | class TimeEmbedding(nn.Module):
130 | def __init__(self, hidden_size, frequency_embedding_size=256):
131 | super().__init__()
132 | self.mlp = nn.Sequential(
133 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
134 | nn.SiLU(),
135 | nn.Linear(hidden_size, hidden_size, bias=True),
136 | )
137 | self.frequency_embedding_size = frequency_embedding_size
138 |
139 | @staticmethod
140 | def timestep_embedding(t, dim, max_period=10000):
141 | """
142 | Create sinusoidal timestep embeddings.
143 | :param t: a 1-D Tensor of N indices, one per batch element.
144 | These may be fractional.
145 | :param dim: the dimension of the output.
146 | :param max_period: controls the minimum frequency of the embeddings.
147 | :return: an (N, D) Tensor of positional embeddings.
148 | """
149 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
150 | half = dim // 2
151 | freqs = torch.exp(
152 | -math.log(max_period)
153 | * torch.arange(start=0, end=half, dtype=torch.float32)
154 | / half
155 | ).to(device=t.device)
156 | args = t[:, None].float() * freqs[None]
157 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
158 | if dim % 2:
159 | embedding = torch.cat(
160 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
161 | )
162 | return embedding
163 |
164 | def forward(self, t):
165 | if t.ndim == 0:
166 | t = t.unsqueeze(-1)
167 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
168 | t_emb = self.mlp(t_freq)
169 | return t_emb
170 |
--------------------------------------------------------------------------------
/image_diffusion_todo/network.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from module import DownSample, ResBlock, Swish, TimeEmbedding, UpSample
7 | from torch.nn import init
8 |
9 |
10 | class UNet(nn.Module):
11 | def __init__(
12 | self,
13 | T: int = 1000,
14 | image_resolution: int = 64,
15 | ch: int = 128,
16 | ch_mult: List[int] = [1, 2, 2, 2],
17 | attn: List[int] = [1],
18 | num_res_blocks: int = 4,
19 | dropout: float = 0.1,
20 | use_cfg: bool = False,
21 | cfg_dropout: float = 0.1,
22 | num_classes: Optional[int] = None,
23 | ):
24 | super().__init__()
25 | self.image_resolution = image_resolution
26 |
27 | # TODO: Implement an architecture according to the provided architecture diagram.
28 | # You can use the modules in `module.py`.
29 |
30 | def forward(self, x, timestep, class_label=None):
31 | """
32 | Input:
33 | x (`torch.Tensor [B,C,H,W]`)
34 | timestep (`torch.Tensor [B]`)
35 | class_label (`torch.Tensor [B]`, optional)
36 | Output:
37 | out (`torch.Tensor [B,C,H,W]`): noise prediction.
38 | """
39 | assert (
40 | x.shape[-1] == x.shape[-2] == self.image_resolution
41 | ), f"The resolution of x ({x.shape[-2]}, {x.shape[-1]}) does not match with the image resolution ({self.image_resolution})."
42 |
43 | # TODO: Implement noise prediction network's forward function.
44 |
45 | out = torch.randn_like(x)
46 | return out
47 |
--------------------------------------------------------------------------------
/image_diffusion_todo/sampling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import torch
5 | from dataset import tensor_to_pil_image
6 | from ddpm import DiffusionModule
7 | from network import UNet
8 | from scheduler import DDIMScheduler, DDPMScheduler
9 | from pathlib import Path
10 |
11 |
12 | def main(args):
13 | save_dir = Path(args.save_dir)
14 | save_dir.mkdir(exist_ok=True, parents=True)
15 |
16 | device = f"cuda:{args.gpu}"
17 |
18 | ddpm = DiffusionModule(None, None)
19 | ddpm.load(args.ckpt_path)
20 | ddpm.eval()
21 | ddpm = ddpm.to(device)
22 |
23 | if isinstance(ddpm.var_scheduler, DDIMScheduler):
24 | ddpm.var_scheduler.set_timesteps(20)
25 |
26 | total_num_samples = 500
27 | num_batches = int(np.ceil(total_num_samples / args.batch_size))
28 |
29 | for i in range(num_batches):
30 | sidx = i * args.batch_size
31 | eidx = min(sidx + args.batch_size, total_num_samples)
32 | samples = ddpm.sample(eidx - sidx)
33 | pil_images = tensor_to_pil_image(samples)
34 |
35 | for j, img in zip(range(sidx, eidx), pil_images):
36 | img.save(save_dir / f"{j}.png")
37 | print(f"Saved the {j}-th image.")
38 |
39 |
40 |
41 | if __name__ == "__main__":
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--batch_size", type=int, default=32)
44 | parser.add_argument("--gpu", type=int, default=0)
45 | parser.add_argument("--ckpt_path", type=str)
46 | parser.add_argument("--save_dir", type=str)
47 |
48 | args = parser.parse_args()
49 | main(args)
50 |
--------------------------------------------------------------------------------
/image_diffusion_todo/scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class BaseScheduler(nn.Module):
9 | def __init__(
10 | self, num_train_timesteps: int = 1000, beta_1: float = 1e-4, beta_T: float = 0.02, mode="linear"
11 | ):
12 | super().__init__()
13 | self.num_train_timesteps = num_train_timesteps
14 | self.num_inference_timesteps = num_train_timesteps
15 | self.timesteps = torch.from_numpy(
16 | np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64)
17 | )
18 |
19 | if mode == "linear":
20 | betas = torch.linspace(beta_1, beta_T, steps=num_train_timesteps)
21 | elif mode == "quad":
22 | betas = (
23 | torch.linspace(beta_1**0.5, beta_T**0.5, num_train_timesteps) ** 2
24 | )
25 | else:
26 | raise NotImplementedError(f"{mode} is not implemented.")
27 |
28 | # TODO: Compute alphas and alphas_cumprod
29 | # alphas and alphas_cumprod correspond to $\alpha$ and $\bar{\alpha}$ in the DDPM paper (https://arxiv.org/abs/2006.11239).
30 | alphas = alphas_cumprod = betas
31 |
32 | self.register_buffer("betas", betas)
33 | self.register_buffer("alphas", alphas)
34 | self.register_buffer("alphas_cumprod", alphas_cumprod)
35 |
36 | def uniform_sample_t(
37 | self, batch_size, device: Optional[torch.device] = None
38 | ) -> torch.IntTensor:
39 | """
40 | Uniformly sample timesteps.
41 | """
42 | ts = np.random.choice(np.arange(self.num_train_timesteps), batch_size)
43 | ts = torch.from_numpy(ts)
44 | if device is not None:
45 | ts = ts.to(device)
46 | return ts
47 |
48 |
49 | class DDPMScheduler(BaseScheduler):
50 | def __init__(
51 | self,
52 | num_train_timesteps: int,
53 | beta_1: float,
54 | beta_T: float,
55 | mode="linear",
56 | sigma_type="small",
57 | ):
58 | super().__init__(num_train_timesteps, beta_1, beta_T, mode)
59 |
60 | # sigmas correspond to $\sigma_t$ in the DDPM paper.
61 | self.sigma_type = sigma_type
62 | if sigma_type == "small":
63 | # when $\sigma_t^2 = \tilde{\beta}_t$.
64 | alphas_cumprod_t_prev = torch.cat(
65 |                 [torch.tensor([1.0]), self.alphas_cumprod[:-1]]
66 | )
67 | sigmas = (
68 | (1 - alphas_cumprod_t_prev) / (1 - self.alphas_cumprod) * self.betas
69 | ) ** 0.5
70 | elif sigma_type == "large":
71 | # when $\sigma_t^2 = \beta_t$.
72 | sigmas = self.betas ** 0.5
73 |
74 | self.register_buffer("sigmas", sigmas)
75 |
76 | def step(self, sample: torch.Tensor, timestep: int, noise_pred: torch.Tensor):
77 | """
78 | One step denoising function of DDPM: x_t -> x_{t-1}.
79 |
80 | Input:
81 | sample (`torch.Tensor [B,C,H,W]`): samples at arbitrary timestep t.
82 | timestep (`int`): current timestep in a reverse process.
83 | noise_pred (`torch.Tensor [B,C,H,W]`): predicted noise from a learned model.
84 |         Output:
85 | sample_prev (`torch.Tensor [B,C,H,W]`): one step denoised sample. (= x_{t-1})
86 | """
87 |
88 | # TODO: Implement the DDPM's one step denoising function.
89 | # Refer to Algorithm 2 in the DDPM paper (https://arxiv.org/abs/2006.11239).
90 |
91 | sample_prev = sample
92 |
93 | return sample_prev
94 |
95 | def add_noise(
96 | self,
97 | original_sample: torch.Tensor,
98 | timesteps: torch.IntTensor,
99 | noise: Optional[torch.Tensor] = None,
100 | ):
101 | """
102 | A forward pass of a Markov chain, i.e., q(x_t | x_0).
103 |
104 | Input:
105 |             original_sample (`torch.Tensor [B,C,H,W]`): samples from the real data distribution q(x_0).
106 | timesteps: (`torch.IntTensor [B]`)
107 | noise: (`torch.Tensor [B,C,H,W]`, optional): if None, randomly sample Gaussian noise in the function.
108 | Output:
109 | x_noisy: (`torch.Tensor [B,C,H,W]`): noisy samples
110 | noise: (`torch.Tensor [B,C,H,W]`): injected noise.
111 | """
112 |
113 | # TODO: Implement the function that samples $\mathbf{x}_t$ from $\mathbf{x}_0$.
114 | # Refer to Equation 4 in the DDPM paper (https://arxiv.org/abs/2006.11239).
115 |
116 | noisy_sample = noise = original_sample
117 |
118 | return noisy_sample, noise
119 |
120 |
121 | class DDIMScheduler(BaseScheduler):
122 | def __init__(self, num_train_timesteps, beta_1, beta_T, mode="linear"):
123 | super().__init__(num_train_timesteps, beta_1, beta_T, mode)
124 |
125 | def set_timesteps(
126 | self, num_inference_timesteps: int, device: Union[str, torch.device] = None
127 | ):
128 | """
129 |         Sets the timesteps of the diffusion Markov chain used for the accelerated generation process (Sec. 4.2) in the DDIM paper (https://arxiv.org/abs/2010.02502).
130 | """
131 | if num_inference_timesteps > self.num_train_timesteps:
132 | raise ValueError(
133 | f"num_inference_timesteps ({num_inference_timesteps}) cannot exceed self.num_train_timesteps ({self.num_train_timesteps})"
134 | )
135 |
136 | self.num_inference_timesteps = num_inference_timesteps
137 |
138 | step_ratio = self.num_train_timesteps // num_inference_timesteps
139 | timesteps = (
140 | (np.arange(0, num_inference_timesteps) * step_ratio)
141 | .round()[::-1]
142 | .copy()
143 | .astype(np.int64)
144 | )
145 | self.timesteps = torch.from_numpy(timesteps)
146 |
147 | def step(
148 | self,
149 | sample: torch.Tensor,
150 | timestep: int,
151 | noise_pred: torch.Tensor,
152 | eta: float = 0.0,
153 | ):
154 | """
155 | One step denoising function of DDIM: $x_{\tau_i}$ -> $x_{\tau_{i-1}}$.
156 |
157 | Input:
158 | sample (`torch.Tensor [B,C,H,W]`): samples at arbitrary timestep $\tau_i$.
159 | timestep (`int`): current timestep in a reverse process.
160 | noise_pred (`torch.Tensor [B,C,H,W]`): predicted noise from a learned model.
161 |             eta (float): corresponds to η in DDIM, which controls the stochasticity of the reverse process.
162 |         Output:
163 | sample_prev (`torch.Tensor [B,C,H,W]`): one step denoised sample. (= $x_{\tau_{i-1}}$)
164 | """
165 | # TODO: Implement the DDIM's one step denoising function.
166 | # Refer to Equation 12 in the DDIM paper (https://arxiv.org/abs/2010.02502).
167 |
168 | sample_prev = sample
169 |
170 | return sample_prev
171 |
--------------------------------------------------------------------------------
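The TODOs in `scheduler.py` above reduce to a few closed-form DDPM/DDIM equations. Below is a minimal, self-contained sketch of those computations, assuming a linear beta schedule and integer timesteps used directly as indices; the helper names (`make_schedule`, `ddpm_add_noise`, `ddpm_step`, `ddim_step`) are illustrative and not part of the provided API.

```python
import torch

def make_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    betas = torch.linspace(beta_1, beta_T, T)
    alphas = 1.0 - betas                           # alpha_t = 1 - beta_t
    alphas_cumprod = torch.cumprod(alphas, dim=0)  # \bar{alpha}_t
    return betas, alphas, alphas_cumprod

def ddpm_add_noise(x0, t, alphas_cumprod, noise=None):
    # q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)   (DDPM Eq. 4)
    if noise is None:
        noise = torch.randn_like(x0)
    abar = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))
    return abar.sqrt() * x0 + (1 - abar).sqrt() * noise, noise

def ddpm_step(xt, t, eps_pred, betas, alphas, alphas_cumprod):
    # One reverse step x_t -> x_{t-1} (DDPM Algorithm 2), here with sigma_t^2 = beta_t.
    beta, alpha, abar = betas[t], alphas[t], alphas_cumprod[t]
    mean = (xt - beta / (1 - abar).sqrt() * eps_pred) / alpha.sqrt()
    if t == 0:
        return mean
    return mean + beta.sqrt() * torch.randn_like(xt)

def ddim_step(xt, t, t_prev, eps_pred, alphas_cumprod, eta=0.0):
    # One reverse step x_{tau_i} -> x_{tau_{i-1}} (DDIM Eq. 12).
    abar_t = alphas_cumprod[t]
    abar_prev = alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0)
    x0_pred = (xt - (1 - abar_t).sqrt() * eps_pred) / abar_t.sqrt()
    sigma = eta * ((1 - abar_prev) / (1 - abar_t) * (1 - abar_t / abar_prev)).sqrt()
    dir_xt = (1 - abar_prev - sigma ** 2).sqrt() * eps_pred
    return abar_prev.sqrt() * x0_pred + dir_xt + sigma * torch.randn_like(xt)
```

The "small" sigma option in `DDPMScheduler` corresponds to replacing beta_t in `ddpm_step` with the posterior variance \tilde{beta}_t = (1 - \bar{alpha}_{t-1}) / (1 - \bar{alpha}_t) * beta_t.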
/image_diffusion_todo/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from datetime import datetime
4 | from pathlib import Path
5 |
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | import torch
9 | from dataset import AFHQDataModule, get_data_iterator, tensor_to_pil_image
10 | from dotmap import DotMap
11 | from ddpm import DiffusionModule
12 | from network import UNet
13 | from pytorch_lightning import seed_everything
14 | from scheduler import DDIMScheduler, DDPMScheduler
15 | from torchvision.transforms.functional import to_pil_image
16 | from tqdm import tqdm
17 |
18 | matplotlib.use("Agg")
19 |
20 |
21 | def get_current_time():
22 | now = datetime.now().strftime("%m-%d-%H%M%S")
23 | return now
24 |
25 |
26 | def main(args):
27 | """config"""
28 | config = DotMap()
29 | config.update(vars(args))
30 | config.device = f"cuda:{args.gpu}"
31 |
32 | now = get_current_time()
33 | save_dir = Path(f"results/diffusion-{now}")
34 |     save_dir.mkdir(exist_ok=True, parents=True)
35 | print(f"save_dir: {save_dir}")
36 |
37 | seed_everything(config.seed)
38 |
39 | with open(save_dir / "config.json", "w") as f:
40 |         json.dump(config.toDict(), f, indent=2)
41 | """######"""
42 |
43 | image_resolution = config.image_resolution
44 | ds_module = AFHQDataModule(
45 | "./data",
46 | batch_size=config.batch_size,
47 | num_workers=4,
48 | max_num_images_per_cat=config.max_num_images_per_cat,
49 | image_resolution=image_resolution
50 | )
51 |
52 | train_dl = ds_module.train_dataloader()
53 | train_it = get_data_iterator(train_dl)
54 |
55 | var_scheduler = DDPMScheduler(
56 | config.num_diffusion_train_timesteps,
57 | beta_1=config.beta_1,
58 | beta_T=config.beta_T,
59 | mode="linear",
60 | )
61 | if isinstance(var_scheduler, DDIMScheduler):
62 | var_scheduler.set_timesteps(20) # 20 steps are enough in the case of DDIM.
63 |
64 | network = UNet(
65 | T=config.num_diffusion_train_timesteps,
66 | image_resolution=image_resolution,
67 | ch=128,
68 | ch_mult=[1, 2, 2, 2],
69 | attn=[1],
70 | num_res_blocks=4,
71 | dropout=0.1,
72 | use_cfg=args.use_cfg,
73 | cfg_dropout=args.cfg_dropout,
74 | num_classes=getattr(ds_module, "num_classes", None),
75 | )
76 |
77 | ddpm = DiffusionModule(network, var_scheduler)
78 | ddpm = ddpm.to(config.device)
79 |
80 | optimizer = torch.optim.Adam(ddpm.network.parameters(), lr=2e-4)
81 | scheduler = torch.optim.lr_scheduler.LambdaLR(
82 | optimizer, lr_lambda=lambda t: min((t + 1) / config.warmup_steps, 1.0)
83 | )
84 |
85 | step = 0
86 | losses = []
87 | with tqdm(initial=step, total=config.train_num_steps) as pbar:
88 | while step < config.train_num_steps:
89 | if step % config.log_interval == 0:
90 | ddpm.eval()
91 | plt.plot(losses)
92 | plt.savefig(f"{save_dir}/loss.png")
93 | plt.close()
94 |
95 | samples = ddpm.sample(4, return_traj=False)
96 | pil_images = tensor_to_pil_image(samples)
97 | for i, img in enumerate(pil_images):
98 | img.save(save_dir / f"step={step}-{i}.png")
99 |
100 | ddpm.save(f"{save_dir}/last.ckpt")
101 | ddpm.train()
102 |
103 | img, label = next(train_it)
104 | img, label = img.to(config.device), label.to(config.device)
105 | loss = ddpm.get_loss(img, class_label=label)
106 | pbar.set_description(f"Loss: {loss.item():.4f}")
107 |
108 | optimizer.zero_grad()
109 | loss.backward()
110 | optimizer.step()
111 | scheduler.step()
112 | losses.append(loss.item())
113 |
114 | step += 1
115 | pbar.update(1)
116 |
117 | print(f"last.ckpt is saved at {save_dir}")
118 |
119 | if __name__ == "__main__":
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument("--gpu", type=int, default=0)
122 | parser.add_argument("--batch_size", type=int, default=32)
123 | parser.add_argument(
124 | "--train_num_steps",
125 | type=int,
126 | default=100000,
127 | help="the number of model training steps.",
128 | )
129 | parser.add_argument("--warmup_steps", type=int, default=200)
130 | parser.add_argument("--log_interval", type=int, default=200)
131 | parser.add_argument(
132 | "--max_num_images_per_cat",
133 | type=int,
134 | default=-1,
135 | help="max number of images per category for AFHQ dataset",
136 | )
137 | parser.add_argument(
138 | "--num_diffusion_train_timesteps",
139 | type=int,
140 | default=1000,
141 | help="diffusion Markov chain num steps",
142 | )
143 | parser.add_argument("--beta_1", type=float, default=1e-4)
144 | parser.add_argument("--beta_T", type=float, default=0.02)
145 | parser.add_argument("--seed", type=int, default=63)
146 |     parser.add_argument("--image_resolution", type=int, default=64)
147 |     parser.add_argument("--use_cfg", action="store_true")  # referenced by main() when building the UNet
148 |     parser.add_argument("--cfg_dropout", type=float, default=0.1)  # label-drop probability for CFG
149 | 
150 |     args = parser.parse_args()
151 |     main(args)
152 | 
--------------------------------------------------------------------------------
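For reference, a typical invocation of the two scripts above (using only the flags defined in their argument parsers) might look like the following; the checkpoint path and output directory are placeholders.

```
python train.py --gpu 0 --batch_size 32 --image_resolution 64 --train_num_steps 100000
python sampling.py --gpu 0 --batch_size 32 --ckpt_path results/diffusion-<timestamp>/last.ckpt --save_dir samples/
```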
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==1.1.3
2 | ipython==8.12.0
3 | jupyter==1.0.0
4 | matplotlib==3.6.0
5 | torch==2.0.1
6 | torch-ema==0.3
7 | torchvision==0.15.2
8 | tqdm==4.64.1
9 | jupyterlab==4.0.2
10 | jaxtyping==0.2.20
11 | pytorch_lightning
12 | dotmap
13 |
--------------------------------------------------------------------------------
/sde_todo/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | from sklearn import datasets
5 |
6 | def normalize(ds, scaling_factor=2.):
7 | return (ds - ds.mean()) / \
8 | ds.std() * scaling_factor
9 |
10 |
11 | def sample_checkerboard(n):
12 | # https://github.com/ghliu/SB-FBSDE/blob/main/data.py
13 | n_points = 3*n
14 | n_classes = 2
15 | freq = 5
16 | x = np.random.uniform(-(freq//2)*np.pi, (freq//2)*np.pi, size=(n_points, n_classes))
17 | mask = np.logical_or(np.logical_and(np.sin(x[:,0]) > 0.0, np.sin(x[:,1]) > 0.0), \
18 | np.logical_and(np.sin(x[:,0]) < 0.0, np.sin(x[:,1]) < 0.0))
19 | y = np.eye(n_classes)[1*mask]
20 | x0 = x[:,0]*y[:,0]
21 | x1 = x[:,1]*y[:,0]
22 | sample = np.concatenate([x0[...,None],x1[...,None]],axis=-1)
23 | sqr = np.sum(np.square(sample),axis=-1)
24 | idxs = np.where(sqr==0)
25 | sample = np.delete(sample,idxs,axis=0)
26 |
27 | return sample
28 |
29 |
30 | def load_twodim(num_samples: int,
31 | dataset: str,
32 | dimension: int = 2):
33 |
34 | if dataset == 'gaussian_centered':
35 | sample = np.random.normal(size=(num_samples, dimension))
36 | sample = sample
37 |
38 | if dataset == 'gaussian_shift':
39 | sample = np.random.normal(size=(num_samples, dimension))
40 | sample = sample + 1.5
41 |
42 | if dataset == 'circle':
43 | X, y = datasets.make_circles(
44 | n_samples=num_samples, noise=0.0, random_state=None, factor=.5)
45 | sample = X * 4
46 |
47 | if dataset == 'scurve':
48 | X, y = datasets.make_s_curve(
49 | n_samples=num_samples, noise=0.0, random_state=None)
50 | sample = normalize(X[:, [0, 2]])
51 |
52 | if dataset == 'moon':
53 | X, y = datasets.make_moons(
54 | n_samples=num_samples, noise=0.0, random_state=None)
55 | sample = normalize(X)
56 |
57 | if dataset == 'swiss_roll':
58 | X, y = datasets.make_swiss_roll(
59 | n_samples=num_samples, noise=0.0, random_state=None, hole=True)
60 | sample = normalize(X[:, [0, 2]])
61 |
62 | if dataset == 'checkerboard':
63 | sample = normalize(sample_checkerboard(num_samples))
64 |
65 | return torch.tensor(sample).float()
66 |
67 |
68 | class TwoDimDataClass(Dataset):
69 |
70 | def __init__(self,
71 | dataset_type: str,
72 | N: int,
73 | batch_size: int,
74 | dimension = 2):
75 |
76 | self.X = load_twodim(N, dataset_type, dimension=dimension)
77 | self.name = dataset_type
78 | self.batch_size = batch_size
79 | self.dimension = 2
80 |
81 | def __len__(self):
82 | return self.X.shape[0]
83 |
84 | def __getitem__(self, idx):
85 | return self.X[idx]
86 |
87 | def get_dataloader(self, shuffle=True):
88 | return DataLoader(self,
89 | batch_size=self.batch_size,
90 | shuffle=shuffle,
91 | pin_memory=True,
92 | )
93 |
--------------------------------------------------------------------------------
/sde_todo/eval.py:
--------------------------------------------------------------------------------
1 | from scipy.optimize import linear_sum_assignment
2 | import numpy as np
3 | import torch
4 |
5 | def CD(a, b):
6 | x, y = a, b
7 | bs, num_points, points_dim = x.size()
8 | xx = torch.bmm(x, x.transpose(2, 1))
9 | yy = torch.bmm(y, y.transpose(2, 1))
10 | zz = torch.bmm(x, y.transpose(2, 1))
11 | diag_ind = torch.arange(0, num_points).to(a).long()
12 | rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
13 | ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
14 | P = (rx.transpose(2, 1) + ry - 2 * zz)
15 | return P.min(1)[0], P.min(2)[0]
16 |
17 | def EMD(x, y):
18 | bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2)
19 | assert npts == mpts, "EMD only works if two point clouds are equal size"
20 | dim = x.shape[-1]
21 | x = x.reshape(bs, npts, 1, dim)
22 | y = y.reshape(bs, 1, mpts, dim)
23 | dist = (x - y).norm(dim=-1, keepdim=False) # (bs, npts, mpts)
24 |
25 | emd_lst = []
26 | dist_np = dist.cpu().detach().numpy()
27 | for i in range(bs):
28 | d_i = dist_np[i]
29 | r_idx, c_idx = linear_sum_assignment(d_i)
30 | emd_i = d_i[r_idx, c_idx].mean()
31 | emd_lst.append(emd_i)
32 | emd = np.stack(emd_lst).reshape(-1)
33 | emd_torch = torch.from_numpy(emd).to(x)
34 | return emd_torch
35 |
36 |
--------------------------------------------------------------------------------
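A brief usage sketch for the two metrics above, assuming two batches of equally sized 2D point sets (the shapes are illustrative):

```python
import torch
from eval import CD, EMD  # assuming this file is importable as eval.py

a = torch.randn(4, 128, 2)           # (batch, points, dim)
b = torch.randn(4, 128, 2)
d_b2a, d_a2b = CD(a, b)              # per-point min squared distances (b -> a and a -> b)
chamfer = d_b2a.mean() + d_a2b.mean()  # symmetric Chamfer distance, averaged over points and batch
emd = EMD(a, b)                      # Hungarian-matched mean distance per batch element
```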
/sde_todo/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from jaxtyping import Array, Float
4 | from typing import Callable
5 |
6 | def get_div(y: Array, x: Array):
7 | """
8 | (Optional)
9 |     Return the divergence of y w.r.t. x. Let y = f(x); get_div returns div_x(y).
10 |
11 | Args:
12 | x Input of a differentiable function
13 | y Output of a differentiable function
14 |
15 | Returns:
16 | div_x(y)
17 | """
18 | pass
19 |
20 | class DSMLoss():
21 |
22 | def __init__(self, alpha: float, diff_weight: bool):
23 | """
24 | (TODO) Initialize the DSM Loss.
25 |
26 | Args:
27 | alpha: regularization weight
28 | diff_weight: scale loss by square of diffusion
29 |
30 | Returns:
31 | None
32 | """
33 |
34 | def __call__(self,
35 | t: Array,
36 | x: Array,
37 | model: Callable[[Array], Array],
38 | s: Array,
39 | diff_sq: Float):
40 | """
41 | Args:
42 | t: uniformly sampled time period
43 | x: samples after t diffusion
44 | model: score prediction function s(x,t)
45 | s: ground truth score
46 |
47 | Returns:
48 | loss: average loss value
49 | """
50 | loss = None
51 | return loss
52 |
53 |
54 | class ISMLoss():
55 | """
56 | (Optional) Implicit Score Matching Loss
57 | """
58 |
59 | def __init__(self):
60 | pass
61 |
62 | def __call__(self, t, x, model):
63 | """
64 | Args:
65 | t: uniformly sampled time period
66 | x: samples after t diffusion
67 | model: score prediction function s(x,t)
68 |
69 | Returns:
70 | loss: average loss value
71 | """
72 | return loss
73 |
74 |
75 | class SBJLoss():
76 | """
77 | (Optional) Joint Schrodinger Bridge Loss
78 |
79 | hint: You will need to implement the divergence.
80 | """
81 |
82 | def __init__(self):
83 | pass
84 |
85 | def __call__(self, t, xf, zf, zb_fn):
86 | """
87 |         Compute the joint Schrodinger Bridge loss.
88 |
89 | Args:
90 | t: uniformly sampled time period
91 | xf: samples after t forward diffusion
92 | zf: ground truth forward value
93 | zb_fn: backward Z function
94 |
95 | Returns:
96 | loss: average loss value
97 | """
98 | return loss
99 |
100 |
101 | class SBALoss():
102 | """
103 | (Optional) Alternating Schrodinger Bridge Loss
104 |
105 | hint: You will need to implement the divergence.
106 | """
107 |
108 | def __init__(self):
109 | pass
110 |
111 | def __call__(self, t, xf, zf, zb_fn):
112 | """
113 |         Compute the alternating Schrodinger Bridge loss.
114 |
115 | Args:
116 | t: uniformly sampled time period
117 | xf: samples after t forward diffusion
118 | zf: ground truth forward value
119 | zb_fn: backward Z function
120 |
121 | Returns:
122 | loss: average loss value
123 | """
124 | return loss
125 |
--------------------------------------------------------------------------------
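For the DSM TODO above, one common form of the objective is sketched below. It assumes the model is called as `model(t, x)` (matching `SimpleNet.forward` in `network.py`) and that `s` is the ground-truth score of the perturbation kernel; how exactly `alpha` and `diff_weight` enter the loss is a design choice, so treat this as one possibility rather than the required implementation.

```python
import torch

class DSMLossSketch:
    """Denoising score matching: || s_theta(x_t, t) - grad_x log q(x_t | x_0) ||^2."""

    def __init__(self, alpha: float = 0.0, diff_weight: bool = False):
        self.alpha = alpha              # optional regularization weight (illustrative use below)
        self.diff_weight = diff_weight  # if True, weight the loss by the squared diffusion g(t)^2

    def __call__(self, t, x, model, s, diff_sq=None):
        pred = model(t, x)                        # predicted score s_theta(x_t, t)
        loss = ((pred - s) ** 2).sum(dim=-1)      # per-sample squared error
        if self.diff_weight and diff_sq is not None:
            loss = loss * diff_sq                 # likelihood-style weighting
        if self.alpha > 0:
            loss = loss + self.alpha * (pred ** 2).sum(dim=-1)  # one possible regularizer
        return loss.mean()
```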
/sde_todo/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from jaxtyping import Array, Int, Float
6 |
7 | class PositionalEncoding(nn.Module):
8 |
9 | def __init__(self, t_channel: Int):
10 | """
11 | (Optional) Initialize positional encoding network
12 |
13 | Args:
14 | t_channel: number of modulation channel
15 | """
16 | super().__init__()
17 |
18 | def forward(self, t: Float):
19 | """
20 |         Return the positional encoding of the input time t.
21 |
22 | Args:
23 | t: input time
24 |
25 | Returns:
26 | emb: time embedding
27 | """
28 | emb = None
29 | return emb
30 |
31 |
32 | class MLP(nn.Module):
33 |
34 | def __init__(self,
35 | in_dim: Int,
36 | out_dim: Int,
37 | hid_shapes: Int[Array]):
38 | '''
39 | (TODO) Build simple MLP
40 |
41 | Args:
42 | in_dim: input dimension
43 | out_dim: output dimension
44 | hid_shapes: array of hidden layers' dimension
45 | '''
46 | super().__init__()
47 | self.model = None
48 |
49 | def forward(self, x: Array):
50 | return self.model(x)
51 |
52 |
53 |
54 | class SimpleNet(nn.Module):
55 |
56 | def __init__(self,
57 | in_dim: Int,
58 | enc_shapes: Int[Array],
59 | dec_shapes: Int[Array],
60 | z_dim: Int):
61 |         '''
62 |         (TODO) Build the score estimation network.
63 |         You are free to modify this function signature and to design
64 |         whatever architecture you like.
65 | 
66 |         hint: it's recommended to first encode the time and x to get
67 |         time and x embeddings, then concatenate them before feeding them
68 |         to the decoder.
69 | 
70 |         Args:
71 |             in_dim: dimension of input
72 |             enc_shapes: array of dimensions of encoder
73 |             dec_shapes: array of dimensions of decoder
74 |             z_dim: output dimension of encoder
75 |         '''
76 |         super().__init__()
77 | 
78 | def forward(self, t: Array, x: Array):
79 | '''
80 | (TODO) Implement the forward pass. This should output
81 | the score s of the noisy input x.
82 |
83 |         hint: you are free to design the forward pass however you like.
84 |
85 | Args:
86 | t: the time that the forward diffusion has been running
87 | x: the noisy data after t period diffusion
88 | '''
89 | s = None
90 | return s
91 |
--------------------------------------------------------------------------------
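A common way to fill in the two TODOs above is a sinusoidal time embedding plus plain MLPs: encode t and x separately, concatenate, and decode to a score. The sketch below is one such design under those assumptions (`SinusoidalTimeEmbedding`, `build_mlp`, and `ScoreNetSketch` are illustrative names), not the required architecture.

```python
import torch
import torch.nn as nn

class SinusoidalTimeEmbedding(nn.Module):
    """Transformer-style sinusoidal embedding of a scalar diffusion time t."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -torch.arange(half, dtype=torch.float32) * torch.log(torch.tensor(10000.0)) / half
        )
        args = t.reshape(-1, 1).float() * freqs.reshape(1, -1)
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, dim)

def build_mlp(in_dim, out_dim, hid_shapes):
    """Stack Linear + SiLU blocks following hid_shapes, ending in a plain Linear."""
    dims = [in_dim, *hid_shapes]
    layers = []
    for a, b in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(a, b), nn.SiLU()]
    layers.append(nn.Linear(dims[-1], out_dim))
    return nn.Sequential(*layers)

class ScoreNetSketch(nn.Module):
    """Encode t and x, concatenate the embeddings, and decode to a score."""

    def __init__(self, in_dim=2, z_dim=64, hid=(128, 128)):
        super().__init__()
        self.t_emb = SinusoidalTimeEmbedding(z_dim)
        self.x_enc = build_mlp(in_dim, z_dim, hid)
        self.dec = build_mlp(2 * z_dim, in_dim, hid)

    def forward(self, t, x):
        h = torch.cat([self.x_enc(x), self.t_emb(t)], dim=-1)
        return self.dec(h)
```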
/sde_todo/sampling.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from tqdm import tqdm
4 | from jaxtyping import Float, Int
5 |
6 | class Sampler():
7 |
8 | def __init__(self, eps: Float):
9 | self.eps = eps
10 |
11 | def get_sampling_fn(self, sde, dataset):
12 |
13 | def sampling_fn(N_samples: Int):
14 | """
15 |             Return the final denoised samples, the total number of steps,
16 |             the timesteps used, and the full diffusion trajectory.
17 |
18 | Args:
19 | N_samples: number of samples
20 |
21 | Returns:
22 | out: the final denoised samples (out == x_traj[-1])
23 | ntot (int): the total number of timesteps
24 | timesteps Int[Array]: the array of timesteps used
25 | x_traj: the entire diffusion trajectory
26 | """
27 | x = dataset[range(N_samples)] # initial sample
28 | timesteps = torch.linspace(0, sde.T-self.eps, sde.N)
29 |
30 | x_traj = torch.zeros((sde.N, *x.shape))
31 | with torch.no_grad():
32 | for i, t in enumerate(tqdm(timesteps, desc='sampling')):
33 | # TODO
34 | pass
35 |
36 | out = x
37 | ntot = sde.N
38 | return out, ntot, timesteps, x_traj
39 |
40 | return sampling_fn
41 |
--------------------------------------------------------------------------------
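One way to complete the loop above is a plain Euler–Maruyama-style iteration over the reverse SDE, recording each intermediate state. The sketch below assumes `sde` is the reverse-time object returned by `SDE.reverse(model)` in `sde.py`, exposing `N`, `T`, `prior_sampling`, and `predict_fn(t, x)`; whether the initial points come straight from `dataset` or through `prior_sampling` depends on how you set up the pipeline, and `euler_maruyama_sample` is an illustrative name.

```python
import torch
from tqdm import tqdm

def euler_maruyama_sample(sde, x_init, eps=1e-3):
    """Iterate the reverse SDE from x_init, returning the sample and trajectory."""
    x = sde.prior_sampling(x_init)                    # start from the prior, shaped like x_init
    timesteps = torch.linspace(0, sde.T - eps, sde.N)
    x_traj = torch.zeros((sde.N, *x.shape))
    with torch.no_grad():
        for i, t in enumerate(tqdm(timesteps, desc="sampling")):
            x = sde.predict_fn(t, x)                  # one discretized reverse-SDE step
            x_traj[i] = x
    return x, sde.N, timesteps, x_traj
```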
/sde_todo/sde.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | import numpy as np
4 | from jaxtyping import Array, Float
5 |
6 | class SDE(abc.ABC):
7 | def __init__(self, N: int, T: int):
8 | super().__init__()
9 | self.N = N # number of discretization steps
10 | self.T = T # terminal time
11 | self.dt = T / N
12 | self.is_reverse = False
13 | self.is_bridge = False
14 |
15 | @abc.abstractmethod
16 | def sde_coeff(self, t, x):
17 | return NotImplemented
18 |
19 | @abc.abstractmethod
20 | def marginal_prob(self, t, x):
21 | return NotImplemented
22 |
23 | @abc.abstractmethod
24 | def predict_fn(self, x):
25 | return NotImplemented
26 |
27 | @abc.abstractmethod
28 | def correct_fn(self, t, x):
29 | return NotImplemented
30 |
31 | def dw(self, x, dt=None):
32 | """
33 | (TODO) Return the differential of Brownian motion
34 |
35 | Args:
36 | x: input data
37 |
38 | Returns:
39 | dw (same shape as x)
40 | """
41 | dt = self.dt if dt is None else dt
42 | return None
43 |
44 | def prior_sampling(self, x: Array):
45 | """
46 |         Sample from the prior distribution. Defaults to a unit Gaussian.
47 |
48 | Args:
49 | x: input data
50 |
51 | Returns:
52 | z: random variable with same shape as x
53 | """
54 | return torch.randn_like(x)
55 |
56 | def predict_fn(self,
57 | t: Array,
58 | x: Array,
59 | dt: Float=None):
60 | """
61 | (TODO) Perform single step diffusion.
62 |
63 | Args:
64 | t: current diffusion time
65 | x: input with noise level at time t
66 | dt: the discrete time step. Default to T/N
67 |
68 | Returns:
69 | x: input at time t+dt
70 | """
71 | dt = self.dt if dt is None else dt
72 | pred = None
73 | return pred
74 |
75 | def correct_fn(self, t: Array, x: Array):
76 | return None
77 |
78 | def reverse(self, model):
79 | N = self.N
80 | T = self.T
81 | forward_sde_coeff = self.sde_coeff
82 |
83 | class RSDE(self.__class__):
84 | def __init__(self, score_fn):
85 | super().__init__(N, T)
86 | self.score_fn = score_fn
87 | self.is_reverse = True
88 | self.forward_sde_coeff = forward_sde_coeff
89 |
90 | def sde_coeff(self, t: Array, x: Array):
91 | """
92 | (TODO) Return the reverse drift and diffusion terms.
93 |
94 | Args:
95 | t: current diffusion time
96 | x: current input at time t
97 |
98 | Returns:
99 | reverse_f: reverse drift term
100 | g: reverse diffusion term
101 | """
102 | reverse_f = None
103 | g = None
104 | return reverse_f, g
105 |
106 | def ode_coeff(self, t: Array, x: Array):
107 | """
108 | (Optional) Return the reverse drift and diffusion terms in
109 | ODE sampling.
110 |
111 | Args:
112 | t: current diffusion time
113 | x: current input at time t
114 |
115 | Returns:
116 | reverse_f: reverse drift term
117 | g: reverse diffusion term
118 | """
119 | reverse_f = None
120 | g = None
121 | return reverse_f, g
122 |
123 | def predict_fn(self,
124 | t: Array,
125 | x,
126 | dt=None,
127 | ode=False):
128 | """
129 | (TODO) Perform single step reverse diffusion
130 |
131 | """
132 | return x
133 |
134 | return RSDE(model)
135 |
136 | class OU(SDE):
137 | def __init__(self, N=1000, T=1):
138 | super().__init__(N, T)
139 |
140 | def sde_coeff(self, t, x):
141 | f, g = None, None
142 | return f, g
143 |
144 | def marginal_prob(self, t, x):
145 | mean, std = None, None
146 | return mean, std
147 |
148 | class VESDE(SDE):
149 | def __init__(self, N=100, T=1, sigma_min=0.01, sigma_max=50):
150 | super().__init__(N, T)
151 |
152 | def sde_coeff(self, t, x):
153 | f, g = None, None
154 | return f, g
155 |
156 | def marginal_prob(self, t, x):
157 | mean, std = None, None
158 | return mean, std
159 |
160 |
161 | class VPSDE(SDE):
162 | def __init__(self, N=1000, T=1, beta_min=0.1, beta_max=20):
163 | super().__init__(N, T)
164 |
165 | def sde_coeff(self, t, x):
166 | f, g = None, None
167 | return f, g
168 |
169 | def marginal_prob(self, t, x):
170 | mean, std = None, None
171 | return mean, std
172 |
173 |
174 | class SB(abc.ABC):
175 | def __init__(self, N=1000, T=1, zf_model=None, zb_model=None):
176 | super().__init__()
177 | self.N = N # number of time step
178 | self.T = T # end time
179 | self.dt = T / N
180 |
181 | self.is_reverse = False
182 | self.is_bridge = True
183 |
184 | self.zf_model = zf_model
185 | self.zb_model = zb_model
186 |
187 | def dw(self, x, dt=None):
188 | dt = self.dt if dt is None else dt
189 | return torch.randn_like(x) * (dt**0.5)
190 |
191 | @abc.abstractmethod
192 | def sde_coeff(self, t, x):
193 | return NotImplemented
194 |
195 | def sb_coeff(self, t, x):
196 | """
197 | (Optional) Return the SB reverse drift and diffusion terms.
198 |
199 | Args:
200 | """
201 | sb_f = None
202 | g = None
203 | return sb_f, g
204 |
205 | def predict_fn(self,
206 | t: Array,
207 | x: Array,
208 | dt:Float =None):
209 | """
210 | Args:
211 | t:
212 | x:
213 | dt:
214 | """
215 | return x
216 |
217 | def correct_fn(self, t, x, dt=None):
218 | return x
219 |
220 | def reverse(self, model):
221 | """
222 | (Optional) Initialize the reverse process
223 | """
224 |
225 | class RSB(self.__class__):
226 | def __init__(self, model):
227 | super().__init__(N, T, zf_model, zb_model)
228 | """
229 | (Optional) Initialize the reverse process
230 | """
231 |
232 | def sb_coeff(self, t, x):
233 | """
234 | (Optional) Return the SB reverse drift and diffusion terms.
235 | """
236 | sb_f = None
237 | g = None
238 | return sb_f, g
239 |
240 |         return RSB(model)
241 |
242 | class OUSB(SB):
243 | def __init__(self, N=1000, T=1, zf_model=None, zb_model=None):
244 | super().__init__(N, T, zf_model, zb_model)
245 |
246 | def sde_coeff(self, t, x):
247 | f = -0.5 * x
248 | g = torch.ones(x.shape)
249 | return f, g
250 |
--------------------------------------------------------------------------------
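For orientation, the standard coefficient and marginal forms for the VP family (Song et al., 2021) are sketched below; the OU process in Task 1 corresponds to beta(t) ≡ 1, and the reverse-time drift is f - g²·score with the diffusion term unchanged. Exact sign and time conventions depend on how you parameterize the reverse process, so treat `vpsde_coeff`, `vpsde_marginal_prob`, and `reverse_coeff` as illustrative helpers rather than the expected implementation.

```python
import torch

def vpsde_coeff(t, x, beta_min=0.1, beta_max=20.0):
    # Forward SDE: dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dw, with linear beta(t).
    beta_t = beta_min + t * (beta_max - beta_min)
    f = -0.5 * beta_t * x
    g = torch.sqrt(torch.as_tensor(beta_t))
    return f, g

def vpsde_marginal_prob(t, x, beta_min=0.1, beta_max=20.0):
    # q(x_t | x_0) is Gaussian; the log mean-coefficient is -0.5 * \int_0^t beta(s) ds.
    log_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean = torch.exp(torch.as_tensor(log_coeff)) * x
    std = torch.sqrt(1.0 - torch.exp(torch.as_tensor(2.0 * log_coeff)))
    return mean, std

def reverse_coeff(t, x, forward_coeff, score_fn):
    # Reverse-time drift: f(t, x) - g(t)^2 * score(t, x); the diffusion term stays g(t).
    f, g = forward_coeff(t, x)
    return f - (g ** 2) * score_fn(t, x), g
```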
/sde_todo/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from itertools import repeat
4 | import matplotlib.pyplot as plt
5 | from loss import ISMLoss, DSMLoss
6 |
7 | def freeze(model):
8 | """
9 | (Optional) This is for Alternating Schrodinger Bridge.
10 | """
11 | for p in model.parameters():
12 | p.requires_grad = False
13 | model.eval()
14 | return model
15 |
16 |
17 | def unfreeze(model):
18 | """
19 | (Optional) This is for Alternating Schrodinger Bridge.
20 | """
21 | for p in model.parameters():
22 | p.requires_grad = True
23 | model.train()
24 | return model
25 |
26 |
27 | def get_sde_step_fn(model, ema, opt, loss_fn, sde):
28 | def step_fn(batch):
29 | # uniformly sample time step
30 | t = sde.T*torch.rand(batch.shape[0])
31 |
32 | # TODO forward diffusion
33 | xt = None
34 |
35 | # get loss
36 | if isinstance(loss_fn, DSMLoss):
37 |             logp_grad = diff_sq = None  # TODO: ground-truth score and squared diffusion weight
38 | loss = loss_fn(t, xt.float(), model, logp_grad, diff_sq)
39 | elif isinstance(loss_fn, ISMLoss):
40 | loss = loss_fn(t, xt.float(), model)
41 | else:
42 | print(loss_fn)
43 | raise Exception("undefined loss")
44 |
45 | # optimize model
46 | opt.zero_grad()
47 | loss.backward()
48 | opt.step()
49 |
50 | if ema is not None:
51 | ema.update()
52 |
53 | return loss.item()
54 |
55 | return step_fn
56 |
57 |
58 | def get_sb_step_fn(model_f, model_b, ema_f, ema_b,
59 | opt_f, opt_b, loss_fn, sb, joint=True):
60 | def step_fn_alter(batch, forward):
61 | """
62 | (Optional) Implement the optimization step for alternating
63 | likelihood training of Schrodinger Bridge
64 | """
65 | pass
66 |
67 | def step_fn_joint(batch):
68 | """
69 | (Optional) Implement the optimization step for joint likelihood
70 | training of Schrodinger Bridge
71 | """
72 | pass
73 |
74 | if joint:
75 | return step_fn_joint
76 | else:
77 | return step_fn_alter
78 |
79 |
80 | def repeater(data_loader):
81 | for loader in repeat(data_loader):
82 | for data in loader:
83 | yield data
84 |
85 |
86 | def train_diffusion(dataloader, step_fn, N_steps, plot=False):
87 | pbar = tqdm(range(N_steps), bar_format="{desc}{bar}{r_bar}", mininterval=1)
88 | loader = iter(repeater(dataloader))
89 |
90 | log_freq = 200
91 | loss_history = torch.zeros(N_steps//log_freq)
92 | for i, step in enumerate(pbar):
93 | batch = next(loader)
94 | loss = step_fn(batch)
95 |
96 | if step % log_freq == 0:
97 | loss_history[i//log_freq] = loss
98 | pbar.set_description("Loss: {:.3f}".format(loss))
99 |
100 | if plot:
101 | plt.plot(range(len(loss_history)), loss_history)
102 | plt.show()
103 |
--------------------------------------------------------------------------------
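The two TODOs in `get_sde_step_fn` amount to sampling from the perturbation kernel q(x_t | x_0) and using its known score as the DSM target. A possible sketch, assuming `sde.marginal_prob(t, x)` returns the mean and a per-sample standard deviation of that Gaussian kernel (`perturb_batch` is an illustrative helper):

```python
import torch

def perturb_batch(sde, batch, t):
    """Sample x_t ~ q(x_t | x_0) and return the exact score of the Gaussian kernel."""
    mean, std = sde.marginal_prob(t, batch)
    std = std.reshape(-1, *([1] * (batch.dim() - 1)))   # broadcast (B,) over data dims
    noise = torch.randn_like(batch)
    xt = mean + std * noise
    # grad_x log N(x_t; mean, std^2 I) = -(x_t - mean) / std^2 = -noise / std
    logp_grad = -noise / std
    return xt, logp_grad
```

With `xt` and `logp_grad` in hand, `diff_sq` (used when `diff_weight` is enabled in the DSM loss) would typically be the squared diffusion coefficient g(t)² obtained from `sde.sde_coeff(t, xt)`.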