├── figures ├── figure1.PNG └── figure2.PNG ├── guided_diffusion ├── __init__.py ├── script_util.py ├── dist_util.py ├── losses.py ├── respace.py ├── resample.py ├── nn.py ├── fp16_util.py ├── train_util.py └── logger.py ├── torch_utils ├── __init__.py ├── distributed.py ├── persistence.py ├── training_stats.py └── misc.py ├── training ├── __init__.py ├── loss.py ├── dataset.py └── training_loop.py ├── environment.yml ├── dnnlib ├── __init__.py └── util.py ├── Dockerfile ├── example.py ├── README.md ├── train_classifier.py ├── fid.py ├── classifier_lib.py ├── LICENSE ├── train.py └── generate.py /figures/figure1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aailabkaist/TIW-DSM/HEAD/figures/figure1.PNG -------------------------------------------------------------------------------- /figures/figure2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aailabkaist/TIW-DSM/HEAD/figures/figure2.PNG -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 |
-------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: edm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python>=3.8, < 3.10 # package build failures on 3.10 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow>=8.3.1 11 | - scipy>=1.7.1 12 | - pytorch=1.12.1 13 | - psutil 14 | - requests 15 | - tqdm 16 | - imageio 17 | - pip: 18 | - imageio-ffmpeg>=0.4.3 19 | - pyspng 20 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | FROM nvcr.io/nvidia/pytorch:22.10-py3 9 | 10 | ENV PYTHONDONTWRITEBYTECODE 1 11 | ENV PYTHONUNBUFFERED 1 12 | 13 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 14 | 15 | WORKDIR /workspace 16 | 17 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 18 | ENTRYPOINT ["/entry.sh"] 19 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | from .unet import EncoderUNetModel 2 | 3 | NUM_CLASSES = 1000 4 | def create_classifier( 5 | image_size, 6 | classifier_use_fp16, 7 | classifier_width, 8 | classifier_depth, 9 | classifier_attention_resolutions, 10 | classifier_use_scale_shift_norm, 11 | classifier_resblock_updown, 12 | classifier_pool, 13 | out_channels, 14 | in_channels = 3, 15 | condition=False, 16 | ): 17 | if image_size == 512: 18 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 19 | elif image_size == 256: 20 | channel_mult = (1, 1, 2, 2, 4, 4) 21 | elif image_size == 128: 22 | channel_mult = (1, 1, 2, 3, 4) 23 | elif image_size == 64: 24 | channel_mult = (1, 2, 3, 4) 25 | elif image_size == 32: 26 | channel_mult = (1, 2, 4) 27 | elif image_size == 16: 28 | channel_mult = (1, 2) 29 | elif image_size == 8: 30 | channel_mult = (1,) 31 | else: 32 | raise ValueError(f"unsupported image size: {image_size}") 33 | 34 | attention_ds = [] 35 | for res in classifier_attention_resolutions.split(","): 36 | attention_ds.append(image_size // int(res)) 37 | 38 | return EncoderUNetModel( 39 | image_size=image_size, 40 | in_channels=in_channels, 41 | model_channels=classifier_width, 42 | out_channels=out_channels, 43 | num_res_blocks=classifier_depth, 44 | attention_resolutions=tuple(attention_ds), 45 | channel_mult=channel_mult, 46 | use_fp16=classifier_use_fp16, 47 | num_head_channels=64, 48 | use_scale_shift_norm=classifier_use_scale_shift_norm, 49 | resblock_updown=classifier_resblock_updown, 50 | pool=classifier_pool, 51 | condition=condition, 52 | ) 53 | 54 | -------------------------------------------------------------------------------- /torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . 
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(args): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE + args.initial_gpu}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 
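    Rank 0 reads the file once and broadcasts the raw bytes to the other
    ranks in chunks of at most 2**30 bytes, working around MPI's message
    size limit.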
57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Minimal standalone example to reproduce the main results from the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import tqdm 12 | import pickle 13 | import numpy as np 14 | import torch 15 | import PIL.Image 16 | import dnnlib 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def generate_image_grid( 21 | network_pkl, dest_path, 22 | seed=0, gridw=8, gridh=8, device=torch.device('cuda'), 23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 25 | ): 26 | batch_size = gridw * gridh 27 | torch.manual_seed(seed) 28 | 29 | # Load network. 30 | print(f'Loading network from "{network_pkl}"...') 31 | with dnnlib.util.open_url(network_pkl) as f: 32 | net = pickle.load(f)['ema'].to(device) 33 | 34 | # Pick latents and labels. 35 | print(f'Generating {batch_size} images...') 36 | latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 37 | class_labels = None 38 | if net.label_dim: 39 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)] 40 | 41 | # Adjust noise levels based on what's supported by the network. 42 | sigma_min = max(sigma_min, net.sigma_min) 43 | sigma_max = min(sigma_max, net.sigma_max) 44 | 45 | # Time step discretization. 46 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device) 47 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 48 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 49 | 50 | # Main sampling loop. 51 | x_next = latents.to(torch.float64) * t_steps[0] 52 | for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1 53 | x_cur = x_next 54 | 55 | # Increase noise temporarily. 56 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 57 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 58 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) 59 | 60 | # Euler step. 
61 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 62 | d_cur = (x_hat - denoised) / t_hat 63 | x_next = x_hat + (t_next - t_hat) * d_cur 64 | 65 | # Apply 2nd order correction. 66 | if i < num_steps - 1: 67 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 68 | d_prime = (x_next - denoised) / t_next 69 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 70 | 71 | # Save image grid. 72 | print(f'Saving image grid to "{dest_path}"...') 73 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8) 74 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2) 75 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels) 76 | image = image.cpu().numpy() 77 | PIL.Image.fromarray(image, 'RGB').save(dest_path) 78 | print('Done.') 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def main(): 83 | model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained' 84 | generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35 85 | generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 2.39, NFE = 79 86 | generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79 87 | generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | if __name__ == "__main__": 92 | main() 93 | 94 | #---------------------------------------------------------------------------- 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## TRAINING UNBIASED DIFFUSION MODELS FROM BIASED DATASET (TIW-DSM) (ICLR 2024)
Official PyTorch implementation of TIW-DSM 2 | 3 | 4 | 5 | **[Yeongmin Kim](https://sites.google.com/view/yeongmin-space/%ED%99%88), [Byeonghu Na](https://sites.google.com/view/byeonghu-na), Minsang Park, [JoonHo Jang](https://sites.google.com/view/joonhojang), [Dongjun Kim](https://sites.google.com/view/dongjun-kim), [Wanmo Kang](https://sites.google.com/site/wanmokang), and [Il-Chul Moon](https://aai.kaist.ac.kr/bbs/board.php?bo_table=sub2_1&wr_id=3)** 6 | 7 | | [openreview](https://openreview.net/forum?id=39cPKijBed) | [arxiv](https://arxiv.org/abs/2403.01189) | [datasets](https://drive.google.com/drive/u/0/folders/1RakPtfp70E2BSgDM5xMBd2Om0N8ycrRK) | [checkpoints](https://drive.google.com/drive/u/0/folders/1vYLH8UNlXWZarn0IOtiPuU8FvBFqJvTP) | 8 | 9 | -------------------- 10 | 11 | ## Overview 12 | ![Teaser image](./figures/figure1.PNG) 13 | ![Teaser image](./figures/figure2.PNG) 14 | ## Requirements 15 | The requirements for this code are the same as those outlined for [EDM](https://github.com/NVlabs/edm). 16 | 17 | ## Datasets 18 | - Download the data from [datasets](https://drive.google.com/drive/u/0/folders/1RakPtfp70E2BSgDM5xMBd2Om0N8ycrRK) and keep the same directory structure. 19 | ## Training 20 | - Download the pre-trained feature extractor from [checkpoints](https://drive.google.com/drive/u/0/folders/1vYLH8UNlXWZarn0IOtiPuU8FvBFqJvTP). 21 | ### Time-dependent discriminator 22 | - CIFAR10 LT / 5% setting 23 | ``` 24 | python train_classifier.py 25 | ``` 26 | - CIFAR10 LT / 10% setting 27 | ``` 28 | python train_classifier.py --savedir=/checkpoints/discriminator/cifar10/unbias_1000/ --refdir=/datasets/cifar10/discriminator_training/unbias_1000/real_data.npz --real_mul=10 29 | ``` 30 | - CelebA64 / 5% setting 31 | ``` 32 | python train_classifier.py --feature_extractor=/checkpoints/discriminator/feature_extractor/64x64_classifier.pt --savedir=/checkpoints/discriminator/celeba64/unbias_8k/ --biasdir=/datasets/celeba64/discriminator_training/bias_162k/fake_data.npz --refdir=/datasets/celeba64/discriminator_training/unbias_8k/real_data.npz --img_resolution=64 33 | ``` 34 | 35 | ### Diffusion model with TIW-DSM objective 36 | - CIFAR10 LT / 5% setting 37 | ``` 38 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 train.py --arch=ddpmpp --outdir=out_dir --data=PATH/TIW-DSM/datasets/cifar10/score_training/500_10000/dataset.zip --batch=256 --tick=5 --duration=11 39 | ``` 40 | - CIFAR10 LT / 10% setting 41 | ``` 42 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 train.py --arch=ddpmpp --outdir=out_dir --data=PATH/TIW-DSM/datasets/cifar10/score_training/1000_10000/dataset.zip --batch=256 --tick=5 --duration=21 43 | ``` 44 | - CelebA64 / 5% setting 45 | ``` 46 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 train.py --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 --outdir=outdir --data=PATH/TIW-DSM/datasets/celeba64/score_training/dataset.zip --cla_path=PATH/TIW-DSM/checkpoints/discriminator/feature_extractor/64x64_classifier.pt --dis_path=PATH/TIW-DSM/checkpoints/discriminator/celeba64/unbias_8k/discriminator_9501.pt --tick=5 47 | ``` 48 | 49 | ## Sampling 50 | - CIFAR10 51 | ``` 52 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-49999 --batch=64 --network=TRAINED_PKL 53 | ``` 54 | - CelebA64 55 | ``` 56 | torchrun --standalone --nproc_per_node=2 generate.py --steps=40 --outdir=out --seeds=0-49999 --batch=64 --network=TRAINED_PKL 57 | ``` 58 | ## Evaluation 59 | - CIFAR10
60 | ``` 61 | python fid.py calc --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz --num=50000 --images=YOUR_SAMPLE_PATH 62 | ``` 63 | - CelebA64 64 | ``` 65 | python fid.py calc --ref=/PATH/TIW-DSM/datasets/celeba/evaluation/FID_stat.npz --num=50000 --images=YOUR_SAMPLE_PATH 66 | ``` 67 | 68 | ## Reference 69 | If you find the code useful for your research, please consider citing 70 | ```bib 71 | @inproceedings{ 72 | kim2024training, 73 | title={Training Unbiased Diffusion Models From Biased Dataset}, 74 | author={Yeongmin Kim and Byeonghu Na and Minsang Park and JoonHo Jang and Dongjun Kim and Wanmo Kang and Il-chul Moon}, 75 | booktitle={The Twelfth International Conference on Learning Representations}, 76 | year={2024}, 77 | url={https://openreview.net/forum?id=39cPKijBed} 78 | } 79 | ``` 80 | This work is heavily built upon the code from 81 | - *Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. Elucidating the design space of diffusion-based generative models. Advances in Neural Information Processing Systems, 35:26565–26577,2022.* 82 | - *Dongjun Kim\*, Yeongmin Kim\*, Se Jung Kwon, Wanmo Kang, and Il-Chul Moon. Refining generative process with discriminator guidance in score-based diffusion models. In Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pp. 16567–16598. PMLR, 23–29 Jul 2023* 83 | 84 | 85 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Loss functions used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import torch 12 | from torch_utils import persistence 13 | import classifier_lib 14 | #---------------------------------------------------------------------------- 15 | # Loss function corresponding to the variance preserving (VP) formulation 16 | # from the paper "Score-Based Generative Modeling through Stochastic 17 | # Differential Equations". 
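# Note: for the VP SDE, the marginal perturbation std at time t is
# sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1); VPLoss.sigma()
# below implements exactly this, with t drawn uniformly from [epsilon_t, 1].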
18 | 19 | @persistence.persistent_class 20 | class VPLoss: 21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): 22 | self.beta_d = beta_d 23 | self.beta_min = beta_min 24 | self.epsilon_t = epsilon_t 25 | 26 | def __call__(self, net, images, labels, augment_pipe=None): 27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 29 | weight = 1 / sigma ** 2 30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 31 | n = torch.randn_like(y) * sigma 32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 33 | loss = weight * ((D_yn - y) ** 2) 34 | return loss 35 | 36 | def sigma(self, t): 37 | t = torch.as_tensor(t) 38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 39 | 40 | #---------------------------------------------------------------------------- 41 | # Loss function corresponding to the variance exploding (VE) formulation 42 | # from the paper "Score-Based Generative Modeling through Stochastic 43 | # Differential Equations". 44 | 45 | @persistence.persistent_class 46 | class VELoss: 47 | def __init__(self, sigma_min=0.02, sigma_max=100): 48 | self.sigma_min = sigma_min 49 | self.sigma_max = sigma_max 50 | 51 | def __call__(self, net, images, labels, augment_pipe=None): 52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 54 | weight = 1 / sigma ** 2 55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 56 | n = torch.randn_like(y) * sigma 57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 58 | loss = weight * ((D_yn - y) ** 2) 59 | return loss 60 | 61 | #---------------------------------------------------------------------------- 62 | # Improved loss function proposed in the paper "Elucidating the Design Space 63 | # of Diffusion-Based Generative Models" (EDM). 
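# Note: EDM draws noise levels log-normally, ln(sigma) ~ N(P_mean, P_std^2), and
# weights each term by lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2,
# which keeps the effective loss magnitude roughly constant across noise levels.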
64 | 65 | @persistence.persistent_class 66 | class EDMLoss: 67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 68 | self.P_mean = P_mean 69 | self.P_std = P_std 70 | self.sigma_data = sigma_data 71 | 72 | def __call__(self, net, images, labels=None, augment_pipe=None): 73 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 74 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 75 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 76 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 77 | n = torch.randn_like(y) * sigma 78 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 79 | loss = weight * ((D_yn - y) ** 2) 80 | return loss 81 | 82 | #---------------------------------------------------------------------------- 83 | @persistence.persistent_class 84 | class TIW_EDMLoss: 85 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 86 | self.P_mean = P_mean 87 | self.P_std = P_std 88 | self.sigma_data = sigma_data 89 | 90 | def __call__(self, vpsde, dis, net, images, labels=None, augment_pipe=None): 91 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 92 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 93 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 94 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 95 | n = torch.randn_like(y) * sigma 96 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 97 | 98 | correction, prediction = classifier_lib.get_weight_correction(dis, vpsde, y + n, sigma.flatten(), images.shape[1], 0., 1.) 99 | correction = correction.detach() 100 | prediction = prediction.detach()[:, None, None, None] 101 | 102 | ## Please refer to Eq. (36) and (37) in the main paper. 103 | tiw = 2 * prediction 104 | y += correction 105 | 106 | loss = weight * tiw * ((D_yn - y) ** 2) # tiw was already expanded to [N, 1, 1, 1] above 107 | return loss 108 | -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there are 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use.
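    For example, space_timesteps(100, "ddim25") uses every 4th original
    step, i.e. {0, 4, 8, ..., 96}.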
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t, reversed=False): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 
82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import numpy as np 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | class GaussianFourierProjection(nn.Module): 36 | """Gaussian Fourier embeddings for noise levels.""" 37 | 38 | def __init__(self, embedding_size=256, scale=1.0): 39 | super().__init__() 40 | self.W = nn.Parameter(th.randn(embedding_size) * scale, requires_grad=False) 41 | 42 | def forward(self, x): 43 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 44 | return th.cat([th.sin(x_proj), th.cos(x_proj)], dim=-1) 45 | 46 | 47 | def GFP(embedding_size=256, scale=1.0): 48 | return GaussianFourierProjection(embedding_size, scale) 49 | 50 | 51 | def linear(*args, **kwargs): 52 | """ 53 | Create a linear module. 54 | """ 55 | return nn.Linear(*args, **kwargs) 56 | 57 | 58 | def avg_pool_nd(dims, *args, **kwargs): 59 | """ 60 | Create a 1D, 2D, or 3D average pooling module. 61 | """ 62 | if dims == 1: 63 | return nn.AvgPool1d(*args, **kwargs) 64 | elif dims == 2: 65 | return nn.AvgPool2d(*args, **kwargs) 66 | elif dims == 3: 67 | return nn.AvgPool3d(*args, **kwargs) 68 | raise ValueError(f"unsupported dimensions: {dims}") 69 | 70 | 71 | def update_ema(target_params, source_params, rate=0.99): 72 | """ 73 | Update target parameters to be closer to those of source parameters using 74 | an exponential moving average. 75 | 76 | :param target_params: the target parameter sequence. 77 | :param source_params: the source parameter sequence. 78 | :param rate: the EMA rate (closer to 1 means slower). 79 | """ 80 | for targ, src in zip(target_params, source_params): 81 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 82 | 83 | 84 | def zero_module(module): 85 | """ 86 | Zero out the parameters of a module and return it. 87 | """ 88 | for p in module.parameters(): 89 | p.detach().zero_() 90 | return module 91 | 92 | 93 | def scale_module(module, scale): 94 | """ 95 | Scale the parameters of a module and return it. 96 | """ 97 | for p in module.parameters(): 98 | p.detach().mul_(scale) 99 | return module 100 | 101 | 102 | def mean_flat(tensor): 103 | """ 104 | Take the mean over all non-batch dimensions. 105 | """ 106 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 107 | 108 | 109 | def normalization(channels): 110 | """ 111 | Make a standard normalization layer. 112 | 113 | :param channels: number of input channels. 114 | :return: an nn.Module for normalization. 115 | """ 116 | return GroupNorm32(32, channels) 117 | 118 | 119 | def timestep_embedding(timesteps, dim, max_period=10000): 120 | """ 121 | Create sinusoidal timestep embeddings. 122 | 123 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 124 | These may be fractional. 125 | :param dim: the dimension of the output. 126 | :param max_period: controls the minimum frequency of the embeddings. 127 | :return: an [N x dim] Tensor of positional embeddings. 
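    The frequencies form a geometric sequence, max_period**(-i/half) for
    i = 0, ..., half - 1, as in the Transformer positional encoding.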
128 | """ 129 | half = dim // 2 130 | freqs = th.exp( 131 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 132 | ).to(device=timesteps.device) 133 | args = timesteps[:, None].float() * freqs[None] 134 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 135 | if dim % 2: 136 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 137 | return embedding 138 | 139 | 140 | def checkpoint(func, inputs, params, flag): 141 | """ 142 | Evaluate a function without caching intermediate activations, allowing for 143 | reduced memory at the expense of extra compute in the backward pass. 144 | 145 | :param func: the function to evaluate. 146 | :param inputs: the argument sequence to pass to `func`. 147 | :param params: a sequence of parameters `func` depends on but does not 148 | explicitly take as arguments. 149 | :param flag: if False, disable gradient checkpointing. 150 | """ 151 | if flag: 152 | args = tuple(inputs) + tuple(params) 153 | return CheckpointFunction.apply(func, len(inputs), *args) 154 | else: 155 | return func(*inputs) 156 | 157 | 158 | class CheckpointFunction(th.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, run_function, length, *args): 161 | ctx.run_function = run_function 162 | ctx.input_tensors = list(args[:length]) 163 | ctx.input_params = list(args[length:]) 164 | with th.no_grad(): 165 | output_tensors = ctx.run_function(*ctx.input_tensors) 166 | return output_tensors 167 | 168 | @staticmethod 169 | def backward(ctx, *output_grads): 170 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 171 | with th.enable_grad(): 172 | # Fixes a bug where the first op in run_function modifies the 173 | # Tensor storage in place, which is not allowed for detach()'d 174 | # Tensors. 
175 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 176 | output_tensors = ctx.run_function(*shallow_copies) 177 | input_grads = th.autograd.grad( 178 | output_tensors, 179 | ctx.input_tensors + ctx.input_params, 180 | output_grads, 181 | allow_unused=True, 182 | ) 183 | del ctx.input_tensors 184 | del ctx.input_params 185 | del output_tensors 186 | return (None, None) + input_grads 187 | -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | import click 2 | import os 3 | import classifier_lib 4 | import torch 5 | import numpy as np 6 | import dnnlib 7 | import torchvision.transforms as transforms 8 | import torch.utils.data as data 9 | import random 10 | 11 | class BasicDataset(data.Dataset): 12 | def __init__(self, x_np, y_np, transform=transforms.ToTensor()): 13 | super(BasicDataset, self).__init__() 14 | 15 | self.x = x_np 16 | self.y = y_np 17 | self.transform = transform 18 | 19 | def __getitem__(self, index): 20 | return self.transform(self.x[index]), self.y[index] 21 | 22 | def __len__(self): 23 | return len(self.x) 24 | 25 | 26 | @click.command() 27 | 28 | ## Data configuration 29 | @click.option('--feature_extractor', help='Path of feature extractor', metavar='STR', type=str, default='/checkpoints/discriminator/feature_extractor/32x32_classifier.pt') 30 | @click.option('--savedir', help='Discriminator save directory', metavar='PATH', type=str, default="/checkpoints/discriminator/cifar10/unbias_500/") 31 | @click.option('--biasdir', help='Bias data directory', metavar='PATH', type=str, default="/datasets/cifar10/discriminator_training/bias_10000/fake_data.npz") 32 | @click.option('--refdir', help='Real sample directory', metavar='PATH', type=str, default="/datasets/cifar10/discriminator_training/unbias_500/real_data.npz") 33 | 34 | @click.option('--img_resolution', help='Image resolution', metavar='INT', type=click.IntRange(min=1), default=32) 35 | @click.option('--real_mul', help='Scaling imbalance', metavar='INT', type=click.IntRange(min=1), default=20) 36 | 37 | ## Training configuration 38 | @click.option('--batch_size', help='Batch size', metavar='INT', type=click.IntRange(min=1), default=128) 39 | @click.option('--iter', help='Training iterations', metavar='INT', type=click.IntRange(min=1), default=20000) 40 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=3e-4) 41 | @click.option('--device', help='Device', metavar='STR', type=str, default='cuda:0') 42 | 43 | def main(**kwargs): 44 | opts = dnnlib.EasyDict(kwargs) 45 | path = os.getcwd() 46 | path_feat = path + opts.feature_extractor 47 | savedir = path + opts.savedir 48 | refdir = path + opts.refdir 49 | biasdir = path + opts.biasdir 50 | os.makedirs(savedir, exist_ok=True) 51 | 52 | ## Prepare real & fake data 53 | ref_data = np.load(refdir)['samples'] 54 | for k in range(opts.real_mul): 55 | if k == 0: 56 | ref_datas = ref_data 57 | else: 58 | ref_datas = np.concatenate([ref_datas, ref_data]) 59 | ref_data = ref_datas 60 | bias_data = np.load(biasdir)['samples'] 61 | print("bias:", len(bias_data)) 62 | print("scaled unbias:", len(ref_data)) 63 | 64 | ## Loader 65 | transform = transforms.Compose([transforms.ToTensor()]) 66 | ref_label = torch.ones(ref_data.shape[0]) 67 | bias_label = torch.zeros(bias_data.shape[0]) 68 | 69 | ref_data = BasicDataset(ref_data, ref_label, transform) 70 | bias_data = BasicDataset(bias_data,
bias_label, transform) 71 | ref_loader = torch.utils.data.DataLoader(dataset=ref_data, batch_size=opts.batch_size, num_workers=0, shuffle=True, drop_last=True) 72 | bias_loader = torch.utils.data.DataLoader(dataset=bias_data, batch_size=opts.batch_size, num_workers=0, shuffle=True, drop_last=True) 73 | ref_iterator = iter(ref_loader) 74 | bias_iterator = iter(bias_loader) 75 | 76 | ## Extractor & Discriminator 77 | pretrained_classifier = classifier_lib.load_classifier(path_feat, opts.img_resolution, opts.device, eval=False) 78 | discriminator = classifier_lib.load_discriminator(None, opts.device, 0, eval=False) 79 | 80 | ## Prepare training 81 | vpsde = classifier_lib.vpsde() 82 | optimizer = torch.optim.Adam(discriminator.parameters(), lr=opts.lr, weight_decay=1e-7) 83 | loss = torch.nn.BCELoss() 84 | scaler = lambda x: 2. * x - 1. 85 | 86 | ## Training 87 | for i in range(opts.iter): 88 | if i % 500 == 0: 89 | outs = [] 90 | cors = [] 91 | num_data = 0 92 | try: 93 | r_inputs, r_labels = next(ref_iterator) 94 | except StopIteration: 95 | ref_iterator = iter(ref_loader) 96 | r_inputs, r_labels = next(ref_iterator) 97 | try: 98 | f_inputs, f_labels = next(bias_iterator) 99 | except StopIteration: 100 | bias_iterator = iter(bias_loader) 101 | f_inputs, f_labels = next(bias_iterator) 102 | 103 | inputs = torch.cat([r_inputs, f_inputs]) 104 | labels = torch.cat([r_labels, f_labels]) 105 | c = list(range(inputs.shape[0])) 106 | random.shuffle(c) 107 | inputs, labels = inputs[c], labels[c] 108 | 109 | optimizer.zero_grad() 110 | inputs = inputs.to(opts.device) 111 | labels = labels.to(opts.device) 112 | inputs = scaler(inputs) 113 | 114 | ## Data perturbation 115 | t, _ = vpsde.get_diffusion_time(inputs.shape[0], inputs.device, importance_sampling=True) 116 | mean, std = vpsde.marginal_prob(t) 117 | z = torch.randn_like(inputs) 118 | perturbed_inputs = mean[:, None, None, None] * inputs + std[:, None, None, None] * z 119 | 120 | ## Forward 121 | with torch.no_grad(): 122 | pretrained_feature = pretrained_classifier(perturbed_inputs, timesteps=t, feature=True) 123 | label_prediction = discriminator(pretrained_feature, t, sigmoid=True).view(-1) 124 | 125 | ## Backward 126 | out = loss(label_prediction, labels) 127 | out.backward() 128 | optimizer.step() 129 | 130 | ## Report 131 | cor = ((label_prediction > 0.5).float() == labels).float().mean() 132 | outs.append(out.item()) 133 | cors.append(cor.item()) 134 | num_data += inputs.shape[0] 135 | print(f"{i}-th iter BCE loss: {np.mean(outs)}, accuracy: {np.mean(cors)}") 136 | 137 | if i % 500 == 0: 138 | torch.save(discriminator.state_dict(), savedir + f"/discriminator_{i+1}.pt") 139 | 140 | #---------------------------------------------------------------------------- 141 | if __name__ == "__main__": 142 | main() 143 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work.
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Script for calculating Frechet Inception Distance (FID).""" 9 | 10 | import os 11 | import click 12 | import tqdm 13 | import pickle 14 | import numpy as np 15 | import scipy.linalg 16 | import torch 17 | import dnnlib 18 | from torch_utils import distributed as dist 19 | from training import dataset 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def calculate_inception_stats( 24 | image_path, num_expected=None, seed=0, max_batch_size=64, 25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'), 26 | ): 27 | # Rank 0 goes first. 28 | if dist.get_rank() != 0: 29 | torch.distributed.barrier() 30 | 31 | # Load Inception-v3 model. 32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 33 | dist.print0('Loading Inception-v3 model...') 34 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 35 | detector_kwargs = dict(return_features=True) 36 | feature_dim = 2048 37 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: 38 | detector_net = pickle.load(f).to(device) 39 | 40 | # List images. 41 | dist.print0(f'Loading images from "{image_path}"...') 42 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed) 43 | if num_expected is not None and len(dataset_obj) < num_expected: 44 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}') 45 | if len(dataset_obj) < 2: 46 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics') 47 | 48 | # Other ranks follow. 49 | if dist.get_rank() == 0: 50 | torch.distributed.barrier() 51 | 52 | # Divide images into batches. 53 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 54 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) 55 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 56 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor) 57 | 58 | # Accumulate statistics. 59 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...') 60 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device) 61 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device) 62 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)): 63 | torch.distributed.barrier() 64 | if images.shape[0] == 0: 65 | continue 66 | if images.shape[1] == 1: 67 | images = images.repeat([1, 3, 1, 1]) 68 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64) 69 | mu += features.sum(0) 70 | sigma += features.T @ features 71 | 72 | # Calculate grand totals. 
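    # Note: mu accumulates the per-rank feature sum and sigma the raw second
    # moment sum(f f^T); after the all-reduce below they become the sample mean
    # and the unbiased covariance, (sum(f f^T) - N * mu mu^T) / (N - 1).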
73 | torch.distributed.all_reduce(mu) 74 | torch.distributed.all_reduce(sigma) 75 | mu /= len(dataset_obj) 76 | sigma -= mu.ger(mu) * len(dataset_obj) 77 | sigma /= len(dataset_obj) - 1 78 | return mu.cpu().numpy(), sigma.cpu().numpy() 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 83 | m = np.square(mu - mu_ref).sum() 84 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 85 | fid = m + np.trace(sigma + sigma_ref - s * 2) 86 | return float(np.real(fid)) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @click.group() 91 | def main(): 92 | """Calculate Frechet Inception Distance (FID). 93 | 94 | Examples: 95 | 96 | \b 97 | # Generate 50000 images and save them as fid-tmp/*/*.png 98 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\ 99 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 100 | 101 | \b 102 | # Calculate FID 103 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\ 104 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 105 | 106 | \b 107 | # Compute dataset reference statistics 108 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 109 | """ 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | @main.command() 114 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True) 115 | @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True) 116 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True) 117 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True) 118 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 119 | 120 | def calc(image_path, ref_path, num_expected, seed, batch): 121 | """Calculate FID for a given set of images.""" 122 | torch.multiprocessing.set_start_method('spawn') 123 | dist.init() 124 | 125 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...') 126 | ref = None 127 | if dist.get_rank() == 0: 128 | with dnnlib.util.open_url(ref_path) as f: 129 | ref = dict(np.load(f)) 130 | 131 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch) 132 | dist.print0('Calculating FID...') 133 | if dist.get_rank() == 0: 134 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma']) 135 | print(f'{fid:g}') 136 | torch.distributed.barrier() 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | @main.command() 141 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True) 142 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True) 143 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 144 | 145 | def ref(dataset_path, dest_path, batch): 146 | """Calculate dataset reference statistics needed by 'calc'.""" 147 | 
torch.multiprocessing.set_start_method('spawn') 148 | dist.init() 149 | 150 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch) 151 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...') 152 | if dist.get_rank() == 0: 153 | if os.path.dirname(dest_path): 154 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 155 | np.savez(dest_path, mu=mu, sigma=sigma) 156 | 157 | torch.distributed.barrier() 158 | dist.print0('Done.') 159 | 160 | #---------------------------------------------------------------------------- 161 | 162 | if __name__ == "__main__": 163 | main() 164 | 165 | #---------------------------------------------------------------------------- 166 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
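# Each master param is a single flat fp32 tensor covering a whole param group
# (scalars/vectors in one group, matrices in the other; see
# get_param_groups_and_shapes below), so it is unflattened back into the
# original per-parameter shapes before being copied into the model.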
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Streaming images and labels from datasets created with dataset_tool.py.""" 9 | 10 | import os 11 | import numpy as np 12 | import zipfile 13 | import PIL.Image 14 | import json 15 | import torch 16 | import dnnlib 17 | 18 | try: 19 | import pyspng 20 | except ImportError: 21 | pyspng = None 22 | 23 | #---------------------------------------------------------------------------- 24 | # Abstract base class for datasets. 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, 28 | name, # Name of the dataset. 29 | raw_shape, # Shape of the raw image data (NCHW). 30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 31 | use_labels = False, # Enable conditioning labels? 
False = label dimension is zero. 32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 33 | random_seed = 0, # Random seed to use when applying max_size. 34 | cache = False, # Cache images in CPU memory? 35 | ): 36 | self._name = name 37 | self._raw_shape = list(raw_shape) 38 | self._use_labels = use_labels 39 | self._cache = cache 40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 41 | self._raw_labels = None 42 | self._label_shape = None 43 | 44 | # Apply max_size. 45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 46 | if (max_size is not None) and (self._raw_idx.size > max_size): 47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 48 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 49 | 50 | # Apply xflip. 51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 52 | if xflip: 53 | self._raw_idx = np.tile(self._raw_idx, 2) 54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 55 | 56 | def _get_raw_labels(self): 57 | if self._raw_labels is None: 58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 59 | if self._raw_labels is None: 60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 61 | assert isinstance(self._raw_labels, np.ndarray) 62 | assert self._raw_labels.shape[0] == self._raw_shape[0] 63 | assert self._raw_labels.dtype in [np.float32, np.int64] 64 | if self._raw_labels.dtype == np.int64: 65 | assert self._raw_labels.ndim == 1 66 | assert np.all(self._raw_labels >= 0) 67 | return self._raw_labels 68 | 69 | def close(self): # to be overridden by subclass 70 | pass 71 | 72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def _load_raw_labels(self): # to be overridden by subclass 76 | raise NotImplementedError 77 | 78 | def __getstate__(self): 79 | return dict(self.__dict__, _raw_labels=None) 80 | 81 | def __del__(self): 82 | try: 83 | self.close() 84 | except: 85 | pass 86 | 87 | def __len__(self): 88 | return self._raw_idx.size 89 | 90 | def __getitem__(self, idx): 91 | raw_idx = self._raw_idx[idx] 92 | image = self._cached_images.get(raw_idx, None) 93 | if image is None: 94 | image = self._load_raw_image(raw_idx) 95 | if self._cache: 96 | self._cached_images[raw_idx] = image 97 | assert isinstance(image, np.ndarray) 98 | assert list(image.shape) == self.image_shape 99 | assert image.dtype == np.uint8 100 | if self._xflip[idx]: 101 | assert image.ndim == 3 # CHW 102 | image = image[:, :, ::-1] 103 | return image.copy(), self.get_label(idx) 104 | 105 | def get_label(self, idx): 106 | label = self._get_raw_labels()[self._raw_idx[idx]] 107 | if label.dtype == np.int64: 108 | onehot = np.zeros(self.label_shape, dtype=np.float32) 109 | onehot[label] = 1 110 | label = onehot 111 | return label.copy() 112 | 113 | def get_details(self, idx): 114 | d = dnnlib.EasyDict() 115 | d.raw_idx = int(self._raw_idx[idx]) 116 | d.xflip = (int(self._xflip[idx]) != 0) 117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 118 | return d 119 | 120 | @property 121 | def name(self): 122 | return self._name 123 | 124 | @property 125 | def image_shape(self): 126 | return list(self._raw_shape[1:]) 127 | 128 | @property 129 | def num_channels(self): 130 | assert len(self.image_shape) == 3 # CHW 131 | return self.image_shape[0] 132 | 133 | @property 134 | def resolution(self): 135 | assert len(self.image_shape) == 3 # CHW 136 | assert self.image_shape[1] == 
self.image_shape[2] 137 | return self.image_shape[1] 138 | 139 | @property 140 | def label_shape(self): 141 | if self._label_shape is None: 142 | raw_labels = self._get_raw_labels() 143 | if raw_labels.dtype == np.int64: 144 | self._label_shape = [int(np.max(raw_labels)) + 1] 145 | else: 146 | self._label_shape = raw_labels.shape[1:] 147 | return list(self._label_shape) 148 | 149 | @property 150 | def label_dim(self): 151 | assert len(self.label_shape) == 1 152 | return self.label_shape[0] 153 | 154 | @property 155 | def has_labels(self): 156 | return any(x != 0 for x in self.label_shape) 157 | 158 | @property 159 | def has_onehot_labels(self): 160 | return self._get_raw_labels().dtype == np.int64 161 | 162 | #---------------------------------------------------------------------------- 163 | # Dataset subclass that loads images recursively from the specified directory 164 | # or ZIP file. 165 | 166 | class ImageFolderDataset(Dataset): 167 | def __init__(self, 168 | path, # Path to directory or zip. 169 | resolution = None, # Ensure specific resolution, None = highest available. 170 | use_pyspng = True, # Use pyspng if available? 171 | **super_kwargs, # Additional arguments for the Dataset base class. 172 | ): 173 | self._path = path 174 | self._use_pyspng = use_pyspng 175 | self._zipfile = None 176 | 177 | if os.path.isdir(self._path): 178 | self._type = 'dir' 179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 180 | elif self._file_ext(self._path) == '.zip': 181 | self._type = 'zip' 182 | self._all_fnames = set(self._get_zipfile().namelist()) 183 | else: 184 | raise IOError('Path must point to a directory or zip') 185 | 186 | PIL.Image.init() 187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 188 | if len(self._image_fnames) == 0: 189 | raise IOError('No image files found in the specified path') 190 | 191 | name = os.path.splitext(os.path.basename(self._path))[0] 192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 194 | raise IOError('Image files do not match the specified resolution') 195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 196 | 197 | @staticmethod 198 | def _file_ext(fname): 199 | return os.path.splitext(fname)[1].lower() 200 | 201 | def _get_zipfile(self): 202 | assert self._type == 'zip' 203 | if self._zipfile is None: 204 | self._zipfile = zipfile.ZipFile(self._path) 205 | return self._zipfile 206 | 207 | def _open_file(self, fname): 208 | if self._type == 'dir': 209 | return open(os.path.join(self._path, fname), 'rb') 210 | if self._type == 'zip': 211 | return self._get_zipfile().open(fname, 'r') 212 | return None 213 | 214 | def close(self): 215 | try: 216 | if self._zipfile is not None: 217 | self._zipfile.close() 218 | finally: 219 | self._zipfile = None 220 | 221 | def __getstate__(self): 222 | return dict(super().__getstate__(), _zipfile=None) 223 | 224 | def _load_raw_image(self, raw_idx): 225 | fname = self._image_fnames[raw_idx] 226 | with self._open_file(fname) as f: 227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 228 | image = pyspng.load(f.read()) 229 | else: 230 | image = np.array(PIL.Image.open(f)) 231 | if image.ndim == 2: 232 | image = image[:, :, np.newaxis] # HW => HWC 233 | image = image.transpose(2, 0, 1) 
# HWC => CHW 234 | return image 235 | 236 | def _load_raw_labels(self): 237 | fname = 'dataset.json' 238 | if fname not in self._all_fnames: 239 | return None 240 | with self._open_file(fname) as f: 241 | labels = json.load(f)['labels'] 242 | if labels is None: 243 | return None 244 | labels = dict(labels) 245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 246 | labels = np.array(labels) 247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 248 | return labels 249 | 250 | #---------------------------------------------------------------------------- 251 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 
68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 83 | 84 | A typical use case is to first unpickle a previous instance of a 85 | persistent class, and then upgrade it to use the latest version of 86 | the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | record_init_args = getattr(self, '_record_init_args', True) 108 | self._init_args = copy.deepcopy(args) if record_init_args else None 109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 110 | assert orig_class.__name__ in orig_module.__dict__ 111 | _check_pickleable(self.__reduce__()) 112 | 113 | @property 114 | def init_args(self): 115 | assert self._init_args is not None 116 | return copy.deepcopy(self._init_args) 117 | 118 | @property 119 | def init_kwargs(self): 120 | assert self._init_kwargs is not None 121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 122 | 123 | def __reduce__(self): 124 | fields = list(super().__reduce__()) 125 | fields += [None] * max(3 - len(fields), 0) 126 | if fields[0] is not _reconstruct_persistent_obj: 127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 128 | fields[0] = _reconstruct_persistent_obj # reconstruct func 129 | fields[1] = (meta,) # reconstruct args 130 | fields[2] = None # state dict 131 | return tuple(fields) 132 | 133 | Decorator.__name__ = orig_class.__name__ 134 | Decorator.__module__ = orig_class.__module__ 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src: Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object. 170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable.
249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /classifier_lib.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from guided_diffusion.script_util import create_classifier 3 | import os 4 | import numpy as np 5 | import io 6 | 7 | def get_discriminator(latent_extractor_ckpt, discriminator_ckpt, enable_grad=True): 8 | classifier = latent_extractor_ckpt 9 | discriminator = discriminator_ckpt 10 | def evaluate(perturbed_inputs, timesteps=None, condition=None): 11 | with torch.enable_grad() if enable_grad else torch.no_grad(): 12 | adm_features = classifier(perturbed_inputs, timesteps=timesteps, feature=True) 13 | prediction = discriminator(adm_features, timesteps, sigmoid=True, condition=condition).view(-1) 14 | return prediction 15 | return evaluate 16 | 17 | def get_reverse_discriminator(latent_extractor_ckpt, discriminator_ckpt, enable_grad=True): 18 | classifier = latent_extractor_ckpt 19 | discriminator = discriminator_ckpt 20 | def evaluate(perturbed_inputs, timesteps=None, condition=None): 21 | with torch.enable_grad() if enable_grad else torch.no_grad(): 22 | adm_features = classifier(perturbed_inputs, timesteps=timesteps, feature=True) 23 | prediction = discriminator(adm_features, timesteps, sigmoid=True, condition=condition).view(-1) 24 | return 1. - prediction 25 | return evaluate 26 | 27 | def load_classifier(ckpt_path, img_resolution, device, eval=True): 28 | classifier_args = dict( 29 | image_size=img_resolution, 30 | classifier_use_fp16=False, 31 | classifier_width=128, 32 | classifier_depth=4 if img_resolution in [64, 32] else 2, 33 | classifier_attention_resolutions="32,16,8", 34 | classifier_use_scale_shift_norm=True, 35 | classifier_resblock_updown=True, 36 | classifier_pool="attention", 37 | out_channels=1000, 38 | ) 39 | classifier = create_classifier(**classifier_args) 40 | classifier.to(device) 41 | if ckpt_path is not None: 42 | classifier_state = torch.load(ckpt_path, map_location="cpu") 43 | classifier.load_state_dict(classifier_state) 44 | if eval: 45 | classifier.eval() 46 | return classifier 47 | 48 | def load_discriminator(ckpt_path, device, condition, eval=False, channel=512): 49 | discriminator_args = dict( 50 | image_size=8, 51 | classifier_use_fp16=False, 52 | classifier_width=128, 53 | classifier_depth=2, 54 | classifier_attention_resolutions="32,16,8", 55 | classifier_use_scale_shift_norm=True, 56 | classifier_resblock_updown=True, 57 | classifier_pool="attention", 58 | out_channels=1, 59 | in_channels=channel, 60 | condition=condition, 61 | ) 62 | discriminator = create_classifier(**discriminator_args) 63 | discriminator.to(device) 64 | if ckpt_path is not None: 65 | 66 | discriminator_state = torch.load(ckpt_path, map_location="cpu") 67 | discriminator.load_state_dict(discriminator_state) 68 | if eval: 69 | discriminator.eval() 70 | return discriminator 71 | 72 | def load_attribute_classifier(ckpt_path, device, condition, eval=False, channel=512):
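    # Same architecture and loading logic as load_discriminator above, but with
    # out_channels=4 -- presumably one output per tracked attribute; the exact
    # attribute set is not documented in this file.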
73 | discriminator_args = dict( 74 | image_size=8, 75 | classifier_use_fp16=False, 76 | classifier_width=128, 77 | classifier_depth=2, 78 | classifier_attention_resolutions="32,16,8", 79 | classifier_use_scale_shift_norm=True, 80 | classifier_resblock_updown=True, 81 | classifier_pool="attention", 82 | out_channels=4, 83 | in_channels=channel, 84 | condition=condition, 85 | ) 86 | discriminator = create_classifier(**discriminator_args) 87 | discriminator.to(device) 88 | if ckpt_path is not None: 89 | 90 | discriminator_state = torch.load(ckpt_path, map_location="cpu") 91 | discriminator.load_state_dict(discriminator_state) 92 | if eval: 93 | discriminator.eval() 94 | return discriminator 95 | 96 | def get_weight_correction(discriminator, vpsde, unnormalized_input, std_wve_t, img_resolution, time_min, time_max, class_labels=None): 97 | mean_vp_tau, tau = vpsde.transform_unnormalized_wve_to_normalized_vp(std_wve_t) ## VP pretrained classifier 98 | if tau.min() > time_max or tau.min() < time_min or discriminator is None: 99 | return torch.zeros_like(unnormalized_input), 10000000. * torch.ones(unnormalized_input.shape[0], device=unnormalized_input.device) 100 | 101 | else: 102 | input = mean_vp_tau[:,None,None,None] * unnormalized_input 103 | with torch.enable_grad(): 104 | x_ = input.float().clone().detach().requires_grad_() 105 | if img_resolution == 64: # ADM trained UNet classifier for 64x64 with Cosine VPSDE 106 | tau = vpsde.compute_t_cos_from_t_lin(tau) 107 | tau = torch.ones(input.shape[0], device=tau.device) * tau 108 | log_ratio, prediction = get_log_ratio(discriminator, x_, tau, class_labels) 109 | correction = torch.autograd.grad(outputs=prediction.sum(), inputs=x_, retain_graph=False)[0] 110 | correction *= - ((std_wve_t[:,None,None,None] ** 2) * mean_vp_tau[:,None,None,None]) ## Scaling to Expected Denoised Point 111 | return correction, prediction 112 | 113 | def get_grad_log_ratio(discriminator, vpsde, unnormalized_input, std_wve_t, img_resolution, time_min, time_max, class_labels, log=False): 114 | mean_vp_tau, tau = vpsde.transform_unnormalized_wve_to_normalized_vp(std_wve_t) ## VP pretrained classifier 115 | if tau.min() > time_max or tau.min() < time_min or discriminator is None: 116 | if log: 117 | return torch.zeros_like(unnormalized_input), 10000000. * torch.ones(unnormalized_input.shape[0], device=unnormalized_input.device), torch.zeros_like(unnormalized_input) 118 | return torch.zeros_like(unnormalized_input) 119 | else: 120 | input = mean_vp_tau[:,None,None,None] * unnormalized_input 121 | with torch.enable_grad(): 122 | x_ = input.float().clone().detach().requires_grad_() 123 | if img_resolution == 64: # ADM trained UNet classifier for 64x64 with Cosine VPSDE 124 | tau = vpsde.compute_t_cos_from_t_lin(tau) 125 | tau = torch.ones(input.shape[0], device=tau.device) * tau 126 | log_ratio, prediction = get_log_ratio(discriminator, x_, tau, class_labels) 127 | discriminator_guidance_score = torch.autograd.grad(outputs=log_ratio.sum(), inputs=x_, retain_graph=False)[0] 128 | 129 | 130 | 131 | discriminator_guidance_score *= - ((std_wve_t[:,None,None,None] ** 2) * mean_vp_tau[:,None,None,None]) 132 | if log: 133 | return discriminator_guidance_score, log_ratio, prediction 134 | return discriminator_guidance_score 135 | 136 | def get_prediction(discriminator, vpsde, unnormalized_input, std_wve_t, img_resolution, time_min, time_max, class_labels, log=False): 137 | mean_vp_tau, tau = vpsde.transform_unnormalized_wve_to_normalized_vp(std_wve_t) ## VP pretrained classifier 138 | if tau.min() > time_max or tau.min() < time_min or discriminator is None: 139 | if log: 140 | return torch.zeros_like(unnormalized_input), 10000000. * torch.ones(unnormalized_input.shape[0], device=unnormalized_input.device), torch.zeros_like(unnormalized_input) 141 | return torch.zeros_like(unnormalized_input) 142 | else: 143 | input = mean_vp_tau[:,None,None,None] * unnormalized_input 144 | with torch.enable_grad(): 145 | x_ = input.float().clone().detach().requires_grad_() 146 | if img_resolution == 64: # ADM trained UNet classifier for 64x64 with Cosine VPSDE 147 | tau = vpsde.compute_t_cos_from_t_lin(tau) 148 | tau = torch.ones(input.shape[0], device=tau.device) * tau 149 | log_ratio, prediction = get_log_ratio(discriminator, x_, tau, class_labels) 150 | return prediction 151 | 152 | 153 | 154 | def get_log_ratio(discriminator, input, time, class_labels): 155 | if discriminator is None: 156 | return torch.zeros(input.shape[0], device=input.device), 0.5 * torch.ones(input.shape[0], device=input.device) 157 | else: 158 | logits = discriminator(input, timesteps=time, condition=class_labels) 159 | prediction = torch.clip(logits, 1e-5, 1. - 1e-5) 160 | log_ratio = torch.log(prediction / (1. - prediction)) 161 | return log_ratio, prediction 162 | 163 | class vpsde(): 164 | def __init__(self): 165 | self.beta_0 = 0.1 166 | self.beta_1 = 20. 167 | self.s = 0.008 168 | self.f_0 = np.cos(self.s / (1. + self.s) * np.pi / 2.) ** 2 169 | 170 | @property 171 | def T(self): 172 | return 1 173 | def compute_reverse_tau(self, tau): 174 | tau = tau * (self.beta_1 - self.beta_0) 175 | std_wve_t = (tau + self.beta_0) ** 2 - self.beta_0 ** 2 176 | std_wve_t /= 2. * (self.beta_1 - self.beta_0) 177 | std_wve_t = torch.exp(std_wve_t) - 1. 178 | std_wve_t = std_wve_t ** 0.5 179 | 180 | return std_wve_t 181 | 182 | def compute_tau(self, std_wve_t): 183 | tau = -self.beta_0 + torch.sqrt(self.beta_0 ** 2 + 2. * (self.beta_1 - self.beta_0) * torch.log(1. + std_wve_t ** 2)) 184 | tau /= self.beta_1 - self.beta_0 185 | return tau 186 | 187 | def marginal_prob(self, t): 188 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 189 | mean = torch.exp(log_mean_coeff) 190 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) 191 | return mean, std 192 | 193 | def transform_unnormalized_wve_to_normalized_vp(self, t, std_out=False): 194 | tau = self.compute_tau(t) 195 | mean_vp_tau, std_vp_tau = self.marginal_prob(tau) 196 | if std_out: 197 | return mean_vp_tau, std_vp_tau, tau 198 | return mean_vp_tau, tau 199 | 200 | def compute_t_cos_from_t_lin(self, t_lin): 201 | sqrt_alpha_t_bar = torch.exp(-0.25 * t_lin ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t_lin * self.beta_0) 202 | time = torch.arccos(np.sqrt(self.f_0) * sqrt_alpha_t_bar) 203 | t_cos = self.T * ((1. + self.s) * 2. / np.pi * time - self.s) 204 | return t_cos 205 | 206 | def get_diffusion_time(self, batch_size, batch_device, t_min=1e-5, importance_sampling=True): 207 | if importance_sampling: 208 | Z = self.normalizing_constant(t_min) 209 | u = torch.rand(batch_size, device=batch_device) 210 | return (-self.beta_0 + torch.sqrt(self.beta_0 ** 2 + 2 * (self.beta_1 - self.beta_0) * 211 | torch.log(1. + torch.exp(Z * u + self.antiderivative(t_min))))) / (self.beta_1 - self.beta_0), Z.detach() 212 | else: 213 | return torch.rand(batch_size, device=batch_device) * (self.T - t_min) + t_min, 1 214 | 215 | def antiderivative(self, t, stabilizing_constant=0.): 216 | if isinstance(t, float) or isinstance(t, int): 217 | t = torch.tensor(t).float() 218 | return torch.log(1. - torch.exp(- self.integral_beta(t)) + stabilizing_constant) + self.integral_beta(t) 219 | 220 | def normalizing_constant(self, t_min): 221 | return self.antiderivative(self.T) - self.antiderivative(t_min) 222 | 223 | def integral_beta(self, t): 224 | return 0.5 * t ** 2 * (self.beta_1 - self.beta_0) + t * self.beta_0 -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet?
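# Minimal usage sketch (illustrative only; `loss_value` stands in for any
# scalar produced by the training loop):
#
#     from torch_utils import training_stats
#     collector = training_stats.Collector(regex='Loss/.*')
#     for step in range(1000):
#         loss_value = 0.1  # placeholder scalar
#         training_stats.report('Loss/train', loss_value)
#         if (step + 1) % 100 == 0:
#             collector.update()
#             print(step, collector.mean('Loss/train'), collector.std('Loss/train'))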
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. 58 | 59 | This function is expected to be extremely cheap and can be safely 60 | called from anywhere in the training loop, loss function, or inside a 61 | `torch.nn.Module`. 62 | 63 | Warning: The current implementation expects the set of unique names to 64 | be consistent across processes. Please make sure that `report()` is 65 | called at least once for each unique name by each process, and in the 66 | same order. If a given process has no scalars to broadcast, it can do 67 | `report(name, [])` (empty list). 68 | 69 | Args: 70 | name: Arbitrary string specifying the name of the statistic. 71 | Averages are accumulated separately for each unique name. 72 | value: Arbitrary set of scalars. Can be a list, tuple, 73 | NumPy array, PyTorch tensor, or Python scalar. 74 | 75 | Returns: 76 | The same `value` that was passed in. 77 | """ 78 | if name not in _counters: 79 | _counters[name] = dict() 80 | 81 | elems = torch.as_tensor(value) 82 | if elems.numel() == 0: 83 | return value 84 | 85 | elems = elems.detach().flatten().to(_reduce_dtype) 86 | moments = torch.stack([ 87 | torch.ones_like(elems).sum(), 88 | elems.sum(), 89 | elems.square().sum(), 90 | ]) 91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 92 | moments = moments.to(_counter_dtype) 93 | 94 | device = moments.device 95 | if device not in _counters[name]: 96 | _counters[name][device] = torch.zeros_like(moments) 97 | _counters[name][device].add_(moments) 98 | return value 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | def report0(name, value): 103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 104 | but ignores any scalars provided by the other processes. 105 | See `report()` for further details. 106 | """ 107 | report(name, value if _rank == 0 else []) 108 | return value 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | class Collector: 113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 114 | computes their long-term averages (mean and standard deviation) over 115 | user-defined periods of time. 
116 | 117 | The averages are first collected into internal counters that are not 118 | directly visible to the user. They are then copied to the user-visible 119 | state as a result of calling `update()` and can then be queried using 120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 121 | internal counters for the next round, so that the user-visible state 122 | effectively reflects averages collected between the last two calls to 123 | `update()`. 124 | 125 | Args: 126 | regex: Regular expression defining which statistics to 127 | collect. The default is to collect everything. 128 | keep_previous: Whether to retain the previous averages if no 129 | scalars were collected on a given round 130 | (default: True). 131 | """ 132 | def __init__(self, regex='.*', keep_previous=True): 133 | self._regex = re.compile(regex) 134 | self._keep_previous = keep_previous 135 | self._cumulative = dict() 136 | self._moments = dict() 137 | self.update() 138 | self._moments.clear() 139 | 140 | def names(self): 141 | r"""Returns the names of all statistics broadcasted so far that 142 | match the regular expression specified at construction time. 143 | """ 144 | return [name for name in _counters if self._regex.fullmatch(name)] 145 | 146 | def update(self): 147 | r"""Copies current values of the internal counters to the 148 | user-visible state and resets them for the next round. 149 | 150 | If `keep_previous=True` was specified at construction time, the 151 | operation is skipped for statistics that have received no scalars 152 | since the last update, retaining their previous averages. 153 | 154 | This method performs a number of GPU-to-CPU transfers and one 155 | `torch.distributed.all_reduce()`. It is intended to be called 156 | periodically in the main training loop, typically once every 157 | N training steps. 158 | """ 159 | if not self._keep_previous: 160 | self._moments.clear() 161 | for name, cumulative in _sync(self.names()): 162 | if name not in self._cumulative: 163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 164 | delta = cumulative - self._cumulative[name] 165 | self._cumulative[name].copy_(cumulative) 166 | if float(delta[0]) != 0: 167 | self._moments[name] = delta 168 | 169 | def _get_delta(self, name): 170 | r"""Returns the raw moments that were accumulated for the given 171 | statistic between the last two calls to `update()`, or zero if 172 | no scalars were collected. 173 | """ 174 | assert self._regex.fullmatch(name) 175 | if name not in self._moments: 176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 177 | return self._moments[name] 178 | 179 | def num(self, name): 180 | r"""Returns the number of scalars that were accumulated for the given 181 | statistic between the last two calls to `update()`, or zero if 182 | no scalars were collected. 183 | """ 184 | delta = self._get_delta(name) 185 | return int(delta[0]) 186 | 187 | def mean(self, name): 188 | r"""Returns the mean of the scalars that were accumulated for the 189 | given statistic between the last two calls to `update()`, or NaN if 190 | no scalars were collected. 191 | """ 192 | delta = self._get_delta(name) 193 | if int(delta[0]) == 0: 194 | return float('nan') 195 | return float(delta[1] / delta[0]) 196 | 197 | def std(self, name): 198 | r"""Returns the standard deviation of the scalars that were 199 | accumulated for the given statistic between the last two calls to 200 | `update()`, or NaN if no scalars were collected. 
201 | """ 202 | delta = self._get_delta(name) 203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 204 | return float('nan') 205 | if int(delta[0]) == 1: 206 | return float(0) 207 | mean = float(delta[1] / delta[0]) 208 | raw_var = float(delta[2] / delta[0]) 209 | return np.sqrt(max(raw_var - np.square(mean), 0)) 210 | 211 | def as_dict(self): 212 | r"""Returns the averages accumulated between the last two calls to 213 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 214 | 215 | dnnlib.EasyDict( 216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 217 | ... 218 | ) 219 | """ 220 | stats = dnnlib.EasyDict() 221 | for name in self.names(): 222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 223 | return stats 224 | 225 | def __getitem__(self, name): 226 | r"""Convenience getter. 227 | `collector[name]` is a synonym for `collector.mean(name)`. 228 | """ 229 | return self.mean(name) 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def _sync(names): 234 | r"""Synchronize the global cumulative counters across devices and 235 | processes. Called internally by `Collector.update()`. 236 | """ 237 | if len(names) == 0: 238 | return [] 239 | global _sync_called 240 | _sync_called = True 241 | 242 | # Collect deltas within current rank. 243 | deltas = [] 244 | device = _sync_device if _sync_device is not None else torch.device('cpu') 245 | for name in names: 246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 247 | for counter in _counters[name].values(): 248 | delta.add_(counter.to(device)) 249 | counter.copy_(torch.zeros_like(counter)) 250 | deltas.append(delta) 251 | deltas = torch.stack(deltas) 252 | 253 | # Sum deltas across ranks. 254 | if _sync_device is not None: 255 | torch.distributed.all_reduce(deltas) 256 | 257 | # Update cumulative values. 258 | deltas = deltas.cpu() 259 | for idx, name in enumerate(names): 260 | if name not in _cumulative: 261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 262 | _cumulative[name].add_(deltas[idx]) 263 | 264 | # Return name-value pairs. 265 | return [(name, _cumulative[name]) for name in names] 266 | 267 | #---------------------------------------------------------------------------- 268 | # Convenience. 269 | 270 | default_collector = Collector() 271 | 272 | #---------------------------------------------------------------------------- 273 | -------------------------------------------------------------------------------- /guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 
19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | batch_size, 30 | microbatch, 31 | lr, 32 | ema_rate, 33 | log_interval, 34 | save_interval, 35 | resume_checkpoint, 36 | use_fp16=False, 37 | fp16_scale_growth=1e-3, 38 | schedule_sampler=None, 39 | weight_decay=0.0, 40 | lr_anneal_steps=0, 41 | ): 42 | self.model = model 43 | self.diffusion = diffusion 44 | self.data = data 45 | self.batch_size = batch_size 46 | self.microbatch = microbatch if microbatch > 0 else batch_size 47 | self.lr = lr 48 | self.ema_rate = ( 49 | [ema_rate] 50 | if isinstance(ema_rate, float) 51 | else [float(x) for x in ema_rate.split(",")] 52 | ) 53 | self.log_interval = log_interval 54 | self.save_interval = save_interval 55 | self.resume_checkpoint = resume_checkpoint 56 | self.use_fp16 = use_fp16 57 | self.fp16_scale_growth = fp16_scale_growth 58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 59 | self.weight_decay = weight_decay 60 | self.lr_anneal_steps = lr_anneal_steps 61 | 62 | self.step = 0 63 | self.resume_step = 0 64 | self.global_batch = self.batch_size * dist.get_world_size() 65 | 66 | self.sync_cuda = th.cuda.is_available() 67 | 68 | self._load_and_sync_parameters() 69 | self.mp_trainer = MixedPrecisionTrainer( 70 | model=self.model, 71 | use_fp16=self.use_fp16, 72 | fp16_scale_growth=fp16_scale_growth, 73 | ) 74 | 75 | self.opt = AdamW( 76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 77 | ) 78 | if self.resume_step: 79 | self._load_optimizer_state() 80 | # Model was resumed, either due to a restart or a checkpoint 81 | # being specified at the command line. 82 | self.ema_params = [ 83 | self._load_ema_parameters(rate) for rate in self.ema_rate 84 | ] 85 | else: 86 | self.ema_params = [ 87 | copy.deepcopy(self.mp_trainer.master_params) 88 | for _ in range(len(self.ema_rate)) 89 | ] 90 | 91 | if th.cuda.is_available(): 92 | self.use_ddp = True 93 | self.ddp_model = DDP( 94 | self.model, 95 | device_ids=[dist_util.dev()], 96 | output_device=dist_util.dev(), 97 | broadcast_buffers=False, 98 | bucket_cap_mb=128, 99 | find_unused_parameters=False, 100 | ) 101 | else: 102 | if dist.get_world_size() > 1: 103 | logger.warn( 104 | "Distributed training requires CUDA. " 105 | "Gradients will not be synchronized properly!" 
106 | ) 107 | self.use_ddp = False 108 | self.ddp_model = self.model 109 | 110 | def _load_and_sync_parameters(self): 111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 112 | 113 | if resume_checkpoint: 114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 115 | if dist.get_rank() == 0: 116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 117 | self.model.load_state_dict( 118 | dist_util.load_state_dict( 119 | resume_checkpoint, map_location=dist_util.dev() 120 | ) 121 | ) 122 | 123 | dist_util.sync_params(self.model.parameters()) 124 | 125 | def _load_ema_parameters(self, rate): 126 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 127 | 128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 130 | if ema_checkpoint: 131 | if dist.get_rank() == 0: 132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 133 | state_dict = dist_util.load_state_dict( 134 | ema_checkpoint, map_location=dist_util.dev() 135 | ) 136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 137 | 138 | dist_util.sync_params(ema_params) 139 | return ema_params 140 | 141 | def _load_optimizer_state(self): 142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 143 | opt_checkpoint = bf.join( 144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 145 | ) 146 | if bf.exists(opt_checkpoint): 147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 148 | state_dict = dist_util.load_state_dict( 149 | opt_checkpoint, map_location=dist_util.dev() 150 | ) 151 | self.opt.load_state_dict(state_dict) 152 | 153 | def run_loop(self): 154 | while ( 155 | not self.lr_anneal_steps 156 | or self.step + self.resume_step < self.lr_anneal_steps 157 | ): 158 | batch, cond = next(self.data) 159 | self.run_step(batch, cond) 160 | if self.step % self.log_interval == 0: 161 | logger.dumpkvs() 162 | if self.step % self.save_interval == 0: 163 | self.save() 164 | # Run for a finite amount of time in integration tests. 165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 166 | return 167 | self.step += 1 168 | # Save the last checkpoint if it wasn't already saved. 
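# By this point self.step has been incremented one past the last completed
# step, hence the (self.step - 1) below.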
169 | if (self.step - 1) % self.save_interval != 0: 170 | self.save() 171 | 172 | def run_step(self, batch, cond): 173 | self.forward_backward(batch, cond) 174 | took_step = self.mp_trainer.optimize(self.opt) 175 | if took_step: 176 | self._update_ema() 177 | self._anneal_lr() 178 | self.log_step() 179 | 180 | def forward_backward(self, batch, cond): 181 | self.mp_trainer.zero_grad() 182 | for i in range(0, batch.shape[0], self.microbatch): 183 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 184 | micro_cond = { 185 | k: v[i : i + self.microbatch].to(dist_util.dev()) 186 | for k, v in cond.items() 187 | } 188 | last_batch = (i + self.microbatch) >= batch.shape[0] 189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 190 | 191 | compute_losses = functools.partial( 192 | self.diffusion.training_losses, 193 | self.ddp_model, 194 | micro, 195 | t, 196 | model_kwargs=micro_cond, 197 | ) 198 | 199 | if last_batch or not self.use_ddp: 200 | losses = compute_losses() 201 | else: 202 | with self.ddp_model.no_sync(): 203 | losses = compute_losses() 204 | 205 | if isinstance(self.schedule_sampler, LossAwareSampler): 206 | self.schedule_sampler.update_with_local_losses( 207 | t, losses["loss"].detach() 208 | ) 209 | 210 | loss = (losses["loss"] * weights).mean() 211 | log_loss_dict( 212 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 213 | ) 214 | self.mp_trainer.backward(loss) 215 | 216 | def _update_ema(self): 217 | for rate, params in zip(self.ema_rate, self.ema_params): 218 | update_ema(params, self.mp_trainer.master_params, rate=rate) 219 | 220 | def _anneal_lr(self): 221 | if not self.lr_anneal_steps: 222 | return 223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 224 | lr = self.lr * (1 - frac_done) 225 | for param_group in self.opt.param_groups: 226 | param_group["lr"] = lr 227 | 228 | def log_step(self): 229 | logger.logkv("step", self.step + self.resume_step) 230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 231 | 232 | def save(self): 233 | def save_checkpoint(rate, params): 234 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 235 | if dist.get_rank() == 0: 236 | logger.log(f"saving model {rate}...") 237 | if not rate: 238 | filename = f"model{(self.step+self.resume_step):06d}.pt" 239 | else: 240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 242 | th.save(state_dict, f) 243 | 244 | save_checkpoint(0, self.mp_trainer.master_params) 245 | for rate, params in zip(self.ema_rate, self.ema_params): 246 | save_checkpoint(rate, params) 247 | 248 | if dist.get_rank() == 0: 249 | with bf.BlobFile( 250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 251 | "wb", 252 | ) as f: 253 | th.save(self.opt.state_dict(), f) 254 | 255 | dist.barrier() 256 | 257 | 258 | def parse_resume_step_from_filename(filename): 259 | """ 260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 261 | checkpoint's number of steps. 262 | """ 263 | split = filename.split("model") 264 | if len(split) < 2: 265 | return 0 266 | split1 = split[-1].split(".")[0] 267 | try: 268 | return int(split1) 269 | except ValueError: 270 | return 0 271 | 272 | 273 | def get_blob_logdir(): 274 | # You can change this to be a separate path to save checkpoints to 275 | # a blobstore or some external drive. 
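    # For example (hypothetical path), returning "gs://my-bucket/ckpts" here
    # would route the bf.BlobFile writes in save() to remote storage instead
    # of the local logging directory.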
276 | return logger.get_dir() 277 | 278 | 279 | def find_resume_checkpoint(): 280 | # On your infrastructure, you may want to override this to automatically 281 | # discover the latest checkpoint on your blob storage, etc. 282 | return None 283 | 284 | 285 | def find_ema_checkpoint(main_checkpoint, step, rate): 286 | if main_checkpoint is None: 287 | return None 288 | filename = f"ema_{rate}_{(step):06d}.pt" 289 | path = bf.join(bf.dirname(main_checkpoint), filename) 290 | if bf.exists(path): 291 | return path 292 | return None 293 | 294 | 295 | def log_loss_dict(diffusion, ts, losses): 296 | for key, values in losses.items(): 297 | logger.logkv_mean(key, values.mean().item()) 298 | # Log the quantiles (four quartiles, in particular). 299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 300 | quartile = int(4 * sub_t / diffusion.num_timesteps) 301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 302 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import numpy as np 11 | import torch 12 | import warnings 13 | import dnnlib 14 | 15 | #---------------------------------------------------------------------------- 16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 17 | # same constant is used multiple times. 18 | 19 | _constant_cache = dict() 20 | 21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 22 | value = np.asarray(value) 23 | if shape is not None: 24 | shape = tuple(shape) 25 | if dtype is None: 26 | dtype = torch.get_default_dtype() 27 | if device is None: 28 | device = torch.device('cpu') 29 | if memory_format is None: 30 | memory_format = torch.contiguous_format 31 | 32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 33 | tensor = _constant_cache.get(key, None) 34 | if tensor is None: 35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 36 | if shape is not None: 37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 38 | tensor = tensor.contiguous(memory_format=memory_format) 39 | _constant_cache[key] = tensor 40 | return tensor 41 | 42 | #---------------------------------------------------------------------------- 43 | # Replace NaN/Inf with specified numerical values. 44 | 45 | try: 46 | nan_to_num = torch.nan_to_num # 1.8.0a0 47 | except AttributeError: 48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 49 | assert isinstance(input, torch.Tensor) 50 | if posinf is None: 51 | posinf = torch.finfo(input.dtype).max 52 | if neginf is None: 53 | neginf = torch.finfo(input.dtype).min 54 | assert nan == 0 55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Symbolic assert. 59 | 60 | try: 61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 62 | except AttributeError: 63 | symbolic_assert = torch.Assert # 1.7.0 64 | 65 | #---------------------------------------------------------------------------- 66 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 68 | 69 | @contextlib.contextmanager 70 | def suppress_tracer_warnings(): 71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 72 | warnings.filters.insert(0, flt) 73 | yield 74 | warnings.filters.remove(flt) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Assert that the shape of a tensor matches the given list of integers. 78 | # None indicates that the size of a dimension is allowed to vary. 79 | # Performs symbolic assertion when used in torch.jit.trace(). 
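# Usage sketch (hypothetical tensor): for a batch of CIFAR-10 images,
#   x = torch.zeros(64, 3, 32, 32)
#   assert_shape(x, [None, 3, 32, 32])  # passes; the batch dimension may vary
#   assert_shape(x, [None, 1, 32, 32])  # raises AssertionError on dimension 1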
80 | 81 | def assert_shape(tensor, ref_shape): 82 | if tensor.ndim != len(ref_shape): 83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 85 | if ref_size is None: 86 | pass 87 | elif isinstance(ref_size, torch.Tensor): 88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 90 | elif isinstance(size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 93 | elif size != ref_size: 94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 95 | 96 | #---------------------------------------------------------------------------- 97 | # Function decorator that calls torch.autograd.profiler.record_function(). 98 | 99 | def profiled_function(fn): 100 | def decorator(*args, **kwargs): 101 | with torch.autograd.profiler.record_function(fn.__name__): 102 | return fn(*args, **kwargs) 103 | decorator.__name__ = fn.__name__ 104 | return decorator 105 | 106 | #---------------------------------------------------------------------------- 107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 108 | # indefinitely, shuffling items as it goes. 109 | 110 | class InfiniteSampler(torch.utils.data.Sampler): 111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 112 | assert len(dataset) > 0 113 | assert num_replicas > 0 114 | assert 0 <= rank < num_replicas 115 | assert 0 <= window_size <= 1 116 | super().__init__(dataset) 117 | self.dataset = dataset 118 | self.rank = rank 119 | self.num_replicas = num_replicas 120 | self.shuffle = shuffle 121 | self.seed = seed 122 | self.window_size = window_size 123 | 124 | def __iter__(self): 125 | order = np.arange(len(self.dataset)) 126 | rnd = None 127 | window = 0 128 | if self.shuffle: 129 | rnd = np.random.RandomState(self.seed) 130 | rnd.shuffle(order) 131 | window = int(np.rint(order.size * self.window_size)) 132 | 133 | idx = 0 134 | while True: 135 | i = idx % order.size 136 | if idx % self.num_replicas == self.rank: 137 | yield order[i] 138 | if window >= 2: 139 | j = (i - rnd.randint(window)) % order.size 140 | order[i], order[j] = order[j], order[i] 141 | idx += 1 142 | 143 | #---------------------------------------------------------------------------- 144 | # Utilities for operating with torch.nn.Module parameters and buffers. 
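# Note: module.parameters() does not include buffers (e.g. BatchNorm running
# statistics), so the helpers below combine both to capture a module's full state.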
145 | 146 | def params_and_buffers(module): 147 | assert isinstance(module, torch.nn.Module) 148 | return list(module.parameters()) + list(module.buffers()) 149 | 150 | def named_params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.named_parameters()) + list(module.named_buffers()) 153 | 154 | @torch.no_grad() 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name]) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 
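    # (Tensors shared by several submodules, e.g. tied weights, are attributed
    # to the first entry that reports them, so the totals count each tensor once.)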
221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /training/training_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Main training loop.""" 9 | 10 | import os 11 | import time 12 | import copy 13 | import json 14 | import pickle 15 | import psutil 16 | import numpy as np 17 | import torch 18 | import dnnlib 19 | from torch_utils import distributed as dist 20 | from torch_utils import training_stats 21 | from torch_utils import misc 22 | 23 | import classifier_lib 24 | #---------------------------------------------------------------------------- 25 | 26 | def training_loop( 27 | run_dir = '.', # Output directory. 28 | dataset_kwargs = {}, # Options for training set. 29 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. 30 | network_kwargs = {}, # Options for model and preconditioning. 31 | loss_kwargs = {}, # Options for loss function. 32 | optimizer_kwargs = {}, # Options for optimizer. 33 | augment_kwargs = None, # Options for augmentation pipeline, None = disable. 
34 | seed = 0, # Global random seed. 35 | batch_size = 512, # Total batch size for one training iteration. 36 | batch_gpu = None, # Limit batch size per GPU, None = no limit. 37 | total_kimg = 200000, # Training duration, measured in thousands of training images. 38 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights. 39 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup. 40 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration. 41 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows. 42 | kimg_per_tick = 50, # Interval of progress prints. 43 | snapshot_ticks = 50, # How often to save network snapshots, None = disable. 44 | state_dump_ticks = 500, # How often to dump training state, None = disable. 45 | resume_pkl = None, # Start from the given network snapshot, None = random initialization. 46 | resume_state_dump = None, # Start from the given training state, None = reset training state. 47 | resume_kimg = 0, # Start from the given training progress. 48 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 49 | device = torch.device('cuda'), 50 | cla_path = None, 51 | dis_path = None, 52 | ): 53 | # Initialize. 54 | start_time = time.time() 55 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31)) 56 | torch.manual_seed(np.random.randint(1 << 31)) 57 | torch.backends.cudnn.benchmark = cudnn_benchmark 58 | torch.backends.cudnn.allow_tf32 = False 59 | torch.backends.cuda.matmul.allow_tf32 = False 60 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 61 | 62 | # Select batch size per GPU. 63 | batch_gpu_total = batch_size // dist.get_world_size() 64 | if batch_gpu is None or batch_gpu > batch_gpu_total: 65 | batch_gpu = batch_gpu_total 66 | num_accumulation_rounds = batch_gpu_total // batch_gpu 67 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size() 68 | 69 | # Load dataset. 70 | dist.print0('Loading dataset...') 71 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset 72 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed) 73 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs)) 74 | 75 | # Construct network. 76 | dist.print0('Constructing network...') 77 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim) 78 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module 79 | net.train().requires_grad_(True).to(device) 80 | if dist.get_rank() == 0: 81 | with torch.no_grad(): 82 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device) 83 | sigma = torch.ones([batch_gpu], device=device) 84 | labels = torch.zeros([batch_gpu, net.label_dim], device=device) 85 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2) 86 | 87 | # Setup optimizer. 
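    # The class names below come from train.py: with the default --precond=tiw_edm,
    # loss_fn is training.loss.TIW_EDMLoss and the optimizer is torch.optim.Adam.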
88 |     dist.print0('Setting up optimizer...')
89 |     loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM|TIW_EDM)Loss
90 |     optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
91 |     augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
92 |     ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
93 |     ema = copy.deepcopy(net).eval().requires_grad_(False)
94 | 
95 |     # Load the pretrained classifier features and discriminator that define the time-dependent importance weight.
96 |     cla = classifier_lib.load_classifier(cla_path, net.img_resolution, device, eval=True)
97 |     dis = classifier_lib.load_discriminator(dis_path, device, False, eval=True, channel=512)
98 |     discriminator = classifier_lib.get_discriminator(cla, dis, enable_grad=True)
99 |     vpsde = classifier_lib.vpsde()
100 | 
101 |     # Resume training from previous snapshot.
102 |     if resume_pkl is not None:
103 |         dist.print0(f'Loading network weights from "{resume_pkl}"...')
104 |         if dist.get_rank() != 0:
105 |             torch.distributed.barrier() # rank 0 goes first
106 |         with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
107 |             data = pickle.load(f)
108 |         if dist.get_rank() == 0:
109 |             torch.distributed.barrier() # other ranks follow
110 |         misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
111 |         misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
112 |         del data # conserve memory
113 |     if resume_state_dump:
114 |         dist.print0(f'Loading training state from "{resume_state_dump}"...')
115 |         data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
116 |         misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
117 |         optimizer.load_state_dict(data['optimizer_state'])
118 |         del data # conserve memory
119 | 
120 |     # Train.
121 |     dist.print0(f'Training for {total_kimg} kimg...')
122 |     dist.print0()
123 |     cur_nimg = resume_kimg * 1000
124 |     cur_tick = 0
125 |     tick_start_nimg = cur_nimg
126 |     tick_start_time = time.time()
127 |     maintenance_time = tick_start_time - start_time
128 |     dist.update_progress(cur_nimg // 1000, total_kimg)
129 |     stats_jsonl = None
130 |     while True:
131 | 
132 |         # Accumulate gradients.
133 |         optimizer.zero_grad(set_to_none=True)
134 |         for round_idx in range(num_accumulation_rounds):
135 |             with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
136 |                 images, labels = next(dataset_iterator)
137 |                 images = images.to(device).to(torch.float32) / 127.5 - 1
138 |                 labels = labels.to(device)
139 |                 loss = loss_fn(vpsde=vpsde, dis=discriminator, net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
140 |                 training_stats.report('Loss/loss', loss)
141 |                 loss.sum().mul(loss_scaling / batch_gpu_total).backward()
142 | 
143 |         # Update weights.
144 |         for g in optimizer.param_groups:
145 |             g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
146 |         for param in net.parameters():
147 |             if param.grad is not None:
148 |                 torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
149 |         optimizer.step()
150 | 
151 |         # Update EMA.
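        # With batch_size images per step and a half-life of ema_halflife_nimg
        # images, beta = 0.5 ** (batch_size / half-life), so the old EMA state
        # decays by half every half-life images; early in training the half-life
        # is capped at ema_rampup_ratio * cur_nimg.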
152 | ema_halflife_nimg = ema_halflife_kimg * 1000 153 | if ema_rampup_ratio is not None: 154 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) 155 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) 156 | for p_ema, p_net in zip(ema.parameters(), net.parameters()): 157 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) 158 | 159 | # Perform maintenance tasks once per tick. 160 | cur_nimg += batch_size 161 | done = (cur_nimg >= total_kimg * 1000) 162 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): 163 | continue 164 | 165 | # Print status line, accumulating the same information in training_stats. 166 | tick_end_time = time.time() 167 | fields = [] 168 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] 169 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] 170 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] 171 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] 172 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] 173 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] 174 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] 175 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] 176 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] 177 | torch.cuda.reset_peak_memory_stats() 178 | dist.print0(' '.join(fields)) 179 | 180 | # Check for abort. 181 | if (not done) and dist.should_stop(): 182 | done = True 183 | dist.print0() 184 | dist.print0('Aborting...') 185 | 186 | # Save network snapshot. 187 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0): 188 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs)) 189 | for key, value in data.items(): 190 | if isinstance(value, torch.nn.Module): 191 | value = copy.deepcopy(value).eval().requires_grad_(False) 192 | misc.check_ddp_consistency(value) 193 | data[key] = value.cpu() 194 | del value # conserve memory 195 | if dist.get_rank() == 0: 196 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f: 197 | pickle.dump(data, f) 198 | del data # conserve memory 199 | 200 | # Save full dump of the training state. 201 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0: 202 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')) 203 | 204 | # Update logs. 205 | training_stats.default_collector.update() 206 | if dist.get_rank() == 0: 207 | if stats_jsonl is None: 208 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at') 209 | stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n') 210 | stats_jsonl.flush() 211 | dist.update_progress(cur_nimg // 1000, total_kimg) 212 | 213 | # Update state. 
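        # (maintenance_time measures the work done between tick_end_time and the
        # next tick: status printing, snapshots, state dumps, and log updates.)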
214 | cur_tick += 1 215 | tick_start_nimg = cur_nimg 216 | tick_start_time = time.time() 217 | maintenance_time = tick_start_time - tick_end_time 218 | if done: 219 | break 220 | 221 | # Done. 222 | dist.print0() 223 | dist.print0('Exiting...') 224 | 225 | #---------------------------------------------------------------------------- 226 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Train diffusion-based generative model using the techniques described in the 9 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import os 12 | import re 13 | import json 14 | import click 15 | import torch 16 | import dnnlib 17 | from torch_utils import distributed as dist 18 | from training import training_loop 19 | 20 | import warnings 21 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. 22 | 23 | #---------------------------------------------------------------------------- 24 | # Parse a comma separated list of numbers or ranges and return a list of ints. 25 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 26 | 27 | def parse_int_list(s): 28 | if isinstance(s, list): return s 29 | ranges = [] 30 | range_re = re.compile(r'^(\d+)-(\d+)$') 31 | for p in s.split(','): 32 | m = range_re.match(p) 33 | if m: 34 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 35 | else: 36 | ranges.append(int(p)) 37 | return ranges 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @click.command() 42 | 43 | # Main options. 44 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True) 45 | @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True) 46 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True) 47 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True) 48 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm', 'tiw_edm']), default='tiw_edm', show_default=True) 49 | 50 | # Hyperparameters. 
51 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True) 52 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True) 53 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1)) 54 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int) 55 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list) 56 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True) 57 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True) 58 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True) 59 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True) 60 | @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True) 61 | 62 | # Performance-related. 63 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 64 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 65 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 66 | @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True) 67 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) 68 | 69 | # I/O-related. 
70 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
71 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
72 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
73 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
74 | @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
75 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
76 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
77 | @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
78 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
79 | 
80 | ## TIW config
81 | @click.option('--cla_path', help='Path to the pretrained classifier (feature extractor) checkpoint', metavar='STR', default="/home/aailab/alsdudrla10/FirstArticle/Github/TIW-DSM/checkpoints/discriminator/feature_extractor/32x32_classifier.pt", type=str)
82 | @click.option('--dis_path', help='Path to the pretrained discriminator checkpoint', metavar='STR', default="/home/aailab/alsdudrla10/FirstArticle/Github/TIW-DSM/checkpoints/discriminator/cifar10/unbias_500/discriminator_501.pt", type=str)
83 | 
84 | def main(**kwargs):
85 |     """Train diffusion-based generative model using the techniques described in the
86 |     paper "Elucidating the Design Space of Diffusion-Based Generative Models".
87 | 
88 |     Examples:
89 | 
90 |     \b
91 |     # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
92 |     torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \\
93 |         --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
94 |     """
95 |     opts = dnnlib.EasyDict(kwargs)
96 |     torch.multiprocessing.set_start_method('spawn')
97 |     dist.init()
98 | 
99 |     # Initialize config dict.
100 |     c = dnnlib.EasyDict()
101 |     c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
102 |     c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
103 |     c.network_kwargs = dnnlib.EasyDict()
104 |     c.loss_kwargs = dnnlib.EasyDict()
105 |     c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
106 | 
107 |     # Checkpoint paths for the time-dependent weighting function.
108 |     c.cla_path = opts.cla_path
109 |     c.dis_path = opts.dis_path
110 |     # (the dataset resolution is inferred from the dataset and recorded in c.dataset_kwargs below)
111 | 
112 |     # Validate dataset options.
113 |     try:
114 |         dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
115 |         dataset_name = dataset_obj.name
116 |         c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
117 |         c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
118 |         if opts.cond and not dataset_obj.has_labels:
119 |             raise click.ClickException('--cond=True requires labels specified in dataset.json')
120 |         del dataset_obj # conserve memory
121 |     except IOError as err:
122 |         raise click.ClickException(f'--data: {err}')
123 | 
124 |     # Network architecture.
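    # ddpmpp and ncsnpp both use the SongUNet backbone (positional vs. Fourier
    # noise embeddings, standard vs. residual encoder); adm selects the
    # DhariwalUNet (ADM) architecture.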
125 |     if opts.arch == 'ddpmpp':
126 |         c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
127 |         c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2])
128 |     elif opts.arch == 'ncsnpp':
129 |         c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
130 |         c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2])
131 |     else:
132 |         assert opts.arch == 'adm'
133 |         c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
134 | 
135 |     # Preconditioning & loss function.
136 |     if opts.precond == 'vp':
137 |         c.network_kwargs.class_name = 'training.networks.VPPrecond'
138 |         c.loss_kwargs.class_name = 'training.loss.VPLoss'
139 |     elif opts.precond == 've':
140 |         c.network_kwargs.class_name = 'training.networks.VEPrecond'
141 |         c.loss_kwargs.class_name = 'training.loss.VELoss'
142 |     elif opts.precond == 'edm':
143 |         c.network_kwargs.class_name = 'training.networks.EDMPrecond'
144 |         c.loss_kwargs.class_name = 'training.loss.EDMLoss'
145 |     elif opts.precond == 'tiw_edm':
146 |         c.network_kwargs.class_name = 'training.networks.EDMPrecond'
147 |         c.loss_kwargs.class_name = 'training.loss.TIW_EDMLoss'
148 |     else:
149 |         assert 0
150 | 
151 |     # Network options.
152 |     if opts.cbase is not None:
153 |         c.network_kwargs.model_channels = opts.cbase
154 |     if opts.cres is not None:
155 |         c.network_kwargs.channel_mult = opts.cres
156 |     if opts.augment:
157 |         c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
158 |         c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
159 |         c.network_kwargs.augment_dim = 9
160 |     c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
161 | 
162 |     # Training options.
163 |     c.total_kimg = max(int(opts.duration * 1000), 1)
164 |     c.ema_halflife_kimg = int(opts.ema * 1000)
165 |     c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
166 |     c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
167 |     c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)
168 | 
169 |     # Random seed.
170 |     if opts.seed is not None:
171 |         c.seed = opts.seed
172 |     else:
173 |         seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
174 |         torch.distributed.broadcast(seed, src=0)
175 |         c.seed = int(seed)
176 | 
177 |     # Transfer learning and resume.
178 |     if opts.transfer is not None:
179 |         if opts.resume is not None:
180 |             raise click.ClickException('--transfer and --resume cannot be specified at the same time')
181 |         c.resume_pkl = opts.transfer
182 |         c.ema_rampup_ratio = None
183 |     elif opts.resume is not None:
184 |         match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
185 |         if not match or not os.path.isfile(opts.resume):
186 |             raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
187 |         c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
188 |         c.resume_kimg = int(match.group(1))
189 |         c.resume_state_dump = opts.resume
190 | 
191 |     # Description string.
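    # e.g. 'cifar10-32x32-uncond-ddpmpp-tiw_edm-gpus8-batch512-fp32' for the
    # defaults on 8 GPUs (the dataset name shown is illustrative).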
192 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond' 193 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32' 194 | desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}' 195 | if opts.desc is not None: 196 | desc += f'-{opts.desc}' 197 | 198 | # Pick output directory. 199 | if dist.get_rank() != 0: 200 | c.run_dir = None 201 | elif opts.nosubdir: 202 | c.run_dir = opts.outdir 203 | else: 204 | prev_run_dirs = [] 205 | if os.path.isdir(opts.outdir): 206 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))] 207 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] 208 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] 209 | cur_run_id = max(prev_run_ids, default=-1) + 1 210 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}') 211 | assert not os.path.exists(c.run_dir) 212 | 213 | # Print options. 214 | dist.print0() 215 | dist.print0('Training options:') 216 | dist.print0(json.dumps(c, indent=2)) 217 | dist.print0() 218 | dist.print0(f'Output directory: {c.run_dir}') 219 | dist.print0(f'Dataset path: {c.dataset_kwargs.path}') 220 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}') 221 | dist.print0(f'Network architecture: {opts.arch}') 222 | dist.print0(f'Preconditioning & loss: {opts.precond}') 223 | dist.print0(f'Number of GPUs: {dist.get_world_size()}') 224 | dist.print0(f'Batch size: {c.batch_size}') 225 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}') 226 | dist.print0() 227 | 228 | # Dry run? 229 | if opts.dry_run: 230 | dist.print0('Dry run; exiting.') 231 | return 232 | 233 | # Create output directory. 234 | dist.print0('Creating output directory...') 235 | if dist.get_rank() == 0: 236 | os.makedirs(c.run_dir, exist_ok=True) 237 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: 238 | json.dump(c, f, indent=2) 239 | dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) 240 | 241 | # Train. 
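    # Every key collected in c must correspond to a keyword parameter of
    # training_loop(); an unexpected key would raise a TypeError here.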
242 | training_loop.training_loop(**c) 243 | 244 | #---------------------------------------------------------------------------- 245 | 246 | if __name__ == "__main__": 247 | main() 248 | 249 | #---------------------------------------------------------------------------- 250 | -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic. 215 | Call this once for each diagnostic quantity, each iteration. 216 | If called many times, the last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, the values are averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs. 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration. 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get the directory that log files are being written to.
284 | Returns None if there is no output directory (i.e., if you didn't call configure()). 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname])
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Generate random images using the techniques described in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import os 12 | import re 13 | import click 14 | import tqdm 15 | import pickle 16 | import numpy as np 17 | import torch 18 | import PIL.Image 19 | import dnnlib 20 | from torch_utils import distributed as dist 21 | 22 | #---------------------------------------------------------------------------- 23 | # Proposed EDM sampler (Algorithm 2). 24 | 25 | def edm_sampler( 26 | net, latents, class_labels=None, randn_like=torch.randn_like, 27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 29 | ): 30 | # Adjust noise levels based on what's supported by the network. 31 | sigma_min = max(sigma_min, net.sigma_min) 32 | sigma_max = min(sigma_max, net.sigma_max) 33 | 34 | # Time step discretization. 35 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 36 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 37 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 38 | 39 | # Main sampling loop. 40 | x_next = latents.to(torch.float64) * t_steps[0] 41 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 42 | x_cur = x_next 43 | 44 | # Increase noise temporarily. 45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 46 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 48 | 49 | # Euler step. 50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 51 | d_cur = (x_hat - denoised) / t_hat 52 | x_next = x_hat + (t_next - t_hat) * d_cur 53 | 54 | # Apply 2nd order correction. 55 | if i < num_steps - 1: 56 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 57 | d_prime = (x_next - denoised) / t_next 58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 59 | 60 | return x_next 61 | 62 | #---------------------------------------------------------------------------- 63 | # Generalized ablation sampler, representing the superset of all sampling 64 | # methods discussed in the paper. 65 | 66 | def ablation_sampler( 67 | net, latents, class_labels=None, randn_like=torch.randn_like, 68 | num_steps=18, sigma_min=None, sigma_max=None, rho=7, 69 | solver='heun', discretization='edm', schedule='linear', scaling='none', 70 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, 71 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 72 | ): 73 | assert solver in ['euler', 'heun'] 74 | assert discretization in ['vp', 've', 'iddpm', 'edm'] 75 | assert schedule in ['vp', 've', 'linear'] 76 | assert scaling in ['vp', 'none'] 77 | 78 | # Helper functions for VP & VE noise level schedules. 
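# For reference (these identities are derivable from the definitions just below): the VP
# schedule is sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1); differentiating
# gives sigma'(t) = 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)), which is the
# form vp_sigma_deriv uses. Note that the name `sigma` inside vp_sigma_deriv is resolved at
# call time (late binding), so it refers to the schedule function selected further down.
# For VE, sigma(t) = sqrt(t), hence sigma'(t) = 0.5 / sqrt(t) and sigma^-1(s) = s**2.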
79 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 80 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) 81 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d 82 | ve_sigma = lambda t: t.sqrt() 83 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt() 84 | ve_sigma_inv = lambda sigma: sigma ** 2 85 | 86 | # Select default noise level range based on the specified time step discretization. 87 | if sigma_min is None: 88 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) 89 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] 90 | if sigma_max is None: 91 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) 92 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] 93 | 94 | # Adjust noise levels based on what's supported by the network. 95 | sigma_min = max(sigma_min, net.sigma_min) 96 | sigma_max = min(sigma_max, net.sigma_max) 97 | 98 | # Compute corresponding betas for VP. 99 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) 100 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d 101 | 102 | # Define time steps in terms of noise level. 103 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 104 | if discretization == 'vp': 105 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) 106 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) 107 | elif discretization == 've': 108 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) 109 | sigma_steps = ve_sigma(orig_t_steps) 110 | elif discretization == 'iddpm': 111 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) 112 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 113 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 114 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() 115 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] 116 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] 117 | else: 118 | assert discretization == 'edm' 119 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 120 | 121 | # Define noise level schedule. 122 | if schedule == 'vp': 123 | sigma = vp_sigma(vp_beta_d, vp_beta_min) 124 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) 125 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) 126 | elif schedule == 've': 127 | sigma = ve_sigma 128 | sigma_deriv = ve_sigma_deriv 129 | sigma_inv = ve_sigma_inv 130 | else: 131 | assert schedule == 'linear' 132 | sigma = lambda t: t 133 | sigma_deriv = lambda t: 1 134 | sigma_inv = lambda sigma: sigma 135 | 136 | # Define scaling schedule. 137 | if scaling == 'vp': 138 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() 139 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) 140 | else: 141 | assert scaling == 'none' 142 | s = lambda t: 1 143 | s_deriv = lambda t: 0 144 | 145 | # Compute final time steps based on the corresponding noise levels. 
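# Sanity check worth keeping in mind: with the default arguments (solver='heun',
# discretization='edm', schedule='linear', scaling='none', alpha=1), sigma(t) = t and
# s(t) = 1, so t_steps below equals sigma_steps, d_cur reduces to (x_hat - denoised) / t_hat,
# and the correction step reduces to the 0.5 * d_cur + 0.5 * d_prime average -- i.e., this
# generalized sampler reproduces edm_sampler above exactly.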
146 | t_steps = sigma_inv(net.round_sigma(sigma_steps)) 147 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 148 | 149 | # Main sampling loop. 150 | t_next = t_steps[0] 151 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) 152 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 153 | x_cur = x_next 154 | 155 | # Increase noise temporarily. 156 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 157 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) 158 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) 159 | 160 | # Euler step. 161 | h = t_next - t_hat 162 | denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) 163 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised 164 | x_prime = x_hat + alpha * h * d_cur 165 | t_prime = t_hat + alpha * h 166 | 167 | # Apply 2nd order correction. 168 | if solver == 'euler' or i == num_steps - 1: 169 | x_next = x_hat + h * d_cur 170 | else: 171 | assert solver == 'heun' 172 | denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64) 173 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised 174 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) 175 | 176 | return x_next 177 | 178 | #---------------------------------------------------------------------------- 179 | # Wrapper for torch.Generator that allows specifying a different random seed 180 | # for each sample in a minibatch. 181 | 182 | class StackedRandomGenerator: 183 | def __init__(self, device, seeds): 184 | super().__init__() 185 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] 186 | 187 | def randn(self, size, **kwargs): 188 | assert size[0] == len(self.generators) 189 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) 190 | 191 | def randn_like(self, input): 192 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) 193 | 194 | def randint(self, *args, size, **kwargs): 195 | assert size[0] == len(self.generators) 196 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) 197 | 198 | #---------------------------------------------------------------------------- 199 | # Parse a comma separated list of numbers or ranges and return a list of ints. 
200 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 201 | 202 | def parse_int_list(s): 203 | if isinstance(s, list): return s 204 | ranges = [] 205 | range_re = re.compile(r'^(\d+)-(\d+)$') 206 | for p in s.split(','): 207 | m = range_re.match(p) 208 | if m: 209 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 210 | else: 211 | ranges.append(int(p)) 212 | return ranges 213 | 214 | #---------------------------------------------------------------------------- 215 | 216 | @click.command() 217 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True) 218 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True) 219 | @click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True) 220 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True) 221 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None) 222 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 223 | 224 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True) 225 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 226 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 227 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True) 228 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 229 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 230 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True) 231 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True) 232 | 233 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun'])) 234 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm'])) 235 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear'])) 236 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none'])) 237 | 238 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs): 239 | """Generate random images using the techniques described in the paper 240 | "Elucidating the Design Space of Diffusion-Based Generative Models". 
241 | 242 | Examples: 243 | 244 | \b 245 | # Generate 64 images and save them as out/*.png 246 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\ 247 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 248 | 249 | \b 250 | # Generate 1024 images using 2 GPUs 251 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\ 252 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 253 | """ 254 | dist.init() 255 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 256 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches) 257 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 258 | 259 | # Rank 0 goes first. 260 | if dist.get_rank() != 0: 261 | torch.distributed.barrier() 262 | 263 | # Load network. 264 | dist.print0(f'Loading network from "{network_pkl}"...') 265 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f: 266 | net = pickle.load(f)['ema'].to(device) 267 | 268 | # Other ranks follow. 269 | if dist.get_rank() == 0: 270 | torch.distributed.barrier() 271 | 272 | # Loop over batches. 273 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...') 274 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)): 275 | torch.distributed.barrier() 276 | batch_size = len(batch_seeds) 277 | if batch_size == 0: 278 | continue 279 | 280 | # Pick latents and labels. 281 | rnd = StackedRandomGenerator(device, batch_seeds) 282 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 283 | class_labels = None 284 | if net.label_dim: 285 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)] 286 | if class_idx is not None: 287 | class_labels[:, :] = 0 288 | class_labels[:, class_idx] = 1 289 | 290 | # Generate images. 291 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None} 292 | have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling']) 293 | sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler 294 | images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs) 295 | 296 | # Save images. 297 | images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy() 298 | for seed, image_np in zip(batch_seeds, images_np): 299 | image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir 300 | os.makedirs(image_dir, exist_ok=True) 301 | image_path = os.path.join(image_dir, f'{seed:06d}.png') 302 | if image_np.shape[2] == 1: 303 | PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path) 304 | else: 305 | PIL.Image.fromarray(image_np, 'RGB').save(image_path) 306 | 307 | # Done. 308 | torch.distributed.barrier() 309 | dist.print0('Done.') 310 | 311 | #---------------------------------------------------------------------------- 312 | 313 | if __name__ == "__main__": 314 | main() 315 | 316 | #---------------------------------------------------------------------------- 317 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import tempfile 27 | import urllib 28 | import urllib.request 29 | import uuid 30 | 31 | from distutils.util import strtobool 32 | from typing import Any, List, Tuple, Union, Optional 33 | 34 | 35 | # Util classes 36 | # ------------------------------------------------------------------------------------------ 37 | 38 | 39 | class EasyDict(dict): 40 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 41 | 42 | def __getattr__(self, name: str) -> Any: 43 | try: 44 | return self[name] 45 | except KeyError: 46 | raise AttributeError(name) 47 | 48 | def __setattr__(self, name: str, value: Any) -> None: 49 | self[name] = value 50 | 51 | def __delattr__(self, name: str) -> None: 52 | del self[name] 53 | 54 | 55 | class Logger(object): 56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 57 | 58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): 59 | self.file = None 60 | 61 | if file_name is not None: 62 | self.file = open(file_name, file_mode) 63 | 64 | self.should_flush = should_flush 65 | self.stdout = sys.stdout 66 | self.stderr = sys.stderr 67 | 68 | sys.stdout = self 69 | sys.stderr = self 70 | 71 | def __enter__(self) -> "Logger": 72 | return self 73 | 74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 75 | self.close() 76 | 77 | def write(self, text: Union[str, bytes]) -> None: 78 | """Write text to stdout (and a file) and optionally flush.""" 79 | if isinstance(text, bytes): 80 | text = text.decode() 81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 82 | return 83 | 84 | if self.file is not None: 85 | self.file.write(text) 86 | 87 | self.stdout.write(text) 88 | 89 | if self.should_flush: 90 | self.flush() 91 | 92 | def flush(self) -> None: 93 | """Flush written text to both stdout and a file, if open.""" 94 | if self.file is not None: 95 | self.file.flush() 96 | 97 | self.stdout.flush() 98 | 99 | def close(self) -> None: 100 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 101 | self.flush() 102 | 103 | # if using multiple loggers, prevent closing in wrong order 104 | if sys.stdout is self: 105 | sys.stdout = self.stdout 106 | if sys.stderr is self: 107 | sys.stderr = self.stderr 108 | 109 | if self.file is not None: 110 | self.file.close() 111 | self.file = None 112 | 113 | 114 | # Cache directories 115 | # ------------------------------------------------------------------------------------------ 116 | 117 | _dnnlib_cache_dir = None 118 | 119 | def set_cache_dir(path: str) -> None: 120 | global _dnnlib_cache_dir 121 | _dnnlib_cache_dir = path 122 | 123 | def make_cache_dir_path(*paths: str) -> str: 124 | if _dnnlib_cache_dir is not None: 125 | return 
os.path.join(_dnnlib_cache_dir, *paths) 126 | if 'DNNLIB_CACHE_DIR' in os.environ: 127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 128 | if 'HOME' in os.environ: 129 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 130 | if 'USERPROFILE' in os.environ: 131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 133 | 134 | # Small util functions 135 | # ------------------------------------------------------------------------------------------ 136 | 137 | 138 | def format_time(seconds: Union[int, float]) -> str: 139 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 140 | s = int(np.rint(seconds)) 141 | 142 | if s < 60: 143 | return "{0}s".format(s) 144 | elif s < 60 * 60: 145 | return "{0}m {1:02}s".format(s // 60, s % 60) 146 | elif s < 24 * 60 * 60: 147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 148 | else: 149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 150 | 151 | 152 | def format_time_brief(seconds: Union[int, float]) -> str: 153 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 154 | s = int(np.rint(seconds)) 155 | 156 | if s < 60: 157 | return "{0}s".format(s) 158 | elif s < 60 * 60: 159 | return "{0}m {1:02}s".format(s // 60, s % 60) 160 | elif s < 24 * 60 * 60: 161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 162 | else: 163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 164 | 165 | 166 | def ask_yes_no(question: str) -> bool: 167 | """Ask the user the question until the user inputs a valid answer.""" 168 | while True: 169 | try: 170 | print("{0} [y/n]".format(question)) 171 | return strtobool(input().lower()) 172 | except ValueError: 173 | pass 174 | 175 | 176 | def tuple_product(t: Tuple) -> Any: 177 | """Calculate the product of the tuple elements.""" 178 | result = 1 179 | 180 | for v in t: 181 | result *= v 182 | 183 | return result 184 | 185 | 186 | _str_to_ctype = { 187 | "uint8": ctypes.c_ubyte, 188 | "uint16": ctypes.c_uint16, 189 | "uint32": ctypes.c_uint32, 190 | "uint64": ctypes.c_uint64, 191 | "int8": ctypes.c_byte, 192 | "int16": ctypes.c_int16, 193 | "int32": ctypes.c_int32, 194 | "int64": ctypes.c_int64, 195 | "float32": ctypes.c_float, 196 | "float64": ctypes.c_double 197 | } 198 | 199 | 200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 202 | type_str = None 203 | 204 | if isinstance(type_obj, str): 205 | type_str = type_obj 206 | elif hasattr(type_obj, "__name__"): 207 | type_str = type_obj.__name__ 208 | elif hasattr(type_obj, "name"): 209 | type_str = type_obj.name 210 | else: 211 | raise RuntimeError("Cannot infer type name from input") 212 | 213 | assert type_str in _str_to_ctype.keys() 214 | 215 | my_dtype = np.dtype(type_str) 216 | my_ctype = _str_to_ctype[type_str] 217 | 218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 219 | 220 | return my_dtype, my_ctype 221 | 222 | 223 | def is_pickleable(obj: Any) -> bool: 224 | try: 225 | with io.BytesIO() as stream: 226 | pickle.dump(obj, stream) 227 | return True 228 | except: 229 | return False 230 | 231 | 232 | # Functionality to import modules/objects by name, and call functions by 
name 233 | # ------------------------------------------------------------------------------------------ 234 | 235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 236 | """Searches for the underlying module behind the name to some python object. 237 | Returns the module and the object name (original name with module part removed).""" 238 | 239 | # allow convenience shorthands, substitute them by full names 240 | obj_name = re.sub("^np.", "numpy.", obj_name) 241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 242 | 243 | # list alternatives for (module_name, local_obj_name) 244 | parts = obj_name.split(".") 245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 246 | 247 | # try each alternative in turn 248 | for module_name, local_obj_name in name_pairs: 249 | try: 250 | module = importlib.import_module(module_name) # may raise ImportError 251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 252 | return module, local_obj_name 253 | except: 254 | pass 255 | 256 | # maybe some of the modules themselves contain errors? 257 | for module_name, _local_obj_name in name_pairs: 258 | try: 259 | importlib.import_module(module_name) # may raise ImportError 260 | except ImportError: 261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 262 | raise 263 | 264 | # maybe the requested attribute is missing? 265 | for module_name, local_obj_name in name_pairs: 266 | try: 267 | module = importlib.import_module(module_name) # may raise ImportError 268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 269 | except ImportError: 270 | pass 271 | 272 | # we are out of luck, but we have no idea why 273 | raise ImportError(obj_name) 274 | 275 | 276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 277 | """Traverses the object name and returns the last (rightmost) python object.""" 278 | if obj_name == '': 279 | return module 280 | obj = module 281 | for part in obj_name.split("."): 282 | obj = getattr(obj, part) 283 | return obj 284 | 285 | 286 | def get_obj_by_name(name: str) -> Any: 287 | """Finds the python object with the given name.""" 288 | module, obj_name = get_module_from_obj_name(name) 289 | return get_obj_from_module(module, obj_name) 290 | 291 | 292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 293 | """Finds the python object with the given name and calls it as a function.""" 294 | assert func_name is not None 295 | func_obj = get_obj_by_name(func_name) 296 | assert callable(func_obj) 297 | return func_obj(*args, **kwargs) 298 | 299 | 300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 301 | """Finds the python class with the given name and constructs it with the given arguments.""" 302 | return call_func_by_name(*args, func_name=class_name, **kwargs) 303 | 304 | 305 | def get_module_dir_by_obj_name(obj_name: str) -> str: 306 | """Get the directory path of the module containing the given object name.""" 307 | module, _ = get_module_from_obj_name(obj_name) 308 | return os.path.dirname(inspect.getfile(module)) 309 | 310 | 311 | def is_top_level_function(obj: Any) -> bool: 312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 314 | 315 | 316 | def get_top_level_function_name(obj: Any) -> str: 317 | """Return the 
fully-qualified name of a top-level function.""" 318 | assert is_top_level_function(obj) 319 | module = obj.__module__ 320 | if module == '__main__': 321 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 322 | return module + "." + obj.__name__ 323 | 324 | 325 | # File system helpers 326 | # ------------------------------------------------------------------------------------------ 327 | 328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 329 | """List all files recursively in a given directory while ignoring given file and directory names. 330 | Returns list of tuples containing both absolute and relative paths.""" 331 | assert os.path.isdir(dir_path) 332 | base_name = os.path.basename(os.path.normpath(dir_path)) 333 | 334 | if ignores is None: 335 | ignores = [] 336 | 337 | result = [] 338 | 339 | for root, dirs, files in os.walk(dir_path, topdown=True): 340 | for ignore_ in ignores: 341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 342 | 343 | # dirs need to be edited in-place 344 | for d in dirs_to_remove: 345 | dirs.remove(d) 346 | 347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 348 | 349 | absolute_paths = [os.path.join(root, f) for f in files] 350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 351 | 352 | if add_base_to_relative: 353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 354 | 355 | assert len(absolute_paths) == len(relative_paths) 356 | result += zip(absolute_paths, relative_paths) 357 | 358 | return result 359 | 360 | 361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 362 | """Takes in a list of tuples of (src, dst) paths and copies files. 363 | Will create all necessary directories.""" 364 | for file in files: 365 | target_dir_name = os.path.dirname(file[1]) 366 | 367 | # will create all intermediate-level directories 368 | if not os.path.exists(target_dir_name): 369 | os.makedirs(target_dir_name) 370 | 371 | shutil.copyfile(file[0], file[1]) 372 | 373 | 374 | # URL helpers 375 | # ------------------------------------------------------------------------------------------ 376 | 377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 378 | """Determine whether the given object is a valid URL string.""" 379 | if not isinstance(obj, str) or not "://" in obj: 380 | return False 381 | if allow_file_urls and obj.startswith('file://'): 382 | return True 383 | try: 384 | res = requests.compat.urlparse(obj) 385 | if not res.scheme or not res.netloc or not "." in res.netloc: 386 | return False 387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 388 | if not res.scheme or not res.netloc or not "." in res.netloc: 389 | return False 390 | except: 391 | return False 392 | return True 393 | 394 | 395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 396 | """Download the given URL and return a binary-mode file object to access the data.""" 397 | assert num_attempts >= 1 398 | assert not (return_filename and (not cache)) 399 | 400 | # Doesn't look like a URL scheme, so interpret it as a local filename. 401 | if not re.match('^[a-z]+://', url): 402 | return url if return_filename else open(url, "rb") 403 | 404 | # Handle file URLs.
This code handles unusual file:// patterns that 405 | # arise on Windows: 406 | # 407 | # file:///c:/foo.txt 408 | # 409 | # which would translate to a local '/c:/foo.txt' filename that's 410 | # invalid. Drop the forward slash for such pathnames. 411 | # 412 | # If you touch this code path, you should test it on both Linux and 413 | # Windows. 414 | # 415 | # Some internet resources suggest using urllib.request.url2pathname(), 416 | # but that converts forward slashes to backslashes and this causes 417 | # its own set of problems. 418 | if url.startswith('file://'): 419 | filename = urllib.parse.urlparse(url).path 420 | if re.match(r'^/[a-zA-Z]:', filename): 421 | filename = filename[1:] 422 | return filename if return_filename else open(filename, "rb") 423 | 424 | assert is_url(url) 425 | 426 | # Lookup from cache. 427 | if cache_dir is None: 428 | cache_dir = make_cache_dir_path('downloads') 429 | 430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 431 | if cache: 432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 433 | if len(cache_files) == 1: 434 | filename = cache_files[0] 435 | return filename if return_filename else open(filename, "rb") 436 | 437 | # Download. 438 | url_name = None 439 | url_data = None 440 | with requests.Session() as session: 441 | if verbose: 442 | print("Downloading %s ..." % url, end="", flush=True) 443 | for attempts_left in reversed(range(num_attempts)): 444 | try: 445 | with session.get(url) as res: 446 | res.raise_for_status() 447 | if len(res.content) == 0: 448 | raise IOError("No data received") 449 | 450 | if len(res.content) < 8192: 451 | content_str = res.content.decode("utf-8") 452 | if "download_warning" in res.headers.get("Set-Cookie", ""): 453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 454 | if len(links) == 1: 455 | url = requests.compat.urljoin(url, links[0]) 456 | raise IOError("Google Drive virus checker nag") 457 | if "Google Drive - Quota exceeded" in content_str: 458 | raise IOError("Google Drive download quota exceeded -- please try again later") 459 | 460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 461 | url_name = match[1] if match else url 462 | url_data = res.content 463 | if verbose: 464 | print(" done") 465 | break 466 | except KeyboardInterrupt: 467 | raise 468 | except: 469 | if not attempts_left: 470 | if verbose: 471 | print(" failed") 472 | raise 473 | if verbose: 474 | print(".", end="", flush=True) 475 | 476 | # Save to cache. 477 | if cache: 478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 479 | safe_name = safe_name[:min(len(safe_name), 128)] 480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 482 | os.makedirs(cache_dir, exist_ok=True) 483 | with open(temp_file, "wb") as f: 484 | f.write(url_data) 485 | os.replace(temp_file, cache_file) # atomic 486 | if return_filename: 487 | return cache_file 488 | 489 | # Return data as file object. 490 | assert not return_filename 491 | return io.BytesIO(url_data) 492 | --------------------------------------------------------------------------------
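Usage note: a minimal single-process sketch of how the pieces above fit together, mirroring the sampling path in generate.py but without torch.distributed. This is not one of the repository's entry points; the network URL is the example one from generate.py's help text, and a CUDA device is assumed.

import pickle
import torch
import dnnlib
from generate import edm_sampler, StackedRandomGenerator

network_pkl = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'
device = torch.device('cuda')

# Load the EMA weights from the network pickle, as in generate.py main().
with dnnlib.util.open_url(network_pkl) as f:
    net = pickle.load(f)['ema'].to(device)

# One torch.Generator per seed, so each image depends only on its own seed,
# not on how the seeds are batched.
seeds = [0, 1, 2, 3]
rnd = StackedRandomGenerator(device, seeds)
latents = rnd.randn([len(seeds), net.img_channels, net.img_resolution, net.img_resolution], device=device)

# Random one-hot labels for conditional networks (label_dim == 0 means unconditional).
class_labels = None
if net.label_dim:
    class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[len(seeds)], device=device)]

# 18-step deterministic Heun sampling (S_churn=0); outputs land roughly in [-1, 1],
# so rescale to uint8 the same way generate.py does before saving PNGs.
images = edm_sampler(net, latents, class_labels, randn_like=rnd.randn_like)
images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()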