├── .gitignore ├── README.md ├── config ├── celeba_hq_256.yaml ├── cifar10.yaml ├── cifar10_128dim.yaml ├── cifar10_example.yaml ├── cifar10_torch_example.yaml └── inference │ ├── celeba_hq_256.yaml │ ├── cifar10.yaml │ └── cifar10_128dim.yaml ├── images_README ├── celeba_hq_ex1.gif ├── celeba_hq_ex1.png ├── celeba_hq_ex2.gif ├── celeba_hq_ex2.png ├── celeba_hq_ex3.gif ├── celeba_hq_ex3.png ├── celeba_hq_ex4.gif ├── celeba_hq_ex4.png ├── cifar10_128_ex1.png ├── cifar10_128_ex2.gif ├── cifar10_128_ex3.png ├── cifar10_128_ex4.gif ├── cifar10_64_ex1.png └── cifar10_64_ex2.gif ├── inference.py ├── requirements.txt ├── src ├── dataset.py ├── diffusion.py ├── inferencer.py ├── model_original.py ├── model_torch.py ├── trainer.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | __pycache__/ 4 | results/ 5 | inference_results/ 6 | tensorboard/ 7 | *.pt 8 | tmp* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DDPM & DDIM PyTorch 2 | ### DDPM & DDIM re-implementation with various functionality 3 | 4 | This code is the implementation code of the papers DDPM ([Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)) 5 | and DDIM ([Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)). 6 | 7 |       8 | 9 |

10 |       11 | 12 | 13 | --- 14 | ## Objective 15 | 16 | Our code is mainly based on [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch) 17 | repository, which is a Pytorch implementation of the official Tensorflow code 18 | [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion). But we find that there are some 19 | differences in the implementation especially in **U-net** structure. And further we find that Pytorch implementation 20 | version lacks some functionality for monitoring training and does not have inference code. So we decided to 21 | re-implement the code such that it can be helpful for someone who is first to **Diffusion** models. 22 | 23 | --- 24 | ## Results 25 | 26 | | Dataset | Model checkpoint name | FID (↓) | 27 | |:---------:|:---------------------:|:-------:| 28 | | Cifar10 | [cifar10_64dim.pt](https://drive.google.com/file/d/1vHfP8f_viyadhuXMaLfAQ1Iu5UE0WiiJ/view?usp=drive_link) | 11.81 | 29 | | Cifar10 | [cifar10_128dim.pt](https://drive.google.com/file/d/1NtysETxHPinns6JabjawyWTnkjJKT34M/view?usp=drive_link) | 8.31 | 30 | | CelebA-HQ | [celeba_hq_256.pt](https://drive.google.com/file/d/1zzZbkNkMYCFKmWKW5Sh2JsrUsNrWyDCs/view?usp=drive_link) | 11.97 | 31 | 32 | - cifar10_64dim 33 | 34 |       35 | 36 | 37 | - cifar10_128dim 38 | 39 |       40 | 41 | 42 | - celeba_hq_256 43 | 44 |       45 | 46 |

47 |       48 | 49 |

50 |       51 | 52 | 53 | --- 54 | ## Installation 55 | Tested for ```python 3.8.17``` with ```torch 1.12.1+cu113``` and ```torchvision 0.13.1+cu113```. 56 | Download appropriate pytorch version via [torch website](https://pytorch.org/get-started/previous-versions/#v1121) or by following command. 57 | ``` 58 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 59 | ``` 60 | Install other required moduls by following command. 61 | ``` 62 | pip install -r requirements.txt 63 | ``` 64 | 65 | --- 66 | ## Quick Start 67 | ### Inference 68 | Download pre-trained model checkpoints from [model checkpoints](https://drive.google.com/drive/folders/1YdFQEb3d7rInRVVLLN3VZu-fI0qq15pG?usp=sharing) 69 | 70 | - Cifar10 (64 dimension for first hidden dimension) 71 | ```commandline 72 | python inference.py -c ./config/inference/cifar10.yaml -l /path_to_cifar10_64dim.pt/cifar10_64dim.pt 73 | ``` 74 | - Cifar10 (128 dimension for first hidden dimension, purposed structure by original implementation) 75 | ```commandline 76 | python inference.py -c ./config/inference/cifar10_128dim.yaml -l /path_to_cifar10_128dim.pt/cifar10_128dim.pt 77 | ``` 78 | - CelebA-HQ 79 | 80 | You have to download CelebA-HQ dataset from [kaggle](https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256/). 81 | After un-zipping the zip file, you may find folder named ```/celeba_hq_256```, make the folder named ```/data``` if your 82 | project directory does not have it, place the ```/celeba_hq_256``` folder under the ```/data``` folder such that final 83 | structure must be configured as follows. 84 | ``` 85 | - DDPM 86 | - /config 87 | - /src 88 | - /images_README 89 | - inference.py 90 | - train.py 91 | ... 92 | - /data (make the directory if you don't have it) 93 | - /celeba_hq_256 94 | - 00000.jpg 95 | - 00001.jpg 96 | ... 97 | - 29999.jpg 98 | 99 | ``` 100 | 101 | ```commandline 102 | python inference.py -c ./config/inference/celeba_hq_256.yaml -l /path_to_celeba_hq_256.pt/celeba_hq_256.pt 103 | ``` 104 | 105 | ### Training 106 | 107 | - Cifar10 (64 dimension for first hidden dimension) 108 | ```commandline 109 | python train.py -c ./config/cifar10.yaml 110 | ``` 111 | - Cifar10 (128 dimension for first hidden dimension, purposed structure by original implementation) 112 | ```commandline 113 | python train.py -c ./config/cifar10_128dim.yaml 114 | ``` 115 | - CelebA-HQ 116 | 117 | You have to download dataset, consult details in inference section in Quick Start. 118 | 119 | ```commandline 120 | python train.py -c ./config/celeba_hq_256.yaml 121 | ``` 122 | 123 | :pushpin: 124 | ***You can find more detailed explanation for training and inference at the below two section.*** 125 | 126 | --- 127 | ## Training ( Detailed version ) 128 | 129 |
130 | Expand for details 131 | 132 | To train the diffusion model, first thing you have to do is to configure your training settings by making configuration 133 | file. You can find some example inside the folder ```./config```. I will explain how to configure your training using 134 | ```./config/cifar10_example.yaml``` file and ```./config/cifar10_torch_example.yaml``` file. 135 | Inside the ```cifar10_example.yaml``` you may find 4 primary section, ```type, unet, ddim, trainer```. 136 | 137 | We will first look at ```trainer``` section which is configured as follows. 138 | 139 | ```yaml 140 | dataset: cifar10 141 | batch_size: 128 142 | lr: 0.0002 143 | total_step: 600000 144 | save_and_sample_every: 2500 145 | num_samples: 64 146 | fid_estimate_batch_size: 128 147 | ddpm_fid_score_estimate_every: null 148 | ddpm_num_fid_samples: null 149 | tensorboard: true 150 | clip: both 151 | ``` 152 | - ```dataset```: You can give Dataset name which is available by ```torchvision.datasets```. You can find some Datasets provided 153 | by torchvision in this [website](https://pytorch.org/vision/stable/datasets.html). If you want to use torchvision's 154 | Dataset just provide the dataset name, for example ```cifar10```. Currently, tested datasets are ```cifar10```. 155 | Or if you want to use custom datasets which you have prepared, you have to pass the path to the folder which is containing 156 | images 157 | 158 | 159 | - ```batch_size, lr, total_step```: You can find the values used by DDPM author in the [DDPM paper](https://arxiv.org/abs/2006.11239) 160 | Appendix B. total_step means total training step, for example DDPM author trained cifar10 model with 800K steps. 161 | 162 | 163 | - ```save_and_sample_every```: The interval to which save the model and generated samples. For example in this case, 164 | for every 2500 steps trainer will save the model weights and also generate some sample images to visualize the training progress. 165 | 166 | 167 | - ```num_samples```: When sampling the images evey ```save_and_sample_every``` steps, trainer will sample total ```num_samples``` 168 | images and save it to one large image containing each sampled images where one large image have (num_samples)**0.5 rows and 169 | columns. So ```num_samples``` must be square number ex) 25, 36, 49, 64, ... 170 | 171 | 172 | - ```fid_estimate_batch_size```: Batch size for sampling images for FID calculation. This batch size will be applied to 173 | DDPM sampler as well as DDIM samplers. If you cannot decide the value, just setting this value equal to ```batch_size``` will 174 | be fine. 175 | 176 | 177 | - ```ddpm_fid_score_estimate_every```: Step interval for FID calculation using DDPM sampler. If set to null, FID score 178 | will not be calculated with DDPM sampling. If you use DDPM sampling for FID calculation, i.e. setting this value other than null, 179 | it can be very time-consuming, ***so it is wise to set this value to null***, and use DDIM sampler for 180 | FID calculation (Using DDIM sampler is explained below). But anyway you can calculate FID score with DDPM sampler if you **insist**. 181 | 182 | 183 | - ```ddpm_num_fid_samples```: Number of sampling images for FID calculation using DDPM sampler. If you set 184 | ```ddpm_fid_score_estimate_every``` to null, i.e. not using DDPM sampler for FID calculation, then this value will 185 | be just ignored. 
186 | 187 | 188 | - ```tensorboard```: If set to true, then you can monitor training progress such as loss, FID value, and sampled images, 189 | during training, with the tensorboard. 190 | 191 | 192 | - ```clip```: It must be one of [both, true, false]. This is related to sampling of x_{t-1} from p_{theta}(x_{t-1} | x_t). 193 | There are two ways to sample x_{t-1}. 194 | One way is to follow paper and this corresponds to line 4 in Algorithm 2 in DDPM paper. (```clip==False```) 195 | Another way is to clip(or clamp) the predicted x_0 to -1 ~ 1 for better sampling result. 196 | To clip the x_0 to out desired range, we cannot directly apply (11) to sample x_{t-1}, rather we have to 197 | calculate predicted x_0 using (4) and then calculate mu in (7) using that predicted x_0. Which is exactly 198 | same calculation except for clipping. 199 | As you might easily expect, using clip leads to better sampling result since it 200 | restricts sampled images range to -1 ~ 1. So for the better sampling result, it is strongly suggested 201 | setting ```clip``` to true. If ```clip==both``` then sampling is done twice, one done with 202 | ```clip==True``` and the other done with ```clip==False```. 203 | [Reference](https://github.com/hojonathanho/diffusion/issues/5) 204 | 205 | --- 206 | 207 | Now we will look at ```type, unet``` section which is configured as follows. 208 | 209 | ```yaml 210 | # ```./config/cifar10_example.yaml``` 211 | type: original 212 | unet: 213 | dim: 64 214 | image_size: 32 215 | dim_multiply: 216 | - 1 217 | - 2 218 | - 2 219 | - 2 220 | attn_resolutions: 221 | - 16 222 | dropout: 0.1 223 | num_res_blocks: 2 224 | ``` 225 | 226 | ```yaml 227 | # ```./config/cifar10_torch_example.yaml``` 228 | type: torch 229 | unet: 230 | dim: 64 231 | image_size: 32 232 | dim_multiply: 233 | - 1 234 | - 2 235 | - 2 236 | - 2 237 | full_attn: 238 | - false 239 | - true 240 | - false 241 | - false 242 | attn_heads: 4 243 | attn_head_dim: 32 244 | ``` 245 | 246 | -```type```: It must be one of [original, torch]. ***original*** will use U-net structure which was originally suggested by 247 | Jonathan Ho. So it's structure will be the one used in [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion) 248 | which is an official version written in Tensorflow. ***torch*** will use U-net structure which was suggested by 249 | [denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch) which is a transcribed version of 250 | official Tensorflow version. 251 | 252 | I have separated those two because there structure differs significantly. To name a few, following is the difference of 253 | those two U-net structure. 254 | 1. official version use self Attention where the feature map resolution at each U-net level 255 | is in ```attn_resolutions```. In the DDPM paper you can find that they used self Attention at the 16X16 resolution, 256 | and this is why ```attn_resolutions``` is by default ```[16, ]``` 257 | 258 | On the other hand, Pytorch transcribed version use Linear Attention and multi-head self Attention. They use 259 | multi-head self Attention at the U-net level where ```full_attn``` is true and Linear Attention at the 260 | rest of the U-net level. So in this particular case, they used multi-head self Attention at the U-net level 1 (I will 261 | denote U-net level as 0, 1, 2, 3, ...) and the Linear Attention at the U-net level 0, 2, 3. 
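To make the two conventions concrete, here is a minimal sketch (illustrative only, not code from this repository) that checks which U-net levels end up with full self-attention for the two cifar10 example configs above; the level-to-resolution relation follows the ```unet.dim_multiply``` description below, where each level halves the feature-map resolution of the previous one.

```python
# Minimal sketch (illustrative only, not part of this repo): which U-net levels
# get full self-attention under the 'original' and 'torch' conventions.
image_size = 32
dim_multiply = [1, 2, 2, 2]               # 4 U-net levels: 0, 1, 2, 3
attn_resolutions = [16]                   # 'original' config: attention where resolution == 16
full_attn = [False, True, False, False]   # 'torch' config: attention where the flag is true

# each level halves the feature-map resolution of the previous level
resolutions = [image_size // (2 ** level) for level in range(len(dim_multiply))]
print(resolutions)                        # [32, 16, 8, 4]

original_levels = [lvl for lvl, res in enumerate(resolutions) if res in attn_resolutions]
torch_levels = [lvl for lvl, flag in enumerate(full_attn) if flag]
print(original_levels, torch_levels)      # [1] [1] -> both place self-attention at U-net level 1
```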
262 | 263 | -```unet.dim```: This is related to the hidden channel dimension of feature map at each U-net model. You can find more 264 | detail right below. 265 | 266 | -```unet.dim_multiply```: ```len(dim_multiply)``` will be the depth of U-net model with at each level i, the dimension 267 | of channel will be ```dim * dim_multiply[i]```. If the input image shape is [H, W, 3] then at the lowest level, 268 | feature map shape will be [H/(2^(len(dim_multiply)-1), W/(2^(len(dim_multiply)-1), dim*dim_multiply[-1]] 269 | if not considering U-net down-up path connection. 270 | 271 | -```unet.image_size```: Size of the input image. Image will be resized and cropped if ```image_size``` does not 272 | equal to actual input image size. Expected to be ```Integer```, I have not tested for non-square images. 273 | 274 | -```unet.attn_resolution / unet.full_attn```: Explained above. Since ```attn_resolution``` value 275 | must equal to the resolution value of feature map where you want to apply self Attention, you have to carefully 276 | calculate desired resolution value. In the case of ```full_attn```, it is related to applying particular Attention mechanism at each 277 | level, it must satisfy ```len(full_attn) == len(dim_multiply)``` 278 | 279 | -```unet.num_res_blocks```: Number of ResnetBlock at each level. In downward path, at each level, there will be 280 | num_res_blocks amount of ResnetBlock module and in upward path, at each level, there will be 281 | (num_res_blocks+1) amount of ResnetBlock module. 282 | 283 | -```unet.attn_heads, unet.attn_head_dim```: In the torch implementation it uses multi-head-self-Attention. attn_head is 284 | the # of head. It corresponds to h in "Attention is all you need" paper. See section 3.2.2 285 | attn_head_dim is the dimension of each head. It corresponds to d_k in Attention paper. 286 | 287 | --- 288 | 289 | Lastly we will look at ```ddim``` section which is configured as follows. 290 | 291 | 292 | ```yaml 293 | ddim: 294 | 0: 295 | ddim_sampling_steps: 20 296 | sample_every: 5000 297 | calculate_fid: true 298 | num_fid_sample: 6000 299 | eta: 0 300 | save: true 301 | 1: 302 | ddim_sampling_steps: 50 303 | sample_every: 50000 304 | calculate_fid: true 305 | num_fid_sample: 6000 306 | eta: 0 307 | save: true 308 | 2: 309 | ddim_sampling_steps: 20 310 | sample_every: 100000 311 | calculate_fid: true 312 | num_fid_sample: 60000 313 | eta: 0 314 | save: true 315 | ``` 316 | 317 | There are 3 subsection (0, 1, 2) which means it will use 3 DDIM Sampler for sampling image 318 | and FID calculation during training. The name of each subsection, which are 0, 1, 2, 319 | is not important. Each DDIM Sampler name will be set to ```DDIM_{index}_steps{ddim_sampling_steps}_eta{eta}``` no matter 320 | what the name of each subsection is set in configuration file. 321 | 322 | -```ddim_sampling_steps```: The number of de-noising steps for DDIM sampling. In DDPM they used 1000 steps for sampling images. But 323 | in DDIM we can control the total number of de-noising steps for generating images. If this value is set to 20, 324 | then the speed of generating image will be 50 times faster than DDPM sampling. Preferred value would be 10~100. 325 | 326 | -```sample_every```: This control how often sampling be done with particular sampler. If set to 5000 then 327 | every 5000 steps, this sampler will be activated for sampling images. So if total training step is 50000, there will be 328 | total 10 sampling action. 
329 | 
330 | -```calculate_fid```: Whether to calculate the FID score with this sampler.
331 | 
332 | -```num_fid_sample```: Only valid if ```calculate_fid``` is set to true. This controls how many sampled images are used for
333 | FID calculation. The speed of FID calculation for a particular sampler is inversely
334 | proportional to (ddim_sampling_steps * num_fid_sample).
335 | 
336 | -```eta```: Hyperparameter to control the stochasticity, see (16) in the DDIM paper.
337 | 0: deterministic (DDIM), 1: fully stochastic (DDPM)
338 | 
339 | -```save```: Whether to save the model checkpoint based on the FID value calculated by this particular sampler. If set to true,
340 | the model checkpoint will be saved to a ```.pt``` file whenever the model achieves its best
341 | FID value for this sampler.
342 | 
343 | ---
344 | 
345 | Now that the configuration file is ready, training can be started with the following command.
346 | 
347 | ```commandline
348 | python train.py -c /path_to_config_file/configuration_file.yaml
349 | ```
350 | [Mandatory]
351 | 
352 | -c, --config : Path to configuration file. Path must include the file name with extension .yaml
353 | 
354 | [Optional]
355 | 
356 | -l, --load : Path to a model checkpoint. Path must include the file name with extension .pt
357 | You can resume training by using this option.
358 | -t, --tensorboard : Path to a tensorboard folder. If you resume training and want to restore the previous tensorboard, set
359 | this option to the previously generated tensorboard folder.
360 | --exp_name : Name for the experiment. If not set, the current time will be used as the experiment name.
361 | --cpu_percentage : Float value from 0.0 to 1.0, default 0.0. It is used to control the num_workers parameter of the DataLoader;
362 | num_workers will be set to "number of CPUs available on your device * cpu_percentage". On Windows, setting
363 | this value to anything other than 0.0 sometimes yields unexpected behavior or failure to train the model, so if you have problems training
364 | the model on Windows, do not change this value.
365 | --no_prev_ddim_setting : Store-true flag. Set this option if you have changed the DDIM settings, for example the
366 | number of DDIM samplers or the sampling steps of a DDIM sampler.
367 | 
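For example, resuming a previous run from a saved checkpoint while restoring its tensorboard logs could look like the following (the checkpoint and tensorboard paths are placeholders for your own files):

```commandline
python train.py -c /path_to_config_file/configuration_file.yaml -l /path_to_model_checkpoint_file/model_checkpoint.pt -t /path_to_previous_tensorboard_folder
```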
368 | 369 | --- 370 | ## Inference ( Detailed version ) 371 | 372 |
373 | Expand for details
374 | 
375 | To run inference with the diffusion model, the first thing you have to do is configure your inference settings by creating a configuration
376 | file. You can find some examples inside the folder ```./config/inference```. I will explain how to configure your inference using the
377 | ```./config/inference/cifar10.yaml``` file.
378 | Inside the file you may find 4 primary sections, ```type, unet, ddim, inferencer```. ```type, unet``` must match the
379 | configuration used for training. On the other hand, the ```ddim``` section does not need to match the configuration used for
380 | training. One thing to notice is that the ```sample_every, save``` options are not used during DDIM inference. We
381 | are left with the ```inferencer``` section.
382 | 
383 | ```yaml
384 | inferencer:
385 |   dataset: cifar10
386 |   batch_size: 128
387 |   clip: true
388 |   num_samples_per_image: 64
389 |   num_images_to_generate: 2
390 |   ddpm_fid_estimate: true
391 |   ddpm_num_fid_samples: 60000
392 |   return_all_step: true
393 |   make_denoising_gif: true
394 |   num_gif: 50
395 | ```
396 | 
397 | -```dataset, batch_size, clip```: Same meaning as in the configuration file for training.
398 | 
399 | -```num_samples_per_image```: Same meaning as ```num_samples``` in the configuration file for training.
400 | The sampler will sample a total of ```num_samples_per_image``` images and save them to one large image containing every sampled image,
401 | where the large image has (num_samples_per_image)**0.5 rows and columns. So ```num_samples_per_image``` must be a square number, e.g. 25, 36, 49, 64, ...
402 | 
403 | -```num_images_to_generate```: How many large merged images to generate. So if this value is set to 2, there will be
404 | 2 large images, each containing ```num_samples_per_image``` sampled sub-images.
405 | 
406 | -```ddpm_fid_estimate, ddpm_num_fid_samples```: Whether to calculate the FID value for the DDPM sampler. If ```ddpm_fid_estimate```
407 | is set to true, ```ddpm_num_fid_samples``` decides the number of sampled images used for calculating the FID value.
408 | 
409 | -```return_all_step```: Whether to return all the images produced during the de-noising steps. So for the DDPM sampler, all the intermediate images of the 1000
410 | de-noising steps will be returned. In the case of a DDIM sampler, ```ddim_sampling_steps``` images
411 | will be returned.
412 | 
413 | -```make_denoising_gif```: Whether to make a gif which shows the de-noising process visually. To make a denoising gif,
414 | ```return_all_step``` must be set to true.
415 | 
416 | -```num_gif```: Number of frames used to make the gif which shows the de-noising process visually. Intermediate denoised images
417 | will be sampled evenly to obtain ```num_gif``` frames for the denoising gif.
418 | 
419 | ---
420 | 
421 | Now that the configuration file is ready, inference can be run with the following command.
422 | 
423 | ```commandline
424 | python inference.py -c /path_to_config_file/configuration_file.yaml -l /path_to_model_checkpoint_file/model_checkpoint.pt
425 | ```
426 | 
427 | [Mandatory]
428 | 
429 | -c, --config : Path to configuration file. Path must include the file name with extension .yaml
430 | -l, --load : Path to model checkpoint. Path must include the file name with extension .pt
431 | 
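For reference, with the default ```result_folder``` in ```src/inferencer.py``` the outputs are written under ```./inference_results/<dataset_name>```. For the cifar10 inference config above, the layout is roughly as follows (a sketch; the DDIM folder name and the clip / no_clip sub-folders depend on your ```ddim``` and ```clip``` settings):

```
- /inference_results
    - /cifar10
        - /DDPM
            - /clip
                - /1            (per-step images 0000.png, 0001.png, ... if return_all_step)
            - 1_clip.png        (merged grid of num_samples_per_image samples)
            - 1_clip.gif        (de-noising gif, if make_denoising_gif)
        - /DDIM_1_steps20_eta0
            - /clip
                - /1
            - 1_clip.png
            - 1_clip.gif
        - FID.txt               (FID results, written if any FID calculation is enabled)
```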
432 | 433 | --- 434 | ## References 435 | - [Jonathan Ho, Ajay Jain, Pieter Abbeel : Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) 436 | 437 | - [Jiaming Song, Chenlin Meng, Stefano Ermon : Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) 438 | -------------------------------------------------------------------------------- /config/celeba_hq_256.yaml: -------------------------------------------------------------------------------- 1 | type: original 2 | unet: 3 | dim: 128 4 | image_size: 64 5 | dim_multiply: 6 | - 1 7 | - 1 8 | - 2 9 | - 2 10 | - 4 11 | attn_resolutions: 12 | - 16 13 | dropout: 0.0 14 | num_res_blocks: 2 15 | ddim: 16 | 0: 17 | ddim_sampling_steps: 20 18 | sample_every: 20000 19 | calculate_fid: true 20 | num_fid_sample: 10000 21 | save: true 22 | trainer: 23 | dataset: ./data/celeba_hq_256 24 | batch_size: 64 25 | lr: 2.0e-05 26 | clip: true 27 | total_step: 500000 28 | save_and_sample_every: 10000 29 | fid_estimate_batch_size: 64 30 | num_samples: 64 31 | -------------------------------------------------------------------------------- /config/cifar10.yaml: -------------------------------------------------------------------------------- 1 | type: original 2 | unet: 3 | dim: 64 4 | image_size: 32 5 | dim_multiply: 6 | - 1 7 | - 2 8 | - 2 9 | - 2 10 | attn_resolutions: 11 | - 16 12 | dropout: 0.1 13 | num_res_blocks: 2 14 | ddim: 15 | 0: 16 | ddim_sampling_steps: 20 17 | sample_every: 10000 18 | calculate_fid: true 19 | num_fid_sample: 30000 20 | save: true 21 | trainer: 22 | dataset: cifar10 23 | batch_size: 128 24 | lr: 0.0002 25 | total_step: 600000 26 | save_and_sample_every: 2500 27 | num_samples: 64 28 | fid_estimate_batch_size: 128 29 | clip: true 30 | -------------------------------------------------------------------------------- /config/cifar10_128dim.yaml: -------------------------------------------------------------------------------- 1 | type: original 2 | unet: 3 | dim: 128 4 | image_size: 32 5 | dim_multiply: 6 | - 1 7 | - 2 8 | - 2 9 | - 2 10 | attn_resolutions: 11 | - 16 12 | dropout: 0.1 13 | num_res_blocks: 2 14 | ddim: 15 | 0: 16 | ddim_sampling_steps: 20 17 | sample_every: 10000 18 | calculate_fid: true 19 | num_fid_sample: 30000 20 | save: true 21 | trainer: 22 | dataset: cifar10 23 | batch_size: 128 24 | lr: 0.0002 25 | total_step: 600000 26 | save_and_sample_every: 5000 27 | fid_estimate_batch_size: 128 28 | num_samples: 64 29 | clip: true 30 | 31 | -------------------------------------------------------------------------------- /config/cifar10_example.yaml: -------------------------------------------------------------------------------- 1 | type: original 2 | unet: 3 | dim: 64 4 | image_size: 32 5 | dim_multiply: 6 | - 1 7 | - 2 8 | - 2 9 | - 2 10 | attn_resolutions: 11 | - 16 12 | dropout: 0.1 13 | num_res_blocks: 2 14 | ddim: 15 | 0: 16 | ddim_sampling_steps: 20 17 | sample_every: 5000 18 | calculate_fid: true 19 | num_fid_sample: 6000 20 | eta: 0 21 | save: true 22 | 1: 23 | ddim_sampling_steps: 50 24 | sample_every: 50000 25 | calculate_fid: true 26 | num_fid_sample: 6000 27 | eta: 0 28 | save: true 29 | 2: 30 | ddim_sampling_steps: 20 31 | sample_every: 100000 32 | calculate_fid: true 33 | num_fid_sample: 60000 34 | eta: 0 35 | save: true 36 | trainer: 37 | dataset: cifar10 38 | batch_size: 128 39 | lr: 0.0002 40 | total_step: 600000 41 | save_and_sample_every: 2500 42 | num_samples: 64 43 | fid_estimate_batch_size: 128 44 | ddpm_fid_score_estimate_every: null 45 | ddpm_num_fid_samples: null 
46 | tensorboard: true 47 | clip: both 48 | -------------------------------------------------------------------------------- /config/cifar10_torch_example.yaml: -------------------------------------------------------------------------------- 1 | type: torch 2 | unet: 3 | dim: 64 4 | image_size: 32 5 | dim_multiply: 6 | - 1 7 | - 2 8 | - 2 9 | - 2 10 | full_attn: 11 | - false 12 | - true 13 | - false 14 | - false 15 | attn_heads: 4 16 | attn_head_dim: 32 17 | ddim: 18 | 0: 19 | ddim_sampling_steps: 20 20 | sample_every: 5000 21 | calculate_fid: true 22 | num_fid_sample: 6000 23 | eta: 0 24 | save: true 25 | 1: 26 | ddim_sampling_steps: 50 27 | sample_every: 50000 28 | calculate_fid: true 29 | num_fid_sample: 6000 30 | eta: 0 31 | save: true 32 | 2: 33 | ddim_sampling_steps: 20 34 | sample_every: 100000 35 | calculate_fid: true 36 | num_fid_sample: 60000 37 | eta: 0 38 | save: true 39 | trainer: 40 | dataset: cifar10 41 | batch_size: 128 42 | lr: 0.0002 43 | total_step: 600000 44 | save_and_sample_every: 2500 45 | num_samples: 64 46 | fid_estimate_batch_size: 128 47 | ddpm_fid_score_estimate_every: null 48 | ddpm_num_fid_samples: null 49 | tensorboard: true 50 | clip: both 51 | -------------------------------------------------------------------------------- /config/inference/celeba_hq_256.yaml: -------------------------------------------------------------------------------- 1 | type: original 2 | unet: 3 | dim: 128 4 | image_size: 64 5 | dim_multiply: 6 | - 1 7 | - 1 8 | - 2 9 | - 2 10 | - 4 11 | attn_resolutions: 12 | - 16 13 | dropout: 0.0 14 | num_res_blocks: 2 15 | ddim: 16 | inferencer: 17 | dataset: ./data/celeba_hq_256 18 | batch_size: 32 19 | clip: true 20 | num_samples_per_image: 64 21 | num_images_to_generate: 20 22 | ddpm_fid_estimate: true 23 | ddpm_num_fid_samples: 30000 24 | return_all_step: true 25 | make_denoising_gif: true 26 | num_gif: 50 27 | -------------------------------------------------------------------------------- /config/inference/cifar10.yaml: -------------------------------------------------------------------------------- 1 | type: original 2 | unet: 3 | dim: 64 4 | image_size: 32 5 | dim_multiply: 6 | - 1 7 | - 2 8 | - 2 9 | - 2 10 | attn_resolutions: 11 | - 16 12 | dropout: 0.1 13 | num_res_blocks: 2 14 | ddim: 15 | 0: 16 | ddim_sampling_steps: 20 17 | calculate_fid: true 18 | num_fid_sample: 6000 19 | generate_image: true 20 | inferencer: 21 | dataset: cifar10 22 | batch_size: 128 23 | clip: true 24 | num_samples_per_image: 64 25 | num_images_to_generate: 10 26 | ddpm_fid_estimate: true 27 | ddpm_num_fid_samples: 60000 28 | return_all_step: true 29 | make_denoising_gif: true 30 | num_gif: 50 -------------------------------------------------------------------------------- /config/inference/cifar10_128dim.yaml: -------------------------------------------------------------------------------- 1 | type: original 2 | unet: 3 | dim: 128 4 | image_size: 32 5 | dim_multiply: 6 | - 1 7 | - 2 8 | - 2 9 | - 2 10 | attn_resolutions: 11 | - 16 12 | dropout: 0.1 13 | num_res_blocks: 2 14 | ddim: 15 | 0: 16 | ddim_sampling_steps: 20 17 | calculate_fid: true 18 | num_fid_sample: 6000 19 | generate_image: true 20 | inferencer: 21 | dataset: cifar10 22 | batch_size: 128 23 | clip: true 24 | num_samples_per_image: 64 25 | num_images_to_generate: 20 26 | ddpm_fid_estimate: true 27 | ddpm_num_fid_samples: 60000 28 | return_all_step: true 29 | make_denoising_gif: true 30 | num_gif: 50 -------------------------------------------------------------------------------- 
/images_README/celeba_hq_ex1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex1.gif -------------------------------------------------------------------------------- /images_README/celeba_hq_ex1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex1.png -------------------------------------------------------------------------------- /images_README/celeba_hq_ex2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex2.gif -------------------------------------------------------------------------------- /images_README/celeba_hq_ex2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex2.png -------------------------------------------------------------------------------- /images_README/celeba_hq_ex3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex3.gif -------------------------------------------------------------------------------- /images_README/celeba_hq_ex3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex3.png -------------------------------------------------------------------------------- /images_README/celeba_hq_ex4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex4.gif -------------------------------------------------------------------------------- /images_README/celeba_hq_ex4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/celeba_hq_ex4.png -------------------------------------------------------------------------------- /images_README/cifar10_128_ex1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex1.png -------------------------------------------------------------------------------- /images_README/cifar10_128_ex2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex2.gif -------------------------------------------------------------------------------- /images_README/cifar10_128_ex3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex3.png -------------------------------------------------------------------------------- /images_README/cifar10_128_ex4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_128_ex4.gif -------------------------------------------------------------------------------- /images_README/cifar10_64_ex1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_64_ex1.png -------------------------------------------------------------------------------- /images_README/cifar10_64_ex2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taehoon-yoon/Diffusion-Probabilistic-Models/2b6ff285e88583ccf4ebdd38c1662c67414e676d/images_README/cifar10_64_ex2.gif -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from src import model_torch 2 | from src import model_original 3 | from src.diffusion import GaussianDiffusion, DDIM_Sampler 4 | from src.inferencer import Inferencer 5 | import yaml 6 | import argparse 7 | 8 | 9 | def main(args): 10 | with open(args.config, 'r') as f: 11 | config = yaml.load(f, Loader=yaml.FullLoader) 12 | unet_cfg = config['unet'] 13 | ddim_cfg = config['ddim'] 14 | trainer_cfg = config['inferencer'] 15 | image_size = unet_cfg['image_size'] 16 | 17 | if config['type'] == 'original': 18 | unet = model_original.Unet(**unet_cfg).to(args.device) 19 | elif config['type'] == 'torch': 20 | unet = model_torch.Unet(**unet_cfg).to(args.device) 21 | else: 22 | unet = None 23 | print("Unet type must be one of ['original', 'torch']") 24 | exit() 25 | 26 | diffusion = GaussianDiffusion(unet, image_size=image_size).to(args.device) 27 | 28 | ddim_samplers = list() 29 | if isinstance(ddim_cfg, dict): 30 | for sampler_cfg in ddim_cfg.values(): 31 | ddim_samplers.append(DDIM_Sampler(diffusion, **sampler_cfg)) 32 | 33 | inferencer = Inferencer(diffusion, ddim_samplers=ddim_samplers, time_step=diffusion.time_step, **trainer_cfg) 34 | inferencer.load(args.load) 35 | inferencer.inference() 36 | 37 | 38 | if __name__ == '__main__': 39 | parse = argparse.ArgumentParser(description='DDPM & DDIM') 40 | parse.add_argument('-c', '--config', type=str, default='./config/inference/cifar10.yaml') 41 | parse.add_argument('-l', '--load', type=str, default=None) 42 | parse.add_argument('-d', '--device', type=str, choices=['cuda', 'cpu'], default='cuda') 43 | args = parse.parse_args() 44 | main(args) 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | pillow==9.5.0 3 | tensorboard 4 | tqdm 5 | termcolor 6 | pytorch-fid 7 | pyyaml 8 | six 9 | imageio -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 
1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import Dataset 5 | from termcolor import colored 6 | import ssl 7 | import os 8 | from glob import glob 9 | from PIL import Image 10 | 11 | ssl._create_default_https_context = ssl._create_unverified_context 12 | 13 | 14 | class customDataset(Dataset): 15 | def __init__(self, folder, transform, exts=['jpg', 'jpeg', 'png', 'tiff']): 16 | super().__init__() 17 | self.paths = [p for ext in exts for p in glob(os.path.join(folder, f'*.{ext}'))] 18 | self.transform = transform 19 | 20 | def __len__(self): 21 | return len(self.paths) 22 | 23 | def __getitem__(self, item): 24 | img = Image.open(self.paths[item]) 25 | return self.transform(img) 26 | 27 | 28 | def dataset_wrapper(dataset, image_size, augment_horizontal_flip=True, info_color='light_green', min1to1=True): 29 | transform = transforms.Compose([ 30 | transforms.Resize(image_size), 31 | transforms.RandomHorizontalFlip() if augment_horizontal_flip else torch.nn.Identity(), 32 | transforms.CenterCrop(image_size), 33 | transforms.ToTensor(), # turn into torch Tensor of shape CHW, 0 ~ 1 34 | transforms.Lambda(lambda x: ((x * 2) - 1)) if min1to1 else torch.nn.Identity()# -1 ~ 1 35 | ]) 36 | if os.path.isdir(dataset): 37 | print(colored('Loading local file directory', info_color)) 38 | dataSet = customDataset(dataset, transform) 39 | print(colored('Successfully loaded {} images!'.format(len(dataSet)), info_color)) 40 | return dataSet 41 | else: 42 | dataset = dataset.lower() 43 | assert dataset in ['cifar10'] 44 | print(colored('Loading {} dataset'.format(dataset), info_color)) 45 | if dataset == 'cifar10': 46 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) 47 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) 48 | fullset = torch.utils.data.ConcatDataset([trainset, testset]) 49 | return fullset 50 | -------------------------------------------------------------------------------- /src/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | 6 | 7 | class GaussianDiffusion(nn.Module): 8 | def __init__(self, model, image_size, time_step=1000, loss_type='l2'): 9 | """ 10 | Diffusion model. It is based on Denoising Diffusion Probabilistic Models (DDPM), Jonathan Ho et al. 11 | :param model: U-net model for de-noising network 12 | :param image_size: image size 13 | :param time_step: Gaussian diffusion length T. In paper they used T=1000 14 | :param loss_type: either l1, l2, huber. Default, l2 15 | """ 16 | super().__init__() 17 | self.unet = model 18 | self.channel = self.unet.channel 19 | self.device = self.unet.device 20 | self.image_size = image_size 21 | self.time_step = time_step 22 | self.loss_type = loss_type 23 | 24 | beta = self.linear_beta_schedule() # (t, ) t=time_step, in DDPM paper t=1000 25 | alpha = 1. - beta # (a1, a2, a3, ... at) 26 | alpha_bar = torch.cumprod(alpha, dim=0) # (a1, a1*a2, a1*a2*a3, ..., a1*a2*~*at) 27 | alpha_bar_prev = F.pad(alpha_bar[:-1], pad=(1, 0), value=1.) 
# (1, a1, a1*a2, ..., a1*a2*~*a(t-1)) 28 | 29 | self.register_buffer('beta', beta) 30 | self.register_buffer('alpha', alpha) 31 | self.register_buffer('alpha_bar', alpha_bar) 32 | self.register_buffer('alpha_bar_prev', alpha_bar_prev) 33 | 34 | # calculation for q(x_t | x_0) consult (4) in DDPM paper. 35 | self.register_buffer('sqrt_alpha_bar', torch.sqrt(alpha_bar)) 36 | self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - alpha_bar)) 37 | 38 | # calculation for q(x_{t-1} | x_t, x_0) consult (7) in DDPM paper. 39 | self.register_buffer('beta_tilde', beta * ((1. - alpha_bar_prev) / (1. - alpha_bar))) 40 | self.register_buffer('mean_tilde_x0_coeff', beta * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar)) 41 | self.register_buffer('mean_tilde_xt_coeff', torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar)) 42 | 43 | # calculation for x0 consult (9) in DDPM paper. 44 | self.register_buffer('sqrt_recip_alpha_bar', torch.sqrt(1. / alpha_bar)) 45 | self.register_buffer('sqrt_recip_alpha_bar_min_1', torch.sqrt(1. / alpha_bar - 1)) 46 | 47 | # calculation for (11) in DDPM paper. 48 | self.register_buffer('sqrt_recip_alpha', torch.sqrt(1. / alpha)) 49 | self.register_buffer('beta_over_sqrt_one_minus_alpha_bar', beta / torch.sqrt(1. - alpha_bar)) 50 | 51 | # Forward Process / Diffusion Process ############################################################################## 52 | def q_sample(self, x0, t, noise): 53 | """ 54 | Sampling x_t, according to q(x_t | x_0). Consult (4) in DDPM paper. 55 | :param x0: (b, c, h, w), original image 56 | :param t: (b, ), timestep t 57 | :param noise: (b, c, h, w), We calculate q(x_t | x_0) using re-parameterization trick. 58 | :return: x_t with shape=(b, c, h, w), which is a noised image at timestep t 59 | """ 60 | # Get x_t ~ q(x_t | x_0) using re-parameterization trick 61 | return self.sqrt_alpha_bar[t][:, None, None, None] * x0 + \ 62 | self.sqrt_one_minus_alpha_bar[t][:, None, None, None] * noise 63 | 64 | def forward(self, img): 65 | """ 66 | Calculate L_simple according to (14) in DDPM paper 67 | :param img: (b, c, h, w), original image 68 | :return: L_simple 69 | """ 70 | b, c, h, w = img.shape 71 | assert h == self.image_size and w == self.image_size, f'height and width of image must be {self.image_size}' 72 | t = torch.randint(0, self.time_step, (b,), device=img.device).long() # (b, ) 73 | noise = torch.randn_like(img) # corresponds to epsilon in (14) 74 | noised_image = self.q_sample(img, t, noise) # argument inside epsilon_theta 75 | predicted_noise = self.unet(noised_image, t) # epsilon_theta in (14) 76 | 77 | if self.loss_type == 'l1': 78 | loss = F.l1_loss(noise, predicted_noise) 79 | elif self.loss_type == 'l2': 80 | loss = F.mse_loss(noise, predicted_noise) 81 | elif self.loss_type == "huber": 82 | loss = F.smooth_l1_loss(noise, predicted_noise) 83 | else: 84 | raise NotImplementedError() 85 | return loss 86 | 87 | #################################################################################################################### 88 | 89 | # Reverse Process / De-noising Process ############################################################################# 90 | @torch.inference_mode() 91 | def p_sample(self, xt, t, clip=True): 92 | """ 93 | Sample x_{t-1} from p_{theta}(x_{t-1} | x_t). 94 | There are two ways to sample x_{t-1}. 95 | One way is to follow paper and this corresponds to line 4 in Algorithm 2 in DDPM paper. (clip==False) 96 | Another way is to clip(or clamp) the predicted x_0 to -1 ~ 1 for better sampling result. 
97 | To clip the x_0 to out desired range, we cannot directly apply (11) to sample x_{t-1}, rather we have to 98 | calculate predicted x_0 using (4) and then calculate mu in (7) using that predicted x_0. Which is exactly 99 | same calculation except for clipping. 100 | As you might easily expect, using clip leads to better sampling result since it 101 | restricts sampled images range to -1 ~ 1. Ref: https://github.com/hojonathanho/diffusion/issues/5 102 | 103 | :param xt: ( b, c, h, w), noised image at time step t 104 | :param t: ( b, ) 105 | :param clip: [True, False] Whether to clip predicted x_0 to our desired range -1 ~ 1. 106 | :return: de-noised image at time step t-1 107 | """ 108 | batched_time = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long) 109 | pred_noise = self.unet(xt, batched_time) # corresponds to epsilon_{theta} 110 | if clip: 111 | x0 = self.sqrt_recip_alpha_bar[t] * xt - self.sqrt_recip_alpha_bar_min_1[t] * pred_noise 112 | x0.clamp_(-1., 1.) 113 | mean = self.mean_tilde_x0_coeff[t] * x0 + self.mean_tilde_xt_coeff[t] * xt 114 | else: 115 | mean = self.sqrt_recip_alpha[t] * (xt - self.beta_over_sqrt_one_minus_alpha_bar[t] * pred_noise) 116 | variance = self.beta_tilde[t] 117 | noise = torch.randn_like(xt) if t > 0 else 0. # corresponds to z, consult 4: in Algorithm 2. 118 | x_t_minus_1 = mean + torch.sqrt(variance) * noise 119 | return x_t_minus_1 120 | 121 | @torch.inference_mode() 122 | def sample(self, batch_size=16, return_all_timestep=False, clip=True, min1to1=False): 123 | """ 124 | 125 | :param batch_size: # of image to generate. 126 | :param return_all_timestep: Whether to return all images during de-noising process. So it will return the 127 | images from time step T ~ time step 0 128 | :param clip: [True, False]. Explanation in p_sample function. 129 | :return: Generated image of shape (b, 3, h, w) if return_all_timestep==False else (b, T, 3, h, w) 130 | """ 131 | xT = torch.randn([batch_size, self.channel, self.image_size, self.image_size], device=self.device) 132 | denoised_intermediates = [xT] 133 | xt = xT 134 | for t in tqdm(reversed(range(0, self.time_step)), desc='DDPM Sampling', total=self.time_step, leave=False): 135 | x_t_minus_1 = self.p_sample(xt, t, clip) 136 | denoised_intermediates.append(x_t_minus_1) 137 | xt = x_t_minus_1 138 | 139 | images = xt if not return_all_timestep else torch.stack(denoised_intermediates, dim=1) 140 | # images = (images + 1.0) * 0.5 # scale to 0~1 141 | images.clamp_(min=-1.0, max=1.0) 142 | if not min1to1: 143 | images.sub_(-1.0).div_(2.0) 144 | return images 145 | 146 | #################################################################################################################### 147 | 148 | def linear_beta_schedule(self): 149 | """ 150 | linear schedule, proposed in original ddpm paper 151 | """ 152 | scale = 1000 / self.time_step 153 | beta_start = scale * 0.0001 154 | beta_end = scale * 0.02 155 | return torch.linspace(beta_start, beta_end, self.time_step, dtype=torch.float32) 156 | 157 | 158 | class DDIM_Sampler(nn.Module): 159 | def __init__(self, ddpm_diffusion_model, ddim_sampling_steps=100, eta=0, sample_every=5000, fixed_noise=False, 160 | calculate_fid=False, num_fid_sample=None, generate_image=True, clip=True, save=False): 161 | """ 162 | Denoising Diffusion Implicit Models (DDIM), Jiaming Song et al. 163 | :param ddpm_diffusion_model: DDPM diffusion model. 164 | :param ddim_sampling_steps: Total sampling steps for DDIM sampling process. It corresponds to S in DDIM paper. 
165 | Consult section 4.2 Accelerated Generation Processes in DDIM paper. 166 | :param eta: Hyperparameter to control the stochasticity, see (16) in DDIM paper. 167 | 0: deterministic(DDIM) , 1: fully stochastic(DDPM) 168 | :param sample_every: The interval for calling this DDIM sampler. It is only valid during training. 169 | For example if sample_every=5000 then during training, every 5000steps, trainer will call this DDIM sampler. 170 | :param fixed_noise: If set to True, then this Sampler will always use same starting noise for image generation. 171 | :param calculate_fid: Whether to calculate FID score for this sampler. 172 | :param num_fid_sample: # of generating samples for FID calculation. 173 | If calculate_fid==True and num_fid_sample==None, then it will automatically set # of generating image to the 174 | total number of image in original dataset. 175 | :param generate_image: Whether to save the generated image to folder. 176 | :param clip: [True, False, 'both'] 'both' will sample twice for clip==True and clip==False. 177 | Details in ddim_p_sample function. 178 | :param save: Whether to save the diffusion model based on the FID score calculated by this sampler. 179 | So calculate_fid must be set to True, if you want to set this parameter to be True. 180 | """ 181 | super().__init__() 182 | self.ddim_steps = ddim_sampling_steps 183 | self.eta = eta 184 | self.sample_every = sample_every 185 | self.fixed_noise = fixed_noise 186 | self.calculate_fid = calculate_fid 187 | self.num_fid_sample = num_fid_sample 188 | self.generate_image = generate_image 189 | self.channel = ddpm_diffusion_model.channel 190 | self.image_size = ddpm_diffusion_model.image_size 191 | self.device = ddpm_diffusion_model.device 192 | self.clip = clip 193 | self.save = save 194 | self.sampler_name = None 195 | self.save_path = None 196 | ddpm_steps = ddpm_diffusion_model.time_step 197 | assert self.ddim_steps <= ddpm_steps, 'DDIM sampling step must be smaller or equal to DDPM sampling step' 198 | assert clip in [True, False, 'both'], "clip must be one of [True, False, 'both']" 199 | if self.save: 200 | assert self.calculate_fid is True, 'To save model based on FID score, you must set [calculate_fid] to True' 201 | self.register_buffer('best_fid', torch.tensor([1e10], dtype=torch.float32)) 202 | 203 | alpha_bar = ddpm_diffusion_model.alpha_bar 204 | # One thing you mush notice is that although sampling time is indexed as [1,...T] in paper, 205 | # since in computer program we index from [0,...T-1] rather than [1,...T], 206 | # value of tau ranges from [-1, ...T-1] where t=-1 indicate initial state (Data distribution) 207 | 208 | # [tau_1, tau_2, ... tau_S] sec 4.2 209 | self.register_buffer('tau', torch.linspace(-1, ddpm_steps - 1, steps=self.ddim_steps + 1, dtype=torch.long)[1:]) 210 | 211 | alpha_tau_i = alpha_bar[self.tau] 212 | alpha_tau_i_min_1 = F.pad(alpha_bar[self.tau[:-1]], pad=(1, 0), value=1.) 
# alpha_0 = 1 213 | 214 | # (16) in DDIM 215 | self.register_buffer('sigma', eta * (((1 - alpha_tau_i_min_1) / (1 - alpha_tau_i) * 216 | (1 - alpha_tau_i / alpha_tau_i_min_1)).sqrt())) 217 | # (12) in DDIM 218 | self.register_buffer('coeff', (1 - alpha_tau_i_min_1 - self.sigma ** 2).sqrt()) 219 | self.register_buffer('sqrt_alpha_i_min_1', alpha_tau_i_min_1.sqrt()) 220 | 221 | assert self.coeff[0] == 0.0 and self.sqrt_alpha_i_min_1[0] == 1.0, 'DDIM parameter error' 222 | 223 | @torch.inference_mode() 224 | def ddim_p_sample(self, model, xt, i, clip=True): 225 | """ 226 | Sample x_{tau_(i-1)} from p(x_{tau_(i-1)} | x_{tau_i}), consult (56) in DDIM paper. 227 | Calculation is done using (12) in DDIM paper where t-1 has to be changed to tau_(i-1) and t has to be 228 | changed to tau_i in (12), for accelerated generation process where total # of de-noising step is S. 229 | 230 | :param model: Diffusion model 231 | :param xt: noisy image at time step tau_i 232 | :param i: i is the index of array tau which is an sub-sequence of [1, ..., T] of length S. See sec. 4.2 233 | :param clip: Like in GaussianDiffusion p_sample, we can clip(or clamp) the predicted x_0 to -1 ~ 1 234 | for better sampling result. If you see (12) in DDIM paper, sampling x_(t-1) depends on epsilon_theta which is 235 | U-net network predicted noise at time step t. If we want to clip the "predicted x0", we have to 236 | re-calculate the epsilon_theta to make "predicted x0" lie in -1 ~ 1. This is exactly what is going on 237 | if you set clip==True. 238 | :return: de-noised image at time step tau_(i-1) 239 | """ 240 | t = self.tau[i] 241 | batched_time = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long) 242 | pred_noise = model.unet(xt, batched_time) # corresponds to epsilon_{theta} 243 | x0 = model.sqrt_recip_alpha_bar[t] * xt - model.sqrt_recip_alpha_bar_min_1[t] * pred_noise 244 | if clip: 245 | x0.clamp_(-1., 1.) 246 | pred_noise = (model.sqrt_recip_alpha_bar[t] * xt - x0) / model.sqrt_recip_alpha_bar_min_1[t] 247 | 248 | # x0 corresponds to "predicted x0" and pred_noise corresponds to epsilon_theta(xt) in (12) DDIM 249 | # Thus self.coeff[i] * pred_noise corresponds to "direction pointing to xt" in (12) 250 | mean = self.sqrt_alpha_i_min_1[i] * x0 + self.coeff[i] * pred_noise 251 | noise = torch.randn_like(xt) if i > 0 else 0. 252 | # self.sigma[i] * noise corresponds to "random noise" in (12) 253 | x_t_minus_1 = mean + self.sigma[i] * noise 254 | return x_t_minus_1 255 | 256 | @torch.inference_mode() 257 | def sample(self, diffusion_model, batch_size, noise=None, return_all_timestep=False, clip=True, min1to1=False): 258 | """ 259 | 260 | :param diffusion_model: Diffusion model 261 | :param batch_size: # of image to generate. 262 | :param noise: If set to True, then this Sampler will always use same starting noise for image generation. 263 | :param return_all_timestep: Whether to return all images during de-noising process. 
So it will return the 264 | images from time step tau_S ~ time step tau_0 265 | :param clip: See ddim_p_sample function 266 | :return: Generated image of shape (b, 3, h, w) if return_all_timestep==False else (b, S, 3, h, w) 267 | """ 268 | clip = clip if clip is not None else self.clip 269 | xT = torch.randn([batch_size, self.channel, self.image_size, self.image_size], device=self.device) \ 270 | if noise is None else noise.to(self.device) 271 | denoised_intermediates = [xT] 272 | xt = xT 273 | for i in tqdm(reversed(range(0, self.ddim_steps)), desc='DDIM Sampling', total=self.ddim_steps, leave=False): 274 | x_t_minus_1 = self.ddim_p_sample(diffusion_model, xt, i, clip) 275 | denoised_intermediates.append(x_t_minus_1) 276 | xt = x_t_minus_1 277 | 278 | images = xt if not return_all_timestep else torch.stack(denoised_intermediates, dim=1) 279 | # images = (images + 1.0) * 0.5 # scale to 0~1 280 | images.clamp_(min=-1.0, max=1.0) 281 | if not min1to1: 282 | images.sub_(-1.0).div_(2.0) 283 | return images 284 | -------------------------------------------------------------------------------- /src/inferencer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torchvision.utils import save_image 6 | from .dataset import dataset_wrapper 7 | from .utils import * 8 | from termcolor import colored 9 | from tqdm import tqdm 10 | from glob import glob 11 | import imageio 12 | from imageio import mimsave 13 | from functools import partial 14 | from torchvision.transforms import ToPILImage 15 | 16 | 17 | class Inferencer: 18 | def __init__(self, diffusion_model, dataset, ddim_samplers=None, batch_size=32, num_samples_per_image=25, 19 | result_folder='./inference_results', num_images_to_generate=1, ddpm_fid_estimate=True, time_step=1000, 20 | ddpm_num_fid_samples=None, clip=True, return_all_step=True, make_denoising_gif=True, num_gif=50, 21 | save_generated_img_for_fid_cal=False): 22 | """ 23 | Inferenceer for Diffusion model. Sampling is supported by DDPM sampling & DDIM sampling 24 | :param diffusion_model: GaussianDiffusion model 25 | :param dataset: either 'cifar10' or path to the custom dataset you've prepared, where images are saved 26 | :param ddim_samplers: List containing DDIM samplers. 27 | :param batch_size: batch_size for inferencing 28 | :param num_samples_per_image: # of generating images, must be square number ex) 25, 36, 49... 29 | :param result_folder: where inference result will be saved. 30 | :param num_images_to_generate: # of generated image set. For example if num_samples_per_image==25 and 31 | num_images_to_generate==3 then, in result folder there will be 3 generated image with each image containing 32 | 25 generated sub-images merged into one image file with 5 rows, 5 columns. 33 | :param ddpm_fid_estimate: Whether to calculate FID score based on DDPM sampling. 34 | :param time_step: Gaussian diffusion length T. In DDPM paper they used T=1000 35 | :param ddpm_num_fid_samples: # of generating images for FID calculation using DDPM sampler. If you set 36 | ddpm_fid_estimate to False, i.e. not using DDPM sampler for FID calculation, then this value will 37 | be just ignored. 38 | :param clip: [True, False, 'both'] you can find detail in p_sample function 39 | and ddim_p_sample function in diffusion.py file. 40 | :param return_all_step: Whether to save the entire de-noising processed image to result folder. 
41 | :param make_denoising_gif: Whether to make gif which contains de-noising process visually. 42 | :param num_gif: # of images to make one gif which contains de-noising process visually. 43 | :param save_generated_img_for_fid_cal: Whether to save generated images which are used for FID calculation. 44 | """ 45 | dataset_name = os.path.basename(dataset) 46 | if dataset_name == '': 47 | dataset_name=os.path.basename(os.path.dirname(dataset)) 48 | self.diffusion_model = diffusion_model 49 | self.ddim_samplers = ddim_samplers 50 | self.batch_size = batch_size 51 | self.time_step = time_step 52 | self.num_samples = num_samples_per_image 53 | self.num_images = num_images_to_generate 54 | self.nrow = int(math.sqrt(self.num_samples)) 55 | assert (self.nrow ** 2) == self.num_samples, 'num_samples must be a square number. ex) 25, 36, 49, ...' 56 | self.image_size = self.diffusion_model.image_size 57 | self.result_folder = os.path.join(result_folder, dataset_name) 58 | self.ddpm_result_folder = os.path.join(self.result_folder, 'DDPM') 59 | self.device = self.diffusion_model.device 60 | self.return_all_step = return_all_step or make_denoising_gif 61 | self.make_denoising_gif = make_denoising_gif 62 | self.num_gif = num_gif 63 | self.save_img = save_generated_img_for_fid_cal 64 | self.toPIL = ToPILImage() 65 | self.clip = clip 66 | self.ddpm_fid_flag = ddpm_fid_estimate 67 | self.cal_fid = True if self.ddpm_fid_flag else False 68 | self.fid_score_log = dict() 69 | assert clip in [True, False, 'both'], "clip must be one of [True, False, 'both']" 70 | if clip is True or clip == 'both': 71 | os.makedirs(os.path.join(self.ddpm_result_folder, 'clip'), exist_ok=True) 72 | if clip is False or clip == 'both': 73 | os.makedirs(os.path.join(self.ddpm_result_folder, 'no_clip'), exist_ok=True) 74 | 75 | # Dataset 76 | notification = make_notification('Dataset', color='light_green') 77 | print(notification) 78 | dataSet = dataset_wrapper(dataset, self.image_size, augment_horizontal_flip=False, min1to1=False) 79 | dataLoader = DataLoader(dataSet, batch_size=batch_size) 80 | print(colored('Dataset Length: {}\n'.format(len(dataSet)), 'light_green')) 81 | 82 | # DDIM sampler setting 83 | for idx, sampler in enumerate(self.ddim_samplers): 84 | sampler.sampler_name = 'DDIM_{}_steps{}_eta{}'.format(idx + 1, sampler.ddim_steps, sampler.eta) 85 | save_path = os.path.join(self.result_folder, sampler.sampler_name) 86 | sampler.save_path = save_path 87 | if sampler.generate_image: 88 | if sampler.clip is True or sampler.clip == 'both': 89 | os.makedirs(os.path.join(save_path, 'clip'), exist_ok=True) 90 | if sampler.clip is False or sampler.clip == 'both': 91 | os.makedirs(os.path.join(save_path, 'no_clip'), exist_ok=True) 92 | if sampler.calculate_fid: 93 | self.cal_fid = True 94 | sampler.num_fid_sample = sampler.num_fid_sample if sampler.num_fid_sample is not None else len(dataSet) 95 | 96 | # Image generation log 97 | notification = make_notification('Image Generation', color='light_cyan') 98 | print(notification) 99 | print(colored('Image will be generated with the following sampler(s)', 'light_cyan')) 100 | print(colored('-> DDPM Sampler', 'light_cyan')) 101 | for sampler in self.ddim_samplers: 102 | if sampler.generate_image: 103 | print(colored('-> {}'.format(sampler.sampler_name), 'light_cyan')) 104 | print('\n') 105 | 106 | # FID score 107 | notification = make_notification('FID', color='light_magenta') 108 | print(notification) 109 | if not self.cal_fid: 110 | print(colored('No FID evaluation will be executed!\n' 
111 | 'If you want FID evaluation consider using DDIM sampler.', 'light_magenta')) 112 | else: 113 | self.fid_scorer = FID(self.batch_size, dataLoader, dataset_name=dataset_name, device=self.device, 114 | no_label=os.path.isdir(dataset)) 115 | print(colored('FID score will be calculated with the following sampler(s)', 'light_magenta')) 116 | if self.ddpm_fid_flag: 117 | self.ddpm_num_fid_samples = ddpm_num_fid_samples if ddpm_num_fid_samples is not None else len(dataSet) 118 | print(colored('-> DDPM Sampler / FID calculation with {} generated samples' 119 | .format(self.ddpm_num_fid_samples), 'light_magenta')) 120 | for sampler in self.ddim_samplers: 121 | if sampler.calculate_fid: 122 | print(colored('-> {} / FID calculation with {} generated samples' 123 | .format(sampler.sampler_name, sampler.num_fid_sample), 'light_magenta')) 124 | print('\n') 125 | del dataset 126 | del dataLoader 127 | 128 | @torch.inference_mode() 129 | def inference(self): 130 | self.diffusion_model.eval() 131 | notification = make_notification('Inferencing', color='light_yellow', boundary='+') 132 | print(notification) 133 | print(colored('Image Generation\n', 'light_yellow')) 134 | 135 | # DDPM sampler 136 | for idx in tqdm(range(self.num_images), desc='DDPM image sampling'): 137 | batches = num_to_groups(self.num_samples, self.batch_size) 138 | for i, j in zip([True, False], ['clip', 'no_clip']): 139 | if self.clip not in [i, 'both']: 140 | continue 141 | imgs = list(map(lambda n: self.diffusion_model.sample(n, self.return_all_step, clip=i), batches)) 142 | imgs = torch.cat(imgs, dim=0) # (batch, steps, ch, h, w) 143 | if self.return_all_step: 144 | path = os.path.join(self.ddpm_result_folder, '{}'.format(j), '{}'.format(idx + 1)) 145 | os.makedirs(path, exist_ok=True) 146 | for step in range(imgs.shape[1]): 147 | save_image(imgs[:, step], nrow=self.nrow, fp=os.path.join(path, '{:04d}.png'.format(step))) 148 | if self.make_denoising_gif: 149 | gif_step = int(self.time_step/self.num_gif) 150 | gif_step = max(1, gif_step) 151 | file_names = list(reversed(sorted(glob(os.path.join(path, '*.png')))))[::gif_step] 152 | file_names = reversed(file_names) 153 | gif = [imageio.v2.imread(names) for names in file_names] 154 | mimsave(os.path.join(self.ddpm_result_folder, '{}_{}.gif'.format(idx+1, j)), 155 | gif, **{'duration': self.num_gif / imgs.shape[1]}) 156 | last_img = imgs[:, -1] if self.return_all_step else imgs 157 | save_image(last_img, nrow=self.nrow, 158 | fp=os.path.join(self.ddpm_result_folder, '{}_{}.png'.format(idx+1, j))) 159 | # DDIM sampler 160 | for sampler in self.ddim_samplers: 161 | if sampler.generate_image: 162 | for idx in tqdm(range(self.num_images), desc='{} image sampling'.format(sampler.sampler_name)): 163 | batches = num_to_groups(self.num_samples, self.batch_size) 164 | for i, j in zip([True, False], ['clip', 'no_clip']): 165 | if sampler.clip not in [i, 'both']: 166 | continue 167 | imgs = list(map(lambda n: sampler.sample(self.diffusion_model, batch_size=n, noise=None, 168 | return_all_timestep=self.return_all_step, clip=i), batches)) 169 | imgs = torch.cat(imgs, dim=0) 170 | if self.return_all_step: 171 | path = os.path.join(sampler.save_path, '{}'.format(j), '{}'.format(idx + 1)) 172 | os.makedirs(path, exist_ok=True) 173 | for step in range(imgs.shape[1]): 174 | save_image(imgs[:, step], nrow=self.nrow, 175 | fp=os.path.join(path, '{:04d}.png'.format(step))) 176 | if self.make_denoising_gif: 177 | gif_step = int(sampler.ddim_steps/self.num_gif) 178 | gif_step = max(1, gif_step) 179 | 
file_names = list(reversed(sorted(glob(os.path.join(path, '*.png')))))[::gif_step] 180 | file_names = reversed(file_names) 181 | gif = [imageio.v2.imread(names) for names in file_names] 182 | mimsave(os.path.join(sampler.save_path, '{}_{}.gif'.format(idx + 1, j)), 183 | gif, **{'duration': self.num_gif / imgs.shape[1]}) 184 | last_img = imgs[:, -1] if self.return_all_step else imgs 185 | save_image(last_img, nrow=self.nrow, 186 | fp=os.path.join(sampler.save_path, '{}_{}.png'.format(idx + 1, j))) 187 | if self.cal_fid: 188 | print(colored('\nFID score estimation\n', 'light_yellow')) 189 | if self.ddpm_fid_flag: 190 | print(colored('DDPM FID calculation...', 'yellow')) 191 | ddpm_fid, imgs = self.fid_scorer.fid_score(self.diffusion_model.sample, self.ddpm_num_fid_samples, 192 | self.save_img) 193 | self.fid_score_log['DDPM'] = ddpm_fid 194 | if self.save_img: 195 | path_ = os.path.join(self.ddpm_result_folder, 'generated_samples_for_FID_calculation') 196 | os.makedirs(path_, exist_ok=True) 197 | for i in range(imgs.shape[0]): 198 | img = self.toPIL(imgs[i]) 199 | img.save(os.path.join(path_, '{:06d}.png'.format(i+1))) 200 | for sampler in self.ddim_samplers: 201 | print(colored('{} FID calculation...'.format(sampler.sampler_name), 'yellow')) 202 | if sampler.calculate_fid: 203 | sample_ = partial(sampler.sample, self.diffusion_model) 204 | ddim_fid, imgs = self.fid_scorer.fid_score(sample_, sampler.num_fid_sample, self.save_img) 205 | self.fid_score_log[f'{sampler.sampler_name}'] = ddim_fid 206 | if self.save_img: 207 | path_ = os.path.join(sampler.save_path, 'generated_samples_for_FID_calculation') 208 | os.makedirs(path_, exist_ok=True) 209 | for i in range(imgs.shape[0]): 210 | img = self.toPIL(imgs[i]) 211 | img.save(os.path.join(path_, '{:06d}.png'.format(i + 1))) 212 | print(colored('-'*50, 'yellow')) 213 | for key, val in self.fid_score_log.items(): 214 | print(colored('Sampler: {} -> FID score: {}'.format(key, val), 'yellow')) 215 | with open(os.path.join(self.result_folder, 'FID.txt'), 'w') as f: 216 | f.write('Results\n') 217 | f.write('='*50) 218 | f.write('\n') 219 | for key, val in self.fid_score_log.items(): 220 | f.write('Sampler: {} -> FID score: {}\n'.format(key, val)) 221 | 222 | def load(self, path): 223 | if not os.path.exists(path): 224 | print(make_notification('ERROR', color='red', boundary='*')) 225 | print(colored('No saved checkpoint is detected. 
Please check you gave existing path!', 'red')) 226 | exit() 227 | print(make_notification('Loading Checkpoint', color='green')) 228 | data = torch.load(path, map_location=self.device) 229 | self.diffusion_model.load_state_dict(data['model']) 230 | print(colored('Successfully loaded checkpoint!\n', 'green')) -------------------------------------------------------------------------------- /src/model_original.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from .utils import PositionalEncoding 5 | 6 | 7 | class ResnetBlock(nn.Module): 8 | def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=None, groups=32): 9 | super().__init__() 10 | 11 | self.dim, self.dim_out = dim, dim_out 12 | 13 | dim_out = dim if dim_out is None else dim_out 14 | self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=dim) 15 | self.activation1 = nn.SiLU() 16 | self.conv1 = nn.Conv2d(dim, dim_out, kernel_size=(3, 3), padding=1) 17 | self.block1 = nn.Sequential(self.norm1, self.activation1, self.conv1) 18 | 19 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if time_emb_dim is not None else None 20 | 21 | self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=dim_out) 22 | self.activation2 = nn.SiLU() 23 | self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity() 24 | self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), padding=1) 25 | self.block2 = nn.Sequential(self.norm2, self.activation2, self.dropout, self.conv2) 26 | 27 | self.residual_conv = nn.Conv2d(dim, dim_out, kernel_size=(1, 1)) if dim != dim_out else nn.Identity() 28 | 29 | def forward(self, x, time_emb=None): 30 | hidden = self.block1(x) 31 | if time_emb is not None: 32 | # add in timestep embedding 33 | hidden = hidden + self.mlp(time_emb)[..., None, None] # (B, dim_out, 1, 1) 34 | hidden = self.block2(hidden) 35 | return hidden + self.residual_conv(x) 36 | 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, groups=32): 40 | super().__init__() 41 | 42 | self.dim, self.dim_out = dim, dim 43 | 44 | self.scale = dim ** (-0.5) # 1 / sqrt(d_k) 45 | self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim) 46 | self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=(1, 1)) 47 | self.to_out = nn.Conv2d(dim, dim, kernel_size=(1, 1)) 48 | 49 | def forward(self, x): 50 | b, c, h, w = x.shape 51 | qkv = self.to_qkv(self.norm(x)).chunk(3, dim=1) 52 | # You can think (h*w) as sequence length where c is d_k in 53 | q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), qkv) 54 | 55 | """ 56 | q, k, v shape: (batch, seq_length, d_k) seq_length = height*width, d_k == c == dim 57 | similarity shape: (batch, seq_length, seq_length) 58 | attention_score shape: (batch, seq_length, seq_length) 59 | attention shape: (batch, seq_length, d_k) 60 | out shape: (batch, d_k, height, width) d_k == c == dim 61 | return shape: (batch, dim, height, width) 62 | """ 63 | 64 | similarity = torch.einsum('b i c, b j c -> b i j', q, k) # Q(K^T) 65 | attention_score = torch.softmax(similarity * self.scale, dim=-1) # softmax(Q(K^T) / sqrt(d_k)) 66 | attention = torch.einsum('b i j, b j c -> b i c', attention_score, v) 67 | # attention(Q, K, V) = [softmax(Q(K^T) / sqrt(d_k))]V -> Scaled Dot-Product Attention 68 | out = rearrange(attention, 'b (h w) c -> b c h w', h=h, w=w) 69 | return self.to_out(out) + x 70 | 71 | 72 | class ResnetAttentionBlock(nn.Module): 73 | def __init__(self, dim, 
dim_out=None, time_emb_dim=None, dropout=None, groups=32): 74 | super().__init__() 75 | 76 | self.dim, self.dim_out = dim, dim_out 77 | 78 | self.resnet = ResnetBlock(dim, dim_out, time_emb_dim, dropout, groups) 79 | self.attention = Attention(dim_out, groups) 80 | 81 | def forward(self, x, time_emb=None): 82 | x = self.resnet(x, time_emb) 83 | return self.attention(x) 84 | 85 | 86 | class downSample(nn.Module): 87 | def __init__(self, dim_in): 88 | super().__init__() 89 | 90 | self.dim, self.dim_out = dim_in, dim_in 91 | 92 | self.downsample = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1) 93 | 94 | def forward(self, x): 95 | return self.downsample(x) 96 | 97 | 98 | class upSample(nn.Module): 99 | def __init__(self, dim_in): 100 | super().__init__() 101 | 102 | self.dim, self.dim_out = dim_in, dim_in 103 | 104 | self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), 105 | nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), padding=1)) 106 | 107 | def forward(self, x): 108 | return self.upsample(x) 109 | 110 | 111 | class Unet(nn.Module): 112 | def __init__(self, dim, image_size, dim_multiply=(1, 2, 4, 8), channel=3, num_res_blocks=2, 113 | attn_resolutions=(16,), dropout=0, device='cuda', groups=32): 114 | """ 115 | U-net for noise prediction. Code is based on Denoising Diffusion Probabilistic Models 116 | https://github.com/hojonathanho/diffusion 117 | :param dim: base hidden channel dimension, see dim_multiply below 118 | :param dim_multiply: len(dim_multiply) will be the depth of the U-net; at each level i, the channel 119 | dimension will be dim * dim_multiply[i]. If the input image shape is [H, W, 3], then at the lowest level the 120 | feature map shape will be [H/2^(len(dim_multiply)-1), W/2^(len(dim_multiply)-1), dim*dim_multiply[-1]], 121 | not considering the U-net down-up path connections. 122 | :param image_size: input image size 123 | :param channel: # of input image channels, 3 for RGB 124 | :param num_res_blocks: # of ResnetBlocks at each level. In the downward path, each level will contain 125 | num_res_blocks ResnetBlock modules, and in the upward path, each level will contain 126 | (num_res_blocks+1) ResnetBlock modules. 127 | :param attn_resolutions: The feature map resolutions at which Attention is applied. In the DDPM paper, the 128 | authors used the Attention module when the feature map resolution is 16. 129 | :param dropout: dropout probability. If set to 0, no dropout is applied. 130 | :param device: either 'cuda' or 'cpu' 131 | :param groups: number of groups for Group normalization. 
132 | """ 133 | super().__init__() 134 | assert dim % groups == 0, 'parameter [groups] must be divisible by parameter [dim]' 135 | 136 | # Attributes 137 | self.dim = dim 138 | self.channel = channel 139 | self.time_emb_dim = 4 * self.dim 140 | self.num_resolutions = len(dim_multiply) 141 | self.device = device 142 | self.resolution = [int(image_size / (2 ** i)) for i in range(self.num_resolutions)] 143 | self.hidden_dims = [self.dim, *map(lambda x: x * self.dim, dim_multiply)] 144 | self.num_res_blocks = num_res_blocks 145 | 146 | # Time embedding 147 | positional_encoding = PositionalEncoding(self.dim) 148 | self.time_mlp = nn.Sequential( 149 | positional_encoding, nn.Linear(self.dim, self.time_emb_dim), 150 | nn.SiLU(), nn.Linear(self.time_emb_dim, self.time_emb_dim) 151 | ) 152 | 153 | # Layer definition 154 | self.down_path = nn.ModuleList([]) 155 | self.up_path = nn.ModuleList([]) 156 | concat_dim = list() 157 | 158 | # Downward Path layer definition 159 | self.init_conv = nn.Conv2d(channel, self.dim, kernel_size=(3, 3), padding=1) 160 | concat_dim.append(self.dim) 161 | 162 | for level in range(self.num_resolutions): 163 | d_in, d_out = self.hidden_dims[level], self.hidden_dims[level + 1] 164 | for block in range(num_res_blocks): 165 | d_in_ = d_in if block == 0 else d_out 166 | if self.resolution[level] in attn_resolutions: 167 | self.down_path.append(ResnetAttentionBlock(d_in_, d_out, self.time_emb_dim, dropout, groups)) 168 | else: 169 | self.down_path.append(ResnetBlock(d_in_, d_out, self.time_emb_dim, dropout, groups)) 170 | concat_dim.append(d_out) 171 | if level != self.num_resolutions - 1: 172 | self.down_path.append(downSample(d_out)) 173 | concat_dim.append(d_out) 174 | 175 | # Middle layer definition 176 | mid_dim = self.hidden_dims[-1] 177 | self.middle_resnet_attention = ResnetAttentionBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups) 178 | self.middle_resnet = ResnetBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups) 179 | 180 | # Upward Path layer definition 181 | for level in reversed(range(self.num_resolutions)): 182 | d_out = self.hidden_dims[level + 1] 183 | for block in range(num_res_blocks + 1): 184 | d_in = self.hidden_dims[level + 2] if block == 0 and level != self.num_resolutions - 1 else d_out 185 | d_in = d_in + concat_dim.pop() 186 | if self.resolution[level] in attn_resolutions: 187 | self.up_path.append(ResnetAttentionBlock(d_in, d_out, self.time_emb_dim, dropout, groups)) 188 | else: 189 | self.up_path.append(ResnetBlock(d_in, d_out, self.time_emb_dim, dropout, groups)) 190 | if level != 0: 191 | self.up_path.append(upSample(d_out)) 192 | 193 | assert not concat_dim, 'Error in concatenation between downward path and upward path.' 
194 | 195 | # Output layer 196 | final_ch = self.hidden_dims[1] 197 | self.final_norm = nn.GroupNorm(groups, final_ch) 198 | self.final_activation = nn.SiLU() 199 | self.final_conv = nn.Conv2d(final_ch, channel, kernel_size=(3, 3), padding=1) 200 | 201 | def forward(self, x, time): 202 | """ 203 | return predicted noise given x_t and t 204 | """ 205 | t = self.time_mlp(time) 206 | # Downward 207 | concat = list() 208 | x = self.init_conv(x) 209 | concat.append(x) 210 | for layer in self.down_path: 211 | x = layer(x, t) if not isinstance(layer, (upSample, downSample)) else layer(x) 212 | concat.append(x) 213 | 214 | # Middle 215 | x = self.middle_resnet_attention(x, t) 216 | x = self.middle_resnet(x, t) 217 | 218 | # Upward 219 | for layer in self.up_path: 220 | if not isinstance(layer, upSample): 221 | x = torch.cat((x, concat.pop()), dim=1) 222 | x = layer(x, t) if not isinstance(layer, (upSample, downSample)) else layer(x) 223 | 224 | # Final 225 | x = self.final_activation(self.final_norm(x)) 226 | return self.final_conv(x) 227 | 228 | def print_model_structure(self): 229 | for i in self.down_path: 230 | if i.__class__.__name__ == 'downSample': 231 | print('-' * 20) 232 | if i.__class__.__name__ == "Conv2d": 233 | 234 | print(i.__class__.__name__) 235 | else: 236 | print(i.__class__.__name__, i.dim, i.dim_out) 237 | print('\n') 238 | print('=' * 20) 239 | print('\n') 240 | for i in self.up_path: 241 | if i.__class__.__name__ == 'upSample': 242 | print('-' * 20) 243 | if i.__class__.__name__ == "Conv2d": 244 | print(i.__class__.__name__) 245 | else: 246 | print(i.__class__.__name__, i.dim, i.dim_out) 247 | -------------------------------------------------------------------------------- /src/model_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from einops.layers.torch import Rearrange 6 | from functools import partial 7 | from .utils import PositionalEncoding 8 | 9 | 10 | class RMSNorm(nn.Module): 11 | def __init__(self, dim): 12 | super().__init__() 13 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 14 | 15 | def forward(self, x): 16 | return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5) 17 | 18 | 19 | class Block(nn.Module): 20 | def __init__(self, dim, dim_out, groups=8): 21 | """ 22 | Input shape=(B, dim, H, W) 23 | Output shape=(B, dim_out, H, W) 24 | 25 | :param dim: input channel 26 | :param dim_out: output channel 27 | :param groups: number of groups for Group normalization. 28 | """ 29 | super().__init__() 30 | self.proj = nn.Conv2d(dim, dim_out, kernel_size=(3, 3), padding=1) 31 | self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out) 32 | self.activation = nn.SiLU() 33 | 34 | def forward(self, x, scale_shift=None): 35 | x = self.proj(x) 36 | x = self.norm(x) 37 | if scale_shift is not None: 38 | scale, shift = scale_shift 39 | x = x * (1 + scale) + shift 40 | return self.activation(x) 41 | 42 | 43 | class ResnetBlock(nn.Module): 44 | def __init__(self, dim, dim_out, time_emb_dim=None, group=8): 45 | """ 46 | In abstract, it is composed of two Convolutional layer with residual connection, 47 | with information of time encoding is passed to first Convolutional layer. 48 | 49 | Input shape=(B, dim, H, W) 50 | Output shape=(B, dim_out, H, W) 51 | 52 | :param dim: input channel 53 | :param dim_out: output channel 54 | :param time_emb_dim: Embedding dimension for time. 
55 | :param group: number of groups for Group normalization. 56 | """ 57 | super().__init__() 58 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if time_emb_dim is not None else None 59 | self.block1 = Block(dim, dim_out, group) 60 | self.block2 = Block(dim_out, dim_out, group) 61 | self.residual_conv = nn.Conv2d(dim, dim_out, kernel_size=(1, 1)) if dim != dim_out else nn.Identity() 62 | 63 | def forward(self, x, time_emb=None): 64 | """ 65 | 66 | :param x: (B, dim, H, W) 67 | :param time_emb: (B, time_emb_dim) 68 | :return: (B, dim_out, H, W) 69 | """ 70 | scale_shift = None 71 | if time_emb is not None: 72 | scale_shift = self.mlp(time_emb)[..., None, None] # (B, dim_out*2, 1, 1) 73 | scale_shift = scale_shift.chunk(2, dim=1) # len 2 with each element shape (B, dim_out, 1, 1) 74 | hidden = self.block1(x, scale_shift) 75 | hidden = self.block2(hidden) 76 | return hidden + self.residual_conv(x) 77 | 78 | 79 | class Attention(nn.Module): 80 | def __init__(self, dim, head=4, dim_head=32): 81 | super().__init__() 82 | self.head = head 83 | hidden_dim = head * dim_head 84 | 85 | self.scale = dim_head ** (-0.5) # 1 / sqrt(d_k) 86 | self.norm = RMSNorm(dim) 87 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=(1, 1), bias=False) 88 | self.to_out = nn.Conv2d(hidden_dim, dim, kernel_size=(1, 1)) 89 | 90 | def forward(self, x): 91 | b, c, i, j = x.shape 92 | x = self.norm(x) 93 | 94 | qkv = self.to_qkv(x).chunk(3, dim=1) 95 | # h=self.head, f=dim_head, i=height, j=width. 96 | # You can think (i*j) as sequence length where f is d_k in 97 | q, k, v = map(lambda t: rearrange(t, 'b (h f) i j -> b h (i j) f', h=self.head), qkv) 98 | 99 | """ 100 | q, k, v shape: (batch, # of head, seq_length, d_k) seq_length = height * width 101 | similarity shape: (batch, # of head, seq_length, seq_length) 102 | attention_score shape: (batch, # of head, seq_length, seq_length) 103 | attention shape: (batch, # of head, seq_length, d_k) 104 | out shape: (batch, hidden_dim, height, width) 105 | return shape: (batch, dim, height, width) 106 | """ 107 | # n, m is likewise sequence length. 108 | similarity = torch.einsum('b h n f, b h m f -> b h n m', q, k) # Q(K^T) 109 | attention_score = torch.softmax(similarity * self.scale, dim=-1) # softmax(Q(K^T) / sqrt(d_k)) 110 | attention = torch.einsum('b h n m, b h m f -> b h n f', attention_score, v) 111 | # attention(Q, K, V) = [softmax(Q(K^T) / sqrt(d_k))]V -> Scaled Dot-Product Attention 112 | 113 | out = rearrange(attention, 'b h (i j) f -> b (h f) i j', i=i, j=j) 114 | return self.to_out(out) 115 | 116 | 117 | class LinearAttention(nn.Module): 118 | def __init__(self, dim, head=4, dim_head=32): 119 | super().__init__() 120 | self.head = head 121 | hidden_dim = head * dim_head 122 | 123 | self.scale = dim_head ** (-0.5) 124 | self.norm = RMSNorm(dim) 125 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=(1, 1), bias=False) 126 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, kernel_size=(1, 1)), RMSNorm(dim)) 127 | 128 | def forward(self, x): 129 | b, c, i, j = x.shape 130 | x = self.norm(x) 131 | 132 | qkv = self.to_qkv(x).chunk(3, dim=1) 133 | # h=self.head, f=dim_head, i=height, j=width. 
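        # Added note (not in the original code): unlike the full Attention class above, this linear-attention
        # variant normalizes q and k separately and contracts k with v first, so its cost grows linearly with
        # the sequence length (i*j) instead of quadratically.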
134 | # You can think (i*j) as sequence length where f is d_k in 135 | q, k, v = map(lambda t: rearrange(t, 'b (h f) i j -> b h f (i j)', h=self.head), qkv) 136 | 137 | q = q.softmax(dim=-2) * self.scale 138 | k = k.softmax(dim=-1) 139 | context = torch.einsum('b h f m, b h e m -> b h f e', k, v) 140 | linear_attention = torch.einsum('b h f e, b h f n -> b h e n', context, q) 141 | out = rearrange(linear_attention, 'b h e (i j) -> b (h e) i j', i=i, j=j, h=self.head) 142 | return self.to_out(out) 143 | 144 | 145 | def downSample(dim_in, dim_out): 146 | return nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2), 147 | nn.Conv2d(dim_in * 4, dim_out, kernel_size=(1, 1))) 148 | 149 | 150 | def upSample(dim_in, dim_out): 151 | return nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), 152 | nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), padding=1)) 153 | 154 | 155 | class Unet(nn.Module): 156 | def __init__(self, dim, image_size, dim_multiply=(1, 2, 4, 8), channel=3, attn_heads=4, attn_head_dim=32, 157 | full_attn=(False, False, False, True), resnet_group_norm=8, device='cuda'): 158 | """ 159 | U-net for noise prediction. Code is based on denoising-diffusion-pytorch 160 | https://github.com/lucidrains/denoising-diffusion-pytorch 161 | :param dim: See below 162 | :param dim_multiply: len(dim_multiply) will be the depth of U-net model with at each level i, the dimension 163 | of channel will be dim * dim_multiply[i]. If the input image shape is [H, W, 3] then at the lowest level, 164 | feature map shape will be [H/(2^(len(dim_multiply)-1), W/(2^(len(dim_multiply)-1), dim*dim_multiply[-1]] 165 | if not considering U-net down-up path connection. 166 | :param channel: 3 167 | :param attn_heads: It uses multi-head-self-Attention. attn_head is the # of head. It corresponds to h in 168 | "Attention is all you need" paper. See section 3.2.2 169 | :param attn_head_dim: It is the dimension of each head. It corresponds to d_k in Attention paper. 170 | :param full_attn: In pytorch implementation they used Linear Attention where full Attention(multi head self 171 | attention) is not applied. This param indicates at each level, whether to use full attention 172 | or use linear attention. So the len(full_attn) must equal to len(dim_multiply). For example if 173 | full_attn=(F, F, F, T) then at level 0, 1, 2 it will use Linear Attention and at level 3 it will use 174 | multi-head self attention(i.e. full attention) 175 | :param resnet_group_norm: number of groups for Group normalization. 
176 | :param device: either 'cuda' or 'cpu' 177 | """ 178 | super().__init__() 179 | assert len(dim_multiply) == len(full_attn), 'Length of dim_multiply and Length of full_attn must be same' 180 | 181 | # Attributes 182 | self.dim = dim 183 | self.channel = channel 184 | self.hidden_dims = [self.dim, *map(lambda x: x * self.dim, dim_multiply)] 185 | self.dim_in_out = list(zip(self.hidden_dims[:-1], self.hidden_dims[1:])) 186 | self.time_emb_dim = 4 * self.dim 187 | self.full_attn = full_attn 188 | self.depth = len(dim_multiply) 189 | self.device = device 190 | 191 | # Time embedding 192 | positional_encoding = PositionalEncoding(self.dim) 193 | self.time_mlp = nn.Sequential( 194 | positional_encoding, nn.Linear(self.dim, self.time_emb_dim), 195 | nn.GELU(), nn.Linear(self.time_emb_dim, self.time_emb_dim) 196 | ) 197 | 198 | # Layer definition 199 | resnet_block = partial(ResnetBlock, time_emb_dim=self.time_emb_dim, group=resnet_group_norm) 200 | self.init_conv = nn.Conv2d(self.channel, self.dim, kernel_size=(7, 7), padding=3) 201 | self.down_path = nn.ModuleList([]) 202 | self.up_path = nn.ModuleList([]) 203 | 204 | # Downward Path layer definition 205 | for idx, ((dim_in, dim_out), full_attn_flag) in enumerate(zip(self.dim_in_out, self.full_attn)): 206 | isLast = idx == (self.depth - 1) 207 | attention = LinearAttention if not full_attn_flag else Attention 208 | self.down_path.append(nn.ModuleList([ 209 | resnet_block(dim_in, dim_in), 210 | resnet_block(dim_in, dim_in), 211 | attention(dim_in, head=attn_heads, dim_head=attn_head_dim), 212 | downSample(dim_in, dim_out) if not isLast else nn.Conv2d(dim_in, dim_out, kernel_size=(3, 3), padding=1) 213 | ])) 214 | 215 | # Middle layer definition 216 | mid_dim = self.hidden_dims[-1] 217 | self.mid_resnet_block1 = resnet_block(mid_dim, mid_dim) 218 | self.mid_attention = Attention(mid_dim, head=attn_heads, dim_head=attn_head_dim) 219 | self.mid_resnet_block2 = resnet_block(mid_dim, mid_dim) 220 | 221 | # Upward Path layer definition 222 | for idx, ((dim_in, dim_out), full_attn_flag) in enumerate( 223 | zip(reversed(self.dim_in_out), reversed(self.full_attn))): 224 | isLast = idx == (self.depth - 1) 225 | attention = LinearAttention if not full_attn_flag else Attention 226 | self.up_path.append(nn.ModuleList([ 227 | resnet_block(dim_in + dim_out, dim_out), 228 | resnet_block(dim_in + dim_out, dim_out), 229 | attention(dim_out, head=attn_heads, dim_head=attn_head_dim), 230 | upSample(dim_out, dim_in) if not isLast else nn.Conv2d(dim_out, dim_in, kernel_size=(3, 3), padding=1) 231 | ])) 232 | 233 | self.final_resnet_block = resnet_block(2 * self.dim, self.dim) 234 | self.final_conv = nn.Conv2d(self.dim, self.channel, kernel_size=(1, 1)) 235 | 236 | def forward(self, x, time): 237 | """ 238 | return predicted noise given x_t and t 239 | """ 240 | x = self.init_conv(x) 241 | r = x.clone() 242 | t = self.time_mlp(time) 243 | concat = list() 244 | 245 | for block1, block2, attn, downsample in self.down_path: 246 | x = block1(x, t) 247 | concat.append(x) 248 | 249 | x = block2(x, t) 250 | x = attn(x) + x 251 | concat.append(x) 252 | 253 | x = downsample(x) 254 | 255 | x = self.mid_resnet_block1(x, t) 256 | x = self.mid_attention(x) + x 257 | x = self.mid_resnet_block2(x, t) 258 | 259 | for block1, block2, attn, upsample in self.up_path: 260 | x = torch.cat((x, concat.pop()), dim=1) 261 | x = block1(x, t) 262 | 263 | x = torch.cat((x, concat.pop()), dim=1) 264 | x = block2(x, t) 265 | x = attn(x) + x 266 | x = upsample(x) 267 | 268 | x = torch.cat((x, 
r), dim=1) 269 | x = self.final_resnet_block(x, t) 270 | return self.final_conv(x) 271 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | from .dataset import dataset_wrapper 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import torchvision 9 | from torch.optim import Adam 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torchvision.utils import save_image 12 | from multiprocessing import cpu_count 13 | from functools import partial 14 | from tqdm import tqdm 15 | import datetime 16 | from termcolor import colored 17 | from .utils import * 18 | 19 | 20 | def cycle_with_label(dl): 21 | while True: 22 | for data in dl: 23 | img, label = data 24 | yield img 25 | 26 | 27 | def cycle(dl): 28 | while True: 29 | for data in dl: 30 | yield data 31 | 32 | 33 | class Trainer: 34 | def __init__(self, diffusion_model, dataset, batch_size=32, lr=2e-5, total_step=100000, ddim_samplers=None, 35 | save_and_sample_every=1000, num_samples=25, result_folder='./results', cpu_percentage=0, 36 | fid_estimate_batch_size=None, ddpm_fid_score_estimate_every=None, ddpm_num_fid_samples=None, 37 | max_grad_norm=1., tensorboard=True, exp_name=None, clip=True): 38 | """ 39 | Trainer for Diffusion model. 40 | :param diffusion_model: GaussianDiffusion model 41 | :param dataset: either 'cifar10' or path to the custom dataset you've prepared, where images are saved 42 | :param batch_size: batch size for training. DDPM author used 128 for cifar10 and 64 for 256X256 image 43 | :param lr: DDPM author used 2e-4 for cifar10 and 2e-5 for 256X256 image 44 | :param total_step: total training step. DDPM used 800K for cifar10, CelebA-HQ for 0.5M 45 | :param ddim_samplers: List containing DDIM samplers. 46 | :param save_and_sample_every: Step interval for saving model and generated image(by DDPM sampling). 47 | For example if it is set to 1000, then trainer will save models in every 1000 step and save generated images 48 | based on DDPM sampling schema. If you want to generate image based on DDIM sampling, you have to pass a list 49 | containing corresponding DDIM sampler. 50 | :param num_samples: # of generating images, must be square number ex) 25, 36, 49... 51 | :param result_folder: where model, generated images will be saved 52 | :param cpu_percentage: The percentage of CPU used for Dataloader i.e. num_workers in Dataloader. 53 | Value must be [0, 1] where 1 means using all cpu for dataloader. If you are Windows user setting value other 54 | than 0 will cause problem, so set to 0 55 | :param fid_estimate_batch_size: batch size for FID calculation. It has nothing to do with training. 56 | :param ddpm_fid_score_estimate_every: Step interval for FID calculation using DDPM. If set to None, FID score 57 | will not be calculated with DDPM sampling. If you use DDPM sampling for FID calculation, it can be very 58 | time consuming, so it is wise to set this value to None, and use DDIM sampler for FID calculation. But anyway 59 | you can calculate FID score with DDPM sampler if you insist to. 60 | :param ddpm_num_fid_samples: # of generating images for FID calculation using DDPM sampler. If you set 61 | ddpm_fid_score_estimate_every to None, i.e. not using DDPM sampler for FID calculation, then this value will 62 | be just ignored. 
63 | :param max_grad_norm: maximum gradient norm for gradient clipping 64 | :param tensorboard: Set to True if you want to monitor training 65 | :param exp_name: experiment name. If set to None, it will be set automatically to the dataset folder name. 66 | :param clip: [True, False, 'both'] you can find details in the p_sample function in the diffusion.py file. 67 | """ 68 | 69 | # Metadata & Initialization & Make directory for saving files. 70 | now = datetime.datetime.now() 71 | self.cur_time = now.strftime('%Y-%m-%d_%Hh%Mm') 72 | if exp_name is None: 73 | exp_name = os.path.basename(dataset) 74 | if exp_name == '': 75 | exp_name = os.path.basename(os.path.dirname(dataset)) 76 | self.exp_name = exp_name 77 | self.diffusion_model = diffusion_model 78 | self.ddim_samplers = ddim_samplers 79 | self.batch_size = batch_size 80 | self.num_samples = num_samples 81 | self.nrow = int(math.sqrt(self.num_samples)) 82 | assert (self.nrow ** 2) == self.num_samples, 'num_samples must be a square number. ex) 25, 36, 49, ...' 83 | self.save_and_sample_every = save_and_sample_every 84 | self.image_size = self.diffusion_model.image_size 85 | self.max_grad_norm = max_grad_norm 86 | self.result_folder = os.path.join(result_folder, exp_name, self.cur_time) 87 | self.ddpm_result_folder = os.path.join(self.result_folder, 'DDPM') 88 | self.device = self.diffusion_model.device 89 | self.clip = clip 90 | self.ddpm_fid_flag = True if ddpm_fid_score_estimate_every is not None else False 91 | self.ddpm_fid_score_estimate_every = ddpm_fid_score_estimate_every 92 | self.cal_fid = True if self.ddpm_fid_flag else False 93 | self.tqdm_sampler_name = None 94 | self.tensorboard = tensorboard 95 | self.tensorboard_name = None 96 | self.writer = None 97 | self.global_step = 0 98 | self.total_step = total_step 99 | self.fid_score_log = dict() 100 | assert clip in [True, False, 'both'], "clip must be one of [True, False, 'both']" 101 | if clip is True or clip == 'both': 102 | os.makedirs(os.path.join(self.ddpm_result_folder, 'clip'), exist_ok=True) 103 | if clip is False or clip == 'both': 104 | os.makedirs(os.path.join(self.ddpm_result_folder, 'no_clip'), exist_ok=True) 105 | 106 | # Dataset & DataLoader & Optimizer 107 | notification = make_notification('Dataset', color='light_green') 108 | print(notification) 109 | dataSet = dataset_wrapper(dataset, self.image_size) 110 | assert len(dataSet) >= 100, 'you should have at least 100 images in your folder. At least 10k images are recommended' 111 | print(colored('Dataset Length: {}\n'.format(len(dataSet)), 'light_green')) 112 | CPU_cnt = cpu_count() 113 | # TODO: pin_memory? 
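        # Added note (worked example, not in the original code): with cpu_count() == 16 and cpu_percentage=0.5,
        # the next line gives num_workers = int(16 * 0.5) = 8; leaving cpu_percentage at 0 keeps num_workers = 0,
        # the safe default on Windows mentioned in the constructor docstring.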
114 | num_workers = int(CPU_cnt * cpu_percentage) 115 | assert num_workers <= CPU_cnt, "cpu_percentage must be [0.0, 1.0]" 116 | dataLoader = DataLoader(dataSet, batch_size=self.batch_size, shuffle=True, 117 | num_workers=num_workers, pin_memory=True) 118 | self.dataLoader = cycle(dataLoader) if os.path.isdir(dataset) else cycle_with_label(dataLoader) 119 | self.optimizer = Adam(self.diffusion_model.parameters(), lr=lr) 120 | 121 | # DDIM sampler setting 122 | self.ddim_sampling_schedule = list() 123 | for idx, sampler in enumerate(self.ddim_samplers): 124 | sampler.sampler_name = 'DDIM_{}_steps{}_eta{}'.format(idx + 1, sampler.ddim_steps, sampler.eta) 125 | self.ddim_sampling_schedule.append(sampler.sample_every) 126 | save_path = os.path.join(self.result_folder, sampler.sampler_name) 127 | sampler.save_path = save_path 128 | if sampler.save: 129 | os.makedirs(save_path, exist_ok=True) 130 | if sampler.generate_image: 131 | if sampler.clip is True or sampler.clip == 'both': 132 | os.makedirs(os.path.join(save_path, 'clip'), exist_ok=True) 133 | if sampler.clip is False or sampler.clip == 'both': 134 | os.makedirs(os.path.join(save_path, 'no_clip'), exist_ok=True) 135 | if sampler.calculate_fid: 136 | self.cal_fid = True 137 | if self.tqdm_sampler_name is None: 138 | self.tqdm_sampler_name = sampler.sampler_name 139 | sampler.num_fid_sample = sampler.num_fid_sample if sampler.num_fid_sample is not None else len(dataSet) 140 | self.fid_score_log[sampler.sampler_name] = list() 141 | if sampler.fixed_noise: 142 | sampler.register_buffer('noise', torch.randn([self.num_samples, sampler.channel, 143 | sampler.image_size, sampler.image_size])) 144 | 145 | # Image generation log 146 | notification = make_notification('Image Generation', color='light_cyan') 147 | print(notification) 148 | print(colored('Image will be generated with the following sampler(s)', 'light_cyan')) 149 | print(colored('-> DDPM Sampler / Image generation every {} steps'.format(self.save_and_sample_every), 150 | 'light_cyan')) 151 | for sampler in self.ddim_samplers: 152 | if sampler.generate_image: 153 | print(colored('-> {} / Image generation every {} steps / Fixed Noise : {}' 154 | .format(sampler.sampler_name, sampler.sample_every, sampler.fixed_noise), 'light_cyan')) 155 | print('\n') 156 | 157 | # FID score 158 | notification = make_notification('FID', color='light_magenta') 159 | print(notification) 160 | if not self.cal_fid: 161 | print(colored('No FID evaluation will be executed!\n' 162 | 'If you want FID evaluation consider using DDIM sampler.', 'light_magenta')) 163 | else: 164 | self.fid_batch_size = fid_estimate_batch_size if fid_estimate_batch_size is not None else self.batch_size 165 | dataSet_fid = dataset_wrapper(dataset, self.image_size, 166 | augment_horizontal_flip=False, info_color='light_magenta', min1to1=False) 167 | dataLoader_fid = DataLoader(dataSet_fid, batch_size=self.fid_batch_size, num_workers=num_workers) 168 | 169 | self.fid_scorer = FID(self.fid_batch_size, dataLoader_fid, dataset_name=exp_name, device=self.device, 170 | no_label=os.path.isdir(dataset)) 171 | 172 | print(colored('FID score will be calculated with the following sampler(s)', 'light_magenta')) 173 | if self.ddpm_fid_flag: 174 | self.ddpm_num_fid_samples = ddpm_num_fid_samples if ddpm_num_fid_samples is not None else len(dataSet) 175 | print(colored('-> DDPM Sampler / FID calculation every {} steps with {} generated samples' 176 | .format(self.ddpm_fid_score_estimate_every, self.ddpm_num_fid_samples), 'light_magenta')) 177 | for 
sampler in self.ddim_samplers: 178 | if sampler.calculate_fid: 179 | print(colored('-> {} / FID calculation every {} steps with {} generated samples' 180 | .format(sampler.sampler_name, sampler.sample_every, 181 | sampler.num_fid_sample), 'light_magenta')) 182 | print('\n') 183 | if self.ddpm_fid_flag: 184 | self.tqdm_sampler_name = 'DDPM' 185 | self.fid_score_log['DDPM'] = list() 186 | notification = make_notification('WARNING', color='red', boundary='*') 187 | print(notification) 188 | msg = """ 189 | FID computation witm DDPM sampler requires a lot of generated samples and can therefore be very time 190 | consuming.\nTo accelerate sampling, only using DDIM sampling is recommended. To disable DDPM sampling, 191 | set [ddpm_fid_score_estimate_every] parameter to None while instantiating Trainer.\n 192 | """ 193 | print(colored(msg, 'red')) 194 | del dataLoader_fid 195 | del dataSet_fid 196 | 197 | def train(self): 198 | # Tensorboard 199 | if self.tensorboard: 200 | os.makedirs('./tensorboard', exist_ok=True) 201 | self.tensorboard_name = self.exp_name + '_' + self.cur_time \ 202 | if self.tensorboard_name is None else self.tensorboard_name 203 | notification = make_notification('Tensorboard', color='light_blue') 204 | print(notification) 205 | print(colored('Tensorboard Available!', 'light_blue')) 206 | print(colored('Tensorboard name: {}'.format(self.tensorboard_name), 'light_blue')) 207 | print(colored('Launch Tensorboard by running following command on terminal', 'light_blue')) 208 | print(colored('tensorboard --logdir ./tensorboard\n', 'light_blue')) 209 | self.writer = SummaryWriter(os.path.join('./tensorboard', self.tensorboard_name)) 210 | notification = make_notification('Training', color='light_yellow', boundary='+') 211 | print(notification) 212 | cur_fid = 'NAN' 213 | ddpm_best_fid = 1e10 214 | stepTQDM = tqdm(range(self.global_step, self.total_step)) 215 | for cur_step in stepTQDM: 216 | self.diffusion_model.train() 217 | self.optimizer.zero_grad() 218 | image = next(self.dataLoader).to(self.device) 219 | loss = self.diffusion_model(image) 220 | loss.backward() 221 | nn.utils.clip_grad_norm_(self.diffusion_model.parameters(), self.max_grad_norm) 222 | self.optimizer.step() 223 | 224 | vis_fid = cur_fid if isinstance(cur_fid, str) else '{:.04f}'.format(cur_fid) 225 | stepTQDM.set_postfix({'loss': '{:.04f}'.format(loss.item()), 'FID': vis_fid, 'step':self.global_step}) 226 | 227 | self.diffusion_model.eval() 228 | # DDPM Sampler for generating images 229 | if cur_step != 0 and (cur_step % self.save_and_sample_every) == 0: 230 | if self.writer is not None: 231 | self.writer.add_scalar('Loss', loss.item(), cur_step) 232 | with torch.inference_mode(): 233 | batches = num_to_groups(self.num_samples, self.batch_size) 234 | for i, j in zip([True, False], ['clip', 'no_clip']): 235 | if self.clip not in [i, 'both']: 236 | continue 237 | imgs = list(map(lambda n: self.diffusion_model.sample(batch_size=n, clip=i), batches)) 238 | imgs = torch.cat(imgs, dim=0) 239 | save_image(imgs, nrow=self.nrow, 240 | fp=os.path.join(self.ddpm_result_folder, j, f'sample_{cur_step}.png')) 241 | if self.writer is not None: 242 | self.writer.add_images('DDPM sampling result ({})'.format(j), imgs, cur_step) 243 | self.save('latest') 244 | 245 | # DDPM Sampler for FID score evaluation 246 | if self.ddpm_fid_flag and cur_step != 0 and (cur_step % self.ddpm_fid_score_estimate_every) == 0: 247 | ddpm_cur_fid, _ = self.fid_scorer.fid_score(self.diffusion_model.sample, self.ddpm_num_fid_samples) 248 | if 
ddpm_best_fid > ddpm_cur_fid: 249 | ddpm_best_fid = ddpm_cur_fid 250 | self.save('best_fid_ddpm') 251 | if self.writer is not None: 252 | self.writer.add_scalars('FID', {'DDPM': ddpm_cur_fid}, cur_step) 253 | cur_fid = ddpm_cur_fid 254 | self.fid_score_log['DDPM'].append((self.global_step, ddpm_cur_fid)) 255 | 256 | # DDIM Sampler 257 | for sampler in self.ddim_samplers: 258 | if cur_step != 0 and (cur_step % sampler.sample_every) == 0: 259 | # DDPM Sampler for generating images 260 | if sampler.generate_image: 261 | with torch.inference_mode(): 262 | batches = num_to_groups(self.num_samples, self.batch_size) 263 | c_batch = np.insert(np.cumsum(np.array(batches)), 0, 0) 264 | for i, j in zip([True, False], ['clip', 'no_clip']): 265 | if sampler.clip not in [i, 'both']: 266 | continue 267 | if sampler.fixed_noise: 268 | imgs = list() 269 | for b in range(len(batches)): 270 | imgs.append(sampler.sample(self.diffusion_model, batch_size=None, clip=i, 271 | noise=sampler.noise[c_batch[b]:c_batch[b+1]])) 272 | else: 273 | imgs = list(map(lambda n: sampler.sample(self.diffusion_model, 274 | batch_size=n, clip=i), batches)) 275 | imgs = torch.cat(imgs, dim=0) 276 | save_image(imgs, nrow=self.nrow, 277 | fp=os.path.join(sampler.save_path, j, f'sample_{cur_step}.png')) 278 | if self.writer is not None: 279 | self.writer.add_images('{} sampling result ({})' 280 | .format(sampler.sampler_name, j), imgs, cur_step) 281 | 282 | # DDPM Sampler for FID score evaluation 283 | if sampler.calculate_fid: 284 | sample_ = partial(sampler.sample, self.diffusion_model) 285 | ddim_cur_fid, _ = self.fid_scorer.fid_score(sample_, sampler.num_fid_sample) 286 | if sampler.best_fid[0] > ddim_cur_fid: 287 | sampler.best_fid[0] = ddim_cur_fid 288 | if sampler.save: 289 | self.save('best_fid_{}'.format(sampler.sampler_name)) 290 | if sampler.sampler_name == self.tqdm_sampler_name: 291 | cur_fid = ddim_cur_fid 292 | if self.writer is not None: 293 | self.writer.add_scalars('FID', {sampler.sampler_name: ddim_cur_fid}, cur_step) 294 | self.fid_score_log[sampler.sampler_name].append((self.global_step, ddim_cur_fid)) 295 | 296 | self.global_step += 1 297 | 298 | print(colored('Training Finished!', 'light_yellow')) 299 | if self.writer is not None: 300 | self.writer.close() 301 | 302 | def save(self, name): 303 | data = { 304 | 'global_step': self.global_step, 305 | 'model': self.diffusion_model.state_dict(), 306 | 'opt': self.optimizer.state_dict(), 307 | 'fid_logger': self.fid_score_log, 308 | 'tensorboard': self.tensorboard_name 309 | } 310 | for sampler in self.ddim_samplers: 311 | data[sampler.sampler_name] = sampler.state_dict() 312 | torch.save(data, os.path.join(self.result_folder, 'model_{}.pt'.format(name))) 313 | 314 | def load(self, path, tensorboard_path=None, no_prev_ddim_setting=False): 315 | if not os.path.exists(path): 316 | print(make_notification('ERROR', color='red', boundary='*')) 317 | print(colored('No saved checkpoint is detected. Please check you gave existing path!', 'red')) 318 | exit() 319 | if tensorboard_path is not None and not os.path.exists(tensorboard_path): 320 | print(make_notification('ERROR', color='red', boundary='*')) 321 | print(colored('No tensorboard is detected. 
Please check you gave existing path!', 'red')) 322 | exit() 323 | print(make_notification('Loading Checkpoint', color='green')) 324 | data = torch.load(path, map_location=self.device) 325 | self.diffusion_model.load_state_dict(data['model']) 326 | self.global_step = data['global_step'] 327 | self.optimizer.load_state_dict(data['opt']) 328 | fid_score_log = data['fid_logger'] 329 | if no_prev_ddim_setting: 330 | for key, val in self.fid_score_log.items(): 331 | if key not in fid_score_log: 332 | fid_score_log[key] = val 333 | else: 334 | for sampler in self.ddim_samplers: 335 | sampler.load_state_dict(data[sampler.sampler_name]) 336 | self.fid_score_log = fid_score_log 337 | if tensorboard_path is not None: 338 | self.tensorboard_name = data['tensorboard'] 339 | print(colored('Successfully loaded checkpoint!\n', 'green')) 340 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import os 6 | from pytorch_fid.fid_score import calculate_frechet_distance 7 | from pytorch_fid.inception import InceptionV3 8 | from torch.nn.functional import adaptive_avg_pool2d 9 | from tqdm import tqdm 10 | from termcolor import colored 11 | 12 | 13 | def num_to_groups(num, divisor): 14 | groups = num // divisor 15 | remainder = num % divisor 16 | arr = [divisor] * groups 17 | if remainder > 0: 18 | arr.append(remainder) 19 | return arr 20 | 21 | 22 | class PositionalEncoding(nn.Module): 23 | def __init__(self, d_model): 24 | super().__init__() 25 | half_d_model = d_model // 2 26 | log_denominator = -math.log(10000) / (half_d_model - 1) 27 | denominator_ = torch.exp(torch.arange(half_d_model) * log_denominator) 28 | self.register_buffer('denominator', denominator_) 29 | 30 | def forward(self, time): 31 | """ 32 | :param time: shape=(B, ) 33 | :return: Positional Encoding shape=(B, d_model) 34 | """ 35 | argument = time[:, None] * self.denominator[None, :] # (B, half_d_model) 36 | return torch.cat([argument.sin(), argument.cos()], dim=-1) # (B, d_model) 37 | 38 | 39 | class FID: 40 | def __init__(self, batch_size, dataLoader, dataset_name, cache_dir='./results/fid_cache/', device='cuda', 41 | no_label=False, inception_block_idx=3): 42 | assert inception_block_idx in [0, 1, 2, 3], 'inception_block_idx must be either 0, 1, 2, 3' 43 | self.batch_size = batch_size 44 | self.dataLoader = dataLoader 45 | self.cache_dir = cache_dir 46 | self.dataset_name = dataset_name 47 | self.device = device 48 | self.no_label = no_label 49 | self.inception = InceptionV3([inception_block_idx]).to(device) 50 | 51 | os.makedirs(cache_dir, exist_ok=True) 52 | self.m2, self.s2 = self.load_dataset_stats() 53 | 54 | def calculate_inception_features(self, samples): 55 | self.inception.eval() 56 | features = self.inception(samples)[0] 57 | if features.size(2) != 1 or features.size(3) != 1: 58 | features = adaptive_avg_pool2d(features, output_size=(1, 1)) 59 | return features.squeeze() 60 | 61 | def load_dataset_stats(self): 62 | path = os.path.join(self.cache_dir, self.dataset_name + '.npz') 63 | if os.path.exists(path): 64 | with np.load(path) as f: 65 | m2, s2 = f['m2'], f['s2'] 66 | print(colored('Successfully loaded pre-computed Inception feature from cached file\n', 'light_magenta')) 67 | else: 68 | stacked_real_features = list() 69 | print(colored('Computing Inception features for {} ' 70 | 'samples from real 
dataset.'.format(len(self.dataLoader.dataset)), 'light_magenta')) 71 | for batch in tqdm(self.dataLoader, desc='Calculating stats for data distribution', leave=False): 72 | real_samples = batch.to(self.device) if self.no_label else batch[0].to(self.device) 73 | real_features = self.calculate_inception_features(real_samples) 74 | stacked_real_features.append(real_features) 75 | 76 | stacked_real_features = torch.cat(stacked_real_features, dim=0).cpu().numpy() 77 | m2 = np.mean(stacked_real_features, axis=0) 78 | s2 = np.cov(stacked_real_features, rowvar=False) 79 | np.savez_compressed(path, m2=m2, s2=s2) 80 | print(colored('Dataset stats cached to {} for future use\n'.format(path), 'light_magenta')) 81 | return m2, s2 82 | 83 | @torch.inference_mode() 84 | def fid_score(self, sampler, num_samples, return_sample_image=False): 85 | batches = num_to_groups(num_samples, self.batch_size) 86 | stacked_fake_features = list() 87 | generated_samples = list() if return_sample_image else None 88 | for batch in tqdm(batches, desc='FID score calculation', leave=False): 89 | fake_samples = sampler(batch, clip=True, min1to1=False) 90 | if return_sample_image: 91 | generated_samples.append(fake_samples) 92 | fake_features = self.calculate_inception_features(fake_samples) 93 | stacked_fake_features.append(fake_features) 94 | stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu().numpy() 95 | m1 = np.mean(stacked_fake_features, axis=0) 96 | s1 = np.cov(stacked_fake_features, rowvar=False) 97 | generated_samples_return = None 98 | if return_sample_image: 99 | generated_samples_return = torch.cat(generated_samples, dim=0) 100 | generated_samples_return = (generated_samples_return + 1.0) * 0.5 101 | return calculate_frechet_distance(m1, s1, self.m2, self.s2), generated_samples_return 102 | 103 | 104 | def make_notification(content, color, boundary='-'): 105 | notice = boundary * 30 + '\n' 106 | side = boundary if boundary != '-' else '|' 107 | notice += '{}{:^28}{}\n'.format(side, content, side) 108 | notice += boundary * 30 109 | return colored(notice, color) 110 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from src import model_torch 2 | from src import model_original 3 | from src.trainer import Trainer 4 | from src.diffusion import GaussianDiffusion, DDIM_Sampler 5 | import yaml 6 | import argparse 7 | 8 | 9 | def main(args): 10 | with open(args.config, 'r') as f: 11 | config = yaml.load(f, Loader=yaml.FullLoader) 12 | unet_cfg = config['unet'] 13 | ddim_cfg = config['ddim'] 14 | trainer_cfg = config['trainer'] 15 | image_size = unet_cfg['image_size'] 16 | 17 | if config['type'] == 'original': 18 | unet = model_original.Unet(**unet_cfg).to(args.device) 19 | elif config['type'] == 'torch': 20 | unet = model_torch.Unet(**unet_cfg).to(args.device) 21 | else: 22 | unet = None 23 | print("Unet type must be one of ['original', 'torch']") 24 | exit() 25 | 26 | diffusion = GaussianDiffusion(unet, image_size=image_size).to(args.device) 27 | 28 | ddim_samplers = list() 29 | for sampler_cfg in ddim_cfg.values(): 30 | ddim_samplers.append(DDIM_Sampler(diffusion, **sampler_cfg)) 31 | 32 | trainer = Trainer(diffusion, ddim_samplers=ddim_samplers, exp_name=args.exp_name, 33 | cpu_percentage=args.cpu_percentage, **trainer_cfg) 34 | if args.load is not None: 35 | trainer.load(args.load, args.tensorboard, args.no_prev_ddim_setting) 36 | trainer.train() 37 | 38 | 39 | if __name__ 
== '__main__': 40 | parse = argparse.ArgumentParser(description='DDPM & DDIM') 41 | parse.add_argument('-c', '--config', type=str, default='./config/cifar10.yaml') 42 | parse.add_argument('-l', '--load', type=str, default=None) 43 | parse.add_argument('-t', '--tensorboard', type=str, default=None) 44 | parse.add_argument('--exp_name', default=None) 45 | parse.add_argument('--device', type=str, choices=['cuda', 'cpu'], default='cuda') 46 | parse.add_argument('--cpu_percentage', type=float, default=0) 47 | parse.add_argument('--no_prev_ddim_setting', action='store_true') 48 | args = parse.parse_args() 49 | main(args) 50 | --------------------------------------------------------------------------------
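For reference, the classes listed above can also be composed directly in a few lines of Python instead of going through the command-line scripts. The snippet below is a minimal, illustrative sketch only: the dim=64 CIFAR-10-style settings and the checkpoint path are placeholders, DDIM samplers and FID estimation are disabled to keep it short, and the shipped inference.py with its YAML configs remains the supported entry point.

```python
from src import model_torch
from src.diffusion import GaussianDiffusion
from src.inferencer import Inferencer

# Placeholder, CIFAR-10-like settings (see config/*.yaml for the actual values used in the repo)
unet = model_torch.Unet(dim=64, image_size=32, device='cuda').to('cuda')
diffusion = GaussianDiffusion(unet, image_size=32).to('cuda')

# DDPM sampling only: no DDIM samplers, FID disabled, no gif, final images only
inferencer = Inferencer(diffusion, dataset='cifar10', ddim_samplers=[],
                        num_samples_per_image=25, ddpm_fid_estimate=False,
                        make_denoising_gif=False, return_all_step=False)
inferencer.load('/path/to/checkpoint.pt')  # placeholder path to a trained model checkpoint
inferencer.inference()                     # image grids are written under ./inference_results/cifar10/DDPM/
```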