├── .gitignore ├── README.md ├── assets ├── datasets │ ├── download_extract_DIODE.sh │ ├── download_extract_ImageNet.sh │ ├── download_extract_edges2handbags.sh │ ├── val_faster_imagefolder_10k_fn.txt │ └── val_faster_imagefolder_10k_label.txt └── teaser.png ├── corruption ├── LICENSE_DDRM ├── LICENSE_DDRM_JPEG ├── __init__.py ├── base.py ├── blur.py ├── inpaint.py ├── jpeg.py ├── mixture.py └── superresolution.py ├── datasets ├── __init__.py ├── aligned_dataset.py ├── augment.py ├── image_folder.py ├── imagenet_inpaint.py └── misc.py ├── ddbm ├── __init__.py ├── dist_util.py ├── karras_diffusion.py ├── logger.py ├── nn.py ├── random_util.py ├── resample.py ├── script_util.py ├── train_util.py └── unet.py ├── download_diffusion.py ├── evaluation ├── __init__.py ├── compute_metrices_imagenet.py ├── fid_util.py └── resnet.py ├── evaluations ├── __init__.py ├── evaluator.py ├── feature_extractor.py ├── inception_pytorch.py ├── inception_torchscript.py ├── inception_v3.py └── requirements.txt ├── logger.py ├── preprocess_ckpt.py ├── preprocess_depth.py ├── sample.py ├── scripts ├── args.sh ├── evaluate.sh ├── sample.sh └── train_bridge.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info 3 | wandb 4 | workdir 5 | *.pt 6 | *.ipynb 7 | *.npz 8 | .ipynb_checkpoints 9 | assets/datasets/DIODE/ 10 | assets/datasets/ImageNet/ 11 | assets/datasets/edges2handbags/ 12 | assets/datasets/DIODE-256/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Accelerating Diffusion Bridge Models for Image-to-Image Translation 2 | 3 | Official Implementation of [Diffusion Bridge Implicit Models](https://arxiv.org/abs/2405.15885) (ICLR 2025) and [Consistency Diffusion Bridge Models](https://arxiv.org/abs/2410.22637) (NeurIPS 2024). 4 | 5 | # Dependencies 6 | 7 | To install all packages in this codebase along with their dependencies, run 8 | ```sh 9 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 10 | pip install blobfile piq matplotlib opencv-python joblib lmdb scipy clean-fid easydict torchmetrics rich ipdb wandb 11 | ``` 12 | 13 | # Datasets 14 | 15 | Please put (or link) the datasets under `assets/datasets/`. 16 | 17 | - For Edges2Handbags, please follow instructions from [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md). The resulting folder structure should be `assets/datasets/edges2handbags/train` and `assets/datasets/edges2handbags/val`. 18 | - For DIODE, please download the training dataset and the data list from [here](https://diode-dataset.org/). The resulting folder structure should be `assets/datasets/DIODE/train` and `assets/datasets/DIODE/data_list`. 19 | - For ImageNet, please download the dataset from [here](https://image-net.org/download.php). The resulting folder structure should be `assets/datasets/ImageNet/train` and `assets/datasets/ImageNet/val`. 20 | 21 | We also provide automatic downloading scripts. 22 | ``` 23 | cd assets/datasets 24 | bash download_extract_edges2handbags.sh 25 | bash download_extract_DIODE.sh 26 | bash download_extract_ImageNet.sh 27 | ``` 28 | 29 | After downloading, the DIODE dataset requires preprocessing by running `python preprocess_depth.py`. 30 | 31 | # Diffusion Bridge Implicit Models 32 | 33 |

34 | [teaser figure: assets/teaser.png] 35 |

36 | 37 | DBIM offers a suite of fast samplers tailored for [Denoising Diffusion Bridge Models (DDBMs)](https://arxiv.org/abs/2309.16948). We clean the codebase to support a broad range of diffusion bridges, facilitating unified training and sampling workflows. We also streamline the deployment process by replacing the cumbersome MPI-based distributed launcher with the more efficient and engineer-friendly `torchrun`. 38 | 39 | ## Pre-trained models 40 | 41 | Please put the downloaded checkpoints under `assets/ckpts/`. 42 | 43 | For image translation, we directly adopt the pretrained checkpoints from [DDBM](https://github.com/alexzhou907/DDBM): 44 | 45 | - Edges2Handbags: [e2h_ema_0.9999_420000.pt](https://huggingface.co/alexzhou907/DDBM/resolve/main/e2h_ema_0.9999_420000.pt) 46 | - DIODE: [diode_ema_0.9999_440000.pt](https://huggingface.co/alexzhou907/DDBM/resolve/main/diode_ema_0.9999_440000.pt) 47 | 48 | We remove the dependency on external packages such as `flash_attn` in this codebase, which is already supported natively by PyTorch. After downloading the two checkpoints above, please run `python preprocess_ckpt.py` to complete the conversion. 49 | 50 | For image restoration: 51 | 52 | - Center 128x128 Inpainting on ImageNet 256x256: [imagenet256_inpaint_ema_0.9999_400000.pt](https://drive.google.com/file/d/1WozJyVOAFukj0nUYLS-ZUp1-QHuGNfox) 53 | 54 | ## Sampling 55 | 56 | ``` 57 | bash scripts/sample.sh $DATASET_NAME $NFE $SAMPLER ($AUX) 58 | ``` 59 | 60 | - `$DATASET_NAME` can be chosen from `e2h`/`diode`/`imagenet_inpaint_center`. 61 | - `$NFE` is the *Number of Function Evaluations*, which is proportional to the sampling time. 62 | - `$SAMPLER` can be chosen from `heun`/`dbim`/`dbim_high_order`. 63 | - `heun` is the vanilla sampler of DDBM, which simulates the SDE/ODE step alternatively. In this case, `$AUX` is not required. 64 | - `dbim` and `dbim_high_order` are our proposed samplers. When using `dbim`, `$AUX` corresponds to $\eta$ which controls the stochasticity level (floating-point value in $[0,1]$). When using `dbim_high_order`, `$AUX` corresponds to the order (2 or 3). 65 | 66 | The samples will be saved to `workdir/`. 67 | 68 | ## Evaluations 69 | 70 | Before evaluating the image translation results, please download the reference statistics from DDBM and put them under `assets/stats/`: 71 | - Reference stats for Edge2Handbags: [edges2handbags_ref_64_data.npz](https://huggingface.co/alexzhou907/DDBM/resolve/main/edges2handbags_ref_64_data.npz). 72 | - Reference stats for DIODE: [diode_ref_256_data.npz](https://huggingface.co/alexzhou907/DDBM/resolve/main/diode_ref_256_data.npz). 73 | 74 | The evaluation can automatically proceed by specifying the same dataset and sampler arguments as sampling: 75 | 76 | ``` 77 | bash scripts/evaluate.sh $DATASET_NAME $NFE $SAMPLER ($AUX) 78 | ``` 79 | 80 | ## Training 81 | 82 | We provide the script for training diffusion bridge models on the ImageNet 256x256 inpainting (center 128x128) task. As ImageNet 256x256 is a challenging dataset, we follow [I2SB](https://github.com/NVlabs/I2SB) and initialize the network with a pretrained diffusion model [ADM](https://github.com/openai/guided-diffusion/). To start training, run 83 | 84 | ``` 85 | python download_diffusion.py 86 | bash scripts/train_bridge.sh 87 | ``` 88 | 89 | Training for 400k iterations takes around 2 weeks on 8xA100. We recommend using multiple nodes and modifying `scripts/train_bridge.sh` accordingly. 
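For multi-node training, only the launcher invocation needs to change. As a rough sketch (the training flags themselves are the ones already set in `scripts/train_bridge.sh` and are not repeated here; `NODE_RANK` and `MASTER_ADDR` are placeholders to be set per node), a 2-node x 8-GPU launch with `torchrun` could look like:

```sh
# Hypothetical multi-node launch; copy the actual training arguments from scripts/train_bridge.sh.
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=$NODE_RANK \
    --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 \
    train.py $TRAIN_FLAGS
```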
90 | 91 | # Consistency Diffusion Bridge Models 92 | 93 | TODO 94 | 95 | # Acknowledgement 96 | 97 | This codebase is built upon [DDBM](https://github.com/alexzhou907/DDBM) and [I2SB](https://github.com/NVlabs/I2SB). 98 | 99 | 100 | # Citation 101 | 102 | If you find this method and/or code useful, please consider citing 103 | 104 | ``` 105 | @article{zheng2024diffusion, 106 | title={Diffusion Bridge Implicit Models}, 107 | author={Zheng, Kaiwen and He, Guande and Chen, Jianfei and Bao, Fan and Zhu, Jun}, 108 | journal={arXiv preprint arXiv:2405.15885}, 109 | year={2024} 110 | } 111 | ``` 112 | 113 | and 114 | 115 | ``` 116 | @article{he2024consistency, 117 | title={Consistency Diffusion Bridge Models}, 118 | author={He, Guande and Zheng, Kaiwen and Chen, Jianfei and Bao, Fan and Zhu, Jun}, 119 | journal={arXiv preprint arXiv:2410.22637}, 120 | year={2024} 121 | } 122 | ``` -------------------------------------------------------------------------------- /assets/datasets/download_extract_DIODE.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | 4 | mkdir DIODE && cd DIODE 5 | 6 | wget http://diode-dataset.s3.amazonaws.com/train.tar.gz 7 | wget http://diode-dataset.s3.amazonaws.com/train_normals.tar.gz 8 | wget https://diode-1254389886.cos.ap-hongkong.myqcloud.com/data_list.zip 9 | 10 | tar -xvzf train.tar.gz && rm -f train.tar.gz 11 | tar -xvzf train_normals.tar.gz && rm -f train_normals.tar.gz 12 | unzip data_list.zip && rm -f data_list.zip -------------------------------------------------------------------------------- /assets/datasets/download_extract_ImageNet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | 4 | wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar --no-check-certificate 5 | wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar --no-check-certificate 6 | 7 | # script to extract ImageNet dataset 8 | # ILSVRC2012_img_train.tar (about 138 GB) 9 | # ILSVRC2012_img_val.tar (about 6.3 GB) 10 | # make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar in your current directory 11 | # 12 | # Adapted from: 13 | # https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md 14 | # https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4 15 | # 16 | # ImageNet/train/ 17 | # ├── n01440764 18 | # │ ├── n01440764_10026.JPEG 19 | # │ ├── n01440764_10027.JPEG 20 | # │ ├── ...... 21 | # ├── ...... 22 | # ImageNet/val/ 23 | # ├── n01440764 24 | # │ ├── ILSVRC2012_val_00000293.JPEG 25 | # │ ├── ILSVRC2012_val_00002138.JPEG 26 | # │ ├── ...... 27 | # ├── ...... 28 | # 29 | # 30 | # Make imagnet directory 31 | # 32 | mkdir ImageNet 33 | # 34 | # Extract the training data: 35 | # 36 | # Create train directory; move .tar file; change directory 37 | mkdir ImageNet/train && mv ILSVRC2012_img_train.tar ImageNet/train/ && cd ImageNet/train 38 | # Extract training set; remove compressed file 39 | tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar 40 | # 41 | # At this stage ImageNet/train will contain 1000 compressed .tar files, one for each category 42 | # 43 | # For each .tar file: 44 | # 1. create directory with same name as .tar file 45 | # 2. extract and copy contents of .tar file into directory 46 | # 3. remove .tar file 47 | find . 
-name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done 48 | # 49 | # This results in a training directory like so: 50 | # 51 | # ImageNet/train/ 52 | # ├── n01440764 53 | # │ ├── n01440764_10026.JPEG 54 | # │ ├── n01440764_10027.JPEG 55 | # │ ├── ...... 56 | # ├── ...... 57 | # 58 | # Change back to original directory 59 | cd ../.. 60 | # 61 | # Extract the validation data and move images to subfolders: 62 | # 63 | # Create validation directory; move .tar file; change directory; extract validation .tar; remove compressed file 64 | mkdir ImageNet/val && mv ILSVRC2012_img_val.tar ImageNet/val/ && cd ImageNet/val && tar -xvf ILSVRC2012_img_val.tar && rm -f ILSVRC2012_img_val.tar 65 | # get script from soumith and run; this script creates all class directories and moves images into corresponding directories 66 | wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash 67 | # 68 | # This results in a validation directory like so: 69 | # 70 | # ImageNet/val/ 71 | # ├── n01440764 72 | # │ ├── ILSVRC2012_val_00000293.JPEG 73 | # │ ├── ILSVRC2012_val_00002138.JPEG 74 | # │ ├── ...... 75 | # ├── ...... 76 | # 77 | # 78 | # Check total files after extract 79 | # 80 | # $ find train/ -name "*.JPEG" | wc -l 81 | # 1281167 82 | # $ find val/ -name "*.JPEG" | wc -l 83 | # 50000 84 | # 85 | -------------------------------------------------------------------------------- /assets/datasets/download_extract_edges2handbags.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | 4 | wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/edges2handbags.tar.gz 5 | 6 | tar -xvzf edges2handbags.tar.gz && rm -f edges2handbags.tar.gz -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DiffusionBridge/92522733cc602686df77f07a1824bb89f89cda1a/assets/teaser.png -------------------------------------------------------------------------------- /corruption/LICENSE_DDRM: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bahjat Kawar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /corruption/LICENSE_DDRM_JPEG: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright [2022] [Bahjat Kawar, Jiaming Song, Stefano Ermon, Michael Elad] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /corruption/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | 9 | def build_corruption(opt, log, corrupt_type=None): 10 | 11 | if corrupt_type is None: 12 | corrupt_type = opt.corrupt 13 | 14 | if "inpaint" in corrupt_type: 15 | from .inpaint import build_inpaint_center, build_inpaint_freeform 16 | 17 | mask = corrupt_type.split("-")[1] 18 | assert mask in ["center", "freeform1020", "freeform2030"] 19 | if mask == "center": 20 | method = build_inpaint_center(opt, log, mask) 21 | elif "freeform" in mask: 22 | method = build_inpaint_freeform(opt, log, mask) 23 | 24 | elif "jpeg" in corrupt_type: 25 | from .jpeg import build_jpeg 26 | 27 | quality_factor = int(corrupt_type.split("-")[1]) 28 | method = build_jpeg(log, quality_factor) 29 | 30 | elif "sr4x" in corrupt_type: 31 | from .superresolution import build_sr4x 32 | 33 | sr_filter = corrupt_type.split("-")[1] 34 | assert sr_filter in ["pool", "bicubic"] 35 | method = build_sr4x(opt, log, sr_filter, image_size=opt.image_size) 36 | 37 | elif "blur" in corrupt_type: 38 | from .blur import build_blur 39 | 40 | kernel = corrupt_type.split("-")[1] 41 | assert kernel in ["uni", "gauss"] 42 | method = build_blur(opt, log, kernel) 43 | 44 | elif "mixture" in corrupt_type: 45 | method = None # 46 | else: 47 | raise RuntimeWarning(f"Unknown corruption: {corrupt_type}!") 48 | 49 | return method 50 | -------------------------------------------------------------------------------- /corruption/base.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ddrm. 
5 | # 6 | # Source: 7 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L3 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_DDRM). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | 15 | class H_functions: 16 | """ 17 | A class replacing the SVD of a matrix H, perhaps efficiently. 18 | All input vectors are of shape (Batch, ...). 19 | All output vectors are of shape (Batch, DataDimension). 20 | """ 21 | 22 | def V(self, vec): 23 | """ 24 | Multiplies the input vector by V 25 | """ 26 | raise NotImplementedError() 27 | 28 | def Vt(self, vec): 29 | """ 30 | Multiplies the input vector by V transposed 31 | """ 32 | raise NotImplementedError() 33 | 34 | def U(self, vec): 35 | """ 36 | Multiplies the input vector by U 37 | """ 38 | raise NotImplementedError() 39 | 40 | def Ut(self, vec): 41 | """ 42 | Multiplies the input vector by U transposed 43 | """ 44 | raise NotImplementedError() 45 | 46 | def singulars(self): 47 | """ 48 | Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U) 49 | """ 50 | raise NotImplementedError() 51 | 52 | def add_zeros(self, vec): 53 | """ 54 | Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V) 55 | """ 56 | raise NotImplementedError() 57 | 58 | def H(self, vec): 59 | """ 60 | Multiplies the input vector by H 61 | """ 62 | temp = self.Vt(vec) 63 | singulars = self.singulars() 64 | return self.U(singulars * temp[:, : singulars.shape[0]]) 65 | 66 | def Ht(self, vec): 67 | """ 68 | Multiplies the input vector by H transposed 69 | """ 70 | temp = self.Ut(vec) 71 | singulars = self.singulars() 72 | return self.V(self.add_zeros(singulars * temp[:, : singulars.shape[0]])) 73 | 74 | def H_pinv(self, vec): 75 | """ 76 | Multiplies the input vector by the pseudo inverse of H 77 | """ 78 | temp = self.Ut(vec) 79 | singulars = self.singulars() 80 | temp[:, : singulars.shape[0]] = temp[:, : singulars.shape[0]] / singulars 81 | return self.V(self.add_zeros(temp)) 82 | -------------------------------------------------------------------------------- /corruption/blur.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ddrm. 5 | # 6 | # Source: 7 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L397 8 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L245 9 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L251 10 | # 11 | # The license for the original version of this file can be 12 | # found in this directory (LICENSE_DDRM). 13 | # The modifications to this file are subject to the same license. 
14 | # --------------------------------------------------------------- 15 | 16 | import torch 17 | from .base import H_functions 18 | 19 | class Deblurring(H_functions): 20 | def mat_by_img(self, M, v): 21 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim, 22 | self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim) 23 | 24 | def img_by_mat(self, v, M): 25 | return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim, 26 | self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1]) 27 | 28 | def __init__(self, kernel, channels, img_dim, device, ZERO = 3e-2): 29 | self.img_dim = img_dim 30 | self.channels = channels 31 | #build 1D conv matrix 32 | H_small = torch.zeros(img_dim, img_dim, device=device) 33 | for i in range(img_dim): 34 | for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2): 35 | if j < 0 or j >= img_dim: continue 36 | H_small[i, j] = kernel[j - i + kernel.shape[0]//2] 37 | #get the svd of the 1D conv 38 | self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False) 39 | #ZERO = 3e-2 40 | self.singulars_small[self.singulars_small < ZERO] = 0 41 | #calculate the singular values of the big matrix 42 | self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2) 43 | #sort the big matrix singulars and save the permutation 44 | self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True) 45 | 46 | def V(self, vec): 47 | #invert the permutation 48 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 49 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 50 | temp = temp.permute(0, 2, 1) 51 | #multiply the image by V from the left and by V^T from the right 52 | out = self.mat_by_img(self.V_small, temp) 53 | out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1) 54 | return out 55 | 56 | def Vt(self, vec): 57 | #multiply the image by V^T from the left and by V from the right 58 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone()) 59 | temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1) 60 | #permute the entries according to the singular values 61 | temp = temp[:, :, self._perm].permute(0, 2, 1) 62 | return temp.reshape(vec.shape[0], -1) 63 | 64 | def U(self, vec): 65 | #invert the permutation 66 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 67 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 68 | temp = temp.permute(0, 2, 1) 69 | #multiply the image by U from the left and by U^T from the right 70 | out = self.mat_by_img(self.U_small, temp) 71 | out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1) 72 | return out 73 | 74 | def Ut(self, vec): 75 | #multiply the image by U^T from the left and by U from the right 76 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone()) 77 | temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1) 78 | #permute the entries according to the singular values 79 | temp = temp[:, :, self._perm].permute(0, 2, 1) 80 | return temp.reshape(vec.shape[0], -1) 81 | 82 | def singulars(self): 83 | return self._singulars.repeat(1, 3).reshape(-1) 84 | 85 | def add_zeros(self, vec): 86 | return vec.clone().reshape(vec.shape[0], -1) 87 | 88 | def build_blur(opt, log, 
kernel): 89 | log.info(f"[Corrupt] Bluring {kernel=}...") 90 | 91 | uni = Deblurring(torch.Tensor([1/9] * 9).to(opt.device), 3, opt.image_size, opt.device) 92 | 93 | sigma = 10 94 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x/sigma)**2])) 95 | g_kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(opt.device) 96 | gauss = Deblurring(g_kernel / g_kernel.sum(), 3, opt.image_size, opt.device) 97 | 98 | xdim = (3, opt.image_size, opt.image_size) 99 | 100 | assert kernel in ["uni", "gauss"] 101 | def blur(img): 102 | # img: [-1,1] -> [0,1] 103 | img = (img + 1) / 2 104 | if kernel == "uni": 105 | img = uni.H(img).reshape(img.shape[0], *xdim) 106 | elif kernel == "gauss": 107 | img = gauss.H(img).reshape(img.shape[0], *xdim) 108 | # [0,1] -> [-1,1] 109 | return img * 2 - 1 110 | 111 | return blur 112 | -------------------------------------------------------------------------------- /corruption/inpaint.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from diffusion_palette_eval. 5 | # 6 | # Source: 7 | # https://bit.ly/eval-pix2pix 8 | # 9 | # --------------------------------------------------------------- 10 | 11 | import io 12 | 13 | import os 14 | 15 | import numpy as np 16 | import torch 17 | 18 | from pathlib import Path 19 | 20 | # import gdown 21 | # from ipdb import set_trace as debug 22 | 23 | FREEFORM_URL = "https://drive.google.com/file/d/1-5YRGsekjiRKQWqo0BV5RVQu0bagc12w/view?usp=share_link" 24 | 25 | 26 | # code adoptted from 27 | # https://bit.ly/eval- pix2pix 28 | def bbox2mask(img_shape, bbox, dtype="uint8"): 29 | """Generate mask in ndarray from bbox. 30 | 31 | The returned mask has the shape of (h, w, 1). '1' indicates the 32 | hole and '0' indicates the valid regions. 33 | 34 | We prefer to use `uint8` as the data type of masks, which may be different 35 | from other codes in the community. 36 | 37 | Args: 38 | img_shape (tuple[int]): The size of the image. 39 | bbox (tuple[int]): Configuration tuple, (top, left, height, width) 40 | dtype (str): Indicate the data type of returned masks. Default: 'uint8' 41 | 42 | Return: 43 | numpy.ndarray: Mask in the shape of (h, w, 1). 44 | """ 45 | 46 | height, width = img_shape[:2] 47 | 48 | mask = np.zeros((height, width, 1), dtype=dtype) 49 | mask[bbox[0] : bbox[0] + bbox[2], bbox[1] : bbox[1] + bbox[3], :] = 1 50 | 51 | return mask 52 | 53 | 54 | # code adoptted from 55 | # https://bit.ly/eval-pix2pix 56 | def load_masks(filename): 57 | # filename = "imagenet_freeform_masks.npz" 58 | shape = [10000, 256, 256] 59 | 60 | # shape = [10950, 256, 256] # Uncomment this for places2. 61 | 62 | # Load the npz file. 63 | with open(filename, "rb") as f: 64 | data = f.read() 65 | 66 | data = dict(np.load(io.BytesIO(data))) 67 | # print("Categories of masks:") 68 | # for key in data: 69 | # print(key) 70 | 71 | # Unpack and reshape the masks. 72 | for key in data: 73 | data[key] = np.unpackbits(data[key], axis=None)[: np.prod(shape)].reshape(shape).astype(np.uint8) 74 | 75 | # data[key] contains [10000, 256, 256] array i.e. 10000 256x256 masks. 
76 | return data 77 | 78 | 79 | def load_freeform_masks(op_type): 80 | data_dir = Path("assets/datasets/ImageNet") 81 | 82 | mask_fn = data_dir / f"imagenet_{op_type}_masks.npz" 83 | if not mask_fn.exists(): 84 | # download orignal npz from palette google drive 85 | orig_mask_fn = str(data_dir / "imagenet_freeform_masks.npz") 86 | if not os.path.exists(orig_mask_fn): 87 | assert False 88 | # gdown.download(url=FREEFORM_URL, output=orig_mask_fn, quiet=False, fuzzy=True) 89 | masks = load_masks(orig_mask_fn) 90 | 91 | # store freeform of current ratio for faster loading in future 92 | key = { 93 | "freeform1020": "10-20% freeform", 94 | "freeform2030": "20-30% freeform", 95 | "freeform3040": "30-40% freeform", 96 | }.get(op_type) 97 | np.savez(mask_fn, mask=masks[key]) 98 | 99 | # [10000, 256, 256] --> [10000, 1, 256, 256] 100 | return np.load(mask_fn)["mask"][:, None] 101 | 102 | 103 | def get_center_mask(image_size): 104 | h, w = image_size 105 | mask = bbox2mask(image_size, (h // 4, w // 4, h // 2, w // 2)) 106 | return torch.from_numpy(mask).permute(2, 0, 1) 107 | 108 | 109 | def build_inpaint_center(mask_type, image_size): 110 | assert mask_type == "center" 111 | 112 | # log.info(f"[Corrupt] Inpaint: {mask_type=} ...") 113 | 114 | center_mask = get_center_mask([image_size, image_size]) # [1,1,256,256] 115 | # center_mask = center_mask.to(opt.device) 116 | 117 | def inpaint_center(img): 118 | # img: [-1,1] 119 | mask = center_mask 120 | # img[mask==0] = img[mask==0], img[mask==1] = 1 (white) 121 | return img * (1.0 - mask) + mask * 0, mask 122 | 123 | return inpaint_center 124 | 125 | 126 | def build_inpaint_freeform(mask_type): 127 | assert "freeform" in mask_type 128 | 129 | # log.info(f"[Corrupt] Inpaint: {mask_type=} ...") 130 | 131 | freeform_masks = load_freeform_masks(mask_type) # [10000, 1, 256, 256] 132 | n_freeform_masks = freeform_masks.shape[0] 133 | # freeform_masks = torch.from_numpy(freeform_masks).to(opt.device) 134 | freeform_masks = torch.from_numpy(freeform_masks) 135 | 136 | def inpaint_freeform(img): 137 | # img: [-1,1] 138 | # index = np.random.randint(n_freeform_masks, size=img.shape[0]) 139 | index = np.random.randint(n_freeform_masks) 140 | mask = freeform_masks[index] 141 | # img[mask==0] = img[mask==0], img[mask==1] = 1 (white) 142 | return img * (1.0 - mask) + mask * 0, mask 143 | 144 | return inpaint_freeform 145 | -------------------------------------------------------------------------------- /corruption/jpeg.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ddrm-jpeg. 5 | # 6 | # Source: 7 | # https://github.com/bahjat-kawar/ddrm-jpeg/blob/master/functions/jpeg_torch.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_DDRM_JPEG). 11 | # The modifications to this file are subject to the same license. 
12 | # --------------------------------------------------------------- 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | 18 | def dct1(x): 19 | """ 20 | Discrete Cosine Transform, Type I 21 | :param x: the input signal 22 | :return: the DCT-I of the signal over the last dimension 23 | """ 24 | x_shape = x.shape 25 | x = x.view(-1, x_shape[-1]) 26 | 27 | return torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1))[:, :, 0].view(*x_shape) 28 | 29 | 30 | def idct1(X): 31 | """ 32 | The inverse of DCT-I, which is just a scaled DCT-I 33 | Our definition if idct1 is such that idct1(dct1(x)) == x 34 | :param X: the input signal 35 | :return: the inverse DCT-I of the signal over the last dimension 36 | """ 37 | n = X.shape[-1] 38 | return dct1(X) / (2 * (n - 1)) 39 | 40 | 41 | def dct(x, norm=None): 42 | """ 43 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 44 | For the meaning of the parameter `norm`, see: 45 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 46 | :param x: the input signal 47 | :param norm: the normalization, None or 'ortho' 48 | :return: the DCT-II of the signal over the last dimension 49 | """ 50 | x_shape = x.shape 51 | N = x_shape[-1] 52 | x = x.contiguous().view(-1, N) 53 | 54 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 55 | 56 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 57 | 58 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 59 | W_r = torch.cos(k) 60 | W_i = torch.sin(k) 61 | 62 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 63 | 64 | if norm == 'ortho': 65 | V[:, 0] /= np.sqrt(N) * 2 66 | V[:, 1:] /= np.sqrt(N / 2) * 2 67 | 68 | V = 2 * V.view(*x_shape) 69 | 70 | return V 71 | 72 | 73 | def idct(X, norm=None): 74 | """ 75 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 76 | Our definition of idct is that idct(dct(x)) == x 77 | For the meaning of the parameter `norm`, see: 78 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 79 | :param X: the input signal 80 | :param norm: the normalization, None or 'ortho' 81 | :return: the inverse DCT-II of the signal over the last dimension 82 | """ 83 | 84 | x_shape = X.shape 85 | N = x_shape[-1] 86 | 87 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 88 | 89 | if norm == 'ortho': 90 | X_v[:, 0] *= np.sqrt(N) * 2 91 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 92 | 93 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 94 | W_r = torch.cos(k) 95 | W_i = torch.sin(k) 96 | 97 | V_t_r = X_v 98 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 99 | 100 | V_r = V_t_r * W_r - V_t_i * W_i 101 | V_i = V_t_r * W_i + V_t_i * W_r 102 | 103 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 104 | 105 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 106 | x = v.new_zeros(v.shape) 107 | x[:, ::2] += v[:, :N - (N // 2)] 108 | x[:, 1::2] += v.flip([1])[:, :N // 2] 109 | 110 | return x.view(*x_shape) 111 | 112 | 113 | def dct_2d(x, norm=None): 114 | """ 115 | 2-dimentional Discrete Cosine Transform, Type II (a.k.a. 
the DCT) 116 | For the meaning of the parameter `norm`, see: 117 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 118 | :param x: the input signal 119 | :param norm: the normalization, None or 'ortho' 120 | :return: the DCT-II of the signal over the last 2 dimensions 121 | """ 122 | X1 = dct(x, norm=norm) 123 | X2 = dct(X1.transpose(-1, -2), norm=norm) 124 | return X2.transpose(-1, -2) 125 | 126 | 127 | def idct_2d(X, norm=None): 128 | """ 129 | The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III 130 | Our definition of idct is that idct_2d(dct_2d(x)) == x 131 | For the meaning of the parameter `norm`, see: 132 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 133 | :param X: the input signal 134 | :param norm: the normalization, None or 'ortho' 135 | :return: the DCT-II of the signal over the last 2 dimensions 136 | """ 137 | x1 = idct(X, norm=norm) 138 | x2 = idct(x1.transpose(-1, -2), norm=norm) 139 | return x2.transpose(-1, -2) 140 | 141 | 142 | def dct_3d(x, norm=None): 143 | """ 144 | 3-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT) 145 | For the meaning of the parameter `norm`, see: 146 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 147 | :param x: the input signal 148 | :param norm: the normalization, None or 'ortho' 149 | :return: the DCT-II of the signal over the last 3 dimensions 150 | """ 151 | X1 = dct(x, norm=norm) 152 | X2 = dct(X1.transpose(-1, -2), norm=norm) 153 | X3 = dct(X2.transpose(-1, -3), norm=norm) 154 | return X3.transpose(-1, -3).transpose(-1, -2) 155 | 156 | 157 | def idct_3d(X, norm=None): 158 | """ 159 | The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III 160 | Our definition of idct is that idct_3d(dct_3d(x)) == x 161 | For the meaning of the parameter `norm`, see: 162 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 163 | :param X: the input signal 164 | :param norm: the normalization, None or 'ortho' 165 | :return: the DCT-II of the signal over the last 3 dimensions 166 | """ 167 | x1 = idct(X, norm=norm) 168 | x2 = idct(x1.transpose(-1, -2), norm=norm) 169 | x3 = idct(x2.transpose(-1, -3), norm=norm) 170 | return x3.transpose(-1, -3).transpose(-1, -2) 171 | 172 | 173 | class LinearDCT(nn.Linear): 174 | """Implement any DCT as a linear layer; in practice this executes around 175 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will 176 | increase memory usage. 177 | :param in_features: size of expected input 178 | :param type: which dct function in this file to use""" 179 | def __init__(self, in_features, type, norm=None, bias=False): 180 | self.type = type 181 | self.N = in_features 182 | self.norm = norm 183 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias) 184 | 185 | def reset_parameters(self): 186 | # initialise using dct function 187 | I = torch.eye(self.N) 188 | if self.type == 'dct1': 189 | self.weight.data = dct1(I).data.t() 190 | elif self.type == 'idct1': 191 | self.weight.data = idct1(I).data.t() 192 | elif self.type == 'dct': 193 | self.weight.data = dct(I, norm=self.norm).data.t() 194 | elif self.type == 'idct': 195 | self.weight.data = idct(I, norm=self.norm).data.t() 196 | self.weight.requires_grad = False # don't learn this! 197 | 198 | 199 | def apply_linear_2d(x, linear_layer): 200 | """Can be used with a LinearDCT layer to do a 2D DCT. 
201 | :param x: the input signal 202 | :param linear_layer: any PyTorch Linear layer 203 | :return: result of linear layer applied to last 2 dimensions 204 | """ 205 | X1 = linear_layer(x) 206 | X2 = linear_layer(X1.transpose(-1, -2)) 207 | return X2.transpose(-1, -2) 208 | 209 | 210 | def apply_linear_3d(x, linear_layer): 211 | """Can be used with a LinearDCT layer to do a 3D DCT. 212 | :param x: the input signal 213 | :param linear_layer: any PyTorch Linear layer 214 | :return: result of linear layer applied to last 3 dimensions 215 | """ 216 | X1 = linear_layer(x) 217 | X2 = linear_layer(X1.transpose(-1, -2)) 218 | X3 = linear_layer(X2.transpose(-1, -3)) 219 | return X3.transpose(-1, -3).transpose(-1, -2) 220 | 221 | 222 | def torch_rgb2ycbcr(x): 223 | # Assume x is a batch of size (N x C x H x W) 224 | v = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).to(x.device) 225 | ycbcr = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1) 226 | ycbcr[:,1:] += 128 227 | return ycbcr 228 | 229 | 230 | def torch_ycbcr2rgb(x): 231 | # Assume x is a batch of size (N x C x H x W) 232 | v = torch.tensor([[ 1.00000000e+00, -3.68199903e-05, 1.40198758e+00], 233 | [ 1.00000000e+00, -3.44113281e-01, -7.14103821e-01], 234 | [ 1.00000000e+00, 1.77197812e+00, -1.34583413e-04]]).to(x.device) 235 | x[:, 1:] -= 128 236 | rgb = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1) 237 | return rgb 238 | 239 | def chroma_subsample(x): 240 | return x[:, 0:1, :, :], x[:, 1:, ::2, ::2] 241 | 242 | 243 | def general_quant_matrix(qf = 10): 244 | q1 = torch.tensor([ 245 | 16, 11, 10, 16, 24, 40, 51, 61, 246 | 12, 12, 14, 19, 26, 58, 60, 55, 247 | 14, 13, 16, 24, 40, 57, 69, 56, 248 | 14, 17, 22, 29, 51, 87, 80, 62, 249 | 18, 22, 37, 56, 68, 109, 103, 77, 250 | 24, 35, 55, 64, 81, 104, 113, 92, 251 | 49, 64, 78, 87, 103, 121, 120, 101, 252 | 72, 92, 95, 98, 112, 100, 103, 99 253 | ]) 254 | q2 = torch.tensor([ 255 | 17, 18, 24, 47, 99, 99, 99, 99, 256 | 18, 21, 26, 66, 99, 99, 99, 99, 257 | 24, 26, 56, 99, 99, 99, 99, 99, 258 | 47, 66, 99, 99, 99, 99, 99, 99, 259 | 99, 99, 99, 99, 99, 99, 99, 99, 260 | 99, 99, 99, 99, 99, 99, 99, 99, 261 | 99, 99, 99, 99, 99, 99, 99, 99, 262 | 99, 99, 99, 99, 99, 99, 99, 99 263 | ]) 264 | s = (5000 / qf) if qf < 50 else (200 - 2 * qf) 265 | q1 = torch.floor((s * q1 + 50) / 100) 266 | q1[q1 <= 0] = 1 267 | q1[q1 > 255] = 255 268 | q2 = torch.floor((s * q2 + 50) / 100) 269 | q2[q2 <= 0] = 1 270 | q2[q2 > 255] = 255 271 | return q1, q2 272 | 273 | 274 | def quantization_matrix(qf): 275 | return general_quant_matrix(qf) 276 | # q1 = torch.tensor([[ 80, 55, 50, 80, 120, 200, 255, 255], 277 | # [ 60, 60, 70, 95, 130, 255, 255, 255], 278 | # [ 70, 65, 80, 120, 200, 255, 255, 255], 279 | # [ 70, 85, 110, 145, 255, 255, 255, 255], 280 | # [ 90, 110, 185, 255, 255, 255, 255, 255], 281 | # [120, 175, 255, 255, 255, 255, 255, 255], 282 | # [245, 255, 255, 255, 255, 255, 255, 255], 283 | # [255, 255, 255, 255, 255, 255, 255, 255]]) 284 | # q2 = torch.tensor([[ 85, 90, 120, 235, 255, 255, 255, 255], 285 | # [ 90, 105, 130, 255, 255, 255, 255, 255], 286 | # [120, 130, 255, 255, 255, 255, 255, 255], 287 | # [235, 255, 255, 255, 255, 255, 255, 255], 288 | # [255, 255, 255, 255, 255, 255, 255, 255], 289 | # [255, 255, 255, 255, 255, 255, 255, 255], 290 | # [255, 255, 255, 255, 255, 255, 255, 255], 291 | # [255, 255, 255, 255, 255, 255, 255, 255]]) 292 | # return q1, q2 293 | 294 | def jpeg_encode(x, qf): 295 | # Assume x is a batch of 
size (N x C x H x W) 296 | # [-1, 1] to [0, 255] 297 | x = (x + 1) / 2 * 255 298 | n_batch, _, n_size, _ = x.shape 299 | 300 | x = torch_rgb2ycbcr(x) 301 | x_luma, x_chroma = chroma_subsample(x) 302 | unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8)) 303 | x_luma = unfold(x_luma).transpose(2, 1) 304 | x_chroma = unfold(x_chroma).transpose(2, 1) 305 | 306 | x_luma = x_luma.reshape(-1, 8, 8) - 128 307 | x_chroma = x_chroma.reshape(-1, 8, 8) - 128 308 | 309 | dct_layer = LinearDCT(8, 'dct', norm='ortho') 310 | dct_layer.to(x_luma.device) 311 | x_luma = apply_linear_2d(x_luma, dct_layer) 312 | x_chroma = apply_linear_2d(x_chroma, dct_layer) 313 | 314 | x_luma = x_luma.view(-1, 1, 8, 8) 315 | x_chroma = x_chroma.view(-1, 2, 8, 8) 316 | 317 | q1, q2 = quantization_matrix(qf) 318 | q1 = q1.to(x_luma.device) 319 | q2 = q2.to(x_luma.device) 320 | x_luma /= q1.view(1, 8, 8) 321 | x_chroma /= q2.view(1, 8, 8) 322 | 323 | x_luma = x_luma.round() 324 | x_chroma = x_chroma.round() 325 | 326 | x_luma = x_luma.reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1) 327 | x_chroma = x_chroma.reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1) 328 | 329 | fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8)) 330 | x_luma = fold(x_luma) 331 | fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8)) 332 | x_chroma = fold(x_chroma) 333 | 334 | return [x_luma, x_chroma] 335 | 336 | 337 | 338 | def jpeg_decode(x, qf): 339 | # Assume x[0] is a batch of size (N x 1 x H x W) (luma) 340 | # Assume x[1:] is a batch of size (N x 2 x H/2 x W/2) (chroma) 341 | x_luma, x_chroma = x 342 | n_batch, _, n_size, _ = x_luma.shape 343 | unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8)) 344 | x_luma = unfold(x_luma).transpose(2, 1) 345 | x_luma = x_luma.reshape(-1, 1, 8, 8) 346 | x_chroma = unfold(x_chroma).transpose(2, 1) 347 | x_chroma = x_chroma.reshape(-1, 2, 8, 8) 348 | 349 | q1, q2 = quantization_matrix(qf) 350 | q1 = q1.to(x_luma.device) 351 | q2 = q2.to(x_luma.device) 352 | x_luma *= q1.view(1, 8, 8) 353 | x_chroma *= q2.view(1, 8, 8) 354 | 355 | x_luma = x_luma.reshape(-1, 8, 8) 356 | x_chroma = x_chroma.reshape(-1, 8, 8) 357 | 358 | dct_layer = LinearDCT(8, 'idct', norm='ortho') 359 | dct_layer.to(x_luma.device) 360 | x_luma = apply_linear_2d(x_luma, dct_layer) 361 | x_chroma = apply_linear_2d(x_chroma, dct_layer) 362 | 363 | x_luma = (x_luma + 128).reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1) 364 | x_chroma = (x_chroma + 128).reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1) 365 | 366 | fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8)) 367 | x_luma = fold(x_luma) 368 | fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8)) 369 | x_chroma = fold(x_chroma) 370 | 371 | x_chroma_repeated = torch.zeros(n_batch, 2, n_size, n_size, device = x_luma.device) 372 | x_chroma_repeated[:, :, 0::2, 0::2] = x_chroma 373 | x_chroma_repeated[:, :, 0::2, 1::2] = x_chroma 374 | x_chroma_repeated[:, :, 1::2, 0::2] = x_chroma 375 | x_chroma_repeated[:, :, 1::2, 1::2] = x_chroma 376 | 377 | x = torch.cat([x_luma, x_chroma_repeated], dim=1) 378 | 379 | x = torch_ycbcr2rgb(x) 380 | 381 | # [0, 255] to [-1, 1] 382 | x = x / 255 * 2 - 1 383 | 384 | return x 385 | 386 | 387 | def build_jpeg(log, qf): 388 | log.info(f"[Corrupt] JPEG restoration: {qf=} ...") 389 | def jpeg(img): 390 | return jpeg_decode(jpeg_encode(img, qf), qf) 391 | return jpeg 392 | 
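For reference, a minimal usage sketch of the JPEG corruption defined above (this snippet is not part of the original file; `SimpleLogger` is a hypothetical stand-in for whatever logger object the caller passes to `build_jpeg`):

```python
import torch

from corruption.jpeg import build_jpeg


class SimpleLogger:
    """Stand-in logger exposing the .info() method that build_jpeg expects."""

    def info(self, msg):
        print(msg)


# Degrade a random batch in [-1, 1] with quality factor 10.
jpeg = build_jpeg(SimpleLogger(), qf=10)
x = torch.rand(4, 3, 256, 256) * 2 - 1
y = jpeg(x)       # round-trips jpeg_encode -> jpeg_decode
print(y.shape)    # torch.Size([4, 3, 256, 256])
```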
-------------------------------------------------------------------------------- /corruption/mixture.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import numpy as np 9 | import enum 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | 14 | from .jpeg import jpeg_decode, jpeg_encode 15 | from .blur import Deblurring 16 | from .superresolution import build_sr_bicubic, build_sr_pool 17 | from .inpaint import get_center_mask, load_freeform_masks 18 | 19 | from ipdb import set_trace as debug 20 | 21 | 22 | class AllCorrupt(enum.IntEnum): 23 | JPEG_5 = 0 24 | JPEG_10 = 1 25 | BLUR_UNI = 2 26 | BLUR_GAUSS = 3 27 | SR4X_POOL = 4 28 | SR4X_BICUBIC = 5 29 | INPAINT_CENTER = 6 30 | INPAINT_FREE1020 = 7 31 | INPAINT_FREE2030 = 8 32 | 33 | 34 | class MixtureCorruptMethod: 35 | def __init__(self, opt, device="cpu"): 36 | 37 | # ===== blur ==== 38 | self.blur_uni = Deblurring(torch.Tensor([1 / 9] * 9).to(device), 3, opt.image_size, device) 39 | 40 | sigma = 10 41 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2])) 42 | g_kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(device) 43 | self.blur_gauss = Deblurring(g_kernel / g_kernel.sum(), 3, opt.image_size, device) 44 | 45 | # ===== sr4x ==== 46 | factor = 4 47 | self.sr4x_pool = build_sr_pool(factor, device, opt.image_size) 48 | self.sr4x_bicubic = build_sr_bicubic(factor, device, opt.image_size) 49 | self.upsample = torch.nn.Upsample(scale_factor=factor, mode="nearest") 50 | 51 | # ===== inpaint ==== 52 | self.center_mask = get_center_mask([opt.image_size, opt.image_size])[None, ...] 
# [1, 1, 256, 256] 53 | self.free1020_masks = torch.from_numpy((load_freeform_masks("freeform1020"))) # [10000, 1, 256, 256] 54 | self.free2030_masks = torch.from_numpy((load_freeform_masks("freeform2030"))) # [10000, 1, 256, 256] 55 | 56 | def jpeg(self, img, qf): 57 | return jpeg_decode(jpeg_encode(img, qf), qf) 58 | 59 | def blur(self, img, kernel): 60 | img = (img + 1) / 2 61 | if kernel == "uni": 62 | _img = self.blur_uni.H(img).reshape(*img.shape) 63 | elif kernel == "gauss": 64 | _img = self.blur_gauss.H(img).reshape(*img.shape) 65 | # [0,1] -> [-1,1] 66 | return _img * 2 - 1 67 | 68 | def sr4x(self, img, filter): 69 | b, c, w, h = img.shape 70 | if filter == "pool": 71 | _img = self.sr4x_pool.H(img).reshape(b, c, w // 4, h // 4) 72 | elif filter == "bicubic": 73 | _img = self.sr4x_bicubic.H(img).reshape(b, c, w // 4, h // 4) 74 | 75 | # scale to original image size for I2SB 76 | return self.upsample(_img) 77 | 78 | def inpaint(self, img, mask_type, mask_index=None): 79 | if mask_type == "center": 80 | mask = self.center_mask 81 | elif mask_type == "free1020": 82 | if mask_index is None: 83 | mask_index = np.random.randint(len(self.free1020_masks)) 84 | mask = self.free1020_masks[[mask_index]] 85 | elif mask_type == "free2030": 86 | if mask_index is None: 87 | mask_index = np.random.randint(len(self.free2030_masks)) 88 | mask = self.free2030_masks[[mask_index]] 89 | return img * (1.0 - mask) + mask * torch.randn_like(img) 90 | 91 | def mixture(self, img, corrupt_idx, mask_index=None): 92 | if corrupt_idx == AllCorrupt.JPEG_5: 93 | corrupt_img = self.jpeg(img, 5) 94 | elif corrupt_idx == AllCorrupt.JPEG_10: 95 | corrupt_img = self.jpeg(img, 10) 96 | elif corrupt_idx == AllCorrupt.BLUR_UNI: 97 | corrupt_img = self.blur(img, "uni") 98 | elif corrupt_idx == AllCorrupt.BLUR_GAUSS: 99 | corrupt_img = self.blur(img, "gauss") 100 | elif corrupt_idx == AllCorrupt.SR4X_POOL: 101 | corrupt_img = self.sr4x(img, "pool") 102 | elif corrupt_idx == AllCorrupt.SR4X_BICUBIC: 103 | corrupt_img = self.sr4x(img, "bicubic") 104 | elif corrupt_idx == AllCorrupt.INPAINT_CENTER: 105 | corrupt_img = self.inpaint(img, "center") 106 | elif corrupt_idx == AllCorrupt.INPAINT_FREE1020: 107 | corrupt_img = self.inpaint(img, "free1020", mask_index=mask_index) 108 | elif corrupt_idx == AllCorrupt.INPAINT_FREE2030: 109 | corrupt_img = self.inpaint(img, "free2030", mask_index=mask_index) 110 | return corrupt_img 111 | 112 | 113 | class MixtureCorruptDatasetTrain(Dataset): 114 | def __init__(self, opt, dataset): 115 | super(MixtureCorruptDatasetTrain, self).__init__() 116 | self.dataset = dataset 117 | self.method = MixtureCorruptMethod(opt) 118 | 119 | def __len__(self): 120 | return self.dataset.__len__() 121 | 122 | def __getitem__(self, index): 123 | clean_img, y = self.dataset[index] # clean_img: tensor [-1,1] 124 | 125 | rand_idx = np.random.choice(AllCorrupt) 126 | corrupt_img = self.method.mixture(clean_img.unsqueeze(0), rand_idx).squeeze(0) 127 | 128 | assert corrupt_img.shape == clean_img.shape, (clean_img.shape, corrupt_img.shape) 129 | return clean_img, corrupt_img, y 130 | 131 | 132 | class MixtureCorruptDatasetVal(Dataset): 133 | def __init__(self, opt, dataset): 134 | super(MixtureCorruptDatasetVal, self).__init__() 135 | self.dataset = dataset 136 | self.method = MixtureCorruptMethod(opt) 137 | 138 | def __len__(self): 139 | return self.dataset.__len__() 140 | 141 | def __getitem__(self, index): 142 | clean_img, y = self.dataset[index] # clean_img: tensor [-1,1] 143 | 144 | idx = index % len(AllCorrupt) 
145 | corrupt_img = self.method.mixture(clean_img.unsqueeze(0), idx, mask_index=idx).squeeze(0) 146 | 147 | assert corrupt_img.shape == clean_img.shape, (clean_img.shape, corrupt_img.shape) 148 | return clean_img, corrupt_img, y 149 | -------------------------------------------------------------------------------- /corruption/superresolution.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ddrm. 5 | # 6 | # Source: 7 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L171 8 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L264 9 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L314 10 | # 11 | # The license for the original version of this file can be 12 | # found in this directory (LICENSE_DDRM). 13 | # The modifications to this file are subject to the same license. 14 | # --------------------------------------------------------------- 15 | 16 | import numpy as np 17 | import torch 18 | from .base import H_functions 19 | 20 | from ipdb import set_trace as debug 21 | 22 | 23 | class SuperResolution(H_functions): 24 | def __init__(self, channels, img_dim, ratio, device): # ratio = 2 or 4 25 | assert img_dim % ratio == 0 26 | self.img_dim = img_dim 27 | self.channels = channels 28 | self.y_dim = img_dim // ratio 29 | self.ratio = ratio 30 | H = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device) 31 | self.U_small, self.singulars_small, self.V_small = torch.svd(H, some=False) 32 | self.Vt_small = self.V_small.transpose(0, 1) 33 | 34 | def V(self, vec): 35 | # reorder the vector back into patches (because singulars are ordered descendingly) 36 | temp = vec.clone().reshape(vec.shape[0], -1) 37 | patches = torch.zeros(vec.shape[0], self.channels, self.y_dim**2, self.ratio**2, device=vec.device) 38 | patches[:, :, :, 0] = temp[:, : self.channels * self.y_dim**2].view(vec.shape[0], self.channels, -1) 39 | for idx in range(self.ratio**2 - 1): 40 | patches[:, :, :, idx + 1] = temp[:, (self.channels * self.y_dim**2 + idx) :: self.ratio**2 - 1].view( 41 | vec.shape[0], self.channels, -1 42 | ) 43 | # multiply each patch by the small V 44 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape( 45 | vec.shape[0], self.channels, -1, self.ratio**2 46 | ) 47 | # repatch the patches into an image 48 | patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 49 | recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous() 50 | recon = recon.reshape(vec.shape[0], self.channels * self.img_dim**2) 51 | return recon 52 | 53 | def Vt(self, vec): 54 | # extract flattened patches 55 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 56 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 57 | unfold_shape = patches.shape 58 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2) 59 | # multiply each by the small V transposed 60 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape( 61 | vec.shape[0], self.channels, -1, self.ratio**2 62 | ) 63 | # reorder the vector to have the first entry first (because singulars are ordered descendingly) 64 | recon = torch.zeros(vec.shape[0], self.channels * 
self.img_dim**2, device=vec.device) 65 | recon[:, : self.channels * self.y_dim**2] = patches[:, :, :, 0].view( 66 | vec.shape[0], self.channels * self.y_dim**2 67 | ) 68 | for idx in range(self.ratio**2 - 1): 69 | recon[:, (self.channels * self.y_dim**2 + idx) :: self.ratio**2 - 1] = patches[:, :, :, idx + 1].view( 70 | vec.shape[0], self.channels * self.y_dim**2 71 | ) 72 | return recon 73 | 74 | def U(self, vec): 75 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 76 | 77 | def Ut(self, vec): # U is 1x1, so U^T = U 78 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 79 | 80 | def singulars(self): 81 | return self.singulars_small.repeat(self.channels * self.y_dim**2) 82 | 83 | def add_zeros(self, vec): 84 | reshaped = vec.clone().reshape(vec.shape[0], -1) 85 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device) 86 | temp[:, : reshaped.shape[1]] = reshaped 87 | return temp 88 | 89 | 90 | class SRConv(H_functions): 91 | def mat_by_img(self, M, v, dim): 92 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, dim, dim)).reshape( 93 | v.shape[0], self.channels, M.shape[0], dim 94 | ) 95 | 96 | def img_by_mat(self, v, M, dim): 97 | return torch.matmul(v.reshape(v.shape[0] * self.channels, dim, dim), M).reshape( 98 | v.shape[0], self.channels, dim, M.shape[1] 99 | ) 100 | 101 | def __init__(self, kernel, channels, img_dim, device, stride=1): 102 | self.img_dim = img_dim 103 | self.channels = channels 104 | self.ratio = stride 105 | small_dim = img_dim // stride 106 | self.small_dim = small_dim 107 | # build 1D conv matrix 108 | H_small = torch.zeros(small_dim, img_dim, device=device) 109 | for i in range(stride // 2, img_dim + stride // 2, stride): 110 | for j in range(i - kernel.shape[0] // 2, i + kernel.shape[0] // 2): 111 | j_effective = j 112 | # reflective padding 113 | if j_effective < 0: 114 | j_effective = -j_effective - 1 115 | if j_effective >= img_dim: 116 | j_effective = (img_dim - 1) - (j_effective - img_dim) 117 | # matrix building 118 | H_small[i // stride, j_effective] += kernel[j - i + kernel.shape[0] // 2] 119 | # get the svd of the 1D conv 120 | self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False) 121 | ZERO = 3e-2 122 | self.singulars_small[self.singulars_small < ZERO] = 0 123 | # calculate the singular values of the big matrix 124 | self._singulars = torch.matmul( 125 | self.singulars_small.reshape(small_dim, 1), self.singulars_small.reshape(1, small_dim) 126 | ).reshape(small_dim**2) 127 | # permutation for matching the singular values. See P_1 in Appendix D.5. 
128 | self._perm = ( 129 | torch.Tensor( 130 | [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)] 131 | + [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim, self.img_dim)] 132 | ) 133 | .to(device) 134 | .long() 135 | ) 136 | 137 | def V(self, vec): 138 | # invert the permutation 139 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 140 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[ 141 | :, : self._perm.shape[0], : 142 | ] 143 | temp[:, self._perm.shape[0] :, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[ 144 | :, self._perm.shape[0] :, : 145 | ] 146 | temp = temp.permute(0, 2, 1) 147 | # multiply the image by V from the left and by V^T from the right 148 | out = self.mat_by_img(self.V_small, temp, self.img_dim) 149 | out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1) 150 | return out 151 | 152 | def Vt(self, vec): 153 | # multiply the image by V^T from the left and by V from the right 154 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim) 155 | temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1) 156 | # permute the entries 157 | temp[:, :, : self._perm.shape[0]] = temp[:, :, self._perm] 158 | temp = temp.permute(0, 2, 1) 159 | return temp.reshape(vec.shape[0], -1) 160 | 161 | def U(self, vec): 162 | # invert the permutation 163 | temp = torch.zeros(vec.shape[0], self.small_dim**2, self.channels, device=vec.device) 164 | temp[:, : self.small_dim**2, :] = vec.clone().reshape(vec.shape[0], self.small_dim**2, self.channels) 165 | temp = temp.permute(0, 2, 1) 166 | # multiply the image by U from the left and by U^T from the right 167 | out = self.mat_by_img(self.U_small, temp, self.small_dim) 168 | out = self.img_by_mat(out, self.U_small.transpose(0, 1), self.small_dim).reshape(vec.shape[0], -1) 169 | return out 170 | 171 | def Ut(self, vec): 172 | # multiply the image by U^T from the left and by U from the right 173 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone(), self.small_dim) 174 | temp = self.img_by_mat(temp, self.U_small, self.small_dim).reshape(vec.shape[0], self.channels, -1) 175 | # permute the entries 176 | temp = temp.permute(0, 2, 1) 177 | return temp.reshape(vec.shape[0], -1) 178 | 179 | def singulars(self): 180 | return self._singulars.repeat_interleave(3).reshape(-1) 181 | 182 | def add_zeros(self, vec): 183 | reshaped = vec.clone().reshape(vec.shape[0], -1) 184 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device) 185 | temp[:, : reshaped.shape[1]] = reshaped 186 | return temp 187 | 188 | 189 | # note: codes adoptted from 190 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L228 191 | def build_sr_bicubic(factor, device, image_size, data_channels=3): 192 | def bicubic_kernel(x, a=-0.5): 193 | if abs(x) <= 1: 194 | return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1 195 | elif 1 < abs(x) and abs(x) < 2: 196 | return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a 197 | else: 198 | return 0 199 | 200 | k = np.zeros((factor * 4)) 201 | for i in range(factor * 4): 202 | x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5) 203 | k[i] = bicubic_kernel(x) 204 | k = k / np.sum(k) 205 | kernel = torch.from_numpy(k).float().to(device) 206 | H_funcs = SRConv(kernel / kernel.sum(), 
data_channels, image_size, device, stride=factor) 207 | 208 | return H_funcs 209 | 210 | 211 | def build_sr_pool(factor, device, image_size, data_channels=3): 212 | H_funcs = SuperResolution(data_channels, image_size, factor, device) 213 | return H_funcs 214 | 215 | 216 | def build_sr4x(opt, log, sr_filter, image_size): 217 | assert sr_filter in ["pool", "bicubic"] 218 | log.info(f"[Corrupt] Super-resolution (4x): {sr_filter=} ...") 219 | 220 | factor = 4 221 | 222 | sr_bicubic = build_sr_bicubic(factor, opt.device, image_size) 223 | sr_pool = build_sr_pool(factor, opt.device, image_size) 224 | 225 | upsample = torch.nn.Upsample(scale_factor=factor, mode="nearest") 226 | 227 | assert sr_filter in ["pool", "bicubic"] 228 | 229 | def sr4x(img): 230 | b, c, w, h = img.shape 231 | img = img.to(opt.device) 232 | if sr_filter == "pool": 233 | _img = sr_pool.H(img).reshape(b, c, w // factor, h // factor) 234 | elif sr_filter == "bicubic": 235 | _img = sr_bicubic.H(img).reshape(b, c, w // factor, h // factor) 236 | 237 | # scale to original image size for I2SB 238 | return upsample(_img) 239 | 240 | return sr4x 241 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import torch.distributed as dist 5 | from torch.utils.data.sampler import Sampler 6 | from torch.utils.data import DataLoader 7 | 8 | 9 | def get_data_scaler(config): 10 | """Data normalizer. Assume data are always in [0, 1].""" 11 | if config.data.centered: 12 | # Rescale to [-1, 1] 13 | return lambda x: x * 2.0 - 1.0 14 | else: 15 | return lambda x: x 16 | 17 | 18 | def get_data_inverse_scaler(config): 19 | """Inverse data normalizer.""" 20 | if config.data.centered: 21 | # Rescale [-1, 1] to [0, 1] 22 | return lambda x: (x + 1.0) / 2.0 23 | else: 24 | return lambda x: x 25 | 26 | 27 | class UniformDequant(object): 28 | def __call__(self, x): 29 | return x + torch.rand_like(x) / 256 30 | 31 | 32 | class RASampler(Sampler): 33 | """Sampler that restricts data loading to a subset of the dataset for distributed, 34 | with repeated augmentation. 35 | It ensures that different each augmented version of a sample will be visible to a 36 | different process (GPU). 37 | Heavily based on 'torch.utils.data.DistributedSampler'. 
38 | This is borrowed from the DeiT Repo: 39 | https://github.com/facebookresearch/deit/blob/main/samplers.py 40 | """ 41 | 42 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): 43 | if num_replicas is None: 44 | num_replicas = dist.get_world_size() 45 | if rank is None: 46 | rank = dist.get_rank() 47 | self.dataset = dataset 48 | self.num_replicas = num_replicas 49 | self.rank = rank 50 | self.epoch = 0 51 | self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) 52 | self.total_size = self.num_samples * self.num_replicas 53 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 54 | self.shuffle = shuffle 55 | self.seed = seed 56 | self.repetitions = repetitions 57 | 58 | def __iter__(self): 59 | if self.shuffle: 60 | # Deterministically shuffle based on epoch 61 | g = torch.Generator() 62 | g.manual_seed(self.seed + self.epoch) 63 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 64 | else: 65 | indices = list(range(len(self.dataset))) 66 | 67 | # Add extra samples to make it evenly divisible 68 | indices = [ele for ele in indices for i in range(self.repetitions)] 69 | indices += indices[: (self.total_size - len(indices))] 70 | assert len(indices) == self.total_size 71 | 72 | # Subsample 73 | indices = indices[self.rank : self.total_size : self.num_replicas] 74 | assert len(indices) == self.num_samples 75 | 76 | return iter(indices[: self.num_selected_samples]) 77 | 78 | def __len__(self): 79 | return self.num_selected_samples 80 | 81 | def set_epoch(self, epoch): 82 | self.epoch = epoch 83 | 84 | 85 | class InfiniteBatchSampler(Sampler): 86 | def __init__(self, dataset_len, batch_size, seed=0, filling=False, shuffle=True, drop_last=False): 87 | self.dataset_len = dataset_len 88 | self.batch_size = batch_size 89 | self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size 90 | self.max_p = self.iters_per_ep * batch_size 91 | self.filling = filling 92 | self.shuffle = shuffle 93 | self.epoch = 0 94 | self.seed = seed 95 | self.indices = self.gener_indices() 96 | 97 | def gener_indices(self): 98 | if self.shuffle: 99 | g = torch.Generator() 100 | g.manual_seed(self.epoch + self.seed) 101 | indices = torch.randperm(self.dataset_len, generator=g).numpy() 102 | else: 103 | indices = torch.arange(self.dataset_len).numpy() 104 | 105 | tails = self.batch_size - (self.dataset_len % self.batch_size) 106 | if tails != self.batch_size and self.filling: 107 | tails = indices[:tails] 108 | np.random.shuffle(indices) 109 | indices = np.concatenate((indices, tails)) 110 | 111 | # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop) 112 | # noinspection PyTypeChecker 113 | return tuple(indices.tolist()) 114 | 115 | def __iter__(self): 116 | self.epoch = 0 117 | while True: 118 | self.epoch += 1 119 | p, q = 0, 0 120 | while p < self.max_p: 121 | q = p + self.batch_size 122 | yield self.indices[p:q] 123 | p = q 124 | if self.shuffle: 125 | self.indices = self.gener_indices() 126 | 127 | def __len__(self): 128 | return self.iters_per_ep 129 | 130 | 131 | class DistInfiniteBatchSampler(InfiniteBatchSampler): 132 | def __init__( 133 | self, world_size, rank, dataset_len, glb_batch_size, seed=0, repeated_aug=0, filling=False, shuffle=True 134 | ): 135 | # from torchvision.models import ResNet50_Weights 136 | # RA sampler: 
https://github.com/pytorch/vision/blob/5521e9d01ee7033b9ee9d421c1ef6fb211ed3782/references/classification/sampler.py 137 | 138 | assert glb_batch_size % world_size == 0 139 | self.world_size, self.rank = world_size, rank 140 | self.dataset_len = dataset_len 141 | self.glb_batch_size = glb_batch_size 142 | self.batch_size = glb_batch_size // world_size 143 | 144 | self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size 145 | self.filling = filling 146 | self.shuffle = shuffle 147 | self.repeated_aug = repeated_aug 148 | self.epoch = 0 149 | self.seed = seed 150 | self.indices = self.gener_indices() 151 | 152 | def gener_indices(self): 153 | global_max_p = ( 154 | self.iters_per_ep * self.glb_batch_size 155 | ) # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0 156 | if self.shuffle: 157 | g = torch.Generator() 158 | g.manual_seed(self.epoch + self.seed) 159 | global_indices = torch.randperm(self.dataset_len, generator=g) 160 | if self.repeated_aug > 1: 161 | global_indices = global_indices[ 162 | : (self.dataset_len + self.repeated_aug - 1) // self.repeated_aug 163 | ].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p] 164 | else: 165 | global_indices = torch.arange(self.dataset_len) 166 | filling = global_max_p - global_indices.shape[0] 167 | if filling > 0 and self.filling: 168 | global_indices = torch.cat((global_indices, global_indices[:filling])) 169 | global_indices = tuple(global_indices.numpy().tolist()) 170 | 171 | seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int) 172 | local_indices = global_indices[seps[self.rank] : seps[self.rank + 1]] 173 | self.max_p = len(local_indices) 174 | return local_indices 175 | 176 | 177 | def load_data( 178 | data_dir, 179 | dataset, 180 | batch_size, 181 | image_size, 182 | deterministic=False, 183 | include_test=False, 184 | seed=42, 185 | num_workers=2, 186 | ): 187 | # Compute batch size for this worker. 
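# `batch_size` is the per-GPU batch size: the training sampler below multiplies it by
# dist.get_world_size() to form the global batch and yields batches indefinitely
# (DistInfiniteBatchSampler), while the val/test loaders are finite and sharded across
# ranks with torch.utils.data.DistributedSampler.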
188 | root = data_dir 189 | 190 | if dataset == "edges2handbags": 191 | 192 | from .aligned_dataset import EdgesDataset 193 | 194 | trainset = EdgesDataset(dataroot=root, train=True, img_size=image_size, random_crop=True, random_flip=True) 195 | 196 | valset = EdgesDataset(dataroot=root, train=True, img_size=image_size, random_crop=False, random_flip=False) 197 | if include_test: 198 | testset = EdgesDataset( 199 | dataroot=root, train=False, img_size=image_size, random_crop=False, random_flip=False 200 | ) 201 | 202 | elif dataset == "diode": 203 | 204 | from .aligned_dataset import DIODE 205 | 206 | trainset = DIODE( 207 | dataroot=root, train=True, img_size=image_size, random_crop=True, random_flip=True, disable_cache=True 208 | ) 209 | 210 | valset = DIODE( 211 | dataroot=root, train=True, img_size=image_size, random_crop=False, random_flip=False, disable_cache=True 212 | ) 213 | 214 | if include_test: 215 | testset = DIODE(dataroot=root, train=False, img_size=image_size, random_crop=False, random_flip=False) 216 | 217 | elif "imagenet_inpaint" in dataset: 218 | corrupt_type = dataset.split("_")[-1] 219 | assert corrupt_type in ["center", "freeform2030"] 220 | from .imagenet_inpaint import ImageNetInpaintingDataset, InpaintingVal10kSubset 221 | 222 | trainset = ImageNetInpaintingDataset(root, image_size, corrupt_type, train=True) 223 | valset = ImageNetInpaintingDataset(root, image_size, corrupt_type, train=False) 224 | 225 | if include_test: 226 | testset = InpaintingVal10kSubset(root, image_size, corrupt_type) 227 | 228 | loader = DataLoader( 229 | dataset=trainset, 230 | num_workers=num_workers, 231 | pin_memory=True, 232 | batch_sampler=DistInfiniteBatchSampler( 233 | dataset_len=len(trainset), 234 | glb_batch_size=batch_size * dist.get_world_size(), 235 | seed=seed, 236 | shuffle=not deterministic, 237 | filling=True, 238 | rank=dist.get_rank(), 239 | world_size=dist.get_world_size(), 240 | ), 241 | ) 242 | 243 | num_tasks = dist.get_world_size() 244 | global_rank = dist.get_rank() 245 | sampler = torch.utils.data.DistributedSampler( 246 | valset, num_replicas=num_tasks, rank=global_rank, shuffle=False, drop_last=False 247 | ) 248 | val_loader = torch.utils.data.DataLoader( 249 | valset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, drop_last=False 250 | ) 251 | 252 | if include_test: 253 | 254 | num_tasks = dist.get_world_size() 255 | global_rank = dist.get_rank() 256 | sampler = torch.utils.data.DistributedSampler( 257 | testset, num_replicas=num_tasks, rank=global_rank, shuffle=False, drop_last=False 258 | ) 259 | test_loader = torch.utils.data.DataLoader( 260 | testset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, shuffle=False, drop_last=False 261 | ) 262 | 263 | return loader, val_loader, test_loader 264 | else: 265 | return loader, val_loader 266 | -------------------------------------------------------------------------------- /datasets/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | import random 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | from .image_folder import make_dataset 7 | from PIL import Image 8 | import blobfile as bf 9 | 10 | 11 | def get_params(size, resize_size, crop_size): 12 | w, h = size 13 | new_h = h 14 | new_w = w 15 | 16 | ss, ls = min(w, h), max(w, h) # shortside and longside 17 | width_is_shorter = w == ss 18 | ls = int(resize_size * ls / ss) 19 | ss = resize_size 20 | new_w, new_h = (ss, 
ls) if width_is_shorter else (ls, ss) 21 | 22 | x = random.randint(0, np.maximum(0, new_w - crop_size)) 23 | y = random.randint(0, np.maximum(0, new_h - crop_size)) 24 | 25 | flip = random.random() > 0.5 26 | return {"crop_pos": (x, y), "flip": flip} 27 | 28 | 29 | def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop=True, totensor=True): 30 | transform_list = [] 31 | transform_list.append(transforms.Lambda(lambda img: __scale(img, crop_size, method))) 32 | 33 | if flip: 34 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"]))) 35 | if totensor: 36 | transform_list.append(transforms.ToTensor()) 37 | transform_list.append(transforms.Lambda(lambda t: (t * 2) - 1)) # [0,1] --> [-1, 1]) 38 | return transforms.Compose(transform_list) 39 | 40 | 41 | def get_tensor(normalize=True, toTensor=True): 42 | transform_list = [] 43 | if toTensor: 44 | transform_list += [transforms.ToTensor()] 45 | 46 | if normalize: 47 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 48 | return transforms.Compose(transform_list) 49 | 50 | 51 | def normalize(): 52 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 53 | 54 | 55 | def __scale(img, target_width, method=Image.BICUBIC): 56 | if isinstance(img, torch.Tensor): 57 | return torch.nn.functional.interpolate( 58 | img.unsqueeze(0), size=(target_width, target_width), mode="bicubic", align_corners=False 59 | ).squeeze(0) 60 | else: 61 | return img.resize((target_width, target_width), method) 62 | 63 | 64 | def __flip(img, flip): 65 | if flip: 66 | if isinstance(img, torch.Tensor): 67 | return img.flip(-1) 68 | else: 69 | return img.transpose(Image.FLIP_LEFT_RIGHT) 70 | return img 71 | 72 | 73 | def get_flip(img, flip): 74 | return __flip(img, flip) 75 | 76 | 77 | class EdgesDataset(torch.utils.data.Dataset): 78 | """A dataset class for paired image dataset. 79 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}. 80 | During test time, you need to prepare a directory '/path/to/data/test'. 81 | """ 82 | 83 | def __init__(self, dataroot, train=True, img_size=256, random_crop=False, random_flip=True): 84 | """Initialize this dataset class. 85 | Parameters: 86 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 87 | """ 88 | super().__init__() 89 | if train: 90 | self.train_dir = os.path.join(dataroot, "train") # get the image directory 91 | self.train_paths = make_dataset(self.train_dir) # get image paths 92 | self.AB_paths = sorted(self.train_paths) 93 | else: 94 | 95 | self.test_dir = os.path.join(dataroot, "val") # get the image directory 96 | 97 | self.AB_paths = make_dataset(self.test_dir) # get image paths 98 | 99 | self.crop_size = img_size 100 | self.resize_size = img_size 101 | 102 | self.random_crop = random_crop 103 | self.random_flip = random_flip 104 | self.train = train 105 | 106 | def __getitem__(self, index): 107 | """Return a data point and its metadata information. 
108 | Parameters: 109 | index - - a random integer for data indexing 110 | Returns a dictionary that contains A, B, A_paths and B_paths 111 | A (tensor) - - an image in the input domain 112 | B (tensor) - - its corresponding image in the target domain 113 | A_paths (str) - - image paths 114 | B_paths (str) - - image paths (same as A_paths) 115 | """ 116 | # read a image given a random integer index 117 | 118 | AB_path = self.AB_paths[index] 119 | AB = Image.open(AB_path).convert("RGB") 120 | # split AB image into A and B 121 | w, h = AB.size 122 | w2 = int(w / 2) 123 | A = AB.crop((0, 0, w2, h)) 124 | B = AB.crop((w2, 0, w, h)) 125 | 126 | # apply the same transform to both A and B 127 | params = get_params(A.size, self.resize_size, self.crop_size) 128 | 129 | transform_image = get_transform( 130 | params, self.resize_size, self.crop_size, crop=self.random_crop, flip=self.random_flip 131 | ) 132 | 133 | A = transform_image(A) 134 | B = transform_image(B) 135 | 136 | return B, A, (index, AB_path) 137 | 138 | def __len__(self): 139 | """Return the total number of images in the dataset.""" 140 | return len(self.AB_paths) 141 | 142 | 143 | class DIODE(torch.utils.data.Dataset): 144 | """A dataset class for paired image dataset. 145 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}. 146 | During test time, you need to prepare a directory '/path/to/data/test'. 147 | """ 148 | 149 | def __init__( 150 | self, 151 | dataroot, 152 | train=True, 153 | img_size=256, 154 | random_crop=False, 155 | random_flip=True, 156 | down_sample_img_size=0, 157 | cache_name="cache", 158 | disable_cache=False, 159 | ): 160 | """Initialize this dataset class. 161 | Parameters: 162 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 163 | """ 164 | super().__init__() 165 | self.image_root = os.path.join(dataroot, "train" if train else "val") 166 | self.crop_size = img_size 167 | self.resize_size = img_size 168 | 169 | self.random_crop = random_crop 170 | self.random_flip = random_flip 171 | self.train = train 172 | 173 | self.filenames = [ 174 | l 175 | for l in os.listdir(self.image_root) 176 | if not l.endswith(".pth") and not l.endswith("_depth.png") and not l.endswith("_normal.png") 177 | ] 178 | 179 | self.cache_path = os.path.join(self.image_root, cache_name + f"_{img_size}.pth") 180 | if os.path.exists(self.cache_path) and not disable_cache: 181 | self.cache = torch.load(self.cache_path) 182 | # self.cache['img'] = self.cache['img'][:256] 183 | self.scale_factor = self.cache["scale_factor"] 184 | print("Loaded cache from {}".format(self.cache_path)) 185 | else: 186 | self.cache = None 187 | 188 | def __getitem__(self, index): 189 | """Return a data point and its metadata information. 
190 | Parameters: 191 | index - - a random integer for data indexing 192 | Returns a dictionary that contains A, B, A_paths and B_paths 193 | A (tensor) - - an image in the input domain 194 | B (tensor) - - its corresponding image in the target domain 195 | A_paths (str) - - image paths 196 | B_paths (str) - - image paths (same as A_paths) 197 | """ 198 | # read a image given a random integer index 199 | 200 | fn = self.filenames[index] 201 | img_path = os.path.join(self.image_root, fn) 202 | label_path = os.path.join(self.image_root, fn[:-4] + "_normal.png") 203 | 204 | with bf.BlobFile(img_path, "rb") as f: 205 | pil_image = Image.open(f) 206 | pil_image.load() 207 | pil_image = pil_image.convert("RGB") 208 | 209 | with bf.BlobFile(label_path, "rb") as f: 210 | pil_label = Image.open(f) 211 | pil_label.load() 212 | pil_label = pil_label.convert("RGB") 213 | 214 | # apply the same transform to both A and B 215 | params = get_params(pil_image.size, self.resize_size, self.crop_size) 216 | 217 | transform_label = get_transform( 218 | params, self.resize_size, self.crop_size, method=Image.NEAREST, crop=False, flip=self.random_flip 219 | ) 220 | transform_image = get_transform(params, self.resize_size, self.crop_size, crop=False, flip=self.random_flip) 221 | 222 | cond = transform_label(pil_label) 223 | img = transform_image(pil_image) 224 | 225 | return img, cond, (index, fn) 226 | 227 | def __len__(self): 228 | """Return the total number of images in the dataset.""" 229 | if self.cache is not None: 230 | return len(self.cache["img"]) 231 | else: 232 | return len(self.filenames) 233 | -------------------------------------------------------------------------------- /datasets/image_folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | '.tif', '.TIF', '.tiff', '.TIFF', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | 18 | def make_dataset(dir, max_dataset_size=float("inf")): 19 | images = [] 20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 21 | 22 | for root, _, fnames in sorted(os.walk(dir)): 23 | for fname in fnames: 24 | if is_image_file(fname): 25 | path = os.path.join(root, fname) 26 | images.append(path) 27 | return images[:min(max_dataset_size, len(images))] 28 | 29 | 30 | def default_loader(path): 31 | return Image.open(path).convert('RGB') 32 | 33 | 34 | class ImageFolder(data.Dataset): 35 | 36 | def __init__(self, root, transform=None, return_paths=False, 37 | loader=default_loader): 38 | imgs = make_dataset(root) 39 | if len(imgs) == 0: 40 | raise(RuntimeError("Found 0 images in: " + root + "\n" 41 | "Supported image extensions are: " + 42 | ",".join(IMG_EXTENSIONS))) 43 | 44 | self.root = root 45 | self.imgs = imgs 46 | self.transform = transform 47 | self.return_paths = return_paths 48 | self.loader = loader 49 | 50 | def __getitem__(self, index): 51 | path = self.imgs[index] 52 | img = self.loader(path) 53 | if self.transform is not None: 54 | img = self.transform(img) 55 | if self.return_paths: 56 | return img, path 57 | else: 58 | return img 59 | 60 | def __len__(self): 61 | return len(self.imgs) -------------------------------------------------------------------------------- /datasets/imagenet_inpaint.py: 
-------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import io 10 | 11 | from PIL import Image 12 | import lmdb 13 | 14 | import torch 15 | import torchvision.datasets as datasets 16 | from torchvision import transforms 17 | from torch.utils.data import Dataset 18 | 19 | from corruption.inpaint import build_inpaint_freeform, build_inpaint_center 20 | 21 | # from ipdb import set_trace as debug 22 | 23 | 24 | def lmdb_loader(path, lmdb_data): 25 | # In-memory binary streams 26 | # print(path) 27 | with lmdb_data.begin(write=False, buffers=True) as txn: 28 | bytedata = txn.get(path.encode()) 29 | img = Image.open(io.BytesIO(bytedata)) 30 | return img.convert("RGB") 31 | 32 | 33 | def _build_lmdb_dataset(root, transform=None, target_transform=None, loader=lmdb_loader): 34 | """ 35 | You can create this dataloader using: 36 | train_data = _build_lmdb_dataset(traindir, transform=train_transform) 37 | valid_data = _build_lmdb_dataset(validdir, transform=val_transform) 38 | """ 39 | 40 | root = str(root) 41 | if root.endswith("/"): 42 | root = root[:-1] 43 | pt_path = os.path.join(root + "_faster_imagefolder.lmdb.pt") 44 | lmdb_path = os.path.join(root + "_faster_imagefolder.lmdb") 45 | 46 | if os.path.isfile(pt_path) and os.path.isdir(lmdb_path): 47 | # log.info('[Dataset] Loading pt {} and lmdb {}'.format(pt_path, lmdb_path)) 48 | data_set = torch.load(pt_path) 49 | else: 50 | data_set = datasets.ImageFolder(root, None, None, None) 51 | torch.save(data_set, pt_path, pickle_protocol=4) 52 | # log.info('[Dataset] Saving pt to {}'.format(pt_path)) 53 | # log.info('[Dataset] Building lmdb to {}'.format(lmdb_path)) 54 | env = lmdb.open(lmdb_path, map_size=1e12) 55 | with env.begin(write=True) as txn: 56 | for _path, class_index in data_set.imgs: 57 | with open(_path, "rb") as f: 58 | data = f.read() 59 | txn.put(_path.encode("ascii"), data) 60 | data_set.lmdb_data = lmdb.open(lmdb_path, readonly=True, max_readers=1, lock=False, readahead=False, meminit=False) 61 | # reset transform and target_transform 62 | data_set.samples = data_set.imgs 63 | data_set.transform = transform 64 | data_set.target_transform = target_transform 65 | data_set.loader = lambda path: loader(path, data_set.lmdb_data) 66 | 67 | return data_set 68 | 69 | 70 | def build_train_transform(image_size): 71 | return transforms.Compose( 72 | [ 73 | transforms.Resize(image_size), 74 | transforms.CenterCrop(image_size), 75 | transforms.RandomHorizontalFlip(p=0.5), 76 | transforms.ToTensor(), 77 | transforms.Lambda(lambda t: (t * 2) - 1), # [0,1] --> [-1, 1] 78 | ] 79 | ) 80 | 81 | 82 | def build_test_transform(image_size): 83 | return transforms.Compose( 84 | [ 85 | transforms.Resize(image_size), 86 | transforms.CenterCrop(image_size), 87 | # transforms.RandomHorizontalFlip(p=0.5), 88 | transforms.ToTensor(), 89 | transforms.Lambda(lambda t: (t * 2) - 1), # [0,1] --> [-1, 1] 90 | ] 91 | ) 92 | 93 | 94 | def build_lmdb_dataset(dataset_dir, image_size, train, transform=None): 95 | """resize -> crop -> to_tensor -> norm(-1,1)""" 96 | fn = os.path.join(dataset_dir, "train" if train else "val") 97 | 98 | if transform is None: 99 | build_transform = 
build_train_transform if train else build_test_transform 100 | transform = build_transform(image_size) 101 | 102 | dataset = _build_lmdb_dataset(fn, transform=transform) 103 | # log.info(f"[Dataset] Built Imagenet dataset {fn=}, size={len(dataset)}!") 104 | return dataset 105 | 106 | 107 | def readlines(fn): 108 | file = open(fn, "r").readlines() 109 | return [line.strip("\n\r") for line in file] 110 | 111 | 112 | def build_lmdb_dataset_val10k(dataset_dir, image_size, transform=None): 113 | fn = os.path.join(dataset_dir, "val") 114 | fn_10k = readlines("assets/datasets/val_faster_imagefolder_10k_fn.txt") 115 | label_10k = readlines("assets/datasets/val_faster_imagefolder_10k_label.txt") 116 | 117 | if transform is None: 118 | transform = build_test_transform(image_size) 119 | dataset = _build_lmdb_dataset(fn, transform=transform) 120 | dataset.samples = [(os.path.join(fn, name), int(label)) for name, label in zip(fn_10k, label_10k)] 121 | 122 | assert len(dataset) == 10_000 123 | # log.info(f"[Dataset] Built Imagenet val10k, size={len(dataset)}!") 124 | return dataset 125 | 126 | 127 | class InpaintingVal10kSubset(Dataset): 128 | def __init__(self, dataset_dir, image_size, corrupt_type, transform=None): 129 | super(InpaintingVal10kSubset, self).__init__() 130 | self.dataset = build_lmdb_dataset_val10k(dataset_dir, image_size, transform) 131 | 132 | if corrupt_type == "center": 133 | self.corrupt = build_inpaint_center(corrupt_type, image_size) 134 | elif "freeform" in corrupt_type: 135 | self.corrupt = build_inpaint_freeform(corrupt_type) 136 | else: 137 | raise NotImplementedError() 138 | 139 | def __len__(self): 140 | return self.dataset.__len__() 141 | 142 | def __getitem__(self, index): 143 | img_clean, label = self.dataset[index] 144 | img_corrupt, mask = self.corrupt(img_clean) 145 | return img_clean, img_corrupt, (index, mask, label) 146 | 147 | 148 | class ImageNetInpaintingDataset(Dataset): 149 | def __init__(self, dataset_dir, image_size, corrupt_type, train, transform=None): 150 | super(ImageNetInpaintingDataset, self).__init__() 151 | self.dataset = build_lmdb_dataset(dataset_dir, image_size, train, transform) 152 | 153 | if corrupt_type == "center": 154 | self.corrupt = build_inpaint_center(corrupt_type, image_size) 155 | elif "freeform" in corrupt_type: 156 | self.corrupt = build_inpaint_freeform(corrupt_type) 157 | else: 158 | raise NotImplementedError() 159 | 160 | def __len__(self): 161 | return self.dataset.__len__() 162 | 163 | def __getitem__(self, index): 164 | img_clean, label = self.dataset[index] 165 | img_corrupt, mask = self.corrupt(img_clean) 166 | return img_clean, img_corrupt, (index, mask, label) 167 | -------------------------------------------------------------------------------- /datasets/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | # ---------------------------------------------------------------------------- 12 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 13 | # same constant is used multiple times. 
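# For example, constant(0.5, shape=[1, 3, 1, 1], device=img.device) -- with `img` any
# tensor already on the target device -- returns the same cached tensor on every call with
# identical arguments; the cache key covers the value's bytes as well as the requested
# shape, dtype, device and memory format.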
14 | 15 | _constant_cache = dict() 16 | 17 | 18 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 19 | value = np.asarray(value) 20 | if shape is not None: 21 | shape = tuple(shape) 22 | if dtype is None: 23 | dtype = torch.get_default_dtype() 24 | if device is None: 25 | device = torch.device("cpu") 26 | if memory_format is None: 27 | memory_format = torch.contiguous_format 28 | 29 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 30 | tensor = _constant_cache.get(key, None) 31 | if tensor is None: 32 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 33 | if shape is not None: 34 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 35 | tensor = tensor.contiguous(memory_format=memory_format) 36 | _constant_cache[key] = tensor 37 | return tensor 38 | -------------------------------------------------------------------------------- /ddbm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DiffusionBridge/92522733cc602686df77f07a1824bb89f89cda1a/ddbm/__init__.py -------------------------------------------------------------------------------- /ddbm/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import os 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 11 | WORLD_SIZE = int(os.environ["WORLD_SIZE"]) 12 | WORLD_RANK = int(os.environ["RANK"]) 13 | 14 | 15 | def setup_dist(): 16 | """ 17 | Setup a distributed process group. 18 | """ 19 | if dist.is_initialized(): 20 | return 21 | 22 | torch.cuda.set_device(LOCAL_RANK) 23 | backend = "gloo" if not torch.cuda.is_available() else "nccl" 24 | dist.init_process_group(backend) 25 | 26 | 27 | def dev(): 28 | """ 29 | Get the device to use for torch.distributed. 
30 | """ 31 | if torch.cuda.is_available(): 32 | return torch.device("cuda") 33 | return torch.device("cpu") 34 | -------------------------------------------------------------------------------- /ddbm/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import os.path as osp 9 | import json 10 | import time 11 | import datetime 12 | import tempfile 13 | import warnings 14 | from collections import defaultdict 15 | from contextlib import contextmanager 16 | import torch.distributed as dist 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), "expected file or str, got %s" % filename_or_file 43 | self.file = filename_or_file 44 | self.own_file = False 45 | 46 | def writekvs(self, kvs): 47 | if dist.get_rank() != 0 and self.file == sys.stdout: 48 | return 49 | # Create strings for printing 50 | key2str = {} 51 | for key, val in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append("| %s%s | %s%s |" % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))) 71 | lines.append(dashes) 72 | self.file.write("\n".join(lines) + "\n") 73 | 74 | # Flush the output to the file 75 | self.file.flush() 76 | 77 | def _truncate(self, s): 78 | maxlen = 30 79 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 80 | 81 | def writeseq(self, seq): 82 | seq = list(seq) 83 | for i, elem in enumerate(seq): 84 | self.file.write(elem) 85 | if i < len(seq) - 1: # add space unless this is the last one 86 | self.file.write(" ") 87 | self.file.write("\n") 88 | self.file.flush() 89 | 90 | def close(self): 91 | if self.own_file: 92 | self.file.close() 93 | 94 | 95 | class JSONOutputFormat(KVWriter): 96 | def __init__(self, filename): 97 | self.file = open(filename, "wt") 98 | 99 | def writekvs(self, kvs): 100 | for k, v in sorted(kvs.items()): 101 | if hasattr(v, "dtype"): 102 | kvs[k] = float(v) 103 | self.file.write(json.dumps(kvs) + "\n") 104 | self.file.flush() 105 | 106 | def close(self): 107 | self.file.close() 108 | 109 | 110 | class CSVOutputFormat(KVWriter): 111 | def __init__(self, filename): 112 | self.file = open(filename, "w+t") 113 | self.keys = [] 114 | self.sep = "," 115 | 116 | def writekvs(self, kvs): 117 | # Add our current row to the history 118 | extra_keys = list(kvs.keys() - self.keys) 119 | extra_keys.sort() 120 | if extra_keys: 121 | self.keys.extend(extra_keys) 122 | self.file.seek(0) 123 | lines = self.file.readlines() 124 | self.file.seek(0) 125 | for i, k in enumerate(self.keys): 126 | if i > 0: 127 | self.file.write(",") 128 | self.file.write(k) 129 | self.file.write("\n") 130 | for line in lines[1:]: 131 | self.file.write(line[:-1]) 132 | self.file.write(self.sep * len(extra_keys)) 133 | self.file.write("\n") 134 | for i, k in enumerate(self.keys): 135 | if i > 0: 136 | self.file.write(",") 137 | v = kvs.get(k) 138 | if v is not None: 139 | self.file.write(str(v)) 140 | self.file.write("\n") 141 | self.file.flush() 142 | 143 | def close(self): 144 | self.file.close() 145 | 146 | 147 | class TensorBoardOutputFormat(KVWriter): 148 | """ 149 | Dumps key/value pairs into TensorBoard's numeric format. 150 | """ 151 | 152 | def __init__(self, dir): 153 | os.makedirs(dir, exist_ok=True) 154 | self.dir = dir 155 | self.step = 1 156 | prefix = "events" 157 | path = osp.join(osp.abspath(dir), prefix) 158 | import tensorflow as tf 159 | from tensorflow.python import pywrap_tensorflow 160 | from tensorflow.core.util import event_pb2 161 | from tensorflow.python.util import compat 162 | 163 | self.tf = tf 164 | self.event_pb2 = event_pb2 165 | self.pywrap_tensorflow = pywrap_tensorflow 166 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 167 | 168 | def writekvs(self, kvs): 169 | def summary_val(k, v): 170 | kwargs = {"tag": k, "simple_value": float(v)} 171 | return self.tf.Summary.Value(**kwargs) 172 | 173 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 174 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 175 | event.step = self.step # is there any reason why you'd want to specify the step? 
176 | self.writer.WriteEvent(event) 177 | self.writer.Flush() 178 | self.step += 1 179 | 180 | def close(self): 181 | if self.writer: 182 | self.writer.Close() 183 | self.writer = None 184 | 185 | 186 | def make_output_format(format, ev_dir, log_suffix=""): 187 | os.makedirs(ev_dir, exist_ok=True) 188 | if format == "stdout": 189 | return HumanOutputFormat(sys.stdout) 190 | elif format == "log": 191 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 192 | elif format == "json": 193 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 194 | elif format == "csv": 195 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 196 | elif format == "tensorboard": 197 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 198 | else: 199 | raise ValueError("Unknown format specified: %s" % (format,)) 200 | 201 | 202 | # ================================================================ 203 | # API 204 | # ================================================================ 205 | 206 | 207 | def logkv(key, val): 208 | """ 209 | Log a value of some diagnostic 210 | Call this once for each diagnostic quantity, each iteration 211 | If called many times, last value will be used. 212 | """ 213 | get_current().logkv(key, val) 214 | 215 | 216 | def logkv_mean(key, val): 217 | """ 218 | The same as logkv(), but if called many times, values averaged. 219 | """ 220 | get_current().logkv_mean(key, val) 221 | 222 | 223 | def logkvs(d): 224 | """ 225 | Log a dictionary of key-value pairs 226 | """ 227 | for k, v in d.items(): 228 | logkv(k, v) 229 | 230 | 231 | def dumpkvs(): 232 | """ 233 | Write all of the diagnostics from the current iteration 234 | """ 235 | return get_current().dumpkvs() 236 | 237 | 238 | def getkvs(): 239 | return get_current().name2val 240 | 241 | 242 | def log(*args, level=INFO): 243 | """ 244 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 245 | """ 246 | get_current().log(*args, level=level) 247 | 248 | 249 | def debug(*args): 250 | log(*args, level=DEBUG) 251 | 252 | 253 | def info(*args): 254 | log(*args, level=INFO) 255 | 256 | 257 | def warn(*args): 258 | log(*args, level=WARN) 259 | 260 | 261 | def error(*args): 262 | log(*args, level=ERROR) 263 | 264 | 265 | def set_level(level): 266 | """ 267 | Set logging threshold on current logger. 268 | """ 269 | get_current().set_level(level) 270 | 271 | 272 | def set_comm(comm): 273 | get_current().set_comm(comm) 274 | 275 | 276 | def get_dir(): 277 | """ 278 | Get directory that log files are being written to. 
279 | will be None if there is no output directory (i.e., if you didn't call start) 280 | """ 281 | return get_current().get_dir() 282 | 283 | 284 | record_tabular = logkv 285 | dump_tabular = dumpkvs 286 | 287 | 288 | @contextmanager 289 | def profile_kv(scopename): 290 | logkey = "wait_" + scopename 291 | tstart = time.time() 292 | try: 293 | yield 294 | finally: 295 | get_current().name2val[logkey] += time.time() - tstart 296 | 297 | 298 | def profile(n): 299 | """ 300 | Usage: 301 | @profile("my_func") 302 | def my_func(): code 303 | """ 304 | 305 | def decorator_with_name(func): 306 | def func_wrapper(*args, **kwargs): 307 | with profile_kv(n): 308 | return func(*args, **kwargs) 309 | 310 | return func_wrapper 311 | 312 | return decorator_with_name 313 | 314 | 315 | # ================================================================ 316 | # Backend 317 | # ================================================================ 318 | 319 | 320 | def get_current(dir=None): 321 | if Logger.CURRENT is None: 322 | _configure_default_logger(dir=dir) 323 | 324 | return Logger.CURRENT 325 | 326 | 327 | class Logger(object): 328 | DEFAULT = None # A logger with no output files. (See right below class definition) 329 | # So that you can still log to the terminal without setting up any output files 330 | CURRENT = None # Current logger being used by the free functions above 331 | 332 | def __init__(self, dir, output_formats, comm=None): 333 | self.name2val = defaultdict(float) # values this iteration 334 | self.name2cnt = defaultdict(int) 335 | self.level = INFO 336 | self.dir = dir 337 | self.output_formats = output_formats 338 | self.comm = comm 339 | 340 | # Logging API, forwarded 341 | # ---------------------------------------- 342 | def logkv(self, key, val): 343 | self.name2val[key] = val 344 | 345 | def logkv_mean(self, key, val): 346 | oldval, cnt = self.name2val[key], self.name2cnt[key] 347 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 348 | self.name2cnt[key] = cnt + 1 349 | 350 | def dumpkvs(self): 351 | if self.comm is None: 352 | d = self.name2val 353 | else: 354 | d = mpi_weighted_mean( 355 | self.comm, 356 | {name: (val, self.name2cnt.get(name, 1)) for (name, val) in self.name2val.items()}, 357 | ) 358 | if self.comm.rank != 0: 359 | d["dummy"] = 1 # so we don't get a warning about empty dict 360 | out = d.copy() # Return the dict for unit testing purposes 361 | for fmt in self.output_formats: 362 | if isinstance(fmt, KVWriter): 363 | fmt.writekvs(d) 364 | self.name2val.clear() 365 | self.name2cnt.clear() 366 | return out 367 | 368 | def log(self, *args, level=INFO): 369 | if self.level <= level: 370 | self._do_log(args) 371 | 372 | # Configuration 373 | # ---------------------------------------- 374 | def set_level(self, level): 375 | self.level = level 376 | 377 | def set_comm(self, comm): 378 | self.comm = comm 379 | 380 | def get_dir(self): 381 | return self.dir 382 | 383 | def close(self): 384 | for fmt in self.output_formats: 385 | fmt.close() 386 | 387 | # Misc 388 | # ---------------------------------------- 389 | def _do_log(self, args): 390 | for fmt in self.output_formats: 391 | if isinstance(fmt, SeqWriter): 392 | fmt.writeseq(map(str, args)) 393 | 394 | 395 | def get_rank_without_mpi_import(): 396 | # check environment variables here instead of importing mpi4py 397 | # to avoid calling MPI_Init() when this module is imported 398 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 399 | if varname in os.environ: 400 | return int(os.environ[varname]) 401 
| return 0 402 | 403 | 404 | def mpi_weighted_mean(comm, local_name2valcount): 405 | """ 406 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 407 | Perform a weighted average over dicts that are each on a different node 408 | Input: local_name2valcount: dict mapping key -> (value, count) 409 | Returns: key -> mean 410 | """ 411 | all_name2valcount = comm.gather(local_name2valcount) 412 | if comm.rank == 0: 413 | name2sum = defaultdict(float) 414 | name2count = defaultdict(float) 415 | for n2vc in all_name2valcount: 416 | for name, (val, count) in n2vc.items(): 417 | try: 418 | val = float(val) 419 | except ValueError: 420 | if comm.rank == 0: 421 | warnings.warn("WARNING: tried to compute mean on non-float {}={}".format(name, val)) 422 | else: 423 | name2sum[name] += val * count 424 | name2count[name] += count 425 | return {name: name2sum[name] / name2count[name] for name in name2sum} 426 | else: 427 | return {} 428 | 429 | 430 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 431 | """ 432 | If comm is provided, average all numerical stats across that comm 433 | """ 434 | if dir is None: 435 | dir = os.getenv("OPENAI_LOGDIR") 436 | if dir is None: 437 | dir = osp.join( 438 | tempfile.gettempdir(), 439 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 440 | ) 441 | assert isinstance(dir, str) 442 | dir = os.path.expanduser(dir) 443 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 444 | 445 | rank = get_rank_without_mpi_import() 446 | if rank > 0: 447 | log_suffix = log_suffix + "-rank%03i" % rank 448 | 449 | if format_strs is None: 450 | if rank == 0: 451 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 452 | else: 453 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 454 | format_strs = filter(None, format_strs) 455 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 456 | 457 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 458 | if output_formats: 459 | log("Logging to %s" % dir) 460 | 461 | 462 | def _configure_default_logger(dir): 463 | configure(dir=dir) 464 | Logger.DEFAULT = Logger.CURRENT 465 | 466 | 467 | def reset(): 468 | if Logger.CURRENT is not Logger.DEFAULT: 469 | Logger.CURRENT.close() 470 | Logger.CURRENT = Logger.DEFAULT 471 | log("Reset logger") 472 | 473 | 474 | @contextmanager 475 | def scoped_configure(dir=None, format_strs=None, comm=None): 476 | prevlogger = Logger.CURRENT 477 | configure(dir=dir, format_strs=format_strs, comm=comm) 478 | try: 479 | yield 480 | finally: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = prevlogger 483 | -------------------------------------------------------------------------------- /ddbm/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * torch.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 
25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def append_dims(x, target_dims): 94 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 95 | dims_to_append = target_dims - x.ndim 96 | if dims_to_append < 0: 97 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 98 | return x[(...,) + (None,) * dims_to_append] 99 | 100 | 101 | def append_zero(x): 102 | return torch.cat([x, x.new_zeros([1])]) 103 | 104 | 105 | def normalization(channels): 106 | """ 107 | Make a standard normalization layer. 108 | 109 | :param channels: number of input channels. 110 | :return: an nn.Module for normalization. 111 | """ 112 | return GroupNorm32(32, channels) 113 | 114 | 115 | def timestep_embedding(timesteps, dim, max_period=10000): 116 | """ 117 | Create sinusoidal timestep embeddings. 118 | 119 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 120 | These may be fractional. 121 | :param dim: the dimension of the output. 122 | :param max_period: controls the minimum frequency of the embeddings. 123 | :return: an [N x dim] Tensor of positional embeddings. 
124 | """ 125 | half = dim // 2 126 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 127 | device=timesteps.device 128 | ) 129 | args = timesteps[:, None].float() * freqs[None] 130 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 131 | if dim % 2: 132 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 133 | return embedding 134 | 135 | 136 | def checkpoint(func, inputs, params, flag): 137 | """ 138 | Evaluate a function without caching intermediate activations, allowing for 139 | reduced memory at the expense of extra compute in the backward pass. 140 | 141 | :param func: the function to evaluate. 142 | :param inputs: the argument sequence to pass to `func`. 143 | :param params: a sequence of parameters `func` depends on but does not 144 | explicitly take as arguments. 145 | :param flag: if False, disable gradient checkpointing. 146 | """ 147 | if flag: 148 | args = tuple(inputs) + tuple(params) 149 | return CheckpointFunction.apply(func, len(inputs), *args) 150 | else: 151 | return func(*inputs) 152 | 153 | 154 | class CheckpointFunction(torch.autograd.Function): 155 | @staticmethod 156 | def forward(ctx, run_function, length, *args): 157 | ctx.run_function = run_function 158 | ctx.input_tensors = list(args[:length]) 159 | ctx.input_params = list(args[length:]) 160 | with torch.no_grad(): 161 | output_tensors = ctx.run_function(*ctx.input_tensors) 162 | return output_tensors 163 | 164 | @staticmethod 165 | def backward(ctx, *output_grads): 166 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 167 | with torch.enable_grad(): 168 | # Fixes a bug where the first op in run_function modifies the 169 | # Tensor storage in place, which is not allowed for detach()'d 170 | # Tensors. 171 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 172 | output_tensors = ctx.run_function(*shallow_copies) 173 | input_grads = torch.autograd.grad( 174 | output_tensors, 175 | ctx.input_tensors + ctx.input_params, 176 | output_grads, 177 | allow_unused=True, 178 | ) 179 | del ctx.input_tensors 180 | del ctx.input_params 181 | del output_tensors 182 | return (None, None) + input_grads 183 | -------------------------------------------------------------------------------- /ddbm/random_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from . 
import dist_util 4 | 5 | 6 | def get_generator(generator, num_samples=0, seed=0): 7 | if generator == "dummy": 8 | return DummyGenerator() 9 | elif generator == "determ": 10 | return DeterministicGenerator(num_samples, seed) 11 | elif generator == "determ-indiv": 12 | return DeterministicIndividualGenerator(num_samples, seed) 13 | else: 14 | raise NotImplementedError 15 | 16 | 17 | class DummyGenerator: 18 | def randn(self, *args, **kwargs): 19 | return torch.randn(*args, **kwargs) 20 | 21 | def randint(self, *args, **kwargs): 22 | return torch.randint(*args, **kwargs) 23 | 24 | def randn_like(self, *args, **kwargs): 25 | return torch.randn_like(*args, **kwargs) 26 | 27 | 28 | class DeterministicGenerator: 29 | """ 30 | RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines 31 | Uses a single rng and samples num_samples sized randomness and subsamples the current indices 32 | """ 33 | 34 | def __init__(self, num_samples, seed=0): 35 | if dist.is_initialized(): 36 | self.rank = dist.get_rank() 37 | self.world_size = dist.get_world_size() 38 | else: 39 | print("Warning: Distributed not initialised, using single rank") 40 | self.rank = 0 41 | self.world_size = 1 42 | self.num_samples = num_samples 43 | self.done_samples = 0 44 | self.seed = seed 45 | self.rng_cpu = torch.Generator() 46 | if torch.cuda.is_available(): 47 | self.rng_cuda = torch.Generator(dist_util.dev()) 48 | self.set_seed(seed) 49 | 50 | def get_global_size_and_indices(self, size): 51 | global_size = (self.num_samples, *size[1:]) 52 | indices = torch.arange( 53 | self.done_samples + self.rank, 54 | self.done_samples + self.world_size * int(size[0]), 55 | self.world_size, 56 | ) 57 | indices = torch.clamp(indices, 0, self.num_samples - 1) 58 | assert len(indices) == size[0], f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}" 59 | return global_size, indices 60 | 61 | def get_generator(self, device): 62 | return self.rng_cpu if torch.device(device).type == "cpu" else self.rng_cuda 63 | 64 | def randn(self, *size, dtype=torch.float, device="cpu"): 65 | global_size, indices = self.get_global_size_and_indices(size) 66 | generator = self.get_generator(device) 67 | return torch.randn(*global_size, generator=generator, dtype=dtype, device=device)[indices] 68 | 69 | def randint(self, low, high, size, dtype=torch.long, device="cpu"): 70 | global_size, indices = self.get_global_size_and_indices(size) 71 | generator = self.get_generator(device) 72 | return torch.randint(low, high, generator=generator, size=global_size, dtype=dtype, device=device)[indices] 73 | 74 | def randn_like(self, tensor): 75 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device 76 | return self.randn(*size, dtype=dtype, device=device) 77 | 78 | def set_done_samples(self, done_samples): 79 | self.done_samples = done_samples 80 | self.set_seed(self.seed) 81 | 82 | def get_seed(self): 83 | return self.seed 84 | 85 | def set_seed(self, seed): 86 | self.rng_cpu.manual_seed(seed) 87 | if torch.cuda.is_available(): 88 | self.rng_cuda.manual_seed(seed) 89 | 90 | 91 | class DeterministicIndividualGenerator: 92 | """ 93 | RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines 94 | Uses a separate rng for each sample to reduce memoery usage 95 | """ 96 | 97 | def __init__(self, num_samples, seed=0): 98 | if dist.is_initialized(): 99 | self.rank = dist.get_rank() 100 | self.world_size = dist.get_world_size() 101 | else: 102 | print("Warning: 
Distributed not initialised, using single rank") 103 | self.rank = 0 104 | self.world_size = 1 105 | self.num_samples = num_samples 106 | self.done_samples = 0 107 | self.seed = seed 108 | self.rng_cpu = [torch.Generator() for _ in range(num_samples)] 109 | if torch.cuda.is_available(): 110 | self.rng_cuda = [torch.Generator(dist_util.dev()) for _ in range(num_samples)] 111 | self.set_seed(seed) 112 | 113 | def get_size_and_indices(self, size): 114 | indices = torch.arange( 115 | self.done_samples + self.rank, 116 | self.done_samples + self.world_size * int(size[0]), 117 | self.world_size, 118 | ) 119 | indices = torch.clamp(indices, 0, self.num_samples - 1) 120 | assert len(indices) == size[0], f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}" 121 | return (1, *size[1:]), indices 122 | 123 | def get_generator(self, device): 124 | return self.rng_cpu if torch.device(device).type == "cpu" else self.rng_cuda 125 | 126 | def randn(self, *size, dtype=torch.float, device="cpu"): 127 | size, indices = self.get_size_and_indices(size) 128 | generator = self.get_generator(device) 129 | return torch.cat( 130 | [torch.randn(*size, generator=generator[i], dtype=dtype, device=device) for i in indices], 131 | dim=0, 132 | ) 133 | 134 | def randint(self, low, high, size, dtype=torch.long, device="cpu"): 135 | size, indices = self.get_size_and_indices(size) 136 | generator = self.get_generator(device) 137 | return torch.cat( 138 | [ 139 | torch.randint( 140 | low, 141 | high, 142 | generator=generator[i], 143 | size=size, 144 | dtype=dtype, 145 | device=device, 146 | ) 147 | for i in indices 148 | ], 149 | dim=0, 150 | ) 151 | 152 | def randn_like(self, tensor): 153 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device 154 | return self.randn(*size, dtype=dtype, device=device) 155 | 156 | def set_done_samples(self, done_samples): 157 | self.done_samples = done_samples 158 | 159 | def get_seed(self): 160 | return self.seed 161 | 162 | def set_seed(self, seed): 163 | [rng_cpu.manual_seed(i + self.num_samples * seed) for i, rng_cpu in enumerate(self.rng_cpu)] 164 | if torch.cuda.is_available(): 165 | [rng_cuda.manual_seed(i + self.num_samples * seed) for i, rng_cuda in enumerate(self.rng_cuda)] 166 | 167 | 168 | class BatchedSeedGenerator: 169 | 170 | def __init__(self, seeds=None): 171 | self.num_samples = len(seeds) 172 | if torch.cuda.is_available(): 173 | self.rng = [torch.Generator(dist_util.dev()) for _ in range(self.num_samples)] 174 | else: 175 | self.rng = [torch.Generator() for _ in range(self.num_samples)] 176 | [rng.manual_seed(int(seeds[i])) for i, rng in enumerate(self.rng)] 177 | 178 | def randn(self, size, dtype=torch.float, device="cpu"): 179 | assert size[0] == self.num_samples 180 | return torch.cat( 181 | [ 182 | torch.randn(1, *size[1:], generator=self.rng[i], dtype=dtype, device=device) 183 | for i in range(self.num_samples) 184 | ], 185 | dim=0, 186 | ) 187 | 188 | def randint(self, low, high, size, dtype=torch.long, device="cpu"): 189 | assert size[0] == self.num_samples 190 | return torch.cat( 191 | [ 192 | torch.randint( 193 | low, 194 | high, 195 | generator=self.rng[i], 196 | size=(1, *size[1:]), 197 | dtype=dtype, 198 | device=device, 199 | ) 200 | for i in range(self.num_samples) 201 | ], 202 | dim=0, 203 | ) 204 | 205 | def randn_like(self, tensor): 206 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device 207 | return self.randn(size, dtype=dtype, device=device) 208 | 
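The deterministic generators above (`DeterministicGenerator` / `DeterministicIndividualGenerator`) exist so that sampling results do not depend on the per-rank batch size or the number of GPUs: every sample index is assigned a fixed position in one global random stream. Below is a minimal, hypothetical sketch of how they can be driven; it is not part of the repository and assumes a `torchrun` launch (so the `RANK`/`LOCAL_RANK`/`WORLD_SIZE` environment variables read by `ddbm.dist_util` at import time are set) and an available CUDA device.

```python
import torch.distributed as dist

from ddbm import dist_util
from ddbm.random_util import get_generator

dist_util.setup_dist()  # initializes the process group from the torchrun environment

# One global stream of `num_samples` noise draws, shared by all ranks.
rng = get_generator("determ", num_samples=64, seed=42)

# Each rank receives the slice of that stream matching its sample indices, so the noise
# it sees does not change with the per-rank batch size or the world size.
noise = rng.randn(8, 3, 64, 64, device=dist_util.dev())

# Advance the global counter before drawing noise for the next batch of samples.
rng.set_done_samples(8 * dist.get_world_size())
```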
-------------------------------------------------------------------------------- /ddbm/resample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_named_schedule_sampler(name, diffusion): 5 | """ 6 | Create a ScheduleSampler from a library of pre-defined samplers. 7 | 8 | :param name: the name of the sampler. 9 | :param diffusion: the diffusion object to sample for. 10 | """ 11 | if name == "real-uniform": 12 | return RealUniformSampler(diffusion) 13 | else: 14 | raise NotImplementedError(f"unknown schedule sampler: {name}") 15 | 16 | 17 | class RealUniformSampler: 18 | def __init__(self, diffusion): 19 | self.t_max = diffusion.t_max 20 | self.t_min = diffusion.t_min 21 | 22 | def sample(self, batch_size, device): 23 | ts = torch.rand(batch_size).to(device) * (self.t_max - self.t_min) + self.t_min 24 | return ts, torch.ones_like(ts) 25 | -------------------------------------------------------------------------------- /ddbm/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .karras_diffusion import ( 4 | KarrasDenoiser, 5 | VPNoiseSchedule, 6 | VENoiseSchedule, 7 | I2SBNoiseSchedule, 8 | DDBMPreCond, 9 | I2SBPreCond, 10 | ) 11 | from .unet import UNetModel 12 | 13 | NUM_CLASSES = 1000 14 | 15 | 16 | def get_workdir(exp): 17 | workdir = f"./workdir/{exp}" 18 | return workdir 19 | 20 | 21 | def sample_defaults(): 22 | return dict( 23 | generator="determ", 24 | clip_denoised=True, 25 | sampler="euler", 26 | s_churn=0.0, 27 | s_tmin=0.002, 28 | s_tmax=80, 29 | s_noise=1.0, 30 | steps=40, 31 | model_path="", 32 | seed=42, 33 | ts="", 34 | ) 35 | 36 | 37 | def model_and_diffusion_defaults(): 38 | """ 39 | Defaults for image training. 
40 | """ 41 | res = dict( 42 | sigma_data=0.5, 43 | sigma_min=0.002, 44 | sigma_max=80.0, 45 | beta_d=2, 46 | beta_min=0.1, 47 | beta_max=1.0, 48 | cov_xy=0.0, 49 | image_size=64, 50 | in_channels=3, 51 | num_channels=128, 52 | num_res_blocks=2, 53 | num_heads=4, 54 | num_heads_upsample=-1, 55 | num_head_channels=64, 56 | unet_type="adm", 57 | attention_resolutions="32,16,8", 58 | channel_mult="", 59 | dropout=0.0, 60 | class_cond=False, 61 | use_checkpoint=False, 62 | use_scale_shift_norm=True, 63 | resblock_updown=True, 64 | use_fp16=True, 65 | use_new_attention_order=False, 66 | condition_mode=None, 67 | noise_schedule="ve", 68 | ) 69 | return res 70 | 71 | 72 | def create_model_and_diffusion( 73 | image_size, 74 | in_channels, 75 | class_cond, 76 | num_channels, 77 | num_res_blocks, 78 | channel_mult, 79 | num_heads, 80 | num_head_channels, 81 | num_heads_upsample, 82 | attention_resolutions, 83 | dropout, 84 | use_checkpoint, 85 | use_scale_shift_norm, 86 | resblock_updown, 87 | use_fp16, 88 | use_new_attention_order, 89 | condition_mode, 90 | noise_schedule, 91 | sigma_data=0.5, 92 | sigma_min=0.002, 93 | sigma_max=80.0, 94 | beta_d=2, 95 | beta_min=0.1, 96 | beta_max=1.0, 97 | cov_xy=0.0, 98 | unet_type="adm", 99 | ): 100 | model = create_model( 101 | image_size, 102 | in_channels, 103 | num_channels, 104 | num_res_blocks, 105 | unet_type=unet_type, 106 | channel_mult=channel_mult, 107 | class_cond=class_cond, 108 | use_checkpoint=use_checkpoint, 109 | attention_resolutions=attention_resolutions, 110 | num_heads=num_heads, 111 | num_head_channels=num_head_channels, 112 | num_heads_upsample=num_heads_upsample, 113 | use_scale_shift_norm=use_scale_shift_norm, 114 | dropout=dropout, 115 | resblock_updown=resblock_updown, 116 | use_fp16=use_fp16, 117 | use_new_attention_order=use_new_attention_order, 118 | condition_mode=condition_mode, 119 | ) 120 | if noise_schedule.startswith("vp"): 121 | ns = VPNoiseSchedule(beta_d=beta_d, beta_min=beta_min) 122 | precond = DDBMPreCond(ns, sigma_data=sigma_data, cov_xy=cov_xy) 123 | elif noise_schedule == "ve": 124 | ns = VENoiseSchedule(sigma_max=sigma_max) 125 | precond = DDBMPreCond(ns, sigma_data=sigma_data, cov_xy=cov_xy) 126 | elif noise_schedule.startswith("i2sb"): 127 | ns = I2SBNoiseSchedule(beta_max=beta_max, beta_min=beta_min) 128 | precond = I2SBPreCond(ns) 129 | 130 | diffusion = KarrasDenoiser( 131 | noise_schedule=ns, 132 | precond=precond, 133 | t_max=sigma_max, 134 | t_min=sigma_min, 135 | ) 136 | return model, diffusion 137 | 138 | 139 | def create_model( 140 | image_size, 141 | in_channels, 142 | num_channels, 143 | num_res_blocks, 144 | unet_type="adm", 145 | channel_mult="", 146 | class_cond=False, 147 | use_checkpoint=False, 148 | attention_resolutions="16", 149 | num_heads=1, 150 | num_head_channels=-1, 151 | num_heads_upsample=-1, 152 | use_scale_shift_norm=False, 153 | dropout=0, 154 | resblock_updown=False, 155 | use_fp16=False, 156 | use_new_attention_order=False, 157 | condition_mode=None, 158 | ): 159 | if channel_mult == "": 160 | if image_size == 512: 161 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 162 | elif image_size == 256: 163 | channel_mult = (1, 1, 2, 2, 4, 4) 164 | elif image_size == 128: 165 | channel_mult = (1, 1, 2, 3, 4) 166 | elif image_size == 64: 167 | channel_mult = (1, 2, 3, 4) 168 | elif image_size == 32: 169 | channel_mult = (1, 2, 3, 4) 170 | else: 171 | raise ValueError(f"unsupported image size: {image_size}") 172 | else: 173 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 
174 | 175 | attention_ds = [] 176 | for res in attention_resolutions.split(","): 177 | attention_ds.append(image_size // int(res)) 178 | 179 | if unet_type == "adm": 180 | return UNetModel( 181 | image_size=image_size, 182 | in_channels=in_channels, 183 | model_channels=num_channels, 184 | out_channels=in_channels, 185 | num_res_blocks=num_res_blocks, 186 | attention_resolutions=tuple(attention_ds), 187 | dropout=dropout, 188 | channel_mult=channel_mult, 189 | num_classes=(NUM_CLASSES if class_cond else None), 190 | use_checkpoint=use_checkpoint, 191 | use_fp16=use_fp16, 192 | num_heads=num_heads, 193 | num_head_channels=num_head_channels, 194 | num_heads_upsample=num_heads_upsample, 195 | use_scale_shift_norm=use_scale_shift_norm, 196 | resblock_updown=resblock_updown, 197 | use_new_attention_order=use_new_attention_order, 198 | condition_mode=condition_mode, 199 | ) 200 | else: 201 | raise ValueError(f"Unsupported unet type: {unet_type}") 202 | 203 | 204 | def add_dict_to_argparser(parser, default_dict): 205 | for k, v in default_dict.items(): 206 | v_type = type(v) 207 | if v is None: 208 | v_type = str 209 | elif isinstance(v, bool): 210 | v_type = str2bool 211 | parser.add_argument(f"--{k}", default=v, type=v_type) 212 | 213 | 214 | def args_to_dict(args, keys): 215 | return {k: getattr(args, k) for k in keys} 216 | 217 | 218 | def str2bool(v): 219 | """ 220 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 221 | """ 222 | if isinstance(v, bool): 223 | return v 224 | if v.lower() in ("yes", "true", "t", "y", "1"): 225 | return True 226 | elif v.lower() in ("no", "false", "f", "n", "0"): 227 | return False 228 | else: 229 | raise argparse.ArgumentTypeError("boolean value expected") 230 | -------------------------------------------------------------------------------- /ddbm/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import numpy as np 6 | 7 | import blobfile as bf 8 | import torch 9 | import torch.distributed as dist 10 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 11 | from torch.optim import RAdam 12 | 13 | from . 
import dist_util, logger 14 | from .nn import update_ema 15 | 16 | from ddbm.random_util import get_generator 17 | 18 | import glob 19 | 20 | import wandb 21 | 22 | 23 | class TrainLoop: 24 | def __init__( 25 | self, 26 | *, 27 | model, 28 | diffusion, 29 | train_data, 30 | test_data, 31 | batch_size, 32 | microbatch, 33 | lr, 34 | ema_rate, 35 | log_interval, 36 | test_interval, 37 | save_interval, 38 | save_interval_for_preemption, 39 | resume_checkpoint, 40 | workdir, 41 | use_fp16=False, 42 | fp16_scale_growth=1e-3, 43 | schedule_sampler=None, 44 | weight_decay=0.0, 45 | lr_anneal_steps=0, 46 | total_training_steps=10000000, 47 | augment_pipe=None, 48 | train_mode="ddbm", 49 | resume_train_flag=False, 50 | **sample_kwargs, 51 | ): 52 | self.model = model 53 | self.diffusion = diffusion 54 | self.data = train_data 55 | self.test_data = test_data 56 | self.image_size = model.image_size 57 | self.batch_size = batch_size 58 | self.microbatch = microbatch if microbatch > 0 else batch_size 59 | self.lr = lr 60 | self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")] 61 | self.log_interval = log_interval 62 | self.workdir = workdir 63 | self.test_interval = test_interval 64 | self.save_interval = save_interval 65 | self.save_interval_for_preemption = save_interval_for_preemption 66 | self.resume_checkpoint = resume_checkpoint 67 | self.use_fp16 = use_fp16 68 | self.fp16_scale_growth = fp16_scale_growth 69 | self.schedule_sampler = schedule_sampler 70 | self.weight_decay = weight_decay 71 | self.lr_anneal_steps = lr_anneal_steps 72 | self.total_training_steps = total_training_steps 73 | 74 | self.train_mode = train_mode 75 | 76 | self.step = 0 77 | self.resume_train_flag = resume_train_flag 78 | self.resume_step = 0 79 | self.global_batch = self.batch_size * dist.get_world_size() 80 | 81 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_fp16) 82 | 83 | self._load_and_sync_parameters() 84 | if not self.resume_train_flag: 85 | self.resume_step = 0 86 | 87 | self.opt = RAdam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) 88 | if self.resume_step: 89 | self._load_optimizer_state() 90 | # Model was resumed, either due to a restart or a checkpoint 91 | # being specified at the command line. 92 | self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate] 93 | else: 94 | self.ema_params = [copy.deepcopy(list(self.model.parameters())) for _ in range(len(self.ema_rate))] 95 | 96 | if torch.cuda.is_available(): 97 | self.use_ddp = True 98 | local_rank = int(os.environ["LOCAL_RANK"]) 99 | self.ddp_model = DDP( 100 | self.model, 101 | device_ids=[local_rank], 102 | output_device=local_rank, 103 | ) 104 | else: 105 | if dist.get_world_size() > 1: 106 | logger.warn("Distributed training requires CUDA. 
" "Gradients will not be synchronized properly!") 107 | self.use_ddp = False 108 | self.ddp_model = self.model 109 | 110 | self.step = self.resume_step 111 | 112 | self.generator = get_generator(sample_kwargs["generator"], self.batch_size, 42) 113 | self.sample_kwargs = sample_kwargs 114 | 115 | self.augment = augment_pipe 116 | 117 | def _load_and_sync_parameters(self): 118 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 119 | 120 | if resume_checkpoint: 121 | if self.resume_train_flag: 122 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 123 | if dist.get_rank() == 0: 124 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 125 | logger.log("Resume step: ", self.resume_step) 126 | 127 | self.model.load_state_dict(torch.load(resume_checkpoint, map_location="cpu")) 128 | self.model.to(dist_util.dev()) 129 | 130 | dist.barrier() 131 | 132 | def _load_ema_parameters(self, rate): 133 | ema_params = copy.deepcopy(list(self.model.parameters())) 134 | 135 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 136 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 137 | if ema_checkpoint: 138 | if dist.get_rank() == 0: 139 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 140 | state_dict = torch.load(ema_checkpoint, map_location=dist_util.dev()) 141 | ema_params = [state_dict[name] for name, _ in self.model.named_parameters()] 142 | 143 | dist.barrier() 144 | return ema_params 145 | 146 | def _load_optimizer_state(self): 147 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 148 | if main_checkpoint.split("/")[-1].startswith("freq"): 149 | prefix = "freq_" 150 | else: 151 | prefix = "" 152 | opt_checkpoint = bf.join(bf.dirname(main_checkpoint), f"{prefix}opt_{self.resume_step:06}.pt") 153 | if bf.exists(opt_checkpoint): 154 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 155 | state_dict = torch.load(opt_checkpoint, map_location=dist_util.dev()) 156 | self.opt.load_state_dict(state_dict) 157 | dist.barrier() 158 | 159 | def run_loop(self): 160 | while True: 161 | for batch, cond, _ in self.data: 162 | 163 | if "inpaint" in self.workdir: 164 | _, mask, label = _ 165 | else: 166 | mask = None 167 | 168 | if not (not self.lr_anneal_steps or self.step < self.total_training_steps): 169 | # Save the last checkpoint if it wasn't already saved. 170 | if (self.step - 1) % self.save_interval != 0: 171 | self.save() 172 | return 173 | 174 | if self.augment is not None: 175 | batch, _ = self.augment(batch) 176 | if isinstance(cond, torch.Tensor) and batch.ndim == cond.ndim: 177 | cond = {"xT": cond} 178 | else: 179 | cond["xT"] = cond["xT"] 180 | if mask is not None: 181 | # cond["mask"] = mask 182 | cond["y"] = label 183 | 184 | took_step = self.run_step(batch, cond) 185 | if took_step and self.step % self.log_interval == 0: 186 | logs = logger.dumpkvs() 187 | 188 | if dist.get_rank() == 0: 189 | wandb.log(logs, step=self.step) 190 | 191 | if took_step and self.step % self.save_interval == 0: 192 | self.save() 193 | # Run for a finite amount of time in integration tests. 
194 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 195 | return 196 | 197 | test_batch, test_cond, _ = next(iter(self.test_data)) 198 | if "inpaint" in self.workdir: 199 | _, mask, label = _ 200 | else: 201 | mask = None 202 | if isinstance(test_cond, torch.Tensor) and test_batch.ndim == test_cond.ndim: 203 | test_cond = {"xT": test_cond} 204 | else: 205 | test_cond["xT"] = test_cond["xT"] 206 | if mask is not None: 207 | # test_cond["mask"] = mask 208 | test_cond["y"] = label 209 | self.run_test_step(test_batch, test_cond) 210 | logs = logger.dumpkvs() 211 | 212 | if dist.get_rank() == 0: 213 | wandb.log(logs, step=self.step) 214 | 215 | if took_step and self.step % self.save_interval_for_preemption == 0: 216 | self.save(for_preemption=True) 217 | 218 | def run_step(self, batch, cond): 219 | self.forward_backward(batch, cond) 220 | logger.logkv_mean("lg_loss_scale", np.log2(self.scaler.get_scale())) 221 | self.scaler.unscale_(self.opt) 222 | 223 | def _compute_norms(): 224 | grad_norm = 0.0 225 | param_norm = 0.0 226 | for p in self.model.parameters(): 227 | with torch.no_grad(): 228 | param_norm += torch.norm(p, p=2, dtype=torch.float32).item() ** 2 229 | if p.grad is not None: 230 | grad_norm += torch.norm(p.grad, p=2, dtype=torch.float32).item() ** 2 231 | return np.sqrt(grad_norm), np.sqrt(param_norm) 232 | 233 | grad_norm, param_norm = _compute_norms() 234 | 235 | logger.logkv_mean("grad_norm", grad_norm) 236 | logger.logkv_mean("param_norm", param_norm) 237 | 238 | self.scaler.step(self.opt) 239 | self.scaler.update() 240 | self.step += 1 241 | self._update_ema() 242 | 243 | self._anneal_lr() 244 | self.log_step() 245 | return True 246 | 247 | def run_test_step(self, batch, cond): 248 | with torch.no_grad(): 249 | self.forward_backward(batch, cond, train=False) 250 | 251 | def forward_backward(self, batch, cond, train=True): 252 | if train: 253 | self.opt.zero_grad() 254 | assert batch.shape[0] % self.microbatch == 0 255 | num_microbatches = batch.shape[0] // self.microbatch 256 | for i in range(0, batch.shape[0], self.microbatch): 257 | with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.use_fp16): 258 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 259 | micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()} 260 | last_batch = (i + self.microbatch) >= batch.shape[0] 261 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 262 | 263 | if self.train_mode == "ddbm": 264 | compute_losses = functools.partial( 265 | self.diffusion.training_bridge_losses, 266 | self.ddp_model, 267 | micro, 268 | t, 269 | model_kwargs=micro_cond, 270 | ) 271 | else: 272 | raise NotImplementedError() 273 | 274 | if last_batch or not self.use_ddp: 275 | losses = compute_losses() 276 | else: 277 | with self.ddp_model.no_sync(): 278 | losses = compute_losses() 279 | 280 | loss = (losses["loss"] * weights).mean() / num_microbatches 281 | log_loss_dict(self.diffusion, t, {k if train else "test_" + k: v * weights for k, v in losses.items()}) 282 | if train: 283 | self.scaler.scale(loss).backward() 284 | 285 | def _update_ema(self): 286 | for rate, params in zip(self.ema_rate, self.ema_params): 287 | update_ema(params, self.model.parameters(), rate=rate) 288 | 289 | def _anneal_lr(self): 290 | if not self.lr_anneal_steps: 291 | return 292 | frac_done = (self.step) / self.lr_anneal_steps 293 | lr = self.lr * (1 - frac_done) 294 | for param_group in self.opt.param_groups: 295 | param_group["lr"] = lr 
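
For reference, the two bookkeeping steps invoked at the end of `run_step` reduce to a few lines of arithmetic. The sketch below is a hedged restatement: the EMA rule is the standard convex combination that `update_ema` in `ddbm/nn.py` is assumed to implement (that file is not reproduced in this listing), and the learning-rate decay restates the linear schedule of `_anneal_lr` above.

```python
import torch


@torch.no_grad()
def ema_update(ema_params, model_params, rate=0.9999):
    # Assumed behaviour of ddbm.nn.update_ema: ema <- rate * ema + (1 - rate) * theta.
    for targ, src in zip(ema_params, model_params):
        targ.mul_(rate).add_(src, alpha=1 - rate)


def annealed_lr(base_lr, step, lr_anneal_steps):
    # Linear decay used by _anneal_lr (only applied when lr_anneal_steps > 0):
    # the learning rate reaches 0 at step == lr_anneal_steps.
    return base_lr * (1 - step / lr_anneal_steps)


# Example: with base_lr=1e-4 and lr_anneal_steps=400_000, step 100_000 trains at 7.5e-5.
```
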
296 | 297 | def log_step(self): 298 | logger.logkv("step", self.step) 299 | logger.logkv("samples", (self.step + 1) * self.global_batch) 300 | 301 | def save(self, for_preemption=False): 302 | def maybe_delete_earliest(filename): 303 | wc = filename.split(f"{(self.step):06d}")[0] + "*" 304 | freq_states = list(glob.glob(os.path.join(get_blob_logdir(), wc))) 305 | if len(freq_states) > 3000: 306 | earliest = min(freq_states, key=lambda x: x.split("_")[-1].split(".")[0]) 307 | os.remove(earliest) 308 | 309 | # if dist.get_rank() == 0 and for_preemption: 310 | # maybe_delete_earliest(get_blob_logdir()) 311 | def save_checkpoint(rate, params): 312 | state_dict = self.model.state_dict() 313 | for i, (name, _) in enumerate(self.model.named_parameters()): 314 | assert name in state_dict 315 | state_dict[name] = params[i] 316 | if dist.get_rank() == 0: 317 | logger.log(f"saving model {rate}...") 318 | if not rate: 319 | filename = f"model_{(self.step):06d}.pt" 320 | else: 321 | filename = f"ema_{rate}_{(self.step):06d}.pt" 322 | if for_preemption: 323 | filename = f"freq_{filename}" 324 | maybe_delete_earliest(filename) 325 | 326 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 327 | torch.save(state_dict, f) 328 | 329 | for rate, params in zip(self.ema_rate, self.ema_params): 330 | save_checkpoint(rate, params) 331 | 332 | if dist.get_rank() == 0: 333 | filename = f"opt_{(self.step):06d}.pt" 334 | if for_preemption: 335 | filename = f"freq_{filename}" 336 | maybe_delete_earliest(filename) 337 | 338 | with bf.BlobFile( 339 | bf.join(get_blob_logdir(), filename), 340 | "wb", 341 | ) as f: 342 | torch.save(self.opt.state_dict(), f) 343 | 344 | # Save model parameters last to prevent race conditions where a restart 345 | # loads model at step N, but opt/ema state isn't saved for step N. 346 | save_checkpoint(0, list(self.model.parameters())) 347 | dist.barrier() 348 | 349 | 350 | def parse_resume_step_from_filename(filename): 351 | """ 352 | Parse filenames of the form path/to/model_NNNNNN.pt, where NNNNNN is the 353 | checkpoint's number of steps. 354 | """ 355 | split = filename.split("model_") 356 | if len(split) < 2: 357 | return 0 358 | split1 = split[-1].split(".")[0] 359 | try: 360 | return int(split1) 361 | except ValueError: 362 | return 0 363 | 364 | 365 | def get_blob_logdir(): 366 | # You can change this to be a separate path to save checkpoints to 367 | # a blobstore or some external drive. 368 | return logger.get_dir() 369 | 370 | 371 | def find_resume_checkpoint(): 372 | # On your infrastructure, you may want to override this to automatically 373 | # discover the latest checkpoint on your blob storage, etc. 374 | return None 375 | 376 | 377 | def find_ema_checkpoint(main_checkpoint, step, rate): 378 | if main_checkpoint is None: 379 | return None 380 | if main_checkpoint.split("/")[-1].startswith("freq"): 381 | prefix = "freq_" 382 | else: 383 | prefix = "" 384 | filename = f"{prefix}ema_{rate}_{(step):06d}.pt" 385 | path = bf.join(bf.dirname(main_checkpoint), filename) 386 | if bf.exists(path): 387 | return path 388 | return None 389 | 390 | 391 | def log_loss_dict(diffusion, ts, losses): 392 | for key, values in losses.items(): 393 | logger.logkv_mean(key, values.mean().item()) 394 | # Log the quantiles (four quartiles, in particular). 
395 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 396 | quartile = int(4 * sub_t) 397 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 398 | -------------------------------------------------------------------------------- /download_diffusion.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import requests 10 | from tqdm import tqdm 11 | 12 | import torch 13 | 14 | ADM_IMG256_COND_CKPT = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt" 15 | I2SB_IMG256_COND_CKPT = "256x256_diffusion_fixedsigma.pt" 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def download_adm_image256_cond_ckpt(ckpt_dir): 31 | ckpt_pt = os.path.join(ckpt_dir, I2SB_IMG256_COND_CKPT) 32 | if os.path.exists(ckpt_pt): 33 | return 34 | 35 | adm_ckpt = os.path.join(ckpt_dir, os.path.basename(ADM_IMG256_COND_CKPT)) 36 | 37 | print("Downloading ADM checkpoint to {} ...".format(adm_ckpt)) 38 | download(ADM_IMG256_COND_CKPT, adm_ckpt) 39 | ckpt_state_dict = torch.load(adm_ckpt, map_location="cpu") 40 | 41 | # pt: remove the sigma prediction and add concat module 42 | ckpt_state_dict["out.2.weight"] = ckpt_state_dict["out.2.weight"][:3] 43 | ckpt_state_dict["out.2.bias"] = ckpt_state_dict["out.2.bias"][:3] 44 | ckpt_state_dict["input_blocks.0.0.weight"] = torch.cat( 45 | [ckpt_state_dict["input_blocks.0.0.weight"], ckpt_state_dict["input_blocks.0.0.weight"]], dim=1 46 | ) 47 | torch.save(ckpt_state_dict, ckpt_pt) 48 | 49 | print(f"Saved adm cond pretrain models at {ckpt_pt}!") 50 | 51 | 52 | def download_ckpt(ckpt_dir="assets/ckpts"): 53 | os.makedirs(ckpt_dir, exist_ok=True) 54 | download_adm_image256_cond_ckpt(ckpt_dir=ckpt_dir) 55 | 56 | 57 | download_ckpt() 58 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | from .resnet import build_resnet50 -------------------------------------------------------------------------------- /evaluation/compute_metrices_imagenet.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import argparse 10 | import random 11 | from pathlib import Path 12 | import json 13 | 14 | import numpy as np 15 | import torch 16 | import torchvision.transforms as transforms 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data import Dataset 19 | from easydict import EasyDict as edict 20 | from logger import Logger 21 | 22 | from evaluation.resnet import build_resnet50 23 | from evaluation import fid_util 24 | 25 | RESULT_DIR = Path("results") 26 | ADM_IMG256_FID_TRAIN_REF_CKPT = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz" 27 | 28 | def set_seed(seed): 29 | # https://github.com/pytorch/pytorch/issues/7068 30 | random.seed(seed) 31 | os.environ['PYTHONHASHSEED'] = str(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 36 | 37 | class NumpyDataset(Dataset): 38 | def __init__(self, data, targets): 39 | self.data = data 40 | self.targets = torch.LongTensor(targets) 41 | self.transform = transforms.ToTensor() 42 | 43 | def __getitem__(self, index): 44 | img_np = self.data[index] 45 | y = self.targets[index] 46 | 47 | if img_np.dtype == "uint8": 48 | # transform gives [0,1] 49 | img_t = self.transform(img_np) * 2 - 1 50 | elif img_np.dtype == "float32": 51 | # transform gives [0,255] 52 | img_t = self.transform(img_np) / 127.5 - 1 53 | 54 | # img_t: [-1,1] 55 | return img_t, y 56 | 57 | def __len__(self): 58 | return len(self.data) 59 | 60 | @torch.no_grad() 61 | def compute_accu(opt, numpy_arr, numpy_label_arr, batch_size=256): 62 | dataset = NumpyDataset(numpy_arr, numpy_label_arr) 63 | loader = DataLoader(dataset, 64 | batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=1, drop_last=False, 65 | ) 66 | 67 | resnet = build_resnet50().to(opt.device) 68 | correct = total = 0 69 | for (x,y) in loader: 70 | pred_y = resnet(x.to(opt.device)) 71 | 72 | _, predicted = torch.max(pred_y.cpu(), 1) 73 | correct += (predicted==y).sum().item() 74 | total += y.size(0) 75 | 76 | accu = correct / total 77 | return accu 78 | 79 | def convert_to_numpy(t): 80 | # t: [-1,1] 81 | out = (t + 1) * 127.5 82 | out = out.clamp(0, 255) 83 | out = out.to(torch.uint8) 84 | out = out.permute(0, 2, 3, 1) # batch, 256, 256, 3 85 | out = out.contiguous() 86 | return out.cpu().numpy() # [0, 255] 87 | 88 | 89 | def build_numpy_data(opt, log): 90 | arr = [] 91 | ckpt_path = opt.ckpt 92 | numpy_arr = np.load(ckpt_path)['arr_0'] 93 | label_path = opt.label 94 | label_arr = np.load(label_path)['arr_0'] 95 | 96 | # converet to numpy 97 | return numpy_arr, label_arr 98 | 99 | def build_ref_opt(opt, ref_fid_fn): 100 | split = ref_fid_fn.name[:-4].split("_")[-1] 101 | image_size = int(ref_fid_fn.name[:-4].split("_")[-2]) 102 | assert opt.image_size == image_size 103 | return edict( 104 | mode=opt.mode, 105 | split=split, 106 | image_size=image_size, 107 | dataset_dir=opt.dataset_dir, 108 | ) 109 | 110 | def get_ref_fid(opt, log): 111 | # get ref fid npz file 112 | 113 | ref_fid_fn = Path("assets/stats/fid_imagenet_256_val.npz") 114 | if not ref_fid_fn.exists(): 115 | log.info(f"Generating {ref_fid_fn=} (this can take a while ...)") 116 | ref_opt = build_ref_opt(opt, ref_fid_fn) 117 | 
fid_util.compute_fid_ref_stat(ref_opt, log) 118 | 119 | # load npz file 120 | ref_fid = np.load(ref_fid_fn) 121 | ref_mu, ref_sigma = ref_fid['mu'], ref_fid['sigma'] 122 | return ref_fid_fn, ref_mu, ref_sigma 123 | 124 | def log_metrices(opt): 125 | # setup 126 | set_seed(opt.seed) 127 | if opt.gpu is not None: 128 | torch.cuda.set_device(opt.gpu) 129 | log = Logger(0, ".log") 130 | 131 | log.info(f"======== Compute metrices: {opt.ckpt=}, {opt.mode=} ========") 132 | 133 | # find all recon pt files 134 | # recon_imgs_pts = find_recon_imgs_pts(opt, log) 135 | # log.info(f"Found {len(recon_imgs_pts)} pt files={[pt.name for pt in recon_imgs_pts]}") 136 | 137 | # build torch array 138 | numpy_arr, numpy_label_arr = build_numpy_data(opt, log) 139 | log.info(f"Collected {numpy_arr.shape=}!") 140 | 141 | # compute accu 142 | accu = compute_accu(opt, numpy_arr, numpy_label_arr) 143 | log.info(f"Accuracy={accu:.3f}!") 144 | 145 | # load ref fid stat 146 | ref_fid_fn, ref_mu, ref_sigma = get_ref_fid(opt, log) 147 | log.info(f"Loaded FID reference statistics from {ref_fid_fn}!") 148 | 149 | # compute fid 150 | fid = fid_util.compute_fid_from_numpy(numpy_arr, ref_mu, ref_sigma, mode=opt.mode) 151 | log.info(f"FID(w.r.t. {ref_fid_fn=})={fid:.2f}!") 152 | 153 | res_dict = { 154 | "accu": accu, 155 | "fid": fid, 156 | } 157 | 158 | # save to file 159 | res_dir = opt.ckpt.split("/")[:-1] 160 | res_dir = Path("/".join(res_dir)) 161 | with open(res_dir / "res.json", "w") as f: 162 | json.dump(res_dict, f) 163 | 164 | if __name__ == '__main__': 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument("--seed", type=int, default=0) 167 | parser.add_argument("--gpu", type=int, default=None, help="set only if you wish to run on a particular device") 168 | parser.add_argument("--ckpt", type=str, default=None, help="the checkpoint name for which we wish to compute metrices") 169 | parser.add_argument("--label", type=str, default=None) 170 | parser.add_argument("--mode", type=str, default="legacy_pytorch", help="the FID computation mode used in clean-fid") 171 | parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to LMDB dataset") 172 | parser.add_argument("--sample-dir", type=Path, default=None, help="directory where samples are stored") 173 | parser.add_argument("--image-size", type=int, default=256) 174 | 175 | arg = parser.parse_args() 176 | 177 | opt = edict( 178 | device="cuda", 179 | ) 180 | opt.update(vars(arg)) 181 | 182 | log_metrices(opt) 183 | -------------------------------------------------------------------------------- /evaluation/fid_util.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | from tqdm import tqdm 10 | import numpy as np 11 | from pathlib import Path 12 | 13 | import torch 14 | import torchvision 15 | from cleanfid.resize import build_resizer 16 | from cleanfid.features import build_feature_extractor 17 | from cleanfid.fid import get_batch_features, frechet_distance 18 | 19 | FID_REF_DIR = Path("assets/stats") 20 | 21 | 22 | def collect_features( 23 | dataset, mode, batch_size, num_workers, device=torch.device("cuda"), use_dataparallel=True, verbose=True 24 | ): 25 | 26 | dataloader = torch.utils.data.DataLoader( 27 | dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers 28 | ) 29 | feat_model = build_feature_extractor(mode, device, use_dataparallel=use_dataparallel) 30 | l_feats = [] 31 | pbar = tqdm(dataloader, desc="FID") if verbose else dataloader 32 | for batch in pbar: 33 | l_feats.append(get_batch_features(batch, feat_model, device)) 34 | 35 | np_feats = np.concatenate(l_feats) 36 | mu = np.mean(np_feats, axis=0) 37 | sigma = np.cov(np_feats, rowvar=False) 38 | return mu, sigma 39 | 40 | 41 | class NumpyResizeDataset(torch.utils.data.Dataset): 42 | def __init__(self, dataset, mode, size=(299, 299)): 43 | self.dataset = dataset 44 | self.transforms = torchvision.transforms.ToTensor() 45 | self.size = size 46 | self.fn_resize = build_resizer(mode) 47 | self.custom_image_tranform = lambda x: x 48 | 49 | def get_img_np(self, i): 50 | return self.dataset[i] 51 | 52 | def __len__(self): 53 | return len(self.dataset) 54 | 55 | def __getitem__(self, i): 56 | img_np = self.get_img_np(i) 57 | 58 | # apply a custom image transform before resizing the image to 299x299 59 | img_np = self.custom_image_tranform(img_np) 60 | # fn_resize expects a np array and returns a np array 61 | img_resized = self.fn_resize(img_np) 62 | 63 | # ToTensor() converts to [0,1] only if input in uint8 64 | if img_resized.dtype == "uint8": 65 | img_t = self.transforms(np.array(img_resized)) * 255 66 | elif img_resized.dtype == "float32": 67 | img_t = self.transforms(img_resized) 68 | 69 | return img_t 70 | 71 | 72 | @torch.no_grad() 73 | def compute_fid_from_numpy(numpy_arr, ref_mu, ref_sigma, batch_size=256, mode="legacy_pytorch"): 74 | 75 | dataset = NumpyResizeDataset(numpy_arr, mode=mode) 76 | mu, sigma = collect_features( 77 | dataset, 78 | mode, 79 | num_workers=1, 80 | batch_size=batch_size, 81 | use_dataparallel=False, 82 | verbose=False, 83 | ) 84 | return frechet_distance(mu, sigma, ref_mu, ref_sigma) 85 | 86 | 87 | class LMDBResizeDataset(NumpyResizeDataset): 88 | def __init__(self, dataset, mode): 89 | super(LMDBResizeDataset, self).__init__(dataset, mode) 90 | 91 | def get_img_np(self, i): 92 | img_pil, _ = self.dataset[i] 93 | return np.array(img_pil) 94 | 95 | 96 | def compute_fid_ref_stat(opt, log): 97 | from datasets.imagenet_inpaint import build_lmdb_dataset 98 | from torchvision import transforms 99 | 100 | mode = opt.mode 101 | 102 | # build dataset 103 | transform = transforms.Compose( 104 | [ 105 | transforms.Resize(opt.image_size), 106 | transforms.CenterCrop(opt.image_size), 107 | ] 108 | ) 109 | lmdb_dataset = build_lmdb_dataset(opt.dataset_dir, opt.image_size, train=opt.split == "train", transform=transform) 110 | dataset = LMDBResizeDataset(lmdb_dataset, mode=mode) 111 | log.info(f"[FID] Built Imagenet {opt.split} dataset, size={len(dataset)}!") 112 | 113 | # compute fid statistics 114 | num_avail_cpus = len(os.sched_getaffinity(0)) 115 | 
num_workers = min(num_avail_cpus, 8) 116 | mu, sigma = collect_features(dataset, mode, batch_size=512, num_workers=num_workers) 117 | log.info(f"Collected inception features, {mu.shape=}, {sigma.shape=}!") 118 | 119 | # save and return statistics 120 | os.makedirs(FID_REF_DIR, exist_ok=True) 121 | fn = FID_REF_DIR / f"fid_imagenet_{opt.image_size}_{opt.split}.npz" 122 | np.savez(fn, mu=mu, sigma=sigma) 123 | log.info(f"Saved FID reference statistics to {fn}!") 124 | return mu, sigma 125 | 126 | 127 | if __name__ == "__main__": 128 | import argparse 129 | from logger import Logger 130 | 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument( 133 | "--split", type=str, choices=["train", "val"], help="which dataset to compute FID ref statistics" 134 | ) 135 | parser.add_argument("--mode", type=str, default="legacy_pytorch", help="the FID computation mode used in clean-fid") 136 | parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to LMDB dataset") 137 | parser.add_argument("--image-size", type=int, default=256) 138 | opt = parser.parse_args() 139 | 140 | log = Logger(0, ".log") 141 | log.info(f"======== Compute FID ref statistics: mode={opt.mode} ========") 142 | compute_fid_ref_stat(opt, log) 143 | -------------------------------------------------------------------------------- /evaluation/resnet.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import torch 9 | 10 | import torch.nn.functional as F 11 | from collections import OrderedDict 12 | from torchvision.models import resnet50 13 | 14 | from ipdb import set_trace as debug 15 | 16 | 17 | class ImageNormalizer(torch.nn.Module): 18 | 19 | def __init__(self, mean, std) -> None: 20 | super(ImageNormalizer, self).__init__() 21 | 22 | self.register_buffer("mean", torch.as_tensor(mean).view(1, 3, 1, 1)) 23 | self.register_buffer("std", torch.as_tensor(std).view(1, 3, 1, 1)) 24 | 25 | def forward(self, image): 26 | # note: image should be in [-1,1] 27 | image = (image + 1) / 2 # [-1,1] -> [0,1] 28 | image = F.interpolate(image, size=(224, 224), mode="bicubic") 29 | return (image - self.mean) / self.std 30 | 31 | def __repr__(self): 32 | return f"ImageNormalizer(mean={self.mean.squeeze()}, std={self.std.squeeze()})" # type: ignore 33 | 34 | 35 | def normalize_model(model, mean, std): 36 | layers = OrderedDict([("normalize", ImageNormalizer(mean, std)), ("model", model)]) 37 | return torch.nn.Sequential(layers) 38 | 39 | 40 | def build_resnet50(): 41 | model = resnet50(pretrained=True) 42 | mu = (0.485, 0.456, 0.406) 43 | sigma = (0.229, 0.224, 0.225) 44 | model = normalize_model(model, mu, sigma) 45 | model.eval() 46 | return model 47 | -------------------------------------------------------------------------------- /evaluations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-ml/DiffusionBridge/92522733cc602686df77f07a1824bb89f89cda1a/evaluations/__init__.py -------------------------------------------------------------------------------- /evaluations/feature_extractor.py: 
-------------------------------------------------------------------------------- 1 | """ from clean fid """ 2 | 3 | import os 4 | import platform 5 | import numpy as np 6 | import torch 7 | import cleanfid 8 | from cleanfid.downloads_helper import check_download_url 9 | from .inception_pytorch import InceptionV3 10 | from .inception_torchscript import InceptionV3W 11 | 12 | 13 | """ 14 | returns a functions that takes an image in range [0,255] 15 | and outputs a feature embedding vector 16 | """ 17 | def feature_extractor(name="torchscript_inception", device=torch.device("cuda"), resize_inside=False, use_dataparallel=True): 18 | if name == "torchscript_inception": 19 | path = "./" if platform.system() == "Windows" else "/tmp" 20 | model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(device) 21 | model.eval() 22 | if use_dataparallel: 23 | model = torch.nn.DataParallel(model) 24 | def model_fn(x): return model(x) 25 | elif name == "pytorch_inception": 26 | model = InceptionV3(output_blocks=[3], resize_input=False).to(device) 27 | model.eval() 28 | if use_dataparallel: 29 | model = torch.nn.DataParallel(model) 30 | def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1) 31 | else: 32 | raise ValueError(f"{name} feature extractor not implemented") 33 | return model_fn 34 | 35 | 36 | """ 37 | Build a feature extractor for each of the modes 38 | """ 39 | def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True): 40 | assert not (mode == 'legacy_pytorch') 41 | if mode == "legacy_pytorch": 42 | feat_model = feature_extractor(name="pytorch_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel) 43 | elif mode == "legacy_tensorflow": 44 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=True, device=device, use_dataparallel=use_dataparallel) 45 | elif mode == "clean": 46 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel) 47 | return feat_model 48 | -------------------------------------------------------------------------------- /evaluations/inception_pytorch.py: -------------------------------------------------------------------------------- 1 | """ 2 | File from: https://github.com/mseitzer/pytorch-fid 3 | """ 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | 11 | try: 12 | from torchvision.models.utils import load_state_dict_from_url 13 | except ImportError: 14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 15 | 16 | # Inception weights ported to Pytorch from 17 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 18 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 19 | 20 | 21 | class InceptionV3(nn.Module): 22 | """Pretrained InceptionV3 network returning feature maps""" 23 | 24 | # Index of default block of inception to return, 25 | # corresponds to output of final average pooling 26 | DEFAULT_BLOCK_INDEX = 3 27 | 28 | # Maps feature dimensionality to their output blocks indices 29 | BLOCK_INDEX_BY_DIM = { 30 | 64: 0, # First max pooling features 31 | 192: 1, # Second max pooling featurs 32 | 768: 2, # Pre-aux classifier features 33 | 2048: 3 # Final average pooling features 34 | } 35 | 36 | def __init__(self, 37 | output_blocks=(DEFAULT_BLOCK_INDEX,), 38 | resize_input=True, 39 | 
normalize_input=True, 40 | requires_grad=False, 41 | use_fid_inception=True): 42 | """Build pretrained InceptionV3 43 | Parameters 44 | ---------- 45 | output_blocks : list of int 46 | Indices of blocks to return features of. Possible values are: 47 | - 0: corresponds to output of first max pooling 48 | - 1: corresponds to output of second max pooling 49 | - 2: corresponds to output which is fed to aux classifier 50 | - 3: corresponds to output of final average pooling 51 | resize_input : bool 52 | If true, bilinearly resizes input to width and height 299 before 53 | feeding input to model. As the network without fully connected 54 | layers is fully convolutional, it should be able to handle inputs 55 | of arbitrary size, so resizing might not be strictly needed 56 | normalize_input : bool 57 | If true, scales the input from range (0, 1) to the range the 58 | pretrained Inception network expects, namely (-1, 1) 59 | requires_grad : bool 60 | If true, parameters of the model require gradients. Possibly useful 61 | for finetuning the network 62 | use_fid_inception : bool 63 | If true, uses the pretrained Inception model used in Tensorflow's 64 | FID implementation. If false, uses the pretrained Inception model 65 | available in torchvision. The FID Inception model has different 66 | weights and a slightly different structure from torchvision's 67 | Inception model. If you want to compute FID scores, you are 68 | strongly advised to set this parameter to true to get comparable 69 | results. 70 | """ 71 | super(InceptionV3, self).__init__() 72 | 73 | self.resize_input = resize_input 74 | self.normalize_input = normalize_input 75 | self.output_blocks = sorted(output_blocks) 76 | self.last_needed_block = max(output_blocks) 77 | 78 | assert self.last_needed_block <= 3, \ 79 | 'Last possible output block index is 3' 80 | 81 | self.blocks = nn.ModuleList() 82 | 83 | if use_fid_inception: 84 | inception = fid_inception_v3() 85 | else: 86 | inception = _inception_v3(pretrained=True) 87 | 88 | # Block 0: input to maxpool1 89 | block0 = [ 90 | inception.Conv2d_1a_3x3, 91 | inception.Conv2d_2a_3x3, 92 | inception.Conv2d_2b_3x3, 93 | nn.MaxPool2d(kernel_size=3, stride=2) 94 | ] 95 | self.blocks.append(nn.Sequential(*block0)) 96 | 97 | # Block 1: maxpool1 to maxpool2 98 | if self.last_needed_block >= 1: 99 | block1 = [ 100 | inception.Conv2d_3b_1x1, 101 | inception.Conv2d_4a_3x3, 102 | nn.MaxPool2d(kernel_size=3, stride=2) 103 | ] 104 | self.blocks.append(nn.Sequential(*block1)) 105 | 106 | # Block 2: maxpool2 to aux classifier 107 | if self.last_needed_block >= 2: 108 | block2 = [ 109 | inception.Mixed_5b, 110 | inception.Mixed_5c, 111 | inception.Mixed_5d, 112 | inception.Mixed_6a, 113 | inception.Mixed_6b, 114 | inception.Mixed_6c, 115 | inception.Mixed_6d, 116 | inception.Mixed_6e, 117 | ] 118 | self.blocks.append(nn.Sequential(*block2)) 119 | 120 | # Block 3: aux classifier to final avgpool 121 | if self.last_needed_block >= 3: 122 | block3 = [ 123 | inception.Mixed_7a, 124 | inception.Mixed_7b, 125 | inception.Mixed_7c, 126 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 127 | ] 128 | self.blocks.append(nn.Sequential(*block3)) 129 | 130 | for param in self.parameters(): 131 | param.requires_grad = requires_grad 132 | 133 | def forward(self, inp): 134 | """Get Inception feature maps 135 | Parameters 136 | ---------- 137 | inp : torch.autograd.Variable 138 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 139 | range (0, 1) 140 | Returns 141 | ------- 142 | List of torch.autograd.Variable, corresponding to the selected output 143 | block, sorted ascending by index 144 | """ 145 | outp = [] 146 | x = inp 147 | 148 | if self.resize_input: 149 | raise ValueError("should not resize here") 150 | x = F.interpolate(x, 151 | size=(299, 299), 152 | mode='bilinear', 153 | align_corners=False) 154 | 155 | if self.normalize_input: 156 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 157 | 158 | for idx, block in enumerate(self.blocks): 159 | x = block(x) 160 | if idx in self.output_blocks: 161 | outp.append(x) 162 | 163 | if idx == self.last_needed_block: 164 | break 165 | 166 | return outp 167 | 168 | 169 | def _inception_v3(*args, **kwargs): 170 | """Wraps `torchvision.models.inception_v3` 171 | Skips default weight inititialization if supported by torchvision version. 172 | See https://github.com/mseitzer/pytorch-fid/issues/28. 173 | """ 174 | try: 175 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 176 | except ValueError: 177 | # Just a caution against weird version strings 178 | version = (0,) 179 | 180 | if version >= (0, 6): 181 | kwargs['init_weights'] = False 182 | 183 | return torchvision.models.inception_v3(*args, **kwargs) 184 | 185 | 186 | def fid_inception_v3(): 187 | """Build pretrained Inception model for FID computation 188 | The Inception model for FID computation uses a different set of weights 189 | and has a slightly different structure than torchvision's Inception. 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model. 192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zero's in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | 
"""InceptionC block patched for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zero's in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zero's in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 
324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /evaluations/inception_torchscript.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from cleanfid.downloads_helper import * 5 | import contextlib 6 | 7 | 8 | @contextlib.contextmanager 9 | def disable_gpu_fuser_on_pt19(): 10 | # On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model to run. See 11 | # https://github.com/GaParmar/clean-fid/issues/5 12 | # https://github.com/pytorch/pytorch/issues/64062 13 | if torch.__version__.startswith('1.9.'): 14 | old_val = torch._C._jit_can_fuse_on_gpu() 15 | torch._C._jit_override_can_fuse_on_gpu(False) 16 | yield 17 | if torch.__version__.startswith('1.9.'): 18 | torch._C._jit_override_can_fuse_on_gpu(old_val) 19 | 20 | 21 | class InceptionV3W(nn.Module): 22 | """ 23 | Wrapper around Inception V3 torchscript model provided here 24 | https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt 25 | 26 | path: locally saved inception weights 27 | """ 28 | def __init__(self, path, download=True, resize_inside=False): 29 | super(InceptionV3W, self).__init__() 30 | # download the network if it is not present at the given directory 31 | # use the current directory by default 32 | if download: 33 | check_download_inception(fpath=path) 34 | path = os.path.join(path, "inception-2015-12-05.pt") 35 | self.base = torch.jit.load(path).eval() 36 | self.layers = self.base.layers 37 | self.output = self.base.output 38 | self.resize_inside = resize_inside 39 | 40 | """ 41 | Get the inception features without resizing 42 | x: Image with values in range [0,255] 43 | """ 44 | def forward(self, x): 45 | with disable_gpu_fuser_on_pt19(): 46 | bs = x.shape[0] 47 | if self.resize_inside: 48 | features = self.base(x, return_features=True).view((bs, 2048)) 49 | else: 50 | # make sure it is resized already 51 | assert (x.shape[2] == 299) and (x.shape[3] == 299) 52 | # apply normalization 53 | x1 = x - 128 54 | x2 = x1 / 128 55 | features = self.layers.forward(x2, ).view((bs, 2048)) 56 | return features, self.output(features) -------------------------------------------------------------------------------- /evaluations/inception_v3.py: -------------------------------------------------------------------------------- 1 | # Ported from the model here: 2 | # https://github.com/NVlabs/stylegan3/blob/407db86e6fe432540a22515310188288687858fa/metrics/frechet_inception_distance.py#L22 3 | # 4 | # I have verified that the spatial features and output features are correct 5 | # within a mean absolute error of ~3e-5. 
6 | 7 | import collections 8 | 9 | import torch 10 | 11 | 12 | class Conv2dLayer(torch.nn.Module): 13 | def __init__(self, in_channels, out_channels, kh, kw, stride=1, padding=0): 14 | super().__init__() 15 | self.stride = stride 16 | self.padding = padding 17 | self.weight = torch.nn.Parameter(torch.zeros(out_channels, in_channels, kh, kw)) 18 | self.beta = torch.nn.Parameter(torch.zeros(out_channels)) 19 | self.mean = torch.nn.Parameter(torch.zeros(out_channels)) 20 | self.var = torch.nn.Parameter(torch.zeros(out_channels)) 21 | 22 | def forward(self, x): 23 | x = torch.nn.functional.conv2d( 24 | x, self.weight.to(x.dtype), stride=self.stride, padding=self.padding 25 | ) 26 | x = torch.nn.functional.batch_norm( 27 | x, running_mean=self.mean, running_var=self.var, bias=self.beta, eps=1e-3 28 | ) 29 | x = torch.nn.functional.relu(x) 30 | return x 31 | 32 | 33 | # ---------------------------------------------------------------------------- 34 | 35 | 36 | class InceptionA(torch.nn.Module): 37 | def __init__(self, in_channels, tmp_channels): 38 | super().__init__() 39 | self.conv = Conv2dLayer(in_channels, 64, kh=1, kw=1) 40 | self.tower = torch.nn.Sequential( 41 | collections.OrderedDict( 42 | [ 43 | ("conv", Conv2dLayer(in_channels, 48, kh=1, kw=1)), 44 | ("conv_1", Conv2dLayer(48, 64, kh=5, kw=5, padding=2)), 45 | ] 46 | ) 47 | ) 48 | self.tower_1 = torch.nn.Sequential( 49 | collections.OrderedDict( 50 | [ 51 | ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)), 52 | ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)), 53 | ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, padding=1)), 54 | ] 55 | ) 56 | ) 57 | self.tower_2 = torch.nn.Sequential( 58 | collections.OrderedDict( 59 | [ 60 | ( 61 | "pool", 62 | torch.nn.AvgPool2d( 63 | kernel_size=3, stride=1, padding=1, count_include_pad=False 64 | ), 65 | ), 66 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)), 67 | ] 68 | ) 69 | ) 70 | 71 | def forward(self, x): 72 | return torch.cat( 73 | [ 74 | self.conv(x).contiguous(), 75 | self.tower(x).contiguous(), 76 | self.tower_1(x).contiguous(), 77 | self.tower_2(x).contiguous(), 78 | ], 79 | dim=1, 80 | ) 81 | 82 | 83 | # ---------------------------------------------------------------------------- 84 | 85 | 86 | class InceptionB(torch.nn.Module): 87 | def __init__(self, in_channels): 88 | super().__init__() 89 | self.conv = Conv2dLayer(in_channels, 384, kh=3, kw=3, stride=2) 90 | self.tower = torch.nn.Sequential( 91 | collections.OrderedDict( 92 | [ 93 | ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)), 94 | ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)), 95 | ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, stride=2)), 96 | ] 97 | ) 98 | ) 99 | self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2) 100 | 101 | def forward(self, x): 102 | return torch.cat( 103 | [ 104 | self.conv(x).contiguous(), 105 | self.tower(x).contiguous(), 106 | self.pool(x).contiguous(), 107 | ], 108 | dim=1, 109 | ) 110 | 111 | 112 | # ---------------------------------------------------------------------------- 113 | 114 | 115 | class InceptionC(torch.nn.Module): 116 | def __init__(self, in_channels, tmp_channels): 117 | super().__init__() 118 | self.conv = Conv2dLayer(in_channels, 192, kh=1, kw=1) 119 | self.tower = torch.nn.Sequential( 120 | collections.OrderedDict( 121 | [ 122 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)), 123 | ( 124 | "conv_1", 125 | Conv2dLayer( 126 | tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3] 127 | ), 128 | ), 129 | ( 130 | "conv_2", 131 | 
Conv2dLayer(tmp_channels, 192, kh=7, kw=1, padding=[3, 0]), 132 | ), 133 | ] 134 | ) 135 | ) 136 | self.tower_1 = torch.nn.Sequential( 137 | collections.OrderedDict( 138 | [ 139 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)), 140 | ( 141 | "conv_1", 142 | Conv2dLayer( 143 | tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0] 144 | ), 145 | ), 146 | ( 147 | "conv_2", 148 | Conv2dLayer( 149 | tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3] 150 | ), 151 | ), 152 | ( 153 | "conv_3", 154 | Conv2dLayer( 155 | tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0] 156 | ), 157 | ), 158 | ( 159 | "conv_4", 160 | Conv2dLayer(tmp_channels, 192, kh=1, kw=7, padding=[0, 3]), 161 | ), 162 | ] 163 | ) 164 | ) 165 | self.tower_2 = torch.nn.Sequential( 166 | collections.OrderedDict( 167 | [ 168 | ( 169 | "pool", 170 | torch.nn.AvgPool2d( 171 | kernel_size=3, stride=1, padding=1, count_include_pad=False 172 | ), 173 | ), 174 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)), 175 | ] 176 | ) 177 | ) 178 | 179 | def forward(self, x): 180 | return torch.cat( 181 | [ 182 | self.conv(x).contiguous(), 183 | self.tower(x).contiguous(), 184 | self.tower_1(x).contiguous(), 185 | self.tower_2(x).contiguous(), 186 | ], 187 | dim=1, 188 | ) 189 | 190 | 191 | # ---------------------------------------------------------------------------- 192 | 193 | 194 | class InceptionD(torch.nn.Module): 195 | def __init__(self, in_channels): 196 | super().__init__() 197 | self.tower = torch.nn.Sequential( 198 | collections.OrderedDict( 199 | [ 200 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)), 201 | ("conv_1", Conv2dLayer(192, 320, kh=3, kw=3, stride=2)), 202 | ] 203 | ) 204 | ) 205 | self.tower_1 = torch.nn.Sequential( 206 | collections.OrderedDict( 207 | [ 208 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)), 209 | ("conv_1", Conv2dLayer(192, 192, kh=1, kw=7, padding=[0, 3])), 210 | ("conv_2", Conv2dLayer(192, 192, kh=7, kw=1, padding=[3, 0])), 211 | ("conv_3", Conv2dLayer(192, 192, kh=3, kw=3, stride=2)), 212 | ] 213 | ) 214 | ) 215 | self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2) 216 | 217 | def forward(self, x): 218 | return torch.cat( 219 | [ 220 | self.tower(x).contiguous(), 221 | self.tower_1(x).contiguous(), 222 | self.pool(x).contiguous(), 223 | ], 224 | dim=1, 225 | ) 226 | 227 | 228 | # ---------------------------------------------------------------------------- 229 | 230 | 231 | class InceptionE(torch.nn.Module): 232 | def __init__(self, in_channels, use_avg_pool): 233 | super().__init__() 234 | self.conv = Conv2dLayer(in_channels, 320, kh=1, kw=1) 235 | self.tower_conv = Conv2dLayer(in_channels, 384, kh=1, kw=1) 236 | self.tower_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1]) 237 | self.tower_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0]) 238 | self.tower_1_conv = Conv2dLayer(in_channels, 448, kh=1, kw=1) 239 | self.tower_1_conv_1 = Conv2dLayer(448, 384, kh=3, kw=3, padding=1) 240 | self.tower_1_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1]) 241 | self.tower_1_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0]) 242 | if use_avg_pool: 243 | self.tower_2_pool = torch.nn.AvgPool2d( 244 | kernel_size=3, stride=1, padding=1, count_include_pad=False 245 | ) 246 | else: 247 | self.tower_2_pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 248 | self.tower_2_conv = Conv2dLayer(in_channels, 192, kh=1, kw=1) 249 | 250 | def forward(self, x): 251 | a = self.tower_conv(x) 252 | b = 
self.tower_1_conv_1(self.tower_1_conv(x)) 253 | return torch.cat( 254 | [ 255 | self.conv(x).contiguous(), 256 | self.tower_mixed_conv(a).contiguous(), 257 | self.tower_mixed_conv_1(a).contiguous(), 258 | self.tower_1_mixed_conv(b).contiguous(), 259 | self.tower_1_mixed_conv_1(b).contiguous(), 260 | self.tower_2_conv(self.tower_2_pool(x)).contiguous(), 261 | ], 262 | dim=1, 263 | ) 264 | 265 | 266 | # ---------------------------------------------------------------------------- 267 | 268 | 269 | class InceptionV3(torch.nn.Module): 270 | def __init__(self): 271 | super().__init__() 272 | self.layers = torch.nn.Sequential( 273 | collections.OrderedDict( 274 | [ 275 | ("conv", Conv2dLayer(3, 32, kh=3, kw=3, stride=2)), 276 | ("conv_1", Conv2dLayer(32, 32, kh=3, kw=3)), 277 | ("conv_2", Conv2dLayer(32, 64, kh=3, kw=3, padding=1)), 278 | ("pool0", torch.nn.MaxPool2d(kernel_size=3, stride=2)), 279 | ("conv_3", Conv2dLayer(64, 80, kh=1, kw=1)), 280 | ("conv_4", Conv2dLayer(80, 192, kh=3, kw=3)), 281 | ("pool1", torch.nn.MaxPool2d(kernel_size=3, stride=2)), 282 | ("mixed", InceptionA(192, tmp_channels=32)), 283 | ("mixed_1", InceptionA(256, tmp_channels=64)), 284 | ("mixed_2", InceptionA(288, tmp_channels=64)), 285 | ("mixed_3", InceptionB(288)), 286 | ("mixed_4", InceptionC(768, tmp_channels=128)), 287 | ("mixed_5", InceptionC(768, tmp_channels=160)), 288 | ("mixed_6", InceptionC(768, tmp_channels=160)), 289 | ("mixed_7", InceptionC(768, tmp_channels=192)), 290 | ("mixed_8", InceptionD(768)), 291 | ("mixed_9", InceptionE(1280, use_avg_pool=True)), 292 | ("mixed_10", InceptionE(2048, use_avg_pool=False)), 293 | ("pool2", torch.nn.AvgPool2d(kernel_size=8)), 294 | ] 295 | ) 296 | ) 297 | self.output = torch.nn.Linear(2048, 1008) 298 | 299 | def forward( 300 | self, 301 | img, 302 | return_features: bool = True, 303 | use_fp16: bool = False, 304 | no_output_bias: bool = False, 305 | ): 306 | batch_size, channels, height, width = img.shape # [NCHW] 307 | assert channels == 3 308 | 309 | # Cast to float. 310 | x = img.to(torch.float16 if use_fp16 else torch.float32) 311 | 312 | # Emulate tf.image.resize_bilinear(x, [299, 299]), including the funky alignment. 313 | new_width, new_height = 299, 299 314 | theta = torch.eye(2, 3, device=x.device) 315 | theta[0, 2] += theta[0, 0] / width - theta[0, 0] / new_width 316 | theta[1, 2] += theta[1, 1] / height - theta[1, 1] / new_height 317 | theta = theta.to(x.dtype).unsqueeze(0).repeat([batch_size, 1, 1]) 318 | grid = torch.nn.functional.affine_grid( 319 | theta, [batch_size, channels, new_height, new_width], align_corners=False 320 | ) 321 | x = torch.nn.functional.grid_sample( 322 | x, grid, mode="bilinear", padding_mode="border", align_corners=False 323 | ) 324 | 325 | # Scale dynamic range from [0,255] to [-1,1[. 326 | x -= 128 327 | x /= 128 328 | 329 | # Main layers. 330 | intermediate = self.layers[:-6](x) 331 | spatial_features = ( 332 | self.layers[-6] 333 | .conv(intermediate)[:, :7] 334 | .permute(0, 2, 3, 1) 335 | .reshape(-1, 2023) 336 | ) 337 | features = self.layers[-6:](intermediate).reshape(-1, 2048).to(torch.float32) 338 | if return_features: 339 | return features, spatial_features 340 | 341 | # Output layer. 
342 | return self.acts_to_probs(features, no_output_bias=no_output_bias) 343 | 344 | def acts_to_probs(self, features, no_output_bias: bool = False): 345 | if no_output_bias: 346 | logits = torch.nn.functional.linear(features, self.output.weight) 347 | else: 348 | logits = self.output(features) 349 | probs = torch.nn.functional.softmax(logits, dim=1) 350 | return probs 351 | 352 | def create_softmax_model(self): 353 | return SoftmaxModel(self.output.weight) 354 | 355 | 356 | class SoftmaxModel(torch.nn.Module): 357 | def __init__(self, weight: torch.Tensor): 358 | super().__init__() 359 | self.weight = torch.nn.Parameter(weight.detach().clone()) 360 | 361 | def forward(self, x): 362 | logits = torch.nn.functional.linear(x, self.weight) 363 | probs = torch.nn.functional.softmax(logits, dim=1) 364 | return probs 365 | -------------------------------------------------------------------------------- /evaluations/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu>=2.0 2 | scipy 3 | requests 4 | tqdm -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import time 10 | import logging 11 | from rich.console import Console 12 | from rich.logging import RichHandler 13 | 14 | 15 | def get_time(sec): 16 | h = int(sec // 3600) 17 | m = int((sec // 60) % 60) 18 | s = int(sec % 60) 19 | return h, m, s 20 | 21 | 22 | class TimeFilter(logging.Filter): 23 | 24 | def filter(self, record): 25 | try: 26 | start = self.start 27 | except AttributeError: 28 | start = self.start = time.time() 29 | 30 | time_elapsed = get_time(time.time() - start) 31 | 32 | record.relative = "{0}:{1:02d}:{2:02d}".format(*time_elapsed) 33 | 34 | # self.last = record.relativeCreated/1000.0 35 | return True 36 | 37 | 38 | class Logger(object): 39 | def __init__(self, rank=0, log_dir=".log"): 40 | # other libraries may set logging before arriving at this line. 41 | # by reloading logging, we can get rid of previous configs set by other libraries. 
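        # (reload(logging) resets the logging module's root logger, and basicConfig(force=True)
        # below drops any handlers other libraries already attached, so on rank 0 only the two
        # RichHandlers configured here remain: one for the console and one for log.txt.)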
42 | from importlib import reload 43 | 44 | reload(logging) 45 | self.rank = rank 46 | if self.rank == 0: 47 | os.makedirs(log_dir, exist_ok=True) 48 | 49 | log_file = open(os.path.join(log_dir, "log.txt"), "w") 50 | file_console = Console(file=log_file, width=150) 51 | logging.basicConfig( 52 | level=logging.INFO, 53 | format="(%(relative)s) %(message)s", 54 | datefmt="[%X]", 55 | force=True, 56 | handlers=[RichHandler(show_path=False), RichHandler(console=file_console, show_path=False)], 57 | ) 58 | # https://stackoverflow.com/questions/31521859/python-logging-module-time-since-last-log 59 | log = logging.getLogger() 60 | [hndl.addFilter(TimeFilter()) for hndl in log.handlers] 61 | 62 | def info(self, string, *args): 63 | if self.rank == 0: 64 | logging.info(string, *args) 65 | 66 | def warning(self, string, *args): 67 | if self.rank == 0: 68 | logging.warning(string, *args) 69 | 70 | def error(self, string, *args): 71 | if self.rank == 0: 72 | logging.error(string, *args) 73 | -------------------------------------------------------------------------------- /preprocess_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # process DDBM checkpoints 4 | 5 | state_dict = torch.load("assets/ckpts/e2h_ema_0.9999_420000.pt", map_location="cpu") 6 | 7 | module_list = [] 8 | 9 | for i in range(5, 16): 10 | if i == 8 or i == 12: 11 | continue 12 | module_list.append(f"input_blocks.{i}.1.qkv.weight") 13 | module_list.append(f"input_blocks.{i}.1.proj_out.weight") 14 | 15 | module_list.append("middle_block.1.qkv.weight") 16 | module_list.append("middle_block.1.proj_out.weight") 17 | 18 | for i in range(0, 12): 19 | module_list.append(f"output_blocks.{i}.1.qkv.weight") 20 | module_list.append(f"output_blocks.{i}.1.proj_out.weight") 21 | 22 | for name in module_list: 23 | state_dict[name] = state_dict[name].squeeze(-1) 24 | 25 | torch.save(state_dict, "assets/ckpts/e2h_ema_0.9999_420000_adapted.pt") 26 | 27 | state_dict = torch.load("assets/ckpts/diode_ema_0.9999_440000.pt", map_location="cpu") 28 | 29 | module_list = [] 30 | 31 | for i in range(10, 18): 32 | if i == 12 or i == 15: 33 | continue 34 | module_list.append(f"input_blocks.{i}.1.qkv.weight") 35 | module_list.append(f"input_blocks.{i}.1.proj_out.weight") 36 | 37 | module_list.append("middle_block.1.qkv.weight") 38 | module_list.append("middle_block.1.proj_out.weight") 39 | 40 | for i in range(0, 9): 41 | module_list.append(f"output_blocks.{i}.1.qkv.weight") 42 | module_list.append(f"output_blocks.{i}.1.proj_out.weight") 43 | 44 | for name in module_list: 45 | state_dict[name] = state_dict[name].squeeze(-1) 46 | 47 | torch.save(state_dict, "assets/ckpts/diode_ema_0.9999_440000_adapted.pt") 48 | -------------------------------------------------------------------------------- /preprocess_depth.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | from matplotlib import pyplot as plt 4 | import numpy as np 5 | import cv2 6 | 7 | 8 | from joblib import Parallel, delayed 9 | from tqdm import tqdm 10 | import csv 11 | 12 | SPLIT = "train" 13 | img_size = 256 14 | data_dir = "assets/datasets/DIODE/data_list/train_outdoor.csv" 15 | 16 | 17 | def plot_depth_map(dm, validity_mask, name): 18 | validity_mask = validity_mask > 0 19 | MIN_DEPTH = 0.5 20 | MAX_DEPTH = min(300, np.percentile(dm, 99)) 21 | dm = np.clip(dm, MIN_DEPTH, MAX_DEPTH) 22 | dm = np.log(dm, where=validity_mask) 23 | 24 | dm = 
np.ma.masked_where(~validity_mask, dm) 25 | 26 | cmap = plt.cm.get_cmap("jet") 27 | cmap.set_bad(color="black") 28 | norm = plt.Normalize(vmin=0, vmax=np.log(MAX_DEPTH + 1.01)) 29 | image = cmap(norm(dm)) 30 | plt.imsave(name, np.clip(image, 0.0, 1.0)) 31 | 32 | 33 | def plot_normal_map(normal_map, name): 34 | normal_viz = normal_map[:, ::, :] 35 | 36 | normal_viz = normal_viz + np.equal(np.sum(normal_viz, 2, keepdims=True), 0.0).astype(np.float32) * np.min( 37 | normal_viz 38 | ) 39 | 40 | normal_viz = (normal_viz - np.min(normal_viz)) / 2.0 41 | plt.axis("off") 42 | plt.imsave(name, np.clip(normal_viz, 0.0, 1.0)) 43 | 44 | 45 | all_files = [] 46 | with open(data_dir, newline="") as csvfile: 47 | reader = csv.reader(csvfile, delimiter=",") 48 | 49 | for row in reader: 50 | if row[-1] == "Unavailable": 51 | continue 52 | all_files.append(row[0].split("/")[-1]) 53 | 54 | 55 | def process(file): 56 | scene_id, scan_id = file.split("_")[0], file.split("_")[1] 57 | base_path = f"assets/datasets/DIODE/{SPLIT}/outdoor/scene_{scene_id}/scan_{scan_id}" 58 | path = os.path.join(base_path, file) 59 | pil_image = Image.open(path).convert("RGB").resize((img_size, img_size), Image.BICUBIC) 60 | 61 | path2 = os.path.join(base_path, file[:-4] + "_depth.npy") 62 | depth = np.load(path2).squeeze() 63 | depth = depth.astype(np.float32) 64 | 65 | path3 = os.path.join(base_path, file[:-4] + "_depth_mask.npy") 66 | depth_mask = np.load(path3) 67 | depth_mask = depth_mask.astype(np.float32) 68 | 69 | path4 = os.path.join(base_path, file[:-4] + "_normal.npy") 70 | normal = np.load(path4) 71 | normal = normal.astype(np.float32) 72 | 73 | image_depth = cv2.resize(depth, dsize=(img_size, img_size), interpolation=cv2.INTER_NEAREST) 74 | image_depth_mask = cv2.resize(depth_mask, dsize=(img_size, img_size), interpolation=cv2.INTER_NEAREST) 75 | 76 | normal = cv2.resize(normal, dsize=(img_size, img_size), interpolation=cv2.INTER_NEAREST) 77 | 78 | name = os.path.join(target_dir, file) 79 | pil_image.save(name) 80 | plot_depth_map(image_depth, image_depth_mask, name[:-4] + "_depth.png") 81 | plot_normal_map(normal, name[:-4] + "_normal.png") 82 | 83 | 84 | target_dir = f"assets/datasets/DIODE-{img_size}/{SPLIT}" 85 | os.makedirs(target_dir, exist_ok=True) 86 | Parallel(n_jobs=1)(delayed(process)(name) for name in tqdm(all_files)) 87 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch 11 | import torchvision.utils as vutils 12 | import torch.distributed as dist 13 | 14 | from ddbm import dist_util, logger 15 | from ddbm.script_util import ( 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | add_dict_to_argparser, 19 | args_to_dict, 20 | ) 21 | from ddbm.karras_diffusion import karras_sample 22 | 23 | from datasets import load_data 24 | 25 | from pathlib import Path 26 | 27 | 28 | def main(): 29 | args = create_argparser().parse_args() 30 | args.use_fp16 = False 31 | 32 | workdir = os.path.join("workdir", os.path.basename(args.model_path)[:-3]) 33 | 34 | ## assume ema ckpt format: ema_{rate}_{steps}.pt 35 | split = args.model_path.replace("_adapted", "").split("_") 36 | step = int(split[-1].split(".")[0]) 37 | if args.sampler == "dbim": 38 | sample_dir = Path(workdir) / f"sample_{step}/split={args.split}/dbim_eta={args.eta}/steps={args.steps}" 39 | elif args.sampler == "dbim_high_order": 40 | sample_dir = Path(workdir) / f"sample_{step}/split={args.split}/dbim_order={args.order}/steps={args.steps}" 41 | else: 42 | sample_dir = Path(workdir) / f"sample_{step}/split={args.split}/{args.sampler}/steps={args.steps}" 43 | dist_util.setup_dist() 44 | if dist.get_rank() == 0: 45 | 46 | sample_dir.mkdir(parents=True, exist_ok=True) 47 | logger.configure(dir=str(sample_dir)) 48 | 49 | logger.log("creating model and diffusion...") 50 | model, diffusion = create_model_and_diffusion( 51 | **args_to_dict(args, model_and_diffusion_defaults().keys()), 52 | ) 53 | model.load_state_dict(torch.load(args.model_path, map_location="cpu")) 54 | model = model.to(dist_util.dev()) 55 | 56 | if args.use_fp16: 57 | model = model.half() 58 | model.eval() 59 | 60 | logger.log("sampling...") 61 | 62 | all_images = [] 63 | all_labels = [] 64 | 65 | all_dataloaders = load_data( 66 | data_dir=args.data_dir, 67 | dataset=args.dataset, 68 | batch_size=args.batch_size, 69 | image_size=args.image_size, 70 | include_test=(args.split == "test"), 71 | seed=args.seed, 72 | num_workers=args.num_workers, 73 | ) 74 | if args.split == "train": 75 | dataloader = all_dataloaders[1] 76 | elif args.split == "test": 77 | dataloader = all_dataloaders[2] 78 | else: 79 | raise NotImplementedError 80 | args.num_samples = len(dataloader.dataset) 81 | num = 0 82 | for i, data in enumerate(dataloader): 83 | 84 | x0_image = data[0] 85 | x0 = x0_image.to(dist_util.dev()) 86 | 87 | y0_image = data[1].to(dist_util.dev()) 88 | y0 = y0_image 89 | 90 | model_kwargs = {"xT": y0} 91 | 92 | if "inpaint" in args.dataset: 93 | _, mask, label = data[2] 94 | mask = mask.to(dist_util.dev()) 95 | label = label.to(dist_util.dev()) 96 | model_kwargs["y"] = label 97 | else: 98 | mask = None 99 | 100 | indexes = data[2][0].numpy() 101 | sample, path, nfe, pred_x0, sigmas, _ = karras_sample( 102 | diffusion, 103 | model, 104 | y0, 105 | x0, 106 | steps=args.steps, 107 | mask=mask, 108 | model_kwargs=model_kwargs, 109 | device=dist_util.dev(), 110 | clip_denoised=args.clip_denoised, 111 | sampler=args.sampler, 112 | churn_step_ratio=args.churn_step_ratio, 113 | eta=args.eta, 114 | order=args.order, 115 | seed=indexes + args.seed, 116 | ) 117 | 118 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 119 | sample = sample.permute(0, 2, 3, 1) 120 | sample = sample.contiguous() 121 | 122 | gathered_samples = [torch.zeros_like(sample) for _ in range(dist.get_world_size())] 123 | dist.all_gather(gathered_samples, sample) # gather 
not supported with NCCL 124 | gathered_samples = torch.cat(gathered_samples) 125 | if "inpaint" in args.dataset: 126 | gathered_labels = [torch.zeros_like(label) for _ in range(dist.get_world_size())] 127 | dist.all_gather(gathered_labels, label) 128 | gathered_labels = torch.cat(gathered_labels) 129 | num += gathered_samples.shape[0] 130 | 131 | num_display = min(32, sample.shape[0]) 132 | if i == 0 and dist.get_rank() == 0: 133 | vutils.save_image( 134 | sample.permute(0, 3, 1, 2)[:num_display].float() / 255, 135 | f"{sample_dir}/sample_{i}.png", 136 | nrow=int(np.sqrt(num_display)), 137 | ) 138 | if x0 is not None: 139 | vutils.save_image( 140 | x0_image[:num_display] / 2 + 0.5, 141 | f"{sample_dir}/x_{i}.png", 142 | nrow=int(np.sqrt(num_display)), 143 | ) 144 | vutils.save_image( 145 | y0_image[:num_display] / 2 + 0.5, 146 | f"{sample_dir}/y_{i}.png", 147 | nrow=int(np.sqrt(num_display)), 148 | ) 149 | 150 | all_images.append(gathered_samples.detach().cpu().numpy()) 151 | if "inpaint" in args.dataset: 152 | all_labels.append(gathered_labels.detach().cpu().numpy()) 153 | 154 | if dist.get_rank() == 0: 155 | logger.log(f"sampled {num} images") 156 | 157 | logger.log(f"created {len(all_images) * args.batch_size * dist.get_world_size()} samples") 158 | 159 | arr = np.concatenate(all_images, axis=0) 160 | arr = arr[: args.num_samples] 161 | if "inpaint" in args.dataset: 162 | labels = np.concatenate(all_labels, axis=0) 163 | labels = labels[: args.num_samples] 164 | 165 | if dist.get_rank() == 0: 166 | shape_str = "x".join([str(x) for x in arr.shape]) 167 | out_path = os.path.join(sample_dir, f"samples_{shape_str}_nfe{nfe}.npz") 168 | logger.log(f"saving to {out_path}") 169 | np.savez(out_path, arr) 170 | if "inpaint" in args.dataset: 171 | shape_str = "x".join([str(x) for x in labels.shape]) 172 | out_path = os.path.join(sample_dir, f"labels_{shape_str}_nfe{nfe}.npz") 173 | logger.log(f"saving to {out_path}") 174 | np.savez(out_path, labels) 175 | 176 | dist.barrier() 177 | logger.log("sampling complete") 178 | 179 | 180 | def create_argparser(): 181 | defaults = dict( 182 | data_dir="", ## only used in bridge 183 | dataset="edges2handbags", 184 | clip_denoised=True, 185 | num_samples=10000, 186 | batch_size=16, 187 | sampler="heun", 188 | split="train", 189 | churn_step_ratio=0.0, 190 | rho=7.0, 191 | steps=40, 192 | model_path="", 193 | exp="", 194 | seed=42, 195 | num_workers=8, 196 | eta=1.0, 197 | order=1, 198 | ) 199 | defaults.update(model_and_diffusion_defaults()) 200 | parser = argparse.ArgumentParser() 201 | add_dict_to_argparser(parser, defaults) 202 | return parser 203 | 204 | 205 | if __name__ == "__main__": 206 | main() 207 | -------------------------------------------------------------------------------- /scripts/args.sh: -------------------------------------------------------------------------------- 1 | DATASET_NAME=$1 2 | 3 | if [[ $DATASET_NAME == "e2h" ]]; then 4 | DATA_DIR=assets/datasets/edges2handbags 5 | DATASET=edges2handbags 6 | IMG_SIZE=64 7 | 8 | NUM_CH=192 9 | NUM_RES_BLOCKS=3 10 | ATTN_TYPE=True 11 | 12 | EXP="e2h${IMG_SIZE}_${NUM_CH}d" 13 | SAVE_ITER=100000 14 | MICRO_BS=64 15 | DROPOUT=0.1 16 | CLASS_COND=False 17 | 18 | PRED="vp" 19 | elif [[ $DATASET_NAME == "diode" ]]; then 20 | DATA_DIR=assets/datasets/DIODE-256 21 | DATASET=diode 22 | IMG_SIZE=256 23 | 24 | NUM_CH=256 25 | NUM_RES_BLOCKS=2 26 | ATTN_TYPE=True 27 | 28 | EXP="diode${IMG_SIZE}_${NUM_CH}d" 29 | SAVE_ITER=20000 30 | MICRO_BS=16 31 | DROPOUT=0.1 32 | CLASS_COND=False 33 | 34 | PRED="vp" 35 | 
elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then 36 | DATA_DIR=assets/datasets/ImageNet 37 | DATASET=imagenet_inpaint_center 38 | IMG_SIZE=256 39 | 40 | NUM_CH=256 41 | NUM_RES_BLOCKS=2 42 | ATTN_TYPE=False 43 | 44 | EXP="imagenet_inpaint_center${IMG_SIZE}_${NUM_CH}d" 45 | SAVE_ITER=20000 46 | MICRO_BS=16 47 | DROPOUT=0 48 | CLASS_COND=True 49 | 50 | PRED="i2sb_cond" 51 | fi 52 | 53 | if [[ $PRED == "ve" ]]; then 54 | EXP+="_ve" 55 | COND=concat 56 | SIGMA_MAX=80.0 57 | SIGMA_MIN=0.002 58 | elif [[ $PRED == "vp" ]]; then 59 | EXP+="_vp" 60 | COND=concat 61 | BETA_D=2 62 | BETA_MIN=0.1 63 | SIGMA_MAX=1 64 | SIGMA_MIN=0.0001 65 | elif [[ $PRED == "i2sb_cond" ]]; then 66 | EXP+="_i2sb_cond" 67 | COND=concat 68 | BETA_MAX=1.0 69 | BETA_MIN=0.1 70 | SIGMA_MAX=1 71 | SIGMA_MIN=0.0001 72 | else 73 | echo "Not supported" 74 | exit 1 75 | fi -------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:./ 2 | 3 | DATASET_NAME=$1 4 | NFE=$2 5 | GEN_SAMPLER=$3 6 | 7 | if [[ $DATASET_NAME == "e2h" ]]; then 8 | SPLIT=train 9 | PREFIX=e2h_ema_0.9999_420000_adapted/sample_420000 10 | REF_PATH=assets/stats/edges2handbags_ref_64_data.npz 11 | SAMPLE_NAME=samples_138567x64x64x3_nfe${NFE}.npz 12 | elif [[ $DATASET_NAME == "diode" ]]; then 13 | SPLIT=train 14 | PREFIX=diode_ema_0.9999_440000_adapted/sample_440000 15 | REF_PATH=assets/stats/diode_ref_256_data.npz 16 | SAMPLE_NAME=samples_16502x256x256x3_nfe${NFE}.npz 17 | elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then 18 | SPLIT=test 19 | PREFIX=imagenet256_inpaint_ema_0.9999_400000/sample_400000 20 | DATA_DIR=assets/datasets/ImageNet 21 | SAMPLE_NAME=samples_10000x256x256x3_nfe${NFE}.npz 22 | LABEL_NAME=labels_10000_nfe${NFE}.npz 23 | fi 24 | 25 | if [[ $GEN_SAMPLER == "heun" ]]; then 26 | N=$(echo "$NFE" | awk '{print ($1 + 1) / 3}') 27 | N=$(printf "%.0f" "$N") 28 | SAMPLER="heun" 29 | elif [[ $GEN_SAMPLER == "dbim" ]]; then 30 | N=$((NFE-1)) 31 | ETA=$4 32 | SAMPLER="dbim_eta=${ETA}" 33 | elif [[ $GEN_SAMPLER == "dbim_high_order" ]]; then 34 | N=$((NFE-1)) 35 | ORDER=$4 36 | SAMPLER="dbim_order=${ORDER}" 37 | fi 38 | 39 | # For example: 40 | # SAMPLE_PATH="workdir/e2h_ema_0.9999_420000_adapted/sample_420000/split=train/dbim_eta=0.0/steps=4/samples_138567x64x64x3_nfe5.npz" 41 | # SAMPLE_PATH="workdir/diode_ema_0.9999_440000_adapted/sample_440000/split=train/dbim_eta=0.0/steps=4/samples_16502x256x256x3_nfe5.npz" 42 | # SAMPLE_PATH="workdir/imagenet256_inpaint_ema_0.9999_400000/sample_400000/split=test/dbim_order=3/steps=9/samples_10000x256x256x3_nfe10.npz" 43 | 44 | SAMPLE_DIR=workdir/${PREFIX}/split=${SPLIT}/${SAMPLER}/steps=${N} 45 | SAMPLE_PATH=${SAMPLE_DIR}/${SAMPLE_NAME} 46 | 47 | if [[ $DATASET_NAME == "e2h" || $DATASET_NAME == "diode" ]]; then 48 | python evaluations/evaluator.py $REF_PATH $SAMPLE_PATH --metric fid 49 | python evaluations/evaluator.py $REF_PATH $SAMPLE_PATH --metric lpips 50 | elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then 51 | LABEL_PATH=${SAMPLE_DIR}/${LABEL_NAME} 52 | python evaluation/compute_metrices_imagenet.py --ckpt $SAMPLE_PATH --label $LABEL_PATH --dataset-dir $DATA_DIR 53 | python evaluations/evaluator.py "" $SAMPLE_PATH --metric is 54 | fi -------------------------------------------------------------------------------- /scripts/sample.sh: -------------------------------------------------------------------------------- 1 | export 
PYTHONPATH=$PYTHONPATH:./ 2 | 3 | # For cluster 4 | # export ADDR=$1 5 | # run_args="--nproc_per_node 8 \ 6 | # --master_addr $ADDR \ 7 | # --node_rank $RANK \ 8 | # --master_port $MASTER_PORT \ 9 | # --nnodes $WORLD_SIZE" 10 | # For local 11 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 12 | run_args="--nproc_per_node 8 \ 13 | --master_port 29511" 14 | 15 | # Batch size per GPU 16 | BS=16 17 | 18 | # Dataset and checkpoint 19 | DATASET_NAME=$1 20 | 21 | if [[ $DATASET_NAME == "e2h" ]]; then 22 | SPLIT=train 23 | MODEL_PATH=assets/ckpts/e2h_ema_0.9999_420000_adapted.pt 24 | elif [[ $DATASET_NAME == "diode" ]]; then 25 | SPLIT=train 26 | MODEL_PATH=assets/ckpts/diode_ema_0.9999_440000_adapted.pt 27 | elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then 28 | SPLIT=test 29 | MODEL_PATH=assets/ckpts/imagenet256_inpaint_ema_0.9999_400000.pt 30 | fi 31 | 32 | source scripts/args.sh $DATASET_NAME 33 | 34 | # Number of function evaluations (NFE) 35 | NFE=$2 36 | 37 | # Sampler 38 | GEN_SAMPLER=$3 39 | 40 | if [[ $GEN_SAMPLER == "heun" ]]; then 41 | N=$(echo "$NFE" | awk '{print ($1 + 1) / 3}') 42 | N=$(printf "%.0f" "$N") 43 | # Default setting in the DDBM paper 44 | CHURN_STEP_RATIO=0.33 45 | elif [[ $GEN_SAMPLER == "dbim" ]]; then 46 | N=$((NFE-1)) 47 | ETA=$4 48 | elif [[ $GEN_SAMPLER == "dbim_high_order" ]]; then 49 | N=$((NFE-1)) 50 | ORDER=$4 51 | fi 52 | 53 | torchrun $run_args sample.py --steps $N --sampler $GEN_SAMPLER --batch_size $BS \ 54 | --model_path $MODEL_PATH --class_cond $CLASS_COND --noise_schedule $PRED \ 55 | ${BETA_D:+ --beta_d="${BETA_D}"} ${BETA_MIN:+ --beta_min="${BETA_MIN}"} ${BETA_MAX:+ --beta_max="${BETA_MAX}"} \ 56 | --condition_mode=$COND --sigma_max=$SIGMA_MAX --sigma_min=$SIGMA_MIN \ 57 | --dropout $DROPOUT --image_size $IMG_SIZE --num_channels $NUM_CH --num_res_blocks $NUM_RES_BLOCKS \ 58 | --use_new_attention_order $ATTN_TYPE --data_dir=$DATA_DIR --dataset=$DATASET --split $SPLIT\ 59 | ${CHURN_STEP_RATIO:+ --churn_step_ratio="${CHURN_STEP_RATIO}"} \ 60 | ${ETA:+ --eta="${ETA}"} \ 61 | ${ORDER:+ --order="${ORDER}"} 62 | -------------------------------------------------------------------------------- /scripts/train_bridge.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:./ 2 | 3 | DATASET_NAME=imagenet_inpaint_center 4 | TRAIN_MODE=ddbm 5 | 6 | source scripts/args.sh $DATASET_NAME 7 | 8 | FREQ_SAVE_ITER=5000 9 | EXP=${DATASET_NAME}-${TRAIN_MODE} 10 | 11 | CKPT=assets/ckpts/256x256_diffusion_fixedsigma.pt 12 | 13 | # For cluster 14 | # export ADDR=$1 15 | # run_args="--nproc_per_node 8 \ 16 | # --master_addr $ADDR \ 17 | # --node_rank $RANK \ 18 | # --master_port $MASTER_PORT \ 19 | # --nnodes $WORLD_SIZE" 20 | # For local 21 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 22 | run_args="--nproc_per_node 8 \ 23 | --master_port 29511" 24 | 25 | MICRO_BS=4 26 | 27 | torchrun $run_args train.py --exp=$EXP \ 28 | --class_cond $CLASS_COND \ 29 | --dropout $DROPOUT --microbatch $MICRO_BS \ 30 | --image_size $IMG_SIZE --num_channels $NUM_CH \ 31 | --num_res_blocks $NUM_RES_BLOCKS --condition_mode=$COND \ 32 | --noise_schedule=$PRED \ 33 | --use_new_attention_order $ATTN_TYPE \ 34 | ${BETA_D:+ --beta_d="${BETA_D}"} ${BETA_MIN:+ --beta_min="${BETA_MIN}"} ${BETA_MAX:+ --beta_max="${BETA_MAX}"} \ 35 | --data_dir=$DATA_DIR --dataset=$DATASET \ 36 | --sigma_max=$SIGMA_MAX --sigma_min=$SIGMA_MIN \ 37 | --save_interval_for_preemption=$FREQ_SAVE_ITER --save_interval=$SAVE_ITER --debug=False \ 38 | ${CKPT:+ 
--resume_checkpoint="${CKPT}"}
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 | 
5 | import argparse
6 | 
7 | from ddbm import dist_util, logger
8 | from datasets import load_data
9 | from ddbm.resample import create_named_schedule_sampler
10 | from ddbm.script_util import (
11 |     model_and_diffusion_defaults,
12 |     create_model_and_diffusion,
13 |     sample_defaults,
14 |     args_to_dict,
15 |     add_dict_to_argparser,
16 |     get_workdir,
17 | )
18 | from ddbm.train_util import TrainLoop
19 | 
20 | import torch.distributed as dist
21 | 
22 | from pathlib import Path
23 | 
24 | import wandb
25 | 
26 | from glob import glob
27 | import os
28 | from datasets.augment import AugmentPipe
29 | 
30 | 
31 | def main(args):
32 | 
33 |     workdir = get_workdir(args.exp)
34 |     Path(workdir).mkdir(parents=True, exist_ok=True)
35 | 
36 |     dist_util.setup_dist()
37 |     logger.configure(dir=workdir)
38 |     if dist.get_rank() == 0:
39 |         name = args.exp if args.resume_checkpoint == "" else args.exp + "_resume"
40 |         wandb.init(
41 |             project="bridge",
42 |             group=args.exp,
43 |             name=name,
44 |             config=vars(args),
45 |             mode="offline" if not args.debug else "disabled",
46 |         )
47 |         logger.log("creating model and diffusion...")
48 | 
49 |     data_image_size = args.image_size
50 | 
51 |     # Load target model
52 |     resume_train_flag = False
53 |     if args.resume_checkpoint == "":
54 |         model_ckpts = list(glob(f"{workdir}/*model*[0-9].*"))
55 |         if len(model_ckpts) > 0:
56 |             max_ckpt = max(model_ckpts, key=lambda x: int(x.split("model_")[-1].split(".")[0]))
57 |             if os.path.exists(max_ckpt):
58 |                 args.resume_checkpoint = max_ckpt
59 |                 resume_train_flag = True
60 |         elif args.pretrained_ckpt is not None:
61 |             max_ckpt = args.pretrained_ckpt
62 |             args.resume_checkpoint = max_ckpt
63 |     if dist.get_rank() == 0 and args.resume_checkpoint != "":
64 |         logger.log("Resuming from checkpoint: ", args.resume_checkpoint)
65 | 
66 |     model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys()))
67 |     model.to(dist_util.dev())
68 | 
69 |     if dist.get_rank() == 0:
70 |         wandb.watch(model, log="all")
71 |     schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
72 | 
73 |     if args.batch_size == -1:
74 |         batch_size = args.global_batch_size // dist.get_world_size()
75 |         if args.global_batch_size % dist.get_world_size() != 0:
76 |             logger.log(f"warning, using smaller global_batch_size of {dist.get_world_size()*batch_size} instead of {args.global_batch_size}")
77 |     else:
78 |         batch_size = args.batch_size
79 | 
80 |     if dist.get_rank() == 0:
81 |         logger.log("creating data loader...")
82 | 
83 |     data, test_data = load_data(
84 |         data_dir=args.data_dir,
85 |         dataset=args.dataset,
86 |         batch_size=batch_size,
87 |         image_size=data_image_size,
88 |         num_workers=args.num_workers,
89 |     )
90 | 
91 |     if args.use_augment:
92 |         augment = AugmentPipe(p=0.12, xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
93 |     else:
94 |         augment = None
95 | 
96 |     logger.log("training...")
97 |     TrainLoop(
98 |         model=model,
99 |         diffusion=diffusion,
100 |         train_data=data,
101 |         test_data=test_data,
102 |         batch_size=batch_size,
103 |         microbatch=-1 if args.microbatch >= batch_size else args.microbatch,
104 |         lr=args.lr,
105 |         ema_rate=args.ema_rate,
106 |         log_interval=args.log_interval,
107 |         test_interval=args.test_interval,
108 |         save_interval=args.save_interval,
109 |         save_interval_for_preemption=args.save_interval_for_preemption,
110 |         resume_checkpoint=args.resume_checkpoint,
111 |         workdir=workdir,
112 |         use_fp16=args.use_fp16,
113 |         fp16_scale_growth=args.fp16_scale_growth,
114 |         schedule_sampler=schedule_sampler,
115 |         weight_decay=args.weight_decay,
116 |         lr_anneal_steps=args.lr_anneal_steps,
117 |         augment_pipe=augment,
118 |         train_mode=args.train_mode,
119 |         resume_train_flag=resume_train_flag,
120 |         **sample_defaults(),
121 |     ).run_loop()
122 | 
123 | 
124 | def create_argparser():
125 |     defaults = dict(
126 |         data_dir="",
127 |         dataset="edges2handbags",
128 |         schedule_sampler="real-uniform",
129 |         lr=1e-4,
130 |         weight_decay=0.0,
131 |         lr_anneal_steps=0,
132 |         global_batch_size=256,
133 |         batch_size=-1,
134 |         microbatch=-1,  # -1 disables microbatches
135 |         ema_rate="0.9999",  # comma-separated list of EMA values
136 |         log_interval=50,
137 |         test_interval=500,
138 |         save_interval=10000,
139 |         save_interval_for_preemption=50000,
140 |         resume_checkpoint="",
141 |         exp="",
142 |         use_fp16=True,
143 |         fp16_scale_growth=1e-3,
144 |         debug=False,
145 |         num_workers=8,
146 |         use_augment=False,
147 |         pretrained_ckpt=None,
148 |         train_mode="ddbm",
149 |     )
150 |     defaults.update(model_and_diffusion_defaults())
151 |     parser = argparse.ArgumentParser()
152 |     add_dict_to_argparser(parser, defaults)
153 |     return parser
154 | 
155 | 
156 | if __name__ == "__main__":
157 |     args = create_argparser().parse_args()
158 |     main(args)
159 | 
--------------------------------------------------------------------------------