├── .gitignore
├── README.md
├── assets
│   ├── datasets
│   │   ├── download_extract_DIODE.sh
│   │   ├── download_extract_ImageNet.sh
│   │   ├── download_extract_edges2handbags.sh
│   │   ├── val_faster_imagefolder_10k_fn.txt
│   │   └── val_faster_imagefolder_10k_label.txt
│   └── teaser.png
├── corruption
│   ├── LICENSE_DDRM
│   ├── LICENSE_DDRM_JPEG
│   ├── __init__.py
│   ├── base.py
│   ├── blur.py
│   ├── inpaint.py
│   ├── jpeg.py
│   ├── mixture.py
│   └── superresolution.py
├── datasets
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── augment.py
│   ├── image_folder.py
│   ├── imagenet_inpaint.py
│   └── misc.py
├── ddbm
│   ├── __init__.py
│   ├── dist_util.py
│   ├── karras_diffusion.py
│   ├── logger.py
│   ├── nn.py
│   ├── random_util.py
│   ├── resample.py
│   ├── script_util.py
│   ├── train_util.py
│   └── unet.py
├── download_diffusion.py
├── evaluation
│   ├── __init__.py
│   ├── compute_metrices_imagenet.py
│   ├── fid_util.py
│   └── resnet.py
├── evaluations
│   ├── __init__.py
│   ├── evaluator.py
│   ├── feature_extractor.py
│   ├── inception_pytorch.py
│   ├── inception_torchscript.py
│   ├── inception_v3.py
│   └── requirements.txt
├── logger.py
├── preprocess_ckpt.py
├── preprocess_depth.py
├── sample.py
├── scripts
│   ├── args.sh
│   ├── evaluate.sh
│   ├── sample.sh
│   └── train_bridge.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.egg-info
3 | wandb
4 | workdir
5 | *.pt
6 | *.ipynb
7 | *.npz
8 | .ipynb_checkpoints
9 | assets/datasets/DIODE/
10 | assets/datasets/ImageNet/
11 | assets/datasets/edges2handbags/
12 | assets/datasets/DIODE-256/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Accelerating Diffusion Bridge Models for Image-to-Image Translation
2 |
3 | Official Implementation of [Diffusion Bridge Implicit Models](https://arxiv.org/abs/2405.15885) (ICLR 2025) and [Consistency Diffusion Bridge Models](https://arxiv.org/abs/2410.22637) (NeurIPS 2024).
4 |
5 | # Dependencies
6 |
7 | To install all packages in this codebase along with their dependencies, run
8 | ```sh
9 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
10 | pip install blobfile piq matplotlib opencv-python joblib lmdb scipy clean-fid easydict torchmetrics rich ipdb wandb
11 | ```
12 |
13 | # Datasets
14 |
15 | Please put (or link) the datasets under `assets/datasets/`.
16 |
17 | - For Edges2Handbags, please follow instructions from [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md). The resulting folder structure should be `assets/datasets/edges2handbags/train` and `assets/datasets/edges2handbags/val`.
18 | - For DIODE, please download the training dataset and the data list from [here](https://diode-dataset.org/). The resulting folder structure should be `assets/datasets/DIODE/train` and `assets/datasets/DIODE/data_list`.
19 | - For ImageNet, please download the dataset from [here](https://image-net.org/download.php). The resulting folder structure should be `assets/datasets/ImageNet/train` and `assets/datasets/ImageNet/val`.
20 |
21 | We also provide scripts for automatic download and extraction:
22 | ```
23 | cd assets/datasets
24 | bash download_extract_edges2handbags.sh
25 | bash download_extract_DIODE.sh
26 | bash download_extract_ImageNet.sh
27 | ```
28 |
29 | After downloading, the DIODE dataset requires preprocessing by running `python preprocess_depth.py`.
30 |
31 | # Diffusion Bridge Implicit Models
32 |
33 |
34 |
35 |
36 |
37 | DBIM offers a suite of fast samplers tailored for [Denoising Diffusion Bridge Models (DDBMs)](https://arxiv.org/abs/2309.16948). We have cleaned up the codebase to support a broad range of diffusion bridges, facilitating unified training and sampling workflows, and streamlined deployment by replacing the cumbersome MPI-based distributed launcher with the more efficient and engineer-friendly `torchrun`.
38 |
39 | ## Pre-trained models
40 |
41 | Please put the downloaded checkpoints under `assets/ckpts/`.
42 |
43 | For image translation, we directly adopt the pretrained checkpoints from [DDBM](https://github.com/alexzhou907/DDBM):
44 |
45 | - Edges2Handbags: [e2h_ema_0.9999_420000.pt](https://huggingface.co/alexzhou907/DDBM/resolve/main/e2h_ema_0.9999_420000.pt)
46 | - DIODE: [diode_ema_0.9999_440000.pt](https://huggingface.co/alexzhou907/DDBM/resolve/main/diode_ema_0.9999_440000.pt)
47 |
48 | This codebase removes the dependency on external packages such as `flash_attn`, whose functionality is now supported natively by PyTorch. After downloading the two checkpoints above, please run `python preprocess_ckpt.py` to complete the conversion.
49 |
50 | For image restoration:
51 |
52 | - Center 128x128 Inpainting on ImageNet 256x256: [imagenet256_inpaint_ema_0.9999_400000.pt](https://drive.google.com/file/d/1WozJyVOAFukj0nUYLS-ZUp1-QHuGNfox)
53 |
54 | ## Sampling
55 |
56 | ```
57 | bash scripts/sample.sh $DATASET_NAME $NFE $SAMPLER ($AUX)
58 | ```
59 |
60 | - `$DATASET_NAME` can be chosen from `e2h`/`diode`/`imagenet_inpaint_center`.
61 | - `$NFE` is the *Number of Function Evaluations*, which is proportional to the sampling time.
62 | - `$SAMPLER` can be chosen from `heun`/`dbim`/`dbim_high_order`.
63 |   - `heun` is the vanilla sampler of DDBM, which alternately simulates SDE and ODE steps. In this case, `$AUX` is not required.
64 |   - `dbim` and `dbim_high_order` are our proposed samplers. When using `dbim`, `$AUX` corresponds to $\eta$, which controls the stochasticity level (a floating-point value in $[0,1]$). When using `dbim_high_order`, `$AUX` corresponds to the order (2 or 3).
65 |
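For illustration, a concrete invocation could look like the following (the NFE and $\eta$ values here are arbitrary examples rather than recommended settings):

```
# DBIM with stochasticity eta = 0.5 at 20 NFE on Edges2Handbags
bash scripts/sample.sh e2h 20 dbim 0.5
# 3rd-order DBIM at 20 NFE on DIODE
bash scripts/sample.sh diode 20 dbim_high_order 3
```
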
66 | The samples will be saved to `workdir/`.
67 |
68 | ## Evaluations
69 |
70 | Before evaluating the image translation results, please download the reference statistics from DDBM and put them under `assets/stats/`:
71 | - Reference stats for Edges2Handbags: [edges2handbags_ref_64_data.npz](https://huggingface.co/alexzhou907/DDBM/resolve/main/edges2handbags_ref_64_data.npz).
72 | - Reference stats for DIODE: [diode_ref_256_data.npz](https://huggingface.co/alexzhou907/DDBM/resolve/main/diode_ref_256_data.npz).
73 |
74 | Evaluation then proceeds automatically when given the same dataset and sampler arguments as were used for sampling:
75 |
76 | ```
77 | bash scripts/evaluate.sh $DATASET_NAME $NFE $SAMPLER ($AUX)
78 | ```
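
For example, to evaluate the samples produced by the illustrative `dbim` run above:

```
bash scripts/evaluate.sh e2h 20 dbim 0.5
```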
79 |
80 | ## Training
81 |
82 | We provide the script for training diffusion bridge models on the ImageNet 256x256 inpainting (center 128x128) task. As ImageNet 256x256 is a challenging dataset, we follow [I2SB](https://github.com/NVlabs/I2SB) and initialize the network with a pretrained diffusion model [ADM](https://github.com/openai/guided-diffusion/). To start training, run
83 |
84 | ```
85 | python download_diffusion.py
86 | bash scripts/train_bridge.sh
87 | ```
88 |
89 | Training for 400k iterations takes around 2 weeks on 8xA100. We recommend using multiple nodes and modifying `scripts/train_bridge.sh` accordingly.
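
As a rough sketch, a multi-node launch could look like the following; only standard `torchrun` flags are shown here, and the actual entry point and training arguments should be taken from `scripts/train_bridge.sh` rather than from this illustration:

```
# run on each of the $NNODES machines, with NODE_RANK = 0 .. NNODES-1
# and MASTER_ADDR pointing to the rank-0 node
torchrun --nnodes=$NNODES --nproc_per_node=8 \
    --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=29500 \
    train.py <arguments from scripts/train_bridge.sh>
```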
90 |
91 | # Consistency Diffusion Bridge Models
92 |
93 | TODO
94 |
95 | # Acknowledgement
96 |
97 | This codebase is built upon [DDBM](https://github.com/alexzhou907/DDBM) and [I2SB](https://github.com/NVlabs/I2SB).
98 |
99 |
100 | # Citation
101 |
102 | If you find this method and/or code useful, please consider citing
103 |
104 | ```
105 | @article{zheng2024diffusion,
106 | title={Diffusion Bridge Implicit Models},
107 | author={Zheng, Kaiwen and He, Guande and Chen, Jianfei and Bao, Fan and Zhu, Jun},
108 | journal={arXiv preprint arXiv:2405.15885},
109 | year={2024}
110 | }
111 | ```
112 |
113 | and
114 |
115 | ```
116 | @article{he2024consistency,
117 | title={Consistency Diffusion Bridge Models},
118 | author={He, Guande and Zheng, Kaiwen and Chen, Jianfei and Bao, Fan and Zhu, Jun},
119 | journal={arXiv preprint arXiv:2410.22637},
120 | year={2024}
121 | }
122 | ```
--------------------------------------------------------------------------------
/assets/datasets/download_extract_DIODE.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 |
4 | mkdir DIODE && cd DIODE
5 |
6 | wget http://diode-dataset.s3.amazonaws.com/train.tar.gz
7 | wget http://diode-dataset.s3.amazonaws.com/train_normals.tar.gz
8 | wget https://diode-1254389886.cos.ap-hongkong.myqcloud.com/data_list.zip
9 |
10 | tar -xvzf train.tar.gz && rm -f train.tar.gz
11 | tar -xvzf train_normals.tar.gz && rm -f train_normals.tar.gz
12 | unzip data_list.zip && rm -f data_list.zip
--------------------------------------------------------------------------------
/assets/datasets/download_extract_ImageNet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 |
4 | wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar --no-check-certificate
5 | wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar --no-check-certificate
6 |
7 | # script to extract ImageNet dataset
8 | # ILSVRC2012_img_train.tar (about 138 GB)
9 | # ILSVRC2012_img_val.tar (about 6.3 GB)
10 | # make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar in your current directory
11 | #
12 | # Adapted from:
13 | # https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md
14 | # https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4
15 | #
16 | # ImageNet/train/
17 | # ├── n01440764
18 | # │ ├── n01440764_10026.JPEG
19 | # │ ├── n01440764_10027.JPEG
20 | # │ ├── ......
21 | # ├── ......
22 | # ImageNet/val/
23 | # ├── n01440764
24 | # │ ├── ILSVRC2012_val_00000293.JPEG
25 | # │ ├── ILSVRC2012_val_00002138.JPEG
26 | # │ ├── ......
27 | # ├── ......
28 | #
29 | #
30 | # Make ImageNet directory
31 | #
32 | mkdir ImageNet
33 | #
34 | # Extract the training data:
35 | #
36 | # Create train directory; move .tar file; change directory
37 | mkdir ImageNet/train && mv ILSVRC2012_img_train.tar ImageNet/train/ && cd ImageNet/train
38 | # Extract training set; remove compressed file
39 | tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
40 | #
41 | # At this stage ImageNet/train will contain 1000 compressed .tar files, one for each category
42 | #
43 | # For each .tar file:
44 | # 1. create directory with same name as .tar file
45 | # 2. extract and copy contents of .tar file into directory
46 | # 3. remove .tar file
47 | find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
48 | #
49 | # This results in a training directory like so:
50 | #
51 | # ImageNet/train/
52 | # ├── n01440764
53 | # │ ├── n01440764_10026.JPEG
54 | # │ ├── n01440764_10027.JPEG
55 | # │ ├── ......
56 | # ├── ......
57 | #
58 | # Change back to original directory
59 | cd ../..
60 | #
61 | # Extract the validation data and move images to subfolders:
62 | #
63 | # Create validation directory; move .tar file; change directory; extract validation .tar; remove compressed file
64 | mkdir ImageNet/val && mv ILSVRC2012_img_val.tar ImageNet/val/ && cd ImageNet/val && tar -xvf ILSVRC2012_img_val.tar && rm -f ILSVRC2012_img_val.tar
65 | # get script from soumith and run; this script creates all class directories and moves images into corresponding directories
66 | wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
67 | #
68 | # This results in a validation directory like so:
69 | #
70 | # ImageNet/val/
71 | # ├── n01440764
72 | # │ ├── ILSVRC2012_val_00000293.JPEG
73 | # │ ├── ILSVRC2012_val_00002138.JPEG
74 | # │ ├── ......
75 | # ├── ......
76 | #
77 | #
78 | # Check total files after extract
79 | #
80 | # $ find train/ -name "*.JPEG" | wc -l
81 | # 1281167
82 | # $ find val/ -name "*.JPEG" | wc -l
83 | # 50000
84 | #
85 |
--------------------------------------------------------------------------------
/assets/datasets/download_extract_edges2handbags.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 |
4 | wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/edges2handbags.tar.gz
5 |
6 | tar -xvzf edges2handbags.tar.gz && rm -f edges2handbags.tar.gz
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/DiffusionBridge/92522733cc602686df77f07a1824bb89f89cda1a/assets/teaser.png
--------------------------------------------------------------------------------
/corruption/LICENSE_DDRM:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Bahjat Kawar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/corruption/LICENSE_DDRM_JPEG:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright [2022] [Bahjat Kawar, Jiaming Song, Stefano Ermon, Michael Elad]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/corruption/__init__.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 |
9 | def build_corruption(opt, log, corrupt_type=None):
10 |
11 |     if corrupt_type is None:
12 |         corrupt_type = opt.corrupt
13 |
14 |     if "inpaint" in corrupt_type:
15 |         from .inpaint import build_inpaint_center, build_inpaint_freeform
16 |
17 |         mask = corrupt_type.split("-")[1]
18 |         assert mask in ["center", "freeform1020", "freeform2030"]
19 |         if mask == "center":
20 |             method = build_inpaint_center(mask, opt.image_size)
21 |         elif "freeform" in mask:
22 |             method = build_inpaint_freeform(mask)
23 |
24 |     elif "jpeg" in corrupt_type:
25 |         from .jpeg import build_jpeg
26 |
27 |         quality_factor = int(corrupt_type.split("-")[1])
28 |         method = build_jpeg(log, quality_factor)
29 |
30 |     elif "sr4x" in corrupt_type:
31 |         from .superresolution import build_sr4x
32 |
33 |         sr_filter = corrupt_type.split("-")[1]
34 |         assert sr_filter in ["pool", "bicubic"]
35 |         method = build_sr4x(opt, log, sr_filter, image_size=opt.image_size)
36 |
37 |     elif "blur" in corrupt_type:
38 |         from .blur import build_blur
39 |
40 |         kernel = corrupt_type.split("-")[1]
41 |         assert kernel in ["uni", "gauss"]
42 |         method = build_blur(opt, log, kernel)
43 |
44 |     elif "mixture" in corrupt_type:
45 |         method = None  # mixture corruption is applied on-the-fly by the datasets in corruption/mixture.py
46 |     else:
47 |         raise RuntimeWarning(f"Unknown corruption: {corrupt_type}!")
48 |
49 |     return method
50 |
--------------------------------------------------------------------------------
/corruption/base.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from ddrm.
5 | #
6 | # Source:
7 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L3
8 | #
9 | # The license for the original version of this file can be
10 | # found in this directory (LICENSE_DDRM).
11 | # The modifications to this file are subject to the same license.
12 | # ---------------------------------------------------------------
13 |
14 |
15 | class H_functions:
16 |     """
17 |     A class replacing the SVD of a matrix H, perhaps efficiently.
18 |     All input vectors are of shape (Batch, ...).
19 |     All output vectors are of shape (Batch, DataDimension).
20 |     """
21 |
22 |     def V(self, vec):
23 |         """
24 |         Multiplies the input vector by V
25 |         """
26 |         raise NotImplementedError()
27 |
28 |     def Vt(self, vec):
29 |         """
30 |         Multiplies the input vector by V transposed
31 |         """
32 |         raise NotImplementedError()
33 |
34 |     def U(self, vec):
35 |         """
36 |         Multiplies the input vector by U
37 |         """
38 |         raise NotImplementedError()
39 |
40 |     def Ut(self, vec):
41 |         """
42 |         Multiplies the input vector by U transposed
43 |         """
44 |         raise NotImplementedError()
45 |
46 |     def singulars(self):
47 |         """
48 |         Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U)
49 |         """
50 |         raise NotImplementedError()
51 |
52 |     def add_zeros(self, vec):
53 |         """
54 |         Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V)
55 |         """
56 |         raise NotImplementedError()
57 |
58 |     def H(self, vec):
59 |         """
60 |         Multiplies the input vector by H
61 |         """
62 |         temp = self.Vt(vec)
63 |         singulars = self.singulars()
64 |         return self.U(singulars * temp[:, : singulars.shape[0]])
65 |
66 |     def Ht(self, vec):
67 |         """
68 |         Multiplies the input vector by H transposed
69 |         """
70 |         temp = self.Ut(vec)
71 |         singulars = self.singulars()
72 |         return self.V(self.add_zeros(singulars * temp[:, : singulars.shape[0]]))
73 |
74 |     def H_pinv(self, vec):
75 |         """
76 |         Multiplies the input vector by the pseudo inverse of H
77 |         """
78 |         temp = self.Ut(vec)
79 |         singulars = self.singulars()
80 |         temp[:, : singulars.shape[0]] = temp[:, : singulars.shape[0]] / singulars
81 |         return self.V(self.add_zeros(temp))
82 |
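# ------------------------------------------------------------------
# Illustrative sketch (added for exposition; not used anywhere in the
# codebase): the simplest possible H_functions subclass, the identity
# operator H = I, for which U = V = I and every singular value is 1.
# It shows how the abstract pieces above compose into H, Ht and H_pinv.
# ------------------------------------------------------------------
import torch


class IdentityExample(H_functions):
    def __init__(self, data_dim, device="cpu"):
        # data_dim is the flattened data dimension, e.g. 3 * 256 * 256
        self._singulars = torch.ones(data_dim, device=device)

    def V(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def Vt(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def U(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def Ut(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def singulars(self):
        return self._singulars

    def add_zeros(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)


# e.g. IdentityExample(3 * 4 * 4).H(torch.rand(2, 3, 4, 4)) returns the
# flattened input unchanged.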
--------------------------------------------------------------------------------
/corruption/blur.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from ddrm.
5 | #
6 | # Source:
7 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L397
8 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L245
9 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L251
10 | #
11 | # The license for the original version of this file can be
12 | # found in this directory (LICENSE_DDRM).
13 | # The modifications to this file are subject to the same license.
14 | # ---------------------------------------------------------------
15 |
16 | import torch
17 | from .base import H_functions
18 |
19 | class Deblurring(H_functions):
20 | def mat_by_img(self, M, v):
21 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
22 | self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)
23 |
24 | def img_by_mat(self, v, M):
25 | return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
26 | self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])
27 |
28 | def __init__(self, kernel, channels, img_dim, device, ZERO = 3e-2):
29 | self.img_dim = img_dim
30 | self.channels = channels
31 | #build 1D conv matrix
32 | H_small = torch.zeros(img_dim, img_dim, device=device)
33 | for i in range(img_dim):
34 | for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
35 | if j < 0 or j >= img_dim: continue
36 | H_small[i, j] = kernel[j - i + kernel.shape[0]//2]
37 | #get the svd of the 1D conv
38 | self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False)
39 | #ZERO = 3e-2
40 | self.singulars_small[self.singulars_small < ZERO] = 0
41 | #calculate the singular values of the big matrix
42 | self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2)
43 | #sort the big matrix singulars and save the permutation
44 | self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True)
45 |
46 | def V(self, vec):
47 | #invert the permutation
48 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
49 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
50 | temp = temp.permute(0, 2, 1)
51 | #multiply the image by V from the left and by V^T from the right
52 | out = self.mat_by_img(self.V_small, temp)
53 | out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
54 | return out
55 |
56 | def Vt(self, vec):
57 | #multiply the image by V^T from the left and by V from the right
58 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
59 | temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1)
60 | #permute the entries according to the singular values
61 | temp = temp[:, :, self._perm].permute(0, 2, 1)
62 | return temp.reshape(vec.shape[0], -1)
63 |
64 | def U(self, vec):
65 | #invert the permutation
66 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
67 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
68 | temp = temp.permute(0, 2, 1)
69 | #multiply the image by U from the left and by U^T from the right
70 | out = self.mat_by_img(self.U_small, temp)
71 | out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1)
72 | return out
73 |
74 | def Ut(self, vec):
75 | #multiply the image by U^T from the left and by U from the right
76 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone())
77 | temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1)
78 | #permute the entries according to the singular values
79 | temp = temp[:, :, self._perm].permute(0, 2, 1)
80 | return temp.reshape(vec.shape[0], -1)
81 |
82 | def singulars(self):
83 | return self._singulars.repeat(1, 3).reshape(-1)
84 |
85 | def add_zeros(self, vec):
86 | return vec.clone().reshape(vec.shape[0], -1)
87 |
88 | def build_blur(opt, log, kernel):
89 |     log.info(f"[Corrupt] Blurring {kernel=}...")
90 |
91 | uni = Deblurring(torch.Tensor([1/9] * 9).to(opt.device), 3, opt.image_size, opt.device)
92 |
93 | sigma = 10
94 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x/sigma)**2]))
95 | g_kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(opt.device)
96 | gauss = Deblurring(g_kernel / g_kernel.sum(), 3, opt.image_size, opt.device)
97 |
98 | xdim = (3, opt.image_size, opt.image_size)
99 |
100 | assert kernel in ["uni", "gauss"]
101 | def blur(img):
102 | # img: [-1,1] -> [0,1]
103 | img = (img + 1) / 2
104 | if kernel == "uni":
105 | img = uni.H(img).reshape(img.shape[0], *xdim)
106 | elif kernel == "gauss":
107 | img = gauss.H(img).reshape(img.shape[0], *xdim)
108 | # [0,1] -> [-1,1]
109 | return img * 2 - 1
110 |
111 | return blur
112 |
--------------------------------------------------------------------------------
/corruption/inpaint.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from diffusion_palette_eval.
5 | #
6 | # Source:
7 | # https://bit.ly/eval-pix2pix
8 | #
9 | # ---------------------------------------------------------------
10 |
11 | import io
12 |
13 | import os
14 |
15 | import numpy as np
16 | import torch
17 |
18 | from pathlib import Path
19 |
20 | # import gdown
21 | # from ipdb import set_trace as debug
22 |
23 | FREEFORM_URL = "https://drive.google.com/file/d/1-5YRGsekjiRKQWqo0BV5RVQu0bagc12w/view?usp=share_link"
24 |
25 |
26 | # code adopted from
27 | # https://bit.ly/eval-pix2pix
28 | def bbox2mask(img_shape, bbox, dtype="uint8"):
29 | """Generate mask in ndarray from bbox.
30 |
31 | The returned mask has the shape of (h, w, 1). '1' indicates the
32 | hole and '0' indicates the valid regions.
33 |
34 | We prefer to use `uint8` as the data type of masks, which may be different
35 | from other codes in the community.
36 |
37 | Args:
38 | img_shape (tuple[int]): The size of the image.
39 | bbox (tuple[int]): Configuration tuple, (top, left, height, width)
40 | dtype (str): Indicate the data type of returned masks. Default: 'uint8'
41 |
42 | Return:
43 | numpy.ndarray: Mask in the shape of (h, w, 1).
44 | """
45 |
46 | height, width = img_shape[:2]
47 |
48 | mask = np.zeros((height, width, 1), dtype=dtype)
49 | mask[bbox[0] : bbox[0] + bbox[2], bbox[1] : bbox[1] + bbox[3], :] = 1
50 |
51 | return mask
52 |
53 |
54 | # code adopted from
55 | # https://bit.ly/eval-pix2pix
56 | def load_masks(filename):
57 | # filename = "imagenet_freeform_masks.npz"
58 | shape = [10000, 256, 256]
59 |
60 | # shape = [10950, 256, 256] # Uncomment this for places2.
61 |
62 | # Load the npz file.
63 | with open(filename, "rb") as f:
64 | data = f.read()
65 |
66 | data = dict(np.load(io.BytesIO(data)))
67 | # print("Categories of masks:")
68 | # for key in data:
69 | # print(key)
70 |
71 | # Unpack and reshape the masks.
72 | for key in data:
73 | data[key] = np.unpackbits(data[key], axis=None)[: np.prod(shape)].reshape(shape).astype(np.uint8)
74 |
75 | # data[key] contains [10000, 256, 256] array i.e. 10000 256x256 masks.
76 | return data
77 |
78 |
79 | def load_freeform_masks(op_type):
80 | data_dir = Path("assets/datasets/ImageNet")
81 |
82 | mask_fn = data_dir / f"imagenet_{op_type}_masks.npz"
83 | if not mask_fn.exists():
84 |         # download original npz from the Palette Google Drive
85 | orig_mask_fn = str(data_dir / "imagenet_freeform_masks.npz")
86 | if not os.path.exists(orig_mask_fn):
87 |             assert False, f"{orig_mask_fn} not found; please download it from FREEFORM_URL first."
88 | # gdown.download(url=FREEFORM_URL, output=orig_mask_fn, quiet=False, fuzzy=True)
89 | masks = load_masks(orig_mask_fn)
90 |
91 | # store freeform of current ratio for faster loading in future
92 | key = {
93 | "freeform1020": "10-20% freeform",
94 | "freeform2030": "20-30% freeform",
95 | "freeform3040": "30-40% freeform",
96 | }.get(op_type)
97 | np.savez(mask_fn, mask=masks[key])
98 |
99 | # [10000, 256, 256] --> [10000, 1, 256, 256]
100 | return np.load(mask_fn)["mask"][:, None]
101 |
102 |
103 | def get_center_mask(image_size):
104 | h, w = image_size
105 | mask = bbox2mask(image_size, (h // 4, w // 4, h // 2, w // 2))
106 | return torch.from_numpy(mask).permute(2, 0, 1)
107 |
108 |
109 | def build_inpaint_center(mask_type, image_size):
110 | assert mask_type == "center"
111 |
112 | # log.info(f"[Corrupt] Inpaint: {mask_type=} ...")
113 |
114 |     center_mask = get_center_mask([image_size, image_size])  # [1, 256, 256]
115 | # center_mask = center_mask.to(opt.device)
116 |
117 | def inpaint_center(img):
118 | # img: [-1,1]
119 | mask = center_mask
120 |         # keep pixels where mask==0; set masked pixels (mask==1) to 0
121 | return img * (1.0 - mask) + mask * 0, mask
122 |
123 | return inpaint_center
124 |
125 |
126 | def build_inpaint_freeform(mask_type):
127 | assert "freeform" in mask_type
128 |
129 | # log.info(f"[Corrupt] Inpaint: {mask_type=} ...")
130 |
131 | freeform_masks = load_freeform_masks(mask_type) # [10000, 1, 256, 256]
132 | n_freeform_masks = freeform_masks.shape[0]
133 | # freeform_masks = torch.from_numpy(freeform_masks).to(opt.device)
134 | freeform_masks = torch.from_numpy(freeform_masks)
135 |
136 | def inpaint_freeform(img):
137 | # img: [-1,1]
138 | # index = np.random.randint(n_freeform_masks, size=img.shape[0])
139 | index = np.random.randint(n_freeform_masks)
140 | mask = freeform_masks[index]
141 |         # keep pixels where mask==0; set masked pixels (mask==1) to 0
142 | return img * (1.0 - mask) + mask * 0, mask
143 |
144 | return inpaint_freeform
145 |
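if __name__ == "__main__":
    # Illustrative smoke test (added for exposition; values are arbitrary):
    # apply the center mask to a random batch of 256x256 images in [-1, 1].
    corrupt = build_inpaint_center("center", 256)
    img = torch.rand(2, 3, 256, 256) * 2 - 1
    masked, mask = corrupt(img)
    # masked: [2, 3, 256, 256] with the central 128x128 square set to 0;
    # mask:   [1, 256, 256] with 1 marking the hole.
    print(masked.shape, mask.shape)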
--------------------------------------------------------------------------------
/corruption/jpeg.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from ddrm-jpeg.
5 | #
6 | # Source:
7 | # https://github.com/bahjat-kawar/ddrm-jpeg/blob/master/functions/jpeg_torch.py
8 | #
9 | # The license for the original version of this file can be
10 | # found in this directory (LICENSE_DDRM_JPEG).
11 | # The modifications to this file are subject to the same license.
12 | # ---------------------------------------------------------------
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 |
18 | def dct1(x):
19 | """
20 | Discrete Cosine Transform, Type I
21 | :param x: the input signal
22 | :return: the DCT-I of the signal over the last dimension
23 | """
24 | x_shape = x.shape
25 | x = x.view(-1, x_shape[-1])
26 |
27 | return torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1))[:, :, 0].view(*x_shape)
28 |
29 |
30 | def idct1(X):
31 | """
32 | The inverse of DCT-I, which is just a scaled DCT-I
33 |     Our definition of idct1 is such that idct1(dct1(x)) == x
34 | :param X: the input signal
35 | :return: the inverse DCT-I of the signal over the last dimension
36 | """
37 | n = X.shape[-1]
38 | return dct1(X) / (2 * (n - 1))
39 |
40 |
41 | def dct(x, norm=None):
42 | """
43 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
44 | For the meaning of the parameter `norm`, see:
45 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
46 | :param x: the input signal
47 | :param norm: the normalization, None or 'ortho'
48 | :return: the DCT-II of the signal over the last dimension
49 | """
50 | x_shape = x.shape
51 | N = x_shape[-1]
52 | x = x.contiguous().view(-1, N)
53 |
54 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
55 |
56 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
57 |
58 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
59 | W_r = torch.cos(k)
60 | W_i = torch.sin(k)
61 |
62 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
63 |
64 | if norm == 'ortho':
65 | V[:, 0] /= np.sqrt(N) * 2
66 | V[:, 1:] /= np.sqrt(N / 2) * 2
67 |
68 | V = 2 * V.view(*x_shape)
69 |
70 | return V
71 |
72 |
73 | def idct(X, norm=None):
74 | """
75 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
76 | Our definition of idct is that idct(dct(x)) == x
77 | For the meaning of the parameter `norm`, see:
78 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
79 | :param X: the input signal
80 | :param norm: the normalization, None or 'ortho'
81 | :return: the inverse DCT-II of the signal over the last dimension
82 | """
83 |
84 | x_shape = X.shape
85 | N = x_shape[-1]
86 |
87 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
88 |
89 | if norm == 'ortho':
90 | X_v[:, 0] *= np.sqrt(N) * 2
91 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
92 |
93 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
94 | W_r = torch.cos(k)
95 | W_i = torch.sin(k)
96 |
97 | V_t_r = X_v
98 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
99 |
100 | V_r = V_t_r * W_r - V_t_i * W_i
101 | V_i = V_t_r * W_i + V_t_i * W_r
102 |
103 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
104 |
105 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
106 | x = v.new_zeros(v.shape)
107 | x[:, ::2] += v[:, :N - (N // 2)]
108 | x[:, 1::2] += v.flip([1])[:, :N // 2]
109 |
110 | return x.view(*x_shape)
111 |
112 |
113 | def dct_2d(x, norm=None):
114 | """
115 |     2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
116 | For the meaning of the parameter `norm`, see:
117 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
118 | :param x: the input signal
119 | :param norm: the normalization, None or 'ortho'
120 | :return: the DCT-II of the signal over the last 2 dimensions
121 | """
122 | X1 = dct(x, norm=norm)
123 | X2 = dct(X1.transpose(-1, -2), norm=norm)
124 | return X2.transpose(-1, -2)
125 |
126 |
127 | def idct_2d(X, norm=None):
128 | """
129 | The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
130 | Our definition of idct is that idct_2d(dct_2d(x)) == x
131 | For the meaning of the parameter `norm`, see:
132 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
133 | :param X: the input signal
134 | :param norm: the normalization, None or 'ortho'
135 |     :return: the inverse DCT-II of the signal over the last 2 dimensions
136 | """
137 | x1 = idct(X, norm=norm)
138 | x2 = idct(x1.transpose(-1, -2), norm=norm)
139 | return x2.transpose(-1, -2)
140 |
141 |
142 | def dct_3d(x, norm=None):
143 | """
144 |     3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
145 | For the meaning of the parameter `norm`, see:
146 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
147 | :param x: the input signal
148 | :param norm: the normalization, None or 'ortho'
149 | :return: the DCT-II of the signal over the last 3 dimensions
150 | """
151 | X1 = dct(x, norm=norm)
152 | X2 = dct(X1.transpose(-1, -2), norm=norm)
153 | X3 = dct(X2.transpose(-1, -3), norm=norm)
154 | return X3.transpose(-1, -3).transpose(-1, -2)
155 |
156 |
157 | def idct_3d(X, norm=None):
158 | """
159 | The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
160 | Our definition of idct is that idct_3d(dct_3d(x)) == x
161 | For the meaning of the parameter `norm`, see:
162 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
163 | :param X: the input signal
164 | :param norm: the normalization, None or 'ortho'
165 |     :return: the inverse DCT-II of the signal over the last 3 dimensions
166 | """
167 | x1 = idct(X, norm=norm)
168 | x2 = idct(x1.transpose(-1, -2), norm=norm)
169 | x3 = idct(x2.transpose(-1, -3), norm=norm)
170 | return x3.transpose(-1, -3).transpose(-1, -2)
171 |
172 |
173 | class LinearDCT(nn.Linear):
174 | """Implement any DCT as a linear layer; in practice this executes around
175 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
176 | increase memory usage.
177 | :param in_features: size of expected input
178 | :param type: which dct function in this file to use"""
179 | def __init__(self, in_features, type, norm=None, bias=False):
180 | self.type = type
181 | self.N = in_features
182 | self.norm = norm
183 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
184 |
185 | def reset_parameters(self):
186 | # initialise using dct function
187 | I = torch.eye(self.N)
188 | if self.type == 'dct1':
189 | self.weight.data = dct1(I).data.t()
190 | elif self.type == 'idct1':
191 | self.weight.data = idct1(I).data.t()
192 | elif self.type == 'dct':
193 | self.weight.data = dct(I, norm=self.norm).data.t()
194 | elif self.type == 'idct':
195 | self.weight.data = idct(I, norm=self.norm).data.t()
196 | self.weight.requires_grad = False # don't learn this!
197 |
198 |
199 | def apply_linear_2d(x, linear_layer):
200 | """Can be used with a LinearDCT layer to do a 2D DCT.
201 | :param x: the input signal
202 | :param linear_layer: any PyTorch Linear layer
203 | :return: result of linear layer applied to last 2 dimensions
204 | """
205 | X1 = linear_layer(x)
206 | X2 = linear_layer(X1.transpose(-1, -2))
207 | return X2.transpose(-1, -2)
208 |
209 |
210 | def apply_linear_3d(x, linear_layer):
211 | """Can be used with a LinearDCT layer to do a 3D DCT.
212 | :param x: the input signal
213 | :param linear_layer: any PyTorch Linear layer
214 | :return: result of linear layer applied to last 3 dimensions
215 | """
216 | X1 = linear_layer(x)
217 | X2 = linear_layer(X1.transpose(-1, -2))
218 | X3 = linear_layer(X2.transpose(-1, -3))
219 | return X3.transpose(-1, -3).transpose(-1, -2)
220 |
221 |
222 | def torch_rgb2ycbcr(x):
223 | # Assume x is a batch of size (N x C x H x W)
224 | v = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).to(x.device)
225 | ycbcr = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
226 | ycbcr[:,1:] += 128
227 | return ycbcr
228 |
229 |
230 | def torch_ycbcr2rgb(x):
231 | # Assume x is a batch of size (N x C x H x W)
232 | v = torch.tensor([[ 1.00000000e+00, -3.68199903e-05, 1.40198758e+00],
233 | [ 1.00000000e+00, -3.44113281e-01, -7.14103821e-01],
234 | [ 1.00000000e+00, 1.77197812e+00, -1.34583413e-04]]).to(x.device)
235 | x[:, 1:] -= 128
236 | rgb = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1)
237 | return rgb
238 |
239 | def chroma_subsample(x):
240 | return x[:, 0:1, :, :], x[:, 1:, ::2, ::2]
241 |
242 |
243 | def general_quant_matrix(qf = 10):
244 | q1 = torch.tensor([
245 | 16, 11, 10, 16, 24, 40, 51, 61,
246 | 12, 12, 14, 19, 26, 58, 60, 55,
247 | 14, 13, 16, 24, 40, 57, 69, 56,
248 | 14, 17, 22, 29, 51, 87, 80, 62,
249 | 18, 22, 37, 56, 68, 109, 103, 77,
250 | 24, 35, 55, 64, 81, 104, 113, 92,
251 | 49, 64, 78, 87, 103, 121, 120, 101,
252 | 72, 92, 95, 98, 112, 100, 103, 99
253 | ])
254 | q2 = torch.tensor([
255 | 17, 18, 24, 47, 99, 99, 99, 99,
256 | 18, 21, 26, 66, 99, 99, 99, 99,
257 | 24, 26, 56, 99, 99, 99, 99, 99,
258 | 47, 66, 99, 99, 99, 99, 99, 99,
259 | 99, 99, 99, 99, 99, 99, 99, 99,
260 | 99, 99, 99, 99, 99, 99, 99, 99,
261 | 99, 99, 99, 99, 99, 99, 99, 99,
262 | 99, 99, 99, 99, 99, 99, 99, 99
263 | ])
264 | s = (5000 / qf) if qf < 50 else (200 - 2 * qf)
265 | q1 = torch.floor((s * q1 + 50) / 100)
266 | q1[q1 <= 0] = 1
267 | q1[q1 > 255] = 255
268 | q2 = torch.floor((s * q2 + 50) / 100)
269 | q2[q2 <= 0] = 1
270 | q2[q2 > 255] = 255
271 | return q1, q2
272 |
273 |
274 | def quantization_matrix(qf):
275 | return general_quant_matrix(qf)
276 | # q1 = torch.tensor([[ 80, 55, 50, 80, 120, 200, 255, 255],
277 | # [ 60, 60, 70, 95, 130, 255, 255, 255],
278 | # [ 70, 65, 80, 120, 200, 255, 255, 255],
279 | # [ 70, 85, 110, 145, 255, 255, 255, 255],
280 | # [ 90, 110, 185, 255, 255, 255, 255, 255],
281 | # [120, 175, 255, 255, 255, 255, 255, 255],
282 | # [245, 255, 255, 255, 255, 255, 255, 255],
283 | # [255, 255, 255, 255, 255, 255, 255, 255]])
284 | # q2 = torch.tensor([[ 85, 90, 120, 235, 255, 255, 255, 255],
285 | # [ 90, 105, 130, 255, 255, 255, 255, 255],
286 | # [120, 130, 255, 255, 255, 255, 255, 255],
287 | # [235, 255, 255, 255, 255, 255, 255, 255],
288 | # [255, 255, 255, 255, 255, 255, 255, 255],
289 | # [255, 255, 255, 255, 255, 255, 255, 255],
290 | # [255, 255, 255, 255, 255, 255, 255, 255],
291 | # [255, 255, 255, 255, 255, 255, 255, 255]])
292 | # return q1, q2
293 |
294 | def jpeg_encode(x, qf):
295 | # Assume x is a batch of size (N x C x H x W)
296 | # [-1, 1] to [0, 255]
297 | x = (x + 1) / 2 * 255
298 | n_batch, _, n_size, _ = x.shape
299 |
300 | x = torch_rgb2ycbcr(x)
301 | x_luma, x_chroma = chroma_subsample(x)
302 | unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
303 | x_luma = unfold(x_luma).transpose(2, 1)
304 | x_chroma = unfold(x_chroma).transpose(2, 1)
305 |
306 | x_luma = x_luma.reshape(-1, 8, 8) - 128
307 | x_chroma = x_chroma.reshape(-1, 8, 8) - 128
308 |
309 | dct_layer = LinearDCT(8, 'dct', norm='ortho')
310 | dct_layer.to(x_luma.device)
311 | x_luma = apply_linear_2d(x_luma, dct_layer)
312 | x_chroma = apply_linear_2d(x_chroma, dct_layer)
313 |
314 | x_luma = x_luma.view(-1, 1, 8, 8)
315 | x_chroma = x_chroma.view(-1, 2, 8, 8)
316 |
317 | q1, q2 = quantization_matrix(qf)
318 | q1 = q1.to(x_luma.device)
319 | q2 = q2.to(x_luma.device)
320 | x_luma /= q1.view(1, 8, 8)
321 | x_chroma /= q2.view(1, 8, 8)
322 |
323 | x_luma = x_luma.round()
324 | x_chroma = x_chroma.round()
325 |
326 | x_luma = x_luma.reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
327 | x_chroma = x_chroma.reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
328 |
329 | fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
330 | x_luma = fold(x_luma)
331 | fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
332 | x_chroma = fold(x_chroma)
333 |
334 | return [x_luma, x_chroma]
335 |
336 |
337 |
338 | def jpeg_decode(x, qf):
339 | # Assume x[0] is a batch of size (N x 1 x H x W) (luma)
340 | # Assume x[1:] is a batch of size (N x 2 x H/2 x W/2) (chroma)
341 | x_luma, x_chroma = x
342 | n_batch, _, n_size, _ = x_luma.shape
343 | unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8))
344 | x_luma = unfold(x_luma).transpose(2, 1)
345 | x_luma = x_luma.reshape(-1, 1, 8, 8)
346 | x_chroma = unfold(x_chroma).transpose(2, 1)
347 | x_chroma = x_chroma.reshape(-1, 2, 8, 8)
348 |
349 | q1, q2 = quantization_matrix(qf)
350 | q1 = q1.to(x_luma.device)
351 | q2 = q2.to(x_luma.device)
352 | x_luma *= q1.view(1, 8, 8)
353 | x_chroma *= q2.view(1, 8, 8)
354 |
355 | x_luma = x_luma.reshape(-1, 8, 8)
356 | x_chroma = x_chroma.reshape(-1, 8, 8)
357 |
358 | dct_layer = LinearDCT(8, 'idct', norm='ortho')
359 | dct_layer.to(x_luma.device)
360 | x_luma = apply_linear_2d(x_luma, dct_layer)
361 | x_chroma = apply_linear_2d(x_chroma, dct_layer)
362 |
363 | x_luma = (x_luma + 128).reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1)
364 | x_chroma = (x_chroma + 128).reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1)
365 |
366 | fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8))
367 | x_luma = fold(x_luma)
368 | fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8))
369 | x_chroma = fold(x_chroma)
370 |
371 | x_chroma_repeated = torch.zeros(n_batch, 2, n_size, n_size, device = x_luma.device)
372 | x_chroma_repeated[:, :, 0::2, 0::2] = x_chroma
373 | x_chroma_repeated[:, :, 0::2, 1::2] = x_chroma
374 | x_chroma_repeated[:, :, 1::2, 0::2] = x_chroma
375 | x_chroma_repeated[:, :, 1::2, 1::2] = x_chroma
376 |
377 | x = torch.cat([x_luma, x_chroma_repeated], dim=1)
378 |
379 | x = torch_ycbcr2rgb(x)
380 |
381 | # [0, 255] to [-1, 1]
382 | x = x / 255 * 2 - 1
383 |
384 | return x
385 |
386 |
387 | def build_jpeg(log, qf):
388 | log.info(f"[Corrupt] JPEG restoration: {qf=} ...")
389 | def jpeg(img):
390 | return jpeg_decode(jpeg_encode(img, qf), qf)
391 | return jpeg
392 |
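if __name__ == "__main__":
    # Illustrative smoke test (added for exposition; values are arbitrary):
    # JPEG-degrade a random batch of 256x256 images in [-1, 1] with quality
    # factor 10 and verify that the shape is preserved.
    x = torch.rand(2, 3, 256, 256) * 2 - 1
    y = jpeg_decode(jpeg_encode(x, 10), 10)
    print(y.shape)  # torch.Size([2, 3, 256, 256])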
--------------------------------------------------------------------------------
/corruption/mixture.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import numpy as np
9 | import enum
10 |
11 | import torch
12 | from torch.utils.data import Dataset
13 |
14 | from .jpeg import jpeg_decode, jpeg_encode
15 | from .blur import Deblurring
16 | from .superresolution import build_sr_bicubic, build_sr_pool
17 | from .inpaint import get_center_mask, load_freeform_masks
18 |
19 | from ipdb import set_trace as debug
20 |
21 |
22 | class AllCorrupt(enum.IntEnum):
23 | JPEG_5 = 0
24 | JPEG_10 = 1
25 | BLUR_UNI = 2
26 | BLUR_GAUSS = 3
27 | SR4X_POOL = 4
28 | SR4X_BICUBIC = 5
29 | INPAINT_CENTER = 6
30 | INPAINT_FREE1020 = 7
31 | INPAINT_FREE2030 = 8
32 |
33 |
34 | class MixtureCorruptMethod:
35 | def __init__(self, opt, device="cpu"):
36 |
37 | # ===== blur ====
38 | self.blur_uni = Deblurring(torch.Tensor([1 / 9] * 9).to(device), 3, opt.image_size, device)
39 |
40 | sigma = 10
41 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
42 | g_kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(device)
43 | self.blur_gauss = Deblurring(g_kernel / g_kernel.sum(), 3, opt.image_size, device)
44 |
45 | # ===== sr4x ====
46 | factor = 4
47 | self.sr4x_pool = build_sr_pool(factor, device, opt.image_size)
48 | self.sr4x_bicubic = build_sr_bicubic(factor, device, opt.image_size)
49 | self.upsample = torch.nn.Upsample(scale_factor=factor, mode="nearest")
50 |
51 | # ===== inpaint ====
52 | self.center_mask = get_center_mask([opt.image_size, opt.image_size])[None, ...] # [1, 1, 256, 256]
53 | self.free1020_masks = torch.from_numpy((load_freeform_masks("freeform1020"))) # [10000, 1, 256, 256]
54 | self.free2030_masks = torch.from_numpy((load_freeform_masks("freeform2030"))) # [10000, 1, 256, 256]
55 |
56 | def jpeg(self, img, qf):
57 | return jpeg_decode(jpeg_encode(img, qf), qf)
58 |
59 | def blur(self, img, kernel):
60 | img = (img + 1) / 2
61 | if kernel == "uni":
62 | _img = self.blur_uni.H(img).reshape(*img.shape)
63 | elif kernel == "gauss":
64 | _img = self.blur_gauss.H(img).reshape(*img.shape)
65 | # [0,1] -> [-1,1]
66 | return _img * 2 - 1
67 |
68 | def sr4x(self, img, filter):
69 | b, c, w, h = img.shape
70 | if filter == "pool":
71 | _img = self.sr4x_pool.H(img).reshape(b, c, w // 4, h // 4)
72 | elif filter == "bicubic":
73 | _img = self.sr4x_bicubic.H(img).reshape(b, c, w // 4, h // 4)
74 |
75 | # scale to original image size for I2SB
76 | return self.upsample(_img)
77 |
78 | def inpaint(self, img, mask_type, mask_index=None):
79 | if mask_type == "center":
80 | mask = self.center_mask
81 | elif mask_type == "free1020":
82 | if mask_index is None:
83 | mask_index = np.random.randint(len(self.free1020_masks))
84 | mask = self.free1020_masks[[mask_index]]
85 | elif mask_type == "free2030":
86 | if mask_index is None:
87 | mask_index = np.random.randint(len(self.free2030_masks))
88 | mask = self.free2030_masks[[mask_index]]
89 | return img * (1.0 - mask) + mask * torch.randn_like(img)
90 |
91 | def mixture(self, img, corrupt_idx, mask_index=None):
92 | if corrupt_idx == AllCorrupt.JPEG_5:
93 | corrupt_img = self.jpeg(img, 5)
94 | elif corrupt_idx == AllCorrupt.JPEG_10:
95 | corrupt_img = self.jpeg(img, 10)
96 | elif corrupt_idx == AllCorrupt.BLUR_UNI:
97 | corrupt_img = self.blur(img, "uni")
98 | elif corrupt_idx == AllCorrupt.BLUR_GAUSS:
99 | corrupt_img = self.blur(img, "gauss")
100 | elif corrupt_idx == AllCorrupt.SR4X_POOL:
101 | corrupt_img = self.sr4x(img, "pool")
102 | elif corrupt_idx == AllCorrupt.SR4X_BICUBIC:
103 | corrupt_img = self.sr4x(img, "bicubic")
104 | elif corrupt_idx == AllCorrupt.INPAINT_CENTER:
105 | corrupt_img = self.inpaint(img, "center")
106 | elif corrupt_idx == AllCorrupt.INPAINT_FREE1020:
107 | corrupt_img = self.inpaint(img, "free1020", mask_index=mask_index)
108 | elif corrupt_idx == AllCorrupt.INPAINT_FREE2030:
109 | corrupt_img = self.inpaint(img, "free2030", mask_index=mask_index)
110 | return corrupt_img
111 |
112 |
113 | class MixtureCorruptDatasetTrain(Dataset):
114 | def __init__(self, opt, dataset):
115 | super(MixtureCorruptDatasetTrain, self).__init__()
116 | self.dataset = dataset
117 | self.method = MixtureCorruptMethod(opt)
118 |
119 | def __len__(self):
120 | return self.dataset.__len__()
121 |
122 | def __getitem__(self, index):
123 | clean_img, y = self.dataset[index] # clean_img: tensor [-1,1]
124 |
125 | rand_idx = np.random.choice(AllCorrupt)
126 | corrupt_img = self.method.mixture(clean_img.unsqueeze(0), rand_idx).squeeze(0)
127 |
128 | assert corrupt_img.shape == clean_img.shape, (clean_img.shape, corrupt_img.shape)
129 | return clean_img, corrupt_img, y
130 |
131 |
132 | class MixtureCorruptDatasetVal(Dataset):
133 | def __init__(self, opt, dataset):
134 | super(MixtureCorruptDatasetVal, self).__init__()
135 | self.dataset = dataset
136 | self.method = MixtureCorruptMethod(opt)
137 |
138 | def __len__(self):
139 | return self.dataset.__len__()
140 |
141 | def __getitem__(self, index):
142 | clean_img, y = self.dataset[index] # clean_img: tensor [-1,1]
143 |
144 | idx = index % len(AllCorrupt)
145 | corrupt_img = self.method.mixture(clean_img.unsqueeze(0), idx, mask_index=idx).squeeze(0)
146 |
147 | assert corrupt_img.shape == clean_img.shape, (clean_img.shape, corrupt_img.shape)
148 | return clean_img, corrupt_img, y
149 |
--------------------------------------------------------------------------------
/corruption/superresolution.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from ddrm.
5 | #
6 | # Source:
7 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L171
8 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L264
9 | # https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L314
10 | #
11 | # The license for the original version of this file can be
12 | # found in this directory (LICENSE_DDRM).
13 | # The modifications to this file are subject to the same license.
14 | # ---------------------------------------------------------------
15 |
16 | import numpy as np
17 | import torch
18 | from .base import H_functions
19 |
20 | from ipdb import set_trace as debug
21 |
22 |
23 | class SuperResolution(H_functions):
24 | def __init__(self, channels, img_dim, ratio, device): # ratio = 2 or 4
25 | assert img_dim % ratio == 0
26 | self.img_dim = img_dim
27 | self.channels = channels
28 | self.y_dim = img_dim // ratio
29 | self.ratio = ratio
30 | H = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device)
31 | self.U_small, self.singulars_small, self.V_small = torch.svd(H, some=False)
32 | self.Vt_small = self.V_small.transpose(0, 1)
33 |
34 | def V(self, vec):
35 | # reorder the vector back into patches (because singulars are ordered descendingly)
36 | temp = vec.clone().reshape(vec.shape[0], -1)
37 | patches = torch.zeros(vec.shape[0], self.channels, self.y_dim**2, self.ratio**2, device=vec.device)
38 | patches[:, :, :, 0] = temp[:, : self.channels * self.y_dim**2].view(vec.shape[0], self.channels, -1)
39 | for idx in range(self.ratio**2 - 1):
40 | patches[:, :, :, idx + 1] = temp[:, (self.channels * self.y_dim**2 + idx) :: self.ratio**2 - 1].view(
41 | vec.shape[0], self.channels, -1
42 | )
43 | # multiply each patch by the small V
44 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(
45 | vec.shape[0], self.channels, -1, self.ratio**2
46 | )
47 | # repatch the patches into an image
48 | patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio)
49 | recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous()
50 | recon = recon.reshape(vec.shape[0], self.channels * self.img_dim**2)
51 | return recon
52 |
53 | def Vt(self, vec):
54 | # extract flattened patches
55 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim)
56 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio)
57 | unfold_shape = patches.shape
58 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2)
59 | # multiply each by the small V transposed
60 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(
61 | vec.shape[0], self.channels, -1, self.ratio**2
62 | )
63 | # reorder the vector to have the first entry first (because singulars are ordered descendingly)
64 | recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
65 | recon[:, : self.channels * self.y_dim**2] = patches[:, :, :, 0].view(
66 | vec.shape[0], self.channels * self.y_dim**2
67 | )
68 | for idx in range(self.ratio**2 - 1):
69 | recon[:, (self.channels * self.y_dim**2 + idx) :: self.ratio**2 - 1] = patches[:, :, :, idx + 1].view(
70 | vec.shape[0], self.channels * self.y_dim**2
71 | )
72 | return recon
73 |
74 | def U(self, vec):
75 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
76 |
77 | def Ut(self, vec): # U is 1x1, so U^T = U
78 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1)
79 |
80 | def singulars(self):
81 | return self.singulars_small.repeat(self.channels * self.y_dim**2)
82 |
83 | def add_zeros(self, vec):
84 | reshaped = vec.clone().reshape(vec.shape[0], -1)
85 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
86 | temp[:, : reshaped.shape[1]] = reshaped
87 | return temp
88 |
89 |
90 | class SRConv(H_functions):
91 | def mat_by_img(self, M, v, dim):
92 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, dim, dim)).reshape(
93 | v.shape[0], self.channels, M.shape[0], dim
94 | )
95 |
96 | def img_by_mat(self, v, M, dim):
97 | return torch.matmul(v.reshape(v.shape[0] * self.channels, dim, dim), M).reshape(
98 | v.shape[0], self.channels, dim, M.shape[1]
99 | )
100 |
101 | def __init__(self, kernel, channels, img_dim, device, stride=1):
102 | self.img_dim = img_dim
103 | self.channels = channels
104 | self.ratio = stride
105 | small_dim = img_dim // stride
106 | self.small_dim = small_dim
107 | # build 1D conv matrix
108 | H_small = torch.zeros(small_dim, img_dim, device=device)
109 | for i in range(stride // 2, img_dim + stride // 2, stride):
110 | for j in range(i - kernel.shape[0] // 2, i + kernel.shape[0] // 2):
111 | j_effective = j
112 | # reflective padding
113 | if j_effective < 0:
114 | j_effective = -j_effective - 1
115 | if j_effective >= img_dim:
116 | j_effective = (img_dim - 1) - (j_effective - img_dim)
117 | # matrix building
118 | H_small[i // stride, j_effective] += kernel[j - i + kernel.shape[0] // 2]
119 | # get the svd of the 1D conv
120 | self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False)
121 | ZERO = 3e-2
122 | self.singulars_small[self.singulars_small < ZERO] = 0
123 | # calculate the singular values of the big matrix
124 | self._singulars = torch.matmul(
125 | self.singulars_small.reshape(small_dim, 1), self.singulars_small.reshape(1, small_dim)
126 | ).reshape(small_dim**2)
127 | # permutation for matching the singular values. See P_1 in Appendix D.5.
128 | self._perm = (
129 | torch.Tensor(
130 | [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)]
131 | + [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim, self.img_dim)]
132 | )
133 | .to(device)
134 | .long()
135 | )
136 |
137 | def V(self, vec):
138 | # invert the permutation
139 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
140 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[
141 | :, : self._perm.shape[0], :
142 | ]
143 | temp[:, self._perm.shape[0] :, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[
144 | :, self._perm.shape[0] :, :
145 | ]
146 | temp = temp.permute(0, 2, 1)
147 | # multiply the image by V from the left and by V^T from the right
148 | out = self.mat_by_img(self.V_small, temp, self.img_dim)
149 | out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1)
150 | return out
151 |
152 | def Vt(self, vec):
153 | # multiply the image by V^T from the left and by V from the right
154 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim)
155 | temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1)
156 | # permute the entries
157 | temp[:, :, : self._perm.shape[0]] = temp[:, :, self._perm]
158 | temp = temp.permute(0, 2, 1)
159 | return temp.reshape(vec.shape[0], -1)
160 |
161 | def U(self, vec):
162 | # invert the permutation
163 | temp = torch.zeros(vec.shape[0], self.small_dim**2, self.channels, device=vec.device)
164 | temp[:, : self.small_dim**2, :] = vec.clone().reshape(vec.shape[0], self.small_dim**2, self.channels)
165 | temp = temp.permute(0, 2, 1)
166 | # multiply the image by U from the left and by U^T from the right
167 | out = self.mat_by_img(self.U_small, temp, self.small_dim)
168 | out = self.img_by_mat(out, self.U_small.transpose(0, 1), self.small_dim).reshape(vec.shape[0], -1)
169 | return out
170 |
171 | def Ut(self, vec):
172 | # multiply the image by U^T from the left and by U from the right
173 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone(), self.small_dim)
174 | temp = self.img_by_mat(temp, self.U_small, self.small_dim).reshape(vec.shape[0], self.channels, -1)
175 | # permute the entries
176 | temp = temp.permute(0, 2, 1)
177 | return temp.reshape(vec.shape[0], -1)
178 |
179 | def singulars(self):
180 | return self._singulars.repeat_interleave(3).reshape(-1)
181 |
182 | def add_zeros(self, vec):
183 | reshaped = vec.clone().reshape(vec.shape[0], -1)
184 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
185 | temp[:, : reshaped.shape[1]] = reshaped
186 | return temp
187 |
188 |
189 | # note: code adopted from
190 | # https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L228
191 | def build_sr_bicubic(factor, device, image_size, data_channels=3):
192 | def bicubic_kernel(x, a=-0.5):
193 | if abs(x) <= 1:
194 | return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1
195 | elif 1 < abs(x) and abs(x) < 2:
196 | return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a
197 | else:
198 | return 0
199 |
200 | k = np.zeros((factor * 4))
201 | for i in range(factor * 4):
202 | x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5)
203 | k[i] = bicubic_kernel(x)
204 | k = k / np.sum(k)
205 | kernel = torch.from_numpy(k).float().to(device)
206 | H_funcs = SRConv(kernel / kernel.sum(), data_channels, image_size, device, stride=factor)
207 |
208 | return H_funcs
209 |
210 |
211 | def build_sr_pool(factor, device, image_size, data_channels=3):
212 | H_funcs = SuperResolution(data_channels, image_size, factor, device)
213 | return H_funcs
214 |
215 |
216 | def build_sr4x(opt, log, sr_filter, image_size):
217 | assert sr_filter in ["pool", "bicubic"]
218 | log.info(f"[Corrupt] Super-resolution (4x): {sr_filter=} ...")
219 |
220 | factor = 4
221 |
222 | sr_bicubic = build_sr_bicubic(factor, opt.device, image_size)
223 | sr_pool = build_sr_pool(factor, opt.device, image_size)
224 |
225 | upsample = torch.nn.Upsample(scale_factor=factor, mode="nearest")
226 |
227 | assert sr_filter in ["pool", "bicubic"]
228 |
229 | def sr4x(img):
230 | b, c, w, h = img.shape
231 | img = img.to(opt.device)
232 | if sr_filter == "pool":
233 | _img = sr_pool.H(img).reshape(b, c, w // factor, h // factor)
234 | elif sr_filter == "bicubic":
235 | _img = sr_bicubic.H(img).reshape(b, c, w // factor, h // factor)
236 |
237 | # scale to original image size for I2SB
238 | return upsample(_img)
239 |
240 | return sr4x
241 |
--------------------------------------------------------------------------------
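A minimal usage sketch (not part of the repo) of the operators above, with an illustrative 4x factor on 256px images; `H(...)` is inherited from the `H_functions` base class in `corruption/base.py`:

```python
# Sketch: build the pooling and bicubic 4x operators and apply them to a dummy batch,
# mirroring the sr4x closure above. Shapes and paths are illustrative assumptions.
import torch
from corruption.superresolution import build_sr_bicubic, build_sr_pool

device = "cuda" if torch.cuda.is_available() else "cpu"
factor, image_size = 4, 256

H_bicubic = build_sr_bicubic(factor, device, image_size)  # SRConv with a normalized bicubic kernel
H_pool = build_sr_pool(factor, device, image_size)        # block-averaging SuperResolution

x = torch.randn(2, 3, image_size, image_size, device=device)                  # a batch in [-1, 1]
y = H_bicubic.H(x).reshape(2, 3, image_size // factor, image_size // factor)  # low-resolution result
x_up = torch.nn.Upsample(scale_factor=factor, mode="nearest")(y)              # back to 256px, as in sr4x
print(y.shape, x_up.shape)  # (2, 3, 64, 64) and (2, 3, 256, 256)
```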
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 | import torch.distributed as dist
5 | from torch.utils.data.sampler import Sampler
6 | from torch.utils.data import DataLoader
7 |
8 |
9 | def get_data_scaler(config):
10 | """Data normalizer. Assume data are always in [0, 1]."""
11 | if config.data.centered:
12 | # Rescale to [-1, 1]
13 | return lambda x: x * 2.0 - 1.0
14 | else:
15 | return lambda x: x
16 |
17 |
18 | def get_data_inverse_scaler(config):
19 | """Inverse data normalizer."""
20 | if config.data.centered:
21 | # Rescale [-1, 1] to [0, 1]
22 | return lambda x: (x + 1.0) / 2.0
23 | else:
24 | return lambda x: x
25 |
26 |
27 | class UniformDequant(object):
28 | def __call__(self, x):
29 | return x + torch.rand_like(x) / 256
30 |
31 |
32 | class RASampler(Sampler):
33 | """Sampler that restricts data loading to a subset of the dataset for distributed,
34 | with repeated augmentation.
35 | It ensures that each augmented version of a sample will be visible to a
36 | different process (GPU).
37 | Heavily based on 'torch.utils.data.DistributedSampler'.
38 | This is borrowed from the DeiT Repo:
39 | https://github.com/facebookresearch/deit/blob/main/samplers.py
40 | """
41 |
42 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
43 | if num_replicas is None:
44 | num_replicas = dist.get_world_size()
45 | if rank is None:
46 | rank = dist.get_rank()
47 | self.dataset = dataset
48 | self.num_replicas = num_replicas
49 | self.rank = rank
50 | self.epoch = 0
51 | self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
52 | self.total_size = self.num_samples * self.num_replicas
53 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
54 | self.shuffle = shuffle
55 | self.seed = seed
56 | self.repetitions = repetitions
57 |
58 | def __iter__(self):
59 | if self.shuffle:
60 | # Deterministically shuffle based on epoch
61 | g = torch.Generator()
62 | g.manual_seed(self.seed + self.epoch)
63 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
64 | else:
65 | indices = list(range(len(self.dataset)))
66 |
67 | # Add extra samples to make it evenly divisible
68 | indices = [ele for ele in indices for i in range(self.repetitions)]
69 | indices += indices[: (self.total_size - len(indices))]
70 | assert len(indices) == self.total_size
71 |
72 | # Subsample
73 | indices = indices[self.rank : self.total_size : self.num_replicas]
74 | assert len(indices) == self.num_samples
75 |
76 | return iter(indices[: self.num_selected_samples])
77 |
78 | def __len__(self):
79 | return self.num_selected_samples
80 |
81 | def set_epoch(self, epoch):
82 | self.epoch = epoch
83 |
84 |
85 | class InfiniteBatchSampler(Sampler):
86 | def __init__(self, dataset_len, batch_size, seed=0, filling=False, shuffle=True, drop_last=False):
87 | self.dataset_len = dataset_len
88 | self.batch_size = batch_size
89 | self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
90 | self.max_p = self.iters_per_ep * batch_size
91 | self.filling = filling
92 | self.shuffle = shuffle
93 | self.epoch = 0
94 | self.seed = seed
95 | self.indices = self.gener_indices()
96 |
97 | def gener_indices(self):
98 | if self.shuffle:
99 | g = torch.Generator()
100 | g.manual_seed(self.epoch + self.seed)
101 | indices = torch.randperm(self.dataset_len, generator=g).numpy()
102 | else:
103 | indices = torch.arange(self.dataset_len).numpy()
104 |
105 | tails = self.batch_size - (self.dataset_len % self.batch_size)
106 | if tails != self.batch_size and self.filling:
107 | tails = indices[:tails]
108 | np.random.shuffle(indices)
109 | indices = np.concatenate((indices, tails))
110 |
111 | # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
112 | # noinspection PyTypeChecker
113 | return tuple(indices.tolist())
114 |
115 | def __iter__(self):
116 | self.epoch = 0
117 | while True:
118 | self.epoch += 1
119 | p, q = 0, 0
120 | while p < self.max_p:
121 | q = p + self.batch_size
122 | yield self.indices[p:q]
123 | p = q
124 | if self.shuffle:
125 | self.indices = self.gener_indices()
126 |
127 | def __len__(self):
128 | return self.iters_per_ep
129 |
130 |
131 | class DistInfiniteBatchSampler(InfiniteBatchSampler):
132 | def __init__(
133 | self, world_size, rank, dataset_len, glb_batch_size, seed=0, repeated_aug=0, filling=False, shuffle=True
134 | ):
135 | # from torchvision.models import ResNet50_Weights
136 | # RA sampler: https://github.com/pytorch/vision/blob/5521e9d01ee7033b9ee9d421c1ef6fb211ed3782/references/classification/sampler.py
137 |
138 | assert glb_batch_size % world_size == 0
139 | self.world_size, self.rank = world_size, rank
140 | self.dataset_len = dataset_len
141 | self.glb_batch_size = glb_batch_size
142 | self.batch_size = glb_batch_size // world_size
143 |
144 | self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
145 | self.filling = filling
146 | self.shuffle = shuffle
147 | self.repeated_aug = repeated_aug
148 | self.epoch = 0
149 | self.seed = seed
150 | self.indices = self.gener_indices()
151 |
152 | def gener_indices(self):
153 | global_max_p = (
154 | self.iters_per_ep * self.glb_batch_size
155 | ) # global_max_p % world_size must be 0 because glb_batch_size % world_size == 0
156 | if self.shuffle:
157 | g = torch.Generator()
158 | g.manual_seed(self.epoch + self.seed)
159 | global_indices = torch.randperm(self.dataset_len, generator=g)
160 | if self.repeated_aug > 1:
161 | global_indices = global_indices[
162 | : (self.dataset_len + self.repeated_aug - 1) // self.repeated_aug
163 | ].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
164 | else:
165 | global_indices = torch.arange(self.dataset_len)
166 | filling = global_max_p - global_indices.shape[0]
167 | if filling > 0 and self.filling:
168 | global_indices = torch.cat((global_indices, global_indices[:filling]))
169 | global_indices = tuple(global_indices.numpy().tolist())
170 |
171 | seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
172 | local_indices = global_indices[seps[self.rank] : seps[self.rank + 1]]
173 | self.max_p = len(local_indices)
174 | return local_indices
175 |
176 |
177 | def load_data(
178 | data_dir,
179 | dataset,
180 | batch_size,
181 | image_size,
182 | deterministic=False,
183 | include_test=False,
184 | seed=42,
185 | num_workers=2,
186 | ):
187 | # Compute batch size for this worker.
188 | root = data_dir
189 |
190 | if dataset == "edges2handbags":
191 |
192 | from .aligned_dataset import EdgesDataset
193 |
194 | trainset = EdgesDataset(dataroot=root, train=True, img_size=image_size, random_crop=True, random_flip=True)
195 |
196 | valset = EdgesDataset(dataroot=root, train=True, img_size=image_size, random_crop=False, random_flip=False)
197 | if include_test:
198 | testset = EdgesDataset(
199 | dataroot=root, train=False, img_size=image_size, random_crop=False, random_flip=False
200 | )
201 |
202 | elif dataset == "diode":
203 |
204 | from .aligned_dataset import DIODE
205 |
206 | trainset = DIODE(
207 | dataroot=root, train=True, img_size=image_size, random_crop=True, random_flip=True, disable_cache=True
208 | )
209 |
210 | valset = DIODE(
211 | dataroot=root, train=True, img_size=image_size, random_crop=False, random_flip=False, disable_cache=True
212 | )
213 |
214 | if include_test:
215 | testset = DIODE(dataroot=root, train=False, img_size=image_size, random_crop=False, random_flip=False)
216 |
217 | elif "imagenet_inpaint" in dataset:
218 | corrupt_type = dataset.split("_")[-1]
219 | assert corrupt_type in ["center", "freeform2030"]
220 | from .imagenet_inpaint import ImageNetInpaintingDataset, InpaintingVal10kSubset
221 |
222 | trainset = ImageNetInpaintingDataset(root, image_size, corrupt_type, train=True)
223 | valset = ImageNetInpaintingDataset(root, image_size, corrupt_type, train=False)
224 |
225 | if include_test:
226 | testset = InpaintingVal10kSubset(root, image_size, corrupt_type)
227 |
228 | loader = DataLoader(
229 | dataset=trainset,
230 | num_workers=num_workers,
231 | pin_memory=True,
232 | batch_sampler=DistInfiniteBatchSampler(
233 | dataset_len=len(trainset),
234 | glb_batch_size=batch_size * dist.get_world_size(),
235 | seed=seed,
236 | shuffle=not deterministic,
237 | filling=True,
238 | rank=dist.get_rank(),
239 | world_size=dist.get_world_size(),
240 | ),
241 | )
242 |
243 | num_tasks = dist.get_world_size()
244 | global_rank = dist.get_rank()
245 | sampler = torch.utils.data.DistributedSampler(
246 | valset, num_replicas=num_tasks, rank=global_rank, shuffle=False, drop_last=False
247 | )
248 | val_loader = torch.utils.data.DataLoader(
249 | valset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, drop_last=False
250 | )
251 |
252 | if include_test:
253 |
254 | num_tasks = dist.get_world_size()
255 | global_rank = dist.get_rank()
256 | sampler = torch.utils.data.DistributedSampler(
257 | testset, num_replicas=num_tasks, rank=global_rank, shuffle=False, drop_last=False
258 | )
259 | test_loader = torch.utils.data.DataLoader(
260 | testset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, shuffle=False, drop_last=False
261 | )
262 |
263 | return loader, val_loader, test_loader
264 | else:
265 | return loader, val_loader
266 |
--------------------------------------------------------------------------------
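A hedged sketch of how `load_data` is typically driven. It calls into `torch.distributed`, so even a single-GPU run needs a process group (e.g. via `torchrun --nproc_per_node=1`); the batch size, image size, and dataset path below are illustrative only:

```python
# demo_load_data.py -- launch with: torchrun --nproc_per_node=1 demo_load_data.py
import torch.distributed as dist
from datasets import load_data

dist.init_process_group("nccl")  # or "gloo" on CPU-only machines

train_loader, val_loader = load_data(
    data_dir="assets/datasets/edges2handbags",
    dataset="edges2handbags",
    batch_size=16,      # per-GPU; the infinite train sampler uses batch_size * world_size globally
    image_size=64,
)
x0, xT, (idx, paths) = next(iter(val_loader))  # EdgesDataset yields (B, A, (index, AB_path))
print(x0.shape, xT.shape)                      # e.g. [16, 3, 64, 64] for both
```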
/datasets/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torch
3 | import random
4 | import numpy as np
5 | import torchvision.transforms as transforms
6 | from .image_folder import make_dataset
7 | from PIL import Image
8 | import blobfile as bf
9 |
10 |
11 | def get_params(size, resize_size, crop_size):
12 | w, h = size
13 | new_h = h
14 | new_w = w
15 |
16 | ss, ls = min(w, h), max(w, h) # shortside and longside
17 | width_is_shorter = w == ss
18 | ls = int(resize_size * ls / ss)
19 | ss = resize_size
20 | new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
21 |
22 | x = random.randint(0, np.maximum(0, new_w - crop_size))
23 | y = random.randint(0, np.maximum(0, new_h - crop_size))
24 |
25 | flip = random.random() > 0.5
26 | return {"crop_pos": (x, y), "flip": flip}
27 |
28 |
29 | def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop=True, totensor=True):
30 | transform_list = []
31 | transform_list.append(transforms.Lambda(lambda img: __scale(img, crop_size, method)))
32 |
33 | if flip:
34 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))
35 | if totensor:
36 | transform_list.append(transforms.ToTensor())
37 | transform_list.append(transforms.Lambda(lambda t: (t * 2) - 1)) # [0,1] --> [-1, 1]
38 | return transforms.Compose(transform_list)
39 |
40 |
41 | def get_tensor(normalize=True, toTensor=True):
42 | transform_list = []
43 | if toTensor:
44 | transform_list += [transforms.ToTensor()]
45 |
46 | if normalize:
47 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
48 | return transforms.Compose(transform_list)
49 |
50 |
51 | def normalize():
52 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
53 |
54 |
55 | def __scale(img, target_width, method=Image.BICUBIC):
56 | if isinstance(img, torch.Tensor):
57 | return torch.nn.functional.interpolate(
58 | img.unsqueeze(0), size=(target_width, target_width), mode="bicubic", align_corners=False
59 | ).squeeze(0)
60 | else:
61 | return img.resize((target_width, target_width), method)
62 |
63 |
64 | def __flip(img, flip):
65 | if flip:
66 | if isinstance(img, torch.Tensor):
67 | return img.flip(-1)
68 | else:
69 | return img.transpose(Image.FLIP_LEFT_RIGHT)
70 | return img
71 |
72 |
73 | def get_flip(img, flip):
74 | return __flip(img, flip)
75 |
76 |
77 | class EdgesDataset(torch.utils.data.Dataset):
78 | """A dataset class for paired image dataset.
79 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
80 | During test time, you need to prepare a directory '/path/to/data/val'.
81 | """
82 |
83 | def __init__(self, dataroot, train=True, img_size=256, random_crop=False, random_flip=True):
84 | """Initialize this dataset class.
85 | Parameters:
86 | dataroot (str) -- dataset root with 'train'/'val' subfolders of combined {A,B} images; train/img_size/random_crop/random_flip select the split, resolution, and augmentation
87 | """
88 | super().__init__()
89 | if train:
90 | self.train_dir = os.path.join(dataroot, "train") # get the image directory
91 | self.train_paths = make_dataset(self.train_dir) # get image paths
92 | self.AB_paths = sorted(self.train_paths)
93 | else:
94 |
95 | self.test_dir = os.path.join(dataroot, "val") # get the image directory
96 |
97 | self.AB_paths = make_dataset(self.test_dir) # get image paths
98 |
99 | self.crop_size = img_size
100 | self.resize_size = img_size
101 |
102 | self.random_crop = random_crop
103 | self.random_flip = random_flip
104 | self.train = train
105 |
106 | def __getitem__(self, index):
107 | """Return a data point and its metadata information.
108 | Parameters:
109 | index - - a random integer for data indexing
110 | Returns a tuple (B, A, (index, AB_path)), where
111 | A (tensor) - - an image in the input domain (edge map)
112 | B (tensor) - - its corresponding image in the target domain (photo)
113 | index (int) - - the dataset index
114 | AB_path (str) - - path of the combined {A,B} image
115 | """
116 | # read an image given a random integer index
117 |
118 | AB_path = self.AB_paths[index]
119 | AB = Image.open(AB_path).convert("RGB")
120 | # split AB image into A and B
121 | w, h = AB.size
122 | w2 = int(w / 2)
123 | A = AB.crop((0, 0, w2, h))
124 | B = AB.crop((w2, 0, w, h))
125 |
126 | # apply the same transform to both A and B
127 | params = get_params(A.size, self.resize_size, self.crop_size)
128 |
129 | transform_image = get_transform(
130 | params, self.resize_size, self.crop_size, crop=self.random_crop, flip=self.random_flip
131 | )
132 |
133 | A = transform_image(A)
134 | B = transform_image(B)
135 |
136 | return B, A, (index, AB_path)
137 |
138 | def __len__(self):
139 | """Return the total number of images in the dataset."""
140 | return len(self.AB_paths)
141 |
142 |
143 | class DIODE(torch.utils.data.Dataset):
144 | """A dataset class for paired image dataset.
145 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
146 | During test time, you need to prepare a directory '/path/to/data/test'.
147 | """
148 |
149 | def __init__(
150 | self,
151 | dataroot,
152 | train=True,
153 | img_size=256,
154 | random_crop=False,
155 | random_flip=True,
156 | down_sample_img_size=0,
157 | cache_name="cache",
158 | disable_cache=False,
159 | ):
160 | """Initialize this dataset class.
161 | Parameters:
162 | dataroot (str) -- dataset root with 'train'/'val' subfolders; img_size/random_crop/random_flip control resolution and augmentation
163 | """
164 | super().__init__()
165 | self.image_root = os.path.join(dataroot, "train" if train else "val")
166 | self.crop_size = img_size
167 | self.resize_size = img_size
168 |
169 | self.random_crop = random_crop
170 | self.random_flip = random_flip
171 | self.train = train
172 |
173 | self.filenames = [
174 | l
175 | for l in os.listdir(self.image_root)
176 | if not l.endswith(".pth") and not l.endswith("_depth.png") and not l.endswith("_normal.png")
177 | ]
178 |
179 | self.cache_path = os.path.join(self.image_root, cache_name + f"_{img_size}.pth")
180 | if os.path.exists(self.cache_path) and not disable_cache:
181 | self.cache = torch.load(self.cache_path)
182 | # self.cache['img'] = self.cache['img'][:256]
183 | self.scale_factor = self.cache["scale_factor"]
184 | print("Loaded cache from {}".format(self.cache_path))
185 | else:
186 | self.cache = None
187 |
188 | def __getitem__(self, index):
189 | """Return a data point and its metadata information.
190 | Parameters:
191 | index - - a random integer for data indexing
192 | Returns a tuple (img, cond, (index, fn)), where
193 | img (tensor) - - the RGB image (target domain)
194 | cond (tensor) - - the corresponding surface-normal map (input domain)
195 | index (int) - - the dataset index
196 | fn (str) - - the image filename
197 | """
198 | # read an image given a random integer index
199 |
200 | fn = self.filenames[index]
201 | img_path = os.path.join(self.image_root, fn)
202 | label_path = os.path.join(self.image_root, fn[:-4] + "_normal.png")
203 |
204 | with bf.BlobFile(img_path, "rb") as f:
205 | pil_image = Image.open(f)
206 | pil_image.load()
207 | pil_image = pil_image.convert("RGB")
208 |
209 | with bf.BlobFile(label_path, "rb") as f:
210 | pil_label = Image.open(f)
211 | pil_label.load()
212 | pil_label = pil_label.convert("RGB")
213 |
214 | # apply the same transform to both A and B
215 | params = get_params(pil_image.size, self.resize_size, self.crop_size)
216 |
217 | transform_label = get_transform(
218 | params, self.resize_size, self.crop_size, method=Image.NEAREST, crop=False, flip=self.random_flip
219 | )
220 | transform_image = get_transform(params, self.resize_size, self.crop_size, crop=False, flip=self.random_flip)
221 |
222 | cond = transform_label(pil_label)
223 | img = transform_image(pil_image)
224 |
225 | return img, cond, (index, fn)
226 |
227 | def __len__(self):
228 | """Return the total number of images in the dataset."""
229 | if self.cache is not None:
230 | return len(self.cache["img"])
231 | else:
232 | return len(self.filenames)
233 |
--------------------------------------------------------------------------------
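A small sketch of the paired transform pipeline used by `EdgesDataset`: one `get_params` call fixes the crop/flip decision and the same `get_transform` is applied to both halves of the combined {A,B} image. The file path and image size below are hypothetical:

```python
from PIL import Image
from datasets.aligned_dataset import get_params, get_transform

AB = Image.open("assets/datasets/edges2handbags/val/1_AB.jpg").convert("RGB")  # hypothetical file
w, h = AB.size
A, B = AB.crop((0, 0, w // 2, h)), AB.crop((w // 2, 0, w, h))  # edge map | photo

params = get_params(A.size, resize_size=64, crop_size=64)       # shared random flip decision
tf = get_transform(params, resize_size=64, crop_size=64, flip=True)
a, b = tf(A), tf(B)             # both are [3, 64, 64] tensors in [-1, 1], flipped identically
print(a.shape, float(a.min()), float(a.max()))
```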
/datasets/image_folder.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 | '.tif', '.TIF', '.tiff', '.TIFF',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 |
18 | def make_dataset(dir, max_dataset_size=float("inf")):
19 | images = []
20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
21 |
22 | for root, _, fnames in sorted(os.walk(dir)):
23 | for fname in fnames:
24 | if is_image_file(fname):
25 | path = os.path.join(root, fname)
26 | images.append(path)
27 | return images[:min(max_dataset_size, len(images))]
28 |
29 |
30 | def default_loader(path):
31 | return Image.open(path).convert('RGB')
32 |
33 |
34 | class ImageFolder(data.Dataset):
35 |
36 | def __init__(self, root, transform=None, return_paths=False,
37 | loader=default_loader):
38 | imgs = make_dataset(root)
39 | if len(imgs) == 0:
40 | raise(RuntimeError("Found 0 images in: " + root + "\n"
41 | "Supported image extensions are: " +
42 | ",".join(IMG_EXTENSIONS)))
43 |
44 | self.root = root
45 | self.imgs = imgs
46 | self.transform = transform
47 | self.return_paths = return_paths
48 | self.loader = loader
49 |
50 | def __getitem__(self, index):
51 | path = self.imgs[index]
52 | img = self.loader(path)
53 | if self.transform is not None:
54 | img = self.transform(img)
55 | if self.return_paths:
56 | return img, path
57 | else:
58 | return img
59 |
60 | def __len__(self):
61 | return len(self.imgs)
--------------------------------------------------------------------------------
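A quick sketch of the lightweight `ImageFolder` above: it walks a directory for image files and applies a single transform (the directory and transform here are illustrative):

```python
import torchvision.transforms as T
from datasets.image_folder import ImageFolder, make_dataset

root = "assets/datasets/edges2handbags/val"
print(len(make_dataset(root)))                 # number of image files found recursively

ds = ImageFolder(
    root,
    transform=T.Compose([T.Resize(64), T.CenterCrop(64), T.ToTensor()]),
    return_paths=True,
)
img, path = ds[0]
print(img.shape, path)
```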
/datasets/imagenet_inpaint.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 | import io
10 |
11 | from PIL import Image
12 | import lmdb
13 |
14 | import torch
15 | import torchvision.datasets as datasets
16 | from torchvision import transforms
17 | from torch.utils.data import Dataset
18 |
19 | from corruption.inpaint import build_inpaint_freeform, build_inpaint_center
20 |
21 | # from ipdb import set_trace as debug
22 |
23 |
24 | def lmdb_loader(path, lmdb_data):
25 | # In-memory binary streams
26 | # print(path)
27 | with lmdb_data.begin(write=False, buffers=True) as txn:
28 | bytedata = txn.get(path.encode())
29 | img = Image.open(io.BytesIO(bytedata))
30 | return img.convert("RGB")
31 |
32 |
33 | def _build_lmdb_dataset(root, transform=None, target_transform=None, loader=lmdb_loader):
34 | """
35 | You can create this dataloader using:
36 | train_data = _build_lmdb_dataset(traindir, transform=train_transform)
37 | valid_data = _build_lmdb_dataset(validdir, transform=val_transform)
38 | """
39 |
40 | root = str(root)
41 | if root.endswith("/"):
42 | root = root[:-1]
43 | pt_path = os.path.join(root + "_faster_imagefolder.lmdb.pt")
44 | lmdb_path = os.path.join(root + "_faster_imagefolder.lmdb")
45 |
46 | if os.path.isfile(pt_path) and os.path.isdir(lmdb_path):
47 | # log.info('[Dataset] Loading pt {} and lmdb {}'.format(pt_path, lmdb_path))
48 | data_set = torch.load(pt_path)
49 | else:
50 | data_set = datasets.ImageFolder(root, None, None, None)
51 | torch.save(data_set, pt_path, pickle_protocol=4)
52 | # log.info('[Dataset] Saving pt to {}'.format(pt_path))
53 | # log.info('[Dataset] Building lmdb to {}'.format(lmdb_path))
54 | env = lmdb.open(lmdb_path, map_size=1e12)
55 | with env.begin(write=True) as txn:
56 | for _path, class_index in data_set.imgs:
57 | with open(_path, "rb") as f:
58 | data = f.read()
59 | txn.put(_path.encode("ascii"), data)
60 | data_set.lmdb_data = lmdb.open(lmdb_path, readonly=True, max_readers=1, lock=False, readahead=False, meminit=False)
61 | # reset transform and target_transform
62 | data_set.samples = data_set.imgs
63 | data_set.transform = transform
64 | data_set.target_transform = target_transform
65 | data_set.loader = lambda path: loader(path, data_set.lmdb_data)
66 |
67 | return data_set
68 |
69 |
70 | def build_train_transform(image_size):
71 | return transforms.Compose(
72 | [
73 | transforms.Resize(image_size),
74 | transforms.CenterCrop(image_size),
75 | transforms.RandomHorizontalFlip(p=0.5),
76 | transforms.ToTensor(),
77 | transforms.Lambda(lambda t: (t * 2) - 1), # [0,1] --> [-1, 1]
78 | ]
79 | )
80 |
81 |
82 | def build_test_transform(image_size):
83 | return transforms.Compose(
84 | [
85 | transforms.Resize(image_size),
86 | transforms.CenterCrop(image_size),
87 | # transforms.RandomHorizontalFlip(p=0.5),
88 | transforms.ToTensor(),
89 | transforms.Lambda(lambda t: (t * 2) - 1), # [0,1] --> [-1, 1]
90 | ]
91 | )
92 |
93 |
94 | def build_lmdb_dataset(dataset_dir, image_size, train, transform=None):
95 | """resize -> crop -> to_tensor -> norm(-1,1)"""
96 | fn = os.path.join(dataset_dir, "train" if train else "val")
97 |
98 | if transform is None:
99 | build_transform = build_train_transform if train else build_test_transform
100 | transform = build_transform(image_size)
101 |
102 | dataset = _build_lmdb_dataset(fn, transform=transform)
103 | # log.info(f"[Dataset] Built Imagenet dataset {fn=}, size={len(dataset)}!")
104 | return dataset
105 |
106 |
107 | def readlines(fn):
108 | file = open(fn, "r").readlines()
109 | return [line.strip("\n\r") for line in file]
110 |
111 |
112 | def build_lmdb_dataset_val10k(dataset_dir, image_size, transform=None):
113 | fn = os.path.join(dataset_dir, "val")
114 | fn_10k = readlines("assets/datasets/val_faster_imagefolder_10k_fn.txt")
115 | label_10k = readlines("assets/datasets/val_faster_imagefolder_10k_label.txt")
116 |
117 | if transform is None:
118 | transform = build_test_transform(image_size)
119 | dataset = _build_lmdb_dataset(fn, transform=transform)
120 | dataset.samples = [(os.path.join(fn, name), int(label)) for name, label in zip(fn_10k, label_10k)]
121 |
122 | assert len(dataset) == 10_000
123 | # log.info(f"[Dataset] Built Imagenet val10k, size={len(dataset)}!")
124 | return dataset
125 |
126 |
127 | class InpaintingVal10kSubset(Dataset):
128 | def __init__(self, dataset_dir, image_size, corrupt_type, transform=None):
129 | super(InpaintingVal10kSubset, self).__init__()
130 | self.dataset = build_lmdb_dataset_val10k(dataset_dir, image_size, transform)
131 |
132 | if corrupt_type == "center":
133 | self.corrupt = build_inpaint_center(corrupt_type, image_size)
134 | elif "freeform" in corrupt_type:
135 | self.corrupt = build_inpaint_freeform(corrupt_type)
136 | else:
137 | raise NotImplementedError()
138 |
139 | def __len__(self):
140 | return self.dataset.__len__()
141 |
142 | def __getitem__(self, index):
143 | img_clean, label = self.dataset[index]
144 | img_corrupt, mask = self.corrupt(img_clean)
145 | return img_clean, img_corrupt, (index, mask, label)
146 |
147 |
148 | class ImageNetInpaintingDataset(Dataset):
149 | def __init__(self, dataset_dir, image_size, corrupt_type, train, transform=None):
150 | super(ImageNetInpaintingDataset, self).__init__()
151 | self.dataset = build_lmdb_dataset(dataset_dir, image_size, train, transform)
152 |
153 | if corrupt_type == "center":
154 | self.corrupt = build_inpaint_center(corrupt_type, image_size)
155 | elif "freeform" in corrupt_type:
156 | self.corrupt = build_inpaint_freeform(corrupt_type)
157 | else:
158 | raise NotImplementedError()
159 |
160 | def __len__(self):
161 | return self.dataset.__len__()
162 |
163 | def __getitem__(self, index):
164 | img_clean, label = self.dataset[index]
165 | img_corrupt, mask = self.corrupt(img_clean)
166 | return img_clean, img_corrupt, (index, mask, label)
167 |
--------------------------------------------------------------------------------
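A hedged sketch of the inpainting dataset above, assuming ImageNet is laid out as `assets/datasets/ImageNet/{train,val}`. The first instantiation builds a `*_faster_imagefolder.lmdb` cache next to the split directory, so it can take a while:

```python
from datasets.imagenet_inpaint import ImageNetInpaintingDataset

ds = ImageNetInpaintingDataset(
    "assets/datasets/ImageNet", image_size=256, corrupt_type="center", train=False
)
img_clean, img_corrupt, (index, mask, label) = ds[0]
print(img_clean.shape, img_corrupt.shape, label)  # clean/corrupted pair plus the inpainting mask
```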
/datasets/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import numpy as np
9 | import torch
10 |
11 | # ----------------------------------------------------------------------------
12 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
13 | # same constant is used multiple times.
14 |
15 | _constant_cache = dict()
16 |
17 |
18 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
19 | value = np.asarray(value)
20 | if shape is not None:
21 | shape = tuple(shape)
22 | if dtype is None:
23 | dtype = torch.get_default_dtype()
24 | if device is None:
25 | device = torch.device("cpu")
26 | if memory_format is None:
27 | memory_format = torch.contiguous_format
28 |
29 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
30 | tensor = _constant_cache.get(key, None)
31 | if tensor is None:
32 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
33 | if shape is not None:
34 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
35 | tensor = tensor.contiguous(memory_format=memory_format)
36 | _constant_cache[key] = tensor
37 | return tensor
38 |
--------------------------------------------------------------------------------
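A tiny sketch of the `constant()` cache above: calling it twice with identical arguments returns the same cached tensor object instead of rebuilding it.

```python
from datasets.misc import constant

a = constant(0.5, shape=(1, 3, 1, 1))
b = constant(0.5, shape=(1, 3, 1, 1))
print(a.shape, a is b)  # torch.Size([1, 3, 1, 1]) True -- second call hits the cache
```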
/ddbm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/DiffusionBridge/92522733cc602686df77f07a1824bb89f89cda1a/ddbm/__init__.py
--------------------------------------------------------------------------------
/ddbm/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import os
6 |
7 | import torch
8 | import torch.distributed as dist
9 |
10 | LOCAL_RANK = int(os.environ["LOCAL_RANK"])
11 | WORLD_SIZE = int(os.environ["WORLD_SIZE"])
12 | WORLD_RANK = int(os.environ["RANK"])
13 |
14 |
15 | def setup_dist():
16 | """
17 | Setup a distributed process group.
18 | """
19 | if dist.is_initialized():
20 | return
21 |
22 | torch.cuda.set_device(LOCAL_RANK)
23 | backend = "gloo" if not torch.cuda.is_available() else "nccl"
24 | dist.init_process_group(backend)
25 |
26 |
27 | def dev():
28 | """
29 | Get the device to use for torch.distributed.
30 | """
31 | if torch.cuda.is_available():
32 | return torch.device("cuda")
33 | return torch.device("cpu")
34 |
--------------------------------------------------------------------------------
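`dist_util` reads `LOCAL_RANK`/`RANK`/`WORLD_SIZE` from the environment at import time, so it only works under `torchrun`. A minimal sketch, assuming CUDA GPUs are available:

```python
# demo_dist.py -- launch with: torchrun --nproc_per_node=2 demo_dist.py
import torch.distributed as dist
from ddbm import dist_util

dist_util.setup_dist()  # NCCL when CUDA is available, Gloo otherwise
print(f"rank {dist.get_rank()} / {dist.get_world_size()} on {dist_util.dev()}")
```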
/ddbm/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import os.path as osp
9 | import json
10 | import time
11 | import datetime
12 | import tempfile
13 | import warnings
14 | from collections import defaultdict
15 | from contextlib import contextmanager
16 | import torch.distributed as dist
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "read"), "expected file or str, got %s" % filename_or_file
43 | self.file = filename_or_file
44 | self.own_file = False
45 |
46 | def writekvs(self, kvs):
47 | if dist.get_rank() != 0 and self.file == sys.stdout:
48 | return
49 | # Create strings for printing
50 | key2str = {}
51 | for key, val in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append("| %s%s | %s%s |" % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))))
71 | lines.append(dashes)
72 | self.file.write("\n".join(lines) + "\n")
73 |
74 | # Flush the output to the file
75 | self.file.flush()
76 |
77 | def _truncate(self, s):
78 | maxlen = 30
79 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
80 |
81 | def writeseq(self, seq):
82 | seq = list(seq)
83 | for i, elem in enumerate(seq):
84 | self.file.write(elem)
85 | if i < len(seq) - 1: # add space unless this is the last one
86 | self.file.write(" ")
87 | self.file.write("\n")
88 | self.file.flush()
89 |
90 | def close(self):
91 | if self.own_file:
92 | self.file.close()
93 |
94 |
95 | class JSONOutputFormat(KVWriter):
96 | def __init__(self, filename):
97 | self.file = open(filename, "wt")
98 |
99 | def writekvs(self, kvs):
100 | for k, v in sorted(kvs.items()):
101 | if hasattr(v, "dtype"):
102 | kvs[k] = float(v)
103 | self.file.write(json.dumps(kvs) + "\n")
104 | self.file.flush()
105 |
106 | def close(self):
107 | self.file.close()
108 |
109 |
110 | class CSVOutputFormat(KVWriter):
111 | def __init__(self, filename):
112 | self.file = open(filename, "w+t")
113 | self.keys = []
114 | self.sep = ","
115 |
116 | def writekvs(self, kvs):
117 | # Add our current row to the history
118 | extra_keys = list(kvs.keys() - self.keys)
119 | extra_keys.sort()
120 | if extra_keys:
121 | self.keys.extend(extra_keys)
122 | self.file.seek(0)
123 | lines = self.file.readlines()
124 | self.file.seek(0)
125 | for i, k in enumerate(self.keys):
126 | if i > 0:
127 | self.file.write(",")
128 | self.file.write(k)
129 | self.file.write("\n")
130 | for line in lines[1:]:
131 | self.file.write(line[:-1])
132 | self.file.write(self.sep * len(extra_keys))
133 | self.file.write("\n")
134 | for i, k in enumerate(self.keys):
135 | if i > 0:
136 | self.file.write(",")
137 | v = kvs.get(k)
138 | if v is not None:
139 | self.file.write(str(v))
140 | self.file.write("\n")
141 | self.file.flush()
142 |
143 | def close(self):
144 | self.file.close()
145 |
146 |
147 | class TensorBoardOutputFormat(KVWriter):
148 | """
149 | Dumps key/value pairs into TensorBoard's numeric format.
150 | """
151 |
152 | def __init__(self, dir):
153 | os.makedirs(dir, exist_ok=True)
154 | self.dir = dir
155 | self.step = 1
156 | prefix = "events"
157 | path = osp.join(osp.abspath(dir), prefix)
158 | import tensorflow as tf
159 | from tensorflow.python import pywrap_tensorflow
160 | from tensorflow.core.util import event_pb2
161 | from tensorflow.python.util import compat
162 |
163 | self.tf = tf
164 | self.event_pb2 = event_pb2
165 | self.pywrap_tensorflow = pywrap_tensorflow
166 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
167 |
168 | def writekvs(self, kvs):
169 | def summary_val(k, v):
170 | kwargs = {"tag": k, "simple_value": float(v)}
171 | return self.tf.Summary.Value(**kwargs)
172 |
173 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
174 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
175 | event.step = self.step # is there any reason why you'd want to specify the step?
176 | self.writer.WriteEvent(event)
177 | self.writer.Flush()
178 | self.step += 1
179 |
180 | def close(self):
181 | if self.writer:
182 | self.writer.Close()
183 | self.writer = None
184 |
185 |
186 | def make_output_format(format, ev_dir, log_suffix=""):
187 | os.makedirs(ev_dir, exist_ok=True)
188 | if format == "stdout":
189 | return HumanOutputFormat(sys.stdout)
190 | elif format == "log":
191 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
192 | elif format == "json":
193 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
194 | elif format == "csv":
195 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
196 | elif format == "tensorboard":
197 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
198 | else:
199 | raise ValueError("Unknown format specified: %s" % (format,))
200 |
201 |
202 | # ================================================================
203 | # API
204 | # ================================================================
205 |
206 |
207 | def logkv(key, val):
208 | """
209 | Log a value of some diagnostic
210 | Call this once for each diagnostic quantity, each iteration
211 | If called many times, last value will be used.
212 | """
213 | get_current().logkv(key, val)
214 |
215 |
216 | def logkv_mean(key, val):
217 | """
218 | The same as logkv(), but if called many times, values averaged.
219 | """
220 | get_current().logkv_mean(key, val)
221 |
222 |
223 | def logkvs(d):
224 | """
225 | Log a dictionary of key-value pairs
226 | """
227 | for k, v in d.items():
228 | logkv(k, v)
229 |
230 |
231 | def dumpkvs():
232 | """
233 | Write all of the diagnostics from the current iteration
234 | """
235 | return get_current().dumpkvs()
236 |
237 |
238 | def getkvs():
239 | return get_current().name2val
240 |
241 |
242 | def log(*args, level=INFO):
243 | """
244 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
245 | """
246 | get_current().log(*args, level=level)
247 |
248 |
249 | def debug(*args):
250 | log(*args, level=DEBUG)
251 |
252 |
253 | def info(*args):
254 | log(*args, level=INFO)
255 |
256 |
257 | def warn(*args):
258 | log(*args, level=WARN)
259 |
260 |
261 | def error(*args):
262 | log(*args, level=ERROR)
263 |
264 |
265 | def set_level(level):
266 | """
267 | Set logging threshold on current logger.
268 | """
269 | get_current().set_level(level)
270 |
271 |
272 | def set_comm(comm):
273 | get_current().set_comm(comm)
274 |
275 |
276 | def get_dir():
277 | """
278 | Get directory that log files are being written to.
279 | will be None if there is no output directory (i.e., if you didn't call start)
280 | """
281 | return get_current().get_dir()
282 |
283 |
284 | record_tabular = logkv
285 | dump_tabular = dumpkvs
286 |
287 |
288 | @contextmanager
289 | def profile_kv(scopename):
290 | logkey = "wait_" + scopename
291 | tstart = time.time()
292 | try:
293 | yield
294 | finally:
295 | get_current().name2val[logkey] += time.time() - tstart
296 |
297 |
298 | def profile(n):
299 | """
300 | Usage:
301 | @profile("my_func")
302 | def my_func(): code
303 | """
304 |
305 | def decorator_with_name(func):
306 | def func_wrapper(*args, **kwargs):
307 | with profile_kv(n):
308 | return func(*args, **kwargs)
309 |
310 | return func_wrapper
311 |
312 | return decorator_with_name
313 |
314 |
315 | # ================================================================
316 | # Backend
317 | # ================================================================
318 |
319 |
320 | def get_current(dir=None):
321 | if Logger.CURRENT is None:
322 | _configure_default_logger(dir=dir)
323 |
324 | return Logger.CURRENT
325 |
326 |
327 | class Logger(object):
328 | DEFAULT = None # A logger with no output files. (See right below class definition)
329 | # So that you can still log to the terminal without setting up any output files
330 | CURRENT = None # Current logger being used by the free functions above
331 |
332 | def __init__(self, dir, output_formats, comm=None):
333 | self.name2val = defaultdict(float) # values this iteration
334 | self.name2cnt = defaultdict(int)
335 | self.level = INFO
336 | self.dir = dir
337 | self.output_formats = output_formats
338 | self.comm = comm
339 |
340 | # Logging API, forwarded
341 | # ----------------------------------------
342 | def logkv(self, key, val):
343 | self.name2val[key] = val
344 |
345 | def logkv_mean(self, key, val):
346 | oldval, cnt = self.name2val[key], self.name2cnt[key]
347 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
348 | self.name2cnt[key] = cnt + 1
349 |
350 | def dumpkvs(self):
351 | if self.comm is None:
352 | d = self.name2val
353 | else:
354 | d = mpi_weighted_mean(
355 | self.comm,
356 | {name: (val, self.name2cnt.get(name, 1)) for (name, val) in self.name2val.items()},
357 | )
358 | if self.comm.rank != 0:
359 | d["dummy"] = 1 # so we don't get a warning about empty dict
360 | out = d.copy() # Return the dict for unit testing purposes
361 | for fmt in self.output_formats:
362 | if isinstance(fmt, KVWriter):
363 | fmt.writekvs(d)
364 | self.name2val.clear()
365 | self.name2cnt.clear()
366 | return out
367 |
368 | def log(self, *args, level=INFO):
369 | if self.level <= level:
370 | self._do_log(args)
371 |
372 | # Configuration
373 | # ----------------------------------------
374 | def set_level(self, level):
375 | self.level = level
376 |
377 | def set_comm(self, comm):
378 | self.comm = comm
379 |
380 | def get_dir(self):
381 | return self.dir
382 |
383 | def close(self):
384 | for fmt in self.output_formats:
385 | fmt.close()
386 |
387 | # Misc
388 | # ----------------------------------------
389 | def _do_log(self, args):
390 | for fmt in self.output_formats:
391 | if isinstance(fmt, SeqWriter):
392 | fmt.writeseq(map(str, args))
393 |
394 |
395 | def get_rank_without_mpi_import():
396 | # check environment variables here instead of importing mpi4py
397 | # to avoid calling MPI_Init() when this module is imported
398 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
399 | if varname in os.environ:
400 | return int(os.environ[varname])
401 | return 0
402 |
403 |
404 | def mpi_weighted_mean(comm, local_name2valcount):
405 | """
406 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
407 | Perform a weighted average over dicts that are each on a different node
408 | Input: local_name2valcount: dict mapping key -> (value, count)
409 | Returns: key -> mean
410 | """
411 | all_name2valcount = comm.gather(local_name2valcount)
412 | if comm.rank == 0:
413 | name2sum = defaultdict(float)
414 | name2count = defaultdict(float)
415 | for n2vc in all_name2valcount:
416 | for name, (val, count) in n2vc.items():
417 | try:
418 | val = float(val)
419 | except ValueError:
420 | if comm.rank == 0:
421 | warnings.warn("WARNING: tried to compute mean on non-float {}={}".format(name, val))
422 | else:
423 | name2sum[name] += val * count
424 | name2count[name] += count
425 | return {name: name2sum[name] / name2count[name] for name in name2sum}
426 | else:
427 | return {}
428 |
429 |
430 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
431 | """
432 | If comm is provided, average all numerical stats across that comm
433 | """
434 | if dir is None:
435 | dir = os.getenv("OPENAI_LOGDIR")
436 | if dir is None:
437 | dir = osp.join(
438 | tempfile.gettempdir(),
439 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
440 | )
441 | assert isinstance(dir, str)
442 | dir = os.path.expanduser(dir)
443 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
444 |
445 | rank = get_rank_without_mpi_import()
446 | if rank > 0:
447 | log_suffix = log_suffix + "-rank%03i" % rank
448 |
449 | if format_strs is None:
450 | if rank == 0:
451 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
452 | else:
453 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
454 | format_strs = filter(None, format_strs)
455 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
456 |
457 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
458 | if output_formats:
459 | log("Logging to %s" % dir)
460 |
461 |
462 | def _configure_default_logger(dir):
463 | configure(dir=dir)
464 | Logger.DEFAULT = Logger.CURRENT
465 |
466 |
467 | def reset():
468 | if Logger.CURRENT is not Logger.DEFAULT:
469 | Logger.CURRENT.close()
470 | Logger.CURRENT = Logger.DEFAULT
471 | log("Reset logger")
472 |
473 |
474 | @contextmanager
475 | def scoped_configure(dir=None, format_strs=None, comm=None):
476 | prevlogger = Logger.CURRENT
477 | configure(dir=dir, format_strs=format_strs, comm=comm)
478 | try:
479 | yield
480 | finally:
481 | Logger.CURRENT.close()
482 | Logger.CURRENT = prevlogger
483 |
--------------------------------------------------------------------------------
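A minimal sketch of the key/value logging API above. The stdout writer calls `dist.get_rank()`, so a process group (even a single-process one) must already exist; the directory and values are illustrative:

```python
import torch.distributed as dist
from ddbm import logger

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

logger.configure(dir="workdir/demo_logs", format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 1.0 / (step + 1))  # averaged if logged several times per iteration
    logger.dumpkvs()                             # flushes one row to stdout and progress.csv
logger.log("done")
```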
/ddbm/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * torch.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def append_dims(x, target_dims):
94 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
95 | dims_to_append = target_dims - x.ndim
96 | if dims_to_append < 0:
97 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
98 | return x[(...,) + (None,) * dims_to_append]
99 |
100 |
101 | def append_zero(x):
102 | return torch.cat([x, x.new_zeros([1])])
103 |
104 |
105 | def normalization(channels):
106 | """
107 | Make a standard normalization layer.
108 |
109 | :param channels: number of input channels.
110 | :return: an nn.Module for normalization.
111 | """
112 | return GroupNorm32(32, channels)
113 |
114 |
115 | def timestep_embedding(timesteps, dim, max_period=10000):
116 | """
117 | Create sinusoidal timestep embeddings.
118 |
119 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
120 | These may be fractional.
121 | :param dim: the dimension of the output.
122 | :param max_period: controls the minimum frequency of the embeddings.
123 | :return: an [N x dim] Tensor of positional embeddings.
124 | """
125 | half = dim // 2
126 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
127 | device=timesteps.device
128 | )
129 | args = timesteps[:, None].float() * freqs[None]
130 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
131 | if dim % 2:
132 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
133 | return embedding
134 |
135 |
136 | def checkpoint(func, inputs, params, flag):
137 | """
138 | Evaluate a function without caching intermediate activations, allowing for
139 | reduced memory at the expense of extra compute in the backward pass.
140 |
141 | :param func: the function to evaluate.
142 | :param inputs: the argument sequence to pass to `func`.
143 | :param params: a sequence of parameters `func` depends on but does not
144 | explicitly take as arguments.
145 | :param flag: if False, disable gradient checkpointing.
146 | """
147 | if flag:
148 | args = tuple(inputs) + tuple(params)
149 | return CheckpointFunction.apply(func, len(inputs), *args)
150 | else:
151 | return func(*inputs)
152 |
153 |
154 | class CheckpointFunction(torch.autograd.Function):
155 | @staticmethod
156 | def forward(ctx, run_function, length, *args):
157 | ctx.run_function = run_function
158 | ctx.input_tensors = list(args[:length])
159 | ctx.input_params = list(args[length:])
160 | with torch.no_grad():
161 | output_tensors = ctx.run_function(*ctx.input_tensors)
162 | return output_tensors
163 |
164 | @staticmethod
165 | def backward(ctx, *output_grads):
166 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
167 | with torch.enable_grad():
168 | # Fixes a bug where the first op in run_function modifies the
169 | # Tensor storage in place, which is not allowed for detach()'d
170 | # Tensors.
171 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
172 | output_tensors = ctx.run_function(*shallow_copies)
173 | input_grads = torch.autograd.grad(
174 | output_tensors,
175 | ctx.input_tensors + ctx.input_params,
176 | output_grads,
177 | allow_unused=True,
178 | )
179 | del ctx.input_tensors
180 | del ctx.input_params
181 | del output_tensors
182 | return (None, None) + input_grads
183 |
--------------------------------------------------------------------------------
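A quick sketch exercising a few of the helpers above (shapes are illustrative):

```python
import torch
from ddbm.nn import append_dims, mean_flat, timestep_embedding

t = torch.tensor([0.0, 0.5, 1.0])
emb = timestep_embedding(t, dim=128)   # [3, 128] sinusoidal embeddings
x = torch.randn(3, 3, 32, 32)
sigma = append_dims(t, x.ndim)         # [3, 1, 1, 1], broadcastable against x
loss = mean_flat((sigma * x) ** 2)     # one scalar per sample, shape [3]
print(emb.shape, sigma.shape, loss.shape)
```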
/ddbm/random_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from . import dist_util
4 |
5 |
6 | def get_generator(generator, num_samples=0, seed=0):
7 | if generator == "dummy":
8 | return DummyGenerator()
9 | elif generator == "determ":
10 | return DeterministicGenerator(num_samples, seed)
11 | elif generator == "determ-indiv":
12 | return DeterministicIndividualGenerator(num_samples, seed)
13 | else:
14 | raise NotImplementedError
15 |
16 |
17 | class DummyGenerator:
18 | def randn(self, *args, **kwargs):
19 | return torch.randn(*args, **kwargs)
20 |
21 | def randint(self, *args, **kwargs):
22 | return torch.randint(*args, **kwargs)
23 |
24 | def randn_like(self, *args, **kwargs):
25 | return torch.randn_like(*args, **kwargs)
26 |
27 |
28 | class DeterministicGenerator:
29 | """
30 | RNG that deterministically produces num_samples samples, independent of batch_size or the number of machines
31 | Uses a single rng, draws num_samples worth of randomness at once, and subsamples the current indices
32 | """
33 |
34 | def __init__(self, num_samples, seed=0):
35 | if dist.is_initialized():
36 | self.rank = dist.get_rank()
37 | self.world_size = dist.get_world_size()
38 | else:
39 | print("Warning: Distributed not initialised, using single rank")
40 | self.rank = 0
41 | self.world_size = 1
42 | self.num_samples = num_samples
43 | self.done_samples = 0
44 | self.seed = seed
45 | self.rng_cpu = torch.Generator()
46 | if torch.cuda.is_available():
47 | self.rng_cuda = torch.Generator(dist_util.dev())
48 | self.set_seed(seed)
49 |
50 | def get_global_size_and_indices(self, size):
51 | global_size = (self.num_samples, *size[1:])
52 | indices = torch.arange(
53 | self.done_samples + self.rank,
54 | self.done_samples + self.world_size * int(size[0]),
55 | self.world_size,
56 | )
57 | indices = torch.clamp(indices, 0, self.num_samples - 1)
58 | assert len(indices) == size[0], f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
59 | return global_size, indices
60 |
61 | def get_generator(self, device):
62 | return self.rng_cpu if torch.device(device).type == "cpu" else self.rng_cuda
63 |
64 | def randn(self, *size, dtype=torch.float, device="cpu"):
65 | global_size, indices = self.get_global_size_and_indices(size)
66 | generator = self.get_generator(device)
67 | return torch.randn(*global_size, generator=generator, dtype=dtype, device=device)[indices]
68 |
69 | def randint(self, low, high, size, dtype=torch.long, device="cpu"):
70 | global_size, indices = self.get_global_size_and_indices(size)
71 | generator = self.get_generator(device)
72 | return torch.randint(low, high, generator=generator, size=global_size, dtype=dtype, device=device)[indices]
73 |
74 | def randn_like(self, tensor):
75 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device
76 | return self.randn(*size, dtype=dtype, device=device)
77 |
78 | def set_done_samples(self, done_samples):
79 | self.done_samples = done_samples
80 | self.set_seed(self.seed)
81 |
82 | def get_seed(self):
83 | return self.seed
84 |
85 | def set_seed(self, seed):
86 | self.rng_cpu.manual_seed(seed)
87 | if torch.cuda.is_available():
88 | self.rng_cuda.manual_seed(seed)
89 |
90 |
91 | class DeterministicIndividualGenerator:
92 | """
93 | RNG that deterministically produces num_samples samples, independent of batch_size or the number of machines
94 | Uses a separate rng for each sample to reduce memory usage
95 | """
96 |
97 | def __init__(self, num_samples, seed=0):
98 | if dist.is_initialized():
99 | self.rank = dist.get_rank()
100 | self.world_size = dist.get_world_size()
101 | else:
102 | print("Warning: Distributed not initialised, using single rank")
103 | self.rank = 0
104 | self.world_size = 1
105 | self.num_samples = num_samples
106 | self.done_samples = 0
107 | self.seed = seed
108 | self.rng_cpu = [torch.Generator() for _ in range(num_samples)]
109 | if torch.cuda.is_available():
110 | self.rng_cuda = [torch.Generator(dist_util.dev()) for _ in range(num_samples)]
111 | self.set_seed(seed)
112 |
113 | def get_size_and_indices(self, size):
114 | indices = torch.arange(
115 | self.done_samples + self.rank,
116 | self.done_samples + self.world_size * int(size[0]),
117 | self.world_size,
118 | )
119 | indices = torch.clamp(indices, 0, self.num_samples - 1)
120 | assert len(indices) == size[0], f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
121 | return (1, *size[1:]), indices
122 |
123 | def get_generator(self, device):
124 | return self.rng_cpu if torch.device(device).type == "cpu" else self.rng_cuda
125 |
126 | def randn(self, *size, dtype=torch.float, device="cpu"):
127 | size, indices = self.get_size_and_indices(size)
128 | generator = self.get_generator(device)
129 | return torch.cat(
130 | [torch.randn(*size, generator=generator[i], dtype=dtype, device=device) for i in indices],
131 | dim=0,
132 | )
133 |
134 | def randint(self, low, high, size, dtype=torch.long, device="cpu"):
135 | size, indices = self.get_size_and_indices(size)
136 | generator = self.get_generator(device)
137 | return torch.cat(
138 | [
139 | torch.randint(
140 | low,
141 | high,
142 | generator=generator[i],
143 | size=size,
144 | dtype=dtype,
145 | device=device,
146 | )
147 | for i in indices
148 | ],
149 | dim=0,
150 | )
151 |
152 | def randn_like(self, tensor):
153 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device
154 | return self.randn(*size, dtype=dtype, device=device)
155 |
156 | def set_done_samples(self, done_samples):
157 | self.done_samples = done_samples
158 |
159 | def get_seed(self):
160 | return self.seed
161 |
162 | def set_seed(self, seed):
163 | [rng_cpu.manual_seed(i + self.num_samples * seed) for i, rng_cpu in enumerate(self.rng_cpu)]
164 | if torch.cuda.is_available():
165 | [rng_cuda.manual_seed(i + self.num_samples * seed) for i, rng_cuda in enumerate(self.rng_cuda)]
166 |
167 |
168 | class BatchedSeedGenerator:
169 |
170 | def __init__(self, seeds=None):
171 | self.num_samples = len(seeds)
172 | if torch.cuda.is_available():
173 | self.rng = [torch.Generator(dist_util.dev()) for _ in range(self.num_samples)]
174 | else:
175 | self.rng = [torch.Generator() for _ in range(self.num_samples)]
176 | [rng.manual_seed(int(seeds[i])) for i, rng in enumerate(self.rng)]
177 |
178 | def randn(self, size, dtype=torch.float, device="cpu"):
179 | assert size[0] == self.num_samples
180 | return torch.cat(
181 | [
182 | torch.randn(1, *size[1:], generator=self.rng[i], dtype=dtype, device=device)
183 | for i in range(self.num_samples)
184 | ],
185 | dim=0,
186 | )
187 |
188 | def randint(self, low, high, size, dtype=torch.long, device="cpu"):
189 | assert size[0] == self.num_samples
190 | return torch.cat(
191 | [
192 | torch.randint(
193 | low,
194 | high,
195 | generator=self.rng[i],
196 | size=(1, *size[1:]),
197 | dtype=dtype,
198 | device=device,
199 | )
200 | for i in range(self.num_samples)
201 | ],
202 | dim=0,
203 | )
204 |
205 | def randn_like(self, tensor):
206 | size, dtype, device = tensor.size(), tensor.dtype, tensor.device
207 | return self.randn(size, dtype=dtype, device=device)
208 |
--------------------------------------------------------------------------------
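
A minimal usage sketch for the deterministic generators above (not part of the repository). It assumes a single-process, CPU-only run with `torch.distributed` not initialised, and illustrates that the noise assigned to each global sample index does not depend on how the batch is split.

```python
# Sketch: DeterministicGenerator assigns noise per global sample index, so drawing
# 8 samples at once or in two batches of 4 produces identical noise.
import torch
from ddbm.random_util import DeterministicGenerator

gen_full = DeterministicGenerator(num_samples=8, seed=0)
gen_split = DeterministicGenerator(num_samples=8, seed=0)

x_full = gen_full.randn(8, 3, device="cpu")    # all 8 samples in one call
x_first = gen_split.randn(4, 3, device="cpu")  # samples 0..3
gen_split.set_done_samples(4)                  # advance to samples 4..7
x_second = gen_split.randn(4, 3, device="cpu")

assert torch.allclose(x_full, torch.cat([x_first, x_second], dim=0))
```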
/ddbm/resample.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def create_named_schedule_sampler(name, diffusion):
5 | """
6 | Create a ScheduleSampler from a library of pre-defined samplers.
7 |
8 | :param name: the name of the sampler.
9 | :param diffusion: the diffusion object to sample for.
10 | """
11 | if name == "real-uniform":
12 | return RealUniformSampler(diffusion)
13 | else:
14 | raise NotImplementedError(f"unknown schedule sampler: {name}")
15 |
16 |
17 | class RealUniformSampler:
18 | def __init__(self, diffusion):
19 | self.t_max = diffusion.t_max
20 | self.t_min = diffusion.t_min
21 |
22 | def sample(self, batch_size, device):
23 | ts = torch.rand(batch_size).to(device) * (self.t_max - self.t_min) + self.t_min
24 | return ts, torch.ones_like(ts)
25 |
--------------------------------------------------------------------------------
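
A small illustrative sketch of how the sampler above is used (not from the repository's scripts): the `diffusion` argument only needs `t_min`/`t_max` attributes, so a bare namespace stands in for the real denoiser here.

```python
# Sketch: RealUniformSampler draws continuous times uniformly in [t_min, t_max)
# and returns unit importance weights.
import torch
from types import SimpleNamespace
from ddbm.resample import create_named_schedule_sampler

diffusion = SimpleNamespace(t_min=0.002, t_max=80.0)
sampler = create_named_schedule_sampler("real-uniform", diffusion)
ts, weights = sampler.sample(batch_size=16, device="cpu")
assert ts.min() >= 0.002 and ts.max() < 80.0 and torch.all(weights == 1)
```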
/ddbm/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from .karras_diffusion import (
4 | KarrasDenoiser,
5 | VPNoiseSchedule,
6 | VENoiseSchedule,
7 | I2SBNoiseSchedule,
8 | DDBMPreCond,
9 | I2SBPreCond,
10 | )
11 | from .unet import UNetModel
12 |
13 | NUM_CLASSES = 1000
14 |
15 |
16 | def get_workdir(exp):
17 | workdir = f"./workdir/{exp}"
18 | return workdir
19 |
20 |
21 | def sample_defaults():
22 | return dict(
23 | generator="determ",
24 | clip_denoised=True,
25 | sampler="euler",
26 | s_churn=0.0,
27 | s_tmin=0.002,
28 | s_tmax=80,
29 | s_noise=1.0,
30 | steps=40,
31 | model_path="",
32 | seed=42,
33 | ts="",
34 | )
35 |
36 |
37 | def model_and_diffusion_defaults():
38 | """
39 | Defaults for image training.
40 | """
41 | res = dict(
42 | sigma_data=0.5,
43 | sigma_min=0.002,
44 | sigma_max=80.0,
45 | beta_d=2,
46 | beta_min=0.1,
47 | beta_max=1.0,
48 | cov_xy=0.0,
49 | image_size=64,
50 | in_channels=3,
51 | num_channels=128,
52 | num_res_blocks=2,
53 | num_heads=4,
54 | num_heads_upsample=-1,
55 | num_head_channels=64,
56 | unet_type="adm",
57 | attention_resolutions="32,16,8",
58 | channel_mult="",
59 | dropout=0.0,
60 | class_cond=False,
61 | use_checkpoint=False,
62 | use_scale_shift_norm=True,
63 | resblock_updown=True,
64 | use_fp16=True,
65 | use_new_attention_order=False,
66 | condition_mode=None,
67 | noise_schedule="ve",
68 | )
69 | return res
70 |
71 |
72 | def create_model_and_diffusion(
73 | image_size,
74 | in_channels,
75 | class_cond,
76 | num_channels,
77 | num_res_blocks,
78 | channel_mult,
79 | num_heads,
80 | num_head_channels,
81 | num_heads_upsample,
82 | attention_resolutions,
83 | dropout,
84 | use_checkpoint,
85 | use_scale_shift_norm,
86 | resblock_updown,
87 | use_fp16,
88 | use_new_attention_order,
89 | condition_mode,
90 | noise_schedule,
91 | sigma_data=0.5,
92 | sigma_min=0.002,
93 | sigma_max=80.0,
94 | beta_d=2,
95 | beta_min=0.1,
96 | beta_max=1.0,
97 | cov_xy=0.0,
98 | unet_type="adm",
99 | ):
100 | model = create_model(
101 | image_size,
102 | in_channels,
103 | num_channels,
104 | num_res_blocks,
105 | unet_type=unet_type,
106 | channel_mult=channel_mult,
107 | class_cond=class_cond,
108 | use_checkpoint=use_checkpoint,
109 | attention_resolutions=attention_resolutions,
110 | num_heads=num_heads,
111 | num_head_channels=num_head_channels,
112 | num_heads_upsample=num_heads_upsample,
113 | use_scale_shift_norm=use_scale_shift_norm,
114 | dropout=dropout,
115 | resblock_updown=resblock_updown,
116 | use_fp16=use_fp16,
117 | use_new_attention_order=use_new_attention_order,
118 | condition_mode=condition_mode,
119 | )
120 | if noise_schedule.startswith("vp"):
121 | ns = VPNoiseSchedule(beta_d=beta_d, beta_min=beta_min)
122 | precond = DDBMPreCond(ns, sigma_data=sigma_data, cov_xy=cov_xy)
123 | elif noise_schedule == "ve":
124 | ns = VENoiseSchedule(sigma_max=sigma_max)
125 | precond = DDBMPreCond(ns, sigma_data=sigma_data, cov_xy=cov_xy)
126 | elif noise_schedule.startswith("i2sb"):
127 | ns = I2SBNoiseSchedule(beta_max=beta_max, beta_min=beta_min)
128 | precond = I2SBPreCond(ns)
129 |
130 | diffusion = KarrasDenoiser(
131 | noise_schedule=ns,
132 | precond=precond,
133 | t_max=sigma_max,
134 | t_min=sigma_min,
135 | )
136 | return model, diffusion
137 |
138 |
139 | def create_model(
140 | image_size,
141 | in_channels,
142 | num_channels,
143 | num_res_blocks,
144 | unet_type="adm",
145 | channel_mult="",
146 | class_cond=False,
147 | use_checkpoint=False,
148 | attention_resolutions="16",
149 | num_heads=1,
150 | num_head_channels=-1,
151 | num_heads_upsample=-1,
152 | use_scale_shift_norm=False,
153 | dropout=0,
154 | resblock_updown=False,
155 | use_fp16=False,
156 | use_new_attention_order=False,
157 | condition_mode=None,
158 | ):
159 | if channel_mult == "":
160 | if image_size == 512:
161 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
162 | elif image_size == 256:
163 | channel_mult = (1, 1, 2, 2, 4, 4)
164 | elif image_size == 128:
165 | channel_mult = (1, 1, 2, 3, 4)
166 | elif image_size == 64:
167 | channel_mult = (1, 2, 3, 4)
168 | elif image_size == 32:
169 | channel_mult = (1, 2, 3, 4)
170 | else:
171 | raise ValueError(f"unsupported image size: {image_size}")
172 | else:
173 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
174 |
175 | attention_ds = []
176 | for res in attention_resolutions.split(","):
177 | attention_ds.append(image_size // int(res))
178 |
179 | if unet_type == "adm":
180 | return UNetModel(
181 | image_size=image_size,
182 | in_channels=in_channels,
183 | model_channels=num_channels,
184 | out_channels=in_channels,
185 | num_res_blocks=num_res_blocks,
186 | attention_resolutions=tuple(attention_ds),
187 | dropout=dropout,
188 | channel_mult=channel_mult,
189 | num_classes=(NUM_CLASSES if class_cond else None),
190 | use_checkpoint=use_checkpoint,
191 | use_fp16=use_fp16,
192 | num_heads=num_heads,
193 | num_head_channels=num_head_channels,
194 | num_heads_upsample=num_heads_upsample,
195 | use_scale_shift_norm=use_scale_shift_norm,
196 | resblock_updown=resblock_updown,
197 | use_new_attention_order=use_new_attention_order,
198 | condition_mode=condition_mode,
199 | )
200 | else:
201 | raise ValueError(f"Unsupported unet type: {unet_type}")
202 |
203 |
204 | def add_dict_to_argparser(parser, default_dict):
205 | for k, v in default_dict.items():
206 | v_type = type(v)
207 | if v is None:
208 | v_type = str
209 | elif isinstance(v, bool):
210 | v_type = str2bool
211 | parser.add_argument(f"--{k}", default=v, type=v_type)
212 |
213 |
214 | def args_to_dict(args, keys):
215 | return {k: getattr(args, k) for k in keys}
216 |
217 |
218 | def str2bool(v):
219 | """
220 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
221 | """
222 | if isinstance(v, bool):
223 | return v
224 | if v.lower() in ("yes", "true", "t", "y", "1"):
225 | return True
226 | elif v.lower() in ("no", "false", "f", "n", "0"):
227 | return False
228 | else:
229 | raise argparse.ArgumentTypeError("boolean value expected")
230 |
--------------------------------------------------------------------------------
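
A hedged end-to-end sketch of how these helpers fit together (illustrative only; the repository's own scripts add more options): the defaults dict feeds an argument parser via `add_dict_to_argparser`, and `args_to_dict` recovers exactly the keyword arguments of `create_model_and_diffusion`.

```python
# Sketch: build a model/diffusion pair from the defaults above, overriding the
# noise schedule and image size on the command line.
import argparse
from ddbm.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)

parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, model_and_diffusion_defaults())
args = parser.parse_args(["--noise_schedule", "vp", "--image_size", "64"])

model, diffusion = create_model_and_diffusion(
    **args_to_dict(args, model_and_diffusion_defaults().keys())
)
print(type(model).__name__, type(diffusion).__name__)
```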
/ddbm/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import numpy as np
6 |
7 | import blobfile as bf
8 | import torch
9 | import torch.distributed as dist
10 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
11 | from torch.optim import RAdam
12 |
13 | from . import dist_util, logger
14 | from .nn import update_ema
15 |
16 | from ddbm.random_util import get_generator
17 |
18 | import glob
19 |
20 | import wandb
21 |
22 |
23 | class TrainLoop:
24 | def __init__(
25 | self,
26 | *,
27 | model,
28 | diffusion,
29 | train_data,
30 | test_data,
31 | batch_size,
32 | microbatch,
33 | lr,
34 | ema_rate,
35 | log_interval,
36 | test_interval,
37 | save_interval,
38 | save_interval_for_preemption,
39 | resume_checkpoint,
40 | workdir,
41 | use_fp16=False,
42 | fp16_scale_growth=1e-3,
43 | schedule_sampler=None,
44 | weight_decay=0.0,
45 | lr_anneal_steps=0,
46 | total_training_steps=10000000,
47 | augment_pipe=None,
48 | train_mode="ddbm",
49 | resume_train_flag=False,
50 | **sample_kwargs,
51 | ):
52 | self.model = model
53 | self.diffusion = diffusion
54 | self.data = train_data
55 | self.test_data = test_data
56 | self.image_size = model.image_size
57 | self.batch_size = batch_size
58 | self.microbatch = microbatch if microbatch > 0 else batch_size
59 | self.lr = lr
60 | self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")]
61 | self.log_interval = log_interval
62 | self.workdir = workdir
63 | self.test_interval = test_interval
64 | self.save_interval = save_interval
65 | self.save_interval_for_preemption = save_interval_for_preemption
66 | self.resume_checkpoint = resume_checkpoint
67 | self.use_fp16 = use_fp16
68 | self.fp16_scale_growth = fp16_scale_growth
69 | self.schedule_sampler = schedule_sampler
70 | self.weight_decay = weight_decay
71 | self.lr_anneal_steps = lr_anneal_steps
72 | self.total_training_steps = total_training_steps
73 |
74 | self.train_mode = train_mode
75 |
76 | self.step = 0
77 | self.resume_train_flag = resume_train_flag
78 | self.resume_step = 0
79 | self.global_batch = self.batch_size * dist.get_world_size()
80 |
81 | self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_fp16)
82 |
83 | self._load_and_sync_parameters()
84 | if not self.resume_train_flag:
85 | self.resume_step = 0
86 |
87 | self.opt = RAdam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
88 | if self.resume_step:
89 | self._load_optimizer_state()
90 | # Model was resumed, either due to a restart or a checkpoint
91 | # being specified at the command line.
92 | self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate]
93 | else:
94 | self.ema_params = [copy.deepcopy(list(self.model.parameters())) for _ in range(len(self.ema_rate))]
95 |
96 | if torch.cuda.is_available():
97 | self.use_ddp = True
98 | local_rank = int(os.environ["LOCAL_RANK"])
99 | self.ddp_model = DDP(
100 | self.model,
101 | device_ids=[local_rank],
102 | output_device=local_rank,
103 | )
104 | else:
105 | if dist.get_world_size() > 1:
106 | logger.warn("Distributed training requires CUDA. " "Gradients will not be synchronized properly!")
107 | self.use_ddp = False
108 | self.ddp_model = self.model
109 |
110 | self.step = self.resume_step
111 |
112 | self.generator = get_generator(sample_kwargs["generator"], self.batch_size, 42)
113 | self.sample_kwargs = sample_kwargs
114 |
115 | self.augment = augment_pipe
116 |
117 | def _load_and_sync_parameters(self):
118 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
119 |
120 | if resume_checkpoint:
121 | if self.resume_train_flag:
122 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
123 | if dist.get_rank() == 0:
124 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
125 | logger.log("Resume step: ", self.resume_step)
126 |
127 | self.model.load_state_dict(torch.load(resume_checkpoint, map_location="cpu"))
128 | self.model.to(dist_util.dev())
129 |
130 | dist.barrier()
131 |
132 | def _load_ema_parameters(self, rate):
133 | ema_params = copy.deepcopy(list(self.model.parameters()))
134 |
135 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
136 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
137 | if ema_checkpoint:
138 | if dist.get_rank() == 0:
139 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
140 | state_dict = torch.load(ema_checkpoint, map_location=dist_util.dev())
141 | ema_params = [state_dict[name] for name, _ in self.model.named_parameters()]
142 |
143 | dist.barrier()
144 | return ema_params
145 |
146 | def _load_optimizer_state(self):
147 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
148 | if main_checkpoint.split("/")[-1].startswith("freq"):
149 | prefix = "freq_"
150 | else:
151 | prefix = ""
152 | opt_checkpoint = bf.join(bf.dirname(main_checkpoint), f"{prefix}opt_{self.resume_step:06}.pt")
153 | if bf.exists(opt_checkpoint):
154 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
155 | state_dict = torch.load(opt_checkpoint, map_location=dist_util.dev())
156 | self.opt.load_state_dict(state_dict)
157 | dist.barrier()
158 |
159 | def run_loop(self):
160 | while True:
161 | for batch, cond, _ in self.data:
162 |
163 | if "inpaint" in self.workdir:
164 | _, mask, label = _
165 | else:
166 | mask = None
167 |
168 |                 if self.lr_anneal_steps and self.step >= self.total_training_steps:
169 | # Save the last checkpoint if it wasn't already saved.
170 | if (self.step - 1) % self.save_interval != 0:
171 | self.save()
172 | return
173 |
174 | if self.augment is not None:
175 | batch, _ = self.augment(batch)
176 | if isinstance(cond, torch.Tensor) and batch.ndim == cond.ndim:
177 | cond = {"xT": cond}
178 | else:
179 | cond["xT"] = cond["xT"]
180 | if mask is not None:
181 | # cond["mask"] = mask
182 | cond["y"] = label
183 |
184 | took_step = self.run_step(batch, cond)
185 | if took_step and self.step % self.log_interval == 0:
186 | logs = logger.dumpkvs()
187 |
188 | if dist.get_rank() == 0:
189 | wandb.log(logs, step=self.step)
190 |
191 | if took_step and self.step % self.save_interval == 0:
192 | self.save()
193 | # Run for a finite amount of time in integration tests.
194 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
195 | return
196 |
197 | test_batch, test_cond, _ = next(iter(self.test_data))
198 | if "inpaint" in self.workdir:
199 | _, mask, label = _
200 | else:
201 | mask = None
202 | if isinstance(test_cond, torch.Tensor) and test_batch.ndim == test_cond.ndim:
203 | test_cond = {"xT": test_cond}
204 | else:
205 | test_cond["xT"] = test_cond["xT"]
206 | if mask is not None:
207 | # test_cond["mask"] = mask
208 | test_cond["y"] = label
209 | self.run_test_step(test_batch, test_cond)
210 | logs = logger.dumpkvs()
211 |
212 | if dist.get_rank() == 0:
213 | wandb.log(logs, step=self.step)
214 |
215 | if took_step and self.step % self.save_interval_for_preemption == 0:
216 | self.save(for_preemption=True)
217 |
218 | def run_step(self, batch, cond):
219 | self.forward_backward(batch, cond)
220 | logger.logkv_mean("lg_loss_scale", np.log2(self.scaler.get_scale()))
221 | self.scaler.unscale_(self.opt)
222 |
223 | def _compute_norms():
224 | grad_norm = 0.0
225 | param_norm = 0.0
226 | for p in self.model.parameters():
227 | with torch.no_grad():
228 | param_norm += torch.norm(p, p=2, dtype=torch.float32).item() ** 2
229 | if p.grad is not None:
230 | grad_norm += torch.norm(p.grad, p=2, dtype=torch.float32).item() ** 2
231 | return np.sqrt(grad_norm), np.sqrt(param_norm)
232 |
233 | grad_norm, param_norm = _compute_norms()
234 |
235 | logger.logkv_mean("grad_norm", grad_norm)
236 | logger.logkv_mean("param_norm", param_norm)
237 |
238 | self.scaler.step(self.opt)
239 | self.scaler.update()
240 | self.step += 1
241 | self._update_ema()
242 |
243 | self._anneal_lr()
244 | self.log_step()
245 | return True
246 |
247 | def run_test_step(self, batch, cond):
248 | with torch.no_grad():
249 | self.forward_backward(batch, cond, train=False)
250 |
251 | def forward_backward(self, batch, cond, train=True):
252 | if train:
253 | self.opt.zero_grad()
254 | assert batch.shape[0] % self.microbatch == 0
255 | num_microbatches = batch.shape[0] // self.microbatch
256 | for i in range(0, batch.shape[0], self.microbatch):
257 | with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.use_fp16):
258 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
259 | micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()}
260 | last_batch = (i + self.microbatch) >= batch.shape[0]
261 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
262 |
263 | if self.train_mode == "ddbm":
264 | compute_losses = functools.partial(
265 | self.diffusion.training_bridge_losses,
266 | self.ddp_model,
267 | micro,
268 | t,
269 | model_kwargs=micro_cond,
270 | )
271 | else:
272 | raise NotImplementedError()
273 |
274 | if last_batch or not self.use_ddp:
275 | losses = compute_losses()
276 | else:
277 | with self.ddp_model.no_sync():
278 | losses = compute_losses()
279 |
280 | loss = (losses["loss"] * weights).mean() / num_microbatches
281 | log_loss_dict(self.diffusion, t, {k if train else "test_" + k: v * weights for k, v in losses.items()})
282 | if train:
283 | self.scaler.scale(loss).backward()
284 |
285 | def _update_ema(self):
286 | for rate, params in zip(self.ema_rate, self.ema_params):
287 | update_ema(params, self.model.parameters(), rate=rate)
288 |
289 | def _anneal_lr(self):
290 | if not self.lr_anneal_steps:
291 | return
292 | frac_done = (self.step) / self.lr_anneal_steps
293 | lr = self.lr * (1 - frac_done)
294 | for param_group in self.opt.param_groups:
295 | param_group["lr"] = lr
296 |
297 | def log_step(self):
298 | logger.logkv("step", self.step)
299 | logger.logkv("samples", (self.step + 1) * self.global_batch)
300 |
301 | def save(self, for_preemption=False):
302 | def maybe_delete_earliest(filename):
303 | wc = filename.split(f"{(self.step):06d}")[0] + "*"
304 | freq_states = list(glob.glob(os.path.join(get_blob_logdir(), wc)))
305 | if len(freq_states) > 3000:
306 | earliest = min(freq_states, key=lambda x: x.split("_")[-1].split(".")[0])
307 | os.remove(earliest)
308 |
309 | # if dist.get_rank() == 0 and for_preemption:
310 | # maybe_delete_earliest(get_blob_logdir())
311 | def save_checkpoint(rate, params):
312 | state_dict = self.model.state_dict()
313 | for i, (name, _) in enumerate(self.model.named_parameters()):
314 | assert name in state_dict
315 | state_dict[name] = params[i]
316 | if dist.get_rank() == 0:
317 | logger.log(f"saving model {rate}...")
318 | if not rate:
319 | filename = f"model_{(self.step):06d}.pt"
320 | else:
321 | filename = f"ema_{rate}_{(self.step):06d}.pt"
322 | if for_preemption:
323 | filename = f"freq_{filename}"
324 | maybe_delete_earliest(filename)
325 |
326 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
327 | torch.save(state_dict, f)
328 |
329 | for rate, params in zip(self.ema_rate, self.ema_params):
330 | save_checkpoint(rate, params)
331 |
332 | if dist.get_rank() == 0:
333 | filename = f"opt_{(self.step):06d}.pt"
334 | if for_preemption:
335 | filename = f"freq_{filename}"
336 | maybe_delete_earliest(filename)
337 |
338 | with bf.BlobFile(
339 | bf.join(get_blob_logdir(), filename),
340 | "wb",
341 | ) as f:
342 | torch.save(self.opt.state_dict(), f)
343 |
344 | # Save model parameters last to prevent race conditions where a restart
345 | # loads model at step N, but opt/ema state isn't saved for step N.
346 | save_checkpoint(0, list(self.model.parameters()))
347 | dist.barrier()
348 |
349 |
350 | def parse_resume_step_from_filename(filename):
351 | """
352 | Parse filenames of the form path/to/model_NNNNNN.pt, where NNNNNN is the
353 | checkpoint's number of steps.
354 | """
355 | split = filename.split("model_")
356 | if len(split) < 2:
357 | return 0
358 | split1 = split[-1].split(".")[0]
359 | try:
360 | return int(split1)
361 | except ValueError:
362 | return 0
363 |
364 |
365 | def get_blob_logdir():
366 | # You can change this to be a separate path to save checkpoints to
367 | # a blobstore or some external drive.
368 | return logger.get_dir()
369 |
370 |
371 | def find_resume_checkpoint():
372 | # On your infrastructure, you may want to override this to automatically
373 | # discover the latest checkpoint on your blob storage, etc.
374 | return None
375 |
376 |
377 | def find_ema_checkpoint(main_checkpoint, step, rate):
378 | if main_checkpoint is None:
379 | return None
380 | if main_checkpoint.split("/")[-1].startswith("freq"):
381 | prefix = "freq_"
382 | else:
383 | prefix = ""
384 | filename = f"{prefix}ema_{rate}_{(step):06d}.pt"
385 | path = bf.join(bf.dirname(main_checkpoint), filename)
386 | if bf.exists(path):
387 | return path
388 | return None
389 |
390 |
391 | def log_loss_dict(diffusion, ts, losses):
392 | for key, values in losses.items():
393 | logger.logkv_mean(key, values.mean().item())
394 | # Log the quantiles (four quartiles, in particular).
395 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
396 | quartile = int(4 * sub_t)
397 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
398 |
--------------------------------------------------------------------------------
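
A small illustrative check of the checkpoint-name convention used above (not part of the repository's tests; the paths are hypothetical): model files encode the training step as a zero-padded suffix, while EMA and optimizer files do not parse to a step.

```python
# Sketch: parse_resume_step_from_filename recovers the step from model_*.pt names
# and falls back to 0 for anything else.
from ddbm.train_util import parse_resume_step_from_filename

assert parse_resume_step_from_filename("workdir/my_exp/model_012000.pt") == 12000
assert parse_resume_step_from_filename("workdir/my_exp/ema_0.9999_012000.pt") == 0
```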
/download_diffusion.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 | import requests
10 | from tqdm import tqdm
11 |
12 | import torch
13 |
14 | ADM_IMG256_COND_CKPT = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion.pt"
15 | I2SB_IMG256_COND_CKPT = "256x256_diffusion_fixedsigma.pt"
16 |
17 |
18 | def download(url, local_path, chunk_size=1024):
19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20 | with requests.get(url, stream=True) as r:
21 | total_size = int(r.headers.get("content-length", 0))
22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23 | with open(local_path, "wb") as f:
24 | for data in r.iter_content(chunk_size=chunk_size):
25 | if data:
26 | f.write(data)
27 | pbar.update(chunk_size)
28 |
29 |
30 | def download_adm_image256_cond_ckpt(ckpt_dir):
31 | ckpt_pt = os.path.join(ckpt_dir, I2SB_IMG256_COND_CKPT)
32 | if os.path.exists(ckpt_pt):
33 | return
34 |
35 | adm_ckpt = os.path.join(ckpt_dir, os.path.basename(ADM_IMG256_COND_CKPT))
36 |
37 | print("Downloading ADM checkpoint to {} ...".format(adm_ckpt))
38 | download(ADM_IMG256_COND_CKPT, adm_ckpt)
39 | ckpt_state_dict = torch.load(adm_ckpt, map_location="cpu")
40 |
41 |     # patch the checkpoint: remove the sigma-prediction head and duplicate the input conv weight to accept the concatenated condition
42 | ckpt_state_dict["out.2.weight"] = ckpt_state_dict["out.2.weight"][:3]
43 | ckpt_state_dict["out.2.bias"] = ckpt_state_dict["out.2.bias"][:3]
44 | ckpt_state_dict["input_blocks.0.0.weight"] = torch.cat(
45 | [ckpt_state_dict["input_blocks.0.0.weight"], ckpt_state_dict["input_blocks.0.0.weight"]], dim=1
46 | )
47 | torch.save(ckpt_state_dict, ckpt_pt)
48 |
49 |     print(f"Saved ADM conditional pretrained checkpoint at {ckpt_pt}!")
50 |
51 |
52 | def download_ckpt(ckpt_dir="assets/ckpts"):
53 | os.makedirs(ckpt_dir, exist_ok=True)
54 | download_adm_image256_cond_ckpt(ckpt_dir=ckpt_dir)
55 |
56 |
57 | download_ckpt()
58 |
--------------------------------------------------------------------------------
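
A shape-level sketch of the checkpoint surgery performed above, using made-up tensor shapes in place of the real ADM checkpoint: truncating `out.2` drops the learned-sigma channels, and concatenating the first-conv weight along its input-channel axis lets the network take the image and the condition stacked channel-wise.

```python
# Sketch with hypothetical shapes: 6 output channels (mean + sigma) -> 3, and a
# 3-channel input conv duplicated to accept a 6-channel (image, condition) input.
import torch

out_w = torch.randn(6, 256, 3, 3)      # stand-in for out.2.weight
in_w = torch.randn(256, 3, 3, 3)       # stand-in for input_blocks.0.0.weight

out_w = out_w[:3]                       # keep only the mean-prediction channels
in_w = torch.cat([in_w, in_w], dim=1)   # accept [image, condition] concatenation

assert out_w.shape == (3, 256, 3, 3) and in_w.shape == (256, 6, 3, 3)
```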
/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | from .resnet import build_resnet50
--------------------------------------------------------------------------------
/evaluation/compute_metrices_imagenet.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 | import argparse
10 | import random
11 | from pathlib import Path
12 | import json
13 |
14 | import numpy as np
15 | import torch
16 | import torchvision.transforms as transforms
17 | from torch.utils.data import DataLoader
18 | from torch.utils.data import Dataset
19 | from easydict import EasyDict as edict
20 | from logger import Logger
21 |
22 | from evaluation.resnet import build_resnet50
23 | from evaluation import fid_util
24 |
25 | RESULT_DIR = Path("results")
26 | ADM_IMG256_FID_TRAIN_REF_CKPT = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz"
27 |
28 | def set_seed(seed):
29 | # https://github.com/pytorch/pytorch/issues/7068
30 | random.seed(seed)
31 | os.environ['PYTHONHASHSEED'] = str(seed)
32 | np.random.seed(seed)
33 | torch.manual_seed(seed)
34 | torch.cuda.manual_seed(seed)
35 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
36 |
37 | class NumpyDataset(Dataset):
38 | def __init__(self, data, targets):
39 | self.data = data
40 | self.targets = torch.LongTensor(targets)
41 | self.transform = transforms.ToTensor()
42 |
43 | def __getitem__(self, index):
44 | img_np = self.data[index]
45 | y = self.targets[index]
46 |
47 | if img_np.dtype == "uint8":
48 |             # ToTensor scales uint8 inputs to [0,1]
49 | img_t = self.transform(img_np) * 2 - 1
50 | elif img_np.dtype == "float32":
51 |             # ToTensor keeps float32 values, assumed to already be in [0,255]
52 | img_t = self.transform(img_np) / 127.5 - 1
53 |
54 | # img_t: [-1,1]
55 | return img_t, y
56 |
57 | def __len__(self):
58 | return len(self.data)
59 |
60 | @torch.no_grad()
61 | def compute_accu(opt, numpy_arr, numpy_label_arr, batch_size=256):
62 | dataset = NumpyDataset(numpy_arr, numpy_label_arr)
63 | loader = DataLoader(dataset,
64 | batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=1, drop_last=False,
65 | )
66 |
67 | resnet = build_resnet50().to(opt.device)
68 | correct = total = 0
69 | for (x,y) in loader:
70 | pred_y = resnet(x.to(opt.device))
71 |
72 | _, predicted = torch.max(pred_y.cpu(), 1)
73 | correct += (predicted==y).sum().item()
74 | total += y.size(0)
75 |
76 | accu = correct / total
77 | return accu
78 |
79 | def convert_to_numpy(t):
80 | # t: [-1,1]
81 | out = (t + 1) * 127.5
82 | out = out.clamp(0, 255)
83 | out = out.to(torch.uint8)
84 | out = out.permute(0, 2, 3, 1) # batch, 256, 256, 3
85 | out = out.contiguous()
86 | return out.cpu().numpy() # [0, 255]
87 |
88 |
89 | def build_numpy_data(opt, log):
90 | arr = []
91 | ckpt_path = opt.ckpt
92 | numpy_arr = np.load(ckpt_path)['arr_0']
93 | label_path = opt.label
94 | label_arr = np.load(label_path)['arr_0']
95 |
96 |     # arrays loaded from .npz are already numpy
97 | return numpy_arr, label_arr
98 |
99 | def build_ref_opt(opt, ref_fid_fn):
100 | split = ref_fid_fn.name[:-4].split("_")[-1]
101 | image_size = int(ref_fid_fn.name[:-4].split("_")[-2])
102 | assert opt.image_size == image_size
103 | return edict(
104 | mode=opt.mode,
105 | split=split,
106 | image_size=image_size,
107 | dataset_dir=opt.dataset_dir,
108 | )
109 |
110 | def get_ref_fid(opt, log):
111 | # get ref fid npz file
112 |
113 | ref_fid_fn = Path("assets/stats/fid_imagenet_256_val.npz")
114 | if not ref_fid_fn.exists():
115 | log.info(f"Generating {ref_fid_fn=} (this can take a while ...)")
116 | ref_opt = build_ref_opt(opt, ref_fid_fn)
117 | fid_util.compute_fid_ref_stat(ref_opt, log)
118 |
119 | # load npz file
120 | ref_fid = np.load(ref_fid_fn)
121 | ref_mu, ref_sigma = ref_fid['mu'], ref_fid['sigma']
122 | return ref_fid_fn, ref_mu, ref_sigma
123 |
124 | def log_metrices(opt):
125 | # setup
126 | set_seed(opt.seed)
127 | if opt.gpu is not None:
128 | torch.cuda.set_device(opt.gpu)
129 | log = Logger(0, ".log")
130 |
131 |     log.info(f"======== Compute metrics: {opt.ckpt=}, {opt.mode=} ========")
132 |
133 | # find all recon pt files
134 | # recon_imgs_pts = find_recon_imgs_pts(opt, log)
135 | # log.info(f"Found {len(recon_imgs_pts)} pt files={[pt.name for pt in recon_imgs_pts]}")
136 |
137 | # build torch array
138 | numpy_arr, numpy_label_arr = build_numpy_data(opt, log)
139 | log.info(f"Collected {numpy_arr.shape=}!")
140 |
141 | # compute accu
142 | accu = compute_accu(opt, numpy_arr, numpy_label_arr)
143 | log.info(f"Accuracy={accu:.3f}!")
144 |
145 | # load ref fid stat
146 | ref_fid_fn, ref_mu, ref_sigma = get_ref_fid(opt, log)
147 | log.info(f"Loaded FID reference statistics from {ref_fid_fn}!")
148 |
149 | # compute fid
150 | fid = fid_util.compute_fid_from_numpy(numpy_arr, ref_mu, ref_sigma, mode=opt.mode)
151 | log.info(f"FID(w.r.t. {ref_fid_fn=})={fid:.2f}!")
152 |
153 | res_dict = {
154 | "accu": accu,
155 | "fid": fid,
156 | }
157 |
158 | # save to file
159 | res_dir = opt.ckpt.split("/")[:-1]
160 | res_dir = Path("/".join(res_dir))
161 | with open(res_dir / "res.json", "w") as f:
162 | json.dump(res_dict, f)
163 |
164 | if __name__ == '__main__':
165 | parser = argparse.ArgumentParser()
166 | parser.add_argument("--seed", type=int, default=0)
167 | parser.add_argument("--gpu", type=int, default=None, help="set only if you wish to run on a particular device")
168 |     parser.add_argument("--ckpt", type=str, default=None, help="the checkpoint name for which we wish to compute metrics")
169 | parser.add_argument("--label", type=str, default=None)
170 | parser.add_argument("--mode", type=str, default="legacy_pytorch", help="the FID computation mode used in clean-fid")
171 | parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to LMDB dataset")
172 | parser.add_argument("--sample-dir", type=Path, default=None, help="directory where samples are stored")
173 | parser.add_argument("--image-size", type=int, default=256)
174 |
175 | arg = parser.parse_args()
176 |
177 | opt = edict(
178 | device="cuda",
179 | )
180 | opt.update(vars(arg))
181 |
182 | log_metrices(opt)
183 |
--------------------------------------------------------------------------------
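
A quick sketch of the value convention used throughout this script (illustrative; assumes the repository root is on `PYTHONPATH` so `logger.py` and the `evaluation` package are importable): `convert_to_numpy` maps model outputs in [-1, 1] to HWC uint8 arrays in [0, 255], which is what the accuracy and FID code above consumes.

```python
# Sketch: one pixel per channel, mapped from [-1,1] to uint8 [0,255] in NHWC layout.
import torch
from evaluation.compute_metrices_imagenet import convert_to_numpy

x = torch.tensor([-1.0, 0.0, 1.0]).view(1, 3, 1, 1)
out = convert_to_numpy(x)
assert out.shape == (1, 1, 1, 3) and out.tolist() == [[[[0, 127, 255]]]]
```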
/evaluation/fid_util.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 | from tqdm import tqdm
10 | import numpy as np
11 | from pathlib import Path
12 |
13 | import torch
14 | import torchvision
15 | from cleanfid.resize import build_resizer
16 | from cleanfid.features import build_feature_extractor
17 | from cleanfid.fid import get_batch_features, frechet_distance
18 |
19 | FID_REF_DIR = Path("assets/stats")
20 |
21 |
22 | def collect_features(
23 | dataset, mode, batch_size, num_workers, device=torch.device("cuda"), use_dataparallel=True, verbose=True
24 | ):
25 |
26 | dataloader = torch.utils.data.DataLoader(
27 | dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers
28 | )
29 | feat_model = build_feature_extractor(mode, device, use_dataparallel=use_dataparallel)
30 | l_feats = []
31 | pbar = tqdm(dataloader, desc="FID") if verbose else dataloader
32 | for batch in pbar:
33 | l_feats.append(get_batch_features(batch, feat_model, device))
34 |
35 | np_feats = np.concatenate(l_feats)
36 | mu = np.mean(np_feats, axis=0)
37 | sigma = np.cov(np_feats, rowvar=False)
38 | return mu, sigma
39 |
40 |
41 | class NumpyResizeDataset(torch.utils.data.Dataset):
42 | def __init__(self, dataset, mode, size=(299, 299)):
43 | self.dataset = dataset
44 | self.transforms = torchvision.transforms.ToTensor()
45 | self.size = size
46 | self.fn_resize = build_resizer(mode)
47 | self.custom_image_tranform = lambda x: x
48 |
49 | def get_img_np(self, i):
50 | return self.dataset[i]
51 |
52 | def __len__(self):
53 | return len(self.dataset)
54 |
55 | def __getitem__(self, i):
56 | img_np = self.get_img_np(i)
57 |
58 | # apply a custom image transform before resizing the image to 299x299
59 | img_np = self.custom_image_tranform(img_np)
60 | # fn_resize expects a np array and returns a np array
61 | img_resized = self.fn_resize(img_np)
62 |
63 |         # ToTensor() converts to [0,1] only if the input is uint8
64 | if img_resized.dtype == "uint8":
65 | img_t = self.transforms(np.array(img_resized)) * 255
66 | elif img_resized.dtype == "float32":
67 | img_t = self.transforms(img_resized)
68 |
69 | return img_t
70 |
71 |
72 | @torch.no_grad()
73 | def compute_fid_from_numpy(numpy_arr, ref_mu, ref_sigma, batch_size=256, mode="legacy_pytorch"):
74 |
75 | dataset = NumpyResizeDataset(numpy_arr, mode=mode)
76 | mu, sigma = collect_features(
77 | dataset,
78 | mode,
79 | num_workers=1,
80 | batch_size=batch_size,
81 | use_dataparallel=False,
82 | verbose=False,
83 | )
84 | return frechet_distance(mu, sigma, ref_mu, ref_sigma)
85 |
86 |
87 | class LMDBResizeDataset(NumpyResizeDataset):
88 | def __init__(self, dataset, mode):
89 | super(LMDBResizeDataset, self).__init__(dataset, mode)
90 |
91 | def get_img_np(self, i):
92 | img_pil, _ = self.dataset[i]
93 | return np.array(img_pil)
94 |
95 |
96 | def compute_fid_ref_stat(opt, log):
97 | from datasets.imagenet_inpaint import build_lmdb_dataset
98 | from torchvision import transforms
99 |
100 | mode = opt.mode
101 |
102 | # build dataset
103 | transform = transforms.Compose(
104 | [
105 | transforms.Resize(opt.image_size),
106 | transforms.CenterCrop(opt.image_size),
107 | ]
108 | )
109 | lmdb_dataset = build_lmdb_dataset(opt.dataset_dir, opt.image_size, train=opt.split == "train", transform=transform)
110 | dataset = LMDBResizeDataset(lmdb_dataset, mode=mode)
111 | log.info(f"[FID] Built Imagenet {opt.split} dataset, size={len(dataset)}!")
112 |
113 | # compute fid statistics
114 | num_avail_cpus = len(os.sched_getaffinity(0))
115 | num_workers = min(num_avail_cpus, 8)
116 | mu, sigma = collect_features(dataset, mode, batch_size=512, num_workers=num_workers)
117 | log.info(f"Collected inception features, {mu.shape=}, {sigma.shape=}!")
118 |
119 | # save and return statistics
120 | os.makedirs(FID_REF_DIR, exist_ok=True)
121 | fn = FID_REF_DIR / f"fid_imagenet_{opt.image_size}_{opt.split}.npz"
122 | np.savez(fn, mu=mu, sigma=sigma)
123 | log.info(f"Saved FID reference statistics to {fn}!")
124 | return mu, sigma
125 |
126 |
127 | if __name__ == "__main__":
128 | import argparse
129 | from logger import Logger
130 |
131 | parser = argparse.ArgumentParser()
132 | parser.add_argument(
133 | "--split", type=str, choices=["train", "val"], help="which dataset to compute FID ref statistics"
134 | )
135 | parser.add_argument("--mode", type=str, default="legacy_pytorch", help="the FID computation mode used in clean-fid")
136 | parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to LMDB dataset")
137 | parser.add_argument("--image-size", type=int, default=256)
138 | opt = parser.parse_args()
139 |
140 | log = Logger(0, ".log")
141 | log.info(f"======== Compute FID ref statistics: mode={opt.mode} ========")
142 | compute_fid_ref_stat(opt, log)
143 |
--------------------------------------------------------------------------------
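
For reference, the distance returned at the end of `compute_fid_from_numpy` is cleanfid's closed-form Fréchet distance between the two (mu, sigma) pairs. A tiny sanity sketch (illustrative, small dimensions for speed):

```python
# Sketch: identical feature statistics give a Fréchet distance of (numerically) zero.
import numpy as np
from cleanfid.fid import frechet_distance

mu = np.zeros(8)
sigma = np.eye(8)
assert abs(frechet_distance(mu, sigma, mu, sigma)) < 1e-6
```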
/evaluation/resnet.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import torch
9 |
10 | import torch.nn.functional as F
11 | from collections import OrderedDict
12 | from torchvision.models import resnet50
13 |
14 | from ipdb import set_trace as debug
15 |
16 |
17 | class ImageNormalizer(torch.nn.Module):
18 |
19 | def __init__(self, mean, std) -> None:
20 | super(ImageNormalizer, self).__init__()
21 |
22 | self.register_buffer("mean", torch.as_tensor(mean).view(1, 3, 1, 1))
23 | self.register_buffer("std", torch.as_tensor(std).view(1, 3, 1, 1))
24 |
25 | def forward(self, image):
26 | # note: image should be in [-1,1]
27 | image = (image + 1) / 2 # [-1,1] -> [0,1]
28 | image = F.interpolate(image, size=(224, 224), mode="bicubic")
29 | return (image - self.mean) / self.std
30 |
31 | def __repr__(self):
32 | return f"ImageNormalizer(mean={self.mean.squeeze()}, std={self.std.squeeze()})" # type: ignore
33 |
34 |
35 | def normalize_model(model, mean, std):
36 | layers = OrderedDict([("normalize", ImageNormalizer(mean, std)), ("model", model)])
37 | return torch.nn.Sequential(layers)
38 |
39 |
40 | def build_resnet50():
41 | model = resnet50(pretrained=True)
42 | mu = (0.485, 0.456, 0.406)
43 | sigma = (0.229, 0.224, 0.225)
44 | model = normalize_model(model, mu, sigma)
45 | model.eval()
46 | return model
47 |
--------------------------------------------------------------------------------
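
A usage sketch for the classifier above (illustrative; downloads torchvision's pretrained ResNet-50 weights on first call): thanks to the `ImageNormalizer` wrapper, inputs are passed directly in [-1, 1] and are resized to 224x224 internally.

```python
# Sketch: classify a dummy [-1,1] image; the wrapper handles resizing and
# ImageNet mean/std normalization before the ResNet-50 forward pass.
import torch
from evaluation.resnet import build_resnet50

model = build_resnet50()
with torch.no_grad():
    logits = model(torch.zeros(1, 3, 256, 256))  # a mid-gray image in [-1,1]
assert logits.shape == (1, 1000)
```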
/evaluations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-ml/DiffusionBridge/92522733cc602686df77f07a1824bb89f89cda1a/evaluations/__init__.py
--------------------------------------------------------------------------------
/evaluations/feature_extractor.py:
--------------------------------------------------------------------------------
1 | """ from clean fid """
2 |
3 | import os
4 | import platform
5 | import numpy as np
6 | import torch
7 | import cleanfid
8 | from cleanfid.downloads_helper import check_download_url
9 | from .inception_pytorch import InceptionV3
10 | from .inception_torchscript import InceptionV3W
11 |
12 |
13 | """
14 | returns a function that takes an image in range [0,255]
15 | and outputs a feature embedding vector
16 | """
17 | def feature_extractor(name="torchscript_inception", device=torch.device("cuda"), resize_inside=False, use_dataparallel=True):
18 | if name == "torchscript_inception":
19 | path = "./" if platform.system() == "Windows" else "/tmp"
20 | model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(device)
21 | model.eval()
22 | if use_dataparallel:
23 | model = torch.nn.DataParallel(model)
24 | def model_fn(x): return model(x)
25 | elif name == "pytorch_inception":
26 | model = InceptionV3(output_blocks=[3], resize_input=False).to(device)
27 | model.eval()
28 | if use_dataparallel:
29 | model = torch.nn.DataParallel(model)
30 | def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1)
31 | else:
32 | raise ValueError(f"{name} feature extractor not implemented")
33 | return model_fn
34 |
35 |
36 | """
37 | Build a feature extractor for each of the modes
38 | """
39 | def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
40 | assert not (mode == 'legacy_pytorch')
41 | if mode == "legacy_pytorch":
42 | feat_model = feature_extractor(name="pytorch_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel)
43 | elif mode == "legacy_tensorflow":
44 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=True, device=device, use_dataparallel=use_dataparallel)
45 | elif mode == "clean":
46 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel)
47 | return feat_model
48 |
--------------------------------------------------------------------------------
/evaluations/inception_pytorch.py:
--------------------------------------------------------------------------------
1 | """
2 | File from: https://github.com/mseitzer/pytorch-fid
3 | """
4 |
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 |
11 | try:
12 | from torchvision.models.utils import load_state_dict_from_url
13 | except ImportError:
14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
15 |
16 | # Inception weights ported to Pytorch from
17 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
18 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
19 |
20 |
21 | class InceptionV3(nn.Module):
22 | """Pretrained InceptionV3 network returning feature maps"""
23 |
24 | # Index of default block of inception to return,
25 | # corresponds to output of final average pooling
26 | DEFAULT_BLOCK_INDEX = 3
27 |
28 | # Maps feature dimensionality to their output blocks indices
29 | BLOCK_INDEX_BY_DIM = {
30 | 64: 0, # First max pooling features
31 |         192: 1,  # Second max pooling features
32 | 768: 2, # Pre-aux classifier features
33 | 2048: 3 # Final average pooling features
34 | }
35 |
36 | def __init__(self,
37 | output_blocks=(DEFAULT_BLOCK_INDEX,),
38 | resize_input=True,
39 | normalize_input=True,
40 | requires_grad=False,
41 | use_fid_inception=True):
42 | """Build pretrained InceptionV3
43 | Parameters
44 | ----------
45 | output_blocks : list of int
46 | Indices of blocks to return features of. Possible values are:
47 | - 0: corresponds to output of first max pooling
48 | - 1: corresponds to output of second max pooling
49 | - 2: corresponds to output which is fed to aux classifier
50 | - 3: corresponds to output of final average pooling
51 | resize_input : bool
52 | If true, bilinearly resizes input to width and height 299 before
53 | feeding input to model. As the network without fully connected
54 | layers is fully convolutional, it should be able to handle inputs
55 | of arbitrary size, so resizing might not be strictly needed
56 | normalize_input : bool
57 | If true, scales the input from range (0, 1) to the range the
58 | pretrained Inception network expects, namely (-1, 1)
59 | requires_grad : bool
60 | If true, parameters of the model require gradients. Possibly useful
61 | for finetuning the network
62 | use_fid_inception : bool
63 | If true, uses the pretrained Inception model used in Tensorflow's
64 | FID implementation. If false, uses the pretrained Inception model
65 | available in torchvision. The FID Inception model has different
66 | weights and a slightly different structure from torchvision's
67 | Inception model. If you want to compute FID scores, you are
68 | strongly advised to set this parameter to true to get comparable
69 | results.
70 | """
71 | super(InceptionV3, self).__init__()
72 |
73 | self.resize_input = resize_input
74 | self.normalize_input = normalize_input
75 | self.output_blocks = sorted(output_blocks)
76 | self.last_needed_block = max(output_blocks)
77 |
78 | assert self.last_needed_block <= 3, \
79 | 'Last possible output block index is 3'
80 |
81 | self.blocks = nn.ModuleList()
82 |
83 | if use_fid_inception:
84 | inception = fid_inception_v3()
85 | else:
86 | inception = _inception_v3(pretrained=True)
87 |
88 | # Block 0: input to maxpool1
89 | block0 = [
90 | inception.Conv2d_1a_3x3,
91 | inception.Conv2d_2a_3x3,
92 | inception.Conv2d_2b_3x3,
93 | nn.MaxPool2d(kernel_size=3, stride=2)
94 | ]
95 | self.blocks.append(nn.Sequential(*block0))
96 |
97 | # Block 1: maxpool1 to maxpool2
98 | if self.last_needed_block >= 1:
99 | block1 = [
100 | inception.Conv2d_3b_1x1,
101 | inception.Conv2d_4a_3x3,
102 | nn.MaxPool2d(kernel_size=3, stride=2)
103 | ]
104 | self.blocks.append(nn.Sequential(*block1))
105 |
106 | # Block 2: maxpool2 to aux classifier
107 | if self.last_needed_block >= 2:
108 | block2 = [
109 | inception.Mixed_5b,
110 | inception.Mixed_5c,
111 | inception.Mixed_5d,
112 | inception.Mixed_6a,
113 | inception.Mixed_6b,
114 | inception.Mixed_6c,
115 | inception.Mixed_6d,
116 | inception.Mixed_6e,
117 | ]
118 | self.blocks.append(nn.Sequential(*block2))
119 |
120 | # Block 3: aux classifier to final avgpool
121 | if self.last_needed_block >= 3:
122 | block3 = [
123 | inception.Mixed_7a,
124 | inception.Mixed_7b,
125 | inception.Mixed_7c,
126 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
127 | ]
128 | self.blocks.append(nn.Sequential(*block3))
129 |
130 | for param in self.parameters():
131 | param.requires_grad = requires_grad
132 |
133 | def forward(self, inp):
134 | """Get Inception feature maps
135 | Parameters
136 | ----------
137 | inp : torch.autograd.Variable
138 | Input tensor of shape Bx3xHxW. Values are expected to be in
139 | range (0, 1)
140 | Returns
141 | -------
142 | List of torch.autograd.Variable, corresponding to the selected output
143 | block, sorted ascending by index
144 | """
145 | outp = []
146 | x = inp
147 |
148 | if self.resize_input:
149 | raise ValueError("should not resize here")
150 | x = F.interpolate(x,
151 | size=(299, 299),
152 | mode='bilinear',
153 | align_corners=False)
154 |
155 | if self.normalize_input:
156 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
157 |
158 | for idx, block in enumerate(self.blocks):
159 | x = block(x)
160 | if idx in self.output_blocks:
161 | outp.append(x)
162 |
163 | if idx == self.last_needed_block:
164 | break
165 |
166 | return outp
167 |
168 |
169 | def _inception_v3(*args, **kwargs):
170 | """Wraps `torchvision.models.inception_v3`
171 |     Skips default weight initialization if supported by torchvision version.
172 | See https://github.com/mseitzer/pytorch-fid/issues/28.
173 | """
174 | try:
175 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
176 | except ValueError:
177 | # Just a caution against weird version strings
178 | version = (0,)
179 |
180 | if version >= (0, 6):
181 | kwargs['init_weights'] = False
182 |
183 | return torchvision.models.inception_v3(*args, **kwargs)
184 |
185 |
186 | def fid_inception_v3():
187 | """Build pretrained Inception model for FID computation
188 | The Inception model for FID computation uses a different set of weights
189 | and has a slightly different structure than torchvision's Inception.
190 | This method first constructs torchvision's Inception and then patches the
191 | necessary parts that are different in the FID Inception model.
192 | """
193 | inception = _inception_v3(num_classes=1008,
194 | aux_logits=False,
195 | pretrained=False)
196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203 | inception.Mixed_7b = FIDInceptionE_1(1280)
204 | inception.Mixed_7c = FIDInceptionE_2(2048)
205 |
206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
207 | inception.load_state_dict(state_dict)
208 | return inception
209 |
210 |
211 | class FIDInceptionA(torchvision.models.inception.InceptionA):
212 | """InceptionA block patched for FID computation"""
213 | def __init__(self, in_channels, pool_features):
214 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
215 |
216 | def forward(self, x):
217 | branch1x1 = self.branch1x1(x)
218 |
219 | branch5x5 = self.branch5x5_1(x)
220 | branch5x5 = self.branch5x5_2(branch5x5)
221 |
222 | branch3x3dbl = self.branch3x3dbl_1(x)
223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225 |
226 |         # Patch: Tensorflow's average pool does not use the padded zeros in
227 | # its average calculation
228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229 | count_include_pad=False)
230 | branch_pool = self.branch_pool(branch_pool)
231 |
232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233 | return torch.cat(outputs, 1)
234 |
235 |
236 | class FIDInceptionC(torchvision.models.inception.InceptionC):
237 | """InceptionC block patched for FID computation"""
238 | def __init__(self, in_channels, channels_7x7):
239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240 |
241 | def forward(self, x):
242 | branch1x1 = self.branch1x1(x)
243 |
244 | branch7x7 = self.branch7x7_1(x)
245 | branch7x7 = self.branch7x7_2(branch7x7)
246 | branch7x7 = self.branch7x7_3(branch7x7)
247 |
248 | branch7x7dbl = self.branch7x7dbl_1(x)
249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253 |
254 |         # Patch: Tensorflow's average pool does not use the padded zeros in
255 | # its average calculation
256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257 | count_include_pad=False)
258 | branch_pool = self.branch_pool(branch_pool)
259 |
260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261 | return torch.cat(outputs, 1)
262 |
263 |
264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265 | """First InceptionE block patched for FID computation"""
266 | def __init__(self, in_channels):
267 | super(FIDInceptionE_1, self).__init__(in_channels)
268 |
269 | def forward(self, x):
270 | branch1x1 = self.branch1x1(x)
271 |
272 | branch3x3 = self.branch3x3_1(x)
273 | branch3x3 = [
274 | self.branch3x3_2a(branch3x3),
275 | self.branch3x3_2b(branch3x3),
276 | ]
277 | branch3x3 = torch.cat(branch3x3, 1)
278 |
279 | branch3x3dbl = self.branch3x3dbl_1(x)
280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281 | branch3x3dbl = [
282 | self.branch3x3dbl_3a(branch3x3dbl),
283 | self.branch3x3dbl_3b(branch3x3dbl),
284 | ]
285 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
286 |
287 |         # Patch: Tensorflow's average pool does not use the padded zeros in
288 | # its average calculation
289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290 | count_include_pad=False)
291 | branch_pool = self.branch_pool(branch_pool)
292 |
293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294 | return torch.cat(outputs, 1)
295 |
296 |
297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298 | """Second InceptionE block patched for FID computation"""
299 | def __init__(self, in_channels):
300 | super(FIDInceptionE_2, self).__init__(in_channels)
301 |
302 | def forward(self, x):
303 | branch1x1 = self.branch1x1(x)
304 |
305 | branch3x3 = self.branch3x3_1(x)
306 | branch3x3 = [
307 | self.branch3x3_2a(branch3x3),
308 | self.branch3x3_2b(branch3x3),
309 | ]
310 | branch3x3 = torch.cat(branch3x3, 1)
311 |
312 | branch3x3dbl = self.branch3x3dbl_1(x)
313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314 | branch3x3dbl = [
315 | self.branch3x3dbl_3a(branch3x3dbl),
316 | self.branch3x3dbl_3b(branch3x3dbl),
317 | ]
318 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
319 |
320 | # Patch: The FID Inception model uses max pooling instead of average
321 | # pooling. This is likely an error in this specific Inception
322 | # implementation, as other Inception models use average pooling here
323 | # (which matches the description in the paper).
324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325 | branch_pool = self.branch_pool(branch_pool)
326 |
327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328 | return torch.cat(outputs, 1)
--------------------------------------------------------------------------------
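
A minimal feature-extraction sketch for the module above (illustrative; downloads the FID Inception weights on first use): `BLOCK_INDEX_BY_DIM` maps the desired feature dimensionality to the output block index.

```python
# Sketch: extract 2048-dim final-average-pool features for inputs in [0,1],
# with resizing disabled (inputs must already be 299x299).
import torch
from evaluations.inception_pytorch import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3(output_blocks=[block_idx], resize_input=False).eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 299, 299))[0]
assert feats.shape == (1, 2048, 1, 1)
```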
/evaluations/inception_torchscript.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from cleanfid.downloads_helper import *
5 | import contextlib
6 |
7 |
8 | @contextlib.contextmanager
9 | def disable_gpu_fuser_on_pt19():
10 |     # On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model from running. See
11 | # https://github.com/GaParmar/clean-fid/issues/5
12 | # https://github.com/pytorch/pytorch/issues/64062
13 | if torch.__version__.startswith('1.9.'):
14 | old_val = torch._C._jit_can_fuse_on_gpu()
15 | torch._C._jit_override_can_fuse_on_gpu(False)
16 | yield
17 | if torch.__version__.startswith('1.9.'):
18 | torch._C._jit_override_can_fuse_on_gpu(old_val)
19 |
20 |
21 | class InceptionV3W(nn.Module):
22 | """
23 | Wrapper around Inception V3 torchscript model provided here
24 | https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt
25 |
26 | path: locally saved inception weights
27 | """
28 | def __init__(self, path, download=True, resize_inside=False):
29 | super(InceptionV3W, self).__init__()
30 | # download the network if it is not present at the given directory
31 | # use the current directory by default
32 | if download:
33 | check_download_inception(fpath=path)
34 | path = os.path.join(path, "inception-2015-12-05.pt")
35 | self.base = torch.jit.load(path).eval()
36 | self.layers = self.base.layers
37 | self.output = self.base.output
38 | self.resize_inside = resize_inside
39 |
40 | """
41 | Get the inception features without resizing
42 | x: Image with values in range [0,255]
43 | """
44 | def forward(self, x):
45 | with disable_gpu_fuser_on_pt19():
46 | bs = x.shape[0]
47 | if self.resize_inside:
48 | features = self.base(x, return_features=True).view((bs, 2048))
49 | else:
50 | # make sure it is resized already
51 | assert (x.shape[2] == 299) and (x.shape[3] == 299)
52 | # apply normalization
53 | x1 = x - 128
54 | x2 = x1 / 128
55 | features = self.layers.forward(x2, ).view((bs, 2048))
56 | return features, self.output(features)
--------------------------------------------------------------------------------
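
For reference, the normalization applied in the non-resizing branch above maps [0, 255] inputs approximately onto [-1, 1). A tiny illustrative check:

```python
# Sketch: (x - 128) / 128 sends 0 -> -1, 128 -> 0, and 255 -> 127/128.
import torch

x = torch.tensor([0.0, 128.0, 255.0])
assert torch.allclose((x - 128) / 128, torch.tensor([-1.0, 0.0, 127.0 / 128.0]))
```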
/evaluations/inception_v3.py:
--------------------------------------------------------------------------------
1 | # Ported from the model here:
2 | # https://github.com/NVlabs/stylegan3/blob/407db86e6fe432540a22515310188288687858fa/metrics/frechet_inception_distance.py#L22
3 | #
4 | # I have verified that the spatial features and output features are correct
5 | # within a mean absolute error of ~3e-5.
6 |
7 | import collections
8 |
9 | import torch
10 |
11 |
12 | class Conv2dLayer(torch.nn.Module):
13 | def __init__(self, in_channels, out_channels, kh, kw, stride=1, padding=0):
14 | super().__init__()
15 | self.stride = stride
16 | self.padding = padding
17 | self.weight = torch.nn.Parameter(torch.zeros(out_channels, in_channels, kh, kw))
18 | self.beta = torch.nn.Parameter(torch.zeros(out_channels))
19 | self.mean = torch.nn.Parameter(torch.zeros(out_channels))
20 | self.var = torch.nn.Parameter(torch.zeros(out_channels))
21 |
22 | def forward(self, x):
23 | x = torch.nn.functional.conv2d(
24 | x, self.weight.to(x.dtype), stride=self.stride, padding=self.padding
25 | )
26 | x = torch.nn.functional.batch_norm(
27 | x, running_mean=self.mean, running_var=self.var, bias=self.beta, eps=1e-3
28 | )
29 | x = torch.nn.functional.relu(x)
30 | return x
31 |
32 |
33 | # ----------------------------------------------------------------------------
34 |
35 |
36 | class InceptionA(torch.nn.Module):
37 | def __init__(self, in_channels, tmp_channels):
38 | super().__init__()
39 | self.conv = Conv2dLayer(in_channels, 64, kh=1, kw=1)
40 | self.tower = torch.nn.Sequential(
41 | collections.OrderedDict(
42 | [
43 | ("conv", Conv2dLayer(in_channels, 48, kh=1, kw=1)),
44 | ("conv_1", Conv2dLayer(48, 64, kh=5, kw=5, padding=2)),
45 | ]
46 | )
47 | )
48 | self.tower_1 = torch.nn.Sequential(
49 | collections.OrderedDict(
50 | [
51 | ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)),
52 | ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)),
53 | ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, padding=1)),
54 | ]
55 | )
56 | )
57 | self.tower_2 = torch.nn.Sequential(
58 | collections.OrderedDict(
59 | [
60 | (
61 | "pool",
62 | torch.nn.AvgPool2d(
63 | kernel_size=3, stride=1, padding=1, count_include_pad=False
64 | ),
65 | ),
66 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)),
67 | ]
68 | )
69 | )
70 |
71 | def forward(self, x):
72 | return torch.cat(
73 | [
74 | self.conv(x).contiguous(),
75 | self.tower(x).contiguous(),
76 | self.tower_1(x).contiguous(),
77 | self.tower_2(x).contiguous(),
78 | ],
79 | dim=1,
80 | )
81 |
82 |
83 | # ----------------------------------------------------------------------------
84 |
85 |
86 | class InceptionB(torch.nn.Module):
87 | def __init__(self, in_channels):
88 | super().__init__()
89 | self.conv = Conv2dLayer(in_channels, 384, kh=3, kw=3, stride=2)
90 | self.tower = torch.nn.Sequential(
91 | collections.OrderedDict(
92 | [
93 | ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)),
94 | ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)),
95 | ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, stride=2)),
96 | ]
97 | )
98 | )
99 | self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2)
100 |
101 | def forward(self, x):
102 | return torch.cat(
103 | [
104 | self.conv(x).contiguous(),
105 | self.tower(x).contiguous(),
106 | self.pool(x).contiguous(),
107 | ],
108 | dim=1,
109 | )
110 |
111 |
112 | # ----------------------------------------------------------------------------
113 |
114 |
115 | class InceptionC(torch.nn.Module):
116 | def __init__(self, in_channels, tmp_channels):
117 | super().__init__()
118 | self.conv = Conv2dLayer(in_channels, 192, kh=1, kw=1)
119 | self.tower = torch.nn.Sequential(
120 | collections.OrderedDict(
121 | [
122 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)),
123 | (
124 | "conv_1",
125 | Conv2dLayer(
126 | tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3]
127 | ),
128 | ),
129 | (
130 | "conv_2",
131 | Conv2dLayer(tmp_channels, 192, kh=7, kw=1, padding=[3, 0]),
132 | ),
133 | ]
134 | )
135 | )
136 | self.tower_1 = torch.nn.Sequential(
137 | collections.OrderedDict(
138 | [
139 | ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)),
140 | (
141 | "conv_1",
142 | Conv2dLayer(
143 | tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0]
144 | ),
145 | ),
146 | (
147 | "conv_2",
148 | Conv2dLayer(
149 | tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3]
150 | ),
151 | ),
152 | (
153 | "conv_3",
154 | Conv2dLayer(
155 | tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0]
156 | ),
157 | ),
158 | (
159 | "conv_4",
160 | Conv2dLayer(tmp_channels, 192, kh=1, kw=7, padding=[0, 3]),
161 | ),
162 | ]
163 | )
164 | )
165 | self.tower_2 = torch.nn.Sequential(
166 | collections.OrderedDict(
167 | [
168 | (
169 | "pool",
170 | torch.nn.AvgPool2d(
171 | kernel_size=3, stride=1, padding=1, count_include_pad=False
172 | ),
173 | ),
174 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)),
175 | ]
176 | )
177 | )
178 |
179 | def forward(self, x):
180 | return torch.cat(
181 | [
182 | self.conv(x).contiguous(),
183 | self.tower(x).contiguous(),
184 | self.tower_1(x).contiguous(),
185 | self.tower_2(x).contiguous(),
186 | ],
187 | dim=1,
188 | )
189 |
190 |
191 | # ----------------------------------------------------------------------------
192 |
193 |
194 | class InceptionD(torch.nn.Module):
195 | def __init__(self, in_channels):
196 | super().__init__()
197 | self.tower = torch.nn.Sequential(
198 | collections.OrderedDict(
199 | [
200 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)),
201 | ("conv_1", Conv2dLayer(192, 320, kh=3, kw=3, stride=2)),
202 | ]
203 | )
204 | )
205 | self.tower_1 = torch.nn.Sequential(
206 | collections.OrderedDict(
207 | [
208 | ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)),
209 | ("conv_1", Conv2dLayer(192, 192, kh=1, kw=7, padding=[0, 3])),
210 | ("conv_2", Conv2dLayer(192, 192, kh=7, kw=1, padding=[3, 0])),
211 | ("conv_3", Conv2dLayer(192, 192, kh=3, kw=3, stride=2)),
212 | ]
213 | )
214 | )
215 | self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2)
216 |
217 | def forward(self, x):
218 | return torch.cat(
219 | [
220 | self.tower(x).contiguous(),
221 | self.tower_1(x).contiguous(),
222 | self.pool(x).contiguous(),
223 | ],
224 | dim=1,
225 | )
226 |
227 |
228 | # ----------------------------------------------------------------------------
229 |
230 |
231 | class InceptionE(torch.nn.Module):
232 | def __init__(self, in_channels, use_avg_pool):
233 | super().__init__()
234 | self.conv = Conv2dLayer(in_channels, 320, kh=1, kw=1)
235 | self.tower_conv = Conv2dLayer(in_channels, 384, kh=1, kw=1)
236 | self.tower_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1])
237 | self.tower_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0])
238 | self.tower_1_conv = Conv2dLayer(in_channels, 448, kh=1, kw=1)
239 | self.tower_1_conv_1 = Conv2dLayer(448, 384, kh=3, kw=3, padding=1)
240 | self.tower_1_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1])
241 | self.tower_1_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0])
242 | if use_avg_pool:
243 | self.tower_2_pool = torch.nn.AvgPool2d(
244 | kernel_size=3, stride=1, padding=1, count_include_pad=False
245 | )
246 | else:
247 | self.tower_2_pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
248 | self.tower_2_conv = Conv2dLayer(in_channels, 192, kh=1, kw=1)
249 |
250 | def forward(self, x):
251 | a = self.tower_conv(x)
252 | b = self.tower_1_conv_1(self.tower_1_conv(x))
253 | return torch.cat(
254 | [
255 | self.conv(x).contiguous(),
256 | self.tower_mixed_conv(a).contiguous(),
257 | self.tower_mixed_conv_1(a).contiguous(),
258 | self.tower_1_mixed_conv(b).contiguous(),
259 | self.tower_1_mixed_conv_1(b).contiguous(),
260 | self.tower_2_conv(self.tower_2_pool(x)).contiguous(),
261 | ],
262 | dim=1,
263 | )
264 |
265 |
266 | # ----------------------------------------------------------------------------
267 |
268 |
269 | class InceptionV3(torch.nn.Module):
270 | def __init__(self):
271 | super().__init__()
272 | self.layers = torch.nn.Sequential(
273 | collections.OrderedDict(
274 | [
275 | ("conv", Conv2dLayer(3, 32, kh=3, kw=3, stride=2)),
276 | ("conv_1", Conv2dLayer(32, 32, kh=3, kw=3)),
277 | ("conv_2", Conv2dLayer(32, 64, kh=3, kw=3, padding=1)),
278 | ("pool0", torch.nn.MaxPool2d(kernel_size=3, stride=2)),
279 | ("conv_3", Conv2dLayer(64, 80, kh=1, kw=1)),
280 | ("conv_4", Conv2dLayer(80, 192, kh=3, kw=3)),
281 | ("pool1", torch.nn.MaxPool2d(kernel_size=3, stride=2)),
282 | ("mixed", InceptionA(192, tmp_channels=32)),
283 | ("mixed_1", InceptionA(256, tmp_channels=64)),
284 | ("mixed_2", InceptionA(288, tmp_channels=64)),
285 | ("mixed_3", InceptionB(288)),
286 | ("mixed_4", InceptionC(768, tmp_channels=128)),
287 | ("mixed_5", InceptionC(768, tmp_channels=160)),
288 | ("mixed_6", InceptionC(768, tmp_channels=160)),
289 | ("mixed_7", InceptionC(768, tmp_channels=192)),
290 | ("mixed_8", InceptionD(768)),
291 | ("mixed_9", InceptionE(1280, use_avg_pool=True)),
292 | ("mixed_10", InceptionE(2048, use_avg_pool=False)),
293 | ("pool2", torch.nn.AvgPool2d(kernel_size=8)),
294 | ]
295 | )
296 | )
297 | self.output = torch.nn.Linear(2048, 1008)
298 |
299 | def forward(
300 | self,
301 | img,
302 | return_features: bool = True,
303 | use_fp16: bool = False,
304 | no_output_bias: bool = False,
305 | ):
306 | batch_size, channels, height, width = img.shape # [NCHW]
307 | assert channels == 3
308 |
309 | # Cast to float.
310 | x = img.to(torch.float16 if use_fp16 else torch.float32)
311 |
312 | # Emulate tf.image.resize_bilinear(x, [299, 299]), including the funky alignment.
313 | new_width, new_height = 299, 299
314 | theta = torch.eye(2, 3, device=x.device)
315 | theta[0, 2] += theta[0, 0] / width - theta[0, 0] / new_width
316 | theta[1, 2] += theta[1, 1] / height - theta[1, 1] / new_height
317 | theta = theta.to(x.dtype).unsqueeze(0).repeat([batch_size, 1, 1])
318 | grid = torch.nn.functional.affine_grid(
319 | theta, [batch_size, channels, new_height, new_width], align_corners=False
320 | )
321 | x = torch.nn.functional.grid_sample(
322 | x, grid, mode="bilinear", padding_mode="border", align_corners=False
323 | )
324 |
325 |         # Scale dynamic range from [0, 255] to [-1, 1).
326 | x -= 128
327 | x /= 128
328 |
329 | # Main layers.
330 | intermediate = self.layers[:-6](x)
331 | spatial_features = (
332 | self.layers[-6]
333 | .conv(intermediate)[:, :7]
334 | .permute(0, 2, 3, 1)
335 | .reshape(-1, 2023)
336 | )
337 | features = self.layers[-6:](intermediate).reshape(-1, 2048).to(torch.float32)
338 | if return_features:
339 | return features, spatial_features
340 |
341 | # Output layer.
342 | return self.acts_to_probs(features, no_output_bias=no_output_bias)
343 |
344 | def acts_to_probs(self, features, no_output_bias: bool = False):
345 | if no_output_bias:
346 | logits = torch.nn.functional.linear(features, self.output.weight)
347 | else:
348 | logits = self.output(features)
349 | probs = torch.nn.functional.softmax(logits, dim=1)
350 | return probs
351 |
352 | def create_softmax_model(self):
353 | return SoftmaxModel(self.output.weight)
354 |
355 |
356 | class SoftmaxModel(torch.nn.Module):
357 | def __init__(self, weight: torch.Tensor):
358 | super().__init__()
359 | self.weight = torch.nn.Parameter(weight.detach().clone())
360 |
361 | def forward(self, x):
362 | logits = torch.nn.functional.linear(x, self.weight)
363 | probs = torch.nn.functional.softmax(logits, dim=1)
364 | return probs
365 |
--------------------------------------------------------------------------------
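The port above is constructed with zero weights; the actual Inception parameters are expected to be loaded separately (in this repo, presumably via `evaluations/feature_extractor.py`). A minimal shape-check sketch of the forward interface, with inputs of any size resized to 299x299 internally:

```python
# Minimal sketch, not part of the repository; a freshly constructed model has zero
# weights, so the values are meaningless -- this only illustrates the interface.
import torch
from evaluations.inception_v3 import InceptionV3

model = InceptionV3().eval()
img = torch.randint(0, 256, (2, 3, 256, 256)).float()  # any HxW, values in [0, 255]

with torch.no_grad():
    features, spatial_features = model(img, return_features=True)

print(features.shape)          # torch.Size([2, 2048])  pooled features for FID
print(spatial_features.shape)  # torch.Size([2, 2023])  7 x 17 x 17 map from mixed_6, used for sFID
```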
/evaluations/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu>=2.0
2 | scipy
3 | requests
4 | tqdm
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import os
9 | import time
10 | import logging
11 | from rich.console import Console
12 | from rich.logging import RichHandler
13 |
14 |
15 | def get_time(sec):
16 | h = int(sec // 3600)
17 | m = int((sec // 60) % 60)
18 | s = int(sec % 60)
19 | return h, m, s
20 |
21 |
22 | class TimeFilter(logging.Filter):
23 |
24 | def filter(self, record):
25 | try:
26 | start = self.start
27 | except AttributeError:
28 | start = self.start = time.time()
29 |
30 | time_elapsed = get_time(time.time() - start)
31 |
32 | record.relative = "{0}:{1:02d}:{2:02d}".format(*time_elapsed)
33 |
34 | # self.last = record.relativeCreated/1000.0
35 | return True
36 |
37 |
38 | class Logger(object):
39 | def __init__(self, rank=0, log_dir=".log"):
40 | # other libraries may set logging before arriving at this line.
41 | # by reloading logging, we can get rid of previous configs set by other libraries.
42 | from importlib import reload
43 |
44 | reload(logging)
45 | self.rank = rank
46 | if self.rank == 0:
47 | os.makedirs(log_dir, exist_ok=True)
48 |
49 | log_file = open(os.path.join(log_dir, "log.txt"), "w")
50 | file_console = Console(file=log_file, width=150)
51 | logging.basicConfig(
52 | level=logging.INFO,
53 | format="(%(relative)s) %(message)s",
54 | datefmt="[%X]",
55 | force=True,
56 | handlers=[RichHandler(show_path=False), RichHandler(console=file_console, show_path=False)],
57 | )
58 | # https://stackoverflow.com/questions/31521859/python-logging-module-time-since-last-log
59 | log = logging.getLogger()
60 | [hndl.addFilter(TimeFilter()) for hndl in log.handlers]
61 |
62 | def info(self, string, *args):
63 | if self.rank == 0:
64 | logging.info(string, *args)
65 |
66 | def warning(self, string, *args):
67 | if self.rank == 0:
68 | logging.warning(string, *args)
69 |
70 | def error(self, string, *args):
71 | if self.rank == 0:
72 | logging.error(string, *args)
73 |
--------------------------------------------------------------------------------
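`Logger` above writes through `rich` on rank 0 only, mirrors the console output to `<log_dir>/log.txt`, and prefixes every record with the time elapsed since the first message. A minimal usage sketch (the rank would normally come from `torch.distributed`):

```python
# Minimal sketch, not part of the repository.
from logger import Logger

log = Logger(rank=0, log_dir=".log")  # non-zero ranks construct a silent logger
log.info("starting run with %d GPUs", 8)
log.warning("checkpoint %s not found, training from scratch", "assets/ckpts/model.pt")
```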
/preprocess_ckpt.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # Process DDBM checkpoints: the attention qkv/proj_out weights were stored as 1-D conv kernels of shape [out, in, 1]; squeeze away the trailing kernel dimension so the adapted checkpoints load into this codebase's attention layers.
4 |
5 | state_dict = torch.load("assets/ckpts/e2h_ema_0.9999_420000.pt", map_location="cpu")
6 |
7 | module_list = []
8 |
9 | for i in range(5, 16):
10 | if i == 8 or i == 12:
11 | continue
12 | module_list.append(f"input_blocks.{i}.1.qkv.weight")
13 | module_list.append(f"input_blocks.{i}.1.proj_out.weight")
14 |
15 | module_list.append("middle_block.1.qkv.weight")
16 | module_list.append("middle_block.1.proj_out.weight")
17 |
18 | for i in range(0, 12):
19 | module_list.append(f"output_blocks.{i}.1.qkv.weight")
20 | module_list.append(f"output_blocks.{i}.1.proj_out.weight")
21 |
22 | for name in module_list:
23 | state_dict[name] = state_dict[name].squeeze(-1)
24 |
25 | torch.save(state_dict, "assets/ckpts/e2h_ema_0.9999_420000_adapted.pt")
26 |
27 | state_dict = torch.load("assets/ckpts/diode_ema_0.9999_440000.pt", map_location="cpu")
28 |
29 | module_list = []
30 |
31 | for i in range(10, 18):
32 | if i == 12 or i == 15:
33 | continue
34 | module_list.append(f"input_blocks.{i}.1.qkv.weight")
35 | module_list.append(f"input_blocks.{i}.1.proj_out.weight")
36 |
37 | module_list.append("middle_block.1.qkv.weight")
38 | module_list.append("middle_block.1.proj_out.weight")
39 |
40 | for i in range(0, 9):
41 | module_list.append(f"output_blocks.{i}.1.qkv.weight")
42 | module_list.append(f"output_blocks.{i}.1.proj_out.weight")
43 |
44 | for name in module_list:
45 | state_dict[name] = state_dict[name].squeeze(-1)
46 |
47 | torch.save(state_dict, "assets/ckpts/diode_ema_0.9999_440000_adapted.pt")
48 |
--------------------------------------------------------------------------------
/preprocess_depth.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os
3 | from matplotlib import pyplot as plt
4 | import numpy as np
5 | import cv2
6 |
7 |
8 | from joblib import Parallel, delayed
9 | from tqdm import tqdm
10 | import csv
11 |
12 | SPLIT = "train"
13 | img_size = 256
14 | data_dir = "assets/datasets/DIODE/data_list/train_outdoor.csv"
15 |
16 |
17 | def plot_depth_map(dm, validity_mask, name):
18 | validity_mask = validity_mask > 0
19 | MIN_DEPTH = 0.5
20 | MAX_DEPTH = min(300, np.percentile(dm, 99))
21 | dm = np.clip(dm, MIN_DEPTH, MAX_DEPTH)
22 | dm = np.log(dm, where=validity_mask)
23 |
24 | dm = np.ma.masked_where(~validity_mask, dm)
25 |
26 |     cmap = plt.get_cmap("jet").copy()  # copy so set_bad below does not mutate the registered colormap
27 | cmap.set_bad(color="black")
28 | norm = plt.Normalize(vmin=0, vmax=np.log(MAX_DEPTH + 1.01))
29 | image = cmap(norm(dm))
30 | plt.imsave(name, np.clip(image, 0.0, 1.0))
31 |
32 |
33 | def plot_normal_map(normal_map, name):
34 |     normal_viz = normal_map[:, :, :]
35 |
36 | normal_viz = normal_viz + np.equal(np.sum(normal_viz, 2, keepdims=True), 0.0).astype(np.float32) * np.min(
37 | normal_viz
38 | )
39 |
40 | normal_viz = (normal_viz - np.min(normal_viz)) / 2.0
41 | plt.axis("off")
42 | plt.imsave(name, np.clip(normal_viz, 0.0, 1.0))
43 |
44 |
45 | all_files = []
46 | with open(data_dir, newline="") as csvfile:
47 | reader = csv.reader(csvfile, delimiter=",")
48 |
49 | for row in reader:
50 | if row[-1] == "Unavailable":
51 | continue
52 | all_files.append(row[0].split("/")[-1])
53 |
54 |
55 | def process(file):
56 | scene_id, scan_id = file.split("_")[0], file.split("_")[1]
57 | base_path = f"assets/datasets/DIODE/{SPLIT}/outdoor/scene_{scene_id}/scan_{scan_id}"
58 | path = os.path.join(base_path, file)
59 | pil_image = Image.open(path).convert("RGB").resize((img_size, img_size), Image.BICUBIC)
60 |
61 | path2 = os.path.join(base_path, file[:-4] + "_depth.npy")
62 | depth = np.load(path2).squeeze()
63 | depth = depth.astype(np.float32)
64 |
65 | path3 = os.path.join(base_path, file[:-4] + "_depth_mask.npy")
66 | depth_mask = np.load(path3)
67 | depth_mask = depth_mask.astype(np.float32)
68 |
69 | path4 = os.path.join(base_path, file[:-4] + "_normal.npy")
70 | normal = np.load(path4)
71 | normal = normal.astype(np.float32)
72 |
73 | image_depth = cv2.resize(depth, dsize=(img_size, img_size), interpolation=cv2.INTER_NEAREST)
74 | image_depth_mask = cv2.resize(depth_mask, dsize=(img_size, img_size), interpolation=cv2.INTER_NEAREST)
75 |
76 | normal = cv2.resize(normal, dsize=(img_size, img_size), interpolation=cv2.INTER_NEAREST)
77 |
78 | name = os.path.join(target_dir, file)
79 | pil_image.save(name)
80 | plot_depth_map(image_depth, image_depth_mask, name[:-4] + "_depth.png")
81 | plot_normal_map(normal, name[:-4] + "_normal.png")
82 |
83 |
84 | target_dir = f"assets/datasets/DIODE-{img_size}/{SPLIT}"
85 | os.makedirs(target_dir, exist_ok=True)
86 | Parallel(n_jobs=1)(delayed(process)(name) for name in tqdm(all_files))
87 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch
11 | import torchvision.utils as vutils
12 | import torch.distributed as dist
13 |
14 | from ddbm import dist_util, logger
15 | from ddbm.script_util import (
16 | model_and_diffusion_defaults,
17 | create_model_and_diffusion,
18 | add_dict_to_argparser,
19 | args_to_dict,
20 | )
21 | from ddbm.karras_diffusion import karras_sample
22 |
23 | from datasets import load_data
24 |
25 | from pathlib import Path
26 |
27 |
28 | def main():
29 | args = create_argparser().parse_args()
30 | args.use_fp16 = False
31 |
32 | workdir = os.path.join("workdir", os.path.basename(args.model_path)[:-3])
33 |
34 | ## assume ema ckpt format: ema_{rate}_{steps}.pt
35 | split = args.model_path.replace("_adapted", "").split("_")
36 | step = int(split[-1].split(".")[0])
37 | if args.sampler == "dbim":
38 | sample_dir = Path(workdir) / f"sample_{step}/split={args.split}/dbim_eta={args.eta}/steps={args.steps}"
39 | elif args.sampler == "dbim_high_order":
40 | sample_dir = Path(workdir) / f"sample_{step}/split={args.split}/dbim_order={args.order}/steps={args.steps}"
41 | else:
42 | sample_dir = Path(workdir) / f"sample_{step}/split={args.split}/{args.sampler}/steps={args.steps}"
43 | dist_util.setup_dist()
44 | if dist.get_rank() == 0:
45 |
46 | sample_dir.mkdir(parents=True, exist_ok=True)
47 | logger.configure(dir=str(sample_dir))
48 |
49 | logger.log("creating model and diffusion...")
50 | model, diffusion = create_model_and_diffusion(
51 | **args_to_dict(args, model_and_diffusion_defaults().keys()),
52 | )
53 | model.load_state_dict(torch.load(args.model_path, map_location="cpu"))
54 | model = model.to(dist_util.dev())
55 |
56 | if args.use_fp16:
57 | model = model.half()
58 | model.eval()
59 |
60 | logger.log("sampling...")
61 |
62 | all_images = []
63 | all_labels = []
64 |
65 | all_dataloaders = load_data(
66 | data_dir=args.data_dir,
67 | dataset=args.dataset,
68 | batch_size=args.batch_size,
69 | image_size=args.image_size,
70 | include_test=(args.split == "test"),
71 | seed=args.seed,
72 | num_workers=args.num_workers,
73 | )
74 | if args.split == "train":
75 | dataloader = all_dataloaders[1]
76 | elif args.split == "test":
77 | dataloader = all_dataloaders[2]
78 | else:
79 | raise NotImplementedError
80 | args.num_samples = len(dataloader.dataset)
81 | num = 0
82 | for i, data in enumerate(dataloader):
83 |
84 | x0_image = data[0]
85 | x0 = x0_image.to(dist_util.dev())
86 |
87 | y0_image = data[1].to(dist_util.dev())
88 | y0 = y0_image
89 |
90 | model_kwargs = {"xT": y0}
91 |
92 | if "inpaint" in args.dataset:
93 | _, mask, label = data[2]
94 | mask = mask.to(dist_util.dev())
95 | label = label.to(dist_util.dev())
96 | model_kwargs["y"] = label
97 | else:
98 | mask = None
99 |
100 | indexes = data[2][0].numpy()
101 | sample, path, nfe, pred_x0, sigmas, _ = karras_sample(
102 | diffusion,
103 | model,
104 | y0,
105 | x0,
106 | steps=args.steps,
107 | mask=mask,
108 | model_kwargs=model_kwargs,
109 | device=dist_util.dev(),
110 | clip_denoised=args.clip_denoised,
111 | sampler=args.sampler,
112 | churn_step_ratio=args.churn_step_ratio,
113 | eta=args.eta,
114 | order=args.order,
115 | seed=indexes + args.seed,
116 | )
117 |
118 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
119 | sample = sample.permute(0, 2, 3, 1)
120 | sample = sample.contiguous()
121 |
122 | gathered_samples = [torch.zeros_like(sample) for _ in range(dist.get_world_size())]
123 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
124 | gathered_samples = torch.cat(gathered_samples)
125 | if "inpaint" in args.dataset:
126 | gathered_labels = [torch.zeros_like(label) for _ in range(dist.get_world_size())]
127 | dist.all_gather(gathered_labels, label)
128 | gathered_labels = torch.cat(gathered_labels)
129 | num += gathered_samples.shape[0]
130 |
131 | num_display = min(32, sample.shape[0])
132 | if i == 0 and dist.get_rank() == 0:
133 | vutils.save_image(
134 | sample.permute(0, 3, 1, 2)[:num_display].float() / 255,
135 | f"{sample_dir}/sample_{i}.png",
136 | nrow=int(np.sqrt(num_display)),
137 | )
138 | if x0 is not None:
139 | vutils.save_image(
140 | x0_image[:num_display] / 2 + 0.5,
141 | f"{sample_dir}/x_{i}.png",
142 | nrow=int(np.sqrt(num_display)),
143 | )
144 | vutils.save_image(
145 | y0_image[:num_display] / 2 + 0.5,
146 | f"{sample_dir}/y_{i}.png",
147 | nrow=int(np.sqrt(num_display)),
148 | )
149 |
150 | all_images.append(gathered_samples.detach().cpu().numpy())
151 | if "inpaint" in args.dataset:
152 | all_labels.append(gathered_labels.detach().cpu().numpy())
153 |
154 | if dist.get_rank() == 0:
155 | logger.log(f"sampled {num} images")
156 |
157 | logger.log(f"created {len(all_images) * args.batch_size * dist.get_world_size()} samples")
158 |
159 | arr = np.concatenate(all_images, axis=0)
160 | arr = arr[: args.num_samples]
161 | if "inpaint" in args.dataset:
162 | labels = np.concatenate(all_labels, axis=0)
163 | labels = labels[: args.num_samples]
164 |
165 | if dist.get_rank() == 0:
166 | shape_str = "x".join([str(x) for x in arr.shape])
167 | out_path = os.path.join(sample_dir, f"samples_{shape_str}_nfe{nfe}.npz")
168 | logger.log(f"saving to {out_path}")
169 | np.savez(out_path, arr)
170 | if "inpaint" in args.dataset:
171 | shape_str = "x".join([str(x) for x in labels.shape])
172 | out_path = os.path.join(sample_dir, f"labels_{shape_str}_nfe{nfe}.npz")
173 | logger.log(f"saving to {out_path}")
174 | np.savez(out_path, labels)
175 |
176 | dist.barrier()
177 | logger.log("sampling complete")
178 |
179 |
180 | def create_argparser():
181 | defaults = dict(
182 | data_dir="", ## only used in bridge
183 | dataset="edges2handbags",
184 | clip_denoised=True,
185 | num_samples=10000,
186 | batch_size=16,
187 | sampler="heun",
188 | split="train",
189 | churn_step_ratio=0.0,
190 | rho=7.0,
191 | steps=40,
192 | model_path="",
193 | exp="",
194 | seed=42,
195 | num_workers=8,
196 | eta=1.0,
197 | order=1,
198 | )
199 | defaults.update(model_and_diffusion_defaults())
200 | parser = argparse.ArgumentParser()
201 | add_dict_to_argparser(parser, defaults)
202 | return parser
203 |
204 |
205 | if __name__ == "__main__":
206 | main()
207 |
--------------------------------------------------------------------------------
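`sample.py` gathers the generated images across ranks and stores them with `np.savez` as a single unnamed array, so they come back under the default key `arr_0` as `uint8` NHWC data in `[0, 255]`. A minimal loading sketch (the path is one of the example outputs listed in `scripts/evaluate.sh`):

```python
# Minimal sketch, not part of the repository; the path below is an example output location.
import numpy as np

path = "workdir/e2h_ema_0.9999_420000_adapted/sample_420000/split=train/dbim_eta=0.0/steps=4/samples_138567x64x64x3_nfe5.npz"
samples = np.load(path)["arr_0"]
print(samples.shape, samples.dtype)  # (138567, 64, 64, 3) uint8
```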
/scripts/args.sh:
--------------------------------------------------------------------------------
1 | DATASET_NAME=$1
2 |
3 | if [[ $DATASET_NAME == "e2h" ]]; then
4 | DATA_DIR=assets/datasets/edges2handbags
5 | DATASET=edges2handbags
6 | IMG_SIZE=64
7 |
8 | NUM_CH=192
9 | NUM_RES_BLOCKS=3
10 | ATTN_TYPE=True
11 |
12 | EXP="e2h${IMG_SIZE}_${NUM_CH}d"
13 | SAVE_ITER=100000
14 | MICRO_BS=64
15 | DROPOUT=0.1
16 | CLASS_COND=False
17 |
18 | PRED="vp"
19 | elif [[ $DATASET_NAME == "diode" ]]; then
20 | DATA_DIR=assets/datasets/DIODE-256
21 | DATASET=diode
22 | IMG_SIZE=256
23 |
24 | NUM_CH=256
25 | NUM_RES_BLOCKS=2
26 | ATTN_TYPE=True
27 |
28 | EXP="diode${IMG_SIZE}_${NUM_CH}d"
29 | SAVE_ITER=20000
30 | MICRO_BS=16
31 | DROPOUT=0.1
32 | CLASS_COND=False
33 |
34 | PRED="vp"
35 | elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then
36 | DATA_DIR=assets/datasets/ImageNet
37 | DATASET=imagenet_inpaint_center
38 | IMG_SIZE=256
39 |
40 | NUM_CH=256
41 | NUM_RES_BLOCKS=2
42 | ATTN_TYPE=False
43 |
44 | EXP="imagenet_inpaint_center${IMG_SIZE}_${NUM_CH}d"
45 | SAVE_ITER=20000
46 | MICRO_BS=16
47 | DROPOUT=0
48 | CLASS_COND=True
49 |
50 | PRED="i2sb_cond"
51 | fi
52 |
53 | if [[ $PRED == "ve" ]]; then
54 | EXP+="_ve"
55 | COND=concat
56 | SIGMA_MAX=80.0
57 | SIGMA_MIN=0.002
58 | elif [[ $PRED == "vp" ]]; then
59 | EXP+="_vp"
60 | COND=concat
61 | BETA_D=2
62 | BETA_MIN=0.1
63 | SIGMA_MAX=1
64 | SIGMA_MIN=0.0001
65 | elif [[ $PRED == "i2sb_cond" ]]; then
66 | EXP+="_i2sb_cond"
67 | COND=concat
68 | BETA_MAX=1.0
69 | BETA_MIN=0.1
70 | SIGMA_MAX=1
71 | SIGMA_MIN=0.0001
72 | else
73 | echo "Not supported"
74 | exit 1
75 | fi
--------------------------------------------------------------------------------
/scripts/evaluate.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=$PYTHONPATH:./
2 |
3 | DATASET_NAME=$1
4 | NFE=$2
5 | GEN_SAMPLER=$3
6 |
7 | if [[ $DATASET_NAME == "e2h" ]]; then
8 | SPLIT=train
9 | PREFIX=e2h_ema_0.9999_420000_adapted/sample_420000
10 | REF_PATH=assets/stats/edges2handbags_ref_64_data.npz
11 | SAMPLE_NAME=samples_138567x64x64x3_nfe${NFE}.npz
12 | elif [[ $DATASET_NAME == "diode" ]]; then
13 | SPLIT=train
14 | PREFIX=diode_ema_0.9999_440000_adapted/sample_440000
15 | REF_PATH=assets/stats/diode_ref_256_data.npz
16 | SAMPLE_NAME=samples_16502x256x256x3_nfe${NFE}.npz
17 | elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then
18 | SPLIT=test
19 | PREFIX=imagenet256_inpaint_ema_0.9999_400000/sample_400000
20 | DATA_DIR=assets/datasets/ImageNet
21 | SAMPLE_NAME=samples_10000x256x256x3_nfe${NFE}.npz
22 | LABEL_NAME=labels_10000_nfe${NFE}.npz
23 | fi
24 |
25 | if [[ $GEN_SAMPLER == "heun" ]]; then
26 | N=$(echo "$NFE" | awk '{print ($1 + 1) / 3}')
27 | N=$(printf "%.0f" "$N")
28 | SAMPLER="heun"
29 | elif [[ $GEN_SAMPLER == "dbim" ]]; then
30 | N=$((NFE-1))
31 | ETA=$4
32 | SAMPLER="dbim_eta=${ETA}"
33 | elif [[ $GEN_SAMPLER == "dbim_high_order" ]]; then
34 | N=$((NFE-1))
35 | ORDER=$4
36 | SAMPLER="dbim_order=${ORDER}"
37 | fi
38 |
39 | # For example:
40 | # SAMPLE_PATH="workdir/e2h_ema_0.9999_420000_adapted/sample_420000/split=train/dbim_eta=0.0/steps=4/samples_138567x64x64x3_nfe5.npz"
41 | # SAMPLE_PATH="workdir/diode_ema_0.9999_440000_adapted/sample_440000/split=train/dbim_eta=0.0/steps=4/samples_16502x256x256x3_nfe5.npz"
42 | # SAMPLE_PATH="workdir/imagenet256_inpaint_ema_0.9999_400000/sample_400000/split=test/dbim_order=3/steps=9/samples_10000x256x256x3_nfe10.npz"
43 |
44 | SAMPLE_DIR=workdir/${PREFIX}/split=${SPLIT}/${SAMPLER}/steps=${N}
45 | SAMPLE_PATH=${SAMPLE_DIR}/${SAMPLE_NAME}
46 |
47 | if [[ $DATASET_NAME == "e2h" || $DATASET_NAME == "diode" ]]; then
48 | python evaluations/evaluator.py $REF_PATH $SAMPLE_PATH --metric fid
49 | python evaluations/evaluator.py $REF_PATH $SAMPLE_PATH --metric lpips
50 | elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then
51 | LABEL_PATH=${SAMPLE_DIR}/${LABEL_NAME}
52 | python evaluation/compute_metrices_imagenet.py --ckpt $SAMPLE_PATH --label $LABEL_PATH --dataset-dir $DATA_DIR
53 | python evaluations/evaluator.py "" $SAMPLE_PATH --metric is
54 | fi
--------------------------------------------------------------------------------
/scripts/sample.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=$PYTHONPATH:./
2 |
3 | # For cluster
4 | # export ADDR=$1
5 | # run_args="--nproc_per_node 8 \
6 | # --master_addr $ADDR \
7 | # --node_rank $RANK \
8 | # --master_port $MASTER_PORT \
9 | # --nnodes $WORLD_SIZE"
10 | # For local
11 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
12 | run_args="--nproc_per_node 8 \
13 | --master_port 29511"
14 |
15 | # Batch size per GPU
16 | BS=16
17 |
18 | # Dataset and checkpoint
19 | DATASET_NAME=$1
20 |
21 | if [[ $DATASET_NAME == "e2h" ]]; then
22 | SPLIT=train
23 | MODEL_PATH=assets/ckpts/e2h_ema_0.9999_420000_adapted.pt
24 | elif [[ $DATASET_NAME == "diode" ]]; then
25 | SPLIT=train
26 | MODEL_PATH=assets/ckpts/diode_ema_0.9999_440000_adapted.pt
27 | elif [[ $DATASET_NAME == "imagenet_inpaint_center" ]]; then
28 | SPLIT=test
29 | MODEL_PATH=assets/ckpts/imagenet256_inpaint_ema_0.9999_400000.pt
30 | fi
31 |
32 | source scripts/args.sh $DATASET_NAME
33 |
34 | # Number of function evaluations (NFE)
35 | NFE=$2
36 |
37 | # Sampler
38 | GEN_SAMPLER=$3
39 |
40 | if [[ $GEN_SAMPLER == "heun" ]]; then
41 | N=$(echo "$NFE" | awk '{print ($1 + 1) / 3}')
42 | N=$(printf "%.0f" "$N")
43 | # Default setting in the DDBM paper
44 | CHURN_STEP_RATIO=0.33
45 | elif [[ $GEN_SAMPLER == "dbim" ]]; then
46 | N=$((NFE-1))
47 | ETA=$4
48 | elif [[ $GEN_SAMPLER == "dbim_high_order" ]]; then
49 | N=$((NFE-1))
50 | ORDER=$4
51 | fi
52 |
53 | torchrun $run_args sample.py --steps $N --sampler $GEN_SAMPLER --batch_size $BS \
54 | --model_path $MODEL_PATH --class_cond $CLASS_COND --noise_schedule $PRED \
55 | ${BETA_D:+ --beta_d="${BETA_D}"} ${BETA_MIN:+ --beta_min="${BETA_MIN}"} ${BETA_MAX:+ --beta_max="${BETA_MAX}"} \
56 | --condition_mode=$COND --sigma_max=$SIGMA_MAX --sigma_min=$SIGMA_MIN \
57 | --dropout $DROPOUT --image_size $IMG_SIZE --num_channels $NUM_CH --num_res_blocks $NUM_RES_BLOCKS \
58 |     --use_new_attention_order $ATTN_TYPE --data_dir=$DATA_DIR --dataset=$DATASET --split $SPLIT \
59 | ${CHURN_STEP_RATIO:+ --churn_step_ratio="${CHURN_STEP_RATIO}"} \
60 | ${ETA:+ --eta="${ETA}"} \
61 | ${ORDER:+ --order="${ORDER}"}
62 |
--------------------------------------------------------------------------------
/scripts/train_bridge.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=$PYTHONPATH:./
2 |
3 | DATASET_NAME=imagenet_inpaint_center
4 | TRAIN_MODE=ddbm
5 |
6 | source scripts/args.sh $DATASET_NAME
7 |
8 | FREQ_SAVE_ITER=5000
9 | EXP=${DATASET_NAME}-${TRAIN_MODE}
10 |
11 | CKPT=assets/ckpts/256x256_diffusion_fixedsigma.pt
12 |
13 | # For cluster
14 | # export ADDR=$1
15 | # run_args="--nproc_per_node 8 \
16 | # --master_addr $ADDR \
17 | # --node_rank $RANK \
18 | # --master_port $MASTER_PORT \
19 | # --nnodes $WORLD_SIZE"
20 | # For local
21 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
22 | run_args="--nproc_per_node 8 \
23 | --master_port 29511"
24 |
25 | MICRO_BS=4
26 |
27 | torchrun $run_args train.py --exp=$EXP \
28 | --class_cond $CLASS_COND \
29 | --dropout $DROPOUT --microbatch $MICRO_BS \
30 | --image_size $IMG_SIZE --num_channels $NUM_CH \
31 | --num_res_blocks $NUM_RES_BLOCKS --condition_mode=$COND \
32 | --noise_schedule=$PRED \
33 | --use_new_attention_order $ATTN_TYPE \
34 | ${BETA_D:+ --beta_d="${BETA_D}"} ${BETA_MIN:+ --beta_min="${BETA_MIN}"} ${BETA_MAX:+ --beta_max="${BETA_MAX}"} \
35 | --data_dir=$DATA_DIR --dataset=$DATASET \
36 | --sigma_max=$SIGMA_MAX --sigma_min=$SIGMA_MIN \
37 | --save_interval_for_preemption=$FREQ_SAVE_ITER --save_interval=$SAVE_ITER --debug=False \
38 | ${CKPT:+ --resume_checkpoint="${CKPT}"}
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 |
5 | import argparse
6 |
7 | from ddbm import dist_util, logger
8 | from datasets import load_data
9 | from ddbm.resample import create_named_schedule_sampler
10 | from ddbm.script_util import (
11 | model_and_diffusion_defaults,
12 | create_model_and_diffusion,
13 | sample_defaults,
14 | args_to_dict,
15 | add_dict_to_argparser,
16 | get_workdir,
17 | )
18 | from ddbm.train_util import TrainLoop
19 |
20 | import torch.distributed as dist
21 |
22 | from pathlib import Path
23 |
24 | import wandb
25 |
26 | from glob import glob
27 | import os
28 | from datasets.augment import AugmentPipe
29 |
30 |
31 | def main(args):
32 |
33 | workdir = get_workdir(args.exp)
34 | Path(workdir).mkdir(parents=True, exist_ok=True)
35 |
36 | dist_util.setup_dist()
37 | logger.configure(dir=workdir)
38 | if dist.get_rank() == 0:
39 | name = args.exp if args.resume_checkpoint == "" else args.exp + "_resume"
40 | wandb.init(
41 | project="bridge",
42 | group=args.exp,
43 | name=name,
44 | config=vars(args),
45 | mode="offline" if not args.debug else "disabled",
46 | )
47 | logger.log("creating model and diffusion...")
48 |
49 | data_image_size = args.image_size
50 |
51 | # Load target model
52 | resume_train_flag = False
53 | if args.resume_checkpoint == "":
54 | model_ckpts = list(glob(f"{workdir}/*model*[0-9].*"))
55 | if len(model_ckpts) > 0:
56 | max_ckpt = max(model_ckpts, key=lambda x: int(x.split("model_")[-1].split(".")[0]))
57 | if os.path.exists(max_ckpt):
58 | args.resume_checkpoint = max_ckpt
59 | resume_train_flag = True
60 | elif args.pretrained_ckpt is not None:
61 | max_ckpt = args.pretrained_ckpt
62 | args.resume_checkpoint = max_ckpt
63 | if dist.get_rank() == 0 and args.resume_checkpoint != "":
64 |         logger.log("Resuming from checkpoint: ", args.resume_checkpoint)
65 |
66 | model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys()))
67 | model.to(dist_util.dev())
68 |
69 | if dist.get_rank() == 0:
70 | wandb.watch(model, log="all")
71 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
72 |
73 | if args.batch_size == -1:
74 | batch_size = args.global_batch_size // dist.get_world_size()
75 | if args.global_batch_size % dist.get_world_size() != 0:
76 | logger.log(f"warning, using smaller global_batch_size of {dist.get_world_size()*batch_size} instead of {args.global_batch_size}")
77 | else:
78 | batch_size = args.batch_size
79 |
80 | if dist.get_rank() == 0:
81 | logger.log("creating data loader...")
82 |
83 | data, test_data = load_data(
84 | data_dir=args.data_dir,
85 | dataset=args.dataset,
86 | batch_size=batch_size,
87 | image_size=data_image_size,
88 | num_workers=args.num_workers,
89 | )
90 |
91 | if args.use_augment:
92 | augment = AugmentPipe(p=0.12, xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
93 | else:
94 | augment = None
95 |
96 | logger.log("training...")
97 | TrainLoop(
98 | model=model,
99 | diffusion=diffusion,
100 | train_data=data,
101 | test_data=test_data,
102 | batch_size=batch_size,
103 | microbatch=-1 if args.microbatch >= batch_size else args.microbatch,
104 | lr=args.lr,
105 | ema_rate=args.ema_rate,
106 | log_interval=args.log_interval,
107 | test_interval=args.test_interval,
108 | save_interval=args.save_interval,
109 | save_interval_for_preemption=args.save_interval_for_preemption,
110 | resume_checkpoint=args.resume_checkpoint,
111 | workdir=workdir,
112 | use_fp16=args.use_fp16,
113 | fp16_scale_growth=args.fp16_scale_growth,
114 | schedule_sampler=schedule_sampler,
115 | weight_decay=args.weight_decay,
116 | lr_anneal_steps=args.lr_anneal_steps,
117 | augment_pipe=augment,
118 | train_mode=args.train_mode,
119 | resume_train_flag=resume_train_flag,
120 | **sample_defaults(),
121 | ).run_loop()
122 |
123 |
124 | def create_argparser():
125 | defaults = dict(
126 | data_dir="",
127 | dataset="edges2handbags",
128 | schedule_sampler="real-uniform",
129 | lr=1e-4,
130 | weight_decay=0.0,
131 | lr_anneal_steps=0,
132 | global_batch_size=256,
133 | batch_size=-1,
134 | microbatch=-1, # -1 disables microbatches
135 | ema_rate="0.9999", # comma-separated list of EMA values
136 | log_interval=50,
137 | test_interval=500,
138 | save_interval=10000,
139 | save_interval_for_preemption=50000,
140 | resume_checkpoint="",
141 | exp="",
142 | use_fp16=True,
143 | fp16_scale_growth=1e-3,
144 | debug=False,
145 | num_workers=8,
146 | use_augment=False,
147 | pretrained_ckpt=None,
148 | train_mode="ddbm",
149 | )
150 | defaults.update(model_and_diffusion_defaults())
151 | parser = argparse.ArgumentParser()
152 | add_dict_to_argparser(parser, defaults)
153 | return parser
154 |
155 |
156 | if __name__ == "__main__":
157 | args = create_argparser().parse_args()
158 | main(args)
159 |
--------------------------------------------------------------------------------