├── torch_utils ├── __init__.py ├── distributed.py ├── persistence.py ├── training_stats.py └── misc.py ├── training ├── __init__.py ├── loss.py ├── dataset.py ├── training_loop.py ├── augment.py └── networks.py ├── dnnlib ├── __init__.py └── util.py ├── Dockerfile ├── fid.py ├── README.md ├── generate.py ├── LICENSE ├── train.py └── dataset_tool.py /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | FROM nvcr.io/nvidia/pytorch:22.08-py3 16 | ENV TZ=Asia/Singapore 17 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 18 | 19 | ENV PYTHONDONTWRITEBYTECODE 1 20 | ENV PYTHONUNBUFFERED 1 21 | 22 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 23 | 24 | WORKDIR /workspace 25 | 26 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 27 | ENTRYPOINT ["/entry.sh"] 28 | -------------------------------------------------------------------------------- /torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Loss functions used in the paper "Fast Diffusion Model".""" 16 | 17 | import torch 18 | from torch_utils import persistence 19 | from torch_utils import distributed as dist 20 | import numpy as np 21 | 22 | #---------------------------------------------------------------------------- 23 | # VP loss function with FDM loss weight warmup strategy. 
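# How the FDM warmup below works (shared by all three losses in this file):
# the per-sample loss weight is clamped to a ceiling clamp_cur that starts at
# 5 and is multiplied by warmup_step = 100 ** (1 / warmup_ite) on every call,
# so the ceiling grows geometrically and reaches clamp_max = 5 * 100 = 500
# after roughly warmup_ite calls, after which the weights are no longer
# clamped. With warmup_ite=None the clamp is disabled and the losses reduce
# to their usual unclamped formulations.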
24 | 25 | @persistence.persistent_class 26 | class VPLoss: 27 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5, warmup_ite=None): 28 | self.beta_d = beta_d 29 | self.beta_min = beta_min 30 | self.epsilon_t = epsilon_t 31 | self.warmup_ite = warmup_ite 32 | self.clamp_cur = 5. 33 | self.clamp_max = 500. 34 | if self.warmup_ite: 35 | self.warmup_step = np.exp(np.log(100) / self.warmup_ite) 36 | 37 | def __call__(self, net, images, labels, augment_pipe=None): 38 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 39 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 40 | weight = 1 / sigma ** 2 41 | if self.warmup_ite: 42 | if self.clamp_cur < self.clamp_max: 43 | weight.clamp_max_(self.clamp_cur) 44 | self.clamp_cur *= self.warmup_step 45 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 46 | n = torch.randn_like(y) * sigma 47 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 48 | loss = weight * ((D_yn - y) ** 2) 49 | return loss 50 | 51 | def sigma(self, t): 52 | t = torch.as_tensor(t) 53 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 54 | 55 | #---------------------------------------------------------------------------- 56 | # VE loss function with FDM loss weight warmup strategy. 57 | 58 | @persistence.persistent_class 59 | class VELoss: 60 | def __init__(self, sigma_min=0.02, sigma_max=100, warmup_ite=None): 61 | self.sigma_min = sigma_min 62 | self.sigma_max = sigma_max 63 | self.warmup_ite = warmup_ite 64 | self.clamp_cur = 5. 65 | self.clamp_max = 500. 66 | if self.warmup_ite: 67 | self.warmup_step = np.exp(np.log(100) / self.warmup_ite) 68 | 69 | def __call__(self, net, images, labels, augment_pipe=None): 70 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 71 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 72 | weight = 1 / sigma ** 2 73 | if self.warmup_ite: 74 | if self.clamp_cur < self.clamp_max: 75 | weight.clamp_max_(self.clamp_cur) 76 | self.clamp_cur *= self.warmup_step 77 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 78 | n = torch.randn_like(y) * sigma 79 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 80 | loss = weight * ((D_yn - y) ** 2) 81 | return loss 82 | 83 | #---------------------------------------------------------------------------- 84 | # EDM loss function with our FDM weight warmup strategy. 85 | 86 | @persistence.persistent_class 87 | class EDMLoss: 88 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, warmup_ite=None): 89 | self.P_mean = P_mean 90 | self.P_std = P_std 91 | self.sigma_data = sigma_data 92 | self.warmup_ite = warmup_ite 93 | self.clamp_cur = 5. 94 | self.clamp_max = 500. 
95 | if self.warmup_ite: 96 | self.warmup_step = np.exp(np.log(100) / self.warmup_ite) 97 | 98 | def __call__(self, net, images, labels=None, augment_pipe=None): 99 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 100 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 101 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 102 | if self.warmup_ite: 103 | if self.clamp_cur < self.clamp_max: 104 | weight.clamp_max_(self.clamp_cur) 105 | self.clamp_cur *= self.warmup_step 106 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 107 | n = torch.randn_like(y) * sigma 108 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 109 | loss = weight * ((D_yn - y) ** 2) 110 | return loss 111 | -------------------------------------------------------------------------------- /fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Script for calculating Frechet Inception Distance (FID).""" 9 | 10 | import os 11 | import click 12 | import tqdm 13 | import pickle 14 | import numpy as np 15 | import scipy.linalg 16 | import torch 17 | import dnnlib 18 | from torch_utils import distributed as dist 19 | from training import dataset 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def calculate_inception_stats( 24 | image_path, num_expected=None, seed=0, max_batch_size=64, 25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'), 26 | ): 27 | # Rank 0 goes first. 28 | if dist.get_rank() != 0: 29 | torch.distributed.barrier() 30 | 31 | # Load Inception-v3 model. 32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 33 | dist.print0('Loading Inception-v3 model...') 34 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 35 | detector_kwargs = dict(return_features=True) 36 | feature_dim = 2048 37 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: 38 | detector_net = pickle.load(f).to(device) 39 | 40 | # List images. 41 | dist.print0(f'Loading images from "{image_path}"...') 42 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed) 43 | if num_expected is not None and len(dataset_obj) < num_expected: 44 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}') 45 | if len(dataset_obj) < 2: 46 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics') 47 | 48 | # Other ranks follow. 49 | if dist.get_rank() == 0: 50 | torch.distributed.barrier() 51 | 52 | # Divide images into batches. 
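    # Every rank must iterate over the same number of batches so that the
    # barrier inside the loop below stays in sync: the batch count is rounded
    # up to a multiple of the world size, the image indices are split into
    # that many (possibly empty) batches, and each rank takes a strided slice.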
53 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 54 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) 55 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 56 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor) 57 | 58 | # Accumulate statistics. 59 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...') 60 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device) 61 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device) 62 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)): 63 | torch.distributed.barrier() 64 | if images.shape[0] == 0: 65 | continue 66 | if images.shape[1] == 1: 67 | images = images.repeat([1, 3, 1, 1]) 68 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64) 69 | mu += features.sum(0) 70 | sigma += features.T @ features 71 | 72 | # Calculate grand totals. 73 | torch.distributed.all_reduce(mu) 74 | torch.distributed.all_reduce(sigma) 75 | mu /= len(dataset_obj) 76 | sigma -= mu.ger(mu) * len(dataset_obj) 77 | sigma /= len(dataset_obj) - 1 78 | return mu.cpu().numpy(), sigma.cpu().numpy() 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 83 | m = np.square(mu - mu_ref).sum() 84 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 85 | fid = m + np.trace(sigma + sigma_ref - s * 2) 86 | return float(np.real(fid)) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @click.group() 91 | def main(): 92 | """Calculate Frechet Inception Distance (FID). 
93 | 94 | Examples: 95 | 96 | \b 97 | # Generate 50000 images and save them as fid-tmp/*/*.png 98 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\ 99 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 100 | 101 | \b 102 | # Calculate FID 103 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\ 104 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 105 | 106 | \b 107 | # Compute dataset reference statistics 108 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 109 | """ 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | @main.command() 114 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True) 115 | @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True) 116 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True) 117 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True) 118 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 119 | 120 | def calc(image_path, ref_path, num_expected, seed, batch): 121 | """Calculate FID for a given set of images.""" 122 | torch.multiprocessing.set_start_method('spawn') 123 | dist.init() 124 | 125 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...') 126 | ref = None 127 | if dist.get_rank() == 0: 128 | with dnnlib.util.open_url(ref_path) as f: 129 | ref = dict(np.load(f)) 130 | 131 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch) 132 | dist.print0('Calculating FID...') 133 | if dist.get_rank() == 0: 134 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma']) 135 | print(f'{fid:g}') 136 | torch.distributed.barrier() 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | @main.command() 141 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True) 142 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True) 143 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 144 | 145 | def ref(dataset_path, dest_path, batch): 146 | """Calculate dataset reference statistics needed by 'calc'.""" 147 | torch.multiprocessing.set_start_method('spawn') 148 | dist.init() 149 | 150 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch) 151 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...') 152 | if dist.get_rank() == 0: 153 | if os.path.dirname(dest_path): 154 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 155 | np.savez(dest_path, mu=mu, sigma=sigma) 156 | 157 | torch.distributed.barrier() 158 | dist.print0('Done.') 159 | 160 | #---------------------------------------------------------------------------- 161 | 162 | if __name__ == "__main__": 163 | main() 164 | 165 | #---------------------------------------------------------------------------- 166 | 
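For reference, `calculate_fid_from_inception_stats()` implements the closed-form Fréchet distance between two Gaussians, ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 (sigma sigma_ref)^(1/2)). A minimal sketch of reusing it outside the click CLI is shown below, assuming two statistics files previously produced by the `ref` command; both paths are placeholders.

```python
# Sketch: compare two precomputed Inception statistics files directly.
# Both paths are placeholders for .npz files created by `python fid.py ref`,
# each containing 'mu' (2048,) and 'sigma' (2048, 2048) arrays.
import numpy as np
from fid import calculate_fid_from_inception_stats

gen_stats = np.load('fid-refs/my-generated-set.npz')   # hypothetical generated-set statistics
ref_stats = np.load('fid-refs/cifar10-32x32.npz')      # dataset reference statistics
fid = calculate_fid_from_inception_stats(
    gen_stats['mu'], gen_stats['sigma'], ref_stats['mu'], ref_stats['sigma'])
print(f'FID = {fid:g}')
```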
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Fast Diffusion Model
2 | 
3 | This is an official PyTorch implementation of the Fast Diffusion Model (FDM). See the paper [here](https://arxiv.org/abs/2306.06991). If you find our FDM helpful or inspiring for your projects, please cite this paper and star this repository. Thanks!
4 | 
5 | ```bibtex
6 | @misc{wu2023fast,
7 |   title={Fast Diffusion Model},
8 |   author={Zike Wu and Pan Zhou and Kenji Kawaguchi and Hanwang Zhang},
9 |   year={2023},
10 |   eprint={2306.06991},
11 |   archivePrefix={arXiv},
12 |   primaryClass={cs.CV}
13 | }
14 | ```
15 | 
16 | Acknowledgement: This repo is based on the following amazing projects: [EDM](https://github.com/NVlabs/edm) and [DPM-Solver](https://github.com/LuChengTHU/dpm-solver).
17 | ## Results
18 | Image synthesis performance (FID) at different training durations, measured in millions of training images (Mimg), is as follows.
19 | | Dataset | Duration (Mimg) | EDM | EDM-FDM | VP | VP-FDM | VE | VE-FDM |
20 | |:--------:|:---------------:|:----:|:-------:|:----:|:------:|:-----:|:------:|
21 | | CIFAR10 | 50 | 5.76 | 2.17 | 2.74 | 2.74 | 49.47 | 10.01 |
22 | | CIFAR10 | 100 | 1.99 | 1.93 | 2.24 | 2.24 | 4.05 | 3.26 |
23 | | CIFAR10 | 150 | 1.92 | 1.83 | 2.19 | 2.13 | 3.27 | 3.00 |
24 | | CIFAR10 | 200 | 1.88 | 1.79 | 2.15 | 2.08 | 3.09 | 2.85 |
25 | | FFHQ | 50 | 3.21 | 3.27 | 3.07 | 12.49 | 96.49 | 93.72 |
26 | | FFHQ | 100 | 2.87 | 2.69 | 2.83 | 2.80 | 94.14 | 88.42 |
27 | | FFHQ | 150 | 2.69 | 2.63 | 2.73 | 2.53 | 79.20 | 4.73 |
28 | | FFHQ | 200 | 2.65 | 2.59 | 2.69 | 2.43 | 38.97 | 3.04 |
29 | | AFHQv2 | 50 | 2.62 | 2.73 | 3.46 | 25.70 | 57.93 | 54.41 |
30 | | AFHQv2 | 100 | 2.57 | 2.05 | 2.81 | 2.65 | 57.87 | 52.45 |
31 | | AFHQv2 | 150 | 2.44 | 1.96 | 2.72 | 2.47 | 57.69 | 50.53 |
32 | | AFHQv2 | 200 | 2.37 | 1.93 | 2.61 | 2.39 | 57.48 | 47.30 |
33 | 
34 | Image synthesis performance (FID) under different inference costs (NFE) on AFHQv2 with the EDM sampler.
35 | | NFE | EDM | EDM-FDM | VP | VP-FDM | VE | VE-FDM |
36 | |:---:|:----:|:-------:|:----:|:------:|:-----:|:------:|
37 | | 25 | 2.78 | 2.32 | 2.88 | 2.59 | 61.04 | 48.29 |
38 | | 49 | 2.39 | 1.93 | 2.64 | 2.41 | 57.59 | 47.49 |
39 | | 79 | 2.37 | 1.93 | 2.61 | 2.39 | 57.48 | 47.30 |
40 | 
41 | Image synthesis performance (FID) under different inference costs (NFE) on AFHQv2 with DPM-Solver++.
42 | | NFE | EDM | EDM-FDM | VP | VP-FDM | VE | VE-FDM |
43 | |:---:|:----:|:-------:|:----:|:------:|:-----:|:------:|
44 | | 25 | 2.60 | 2.09 | 2.99 | 2.64 | 59.26 | 49.51 |
45 | | 49 | 2.42 | 1.98 | 2.79 | 2.45 | 59.16 | 48.68 |
46 | | 79 | 2.39 | 1.95 | 2.78 | 2.42 | 58.91 | 48.66 |
47 | 
48 | ## Requirements
49 | All experiments were conducted using PyTorch 1.13.0, CUDA 11.7.1, and cuDNN 8.5.0. We strongly recommend using the [provided Dockerfile](./Dockerfile) to build an image for reproducing our experiments.
50 | 
51 | ## Pre-trained models
52 | We provide pre-trained models for our FDMs along with the baseline models on Hugging Face. Download the checkpoints [here](https://huggingface.co/sail/FDM/tree/main).
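The checkpoints are standard EDM-style pickles, so they can also be loaded directly from Python. A minimal sketch is shown below; the checkpoint filename is a placeholder, and the repository root (including the DPM-Solver module imported by `generate.py`) is assumed to be on `PYTHONPATH`. See `generate.py` for the full sampling loop with per-seed noise and class-label handling.

```python
# Minimal sketch: load a downloaded FDM checkpoint and draw a few samples
# with the EDM sampler. The .pkl filename below is a placeholder, not an
# actual file name from the Hugging Face repository.
import pickle
import torch
import dnnlib
from generate import edm_sampler

device = torch.device('cuda')
with dnnlib.util.open_url('checkpoints/edm-fdm-cifar10.pkl') as f:  # placeholder path
    net = pickle.load(f)['ema'].to(device)

latents = torch.randn([4, net.img_channels, net.img_resolution, net.img_resolution], device=device)
class_labels = torch.eye(net.label_dim, device=device)[:4] if net.label_dim else None  # first 4 classes
images = edm_sampler(net, latents, class_labels)              # values roughly in [-1, 1]
images = (images * 127.5 + 128).clip(0, 255).to(torch.uint8)  # 8-bit conversion as in generate.py
```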
53 | 
54 | ## Preparing datasets
55 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert it to a ZIP archive:
56 | 
57 | ```.bash
58 | python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
59 |     --dest=datasets/cifar10-32x32.zip
60 | python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz
61 | ```
62 | 
63 | **FFHQ:** Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and convert it to a ZIP archive at 64x64 resolution:
64 | 
65 | ```.bash
66 | python dataset_tool.py --source=downloads/ffhq/images1024x1024 \
67 |     --dest=datasets/ffhq-64x64.zip --resolution=64x64
68 | python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz
69 | ```
70 | 
71 | **AFHQv2:** Download the updated [Animal Faces-HQ dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) (`afhq-v2-dataset`) and convert it to a ZIP archive at 64x64 resolution:
72 | 
73 | ```.bash
74 | python dataset_tool.py --source=downloads/afhqv2 \
75 |     --dest=datasets/afhqv2-64x64.zip --resolution=64x64
76 | python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz
77 | ```
78 | ## Training from scratch
79 | Train FDM for class-conditional CIFAR-10 using 8 GPUs:
80 | ```.bash
81 | # EDM-FDM
82 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
83 |     --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp \
84 |     --precond=fdm_edm --warmup_ite=200
85 | 
86 | # VP-FDM
87 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
88 |     --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp --cres=1,2,2,2 \
89 |     --precond=fdm_vp --warmup_ite=400
90 | 
91 | # VE-FDM
92 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
93 |     --data=datasets/cifar10-32x32.zip --cond=1 --arch=ncsnpp --cres=1,2,2,2 \
94 |     --precond=fdm_ve --warmup_ite=400
95 | ```
96 | 
97 | Train FDM for unconditional FFHQ using 8 GPUs:
98 | ```.bash
99 | # EDM-FDM
100 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
101 |     --data=datasets/ffhq-64x64.zip --cond=0 --arch=ddpmpp \
102 |     --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 \
103 |     --precond=fdm_edm --warmup_ite=800 --fdm_multipler=1
104 | 
105 | # VP-FDM
106 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
107 |     --data=datasets/ffhq-64x64.zip --cond=0 --arch=ddpmpp \
108 |     --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 \
109 |     --precond=fdm_vp --warmup_ite=400 --fdm_multipler=1
110 | 
111 | # VE-FDM
112 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
113 |     --data=datasets/ffhq-64x64.zip --cond=0 --arch=ncsnpp \
114 |     --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 \
115 |     --precond=fdm_ve --warmup_ite=400
116 | ```
117 | 
118 | Train FDM for unconditional AFHQv2 using 8 GPUs:
119 | ```.bash
120 | # EDM-FDM
121 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
122 |     --data=datasets/afhqv2-64x64.zip --cond=0 --arch=ddpmpp \
123 |     --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 \
124 |     --precond=fdm_edm --warmup_ite=400
125 | 
126 | # VP-FDM
127 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
128 |     --data=datasets/afhqv2-64x64.zip --cond=0 --arch=ddpmpp \
129 |     --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 \
130 |     --precond=fdm_vp --warmup_ite=400
131 | 
132 | # VE-FDM
133 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
134 |     --data=datasets/afhqv2-64x64.zip --cond=0 --arch=ncsnpp \
135 |     --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 \
136 |     --precond=fdm_ve --warmup_ite=400
137 | ```
138 | 
139 | ## Calculating FID
140 | To compute the Fréchet inception distance (FID) for a given model and sampler, first generate 50,000 random images and then compare them against the dataset reference statistics using `fid.py`. Replace `$PATH_TO_CHECKPOINT` with the path to the checkpoint:
141 | 
142 | ```.bash
143 | # Generate 50000 images
144 | torchrun --standalone --nproc_per_node=8 generate.py --outdir=fid \
145 |     --seeds=0-49999 --subdirs --network=$PATH_TO_CHECKPOINT
146 | # Calculate FID
147 | torchrun --standalone --nproc_per_node=8 fid.py calc --images=fid \
148 |     --ref=fid-refs/cifar10-32x32.npz
149 | ```
150 | 
151 | Note that the generated images should be evaluated against the same reference dataset that the model was originally trained on. Make sure to replace the `--ref` option with the correct statistics file (*e.g.*, `fid-refs/ffhq-64x64.npz` or `fid-refs/afhqv2-64x64.npz`) to obtain the right FID score. Additionally, you can use the `--solver=dpm` option to generate images with DPM-Solver++.
-------------------------------------------------------------------------------- /training/dataset.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Streaming images and labels from datasets created with dataset_tool.py."""
9 | 
10 | import os
11 | import numpy as np
12 | import zipfile
13 | import PIL.Image
14 | import json
15 | import torch
16 | import dnnlib
17 | 
18 | try:
19 |     import pyspng
20 | except ImportError:
21 |     pyspng = None
22 | 
23 | #----------------------------------------------------------------------------
24 | # Abstract base class for datasets.
25 | 
26 | class Dataset(torch.utils.data.Dataset):
27 |     def __init__(self,
28 |         name,                # Name of the dataset.
29 |         raw_shape,           # Shape of the raw image data (NCHW).
30 |         max_size    = None,  # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
31 |         use_labels  = False, # Enable conditioning labels? False = label dimension is zero.
32 |         xflip       = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
33 |         random_seed = 0,     # Random seed to use when applying max_size.
34 |         cache       = False, # Cache images in CPU memory?
35 |     ):
36 |         self._name = name
37 |         self._raw_shape = list(raw_shape)
38 |         self._use_labels = use_labels
39 |         self._cache = cache
40 |         self._cached_images = dict() # {raw_idx: np.ndarray, ...}
41 |         self._raw_labels = None
42 |         self._label_shape = None
43 | 
44 |         # Apply max_size.
45 |         self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
46 |         if (max_size is not None) and (self._raw_idx.size > max_size):
47 |             np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
48 |             self._raw_idx = np.sort(self._raw_idx[:max_size])
49 | 
50 |         # Apply xflip.
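        # x-flips are applied lazily: the index array is duplicated and the
        # second copy is tagged in self._xflip, so __getitem__ mirrors those
        # images (image[:, :, ::-1]) at load time instead of storing copies.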
51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 52 | if xflip: 53 | self._raw_idx = np.tile(self._raw_idx, 2) 54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 55 | 56 | def _get_raw_labels(self): 57 | if self._raw_labels is None: 58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 59 | if self._raw_labels is None: 60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 61 | assert isinstance(self._raw_labels, np.ndarray) 62 | assert self._raw_labels.shape[0] == self._raw_shape[0] 63 | assert self._raw_labels.dtype in [np.float32, np.int64] 64 | if self._raw_labels.dtype == np.int64: 65 | assert self._raw_labels.ndim == 1 66 | assert np.all(self._raw_labels >= 0) 67 | return self._raw_labels 68 | 69 | def close(self): # to be overridden by subclass 70 | pass 71 | 72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def _load_raw_labels(self): # to be overridden by subclass 76 | raise NotImplementedError 77 | 78 | def __getstate__(self): 79 | return dict(self.__dict__, _raw_labels=None) 80 | 81 | def __del__(self): 82 | try: 83 | self.close() 84 | except: 85 | pass 86 | 87 | def __len__(self): 88 | return self._raw_idx.size 89 | 90 | def __getitem__(self, idx): 91 | raw_idx = self._raw_idx[idx] 92 | image = self._cached_images.get(raw_idx, None) 93 | if image is None: 94 | image = self._load_raw_image(raw_idx) 95 | if self._cache: 96 | self._cached_images[raw_idx] = image 97 | assert isinstance(image, np.ndarray) 98 | assert list(image.shape) == self.image_shape 99 | assert image.dtype == np.uint8 100 | if self._xflip[idx]: 101 | assert image.ndim == 3 # CHW 102 | image = image[:, :, ::-1] 103 | return image.copy(), self.get_label(idx) 104 | 105 | def get_label(self, idx): 106 | label = self._get_raw_labels()[self._raw_idx[idx]] 107 | if label.dtype == np.int64: 108 | onehot = np.zeros(self.label_shape, dtype=np.float32) 109 | onehot[label] = 1 110 | label = onehot 111 | return label.copy() 112 | 113 | def get_details(self, idx): 114 | d = dnnlib.EasyDict() 115 | d.raw_idx = int(self._raw_idx[idx]) 116 | d.xflip = (int(self._xflip[idx]) != 0) 117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 118 | return d 119 | 120 | @property 121 | def name(self): 122 | return self._name 123 | 124 | @property 125 | def image_shape(self): 126 | return list(self._raw_shape[1:]) 127 | 128 | @property 129 | def num_channels(self): 130 | assert len(self.image_shape) == 3 # CHW 131 | return self.image_shape[0] 132 | 133 | @property 134 | def resolution(self): 135 | assert len(self.image_shape) == 3 # CHW 136 | assert self.image_shape[1] == self.image_shape[2] 137 | return self.image_shape[1] 138 | 139 | @property 140 | def label_shape(self): 141 | if self._label_shape is None: 142 | raw_labels = self._get_raw_labels() 143 | if raw_labels.dtype == np.int64: 144 | self._label_shape = [int(np.max(raw_labels)) + 1] 145 | else: 146 | self._label_shape = raw_labels.shape[1:] 147 | return list(self._label_shape) 148 | 149 | @property 150 | def label_dim(self): 151 | assert len(self.label_shape) == 1 152 | return self.label_shape[0] 153 | 154 | @property 155 | def has_labels(self): 156 | return any(x != 0 for x in self.label_shape) 157 | 158 | @property 159 | def has_onehot_labels(self): 160 | return self._get_raw_labels().dtype == np.int64 161 | 162 | #---------------------------------------------------------------------------- 163 | # Dataset subclass that 
loads images recursively from the specified directory 164 | # or ZIP file. 165 | 166 | class ImageFolderDataset(Dataset): 167 | def __init__(self, 168 | path, # Path to directory or zip. 169 | resolution = None, # Ensure specific resolution, None = highest available. 170 | use_pyspng = True, # Use pyspng if available? 171 | **super_kwargs, # Additional arguments for the Dataset base class. 172 | ): 173 | self._path = path 174 | self._use_pyspng = use_pyspng 175 | self._zipfile = None 176 | 177 | if os.path.isdir(self._path): 178 | self._type = 'dir' 179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 180 | elif self._file_ext(self._path) == '.zip': 181 | self._type = 'zip' 182 | self._all_fnames = set(self._get_zipfile().namelist()) 183 | else: 184 | raise IOError('Path must point to a directory or zip') 185 | 186 | PIL.Image.init() 187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 188 | if len(self._image_fnames) == 0: 189 | raise IOError('No image files found in the specified path') 190 | 191 | name = os.path.splitext(os.path.basename(self._path))[0] 192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 194 | raise IOError('Image files do not match the specified resolution') 195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 196 | 197 | @staticmethod 198 | def _file_ext(fname): 199 | return os.path.splitext(fname)[1].lower() 200 | 201 | def _get_zipfile(self): 202 | assert self._type == 'zip' 203 | if self._zipfile is None: 204 | self._zipfile = zipfile.ZipFile(self._path) 205 | return self._zipfile 206 | 207 | def _open_file(self, fname): 208 | if self._type == 'dir': 209 | return open(os.path.join(self._path, fname), 'rb') 210 | if self._type == 'zip': 211 | return self._get_zipfile().open(fname, 'r') 212 | return None 213 | 214 | def close(self): 215 | try: 216 | if self._zipfile is not None: 217 | self._zipfile.close() 218 | finally: 219 | self._zipfile = None 220 | 221 | def __getstate__(self): 222 | return dict(super().__getstate__(), _zipfile=None) 223 | 224 | def _load_raw_image(self, raw_idx): 225 | fname = self._image_fnames[raw_idx] 226 | with self._open_file(fname) as f: 227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 228 | image = pyspng.load(f.read()) 229 | else: 230 | image = np.array(PIL.Image.open(f)) 231 | if image.ndim == 2: 232 | image = image[:, :, np.newaxis] # HW => HWC 233 | image = image.transpose(2, 0, 1) # HWC => CHW 234 | return image 235 | 236 | def _load_raw_labels(self): 237 | fname = 'dataset.json' 238 | if fname not in self._all_fnames: 239 | return None 240 | with self._open_file(fname) as f: 241 | labels = json.load(f)['labels'] 242 | if labels is None: 243 | return None 244 | labels = dict(labels) 245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 246 | labels = np.array(labels) 247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 248 | return labels 249 | 250 | #---------------------------------------------------------------------------- 251 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 
83 | 84 | A typical use case is to first unpickle a previous instance of a 85 | persistent class, and then upgrade it to use the latest version of 86 | the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | record_init_args = getattr(self, '_record_init_args', True) 108 | self._init_args = copy.deepcopy(args) if record_init_args else None 109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 110 | assert orig_class.__name__ in orig_module.__dict__ 111 | _check_pickleable(self.__reduce__()) 112 | 113 | @property 114 | def init_args(self): 115 | assert self._init_args is not None 116 | return copy.deepcopy(self._init_args) 117 | 118 | @property 119 | def init_kwargs(self): 120 | assert self._init_kwargs is not None 121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 122 | 123 | def __reduce__(self): 124 | fields = list(super().__reduce__()) 125 | fields += [None] * max(3 - len(fields), 0) 126 | if fields[0] is not _reconstruct_persistent_obj: 127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 128 | fields[0] = _reconstruct_persistent_obj # reconstruct func 129 | fields[1] = (meta,) # reconstruct args 130 | fields[2] = None # state dict 131 | return tuple(fields) 132 | 133 | Decorator.__name__ = orig_class.__name__ 134 | Decorator.__module__ = orig_class.__module__ 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object. 
170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Generate random images using EDM Sampler and DPM-Solver++.""" 16 | 17 | import os 18 | import re 19 | import click 20 | import tqdm 21 | import pickle 22 | import numpy as np 23 | import torch 24 | import PIL.Image 25 | import dnnlib 26 | from torch_utils import distributed as dist 27 | from dpm_solver import NoiseScheduleEDM, DPM_Solver 28 | 29 | #---------------------------------------------------------------------------- 30 | # EDM sampler 31 | 32 | def edm_sampler( 33 | net, latents, class_labels=None, randn_like=torch.randn_like, 34 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 35 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 36 | ): 37 | # Adjust noise levels based on what's supported by the network. 38 | sigma_min = max(sigma_min, net.sigma_min) 39 | sigma_max = min(sigma_max, net.sigma_max) 40 | 41 | # Time step discretization. 42 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 43 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 44 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 45 | 46 | # Main sampling loop. 47 | x_next = latents.to(torch.float64) * t_steps[0] 48 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 49 | x_cur = x_next 50 | 51 | # Increase noise temporarily. 52 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 53 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 54 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 55 | 56 | # Euler step. 57 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 58 | d_cur = (x_hat - denoised) / t_hat 59 | x_next = x_hat + (t_next - t_hat) * d_cur 60 | 61 | # Apply 2nd order correction. 
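        # Heun correction: re-evaluate the slope at (x_next, t_next) and
        # average it with the Euler slope. This is skipped on the final step,
        # where t_next = 0 and the Euler step already returns the denoised
        # output directly.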
62 | if i < num_steps - 1: 63 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 64 | d_prime = (x_next - denoised) / t_next 65 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 66 | 67 | return x_next 68 | 69 | #---------------------------------------------------------------------------- 70 | # DPM-Solver++ 71 | 72 | def ablation_sampler( 73 | net, latents, class_labels=None, randn_like=torch.randn_like, 74 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 75 | denoise=True, **kwargs 76 | ): 77 | x_next = latents.to(torch.float64) * sigma_max 78 | ns = NoiseScheduleEDM('linear') 79 | with torch.no_grad(): 80 | noise_pred_fn = lambda x, t: (x - net(x, t, class_labels).to(torch.float64)) / t 81 | dpm_solver = DPM_Solver(noise_pred_fn, ns, algorithm_type="dpmsolver++") 82 | # Initial sample 83 | x_next = dpm_solver.sample( 84 | x_next, 85 | steps=num_steps - 1 if denoise else num_steps, 86 | t_start=sigma_max, 87 | t_end=sigma_min, 88 | order=3, 89 | skip_type="logSNR", 90 | method="singlestep", 91 | denoise_to_zero=denoise, 92 | lower_order_final=True, 93 | ) 94 | return x_next 95 | 96 | #---------------------------------------------------------------------------- 97 | # Wrapper for torch.Generator that allows specifying a different random seed 98 | # for each sample in a minibatch. 99 | 100 | class StackedRandomGenerator: 101 | def __init__(self, device, seeds): 102 | super().__init__() 103 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] 104 | 105 | def randn(self, size, **kwargs): 106 | assert size[0] == len(self.generators) 107 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) 108 | 109 | def randn_like(self, input): 110 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) 111 | 112 | def randint(self, *args, size, **kwargs): 113 | assert size[0] == len(self.generators) 114 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) 115 | 116 | #---------------------------------------------------------------------------- 117 | # Parse a comma separated list of numbers or ranges and return a list of ints. 118 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 119 | 120 | def parse_int_list(s): 121 | if isinstance(s, list): return s 122 | ranges = [] 123 | range_re = re.compile(r'^(\d+)-(\d+)$') 124 | for p in s.split(','): 125 | m = range_re.match(p) 126 | if m: 127 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 128 | else: 129 | ranges.append(int(p)) 130 | return ranges 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | @click.command() 135 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True) 136 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True) 137 | @click.option('--seeds', help='Random seeds (e.g. 
1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True) 138 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True) 139 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None) 140 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 141 | 142 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True) 143 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 144 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 145 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True) 146 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 147 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 148 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True) 149 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True) 150 | @click.option('--solver', help='Ablate ODE solver', metavar='edm|dpm', type=click.Choice(['edm', 'dpm']), default='edm') 151 | 152 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs): 153 | dist.init() 154 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 155 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches) 156 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 157 | 158 | # Rank 0 goes first. 159 | if dist.get_rank() != 0: 160 | torch.distributed.barrier() 161 | 162 | # Load network. 163 | dist.print0(f'Loading network from "{network_pkl}"...') 164 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f: 165 | net = pickle.load(f)['ema'].to(device) 166 | 167 | # Other ranks follow. 168 | if dist.get_rank() == 0: 169 | torch.distributed.barrier() 170 | 171 | # Loop over batches. 172 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...') 173 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)): 174 | torch.distributed.barrier() 175 | batch_size = len(batch_seeds) 176 | if batch_size == 0: 177 | continue 178 | 179 | # Pick latents and labels. 180 | rnd = StackedRandomGenerator(device, batch_seeds) 181 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 182 | class_labels = None 183 | if net.label_dim: 184 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)] 185 | if class_idx is not None: 186 | class_labels[:, :] = 0 187 | class_labels[:, class_idx] = 1 188 | 189 | # Generate images. 
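        # Options left at None (e.g. sigma_min/sigma_max) are dropped so each
        # sampler falls back to its own defaults; 'solver' is popped here so
        # it only selects the sampler function and is not forwarded to it.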
190 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None} 191 | sampler_fn = edm_sampler if sampler_kwargs.pop('solver', 'edm') == 'edm' else ablation_sampler 192 | images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs) 193 | 194 | # Save images. 195 | images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy() 196 | for seed, image_np in zip(batch_seeds, images_np): 197 | image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir 198 | os.makedirs(image_dir, exist_ok=True) 199 | image_path = os.path.join(image_dir, f'{seed:06d}.png') 200 | if image_np.shape[2] == 1: 201 | PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path) 202 | else: 203 | PIL.Image.fromarray(image_np, 'RGB').save(image_path) 204 | 205 | # Done. 206 | torch.distributed.barrier() 207 | dist.print0('Done.') 208 | 209 | #---------------------------------------------------------------------------- 210 | 211 | if __name__ == "__main__": 212 | main() 213 | 214 | #---------------------------------------------------------------------------- 215 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet? 28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 
46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. 58 | 59 | This function is expected to be extremely cheap and can be safely 60 | called from anywhere in the training loop, loss function, or inside a 61 | `torch.nn.Module`. 62 | 63 | Warning: The current implementation expects the set of unique names to 64 | be consistent across processes. Please make sure that `report()` is 65 | called at least once for each unique name by each process, and in the 66 | same order. If a given process has no scalars to broadcast, it can do 67 | `report(name, [])` (empty list). 68 | 69 | Args: 70 | name: Arbitrary string specifying the name of the statistic. 71 | Averages are accumulated separately for each unique name. 72 | value: Arbitrary set of scalars. Can be a list, tuple, 73 | NumPy array, PyTorch tensor, or Python scalar. 74 | 75 | Returns: 76 | The same `value` that was passed in. 77 | """ 78 | if name not in _counters: 79 | _counters[name] = dict() 80 | 81 | elems = torch.as_tensor(value) 82 | if elems.numel() == 0: 83 | return value 84 | 85 | elems = elems.detach().flatten().to(_reduce_dtype) 86 | moments = torch.stack([ 87 | torch.ones_like(elems).sum(), 88 | elems.sum(), 89 | elems.square().sum(), 90 | ]) 91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 92 | moments = moments.to(_counter_dtype) 93 | 94 | device = moments.device 95 | if device not in _counters[name]: 96 | _counters[name][device] = torch.zeros_like(moments) 97 | _counters[name][device].add_(moments) 98 | return value 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | def report0(name, value): 103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 104 | but ignores any scalars provided by the other processes. 105 | See `report()` for further details. 106 | """ 107 | report(name, value if _rank == 0 else []) 108 | return value 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | class Collector: 113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 114 | computes their long-term averages (mean and standard deviation) over 115 | user-defined periods of time. 116 | 117 | The averages are first collected into internal counters that are not 118 | directly visible to the user. They are then copied to the user-visible 119 | state as a result of calling `update()` and can then be queried using 120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 121 | internal counters for the next round, so that the user-visible state 122 | effectively reflects averages collected between the last two calls to 123 | `update()`. 124 | 125 | Args: 126 | regex: Regular expression defining which statistics to 127 | collect. The default is to collect everything. 128 | keep_previous: Whether to retain the previous averages if no 129 | scalars were collected on a given round 130 | (default: True). 
131 | """ 132 | def __init__(self, regex='.*', keep_previous=True): 133 | self._regex = re.compile(regex) 134 | self._keep_previous = keep_previous 135 | self._cumulative = dict() 136 | self._moments = dict() 137 | self.update() 138 | self._moments.clear() 139 | 140 | def names(self): 141 | r"""Returns the names of all statistics broadcasted so far that 142 | match the regular expression specified at construction time. 143 | """ 144 | return [name for name in _counters if self._regex.fullmatch(name)] 145 | 146 | def update(self): 147 | r"""Copies current values of the internal counters to the 148 | user-visible state and resets them for the next round. 149 | 150 | If `keep_previous=True` was specified at construction time, the 151 | operation is skipped for statistics that have received no scalars 152 | since the last update, retaining their previous averages. 153 | 154 | This method performs a number of GPU-to-CPU transfers and one 155 | `torch.distributed.all_reduce()`. It is intended to be called 156 | periodically in the main training loop, typically once every 157 | N training steps. 158 | """ 159 | if not self._keep_previous: 160 | self._moments.clear() 161 | for name, cumulative in _sync(self.names()): 162 | if name not in self._cumulative: 163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 164 | delta = cumulative - self._cumulative[name] 165 | self._cumulative[name].copy_(cumulative) 166 | if float(delta[0]) != 0: 167 | self._moments[name] = delta 168 | 169 | def _get_delta(self, name): 170 | r"""Returns the raw moments that were accumulated for the given 171 | statistic between the last two calls to `update()`, or zero if 172 | no scalars were collected. 173 | """ 174 | assert self._regex.fullmatch(name) 175 | if name not in self._moments: 176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 177 | return self._moments[name] 178 | 179 | def num(self, name): 180 | r"""Returns the number of scalars that were accumulated for the given 181 | statistic between the last two calls to `update()`, or zero if 182 | no scalars were collected. 183 | """ 184 | delta = self._get_delta(name) 185 | return int(delta[0]) 186 | 187 | def mean(self, name): 188 | r"""Returns the mean of the scalars that were accumulated for the 189 | given statistic between the last two calls to `update()`, or NaN if 190 | no scalars were collected. 191 | """ 192 | delta = self._get_delta(name) 193 | if int(delta[0]) == 0: 194 | return float('nan') 195 | return float(delta[1] / delta[0]) 196 | 197 | def std(self, name): 198 | r"""Returns the standard deviation of the scalars that were 199 | accumulated for the given statistic between the last two calls to 200 | `update()`, or NaN if no scalars were collected. 201 | """ 202 | delta = self._get_delta(name) 203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 204 | return float('nan') 205 | if int(delta[0]) == 1: 206 | return float(0) 207 | mean = float(delta[1] / delta[0]) 208 | raw_var = float(delta[2] / delta[0]) 209 | return np.sqrt(max(raw_var - np.square(mean), 0)) 210 | 211 | def as_dict(self): 212 | r"""Returns the averages accumulated between the last two calls to 213 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 214 | 215 | dnnlib.EasyDict( 216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 217 | ... 
218 | ) 219 | """ 220 | stats = dnnlib.EasyDict() 221 | for name in self.names(): 222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 223 | return stats 224 | 225 | def __getitem__(self, name): 226 | r"""Convenience getter. 227 | `collector[name]` is a synonym for `collector.mean(name)`. 228 | """ 229 | return self.mean(name) 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def _sync(names): 234 | r"""Synchronize the global cumulative counters across devices and 235 | processes. Called internally by `Collector.update()`. 236 | """ 237 | if len(names) == 0: 238 | return [] 239 | global _sync_called 240 | _sync_called = True 241 | 242 | # Collect deltas within current rank. 243 | deltas = [] 244 | device = _sync_device if _sync_device is not None else torch.device('cpu') 245 | for name in names: 246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 247 | for counter in _counters[name].values(): 248 | delta.add_(counter.to(device)) 249 | counter.copy_(torch.zeros_like(counter)) 250 | deltas.append(delta) 251 | deltas = torch.stack(deltas) 252 | 253 | # Sum deltas across ranks. 254 | if _sync_device is not None: 255 | torch.distributed.all_reduce(deltas) 256 | 257 | # Update cumulative values. 258 | deltas = deltas.cpu() 259 | for idx, name in enumerate(names): 260 | if name not in _cumulative: 261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 262 | _cumulative[name].add_(deltas[idx]) 263 | 264 | # Return name-value pairs. 265 | return [(name, _cumulative[name]) for name in names] 266 | 267 | #---------------------------------------------------------------------------- 268 | # Convenience. 269 | 270 | default_collector = Collector() 271 | 272 | #---------------------------------------------------------------------------- 273 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import numpy as np 11 | import torch 12 | import warnings 13 | import dnnlib 14 | 15 | #---------------------------------------------------------------------------- 16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 17 | # same constant is used multiple times. 18 | 19 | _constant_cache = dict() 20 | 21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 22 | value = np.asarray(value) 23 | if shape is not None: 24 | shape = tuple(shape) 25 | if dtype is None: 26 | dtype = torch.get_default_dtype() 27 | if device is None: 28 | device = torch.device('cpu') 29 | if memory_format is None: 30 | memory_format = torch.contiguous_format 31 | 32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 33 | tensor = _constant_cache.get(key, None) 34 | if tensor is None: 35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 36 | if shape is not None: 37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 38 | tensor = tensor.contiguous(memory_format=memory_format) 39 | _constant_cache[key] = tensor 40 | return tensor 41 | 42 | #---------------------------------------------------------------------------- 43 | # Replace NaN/Inf with specified numerical values. 44 | 45 | try: 46 | nan_to_num = torch.nan_to_num # 1.8.0a0 47 | except AttributeError: 48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 49 | assert isinstance(input, torch.Tensor) 50 | if posinf is None: 51 | posinf = torch.finfo(input.dtype).max 52 | if neginf is None: 53 | neginf = torch.finfo(input.dtype).min 54 | assert nan == 0 55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Symbolic assert. 59 | 60 | try: 61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 62 | except AttributeError: 63 | symbolic_assert = torch.Assert # 1.7.0 64 | 65 | #---------------------------------------------------------------------------- 66 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 68 | 69 | @contextlib.contextmanager 70 | def suppress_tracer_warnings(): 71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 72 | warnings.filters.insert(0, flt) 73 | yield 74 | warnings.filters.remove(flt) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Assert that the shape of a tensor matches the given list of integers. 78 | # None indicates that the size of a dimension is allowed to vary. 79 | # Performs symbolic assertion when used in torch.jit.trace(). 
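# A minimal usage sketch (hypothetical tensors, not part of the original file):
#   x = torch.zeros([8, 3, 32, 32])
#   assert_shape(x, [None, 3, 32, 32])  # passes; the batch dimension may vary
#   assert_shape(x, [8, 1, 32, 32])     # raises AssertionError for dimension 1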
80 | 81 | def assert_shape(tensor, ref_shape): 82 | if tensor.ndim != len(ref_shape): 83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 85 | if ref_size is None: 86 | pass 87 | elif isinstance(ref_size, torch.Tensor): 88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 90 | elif isinstance(size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 93 | elif size != ref_size: 94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 95 | 96 | #---------------------------------------------------------------------------- 97 | # Function decorator that calls torch.autograd.profiler.record_function(). 98 | 99 | def profiled_function(fn): 100 | def decorator(*args, **kwargs): 101 | with torch.autograd.profiler.record_function(fn.__name__): 102 | return fn(*args, **kwargs) 103 | decorator.__name__ = fn.__name__ 104 | return decorator 105 | 106 | #---------------------------------------------------------------------------- 107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 108 | # indefinitely, shuffling items as it goes. 109 | 110 | class InfiniteSampler(torch.utils.data.Sampler): 111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 112 | assert len(dataset) > 0 113 | assert num_replicas > 0 114 | assert 0 <= rank < num_replicas 115 | assert 0 <= window_size <= 1 116 | super().__init__(dataset) 117 | self.dataset = dataset 118 | self.rank = rank 119 | self.num_replicas = num_replicas 120 | self.shuffle = shuffle 121 | self.seed = seed 122 | self.window_size = window_size 123 | 124 | def __iter__(self): 125 | order = np.arange(len(self.dataset)) 126 | rnd = None 127 | window = 0 128 | if self.shuffle: 129 | rnd = np.random.RandomState(self.seed) 130 | rnd.shuffle(order) 131 | window = int(np.rint(order.size * self.window_size)) 132 | 133 | idx = 0 134 | while True: 135 | i = idx % order.size 136 | if idx % self.num_replicas == self.rank: 137 | yield order[i] 138 | if window >= 2: 139 | j = (i - rnd.randint(window)) % order.size 140 | order[i], order[j] = order[j], order[i] 141 | idx += 1 142 | 143 | #---------------------------------------------------------------------------- 144 | # Utilities for operating with torch.nn.Module parameters and buffers. 
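# Rough usage sketch (hypothetical modules, not part of the original file), mirroring how
# the training loop copies EMA weights into a freshly constructed network when resuming:
#   src = torch.nn.Linear(4, 4)   # stands in for data['ema']
#   dst = torch.nn.Linear(4, 4)   # stands in for the newly built net
#   copy_params_and_buffers(src_module=src, dst_module=dst, require_all=False)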
145 | 146 | def params_and_buffers(module): 147 | assert isinstance(module, torch.nn.Module) 148 | return list(module.parameters()) + list(module.buffers()) 149 | 150 | def named_params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.named_parameters()) + list(module.named_buffers()) 153 | 154 | @torch.no_grad() 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name]) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 
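# Tensors are tracked by their Python id() below, so parameters or buffers shared between
# submodules are attributed to the first module that reports them and counted only once
# in the summary totals.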
221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | print(vars(module)) 265 | print() 266 | return outputs 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /training/training_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Main training loop.""" 9 | 10 | import os 11 | import time 12 | import copy 13 | import json 14 | import pickle 15 | import psutil 16 | import numpy as np 17 | import torch 18 | import dnnlib 19 | from torch_utils import distributed as dist 20 | from torch_utils import training_stats 21 | from torch_utils import misc 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def training_loop( 26 | run_dir = '.', # Output directory. 27 | dataset_kwargs = {}, # Options for training set. 28 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. 29 | network_kwargs = {}, # Options for model and preconditioning. 30 | loss_kwargs = {}, # Options for loss function. 31 | optimizer_kwargs = {}, # Options for optimizer. 32 | augment_kwargs = None, # Options for augmentation pipeline, None = disable. 
33 | seed = 0, # Global random seed. 34 | batch_size = 512, # Total batch size for one training iteration. 35 | batch_gpu = None, # Limit batch size per GPU, None = no limit. 36 | total_kimg = 200000, # Training duration, measured in thousands of training images. 37 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights. 38 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup. 39 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration. 40 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows. 41 | kimg_per_tick = 50, # Interval of progress prints. 42 | snapshot_ticks = 50, # How often to save network snapshots, None = disable. 43 | state_dump_ticks = 500, # How often to dump training state, None = disable. 44 | resume_pkl = None, # Start from the given network snapshot, None = random initialization. 45 | resume_state_dump = None, # Start from the given training state, None = reset training state. 46 | resume_kimg = 0, # Start from the given training progress. 47 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 48 | device = torch.device('cuda'), 49 | ): 50 | # Initialize. 51 | start_time = time.time() 52 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31)) 53 | torch.manual_seed(np.random.randint(1 << 31)) 54 | torch.backends.cudnn.benchmark = cudnn_benchmark 55 | torch.backends.cudnn.allow_tf32 = False 56 | torch.backends.cuda.matmul.allow_tf32 = False 57 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 58 | 59 | # Select batch size per GPU. 60 | batch_gpu_total = batch_size // dist.get_world_size() 61 | if batch_gpu is None or batch_gpu > batch_gpu_total: 62 | batch_gpu = batch_gpu_total 63 | num_accumulation_rounds = batch_gpu_total // batch_gpu 64 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size() 65 | 66 | # Load dataset. 67 | dist.print0('Loading dataset...') 68 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset 69 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed) 70 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs)) 71 | 72 | # Construct network. 73 | dist.print0('Constructing network...') 74 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim) 75 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module 76 | net.train().requires_grad_(True).to(device) 77 | if dist.get_rank() == 0: 78 | with torch.no_grad(): 79 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device) 80 | sigma = torch.ones([batch_gpu], device=device) 81 | labels = torch.zeros([batch_gpu, net.label_dim], device=device) 82 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2) 83 | 84 | # Setup optimizer. 
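# Besides the optimizer itself, this block constructs the loss function, the optional
# augmentation pipeline, the DDP wrapper used for gradient synchronization, and a frozen
# EMA copy of the network that is updated manually after every optimizer step.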
85 | dist.print0('Setting up optimizer...') 86 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss 87 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer 88 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe 89 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False) 90 | ema = copy.deepcopy(net).eval().requires_grad_(False) 91 | 92 | # Resume training from previous snapshot. 93 | if resume_pkl is not None: 94 | dist.print0(f'Loading network weights from "{resume_pkl}"...') 95 | if dist.get_rank() != 0: 96 | torch.distributed.barrier() # rank 0 goes first 97 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f: 98 | data = pickle.load(f) 99 | if dist.get_rank() == 0: 100 | torch.distributed.barrier() # other ranks follow 101 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False) 102 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False) 103 | del data # conserve memory 104 | if resume_state_dump: 105 | dist.print0(f'Loading training state from "{resume_state_dump}"...') 106 | data = torch.load(resume_state_dump, map_location=torch.device('cpu')) 107 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True) 108 | optimizer.load_state_dict(data['optimizer_state']) 109 | del data # conserve memory 110 | 111 | # Train. 112 | dist.print0(f'Training for {total_kimg} kimg...') 113 | dist.print0() 114 | cur_nimg = resume_kimg * 1000 115 | cur_tick = 0 116 | tick_start_nimg = cur_nimg 117 | tick_start_time = time.time() 118 | maintenance_time = tick_start_time - start_time 119 | dist.update_progress(cur_nimg // 1000, total_kimg) 120 | stats_jsonl = None 121 | while True: 122 | 123 | # Accumulate gradients. 124 | optimizer.zero_grad(set_to_none=True) 125 | for round_idx in range(num_accumulation_rounds): 126 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)): 127 | images, labels = next(dataset_iterator) 128 | images = images.to(device).to(torch.float32) / 127.5 - 1 129 | labels = labels.to(device) 130 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe) 131 | training_stats.report('Loss/loss', loss) 132 | loss.sum().mul(loss_scaling / batch_gpu_total).backward() 133 | 134 | # Update weights. 135 | for g in optimizer.param_groups: 136 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1) 137 | for param in net.parameters(): 138 | if param.grad is not None: 139 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) 140 | optimizer.step() 141 | 142 | # Update EMA. 143 | ema_halflife_nimg = ema_halflife_kimg * 1000 144 | if ema_rampup_ratio is not None: 145 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) 146 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) 147 | for p_ema, p_net in zip(ema.parameters(), net.parameters()): 148 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) 149 | 150 | # Perform maintenance tasks once per tick. 
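# A tick corresponds to kimg_per_tick thousand training images; until the next tick
# boundary (or until training finishes) the loop skips the status print, snapshot, and
# logging code below and goes straight back to gradient accumulation.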
151 | cur_nimg += batch_size 152 | done = (cur_nimg >= total_kimg * 1000) 153 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): 154 | continue 155 | 156 | # Print status line, accumulating the same information in training_stats. 157 | tick_end_time = time.time() 158 | fields = [] 159 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] 160 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] 161 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] 162 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] 163 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] 164 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] 165 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] 166 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] 167 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] 168 | fields += [f"loss {loss.sum().item():<10.2f}"] 169 | torch.cuda.reset_peak_memory_stats() 170 | dist.print0(' '.join(fields)) 171 | 172 | # Check for abort. 173 | if (not done) and dist.should_stop(): 174 | done = True 175 | dist.print0() 176 | dist.print0('Aborting...') 177 | 178 | # Save network snapshot. 179 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0): 180 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs)) 181 | for key, value in data.items(): 182 | if isinstance(value, torch.nn.Module): 183 | value = copy.deepcopy(value).eval().requires_grad_(False) 184 | misc.check_ddp_consistency(value) 185 | data[key] = value.cpu() 186 | del value # conserve memory 187 | if dist.get_rank() == 0: 188 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f: 189 | pickle.dump(data, f) 190 | del data # conserve memory 191 | 192 | # Save full dump of the training state. 193 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0: 194 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')) 195 | 196 | # Update logs. 197 | training_stats.default_collector.update() 198 | if dist.get_rank() == 0: 199 | if stats_jsonl is None: 200 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at') 201 | stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n') 202 | stats_jsonl.flush() 203 | dist.update_progress(cur_nimg // 1000, total_kimg) 204 | 205 | # Update state. 206 | cur_tick += 1 207 | tick_start_nimg = cur_nimg 208 | tick_start_time = time.time() 209 | maintenance_time = tick_start_time - tick_end_time 210 | if done: 211 | break 212 | 213 | # Done. 
214 | dist.print0() 215 | dist.print0('Exiting...') 216 | 217 | #---------------------------------------------------------------------------- 218 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Train diffusion-based generative model using the techniques described in the 16 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" 17 | 18 | import os 19 | import re 20 | import json 21 | import click 22 | import torch 23 | import dnnlib 24 | from torch_utils import distributed as dist 25 | from training import training_loop 26 | 27 | import warnings 28 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. 29 | 30 | #---------------------------------------------------------------------------- 31 | # Parse a comma separated list of numbers or ranges and return a list of ints. 32 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 33 | 34 | def parse_int_list(s): 35 | if isinstance(s, list): return s 36 | ranges = [] 37 | range_re = re.compile(r'^(\d+)-(\d+)$') 38 | for p in s.split(','): 39 | m = range_re.match(p) 40 | if m: 41 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 42 | else: 43 | ranges.append(int(p)) 44 | return ranges 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | @click.command() 49 | 50 | # Main options. 51 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True) 52 | @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True) 53 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True) 54 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp', type=click.Choice(['ddpmpp', 'ncsnpp']), default='ddpmpp', show_default=True) 55 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm', 'fdm_edm', 'fdm_vp', 'fdm_ve']), default='fdm_edm', show_default=True) 56 | 57 | # Hyperparameters. 
58 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True) 59 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True) 60 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1)) 61 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int) 62 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list) 63 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True) 64 | @click.option('--lr_rampup', help='Learning rate rampup', metavar='FLOAT', type=click.FloatRange(min=0, max=1000), default=10, show_default=True) 65 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True) 66 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True) 67 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True) 68 | @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True) 69 | @click.option('--warmup_ite', help='Loss weight warmup iteration', metavar='FLOAT', type=float, default=None, show_default=True) 70 | @click.option('--fdm_multiplier', help='FDM multiplier', metavar='FLOAT', type=float, default=2.0) 71 | 72 | # Performance-related. 73 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 74 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 75 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 76 | @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True) 77 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=8, show_default=True) 78 | 79 | # I/O-related. 
80 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str) 81 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True) 82 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True) 83 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=200, show_default=True) 84 | @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True) 85 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int) 86 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) 87 | @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str) 88 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True) 89 | 90 | def main(**kwargs): 91 | opts = dnnlib.EasyDict(kwargs) 92 | torch.multiprocessing.set_start_method('spawn') 93 | dist.init() 94 | 95 | # Initialize config dict. 96 | c = dnnlib.EasyDict() 97 | c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache) 98 | c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2) 99 | c.network_kwargs = dnnlib.EasyDict() 100 | c.loss_kwargs = dnnlib.EasyDict() 101 | c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8) 102 | 103 | # Validate dataset options. 104 | try: 105 | dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs) 106 | dataset_name = dataset_obj.name 107 | c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution 108 | c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size 109 | if opts.cond and not dataset_obj.has_labels: 110 | raise click.ClickException('--cond=True requires labels specified in dataset.json') 111 | del dataset_obj # conserve memory 112 | except IOError as err: 113 | raise click.ClickException(f'--data: {err}') 114 | 115 | # Network architecture. 116 | if opts.arch == 'ddpmpp': 117 | c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') 118 | c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2]) 119 | elif opts.arch == 'ncsnpp': 120 | c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard') 121 | c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2]) 122 | else: 123 | raise click.ClickException(f'Unknown architecture: {opts.arch}') 124 | 125 | # Preconditioning & loss function.
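# Each choice selects a (preconditioning wrapper, loss class) pair; the fdm_* variants
# reuse the corresponding VP/VE/EDM losses but additionally pass the FDM multiplier to the
# network wrapper and the loss-weight warmup iteration count to the loss.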
126 | if opts.precond == 'vp': 127 | c.network_kwargs.class_name = 'training.networks.VPPrecond' 128 | c.loss_kwargs.class_name = 'training.loss.VPLoss' 129 | elif opts.precond == 've': 130 | c.network_kwargs.class_name = 'training.networks.VEPrecond' 131 | c.loss_kwargs.class_name = 'training.loss.VELoss' 132 | elif opts.precond == 'edm': 133 | c.network_kwargs.class_name = 'training.networks.EDMPrecond' 134 | c.loss_kwargs.class_name = 'training.loss.EDMLoss' 135 | elif opts.precond == 'fdm_vp': 136 | # VP-FDM 137 | c.network_kwargs.class_name = 'training.networks.FDM_VPPrecond' 138 | c.loss_kwargs.class_name = 'training.loss.VPLoss' 139 | c.network_kwargs.update(fdm_multiplier=opts.fdm_multiplier) 140 | c.loss_kwargs.update(warmup_ite=opts.warmup_ite) 141 | elif opts.precond == 'fdm_ve': 142 | # VE-FDM 143 | c.network_kwargs.class_name = 'training.networks.FDM_VEPrecond' 144 | c.loss_kwargs.class_name = 'training.loss.VELoss' 145 | c.network_kwargs.update(fdm_multiplier=opts.fdm_multiplier) 146 | c.loss_kwargs.update(warmup_ite=opts.warmup_ite) 147 | else: 148 | # EDM-FDM 149 | assert opts.precond == 'fdm_edm' 150 | c.network_kwargs.class_name = 'training.networks.FDM_EDMPrecond' 151 | c.loss_kwargs.class_name = 'training.loss.EDMLoss' 152 | c.network_kwargs.update(fdm_multiplier=opts.fdm_multiplier) 153 | c.loss_kwargs.update(warmup_ite=opts.warmup_ite) 154 | 155 | # Network options. 156 | if opts.cbase is not None: 157 | c.network_kwargs.model_channels = opts.cbase 158 | if opts.cres is not None: 159 | c.network_kwargs.channel_mult = opts.cres 160 | if opts.augment: 161 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment) 162 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1) 163 | c.network_kwargs.augment_dim = 9 164 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16) 165 | 166 | # Training options. 167 | c.total_kimg = max(int(opts.duration * 1000), 1) 168 | c.ema_halflife_kimg = int(opts.ema * 1000) 169 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu) 170 | c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench) 171 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump) 172 | 173 | # Random seed. 174 | if opts.seed is not None: 175 | c.seed = opts.seed 176 | else: 177 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda')) 178 | torch.distributed.broadcast(seed, src=0) 179 | c.seed = int(seed) 180 | 181 | # Transfer learning and resume. 182 | if opts.transfer is not None: 183 | if opts.resume is not None: 184 | raise click.ClickException('--transfer and --resume cannot be specified at the same time') 185 | c.resume_pkl = opts.transfer 186 | c.ema_rampup_ratio = None 187 | elif opts.resume is not None: 188 | match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume)) 189 | if not match or not os.path.isfile(opts.resume): 190 | raise click.ClickException('--resume must point to training-state-*.pt from a previous training run') 191 | c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl') 192 | c.resume_kimg = int(match.group(1)) 193 | c.resume_state_dump = opts.resume 194 | 195 | # Description string. 
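# Produces names like 'cifar10-32x32-uncond-ddpmpp-fdm_edm-gpus8-batch512-fp32' (dataset
# name and GPU count are hypothetical); the numeric run-id prefix is added when the output
# directory is picked below.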
196 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond' 197 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32' 198 | desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}' 199 | if opts.desc is not None: 200 | desc += f'-{opts.desc}' 201 | 202 | # Pick output directory. 203 | if dist.get_rank() == -1: # != 0: 204 | c.run_dir = None 205 | elif opts.nosubdir: 206 | c.run_dir = opts.outdir 207 | else: 208 | prev_run_dirs = [] 209 | if os.path.isdir(opts.outdir): 210 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))] 211 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] 212 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] 213 | cur_run_id = max(prev_run_ids, default=-1) + 1 214 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}') 215 | assert not os.path.exists(c.run_dir) 216 | 217 | # Print options. 218 | dist.print0() 219 | dist.print0('Training options:') 220 | dist.print0(json.dumps(c, indent=2)) 221 | dist.print0() 222 | dist.print0(f'Output directory: {c.run_dir}') 223 | dist.print0(f'Dataset path: {c.dataset_kwargs.path}') 224 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}') 225 | dist.print0(f'Network architecture: {opts.arch}') 226 | dist.print0(f'Preconditioning & loss: {opts.precond}') 227 | dist.print0(f'Number of GPUs: {dist.get_world_size()}') 228 | dist.print0(f'Batch size: {c.batch_size}') 229 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}') 230 | dist.print0(f'network_kwargs: {c.network_kwargs}') 231 | dist.print0(f'loss_kwargs: {c.loss_kwargs}') 232 | dist.print0() 233 | 234 | # Dry run? 235 | if opts.dry_run: 236 | dist.print0('Dry run; exiting.') 237 | return 238 | 239 | # Create output directory. 240 | dist.print0('Creating output directory...') 241 | if dist.get_rank() == 0: 242 | os.makedirs(c.run_dir, exist_ok=True) 243 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: 244 | json.dump(c, f, indent=2) 245 | dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) 246 | 247 | # Train. 248 | training_loop.training_loop(**c, lr_rampup_kimg=opts.lr_rampup * 1000) 249 | 250 | #---------------------------------------------------------------------------- 251 | 252 | if __name__ == "__main__": 253 | main() 254 | 255 | #---------------------------------------------------------------------------- 256 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import tempfile 27 | import urllib 28 | import urllib.request 29 | import uuid 30 | 31 | from distutils.util import strtobool 32 | from typing import Any, List, Tuple, Union, Optional 33 | 34 | 35 | # Util classes 36 | # ------------------------------------------------------------------------------------------ 37 | 38 | 39 | class EasyDict(dict): 40 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 41 | 42 | def __getattr__(self, name: str) -> Any: 43 | try: 44 | return self[name] 45 | except KeyError: 46 | raise AttributeError(name) 47 | 48 | def __setattr__(self, name: str, value: Any) -> None: 49 | self[name] = value 50 | 51 | def __delattr__(self, name: str) -> None: 52 | del self[name] 53 | 54 | 55 | class Logger(object): 56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 57 | 58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): 59 | self.file = None 60 | 61 | if file_name is not None: 62 | self.file = open(file_name, file_mode) 63 | 64 | self.should_flush = should_flush 65 | self.stdout = sys.stdout 66 | self.stderr = sys.stderr 67 | 68 | sys.stdout = self 69 | sys.stderr = self 70 | 71 | def __enter__(self) -> "Logger": 72 | return self 73 | 74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 75 | self.close() 76 | 77 | def write(self, text: Union[str, bytes]) -> None: 78 | """Write text to stdout (and a file) and optionally flush.""" 79 | if isinstance(text, bytes): 80 | text = text.decode() 81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 82 | return 83 | 84 | if self.file is not None: 85 | self.file.write(text) 86 | 87 | self.stdout.write(text) 88 | 89 | if self.should_flush: 90 | self.flush() 91 | 92 | def flush(self) -> None: 93 | """Flush written text to both stdout and a file, if open.""" 94 | if self.file is not None: 95 | self.file.flush() 96 | 97 | self.stdout.flush() 98 | 99 | def close(self) -> None: 100 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 101 | self.flush() 102 | 103 | # if using multiple loggers, prevent closing in wrong order 104 | if sys.stdout is self: 105 | sys.stdout = self.stdout 106 | if sys.stderr is self: 107 | sys.stderr = self.stderr 108 | 109 | if self.file is not None: 110 | self.file.close() 111 | self.file = None 112 | 113 | 114 | # Cache directories 115 | # ------------------------------------------------------------------------------------------ 116 | 117 | _dnnlib_cache_dir = None 118 | 119 | def set_cache_dir(path: str) -> None: 120 | global _dnnlib_cache_dir 121 | _dnnlib_cache_dir = path 122 | 123 | def make_cache_dir_path(*paths: str) -> str: 124 | if _dnnlib_cache_dir is not None: 125 | return os.path.join(_dnnlib_cache_dir, *paths) 126 | if 'DNNLIB_CACHE_DIR' in os.environ: 127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 128 | if 'HOME' in os.environ: 129 | return os.path.join(os.environ['HOME'], 
'.cache', 'dnnlib', *paths) 130 | if 'USERPROFILE' in os.environ: 131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 133 | 134 | # Small util functions 135 | # ------------------------------------------------------------------------------------------ 136 | 137 | 138 | def format_time(seconds: Union[int, float]) -> str: 139 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 140 | s = int(np.rint(seconds)) 141 | 142 | if s < 60: 143 | return "{0}s".format(s) 144 | elif s < 60 * 60: 145 | return "{0}m {1:02}s".format(s // 60, s % 60) 146 | elif s < 24 * 60 * 60: 147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 148 | else: 149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 150 | 151 | 152 | def format_time_brief(seconds: Union[int, float]) -> str: 153 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 154 | s = int(np.rint(seconds)) 155 | 156 | if s < 60: 157 | return "{0}s".format(s) 158 | elif s < 60 * 60: 159 | return "{0}m {1:02}s".format(s // 60, s % 60) 160 | elif s < 24 * 60 * 60: 161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 162 | else: 163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 164 | 165 | 166 | def ask_yes_no(question: str) -> bool: 167 | """Ask the user the question until the user inputs a valid answer.""" 168 | while True: 169 | try: 170 | print("{0} [y/n]".format(question)) 171 | return strtobool(input().lower()) 172 | except ValueError: 173 | pass 174 | 175 | 176 | def tuple_product(t: Tuple) -> Any: 177 | """Calculate the product of the tuple elements.""" 178 | result = 1 179 | 180 | for v in t: 181 | result *= v 182 | 183 | return result 184 | 185 | 186 | _str_to_ctype = { 187 | "uint8": ctypes.c_ubyte, 188 | "uint16": ctypes.c_uint16, 189 | "uint32": ctypes.c_uint32, 190 | "uint64": ctypes.c_uint64, 191 | "int8": ctypes.c_byte, 192 | "int16": ctypes.c_int16, 193 | "int32": ctypes.c_int32, 194 | "int64": ctypes.c_int64, 195 | "float32": ctypes.c_float, 196 | "float64": ctypes.c_double 197 | } 198 | 199 | 200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 202 | type_str = None 203 | 204 | if isinstance(type_obj, str): 205 | type_str = type_obj 206 | elif hasattr(type_obj, "__name__"): 207 | type_str = type_obj.__name__ 208 | elif hasattr(type_obj, "name"): 209 | type_str = type_obj.name 210 | else: 211 | raise RuntimeError("Cannot infer type name from input") 212 | 213 | assert type_str in _str_to_ctype.keys() 214 | 215 | my_dtype = np.dtype(type_str) 216 | my_ctype = _str_to_ctype[type_str] 217 | 218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 219 | 220 | return my_dtype, my_ctype 221 | 222 | 223 | def is_pickleable(obj: Any) -> bool: 224 | try: 225 | with io.BytesIO() as stream: 226 | pickle.dump(obj, stream) 227 | return True 228 | except: 229 | return False 230 | 231 | 232 | # Functionality to import modules/objects by name, and call functions by name 233 | # ------------------------------------------------------------------------------------------ 234 | 235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 236 | """Searches for the 
underlying module behind the name to some python object. 237 | Returns the module and the object name (original name with module part removed).""" 238 | 239 | # allow convenience shorthands, substitute them by full names 240 | obj_name = re.sub("^np.", "numpy.", obj_name) 241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 242 | 243 | # list alternatives for (module_name, local_obj_name) 244 | parts = obj_name.split(".") 245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 246 | 247 | # try each alternative in turn 248 | for module_name, local_obj_name in name_pairs: 249 | try: 250 | module = importlib.import_module(module_name) # may raise ImportError 251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 252 | return module, local_obj_name 253 | except: 254 | pass 255 | 256 | # maybe some of the modules themselves contain errors? 257 | for module_name, _local_obj_name in name_pairs: 258 | try: 259 | importlib.import_module(module_name) # may raise ImportError 260 | except ImportError: 261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 262 | raise 263 | 264 | # maybe the requested attribute is missing? 265 | for module_name, local_obj_name in name_pairs: 266 | try: 267 | module = importlib.import_module(module_name) # may raise ImportError 268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 269 | except ImportError: 270 | pass 271 | 272 | # we are out of luck, but we have no idea why 273 | raise ImportError(obj_name) 274 | 275 | 276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 277 | """Traverses the object name and returns the last (rightmost) python object.""" 278 | if obj_name == '': 279 | return module 280 | obj = module 281 | for part in obj_name.split("."): 282 | obj = getattr(obj, part) 283 | return obj 284 | 285 | 286 | def get_obj_by_name(name: str) -> Any: 287 | """Finds the python object with the given name.""" 288 | module, obj_name = get_module_from_obj_name(name) 289 | return get_obj_from_module(module, obj_name) 290 | 291 | 292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 293 | """Finds the python object with the given name and calls it as a function.""" 294 | assert func_name is not None 295 | func_obj = get_obj_by_name(func_name) 296 | assert callable(func_obj) 297 | return func_obj(*args, **kwargs) 298 | 299 | 300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 301 | """Finds the python class with the given name and constructs it with the given arguments.""" 302 | return call_func_by_name(*args, func_name=class_name, **kwargs) 303 | 304 | 305 | def get_module_dir_by_obj_name(obj_name: str) -> str: 306 | """Get the directory path of the module containing the given object name.""" 307 | module, _ = get_module_from_obj_name(obj_name) 308 | return os.path.dirname(inspect.getfile(module)) 309 | 310 | 311 | def is_top_level_function(obj: Any) -> bool: 312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 314 | 315 | 316 | def get_top_level_function_name(obj: Any) -> str: 317 | """Return the fully-qualified name of a top-level function.""" 318 | assert is_top_level_function(obj) 319 | module = obj.__module__ 320 | if module == '__main__': 321 | module = 
os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 322 | return module + "." + obj.__name__ 323 | 324 | 325 | # File system helpers 326 | # ------------------------------------------------------------------------------------------ 327 | 328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 329 | """List all files recursively in a given directory while ignoring given file and directory names. 330 | Returns list of tuples containing both absolute and relative paths.""" 331 | assert os.path.isdir(dir_path) 332 | base_name = os.path.basename(os.path.normpath(dir_path)) 333 | 334 | if ignores is None: 335 | ignores = [] 336 | 337 | result = [] 338 | 339 | for root, dirs, files in os.walk(dir_path, topdown=True): 340 | for ignore_ in ignores: 341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 342 | 343 | # dirs need to be edited in-place 344 | for d in dirs_to_remove: 345 | dirs.remove(d) 346 | 347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 348 | 349 | absolute_paths = [os.path.join(root, f) for f in files] 350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 351 | 352 | if add_base_to_relative: 353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 354 | 355 | assert len(absolute_paths) == len(relative_paths) 356 | result += zip(absolute_paths, relative_paths) 357 | 358 | return result 359 | 360 | 361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 362 | """Takes in a list of tuples of (src, dst) paths and copies files. 363 | Will create all necessary directories.""" 364 | for file in files: 365 | target_dir_name = os.path.dirname(file[1]) 366 | 367 | # will create all intermediate-level directories 368 | if not os.path.exists(target_dir_name): 369 | os.makedirs(target_dir_name) 370 | 371 | shutil.copyfile(file[0], file[1]) 372 | 373 | 374 | # URL helpers 375 | # ------------------------------------------------------------------------------------------ 376 | 377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 378 | """Determine whether the given object is a valid URL string.""" 379 | if not isinstance(obj, str) or not "://" in obj: 380 | return False 381 | if allow_file_urls and obj.startswith('file://'): 382 | return True 383 | try: 384 | res = requests.compat.urlparse(obj) 385 | if not res.scheme or not res.netloc or not "." in res.netloc: 386 | return False 387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 388 | if not res.scheme or not res.netloc or not "." in res.netloc: 389 | return False 390 | except: 391 | return False 392 | return True 393 | 394 | 395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 396 | """Download the given URL and return a binary-mode file object to access the data.""" 397 | assert num_attempts >= 1 398 | assert not (return_filename and (not cache)) 399 | 400 | # Doesn't look like an URL scheme so interpret it as a local filename. 401 | if not re.match('^[a-z]+://', url): 402 | return url if return_filename else open(url, "rb") 403 | 404 | # Handle file URLs. This code handles unusual file:// patterns that 405 | # arise on Windows: 406 | # 407 | # file:///c:/foo.txt 408 | # 409 | # which would translate to a local '/c:/foo.txt' filename that's 410 | # invalid. Drop the forward slash for such pathnames. 
411 | # 412 | # If you touch this code path, you should test it on both Linux and 413 | # Windows. 414 | # 415 | # Some internet resources suggest using urllib.request.url2pathname() but 416 | # but that converts forward slashes to backslashes and this causes 417 | # its own set of problems. 418 | if url.startswith('file://'): 419 | filename = urllib.parse.urlparse(url).path 420 | if re.match(r'^/[a-zA-Z]:', filename): 421 | filename = filename[1:] 422 | return filename if return_filename else open(filename, "rb") 423 | 424 | assert is_url(url) 425 | 426 | # Lookup from cache. 427 | if cache_dir is None: 428 | cache_dir = make_cache_dir_path('downloads') 429 | 430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 431 | if cache: 432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 433 | if len(cache_files) == 1: 434 | filename = cache_files[0] 435 | return filename if return_filename else open(filename, "rb") 436 | 437 | # Download. 438 | url_name = None 439 | url_data = None 440 | with requests.Session() as session: 441 | if verbose: 442 | print("Downloading %s ..." % url, end="", flush=True) 443 | for attempts_left in reversed(range(num_attempts)): 444 | try: 445 | with session.get(url) as res: 446 | res.raise_for_status() 447 | if len(res.content) == 0: 448 | raise IOError("No data received") 449 | 450 | if len(res.content) < 8192: 451 | content_str = res.content.decode("utf-8") 452 | if "download_warning" in res.headers.get("Set-Cookie", ""): 453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 454 | if len(links) == 1: 455 | url = requests.compat.urljoin(url, links[0]) 456 | raise IOError("Google Drive virus checker nag") 457 | if "Google Drive - Quota exceeded" in content_str: 458 | raise IOError("Google Drive download quota exceeded -- please try again later") 459 | 460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 461 | url_name = match[1] if match else url 462 | url_data = res.content 463 | if verbose: 464 | print(" done") 465 | break 466 | except KeyboardInterrupt: 467 | raise 468 | except: 469 | if not attempts_left: 470 | if verbose: 471 | print(" failed") 472 | raise 473 | if verbose: 474 | print(".", end="", flush=True) 475 | 476 | # Save to cache. 477 | if cache: 478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 479 | safe_name = safe_name[:min(len(safe_name), 128)] 480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 482 | os.makedirs(cache_dir, exist_ok=True) 483 | with open(temp_file, "wb") as f: 484 | f.write(url_data) 485 | os.replace(temp_file, cache_file) # atomic 486 | if return_filename: 487 | return cache_file 488 | 489 | # Return data as file object. 490 | assert not return_filename 491 | return io.BytesIO(url_data) 492 | -------------------------------------------------------------------------------- /dataset_tool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Tool for creating ZIP/PNG based datasets.""" 9 | 10 | import functools 11 | import gzip 12 | import io 13 | import json 14 | import os 15 | import pickle 16 | import re 17 | import sys 18 | import tarfile 19 | import zipfile 20 | from pathlib import Path 21 | from typing import Callable, Optional, Tuple, Union 22 | import click 23 | import numpy as np 24 | import PIL.Image 25 | from tqdm import tqdm 26 | 27 | if not hasattr(PIL.Image, 'Resampling'): # Pillow<9.0 28 | PIL.Image.Resampling = PIL.Image 29 | 30 | #---------------------------------------------------------------------------- 31 | # Parse a 'M,N' or 'MxN' integer tuple. 32 | # Example: '4x2' returns (4,2) 33 | 34 | def parse_tuple(s: str) -> Tuple[int, int]: 35 | m = re.match(r'^(\d+)[x,](\d+)$', s) 36 | if m: 37 | return int(m.group(1)), int(m.group(2)) 38 | raise click.ClickException(f'cannot parse tuple {s}') 39 | 40 | #---------------------------------------------------------------------------- 41 | 42 | def maybe_min(a: int, b: Optional[int]) -> int: 43 | if b is not None: 44 | return min(a, b) 45 | return a 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def file_ext(name: Union[str, Path]) -> str: 50 | return str(name).split('.')[-1] 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | def is_image_ext(fname: Union[str, Path]) -> bool: 55 | ext = file_ext(fname).lower() 56 | return f'.{ext}' in PIL.Image.EXTENSION 57 | 58 | #---------------------------------------------------------------------------- 59 | 60 | def open_image_folder(source_dir, *, max_images: Optional[int]): 61 | input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)] 62 | arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images} 63 | max_idx = maybe_min(len(input_images), max_images) 64 | 65 | # Load labels. 66 | labels = dict() 67 | meta_fname = os.path.join(source_dir, 'dataset.json') 68 | if os.path.isfile(meta_fname): 69 | with open(meta_fname, 'r') as file: 70 | data = json.load(file)['labels'] 71 | if data is not None: 72 | labels = {x[0]: x[1] for x in data} 73 | 74 | # No labels available => determine from top-level directory names. 75 | if len(labels) == 0: 76 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()} 77 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))} 78 | if len(toplevel_indices) > 1: 79 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()} 80 | 81 | def iterate_images(): 82 | for idx, fname in enumerate(input_images): 83 | img = np.array(PIL.Image.open(fname)) 84 | yield dict(img=img, label=labels.get(arch_fnames.get(fname))) 85 | if idx >= max_idx - 1: 86 | break 87 | return max_idx, iterate_images() 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | def open_image_zip(source, *, max_images: Optional[int]): 92 | with zipfile.ZipFile(source, mode='r') as z: 93 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] 94 | max_idx = maybe_min(len(input_images), max_images) 95 | 96 | # Load labels. 
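#----------------------------------------------------------------------------
# Illustrative note: the dataset.json read below is expected to have the
# structure {"labels": [["00000/img00000000.png", 6], ...]} (see the docstring
# of main() further down in this file). The dict comprehension used by the
# loaders simply re-keys that list by filename, e.g.
#
#     data = [["00000/img00000000.png", 6], ["00000/img00000001.png", 9]]
#     {x[0]: x[1] for x in data}
#     # -> {'00000/img00000000.png': 6, '00000/img00000001.png': 9}
#----------------------------------------------------------------------------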
97 | labels = dict() 98 | if 'dataset.json' in z.namelist(): 99 | with z.open('dataset.json', 'r') as file: 100 | data = json.load(file)['labels'] 101 | if data is not None: 102 | labels = {x[0]: x[1] for x in data} 103 | 104 | def iterate_images(): 105 | with zipfile.ZipFile(source, mode='r') as z: 106 | for idx, fname in enumerate(input_images): 107 | with z.open(fname, 'r') as file: 108 | img = np.array(PIL.Image.open(file)) 109 | yield dict(img=img, label=labels.get(fname)) 110 | if idx >= max_idx - 1: 111 | break 112 | return max_idx, iterate_images() 113 | 114 | #---------------------------------------------------------------------------- 115 | 116 | def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]): 117 | import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python 118 | import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb 119 | 120 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 121 | max_idx = maybe_min(txn.stat()['entries'], max_images) 122 | 123 | def iterate_images(): 124 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 125 | for idx, (_key, value) in enumerate(txn.cursor()): 126 | try: 127 | try: 128 | img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) 129 | if img is None: 130 | raise IOError('cv2.imdecode failed') 131 | img = img[:, :, ::-1] # BGR => RGB 132 | except IOError: 133 | img = np.array(PIL.Image.open(io.BytesIO(value))) 134 | yield dict(img=img, label=None) 135 | if idx >= max_idx - 1: 136 | break 137 | except: 138 | print(sys.exc_info()[1]) 139 | 140 | return max_idx, iterate_images() 141 | 142 | #---------------------------------------------------------------------------- 143 | 144 | def open_cifar10(tarball: str, *, max_images: Optional[int]): 145 | images = [] 146 | labels = [] 147 | 148 | with tarfile.open(tarball, 'r:gz') as tar: 149 | for batch in range(1, 6): 150 | member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}') 151 | with tar.extractfile(member) as file: 152 | data = pickle.load(file, encoding='latin1') 153 | images.append(data['data'].reshape(-1, 3, 32, 32)) 154 | labels.append(data['labels']) 155 | 156 | images = np.concatenate(images) 157 | labels = np.concatenate(labels) 158 | images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC 159 | assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 160 | assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] 161 | assert np.min(images) == 0 and np.max(images) == 255 162 | assert np.min(labels) == 0 and np.max(labels) == 9 163 | 164 | max_idx = maybe_min(len(images), max_images) 165 | 166 | def iterate_images(): 167 | for idx, img in enumerate(images): 168 | yield dict(img=img, label=int(labels[idx])) 169 | if idx >= max_idx - 1: 170 | break 171 | 172 | return max_idx, iterate_images() 173 | 174 | #---------------------------------------------------------------------------- 175 | 176 | def open_mnist(images_gz: str, *, max_images: Optional[int]): 177 | labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') 178 | assert labels_gz != images_gz 179 | images = [] 180 | labels = [] 181 | 182 | with gzip.open(images_gz, 'rb') as f: 183 | images = np.frombuffer(f.read(), np.uint8, offset=16) 184 | with gzip.open(labels_gz, 'rb') as f: 185 | labels = np.frombuffer(f.read(), np.uint8, offset=8) 186 | 187 | images = images.reshape(-1, 28, 28) 188 | images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) 
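#----------------------------------------------------------------------------
# Note on the padding above: raw MNIST digits are 28x28, which is not a power
# of two. Padding two zero-valued pixels on every side yields 32x32, e.g.
#
#     np.pad(np.zeros([60000, 28, 28], np.uint8),
#            [(0, 0), (2, 2), (2, 2)], 'constant').shape == (60000, 32, 32)
#
# so the images pass the square / power-of-two checks enforced in main().
#----------------------------------------------------------------------------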
189 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 190 | assert labels.shape == (60000,) and labels.dtype == np.uint8 191 | assert np.min(images) == 0 and np.max(images) == 255 192 | assert np.min(labels) == 0 and np.max(labels) == 9 193 | 194 | max_idx = maybe_min(len(images), max_images) 195 | 196 | def iterate_images(): 197 | for idx, img in enumerate(images): 198 | yield dict(img=img, label=int(labels[idx])) 199 | if idx >= max_idx - 1: 200 | break 201 | 202 | return max_idx, iterate_images() 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def make_transform( 207 | transform: Optional[str], 208 | output_width: Optional[int], 209 | output_height: Optional[int] 210 | ) -> Callable[[np.ndarray], Optional[np.ndarray]]: 211 | def scale(width, height, img): 212 | w = img.shape[1] 213 | h = img.shape[0] 214 | if width == w and height == h: 215 | return img 216 | img = PIL.Image.fromarray(img) 217 | ww = width if width is not None else w 218 | hh = height if height is not None else h 219 | img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS) 220 | return np.array(img) 221 | 222 | def center_crop(width, height, img): 223 | crop = np.min(img.shape[:2]) 224 | img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] 225 | if img.ndim == 2: 226 | img = img[:, :, np.newaxis].repeat(3, axis=2) 227 | img = PIL.Image.fromarray(img, 'RGB') 228 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 229 | return np.array(img) 230 | 231 | def center_crop_wide(width, height, img): 232 | ch = int(np.round(width * img.shape[0] / img.shape[1])) 233 | if img.shape[1] < width or ch < height: 234 | return None 235 | 236 | img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] 237 | if img.ndim == 2: 238 | img = img[:, :, np.newaxis].repeat(3, axis=2) 239 | img = PIL.Image.fromarray(img, 'RGB') 240 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 241 | img = np.array(img) 242 | 243 | canvas = np.zeros([width, width, 3], dtype=np.uint8) 244 | canvas[(width - height) // 2 : (width + height) // 2, :] = img 245 | return canvas 246 | 247 | if transform is None: 248 | return functools.partial(scale, output_width, output_height) 249 | if transform == 'center-crop': 250 | if output_width is None or output_height is None: 251 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + 'transform') 252 | return functools.partial(center_crop, output_width, output_height) 253 | if transform == 'center-crop-wide': 254 | if output_width is None or output_height is None: 255 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform') 256 | return functools.partial(center_crop_wide, output_width, output_height) 257 | assert False, 'unknown transform' 258 | 259 | #---------------------------------------------------------------------------- 260 | 261 | def open_dataset(source, *, max_images: Optional[int]): 262 | if os.path.isdir(source): 263 | if source.rstrip('/').endswith('_lmdb'): 264 | return open_lmdb(source, max_images=max_images) 265 | else: 266 | return open_image_folder(source, max_images=max_images) 267 | elif os.path.isfile(source): 268 | if os.path.basename(source) == 'cifar-10-python.tar.gz': 269 | return open_cifar10(source, max_images=max_images) 270 | elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': 271 | return open_mnist(source, max_images=max_images) 
272 |         elif file_ext(source) == 'zip':
273 |             return open_image_zip(source, max_images=max_images)
274 |         else:
275 |             assert False, 'unknown archive type'
276 |     else:
277 |         raise click.ClickException(f'Missing input file or directory: {source}')
278 | 
279 | #----------------------------------------------------------------------------
280 | 
281 | def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
282 |     dest_ext = file_ext(dest)
283 | 
284 |     if dest_ext == 'zip':
285 |         if os.path.dirname(dest) != '':
286 |             os.makedirs(os.path.dirname(dest), exist_ok=True)
287 |         zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
288 |         def zip_write_bytes(fname: str, data: Union[bytes, str]):
289 |             zf.writestr(fname, data)
290 |         return '', zip_write_bytes, zf.close
291 |     else:
292 |         # If the output folder already exists, check that it is
293 |         # empty.
294 |         #
295 |         # Note: creating the output directory is not strictly
296 |         # necessary as folder_write_bytes() also mkdirs, but it's better
297 |         # to give an error message earlier in case the dest folder
298 |         # somehow cannot be created.
299 |         if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
300 |             raise click.ClickException('--dest folder must be empty')
301 |         os.makedirs(dest, exist_ok=True)
302 | 
303 |         def folder_write_bytes(fname: str, data: Union[bytes, str]):
304 |             os.makedirs(os.path.dirname(fname), exist_ok=True)
305 |             with open(fname, 'wb') as fout:
306 |                 if isinstance(data, str):
307 |                     data = data.encode('utf8')
308 |                 fout.write(data)
309 |         return dest, folder_write_bytes, lambda: None
310 | 
311 | #----------------------------------------------------------------------------
312 | 
313 | @click.command()
314 | @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
315 | @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
316 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
317 | @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
318 | @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
319 | 
320 | def main(
321 |     source: str,
322 |     dest: str,
323 |     max_images: Optional[int],
324 |     transform: Optional[str],
325 |     resolution: Optional[Tuple[int, int]]
326 | ):
327 |     """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
328 | 
329 |     The input dataset format is guessed from the --source argument:
330 | 
331 |     \b
332 |     --source *_lmdb/                    Load LSUN dataset
333 |     --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
334 |     --source train-images-idx3-ubyte.gz Load MNIST dataset
335 |     --source path/                      Recursively load all images from path/
336 |     --source dataset.zip                Recursively load all images from dataset.zip
337 | 
338 |     Specifying the output format and path:
339 | 
340 |     \b
341 |     --dest /path/to/dir                 Save output files under /path/to/dir
342 |     --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip
343 | 
344 |     The output dataset format can be either an image folder or an uncompressed zip archive.
345 |     Zip archives make it easier to move datasets around file servers and clusters, and may
346 |     offer better training performance on network file systems.
347 | 
348 |     Images within the dataset archive will be stored as uncompressed PNG.
349 |     Uncompressed PNGs can be efficiently decoded in the training loop.
350 | 
351 |     Class labels are stored in a file called 'dataset.json' that is stored at the
352 |     dataset root folder. This file has the following structure:
353 | 
354 |     \b
355 |     {
356 |         "labels": [
357 |             ["00000/img00000000.png",6],
358 |             ["00000/img00000001.png",9],
359 |             ... repeated for every image in the dataset
360 |             ["00049/img00049999.png",1]
361 |         ]
362 |     }
363 | 
364 |     If the 'dataset.json' file cannot be found, class labels are determined from
365 |     top-level directory names.
366 | 
367 |     Image scale/crop and resolution requirements:
368 | 
369 |     Output images must be square-shaped and they must all have the same power-of-two
370 |     dimensions.
371 | 
372 |     To scale arbitrary input image sizes to a specific width and height, use the
373 |     --resolution option. Output resolution will be either the original
374 |     input resolution (if --resolution was not specified) or the one specified with
375 |     the --resolution option.
376 | 
377 |     Use the --transform=center-crop or --transform=center-crop-wide options to apply a
378 |     center crop transform on the input image. These options should be used with the
379 |     --resolution option. For example:
380 | 
381 |     \b
382 |     python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
383 |         --transform=center-crop-wide --resolution=512x384
384 |     """
385 | 
386 |     PIL.Image.init()
387 | 
388 |     if dest == '':
389 |         raise click.ClickException('--dest output filename or directory must not be an empty string')
390 | 
391 |     num_files, input_iter = open_dataset(source, max_images=max_images)
392 |     archive_root_dir, save_bytes, close_dest = open_dest(dest)
393 | 
394 |     if resolution is None: resolution = (None, None)
395 |     transform_image = make_transform(transform, *resolution)
396 | 
397 |     dataset_attrs = None
398 | 
399 |     labels = []
400 |     for idx, image in tqdm(enumerate(input_iter), total=num_files):
401 |         idx_str = f'{idx:08d}'
402 |         archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
403 | 
404 |         # Apply crop and resize.
405 |         img = transform_image(image['img'])
406 |         if img is None:
407 |             continue
408 | 
409 |         # Error check to require uniform image attributes across
410 |         # the whole dataset.
411 |         channels = img.shape[2] if img.ndim == 3 else 1
412 |         cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
413 |         if dataset_attrs is None:
414 |             dataset_attrs = cur_image_attrs
415 |             width = dataset_attrs['width']
416 |             height = dataset_attrs['height']
417 |             if width != height:
418 |                 raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
419 |             if dataset_attrs['channels'] not in [1, 3]:
420 |                 raise click.ClickException('Input images must be stored as RGB or grayscale')
421 |             if width != 2 ** int(np.floor(np.log2(width))):
422 |                 raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
423 |         elif dataset_attrs != cur_image_attrs:
424 |             err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
425 |             raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
426 | 
427 |         # Save the image as an uncompressed PNG.
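#----------------------------------------------------------------------------
# Note on the save path below: compress_level=0 writes the PNG without
# deflate compression, trading larger files for fast decoding at training
# time. The archive layout f'{idx_str[:5]}/img{idx_str}.png' groups images
# into subfolders of 1000, e.g. index 1234 is stored as
# '00001/img00001234.png'.
#----------------------------------------------------------------------------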
428 | img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels]) 429 | image_bits = io.BytesIO() 430 | img.save(image_bits, format='png', compress_level=0, optimize=False) 431 | save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer()) 432 | labels.append([archive_fname, image['label']] if image['label'] is not None else None) 433 | 434 | metadata = {'labels': labels if all(x is not None for x in labels) else None} 435 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) 436 | close_dest() 437 | 438 | #---------------------------------------------------------------------------- 439 | 440 | if __name__ == "__main__": 441 | main() 442 | 443 | #---------------------------------------------------------------------------- 444 | -------------------------------------------------------------------------------- /training/augment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Augmentation pipeline used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models". 10 | Built around the same concepts that were originally proposed in the paper 11 | "Training Generative Adversarial Networks with Limited Data".""" 12 | 13 | import numpy as np 14 | import torch 15 | from torch_utils import persistence 16 | from torch_utils import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | # Coefficients of various wavelet decomposition low-pass filters. 
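#----------------------------------------------------------------------------
# Illustrative check (hypothetical helper, not called by the pipeline): the
# entries below are 1D low-pass decomposition filters of orthogonal wavelets,
# so the coefficients of each filter sum to sqrt(2) and their squares sum
# to 1. This is what makes the filtered 2x up/downsampling further down
# approximately energy-preserving.

def _check_wavelet_filters():
    for name, taps in wavelets.items():
        taps = np.asarray(taps, dtype=np.float64)
        assert np.isclose(taps.sum(), np.sqrt(2), atol=1e-4), name
        assert np.isclose(np.square(taps).sum(), 1.0, atol=1e-4), name

#----------------------------------------------------------------------------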
20 | 21 | wavelets = { 22 | 'haar': [0.7071067811865476, 0.7071067811865476], 23 | 'db1': [0.7071067811865476, 0.7071067811865476], 24 | 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 25 | 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 26 | 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], 27 | 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], 28 | 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], 29 | 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], 30 | 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], 31 | 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 32 | 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 33 | 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], 34 | 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], 35 | 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], 36 | 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], 37 | 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], 38 | } 39 | 40 | #---------------------------------------------------------------------------- 41 | # Helpers for constructing transformation matrices. 
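#----------------------------------------------------------------------------
# Minimal sketch (hypothetical helper): composing the helpers defined below.
# It assumes misc.constant returns CPU float32 tensors when no device is
# given, which is how matrix() is used throughout this file.

def _example_compose_affine():
    # Forward transform: scale by 2, then translate by (3, 5) pixels.
    G = translate2d(3.0, 5.0) @ scale2d(2.0, 2.0)
    # The *_inv helpers build the inverse factors directly; composing them in
    # reverse order undoes the forward transform.
    G_inv = scale2d_inv(2.0, 2.0) @ translate2d_inv(3.0, 5.0)
    assert torch.allclose(G @ G_inv, torch.eye(3, dtype=G.dtype), atol=1e-6)

#----------------------------------------------------------------------------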
42 | 43 | def matrix(*rows, device=None): 44 | assert all(len(row) == len(rows[0]) for row in rows) 45 | elems = [x for row in rows for x in row] 46 | ref = [x for x in elems if isinstance(x, torch.Tensor)] 47 | if len(ref) == 0: 48 | return misc.constant(np.asarray(rows), device=device) 49 | assert device is None or device == ref[0].device 50 | elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] 51 | return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) 52 | 53 | def translate2d(tx, ty, **kwargs): 54 | return matrix( 55 | [1, 0, tx], 56 | [0, 1, ty], 57 | [0, 0, 1], 58 | **kwargs) 59 | 60 | def translate3d(tx, ty, tz, **kwargs): 61 | return matrix( 62 | [1, 0, 0, tx], 63 | [0, 1, 0, ty], 64 | [0, 0, 1, tz], 65 | [0, 0, 0, 1], 66 | **kwargs) 67 | 68 | def scale2d(sx, sy, **kwargs): 69 | return matrix( 70 | [sx, 0, 0], 71 | [0, sy, 0], 72 | [0, 0, 1], 73 | **kwargs) 74 | 75 | def scale3d(sx, sy, sz, **kwargs): 76 | return matrix( 77 | [sx, 0, 0, 0], 78 | [0, sy, 0, 0], 79 | [0, 0, sz, 0], 80 | [0, 0, 0, 1], 81 | **kwargs) 82 | 83 | def rotate2d(theta, **kwargs): 84 | return matrix( 85 | [torch.cos(theta), torch.sin(-theta), 0], 86 | [torch.sin(theta), torch.cos(theta), 0], 87 | [0, 0, 1], 88 | **kwargs) 89 | 90 | def rotate3d(v, theta, **kwargs): 91 | vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] 92 | s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c 93 | return matrix( 94 | [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], 95 | [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], 96 | [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], 97 | [0, 0, 0, 1], 98 | **kwargs) 99 | 100 | def translate2d_inv(tx, ty, **kwargs): 101 | return translate2d(-tx, -ty, **kwargs) 102 | 103 | def scale2d_inv(sx, sy, **kwargs): 104 | return scale2d(1 / sx, 1 / sy, **kwargs) 105 | 106 | def rotate2d_inv(theta, **kwargs): 107 | return rotate2d(-theta, **kwargs) 108 | 109 | #---------------------------------------------------------------------------- 110 | # Augmentation pipeline main class. 111 | # All augmentations are disabled by default; individual augmentations can 112 | # be enabled by setting their probability multipliers to 1. 113 | 114 | @persistence.persistent_class 115 | class AugmentPipe: 116 | def __init__(self, p=1, 117 | xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125, 118 | scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125, 119 | brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, 120 | ): 121 | super().__init__() 122 | self.p = float(p) # Overall multiplier for augmentation probability. 123 | 124 | # Pixel blitting. 125 | self.xflip = float(xflip) # Probability multiplier for x-flip. 126 | self.yflip = float(yflip) # Probability multiplier for y-flip. 127 | self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation. 128 | self.translate_int = float(translate_int) # Probability multiplier for integer translation. 129 | self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions. 130 | 131 | # Geometric transformations. 132 | self.scale = float(scale) # Probability multiplier for isotropic scaling. 133 | self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation. 
134 |         self.aniso = float(aniso)                            # Probability multiplier for anisotropic scaling.
135 |         self.translate_frac = float(translate_frac)          # Probability multiplier for fractional translation.
136 |         self.scale_std = float(scale_std)                    # Log2 standard deviation of isotropic scaling.
137 |         self.rotate_frac_max = float(rotate_frac_max)        # Range of fractional rotation, 1 = full circle.
138 |         self.aniso_std = float(aniso_std)                    # Log2 standard deviation of anisotropic scaling.
139 |         self.aniso_rotate_prob = float(aniso_rotate_prob)    # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
140 |         self.translate_frac_std = float(translate_frac_std)  # Standard deviation of fractional translation, relative to image dimensions.
141 | 
142 |         # Color transformations.
143 |         self.brightness = float(brightness)                  # Probability multiplier for brightness.
144 |         self.contrast = float(contrast)                      # Probability multiplier for contrast.
145 |         self.lumaflip = float(lumaflip)                      # Probability multiplier for luma flip.
146 |         self.hue = float(hue)                                # Probability multiplier for hue rotation.
147 |         self.saturation = float(saturation)                  # Probability multiplier for saturation.
148 |         self.brightness_std = float(brightness_std)          # Standard deviation of brightness.
149 |         self.contrast_std = float(contrast_std)              # Log2 standard deviation of contrast.
150 |         self.hue_max = float(hue_max)                        # Range of hue rotation, 1 = full circle.
151 |         self.saturation_std = float(saturation_std)          # Log2 standard deviation of saturation.
152 | 
153 |     def __call__(self, images):
154 |         N, C, H, W = images.shape
155 |         device = images.device
156 |         labels = [torch.zeros([images.shape[0], 0], device=device)]
157 | 
158 |         # ---------------
159 |         # Pixel blitting.
160 |         # ---------------
161 | 
162 |         if self.xflip > 0:
163 |             w = torch.randint(2, [N, 1, 1, 1], device=device)
164 |             w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w))
165 |             images = torch.where(w == 1, images.flip(3), images)
166 |             labels += [w]
167 | 
168 |         if self.yflip > 0:
169 |             w = torch.randint(2, [N, 1, 1, 1], device=device)
170 |             w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w))
171 |             images = torch.where(w == 1, images.flip(2), images)
172 |             labels += [w]
173 | 
174 |         if self.rotate_int > 0:
175 |             w = torch.randint(4, [N, 1, 1, 1], device=device)
176 |             w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w))
177 |             images = torch.where((w == 1) | (w == 2), images.flip(3), images)
178 |             images = torch.where((w == 2) | (w == 3), images.flip(2), images)
179 |             images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images)
180 |             labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)]
181 | 
182 |         if self.translate_int > 0:
183 |             w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1
184 |             w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w))
185 |             tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64)
186 |             ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64)
187 |             b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij')
188 |             x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs()
189 |             y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs()
190 |             images = images.flatten()[(((b * C) + c) * H + y) * W + x]
191 |             labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)]
192 | 
193 |         # 
------------------------------------------------ 194 | # Select parameters for geometric transformations. 195 | # ------------------------------------------------ 196 | 197 | I_3 = torch.eye(3, device=device) 198 | G_inv = I_3 199 | 200 | if self.scale > 0: 201 | w = torch.randn([N], device=device) 202 | w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w)) 203 | s = w.mul(self.scale_std).exp2() 204 | G_inv = G_inv @ scale2d_inv(s, s) 205 | labels += [w] 206 | 207 | if self.rotate_frac > 0: 208 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max) 209 | w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w)) 210 | G_inv = G_inv @ rotate2d_inv(-w) 211 | labels += [w.cos() - 1, w.sin()] 212 | 213 | if self.aniso > 0: 214 | w = torch.randn([N], device=device) 215 | r = (torch.rand([N], device=device) * 2 - 1) * np.pi 216 | w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w)) 217 | r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r)) 218 | s = w.mul(self.aniso_std).exp2() 219 | G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r) 220 | labels += [w * r.cos(), w * r.sin()] 221 | 222 | if self.translate_frac > 0: 223 | w = torch.randn([2, N], device=device) 224 | w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w)) 225 | G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std)) 226 | labels += [w[0], w[1]] 227 | 228 | # ---------------------------------- 229 | # Execute geometric transformations. 230 | # ---------------------------------- 231 | 232 | if G_inv is not I_3: 233 | cx = (W - 1) / 2 234 | cy = (H - 1) / 2 235 | cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] 236 | cp = G_inv @ cp.t() # [batch, xyz, idx] 237 | Hz = np.asarray(wavelets['sym6'], dtype=np.float32) 238 | Hz_pad = len(Hz) // 4 239 | margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] 240 | margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] 241 | margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) 242 | margin = margin.max(misc.constant([0, 0] * 2, device=device)) 243 | margin = margin.min(misc.constant([W - 1, H - 1] * 2, device=device)) 244 | mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) 245 | 246 | # Pad image and adjust origin. 247 | images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') 248 | G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv 249 | 250 | # Upsample. 
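#----------------------------------------------------------------------------
# What the next few lines do: the padded image is upsampled 2x by
# interleaving zeros between pixels and then filtering rows and columns with
# the time-reversed 'sym6' low-pass filter (the standard wavelet
# reconstruction step). The scale2d/translate2d adjustments that follow
# re-express G_inv in the coordinates of this doubled-resolution grid,
# including the half-pixel shift of the origin, so that the affine warp is
# applied on the upsampled image and then filtered back down, which reduces
# aliasing from the geometric transform.
#----------------------------------------------------------------------------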
251 | conv_weight = misc.constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) 252 | conv_pad = (len(Hz) + 1) // 2 253 | images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1] 254 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad]) 255 | images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :] 256 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0]) 257 | G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) 258 | G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) 259 | 260 | # Execute transformation. 261 | shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2] 262 | G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) 263 | grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) 264 | images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False) 265 | 266 | # Downsample and crop. 267 | conv_weight = misc.constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) 268 | conv_pad = (len(Hz) - 1) // 2 269 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad] 270 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :] 271 | 272 | # -------------------------------------------- 273 | # Select parameters for color transformations. 
274 | # -------------------------------------------- 275 | 276 | I_4 = torch.eye(4, device=device) 277 | M = I_4 278 | luma_axis = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) 279 | 280 | if self.brightness > 0: 281 | w = torch.randn([N], device=device) 282 | w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w)) 283 | b = w * self.brightness_std 284 | M = translate3d(b, b, b) @ M 285 | labels += [w] 286 | 287 | if self.contrast > 0: 288 | w = torch.randn([N], device=device) 289 | w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w)) 290 | c = w.mul(self.contrast_std).exp2() 291 | M = scale3d(c, c, c) @ M 292 | labels += [w] 293 | 294 | if self.lumaflip > 0: 295 | w = torch.randint(2, [N, 1, 1], device=device) 296 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w)) 297 | M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M 298 | labels += [w] 299 | 300 | if self.hue > 0: 301 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max) 302 | w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w)) 303 | M = rotate3d(luma_axis, w) @ M 304 | labels += [w.cos() - 1, w.sin()] 305 | 306 | if self.saturation > 0: 307 | w = torch.randn([N, 1, 1], device=device) 308 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w)) 309 | M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M 310 | labels += [w] 311 | 312 | # ------------------------------ 313 | # Execute color transformations. 314 | # ------------------------------ 315 | 316 | if M is not I_4: 317 | images = images.reshape([N, C, H * W]) 318 | if C == 3: 319 | images = M[:, :3, :3] @ images + M[:, :3, 3:] 320 | elif C == 1: 321 | M = M[:, :3, :].mean(dim=1, keepdims=True) 322 | images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:] 323 | else: 324 | raise ValueError('Image must be RGB (3 channels) or L (1 channel)') 325 | images = images.reshape([N, C, H, W]) 326 | 327 | labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1) 328 | return images, labels 329 | 330 | #---------------------------------------------------------------------------- 331 | -------------------------------------------------------------------------------- /training/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import torch 18 | from torch_utils import persistence 19 | from torch.nn.functional import silu 20 | 21 | #---------------------------------------------------------------------------- 22 | # Unified routine for initializing weights and biases. 
23 | 24 | def weight_init(shape, mode, fan_in, fan_out): 25 | if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) 26 | if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) 27 | if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) 28 | if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape) 29 | raise ValueError(f'Invalid init mode "{mode}"') 30 | 31 | #---------------------------------------------------------------------------- 32 | # Fully-connected layer. 33 | 34 | @persistence.persistent_class 35 | class Linear(torch.nn.Module): 36 | def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0): 37 | super().__init__() 38 | self.in_features = in_features 39 | self.out_features = out_features 40 | init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) 41 | self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) 42 | self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None 43 | 44 | def forward(self, x): 45 | x = x @ self.weight.to(x.dtype).t() 46 | if self.bias is not None: 47 | x = x.add_(self.bias.to(x.dtype)) 48 | return x 49 | 50 | #---------------------------------------------------------------------------- 51 | # Convolutional layer with optional up/downsampling. 52 | 53 | @persistence.persistent_class 54 | class Conv2d(torch.nn.Module): 55 | def __init__(self, 56 | in_channels, out_channels, kernel, bias=True, up=False, down=False, 57 | resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0, 58 | ): 59 | assert not (up and down) 60 | super().__init__() 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | self.up = up 64 | self.down = down 65 | self.fused_resample = fused_resample 66 | init_kwargs = dict(mode=init_mode, fan_in=in_channels*kernel*kernel, fan_out=out_channels*kernel*kernel) 67 | self.weight = torch.nn.Parameter(weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight) if kernel else None 68 | self.bias = torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None 69 | f = torch.as_tensor(resample_filter, dtype=torch.float32) 70 | f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() 71 | self.register_buffer('resample_filter', f if up or down else None) 72 | 73 | def forward(self, x): 74 | w = self.weight.to(x.dtype) if self.weight is not None else None 75 | b = self.bias.to(x.dtype) if self.bias is not None else None 76 | f = self.resample_filter.to(x.dtype) if self.resample_filter is not None else None 77 | w_pad = w.shape[-1] // 2 if w is not None else 0 78 | f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 79 | 80 | if self.fused_resample and self.up and w is not None: 81 | x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0)) 82 | x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) 83 | elif self.fused_resample and self.down and w is not None: 84 | x = torch.nn.functional.conv2d(x, w, padding=w_pad+f_pad) 85 | x = torch.nn.functional.conv2d(x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2) 86 | else: 87 | if self.up: 88 | x = torch.nn.functional.conv_transpose2d(x, 
f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) 89 | if self.down: 90 | x = torch.nn.functional.conv2d(x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) 91 | if w is not None: 92 | x = torch.nn.functional.conv2d(x, w, padding=w_pad) 93 | if b is not None: 94 | x = x.add_(b.reshape(1, -1, 1, 1)) 95 | return x 96 | 97 | #---------------------------------------------------------------------------- 98 | # Group normalization. 99 | 100 | @persistence.persistent_class 101 | class GroupNorm(torch.nn.Module): 102 | def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5): 103 | super().__init__() 104 | self.num_groups = min(num_groups, num_channels // min_channels_per_group) 105 | self.eps = eps 106 | self.weight = torch.nn.Parameter(torch.ones(num_channels)) 107 | self.bias = torch.nn.Parameter(torch.zeros(num_channels)) 108 | 109 | def forward(self, x): 110 | x = torch.nn.functional.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps) 111 | return x 112 | 113 | #---------------------------------------------------------------------------- 114 | # Attention weight computation, i.e., softmax(Q^T * K). 115 | # Performs all computation using FP32, but uses the original datatype for 116 | # inputs/outputs/gradients to conserve memory. 117 | 118 | class AttentionOp(torch.autograd.Function): 119 | @staticmethod 120 | def forward(ctx, q, k): 121 | w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype) 122 | ctx.save_for_backward(q, k, w) 123 | return w 124 | 125 | @staticmethod 126 | def backward(ctx, dw): 127 | q, k, w = ctx.saved_tensors 128 | db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32) 129 | dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1]) 130 | dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1]) 131 | return dq, dk 132 | 133 | #---------------------------------------------------------------------------- 134 | # Unified U-Net block with optional up/downsampling and self-attention. 135 | # Represents the union of all features employed by the DDPM++, NCSN++, and 136 | # ADM architectures. 
137 | 138 | @persistence.persistent_class 139 | class UNetBlock(torch.nn.Module): 140 | def __init__(self, 141 | in_channels, out_channels, emb_channels, up=False, down=False, attention=False, 142 | num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5, 143 | resample_filter=[1,1], resample_proj=False, adaptive_scale=True, 144 | init=dict(), init_zero=dict(init_weight=0), init_attn=None, 145 | ): 146 | super().__init__() 147 | self.in_channels = in_channels 148 | self.out_channels = out_channels 149 | self.emb_channels = emb_channels 150 | self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head 151 | self.dropout = dropout 152 | self.skip_scale = skip_scale 153 | self.adaptive_scale = adaptive_scale 154 | 155 | self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) 156 | self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init) 157 | self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init) 158 | self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) 159 | self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero) 160 | 161 | self.skip = None 162 | if out_channels != in_channels or up or down: 163 | kernel = 1 if resample_proj or out_channels!= in_channels else 0 164 | self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init) 165 | 166 | if self.num_heads: 167 | self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) 168 | self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init)) 169 | self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero) 170 | 171 | def forward(self, x, emb): 172 | orig = x 173 | x = self.conv0(silu(self.norm0(x))) 174 | 175 | params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) 176 | if self.adaptive_scale: 177 | scale, shift = params.chunk(chunks=2, dim=1) 178 | x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) 179 | else: 180 | x = silu(self.norm1(x.add_(params))) 181 | 182 | x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training)) 183 | x = x.add_(self.skip(orig) if self.skip is not None else orig) 184 | x = x * self.skip_scale 185 | 186 | if self.num_heads: 187 | q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2) 188 | w = AttentionOp.apply(q, k) 189 | a = torch.einsum('nqk,nck->ncq', w, v) 190 | x = self.proj(a.reshape(*x.shape)).add_(x) 191 | x = x * self.skip_scale 192 | return x 193 | 194 | #---------------------------------------------------------------------------- 195 | # Timestep embedding used in the DDPM++ and ADM architectures. 
196 | 197 | @persistence.persistent_class 198 | class PositionalEmbedding(torch.nn.Module): 199 | def __init__(self, num_channels, max_positions=10000, endpoint=False): 200 | super().__init__() 201 | self.num_channels = num_channels 202 | self.max_positions = max_positions 203 | self.endpoint = endpoint 204 | 205 | def forward(self, x): 206 | freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) 207 | freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) 208 | freqs = (1 / self.max_positions) ** freqs 209 | x = x.ger(freqs.to(x.dtype)) 210 | x = torch.cat([x.cos(), x.sin()], dim=1) 211 | return x 212 | 213 | #---------------------------------------------------------------------------- 214 | # Timestep embedding used in the NCSN++ architecture. 215 | 216 | @persistence.persistent_class 217 | class FourierEmbedding(torch.nn.Module): 218 | def __init__(self, num_channels, scale=16): 219 | super().__init__() 220 | self.register_buffer('freqs', torch.randn(num_channels // 2) * scale) 221 | 222 | def forward(self, x): 223 | x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) 224 | x = torch.cat([x.cos(), x.sin()], dim=1) 225 | return x 226 | 227 | #---------------------------------------------------------------------------- 228 | # Reimplementation of the DDPM++ and NCSN++ architectures from the paper 229 | # "Score-Based Generative Modeling through Stochastic Differential 230 | # Equations". Equivalent to the original implementation by Song et al., 231 | # available at https://github.com/yang-song/score_sde_pytorch 232 | 233 | @persistence.persistent_class 234 | class SongUNet(torch.nn.Module): 235 | def __init__(self, 236 | img_resolution, # Image resolution at input/output. 237 | in_channels, # Number of color channels at input. 238 | out_channels, # Number of color channels at output. 239 | label_dim = 0, # Number of class labels, 0 = unconditional. 240 | augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation. 241 | 242 | model_channels = 128, # Base multiplier for the number of channels. 243 | channel_mult = [1,2,2,2], # Per-resolution multipliers for the number of channels. 244 | channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector. 245 | num_blocks = 4, # Number of residual blocks per resolution. 246 | attn_resolutions = [16], # List of resolutions with self-attention. 247 | dropout = 0.10, # Dropout probability of intermediate activations. 248 | label_dropout = 0, # Dropout probability of class labels for classifier-free guidance. 249 | 250 | embedding_type = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. 251 | channel_mult_noise = 1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++. 252 | encoder_type = 'standard', # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. 253 | decoder_type = 'standard', # Decoder architecture: 'standard' for both DDPM++ and NCSN++. 254 | resample_filter = [1,1], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. 
255 | ): 256 | assert embedding_type in ['fourier', 'positional'] 257 | assert encoder_type in ['standard', 'skip', 'residual'] 258 | assert decoder_type in ['standard', 'skip'] 259 | 260 | super().__init__() 261 | self.label_dropout = label_dropout 262 | emb_channels = model_channels * channel_mult_emb 263 | noise_channels = model_channels * channel_mult_noise 264 | init = dict(init_mode='xavier_uniform') 265 | init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5) 266 | init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2)) 267 | block_kwargs = dict( 268 | emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6, 269 | resample_filter=resample_filter, resample_proj=True, adaptive_scale=False, 270 | init=init, init_zero=init_zero, init_attn=init_attn, 271 | ) 272 | 273 | # Mapping. 274 | self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True) if embedding_type == 'positional' else FourierEmbedding(num_channels=noise_channels) 275 | self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None 276 | self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None 277 | self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init) 278 | self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init) 279 | 280 | # Encoder. 281 | self.enc = torch.nn.ModuleDict() 282 | cout = in_channels 283 | caux = in_channels 284 | for level, mult in enumerate(channel_mult): 285 | res = img_resolution >> level 286 | if level == 0: 287 | cin = cout 288 | cout = model_channels 289 | self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init) 290 | else: 291 | self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs) 292 | if encoder_type == 'skip': 293 | self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter) 294 | self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init) 295 | if encoder_type == 'residual': 296 | self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init) 297 | caux = cout 298 | for idx in range(num_blocks): 299 | cin = cout 300 | cout = model_channels * mult 301 | attn = (res in attn_resolutions) 302 | self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs) 303 | skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name] 304 | 305 | # Decoder. 
306 | self.dec = torch.nn.ModuleDict() 307 | for level, mult in reversed(list(enumerate(channel_mult))): 308 | res = img_resolution >> level 309 | if level == len(channel_mult) - 1: 310 | self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs) 311 | self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs) 312 | else: 313 | self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs) 314 | for idx in range(num_blocks + 1): 315 | cin = cout + skips.pop() 316 | cout = model_channels * mult 317 | attn = (idx == num_blocks and res in attn_resolutions) 318 | self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs) 319 | if decoder_type == 'skip' or level == 0: 320 | if decoder_type == 'skip' and level < len(channel_mult) - 1: 321 | self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter) 322 | self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6) 323 | self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero) 324 | 325 | def forward(self, x, noise_labels, class_labels, augment_labels=None): 326 | # Mapping. 327 | emb = self.map_noise(noise_labels) 328 | emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos 329 | if self.map_label is not None: 330 | tmp = class_labels 331 | if self.training and self.label_dropout: 332 | tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype) 333 | emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) 334 | if self.map_augment is not None and augment_labels is not None: 335 | emb = emb + self.map_augment(augment_labels) 336 | emb = silu(self.map_layer0(emb)) 337 | emb = silu(self.map_layer1(emb)) 338 | 339 | # Encoder. 340 | skips = [] 341 | aux = x 342 | for name, block in self.enc.items(): 343 | if 'aux_down' in name: 344 | aux = block(aux) 345 | elif 'aux_skip' in name: 346 | x = skips[-1] = x + block(aux) 347 | elif 'aux_residual' in name: 348 | x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) 349 | else: 350 | x = block(x, emb) if isinstance(block, UNetBlock) else block(x) 351 | skips.append(x) 352 | 353 | # Decoder. 354 | aux = None 355 | tmp = None 356 | for name, block in self.dec.items(): 357 | if 'aux_up' in name: 358 | aux = block(aux) 359 | elif 'aux_norm' in name: 360 | tmp = block(x) 361 | elif 'aux_conv' in name: 362 | tmp = block(silu(tmp)) 363 | aux = tmp if aux is None else tmp + aux 364 | else: 365 | if x.shape[1] != block.in_channels: 366 | x = torch.cat([x, skips.pop()], dim=1) 367 | x = block(x, emb) 368 | return aux 369 | 370 | 371 | #---------------------------------------------------------------------------- 372 | # Preconditioning corresponding to the variance preserving (VP) formulation 373 | # from the paper "Score-Based Generative Modeling through Stochastic 374 | # Differential Equations". 375 | 376 | @persistence.persistent_class 377 | class VPPrecond(torch.nn.Module): 378 | def __init__(self, 379 | img_resolution, # Image resolution. 380 | img_channels, # Number of color channels. 381 | label_dim = 0, # Number of class labels, 0 = unconditional. 382 | use_fp16 = False, # Execute the underlying model at FP16 precision? 383 | beta_d = 19.9, # Extent of the noise level schedule. 
384 | beta_min = 0.1, # Initial slope of the noise level schedule. 385 | M = 1000, # Original number of timesteps in the DDPM formulation. 386 | epsilon_t = 1e-5, # Minimum t-value used during training. 387 | model_type = 'SongUNet', # Class name of the underlying model. 388 | **model_kwargs, # Keyword arguments for the underlying model. 389 | ): 390 | super().__init__() 391 | self.img_resolution = img_resolution 392 | self.img_channels = img_channels 393 | self.label_dim = label_dim 394 | self.use_fp16 = use_fp16 395 | self.beta_d = beta_d 396 | self.beta_min = beta_min 397 | self.M = M 398 | self.epsilon_t = epsilon_t 399 | self.sigma_min = float(self.sigma(epsilon_t)) 400 | self.sigma_max = float(self.sigma(1)) 401 | self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) 402 | 403 | def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): 404 | x = x.to(torch.float32) 405 | sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) 406 | class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) 407 | dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 408 | 409 | c_skip = 1 410 | c_out = -sigma 411 | c_in = 1 / (sigma ** 2 + 1).sqrt() 412 | c_noise = (self.M - 1) * self.sigma_inv(sigma) 413 | 414 | F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) 415 | assert F_x.dtype == dtype 416 | D_x = c_skip * x + c_out * F_x.to(torch.float32) 417 | return D_x 418 | 419 | def sigma(self, t): 420 | t = torch.as_tensor(t) 421 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 422 | 423 | def sigma_inv(self, sigma): 424 | sigma = torch.as_tensor(sigma) 425 | return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d 426 | 427 | def round_sigma(self, sigma): 428 | return torch.as_tensor(sigma) 429 | 430 | #---------------------------------------------------------------------------- 431 | # Preconditioning corresponding to the variance exploding (VE) formulation 432 | # from the paper "Score-Based Generative Modeling through Stochastic 433 | # Differential Equations". 434 | 435 | @persistence.persistent_class 436 | class VEPrecond(torch.nn.Module): 437 | def __init__(self, 438 | img_resolution, # Image resolution. 439 | img_channels, # Number of color channels. 440 | label_dim = 0, # Number of class labels, 0 = unconditional. 441 | use_fp16 = False, # Execute the underlying model at FP16 precision? 442 | sigma_min = 0.02, # Minimum supported noise level. 443 | sigma_max = 100, # Maximum supported noise level. 444 | model_type = 'SongUNet', # Class name of the underlying model. 445 | **model_kwargs, # Keyword arguments for the underlying model. 
446 | ): 447 | super().__init__() 448 | self.img_resolution = img_resolution 449 | self.img_channels = img_channels 450 | self.label_dim = label_dim 451 | self.use_fp16 = use_fp16 452 | self.sigma_min = sigma_min 453 | self.sigma_max = sigma_max 454 | self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) 455 | 456 | def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): 457 | x = x.to(torch.float32) 458 | sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) 459 | class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) 460 | dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 461 | 462 | c_skip = 1 463 | c_out = sigma 464 | c_in = 1 465 | c_noise = (0.5 * sigma).log() 466 | 467 | F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) 468 | assert F_x.dtype == dtype 469 | D_x = c_skip * x + c_out * F_x.to(torch.float32) 470 | return D_x 471 | 472 | def round_sigma(self, sigma): 473 | return torch.as_tensor(sigma) 474 | 475 | #---------------------------------------------------------------------------- 476 | # Preconditioning corresponding to the EDM formulation 477 | # from the paper "Elucidating the Design Space of Diffusion-Based 478 | # Generative Models". 479 | 480 | @persistence.persistent_class 481 | class EDMPrecond(torch.nn.Module): 482 | def __init__(self, 483 | img_resolution, # Image resolution. 484 | img_channels, # Number of color channels. 485 | label_dim = 0, # Number of class labels, 0 = unconditional. 486 | use_fp16 = False, # Execute the underlying model at FP16 precision? 487 | sigma_min = 0, # Minimum supported noise level. 488 | sigma_max = float('inf'), # Maximum supported noise level. 489 | sigma_data = 0.5, # Expected standard deviation of the training data. 490 | model_type = 'SongUNet', # Class name of the underlying model. 491 | **model_kwargs, # Keyword arguments for the underlying model. 
492 | ): 493 | super().__init__() 494 | self.img_resolution = img_resolution 495 | self.img_channels = img_channels 496 | self.label_dim = label_dim 497 | self.use_fp16 = use_fp16 498 | self.sigma_min = sigma_min 499 | self.sigma_max = sigma_max 500 | self.sigma_data = sigma_data 501 | self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) 502 | 503 | def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): 504 | x = x.to(torch.float32) 505 | sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) 506 | class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) 507 | dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 508 | 509 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 510 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() 511 | c_in = 1 / (self.sigma_data ** 2 + (sigma) ** 2).sqrt() 512 | c_noise = sigma.log() / 4 513 | 514 | F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) 515 | assert F_x.dtype == dtype 516 | D_x = c_skip * x + c_out * F_x.to(torch.float32) 517 | return D_x 518 | 519 | def round_sigma(self, sigma): 520 | return torch.as_tensor(sigma) 521 | 522 | 523 | #---------------------------------------------------------------------------- 524 | # Fast Diffusion Model with EDM Preconditioning (EDM-FDM) 525 | # from the paper "Elucidating the Design Space of Diffusion-Based 526 | # Generative Models". 527 | 528 | @persistence.persistent_class 529 | class FDM_EDMPrecond(torch.nn.Module): 530 | def __init__(self, 531 | img_resolution, # Image resolution. 532 | img_channels, # Number of color channels. 533 | label_dim = 0, # Number of class labels, 0 = unconditional. 534 | use_fp16 = False, # Execute the underlying model at FP16 precision? 535 | sigma_min = 0.002, # Minimum supported noise level. 536 | sigma_max = 80.0, # Maximum supported noise level. 537 | sigma_data = 0.5, # Expected standard deviation of the training data. 538 | model_type = 'SongUNet', # Class name of the underlying model. 539 | fdm_beta_d = 19.9, # Extent of the FDM noise level schedule. 540 | fdm_beta_min = 0.1, # Initial slope of the FDM noise level schedule. 541 | fdm_multiplier = 1.0, # Multiplier of the FDM noise level schedule. 542 | **model_kwargs, # Keyword arguments for the underlying model. 
543 | ): 544 | super().__init__() 545 | self.img_resolution = img_resolution 546 | self.img_channels = img_channels 547 | self.label_dim = label_dim 548 | self.use_fp16 = use_fp16 549 | self.sigma_min = sigma_min 550 | self.sigma_max = sigma_max 551 | self.sigma_data = sigma_data 552 | self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) 553 | 554 | self.fdm_beta_d = fdm_beta_d 555 | self.fdm_beta_min = fdm_beta_min 556 | self.fdm_multiplier = fdm_multiplier 557 | 558 | def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): 559 | x = x.to(torch.float32) 560 | sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) 561 | class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) 562 | dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 563 | 564 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 565 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() 566 | c_in = self.s(sigma) 567 | c_noise = sigma.log() / 4 568 | 569 | F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) 570 | assert F_x.dtype == dtype 571 | D_x = c_skip * x + c_out * F_x.to(torch.float32) 572 | return D_x 573 | 574 | def round_sigma(self, sigma): 575 | return torch.as_tensor(sigma) 576 | 577 | def fdm_sigma_inv(self, sigma): 578 | sigma = torch.as_tensor(sigma) 579 | return (sigma - self.sigma_min) / (self.sigma_max - self.sigma_min) 580 | 581 | def fdm_beta_fn(self, t): 582 | return self.fdm_beta_min * t + 0.5 * self.fdm_beta_d * t**2 583 | 584 | def s(self, sigma): 585 | t = self.fdm_sigma_inv(sigma) 586 | beta = self.fdm_beta_fn(t) 587 | return torch.exp(-self.fdm_multiplier * beta) * (1. + self.fdm_multiplier * beta) 588 | 589 | 590 | #---------------------------------------------------------------------------- 591 | # Fast Diffusion Model with VP Preconditioning (VP-FDM) 592 | # from the paper "Score-Based Generative Modeling through Stochastic 593 | # Differential Equations". 594 | 595 | @persistence.persistent_class 596 | class FDM_VPPrecond(torch.nn.Module): 597 | def __init__(self, 598 | img_resolution, # Image resolution. 599 | img_channels, # Number of color channels. 600 | label_dim = 0, # Number of class labels, 0 = unconditional. 601 | use_fp16 = False, # Execute the underlying model at FP16 precision? 602 | beta_d = 19.9, # Extent of the noise level schedule. 603 | beta_min = 0.1, # Initial slope of the noise level schedule. 604 | M = 1000, # Original number of timesteps in the DDPM formulation. 605 | epsilon_t = 1e-5, # Minimum t-value used during training. 606 | model_type = 'SongUNet', # Class name of the underlying model. 607 | fdm_beta_d = 19.9, # Extent of the FDM noise level schedule. 608 | fdm_beta_min = 0.1, # Initial slope of the FDM noise level schedule. 609 | fdm_multiplier = 1.0, # Multiplier of the FDM noise level schedule. 610 | **model_kwargs, # Keyword arguments for the underlying model. 
611 | ): 612 | super().__init__() 613 | self.img_resolution = img_resolution 614 | self.img_channels = img_channels 615 | self.label_dim = label_dim 616 | self.use_fp16 = use_fp16 617 | self.beta_d = beta_d 618 | self.beta_min = beta_min 619 | self.M = M 620 | self.epsilon_t = epsilon_t 621 | self.sigma_min = float(self.sigma(epsilon_t)) 622 | self.sigma_max = float(self.sigma(1)) 623 | self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) 624 | 625 | self.fdm_beta_d = fdm_beta_d 626 | self.fdm_beta_min = fdm_beta_min 627 | self.fdm_multiplier = fdm_multiplier 628 | 629 | def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): 630 | x = x.to(torch.float32) 631 | sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) 632 | class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) 633 | dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 634 | 635 | c_skip = 1 636 | c_out = -sigma 637 | c_in = self.s(sigma) 638 | c_noise = (self.M - 1) * self.sigma_inv(sigma) 639 | 640 | F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) 641 | assert F_x.dtype == dtype 642 | D_x = c_skip * x + c_out * F_x.to(torch.float32) 643 | return D_x 644 | 645 | def sigma(self, t): 646 | t = torch.as_tensor(t) 647 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 648 | 649 | def sigma_inv(self, sigma): 650 | sigma = torch.as_tensor(sigma) 651 | return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d 652 | 653 | def round_sigma(self, sigma): 654 | return torch.as_tensor(sigma) 655 | 656 | def fdm_sigma_inv(self, sigma): 657 | sigma = torch.as_tensor(sigma) 658 | return (sigma - self.sigma_min) / (self.sigma_max - self.sigma_min) 659 | 660 | def fdm_beta_fn(self, t): 661 | return self.fdm_beta_min * t + 0.5 * self.fdm_beta_d * t**2 662 | 663 | def s(self, sigma): 664 | t = self.fdm_sigma_inv(sigma) 665 | beta = self.fdm_beta_fn(t) 666 | return torch.exp(-self.fdm_multiplier * beta) * (1. + self.fdm_multiplier * beta) 667 | 668 | #---------------------------------------------------------------------------- 669 | # Fast Diffusion Model with VE Preconditioning (VE-FDM) 670 | # from the paper "Score-Based Generative Modeling through Stochastic 671 | # Differential Equations". 672 | 673 | @persistence.persistent_class 674 | class FDM_VEPrecond(torch.nn.Module): 675 | def __init__(self, 676 | img_resolution, # Image resolution. 677 | img_channels, # Number of color channels. 678 | label_dim = 0, # Number of class labels, 0 = unconditional. 679 | use_fp16 = False, # Execute the underlying model at FP16 precision? 680 | sigma_min = 0.02, # Minimum supported noise level. 681 | sigma_max = 100, # Maximum supported noise level. 682 | model_type = 'SongUNet', # Class name of the underlying model. 683 | fdm_beta_d = 19.9, # Extent of the FDM noise level schedule. 684 | fdm_beta_min = 0.1, # Initial slope of the FDM noise level schedule. 685 | fdm_multiplier = 1.0, # Multiplier of the FDM noise level schedule. 686 | **model_kwargs, # Keyword arguments for the underlying model. 
687 | ): 688 | super().__init__() 689 | self.img_resolution = img_resolution 690 | self.img_channels = img_channels 691 | self.label_dim = label_dim 692 | self.use_fp16 = use_fp16 693 | self.sigma_min = sigma_min 694 | self.sigma_max = sigma_max 695 | self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) 696 | 697 | self.fdm_beta_d = fdm_beta_d 698 | self.fdm_beta_min = fdm_beta_min 699 | self.fdm_multiplier = fdm_multiplier 700 | 701 | def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): 702 | x = x.to(torch.float32) 703 | sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) 704 | class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) 705 | dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 706 | 707 | c_skip = 1 708 | c_out = sigma 709 | c_in = self.s(sigma) 710 | c_noise = (0.5 * sigma).log() 711 | 712 | F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) 713 | assert F_x.dtype == dtype 714 | D_x = c_skip * x + c_out * F_x.to(torch.float32) 715 | return D_x 716 | 717 | def round_sigma(self, sigma): 718 | return torch.as_tensor(sigma) 719 | 720 | def fdm_sigma_inv(self, sigma): 721 | sigma = torch.as_tensor(sigma) 722 | return (sigma - self.sigma_min) / (self.sigma_max - self.sigma_min) 723 | 724 | def fdm_beta_fn(self, t): 725 | return self.fdm_beta_min * t + 0.5 * self.fdm_beta_d * t**2 726 | 727 | def s(self, sigma): 728 | t = self.fdm_sigma_inv(sigma) 729 | beta = self.fdm_beta_fn(t) 730 | return torch.exp(-beta) * (1. + beta) --------------------------------------------------------------------------------
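The three FDM preconditioners above (FDM_EDMPrecond, FDM_VPPrecond, FDM_VEPrecond) share the same input-scaling rule: sigma is first mapped linearly onto t in [0, 1] by fdm_sigma_inv, a quadratic beta(t) schedule is evaluated by fdm_beta_fn, and the network input is multiplied by s(sigma) = exp(-m*beta) * (1 + m*beta). The sketch below is illustrative only and not part of the repository; it evaluates this scaling factor in isolation, with the sigma range and schedule constants assumed from the VE-FDM defaults shown above.

    # Minimal sketch (assumed defaults from FDM_VEPrecond): evaluate the FDM
    # input-scaling factor s(sigma) on its own, outside the network wrapper.
    import torch

    def fdm_scale(sigma, sigma_min=0.02, sigma_max=100.0,
                  beta_min=0.1, beta_d=19.9, multiplier=1.0):
        # Mirrors fdm_sigma_inv / fdm_beta_fn / s: map sigma to t in [0, 1],
        # evaluate the quadratic beta schedule, and combine as exp(-m*beta)*(1 + m*beta).
        sigma = torch.as_tensor(sigma, dtype=torch.float32)
        t = (sigma - sigma_min) / (sigma_max - sigma_min)
        beta = beta_min * t + 0.5 * beta_d * t ** 2
        return torch.exp(-multiplier * beta) * (1.0 + multiplier * beta)

    # s(sigma) is 1 at sigma = sigma_min and decays monotonically toward 0 as
    # sigma approaches sigma_max.
    print(fdm_scale(torch.tensor([0.02, 1.0, 10.0, 100.0])))

Because beta(t) increases with sigma and d/d_beta [exp(-beta)(1 + beta)] = -beta*exp(-beta) <= 0, the factor starts at 1 for the smallest supported noise level and shrinks monotonically, so heavily noised inputs are attenuated (via c_in = self.s(sigma)) before entering the underlying U-Net.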