├── runners └── __init__.py ├── models ├── improved_ddpm │ ├── __init__.py │ ├── fp16_util.py │ └── nn.py ├── guided_diffusion │ ├── __init__.py │ ├── nn.py │ └── fp16_util.py └── ema.py ├── edm ├── docs │ ├── afhqv2-64x64.png │ ├── ffhq-64x64.png │ ├── cifar10-32x32.png │ ├── imagenet-64x64.png │ ├── teaser-640x480.jpg │ ├── teaser-1280x640.jpg │ ├── teaser-1920x640.jpg │ ├── fid-help.txt │ ├── generate-help.txt │ ├── train-help.txt │ └── dataset-tool-help.txt ├── training │ ├── __init__.py │ ├── loss.py │ ├── dataset.py │ └── training_loop.py ├── torch_utils │ ├── __init__.py │ ├── distributed.py │ ├── persistence.py │ ├── training_stats.py │ └── misc.py ├── environment.yml ├── dnnlib │ └── __init__.py ├── Dockerfile ├── example.py └── fid.py ├── .gitignore ├── functions ├── losses.py ├── __init__.py ├── ckpt_util.py └── denoising.py ├── datasets ├── ffhq.py ├── vision.py ├── lsun.py ├── utils.py ├── imagenet64.py ├── celeba.py └── __init__.py ├── configs ├── celeba.yml ├── cifar10.yml ├── bedroom_guided.yml ├── imagenet64.yml └── imagenet128_guided.yml ├── README.md ├── main.py └── evaluate ├── fid_score.py └── inception.py /runners/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/improved_ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /edm/docs/afhqv2-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thudzj/Calibrated-DPMs/HEAD/edm/docs/afhqv2-64x64.png -------------------------------------------------------------------------------- /edm/docs/ffhq-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thudzj/Calibrated-DPMs/HEAD/edm/docs/ffhq-64x64.png -------------------------------------------------------------------------------- /edm/docs/cifar10-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thudzj/Calibrated-DPMs/HEAD/edm/docs/cifar10-32x32.png -------------------------------------------------------------------------------- /edm/docs/imagenet-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thudzj/Calibrated-DPMs/HEAD/edm/docs/imagenet-64x64.png -------------------------------------------------------------------------------- /edm/docs/teaser-640x480.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thudzj/Calibrated-DPMs/HEAD/edm/docs/teaser-640x480.jpg -------------------------------------------------------------------------------- /edm/docs/teaser-1280x640.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thudzj/Calibrated-DPMs/HEAD/edm/docs/teaser-1280x640.jpg -------------------------------------------------------------------------------- /edm/docs/teaser-1920x640.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thudzj/Calibrated-DPMs/HEAD/edm/docs/teaser-1920x640.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | *exp* 3 | *fid_stats 4 | edm/datasets 5 | edm/downloads 6 | edm/fid-refs 7 | edm/generations 8 | edm/training-runs/ 9 | -------------------------------------------------------------------------------- /edm/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /edm/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /edm/environment.yml: -------------------------------------------------------------------------------- 1 | name: edm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python>=3.8, < 3.10 # package build failures on 3.10 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow>=8.3.1 11 | - scipy>=1.7.1 12 | - pytorch=1.12.1 13 | - psutil 14 | - requests 15 | - tqdm 16 | - imageio 17 | - pip: 18 | - imageio-ffmpeg>=0.4.3 19 | - pyspng 20 | -------------------------------------------------------------------------------- /edm/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /edm/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | FROM nvcr.io/nvidia/pytorch:22.10-py3 9 | 10 | ENV PYTHONDONTWRITEBYTECODE 1 11 | ENV PYTHONUNBUFFERED 1 12 | 13 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 14 | 15 | WORKDIR /workspace 16 | 17 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 18 | ENTRYPOINT ["/entry.sh"] 19 | -------------------------------------------------------------------------------- /functions/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def noise_estimation_loss(model, 5 | x0: torch.Tensor, 6 | t: torch.LongTensor, 7 | e: torch.Tensor, 8 | b: torch.Tensor, keepdim=False): 9 | a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1) 10 | x = x0 * a.sqrt() + e * (1.0 - a).sqrt() 11 | output = model(x, t.float()) 12 | if keepdim: 13 | return (e - output).square().sum(dim=(1, 2, 3)) 14 | else: 15 | return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0) 16 | 17 | 18 | loss_registry = { 19 | 'simple': noise_estimation_loss, 20 | } 21 | -------------------------------------------------------------------------------- /functions/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def get_optimizer(config, parameters): 5 | if config.optim.optimizer == 'Adam': 6 | return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay, 7 | betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad, 8 | eps=config.optim.eps) 9 | elif config.optim.optimizer == 'RMSProp': 10 | return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay) 11 | elif config.optim.optimizer == 'SGD': 12 | return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9) 13 | else: 14 | raise NotImplementedError( 15 | 'Optimizer {} not understood.'.format(config.optim.optimizer)) 16 | -------------------------------------------------------------------------------- /datasets/ffhq.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FFHQ(Dataset): 9 | def __init__(self, path, transform, resolution=8): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | target = 0 40 | 41 | return img, target -------------------------------------------------------------------------------- /configs/celeba.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CELEBA" 3 | image_size: 64 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | 
rescaled: true 10 | num_workers: 4 11 | num_classes: 1 12 | 13 | model: 14 | model_type: "ddpm" 15 | is_upsampling: false 16 | type: "simple" 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: [1, 2, 2, 2, 4] 21 | num_res_blocks: 2 22 | attn_resolutions: [16, ] 23 | dropout: 0.1 24 | var_type: fixedlarge 25 | ema_rate: 0.9999 26 | ema: True 27 | resamp_with_conv: True 28 | ckpt_dir: "~/ddpm_ckpt/celeba/ckpt.pth" 29 | 30 | diffusion: 31 | beta_schedule: linear 32 | beta_start: 0.0001 33 | beta_end: 0.02 34 | num_diffusion_timesteps: 1000 35 | 36 | training: 37 | batch_size: 200 38 | n_epochs: 10000 39 | n_iters: 5000000 40 | snapshot_freq: 5000 41 | validation_freq: 20000 42 | 43 | sampling: 44 | total_N: 1000 45 | schedule: "linear" 46 | time_input_type: '1' 47 | batch_size: 200 48 | last_only: True 49 | fid_stats_dir: "fid_stats/fid_stats_celeba64_train_50000_ddim.npz" 50 | fid_total_samples: 50000 51 | fid_batch_size: 200 52 | cond_class: false 53 | classifier_scale: 0.0 54 | 55 | optim: 56 | weight_decay: 0.000 57 | optimizer: "Adam" 58 | lr: 0.0002 59 | beta1: 0.9 60 | amsgrad: false 61 | eps: 0.00000001 62 | grad_clip: 1.0 63 | -------------------------------------------------------------------------------- /configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "CIFAR10" 3 | image_size: 32 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | num_classes: 10 12 | # generations_dir: /home/zhijie/dpm-solver/example_pytorch/experiments/cifar10/image_samples/baseline-o3-50steps-200k 13 | 14 | model: 15 | model_type: "ddpm" 16 | is_upsampling: false 17 | type: "simple" 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [1, 2, 2, 2] 22 | num_res_blocks: 2 23 | attn_resolutions: [16, ] 24 | dropout: 0.1 25 | var_type: fixedlarge 26 | ema_rate: 0.9999 27 | ema: True 28 | resamp_with_conv: True 29 | 30 | diffusion: 31 | beta_schedule: linear 32 | beta_start: 0.0001 33 | beta_end: 0.02 34 | num_diffusion_timesteps: 1000 35 | 36 | training: 37 | batch_size: 128 38 | n_epochs: 10000 39 | n_iters: 5000000 40 | snapshot_freq: 5000 41 | validation_freq: 2000 42 | 43 | sampling: 44 | total_N: 1000 45 | schedule: "linear" 46 | time_input_type: '1' 47 | batch_size: 1000 48 | last_only: True 49 | fid_stats_dir: "fid_stats/fid_stats_cifar10_train_pytorch.npz" 50 | fid_total_samples: 50000 51 | fid_batch_size: 500 52 | likelihood_batch_size: 100 53 | cond_class: false 54 | classifier_scale: 0.0 55 | 56 | optim: 57 | weight_decay: 0.000 58 | optimizer: "Adam" 59 | lr: 0.0002 60 | beta1: 0.9 61 | amsgrad: false 62 | eps: 0.00000001 63 | grad_clip: 1.0 64 | -------------------------------------------------------------------------------- /configs/bedroom_guided.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "LSUN" 3 | category: "bedroom" 4 | image_size: 256 5 | channels: 3 6 | logit_transform: false 7 | uniform_dequantization: false 8 | gaussian_dequantization: false 9 | random_flip: true 10 | rescaled: true 11 | num_workers: 32 12 | num_classes: 1 13 | root: "/data/LargeData/Large" 14 | 15 | model: 16 | model_type: "guided_diffusion" 17 | is_upsampling: false 18 | image_size: 256 19 | in_channels: 3 20 | model_channels: 256 21 | out_channels: 6 22 | num_res_blocks: 2 23 | attention_resolutions: [8, 16, 32] # [256 // 32, 256 // 16, 256 
// 8] 24 | dropout: 0.1 25 | channel_mult: [1, 1, 2, 2, 4, 4] 26 | conv_resample: true 27 | dims: 2 28 | num_classes: null 29 | use_checkpoint: false 30 | use_fp16: true 31 | num_heads: 4 32 | num_head_channels: 64 33 | num_heads_upsample: -1 34 | use_scale_shift_norm: true 35 | resblock_updown: true 36 | use_new_attention_order: false 37 | var_type: fixedlarge 38 | ema: false 39 | ckpt_dir: "~/ddpm_ckpt/bedroom/lsun_bedroom.pt" 40 | 41 | diffusion: 42 | beta_schedule: linear 43 | beta_start: 0.0001 44 | beta_end: 0.02 45 | num_diffusion_timesteps: 1000 46 | 47 | training: 48 | batch_size: 25 49 | n_epochs: 10000 50 | n_iters: 5000000 51 | snapshot_freq: 5000 52 | validation_freq: 20000 53 | 54 | sampling: 55 | total_N: 1000 56 | schedule: "linear" 57 | time_input_type: '1' 58 | batch_size: 25 59 | last_only: True 60 | fid_stats_dir: "fid_stats/VIRTUAL_lsun_bedroom256.npz" 61 | fid_total_samples: 50000 62 | fid_batch_size: 50 63 | cond_class: false 64 | classifier_scale: 0.0 65 | -------------------------------------------------------------------------------- /configs/imagenet64.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "IMAGENET64" 3 | image_size: 64 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 4 11 | num_classes: 1000 12 | root: "/data/LargeData/Large/ImageNet" #"/home/dzj/imagenet64/" #""/data/LargeData/Large/imagenet64" #"/home/dzj/imagenet64/" #" 13 | # loader_type: "" #"custom" #None 14 | generations_dir: "" #"experiments/imagenet64/image_samples/original-10/" 15 | 16 | 17 | model: 18 | model_type: "improved_ddpm" 19 | is_upsampling: false 20 | in_channels: 3 21 | model_channels: 128 22 | out_channels: 6 23 | num_res_blocks: 3 24 | attention_resolutions: [4, 8] 25 | dropout: 0.0 26 | channel_mult: [1, 2, 3, 4] 27 | conv_resample: true 28 | dims: 2 29 | use_checkpoint: false 30 | num_heads: 4 31 | num_heads_upsample: -1 32 | use_scale_shift_norm: true 33 | var_type: fixedlarge 34 | use_fp16: false 35 | ema: false 36 | ckpt_dir: "~/ddpm_ckpt/imagenet64/imagenet64_uncond_100M_1500K.pt" 37 | 38 | diffusion: 39 | beta_schedule: linear 40 | beta_start: 0.0001 41 | beta_end: 0.02 42 | num_diffusion_timesteps: 1000 43 | 44 | training: 45 | batch_size: 200 46 | n_epochs: 10000 47 | n_iters: 5000000 48 | snapshot_freq: 5000 49 | validation_freq: 20000 50 | 51 | sampling: 52 | total_N: 4000 53 | schedule: "cosine" 54 | time_input_type: '2' 55 | batch_size: 200 56 | last_only: True 57 | fid_stats_dir: "fid_stats/fid_stats_imagenet64_train.npz" 58 | fid_total_samples: 50000 59 | fid_batch_size: 1000 60 | cond_class: false 61 | classifier_scale: 0.0 62 | -------------------------------------------------------------------------------- /edm/docs/fid-help.txt: -------------------------------------------------------------------------------- 1 | Usage: fid.py [OPTIONS] COMMAND [ARGS]... 2 | 3 | Calculate Frechet Inception Distance (FID). 
4 | 5 | Examples: 6 | 7 | # Generate 50000 images and save them as fid-tmp/*/*.png 8 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \ 9 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 10 | 11 | # Calculate FID 12 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \ 13 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 14 | 15 | # Compute dataset reference statistics 16 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 17 | 18 | Options: 19 | --help Show this message and exit. 20 | 21 | Commands: 22 | calc Calculate FID for a given set of images. 23 | ref Calculate dataset reference statistics needed by 'calc'. 24 | 25 | 26 | Usage: fid.py calc [OPTIONS] 27 | 28 | Calculate FID for a given set of images. 29 | 30 | Options: 31 | --images PATH|ZIP Path to the images [required] 32 | --ref NPZ|URL Dataset reference statistics [required] 33 | --num INT Number of images to use [default: 50000; x>=2] 34 | --seed INT Random seed for selecting the images [default: 0] 35 | --batch INT Maximum batch size [default: 64; x>=1] 36 | --help Show this message and exit. 37 | 38 | 39 | Usage: fid.py ref [OPTIONS] 40 | 41 | Calculate dataset reference statistics needed by 'calc'. 42 | 43 | Options: 44 | --data PATH|ZIP Path to the dataset [required] 45 | --dest NPZ Destination .npz file [required] 46 | --batch INT Maximum batch size [default: 64; x>=1] 47 | --help Show this message and exit. 48 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = ( 22 | 1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict 50 | -------------------------------------------------------------------------------- /edm/docs/generate-help.txt: -------------------------------------------------------------------------------- 1 | Usage: generate.py [OPTIONS] 2 | 3 | Generate random images using the techniques described in the paper 4 | "Elucidating the Design Space of Diffusion-Based Generative Models". 5 | 6 | Examples: 7 | 8 | # Generate 64 images and save them as out/*.png 9 | python generate.py --outdir=out --seeds=0-63 --batch=64 \ 10 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 11 | 12 | # Generate 1024 images using 2 GPUs 13 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \ 14 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 15 | 16 | Options: 17 | --network PATH|URL Network pickle filename [required] 18 | --outdir DIR Where to save the output images [required] 19 | --seeds LIST Random seeds (e.g. 1,2,5-10) [default: 0-63] 20 | --subdirs Create subdirectory for every 1000 seeds 21 | --class INT Class label [default: random] [x>=0] 22 | --batch INT Maximum batch size [default: 64; x>=1] 23 | --steps INT Number of sampling steps [default: 18; x>=1] 24 | --sigma_min FLOAT Lowest noise level [default: varies] [x>0] 25 | --sigma_max FLOAT Highest noise level [default: varies] [x>0] 26 | --rho FLOAT Time step exponent [default: 7; x>0] 27 | --S_churn FLOAT Stochasticity strength [default: 0; x>=0] 28 | --S_min FLOAT Stoch. min noise level [default: 0; x>=0] 29 | --S_max FLOAT Stoch. max noise level [default: inf; x>=0] 30 | --S_noise FLOAT Stoch. noise inflation [default: 1] 31 | --solver euler|heun Ablate ODE solver 32 | --disc vp|ve|iddpm|edm Ablate time step discretization {t_i} 33 | --schedule vp|ve|linear Ablate noise schedule sigma(t) 34 | --scaling vp|none Ablate signal scaling s(t) 35 | --help Show this message and exit. 
36 | -------------------------------------------------------------------------------- /configs/imagenet128_guided.yml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset: "IMAGENET128" 3 | image_size: 128 4 | channels: 3 5 | logit_transform: false 6 | uniform_dequantization: false 7 | gaussian_dequantization: false 8 | random_flip: true 9 | rescaled: true 10 | num_workers: 32 11 | num_classes: 1000 12 | 13 | model: 14 | model_type: "guided_diffusion" 15 | is_upsampling: false 16 | image_size: 128 17 | in_channels: 3 18 | model_channels: 256 19 | out_channels: 6 20 | num_res_blocks: 2 21 | attention_resolutions: [4, 8, 16] # [128 // 32, 128 // 16, 128 // 8] 22 | dropout: 0.0 23 | channel_mult: [1, 1, 2, 3, 4] 24 | conv_resample: true 25 | dims: 2 26 | num_classes: 1000 27 | use_checkpoint: false 28 | use_fp16: true 29 | num_heads: 4 30 | num_head_channels: -1 31 | num_heads_upsample: -1 32 | use_scale_shift_norm: true 33 | resblock_updown: true 34 | use_new_attention_order: false 35 | var_type: fixedlarge 36 | ema: false 37 | ckpt_dir: "~/ddpm_ckpt/imagenet128/128x128_diffusion.pt" 38 | 39 | classifier: 40 | ckpt_dir: "~/ddpm_ckpt/imagenet128/128x128_classifier.pt" 41 | image_size: 128 42 | in_channels: 3 43 | model_channels: 128 44 | out_channels: 1000 45 | num_res_blocks: 2 46 | attention_resolutions: [4, 8, 16] # [128 // 32, 128 // 16, 128 // 8] 47 | channel_mult: [1, 1, 2, 3, 4] 48 | use_fp16: true 49 | num_head_channels: 64 50 | use_scale_shift_norm: true 51 | resblock_updown: true 52 | pool: "attention" 53 | 54 | diffusion: 55 | beta_schedule: linear 56 | beta_start: 0.0001 57 | beta_end: 0.02 58 | num_diffusion_timesteps: 1000 59 | 60 | sampling: 61 | total_N: 1000 62 | schedule: "linear" 63 | time_input_type: '1' 64 | batch_size: 500 65 | last_only: True 66 | fid_stats_dir: "fid_stats/VIRTUAL_imagenet128_labeled.npz" 67 | fid_total_samples: 50000 68 | fid_batch_size: 200 69 | cond_class: true 70 | classifier_scale: 1.25 71 | -------------------------------------------------------------------------------- /edm/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . 
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /models/improved_ddpm/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 
44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() -------------------------------------------------------------------------------- /edm/docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train diffusion-based generative model using the techniques described in the 4 | paper "Elucidating the Design Space of Diffusion-Based Generative Models". 5 | 6 | Examples: 7 | 8 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs 9 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \ 10 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp 11 | 12 | Options: 13 | --outdir DIR Where to save the results [required] 14 | --data ZIP|DIR Path to the dataset [required] 15 | --cond BOOL Train class-conditional model [default: False] 16 | --arch ddpmpp|ncsnpp|adm Network architecture [default: ddpmpp] 17 | --precond vp|ve|edm Preconditioning & loss function [default: edm] 18 | --duration MIMG Training duration [default: 200; x>0] 19 | --batch INT Total batch size [default: 512; x>=1] 20 | --batch-gpu INT Limit batch size per GPU [x>=1] 21 | --cbase INT Channel multiplier [default: varies] 22 | --cres LIST Channels per resolution [default: varies] 23 | --lr FLOAT Learning rate [default: 0.001; x>0] 24 | --ema MIMG EMA half-life [default: 0.5; x>=0] 25 | --dropout FLOAT Dropout probability [default: 0.13; 0<=x<=1] 26 | --augment FLOAT Augment probability [default: 0.12; 0<=x<=1] 27 | --xflip BOOL Enable dataset x-flips [default: False] 28 | --fp16 BOOL Enable mixed-precision training [default: False] 29 | --ls FLOAT Loss scaling [default: 1; x>0] 30 | --bench BOOL Enable cuDNN benchmarking [default: True] 31 | --cache BOOL Cache dataset in CPU memory [default: True] 32 | --workers INT DataLoader worker processes [default: 1; x>=1] 33 | --desc STR String to include in result dir name 34 | --nosubdir Do not create a subdirectory for results 35 | --tick KIMG How often to print progress [default: 50; x>=1] 36 | --snap TICKS How often to save snapshots [default: 50; x>=1] 37 | --dump TICKS How often to dump state [default: 500; x>=1] 38 | --seed INT Random seed [default: random] 39 | --transfer PKL|URL Transfer learning from network pickle 40 | --resume PT Resume from previous training state 41 | -n, --dry-run Print training options and exit 42 | --help Show this message and exit. 
43 | -------------------------------------------------------------------------------- /edm/docs/dataset-tool-help.txt: -------------------------------------------------------------------------------- 1 | Usage: dataset_tool.py [OPTIONS] 2 | 3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA 4 | PyTorch. 5 | 6 | The input dataset format is guessed from the --source argument: 7 | 8 | --source *_lmdb/ Load LSUN dataset 9 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset 10 | --source train-images-idx3-ubyte.gz Load MNIST dataset 11 | --source path/ Recursively load all images from path/ 12 | --source dataset.zip Recursively load all images from dataset.zip 13 | 14 | Specifying the output format and path: 15 | 16 | --dest /path/to/dir Save output files under /path/to/dir 17 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip 18 | 19 | The output dataset format can be either an image folder or an uncompressed 20 | zip archive. Zip archives makes it easier to move datasets around file 21 | servers and clusters, and may offer better training performance on network 22 | file systems. 23 | 24 | Images within the dataset archive will be stored as uncompressed PNG. 25 | Uncompresed PNGs can be efficiently decoded in the training loop. 26 | 27 | Class labels are stored in a file called 'dataset.json' that is stored at 28 | the dataset root folder. This file has the following structure: 29 | 30 | { 31 | "labels": [ 32 | ["00000/img00000000.png",6], 33 | ["00000/img00000001.png",9], 34 | ... repeated for every image in the datase 35 | ["00049/img00049999.png",1] 36 | ] 37 | } 38 | 39 | If the 'dataset.json' file cannot be found, class labels are determined from 40 | top-level directory names. 41 | 42 | Image scale/crop and resolution requirements: 43 | 44 | Output images must be square-shaped and they must all have the same power- 45 | of-two dimensions. 46 | 47 | To scale arbitrary input image size to a specific width and height, use the 48 | --resolution option. Output resolution will be either the original input 49 | resolution (if resolution was not specified) or the one specified with 50 | --resolution option. 51 | 52 | Use the --transform=center-crop or --transform=center-crop-wide options to 53 | apply a center crop transform on the input image. These options should be 54 | used with the --resolution option. For example: 55 | 56 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 57 | --transform=center-crop-wide --resolution=512x384 58 | 59 | Options: 60 | --source PATH Input directory or archive name [required] 61 | --dest PATH Output directory or archive name [required] 62 | --max-images INT Maximum number of images to output 63 | --transform MODE Input crop/resize mode 64 | --resolution WxH Output resolution (e.g., 512x512) 65 | --help Show this message and exit. 
66 | -------------------------------------------------------------------------------- /functions/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | 37 | def download(url, local_path, chunk_size=1024): 38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 39 | with requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root=None, check=False): 56 | if 'church_outdoor' in name: 57 | name = name.replace('church_outdoor', 'church') 58 | # Modify the path when necessary 59 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/ddpm_ckpt")) 60 | assert name in URL_MAP 61 | root = ( 62 | root 63 | if root is not None 64 | else os.path.join(cachedir, "diffusion_models_converted") 65 | ) 66 | path = os.path.join(root, CKPT_MAP[name]) 67 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 69 | download(URL_MAP[name], path) 70 | md5 = md5_hash(path) 71 | assert md5 == MD5_MAP[name], md5 72 | return path 73 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # On Calibrating Diffusion Probabilistic Models 2 | 3 | The official code for the paper [On Calibrating Diffusion Probabilistic Models](https://arxiv.org/abs/2302.10688). 4 | 5 | -------------------- 6 | We propose a straightforward method for calibrating diffusion probabilistic models that reduces the values of SM objectives and increases model likelihood lower bounds. 7 | 8 | ## Acknowledgement 9 | The codes are modifed based on the [DPM-solver](https://github.com/LuChengTHU/dpm-solver) and [EDM](https://github.com/NVlabs/edm). 10 | 11 | 12 | ## Reproducing CIFAR-10 results on image generation and FID 13 | 14 | The command for computing the FID of **baseline** methods (without calibration): 15 | ```python 16 | python main.py --config cifar10.yml \ 17 | --exp=experiments/cifar10 \ 18 | --sample --fid \ 19 | --timesteps=20 \ 20 | --eta 0 --ni \ 21 | --skip_type=logSNR \ 22 | --sample_type=dpm_solver \ 23 | --start_time=1e-4 \ 24 | --dpm_solver_fast -i baseline 25 | ``` 26 | 27 | The command for computing the FID of **our** methods (with calibration): 28 | ```python 29 | python main.py --config cifar10.yml \ 30 | --exp=experiments/cifar10 \ 31 | --sample --fid \ 32 | --timesteps=20 \ 33 | --eta 0 --ni \ 34 | --skip_type=logSNR \ 35 | --sample_type=dpm_solver \ 36 | --start_time=1e-4 \ 37 | --dpm_solver_fast -i our --score_mean 38 | ``` 39 | 40 | ## Reproducing CelebA results on image generation and FID 41 | 42 | The command for computing the FID of **baseline** methods (without calibration): 43 | ```python 44 | python main.py --config celeba.yml \ 45 | --exp=experiments/celeba \ 46 | --sample --fid \ 47 | --timesteps=50 \ 48 | --eta 0 --ni \ 49 | --skip_type=logSNR \ 50 | --sample_type=dpm_solver \ 51 | --start_time=1e-4 \ 52 | --dpm_solver_fast -i baseline 53 | ``` 54 | 55 | The command for computing the FID of **our** methods (with calibration): 56 | ```python 57 | python main.py --config celeba.yml \ 58 | --exp=experiments/celeba \ 59 | --sample --fid \ 60 | --timesteps=50 \ 61 | --eta 0 --ni \ 62 | --skip_type=logSNR \ 63 | --sample_type=dpm_solver \ 64 | --start_time=1e-4 \ 65 | --dpm_solver_fast -i our --score_mean 66 | ``` 67 | 68 | ## Estimating SDE likelihood 69 | The command for running on **CIFAR-10**: 70 | ```python 71 | python main.py --config cifar10.yml \ 72 | --exp=experiments/cifar10 \ 73 | --sample --eta 0 \ 74 | --ni --start_time=1e-4 \ 75 | -i temp --likelihood sde 76 | ``` 77 | 78 | The command for running on **CelebA**: 79 | ```python 80 | python main.py --config celeba.yml \ 81 | --exp=experiments/celeba \ 82 | --sample --eta 0 \ 83 | --ni --start_time=1e-4 \ 84 | -i temp --likelihood sde 85 | ``` 86 | 87 | ## Estimating the average estimated score with EDM 88 | 89 | ```python 90 | cd edm/; 91 | 92 | # CIFAR-10 93 | python torch.distributed.run --master_port 12315 --nproc_per_node=1 generate.py --outdir=generations/cifar10/temp --seeds=0-49999 --subdirs --method our --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl 94 | 95 | # ImageNet 96 | python torch.distributed.run --master_port 12311 --nproc_per_node=1 generate.py --outdir=generations/imagenet/temp --seeds=0-49999 --subdirs --steps=256 --S_churn=40 --S_min=0.05 --S_max=50 --S_noise=1.003 --method our --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl 97 | ``` 98 | The commands for running on `FFHQ` and `AFHQv2` are similar. 
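For example, a minimal sketch of the corresponding **FFHQ** command (hypothetical: the checkpoint URL and the 40-step setting follow `edm/example.py`, and the `--master_port` value is arbitrary; adjust `--outdir` as needed):
```python
# Sketch of an FFHQ variant of the CIFAR-10 command above (not separately verified):
# network URL and --steps=40 are taken from edm/example.py; --method our enables the calibration.
python torch.distributed.run --master_port 12316 --nproc_per_node=1 generate.py --outdir=generations/ffhq/temp --seeds=0-49999 --subdirs --steps=40 --method our --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-ffhq-64x64-uncond-vp.pkl
```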
99 | -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | has_transforms = transforms is not None 15 | has_separate_transform = transform is not None or target_transform is not None 16 | if has_transforms and has_separate_transform: 17 | raise ValueError("Only transforms or transform/target_transform can " 18 | "be passed as argument") 19 | 20 | # for backwards-compatibility 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | if has_separate_transform: 25 | transforms = StandardTransform(transform, target_transform) 26 | self.transforms = transforms 27 | 28 | def __getitem__(self, index): 29 | raise NotImplementedError 30 | 31 | def __len__(self): 32 | raise NotImplementedError 33 | 34 | def __repr__(self): 35 | head = "Dataset " + self.__class__.__name__ 36 | body = ["Number of datapoints: {}".format(self.__len__())] 37 | if self.root is not None: 38 | body.append("Root location: {}".format(self.root)) 39 | body += self.extra_repr().splitlines() 40 | if hasattr(self, 'transform') and self.transform is not None: 41 | body += self._format_transform_repr(self.transform, 42 | "Transforms: ") 43 | if hasattr(self, 'target_transform') and self.target_transform is not None: 44 | body += self._format_transform_repr(self.target_transform, 45 | "Target transforms: ") 46 | lines = [head] + [" " * self._repr_indent + line for line in body] 47 | return '\n'.join(lines) 48 | 49 | def _format_transform_repr(self, transform, head): 50 | lines = transform.__repr__().splitlines() 51 | return (["{}{}".format(head, lines[0])] + 52 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 53 | 54 | def extra_repr(self): 55 | return "" 56 | 57 | 58 | class StandardTransform(object): 59 | def __init__(self, transform=None, target_transform=None): 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | def __call__(self, input, target): 64 | if self.transform is not None: 65 | input = self.transform(input) 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | return input, target 69 | 70 | def _format_transform_repr(self, transform, head): 71 | lines = transform.__repr__().splitlines() 72 | return (["{}{}".format(head, lines[0])] + 73 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 74 | 75 | def __repr__(self): 76 | body = [self.__class__.__name__] 77 | if self.transform is not None: 78 | body += self._format_transform_repr(self.transform, 79 | "Transform: ") 80 | if self.target_transform is not None: 81 | body += self._format_transform_repr(self.target_transform, 82 | "Target transform: ") 83 | 84 | return '\n'.join(body) 85 | -------------------------------------------------------------------------------- /edm/example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 
5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Minimal standalone example to reproduce the main results from the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import tqdm 12 | import pickle 13 | import numpy as np 14 | import torch 15 | import PIL.Image 16 | import dnnlib 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def generate_image_grid( 21 | network_pkl, dest_path, 22 | seed=0, gridw=8, gridh=8, device=torch.device('cuda'), 23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 25 | ): 26 | batch_size = gridw * gridh 27 | torch.manual_seed(seed) 28 | 29 | # Load network. 30 | print(f'Loading network from "{network_pkl}"...') 31 | with dnnlib.util.open_url(network_pkl) as f: 32 | net = pickle.load(f)['ema'].to(device) 33 | 34 | # Pick latents and labels. 35 | print(f'Generating {batch_size} images...') 36 | latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 37 | class_labels = None 38 | if net.label_dim: 39 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)] 40 | 41 | # Adjust noise levels based on what's supported by the network. 42 | sigma_min = max(sigma_min, net.sigma_min) 43 | sigma_max = min(sigma_max, net.sigma_max) 44 | 45 | # Time step discretization. 46 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device) 47 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 48 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 49 | 50 | # Main sampling loop. 51 | x_next = latents.to(torch.float64) * t_steps[0] 52 | for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1 53 | x_cur = x_next 54 | 55 | # Increase noise temporarily. 56 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 57 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 58 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) 59 | 60 | # Euler step. 61 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 62 | d_cur = (x_hat - denoised) / t_hat 63 | x_next = x_hat + (t_next - t_hat) * d_cur 64 | 65 | # Apply 2nd order correction. 66 | if i < num_steps - 1: 67 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 68 | d_prime = (x_next - denoised) / t_next 69 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 70 | 71 | # Save image grid. 
72 | print(f'Saving image grid to "{dest_path}"...') 73 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8) 74 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2) 75 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels) 76 | image = image.cpu().numpy() 77 | PIL.Image.fromarray(image, 'RGB').save(dest_path) 78 | print('Done.') 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def main(): 83 | model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained' 84 | generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35 85 | generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 1.97, NFE = 79 86 | generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79 87 | generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | if __name__ == "__main__": 92 | main() 93 | 94 | #---------------------------------------------------------------------------- 95 | -------------------------------------------------------------------------------- /edm/training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Loss functions used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import torch 12 | from torch_utils import persistence 13 | 14 | #---------------------------------------------------------------------------- 15 | # Loss function corresponding to the variance preserving (VP) formulation 16 | # from the paper "Score-Based Generative Modeling through Stochastic 17 | # Differential Equations". 18 | 19 | @persistence.persistent_class 20 | class VPLoss: 21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): 22 | self.beta_d = beta_d 23 | self.beta_min = beta_min 24 | self.epsilon_t = epsilon_t 25 | 26 | def __call__(self, net, images, labels, augment_pipe=None): 27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 29 | weight = 1 / sigma ** 2 30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 31 | n = torch.randn_like(y) * sigma 32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 33 | loss = weight * ((D_yn - y) ** 2) 34 | return loss 35 | 36 | def sigma(self, t): 37 | t = torch.as_tensor(t) 38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 39 | 40 | #---------------------------------------------------------------------------- 41 | # Loss function corresponding to the variance exploding (VE) formulation 42 | # from the paper "Score-Based Generative Modeling through Stochastic 43 | # Differential Equations". 
44 | 45 | @persistence.persistent_class 46 | class VELoss: 47 | def __init__(self, sigma_min=0.02, sigma_max=100): 48 | self.sigma_min = sigma_min 49 | self.sigma_max = sigma_max 50 | 51 | def __call__(self, net, images, labels, augment_pipe=None): 52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 54 | weight = 1 / sigma ** 2 55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 56 | n = torch.randn_like(y) * sigma 57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 58 | loss = weight * ((D_yn - y) ** 2) 59 | return loss 60 | 61 | #---------------------------------------------------------------------------- 62 | # Improved loss function proposed in the paper "Elucidating the Design Space 63 | # of Diffusion-Based Generative Models" (EDM). 64 | 65 | @persistence.persistent_class 66 | class EDMLoss: 67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, reg_on_mean=False, cal_weight=0): 68 | self.P_mean = P_mean 69 | self.P_std = P_std 70 | self.sigma_data = sigma_data 71 | self.reg_on_mean = reg_on_mean 72 | self.cal_weight = cal_weight 73 | 74 | def __call__(self, net, images, labels=None, augment_pipe=None): 75 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 76 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 77 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 78 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 79 | n = torch.randn_like(y) * sigma 80 | if self.reg_on_mean: 81 | D_yn, score_means = net(y + n, sigma, labels, augment_labels=augment_labels, return_all=True) 82 | loss = weight * ((y - D_yn - score_means.detach()) ** 2) 83 | 84 | # to make the score net calibrated (weighted by an extra coefficient cal_weight) 85 | score = (y + n - D_yn) 86 | # loss2 = (score_means.detach() + score)**2 * self.cal_weight * weight 87 | 88 | # to make the small net approximate score mean 89 | loss3 = (score_means - score.mean(0).detach())**2 90 | 91 | return loss, loss3 #, loss2 92 | else: 93 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 94 | loss = weight * ((D_yn - y) ** 2) 95 | return loss 96 | 97 | #---------------------------------------------------------------------------- 98 | -------------------------------------------------------------------------------- /functions/denoising.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def compute_alpha(beta, t): 6 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 7 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 8 | return a 9 | 10 | 11 | def cond_fn(x, t_discrete, y, classifier, classifier_scale): 12 | assert y is not None 13 | with torch.enable_grad(): 14 | x_in = x.detach().requires_grad_(True) 15 | logits = classifier(x_in, t_discrete) 16 | log_probs = F.log_softmax(logits, dim=-1) 17 | selected = log_probs[range(len(logits)), y.view(-1)] 18 | return torch.autograd.grad(selected.sum(), x_in)[0] * classifier_scale 19 | 20 | 21 | def generalized_steps(x, seq, model_fn, b, eta=0, is_cond_classifier=False, classifier=None, classifier_scale=1.0, **model_kwargs): 22 | with torch.no_grad(): 23 | def model(x, t_discrete): 24 | if is_cond_classifier: 25 | y = model_kwargs.get("y", None) 26 | if y is None: 27 | raise 
ValueError("For classifier guidance, the label y has to be in the input.") 28 | noise_uncond = model_fn(x, t_discrete, **model_kwargs) 29 | cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale) 30 | at = compute_alpha(b, t_discrete.long()) 31 | sigma_t = (1 - at).sqrt() 32 | return noise_uncond - sigma_t * cond_grad 33 | else: 34 | return model_fn(x, t_discrete, **model_kwargs) 35 | n = x.size(0) 36 | seq_next = [-1] + list(seq[:-1]) 37 | x0_preds = [] 38 | xs = [x] 39 | for i, j in zip(reversed(seq), reversed(seq_next)): 40 | t = (torch.ones(n) * i).to(x.device) 41 | next_t = (torch.ones(n) * j).to(x.device) 42 | at = compute_alpha(b, t.long()) 43 | at_next = compute_alpha(b, next_t.long()) 44 | xt = xs[-1].to('cuda') 45 | et = model(xt, t) 46 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 47 | x0_preds.append(x0_t.to('cpu')) 48 | c1 = ( 49 | eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 50 | ) 51 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 52 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 53 | xs.append(xt_next.to('cpu')) 54 | 55 | return xs, x0_preds 56 | 57 | 58 | def ddpm_steps(x, seq, model_fn, b, is_cond_classifier=False, classifier=None, classifier_scale=1.0, **model_kwargs): 59 | with torch.no_grad(): 60 | def model(x, t_discrete): 61 | if is_cond_classifier: 62 | y = model_kwargs.get("y", None) 63 | if y is None: 64 | raise ValueError("For classifier guidance, the label y has to be in the input.") 65 | noise_uncond = model_fn(x, t_discrete, **model_kwargs) 66 | cond_grad = cond_fn(x, t_discrete, y, classifier=classifier, classifier_scale=classifier_scale) 67 | at = compute_alpha(b, t_discrete.long()) 68 | sigma_t = (1 - at).sqrt() 69 | return noise_uncond - sigma_t * cond_grad 70 | else: 71 | return model_fn(x, t_discrete, **model_kwargs) 72 | n = x.size(0) 73 | seq_next = [-1] + list(seq[:-1]) 74 | xs = [x] 75 | x0_preds = [] 76 | betas = b 77 | for i, j in zip(reversed(seq), reversed(seq_next)): 78 | t = (torch.ones(n) * i).to(x.device) 79 | next_t = (torch.ones(n) * j).to(x.device) 80 | at = compute_alpha(betas, t.long()) 81 | atm1 = compute_alpha(betas, next_t.long()) 82 | beta_t = 1 - at / atm1 83 | x = xs[-1].to('cuda') 84 | 85 | output = model(x, t.float()) 86 | e = output 87 | 88 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e 89 | x0_from_e = torch.clamp(x0_from_e, -1, 1) 90 | x0_preds.append(x0_from_e.to('cpu')) 91 | mean_eps = ( 92 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x 93 | ) / (1.0 - at) 94 | 95 | mean = mean_eps 96 | noise = torch.randn_like(x) 97 | mask = 1 - (t == 0).float() 98 | mask = mask.view(-1, 1, 1, 1) 99 | logvar = beta_t.log() 100 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 101 | xs.append(sample.to('cpu')) 102 | return xs, x0_preds 103 | -------------------------------------------------------------------------------- /models/improved_ddpm/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | :param target_params: the target parameter sequence. 60 | :param source_params: the source parameter sequence. 61 | :param rate: the EMA rate (closer to 1 means slower). 62 | """ 63 | for targ, src in zip(target_params, source_params): 64 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def scale_module(module, scale): 77 | """ 78 | Scale the parameters of a module and return it. 79 | """ 80 | for p in module.parameters(): 81 | p.detach().mul_(scale) 82 | return module 83 | 84 | 85 | def mean_flat(tensor): 86 | """ 87 | Take the mean over all non-batch dimensions. 88 | """ 89 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 90 | 91 | 92 | def normalization(channels): 93 | """ 94 | Make a standard normalization layer. 95 | :param channels: number of input channels. 96 | :return: an nn.Module for normalization. 97 | """ 98 | return GroupNorm32(32, channels) 99 | 100 | 101 | def timestep_embedding(timesteps, dim, max_period=10000): 102 | """ 103 | Create sinusoidal timestep embeddings. 104 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 105 | These may be fractional. 106 | :param dim: the dimension of the output. 107 | :param max_period: controls the minimum frequency of the embeddings. 108 | :return: an [N x dim] Tensor of positional embeddings. 109 | """ 110 | half = dim // 2 111 | freqs = th.exp( 112 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 113 | ).to(device=timesteps.device) 114 | args = timesteps[:, None].float() * freqs[None] 115 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 116 | if dim % 2: 117 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 118 | return embedding 119 | 120 | 121 | def checkpoint(func, inputs, params, flag): 122 | """ 123 | Evaluate a function without caching intermediate activations, allowing for 124 | reduced memory at the expense of extra compute in the backward pass. 125 | :param func: the function to evaluate. 126 | :param inputs: the argument sequence to pass to `func`. 
127 | :param params: a sequence of parameters `func` depends on but does not 128 | explicitly take as arguments. 129 | :param flag: if False, disable gradient checkpointing. 130 | """ 131 | if flag: 132 | args = tuple(inputs) + tuple(params) 133 | return CheckpointFunction.apply(func, len(inputs), *args) 134 | else: 135 | return func(*inputs) 136 | 137 | 138 | class CheckpointFunction(th.autograd.Function): 139 | @staticmethod 140 | def forward(ctx, run_function, length, *args): 141 | ctx.run_function = run_function 142 | ctx.input_tensors = list(args[:length]) 143 | ctx.input_params = list(args[length:]) 144 | with th.no_grad(): 145 | output_tensors = ctx.run_function(*ctx.input_tensors) 146 | return output_tensors 147 | 148 | @staticmethod 149 | def backward(ctx, *output_grads): 150 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 151 | with th.enable_grad(): 152 | # Fixes a bug where the first op in run_function modifies the 153 | # Tensor storage in place, which is not allowed for detach()'d 154 | # Tensors. 155 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 156 | output_tensors = ctx.run_function(*shallow_copies) 157 | input_grads = th.autograd.grad( 158 | output_tensors, 159 | ctx.input_tensors + ctx.input_params, 160 | output_grads, 161 | allow_unused=True, 162 | ) 163 | del ctx.input_tensors 164 | del ctx.input_params 165 | del output_tensors 166 | return (None, None) + input_grads -------------------------------------------------------------------------------- /models/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 
71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads -------------------------------------------------------------------------------- /datasets/lsun.py: -------------------------------------------------------------------------------- 1 | from .vision import VisionDataset 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import io 6 | from collections.abc import Iterable 7 | import pickle 8 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str 9 | 10 | 11 | class LSUNClass(VisionDataset): 12 | def __init__(self, root, transform=None, target_transform=None): 13 | import lmdb 14 | 15 | super(LSUNClass, self).__init__( 16 | root, transform=transform, target_transform=target_transform 17 | ) 18 | 19 | self.env = lmdb.open( 20 | root, 21 | max_readers=1, 22 | readonly=True, 23 | lock=False, 24 | readahead=False, 25 | meminit=False, 26 | ) 27 | with self.env.begin(write=False) as txn: 28 | self.length = txn.stat()["entries"] 29 | root_split = root.split("/") 30 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}") 31 | if os.path.isfile(cache_file): 32 | self.keys = pickle.load(open(cache_file, "rb")) 33 | else: 34 | with self.env.begin(write=False) as txn: 35 | self.keys = [key for key, _ in txn.cursor()] 36 | pickle.dump(self.keys, open(cache_file, "wb")) 37 | 38 | def __getitem__(self, index): 39 | img, target = None, None 40 | env = self.env 41 | with env.begin(write=False) as txn: 42 | imgbuf = txn.get(self.keys[index]) 43 | 44 | buf = io.BytesIO() 45 | buf.write(imgbuf) 46 | buf.seek(0) 47 | img = Image.open(buf).convert("RGB") 48 | 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | if self.target_transform is not None: 53 | target = self.target_transform(target) 54 | 55 | return img, target 56 | 57 | def __len__(self): 58 | return self.length 59 | 60 | 61 | class LSUN(VisionDataset): 62 | """ 63 | `LSUN `_ dataset. 64 | 65 | Args: 66 | root (string): Root directory for the database files. 67 | classes (string or list): One of {'train', 'val', 'test'} or a list of 68 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. 69 | transform (callable, optional): A function/transform that takes in an PIL image 70 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 71 | target_transform (callable, optional): A function/transform that takes in the 72 | target and transforms it. 
73 | """ 74 | 75 | def __init__(self, root, classes="train", transform=None, target_transform=None): 76 | super(LSUN, self).__init__( 77 | root, transform=transform, target_transform=target_transform 78 | ) 79 | self.classes = self._verify_classes(classes) 80 | 81 | # for each class, create an LSUNClassDataset 82 | self.dbs = [] 83 | for c in self.classes: 84 | self.dbs.append( 85 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform) 86 | ) 87 | 88 | self.indices = [] 89 | count = 0 90 | for db in self.dbs: 91 | count += len(db) 92 | self.indices.append(count) 93 | 94 | self.length = count 95 | 96 | def _verify_classes(self, classes): 97 | categories = [ 98 | "bedroom", 99 | "bridge", 100 | "church_outdoor", 101 | "classroom", 102 | "conference_room", 103 | "dining_room", 104 | "kitchen", 105 | "living_room", 106 | "restaurant", 107 | "tower", 108 | ] 109 | dset_opts = ["train", "val", "test"] 110 | 111 | try: 112 | verify_str_arg(classes, "classes", dset_opts) 113 | if classes == "test": 114 | classes = [classes] 115 | else: 116 | classes = [c + "_" + classes for c in categories] 117 | except ValueError: 118 | if not isinstance(classes, Iterable): 119 | msg = ( 120 | "Expected type str or Iterable for argument classes, " 121 | "but got type {}." 122 | ) 123 | raise ValueError(msg.format(type(classes))) 124 | 125 | classes = list(classes) 126 | msg_fmtstr = ( 127 | "Expected type str for elements in argument classes, " 128 | "but got type {}." 129 | ) 130 | for c in classes: 131 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) 132 | c_short = c.split("_") 133 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1] 134 | 135 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 136 | msg = msg_fmtstr.format( 137 | category, "LSUN class", iterable_to_str(categories) 138 | ) 139 | verify_str_arg(category, valid_values=categories, custom_msg=msg) 140 | 141 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) 142 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) 143 | 144 | return classes 145 | 146 | def __getitem__(self, index): 147 | """ 148 | Args: 149 | index (int): Index 150 | 151 | Returns: 152 | tuple: Tuple (image, target) where target is the index of the target category. 
153 | """ 154 | target = 0 155 | sub = 0 156 | for ind in self.indices: 157 | if index < ind: 158 | break 159 | target += 1 160 | sub = ind 161 | 162 | db = self.dbs[target] 163 | index = index - sub 164 | 165 | if self.target_transform is not None: 166 | target = self.target_transform(target) 167 | 168 | img, _ = db[index] 169 | return img, target 170 | 171 | def __len__(self): 172 | return self.length 173 | 174 | def extra_repr(self): 175 | return "Classes: {classes}".format(**self.__dict__) 176 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 
81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /datasets/imagenet64.py: -------------------------------------------------------------------------------- 1 | """Create tools for Pytorch to load Downsampled ImageNet (32X32,64X64) 2 | 3 | Thanks to the cifar.py provided by Pytorch. 4 | 5 | Author: Xu Ma. 6 | Date: Apr/21/2019 7 | 8 | Data Preparation: 9 | 1. Download downsampled data from the ImageNet website. 10 | 2. Unzip the file to rootPath, e.g. /home/xm0036/Datasets/ImageNet64 (no train, val folders). 11 | 12 | Remark: 13 | This tool automatically recognizes the downsampled size. 14 | 15 | 16 | Use this tool like cifar10 in datasets/torchvision. 17 | """ 18 | 19 | 20 | from __future__ import print_function 21 | from PIL import Image 22 | import os 23 | import os.path 24 | import numpy as np 25 | import sys 26 | if sys.version_info[0] == 2: 27 | import cPickle as pickle 28 | else: 29 | import pickle 30 | 31 | import torch.utils.data as data 32 | 33 | 34 | class ImageNetDownSample(data.Dataset): 35 | """`DownsampleImageNet`_ Dataset. 36 | 37 | Args: 38 | root (string): Root directory of dataset where directory 39 | train (bool, optional): If True, creates dataset from training set, otherwise 40 | creates from test set. 41 | transform (callable, optional): A function/transform that takes in a PIL image 42 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 43 | target_transform (callable, optional): A function/transform that takes in the 44 | target and transforms it. 45 | download (bool, optional): If true, downloads the dataset from the internet and 46 | puts it in root directory. If dataset is already downloaded, it is not 47 | downloaded again. 
48 | 49 | """ 50 | 51 | train_list = [ 52 | ['train_data_batch_1'], 53 | ['train_data_batch_2'], 54 | ['train_data_batch_3'], 55 | ['train_data_batch_4'], 56 | ['train_data_batch_5'], 57 | ['train_data_batch_6'], 58 | ['train_data_batch_7'], 59 | ['train_data_batch_8'], 60 | ['train_data_batch_9'], 61 | ['train_data_batch_10'] 62 | ] 63 | test_list = [ 64 | ['val_data'], 65 | ] 66 | 67 | def __init__(self, root, train=True, 68 | transform=None, target_transform=None): 69 | self.root = os.path.expanduser(root) 70 | self.transform = transform 71 | self.target_transform = target_transform 72 | self.train = train # training set or test set 73 | 74 | # now load the picked numpy arrays 75 | if self.train: 76 | self.train_data = [] 77 | self.train_labels = [] 78 | for fentry in self.train_list: 79 | f = fentry[0] 80 | file = os.path.join(self.root, f) 81 | fo = open(file, 'rb') 82 | if sys.version_info[0] == 2: 83 | entry = pickle.load(fo) 84 | else: 85 | entry = pickle.load(fo, encoding='latin1') 86 | self.train_data.append(entry['data']) 87 | if 'labels' in entry: 88 | self.train_labels += entry['labels'] 89 | else: 90 | self.train_labels += entry['fine_labels'] 91 | fo.close() 92 | # resize label range from [1,1000] to [0,1000), 93 | # This is required by CrossEntropyLoss 94 | self.train_labels[:] = [x - 1 for x in self.train_labels] 95 | 96 | self.train_data = np.concatenate(self.train_data) 97 | [picnum, pixel] = self.train_data.shape 98 | pixel = int(np.sqrt(pixel / 3)) 99 | self.train_data = self.train_data.reshape((picnum, 3, pixel, pixel)) 100 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 101 | else: 102 | f = self.test_list[0][0] 103 | file = os.path.join(self.root, f) 104 | fo = open(file, 'rb') 105 | if sys.version_info[0] == 2: 106 | entry = pickle.load(fo) 107 | else: 108 | entry = pickle.load(fo, encoding='latin1') 109 | self.test_data = entry['data'] 110 | [picnum,pixel]= self.test_data.shape 111 | pixel = int(np.sqrt(pixel/3)) 112 | 113 | if 'labels' in entry: 114 | self.test_labels = entry['labels'] 115 | else: 116 | self.test_labels = entry['fine_labels'] 117 | fo.close() 118 | 119 | # resize label range from [1,1000] to [0,1000), 120 | # This is required by CrossEntropyLoss 121 | self.test_labels[:] = [x - 1 for x in self.test_labels] 122 | self.test_data = self.test_data.reshape((picnum, 3, pixel, pixel)) 123 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC 124 | 125 | def __getitem__(self, index): 126 | """ 127 | Args: 128 | index (int): Index 129 | 130 | Returns: 131 | tuple: (image, target) where target is index of the target class. 
132 | """ 133 | if self.train: 134 | img, target = self.train_data[index], self.train_labels[index] 135 | else: 136 | img, target = self.test_data[index], self.test_labels[index] 137 | 138 | # doing this so that it is consistent with all other datasets 139 | # to return a PIL Image 140 | img = Image.fromarray(img) 141 | 142 | if self.transform is not None: 143 | img = self.transform(img) 144 | 145 | if self.target_transform is not None: 146 | target = self.target_transform(target) 147 | 148 | return img, target 149 | 150 | def __len__(self): 151 | if self.train: 152 | return len(self.train_data) 153 | else: 154 | return len(self.test_data) 155 | 156 | 157 | def __repr__(self): 158 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 159 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 160 | tmp = 'train' if self.train is True else 'test' 161 | fmt_str += ' Split: {}\n'.format(tmp) 162 | fmt_str += ' Root Location: {}\n'.format(self.root) 163 | tmp = ' Transforms (if any): ' 164 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 165 | tmp = ' Target Transforms (if any): ' 166 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 167 | return fmt_str 168 | 169 | -------------------------------------------------------------------------------- /edm/fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Script for calculating Frechet Inception Distance (FID).""" 9 | 10 | import os 11 | import click 12 | import tqdm 13 | import pickle 14 | import numpy as np 15 | import scipy.linalg 16 | import torch 17 | import dnnlib 18 | from torch_utils import distributed as dist 19 | from training import dataset 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def calculate_inception_stats( 24 | image_path, num_expected=None, seed=0, max_batch_size=64, 25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'), 26 | ): 27 | # Rank 0 goes first. 28 | if dist.get_rank() != 0: 29 | torch.distributed.barrier() 30 | 31 | # Load Inception-v3 model. 32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 33 | dist.print0('Loading Inception-v3 model...') 34 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 35 | detector_kwargs = dict(return_features=True) 36 | feature_dim = 2048 37 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: 38 | detector_net = pickle.load(f).to(device) 39 | 40 | # List images. 
41 | dist.print0(f'Loading images from "{image_path}"...') 42 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed) 43 | if num_expected is not None and len(dataset_obj) < num_expected: 44 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}') 45 | if len(dataset_obj) < 2: 46 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics') 47 | 48 | # Other ranks follow. 49 | if dist.get_rank() == 0: 50 | torch.distributed.barrier() 51 | 52 | # Divide images into batches. 53 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 54 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) 55 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 56 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor) 57 | 58 | # Accumulate statistics. 59 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...') 60 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device) 61 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device) 62 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)): 63 | torch.distributed.barrier() 64 | if images.shape[0] == 0: 65 | continue 66 | if images.shape[1] == 1: 67 | images = images.repeat([1, 3, 1, 1]) 68 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64) 69 | mu += features.sum(0) 70 | sigma += features.T @ features 71 | 72 | # Calculate grand totals. 73 | torch.distributed.all_reduce(mu) 74 | torch.distributed.all_reduce(sigma) 75 | mu /= len(dataset_obj) 76 | sigma -= mu.ger(mu) * len(dataset_obj) 77 | sigma /= len(dataset_obj) - 1 78 | return mu.cpu().numpy(), sigma.cpu().numpy() 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 83 | m = np.square(mu - mu_ref).sum() 84 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 85 | fid = m + np.trace(sigma + sigma_ref - s * 2) 86 | return float(np.real(fid)) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @click.group() 91 | def main(): 92 | """Calculate Frechet Inception Distance (FID). 
93 | 94 | Examples: 95 | 96 | \b 97 | # Generate 50000 images and save them as fid-tmp/*/*.png 98 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\ 99 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 100 | 101 | \b 102 | # Calculate FID 103 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\ 104 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 105 | 106 | \b 107 | # Compute dataset reference statistics 108 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 109 | """ 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | @main.command() 114 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True) 115 | @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True) 116 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True) 117 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True) 118 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 119 | 120 | def calc(image_path, ref_path, num_expected, seed, batch): 121 | """Calculate FID for a given set of images.""" 122 | torch.multiprocessing.set_start_method('spawn') 123 | dist.init() 124 | 125 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...') 126 | ref = None 127 | if dist.get_rank() == 0: 128 | with dnnlib.util.open_url(ref_path) as f: 129 | ref = dict(np.load(f)) 130 | 131 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch) 132 | dist.print0('Calculating FID...') 133 | if dist.get_rank() == 0: 134 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma']) 135 | print(f'{fid:g}') 136 | torch.distributed.barrier() 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | @main.command() 141 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True) 142 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True) 143 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 144 | 145 | def ref(dataset_path, dest_path, batch): 146 | """Calculate dataset reference statistics needed by 'calc'.""" 147 | torch.multiprocessing.set_start_method('spawn') 148 | dist.init() 149 | 150 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch) 151 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...') 152 | if dist.get_rank() == 0: 153 | if os.path.dirname(dest_path): 154 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 155 | np.savez(dest_path, mu=mu, sigma=sigma) 156 | 157 | torch.distributed.barrier() 158 | dist.print0('Done.') 159 | 160 | #---------------------------------------------------------------------------- 161 | 162 | if __name__ == "__main__": 163 | main() 164 | 165 | #---------------------------------------------------------------------------- 166 | 
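A self-contained sketch of the same Frechet distance computed by calculate_fid_from_inception_stats() above; the 4-dimensional toy statistics are illustrative only (the script itself works with 2048-dimensional Inception features):

import numpy as np
import scipy.linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID between N(mu1, sigma1) and N(mu2, sigma2):
    #   ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
    diff = np.square(mu1 - mu2).sum()
    covmean, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)
    return float(np.real(diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)))

# Toy check: identical statistics give a distance of (numerically) zero.
mu, sigma = np.zeros(4), np.eye(4)
assert abs(frechet_distance(mu, sigma, mu, sigma)) < 1e-6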
-------------------------------------------------------------------------------- /datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import PIL 4 | from .vision import VisionDataset 5 | from .utils import download_file_from_google_drive, check_integrity 6 | 7 | 8 | class CelebA(VisionDataset): 9 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory where images are downloaded to. 13 | split (string): One of {'train', 'valid', 'test'}. 14 | Accordingly dataset is selected. 15 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 16 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 17 | The targets represent: 18 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 19 | ``identity`` (int): label for each person (data points with the same identity are the same person) 20 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 21 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 22 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 23 | Defaults to ``attr``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If true, downloads the dataset from the internet and 29 | puts it in root directory. If dataset is already downloaded, it is not 30 | downloaded again. 31 | """ 32 | 33 | base_folder = "celeba" 34 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 35 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 36 | # right now. 
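# (Each entry below is a (google_drive_file_id, md5_hash, filename) triple;
# download() feeds them to download_file_from_google_drive(), and
# _check_integrity() verifies the md5 of the non-archive files.)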
37 | file_list = [ 38 | # File ID MD5 Hash Filename 39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 48 | ] 49 | 50 | def __init__(self, root, 51 | split="train", 52 | target_type="attr", 53 | transform=None, target_transform=None, 54 | download=False): 55 | import pandas 56 | super(CelebA, self).__init__(root) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it') 71 | 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | 75 | if split.lower() == "train": 76 | split = 0 77 | elif split.lower() == "valid": 78 | split = 1 79 | elif split.lower() == "test": 80 | split = 2 81 | else: 82 | raise ValueError('Wrong split entered! 
Please use split="train" ' 83 | 'or split="valid" or split="test"') 84 | 85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 87 | 88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 90 | 91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 93 | 94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 96 | 97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 99 | 100 | mask = (splits[1] == split) 101 | self.filename = splits[mask].index.values 102 | self.identity = torch.as_tensor(self.identity[mask].values) 103 | self.bbox = torch.as_tensor(self.bbox[mask].values) 104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 105 | self.attr = torch.as_tensor(self.attr[mask].values) 106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 107 | 108 | def _check_integrity(self): 109 | for (_, md5, filename) in self.file_list: 110 | fpath = os.path.join(self.root, self.base_folder, filename) 111 | _, ext = os.path.splitext(filename) 112 | # Allow original archive to be deleted (zip and 7z) 113 | # Only need the extracted images 114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 115 | return False 116 | 117 | # Should check a hash of the images 118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 119 | 120 | def download(self): 121 | import zipfile 122 | 123 | if self._check_integrity(): 124 | print('Files already downloaded and verified') 125 | return 126 | 127 | for (file_id, md5, filename) in self.file_list: 128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 129 | 130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 131 | f.extractall(os.path.join(self.root, self.base_folder)) 132 | 133 | def __getitem__(self, index): 134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 135 | 136 | target = [] 137 | for t in self.target_type: 138 | if t == "attr": 139 | target.append(self.attr[index, :]) 140 | elif t == "identity": 141 | target.append(self.identity[index, 0]) 142 | elif t == "bbox": 143 | target.append(self.bbox[index, :]) 144 | elif t == "landmarks": 145 | target.append(self.landmarks_align[index, :]) 146 | else: 147 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 148 | target = tuple(target) if len(target) > 1 else target[0] 149 | 150 | if self.transform is not None: 151 | X = self.transform(X) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return X, target 157 | 158 | def __len__(self): 159 | return len(self.attr) 160 | 161 | def extra_repr(self): 162 | lines = ["Target type: {target_type}", "Split: {split}"] 163 | return '\n'.join(lines).format(**self.__dict__) 164 | -------------------------------------------------------------------------------- 
/models/guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
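# (Overall flow of this fp16 scheme: the module keeps fp16 weights, while the
# optimizer updates a flattened fp32 "master" copy per parameter group;
# model_grads_to_master_grads() copies the scaled fp16 grads into that master
# copy before opt.step(), and the loop below copies the updated fp32 values
# back into the model parameters afterwards.)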
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import traceback 3 | import shutil 4 | import logging 5 | import yaml 6 | import sys 7 | import os 8 | import torch 9 | import numpy as np 10 | import torch.utils.tensorboard as tb 11 | 12 | from runners.diffusion import Diffusion 13 | 14 | torch.set_printoptions(sci_mode=False) 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 19 | 20 | parser.add_argument( 21 | "--config", type=str, required=True, help="Path to the config file" 22 | ) 23 | parser.add_argument("--seed", type=int, default=1234, help="Random seed") 24 | parser.add_argument( 25 | "--exp", type=str, default="exp", help="Path for saving running related data." 26 | ) 27 | parser.add_argument( 28 | "--doc", 29 | type=str, 30 | required=False, 31 | default="", 32 | help="A string for documentation purpose. 
" 33 | "Will be the name of the log folder.", 34 | ) 35 | parser.add_argument( 36 | "--comment", type=str, default="", help="A string for experiment comment" 37 | ) 38 | parser.add_argument( 39 | "--verbose", 40 | type=str, 41 | default="info", 42 | help="Verbose level: info | debug | warning | critical", 43 | ) 44 | parser.add_argument("--test", action="store_true", help="Whether to test the model") 45 | parser.add_argument( 46 | "--sample", 47 | action="store_true", 48 | help="Whether to produce samples from the model", 49 | ) 50 | parser.add_argument("--fid", action="store_true") 51 | parser.add_argument("--interpolation", action="store_true") 52 | parser.add_argument( 53 | "--resume_training", action="store_true", help="Whether to resume training" 54 | ) 55 | parser.add_argument( 56 | "-i", 57 | "--image_folder", 58 | type=str, 59 | default="images", 60 | help="The folder name of samples", 61 | ) 62 | parser.add_argument( 63 | "--ni", 64 | action="store_true", 65 | help="No interaction. Suitable for Slurm Job launcher", 66 | ) 67 | parser.add_argument( 68 | "--sample_type", 69 | type=str, 70 | default="generalized", 71 | help="sampling approach ('generalized'(DDIM) or 'ddpm_noisy'(DDPM) or 'dpm-solver')", 72 | ) 73 | parser.add_argument( 74 | "--skip_type", 75 | type=str, 76 | default="logSNR", 77 | help="skip according to ('uniform' or 'quadratic' for DDIM/DDPM; 'logSNR' or 'time_uniform' or 'time_quadratic' for DPM-Solver)", 78 | ) 79 | parser.add_argument( 80 | "--base_samples", 81 | type=str, 82 | default=None, 83 | help="base samples for upsampling, *.npz", 84 | ) 85 | parser.add_argument( 86 | "--timesteps", type=int, default=1000, help="number of steps involved" 87 | ) 88 | parser.add_argument( 89 | "--dpm_solver_order", type=int, default=3, help="order of dpm-solver" 90 | ) 91 | parser.add_argument( 92 | "--eta", 93 | type=float, 94 | default=0.0, 95 | help="eta used to control the variances of sigma", 96 | ) 97 | parser.add_argument( 98 | "--start_time", type=float, default=1e-4, help="start time for sampling" 99 | ) 100 | parser.add_argument( 101 | "--fixed_class", type=int, default=None, help="fixed class label for conditional sampling" 102 | ) 103 | parser.add_argument( 104 | "--dpm_solver_atol", type=float, default=0.0078, help="atol for adaptive step size algorithm" 105 | ) 106 | parser.add_argument( 107 | "--dpm_solver_rtol", type=float, default=0.05, help="rtol for adaptive step size algorithm" 108 | ) 109 | 110 | parser.add_argument("--sequence", action="store_true") 111 | parser.add_argument("--adaptive_step_size", action="store_true") 112 | parser.add_argument("--dpm_solver_fast", action="store_true") 113 | 114 | parser.add_argument( 115 | "--score_mean", action="store_true", default=False 116 | ) 117 | parser.add_argument( 118 | "--tradeoff", type=float, default=1 119 | ) 120 | parser.add_argument( 121 | "--subsample", type=int, default=None 122 | ) 123 | parser.add_argument( 124 | "--likelihood", type=str, default=None 125 | ) 126 | parser.add_argument( 127 | "--train_recorder", action="store_true", default=False 128 | ) 129 | 130 | args = parser.parse_args() 131 | args.log_path = os.path.join(args.exp, "logs", args.doc) 132 | 133 | # parse config file 134 | with open(os.path.join("configs", args.config), "r") as f: 135 | config = yaml.safe_load(f) 136 | new_config = dict2namespace(config) 137 | 138 | tb_path = os.path.join(args.exp, "tensorboard", args.doc) 139 | 140 | if not args.test and not args.sample: 141 | if not args.resume_training: 142 | if 
os.path.exists(args.log_path): 143 | overwrite = False 144 | if args.ni: 145 | overwrite = True 146 | else: 147 | response = input("Folder already exists. Overwrite? (Y/N)") 148 | if response.upper() == "Y": 149 | overwrite = True 150 | 151 | if overwrite: 152 | shutil.rmtree(args.log_path) 153 | shutil.rmtree(tb_path) 154 | os.makedirs(args.log_path) 155 | if os.path.exists(tb_path): 156 | shutil.rmtree(tb_path) 157 | else: 158 | print("Folder exists. Program halted.") 159 | sys.exit(0) 160 | else: 161 | os.makedirs(args.log_path) 162 | 163 | with open(os.path.join(args.log_path, "config.yml"), "w") as f: 164 | yaml.dump(new_config, f, default_flow_style=False) 165 | 166 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path) 167 | # setup logger 168 | level = getattr(logging, args.verbose.upper(), None) 169 | if not isinstance(level, int): 170 | raise ValueError("level {} not supported".format(args.verbose)) 171 | 172 | handler1 = logging.StreamHandler() 173 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) 174 | formatter = logging.Formatter( 175 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 176 | ) 177 | handler1.setFormatter(formatter) 178 | handler2.setFormatter(formatter) 179 | logger = logging.getLogger() 180 | logger.addHandler(handler1) 181 | logger.addHandler(handler2) 182 | logger.setLevel(level) 183 | 184 | else: 185 | level = getattr(logging, args.verbose.upper(), None) 186 | if not isinstance(level, int): 187 | raise ValueError("level {} not supported".format(args.verbose)) 188 | 189 | handler1 = logging.StreamHandler() 190 | formatter = logging.Formatter( 191 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 192 | ) 193 | handler1.setFormatter(formatter) 194 | logger = logging.getLogger() 195 | logger.addHandler(handler1) 196 | logger.setLevel(level) 197 | 198 | if args.sample: 199 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) 200 | args.image_folder = os.path.join( 201 | args.exp, "image_samples", args.image_folder 202 | ) 203 | if not os.path.exists(args.image_folder): 204 | os.makedirs(args.image_folder) 205 | else: 206 | if not (args.fid or args.interpolation): 207 | overwrite = False 208 | if args.ni: 209 | overwrite = True 210 | else: 211 | response = input( 212 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)" 213 | ) 214 | if response.upper() == "Y": 215 | overwrite = True 216 | 217 | if overwrite: 218 | shutil.rmtree(args.image_folder) 219 | os.makedirs(args.image_folder) 220 | else: 221 | print("Output image folder exists. 
Program halted.") 222 | sys.exit(0) 223 | 224 | # add device 225 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 226 | logging.info("Using device: {}".format(device)) 227 | new_config.device = device 228 | 229 | # set random seed 230 | torch.manual_seed(args.seed) 231 | np.random.seed(args.seed) 232 | if torch.cuda.is_available(): 233 | torch.cuda.manual_seed_all(args.seed) 234 | 235 | torch.backends.cudnn.benchmark = True 236 | 237 | return args, new_config 238 | 239 | 240 | def dict2namespace(config): 241 | namespace = argparse.Namespace() 242 | for key, value in config.items(): 243 | if isinstance(value, dict): 244 | new_value = dict2namespace(value) 245 | else: 246 | new_value = value 247 | setattr(namespace, key, new_value) 248 | return namespace 249 | 250 | 251 | def main(): 252 | args, config = parse_args_and_config() 253 | logging.info("Writing log file to {}".format(args.log_path)) 254 | logging.info("Exp instance id = {}".format(os.getpid())) 255 | logging.info("Exp comment = {}".format(args.comment)) 256 | 257 | try: 258 | runner = Diffusion(args, config) 259 | if args.sample: 260 | runner.sample() 261 | elif args.test: 262 | runner.test() 263 | else: 264 | runner.train() 265 | except Exception: 266 | logging.error(traceback.format_exc()) 267 | 268 | return 0 269 | 270 | 271 | if __name__ == "__main__": 272 | sys.exit(main()) 273 | -------------------------------------------------------------------------------- /edm/training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Streaming images and labels from datasets created with dataset_tool.py.""" 9 | 10 | import os 11 | import numpy as np 12 | import zipfile 13 | import PIL.Image 14 | import json 15 | import torch 16 | import dnnlib 17 | 18 | try: 19 | import pyspng 20 | except ImportError: 21 | pyspng = None 22 | 23 | #---------------------------------------------------------------------------- 24 | # Abstract base class for datasets. 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, 28 | name, # Name of the dataset. 29 | raw_shape, # Shape of the raw image data (NCHW). 30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 33 | random_seed = 0, # Random seed to use when applying max_size. 34 | cache = False, # Cache images in CPU memory? 35 | ): 36 | self._name = name 37 | self._raw_shape = list(raw_shape) 38 | self._use_labels = use_labels 39 | self._cache = cache 40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 41 | self._raw_labels = None 42 | self._label_shape = None 43 | 44 | # Apply max_size. 45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 46 | if (max_size is not None) and (self._raw_idx.size > max_size): 47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 48 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 49 | 50 | # Apply xflip. 
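# Note on the flag below: when xflip is enabled, the index array is duplicated and the second copy
# is marked in self._xflip, so __getitem__ mirrors those images on the fly (image[:, :, ::-1])
# instead of storing flipped copies.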
51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 52 | if xflip: 53 | self._raw_idx = np.tile(self._raw_idx, 2) 54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 55 | 56 | def _get_raw_labels(self): 57 | if self._raw_labels is None: 58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 59 | if self._raw_labels is None: 60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 61 | assert isinstance(self._raw_labels, np.ndarray) 62 | assert self._raw_labels.shape[0] == self._raw_shape[0] 63 | assert self._raw_labels.dtype in [np.float32, np.int64] 64 | if self._raw_labels.dtype == np.int64: 65 | assert self._raw_labels.ndim == 1 66 | assert np.all(self._raw_labels >= 0) 67 | return self._raw_labels 68 | 69 | def close(self): # to be overridden by subclass 70 | pass 71 | 72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def _load_raw_labels(self): # to be overridden by subclass 76 | raise NotImplementedError 77 | 78 | def __getstate__(self): 79 | return dict(self.__dict__, _raw_labels=None) 80 | 81 | def __del__(self): 82 | try: 83 | self.close() 84 | except: 85 | pass 86 | 87 | def __len__(self): 88 | return self._raw_idx.size 89 | 90 | def __getitem__(self, idx): 91 | raw_idx = self._raw_idx[idx] 92 | image = self._cached_images.get(raw_idx, None) 93 | if image is None: 94 | image = self._load_raw_image(raw_idx) 95 | if self._cache: 96 | self._cached_images[raw_idx] = image 97 | assert isinstance(image, np.ndarray) 98 | assert list(image.shape) == self.image_shape 99 | assert image.dtype == np.uint8 100 | if self._xflip[idx]: 101 | assert image.ndim == 3 # CHW 102 | image = image[:, :, ::-1] 103 | return image.copy(), self.get_label(idx) 104 | 105 | def get_label(self, idx): 106 | label = self._get_raw_labels()[self._raw_idx[idx]] 107 | if label.dtype == np.int64: 108 | onehot = np.zeros(self.label_shape, dtype=np.float32) 109 | onehot[label] = 1 110 | label = onehot 111 | return label.copy() 112 | 113 | def get_details(self, idx): 114 | d = dnnlib.EasyDict() 115 | d.raw_idx = int(self._raw_idx[idx]) 116 | d.xflip = (int(self._xflip[idx]) != 0) 117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 118 | return d 119 | 120 | @property 121 | def name(self): 122 | return self._name 123 | 124 | @property 125 | def image_shape(self): 126 | return list(self._raw_shape[1:]) 127 | 128 | @property 129 | def num_channels(self): 130 | assert len(self.image_shape) == 3 # CHW 131 | return self.image_shape[0] 132 | 133 | @property 134 | def resolution(self): 135 | assert len(self.image_shape) == 3 # CHW 136 | assert self.image_shape[1] == self.image_shape[2] 137 | return self.image_shape[1] 138 | 139 | @property 140 | def label_shape(self): 141 | if self._label_shape is None: 142 | raw_labels = self._get_raw_labels() 143 | if raw_labels.dtype == np.int64: 144 | self._label_shape = [int(np.max(raw_labels)) + 1] 145 | else: 146 | self._label_shape = raw_labels.shape[1:] 147 | return list(self._label_shape) 148 | 149 | @property 150 | def label_dim(self): 151 | assert len(self.label_shape) == 1 152 | return self.label_shape[0] 153 | 154 | @property 155 | def has_labels(self): 156 | return any(x != 0 for x in self.label_shape) 157 | 158 | @property 159 | def has_onehot_labels(self): 160 | return self._get_raw_labels().dtype == np.int64 161 | 162 | #---------------------------------------------------------------------------- 163 | # Dataset subclass that 
loads images recursively from the specified directory 164 | # or ZIP file. 165 | 166 | class ImageFolderDataset(Dataset): 167 | def __init__(self, 168 | path, # Path to directory or zip. 169 | resolution = None, # Ensure specific resolution, None = highest available. 170 | use_pyspng = True, # Use pyspng if available? 171 | **super_kwargs, # Additional arguments for the Dataset base class. 172 | ): 173 | self._path = path 174 | self._use_pyspng = use_pyspng 175 | self._zipfile = None 176 | 177 | if os.path.isdir(self._path): 178 | self._type = 'dir' 179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 180 | elif self._file_ext(self._path) == '.zip': 181 | self._type = 'zip' 182 | self._all_fnames = set(self._get_zipfile().namelist()) 183 | else: 184 | raise IOError('Path must point to a directory or zip') 185 | 186 | PIL.Image.init() 187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 188 | if len(self._image_fnames) == 0: 189 | raise IOError('No image files found in the specified path') 190 | 191 | name = os.path.splitext(os.path.basename(self._path))[0] 192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 194 | raise IOError('Image files do not match the specified resolution') 195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 196 | 197 | @staticmethod 198 | def _file_ext(fname): 199 | return os.path.splitext(fname)[1].lower() 200 | 201 | def _get_zipfile(self): 202 | assert self._type == 'zip' 203 | if self._zipfile is None: 204 | self._zipfile = zipfile.ZipFile(self._path) 205 | return self._zipfile 206 | 207 | def _open_file(self, fname): 208 | if self._type == 'dir': 209 | return open(os.path.join(self._path, fname), 'rb') 210 | if self._type == 'zip': 211 | return self._get_zipfile().open(fname, 'r') 212 | return None 213 | 214 | def close(self): 215 | try: 216 | if self._zipfile is not None: 217 | self._zipfile.close() 218 | finally: 219 | self._zipfile = None 220 | 221 | def __getstate__(self): 222 | return dict(super().__getstate__(), _zipfile=None) 223 | 224 | def _load_raw_image(self, raw_idx): 225 | fname = self._image_fnames[raw_idx] 226 | with self._open_file(fname) as f: 227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 228 | image = pyspng.load(f.read()) 229 | else: 230 | image = np.array(PIL.Image.open(f)) 231 | if image.ndim == 2: 232 | image = image[:, :, np.newaxis] # HW => HWC 233 | image = image.transpose(2, 0, 1) # HWC => CHW 234 | return image 235 | 236 | def _load_raw_labels(self): 237 | fname = 'dataset.json' 238 | if fname not in self._all_fnames: 239 | return None 240 | with self._open_file(fname) as f: 241 | labels = json.load(f)['labels'] 242 | if labels is None: 243 | return None 244 | labels = dict(labels) 245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 246 | labels = np.array(labels) 247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 248 | return labels 249 | 250 | #---------------------------------------------------------------------------- 251 | -------------------------------------------------------------------------------- /edm/torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 
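    For example (an illustrative sketch; `MyBigNetwork` is a hypothetical name):

        @persistence.persistent_class
        class MyBigNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self._record_init_args = False  # skip deep-copying the constructor args
                ...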
83 | 84 | A typical use case is to first unpickle a previous instance of a 85 | persistent class, and then upgrade it to use the latest version of 86 | the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | record_init_args = getattr(self, '_record_init_args', True) 108 | self._init_args = copy.deepcopy(args) if record_init_args else None 109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 110 | assert orig_class.__name__ in orig_module.__dict__ 111 | _check_pickleable(self.__reduce__()) 112 | 113 | @property 114 | def init_args(self): 115 | assert self._init_args is not None 116 | return copy.deepcopy(self._init_args) 117 | 118 | @property 119 | def init_kwargs(self): 120 | assert self._init_kwargs is not None 121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 122 | 123 | def __reduce__(self): 124 | fields = list(super().__reduce__()) 125 | fields += [None] * max(3 - len(fields), 0) 126 | if fields[0] is not _reconstruct_persistent_obj: 127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 128 | fields[0] = _reconstruct_persistent_obj # reconstruct func 129 | fields[1] = (meta,) # reconstruct args 130 | fields[2] = None # state dict 131 | return tuple(fields) 132 | 133 | Decorator.__name__ = orig_class.__name__ 134 | Decorator.__module__ = orig_class.__module__ 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src: Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object.
170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
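# Anything not handled above is returned unchanged, so the pickle.dump() call below exercises it
# for real and raises if it cannot be pickled.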
253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /evaluate/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real world 13 | samples respectively. 14 | 15 | See --help to see further details. 16 | 17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of TensorFlow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 31 | See the License for the specific language governing permissions and 32 | limitations under the License. 33 | """ 34 | import os 35 | import pathlib 36 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 37 | 38 | import numpy as np 39 | import torch 40 | import torchvision.transforms as TF 41 | from PIL import Image 42 | from scipy import linalg 43 | from torch.nn.functional import adaptive_avg_pool2d 44 | 45 | try: 46 | from tqdm import tqdm 47 | except ImportError: 48 | # If tqdm is not available, provide a mock version of it 49 | def tqdm(x): 50 | return x 51 | 52 | from pytorch_fid.inception import InceptionV3 53 | 54 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 55 | parser.add_argument('--batch-size', type=int, default=50, 56 | help='Batch size to use') 57 | parser.add_argument('--num-workers', type=int, 58 | help=('Number of processes to use for data loading. ' 59 | 'Defaults to `min(8, num_cpus)`')) 60 | parser.add_argument('--device', type=str, default=None, 61 | help='Device to use. Like cuda, cuda:0 or cpu') 62 | parser.add_argument('--dims', type=int, default=2048, 63 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 64 | help=('Dimensionality of Inception features to use.
' 65 | 'By default, uses pool3 features')) 66 | parser.add_argument('path', type=str, nargs=2, 67 | help=('Paths to the generated images or ' 68 | 'to .npz statistic files')) 69 | 70 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 71 | 'tif', 'tiff', 'webp'} 72 | 73 | 74 | class ImagePathDataset(torch.utils.data.Dataset): 75 | def __init__(self, files, transforms=None): 76 | self.files = files 77 | self.transforms = transforms 78 | 79 | def __len__(self): 80 | return len(self.files) 81 | 82 | def __getitem__(self, i): 83 | path = self.files[i] 84 | img = Image.open(path).convert('RGB') 85 | if self.transforms is not None: 86 | img = self.transforms(img) 87 | return img 88 | 89 | 90 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', 91 | num_workers=1): 92 | """Calculates the activations of the pool_3 layer for all images. 93 | 94 | Params: 95 | -- files : List of image files paths 96 | -- model : Instance of inception model 97 | -- batch_size : Batch size of images for the model to process at once. 98 | Make sure that the number of samples is a multiple of 99 | the batch size, otherwise some samples are ignored. This 100 | behavior is retained to match the original FID score 101 | implementation. 102 | -- dims : Dimensionality of features returned by Inception 103 | -- device : Device to run calculations 104 | -- num_workers : Number of parallel dataloader workers 105 | 106 | Returns: 107 | -- A numpy array of dimension (num images, dims) that contains the 108 | activations of the given tensor when feeding inception with the 109 | query tensor. 110 | """ 111 | model.eval() 112 | 113 | if batch_size > len(files): 114 | print(('Warning: batch size is bigger than the data size. ' 115 | 'Setting batch size to data size')) 116 | batch_size = len(files) 117 | 118 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 119 | dataloader = torch.utils.data.DataLoader(dataset, 120 | batch_size=batch_size, 121 | shuffle=False, 122 | drop_last=False, 123 | num_workers=num_workers) 124 | 125 | pred_arr = np.empty((len(files), dims)) 126 | 127 | start_idx = 0 128 | 129 | for batch in tqdm(dataloader): 130 | batch = batch.to(device) 131 | 132 | with torch.no_grad(): 133 | pred = model(batch)[0] 134 | 135 | # If model output is not scalar, apply global spatial average pooling. 136 | # This happens if you choose a dimensionality not equal 2048. 137 | if pred.size(2) != 1 or pred.size(3) != 1: 138 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 139 | 140 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 141 | 142 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 143 | 144 | start_idx = start_idx + pred.shape[0] 145 | 146 | return pred_arr 147 | 148 | 149 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 150 | """Numpy implementation of the Frechet Distance. 151 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 152 | and X_2 ~ N(mu_2, C_2) is 153 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 154 | 155 | Stable version by Dougal J. Sutherland. 156 | 157 | Params: 158 | -- mu1 : Numpy array containing the activations of a layer of the 159 | inception net (like returned by the function 'get_predictions') 160 | for generated samples. 161 | -- mu2 : The sample mean over activations, precalculated on an 162 | representative data set. 163 | -- sigma1: The covariance matrix over activations for generated samples. 
164 | -- sigma2: The covariance matrix over activations, precalculated on a 165 | representative data set. 166 | 167 | Returns: 168 | -- : The Frechet Distance. 169 | """ 170 | 171 | mu1 = np.atleast_1d(mu1) 172 | mu2 = np.atleast_1d(mu2) 173 | 174 | sigma1 = np.atleast_2d(sigma1) 175 | sigma2 = np.atleast_2d(sigma2) 176 | 177 | assert mu1.shape == mu2.shape, \ 178 | 'Training and test mean vectors have different lengths' 179 | assert sigma1.shape == sigma2.shape, \ 180 | 'Training and test covariances have different dimensions' 181 | 182 | diff = mu1 - mu2 183 | 184 | # Product might be almost singular 185 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 186 | if not np.isfinite(covmean).all(): 187 | msg = ('fid calculation produces singular product; ' 188 | 'adding %s to diagonal of cov estimates') % eps 189 | print(msg) 190 | offset = np.eye(sigma1.shape[0]) * eps 191 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 192 | 193 | # Numerical error might give slight imaginary component 194 | if np.iscomplexobj(covmean): 195 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 196 | m = np.max(np.abs(covmean.imag)) 197 | raise ValueError('Imaginary component {}'.format(m)) 198 | covmean = covmean.real 199 | 200 | tr_covmean = np.trace(covmean) 201 | 202 | return (diff.dot(diff) + np.trace(sigma1) 203 | + np.trace(sigma2) - 2 * tr_covmean) 204 | 205 | 206 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 207 | device='cpu', num_workers=1): 208 | """Calculation of the statistics used by the FID. 209 | Params: 210 | -- files : List of image file paths 211 | -- model : Instance of inception model 212 | -- batch_size : The images numpy array is split into batches with 213 | batch size batch_size. A reasonable batch size 214 | depends on the hardware. 215 | -- dims : Dimensionality of features returned by Inception 216 | -- device : Device to run calculations 217 | -- num_workers : Number of parallel dataloader workers 218 | 219 | Returns: 220 | -- mu : The mean over samples of the activations of the pool_3 layer of 221 | the inception model. 222 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 223 | the inception model.
224 | """ 225 | act = get_activations(files, model, batch_size, dims, device, num_workers) 226 | mu = np.mean(act, axis=0) 227 | sigma = np.cov(act, rowvar=False) 228 | return mu, sigma 229 | 230 | 231 | def compute_statistics_of_path(path, model, batch_size, dims, device, 232 | num_workers=1): 233 | if path.endswith('.npz'): 234 | with np.load(path) as f: 235 | m, s = f['mu'][:], f['sigma'][:] 236 | else: 237 | path = pathlib.Path(path) 238 | files = sorted([file for ext in IMAGE_EXTENSIONS 239 | for file in path.glob('*.{}'.format(ext))]) 240 | m, s = calculate_activation_statistics(files, model, batch_size, 241 | dims, device, num_workers) 242 | 243 | return m, s 244 | 245 | 246 | def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): 247 | """Calculates the FID of two paths""" 248 | for p in paths: 249 | if not os.path.exists(p): 250 | raise RuntimeError('Invalid path: %s' % p) 251 | 252 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 253 | 254 | model = InceptionV3([block_idx]).to(device) 255 | 256 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 257 | dims, device, num_workers) 258 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 259 | dims, device, num_workers) 260 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 261 | 262 | return fid_value 263 | -------------------------------------------------------------------------------- /edm/torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet? 28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 
43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. 58 | 59 | This function is expected to be extremely cheap and can be safely 60 | called from anywhere in the training loop, loss function, or inside a 61 | `torch.nn.Module`. 62 | 63 | Warning: The current implementation expects the set of unique names to 64 | be consistent across processes. Please make sure that `report()` is 65 | called at least once for each unique name by each process, and in the 66 | same order. If a given process has no scalars to broadcast, it can do 67 | `report(name, [])` (empty list). 68 | 69 | Args: 70 | name: Arbitrary string specifying the name of the statistic. 71 | Averages are accumulated separately for each unique name. 72 | value: Arbitrary set of scalars. Can be a list, tuple, 73 | NumPy array, PyTorch tensor, or Python scalar. 74 | 75 | Returns: 76 | The same `value` that was passed in. 77 | """ 78 | if name not in _counters: 79 | _counters[name] = dict() 80 | 81 | elems = torch.as_tensor(value) 82 | if elems.numel() == 0: 83 | return value 84 | 85 | elems = elems.detach().flatten().to(_reduce_dtype) 86 | moments = torch.stack([ 87 | torch.ones_like(elems).sum(), 88 | elems.sum(), 89 | elems.square().sum(), 90 | ]) 91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 92 | moments = moments.to(_counter_dtype) 93 | 94 | device = moments.device 95 | if device not in _counters[name]: 96 | _counters[name][device] = torch.zeros_like(moments) 97 | _counters[name][device].add_(moments) 98 | return value 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | def report0(name, value): 103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 104 | but ignores any scalars provided by the other processes. 105 | See `report()` for further details. 106 | """ 107 | report(name, value if _rank == 0 else []) 108 | return value 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | class Collector: 113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 114 | computes their long-term averages (mean and standard deviation) over 115 | user-defined periods of time. 116 | 117 | The averages are first collected into internal counters that are not 118 | directly visible to the user. They are then copied to the user-visible 119 | state as a result of calling `update()` and can then be queried using 120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 121 | internal counters for the next round, so that the user-visible state 122 | effectively reflects averages collected between the last two calls to 123 | `update()`. 124 | 125 | Args: 126 | regex: Regular expression defining which statistics to 127 | collect. The default is to collect everything. 128 | keep_previous: Whether to retain the previous averages if no 129 | scalars were collected on a given round 130 | (default: True). 
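    Example (an illustrative sketch of the intended call pattern; `loss`
    stands for any scalar tensor produced by the training loop):

        from torch_utils import training_stats

        training_stats.report('Loss/loss', loss)   # anywhere in the loop
        collector = training_stats.Collector(regex='Loss/.*')
        ...
        collector.update()                         # e.g. once per tick
        print(collector.mean('Loss/loss'), collector.std('Loss/loss'))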
131 | """ 132 | def __init__(self, regex='.*', keep_previous=True): 133 | self._regex = re.compile(regex) 134 | self._keep_previous = keep_previous 135 | self._cumulative = dict() 136 | self._moments = dict() 137 | self.update() 138 | self._moments.clear() 139 | 140 | def names(self): 141 | r"""Returns the names of all statistics broadcasted so far that 142 | match the regular expression specified at construction time. 143 | """ 144 | return [name for name in _counters if self._regex.fullmatch(name)] 145 | 146 | def update(self): 147 | r"""Copies current values of the internal counters to the 148 | user-visible state and resets them for the next round. 149 | 150 | If `keep_previous=True` was specified at construction time, the 151 | operation is skipped for statistics that have received no scalars 152 | since the last update, retaining their previous averages. 153 | 154 | This method performs a number of GPU-to-CPU transfers and one 155 | `torch.distributed.all_reduce()`. It is intended to be called 156 | periodically in the main training loop, typically once every 157 | N training steps. 158 | """ 159 | if not self._keep_previous: 160 | self._moments.clear() 161 | for name, cumulative in _sync(self.names()): 162 | if name not in self._cumulative: 163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 164 | delta = cumulative - self._cumulative[name] 165 | self._cumulative[name].copy_(cumulative) 166 | if float(delta[0]) != 0: 167 | self._moments[name] = delta 168 | 169 | def _get_delta(self, name): 170 | r"""Returns the raw moments that were accumulated for the given 171 | statistic between the last two calls to `update()`, or zero if 172 | no scalars were collected. 173 | """ 174 | assert self._regex.fullmatch(name) 175 | if name not in self._moments: 176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 177 | return self._moments[name] 178 | 179 | def num(self, name): 180 | r"""Returns the number of scalars that were accumulated for the given 181 | statistic between the last two calls to `update()`, or zero if 182 | no scalars were collected. 183 | """ 184 | delta = self._get_delta(name) 185 | return int(delta[0]) 186 | 187 | def mean(self, name): 188 | r"""Returns the mean of the scalars that were accumulated for the 189 | given statistic between the last two calls to `update()`, or NaN if 190 | no scalars were collected. 191 | """ 192 | delta = self._get_delta(name) 193 | if int(delta[0]) == 0: 194 | return float('nan') 195 | return float(delta[1] / delta[0]) 196 | 197 | def std(self, name): 198 | r"""Returns the standard deviation of the scalars that were 199 | accumulated for the given statistic between the last two calls to 200 | `update()`, or NaN if no scalars were collected. 201 | """ 202 | delta = self._get_delta(name) 203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 204 | return float('nan') 205 | if int(delta[0]) == 1: 206 | return float(0) 207 | mean = float(delta[1] / delta[0]) 208 | raw_var = float(delta[2] / delta[0]) 209 | return np.sqrt(max(raw_var - np.square(mean), 0)) 210 | 211 | def as_dict(self): 212 | r"""Returns the averages accumulated between the last two calls to 213 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 214 | 215 | dnnlib.EasyDict( 216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 217 | ... 
218 | ) 219 | """ 220 | stats = dnnlib.EasyDict() 221 | for name in self.names(): 222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 223 | return stats 224 | 225 | def __getitem__(self, name): 226 | r"""Convenience getter. 227 | `collector[name]` is a synonym for `collector.mean(name)`. 228 | """ 229 | return self.mean(name) 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def _sync(names): 234 | r"""Synchronize the global cumulative counters across devices and 235 | processes. Called internally by `Collector.update()`. 236 | """ 237 | if len(names) == 0: 238 | return [] 239 | global _sync_called 240 | _sync_called = True 241 | 242 | # Collect deltas within current rank. 243 | deltas = [] 244 | device = _sync_device if _sync_device is not None else torch.device('cpu') 245 | for name in names: 246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 247 | for counter in _counters[name].values(): 248 | delta.add_(counter.to(device)) 249 | counter.copy_(torch.zeros_like(counter)) 250 | deltas.append(delta) 251 | deltas = torch.stack(deltas) 252 | 253 | # Sum deltas across ranks. 254 | if _sync_device is not None: 255 | torch.distributed.all_reduce(deltas) 256 | 257 | # Update cumulative values. 258 | deltas = deltas.cpu() 259 | for idx, name in enumerate(names): 260 | if name not in _cumulative: 261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 262 | _cumulative[name].add_(deltas[idx]) 263 | 264 | # Return name-value pairs. 265 | return [(name, _cumulative[name]) for name in names] 266 | 267 | #---------------------------------------------------------------------------- 268 | # Convenience. 269 | 270 | default_collector = Collector() 271 | 272 | #---------------------------------------------------------------------------- 273 | -------------------------------------------------------------------------------- /edm/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import numpy as np 11 | import torch 12 | import warnings 13 | import dnnlib 14 | 15 | #---------------------------------------------------------------------------- 16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 17 | # same constant is used multiple times. 
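# Illustrative usage (a sketch; `device` stands for any torch.device): repeated calls with the same
# value/shape/dtype/device return the cached tensor instead of rebuilding and re-uploading it:
#   ref = constant([0.0, 1.0], dtype=torch.float32, device=device)  # first call builds and caches
#   ref = constant([0.0, 1.0], dtype=torch.float32, device=device)  # later calls hit the cache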
18 | 19 | _constant_cache = dict() 20 | 21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 22 | value = np.asarray(value) 23 | if shape is not None: 24 | shape = tuple(shape) 25 | if dtype is None: 26 | dtype = torch.get_default_dtype() 27 | if device is None: 28 | device = torch.device('cpu') 29 | if memory_format is None: 30 | memory_format = torch.contiguous_format 31 | 32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 33 | tensor = _constant_cache.get(key, None) 34 | if tensor is None: 35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 36 | if shape is not None: 37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 38 | tensor = tensor.contiguous(memory_format=memory_format) 39 | _constant_cache[key] = tensor 40 | return tensor 41 | 42 | #---------------------------------------------------------------------------- 43 | # Replace NaN/Inf with specified numerical values. 44 | 45 | try: 46 | nan_to_num = torch.nan_to_num # 1.8.0a0 47 | except AttributeError: 48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 49 | assert isinstance(input, torch.Tensor) 50 | if posinf is None: 51 | posinf = torch.finfo(input.dtype).max 52 | if neginf is None: 53 | neginf = torch.finfo(input.dtype).min 54 | assert nan == 0 55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Symbolic assert. 59 | 60 | try: 61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 62 | except AttributeError: 63 | symbolic_assert = torch.Assert # 1.7.0 64 | 65 | #---------------------------------------------------------------------------- 66 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 68 | 69 | @contextlib.contextmanager 70 | def suppress_tracer_warnings(): 71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 72 | warnings.filters.insert(0, flt) 73 | yield 74 | warnings.filters.remove(flt) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Assert that the shape of a tensor matches the given list of integers. 78 | # None indicates that the size of a dimension is allowed to vary. 79 | # Performs symbolic assertion when used in torch.jit.trace(). 
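# Illustrative usage (a sketch): assert_shape(images, [None, 3, 64, 64]) accepts any batch size but
# raises AssertionError unless `images` is a 4-D tensor with 3 channels at 64x64 resolution.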
80 | 81 | def assert_shape(tensor, ref_shape): 82 | if tensor.ndim != len(ref_shape): 83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 85 | if ref_size is None: 86 | pass 87 | elif isinstance(ref_size, torch.Tensor): 88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 90 | elif isinstance(size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 93 | elif size != ref_size: 94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 95 | 96 | #---------------------------------------------------------------------------- 97 | # Function decorator that calls torch.autograd.profiler.record_function(). 98 | 99 | def profiled_function(fn): 100 | def decorator(*args, **kwargs): 101 | with torch.autograd.profiler.record_function(fn.__name__): 102 | return fn(*args, **kwargs) 103 | decorator.__name__ = fn.__name__ 104 | return decorator 105 | 106 | #---------------------------------------------------------------------------- 107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 108 | # indefinitely, shuffling items as it goes. 109 | 110 | class InfiniteSampler(torch.utils.data.Sampler): 111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 112 | assert len(dataset) > 0 113 | assert num_replicas > 0 114 | assert 0 <= rank < num_replicas 115 | assert 0 <= window_size <= 1 116 | super().__init__(dataset) 117 | self.dataset = dataset 118 | self.rank = rank 119 | self.num_replicas = num_replicas 120 | self.shuffle = shuffle 121 | self.seed = seed 122 | self.window_size = window_size 123 | 124 | def __iter__(self): 125 | order = np.arange(len(self.dataset)) 126 | rnd = None 127 | window = 0 128 | if self.shuffle: 129 | rnd = np.random.RandomState(self.seed) 130 | rnd.shuffle(order) 131 | window = int(np.rint(order.size * self.window_size)) 132 | 133 | idx = 0 134 | while True: 135 | i = idx % order.size 136 | if idx % self.num_replicas == self.rank: 137 | yield order[i] 138 | if window >= 2: 139 | j = (i - rnd.randint(window)) % order.size 140 | order[i], order[j] = order[j], order[i] 141 | idx += 1 142 | 143 | #---------------------------------------------------------------------------- 144 | # Utilities for operating with torch.nn.Module parameters and buffers. 
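# Illustrative usage (a sketch; `net` and `ema` stand for two structurally matching nn.Modules):
#   copy_params_and_buffers(src_module=net, dst_module=ema, require_all=False)
# copies each parameter/buffer of `net` into the same-named tensor of `ema` under torch.no_grad().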
145 | 146 | def params_and_buffers(module): 147 | assert isinstance(module, torch.nn.Module) 148 | return list(module.parameters()) + list(module.buffers()) 149 | 150 | def named_params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.named_parameters()) + list(module.named_buffers()) 153 | 154 | @torch.no_grad() 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name]) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 
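# Tensors are tracked by id() below, so parameters shared across submodules are counted only once
# in the reported totals.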
221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numbers 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as F 6 | from torchvision.datasets import CIFAR10, ImageFolder 7 | from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS 8 | from datasets.celeba import CelebA 9 | from datasets.ffhq import FFHQ 10 | from datasets.lsun import LSUN 11 | from torch.utils.data import Subset 12 | import numpy as np 13 | from PIL import Image 14 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union 15 | 16 | 17 | class Crop(object): 18 | def __init__(self, x1, x2, y1, y2): 19 | self.x1 = x1 20 | self.x2 = x2 21 | self.y1 = y1 22 | self.y2 = y2 23 | 24 | def __call__(self, img): 25 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 26 | 27 | def __repr__(self): 28 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 29 | self.x1, self.x2, self.y1, self.y2 30 | ) 31 | 32 | class ImageDataset64(ImageFolder): 33 | def __init__(self, 34 | root: str, 35 | ): 36 | super().__init__( 37 | root=root 38 | ) 39 | self.resolution = 64 40 | 41 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 42 | """ 43 | Args: 44 | index (int): Index 45 | 46 | 
Returns: 47 | tuple: (sample, target) where target is class_index of the target class. 48 | """ 49 | path, target = self.samples[index] 50 | pil_image = self.loader(path) 51 | # if self.transform is not None: 52 | # sample = self.transform(sample) 53 | 54 | # We are not on a new enough PIL to support the `reducing_gap` 55 | # argument, which uses BOX downsampling at powers of two first. 56 | # Thus, we do it by hand to improve downsample quality. 57 | while min(*pil_image.size) >= 2 * self.resolution: 58 | pil_image = pil_image.resize( 59 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 60 | ) 61 | 62 | scale = self.resolution / min(*pil_image.size) 63 | pil_image = pil_image.resize( 64 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 65 | ) 66 | 67 | arr = np.array(pil_image.convert("RGB")) 68 | crop_y = (arr.shape[0] - self.resolution) // 2 69 | crop_x = (arr.shape[1] - self.resolution) // 2 70 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution] 71 | arr = arr.astype(np.float32) / 255. 72 | 73 | if self.target_transform is not None: 74 | target = self.target_transform(target) 75 | 76 | return np.transpose(arr, [2, 0, 1]), target 77 | 78 | def get_dataset(args, config): 79 | if config.data.random_flip is False: 80 | tran_transform = test_transform = transforms.Compose( 81 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 82 | ) 83 | else: 84 | tran_transform = transforms.Compose( 85 | [ 86 | transforms.Resize(config.data.image_size), 87 | transforms.RandomHorizontalFlip(p=0.5), 88 | transforms.ToTensor(), 89 | ] 90 | ) 91 | test_transform = transforms.Compose( 92 | [transforms.Resize(config.data.image_size), transforms.ToTensor()] 93 | ) 94 | 95 | if config.data.dataset == "CIFAR10": 96 | dataset = CIFAR10( 97 | os.path.join(args.exp, "datasets", "cifar10"), 98 | train=True, 99 | download=True, 100 | transform=tran_transform, 101 | ) 102 | test_dataset = CIFAR10( 103 | os.path.join(args.exp, "datasets", "cifar10_test"), 104 | train=False, 105 | download=True, 106 | transform=test_transform, 107 | ) 108 | 109 | elif config.data.dataset == "CELEBA": 110 | cx = 89 111 | cy = 121 112 | x1 = cy - 64 113 | x2 = cy + 64 114 | y1 = cx - 64 115 | y2 = cx + 64 116 | if config.data.random_flip: 117 | dataset = CelebA( 118 | root=os.path.join(args.exp, "datasets", "celeba"), 119 | split="train", 120 | transform=transforms.Compose( 121 | [ 122 | Crop(x1, x2, y1, y2), 123 | transforms.Resize(config.data.image_size), 124 | transforms.RandomHorizontalFlip(), 125 | transforms.ToTensor(), 126 | ] 127 | ), 128 | download=True, 129 | ) 130 | else: 131 | dataset = CelebA( 132 | root=os.path.join(args.exp, "datasets", "celeba"), 133 | split="train", 134 | transform=transforms.Compose( 135 | [ 136 | Crop(x1, x2, y1, y2), 137 | transforms.Resize(config.data.image_size), 138 | transforms.ToTensor(), 139 | ] 140 | ), 141 | download=True, 142 | ) 143 | 144 | test_dataset = CelebA( 145 | root=os.path.join(args.exp, "datasets", "celeba"), 146 | split="test", 147 | transform=transforms.Compose( 148 | [ 149 | Crop(x1, x2, y1, y2), 150 | transforms.Resize(config.data.image_size), 151 | transforms.ToTensor(), 152 | ] 153 | ), 154 | download=True, 155 | ) 156 | 157 | elif config.data.dataset == "LSUN": 158 | train_folder = "{}_train".format(config.data.category) 159 | val_folder = "{}_val".format(config.data.category) 160 | if config.data.random_flip: 161 | dataset = LSUN( 162 | root=config.data.root, 163 | classes=[train_folder], 164 | 
transform=transforms.Compose( 165 | [ 166 | transforms.Resize(config.data.image_size), 167 | transforms.CenterCrop(config.data.image_size), 168 | transforms.RandomHorizontalFlip(p=0.5), 169 | transforms.ToTensor(), 170 | ] 171 | ), 172 | ) 173 | else: 174 | dataset = LSUN( 175 | root=config.data.root, 176 | classes=[train_folder], 177 | transform=transforms.Compose( 178 | [ 179 | transforms.Resize(config.data.image_size), 180 | transforms.CenterCrop(config.data.image_size), 181 | transforms.ToTensor(), 182 | ] 183 | ), 184 | ) 185 | 186 | test_dataset = None #LSUN( 187 | # root=config.data.root, 188 | # classes=[val_folder], 189 | # transform=transforms.Compose( 190 | # [ 191 | # transforms.Resize(config.data.image_size), 192 | # transforms.CenterCrop(config.data.image_size), 193 | # transforms.ToTensor(), 194 | # ] 195 | # ), 196 | # ) 197 | 198 | elif config.data.dataset == "FFHQ": 199 | if config.data.random_flip: 200 | dataset = FFHQ( 201 | path=os.path.join(args.exp, "datasets", "FFHQ"), 202 | transform=transforms.Compose( 203 | [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()] 204 | ), 205 | resolution=config.data.image_size, 206 | ) 207 | else: 208 | dataset = FFHQ( 209 | path=os.path.join(args.exp, "datasets", "FFHQ"), 210 | transform=transforms.ToTensor(), 211 | resolution=config.data.image_size, 212 | ) 213 | 214 | num_items = len(dataset) 215 | indices = list(range(num_items)) 216 | random_state = np.random.get_state() 217 | np.random.seed(2019) 218 | np.random.shuffle(indices) 219 | np.random.set_state(random_state) 220 | train_indices, test_indices = ( 221 | indices[: int(num_items * 0.9)], 222 | indices[int(num_items * 0.9) :], 223 | ) 224 | test_dataset = Subset(dataset, test_indices) 225 | dataset = Subset(dataset, train_indices) 226 | 227 | elif config.data.dataset == "IMAGENET64": 228 | # if config.data.loader_type == 'custom': 229 | # from datasets.imagenet64 import ImageNetDownSample 230 | # if config.data.random_flip: 231 | # dataset = ImageNetDownSample( 232 | # root=config.data.root, 233 | # transform=transforms.Compose( 234 | # [ 235 | # transforms.RandomHorizontalFlip(p=0.5), 236 | # transforms.ToTensor(), 237 | # ] 238 | # ), 239 | # ) 240 | # else: 241 | # dataset = ImageNetDownSample( 242 | # root=config.data.root, 243 | # transform=transforms.Compose( 244 | # [ 245 | # transforms.ToTensor(), 246 | # ] 247 | # ), 248 | # ) 249 | # test_dataset = None 250 | # else: 251 | train_folder = "{}/train".format(config.data.root) 252 | val_folder = "{}/val".format(config.data.root) 253 | if config.data.random_flip: 254 | dataset = ImageDataset64( 255 | root=train_folder, 256 | # transform=transforms.Compose( 257 | # [ 258 | # transforms.Resize(config.data.image_size, interpolation=transforms.InterpolationMode.BICUBIC), # 259 | # transforms.CenterCrop(config.data.image_size), 260 | # transforms.RandomHorizontalFlip(p=0.5), 261 | # transforms.ToTensor(), 262 | # ] 263 | # ), 264 | ) 265 | else: 266 | dataset = ImageDataset64( 267 | root=train_folder, 268 | # transform=transforms.Compose( 269 | # [ 270 | # transforms.Resize(config.data.image_size, interpolation=transforms.InterpolationMode.BICUBIC), #interpolation=InterpolationMode.BICUBIC 271 | # transforms.CenterCrop(config.data.image_size), 272 | # transforms.ToTensor(), 273 | # ] 274 | # ), 275 | ) 276 | 277 | test_dataset = ImageDataset64( 278 | root=val_folder, 279 | # transform=transforms.Compose( 280 | # [ 281 | # transforms.Resize(config.data.image_size, 
interpolation=transforms.InterpolationMode.BICUBIC), #interpolation=InterpolationMode.BICUBIC 282 | # transforms.CenterCrop(config.data.image_size), 283 | # transforms.ToTensor(), 284 | # ] 285 | # ), 286 | ) 287 | 288 | else: 289 | dataset, test_dataset = None, None 290 | 291 | return dataset, test_dataset 292 | 293 | 294 | def logit_transform(image, lam=1e-6): 295 | image = lam + (1 - 2 * lam) * image 296 | return torch.log(image) - torch.log1p(-image) 297 | 298 | 299 | def data_transform(config, X): 300 | if config.data.uniform_dequantization: 301 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 302 | if config.data.gaussian_dequantization: 303 | X = X + torch.randn_like(X) * 0.01 304 | 305 | if config.data.rescaled: 306 | X = 2 * X - 1.0 307 | elif config.data.logit_transform: 308 | X = logit_transform(X) 309 | 310 | if hasattr(config, "image_mean"): 311 | return X - config.image_mean.to(X.device)[None, ...] 312 | 313 | return X 314 | 315 | 316 | def inverse_data_transform(config, X): 317 | if hasattr(config, "image_mean"): 318 | X = X + config.image_mean.to(X.device)[None, ...] 319 | 320 | if config.data.logit_transform: 321 | X = torch.sigmoid(X) 322 | elif config.data.rescaled: 323 | X = (X + 1.0) / 2.0 324 | 325 | return torch.clamp(X, 0.0, 1.0) 326 | -------------------------------------------------------------------------------- /edm/training/training_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Main training loop.""" 9 | 10 | import os 11 | import time 12 | import copy 13 | import json 14 | import pickle 15 | import psutil 16 | import numpy as np 17 | import torch 18 | import dnnlib 19 | from torch_utils import distributed as dist 20 | from torch_utils import training_stats 21 | from torch_utils import misc 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def training_loop( 26 | run_dir = '.', # Output directory. 27 | dataset_kwargs = {}, # Options for training set. 28 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. 29 | network_kwargs = {}, # Options for model and preconditioning. 30 | loss_kwargs = {}, # Options for loss function. 31 | optimizer_kwargs = {}, # Options for optimizer. 32 | augment_kwargs = None, # Options for augmentation pipeline, None = disable. 33 | seed = 0, # Global random seed. 34 | batch_size = 512, # Total batch size for one training iteration. 35 | batch_gpu = None, # Limit batch size per GPU, None = no limit. 36 | total_kimg = 200000, # Training duration, measured in thousands of training images. 37 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights. 38 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup. 39 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration. 40 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows. 41 | kimg_per_tick = 50, # Interval of progress prints. 42 | snapshot_ticks = 50, # How often to save network snapshots, None = disable. 43 | state_dump_ticks = 500, # How often to dump training state, None = disable. 
44 | resume_pkl = None, # Start from the given network snapshot, None = random initialization. 45 | resume_state_dump = None, # Start from the given training state, None = reset training state. 46 | resume_kimg = 0, # Start from the given training progress. 47 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 48 | device = torch.device('cuda'), 49 | reg_on_mean = False, 50 | ): 51 | # Initialize. 52 | start_time = time.time() 53 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31)) 54 | torch.manual_seed(np.random.randint(1 << 31)) 55 | torch.backends.cudnn.benchmark = cudnn_benchmark 56 | torch.backends.cudnn.allow_tf32 = False 57 | torch.backends.cuda.matmul.allow_tf32 = False 58 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 59 | 60 | # Select batch size per GPU. 61 | batch_gpu_total = batch_size // dist.get_world_size() 62 | if batch_gpu is None or batch_gpu > batch_gpu_total: 63 | batch_gpu = batch_gpu_total 64 | num_accumulation_rounds = batch_gpu_total // batch_gpu 65 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size() 66 | 67 | # Load dataset. 68 | dist.print0('Loading dataset...') 69 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset 70 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed) 71 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs)) 72 | 73 | # Construct network. 74 | dist.print0('Constructing network...') 75 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim) 76 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module 77 | if reg_on_mean: 78 | from training.networks import ModelWrapper 79 | net = ModelWrapper(net, num_layers=4, channel_mult_emb=8, embedding_type='positional') 80 | 81 | net.train().requires_grad_(True).to(device) 82 | if dist.get_rank() == 0: 83 | with torch.no_grad(): 84 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device) 85 | sigma = torch.ones([batch_gpu], device=device) 86 | labels = torch.zeros([batch_gpu, net.label_dim], device=device) 87 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2) 88 | 89 | # Setup optimizer. 90 | dist.print0('Setting up optimizer...') 91 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss 92 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer 93 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe 94 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False) 95 | ema = copy.deepcopy(net).eval().requires_grad_(False) 96 | 97 | # Resume training from previous snapshot. 
98 | if resume_pkl is not None: 99 | dist.print0(f'Loading network weights from "{resume_pkl}"...') 100 | if dist.get_rank() != 0: 101 | torch.distributed.barrier() # rank 0 goes first 102 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f: 103 | data = pickle.load(f) 104 | if dist.get_rank() == 0: 105 | torch.distributed.barrier() # other ranks follow 106 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net.model, require_all=False) 107 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema.model, require_all=False) 108 | del data # conserve memory 109 | if resume_state_dump: 110 | dist.print0(f'Loading training state from "{resume_state_dump}"...') 111 | data = torch.load(resume_state_dump, map_location=torch.device('cpu')) 112 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net.model, require_all=True) 113 | optimizer.load_state_dict(data['optimizer_state']) 114 | del data # conserve memory 115 | 116 | # Train. 117 | dist.print0(f'Training for {total_kimg} kimg...') 118 | dist.print0() 119 | cur_nimg = resume_kimg * 1000 120 | cur_tick = 0 121 | tick_start_nimg = cur_nimg 122 | tick_start_time = time.time() 123 | maintenance_time = tick_start_time - start_time 124 | dist.update_progress(cur_nimg // 1000, total_kimg) 125 | stats_jsonl = None 126 | while True: 127 | 128 | # Accumulate gradients. 129 | optimizer.zero_grad(set_to_none=True) 130 | for round_idx in range(num_accumulation_rounds): 131 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)): 132 | images, labels = next(dataset_iterator) 133 | images = images.to(device).to(torch.float32) / 127.5 - 1 134 | labels = labels.to(device) 135 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe) 136 | if reg_on_mean: 137 | loss, loss3 = loss #, loss2 138 | training_stats.report('Loss/loss', loss) 139 | # training_stats.report('Loss/loss2', loss2) 140 | training_stats.report('Loss/loss3', loss3) 141 | (loss + loss3).sum().mul(loss_scaling / batch_gpu_total).backward() # + loss2 142 | else: 143 | training_stats.report('Loss/loss', loss) 144 | loss.sum().mul(loss_scaling / batch_gpu_total).backward() 145 | 146 | # Update weights. 147 | for g in optimizer.param_groups: 148 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1) 149 | for param in net.parameters(): 150 | if param.grad is not None: 151 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) 152 | optimizer.step() 153 | 154 | # Update EMA. 155 | ema_halflife_nimg = ema_halflife_kimg * 1000 156 | if ema_rampup_ratio is not None: 157 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) 158 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) 159 | for p_ema, p_net in zip(ema.parameters(), net.parameters()): 160 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) 161 | 162 | # Perform maintenance tasks once per tick. 163 | cur_nimg += batch_size 164 | done = (cur_nimg >= total_kimg * 1000) 165 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): 166 | continue 167 | 168 | # Print status line, accumulating the same information in training_stats. 
169 | tick_end_time = time.time() 170 | fields = [] 171 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] 172 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] 173 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] 174 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] 175 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] 176 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] 177 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] 178 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] 179 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] 180 | torch.cuda.reset_peak_memory_stats() 181 | dist.print0(' '.join(fields)) 182 | 183 | # Check for abort. 184 | if (not done) and dist.should_stop(): 185 | done = True 186 | dist.print0() 187 | dist.print0('Aborting...') 188 | 189 | # Save network snapshot. 190 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0): 191 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs)) 192 | for key, value in data.items(): 193 | if isinstance(value, torch.nn.Module): 194 | value = copy.deepcopy(value).eval().requires_grad_(False) 195 | misc.check_ddp_consistency(value) 196 | data[key] = value.cpu() 197 | del value # conserve memory 198 | if dist.get_rank() == 0: 199 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f: 200 | pickle.dump(data, f) 201 | del data # conserve memory 202 | 203 | # Save full dump of the training state. 204 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0: 205 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')) 206 | 207 | # Update logs. 208 | training_stats.default_collector.update() 209 | if dist.get_rank() == 0: 210 | if stats_jsonl is None: 211 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at') 212 | stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n') 213 | stats_jsonl.flush() 214 | dist.update_progress(cur_nimg // 1000, total_kimg) 215 | 216 | # Update state. 217 | cur_tick += 1 218 | tick_start_nimg = cur_nimg 219 | tick_start_time = time.time() 220 | maintenance_time = tick_start_time - tick_end_time 221 | if done: 222 | break 223 | 224 | # Done. 
225 | dist.print0() 226 | dist.print0('Exiting...') 227 | 228 | #---------------------------------------------------------------------------- 229 | -------------------------------------------------------------------------------- /evaluate/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling features 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=(DEFAULT_BLOCK_INDEX,), 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results.
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = _inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def _inception_v3(*args, **kwargs): 167 | """Wraps `torchvision.models.inception_v3` 168 | 169 | Skips default weight inititialization if supported by torchvision version. 170 | See https://github.com/mseitzer/pytorch-fid/issues/28. 171 | """ 172 | try: 173 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 174 | except ValueError: 175 | # Just a caution against weird version strings 176 | version = (0,) 177 | 178 | if version >= (0, 6): 179 | kwargs['init_weights'] = False 180 | 181 | return torchvision.models.inception_v3(*args, **kwargs) 182 | 183 | 184 | def fid_inception_v3(): 185 | """Build pretrained Inception model for FID computation 186 | 187 | The Inception model for FID computation uses a different set of weights 188 | and has a slightly different structure than torchvision's Inception. 
189 | 190 | This method first constructs torchvision's Inception and then patches the 191 | necessary parts that are different in the FID Inception model. 192 | """ 193 | inception = _inception_v3(num_classes=1008, 194 | aux_logits=False, 195 | pretrained=False) 196 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 197 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 198 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 199 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 200 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 201 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 203 | inception.Mixed_7b = FIDInceptionE_1(1280) 204 | inception.Mixed_7c = FIDInceptionE_2(2048) 205 | 206 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 207 | inception.load_state_dict(state_dict) 208 | return inception 209 | 210 | 211 | class FIDInceptionA(torchvision.models.inception.InceptionA): 212 | """InceptionA block patched for FID computation""" 213 | def __init__(self, in_channels, pool_features): 214 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 215 | 216 | def forward(self, x): 217 | branch1x1 = self.branch1x1(x) 218 | 219 | branch5x5 = self.branch5x5_1(x) 220 | branch5x5 = self.branch5x5_2(branch5x5) 221 | 222 | branch3x3dbl = self.branch3x3dbl_1(x) 223 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 224 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 225 | 226 | # Patch: Tensorflow's average pool does not use the padded zeros in 227 | # its average calculation 228 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 229 | count_include_pad=False) 230 | branch_pool = self.branch_pool(branch_pool) 231 | 232 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 233 | return torch.cat(outputs, 1) 234 | 235 | 236 | class FIDInceptionC(torchvision.models.inception.InceptionC): 237 | """InceptionC block patched for FID computation""" 238 | def __init__(self, in_channels, channels_7x7): 239 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 240 | 241 | def forward(self, x): 242 | branch1x1 = self.branch1x1(x) 243 | 244 | branch7x7 = self.branch7x7_1(x) 245 | branch7x7 = self.branch7x7_2(branch7x7) 246 | branch7x7 = self.branch7x7_3(branch7x7) 247 | 248 | branch7x7dbl = self.branch7x7dbl_1(x) 249 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 250 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 251 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 252 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 253 | 254 | # Patch: Tensorflow's average pool does not use the padded zeros in 255 | # its average calculation 256 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 257 | count_include_pad=False) 258 | branch_pool = self.branch_pool(branch_pool) 259 | 260 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 261 | return torch.cat(outputs, 1) 262 | 263 | 264 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 265 | """First InceptionE block patched for FID computation""" 266 | def __init__(self, in_channels): 267 | super(FIDInceptionE_1, self).__init__(in_channels) 268 | 269 | def forward(self, x): 270 | branch1x1 = self.branch1x1(x) 271 | 272 | branch3x3 = self.branch3x3_1(x) 273 | branch3x3 = [ 274 | self.branch3x3_2a(branch3x3), 275 | self.branch3x3_2b(branch3x3), 276 | ] 277 | branch3x3 = torch.cat(branch3x3, 1) 278 | 279 | branch3x3dbl = 
self.branch3x3dbl_1(x) 280 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 281 | branch3x3dbl = [ 282 | self.branch3x3dbl_3a(branch3x3dbl), 283 | self.branch3x3dbl_3b(branch3x3dbl), 284 | ] 285 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 286 | 287 | # Patch: Tensorflow's average pool does not use the padded zeros in 288 | # its average calculation 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 290 | count_include_pad=False) 291 | branch_pool = self.branch_pool(branch_pool) 292 | 293 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 294 | return torch.cat(outputs, 1) 295 | 296 | 297 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 298 | """Second InceptionE block patched for FID computation""" 299 | def __init__(self, in_channels): 300 | super(FIDInceptionE_2, self).__init__(in_channels) 301 | 302 | def forward(self, x): 303 | branch1x1 = self.branch1x1(x) 304 | 305 | branch3x3 = self.branch3x3_1(x) 306 | branch3x3 = [ 307 | self.branch3x3_2a(branch3x3), 308 | self.branch3x3_2b(branch3x3), 309 | ] 310 | branch3x3 = torch.cat(branch3x3, 1) 311 | 312 | branch3x3dbl = self.branch3x3dbl_1(x) 313 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 314 | branch3x3dbl = [ 315 | self.branch3x3dbl_3a(branch3x3dbl), 316 | self.branch3x3dbl_3b(branch3x3dbl), 317 | ] 318 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 319 | 320 | # Patch: The FID Inception model uses max pooling instead of average 321 | # pooling. This is likely an error in this specific Inception 322 | # implementation, as other Inception models use average pooling here 323 | # (which matches the description in the paper). 324 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 325 | branch_pool = self.branch_pool(branch_pool) 326 | 327 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 328 | return torch.cat(outputs, 1) 329 | --------------------------------------------------------------------------------
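
The snippets below are editorial sketches, not files from the repository. First, a minimal, hedged example of how the InceptionV3 wrapper defined in evaluate/inception.py above is typically used to turn a batch of images into the (mu, sigma) statistics that enter the Fréchet Inception Distance. The helper names activation_stats and frechet_distance are illustrative only; the repository's own FID evaluation may differ in batching and numerical safeguards.

import numpy as np
import torch
from scipy import linalg

from evaluate.inception import InceptionV3


def activation_stats(images, device='cuda', batch_size=50):
    # images: float tensor of shape (N, 3, H, W) with values in [0, 1];
    # resize_input=True (the default) handles resolutions other than 299x299.
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]  # final average pooling
    model = InceptionV3([block_idx]).to(device).eval()
    feats = []
    with torch.no_grad():
        for i in range(0, images.shape[0], batch_size):
            pred = model(images[i:i + batch_size].to(device))[0]  # (B, 2048, 1, 1)
            feats.append(pred.squeeze(-1).squeeze(-1).cpu().numpy())
    feats = np.concatenate(feats, axis=0)
    return feats.mean(axis=0), np.cov(feats, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2):
    # Frechet distance between two Gaussians fitted to Inception features.
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

In practice one pair of statistics comes from generated samples and the other from a reference dataset, and the scalar returned by frechet_distance is the FID.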
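
Second, a small numeric sketch of how edm/training/training_loop.py derives its per-step EMA decay: the half-life is expressed in images rather than optimizer steps, so the decay adapts to the batch size, and early in training the half-life is capped at ema_rampup_ratio times the number of images seen so far. The specific values below are assumed, not taken from any config; only the formulas mirror the training loop.

batch_size = 512
ema_halflife_kimg = 500
ema_rampup_ratio = 0.05
cur_nimg = 2_000_000  # pretend we are 2000 kimg into training

ema_halflife_nimg = ema_halflife_kimg * 1000
if ema_rampup_ratio is not None:
    ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)

ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
print(ema_beta)  # ~0.9965, applied as p_ema <- p_net + ema_beta * (p_ema - p_net)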
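
Finally, a hedged sketch of the pixel-space mapping implemented by data_transform and inverse_data_transform in datasets/__init__.py. The SimpleNamespace below is only a stand-in for the YAML-backed config objects the repository actually passes around; with rescaled=True the forward map is X -> 2X - 1 and the inverse map is X -> clamp((X + 1) / 2, 0, 1), so the round trip recovers the original images up to floating-point error.

import torch
from types import SimpleNamespace

from datasets import data_transform, inverse_data_transform

# Stand-in config; real runs build this from the repository's YAML config files.
cfg = SimpleNamespace(data=SimpleNamespace(
    uniform_dequantization=False,
    gaussian_dequantization=False,
    rescaled=True,
    logit_transform=False,
))

X = torch.rand(4, 3, 32, 32)                   # images in [0, 1]
X_model = data_transform(cfg, X)               # roughly in [-1, 1]
X_back = inverse_data_transform(cfg, X_model)  # back in [0, 1]
assert torch.allclose(X, X_back, atol=1e-6)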