├── .gitignore
├── README.md
├── config
│   ├── celeba_hq_256.yaml
│   ├── cifar10.yaml
│   ├── cifar10_128dim.yaml
│   ├── cifar10_example.yaml
│   ├── cifar10_torch_example.yaml
│   └── inference
│       ├── celeba_hq_256.yaml
│       ├── cifar10.yaml
│       └── cifar10_128dim.yaml
├── images_README
│   ├── celeba_hq_ex1.gif
│   ├── celeba_hq_ex1.png
│   ├── celeba_hq_ex2.gif
│   ├── celeba_hq_ex2.png
│   ├── celeba_hq_ex3.gif
│   ├── celeba_hq_ex3.png
│   ├── celeba_hq_ex4.gif
│   ├── celeba_hq_ex4.png
│   ├── cifar10_128_ex1.png
│   ├── cifar10_128_ex2.gif
│   ├── cifar10_128_ex3.png
│   ├── cifar10_128_ex4.gif
│   ├── cifar10_64_ex1.png
│   └── cifar10_64_ex2.gif
├── inference.py
├── requirements.txt
├── src
│   ├── dataset.py
│   ├── diffusion.py
│   ├── inferencer.py
│   ├── model_original.py
│   ├── model_torch.py
│   ├── trainer.py
│   └── utils.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | __pycache__/
4 | results/
5 | inference_results/
6 | tensorboard/
7 | *.pt
8 | tmp*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DDPM & DDIM PyTorch
2 | ### DDPM & DDIM re-implementation with various functionality
3 |
4 | This repository is an implementation of the papers DDPM ([Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239))
5 | and DDIM ([Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)).
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | ---
14 | ## Objective
15 |
16 | Our code is mainly based on the [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch)
17 | repository, which is a PyTorch implementation of the official TensorFlow code
18 | [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion). However, we found some
19 | differences between the two implementations, especially in the **U-net** structure. We also found that the PyTorch
20 | version lacks some functionality for monitoring training and does not provide inference code. So we decided to
21 | re-implement the code so that it can be helpful for anyone who is new to **Diffusion** models.
22 |
23 | ---
24 | ## Results
25 |
26 | | Dataset | Model checkpoint name | FID (↓) |
27 | |:---------:|:---------------------:|:-------:|
28 | | Cifar10 | [cifar10_64dim.pt](https://drive.google.com/file/d/1vHfP8f_viyadhuXMaLfAQ1Iu5UE0WiiJ/view?usp=drive_link) | 11.81 |
29 | | Cifar10 | [cifar10_128dim.pt](https://drive.google.com/file/d/1NtysETxHPinns6JabjawyWTnkjJKT34M/view?usp=drive_link) | 8.31 |
30 | | CelebA-HQ | [celeba_hq_256.pt](https://drive.google.com/file/d/1zzZbkNkMYCFKmWKW5Sh2JsrUsNrWyDCs/view?usp=drive_link) | 11.97 |
31 |
32 | - cifar10_64dim
33 |
34 |
35 |
36 |
37 | - cifar10_128dim
38 |
39 |
40 |
41 |
42 | - celeba_hq_256
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 | ---
54 | ## Installation
55 | Tested with ```python 3.8.17```, ```torch 1.12.1+cu113```, and ```torchvision 0.13.1+cu113```.
56 | Download the appropriate PyTorch version via the [torch website](https://pytorch.org/get-started/previous-versions/#v1121) or with the following command.
57 | ```
58 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
59 | ```
60 | Install the other required modules with the following command.
61 | ```
62 | pip install -r requirements.txt
63 | ```
64 |
65 | ---
66 | ## Quick Start
67 | ### Inference
68 | Download pre-trained model checkpoints from [model checkpoints](https://drive.google.com/drive/folders/1YdFQEb3d7rInRVVLLN3VZu-fI0qq15pG?usp=sharing)
69 |
70 | - Cifar10 (64 channels for the first hidden dimension)
71 | ```commandline
72 | python inference.py -c ./config/inference/cifar10.yaml -l /path_to_cifar10_64dim.pt/cifar10_64dim.pt
73 | ```
74 | - Cifar10 (128 channels for the first hidden dimension, the structure proposed by the original implementation)
75 | ```commandline
76 | python inference.py -c ./config/inference/cifar10_128dim.yaml -l /path_to_cifar10_128dim.pt/cifar10_128dim.pt
77 | ```
78 | - CelebA-HQ
79 |
80 | You have to download the CelebA-HQ dataset from [kaggle](https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256/).
81 | After un-zipping the file you will find a folder named ```/celeba_hq_256```. Make a folder named ```/data``` if your
82 | project directory does not have one, and place the ```/celeba_hq_256``` folder under the ```/data``` folder so that the
83 | final structure is configured as follows.
84 | ```
85 | - DDPM
86 | - /config
87 | - /src
88 | - /images_README
89 | - inference.py
90 | - train.py
91 | ...
92 | - /data (make the directory if you don't have it)
93 | - /celeba_hq_256
94 | - 00000.jpg
95 | - 00001.jpg
96 | ...
97 | - 29999.jpg
98 |
99 | ```
100 |
101 | ```commandline
102 | python inference.py -c ./config/inference/celeba_hq_256.yaml -l /path_to_celeba_hq_256.pt/celeba_hq_256.pt
103 | ```
104 |
105 | ### Training
106 |
107 | - Cifar10 (64 channels for the first hidden dimension)
108 | ```commandline
109 | python train.py -c ./config/cifar10.yaml
110 | ```
111 | - Cifar10 (128 channels for the first hidden dimension, the structure proposed by the original implementation)
112 | ```commandline
113 | python train.py -c ./config/cifar10_128dim.yaml
114 | ```
115 | - CelebA-HQ
116 |
117 | You have to download the dataset first; see the CelebA-HQ part of the Inference section above for details.
118 |
119 | ```commandline
120 | python train.py -c ./config/celeba_hq_256.yaml
121 | ```
122 |
123 | :pushpin:
124 | ***You can find a more detailed explanation of training and inference in the two sections below.***
125 |
126 | ---
127 | ## Training ( Detailed version )
128 |
129 |
130 | Expand for details
131 |
132 | To train the diffusion model, the first thing you have to do is configure your training settings by making a configuration
133 | file. You can find some examples inside the ```./config``` folder. I will explain how to configure your training using
134 | the ```./config/cifar10_example.yaml``` file and the ```./config/cifar10_torch_example.yaml``` file.
135 | Inside ```cifar10_example.yaml``` you will find 4 primary sections: ```type, unet, ddim, trainer```.
136 |
137 | We will first look at the ```trainer``` section, which is configured as follows.
138 |
139 | ```yaml
140 | dataset: cifar10
141 | batch_size: 128
142 | lr: 0.0002
143 | total_step: 600000
144 | save_and_sample_every: 2500
145 | num_samples: 64
146 | fid_estimate_batch_size: 128
147 | ddpm_fid_score_estimate_every: null
148 | ddpm_num_fid_samples: null
149 | tensorboard: true
150 | clip: both
151 | ```
152 | - ```dataset```: You can give any dataset name that is available through ```torchvision.datasets```. You can find the datasets provided
153 | by torchvision on this [website](https://pytorch.org/vision/stable/datasets.html). If you want to use a torchvision
154 | dataset, just provide the dataset name, for example ```cifar10```. Currently, the tested dataset is ```cifar10```.
155 | Or, if you want to use a custom dataset which you have prepared, pass the path to the folder containing the
156 | images.
157 |
158 |
159 | - ```batch_size, lr, total_step```: You can find the values used by the DDPM authors in the [DDPM paper](https://arxiv.org/abs/2006.11239),
160 | Appendix B. ```total_step``` is the total number of training steps; for example, the DDPM authors trained the cifar10 model for 800K steps.
161 |
162 |
163 | - ```save_and_sample_every```: The interval at which the model and generated samples are saved. For example, in this case,
164 | every 2500 steps the trainer will save the model weights and also generate some sample images to visualize the training progress.
165 |
166 |
167 | - ```num_samples```: When sampling images every ```save_and_sample_every``` steps, the trainer will sample ```num_samples```
168 | images in total and save them to one large image, where the large image has (num_samples)**0.5 rows and
169 | columns. So ```num_samples``` must be a square number, e.g. 25, 36, 49, 64, ...
170 |
171 |
172 | - ```fid_estimate_batch_size```: Batch size used when sampling images for FID calculation. This batch size is applied to
173 | the DDPM sampler as well as the DDIM samplers. If you cannot decide on a value, setting it equal to ```batch_size``` will
174 | be fine.
175 |
176 |
177 | - ```ddpm_fid_score_estimate_every```: Step interval for FID calculation using the DDPM sampler. If set to null, the FID score
178 | will not be calculated with DDPM sampling. Using DDPM sampling for FID calculation, i.e. setting this value to anything other than null,
179 | can be very time-consuming, ***so it is wise to set this value to null*** and use a DDIM sampler for
180 | FID calculation (using the DDIM sampler is explained below). But you can still calculate the FID score with the DDPM sampler if you **insist**.
181 |
182 |
183 | - ```ddpm_num_fid_samples```: Number of images to sample for FID calculation using the DDPM sampler. If you set
184 | ```ddpm_fid_score_estimate_every``` to null, i.e. you do not use the DDPM sampler for FID calculation, then this value is
185 | simply ignored.
186 |
187 |
188 | - ```tensorboard```: If set to true, you can monitor training progress, such as loss, FID value, and sampled images,
189 | during training with TensorBoard.
190 |
191 |
192 | - ```clip```: It must be one of [both, true, false]. This is related to sampling x_{t-1} from p_{theta}(x_{t-1} | x_t).
193 | There are two ways to sample x_{t-1}.
194 | One way is to follow the paper; this corresponds to line 4 of Algorithm 2 in the DDPM paper. (```clip==False```)
195 | Another way is to clip (or clamp) the predicted x_0 to -1 ~ 1 for a better sampling result.
196 | To clip x_0 to our desired range, we cannot directly apply (11) to sample x_{t-1}; instead we have to
197 | calculate the predicted x_0 using (4) and then calculate mu in (7) using that predicted x_0, which is exactly the
198 | same calculation except for the clipping (see the sketch right after this list).
199 | As you might expect, using clip leads to better sampling results since it
200 | restricts the sampled images to the range -1 ~ 1. So, for better sampling results, it is strongly suggested to
201 | set ```clip``` to true. If ```clip==both```, then sampling is done twice, once with
202 | ```clip==True``` and once with ```clip==False```.
203 | [Reference](https://github.com/hojonathanho/diffusion/issues/5)
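
For reference, the clipped path computes the following (a sketch in DDPM-paper notation; it mirrors the buffers registered in ```src/diffusion.py```):

```latex
% predicted x_0 from x_t and the network output epsilon_theta (re-arranged eq. (4)), then clamped
\hat{x}_0 = \sqrt{\tfrac{1}{\bar{\alpha}_t}}\, x_t - \sqrt{\tfrac{1}{\bar{\alpha}_t} - 1}\;\epsilon_\theta(x_t, t),
\qquad \hat{x}_0 \leftarrow \operatorname{clamp}(\hat{x}_0,\, -1,\, 1)

% posterior mean of q(x_{t-1} | x_t, x_0) (eq. (7)), evaluated at the clamped \hat{x}_0
\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, \hat{x}_0
              + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t
```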
204 |
205 | ---
206 |
207 | Now we will look at the ```type``` and ```unet``` sections, which are configured as follows.
208 |
209 | ```yaml
210 | # ```./config/cifar10_example.yaml```
211 | type: original
212 | unet:
213 | dim: 64
214 | image_size: 32
215 | dim_multiply:
216 | - 1
217 | - 2
218 | - 2
219 | - 2
220 | attn_resolutions:
221 | - 16
222 | dropout: 0.1
223 | num_res_blocks: 2
224 | ```
225 |
226 | ```yaml
227 | # ```./config/cifar10_torch_example.yaml```
228 | type: torch
229 | unet:
230 | dim: 64
231 | image_size: 32
232 | dim_multiply:
233 | - 1
234 | - 2
235 | - 2
236 | - 2
237 | full_attn:
238 | - false
239 | - true
240 | - false
241 | - false
242 | attn_heads: 4
243 | attn_head_dim: 32
244 | ```
245 |
246 | -```type```: It must be one of [original, torch]. ***original*** uses the U-net structure originally suggested by
247 | Jonathan Ho, so its structure is the one used in [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion),
248 | which is the official version written in Tensorflow. ***torch*** uses the U-net structure suggested by
249 | [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch), which is a transcription of the
250 | official Tensorflow version.
251 |
252 | I have separated the two because their structures differ significantly. To name a few, the following are the differences
253 | between the two U-net structures.
254 | 1. The official version uses self-Attention at the U-net levels where the feature map resolution
255 | is listed in ```attn_resolutions```. In the DDPM paper you can find that they used self-Attention at the 16x16 resolution,
256 | and this is why ```attn_resolutions``` is by default ```[16, ]```.
257 |
258 | On the other hand, the PyTorch transcribed version uses Linear Attention and multi-head self-Attention. It uses
259 | multi-head self-Attention at the U-net levels where ```full_attn``` is true and Linear Attention at the
260 | rest of the U-net levels. So in this particular case, it uses multi-head self-Attention at U-net level 1 (I will
261 | denote U-net levels as 0, 1, 2, 3, ...) and Linear Attention at U-net levels 0, 2, 3.
262 |
263 | -```unet.dim```: The base hidden channel dimension of the feature maps in the U-net model. You can find more
264 | detail right below.
265 |
266 | -```unet.dim_multiply```: ```len(dim_multiply)``` is the depth of the U-net model, and at each level i the channel
267 | dimension is ```dim * dim_multiply[i]```. If the input image shape is [H, W, 3], then at the lowest level the
268 | feature map shape will be [H/(2^(len(dim_multiply)-1)), W/(2^(len(dim_multiply)-1)), dim*dim_multiply[-1]],
269 | not considering the U-net down-up path connections.
270 |
271 | -```unet.image_size```: Size of the input image. The image will be resized and cropped if ```image_size``` does not
272 | equal the actual input image size. Expected to be an ```Integer```; I have not tested non-square images.
273 |
274 | -```unet.attn_resolutions / unet.full_attn```: Explained above. Since each ```attn_resolutions``` value
275 | must equal the resolution of the feature map where you want to apply self-Attention, you have to carefully
276 | calculate the desired resolution value (see the sketch after this list). ```full_attn``` controls which Attention mechanism is applied at each
277 | level, and it must satisfy ```len(full_attn) == len(dim_multiply)```.
278 |
279 | -```unet.num_res_blocks```: Number of ResnetBlocks at each level. In the downward path, at each level, there will be
280 | num_res_blocks ResnetBlock modules, and in the upward path, at each level, there will be
281 | (num_res_blocks+1) ResnetBlock modules.
282 |
283 | -```unet.attn_heads, unet.attn_head_dim```: The torch implementation uses multi-head self-Attention. ```attn_heads``` is
284 | the number of heads; it corresponds to h in the "Attention Is All You Need" paper, see section 3.2.2.
285 | ```attn_head_dim``` is the dimension of each head; it corresponds to d_k in the Attention paper.
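
To help you pick ```attn_resolutions``` (or ```full_attn```), here is a minimal sketch of how per-level channel dimensions and feature-map resolutions follow from ```dim```, ```dim_multiply```, and ```image_size``` (an illustration of the convention described above, not code from this repository):

```python
def unet_level_summary(dim, dim_multiply, image_size):
    # Level i has channel dimension dim * dim_multiply[i]; the feature map is
    # halved once per level below the top one.
    for i, mult in enumerate(dim_multiply):
        resolution = image_size // (2 ** i)
        print(f'level {i}: channels={dim * mult}, resolution={resolution}x{resolution}')

# ./config/cifar10_example.yaml: dim=64, dim_multiply=[1, 2, 2, 2], image_size=32
unet_level_summary(64, [1, 2, 2, 2], 32)
# level 0: channels=64, resolution=32x32
# level 1: channels=128, resolution=16x16  <- 16 appears in attn_resolutions / full_attn is true here
# level 2: channels=128, resolution=8x8
# level 3: channels=128, resolution=4x4
```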
286 |
287 | ---
288 |
289 | Lastly, we will look at the ```ddim``` section, which is configured as follows.
290 |
291 |
292 | ```yaml
293 | ddim:
294 | 0:
295 | ddim_sampling_steps: 20
296 | sample_every: 5000
297 | calculate_fid: true
298 | num_fid_sample: 6000
299 | eta: 0
300 | save: true
301 | 1:
302 | ddim_sampling_steps: 50
303 | sample_every: 50000
304 | calculate_fid: true
305 | num_fid_sample: 6000
306 | eta: 0
307 | save: true
308 | 2:
309 | ddim_sampling_steps: 20
310 | sample_every: 100000
311 | calculate_fid: true
312 | num_fid_sample: 60000
313 | eta: 0
314 | save: true
315 | ```
316 |
317 | There are 3 subsections (0, 1, 2), which means 3 DDIM samplers will be used for sampling images
318 | and FID calculation during training. The names of the subsections, here 0, 1, 2,
319 | are not important. Each DDIM sampler's name will be set to ```DDIM_{index}_steps{ddim_sampling_steps}_eta{eta}``` no matter
320 | how the subsection is named in the configuration file.
321 |
322 | -```ddim_sampling_steps```: The number of de-noising steps for DDIM sampling. DDPM uses 1000 steps to sample images, but
323 | with DDIM we can control the total number of de-noising steps used to generate images. If this value is set to 20,
324 | then image generation will be 50 times faster than DDPM sampling. A preferred value would be 10~100.
325 |
326 | -```sample_every```: This controls how often sampling is done with this particular sampler. If set to 5000, then
327 | every 5000 steps this sampler will be activated to sample images. So if the total training step count is 50000, there will be
328 | 10 sampling actions in total.
329 |
330 | -```calculate_fid```: Whether to calculate FID.
331 |
332 | -```num_fid_sample```: Only valid if ```calculate_fid``` is set to true. This controls how many sampled images are used for
333 | FID calculation. The speed of FID calculation for a particular sampler is inversely
334 | proportional to (ddim_sampling_steps * num_fid_sample).
335 |
336 | -```eta```: Hyperparameter controlling the stochasticity, see (16) in the DDIM paper (and the sketch after this list).
337 | 0: deterministic (DDIM), 1: fully stochastic (DDPM)
338 |
339 | -```save```: Whether to save the model checkpoint based on the FID value calculated by this particular sampler. If set to true,
340 | the model checkpoint will be saved to a ```.pt``` file whenever the model achieves its best
341 | FID value for this sampler.
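
For reference, the sketch below mirrors how ```DDIM_Sampler``` in ```src/diffusion.py``` derives the sampling sub-sequence tau and the noise scale sigma from ```ddim_sampling_steps``` and ```eta``` (simplified; it assumes the DDPM ```alpha_bar``` schedule has already been computed):

```python
import torch
import torch.nn.functional as F

def ddim_schedule(alpha_bar, ddpm_steps, ddim_sampling_steps, eta):
    # Sub-sequence tau of the DDPM time steps (sec. 4.2 in the DDIM paper);
    # -1 stands for the data distribution and is dropped.
    tau = torch.linspace(-1, ddpm_steps - 1, steps=ddim_sampling_steps + 1, dtype=torch.long)[1:]
    alpha_tau_i = alpha_bar[tau]
    alpha_tau_i_min_1 = F.pad(alpha_bar[tau[:-1]], pad=(1, 0), value=1.)  # alpha_0 = 1
    # (16) in the DDIM paper: eta=0 -> deterministic DDIM, eta=1 -> DDPM-like stochasticity
    sigma = eta * ((1 - alpha_tau_i_min_1) / (1 - alpha_tau_i)
                   * (1 - alpha_tau_i / alpha_tau_i_min_1)).sqrt()
    return tau, sigma
```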
342 |
343 | ---
344 |
345 | Now that the configuration file is set, training the model can be done with the following command.
346 |
347 | ```commandline
348 | python train.py -c /path_to_config_file/configuration_file.yaml
349 | ```
350 | [Mandatory]
351 |
352 | -c, --config : Path to configuration file. Path must include file name with extension .yaml
353 |
354 | [Optional]
355 |
356 | -l, --load : Path to model checkpoint. Path must include file name with extension .pt
357 | You can resume training by using this option.
358 | -t, --tensorboard : Path to tensorboard folder. If you resume training and want to restore the previous tensorboard, set
359 | this option to the previously generated tensorboard folder.
360 | --exp_name : Name for the experiment. If not set, the current time will be used as the experiment name.
361 | --cpu_percentage : Float value from 0.0 to 1.0, default 0.0. It is used to control the num_workers parameter of the DataLoader.
362 | num_workers will be set to "number of CPUs available on your device * cpu_percentage". On Windows, setting
363 | this value to anything other than 0.0 sometimes yields unexpected behavior or failure to train the model. So if you have problems training
364 | the model on Windows, do not change this value.
365 | --no_prev_ddim_setting : If set, stores true. Set this option if you have changed the DDIM settings, for example the
366 | number of DDIM samplers or the sampling steps of a DDIM sampler.
367 |
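For example, to resume a previous run and keep writing to its TensorBoard logs (the paths below are placeholders for your own checkpoint and tensorboard folders):

```commandline
python train.py -c ./config/cifar10.yaml -l ./results/cifar10/checkpoint.pt -t ./tensorboard/previous_run
```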
368 |
369 | ---
370 | ## Inference ( Detailed version )
371 |
372 |
373 | Expand for details
374 |
375 | To run inference with the diffusion model, the first thing you have to do is configure your inference settings by making a configuration
376 | file. You can find some examples inside the ```./config/inference``` folder. I will explain how to configure your inference using
377 | the ```./config/inference/cifar10.yaml``` file.
378 | Inside the file you will find 4 primary sections: ```type, unet, ddim, inferencer```. ```type, unet``` must match the
379 | configuration used for training. On the other hand, the ```ddim``` section need not match the configuration used for
380 | training. One thing to notice is that the ```sample_every, save``` options are not used during inference for DDIM. We
381 | are left with the ```inferencer``` section.
382 |
383 | ```yaml
384 | inferencer:
385 | dataset: cifar10
386 | batch_size: 128
387 | clip: true
388 | num_samples_per_image: 64
389 | num_images_to_generate: 2
390 | ddpm_fid_estimate: true
391 | ddpm_num_fid_samples: 60000
392 | return_all_step: true
393 | make_denoising_gif: true
394 | num_gif: 50
395 | ```
396 |
397 | -```dataset, batch_size, clip```: Same meaning as in the configuration file for training.
398 |
399 | -```num_samples_per_image```: Same meaning as ```num_samples``` in the configuration file for training.
400 | The sampler will sample ```num_samples_per_image``` images in total and save them to one large image,
401 | where the large image has (num_samples_per_image)**0.5 rows and columns. So ```num_samples_per_image``` must be a square number, e.g. 25, 36, 49, 64, ...
402 |
403 | -```num_images_to_generate```: How many large merged images to generate. So if this value is set to 2, there will be
404 | 2 large images, each containing ```num_samples_per_image``` sampled sub-images.
405 |
406 | -```ddpm_fid_estimate, ddpm_num_fid_samples```: Whether to calculate the FID value for the DDPM sampler. If ```ddpm_fid_estimate```
407 | is set to true, ```ddpm_num_fid_samples``` decides the number of images sampled for calculating the FID value.
408 |
409 | -```return_all_step```: Whether to return all the intermediate images of the de-noising steps. So for the DDPM sampler, all
410 | intermediate images of the 1000 de-noising steps will be returned. In the case of a DDIM sampler, ```ddim_sampling_steps``` images
411 | will be returned.
412 |
413 | -```make_denoising_gif```: Whether to make a gif that shows the de-noising process visually. To make the denoising gif,
414 | ```return_all_step``` must be set to true.
415 |
416 | -```num_gif```: Number of frames used for the gif that shows the de-noising process visually. Intermediate de-noised images
417 | will be sampled evenly to obtain ```num_gif``` frames for the denoising gif (see the sketch after this list).
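
As an illustration (a minimal sketch, not the exact code in ```src/inferencer.py```), evenly subsampling the returned de-noising trajectory into ```num_gif``` frames and writing a gif with ```imageio``` could look like this:

```python
import numpy as np
from imageio import mimsave

def save_denoising_gif(steps, num_gif, path):
    # steps: tensor of shape (T+1, 3, H, W) in [0, 1], the de-noising trajectory of one image
    idx = np.linspace(0, steps.shape[0] - 1, num_gif).astype(int)  # evenly spaced frame indices
    frames = [(steps[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8) for i in idx]
    mimsave(path, frames)
```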
418 |
419 | ---
420 |
421 | Now that the configuration file is set, inference can be done with the following command.
422 |
423 | ```commandline
424 | python inference.py -c /path_to_config_file/configuration_file.yaml -l /path_to_model_checkpoint_file/model_checkpoint.pt
425 | ```
426 |
427 | [Mandatory]
428 |
429 | -c, --config : Path to configuration file. Path must include file name with extension .yaml
430 | -l, --load : Path to model checkpoint. Path must include file name with extension .pt
431 |
432 |
433 | ---
434 | ## References
435 | - [Jonathan Ho, Ajay Jain, Pieter Abbeel : Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
436 |
437 | - [Jiaming Song, Chenlin Meng, Stefano Ermon : Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
438 |
--------------------------------------------------------------------------------
/config/celeba_hq_256.yaml:
--------------------------------------------------------------------------------
1 | type: original
2 | unet:
3 | dim: 128
4 | image_size: 64
5 | dim_multiply:
6 | - 1
7 | - 1
8 | - 2
9 | - 2
10 | - 4
11 | attn_resolutions:
12 | - 16
13 | dropout: 0.0
14 | num_res_blocks: 2
15 | ddim:
16 | 0:
17 | ddim_sampling_steps: 20
18 | sample_every: 20000
19 | calculate_fid: true
20 | num_fid_sample: 10000
21 | save: true
22 | trainer:
23 | dataset: ./data/celeba_hq_256
24 | batch_size: 64
25 | lr: 2.0e-05
26 | clip: true
27 | total_step: 500000
28 | save_and_sample_every: 10000
29 | fid_estimate_batch_size: 64
30 | num_samples: 64
31 |
--------------------------------------------------------------------------------
/config/cifar10.yaml:
--------------------------------------------------------------------------------
1 | type: original
2 | unet:
3 | dim: 64
4 | image_size: 32
5 | dim_multiply:
6 | - 1
7 | - 2
8 | - 2
9 | - 2
10 | attn_resolutions:
11 | - 16
12 | dropout: 0.1
13 | num_res_blocks: 2
14 | ddim:
15 | 0:
16 | ddim_sampling_steps: 20
17 | sample_every: 10000
18 | calculate_fid: true
19 | num_fid_sample: 30000
20 | save: true
21 | trainer:
22 | dataset: cifar10
23 | batch_size: 128
24 | lr: 0.0002
25 | total_step: 600000
26 | save_and_sample_every: 2500
27 | num_samples: 64
28 | fid_estimate_batch_size: 128
29 | clip: true
30 |
--------------------------------------------------------------------------------
/config/cifar10_128dim.yaml:
--------------------------------------------------------------------------------
1 | type: original
2 | unet:
3 | dim: 128
4 | image_size: 32
5 | dim_multiply:
6 | - 1
7 | - 2
8 | - 2
9 | - 2
10 | attn_resolutions:
11 | - 16
12 | dropout: 0.1
13 | num_res_blocks: 2
14 | ddim:
15 | 0:
16 | ddim_sampling_steps: 20
17 | sample_every: 10000
18 | calculate_fid: true
19 | num_fid_sample: 30000
20 | save: true
21 | trainer:
22 | dataset: cifar10
23 | batch_size: 128
24 | lr: 0.0002
25 | total_step: 600000
26 | save_and_sample_every: 5000
27 | fid_estimate_batch_size: 128
28 | num_samples: 64
29 | clip: true
30 |
31 |
--------------------------------------------------------------------------------
/config/cifar10_example.yaml:
--------------------------------------------------------------------------------
1 | type: original
2 | unet:
3 | dim: 64
4 | image_size: 32
5 | dim_multiply:
6 | - 1
7 | - 2
8 | - 2
9 | - 2
10 | attn_resolutions:
11 | - 16
12 | dropout: 0.1
13 | num_res_blocks: 2
14 | ddim:
15 | 0:
16 | ddim_sampling_steps: 20
17 | sample_every: 5000
18 | calculate_fid: true
19 | num_fid_sample: 6000
20 | eta: 0
21 | save: true
22 | 1:
23 | ddim_sampling_steps: 50
24 | sample_every: 50000
25 | calculate_fid: true
26 | num_fid_sample: 6000
27 | eta: 0
28 | save: true
29 | 2:
30 | ddim_sampling_steps: 20
31 | sample_every: 100000
32 | calculate_fid: true
33 | num_fid_sample: 60000
34 | eta: 0
35 | save: true
36 | trainer:
37 | dataset: cifar10
38 | batch_size: 128
39 | lr: 0.0002
40 | total_step: 600000
41 | save_and_sample_every: 2500
42 | num_samples: 64
43 | fid_estimate_batch_size: 128
44 | ddpm_fid_score_estimate_every: null
45 | ddpm_num_fid_samples: null
46 | tensorboard: true
47 | clip: both
48 |
--------------------------------------------------------------------------------
/config/cifar10_torch_example.yaml:
--------------------------------------------------------------------------------
1 | type: torch
2 | unet:
3 | dim: 64
4 | image_size: 32
5 | dim_multiply:
6 | - 1
7 | - 2
8 | - 2
9 | - 2
10 | full_attn:
11 | - false
12 | - true
13 | - false
14 | - false
15 | attn_heads: 4
16 | attn_head_dim: 32
17 | ddim:
18 | 0:
19 | ddim_sampling_steps: 20
20 | sample_every: 5000
21 | calculate_fid: true
22 | num_fid_sample: 6000
23 | eta: 0
24 | save: true
25 | 1:
26 | ddim_sampling_steps: 50
27 | sample_every: 50000
28 | calculate_fid: true
29 | num_fid_sample: 6000
30 | eta: 0
31 | save: true
32 | 2:
33 | ddim_sampling_steps: 20
34 | sample_every: 100000
35 | calculate_fid: true
36 | num_fid_sample: 60000
37 | eta: 0
38 | save: true
39 | trainer:
40 | dataset: cifar10
41 | batch_size: 128
42 | lr: 0.0002
43 | total_step: 600000
44 | save_and_sample_every: 2500
45 | num_samples: 64
46 | fid_estimate_batch_size: 128
47 | ddpm_fid_score_estimate_every: null
48 | ddpm_num_fid_samples: null
49 | tensorboard: true
50 | clip: both
51 |
--------------------------------------------------------------------------------
/config/inference/celeba_hq_256.yaml:
--------------------------------------------------------------------------------
1 | type: original
2 | unet:
3 | dim: 128
4 | image_size: 64
5 | dim_multiply:
6 | - 1
7 | - 1
8 | - 2
9 | - 2
10 | - 4
11 | attn_resolutions:
12 | - 16
13 | dropout: 0.0
14 | num_res_blocks: 2
15 | ddim:
16 | inferencer:
17 | dataset: ./data/celeba_hq_256
18 | batch_size: 32
19 | clip: true
20 | num_samples_per_image: 64
21 | num_images_to_generate: 20
22 | ddpm_fid_estimate: true
23 | ddpm_num_fid_samples: 30000
24 | return_all_step: true
25 | make_denoising_gif: true
26 | num_gif: 50
27 |
--------------------------------------------------------------------------------
/config/inference/cifar10.yaml:
--------------------------------------------------------------------------------
1 | type: original
2 | unet:
3 | dim: 64
4 | image_size: 32
5 | dim_multiply:
6 | - 1
7 | - 2
8 | - 2
9 | - 2
10 | attn_resolutions:
11 | - 16
12 | dropout: 0.1
13 | num_res_blocks: 2
14 | ddim:
15 | 0:
16 | ddim_sampling_steps: 20
17 | calculate_fid: true
18 | num_fid_sample: 6000
19 | generate_image: true
20 | inferencer:
21 | dataset: cifar10
22 | batch_size: 128
23 | clip: true
24 | num_samples_per_image: 64
25 | num_images_to_generate: 10
26 | ddpm_fid_estimate: true
27 | ddpm_num_fid_samples: 60000
28 | return_all_step: true
29 | make_denoising_gif: true
30 | num_gif: 50
--------------------------------------------------------------------------------
/config/inference/cifar10_128dim.yaml:
--------------------------------------------------------------------------------
1 | type: original
2 | unet:
3 | dim: 128
4 | image_size: 32
5 | dim_multiply:
6 | - 1
7 | - 2
8 | - 2
9 | - 2
10 | attn_resolutions:
11 | - 16
12 | dropout: 0.1
13 | num_res_blocks: 2
14 | ddim:
15 | 0:
16 | ddim_sampling_steps: 20
17 | calculate_fid: true
18 | num_fid_sample: 6000
19 | generate_image: true
20 | inferencer:
21 | dataset: cifar10
22 | batch_size: 128
23 | clip: true
24 | num_samples_per_image: 64
25 | num_images_to_generate: 20
26 | ddpm_fid_estimate: true
27 | ddpm_num_fid_samples: 60000
28 | return_all_step: true
29 | make_denoising_gif: true
30 | num_gif: 50
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex1.gif
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex1.png
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex2.gif
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex2.png
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex3.gif
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex3.png
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex4.gif
--------------------------------------------------------------------------------
/images_README/celeba_hq_ex4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex4.png
--------------------------------------------------------------------------------
/images_README/cifar10_128_ex1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex1.png
--------------------------------------------------------------------------------
/images_README/cifar10_128_ex2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex2.gif
--------------------------------------------------------------------------------
/images_README/cifar10_128_ex3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex3.png
--------------------------------------------------------------------------------
/images_README/cifar10_128_ex4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex4.gif
--------------------------------------------------------------------------------
/images_README/cifar10_64_ex1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_64_ex1.png
--------------------------------------------------------------------------------
/images_README/cifar10_64_ex2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_64_ex2.gif
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from src import model_torch
2 | from src import model_original
3 | from src.diffusion import GaussianDiffusion, DDIM_Sampler
4 | from src.inferencer import Inferencer
5 | import yaml
6 | import argparse
7 |
8 |
9 | def main(args):
10 | with open(args.config, 'r') as f:
11 | config = yaml.load(f, Loader=yaml.FullLoader)
12 | unet_cfg = config['unet']
13 | ddim_cfg = config['ddim']
14 | trainer_cfg = config['inferencer']
15 | image_size = unet_cfg['image_size']
16 |
17 | if config['type'] == 'original':
18 | unet = model_original.Unet(**unet_cfg).to(args.device)
19 | elif config['type'] == 'torch':
20 | unet = model_torch.Unet(**unet_cfg).to(args.device)
21 | else:
22 | unet = None
23 | print("Unet type must be one of ['original', 'torch']")
24 | exit()
25 |
26 | diffusion = GaussianDiffusion(unet, image_size=image_size).to(args.device)
27 |
28 | ddim_samplers = list()
29 | if isinstance(ddim_cfg, dict):
30 | for sampler_cfg in ddim_cfg.values():
31 | ddim_samplers.append(DDIM_Sampler(diffusion, **sampler_cfg))
32 |
33 | inferencer = Inferencer(diffusion, ddim_samplers=ddim_samplers, time_step=diffusion.time_step, **trainer_cfg)
34 | inferencer.load(args.load)
35 | inferencer.inference()
36 |
37 |
38 | if __name__ == '__main__':
39 | parse = argparse.ArgumentParser(description='DDPM & DDIM')
40 | parse.add_argument('-c', '--config', type=str, default='./config/inference/cifar10.yaml')
41 | parse.add_argument('-l', '--load', type=str, default=None)
42 | parse.add_argument('-d', '--device', type=str, choices=['cuda', 'cpu'], default='cuda')
43 | args = parse.parse_args()
44 | main(args)
45 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | pillow==9.5.0
3 | tensorboard
4 | tqdm
5 | termcolor
6 | pytorch-fid
7 | pyyaml
8 | six
9 | imageio
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | from torch.utils.data import Dataset
5 | from termcolor import colored
6 | import ssl
7 | import os
8 | from glob import glob
9 | from PIL import Image
10 |
11 | ssl._create_default_https_context = ssl._create_unverified_context
12 |
13 |
14 | class customDataset(Dataset):
15 | def __init__(self, folder, transform, exts=['jpg', 'jpeg', 'png', 'tiff']):
16 | super().__init__()
17 | self.paths = [p for ext in exts for p in glob(os.path.join(folder, f'*.{ext}'))]
18 | self.transform = transform
19 |
20 | def __len__(self):
21 | return len(self.paths)
22 |
23 | def __getitem__(self, item):
24 | img = Image.open(self.paths[item])
25 | return self.transform(img)
26 |
27 |
28 | def dataset_wrapper(dataset, image_size, augment_horizontal_flip=True, info_color='light_green', min1to1=True):
29 | transform = transforms.Compose([
30 | transforms.Resize(image_size),
31 | transforms.RandomHorizontalFlip() if augment_horizontal_flip else torch.nn.Identity(),
32 | transforms.CenterCrop(image_size),
33 | transforms.ToTensor(), # turn into torch Tensor of shape CHW, 0 ~ 1
34 | transforms.Lambda(lambda x: ((x * 2) - 1)) if min1to1 else torch.nn.Identity()# -1 ~ 1
35 | ])
36 | if os.path.isdir(dataset):
37 | print(colored('Loading local file directory', info_color))
38 | dataSet = customDataset(dataset, transform)
39 | print(colored('Successfully loaded {} images!'.format(len(dataSet)), info_color))
40 | return dataSet
41 | else:
42 | dataset = dataset.lower()
43 | assert dataset in ['cifar10']
44 | print(colored('Loading {} dataset'.format(dataset), info_color))
45 | if dataset == 'cifar10':
46 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
47 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
48 | fullset = torch.utils.data.ConcatDataset([trainset, testset])
49 | return fullset
50 |
--------------------------------------------------------------------------------
/src/diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from tqdm import tqdm
5 |
6 |
7 | class GaussianDiffusion(nn.Module):
8 | def __init__(self, model, image_size, time_step=1000, loss_type='l2'):
9 | """
10 | Diffusion model. It is based on Denoising Diffusion Probabilistic Models (DDPM), Jonathan Ho et al.
11 | :param model: U-net model for de-noising network
12 | :param image_size: image size
13 | :param time_step: Gaussian diffusion length T. In paper they used T=1000
14 | :param loss_type: either l1, l2, huber. Default, l2
15 | """
16 | super().__init__()
17 | self.unet = model
18 | self.channel = self.unet.channel
19 | self.device = self.unet.device
20 | self.image_size = image_size
21 | self.time_step = time_step
22 | self.loss_type = loss_type
23 |
24 | beta = self.linear_beta_schedule() # (t, ) t=time_step, in DDPM paper t=1000
25 | alpha = 1. - beta # (a1, a2, a3, ... at)
26 | alpha_bar = torch.cumprod(alpha, dim=0) # (a1, a1*a2, a1*a2*a3, ..., a1*a2*~*at)
27 | alpha_bar_prev = F.pad(alpha_bar[:-1], pad=(1, 0), value=1.) # (1, a1, a1*a2, ..., a1*a2*~*a(t-1))
28 |
29 | self.register_buffer('beta', beta)
30 | self.register_buffer('alpha', alpha)
31 | self.register_buffer('alpha_bar', alpha_bar)
32 | self.register_buffer('alpha_bar_prev', alpha_bar_prev)
33 |
34 | # calculation for q(x_t | x_0) consult (4) in DDPM paper.
35 | self.register_buffer('sqrt_alpha_bar', torch.sqrt(alpha_bar))
36 | self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - alpha_bar))
37 |
38 | # calculation for q(x_{t-1} | x_t, x_0) consult (7) in DDPM paper.
39 | self.register_buffer('beta_tilde', beta * ((1. - alpha_bar_prev) / (1. - alpha_bar)))
40 | self.register_buffer('mean_tilde_x0_coeff', beta * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar))
41 | self.register_buffer('mean_tilde_xt_coeff', torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
42 |
43 | # calculation for x0 consult (9) in DDPM paper.
44 | self.register_buffer('sqrt_recip_alpha_bar', torch.sqrt(1. / alpha_bar))
45 | self.register_buffer('sqrt_recip_alpha_bar_min_1', torch.sqrt(1. / alpha_bar - 1))
46 |
47 | # calculation for (11) in DDPM paper.
48 | self.register_buffer('sqrt_recip_alpha', torch.sqrt(1. / alpha))
49 | self.register_buffer('beta_over_sqrt_one_minus_alpha_bar', beta / torch.sqrt(1. - alpha_bar))
50 |
51 | # Forward Process / Diffusion Process ##############################################################################
52 | def q_sample(self, x0, t, noise):
53 | """
54 | Sampling x_t, according to q(x_t | x_0). Consult (4) in DDPM paper.
55 | :param x0: (b, c, h, w), original image
56 | :param t: (b, ), timestep t
57 | :param noise: (b, c, h, w), We calculate q(x_t | x_0) using re-parameterization trick.
58 | :return: x_t with shape=(b, c, h, w), which is a noised image at timestep t
59 | """
60 | # Get x_t ~ q(x_t | x_0) using re-parameterization trick
61 | return self.sqrt_alpha_bar[t][:, None, None, None] * x0 + \
62 | self.sqrt_one_minus_alpha_bar[t][:, None, None, None] * noise
63 |
64 | def forward(self, img):
65 | """
66 | Calculate L_simple according to (14) in DDPM paper
67 | :param img: (b, c, h, w), original image
68 | :return: L_simple
69 | """
70 | b, c, h, w = img.shape
71 | assert h == self.image_size and w == self.image_size, f'height and width of image must be {self.image_size}'
72 | t = torch.randint(0, self.time_step, (b,), device=img.device).long() # (b, )
73 | noise = torch.randn_like(img) # corresponds to epsilon in (14)
74 | noised_image = self.q_sample(img, t, noise) # argument inside epsilon_theta
75 | predicted_noise = self.unet(noised_image, t) # epsilon_theta in (14)
76 |
77 | if self.loss_type == 'l1':
78 | loss = F.l1_loss(noise, predicted_noise)
79 | elif self.loss_type == 'l2':
80 | loss = F.mse_loss(noise, predicted_noise)
81 | elif self.loss_type == "huber":
82 | loss = F.smooth_l1_loss(noise, predicted_noise)
83 | else:
84 | raise NotImplementedError()
85 | return loss
86 |
87 | ####################################################################################################################
88 |
89 | # Reverse Process / De-noising Process #############################################################################
90 | @torch.inference_mode()
91 | def p_sample(self, xt, t, clip=True):
92 | """
93 | Sample x_{t-1} from p_{theta}(x_{t-1} | x_t).
94 | There are two ways to sample x_{t-1}.
95 | One way is to follow paper and this corresponds to line 4 in Algorithm 2 in DDPM paper. (clip==False)
96 | Another way is to clip(or clamp) the predicted x_0 to -1 ~ 1 for better sampling result.
97 | To clip the x_0 to our desired range, we cannot directly apply (11) to sample x_{t-1}, rather we have to
98 | calculate the predicted x_0 using (4) and then calculate mu in (7) using that predicted x_0, which is exactly the
99 | same calculation except for the clipping.
100 | As you might easily expect, using clip leads to better sampling result since it
101 | restricts sampled images range to -1 ~ 1. Ref: https://github.com/hojonathanho/diffusion/issues/5
102 |
103 | :param xt: ( b, c, h, w), noised image at time step t
104 | :param t: ( b, )
105 | :param clip: [True, False] Whether to clip predicted x_0 to our desired range -1 ~ 1.
106 | :return: de-noised image at time step t-1
107 | """
108 | batched_time = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long)
109 | pred_noise = self.unet(xt, batched_time) # corresponds to epsilon_{theta}
110 | if clip:
111 | x0 = self.sqrt_recip_alpha_bar[t] * xt - self.sqrt_recip_alpha_bar_min_1[t] * pred_noise
112 | x0.clamp_(-1., 1.)
113 | mean = self.mean_tilde_x0_coeff[t] * x0 + self.mean_tilde_xt_coeff[t] * xt
114 | else:
115 | mean = self.sqrt_recip_alpha[t] * (xt - self.beta_over_sqrt_one_minus_alpha_bar[t] * pred_noise)
116 | variance = self.beta_tilde[t]
117 | noise = torch.randn_like(xt) if t > 0 else 0. # corresponds to z, consult 4: in Algorithm 2.
118 | x_t_minus_1 = mean + torch.sqrt(variance) * noise
119 | return x_t_minus_1
120 |
121 | @torch.inference_mode()
122 | def sample(self, batch_size=16, return_all_timestep=False, clip=True, min1to1=False):
123 | """
124 |
125 | :param batch_size: # of image to generate.
126 | :param return_all_timestep: Whether to return all images during de-noising process. So it will return the
127 | images from time step T ~ time step 0
128 | :param clip: [True, False]. Explanation in p_sample function.
129 | :return: Generated image of shape (b, 3, h, w) if return_all_timestep==False else (b, T, 3, h, w)
130 | """
131 | xT = torch.randn([batch_size, self.channel, self.image_size, self.image_size], device=self.device)
132 | denoised_intermediates = [xT]
133 | xt = xT
134 | for t in tqdm(reversed(range(0, self.time_step)), desc='DDPM Sampling', total=self.time_step, leave=False):
135 | x_t_minus_1 = self.p_sample(xt, t, clip)
136 | denoised_intermediates.append(x_t_minus_1)
137 | xt = x_t_minus_1
138 |
139 | images = xt if not return_all_timestep else torch.stack(denoised_intermediates, dim=1)
140 | # images = (images + 1.0) * 0.5 # scale to 0~1
141 | images.clamp_(min=-1.0, max=1.0)
142 | if not min1to1:
143 | images.sub_(-1.0).div_(2.0)
144 | return images
145 |
146 | ####################################################################################################################
147 |
148 | def linear_beta_schedule(self):
149 | """
150 | linear schedule, proposed in original ddpm paper
151 | """
152 | scale = 1000 / self.time_step
153 | beta_start = scale * 0.0001
154 | beta_end = scale * 0.02
155 | return torch.linspace(beta_start, beta_end, self.time_step, dtype=torch.float32)
156 |
157 |
158 | class DDIM_Sampler(nn.Module):
159 | def __init__(self, ddpm_diffusion_model, ddim_sampling_steps=100, eta=0, sample_every=5000, fixed_noise=False,
160 | calculate_fid=False, num_fid_sample=None, generate_image=True, clip=True, save=False):
161 | """
162 | Denoising Diffusion Implicit Models (DDIM), Jiaming Song et al.
163 | :param ddpm_diffusion_model: DDPM diffusion model.
164 | :param ddim_sampling_steps: Total sampling steps for DDIM sampling process. It corresponds to S in DDIM paper.
165 | Consult section 4.2 Accelerated Generation Processes in DDIM paper.
166 | :param eta: Hyperparameter to control the stochasticity, see (16) in DDIM paper.
167 | 0: deterministic(DDIM) , 1: fully stochastic(DDPM)
168 | :param sample_every: The interval for calling this DDIM sampler. It is only valid during training.
169 | For example if sample_every=5000 then during training, every 5000steps, trainer will call this DDIM sampler.
170 | :param fixed_noise: If set to True, then this Sampler will always use same starting noise for image generation.
171 | :param calculate_fid: Whether to calculate FID score for this sampler.
172 | :param num_fid_sample: # of generating samples for FID calculation.
173 | If calculate_fid==True and num_fid_sample==None, then it will automatically set # of generating image to the
174 | total number of image in original dataset.
175 | :param generate_image: Whether to save the generated image to folder.
176 | :param clip: [True, False, 'both'] 'both' will sample twice for clip==True and clip==False.
177 | Details in ddim_p_sample function.
178 | :param save: Whether to save the diffusion model based on the FID score calculated by this sampler.
179 | So calculate_fid must be set to True, if you want to set this parameter to be True.
180 | """
181 | super().__init__()
182 | self.ddim_steps = ddim_sampling_steps
183 | self.eta = eta
184 | self.sample_every = sample_every
185 | self.fixed_noise = fixed_noise
186 | self.calculate_fid = calculate_fid
187 | self.num_fid_sample = num_fid_sample
188 | self.generate_image = generate_image
189 | self.channel = ddpm_diffusion_model.channel
190 | self.image_size = ddpm_diffusion_model.image_size
191 | self.device = ddpm_diffusion_model.device
192 | self.clip = clip
193 | self.save = save
194 | self.sampler_name = None
195 | self.save_path = None
196 | ddpm_steps = ddpm_diffusion_model.time_step
197 | assert self.ddim_steps <= ddpm_steps, 'DDIM sampling step must be smaller or equal to DDPM sampling step'
198 | assert clip in [True, False, 'both'], "clip must be one of [True, False, 'both']"
199 | if self.save:
200 | assert self.calculate_fid is True, 'To save model based on FID score, you must set [calculate_fid] to True'
201 | self.register_buffer('best_fid', torch.tensor([1e10], dtype=torch.float32))
202 |
203 | alpha_bar = ddpm_diffusion_model.alpha_bar
204 | # One thing you must notice is that although sampling time is indexed as [1,...T] in the paper,
205 | # since in a computer program we index from [0,...T-1] rather than [1,...T],
206 | # the value of tau ranges from [-1, ...T-1] where t=-1 indicates the initial state (data distribution)
207 |
208 | # [tau_1, tau_2, ... tau_S] sec 4.2
209 | self.register_buffer('tau', torch.linspace(-1, ddpm_steps - 1, steps=self.ddim_steps + 1, dtype=torch.long)[1:])
210 |
211 | alpha_tau_i = alpha_bar[self.tau]
212 | alpha_tau_i_min_1 = F.pad(alpha_bar[self.tau[:-1]], pad=(1, 0), value=1.) # alpha_0 = 1
213 |
214 | # (16) in DDIM
215 | self.register_buffer('sigma', eta * (((1 - alpha_tau_i_min_1) / (1 - alpha_tau_i) *
216 | (1 - alpha_tau_i / alpha_tau_i_min_1)).sqrt()))
217 | # (12) in DDIM
218 | self.register_buffer('coeff', (1 - alpha_tau_i_min_1 - self.sigma ** 2).sqrt())
219 | self.register_buffer('sqrt_alpha_i_min_1', alpha_tau_i_min_1.sqrt())
220 |
221 | assert self.coeff[0] == 0.0 and self.sqrt_alpha_i_min_1[0] == 1.0, 'DDIM parameter error'
222 |
223 | @torch.inference_mode()
224 | def ddim_p_sample(self, model, xt, i, clip=True):
225 | """
226 | Sample x_{tau_(i-1)} from p(x_{tau_(i-1)} | x_{tau_i}), consult (56) in DDIM paper.
227 | Calculation is done using (12) in DDIM paper where t-1 has to be changed to tau_(i-1) and t has to be
228 | changed to tau_i in (12), for accelerated generation process where total # of de-noising step is S.
229 |
230 | :param model: Diffusion model
231 | :param xt: noisy image at time step tau_i
232 | :param i: i is the index of the array tau, which is a sub-sequence of [1, ..., T] of length S. See sec. 4.2
233 | :param clip: Like in GaussianDiffusion p_sample, we can clip(or clamp) the predicted x_0 to -1 ~ 1
234 | for better sampling result. If you see (12) in DDIM paper, sampling x_(t-1) depends on epsilon_theta which is
235 | U-net network predicted noise at time step t. If we want to clip the "predicted x0", we have to
236 | re-calculate the epsilon_theta to make "predicted x0" lie in -1 ~ 1. This is exactly what is going on
237 | if you set clip==True.
238 | :return: de-noised image at time step tau_(i-1)
239 | """
240 | t = self.tau[i]
241 | batched_time = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long)
242 | pred_noise = model.unet(xt, batched_time) # corresponds to epsilon_{theta}
243 | x0 = model.sqrt_recip_alpha_bar[t] * xt - model.sqrt_recip_alpha_bar_min_1[t] * pred_noise
244 | if clip:
245 | x0.clamp_(-1., 1.)
246 | pred_noise = (model.sqrt_recip_alpha_bar[t] * xt - x0) / model.sqrt_recip_alpha_bar_min_1[t]
247 |
248 | # x0 corresponds to "predicted x0" and pred_noise corresponds to epsilon_theta(xt) in (12) DDIM
249 | # Thus self.coeff[i] * pred_noise corresponds to "direction pointing to xt" in (12)
250 | mean = self.sqrt_alpha_i_min_1[i] * x0 + self.coeff[i] * pred_noise
251 | noise = torch.randn_like(xt) if i > 0 else 0.
252 | # self.sigma[i] * noise corresponds to "random noise" in (12)
253 | x_t_minus_1 = mean + self.sigma[i] * noise
254 | return x_t_minus_1
255 |
256 | @torch.inference_mode()
257 | def sample(self, diffusion_model, batch_size, noise=None, return_all_timestep=False, clip=True, min1to1=False):
258 | """
259 |
260 | :param diffusion_model: Diffusion model
261 | :param batch_size: # of image to generate.
262 | :param noise: Optional fixed starting noise x_T. If None, random noise is drawn; otherwise the given tensor is used.
263 | :param return_all_timestep: Whether to return all images during de-noising process. So it will return the
264 | images from time step tau_S ~ time step tau_0
265 | :param clip: See ddim_p_sample function
266 | :return: Generated image of shape (b, 3, h, w) if return_all_timestep==False else (b, S, 3, h, w)
267 | """
268 | clip = clip if clip is not None else self.clip
269 | xT = torch.randn([batch_size, self.channel, self.image_size, self.image_size], device=self.device) \
270 | if noise is None else noise.to(self.device)
271 | denoised_intermediates = [xT]
272 | xt = xT
273 | for i in tqdm(reversed(range(0, self.ddim_steps)), desc='DDIM Sampling', total=self.ddim_steps, leave=False):
274 | x_t_minus_1 = self.ddim_p_sample(diffusion_model, xt, i, clip)
275 | denoised_intermediates.append(x_t_minus_1)
276 | xt = x_t_minus_1
277 |
278 | images = xt if not return_all_timestep else torch.stack(denoised_intermediates, dim=1)
279 | # images = (images + 1.0) * 0.5 # scale to 0~1
280 | images.clamp_(min=-1.0, max=1.0)
281 | if not min1to1:
282 | images.sub_(-1.0).div_(2.0)
283 | return images
284 |
--------------------------------------------------------------------------------
/src/inferencer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from torchvision.utils import save_image
6 | from .dataset import dataset_wrapper
7 | from .utils import *
8 | from termcolor import colored
9 | from tqdm import tqdm
10 | from glob import glob
11 | import imageio
12 | from imageio import mimsave
13 | from functools import partial
14 | from torchvision.transforms import ToPILImage
15 |
16 |
17 | class Inferencer:
18 | def __init__(self, diffusion_model, dataset, ddim_samplers=None, batch_size=32, num_samples_per_image=25,
19 | result_folder='./inference_results', num_images_to_generate=1, ddpm_fid_estimate=True, time_step=1000,
20 | ddpm_num_fid_samples=None, clip=True, return_all_step=True, make_denoising_gif=True, num_gif=50,
21 | save_generated_img_for_fid_cal=False):
22 | """
23 | Inferencer for the Diffusion model. Sampling is supported via DDPM sampling & DDIM sampling.
24 | :param diffusion_model: GaussianDiffusion model
25 | :param dataset: either 'cifar10' or path to the custom dataset you've prepared, where images are saved
26 | :param ddim_samplers: List containing DDIM samplers.
27 | :param batch_size: batch_size for inferencing
28 | :param num_samples_per_image: # of generating images, must be square number ex) 25, 36, 49...
29 | :param result_folder: where inference result will be saved.
30 | :param num_images_to_generate: # of generated image set. For example if num_samples_per_image==25 and
31 | num_images_to_generate==3 then, in result folder there will be 3 generated image with each image containing
32 | 25 generated sub-images merged into one image file with 5 rows, 5 columns.
33 | :param ddpm_fid_estimate: Whether to calculate FID score based on DDPM sampling.
34 | :param time_step: Gaussian diffusion length T. In DDPM paper they used T=1000
35 | :param ddpm_num_fid_samples: # of generating images for FID calculation using DDPM sampler. If you set
36 | ddpm_fid_estimate to False, i.e. not using DDPM sampler for FID calculation, then this value will
37 | be just ignored.
38 | :param clip: [True, False, 'both'] you can find detail in p_sample function
39 | and ddim_p_sample function in diffusion.py file.
40 | :param return_all_step: Whether to save the entire de-noising processed image to result folder.
41 | :param make_denoising_gif: Whether to make gif which contains de-noising process visually.
42 | :param num_gif: # of images to make one gif which contains de-noising process visually.
43 | :param save_generated_img_for_fid_cal: Whether to save generated images which are used for FID calculation.
44 | """
45 | dataset_name = os.path.basename(dataset)
46 | if dataset_name == '':
47 | dataset_name=os.path.basename(os.path.dirname(dataset))
48 | self.diffusion_model = diffusion_model
49 | self.ddim_samplers = ddim_samplers
50 | self.batch_size = batch_size
51 | self.time_step = time_step
52 | self.num_samples = num_samples_per_image
53 | self.num_images = num_images_to_generate
54 | self.nrow = int(math.sqrt(self.num_samples))
55 | assert (self.nrow ** 2) == self.num_samples, 'num_samples must be a square number. ex) 25, 36, 49, ...'
56 | self.image_size = self.diffusion_model.image_size
57 | self.result_folder = os.path.join(result_folder, dataset_name)
58 | self.ddpm_result_folder = os.path.join(self.result_folder, 'DDPM')
59 | self.device = self.diffusion_model.device
60 | self.return_all_step = return_all_step or make_denoising_gif
61 | self.make_denoising_gif = make_denoising_gif
62 | self.num_gif = num_gif
63 | self.save_img = save_generated_img_for_fid_cal
64 | self.toPIL = ToPILImage()
65 | self.clip = clip
66 | self.ddpm_fid_flag = ddpm_fid_estimate
67 | self.cal_fid = True if self.ddpm_fid_flag else False
68 | self.fid_score_log = dict()
69 | assert clip in [True, False, 'both'], "clip must be one of [True, False, 'both']"
70 | if clip is True or clip == 'both':
71 | os.makedirs(os.path.join(self.ddpm_result_folder, 'clip'), exist_ok=True)
72 | if clip is False or clip == 'both':
73 | os.makedirs(os.path.join(self.ddpm_result_folder, 'no_clip'), exist_ok=True)
74 |
75 | # Dataset
76 | notification = make_notification('Dataset', color='light_green')
77 | print(notification)
78 | dataSet = dataset_wrapper(dataset, self.image_size, augment_horizontal_flip=False, min1to1=False)
79 | dataLoader = DataLoader(dataSet, batch_size=batch_size)
80 | print(colored('Dataset Length: {}\n'.format(len(dataSet)), 'light_green'))
81 |
82 | # DDIM sampler setting
83 | for idx, sampler in enumerate(self.ddim_samplers):
84 | sampler.sampler_name = 'DDIM_{}_steps{}_eta{}'.format(idx + 1, sampler.ddim_steps, sampler.eta)
85 | save_path = os.path.join(self.result_folder, sampler.sampler_name)
86 | sampler.save_path = save_path
87 | if sampler.generate_image:
88 | if sampler.clip is True or sampler.clip == 'both':
89 | os.makedirs(os.path.join(save_path, 'clip'), exist_ok=True)
90 | if sampler.clip is False or sampler.clip == 'both':
91 | os.makedirs(os.path.join(save_path, 'no_clip'), exist_ok=True)
92 | if sampler.calculate_fid:
93 | self.cal_fid = True
94 | sampler.num_fid_sample = sampler.num_fid_sample if sampler.num_fid_sample is not None else len(dataSet)
95 |
96 | # Image generation log
97 | notification = make_notification('Image Generation', color='light_cyan')
98 | print(notification)
99 | print(colored('Image will be generated with the following sampler(s)', 'light_cyan'))
100 | print(colored('-> DDPM Sampler', 'light_cyan'))
101 | for sampler in self.ddim_samplers:
102 | if sampler.generate_image:
103 | print(colored('-> {}'.format(sampler.sampler_name), 'light_cyan'))
104 | print('\n')
105 |
106 | # FID score
107 | notification = make_notification('FID', color='light_magenta')
108 | print(notification)
109 | if not self.cal_fid:
110 | print(colored('No FID evaluation will be executed!\n'
111 | 'If you want FID evaluation, consider using a DDIM sampler.', 'light_magenta'))
112 | else:
113 | self.fid_scorer = FID(self.batch_size, dataLoader, dataset_name=dataset_name, device=self.device,
114 | no_label=os.path.isdir(dataset))
115 | print(colored('FID score will be calculated with the following sampler(s)', 'light_magenta'))
116 | if self.ddpm_fid_flag:
117 | self.ddpm_num_fid_samples = ddpm_num_fid_samples if ddpm_num_fid_samples is not None else len(dataSet)
118 | print(colored('-> DDPM Sampler / FID calculation with {} generated samples'
119 | .format(self.ddpm_num_fid_samples), 'light_magenta'))
120 | for sampler in self.ddim_samplers:
121 | if sampler.calculate_fid:
122 | print(colored('-> {} / FID calculation with {} generated samples'
123 | .format(sampler.sampler_name, sampler.num_fid_sample), 'light_magenta'))
124 | print('\n')
125 | del dataset
126 | del dataLoader
127 |
128 | @torch.inference_mode()
129 | def inference(self):
130 | self.diffusion_model.eval()
131 | notification = make_notification('Inferencing', color='light_yellow', boundary='+')
132 | print(notification)
133 | print(colored('Image Generation\n', 'light_yellow'))
134 |
135 | # DDPM sampler
136 | for idx in tqdm(range(self.num_images), desc='DDPM image sampling'):
137 | batches = num_to_groups(self.num_samples, self.batch_size)
138 | for i, j in zip([True, False], ['clip', 'no_clip']):
139 | if self.clip not in [i, 'both']:
140 | continue
141 | imgs = list(map(lambda n: self.diffusion_model.sample(n, self.return_all_step, clip=i), batches))
142 | imgs = torch.cat(imgs, dim=0) # (batch, steps, ch, h, w)
143 | if self.return_all_step:
144 | path = os.path.join(self.ddpm_result_folder, '{}'.format(j), '{}'.format(idx + 1))
145 | os.makedirs(path, exist_ok=True)
146 | for step in range(imgs.shape[1]):
147 | save_image(imgs[:, step], nrow=self.nrow, fp=os.path.join(path, '{:04d}.png'.format(step)))
148 | if self.make_denoising_gif:
149 | gif_step = int(self.time_step/self.num_gif)
150 | gif_step = max(1, gif_step)
151 | file_names = list(reversed(sorted(glob(os.path.join(path, '*.png')))))[::gif_step]
152 | file_names = reversed(file_names)
153 | gif = [imageio.v2.imread(names) for names in file_names]
154 | mimsave(os.path.join(self.ddpm_result_folder, '{}_{}.gif'.format(idx+1, j)),
155 | gif, **{'duration': self.num_gif / imgs.shape[1]})
156 | last_img = imgs[:, -1] if self.return_all_step else imgs
157 | save_image(last_img, nrow=self.nrow,
158 | fp=os.path.join(self.ddpm_result_folder, '{}_{}.png'.format(idx+1, j)))
159 | # DDIM sampler
160 | for sampler in self.ddim_samplers:
161 | if sampler.generate_image:
162 | for idx in tqdm(range(self.num_images), desc='{} image sampling'.format(sampler.sampler_name)):
163 | batches = num_to_groups(self.num_samples, self.batch_size)
164 | for i, j in zip([True, False], ['clip', 'no_clip']):
165 | if sampler.clip not in [i, 'both']:
166 | continue
167 | imgs = list(map(lambda n: sampler.sample(self.diffusion_model, batch_size=n, noise=None,
168 | return_all_timestep=self.return_all_step, clip=i), batches))
169 | imgs = torch.cat(imgs, dim=0)
170 | if self.return_all_step:
171 | path = os.path.join(sampler.save_path, '{}'.format(j), '{}'.format(idx + 1))
172 | os.makedirs(path, exist_ok=True)
173 | for step in range(imgs.shape[1]):
174 | save_image(imgs[:, step], nrow=self.nrow,
175 | fp=os.path.join(path, '{:04d}.png'.format(step)))
176 | if self.make_denoising_gif:
177 | gif_step = int(sampler.ddim_steps/self.num_gif)
178 | gif_step = max(1, gif_step)
179 | file_names = list(reversed(sorted(glob(os.path.join(path, '*.png')))))[::gif_step]
180 | file_names = reversed(file_names)
181 | gif = [imageio.v2.imread(names) for names in file_names]
182 | mimsave(os.path.join(sampler.save_path, '{}_{}.gif'.format(idx + 1, j)),
183 | gif, **{'duration': self.num_gif / imgs.shape[1]})
184 | last_img = imgs[:, -1] if self.return_all_step else imgs
185 | save_image(last_img, nrow=self.nrow,
186 | fp=os.path.join(sampler.save_path, '{}_{}.png'.format(idx + 1, j)))
187 | if self.cal_fid:
188 | print(colored('\nFID score estimation\n', 'light_yellow'))
189 | if self.ddpm_fid_flag:
190 | print(colored('DDPM FID calculation...', 'yellow'))
191 | ddpm_fid, imgs = self.fid_scorer.fid_score(self.diffusion_model.sample, self.ddpm_num_fid_samples,
192 | self.save_img)
193 | self.fid_score_log['DDPM'] = ddpm_fid
194 | if self.save_img:
195 | path_ = os.path.join(self.ddpm_result_folder, 'generated_samples_for_FID_calculation')
196 | os.makedirs(path_, exist_ok=True)
197 | for i in range(imgs.shape[0]):
198 | img = self.toPIL(imgs[i])
199 | img.save(os.path.join(path_, '{:06d}.png'.format(i+1)))
200 | for sampler in self.ddim_samplers:
201 | print(colored('{} FID calculation...'.format(sampler.sampler_name), 'yellow'))
202 | if sampler.calculate_fid:
203 | sample_ = partial(sampler.sample, self.diffusion_model)
204 | ddim_fid, imgs = self.fid_scorer.fid_score(sample_, sampler.num_fid_sample, self.save_img)
205 | self.fid_score_log[f'{sampler.sampler_name}'] = ddim_fid
206 | if self.save_img:
207 | path_ = os.path.join(sampler.save_path, 'generated_samples_for_FID_calculation')
208 | os.makedirs(path_, exist_ok=True)
209 | for i in range(imgs.shape[0]):
210 | img = self.toPIL(imgs[i])
211 | img.save(os.path.join(path_, '{:06d}.png'.format(i + 1)))
212 | print(colored('-'*50, 'yellow'))
213 | for key, val in self.fid_score_log.items():
214 | print(colored('Sampler: {} -> FID score: {}'.format(key, val), 'yellow'))
215 | with open(os.path.join(self.result_folder, 'FID.txt'), 'w') as f:
216 | f.write('Results\n')
217 | f.write('='*50)
218 | f.write('\n')
219 | for key, val in self.fid_score_log.items():
220 | f.write('Sampler: {} -> FID score: {}\n'.format(key, val))
221 |
222 | def load(self, path):
223 | if not os.path.exists(path):
224 | print(make_notification('ERROR', color='red', boundary='*'))
225 | print(colored('No saved checkpoint detected. Please check that the path you gave exists!', 'red'))
226 | exit()
227 | print(make_notification('Loading Checkpoint', color='green'))
228 | data = torch.load(path, map_location=self.device)
229 | self.diffusion_model.load_state_dict(data['model'])
230 | print(colored('Successfully loaded checkpoint!\n', 'green'))
--------------------------------------------------------------------------------
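For reference, a minimal sketch of wiring the `Inferencer` up by hand, following the constructor calls used in `train.py`. The DDIM sampler keyword arguments and the checkpoint path are placeholders (see `src/diffusion.py` and the yaml files under `config/inference/` for the real values), so treat this as an assumption-laden outline rather than a drop-in script.

```
import torch
from src.model_torch import Unet
from src.diffusion import GaussianDiffusion, DDIM_Sampler
from src.inferencer import Inferencer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
unet = Unet(dim=64, image_size=32, device=device).to(device)
diffusion = GaussianDiffusion(unet, image_size=32).to(device)

# DDIM_Sampler options (number of steps, eta, ...) come from the inference yaml files;
# the bare call below assumes the defaults are usable.
ddim_samplers = [DDIM_Sampler(diffusion)]

inferencer = Inferencer(diffusion, dataset='cifar10', ddim_samplers=ddim_samplers,
                        batch_size=32, num_samples_per_image=25)
inferencer.load('./model_checkpoint.pt')  # placeholder path to a downloaded checkpoint
inferencer.inference()
```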
/src/model_original.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from .utils import PositionalEncoding
5 |
6 |
7 | class ResnetBlock(nn.Module):
8 | def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=None, groups=32):
9 | super().__init__()
10 |
11 | self.dim, self.dim_out = dim, dim_out
12 |
13 | dim_out = dim if dim_out is None else dim_out
14 | self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=dim)
15 | self.activation1 = nn.SiLU()
16 | self.conv1 = nn.Conv2d(dim, dim_out, kernel_size=(3, 3), padding=1)
17 | self.block1 = nn.Sequential(self.norm1, self.activation1, self.conv1)
18 |
19 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if time_emb_dim is not None else None
20 |
21 | self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
22 | self.activation2 = nn.SiLU()
23 | self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()
24 | self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), padding=1)
25 | self.block2 = nn.Sequential(self.norm2, self.activation2, self.dropout, self.conv2)
26 |
27 | self.residual_conv = nn.Conv2d(dim, dim_out, kernel_size=(1, 1)) if dim != dim_out else nn.Identity()
28 |
29 | def forward(self, x, time_emb=None):
30 | hidden = self.block1(x)
31 | if time_emb is not None:
32 | # add in timestep embedding
33 | hidden = hidden + self.mlp(time_emb)[..., None, None] # (B, dim_out, 1, 1)
34 | hidden = self.block2(hidden)
35 | return hidden + self.residual_conv(x)
36 |
37 |
38 | class Attention(nn.Module):
39 | def __init__(self, dim, groups=32):
40 | super().__init__()
41 |
42 | self.dim, self.dim_out = dim, dim
43 |
44 | self.scale = dim ** (-0.5) # 1 / sqrt(d_k)
45 | self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim)
46 | self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=(1, 1))
47 | self.to_out = nn.Conv2d(dim, dim, kernel_size=(1, 1))
48 |
49 | def forward(self, x):
50 | b, c, h, w = x.shape
51 | qkv = self.to_qkv(self.norm(x)).chunk(3, dim=1)
52 | # You can think of (h*w) as the sequence length and c as d_k in the Attention paper.
53 | q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), qkv)
54 |
55 | """
56 | q, k, v shape: (batch, seq_length, d_k) seq_length = height*width, d_k == c == dim
57 | similarity shape: (batch, seq_length, seq_length)
58 | attention_score shape: (batch, seq_length, seq_length)
59 | attention shape: (batch, seq_length, d_k)
60 | out shape: (batch, d_k, height, width) d_k == c == dim
61 | return shape: (batch, dim, height, width)
62 | """
63 |
64 | similarity = torch.einsum('b i c, b j c -> b i j', q, k) # Q(K^T)
65 | attention_score = torch.softmax(similarity * self.scale, dim=-1) # softmax(Q(K^T) / sqrt(d_k))
66 | attention = torch.einsum('b i j, b j c -> b i c', attention_score, v)
67 | # attention(Q, K, V) = [softmax(Q(K^T) / sqrt(d_k))]V -> Scaled Dot-Product Attention
68 | out = rearrange(attention, 'b (h w) c -> b c h w', h=h, w=w)
69 | return self.to_out(out) + x
70 |
71 |
72 | class ResnetAttentionBlock(nn.Module):
73 | def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=None, groups=32):
74 | super().__init__()
75 |
76 | self.dim, self.dim_out = dim, dim_out
77 |
78 | self.resnet = ResnetBlock(dim, dim_out, time_emb_dim, dropout, groups)
79 | self.attention = Attention(dim_out, groups)
80 |
81 | def forward(self, x, time_emb=None):
82 | x = self.resnet(x, time_emb)
83 | return self.attention(x)
84 |
85 |
86 | class downSample(nn.Module):
87 | def __init__(self, dim_in):
88 | super().__init__()
89 |
90 | self.dim, self.dim_out = dim_in, dim_in
91 |
92 | self.downsample = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1)
93 |
94 | def forward(self, x):
95 | return self.downsample(x)
96 |
97 |
98 | class upSample(nn.Module):
99 | def __init__(self, dim_in):
100 | super().__init__()
101 |
102 | self.dim, self.dim_out = dim_in, dim_in
103 |
104 | self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
105 | nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), padding=1))
106 |
107 | def forward(self, x):
108 | return self.upsample(x)
109 |
110 |
111 | class Unet(nn.Module):
112 | def __init__(self, dim, image_size, dim_multiply=(1, 2, 4, 8), channel=3, num_res_blocks=2,
113 | attn_resolutions=(16,), dropout=0, device='cuda', groups=32):
114 | """
115 | U-net for noise prediction. Code is based on Denoising Diffusion Probabilistic Models
116 | https://github.com/hojonathanho/diffusion
117 | :param dim: base channel dimension, see dim_multiply below
118 | :param dim_multiply: len(dim_multiply) is the depth of the U-net; at each level i, the channel dimension
119 | will be dim * dim_multiply[i]. If the input image shape is [H, W, 3] then at the lowest level the
120 | feature map shape will be [H/2^(len(dim_multiply)-1), W/2^(len(dim_multiply)-1), dim*dim_multiply[-1]],
121 | ignoring the U-net down-up path (skip) connections.
122 | :param image_size: input image size
123 | :param channel: # of input image channels, 3 for RGB
124 | :param num_res_blocks: # of ResnetBlocks at each level. In the downward path each level contains
125 | num_res_blocks ResnetBlock modules, and in the upward path each level contains
126 | (num_res_blocks+1) ResnetBlock modules.
127 | :param attn_resolutions: The feature map resolutions at which Attention is applied. In the DDPM paper, the
128 | authors use an Attention module when the feature map resolution is 16.
129 | :param dropout: dropout. If set to 0 then no dropout.
130 | :param device: either 'cuda' or 'cpu'
131 | :param groups: number of groups for Group normalization.
132 | """
133 | super().__init__()
134 | assert dim % groups == 0, 'parameter [dim] must be divisible by parameter [groups]'
135 |
136 | # Attributes
137 | self.dim = dim
138 | self.channel = channel
139 | self.time_emb_dim = 4 * self.dim
140 | self.num_resolutions = len(dim_multiply)
141 | self.device = device
142 | self.resolution = [int(image_size / (2 ** i)) for i in range(self.num_resolutions)]
143 | self.hidden_dims = [self.dim, *map(lambda x: x * self.dim, dim_multiply)]
144 | self.num_res_blocks = num_res_blocks
145 |
146 | # Time embedding
147 | positional_encoding = PositionalEncoding(self.dim)
148 | self.time_mlp = nn.Sequential(
149 | positional_encoding, nn.Linear(self.dim, self.time_emb_dim),
150 | nn.SiLU(), nn.Linear(self.time_emb_dim, self.time_emb_dim)
151 | )
152 |
153 | # Layer definition
154 | self.down_path = nn.ModuleList([])
155 | self.up_path = nn.ModuleList([])
156 | concat_dim = list()
157 |
158 | # Downward Path layer definition
159 | self.init_conv = nn.Conv2d(channel, self.dim, kernel_size=(3, 3), padding=1)
160 | concat_dim.append(self.dim)
161 |
162 | for level in range(self.num_resolutions):
163 | d_in, d_out = self.hidden_dims[level], self.hidden_dims[level + 1]
164 | for block in range(num_res_blocks):
165 | d_in_ = d_in if block == 0 else d_out
166 | if self.resolution[level] in attn_resolutions:
167 | self.down_path.append(ResnetAttentionBlock(d_in_, d_out, self.time_emb_dim, dropout, groups))
168 | else:
169 | self.down_path.append(ResnetBlock(d_in_, d_out, self.time_emb_dim, dropout, groups))
170 | concat_dim.append(d_out)
171 | if level != self.num_resolutions - 1:
172 | self.down_path.append(downSample(d_out))
173 | concat_dim.append(d_out)
174 |
175 | # Middle layer definition
176 | mid_dim = self.hidden_dims[-1]
177 | self.middle_resnet_attention = ResnetAttentionBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups)
178 | self.middle_resnet = ResnetBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups)
179 |
180 | # Upward Path layer definition
181 | for level in reversed(range(self.num_resolutions)):
182 | d_out = self.hidden_dims[level + 1]
183 | for block in range(num_res_blocks + 1):
184 | d_in = self.hidden_dims[level + 2] if block == 0 and level != self.num_resolutions - 1 else d_out
185 | d_in = d_in + concat_dim.pop()
186 | if self.resolution[level] in attn_resolutions:
187 | self.up_path.append(ResnetAttentionBlock(d_in, d_out, self.time_emb_dim, dropout, groups))
188 | else:
189 | self.up_path.append(ResnetBlock(d_in, d_out, self.time_emb_dim, dropout, groups))
190 | if level != 0:
191 | self.up_path.append(upSample(d_out))
192 |
193 | assert not concat_dim, 'Error in concatenation between downward path and upward path.'
194 |
195 | # Output layer
196 | final_ch = self.hidden_dims[1]
197 | self.final_norm = nn.GroupNorm(groups, final_ch)
198 | self.final_activation = nn.SiLU()
199 | self.final_conv = nn.Conv2d(final_ch, channel, kernel_size=(3, 3), padding=1)
200 |
201 | def forward(self, x, time):
202 | """
203 | return predicted noise given x_t and t
204 | """
205 | t = self.time_mlp(time)
206 | # Downward
207 | concat = list()
208 | x = self.init_conv(x)
209 | concat.append(x)
210 | for layer in self.down_path:
211 | x = layer(x, t) if not isinstance(layer, (upSample, downSample)) else layer(x)
212 | concat.append(x)
213 |
214 | # Middle
215 | x = self.middle_resnet_attention(x, t)
216 | x = self.middle_resnet(x, t)
217 |
218 | # Upward
219 | for layer in self.up_path:
220 | if not isinstance(layer, upSample):
221 | x = torch.cat((x, concat.pop()), dim=1)
222 | x = layer(x, t) if not isinstance(layer, (upSample, downSample)) else layer(x)
223 |
224 | # Final
225 | x = self.final_activation(self.final_norm(x))
226 | return self.final_conv(x)
227 |
228 | def print_model_structure(self):
229 | for i in self.down_path:
230 | if i.__class__.__name__ == 'downSample':
231 | print('-' * 20)
232 | if i.__class__.__name__ == "Conv2d":
233 |
234 | print(i.__class__.__name__)
235 | else:
236 | print(i.__class__.__name__, i.dim, i.dim_out)
237 | print('\n')
238 | print('=' * 20)
239 | print('\n')
240 | for i in self.up_path:
241 | if i.__class__.__name__ == 'upSample':
242 | print('-' * 20)
243 | if i.__class__.__name__ == "Conv2d":
244 | print(i.__class__.__name__)
245 | else:
246 | print(i.__class__.__name__, i.dim, i.dim_out)
247 |
--------------------------------------------------------------------------------
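As a quick sanity check of the original-style U-net above, the sketch below runs a random batch through the model and confirms the predicted noise has the same shape as the input. The dim, image_size and timestep values are illustrative (dim must be divisible by groups, and image_size by 2^(len(dim_multiply)-1)).

```
import torch
from src.model_original import Unet

# Illustrative configuration: 32x32 RGB input, attention applied at the 16x16 level.
unet = Unet(dim=64, image_size=32, dim_multiply=(1, 2, 4, 8),
            num_res_blocks=2, attn_resolutions=(16,), device='cpu')

x = torch.randn(8, 3, 32, 32)        # a batch of noisy images x_t
t = torch.randint(0, 1000, (8,))     # one diffusion timestep per sample
eps = unet(x, t)                     # predicted noise epsilon_theta(x_t, t)
assert eps.shape == x.shape
```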
/src/model_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | from einops.layers.torch import Rearrange
6 | from functools import partial
7 | from .utils import PositionalEncoding
8 |
9 |
10 | class RMSNorm(nn.Module):
11 | def __init__(self, dim):
12 | super().__init__()
13 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
14 |
15 | def forward(self, x):
16 | return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)
17 |
18 |
19 | class Block(nn.Module):
20 | def __init__(self, dim, dim_out, groups=8):
21 | """
22 | Input shape=(B, dim, H, W)
23 | Output shape=(B, dim_out, H, W)
24 |
25 | :param dim: input channel
26 | :param dim_out: output channel
27 | :param groups: number of groups for Group normalization.
28 | """
29 | super().__init__()
30 | self.proj = nn.Conv2d(dim, dim_out, kernel_size=(3, 3), padding=1)
31 | self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
32 | self.activation = nn.SiLU()
33 |
34 | def forward(self, x, scale_shift=None):
35 | x = self.proj(x)
36 | x = self.norm(x)
37 | if scale_shift is not None:
38 | scale, shift = scale_shift
39 | x = x * (1 + scale) + shift
40 | return self.activation(x)
41 |
42 |
43 | class ResnetBlock(nn.Module):
44 | def __init__(self, dim, dim_out, time_emb_dim=None, group=8):
45 | """
46 | In short, it is composed of two convolutional blocks with a residual connection,
47 | where the time-encoding information is injected into the first convolutional block.
48 |
49 | Input shape=(B, dim, H, W)
50 | Output shape=(B, dim_out, H, W)
51 |
52 | :param dim: input channel
53 | :param dim_out: output channel
54 | :param time_emb_dim: Embedding dimension for time.
55 | :param group: number of groups for Group normalization.
56 | """
57 | super().__init__()
58 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if time_emb_dim is not None else None
59 | self.block1 = Block(dim, dim_out, group)
60 | self.block2 = Block(dim_out, dim_out, group)
61 | self.residual_conv = nn.Conv2d(dim, dim_out, kernel_size=(1, 1)) if dim != dim_out else nn.Identity()
62 |
63 | def forward(self, x, time_emb=None):
64 | """
65 |
66 | :param x: (B, dim, H, W)
67 | :param time_emb: (B, time_emb_dim)
68 | :return: (B, dim_out, H, W)
69 | """
70 | scale_shift = None
71 | if time_emb is not None:
72 | scale_shift = self.mlp(time_emb)[..., None, None] # (B, dim_out*2, 1, 1)
73 | scale_shift = scale_shift.chunk(2, dim=1) # len 2 with each element shape (B, dim_out, 1, 1)
74 | hidden = self.block1(x, scale_shift)
75 | hidden = self.block2(hidden)
76 | return hidden + self.residual_conv(x)
77 |
78 |
79 | class Attention(nn.Module):
80 | def __init__(self, dim, head=4, dim_head=32):
81 | super().__init__()
82 | self.head = head
83 | hidden_dim = head * dim_head
84 |
85 | self.scale = dim_head ** (-0.5) # 1 / sqrt(d_k)
86 | self.norm = RMSNorm(dim)
87 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=(1, 1), bias=False)
88 | self.to_out = nn.Conv2d(hidden_dim, dim, kernel_size=(1, 1))
89 |
90 | def forward(self, x):
91 | b, c, i, j = x.shape
92 | x = self.norm(x)
93 |
94 | qkv = self.to_qkv(x).chunk(3, dim=1)
95 | # h=self.head, f=dim_head, i=height, j=width.
96 | # You can think of (i*j) as the sequence length and f as d_k in the Attention paper.
97 | q, k, v = map(lambda t: rearrange(t, 'b (h f) i j -> b h (i j) f', h=self.head), qkv)
98 |
99 | """
100 | q, k, v shape: (batch, # of head, seq_length, d_k) seq_length = height * width
101 | similarity shape: (batch, # of head, seq_length, seq_length)
102 | attention_score shape: (batch, # of head, seq_length, seq_length)
103 | attention shape: (batch, # of head, seq_length, d_k)
104 | out shape: (batch, hidden_dim, height, width)
105 | return shape: (batch, dim, height, width)
106 | """
107 | # n, m is likewise sequence length.
108 | similarity = torch.einsum('b h n f, b h m f -> b h n m', q, k) # Q(K^T)
109 | attention_score = torch.softmax(similarity * self.scale, dim=-1) # softmax(Q(K^T) / sqrt(d_k))
110 | attention = torch.einsum('b h n m, b h m f -> b h n f', attention_score, v)
111 | # attention(Q, K, V) = [softmax(Q(K^T) / sqrt(d_k))]V -> Scaled Dot-Product Attention
112 |
113 | out = rearrange(attention, 'b h (i j) f -> b (h f) i j', i=i, j=j)
114 | return self.to_out(out)
115 |
116 |
117 | class LinearAttention(nn.Module):
118 | def __init__(self, dim, head=4, dim_head=32):
119 | super().__init__()
120 | self.head = head
121 | hidden_dim = head * dim_head
122 |
123 | self.scale = dim_head ** (-0.5)
124 | self.norm = RMSNorm(dim)
125 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=(1, 1), bias=False)
126 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, kernel_size=(1, 1)), RMSNorm(dim))
127 |
128 | def forward(self, x):
129 | b, c, i, j = x.shape
130 | x = self.norm(x)
131 |
132 | qkv = self.to_qkv(x).chunk(3, dim=1)
133 | # h=self.head, f=dim_head, i=height, j=width.
134 | # You can think of (i*j) as the sequence length and f as d_k in the Attention paper.
135 | q, k, v = map(lambda t: rearrange(t, 'b (h f) i j -> b h f (i j)', h=self.head), qkv)
136 |
137 | q = q.softmax(dim=-2) * self.scale
138 | k = k.softmax(dim=-1)
139 | context = torch.einsum('b h f m, b h e m -> b h f e', k, v)
140 | linear_attention = torch.einsum('b h f e, b h f n -> b h e n', context, q)
141 | out = rearrange(linear_attention, 'b h e (i j) -> b (h e) i j', i=i, j=j, h=self.head)
142 | return self.to_out(out)
143 |
144 |
145 | def downSample(dim_in, dim_out):
146 | return nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
147 | nn.Conv2d(dim_in * 4, dim_out, kernel_size=(1, 1)))
148 |
149 |
150 | def upSample(dim_in, dim_out):
151 | return nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
152 | nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), padding=1))
153 |
154 |
155 | class Unet(nn.Module):
156 | def __init__(self, dim, image_size, dim_multiply=(1, 2, 4, 8), channel=3, attn_heads=4, attn_head_dim=32,
157 | full_attn=(False, False, False, True), resnet_group_norm=8, device='cuda'):
158 | """
159 | U-net for noise prediction. Code is based on denoising-diffusion-pytorch
160 | https://github.com/lucidrains/denoising-diffusion-pytorch
161 | :param dim: base channel dimension, see dim_multiply below
162 | :param dim_multiply: len(dim_multiply) is the depth of the U-net; at each level i, the channel dimension
163 | will be dim * dim_multiply[i]. If the input image shape is [H, W, 3] then at the lowest level the
164 | feature map shape will be [H/2^(len(dim_multiply)-1), W/2^(len(dim_multiply)-1), dim*dim_multiply[-1]],
165 | ignoring the U-net down-up path (skip) connections.
166 | :param channel: # of input image channels, 3 for RGB
167 | :param attn_heads: multi-head self-attention is used. attn_heads is the # of heads; it corresponds to h in
168 | the "Attention is all you need" paper, see section 3.2.2
169 | :param attn_head_dim: the dimension of each head. It corresponds to d_k in the Attention paper.
170 | :param full_attn: The PyTorch implementation uses Linear Attention at levels where full attention (multi-head
171 | self-attention) is not applied. This param indicates, at each level, whether to use full attention
172 | or linear attention, so len(full_attn) must equal len(dim_multiply). For example, if
173 | full_attn=(F, F, F, T) then levels 0, 1, 2 use Linear Attention and level 3 uses
174 | multi-head self-attention (i.e. full attention).
175 | :param resnet_group_norm: number of groups for Group normalization.
176 | :param device: either 'cuda' or 'cpu'
177 | """
178 | super().__init__()
179 | assert len(dim_multiply) == len(full_attn), 'Length of dim_multiply and Length of full_attn must be same'
180 |
181 | # Attributes
182 | self.dim = dim
183 | self.channel = channel
184 | self.hidden_dims = [self.dim, *map(lambda x: x * self.dim, dim_multiply)]
185 | self.dim_in_out = list(zip(self.hidden_dims[:-1], self.hidden_dims[1:]))
186 | self.time_emb_dim = 4 * self.dim
187 | self.full_attn = full_attn
188 | self.depth = len(dim_multiply)
189 | self.device = device
190 |
191 | # Time embedding
192 | positional_encoding = PositionalEncoding(self.dim)
193 | self.time_mlp = nn.Sequential(
194 | positional_encoding, nn.Linear(self.dim, self.time_emb_dim),
195 | nn.GELU(), nn.Linear(self.time_emb_dim, self.time_emb_dim)
196 | )
197 |
198 | # Layer definition
199 | resnet_block = partial(ResnetBlock, time_emb_dim=self.time_emb_dim, group=resnet_group_norm)
200 | self.init_conv = nn.Conv2d(self.channel, self.dim, kernel_size=(7, 7), padding=3)
201 | self.down_path = nn.ModuleList([])
202 | self.up_path = nn.ModuleList([])
203 |
204 | # Downward Path layer definition
205 | for idx, ((dim_in, dim_out), full_attn_flag) in enumerate(zip(self.dim_in_out, self.full_attn)):
206 | isLast = idx == (self.depth - 1)
207 | attention = LinearAttention if not full_attn_flag else Attention
208 | self.down_path.append(nn.ModuleList([
209 | resnet_block(dim_in, dim_in),
210 | resnet_block(dim_in, dim_in),
211 | attention(dim_in, head=attn_heads, dim_head=attn_head_dim),
212 | downSample(dim_in, dim_out) if not isLast else nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), padding=1)
213 | ]))
214 |
215 | # Middle layer definition
216 | mid_dim = self.hidden_dims[-1]
217 | self.mid_resnet_block1 = resnet_block(mid_dim, mid_dim)
218 | self.mid_attention = Attention(mid_dim, head=attn_heads, dim_head=attn_head_dim)
219 | self.mid_resnet_block2 = resnet_block(mid_dim, mid_dim)
220 |
221 | # Upward Path layer definition
222 | for idx, ((dim_in, dim_out), full_attn_flag) in enumerate(
223 | zip(reversed(self.dim_in_out), reversed(self.full_attn))):
224 | isLast = idx == (self.depth - 1)
225 | attention = LinearAttention if not full_attn_flag else Attention
226 | self.up_path.append(nn.ModuleList([
227 | resnet_block(dim_in + dim_out, dim_out),
228 | resnet_block(dim_in + dim_out, dim_out),
229 | attention(dim_out, head=attn_heads, dim_head=attn_head_dim),
230 | upSample(dim_out, dim_in) if not isLast else nn.Conv2d(dim_out, dim_in, kernel_size=(3, 3), padding=1)
231 | ]))
232 |
233 | self.final_resnet_block = resnet_block(2 * self.dim, self.dim)
234 | self.final_conv = nn.Conv2d(self.dim, self.channel, kernel_size=(1, 1))
235 |
236 | def forward(self, x, time):
237 | """
238 | return predicted noise given x_t and t
239 | """
240 | x = self.init_conv(x)
241 | r = x.clone()
242 | t = self.time_mlp(time)
243 | concat = list()
244 |
245 | for block1, block2, attn, downsample in self.down_path:
246 | x = block1(x, t)
247 | concat.append(x)
248 |
249 | x = block2(x, t)
250 | x = attn(x) + x
251 | concat.append(x)
252 |
253 | x = downsample(x)
254 |
255 | x = self.mid_resnet_block1(x, t)
256 | x = self.mid_attention(x) + x
257 | x = self.mid_resnet_block2(x, t)
258 |
259 | for block1, block2, attn, upsample in self.up_path:
260 | x = torch.cat((x, concat.pop()), dim=1)
261 | x = block1(x, t)
262 |
263 | x = torch.cat((x, concat.pop()), dim=1)
264 | x = block2(x, t)
265 | x = attn(x) + x
266 | x = upsample(x)
267 |
268 | x = torch.cat((x, r), dim=1)
269 | x = self.final_resnet_block(x, t)
270 | return self.final_conv(x)
271 |
--------------------------------------------------------------------------------
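The same kind of sanity check applies to this lucidrains-style U-net; the values below are illustrative (each dim * dim_multiply[i] must be divisible by resnet_group_norm) and only meant to show the expected input/output shapes.

```
import torch
from src.model_torch import Unet

# Illustrative configuration: linear attention at the three outer levels, full attention at the innermost.
unet = Unet(dim=64, image_size=32, dim_multiply=(1, 2, 4, 8),
            full_attn=(False, False, False, True), device='cpu')

x = torch.randn(4, 3, 32, 32)
t = torch.randint(0, 1000, (4,))
assert unet(x, t).shape == x.shape   # predicted noise has the same shape as the input
```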
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | from .dataset import dataset_wrapper
5 | import torch
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 | import torchvision
9 | from torch.optim import Adam
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torchvision.utils import save_image
12 | from multiprocessing import cpu_count
13 | from functools import partial
14 | from tqdm import tqdm
15 | import datetime
16 | from termcolor import colored
17 | from .utils import *
18 |
19 |
20 | def cycle_with_label(dl):
21 | while True:
22 | for data in dl:
23 | img, label = data
24 | yield img
25 |
26 |
27 | def cycle(dl):
28 | while True:
29 | for data in dl:
30 | yield data
31 |
32 |
33 | class Trainer:
34 | def __init__(self, diffusion_model, dataset, batch_size=32, lr=2e-5, total_step=100000, ddim_samplers=None,
35 | save_and_sample_every=1000, num_samples=25, result_folder='./results', cpu_percentage=0,
36 | fid_estimate_batch_size=None, ddpm_fid_score_estimate_every=None, ddpm_num_fid_samples=None,
37 | max_grad_norm=1., tensorboard=True, exp_name=None, clip=True):
38 | """
39 | Trainer for Diffusion model.
40 | :param diffusion_model: GaussianDiffusion model
41 | :param dataset: either 'cifar10' or path to the custom dataset you've prepared, where images are saved
42 | :param batch_size: batch size for training. The DDPM authors used 128 for cifar10 and 64 for 256x256 images
43 | :param lr: learning rate. The DDPM authors used 2e-4 for cifar10 and 2e-5 for 256x256 images
44 | :param total_step: total # of training steps. The DDPM authors used 800K steps for cifar10 and 0.5M for CelebA-HQ
45 | :param ddim_samplers: List containing DDIM samplers.
46 | :param save_and_sample_every: Step interval for saving the model and generated images (by DDPM sampling).
47 | For example, if it is set to 1000, the trainer saves the model every 1000 steps and saves images generated
48 | with the DDPM sampling scheme. If you want to generate images with DDIM sampling, you have to pass a list
49 | containing the corresponding DDIM samplers.
50 | :param num_samples: # of images to generate, must be a square number ex) 25, 36, 49...
51 | :param result_folder: where the model and generated images will be saved
52 | :param cpu_percentage: The fraction of CPU cores used for the DataLoader, i.e. num_workers in DataLoader.
53 | Value must be in [0, 1], where 1 means using all CPU cores for the dataloader. If you are a Windows user,
54 | any value other than 0 may cause problems, so set it to 0
55 | :param fid_estimate_batch_size: batch size for FID calculation. It has nothing to do with training.
56 | :param ddpm_fid_score_estimate_every: Step interval for FID calculation using DDPM. If set to None, the FID
57 | score will not be calculated with DDPM sampling. FID calculation with DDPM sampling can be very
58 | time consuming, so it is wise to set this value to None and use a DDIM sampler for FID calculation. But
59 | you can still calculate the FID score with the DDPM sampler if you insist.
60 | :param ddpm_num_fid_samples: # of images to generate for FID calculation with the DDPM sampler. If you set
61 | ddpm_fid_score_estimate_every to None, i.e. the DDPM sampler is not used for FID calculation, then this value
62 | is simply ignored.
63 | :param max_grad_norm: gradients are clipped so that their norm does not exceed this value
64 | :param tensorboard: Set to true if you want to monitor training
65 | :param exp_name: experiment name. If set to None, it is derived automatically from the dataset folder name.
66 | :param clip: [True, False, 'both'] you can find detail in p_sample function in diffusion.py file.
67 | """
68 |
69 | # Metadata & Initialization & Make directory for saving files.
70 | now = datetime.datetime.now()
71 | self.cur_time = now.strftime('%Y-%m-%d_%Hh%Mm')
72 | if exp_name is None:
73 | exp_name = os.path.basename(dataset)
74 | if exp_name == '':
75 | exp_name = os.path.basename(os.path.dirname(dataset))
76 | self.exp_name = exp_name
77 | self.diffusion_model = diffusion_model
78 | self.ddim_samplers = ddim_samplers if ddim_samplers is not None else list()  # avoid iterating over None below
79 | self.batch_size = batch_size
80 | self.num_samples = num_samples
81 | self.nrow = int(math.sqrt(self.num_samples))
82 | assert (self.nrow ** 2) == self.num_samples, 'num_samples must be a square number. ex) 25, 36, 49, ...'
83 | self.save_and_sample_every = save_and_sample_every
84 | self.image_size = self.diffusion_model.image_size
85 | self.max_grad_norm = max_grad_norm
86 | self.result_folder = os.path.join(result_folder, exp_name, self.cur_time)
87 | self.ddpm_result_folder = os.path.join(self.result_folder, 'DDPM')
88 | self.device = self.diffusion_model.device
89 | self.clip = clip
90 | self.ddpm_fid_flag = True if ddpm_fid_score_estimate_every is not None else False
91 | self.ddpm_fid_score_estimate_every = ddpm_fid_score_estimate_every
92 | self.cal_fid = True if self.ddpm_fid_flag else False
93 | self.tqdm_sampler_name = None
94 | self.tensorboard = tensorboard
95 | self.tensorboard_name = None
96 | self.writer = None
97 | self.global_step = 0
98 | self.total_step = total_step
99 | self.fid_score_log = dict()
100 | assert clip in [True, False, 'both'], "clip must be one of [True, False, 'both']"
101 | if clip is True or clip == 'both':
102 | os.makedirs(os.path.join(self.ddpm_result_folder, 'clip'), exist_ok=True)
103 | if clip is False or clip == 'both':
104 | os.makedirs(os.path.join(self.ddpm_result_folder, 'no_clip'), exist_ok=True)
105 |
106 | # Dataset & DataLoader & Optimizer
107 | notification = make_notification('Dataset', color='light_green')
108 | print(notification)
109 | dataSet = dataset_wrapper(dataset, self.image_size)
110 | assert len(dataSet) >= 100, 'you should have at least 100 images in your folder. At least 10k images are recommended'
111 | print(colored('Dataset Length: {}\n'.format(len(dataSet)), 'light_green'))
112 | CPU_cnt = cpu_count()
113 | # TODO: pin_memory?
114 | num_workers = int(CPU_cnt * cpu_percentage)
115 | assert num_workers <= CPU_cnt, "cpu_percentage must be [0.0, 1.0]"
116 | dataLoader = DataLoader(dataSet, batch_size=self.batch_size, shuffle=True,
117 | num_workers=num_workers, pin_memory=True)
118 | self.dataLoader = cycle(dataLoader) if os.path.isdir(dataset) else cycle_with_label(dataLoader)
119 | self.optimizer = Adam(self.diffusion_model.parameters(), lr=lr)
120 |
121 | # DDIM sampler setting
122 | self.ddim_sampling_schedule = list()
123 | for idx, sampler in enumerate(self.ddim_samplers):
124 | sampler.sampler_name = 'DDIM_{}_steps{}_eta{}'.format(idx + 1, sampler.ddim_steps, sampler.eta)
125 | self.ddim_sampling_schedule.append(sampler.sample_every)
126 | save_path = os.path.join(self.result_folder, sampler.sampler_name)
127 | sampler.save_path = save_path
128 | if sampler.save:
129 | os.makedirs(save_path, exist_ok=True)
130 | if sampler.generate_image:
131 | if sampler.clip is True or sampler.clip == 'both':
132 | os.makedirs(os.path.join(save_path, 'clip'), exist_ok=True)
133 | if sampler.clip is False or sampler.clip == 'both':
134 | os.makedirs(os.path.join(save_path, 'no_clip'), exist_ok=True)
135 | if sampler.calculate_fid:
136 | self.cal_fid = True
137 | if self.tqdm_sampler_name is None:
138 | self.tqdm_sampler_name = sampler.sampler_name
139 | sampler.num_fid_sample = sampler.num_fid_sample if sampler.num_fid_sample is not None else len(dataSet)
140 | self.fid_score_log[sampler.sampler_name] = list()
141 | if sampler.fixed_noise:
142 | sampler.register_buffer('noise', torch.randn([self.num_samples, sampler.channel,
143 | sampler.image_size, sampler.image_size]))
144 |
145 | # Image generation log
146 | notification = make_notification('Image Generation', color='light_cyan')
147 | print(notification)
148 | print(colored('Image will be generated with the following sampler(s)', 'light_cyan'))
149 | print(colored('-> DDPM Sampler / Image generation every {} steps'.format(self.save_and_sample_every),
150 | 'light_cyan'))
151 | for sampler in self.ddim_samplers:
152 | if sampler.generate_image:
153 | print(colored('-> {} / Image generation every {} steps / Fixed Noise : {}'
154 | .format(sampler.sampler_name, sampler.sample_every, sampler.fixed_noise), 'light_cyan'))
155 | print('\n')
156 |
157 | # FID score
158 | notification = make_notification('FID', color='light_magenta')
159 | print(notification)
160 | if not self.cal_fid:
161 | print(colored('No FID evaluation will be executed!\n'
162 | 'If you want FID evaluation, consider using a DDIM sampler.', 'light_magenta'))
163 | else:
164 | self.fid_batch_size = fid_estimate_batch_size if fid_estimate_batch_size is not None else self.batch_size
165 | dataSet_fid = dataset_wrapper(dataset, self.image_size,
166 | augment_horizontal_flip=False, info_color='light_magenta', min1to1=False)
167 | dataLoader_fid = DataLoader(dataSet_fid, batch_size=self.fid_batch_size, num_workers=num_workers)
168 |
169 | self.fid_scorer = FID(self.fid_batch_size, dataLoader_fid, dataset_name=exp_name, device=self.device,
170 | no_label=os.path.isdir(dataset))
171 |
172 | print(colored('FID score will be calculated with the following sampler(s)', 'light_magenta'))
173 | if self.ddpm_fid_flag:
174 | self.ddpm_num_fid_samples = ddpm_num_fid_samples if ddpm_num_fid_samples is not None else len(dataSet)
175 | print(colored('-> DDPM Sampler / FID calculation every {} steps with {} generated samples'
176 | .format(self.ddpm_fid_score_estimate_every, self.ddpm_num_fid_samples), 'light_magenta'))
177 | for sampler in self.ddim_samplers:
178 | if sampler.calculate_fid:
179 | print(colored('-> {} / FID calculation every {} steps with {} generated samples'
180 | .format(sampler.sampler_name, sampler.sample_every,
181 | sampler.num_fid_sample), 'light_magenta'))
182 | print('\n')
183 | if self.ddpm_fid_flag:
184 | self.tqdm_sampler_name = 'DDPM'
185 | self.fid_score_log['DDPM'] = list()
186 | notification = make_notification('WARNING', color='red', boundary='*')
187 | print(notification)
188 | msg = """
189 | FID computation with the DDPM sampler requires a lot of generated samples and can therefore be very time
190 | consuming.\nTo accelerate sampling, using only DDIM sampling is recommended. To disable DDPM sampling,
191 | set the [ddpm_fid_score_estimate_every] parameter to None when instantiating Trainer.\n
192 | """
193 | print(colored(msg, 'red'))
194 | del dataLoader_fid
195 | del dataSet_fid
196 |
197 | def train(self):
198 | # Tensorboard
199 | if self.tensorboard:
200 | os.makedirs('./tensorboard', exist_ok=True)
201 | self.tensorboard_name = self.exp_name + '_' + self.cur_time \
202 | if self.tensorboard_name is None else self.tensorboard_name
203 | notification = make_notification('Tensorboard', color='light_blue')
204 | print(notification)
205 | print(colored('Tensorboard Available!', 'light_blue'))
206 | print(colored('Tensorboard name: {}'.format(self.tensorboard_name), 'light_blue'))
207 | print(colored('Launch Tensorboard by running following command on terminal', 'light_blue'))
208 | print(colored('tensorboard --logdir ./tensorboard\n', 'light_blue'))
209 | self.writer = SummaryWriter(os.path.join('./tensorboard', self.tensorboard_name))
210 | notification = make_notification('Training', color='light_yellow', boundary='+')
211 | print(notification)
212 | cur_fid = 'NAN'
213 | ddpm_best_fid = 1e10
214 | stepTQDM = tqdm(range(self.global_step, self.total_step))
215 | for cur_step in stepTQDM:
216 | self.diffusion_model.train()
217 | self.optimizer.zero_grad()
218 | image = next(self.dataLoader).to(self.device)
219 | loss = self.diffusion_model(image)
220 | loss.backward()
221 | nn.utils.clip_grad_norm_(self.diffusion_model.parameters(), self.max_grad_norm)
222 | self.optimizer.step()
223 |
224 | vis_fid = cur_fid if isinstance(cur_fid, str) else '{:.04f}'.format(cur_fid)
225 | stepTQDM.set_postfix({'loss': '{:.04f}'.format(loss.item()), 'FID': vis_fid, 'step':self.global_step})
226 |
227 | self.diffusion_model.eval()
228 | # DDPM Sampler for generating images
229 | if cur_step != 0 and (cur_step % self.save_and_sample_every) == 0:
230 | if self.writer is not None:
231 | self.writer.add_scalar('Loss', loss.item(), cur_step)
232 | with torch.inference_mode():
233 | batches = num_to_groups(self.num_samples, self.batch_size)
234 | for i, j in zip([True, False], ['clip', 'no_clip']):
235 | if self.clip not in [i, 'both']:
236 | continue
237 | imgs = list(map(lambda n: self.diffusion_model.sample(batch_size=n, clip=i), batches))
238 | imgs = torch.cat(imgs, dim=0)
239 | save_image(imgs, nrow=self.nrow,
240 | fp=os.path.join(self.ddpm_result_folder, j, f'sample_{cur_step}.png'))
241 | if self.writer is not None:
242 | self.writer.add_images('DDPM sampling result ({})'.format(j), imgs, cur_step)
243 | self.save('latest')
244 |
245 | # DDPM Sampler for FID score evaluation
246 | if self.ddpm_fid_flag and cur_step != 0 and (cur_step % self.ddpm_fid_score_estimate_every) == 0:
247 | ddpm_cur_fid, _ = self.fid_scorer.fid_score(self.diffusion_model.sample, self.ddpm_num_fid_samples)
248 | if ddpm_best_fid > ddpm_cur_fid:
249 | ddpm_best_fid = ddpm_cur_fid
250 | self.save('best_fid_ddpm')
251 | if self.writer is not None:
252 | self.writer.add_scalars('FID', {'DDPM': ddpm_cur_fid}, cur_step)
253 | cur_fid = ddpm_cur_fid
254 | self.fid_score_log['DDPM'].append((self.global_step, ddpm_cur_fid))
255 |
256 | # DDIM Sampler
257 | for sampler in self.ddim_samplers:
258 | if cur_step != 0 and (cur_step % sampler.sample_every) == 0:
259 | # DDIM Sampler for generating images
260 | if sampler.generate_image:
261 | with torch.inference_mode():
262 | batches = num_to_groups(self.num_samples, self.batch_size)
263 | c_batch = np.insert(np.cumsum(np.array(batches)), 0, 0)
264 | for i, j in zip([True, False], ['clip', 'no_clip']):
265 | if sampler.clip not in [i, 'both']:
266 | continue
267 | if sampler.fixed_noise:
268 | imgs = list()
269 | for b in range(len(batches)):
270 | imgs.append(sampler.sample(self.diffusion_model, batch_size=None, clip=i,
271 | noise=sampler.noise[c_batch[b]:c_batch[b+1]]))
272 | else:
273 | imgs = list(map(lambda n: sampler.sample(self.diffusion_model,
274 | batch_size=n, clip=i), batches))
275 | imgs = torch.cat(imgs, dim=0)
276 | save_image(imgs, nrow=self.nrow,
277 | fp=os.path.join(sampler.save_path, j, f'sample_{cur_step}.png'))
278 | if self.writer is not None:
279 | self.writer.add_images('{} sampling result ({})'
280 | .format(sampler.sampler_name, j), imgs, cur_step)
281 |
282 | # DDIM Sampler for FID score evaluation
283 | if sampler.calculate_fid:
284 | sample_ = partial(sampler.sample, self.diffusion_model)
285 | ddim_cur_fid, _ = self.fid_scorer.fid_score(sample_, sampler.num_fid_sample)
286 | if sampler.best_fid[0] > ddim_cur_fid:
287 | sampler.best_fid[0] = ddim_cur_fid
288 | if sampler.save:
289 | self.save('best_fid_{}'.format(sampler.sampler_name))
290 | if sampler.sampler_name == self.tqdm_sampler_name:
291 | cur_fid = ddim_cur_fid
292 | if self.writer is not None:
293 | self.writer.add_scalars('FID', {sampler.sampler_name: ddim_cur_fid}, cur_step)
294 | self.fid_score_log[sampler.sampler_name].append((self.global_step, ddim_cur_fid))
295 |
296 | self.global_step += 1
297 |
298 | print(colored('Training Finished!', 'light_yellow'))
299 | if self.writer is not None:
300 | self.writer.close()
301 |
302 | def save(self, name):
303 | data = {
304 | 'global_step': self.global_step,
305 | 'model': self.diffusion_model.state_dict(),
306 | 'opt': self.optimizer.state_dict(),
307 | 'fid_logger': self.fid_score_log,
308 | 'tensorboard': self.tensorboard_name
309 | }
310 | for sampler in self.ddim_samplers:
311 | data[sampler.sampler_name] = sampler.state_dict()
312 | torch.save(data, os.path.join(self.result_folder, 'model_{}.pt'.format(name)))
313 |
314 | def load(self, path, tensorboard_path=None, no_prev_ddim_setting=False):
315 | if not os.path.exists(path):
316 | print(make_notification('ERROR', color='red', boundary='*'))
317 | print(colored('No saved checkpoint detected. Please check that the path you gave exists!', 'red'))
318 | exit()
319 | if tensorboard_path is not None and not os.path.exists(tensorboard_path):
320 | print(make_notification('ERROR', color='red', boundary='*'))
321 | print(colored('No tensorboard log detected. Please check that the path you gave exists!', 'red'))
322 | exit()
323 | print(make_notification('Loading Checkpoint', color='green'))
324 | data = torch.load(path, map_location=self.device)
325 | self.diffusion_model.load_state_dict(data['model'])
326 | self.global_step = data['global_step']
327 | self.optimizer.load_state_dict(data['opt'])
328 | fid_score_log = data['fid_logger']
329 | if no_prev_ddim_setting:
330 | for key, val in self.fid_score_log.items():
331 | if key not in fid_score_log:
332 | fid_score_log[key] = val
333 | else:
334 | for sampler in self.ddim_samplers:
335 | sampler.load_state_dict(data[sampler.sampler_name])
336 | self.fid_score_log = fid_score_log
337 | if tensorboard_path is not None:
338 | self.tensorboard_name = data['tensorboard']
339 | print(colored('Successfully loaded checkpoint!\n', 'green'))
340 |
--------------------------------------------------------------------------------
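For reference, a small sketch of what `Trainer.save()` writes and how a saved checkpoint can be inspected before resuming. The path below is a placeholder; the keys are the ones assembled in `save()` above.

```
import torch

# Placeholder path: results/<exp_name>/<time>/model_latest.pt is where save('latest') writes.
ckpt = torch.load('./results/cifar10/2024-01-01_00h00m/model_latest.pt', map_location='cpu')

print(ckpt['global_step'])   # step the run had reached when it was saved
print(ckpt.keys())           # 'global_step', 'model', 'opt', 'fid_logger', 'tensorboard',
                             # plus one state_dict entry per DDIM sampler

# fid_logger maps sampler name -> list of (step, FID) pairs collected during training
for name, log in ckpt['fid_logger'].items():
    print(name, log[-1] if log else 'no FID recorded yet')
```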
/src/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import os
6 | from pytorch_fid.fid_score import calculate_frechet_distance
7 | from pytorch_fid.inception import InceptionV3
8 | from torch.nn.functional import adaptive_avg_pool2d
9 | from tqdm import tqdm
10 | from termcolor import colored
11 |
12 |
13 | def num_to_groups(num, divisor):
14 | groups = num // divisor
15 | remainder = num % divisor
16 | arr = [divisor] * groups
17 | if remainder > 0:
18 | arr.append(remainder)
19 | return arr
20 |
21 |
22 | class PositionalEncoding(nn.Module):
23 | def __init__(self, d_model):
24 | super().__init__()
25 | half_d_model = d_model // 2
26 | log_denominator = -math.log(10000) / (half_d_model - 1)
27 | denominator_ = torch.exp(torch.arange(half_d_model) * log_denominator)
28 | self.register_buffer('denominator', denominator_)
29 |
30 | def forward(self, time):
31 | """
32 | :param time: shape=(B, )
33 | :return: Positional Encoding shape=(B, d_model)
34 | """
35 | argument = time[:, None] * self.denominator[None, :] # (B, half_d_model)
36 | return torch.cat([argument.sin(), argument.cos()], dim=-1) # (B, d_model)
37 |
38 |
39 | class FID:
40 | def __init__(self, batch_size, dataLoader, dataset_name, cache_dir='./results/fid_cache/', device='cuda',
41 | no_label=False, inception_block_idx=3):
42 | assert inception_block_idx in [0, 1, 2, 3], 'inception_block_idx must be either 0, 1, 2, 3'
43 | self.batch_size = batch_size
44 | self.dataLoader = dataLoader
45 | self.cache_dir = cache_dir
46 | self.dataset_name = dataset_name
47 | self.device = device
48 | self.no_label = no_label
49 | self.inception = InceptionV3([inception_block_idx]).to(device)
50 |
51 | os.makedirs(cache_dir, exist_ok=True)
52 | self.m2, self.s2 = self.load_dataset_stats()
53 |
54 | def calculate_inception_features(self, samples):
55 | self.inception.eval()
56 | features = self.inception(samples)[0]
57 | if features.size(2) != 1 or features.size(3) != 1:
58 | features = adaptive_avg_pool2d(features, output_size=(1, 1))
59 | return features.squeeze()
60 |
61 | def load_dataset_stats(self):
62 | path = os.path.join(self.cache_dir, self.dataset_name + '.npz')
63 | if os.path.exists(path):
64 | with np.load(path) as f:
65 | m2, s2 = f['m2'], f['s2']
66 | print(colored('Successfully loaded pre-computed Inception feature from cached file\n', 'light_magenta'))
67 | else:
68 | stacked_real_features = list()
69 | print(colored('Computing Inception features for {} '
70 | 'samples from real dataset.'.format(len(self.dataLoader.dataset)), 'light_magenta'))
71 | for batch in tqdm(self.dataLoader, desc='Calculating stats for data distribution', leave=False):
72 | real_samples = batch.to(self.device) if self.no_label else batch[0].to(self.device)
73 | real_features = self.calculate_inception_features(real_samples)
74 | stacked_real_features.append(real_features)
75 |
76 | stacked_real_features = torch.cat(stacked_real_features, dim=0).cpu().numpy()
77 | m2 = np.mean(stacked_real_features, axis=0)
78 | s2 = np.cov(stacked_real_features, rowvar=False)
79 | np.savez_compressed(path, m2=m2, s2=s2)
80 | print(colored('Dataset stats cached to {} for future use\n'.format(path), 'light_magenta'))
81 | return m2, s2
82 |
83 | @torch.inference_mode()
84 | def fid_score(self, sampler, num_samples, return_sample_image=False):
85 | batches = num_to_groups(num_samples, self.batch_size)
86 | stacked_fake_features = list()
87 | generated_samples = list() if return_sample_image else None
88 | for batch in tqdm(batches, desc='FID score calculation', leave=False):
89 | fake_samples = sampler(batch, clip=True, min1to1=False)
90 | if return_sample_image:
91 | generated_samples.append(fake_samples)
92 | fake_features = self.calculate_inception_features(fake_samples)
93 | stacked_fake_features.append(fake_features)
94 | stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu().numpy()
95 | m1 = np.mean(stacked_fake_features, axis=0)
96 | s1 = np.cov(stacked_fake_features, rowvar=False)
97 | generated_samples_return = None
98 | if return_sample_image:
99 | generated_samples_return = torch.cat(generated_samples, dim=0)
100 | generated_samples_return = (generated_samples_return + 1.0) * 0.5
101 | return calculate_frechet_distance(m1, s1, self.m2, self.s2), generated_samples_return
102 |
103 |
104 | def make_notification(content, color, boundary='-'):
105 | notice = boundary * 30 + '\n'
106 | side = boundary if boundary != '-' else '|'
107 | notice += '{}{:^28}{}\n'.format(side, content, side)
108 | notice += boundary * 30
109 | return colored(notice, color)
110 |
--------------------------------------------------------------------------------
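Two of the helpers above are easy to sanity-check in isolation; the snippet below shows how num_to_groups splits a sample count into batch-sized chunks and what shape PositionalEncoding produces (values are illustrative).

```
import torch
from src.utils import num_to_groups, PositionalEncoding

# Split 25 requested samples into batches of at most 10 images each.
assert num_to_groups(25, 10) == [10, 10, 5]

# Sinusoidal timestep embedding: (B,) integer timesteps -> (B, d_model) vectors.
pe = PositionalEncoding(d_model=64)
t = torch.randint(0, 1000, (8,))
emb = pe(t)
assert emb.shape == (8, 64)   # [sin | cos] halves concatenated along the last dim
```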
/train.py:
--------------------------------------------------------------------------------
1 | from src import model_torch
2 | from src import model_original
3 | from src.trainer import Trainer
4 | from src.diffusion import GaussianDiffusion, DDIM_Sampler
5 | import yaml
6 | import argparse
7 |
8 |
9 | def main(args):
10 | with open(args.config, 'r') as f:
11 | config = yaml.load(f, Loader=yaml.FullLoader)
12 | unet_cfg = config['unet']
13 | ddim_cfg = config['ddim']
14 | trainer_cfg = config['trainer']
15 | image_size = unet_cfg['image_size']
16 |
17 | if config['type'] == 'original':
18 | unet = model_original.Unet(**unet_cfg).to(args.device)
19 | elif config['type'] == 'torch':
20 | unet = model_torch.Unet(**unet_cfg).to(args.device)
21 | else:
22 | unet = None
23 | print("Unet type must be one of ['original', 'torch']")
24 | exit()
25 |
26 | diffusion = GaussianDiffusion(unet, image_size=image_size).to(args.device)
27 |
28 | ddim_samplers = list()
29 | for sampler_cfg in ddim_cfg.values():
30 | ddim_samplers.append(DDIM_Sampler(diffusion, **sampler_cfg))
31 |
32 | trainer = Trainer(diffusion, ddim_samplers=ddim_samplers, exp_name=args.exp_name,
33 | cpu_percentage=args.cpu_percentage, **trainer_cfg)
34 | if args.load is not None:
35 | trainer.load(args.load, args.tensorboard, args.no_prev_ddim_setting)
36 | trainer.train()
37 |
38 |
39 | if __name__ == '__main__':
40 | parse = argparse.ArgumentParser(description='DDPM & DDIM')
41 | parse.add_argument('-c', '--config', type=str, default='./config/cifar10.yaml')
42 | parse.add_argument('-l', '--load', type=str, default=None)
43 | parse.add_argument('-t', '--tensorboard', type=str, default=None)
44 | parse.add_argument('--exp_name', default=None)
45 | parse.add_argument('--device', type=str, choices=['cuda', 'cpu'], default='cuda')
46 | parse.add_argument('--cpu_percentage', type=float, default=0)
47 | parse.add_argument('--no_prev_ddim_setting', action='store_true')
48 | args = parse.parse_args()
49 | main(args)
50 |
--------------------------------------------------------------------------------
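For completeness, a hedged sketch of the config structure train.py expects: top-level 'type', 'unet', 'ddim' and 'trainer' keys, with 'unet' containing at least image_size and 'trainer' containing the dataset. The concrete values below are illustrative and are not taken from the shipped yaml files under config/.

```
import yaml

config = {
    'type': 'torch',                        # 'original' -> model_original.Unet, 'torch' -> model_torch.Unet
    'unet': {'dim': 64, 'image_size': 32},  # forwarded as Unet(**unet_cfg); image_size is also read directly
    'ddim': {'ddim_1': {}},                 # each value is forwarded as DDIM_Sampler(diffusion, **sampler_cfg);
                                            # see src/diffusion.py and config/*.yaml for the available keys
    'trainer': {'dataset': 'cifar10', 'batch_size': 128,
                'lr': 2e-4, 'total_step': 100000},  # forwarded to Trainer(...)
}

with open('./config/my_experiment.yaml', 'w') as f:
    yaml.safe_dump(config, f)

# Then: python train.py -c ./config/my_experiment.yaml --exp_name my_experiment
```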