├── figures
│   ├── figure1.PNG
│   └── figure2.PNG
├── guided_diffusion
│   ├── __init__.py
│   ├── script_util.py
│   ├── dist_util.py
│   ├── losses.py
│   ├── respace.py
│   ├── resample.py
│   ├── nn.py
│   ├── fp16_util.py
│   ├── train_util.py
│   └── logger.py
├── torch_utils
│   ├── __init__.py
│   ├── distributed.py
│   ├── persistence.py
│   ├── training_stats.py
│   └── misc.py
├── training
│   ├── __init__.py
│   ├── loss.py
│   ├── dataset.py
│   └── training_loop.py
├── environment.yml
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── Dockerfile
├── example.py
├── README.md
├── train_classifier.py
├── fid.py
├── classifier_lib.py
├── LICENSE
├── train.py
└── generate.py
/figures/figure1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aailabkaist/TIW-DSM/HEAD/figures/figure1.PNG
--------------------------------------------------------------------------------
/figures/figure2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aailabkaist/TIW-DSM/HEAD/figures/figure2.PNG
--------------------------------------------------------------------------------
/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: edm
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 |   - python>=3.8,<3.10  # package build failures on 3.10
7 | - pip
8 | - numpy>=1.20
9 | - click>=8.0
10 | - pillow>=8.3.1
11 | - scipy>=1.7.1
12 | - pytorch=1.12.1
13 | - psutil
14 | - requests
15 | - tqdm
16 | - imageio
17 | - pip:
18 | - imageio-ffmpeg>=0.4.3
19 | - pyspng
20 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | FROM nvcr.io/nvidia/pytorch:22.10-py3
9 |
10 | ENV PYTHONDONTWRITEBYTECODE 1
11 | ENV PYTHONUNBUFFERED 1
12 |
13 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
14 |
15 | WORKDIR /workspace
16 |
17 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
18 | ENTRYPOINT ["/entry.sh"]
19 |
--------------------------------------------------------------------------------
/guided_diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | from .unet import EncoderUNetModel
2 |
3 | NUM_CLASSES = 1000
4 | def create_classifier(
5 | image_size,
6 | classifier_use_fp16,
7 | classifier_width,
8 | classifier_depth,
9 | classifier_attention_resolutions,
10 | classifier_use_scale_shift_norm,
11 | classifier_resblock_updown,
12 | classifier_pool,
13 | out_channels,
14 | in_channels = 3,
15 | condition=False,
16 | ):
17 | if image_size == 512:
18 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
19 | elif image_size == 256:
20 | channel_mult = (1, 1, 2, 2, 4, 4)
21 | elif image_size == 128:
22 | channel_mult = (1, 1, 2, 3, 4)
23 | elif image_size == 64:
24 | channel_mult = (1, 2, 3, 4)
25 | elif image_size == 32:
26 | channel_mult = (1, 2, 4)
27 | elif image_size == 16:
28 | channel_mult = (1, 2)
29 | elif image_size == 8:
30 | channel_mult = (1,)
31 | else:
32 | raise ValueError(f"unsupported image size: {image_size}")
33 |
34 | attention_ds = []
35 | for res in classifier_attention_resolutions.split(","):
36 | attention_ds.append(image_size // int(res))
37 |
38 | return EncoderUNetModel(
39 | image_size=image_size,
40 | in_channels=in_channels,
41 | model_channels=classifier_width,
42 | out_channels=out_channels,
43 | num_res_blocks=classifier_depth,
44 | attention_resolutions=tuple(attention_ds),
45 | channel_mult=channel_mult,
46 | use_fp16=classifier_use_fp16,
47 | num_head_channels=64,
48 | use_scale_shift_norm=classifier_use_scale_shift_norm,
49 | resblock_updown=classifier_resblock_updown,
50 | pool=classifier_pool,
51 | condition=condition,
52 | )
53 |
54 |
--------------------------------------------------------------------------------
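A minimal usage sketch of `create_classifier`, assuming the `unet` module imported at the top of the file is present in the package. The option values below are illustrative defaults, not this repo's training configuration:

```python
# Illustrative only: builds a small 32x32 encoder classifier.
import torch
from guided_diffusion.script_util import create_classifier

classifier = create_classifier(
    image_size=32,
    classifier_use_fp16=False,
    classifier_width=128,
    classifier_depth=2,                       # residual blocks per resolution
    classifier_attention_resolutions="16,8",  # attention at 16x16 and 8x8
    classifier_use_scale_shift_norm=True,
    classifier_resblock_updown=True,
    classifier_pool="attention",
    out_channels=10,
)
x = torch.randn(4, 3, 32, 32)        # batch of (noisy) images
t = torch.randint(0, 1000, (4,))     # diffusion timesteps
logits = classifier(x, timesteps=t)  # -> [4, 10]
```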
/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 | if 'MASTER_PORT' not in os.environ:
18 | os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
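A minimal sketch of how these helpers are typically driven by the entry scripts, assuming a `torchrun` launch as in the README:

```python
# Launch with e.g.: torchrun --standalone --nproc_per_node=2 this_script.py
import torch
from torch_utils import distributed as dist

dist.init()  # env:// process group + per-rank CUDA device
dist.print0(f'world size: {dist.get_world_size()}')  # printed by rank 0 only
work = torch.arange(dist.get_rank(), dist.get_rank() + 4, device='cuda')
torch.distributed.all_reduce(work)  # standard collectives work as usual
dist.print0(work)
```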
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist(args):
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE + args.initial_gpu}"
28 |
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 |         return th.device("cuda")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 |
74 | return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """
81 | for p in params:
82 | with th.no_grad():
83 | dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 |     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
88 |     try:
89 |         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
90 |         s.bind(("", 0))
91 |         return s.getsockname()[1]
92 |     finally:
93 |         s.close()
94 |
--------------------------------------------------------------------------------
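In contrast to `torch_utils.distributed`, this module expects an MPI launch. A hedged sketch (the `initial_gpu` attribute and checkpoint path are illustrative, not defined in this file):

```python
# Launch with e.g.: mpiexec -n 4 python this_script.py
import argparse
from guided_diffusion import dist_util

args = argparse.Namespace(initial_gpu=0)  # assumed attribute consumed by setup_dist
dist_util.setup_dist(args)                # one MPI rank per GPU
# Rank 0 reads the file once; the bytes are broadcast in <2**30 chunks.
state = dist_util.load_state_dict('model.pt', map_location=dist_util.dev())
```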
/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
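A quick numeric check of `normal_kl` against the closed form KL(N(m1, e^{v1}) || N(m2, e^{v2})) = 0.5 * (v2 - v1 + e^{v1 - v2} + (m1 - m2)^2 * e^{-v2} - 1):

```python
import torch as th
from guided_diffusion.losses import normal_kl

zero = th.zeros(4)
print(normal_kl(zero, zero, zero, zero))        # identical Gaussians -> 0
print(normal_kl(zero, zero, zero + 1.0, zero))  # KL(N(0,1) || N(1,1)) = 0.5
```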
/example.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Minimal standalone example to reproduce the main results from the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import tqdm
12 | import pickle
13 | import numpy as np
14 | import torch
15 | import PIL.Image
16 | import dnnlib
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def generate_image_grid(
21 | network_pkl, dest_path,
22 | seed=0, gridw=8, gridh=8, device=torch.device('cuda'),
23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
25 | ):
26 | batch_size = gridw * gridh
27 | torch.manual_seed(seed)
28 |
29 | # Load network.
30 | print(f'Loading network from "{network_pkl}"...')
31 | with dnnlib.util.open_url(network_pkl) as f:
32 | net = pickle.load(f)['ema'].to(device)
33 |
34 | # Pick latents and labels.
35 | print(f'Generating {batch_size} images...')
36 | latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
37 | class_labels = None
38 | if net.label_dim:
39 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
40 |
41 | # Adjust noise levels based on what's supported by the network.
42 | sigma_min = max(sigma_min, net.sigma_min)
43 | sigma_max = min(sigma_max, net.sigma_max)
44 |
45 | # Time step discretization.
46 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
47 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
48 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
49 |
50 | # Main sampling loop.
51 | x_next = latents.to(torch.float64) * t_steps[0]
52 | for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1
53 | x_cur = x_next
54 |
55 | # Increase noise temporarily.
56 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
57 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
58 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)
59 |
60 | # Euler step.
61 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
62 | d_cur = (x_hat - denoised) / t_hat
63 | x_next = x_hat + (t_next - t_hat) * d_cur
64 |
65 | # Apply 2nd order correction.
66 | if i < num_steps - 1:
67 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
68 | d_prime = (x_next - denoised) / t_next
69 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
70 |
71 | # Save image grid.
72 | print(f'Saving image grid to "{dest_path}"...')
73 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8)
74 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
75 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels)
76 | image = image.cpu().numpy()
77 | PIL.Image.fromarray(image, 'RGB').save(dest_path)
78 | print('Done.')
79 |
80 | #----------------------------------------------------------------------------
81 |
82 | def main():
83 | model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained'
84 | generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35
85 | generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 2.39, NFE = 79
86 | generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79
87 | generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511
88 |
89 | #----------------------------------------------------------------------------
90 |
91 | if __name__ == "__main__":
92 | main()
93 |
94 | #----------------------------------------------------------------------------
95 |
--------------------------------------------------------------------------------
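The time step discretization above is the rho-spaced noise schedule sigma_i = (sigma_max^{1/rho} + i/(N-1) * (sigma_min^{1/rho} - sigma_max^{1/rho}))^rho, which can be reproduced standalone (constants are the defaults from `generate_image_grid`):

```python
import torch

num_steps, sigma_min, sigma_max, rho = 18, 0.002, 80, 7
i = torch.arange(num_steps, dtype=torch.float64)
t_steps = (sigma_max ** (1 / rho) + i / (num_steps - 1)
           * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
print(t_steps[0].item(), t_steps[-1].item())  # 80.0 ... 0.002
```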
/README.md:
--------------------------------------------------------------------------------
1 | ## TRAINING UNBIASED DIFFUSION MODELS FROM BIASED DATASET (TIW-DSM) (ICLR 2024)
Official PyTorch implementation of TIW-DSM.
2 |
3 |
4 |
5 | **[Yeongmin Kim](https://sites.google.com/view/yeongmin-space/%ED%99%88), [Byeonghu Na](https://sites.google.com/view/byeonghu-na), Minsang Park, [JoonHo Jang](https://sites.google.com/view/joonhojang), [Dongjun Kim](https://sites.google.com/view/dongjun-kim), [Wanmo Kang](https://sites.google.com/site/wanmokang), and [Il-Chul Moon](https://aai.kaist.ac.kr/bbs/board.php?bo_table=sub2_1&wr_id=3)**
6 |
7 | | [openreview](https://openreview.net/forum?id=39cPKijBed) | [arxiv](https://arxiv.org/abs/2403.01189) | [datasets](https://drive.google.com/drive/u/0/folders/1RakPtfp70E2BSgDM5xMBd2Om0N8ycrRK) | [checkpoints](https://drive.google.com/drive/u/0/folders/1vYLH8UNlXWZarn0IOtiPuU8FvBFqJvTP) |
8 |
9 | --------------------
10 |
11 | ## Overview
12 | ![Figure 1](figures/figure1.PNG)
13 | ![Figure 2](figures/figure2.PNG)
14 | ## Requirements
15 | The requirements for this code are the same as those outlined for [EDM](https://github.com/NVlabs/edm).
16 |
17 | ## Datasets
18 | - Download the data from [datasets](https://drive.google.com/drive/u/0/folders/1RakPtfp70E2BSgDM5xMBd2Om0N8ycrRK) and keep the same directory structure.
19 | ## Training
20 | - Download the pre-trained feature extractor from [checkpoints](https://drive.google.com/drive/u/0/folders/1vYLH8UNlXWZarn0IOtiPuU8FvBFqJvTP).
21 | ### Time-dependent discriminator
22 | - CIFAR10 LT / 5% setting
23 | ```
24 | python train_classifier.py
25 | ```
26 | - CIFAR10 LT / 10% setting
27 | ```
28 | python train_classifier.py --savedir=/checkpoints/discriminator/cifar10/unbias_1000/ --refdir=/datasets/cifar10/discriminator_training/unbias_1000/real_data.npz --real_mul=10
29 | ```
30 | - CelebA64 / 5% setting
31 | ```
32 | python train_classifier.py --feature_extractor=/checkpoints/discriminator/feature_extractor/64x64_classifier.pt --savedir=/checkpoints/discriminator/celeba64/unbias_8k/ --biasdir=/datasets/celeba64/discriminator_training/bias_162k/fake_data.npz --refdir=/datasets/celeba64/discriminator_training/unbias_8k/real_data.npz --img_resolution=64
33 | ```
34 |
35 | ### Diffusion model with TIW-DSM objective
36 | - CIFAR10 LT / 5% setting
37 | ```
38 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 train.py --arch=ddpmpp --outdir=out_dir --data=PATH/TIW-DSM/datasets/cifar10/score_training/500_10000/dataset.zip --batch=256 --tick=5 --duration=11
39 | ```
40 | - CIFAR10 LT / 10% setting
41 | ```
42 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 train.py --arch=ddpmpp --outdir=out_dir --data=PATH/TIW-DSM/datasets/cifar10/score_training/1000_10000/dataset.zip --batch=256 --tick=5 --duration=21
43 | ```
44 | - CelebA64 / 5% setting
45 | ```
46 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 train.py --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 --outdir=outdir --data=PATH/TIW-DSM/datasets/celeba64/score_training/dataset.zip --cla_path=PATH/TIW-DSM/checkpoints/discriminator/feature_extractor/64x64_classifier.pt --dis_path=PATH/TIW-DSM/checkpoints/discriminator/celeba64/unbias_8k/discriminator_9501.pt --tick=5
47 | ```
48 |
49 | ## Sampling
50 | - CIFAR10
51 | ```
52 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-49999 --batch=64 --network=TRAINED_PKL
53 | ```
54 | - CelebA64
55 | ```
56 | torchrun --standalone --nproc_per_node=2 generate.py --steps=40 --outdir=out --seeds=0-49999 --batch=64 --network=TRAINED_PKL
57 | ```
58 | ## Evaluation
59 | - CIFAR10
60 | ```
61 | python fid.py calc --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz --num=50000 --images=YOUR_SAMPLE_PATH
62 | ```
63 | - CelebA64
64 | ```
65 | python fid.py calc --ref=/PATH/TIW-DSM/datasets/celeba/evaluation/FID_stat.npz --num=50000 --images=YOUR_SAMPLE_PATH
66 | ```
67 |
68 | ## Reference
69 | If you find the code useful for your research, please consider citing
70 | ```bib
71 | @inproceedings{
72 | kim2024training,
73 | title={Training Unbiased Diffusion Models From Biased Dataset},
74 | author={Yeongmin Kim and Byeonghu Na and Minsang Park and JoonHo Jang and Dongjun Kim and Wanmo Kang and Il-chul Moon},
75 | booktitle={The Twelfth International Conference on Learning Representations},
76 | year={2024},
77 | url={https://openreview.net/forum?id=39cPKijBed}
78 | }
79 | ```
80 | This work is heavily built upon the code from
81 | - *Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. Elucidating the design space of diffusion-based generative models. Advances in Neural Information Processing Systems, 35:26565–26577, 2022.*
82 | - *Dongjun Kim\*, Yeongmin Kim\*, Se Jung Kwon, Wanmo Kang, and Il-Chul Moon. Refining generative process with discriminator guidance in score-based diffusion models. In Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pp. 16567–16598. PMLR, 23–29 Jul 2023.*
83 |
84 |
85 |
--------------------------------------------------------------------------------
/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Loss functions used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import torch
12 | from torch_utils import persistence
13 | import classifier_lib
14 | #----------------------------------------------------------------------------
15 | # Loss function corresponding to the variance preserving (VP) formulation
16 | # from the paper "Score-Based Generative Modeling through Stochastic
17 | # Differential Equations".
18 |
19 | @persistence.persistent_class
20 | class VPLoss:
21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
22 | self.beta_d = beta_d
23 | self.beta_min = beta_min
24 | self.epsilon_t = epsilon_t
25 |
26 | def __call__(self, net, images, labels, augment_pipe=None):
27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
29 | weight = 1 / sigma ** 2
30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
31 | n = torch.randn_like(y) * sigma
32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
33 | loss = weight * ((D_yn - y) ** 2)
34 | return loss
35 |
36 | def sigma(self, t):
37 | t = torch.as_tensor(t)
38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
39 |
40 | #----------------------------------------------------------------------------
41 | # Loss function corresponding to the variance exploding (VE) formulation
42 | # from the paper "Score-Based Generative Modeling through Stochastic
43 | # Differential Equations".
44 |
45 | @persistence.persistent_class
46 | class VELoss:
47 | def __init__(self, sigma_min=0.02, sigma_max=100):
48 | self.sigma_min = sigma_min
49 | self.sigma_max = sigma_max
50 |
51 | def __call__(self, net, images, labels, augment_pipe=None):
52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
54 | weight = 1 / sigma ** 2
55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
56 | n = torch.randn_like(y) * sigma
57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
58 | loss = weight * ((D_yn - y) ** 2)
59 | return loss
60 |
61 | #----------------------------------------------------------------------------
62 | # Improved loss function proposed in the paper "Elucidating the Design Space
63 | # of Diffusion-Based Generative Models" (EDM).
64 |
65 | @persistence.persistent_class
66 | class EDMLoss:
67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
68 | self.P_mean = P_mean
69 | self.P_std = P_std
70 | self.sigma_data = sigma_data
71 |
72 | def __call__(self, net, images, labels=None, augment_pipe=None):
73 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
74 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
75 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
76 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
77 | n = torch.randn_like(y) * sigma
78 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
79 | loss = weight * ((D_yn - y) ** 2)
80 | return loss
81 |
82 | #----------------------------------------------------------------------------
83 | @persistence.persistent_class
84 | class TIW_EDMLoss:
85 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
86 | self.P_mean = P_mean
87 | self.P_std = P_std
88 | self.sigma_data = sigma_data
89 |
90 | def __call__(self, vpsde, dis, net, images, labels=None, augment_pipe=None):
91 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
92 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
93 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
94 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
95 | n = torch.randn_like(y) * sigma
96 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
97 |
98 | correction, prediction = classifier_lib.get_weight_correction(dis, vpsde, y + n, sigma.flatten(), images.shape[1], 0., 1.)
99 | correction = correction.detach()
100 | prediction = prediction.detach()[:, None, None, None]
101 |
102 |         ## Please refer to Eqs. (36) and (37) in the main paper.
103 |         tiw = 2 * prediction                # time-dependent importance weight, already [B, 1, 1, 1]
104 |         y += correction                     # shift the regression target by the score correction
105 |
106 |         loss = weight * tiw * ((D_yn - y) ** 2)  # tiw broadcasts over [B, C, H, W]
107 | return loss
108 |
--------------------------------------------------------------------------------
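A hedged shape-checking sketch of how `EDMLoss` is consumed, assuming the repo's imports (`classifier_lib`, `torch_utils.persistence`) resolve cleanly; `ToyNet` is a stand-in, not the repo's EDM network wrapper:

```python
import torch
from training.loss import EDMLoss

class ToyNet:
    # Identity "denoiser" with the (x, sigma, labels, augment_labels) call
    # signature the loss expects; for shape checking only.
    def __call__(self, x, sigma, labels, augment_labels=None):
        return x

loss_fn = EDMLoss()
images = torch.randn(8, 3, 32, 32)
loss = loss_fn(ToyNet(), images)  # per-pixel weighted loss, shape [8, 3, 32, 32]
print(loss.mean().item())
```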
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there are 300 timesteps and the section counts are [10,15,20],
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t, reversed=False):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
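The docstring's example, checked concretely (assuming the `gaussian_diffusion` module this file imports is available in the package):

```python
from guided_diffusion.respace import space_timesteps

steps = space_timesteps(300, [10, 15, 20])
print(len(steps))                           # 45 retained timesteps
print(len(space_timesteps(300, "ddim50")))  # 50, evenly strided (stride 6)
```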
/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 |     timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
97 |     loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
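A small sketch of the sampler API; `diffusion` only needs a `num_timesteps` attribute here, so a stub suffices for illustration:

```python
import types
import torch as th
from guided_diffusion.resample import UniformSampler

diffusion = types.SimpleNamespace(num_timesteps=1000)
sampler = UniformSampler(diffusion)
t, w = sampler.sample(batch_size=16, device=th.device('cpu'))
print(t.shape, w.unique())  # torch.Size([16]) tensor([1.]) -- uniform weights are 1
```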
/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import numpy as np
8 | import torch as th
9 | import torch.nn as nn
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | class GaussianFourierProjection(nn.Module):
36 | """Gaussian Fourier embeddings for noise levels."""
37 |
38 | def __init__(self, embedding_size=256, scale=1.0):
39 | super().__init__()
40 | self.W = nn.Parameter(th.randn(embedding_size) * scale, requires_grad=False)
41 |
42 | def forward(self, x):
43 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
44 | return th.cat([th.sin(x_proj), th.cos(x_proj)], dim=-1)
45 |
46 |
47 | def GFP(embedding_size=256, scale=1.0):
48 | return GaussianFourierProjection(embedding_size, scale)
49 |
50 |
51 | def linear(*args, **kwargs):
52 | """
53 | Create a linear module.
54 | """
55 | return nn.Linear(*args, **kwargs)
56 |
57 |
58 | def avg_pool_nd(dims, *args, **kwargs):
59 | """
60 | Create a 1D, 2D, or 3D average pooling module.
61 | """
62 | if dims == 1:
63 | return nn.AvgPool1d(*args, **kwargs)
64 | elif dims == 2:
65 | return nn.AvgPool2d(*args, **kwargs)
66 | elif dims == 3:
67 | return nn.AvgPool3d(*args, **kwargs)
68 | raise ValueError(f"unsupported dimensions: {dims}")
69 |
70 |
71 | def update_ema(target_params, source_params, rate=0.99):
72 | """
73 | Update target parameters to be closer to those of source parameters using
74 | an exponential moving average.
75 |
76 | :param target_params: the target parameter sequence.
77 | :param source_params: the source parameter sequence.
78 | :param rate: the EMA rate (closer to 1 means slower).
79 | """
80 | for targ, src in zip(target_params, source_params):
81 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
82 |
83 |
84 | def zero_module(module):
85 | """
86 | Zero out the parameters of a module and return it.
87 | """
88 | for p in module.parameters():
89 | p.detach().zero_()
90 | return module
91 |
92 |
93 | def scale_module(module, scale):
94 | """
95 | Scale the parameters of a module and return it.
96 | """
97 | for p in module.parameters():
98 | p.detach().mul_(scale)
99 | return module
100 |
101 |
102 | def mean_flat(tensor):
103 | """
104 | Take the mean over all non-batch dimensions.
105 | """
106 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
107 |
108 |
109 | def normalization(channels):
110 | """
111 | Make a standard normalization layer.
112 |
113 | :param channels: number of input channels.
114 | :return: an nn.Module for normalization.
115 | """
116 | return GroupNorm32(32, channels)
117 |
118 |
119 | def timestep_embedding(timesteps, dim, max_period=10000):
120 | """
121 | Create sinusoidal timestep embeddings.
122 |
123 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
124 | These may be fractional.
125 | :param dim: the dimension of the output.
126 | :param max_period: controls the minimum frequency of the embeddings.
127 | :return: an [N x dim] Tensor of positional embeddings.
128 | """
129 | half = dim // 2
130 | freqs = th.exp(
131 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
132 | ).to(device=timesteps.device)
133 | args = timesteps[:, None].float() * freqs[None]
134 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
135 | if dim % 2:
136 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
137 | return embedding
138 |
139 |
140 | def checkpoint(func, inputs, params, flag):
141 | """
142 | Evaluate a function without caching intermediate activations, allowing for
143 | reduced memory at the expense of extra compute in the backward pass.
144 |
145 | :param func: the function to evaluate.
146 | :param inputs: the argument sequence to pass to `func`.
147 | :param params: a sequence of parameters `func` depends on but does not
148 | explicitly take as arguments.
149 | :param flag: if False, disable gradient checkpointing.
150 | """
151 | if flag:
152 | args = tuple(inputs) + tuple(params)
153 | return CheckpointFunction.apply(func, len(inputs), *args)
154 | else:
155 | return func(*inputs)
156 |
157 |
158 | class CheckpointFunction(th.autograd.Function):
159 | @staticmethod
160 | def forward(ctx, run_function, length, *args):
161 | ctx.run_function = run_function
162 | ctx.input_tensors = list(args[:length])
163 | ctx.input_params = list(args[length:])
164 | with th.no_grad():
165 | output_tensors = ctx.run_function(*ctx.input_tensors)
166 | return output_tensors
167 |
168 | @staticmethod
169 | def backward(ctx, *output_grads):
170 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
171 | with th.enable_grad():
172 | # Fixes a bug where the first op in run_function modifies the
173 | # Tensor storage in place, which is not allowed for detach()'d
174 | # Tensors.
175 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
176 | output_tensors = ctx.run_function(*shallow_copies)
177 | input_grads = th.autograd.grad(
178 | output_tensors,
179 | ctx.input_tensors + ctx.input_params,
180 | output_grads,
181 | allow_unused=True,
182 | )
183 | del ctx.input_tensors
184 | del ctx.input_params
185 | del output_tensors
186 | return (None, None) + input_grads
187 |
--------------------------------------------------------------------------------
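A shape check for `timestep_embedding`: N (possibly fractional) timesteps map to an [N, dim] sinusoidal embedding, cosines in the first half and sines in the second, zero-padded when `dim` is odd:

```python
import torch as th
from guided_diffusion.nn import timestep_embedding

t = th.arange(4)
print(timestep_embedding(t, dim=128).shape)  # torch.Size([4, 128])
print(timestep_embedding(t, dim=129).shape)  # torch.Size([4, 129]) (zero-padded)
```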
/train_classifier.py:
--------------------------------------------------------------------------------
1 | import click
2 | import os
3 | import classifier_lib
4 | import torch
5 | import numpy as np
6 | import dnnlib
7 | import torchvision.transforms as transforms
8 | import torch.utils.data as data
9 | import random
10 |
11 | class BasicDataset(data.Dataset):
12 | def __init__(self, x_np, y_np, transform=transforms.ToTensor()):
13 | super(BasicDataset, self).__init__()
14 |
15 | self.x = x_np
16 | self.y = y_np
17 | self.transform = transform
18 |
19 | def __getitem__(self, index):
20 | return self.transform(self.x[index]), self.y[index]
21 |
22 | def __len__(self):
23 | return len(self.x)
24 |
25 |
26 | @click.command()
27 |
28 | ## Data configuration
29 | @click.option('--feature_extractor', help='Path of feature extractor', metavar='STR', type=str, default='/checkpoints/discriminator/feature_extractor/32x32_classifier.pt')
30 | @click.option('--savedir', help='Discriminator save directory', metavar='PATH', type=str, default="/checkpoints/discriminator/cifar10/unbias_500/")
31 | @click.option('--biasdir', help='Bias data directory', metavar='PATH', type=str, default="/datasets/cifar10/discriminator_training/bias_10000/fake_data.npz")
32 | @click.option('--refdir', help='Real sample directory', metavar='PATH', type=str, default="/datasets/cifar10/discriminator_training/unbias_500/real_data.npz")
33 |
34 | @click.option('--img_resolution', help='Image resolution', metavar='INT', type=click.IntRange(min=1), default=32)
35 | @click.option('--real_mul', help='Scaling factor for data imbalance', metavar='INT', type=click.IntRange(min=1), default=20)
36 |
37 | ## Training configuration
38 | @click.option('--batch_size', help='Batch size', metavar='INT', type=click.IntRange(min=1), default=128)
39 | @click.option('--iter', help='Training iteration', metavar='INT', type=click.IntRange(min=1), default=20000)
40 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=3e-4)
41 | @click.option('--device', help='Device', metavar='STR', type=str, default='cuda:0')
42 |
43 | def main(**kwargs):
44 | opts = dnnlib.EasyDict(kwargs)
45 | path = os.getcwd()
46 | path_feat = path + opts.feature_extractor
47 | savedir = path + opts.savedir
48 | refdir = path + opts.refdir
49 | biasdir = path + opts.biasdir
50 | os.makedirs(savedir,exist_ok=True)
51 |
52 | ## Prepare real&fake data
53 | ref_data = np.load(refdir)['samples']
54 | for k in range(opts.real_mul):
55 | if k == 0:
56 | ref_datas = ref_data
57 | else:
58 | ref_datas = np.concatenate([ref_datas, ref_data])
59 | ref_data = ref_datas
60 | bias_data = np.load(biasdir)['samples']
61 | print("bias:", len(bias_data))
62 | print("scaled unbias:", len(ref_data))
63 |
64 | ## Loader
65 | transform = transforms.Compose([transforms.ToTensor()])
66 | ref_label = torch.ones(ref_data.shape[0])
67 | bias_label = torch.zeros(bias_data.shape[0])
68 |
69 | ref_data = BasicDataset(ref_data, ref_label, transform)
70 | bias_data = BasicDataset(bias_data, bias_label, transform)
71 | ref_loader = torch.utils.data.DataLoader(dataset=ref_data, batch_size=opts.batch_size, num_workers=0, shuffle=True, drop_last=True)
72 | bias_loader = torch.utils.data.DataLoader(dataset=bias_data, batch_size=opts.batch_size, num_workers=0, shuffle=True, drop_last=True)
73 | ref_iterator = iter(ref_loader)
74 | bias_iterator = iter(bias_loader)
75 |
76 |     ## Extractor & Discriminator
77 | pretrained_classifier = classifier_lib.load_classifier(path_feat, opts.img_resolution, opts.device, eval=False)
78 | discriminator = classifier_lib.load_discriminator(None, opts.device, 0, eval=False)
79 |
80 | ## Prepare training
81 | vpsde = classifier_lib.vpsde()
82 | optimizer = torch.optim.Adam(discriminator.parameters(), lr=opts.lr, weight_decay=1e-7)
83 | loss = torch.nn.BCELoss()
84 | scaler = lambda x: 2. * x - 1.
85 |
86 | ## Training
87 | for i in range(opts.iter):
88 | if i % 500 == 0:
89 | outs = []
90 | cors = []
91 | num_data = 0
92 |         try:
93 |             r_inputs, r_labels = next(ref_iterator)
94 |         except StopIteration:
95 |             ref_iterator = iter(ref_loader)
96 |             r_inputs, r_labels = next(ref_iterator)
97 |         try:
98 |             f_inputs, f_labels = next(bias_iterator)
99 |         except StopIteration:
100 |             bias_iterator = iter(bias_loader)
101 |             f_inputs, f_labels = next(bias_iterator)
102 |
103 | inputs = torch.cat([r_inputs, f_inputs])
104 | labels = torch.cat([r_labels, f_labels])
105 | c = list(range(inputs.shape[0]))
106 | random.shuffle(c)
107 | inputs, labels = inputs[c], labels[c]
108 |
109 | optimizer.zero_grad()
110 | inputs = inputs.to(opts.device)
111 | labels = labels.to(opts.device)
112 | inputs = scaler(inputs)
113 |
114 | ## Data perturbation
115 | t, _ = vpsde.get_diffusion_time(inputs.shape[0], inputs.device, importance_sampling=True)
116 | mean, std = vpsde.marginal_prob(t)
117 | z = torch.randn_like(inputs)
118 | perturbed_inputs = mean[:, None, None, None] * inputs + std[:, None, None, None] * z
119 |
120 | ## Forward
121 | with torch.no_grad():
122 | pretrained_feature = pretrained_classifier(perturbed_inputs, timesteps=t, feature=True)
123 | label_prediction = discriminator(pretrained_feature, t, sigmoid=True).view(-1)
124 |
125 | ## Backward
126 | out = loss(label_prediction, labels)
127 | out.backward()
128 | optimizer.step()
129 |
130 | ## Report
131 | cor = ((label_prediction > 0.5).float() == labels).float().mean()
132 | outs.append(out.item())
133 | cors.append(cor.item())
134 | num_data += inputs.shape[0]
135 |         print(f"{i}-th iter BCE loss: {np.mean(outs)}, accuracy: {np.mean(cors)}")
136 |
137 | if i % 500 == 0:
138 | torch.save(discriminator.state_dict(), savedir + f"/discriminator_{i+1}.pt")
139 |
140 | #----------------------------------------------------------------------------
141 | if __name__ == "__main__":
142 | main()
143 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
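A minimal sketch of the `BasicDataset` wrapper defined above, with synthetic data standing in for the real/bias `.npz` files:

```python
import numpy as np
import torch
from train_classifier import BasicDataset

x = (np.random.rand(10, 32, 32, 3) * 255).astype(np.uint8)  # HWC uint8, as ToTensor expects
y = torch.ones(10)                                          # 1 = unbiased/reference sample
ds = BasicDataset(x, y)
img, label = ds[0]
print(img.shape, img.dtype, label)  # torch.Size([3, 32, 32]) torch.float32 tensor(1.)
```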
/fid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Script for calculating Frechet Inception Distance (FID)."""
9 |
10 | import os
11 | import click
12 | import tqdm
13 | import pickle
14 | import numpy as np
15 | import scipy.linalg
16 | import torch
17 | import dnnlib
18 | from torch_utils import distributed as dist
19 | from training import dataset
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | def calculate_inception_stats(
24 | image_path, num_expected=None, seed=0, max_batch_size=64,
25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
26 | ):
27 | # Rank 0 goes first.
28 | if dist.get_rank() != 0:
29 | torch.distributed.barrier()
30 |
31 | # Load Inception-v3 model.
32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
33 | dist.print0('Loading Inception-v3 model...')
34 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
35 | detector_kwargs = dict(return_features=True)
36 | feature_dim = 2048
37 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
38 | detector_net = pickle.load(f).to(device)
39 |
40 | # List images.
41 | dist.print0(f'Loading images from "{image_path}"...')
42 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
43 | if num_expected is not None and len(dataset_obj) < num_expected:
44 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
45 | if len(dataset_obj) < 2:
46 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')
47 |
48 | # Other ranks follow.
49 | if dist.get_rank() == 0:
50 | torch.distributed.barrier()
51 |
52 | # Divide images into batches.
53 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
54 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
55 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
56 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)
57 |
58 | # Accumulate statistics.
59 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
60 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
61 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
62 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
63 | torch.distributed.barrier()
64 | if images.shape[0] == 0:
65 | continue
66 | if images.shape[1] == 1:
67 | images = images.repeat([1, 3, 1, 1])
68 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
69 | mu += features.sum(0)
70 | sigma += features.T @ features
71 |
72 | # Calculate grand totals.
73 | torch.distributed.all_reduce(mu)
74 | torch.distributed.all_reduce(sigma)
75 | mu /= len(dataset_obj)
76 | sigma -= mu.ger(mu) * len(dataset_obj)
77 | sigma /= len(dataset_obj) - 1
78 | return mu.cpu().numpy(), sigma.cpu().numpy()
79 |
80 | #----------------------------------------------------------------------------
81 |
82 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
83 | m = np.square(mu - mu_ref).sum()
84 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
85 | fid = m + np.trace(sigma + sigma_ref - s * 2)
86 | return float(np.real(fid))
87 |
88 | #----------------------------------------------------------------------------
89 |
90 | @click.group()
91 | def main():
92 | """Calculate Frechet Inception Distance (FID).
93 |
94 | Examples:
95 |
96 | \b
97 | # Generate 50000 images and save them as fid-tmp/*/*.png
98 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
99 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
100 |
101 | \b
102 | # Calculate FID
103 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
104 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
105 |
106 | \b
107 | # Compute dataset reference statistics
108 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
109 | """
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | @main.command()
114 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
115 | @click.option('--ref', 'ref_path', help='Dataset reference statistics', metavar='NPZ|URL', type=str, required=True)
116 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
117 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
118 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
119 |
120 | def calc(image_path, ref_path, num_expected, seed, batch):
121 | """Calculate FID for a given set of images."""
122 | torch.multiprocessing.set_start_method('spawn')
123 | dist.init()
124 |
125 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
126 | ref = None
127 | if dist.get_rank() == 0:
128 | with dnnlib.util.open_url(ref_path) as f:
129 | ref = dict(np.load(f))
130 |
131 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
132 | dist.print0('Calculating FID...')
133 | if dist.get_rank() == 0:
134 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
135 | print(f'{fid:g}')
136 | torch.distributed.barrier()
137 |
138 | #----------------------------------------------------------------------------
139 |
140 | @main.command()
141 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
142 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
143 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
144 |
145 | def ref(dataset_path, dest_path, batch):
146 | """Calculate dataset reference statistics needed by 'calc'."""
147 | torch.multiprocessing.set_start_method('spawn')
148 | dist.init()
149 |
150 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
151 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
152 | if dist.get_rank() == 0:
153 | if os.path.dirname(dest_path):
154 | os.makedirs(os.path.dirname(dest_path), exist_ok=True)
155 | np.savez(dest_path, mu=mu, sigma=sigma)
156 |
157 | torch.distributed.barrier()
158 | dist.print0('Done.')
159 |
160 | #----------------------------------------------------------------------------
161 |
162 | if __name__ == "__main__":
163 | main()
164 |
165 | #----------------------------------------------------------------------------
166 |
--------------------------------------------------------------------------------
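A standalone sanity check of the closed-form FID used above (a sketch, not part of the repository; it only needs numpy and scipy). For Gaussians with identical covariance, the trace terms cancel and the FID reduces to the squared distance between the means:

import numpy as np
import scipy.linalg

def fid_from_stats(mu, sigma, mu_ref, sigma_ref):
    # Same arithmetic as calculate_fid_from_inception_stats() in fid.py.
    m = np.square(mu - mu_ref).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    return float(np.real(m + np.trace(sigma + sigma_ref - s * 2)))

mu_a = np.zeros(8)
mu_b = np.full(8, 0.1)
sigma = np.eye(8)
print(fid_from_stats(mu_a, sigma, mu_a, sigma))  # 0.0: identical statistics
print(fid_from_stats(mu_a, sigma, mu_b, sigma))  # 8 * 0.1**2 = 0.08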
/guided_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 |         for p in self.master_params: p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))  # unscale every param group, not just the first
203 | opt.step()
204 | zero_master_grads(self.master_params)
205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206 | self.lg_loss_scale += self.fp16_scale_growth
207 | return True
208 |
209 | def _optimize_normal(self, opt: th.optim.Optimizer):
210 | grad_norm, param_norm = self._compute_norms()
211 | logger.logkv_mean("grad_norm", grad_norm)
212 | logger.logkv_mean("param_norm", param_norm)
213 | opt.step()
214 | return True
215 |
216 | def _compute_norms(self, grad_scale=1.0):
217 | grad_norm = 0.0
218 | param_norm = 0.0
219 | for p in self.master_params:
220 | with th.no_grad():
221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222 | if p.grad is not None:
223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225 |
226 | def master_params_to_state_dict(self, master_params):
227 | return master_params_to_state_dict(
228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229 | )
230 |
231 | def state_dict_to_master_params(self, state_dict):
232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233 |
234 |
235 | def check_overflow(value):
236 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
237 |
--------------------------------------------------------------------------------
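A minimal usage sketch of MixedPrecisionTrainer (not part of the repository). With use_fp16=False it reduces to a plain backward/step; the fp16 path additionally requires the model to implement convert_to_fp16(), as the UNet in this codebase does:

import torch as th
from guided_diffusion.fp16_util import MixedPrecisionTrainer

model = th.nn.Linear(8, 8)  # stand-in for a diffusion model
trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

trainer.zero_grad()
loss = model(th.randn(4, 8)).pow(2).mean()
trainer.backward(loss)             # multiplies by 2**lg_loss_scale when use_fp16=True
took_step = trainer.optimize(opt)  # returns False (and shrinks the scale) on overflow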
/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Streaming images and labels from datasets created with dataset_tool.py."""
9 |
10 | import os
11 | import numpy as np
12 | import zipfile
13 | import PIL.Image
14 | import json
15 | import torch
16 | import dnnlib
17 |
18 | try:
19 | import pyspng
20 | except ImportError:
21 | pyspng = None
22 |
23 | #----------------------------------------------------------------------------
24 | # Abstract base class for datasets.
25 |
26 | class Dataset(torch.utils.data.Dataset):
27 | def __init__(self,
28 | name, # Name of the dataset.
29 | raw_shape, # Shape of the raw image data (NCHW).
30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
33 | random_seed = 0, # Random seed to use when applying max_size.
34 | cache = False, # Cache images in CPU memory?
35 | ):
36 | self._name = name
37 | self._raw_shape = list(raw_shape)
38 | self._use_labels = use_labels
39 | self._cache = cache
40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...}
41 | self._raw_labels = None
42 | self._label_shape = None
43 |
44 | # Apply max_size.
45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
46 | if (max_size is not None) and (self._raw_idx.size > max_size):
47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
48 | self._raw_idx = np.sort(self._raw_idx[:max_size])
49 |
50 | # Apply xflip.
51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
52 | if xflip:
53 | self._raw_idx = np.tile(self._raw_idx, 2)
54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
55 |
56 | def _get_raw_labels(self):
57 | if self._raw_labels is None:
58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
59 | if self._raw_labels is None:
60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
61 | assert isinstance(self._raw_labels, np.ndarray)
62 | assert self._raw_labels.shape[0] == self._raw_shape[0]
63 | assert self._raw_labels.dtype in [np.float32, np.int64]
64 | if self._raw_labels.dtype == np.int64:
65 | assert self._raw_labels.ndim == 1
66 | assert np.all(self._raw_labels >= 0)
67 | return self._raw_labels
68 |
69 | def close(self): # to be overridden by subclass
70 | pass
71 |
72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
73 | raise NotImplementedError
74 |
75 | def _load_raw_labels(self): # to be overridden by subclass
76 | raise NotImplementedError
77 |
78 | def __getstate__(self):
79 | return dict(self.__dict__, _raw_labels=None)
80 |
81 | def __del__(self):
82 | try:
83 | self.close()
84 | except:
85 | pass
86 |
87 | def __len__(self):
88 | return self._raw_idx.size
89 |
90 | def __getitem__(self, idx):
91 | raw_idx = self._raw_idx[idx]
92 | image = self._cached_images.get(raw_idx, None)
93 | if image is None:
94 | image = self._load_raw_image(raw_idx)
95 | if self._cache:
96 | self._cached_images[raw_idx] = image
97 | assert isinstance(image, np.ndarray)
98 | assert list(image.shape) == self.image_shape
99 | assert image.dtype == np.uint8
100 | if self._xflip[idx]:
101 | assert image.ndim == 3 # CHW
102 | image = image[:, :, ::-1]
103 | return image.copy(), self.get_label(idx)
104 |
105 | def get_label(self, idx):
106 | label = self._get_raw_labels()[self._raw_idx[idx]]
107 | if label.dtype == np.int64:
108 | onehot = np.zeros(self.label_shape, dtype=np.float32)
109 | onehot[label] = 1
110 | label = onehot
111 | return label.copy()
112 |
113 | def get_details(self, idx):
114 | d = dnnlib.EasyDict()
115 | d.raw_idx = int(self._raw_idx[idx])
116 | d.xflip = (int(self._xflip[idx]) != 0)
117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
118 | return d
119 |
120 | @property
121 | def name(self):
122 | return self._name
123 |
124 | @property
125 | def image_shape(self):
126 | return list(self._raw_shape[1:])
127 |
128 | @property
129 | def num_channels(self):
130 | assert len(self.image_shape) == 3 # CHW
131 | return self.image_shape[0]
132 |
133 | @property
134 | def resolution(self):
135 | assert len(self.image_shape) == 3 # CHW
136 | assert self.image_shape[1] == self.image_shape[2]
137 | return self.image_shape[1]
138 |
139 | @property
140 | def label_shape(self):
141 | if self._label_shape is None:
142 | raw_labels = self._get_raw_labels()
143 | if raw_labels.dtype == np.int64:
144 | self._label_shape = [int(np.max(raw_labels)) + 1]
145 | else:
146 | self._label_shape = raw_labels.shape[1:]
147 | return list(self._label_shape)
148 |
149 | @property
150 | def label_dim(self):
151 | assert len(self.label_shape) == 1
152 | return self.label_shape[0]
153 |
154 | @property
155 | def has_labels(self):
156 | return any(x != 0 for x in self.label_shape)
157 |
158 | @property
159 | def has_onehot_labels(self):
160 | return self._get_raw_labels().dtype == np.int64
161 |
162 | #----------------------------------------------------------------------------
163 | # Dataset subclass that loads images recursively from the specified directory
164 | # or ZIP file.
165 |
166 | class ImageFolderDataset(Dataset):
167 | def __init__(self,
168 | path, # Path to directory or zip.
169 | resolution = None, # Ensure specific resolution, None = highest available.
170 | use_pyspng = True, # Use pyspng if available?
171 | **super_kwargs, # Additional arguments for the Dataset base class.
172 | ):
173 | self._path = path
174 | self._use_pyspng = use_pyspng
175 | self._zipfile = None
176 |
177 | if os.path.isdir(self._path):
178 | self._type = 'dir'
179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
180 | elif self._file_ext(self._path) == '.zip':
181 | self._type = 'zip'
182 | self._all_fnames = set(self._get_zipfile().namelist())
183 | else:
184 | raise IOError('Path must point to a directory or zip')
185 |
186 | PIL.Image.init()
187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
188 | if len(self._image_fnames) == 0:
189 | raise IOError('No image files found in the specified path')
190 |
191 | name = os.path.splitext(os.path.basename(self._path))[0]
192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
194 | raise IOError('Image files do not match the specified resolution')
195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
196 |
197 | @staticmethod
198 | def _file_ext(fname):
199 | return os.path.splitext(fname)[1].lower()
200 |
201 | def _get_zipfile(self):
202 | assert self._type == 'zip'
203 | if self._zipfile is None:
204 | self._zipfile = zipfile.ZipFile(self._path)
205 | return self._zipfile
206 |
207 | def _open_file(self, fname):
208 | if self._type == 'dir':
209 | return open(os.path.join(self._path, fname), 'rb')
210 | if self._type == 'zip':
211 | return self._get_zipfile().open(fname, 'r')
212 | return None
213 |
214 | def close(self):
215 | try:
216 | if self._zipfile is not None:
217 | self._zipfile.close()
218 | finally:
219 | self._zipfile = None
220 |
221 | def __getstate__(self):
222 | return dict(super().__getstate__(), _zipfile=None)
223 |
224 | def _load_raw_image(self, raw_idx):
225 | fname = self._image_fnames[raw_idx]
226 | with self._open_file(fname) as f:
227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
228 | image = pyspng.load(f.read())
229 | else:
230 | image = np.array(PIL.Image.open(f))
231 | if image.ndim == 2:
232 | image = image[:, :, np.newaxis] # HW => HWC
233 | image = image.transpose(2, 0, 1) # HWC => CHW
234 | return image
235 |
236 | def _load_raw_labels(self):
237 | fname = 'dataset.json'
238 | if fname not in self._all_fnames:
239 | return None
240 | with self._open_file(fname) as f:
241 | labels = json.load(f)['labels']
242 | if labels is None:
243 | return None
244 | labels = dict(labels)
245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
246 | labels = np.array(labels)
247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
248 | return labels
249 |
250 | #----------------------------------------------------------------------------
251 |
--------------------------------------------------------------------------------
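A usage sketch for ImageFolderDataset (not part of the repository; 'datasets/cifar10.zip' is an assumed archive in the format this class expects):

import torch
from training.dataset import ImageFolderDataset

ds = ImageFolderDataset(path='datasets/cifar10.zip', use_labels=True, xflip=True)
image, label = ds[0]   # uint8 CHW numpy array and a float32 one-hot label
loader = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True, num_workers=2)
images, labels = next(iter(loader))  # images: uint8 tensor of shape [64, C, H, W]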
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for pickling Python code alongside other data.
9 |
10 | The pickled code is automatically imported into a separate Python module
11 | during unpickling. This way, any previously exported pickles will remain
12 | usable even if the original code is no longer available, or if the current
13 | version of the code is not consistent with what was originally pickled."""
14 |
15 | import sys
16 | import pickle
17 | import io
18 | import inspect
19 | import copy
20 | import uuid
21 | import types
22 | import dnnlib
23 |
24 | #----------------------------------------------------------------------------
25 |
26 | _version = 6 # internal version number
27 | _decorators = set() # {decorator_class, ...}
28 | _import_hooks = [] # [hook_function, ...]
29 | _module_to_src_dict = dict() # {module: src, ...}
30 | _src_to_module_dict = dict() # {src: module, ...}
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def persistent_class(orig_class):
35 | r"""Class decorator that extends a given class to save its source code
36 | when pickled.
37 |
38 | Example:
39 |
40 | from torch_utils import persistence
41 |
42 | @persistence.persistent_class
43 | class MyNetwork(torch.nn.Module):
44 | def __init__(self, num_inputs, num_outputs):
45 | super().__init__()
46 | self.fc = MyLayer(num_inputs, num_outputs)
47 | ...
48 |
49 | @persistence.persistent_class
50 | class MyLayer(torch.nn.Module):
51 | ...
52 |
53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
54 | source code alongside other internal state (e.g., parameters, buffers,
55 | and submodules). This way, any previously exported pickle will remain
56 | usable even if the class definitions have been modified or are no
57 | longer available.
58 |
59 | The decorator saves the source code of the entire Python module
60 | containing the decorated class. It does *not* save the source code of
61 | any imported modules. Thus, the imported modules must be available
62 |     during unpickling, including `torch_utils.persistence` itself.
63 |
64 | It is ok to call functions defined in the same module from the
65 | decorated class. However, if the decorated class depends on other
66 | classes defined in the same module, they must be decorated as well.
67 | This is illustrated in the above example in the case of `MyLayer`.
68 |
69 | It is also possible to employ the decorator just-in-time before
70 | calling the constructor. For example:
71 |
72 | cls = MyLayer
73 | if want_to_make_it_persistent:
74 | cls = persistence.persistent_class(cls)
75 | layer = cls(num_inputs, num_outputs)
76 |
77 | As an additional feature, the decorator also keeps track of the
78 | arguments that were used to construct each instance of the decorated
79 | class. The arguments can be queried via `obj.init_args` and
80 | `obj.init_kwargs`, and they are automatically pickled alongside other
81 | object state. This feature can be disabled on a per-instance basis
82 | by setting `self._record_init_args = False` in the constructor.
83 |
84 | A typical use case is to first unpickle a previous instance of a
85 | persistent class, and then upgrade it to use the latest version of
86 | the source code:
87 |
88 | with open('old_pickle.pkl', 'rb') as f:
89 | old_net = pickle.load(f)
90 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92 | """
93 | assert isinstance(orig_class, type)
94 | if is_persistent(orig_class):
95 | return orig_class
96 |
97 | assert orig_class.__module__ in sys.modules
98 | orig_module = sys.modules[orig_class.__module__]
99 | orig_module_src = _module_to_src(orig_module)
100 |
101 | class Decorator(orig_class):
102 | _orig_module_src = orig_module_src
103 | _orig_class_name = orig_class.__name__
104 |
105 | def __init__(self, *args, **kwargs):
106 | super().__init__(*args, **kwargs)
107 | record_init_args = getattr(self, '_record_init_args', True)
108 | self._init_args = copy.deepcopy(args) if record_init_args else None
109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
110 | assert orig_class.__name__ in orig_module.__dict__
111 | _check_pickleable(self.__reduce__())
112 |
113 | @property
114 | def init_args(self):
115 | assert self._init_args is not None
116 | return copy.deepcopy(self._init_args)
117 |
118 | @property
119 | def init_kwargs(self):
120 | assert self._init_kwargs is not None
121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
122 |
123 | def __reduce__(self):
124 | fields = list(super().__reduce__())
125 | fields += [None] * max(3 - len(fields), 0)
126 | if fields[0] is not _reconstruct_persistent_obj:
127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
128 | fields[0] = _reconstruct_persistent_obj # reconstruct func
129 | fields[1] = (meta,) # reconstruct args
130 | fields[2] = None # state dict
131 | return tuple(fields)
132 |
133 | Decorator.__name__ = orig_class.__name__
134 | Decorator.__module__ = orig_class.__module__
135 | _decorators.add(Decorator)
136 | return Decorator
137 |
138 | #----------------------------------------------------------------------------
139 |
140 | def is_persistent(obj):
141 | r"""Test whether the given object or class is persistent, i.e.,
142 | whether it will save its source code when pickled.
143 | """
144 | try:
145 | if obj in _decorators:
146 | return True
147 | except TypeError:
148 | pass
149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150 |
151 | #----------------------------------------------------------------------------
152 |
153 | def import_hook(hook):
154 | r"""Register an import hook that is called whenever a persistent object
155 | is being unpickled. A typical use case is to patch the pickled source
156 | code to avoid errors and inconsistencies when the API of some imported
157 | module has changed.
158 |
159 | The hook should have the following signature:
160 |
161 | hook(meta) -> modified meta
162 |
163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
164 |
165 | type: Type of the persistent object, e.g. `'class'`.
166 | version: Internal version number of `torch_utils.persistence`.
167 |         module_src:  Original source code of the Python module.
168 | class_name: Class name in the original Python module.
169 | state: Internal state of the object.
170 |
171 | Example:
172 |
173 | @persistence.import_hook
174 | def wreck_my_network(meta):
175 | if meta.class_name == 'MyNetwork':
176 | print('MyNetwork is being imported. I will wreck it!')
177 | meta.module_src = meta.module_src.replace("True", "False")
178 | return meta
179 | """
180 | assert callable(hook)
181 | _import_hooks.append(hook)
182 |
183 | #----------------------------------------------------------------------------
184 |
185 | def _reconstruct_persistent_obj(meta):
186 | r"""Hook that is called internally by the `pickle` module to unpickle
187 | a persistent object.
188 | """
189 | meta = dnnlib.EasyDict(meta)
190 | meta.state = dnnlib.EasyDict(meta.state)
191 | for hook in _import_hooks:
192 | meta = hook(meta)
193 | assert meta is not None
194 |
195 | assert meta.version == _version
196 | module = _src_to_module(meta.module_src)
197 |
198 | assert meta.type == 'class'
199 | orig_class = module.__dict__[meta.class_name]
200 | decorator_class = persistent_class(orig_class)
201 | obj = decorator_class.__new__(decorator_class)
202 |
203 | setstate = getattr(obj, '__setstate__', None)
204 | if callable(setstate):
205 | setstate(meta.state) # pylint: disable=not-callable
206 | else:
207 | obj.__dict__.update(meta.state)
208 | return obj
209 |
210 | #----------------------------------------------------------------------------
211 |
212 | def _module_to_src(module):
213 | r"""Query the source code of a given Python module.
214 | """
215 | src = _module_to_src_dict.get(module, None)
216 | if src is None:
217 | src = inspect.getsource(module)
218 | _module_to_src_dict[module] = src
219 | _src_to_module_dict[src] = module
220 | return src
221 |
222 | def _src_to_module(src):
223 | r"""Get or create a Python module for the given source code.
224 | """
225 | module = _src_to_module_dict.get(src, None)
226 | if module is None:
227 | module_name = "_imported_module_" + uuid.uuid4().hex
228 | module = types.ModuleType(module_name)
229 | sys.modules[module_name] = module
230 | _module_to_src_dict[module] = src
231 | _src_to_module_dict[src] = module
232 | exec(src, module.__dict__) # pylint: disable=exec-used
233 | return module
234 |
235 | #----------------------------------------------------------------------------
236 |
237 | def _check_pickleable(obj):
238 | r"""Check that the given object is pickleable, raising an exception if
239 | it is not. This function is expected to be considerably more efficient
240 | than actually pickling the object.
241 | """
242 | def recurse(obj):
243 | if isinstance(obj, (list, tuple, set)):
244 | return [recurse(x) for x in obj]
245 | if isinstance(obj, dict):
246 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
248 | return None # Python primitive types are pickleable.
249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
250 | return None # NumPy arrays and PyTorch tensors are pickleable.
251 | if is_persistent(obj):
252 | return None # Persistent objects are pickleable, by virtue of the constructor check.
253 | return obj
254 | with io.BytesIO() as f:
255 | pickle.dump(recurse(obj), f)
256 |
257 | #----------------------------------------------------------------------------
258 |
--------------------------------------------------------------------------------
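A same-process round-trip sketch of persistent_class (not part of the repository; it must run as a script file so that inspect.getsource() can read the defining module, and in real use the class would live in an importable module rather than the script itself):

import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):
    def __init__(self, width):
        super().__init__()
        self.fc = torch.nn.Linear(width, width)

net = TinyNet(width=4)
blob = pickle.dumps(net)    # the defining module's source code is embedded
clone = pickle.loads(blob)  # rebuilt via the cached module in this process
print(clone.init_kwargs)    # {'width': 4}: recorded construction arguments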
/classifier_lib.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from guided_diffusion.script_util import create_classifier
3 | import os
4 | import numpy as np
5 | import io
6 |
7 | def get_discriminator(latent_extractor_ckpt, discriminator_ckpt, enable_grad=True):
8 | classifier = latent_extractor_ckpt
9 | discriminator = discriminator_ckpt
10 | def evaluate(perturbed_inputs, timesteps=None, condition=None):
11 | with torch.enable_grad() if enable_grad else torch.no_grad():
12 | adm_features = classifier(perturbed_inputs, timesteps=timesteps, feature=True)
13 | prediction = discriminator(adm_features, timesteps, sigmoid=True, condition=condition).view(-1)
14 | return prediction
15 | return evaluate
16 |
17 | def get_reverse_discriminator(latent_extractor_ckpt, discriminator_ckpt, enable_grad=True):
18 | classifier = latent_extractor_ckpt
19 | discriminator = discriminator_ckpt
20 | def evaluate(perturbed_inputs, timesteps=None, condition=None):
21 | with torch.enable_grad() if enable_grad else torch.no_grad():
22 | adm_features = classifier(perturbed_inputs, timesteps=timesteps, feature=True)
23 | prediction = discriminator(adm_features, timesteps, sigmoid=True, condition=condition).view(-1)
24 |         return 1. - prediction
25 | return evaluate
26 |
27 | def load_classifier(ckpt_path, img_resolution, device, eval=True):
28 | classifier_args = dict(
29 | image_size=img_resolution,
30 | classifier_use_fp16=False,
31 | classifier_width=128,
32 | classifier_depth=4 if img_resolution in [64, 32] else 2,
33 | classifier_attention_resolutions="32,16,8",
34 | classifier_use_scale_shift_norm=True,
35 | classifier_resblock_updown=True,
36 | classifier_pool="attention",
37 | out_channels=1000,
38 | )
39 | classifier = create_classifier(**classifier_args)
40 | classifier.to(device)
41 | if ckpt_path is not None:
42 | classifier_state = torch.load(ckpt_path, map_location="cpu")
43 | classifier.load_state_dict(classifier_state)
44 | if eval:
45 | classifier.eval()
46 | return classifier
47 |
48 | def load_discriminator(ckpt_path, device, condition, eval=False, channel=512):
49 | discriminator_args = dict(
50 | image_size=8,
51 | classifier_use_fp16=False,
52 | classifier_width=128,
53 | classifier_depth=2,
54 | classifier_attention_resolutions="32,16,8",
55 | classifier_use_scale_shift_norm=True,
56 | classifier_resblock_updown=True,
57 | classifier_pool="attention",
58 | out_channels=1,
59 | in_channels=channel,
60 | condition=condition,
61 | )
62 | discriminator = create_classifier(**discriminator_args)
63 | discriminator.to(device)
64 | if ckpt_path is not None:
65 | #ckpt_path = os.getcwd() + ckpt_path
66 | discriminator_state = torch.load(ckpt_path, map_location="cpu")
67 | discriminator.load_state_dict(discriminator_state)
68 | if eval:
69 | discriminator.eval()
70 | return discriminator
71 |
72 | def load_attribute_classifier(ckpt_path, device, condition, eval=False, channel=512):
73 | discriminator_args = dict(
74 | image_size=8,
75 | classifier_use_fp16=False,
76 | classifier_width=128,
77 | classifier_depth=2,
78 | classifier_attention_resolutions="32,16,8",
79 | classifier_use_scale_shift_norm=True,
80 | classifier_resblock_updown=True,
81 | classifier_pool="attention",
82 | out_channels=4,
83 | in_channels=channel,
84 | condition=condition,
85 | )
86 | discriminator = create_classifier(**discriminator_args)
87 | discriminator.to(device)
88 | if ckpt_path is not None:
89 | #ckpt_path = os.getcwd() + ckpt_path
90 | discriminator_state = torch.load(ckpt_path, map_location="cpu")
91 | discriminator.load_state_dict(discriminator_state)
92 | if eval:
93 | discriminator.eval()
94 | return discriminator
95 |
96 | def get_weight_correction(discriminator, vpsde, unnormalized_input, std_wve_t, img_resolution, time_min, time_max, class_labels=None):
97 | mean_vp_tau, tau = vpsde.transform_unnormalized_wve_to_normalized_vp(std_wve_t) ## VP pretrained classifier
98 |     if tau.min() > time_max or tau.min() < time_min or discriminator is None:
99 |         return torch.zeros_like(unnormalized_input), 10000000. * torch.ones(unnormalized_input.shape[0], device=unnormalized_input.device)
100 |
101 | else:
102 | input = mean_vp_tau[:,None,None,None] * unnormalized_input
103 | with torch.enable_grad():
104 | x_ = input.float().clone().detach().requires_grad_()
105 | if img_resolution == 64: # ADM trained UNet classifier for 64x64 with Cosine VPSDE
106 | tau = vpsde.compute_t_cos_from_t_lin(tau)
107 | tau = torch.ones(input.shape[0], device=tau.device) * tau
108 | log_ratio, prediction = get_log_ratio(discriminator, x_, tau, class_labels)
109 | correction = torch.autograd.grad(outputs=prediction.sum(), inputs=x_, retain_graph=False)[0]
110 | correction *= - ((std_wve_t[:,None,None,None] ** 2) * mean_vp_tau[:,None,None,None]) ## Scaling to Expected Denoised Point
111 | return correction, prediction
112 |
113 | def get_grad_log_ratio(discriminator, vpsde, unnormalized_input, std_wve_t, img_resolution, time_min, time_max, class_labels, log=False):
114 | mean_vp_tau, tau = vpsde.transform_unnormalized_wve_to_normalized_vp(std_wve_t) ## VP pretrained classifier
115 |     if tau.min() > time_max or tau.min() < time_min or discriminator is None:
116 | if log:
117 | return torch.zeros_like(unnormalized_input), 10000000. * torch.ones(unnormalized_input.shape[0], device=unnormalized_input.device), torch.zeros_like(unnormalized_input)
118 | return torch.zeros_like(unnormalized_input)
119 | else:
120 | input = mean_vp_tau[:,None,None,None] * unnormalized_input
121 | with torch.enable_grad():
122 | x_ = input.float().clone().detach().requires_grad_()
123 | if img_resolution == 64: # ADM trained UNet classifier for 64x64 with Cosine VPSDE
124 | tau = vpsde.compute_t_cos_from_t_lin(tau)
125 | tau = torch.ones(input.shape[0], device=tau.device) * tau
126 | log_ratio, prediction = get_log_ratio(discriminator, x_, tau, class_labels)
127 | discriminator_guidance_score = torch.autograd.grad(outputs=log_ratio.sum(), inputs=x_, retain_graph=False)[0]
131 | discriminator_guidance_score *= - ((std_wve_t[:,None,None,None] ** 2) * mean_vp_tau[:,None,None,None])
132 | if log:
133 | return discriminator_guidance_score, log_ratio, prediction
134 | return discriminator_guidance_score
135 |
136 | def get_prediction(discriminator, vpsde, unnormalized_input, std_wve_t, img_resolution, time_min, time_max, class_labels, log=False):
137 | mean_vp_tau, tau = vpsde.transform_unnormalized_wve_to_normalized_vp(std_wve_t) ## VP pretrained classifier
138 |     if tau.min() > time_max or tau.min() < time_min or discriminator is None:
139 | if log:
140 | return torch.zeros_like(unnormalized_input), 10000000. * torch.ones(unnormalized_input.shape[0], device=unnormalized_input.device), torch.zeros_like(unnormalized_input)
141 | return torch.zeros_like(unnormalized_input)
142 | else:
143 | input = mean_vp_tau[:,None,None,None] * unnormalized_input
144 | with torch.enable_grad():
145 | x_ = input.float().clone().detach().requires_grad_()
146 | if img_resolution == 64: # ADM trained UNet classifier for 64x64 with Cosine VPSDE
147 | tau = vpsde.compute_t_cos_from_t_lin(tau)
148 | tau = torch.ones(input.shape[0], device=tau.device) * tau
149 | log_ratio, prediction = get_log_ratio(discriminator, x_, tau, class_labels)
150 | return prediction
151 |
152 |
153 |
154 | def get_log_ratio(discriminator, input, time, class_labels):
155 |     if discriminator is None:
156 | return torch.zeros(input.shape[0], device=input.device)
157 | else:
158 | logits = discriminator(input, timesteps=time, condition=class_labels)
159 | prediction = torch.clip(logits, 1e-5, 1. - 1e-5)
160 | log_ratio = torch.log(prediction / (1. - prediction))
161 | return log_ratio, prediction
162 |
163 | class vpsde():
164 | def __init__(self):
165 | self.beta_0 = 0.1
166 | self.beta_1 = 20.
167 | self.s = 0.008
168 | self.f_0 = np.cos(self.s / (1. + self.s) * np.pi / 2.) ** 2
169 |
170 | @property
171 | def T(self):
172 | return 1
173 |     def compute_reverse_tau(self, tau):
174 |         # Inverse of compute_tau(): maps a VP time back to a WVE noise level.
175 |         tau = tau * (self.beta_1 - self.beta_0)  # avoid in-place mutation of the caller's tensor
176 |         std_wve_t = (tau + self.beta_0) ** 2 - self.beta_0 ** 2
177 |         std_wve_t /= 2. * (self.beta_1 - self.beta_0)
178 |         std_wve_t = (torch.exp(std_wve_t) - 1.) ** 0.5
179 | 
180 |         return std_wve_t
181 |
182 | def compute_tau(self, std_wve_t):
183 | tau = -self.beta_0 + torch.sqrt(self.beta_0 ** 2 + 2. * (self.beta_1 - self.beta_0) * torch.log(1. + std_wve_t ** 2))
184 | tau /= self.beta_1 - self.beta_0
185 | return tau
186 |
187 | def marginal_prob(self, t):
188 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
189 | mean = torch.exp(log_mean_coeff)
190 | std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
191 | return mean, std
192 |
193 | def transform_unnormalized_wve_to_normalized_vp(self, t, std_out=False):
194 | tau = self.compute_tau(t)
195 | mean_vp_tau, std_vp_tau = self.marginal_prob(tau)
196 | if std_out:
197 | return mean_vp_tau, std_vp_tau, tau
198 | return mean_vp_tau, tau
199 |
200 | def compute_t_cos_from_t_lin(self, t_lin):
201 | sqrt_alpha_t_bar = torch.exp(-0.25 * t_lin ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t_lin * self.beta_0)
202 | time = torch.arccos(np.sqrt(self.f_0) * sqrt_alpha_t_bar)
203 | t_cos = self.T * ((1. + self.s) * 2. / np.pi * time - self.s)
204 | return t_cos
205 |
206 | def get_diffusion_time(self, batch_size, batch_device, t_min=1e-5, importance_sampling=True):
207 | if importance_sampling:
208 | Z = self.normalizing_constant(t_min)
209 | u = torch.rand(batch_size, device=batch_device)
210 | return (-self.beta_0 + torch.sqrt(self.beta_0 ** 2 + 2 * (self.beta_1 - self.beta_0) *
211 | torch.log(1. + torch.exp(Z * u + self.antiderivative(t_min))))) / (self.beta_1 - self.beta_0), Z.detach()
212 | else:
213 | return torch.rand(batch_size, device=batch_device) * (self.T - t_min) + t_min, 1
214 |
215 | def antiderivative(self, t, stabilizing_constant=0.):
216 | if isinstance(t, float) or isinstance(t, int):
217 | t = torch.tensor(t).float()
218 | return torch.log(1. - torch.exp(- self.integral_beta(t)) + stabilizing_constant) + self.integral_beta(t)
219 |
220 | def normalizing_constant(self, t_min):
221 | return self.antiderivative(self.T) - self.antiderivative(t_min)
222 |
223 | def integral_beta(self, t):
224 | return 0.5 * t ** 2 * (self.beta_1 - self.beta_0) + t * self.beta_0
225 |
226 | if __name__ == "__main__":
227 |     # Quick sanity check of the time conversions; guarded so that importing
228 |     # this module produces no side effects.
229 |     a = vpsde()
230 |     q = a.compute_reverse_tau(torch.tensor(0.))
231 |     print(q)
--------------------------------------------------------------------------------
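A round-trip sketch of the time conversions in vpsde (not part of the repository; assumes the repository root is on PYTHONPATH). transform_unnormalized_wve_to_normalized_vp() maps EDM-style WVE noise levels to VP diffusion times via compute_tau(), and compute_reverse_tau() inverts that mapping:

import torch
from classifier_lib import vpsde

sde = vpsde()
sigma = torch.tensor([0.5, 10.0, 80.0])  # WVE noise levels
mean_vp, tau = sde.transform_unnormalized_wve_to_normalized_vp(sigma)
print(tau)                           # VP times in (0, 1]
print(sde.compute_reverse_tau(tau))  # recovers sigma up to float error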
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for reporting and collecting training statistics across
9 | multiple processes and devices. The interface is designed to minimize
10 | synchronization overhead as well as the amount of boilerplate in user
11 | code."""
12 |
13 | import re
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 |
18 | from . import misc
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
24 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
25 | _rank = 0 # Rank of the current process.
26 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
27 | _sync_called = False # Has _sync() been called yet?
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
30 |
31 | #----------------------------------------------------------------------------
32 |
33 | def init_multiprocessing(rank, sync_device):
34 | r"""Initializes `torch_utils.training_stats` for collecting statistics
35 | across multiple processes.
36 |
37 | This function must be called after
38 | `torch.distributed.init_process_group()` and before `Collector.update()`.
39 | The call is not necessary if multi-process collection is not needed.
40 |
41 | Args:
42 | rank: Rank of the current process.
43 | sync_device: PyTorch device to use for inter-process
44 | communication, or None to disable multi-process
45 | collection. Typically `torch.device('cuda', rank)`.
46 | """
47 | global _rank, _sync_device
48 | assert not _sync_called
49 | _rank = rank
50 | _sync_device = sync_device
51 |
52 | #----------------------------------------------------------------------------
53 |
54 | @misc.profiled_function
55 | def report(name, value):
56 | r"""Broadcasts the given set of scalars to all interested instances of
57 | `Collector`, across device and process boundaries.
58 |
59 | This function is expected to be extremely cheap and can be safely
60 | called from anywhere in the training loop, loss function, or inside a
61 | `torch.nn.Module`.
62 |
63 | Warning: The current implementation expects the set of unique names to
64 | be consistent across processes. Please make sure that `report()` is
65 | called at least once for each unique name by each process, and in the
66 | same order. If a given process has no scalars to broadcast, it can do
67 | `report(name, [])` (empty list).
68 |
69 | Args:
70 | name: Arbitrary string specifying the name of the statistic.
71 | Averages are accumulated separately for each unique name.
72 | value: Arbitrary set of scalars. Can be a list, tuple,
73 | NumPy array, PyTorch tensor, or Python scalar.
74 |
75 | Returns:
76 | The same `value` that was passed in.
77 | """
78 | if name not in _counters:
79 | _counters[name] = dict()
80 |
81 | elems = torch.as_tensor(value)
82 | if elems.numel() == 0:
83 | return value
84 |
85 | elems = elems.detach().flatten().to(_reduce_dtype)
86 | moments = torch.stack([
87 | torch.ones_like(elems).sum(),
88 | elems.sum(),
89 | elems.square().sum(),
90 | ])
91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
92 | moments = moments.to(_counter_dtype)
93 |
94 | device = moments.device
95 | if device not in _counters[name]:
96 | _counters[name][device] = torch.zeros_like(moments)
97 | _counters[name][device].add_(moments)
98 | return value
99 |
100 | #----------------------------------------------------------------------------
101 |
102 | def report0(name, value):
103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
104 | but ignores any scalars provided by the other processes.
105 | See `report()` for further details.
106 | """
107 | report(name, value if _rank == 0 else [])
108 | return value
109 |
110 | #----------------------------------------------------------------------------
111 |
112 | class Collector:
113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
114 | computes their long-term averages (mean and standard deviation) over
115 | user-defined periods of time.
116 |
117 | The averages are first collected into internal counters that are not
118 | directly visible to the user. They are then copied to the user-visible
119 | state as a result of calling `update()` and can then be queried using
120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
121 | internal counters for the next round, so that the user-visible state
122 | effectively reflects averages collected between the last two calls to
123 | `update()`.
124 |
125 | Args:
126 | regex: Regular expression defining which statistics to
127 | collect. The default is to collect everything.
128 | keep_previous: Whether to retain the previous averages if no
129 | scalars were collected on a given round
130 | (default: True).
131 | """
132 | def __init__(self, regex='.*', keep_previous=True):
133 | self._regex = re.compile(regex)
134 | self._keep_previous = keep_previous
135 | self._cumulative = dict()
136 | self._moments = dict()
137 | self.update()
138 | self._moments.clear()
139 |
140 | def names(self):
141 | r"""Returns the names of all statistics broadcasted so far that
142 | match the regular expression specified at construction time.
143 | """
144 | return [name for name in _counters if self._regex.fullmatch(name)]
145 |
146 | def update(self):
147 | r"""Copies current values of the internal counters to the
148 | user-visible state and resets them for the next round.
149 |
150 | If `keep_previous=True` was specified at construction time, the
151 | operation is skipped for statistics that have received no scalars
152 | since the last update, retaining their previous averages.
153 |
154 | This method performs a number of GPU-to-CPU transfers and one
155 | `torch.distributed.all_reduce()`. It is intended to be called
156 | periodically in the main training loop, typically once every
157 | N training steps.
158 | """
159 | if not self._keep_previous:
160 | self._moments.clear()
161 | for name, cumulative in _sync(self.names()):
162 | if name not in self._cumulative:
163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
164 | delta = cumulative - self._cumulative[name]
165 | self._cumulative[name].copy_(cumulative)
166 | if float(delta[0]) != 0:
167 | self._moments[name] = delta
168 |
169 | def _get_delta(self, name):
170 | r"""Returns the raw moments that were accumulated for the given
171 | statistic between the last two calls to `update()`, or zero if
172 | no scalars were collected.
173 | """
174 | assert self._regex.fullmatch(name)
175 | if name not in self._moments:
176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
177 | return self._moments[name]
178 |
179 | def num(self, name):
180 | r"""Returns the number of scalars that were accumulated for the given
181 | statistic between the last two calls to `update()`, or zero if
182 | no scalars were collected.
183 | """
184 | delta = self._get_delta(name)
185 | return int(delta[0])
186 |
187 | def mean(self, name):
188 | r"""Returns the mean of the scalars that were accumulated for the
189 | given statistic between the last two calls to `update()`, or NaN if
190 | no scalars were collected.
191 | """
192 | delta = self._get_delta(name)
193 | if int(delta[0]) == 0:
194 | return float('nan')
195 | return float(delta[1] / delta[0])
196 |
197 | def std(self, name):
198 | r"""Returns the standard deviation of the scalars that were
199 | accumulated for the given statistic between the last two calls to
200 | `update()`, or NaN if no scalars were collected.
201 | """
202 | delta = self._get_delta(name)
203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
204 | return float('nan')
205 | if int(delta[0]) == 1:
206 | return float(0)
207 | mean = float(delta[1] / delta[0])
208 | raw_var = float(delta[2] / delta[0])
209 | return np.sqrt(max(raw_var - np.square(mean), 0))
210 |
211 | def as_dict(self):
212 | r"""Returns the averages accumulated between the last two calls to
213 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
214 |
215 | dnnlib.EasyDict(
216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
217 | ...
218 | )
219 | """
220 | stats = dnnlib.EasyDict()
221 | for name in self.names():
222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
223 | return stats
224 |
225 | def __getitem__(self, name):
226 | r"""Convenience getter.
227 | `collector[name]` is a synonym for `collector.mean(name)`.
228 | """
229 | return self.mean(name)
230 |
231 | #----------------------------------------------------------------------------
232 |
233 | def _sync(names):
234 | r"""Synchronize the global cumulative counters across devices and
235 | processes. Called internally by `Collector.update()`.
236 | """
237 | if len(names) == 0:
238 | return []
239 | global _sync_called
240 | _sync_called = True
241 |
242 | # Collect deltas within current rank.
243 | deltas = []
244 | device = _sync_device if _sync_device is not None else torch.device('cpu')
245 | for name in names:
246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
247 | for counter in _counters[name].values():
248 | delta.add_(counter.to(device))
249 | counter.copy_(torch.zeros_like(counter))
250 | deltas.append(delta)
251 | deltas = torch.stack(deltas)
252 |
253 | # Sum deltas across ranks.
254 | if _sync_device is not None:
255 | torch.distributed.all_reduce(deltas)
256 |
257 | # Update cumulative values.
258 | deltas = deltas.cpu()
259 | for idx, name in enumerate(names):
260 | if name not in _cumulative:
261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
262 | _cumulative[name].add_(deltas[idx])
263 |
264 | # Return name-value pairs.
265 | return [(name, _cumulative[name]) for name in names]
266 |
267 | #----------------------------------------------------------------------------
268 | # Convenience.
269 |
270 | default_collector = Collector()
271 |
272 | #----------------------------------------------------------------------------
273 |
--------------------------------------------------------------------------------
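A single-process usage sketch (not part of the repository): report() is cheap enough to call inside the training loop, and Collector.update() folds the accumulated moments into queryable averages once per reporting period:

from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')
for value in (0.8, 1.0, 1.2):
    training_stats.report('Loss/train', value)
collector.update()
print(collector.num('Loss/train'))   # 3
print(collector.mean('Loss/train'))  # 1.0
print(collector.std('Loss/train'))   # ~0.163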
/guided_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import torch as th
7 | import torch.distributed as dist
8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
9 | from torch.optim import AdamW
10 |
11 | from . import dist_util, logger
12 | from .fp16_util import MixedPrecisionTrainer
13 | from .nn import update_ema
14 | from .resample import LossAwareSampler, UniformSampler
15 |
16 | # For ImageNet experiments, this was a good default value.
17 | # We found that the lg_loss_scale quickly climbed to
18 | # 20-21 within the first ~1K steps of training.
19 | INITIAL_LOG_LOSS_SCALE = 20.0
20 |
21 |
22 | class TrainLoop:
23 | def __init__(
24 | self,
25 | *,
26 | model,
27 | diffusion,
28 | data,
29 | batch_size,
30 | microbatch,
31 | lr,
32 | ema_rate,
33 | log_interval,
34 | save_interval,
35 | resume_checkpoint,
36 | use_fp16=False,
37 | fp16_scale_growth=1e-3,
38 | schedule_sampler=None,
39 | weight_decay=0.0,
40 | lr_anneal_steps=0,
41 | ):
42 | self.model = model
43 | self.diffusion = diffusion
44 | self.data = data
45 | self.batch_size = batch_size
46 | self.microbatch = microbatch if microbatch > 0 else batch_size
47 | self.lr = lr
48 | self.ema_rate = (
49 | [ema_rate]
50 | if isinstance(ema_rate, float)
51 | else [float(x) for x in ema_rate.split(",")]
52 | )
53 | self.log_interval = log_interval
54 | self.save_interval = save_interval
55 | self.resume_checkpoint = resume_checkpoint
56 | self.use_fp16 = use_fp16
57 | self.fp16_scale_growth = fp16_scale_growth
58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
59 | self.weight_decay = weight_decay
60 | self.lr_anneal_steps = lr_anneal_steps
61 |
62 | self.step = 0
63 | self.resume_step = 0
64 | self.global_batch = self.batch_size * dist.get_world_size()
65 |
66 | self.sync_cuda = th.cuda.is_available()
67 |
68 | self._load_and_sync_parameters()
69 | self.mp_trainer = MixedPrecisionTrainer(
70 | model=self.model,
71 | use_fp16=self.use_fp16,
72 | fp16_scale_growth=fp16_scale_growth,
73 | )
74 |
75 | self.opt = AdamW(
76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
77 | )
78 | if self.resume_step:
79 | self._load_optimizer_state()
80 | # Model was resumed, either due to a restart or a checkpoint
81 | # being specified at the command line.
82 | self.ema_params = [
83 | self._load_ema_parameters(rate) for rate in self.ema_rate
84 | ]
85 | else:
86 | self.ema_params = [
87 | copy.deepcopy(self.mp_trainer.master_params)
88 | for _ in range(len(self.ema_rate))
89 | ]
90 |
91 | if th.cuda.is_available():
92 | self.use_ddp = True
93 | self.ddp_model = DDP(
94 | self.model,
95 | device_ids=[dist_util.dev()],
96 | output_device=dist_util.dev(),
97 | broadcast_buffers=False,
98 | bucket_cap_mb=128,
99 | find_unused_parameters=False,
100 | )
101 | else:
102 | if dist.get_world_size() > 1:
103 | logger.warn(
104 | "Distributed training requires CUDA. "
105 | "Gradients will not be synchronized properly!"
106 | )
107 | self.use_ddp = False
108 | self.ddp_model = self.model
109 |
110 | def _load_and_sync_parameters(self):
111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
112 |
113 | if resume_checkpoint:
114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
115 | if dist.get_rank() == 0:
116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
117 | self.model.load_state_dict(
118 | dist_util.load_state_dict(
119 | resume_checkpoint, map_location=dist_util.dev()
120 | )
121 | )
122 |
123 | dist_util.sync_params(self.model.parameters())
124 |
125 | def _load_ema_parameters(self, rate):
126 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
127 |
128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
130 | if ema_checkpoint:
131 | if dist.get_rank() == 0:
132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
133 | state_dict = dist_util.load_state_dict(
134 | ema_checkpoint, map_location=dist_util.dev()
135 | )
136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
137 |
138 | dist_util.sync_params(ema_params)
139 | return ema_params
140 |
141 | def _load_optimizer_state(self):
142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
143 | opt_checkpoint = bf.join(
144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
145 | )
146 | if bf.exists(opt_checkpoint):
147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
148 | state_dict = dist_util.load_state_dict(
149 | opt_checkpoint, map_location=dist_util.dev()
150 | )
151 | self.opt.load_state_dict(state_dict)
152 |
153 | def run_loop(self):
154 | while (
155 | not self.lr_anneal_steps
156 | or self.step + self.resume_step < self.lr_anneal_steps
157 | ):
158 | batch, cond = next(self.data)
159 | self.run_step(batch, cond)
160 | if self.step % self.log_interval == 0:
161 | logger.dumpkvs()
162 | if self.step % self.save_interval == 0:
163 | self.save()
164 | # Run for a finite amount of time in integration tests.
165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
166 | return
167 | self.step += 1
168 | # Save the last checkpoint if it wasn't already saved.
169 | if (self.step - 1) % self.save_interval != 0:
170 | self.save()
171 |
172 | def run_step(self, batch, cond):
173 | self.forward_backward(batch, cond)
174 | took_step = self.mp_trainer.optimize(self.opt)
175 | if took_step:
176 | self._update_ema()
177 | self._anneal_lr()
178 | self.log_step()
179 |
180 | def forward_backward(self, batch, cond):
181 | self.mp_trainer.zero_grad()
182 | for i in range(0, batch.shape[0], self.microbatch):
183 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
184 | micro_cond = {
185 | k: v[i : i + self.microbatch].to(dist_util.dev())
186 | for k, v in cond.items()
187 | }
188 | last_batch = (i + self.microbatch) >= batch.shape[0]
189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
190 |
191 | compute_losses = functools.partial(
192 | self.diffusion.training_losses,
193 | self.ddp_model,
194 | micro,
195 | t,
196 | model_kwargs=micro_cond,
197 | )
198 |
199 | if last_batch or not self.use_ddp:
200 | losses = compute_losses()
201 | else:
202 | with self.ddp_model.no_sync():
203 | losses = compute_losses()
204 |
205 | if isinstance(self.schedule_sampler, LossAwareSampler):
206 | self.schedule_sampler.update_with_local_losses(
207 | t, losses["loss"].detach()
208 | )
209 |
210 | loss = (losses["loss"] * weights).mean()
211 | log_loss_dict(
212 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
213 | )
214 | self.mp_trainer.backward(loss)
215 |
216 | def _update_ema(self):
217 | for rate, params in zip(self.ema_rate, self.ema_params):
218 | update_ema(params, self.mp_trainer.master_params, rate=rate)
219 |
220 | def _anneal_lr(self):
221 | if not self.lr_anneal_steps:
222 | return
223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
224 | lr = self.lr * (1 - frac_done)
225 | for param_group in self.opt.param_groups:
226 | param_group["lr"] = lr
227 |
228 | def log_step(self):
229 | logger.logkv("step", self.step + self.resume_step)
230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
231 |
232 | def save(self):
233 | def save_checkpoint(rate, params):
234 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
235 | if dist.get_rank() == 0:
236 | logger.log(f"saving model {rate}...")
237 | if not rate:
238 | filename = f"model{(self.step+self.resume_step):06d}.pt"
239 | else:
240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
242 | th.save(state_dict, f)
243 |
244 | save_checkpoint(0, self.mp_trainer.master_params)
245 | for rate, params in zip(self.ema_rate, self.ema_params):
246 | save_checkpoint(rate, params)
247 |
248 | if dist.get_rank() == 0:
249 | with bf.BlobFile(
250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
251 | "wb",
252 | ) as f:
253 | th.save(self.opt.state_dict(), f)
254 |
255 | dist.barrier()
256 |
257 |
258 | def parse_resume_step_from_filename(filename):
259 | """
260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
261 | checkpoint's number of steps.
262 | """
263 | split = filename.split("model")
264 | if len(split) < 2:
265 | return 0
266 | split1 = split[-1].split(".")[0]
267 | try:
268 | return int(split1)
269 | except ValueError:
270 | return 0
271 |
272 |
273 | def get_blob_logdir():
274 | # You can change this to be a separate path to save checkpoints to
275 | # a blobstore or some external drive.
276 | return logger.get_dir()
277 |
278 |
279 | def find_resume_checkpoint():
280 | # On your infrastructure, you may want to override this to automatically
281 | # discover the latest checkpoint on your blob storage, etc.
282 | return None
283 |
284 |
285 | def find_ema_checkpoint(main_checkpoint, step, rate):
286 | if main_checkpoint is None:
287 | return None
288 | filename = f"ema_{rate}_{(step):06d}.pt"
289 | path = bf.join(bf.dirname(main_checkpoint), filename)
290 | if bf.exists(path):
291 | return path
292 | return None
293 |
294 |
295 | def log_loss_dict(diffusion, ts, losses):
296 | for key, values in losses.items():
297 | logger.logkv_mean(key, values.mean().item())
298 | # Log the quantiles (four quartiles, in particular).
299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
300 | quartile = int(4 * sub_t / diffusion.num_timesteps)
301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
302 |
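
forward_backward() above splits each batch into microbatches and wraps all but
the last one in DDP's no_sync(), so gradients are all-reduced across ranks only
once per optimizer step. A minimal sketch of that accumulation pattern, with an
illustrative stand-in model and loss (not the repo's API):

    import torch

    def accumulate_grads(model, batch, microbatch, loss_fn):
        # Mirror forward_backward(): synchronize only on the final microbatch.
        is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
        for i in range(0, batch.shape[0], microbatch):
            micro = batch[i : i + microbatch]
            last = (i + microbatch) >= batch.shape[0]
            if last or not is_ddp:
                loss_fn(model(micro)).backward()
            else:
                with model.no_sync():  # skip the gradient all-reduce this round
                    loss_fn(model(micro)).backward()

    net = torch.nn.Linear(8, 1)
    accumulate_grads(net, torch.randn(6, 8), microbatch=2,
                     loss_fn=lambda y: y.pow(2).mean())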
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import re
9 | import contextlib
10 | import numpy as np
11 | import torch
12 | import warnings
13 | import dnnlib
14 |
15 | #----------------------------------------------------------------------------
16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17 | # same constant is used multiple times.
18 |
19 | _constant_cache = dict()
20 |
21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
22 | value = np.asarray(value)
23 | if shape is not None:
24 | shape = tuple(shape)
25 | if dtype is None:
26 | dtype = torch.get_default_dtype()
27 | if device is None:
28 | device = torch.device('cpu')
29 | if memory_format is None:
30 | memory_format = torch.contiguous_format
31 |
32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
33 | tensor = _constant_cache.get(key, None)
34 | if tensor is None:
35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
36 | if shape is not None:
37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
38 | tensor = tensor.contiguous(memory_format=memory_format)
39 | _constant_cache[key] = tensor
40 | return tensor
41 |
42 | #----------------------------------------------------------------------------
43 | # Replace NaN/Inf with specified numerical values.
44 |
45 | try:
46 | nan_to_num = torch.nan_to_num # 1.8.0a0
47 | except AttributeError:
48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
49 | assert isinstance(input, torch.Tensor)
50 | if posinf is None:
51 | posinf = torch.finfo(input.dtype).max
52 | if neginf is None:
53 | neginf = torch.finfo(input.dtype).min
54 | assert nan == 0
55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
56 |
57 | #----------------------------------------------------------------------------
58 | # Symbolic assert.
59 |
60 | try:
61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
62 | except AttributeError:
63 | symbolic_assert = torch.Assert # 1.7.0
64 |
65 | #----------------------------------------------------------------------------
66 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
68 |
69 | @contextlib.contextmanager
70 | def suppress_tracer_warnings():
71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
72 | warnings.filters.insert(0, flt)
73 | yield
74 | warnings.filters.remove(flt)
75 |
76 | #----------------------------------------------------------------------------
77 | # Assert that the shape of a tensor matches the given list of integers.
78 | # None indicates that the size of a dimension is allowed to vary.
79 | # Performs symbolic assertion when used in torch.jit.trace().
80 |
81 | def assert_shape(tensor, ref_shape):
82 | if tensor.ndim != len(ref_shape):
83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
85 | if ref_size is None:
86 | pass
87 | elif isinstance(ref_size, torch.Tensor):
88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
90 | elif isinstance(size, torch.Tensor):
91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
93 | elif size != ref_size:
94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95 |
96 | #----------------------------------------------------------------------------
97 | # Function decorator that calls torch.autograd.profiler.record_function().
98 |
99 | def profiled_function(fn):
100 | def decorator(*args, **kwargs):
101 | with torch.autograd.profiler.record_function(fn.__name__):
102 | return fn(*args, **kwargs)
103 | decorator.__name__ = fn.__name__
104 | return decorator
105 |
106 | #----------------------------------------------------------------------------
107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
108 | # indefinitely, shuffling items as it goes.
109 |
110 | class InfiniteSampler(torch.utils.data.Sampler):
111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
112 | assert len(dataset) > 0
113 | assert num_replicas > 0
114 | assert 0 <= rank < num_replicas
115 | assert 0 <= window_size <= 1
116 | super().__init__(dataset)
117 | self.dataset = dataset
118 | self.rank = rank
119 | self.num_replicas = num_replicas
120 | self.shuffle = shuffle
121 | self.seed = seed
122 | self.window_size = window_size
123 |
124 | def __iter__(self):
125 | order = np.arange(len(self.dataset))
126 | rnd = None
127 | window = 0
128 | if self.shuffle:
129 | rnd = np.random.RandomState(self.seed)
130 | rnd.shuffle(order)
131 | window = int(np.rint(order.size * self.window_size))
132 |
133 | idx = 0
134 | while True:
135 | i = idx % order.size
136 | if idx % self.num_replicas == self.rank:
137 | yield order[i]
138 | if window >= 2:
139 | j = (i - rnd.randint(window)) % order.size
140 | order[i], order[j] = order[j], order[i]
141 | idx += 1
142 |
143 | #----------------------------------------------------------------------------
144 | # Utilities for operating with torch.nn.Module parameters and buffers.
145 |
146 | def params_and_buffers(module):
147 | assert isinstance(module, torch.nn.Module)
148 | return list(module.parameters()) + list(module.buffers())
149 |
150 | def named_params_and_buffers(module):
151 | assert isinstance(module, torch.nn.Module)
152 | return list(module.named_parameters()) + list(module.named_buffers())
153 |
154 | @torch.no_grad()
155 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
156 | assert isinstance(src_module, torch.nn.Module)
157 | assert isinstance(dst_module, torch.nn.Module)
158 | src_tensors = dict(named_params_and_buffers(src_module))
159 | for name, tensor in named_params_and_buffers(dst_module):
160 | assert (name in src_tensors) or (not require_all)
161 | if name in src_tensors:
162 | tensor.copy_(src_tensors[name])
163 |
164 | #----------------------------------------------------------------------------
165 | # Context manager for easily enabling/disabling DistributedDataParallel
166 | # synchronization.
167 |
168 | @contextlib.contextmanager
169 | def ddp_sync(module, sync):
170 | assert isinstance(module, torch.nn.Module)
171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172 | yield
173 | else:
174 | with module.no_sync():
175 | yield
176 |
177 | #----------------------------------------------------------------------------
178 | # Check DistributedDataParallel consistency across processes.
179 |
180 | def check_ddp_consistency(module, ignore_regex=None):
181 | assert isinstance(module, torch.nn.Module)
182 | for name, tensor in named_params_and_buffers(module):
183 | fullname = type(module).__name__ + '.' + name
184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185 | continue
186 | tensor = tensor.detach()
187 | if tensor.is_floating_point():
188 | tensor = nan_to_num(tensor)
189 | other = tensor.clone()
190 | torch.distributed.broadcast(tensor=other, src=0)
191 | assert (tensor == other).all(), fullname
192 |
193 | #----------------------------------------------------------------------------
194 | # Print summary table of module hierarchy.
195 |
196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197 | assert isinstance(module, torch.nn.Module)
198 | assert not isinstance(module, torch.jit.ScriptModule)
199 | assert isinstance(inputs, (tuple, list))
200 |
201 | # Register hooks.
202 | entries = []
203 | nesting = [0]
204 | def pre_hook(_mod, _inputs):
205 | nesting[0] += 1
206 | def post_hook(mod, _inputs, outputs):
207 | nesting[0] -= 1
208 | if nesting[0] <= max_nesting:
209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214 |
215 | # Run module.
216 | outputs = module(*inputs)
217 | for hook in hooks:
218 | hook.remove()
219 |
220 | # Identify unique outputs, parameters, and buffers.
221 | tensors_seen = set()
222 | for e in entries:
223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227 |
228 | # Filter out redundant entries.
229 | if skip_redundant:
230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231 |
232 | # Construct table.
233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234 | rows += [['---'] * len(rows[0])]
235 | param_total = 0
236 | buffer_total = 0
237 | submodule_names = {mod: name for name, mod in module.named_modules()}
238 | for e in entries:
239 | name = '' if e.mod is module else submodule_names[e.mod]
240 | param_size = sum(t.numel() for t in e.unique_params)
241 | buffer_size = sum(t.numel() for t in e.unique_buffers)
242 | output_shapes = [str(list(t.shape)) for t in e.outputs]
243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244 | rows += [[
245 | name + (':0' if len(e.outputs) >= 2 else ''),
246 | str(param_size) if param_size else '-',
247 | str(buffer_size) if buffer_size else '-',
248 | (output_shapes + ['-'])[0],
249 | (output_dtypes + ['-'])[0],
250 | ]]
251 | for idx in range(1, len(e.outputs)):
252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253 | param_total += param_size
254 | buffer_total += buffer_size
255 | rows += [['---'] * len(rows[0])]
256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257 |
258 | # Print table.
259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260 | print()
261 | for row in rows:
262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263 | print()
264 | return outputs
265 |
266 | #----------------------------------------------------------------------------
267 |
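
A short usage sketch for two of the helpers above (sizes are illustrative):
InfiniteSampler turns a DataLoader into an endless shuffled stream, and
assert_shape() accepts None for dimensions that are allowed to vary.

    import torch
    from torch_utils import misc

    data = torch.utils.data.TensorDataset(torch.randn(10, 3))
    sampler = misc.InfiniteSampler(dataset=data, rank=0, num_replicas=1, seed=0)
    loader = iter(torch.utils.data.DataLoader(data, sampler=sampler, batch_size=4))
    (x,) = next(loader)              # never raises StopIteration; call next() forever
    misc.assert_shape(x, [None, 3])  # any batch size, feature dimension must be 3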
--------------------------------------------------------------------------------
/training/training_loop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Main training loop."""
9 |
10 | import os
11 | import time
12 | import copy
13 | import json
14 | import pickle
15 | import psutil
16 | import numpy as np
17 | import torch
18 | import dnnlib
19 | from torch_utils import distributed as dist
20 | from torch_utils import training_stats
21 | from torch_utils import misc
22 |
23 | import classifier_lib
24 | #----------------------------------------------------------------------------
25 |
26 | def training_loop(
27 | run_dir = '.', # Output directory.
28 | dataset_kwargs = {}, # Options for training set.
29 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
30 | network_kwargs = {}, # Options for model and preconditioning.
31 | loss_kwargs = {}, # Options for loss function.
32 | optimizer_kwargs = {}, # Options for optimizer.
33 | augment_kwargs = None, # Options for augmentation pipeline, None = disable.
34 | seed = 0, # Global random seed.
35 | batch_size = 512, # Total batch size for one training iteration.
36 | batch_gpu = None, # Limit batch size per GPU, None = no limit.
37 | total_kimg = 200000, # Training duration, measured in thousands of training images.
38 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights.
39 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup.
40 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration.
41 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows.
42 | kimg_per_tick = 50, # Interval of progress prints.
43 | snapshot_ticks = 50, # How often to save network snapshots, None = disable.
44 | state_dump_ticks = 500, # How often to dump training state, None = disable.
45 | resume_pkl = None, # Start from the given network snapshot, None = random initialization.
46 | resume_state_dump = None, # Start from the given training state, None = reset training state.
47 | resume_kimg = 0, # Start from the given training progress.
48 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
49 | device = torch.device('cuda'),
50 | cla_path = None,
51 | dis_path = None,
52 | ):
53 | # Initialize.
54 | start_time = time.time()
55 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
56 | torch.manual_seed(np.random.randint(1 << 31))
57 | torch.backends.cudnn.benchmark = cudnn_benchmark
58 | torch.backends.cudnn.allow_tf32 = False
59 | torch.backends.cuda.matmul.allow_tf32 = False
60 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
61 |
62 | # Select batch size per GPU.
63 | batch_gpu_total = batch_size // dist.get_world_size()
64 | if batch_gpu is None or batch_gpu > batch_gpu_total:
65 | batch_gpu = batch_gpu_total
66 | num_accumulation_rounds = batch_gpu_total // batch_gpu
67 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()
68 |
69 | # Load dataset.
70 | dist.print0('Loading dataset...')
71 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
72 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
73 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
74 |
75 | # Construct network.
76 | dist.print0('Constructing network...')
77 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
78 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
79 | net.train().requires_grad_(True).to(device)
80 | if dist.get_rank() == 0:
81 | with torch.no_grad():
82 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
83 | sigma = torch.ones([batch_gpu], device=device)
84 | labels = torch.zeros([batch_gpu, net.label_dim], device=device)
85 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)
86 |
87 | # Setup optimizer.
88 | dist.print0('Setting up optimizer...')
89 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
90 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
91 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
92 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
93 | ema = copy.deepcopy(net).eval().requires_grad_(False)
94 |
95 | # Load the pretrained classifier (feature extractor) and discriminator for the weighting function.
96 | cla = classifier_lib.load_classifier(cla_path, net.img_resolution, device, eval=True)
97 | dis = classifier_lib.load_discriminator(dis_path, device, False, eval=True, channel=512)
98 | discriminator = classifier_lib.get_discriminator(cla, dis, enable_grad=True)
99 | vpsde = classifier_lib.vpsde()
100 |
101 | # Resume training from previous snapshot.
102 | if resume_pkl is not None:
103 | dist.print0(f'Loading network weights from "{resume_pkl}"...')
104 | if dist.get_rank() != 0:
105 | torch.distributed.barrier() # rank 0 goes first
106 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
107 | data = pickle.load(f)
108 | if dist.get_rank() == 0:
109 | torch.distributed.barrier() # other ranks follow
110 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
111 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
112 | del data # conserve memory
113 | if resume_state_dump:
114 | dist.print0(f'Loading training state from "{resume_state_dump}"...')
115 | data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
116 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
117 | optimizer.load_state_dict(data['optimizer_state'])
118 | del data # conserve memory
119 |
120 | # Train.
121 | dist.print0(f'Training for {total_kimg} kimg...')
122 | dist.print0()
123 | cur_nimg = resume_kimg * 1000
124 | cur_tick = 0
125 | tick_start_nimg = cur_nimg
126 | tick_start_time = time.time()
127 | maintenance_time = tick_start_time - start_time
128 | dist.update_progress(cur_nimg // 1000, total_kimg)
129 | stats_jsonl = None
130 | while True:
131 |
132 | # Accumulate gradients.
133 | optimizer.zero_grad(set_to_none=True)
134 | for round_idx in range(num_accumulation_rounds):
135 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
136 | images, labels = next(dataset_iterator)
137 | images = images.to(device).to(torch.float32) / 127.5 - 1
138 | labels = labels.to(device)
139 | loss = loss_fn(vpsde=vpsde, dis=discriminator, net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
140 | training_stats.report('Loss/loss', loss)
141 | loss.sum().mul(loss_scaling / batch_gpu_total).backward()
142 |
143 | # Update weights.
144 | for g in optimizer.param_groups:
145 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
146 | for param in net.parameters():
147 | if param.grad is not None:
148 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
149 | optimizer.step()
150 |
151 | # Update EMA.
152 | ema_halflife_nimg = ema_halflife_kimg * 1000
153 | if ema_rampup_ratio is not None:
154 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
155 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
156 | for p_ema, p_net in zip(ema.parameters(), net.parameters()):
157 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
158 |
159 | # Perform maintenance tasks once per tick.
160 | cur_nimg += batch_size
161 | done = (cur_nimg >= total_kimg * 1000)
162 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
163 | continue
164 |
165 | # Print status line, accumulating the same information in training_stats.
166 | tick_end_time = time.time()
167 | fields = []
168 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
169 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
170 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
171 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
172 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
173 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
174 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
175 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
176 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
177 | torch.cuda.reset_peak_memory_stats()
178 | dist.print0(' '.join(fields))
179 |
180 | # Check for abort.
181 | if (not done) and dist.should_stop():
182 | done = True
183 | dist.print0()
184 | dist.print0('Aborting...')
185 |
186 | # Save network snapshot.
187 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
188 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
189 | for key, value in data.items():
190 | if isinstance(value, torch.nn.Module):
191 | value = copy.deepcopy(value).eval().requires_grad_(False)
192 | misc.check_ddp_consistency(value)
193 | data[key] = value.cpu()
194 | del value # conserve memory
195 | if dist.get_rank() == 0:
196 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
197 | pickle.dump(data, f)
198 | del data # conserve memory
199 |
200 | # Save full dump of the training state.
201 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
202 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
203 |
204 | # Update logs.
205 | training_stats.default_collector.update()
206 | if dist.get_rank() == 0:
207 | if stats_jsonl is None:
208 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at')
209 | stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n')
210 | stats_jsonl.flush()
211 | dist.update_progress(cur_nimg // 1000, total_kimg)
212 |
213 | # Update state.
214 | cur_tick += 1
215 | tick_start_nimg = cur_nimg
216 | tick_start_time = time.time()
217 | maintenance_time = tick_start_time - tick_end_time
218 | if done:
219 | break
220 |
221 | # Done.
222 | dist.print0()
223 | dist.print0('Exiting...')
224 |
225 | #----------------------------------------------------------------------------
226 |
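
The EMA update above turns a half-life measured in training images into a
per-step decay factor, ema_beta = 0.5 ** (batch_size / ema_halflife_nimg),
optionally shortened early in training by ema_rampup_ratio. A worked example
with the default settings:

    batch_size = 512
    ema_halflife_nimg = 500 * 1000          # ema_halflife_kimg = 500
    ema_beta = 0.5 ** (batch_size / ema_halflife_nimg)
    print(ema_beta)                         # ~0.999291 per optimizer step
    steps = ema_halflife_nimg / batch_size  # optimizer steps per 500 kimg
    print(ema_beta ** steps)                # 0.5: old weights decay by half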
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Train diffusion-based generative model using the techniques described in the
9 | paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import os
12 | import re
13 | import json
14 | import click
15 | import torch
16 | import dnnlib
17 | from torch_utils import distributed as dist
18 | from training import training_loop
19 |
20 | import warnings
21 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
22 |
23 | #----------------------------------------------------------------------------
24 | # Parse a comma separated list of numbers or ranges and return a list of ints.
25 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
26 |
27 | def parse_int_list(s):
28 | if isinstance(s, list): return s
29 | ranges = []
30 | range_re = re.compile(r'^(\d+)-(\d+)$')
31 | for p in s.split(','):
32 | m = range_re.match(p)
33 | if m:
34 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
35 | else:
36 | ranges.append(int(p))
37 | return ranges
38 |
39 | #----------------------------------------------------------------------------
40 |
41 | @click.command()
42 |
43 | # Main options.
44 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True)
45 | @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
46 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
47 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True)
48 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm|tiw_edm', type=click.Choice(['vp', 've', 'edm', 'tiw_edm']), default='tiw_edm', show_default=True)
49 |
50 | # Hyperparameters.
51 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
52 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
53 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
54 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
55 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
56 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
57 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
58 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
59 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
60 | @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
61 |
62 | # Performance-related.
63 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
64 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
65 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
66 | @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
67 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
68 |
69 | # I/O-related.
70 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
71 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
72 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
73 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
74 | @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
75 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
76 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
77 | @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
78 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
79 |
80 | ## TIW config
81 | @click.option('--cla_path', help='Path to the pretrained classifier (feature extractor) checkpoint', metavar='PT', default="/home/aailab/alsdudrla10/FirstArticle/Github/TIW-DSM/checkpoints/discriminator/feature_extractor/32x32_classifier.pt", type=str)
82 | @click.option('--dis_path', help='Path to the pretrained discriminator checkpoint', metavar='PT', default="/home/aailab/alsdudrla10/FirstArticle/Github/TIW-DSM/checkpoints/discriminator/cifar10/unbias_500/discriminator_501.pt", type=str)
83 |
84 | def main(**kwargs):
85 | """Train diffusion-based generative model using the techniques described in the
86 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".
87 |
88 | Examples:
89 |
90 | \b
91 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
92 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \\
93 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
94 | """
95 | opts = dnnlib.EasyDict(kwargs)
96 | torch.multiprocessing.set_start_method('spawn')
97 | dist.init()
98 |
99 | # Initialize config dict.
100 | c = dnnlib.EasyDict()
101 | c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
102 | c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
103 | c.network_kwargs = dnnlib.EasyDict()
104 | c.loss_kwargs = dnnlib.EasyDict()
105 | c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
106 |
107 | # Checkpoint paths for the weighting function (classifier + discriminator).
108 | c.cla_path = opts.cla_path
109 | c.dis_path = opts.dis_path
110 |
111 |
112 | # Validate dataset options.
113 | try:
114 | dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
115 | dataset_name = dataset_obj.name
116 | c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
117 | c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
118 | if opts.cond and not dataset_obj.has_labels:
119 | raise click.ClickException('--cond=True requires labels specified in dataset.json')
120 | del dataset_obj # conserve memory
121 | except IOError as err:
122 | raise click.ClickException(f'--data: {err}')
123 |
124 | # Network architecture.
125 | if opts.arch == 'ddpmpp':
126 | c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
127 | c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2])
128 | elif opts.arch == 'ncsnpp':
129 | c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
130 | c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2])
131 | else:
132 | assert opts.arch == 'adm'
133 | c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
134 |
135 | # Preconditioning & loss function.
136 | if opts.precond == 'vp':
137 | c.network_kwargs.class_name = 'training.networks.VPPrecond'
138 | c.loss_kwargs.class_name = 'training.loss.VPLoss'
139 | elif opts.precond == 've':
140 | c.network_kwargs.class_name = 'training.networks.VEPrecond'
141 | c.loss_kwargs.class_name = 'training.loss.VELoss'
142 | elif opts.precond == 'edm':
143 | c.network_kwargs.class_name = 'training.networks.EDMPrecond'
144 | c.loss_kwargs.class_name = 'training.loss.EDMLoss'
145 | elif opts.precond =='tiw_edm':
146 | c.network_kwargs.class_name = 'training.networks.EDMPrecond'
147 | c.loss_kwargs.class_name = 'training.loss.TIW_EDMLoss'
148 | else:
149 | assert 0
150 |
151 | # Network options.
152 | if opts.cbase is not None:
153 | c.network_kwargs.model_channels = opts.cbase
154 | if opts.cres is not None:
155 | c.network_kwargs.channel_mult = opts.cres
156 | if opts.augment:
157 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
158 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
159 | c.network_kwargs.augment_dim = 9
160 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
161 |
162 | # Training options.
163 | c.total_kimg = max(int(opts.duration * 1000), 1)
164 | c.ema_halflife_kimg = int(opts.ema * 1000)
165 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
166 | c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
167 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)
168 |
169 | # Random seed.
170 | if opts.seed is not None:
171 | c.seed = opts.seed
172 | else:
173 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
174 | torch.distributed.broadcast(seed, src=0)
175 | c.seed = int(seed)
176 |
177 | # Transfer learning and resume.
178 | if opts.transfer is not None:
179 | if opts.resume is not None:
180 | raise click.ClickException('--transfer and --resume cannot be specified at the same time')
181 | c.resume_pkl = opts.transfer
182 | c.ema_rampup_ratio = None
183 | elif opts.resume is not None:
184 | match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
185 | if not match or not os.path.isfile(opts.resume):
186 | raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
187 | c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
188 | c.resume_kimg = int(match.group(1))
189 | c.resume_state_dump = opts.resume
190 |
191 | # Description string.
192 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
193 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
194 | desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}'
195 | if opts.desc is not None:
196 | desc += f'-{opts.desc}'
197 |
198 | # Pick output directory.
199 | if dist.get_rank() != 0:
200 | c.run_dir = None
201 | elif opts.nosubdir:
202 | c.run_dir = opts.outdir
203 | else:
204 | prev_run_dirs = []
205 | if os.path.isdir(opts.outdir):
206 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
207 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
208 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
209 | cur_run_id = max(prev_run_ids, default=-1) + 1
210 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
211 | assert not os.path.exists(c.run_dir)
212 |
213 | # Print options.
214 | dist.print0()
215 | dist.print0('Training options:')
216 | dist.print0(json.dumps(c, indent=2))
217 | dist.print0()
218 | dist.print0(f'Output directory: {c.run_dir}')
219 | dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
220 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
221 | dist.print0(f'Network architecture: {opts.arch}')
222 | dist.print0(f'Preconditioning & loss: {opts.precond}')
223 | dist.print0(f'Number of GPUs: {dist.get_world_size()}')
224 | dist.print0(f'Batch size: {c.batch_size}')
225 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
226 | dist.print0()
227 |
228 | # Dry run?
229 | if opts.dry_run:
230 | dist.print0('Dry run; exiting.')
231 | return
232 |
233 | # Create output directory.
234 | dist.print0('Creating output directory...')
235 | if dist.get_rank() == 0:
236 | os.makedirs(c.run_dir, exist_ok=True)
237 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
238 | json.dump(c, f, indent=2)
239 | dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
240 |
241 | # Train.
242 | training_loop.training_loop(**c)
243 |
244 | #----------------------------------------------------------------------------
245 |
246 | if __name__ == "__main__":
247 | main()
248 |
249 | #----------------------------------------------------------------------------
250 |
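
A short sketch of how --resume is resolved above (the path is illustrative):
the kimg count embedded in the training-state filename sets the resume point
and selects the matching network snapshot in the same directory.

    import os, re

    resume = 'training-runs/00000-example/training-state-002500.pt'
    match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(resume))
    resume_kimg = int(match.group(1))       # 2500
    resume_pkl = os.path.join(os.path.dirname(resume),
                              f'network-snapshot-{match.group(1)}.pkl')
    print(resume_kimg, resume_pkl)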
--------------------------------------------------------------------------------
/guided_diffusion/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "write"), (
43 | "expected file or str, got %s" % filename_or_file
44 | )
45 | self.file = filename_or_file
46 | self.own_file = False
47 |
48 | def writekvs(self, kvs):
49 | # Create strings for printing
50 | key2str = {}
51 | for (key, val) in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append(
71 | "| %s%s | %s%s |"
72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 | )
74 | lines.append(dashes)
75 | self.file.write("\n".join(lines) + "\n")
76 |
77 | # Flush the output to the file
78 | self.file.flush()
79 |
80 | def _truncate(self, s):
81 | maxlen = 30
82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | def make_output_format(format, ev_dir, log_suffix=""):
192 | os.makedirs(ev_dir, exist_ok=True)
193 | if format == "stdout":
194 | return HumanOutputFormat(sys.stdout)
195 | elif format == "log":
196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197 | elif format == "json":
198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199 | elif format == "csv":
200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201 | elif format == "tensorboard":
202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203 | else:
204 | raise ValueError("Unknown format specified: %s" % (format,))
205 |
206 |
207 | # ================================================================
208 | # API
209 | # ================================================================
210 |
211 |
212 | def logkv(key, val):
213 | """
214 |     Log a value of some diagnostic.
215 |     Call this once for each diagnostic quantity, each iteration.
216 |     If called multiple times within one iteration, the last value is used.
217 | """
218 | get_current().logkv(key, val)
219 |
220 |
221 | def logkv_mean(key, val):
222 | """
223 |     The same as logkv(), but if called many times, the values are averaged.
224 | """
225 | get_current().logkv_mean(key, val)
226 |
227 |
228 | def logkvs(d):
229 | """
230 | Log a dictionary of key-value pairs
231 | """
232 | for (k, v) in d.items():
233 | logkv(k, v)
234 |
235 |
236 | def dumpkvs():
237 | """
238 | Write all of the diagnostics from the current iteration
239 | """
240 | return get_current().dumpkvs()
241 |
242 |
243 | def getkvs():
244 | return get_current().name2val
245 |
246 |
247 | def log(*args, level=INFO):
248 | """
249 |     Write the sequence of args, separated by spaces, to the console and output files (if you've configured an output file).
250 | """
251 | get_current().log(*args, level=level)
252 |
253 |
254 | def debug(*args):
255 | log(*args, level=DEBUG)
256 |
257 |
258 | def info(*args):
259 | log(*args, level=INFO)
260 |
261 |
262 | def warn(*args):
263 | log(*args, level=WARN)
264 |
265 |
266 | def error(*args):
267 | log(*args, level=ERROR)
268 |
269 |
270 | def set_level(level):
271 | """
272 | Set logging threshold on current logger.
273 | """
274 | get_current().set_level(level)
275 |
276 |
277 | def set_comm(comm):
278 | get_current().set_comm(comm)
279 |
280 |
281 | def get_dir():
282 | """
283 |     Get the directory that log files are being written to.
284 |     Will be None if there is no output directory (i.e., if configure() was never called).
285 | """
286 | return get_current().get_dir()
287 |
288 |
289 | record_tabular = logkv
290 | dump_tabular = dumpkvs
291 |
292 |
293 | @contextmanager
294 | def profile_kv(scopename):
295 | logkey = "wait_" + scopename
296 | tstart = time.time()
297 | try:
298 | yield
299 | finally:
300 | get_current().name2val[logkey] += time.time() - tstart
301 |
302 |
303 | def profile(n):
304 | """
305 | Usage:
306 | @profile("my_func")
307 | def my_func(): code
308 | """
309 |
310 | def decorator_with_name(func):
311 | def func_wrapper(*args, **kwargs):
312 | with profile_kv(n):
313 | return func(*args, **kwargs)
314 |
315 | return func_wrapper
316 |
317 | return decorator_with_name
318 |
319 |
320 | # ================================================================
321 | # Backend
322 | # ================================================================
323 |
324 |
325 | def get_current():
326 | if Logger.CURRENT is None:
327 | _configure_default_logger()
328 |
329 | return Logger.CURRENT
330 |
331 |
332 | class Logger(object):
333 | DEFAULT = None # A logger with no output files. (See right below class definition)
334 | # So that you can still log to the terminal without setting up any output files
335 | CURRENT = None # Current logger being used by the free functions above
336 |
337 | def __init__(self, dir, output_formats, comm=None):
338 | self.name2val = defaultdict(float) # values this iteration
339 | self.name2cnt = defaultdict(int)
340 | self.level = INFO
341 | self.dir = dir
342 | self.output_formats = output_formats
343 | self.comm = comm
344 |
345 | # Logging API, forwarded
346 | # ----------------------------------------
347 | def logkv(self, key, val):
348 | self.name2val[key] = val
349 |
350 | def logkv_mean(self, key, val):
351 | oldval, cnt = self.name2val[key], self.name2cnt[key]
352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353 | self.name2cnt[key] = cnt + 1
354 |
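    # (Editorial note) The update above is the standard incremental mean: after n
    # values, mean_n = mean_{n-1} * (n-1)/n + x_n / n, so repeated logkv_mean()
    # calls average everything seen since the last dump without storing the
    # history; dumpkvs() clears both name2val and name2cnt.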
355 | def dumpkvs(self):
356 | if self.comm is None:
357 | d = self.name2val
358 | else:
359 | d = mpi_weighted_mean(
360 | self.comm,
361 | {
362 | name: (val, self.name2cnt.get(name, 1))
363 | for (name, val) in self.name2val.items()
364 | },
365 | )
366 | if self.comm.rank != 0:
367 | d["dummy"] = 1 # so we don't get a warning about empty dict
368 | out = d.copy() # Return the dict for unit testing purposes
369 | for fmt in self.output_formats:
370 | if isinstance(fmt, KVWriter):
371 | fmt.writekvs(d)
372 | self.name2val.clear()
373 | self.name2cnt.clear()
374 | return out
375 |
376 | def log(self, *args, level=INFO):
377 | if self.level <= level:
378 | self._do_log(args)
379 |
380 | # Configuration
381 | # ----------------------------------------
382 | def set_level(self, level):
383 | self.level = level
384 |
385 | def set_comm(self, comm):
386 | self.comm = comm
387 |
388 | def get_dir(self):
389 | return self.dir
390 |
391 | def close(self):
392 | for fmt in self.output_formats:
393 | fmt.close()
394 |
395 | # Misc
396 | # ----------------------------------------
397 | def _do_log(self, args):
398 | for fmt in self.output_formats:
399 | if isinstance(fmt, SeqWriter):
400 | fmt.writeseq(map(str, args))
401 |
402 |
403 | def get_rank_without_mpi_import():
404 | # check environment variables here instead of importing mpi4py
405 | # to avoid calling MPI_Init() when this module is imported
406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407 | if varname in os.environ:
408 | return int(os.environ[varname])
409 | return 0
410 |
411 |
412 | def mpi_weighted_mean(comm, local_name2valcount):
413 | """
414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415 | Perform a weighted average over dicts that are each on a different node
416 | Input: local_name2valcount: dict mapping key -> (value, count)
417 | Returns: key -> mean
418 | """
419 | all_name2valcount = comm.gather(local_name2valcount)
420 | if comm.rank == 0:
421 | name2sum = defaultdict(float)
422 | name2count = defaultdict(float)
423 | for n2vc in all_name2valcount:
424 | for (name, (val, count)) in n2vc.items():
425 | try:
426 | val = float(val)
427 | except ValueError:
428 | if comm.rank == 0:
429 | warnings.warn(
430 | "WARNING: tried to compute mean on non-float {}={}".format(
431 | name, val
432 | )
433 | )
434 | else:
435 | name2sum[name] += val * count
436 | name2count[name] += count
437 | return {name: name2sum[name] / name2count[name] for name in name2sum}
438 | else:
439 | return {}
440 |
441 |
442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443 | """
444 | If comm is provided, average all numerical stats across that comm
445 | """
446 | if dir is None:
447 | dir = os.getenv("OPENAI_LOGDIR")
448 | if dir is None:
449 | dir = osp.join(
450 | tempfile.gettempdir(),
451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452 | )
453 | assert isinstance(dir, str)
454 | dir = os.path.expanduser(dir)
455 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
456 |
457 | rank = get_rank_without_mpi_import()
458 | if rank > 0:
459 | log_suffix = log_suffix + "-rank%03i" % rank
460 |
461 | if format_strs is None:
462 | if rank == 0:
463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464 | else:
465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466 | format_strs = filter(None, format_strs)
467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468 |
469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470 | if output_formats:
471 | log("Logging to %s" % dir)
472 |
473 |
474 | def _configure_default_logger():
475 | configure()
476 | Logger.DEFAULT = Logger.CURRENT
477 |
478 |
479 | def reset():
480 | if Logger.CURRENT is not Logger.DEFAULT:
481 | Logger.CURRENT.close()
482 | Logger.CURRENT = Logger.DEFAULT
483 | log("Reset logger")
484 |
485 |
486 | @contextmanager
487 | def scoped_configure(dir=None, format_strs=None, comm=None):
488 | prevlogger = Logger.CURRENT
489 | configure(dir=dir, format_strs=format_strs, comm=comm)
490 | try:
491 | yield
492 | finally:
493 | Logger.CURRENT.close()
494 | Logger.CURRENT = prevlogger
495 |
496 |
--------------------------------------------------------------------------------
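(Editorial note) A minimal usage sketch for the logger module above; it assumes the module is importable as guided_diffusion.logger, which the package layout supports:

from guided_diffusion import logger

logger.configure(dir='logs', format_strs=['stdout', 'csv'])
for step in range(100):
    logger.logkv('step', step)
    logger.logkv_mean('loss', 1.0 / (step + 1))  # averaged between dumps
    if step % 10 == 0:
        logger.dumpkvs()  # prints a table to stdout and appends a row to logs/progress.csv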
/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Generate random images using the techniques described in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import os
12 | import re
13 | import click
14 | import tqdm
15 | import pickle
16 | import numpy as np
17 | import torch
18 | import PIL.Image
19 | import dnnlib
20 | from torch_utils import distributed as dist
21 |
22 | #----------------------------------------------------------------------------
23 | # Proposed EDM sampler (Algorithm 2).
24 |
25 | def edm_sampler(
26 | net, latents, class_labels=None, randn_like=torch.randn_like,
27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
29 | ):
30 | # Adjust noise levels based on what's supported by the network.
31 | sigma_min = max(sigma_min, net.sigma_min)
32 | sigma_max = min(sigma_max, net.sigma_max)
33 |
34 | # Time step discretization.
35 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
36 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
37 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
38 |
39 | # Main sampling loop.
40 | x_next = latents.to(torch.float64) * t_steps[0]
41 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
42 | x_cur = x_next
43 |
44 | # Increase noise temporarily.
45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
46 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
48 |
49 | # Euler step.
50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
51 | d_cur = (x_hat - denoised) / t_hat
52 | x_next = x_hat + (t_next - t_hat) * d_cur
53 |
54 | # Apply 2nd order correction.
55 | if i < num_steps - 1:
56 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
57 | d_prime = (x_next - denoised) / t_next
58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
59 |
60 | return x_next
61 |
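# (Editorial note) The time-step grid above is the rho-warped schedule from the
# EDM paper:
#     sigma_i = (sigma_max^(1/rho) + i/(num_steps-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
# for i = 0, ..., num_steps-1, with sigma_N = 0 appended as the final step. The
# default rho = 7 concentrates steps at low noise levels, where solver accuracy
# matters most; rho = 1 would space the noise levels uniformly.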
62 | #----------------------------------------------------------------------------
63 | # Generalized ablation sampler, representing the superset of all sampling
64 | # methods discussed in the paper.
65 |
66 | def ablation_sampler(
67 | net, latents, class_labels=None, randn_like=torch.randn_like,
68 | num_steps=18, sigma_min=None, sigma_max=None, rho=7,
69 | solver='heun', discretization='edm', schedule='linear', scaling='none',
70 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
71 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
72 | ):
73 | assert solver in ['euler', 'heun']
74 | assert discretization in ['vp', 've', 'iddpm', 'edm']
75 | assert schedule in ['vp', 've', 'linear']
76 | assert scaling in ['vp', 'none']
77 |
78 | # Helper functions for VP & VE noise level schedules.
79 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
80 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
81 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
82 | ve_sigma = lambda t: t.sqrt()
83 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
84 | ve_sigma_inv = lambda sigma: sigma ** 2
85 |
86 | # Select default noise level range based on the specified time step discretization.
87 | if sigma_min is None:
88 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
89 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
90 | if sigma_max is None:
91 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
92 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
93 |
94 | # Adjust noise levels based on what's supported by the network.
95 | sigma_min = max(sigma_min, net.sigma_min)
96 | sigma_max = min(sigma_max, net.sigma_max)
97 |
98 | # Compute corresponding betas for VP.
99 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
100 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
101 |
102 | # Define time steps in terms of noise level.
103 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
104 | if discretization == 'vp':
105 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
106 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
107 | elif discretization == 've':
108 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
109 | sigma_steps = ve_sigma(orig_t_steps)
110 | elif discretization == 'iddpm':
111 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
112 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
113 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
114 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
115 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
116 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
117 | else:
118 | assert discretization == 'edm'
119 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
120 |
121 | # Define noise level schedule.
122 | if schedule == 'vp':
123 | sigma = vp_sigma(vp_beta_d, vp_beta_min)
124 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
125 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
126 | elif schedule == 've':
127 | sigma = ve_sigma
128 | sigma_deriv = ve_sigma_deriv
129 | sigma_inv = ve_sigma_inv
130 | else:
131 | assert schedule == 'linear'
132 | sigma = lambda t: t
133 | sigma_deriv = lambda t: 1
134 | sigma_inv = lambda sigma: sigma
135 |
136 | # Define scaling schedule.
137 | if scaling == 'vp':
138 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
139 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
140 | else:
141 | assert scaling == 'none'
142 | s = lambda t: 1
143 | s_deriv = lambda t: 0
144 |
145 | # Compute final time steps based on the corresponding noise levels.
146 | t_steps = sigma_inv(net.round_sigma(sigma_steps))
147 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
148 |
149 | # Main sampling loop.
150 | t_next = t_steps[0]
151 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
152 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
153 | x_cur = x_next
154 |
155 | # Increase noise temporarily.
156 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
157 | t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
158 | x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)
159 |
160 | # Euler step.
161 | h = t_next - t_hat
162 | denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
163 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
164 | x_prime = x_hat + alpha * h * d_cur
165 | t_prime = t_hat + alpha * h
166 |
167 |         # Apply 2nd order correction (skipped for the Euler solver and on the last step).
168 | if solver == 'euler' or i == num_steps - 1:
169 | x_next = x_hat + h * d_cur
170 | else:
171 | assert solver == 'heun'
172 | denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
173 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
174 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
175 |
176 | return x_next
177 |
178 | #----------------------------------------------------------------------------
179 | # Wrapper for torch.Generator that allows specifying a different random seed
180 | # for each sample in a minibatch.
181 |
182 | class StackedRandomGenerator:
183 | def __init__(self, device, seeds):
184 | super().__init__()
185 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
186 |
187 | def randn(self, size, **kwargs):
188 | assert size[0] == len(self.generators)
189 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
190 |
191 | def randn_like(self, input):
192 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
193 |
194 | def randint(self, *args, size, **kwargs):
195 | assert size[0] == len(self.generators)
196 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
197 |
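# (Editorial sketch) Because each sample owns its own torch.Generator, a latent
# depends only on its seed, never on the batch composition:
#     rnd = StackedRandomGenerator(torch.device('cuda'), seeds=[7, 42])
#     z = rnd.randn([2, 3, 32, 32], device=torch.device('cuda'))
# Regenerating seed 42 alone, or in any other batch, reproduces the same row of z.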
198 | #----------------------------------------------------------------------------
199 | # Parse a comma separated list of numbers or ranges and return a list of ints.
200 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
201 |
202 | def parse_int_list(s):
203 | if isinstance(s, list): return s
204 | ranges = []
205 | range_re = re.compile(r'^(\d+)-(\d+)$')
206 | for p in s.split(','):
207 | m = range_re.match(p)
208 | if m:
209 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
210 | else:
211 | ranges.append(int(p))
212 | return ranges
213 |
214 | #----------------------------------------------------------------------------
215 |
216 | @click.command()
217 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
218 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True)
219 | @click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True)
220 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
221 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
222 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
223 |
224 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
225 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
226 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
227 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
228 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
229 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
230 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
231 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
232 |
233 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun']))
234 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
235 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear']))
236 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none']))
237 |
238 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs):
239 | """Generate random images using the techniques described in the paper
240 | "Elucidating the Design Space of Diffusion-Based Generative Models".
241 |
242 | Examples:
243 |
244 | \b
245 | # Generate 64 images and save them as out/*.png
246 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\
247 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
248 |
249 | \b
250 | # Generate 1024 images using 2 GPUs
251 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\
252 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
253 | """
254 | dist.init()
255 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
256 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
257 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
258 |
259 | # Rank 0 goes first.
260 | if dist.get_rank() != 0:
261 | torch.distributed.barrier()
262 |
263 | # Load network.
264 | dist.print0(f'Loading network from "{network_pkl}"...')
265 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
266 | net = pickle.load(f)['ema'].to(device)
267 |
268 | # Other ranks follow.
269 | if dist.get_rank() == 0:
270 | torch.distributed.barrier()
271 |
272 | # Loop over batches.
273 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
274 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
275 | torch.distributed.barrier()
276 | batch_size = len(batch_seeds)
277 | if batch_size == 0:
278 | continue
279 |
280 | # Pick latents and labels.
281 | rnd = StackedRandomGenerator(device, batch_seeds)
282 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
283 | class_labels = None
284 | if net.label_dim:
285 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
286 | if class_idx is not None:
287 | class_labels[:, :] = 0
288 | class_labels[:, class_idx] = 1
289 |
290 | # Generate images.
291 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
292 | have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
293 | sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
294 | images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)
295 |
296 | # Save images.
297 | images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
298 | for seed, image_np in zip(batch_seeds, images_np):
299 | image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir
300 | os.makedirs(image_dir, exist_ok=True)
301 | image_path = os.path.join(image_dir, f'{seed:06d}.png')
302 | if image_np.shape[2] == 1:
303 | PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
304 | else:
305 | PIL.Image.fromarray(image_np, 'RGB').save(image_path)
306 |
307 | # Done.
308 | torch.distributed.barrier()
309 | dist.print0('Done.')
310 |
311 | #----------------------------------------------------------------------------
312 |
313 | if __name__ == "__main__":
314 | main()
315 |
316 | #----------------------------------------------------------------------------
317 |
--------------------------------------------------------------------------------
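(Editorial note) Putting the pieces together, a minimal single-process sketch of what main() does per batch; this skips the torch.distributed setup, assumes a CUDA device, and reuses the pretrained-network URL from the CLI examples in the docstring above:

import pickle
import torch
import dnnlib
from generate import edm_sampler, StackedRandomGenerator

url = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'
device = torch.device('cuda')
with dnnlib.util.open_url(url) as f:
    net = pickle.load(f)['ema'].to(device)

seeds = [0, 1, 2, 3]
rnd = StackedRandomGenerator(device, seeds)
latents = rnd.randn([len(seeds), net.img_channels, net.img_resolution, net.img_resolution], device=device)
class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[len(seeds)], device=device)] if net.label_dim else None
images = edm_sampler(net, latents, class_labels, randn_like=rnd.randn_like)  # values roughly in [-1, 1]
images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()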
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import tempfile
27 | import urllib
28 | import urllib.request
29 | import uuid
30 |
31 | from distutils.util import strtobool
32 | from typing import Any, List, Tuple, Union, Optional
33 |
34 |
35 | # Util classes
36 | # ------------------------------------------------------------------------------------------
37 |
38 |
39 | class EasyDict(dict):
40 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41 |
42 | def __getattr__(self, name: str) -> Any:
43 | try:
44 | return self[name]
45 | except KeyError:
46 | raise AttributeError(name)
47 |
48 | def __setattr__(self, name: str, value: Any) -> None:
49 | self[name] = value
50 |
51 | def __delattr__(self, name: str) -> None:
52 | del self[name]
53 |
54 |
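# (Editorial sketch) EasyDict is what the training scripts build their config
# objects from (the `c` passed around in train.py above), so attribute access
# and dict access are interchangeable:
#     c = EasyDict(batch_size=128)
#     c.network_kwargs = EasyDict(use_fp16=True)
#     assert c['batch_size'] == c.batch_size  # plain-dict behavior is preserved,
#     json.dumps(c)                           # e.g. for training_options.json.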
55 | class Logger(object):
56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57 |
58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
59 | self.file = None
60 |
61 | if file_name is not None:
62 | self.file = open(file_name, file_mode)
63 |
64 | self.should_flush = should_flush
65 | self.stdout = sys.stdout
66 | self.stderr = sys.stderr
67 |
68 | sys.stdout = self
69 | sys.stderr = self
70 |
71 | def __enter__(self) -> "Logger":
72 | return self
73 |
74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
75 | self.close()
76 |
77 | def write(self, text: Union[str, bytes]) -> None:
78 | """Write text to stdout (and a file) and optionally flush."""
79 | if isinstance(text, bytes):
80 | text = text.decode()
81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
82 | return
83 |
84 | if self.file is not None:
85 | self.file.write(text)
86 |
87 | self.stdout.write(text)
88 |
89 | if self.should_flush:
90 | self.flush()
91 |
92 | def flush(self) -> None:
93 | """Flush written text to both stdout and a file, if open."""
94 | if self.file is not None:
95 | self.file.flush()
96 |
97 | self.stdout.flush()
98 |
99 | def close(self) -> None:
100 | """Flush, close possible files, and remove stdout/stderr mirroring."""
101 | self.flush()
102 |
103 | # if using multiple loggers, prevent closing in wrong order
104 | if sys.stdout is self:
105 | sys.stdout = self.stdout
106 | if sys.stderr is self:
107 | sys.stderr = self.stderr
108 |
109 | if self.file is not None:
110 | self.file.close()
111 | self.file = None
112 |
113 |
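# (Editorial sketch) train.py instantiates this Logger on rank 0 to tee console
# output into the run directory:
#     with Logger(file_name='log.txt', file_mode='a', should_flush=True):
#         print('goes to both the console and log.txt')
# close() restores the original sys.stdout / sys.stderr.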
114 | # Cache directories
115 | # ------------------------------------------------------------------------------------------
116 |
117 | _dnnlib_cache_dir = None
118 |
119 | def set_cache_dir(path: str) -> None:
120 | global _dnnlib_cache_dir
121 | _dnnlib_cache_dir = path
122 |
123 | def make_cache_dir_path(*paths: str) -> str:
124 | if _dnnlib_cache_dir is not None:
125 | return os.path.join(_dnnlib_cache_dir, *paths)
126 | if 'DNNLIB_CACHE_DIR' in os.environ:
127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
128 | if 'HOME' in os.environ:
129 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
130 | if 'USERPROFILE' in os.environ:
131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
133 |
134 | # Small util functions
135 | # ------------------------------------------------------------------------------------------
136 |
137 |
138 | def format_time(seconds: Union[int, float]) -> str:
139 |     """Convert the seconds to a human-readable string with days, hours, minutes and seconds."""
140 | s = int(np.rint(seconds))
141 |
142 | if s < 60:
143 | return "{0}s".format(s)
144 | elif s < 60 * 60:
145 | return "{0}m {1:02}s".format(s // 60, s % 60)
146 | elif s < 24 * 60 * 60:
147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
148 | else:
149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
150 |
151 |
152 | def format_time_brief(seconds: Union[int, float]) -> str:
153 |     """Like format_time(), but keep only the two most significant units."""
154 | s = int(np.rint(seconds))
155 |
156 | if s < 60:
157 | return "{0}s".format(s)
158 | elif s < 60 * 60:
159 | return "{0}m {1:02}s".format(s // 60, s % 60)
160 | elif s < 24 * 60 * 60:
161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
162 | else:
163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
164 |
165 |
166 | def ask_yes_no(question: str) -> bool:
167 | """Ask the user the question until the user inputs a valid answer."""
168 | while True:
169 | try:
170 | print("{0} [y/n]".format(question))
171 | return strtobool(input().lower())
172 | except ValueError:
173 | pass
174 |
175 |
176 | def tuple_product(t: Tuple) -> Any:
177 | """Calculate the product of the tuple elements."""
178 | result = 1
179 |
180 | for v in t:
181 | result *= v
182 |
183 | return result
184 |
185 |
186 | _str_to_ctype = {
187 | "uint8": ctypes.c_ubyte,
188 | "uint16": ctypes.c_uint16,
189 | "uint32": ctypes.c_uint32,
190 | "uint64": ctypes.c_uint64,
191 | "int8": ctypes.c_byte,
192 | "int16": ctypes.c_int16,
193 | "int32": ctypes.c_int32,
194 | "int64": ctypes.c_int64,
195 | "float32": ctypes.c_float,
196 | "float64": ctypes.c_double
197 | }
198 |
199 |
200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
202 | type_str = None
203 |
204 | if isinstance(type_obj, str):
205 | type_str = type_obj
206 | elif hasattr(type_obj, "__name__"):
207 | type_str = type_obj.__name__
208 | elif hasattr(type_obj, "name"):
209 | type_str = type_obj.name
210 | else:
211 | raise RuntimeError("Cannot infer type name from input")
212 |
213 | assert type_str in _str_to_ctype.keys()
214 |
215 | my_dtype = np.dtype(type_str)
216 | my_ctype = _str_to_ctype[type_str]
217 |
218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
219 |
220 | return my_dtype, my_ctype
221 |
222 |
223 | def is_pickleable(obj: Any) -> bool:
224 | try:
225 | with io.BytesIO() as stream:
226 | pickle.dump(obj, stream)
227 | return True
228 | except:
229 | return False
230 |
231 |
232 | # Functionality to import modules/objects by name, and call functions by name
233 | # ------------------------------------------------------------------------------------------
234 |
235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
236 |     """Searches for the underlying module behind the name of some Python object.
237 | Returns the module and the object name (original name with module part removed)."""
238 |
239 | # allow convenience shorthands, substitute them by full names
240 | obj_name = re.sub("^np.", "numpy.", obj_name)
241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name)
242 |
243 | # list alternatives for (module_name, local_obj_name)
244 | parts = obj_name.split(".")
245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
246 |
247 | # try each alternative in turn
248 | for module_name, local_obj_name in name_pairs:
249 | try:
250 | module = importlib.import_module(module_name) # may raise ImportError
251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
252 | return module, local_obj_name
253 | except:
254 | pass
255 |
256 | # maybe some of the modules themselves contain errors?
257 | for module_name, _local_obj_name in name_pairs:
258 | try:
259 | importlib.import_module(module_name) # may raise ImportError
260 | except ImportError:
261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
262 | raise
263 |
264 | # maybe the requested attribute is missing?
265 | for module_name, local_obj_name in name_pairs:
266 | try:
267 | module = importlib.import_module(module_name) # may raise ImportError
268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
269 | except ImportError:
270 | pass
271 |
272 | # we are out of luck, but we have no idea why
273 | raise ImportError(obj_name)
274 |
275 |
276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
277 | """Traverses the object name and returns the last (rightmost) python object."""
278 | if obj_name == '':
279 | return module
280 | obj = module
281 | for part in obj_name.split("."):
282 | obj = getattr(obj, part)
283 | return obj
284 |
285 |
286 | def get_obj_by_name(name: str) -> Any:
287 | """Finds the python object with the given name."""
288 | module, obj_name = get_module_from_obj_name(name)
289 | return get_obj_from_module(module, obj_name)
290 |
291 |
292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
293 | """Finds the python object with the given name and calls it as a function."""
294 | assert func_name is not None
295 | func_obj = get_obj_by_name(func_name)
296 | assert callable(func_obj)
297 | return func_obj(*args, **kwargs)
298 |
299 |
300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
301 | """Finds the python class with the given name and constructs it with the given arguments."""
302 | return call_func_by_name(*args, func_name=class_name, **kwargs)
303 |
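# (Editorial sketch) These name-based helpers let configs reference classes as
# plain strings plus keyword arguments, the pattern the training configs use.
# A hypothetical, self-contained example:
#     obj = construct_class_by_name(class_name='collections.OrderedDict', a=1)
# is equivalent to importing collections and calling OrderedDict(a=1).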
304 |
305 | def get_module_dir_by_obj_name(obj_name: str) -> str:
306 | """Get the directory path of the module containing the given object name."""
307 | module, _ = get_module_from_obj_name(obj_name)
308 | return os.path.dirname(inspect.getfile(module))
309 |
310 |
311 | def is_top_level_function(obj: Any) -> bool:
312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
314 |
315 |
316 | def get_top_level_function_name(obj: Any) -> str:
317 | """Return the fully-qualified name of a top-level function."""
318 | assert is_top_level_function(obj)
319 | module = obj.__module__
320 | if module == '__main__':
321 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
322 | return module + "." + obj.__name__
323 |
324 |
325 | # File system helpers
326 | # ------------------------------------------------------------------------------------------
327 |
328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
329 | """List all files recursively in a given directory while ignoring given file and directory names.
330 | Returns list of tuples containing both absolute and relative paths."""
331 | assert os.path.isdir(dir_path)
332 | base_name = os.path.basename(os.path.normpath(dir_path))
333 |
334 | if ignores is None:
335 | ignores = []
336 |
337 | result = []
338 |
339 | for root, dirs, files in os.walk(dir_path, topdown=True):
340 | for ignore_ in ignores:
341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
342 |
343 | # dirs need to be edited in-place
344 | for d in dirs_to_remove:
345 | dirs.remove(d)
346 |
347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
348 |
349 | absolute_paths = [os.path.join(root, f) for f in files]
350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
351 |
352 | if add_base_to_relative:
353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
354 |
355 | assert len(absolute_paths) == len(relative_paths)
356 | result += zip(absolute_paths, relative_paths)
357 |
358 | return result
359 |
360 |
361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
362 | """Takes in a list of tuples of (src, dst) paths and copies files.
363 | Will create all necessary directories."""
364 | for file in files:
365 | target_dir_name = os.path.dirname(file[1])
366 |
367 | # will create all intermediate-level directories
368 | if not os.path.exists(target_dir_name):
369 | os.makedirs(target_dir_name)
370 |
371 | shutil.copyfile(file[0], file[1])
372 |
373 |
374 | # URL helpers
375 | # ------------------------------------------------------------------------------------------
376 |
377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
378 | """Determine whether the given object is a valid URL string."""
379 | if not isinstance(obj, str) or not "://" in obj:
380 | return False
381 | if allow_file_urls and obj.startswith('file://'):
382 | return True
383 | try:
384 | res = requests.compat.urlparse(obj)
385 | if not res.scheme or not res.netloc or not "." in res.netloc:
386 | return False
387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
388 | if not res.scheme or not res.netloc or not "." in res.netloc:
389 | return False
390 | except:
391 | return False
392 | return True
393 |
394 |
395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
396 | """Download the given URL and return a binary-mode file object to access the data."""
397 | assert num_attempts >= 1
398 | assert not (return_filename and (not cache))
399 |
400 |     # Doesn't look like a URL scheme so interpret it as a local filename.
401 | if not re.match('^[a-z]+://', url):
402 | return url if return_filename else open(url, "rb")
403 |
404 | # Handle file URLs. This code handles unusual file:// patterns that
405 | # arise on Windows:
406 | #
407 | # file:///c:/foo.txt
408 | #
409 | # which would translate to a local '/c:/foo.txt' filename that's
410 | # invalid. Drop the forward slash for such pathnames.
411 | #
412 | # If you touch this code path, you should test it on both Linux and
413 | # Windows.
414 | #
415 |     # Some internet resources suggest using urllib.request.url2pathname(),
416 |     # but that converts forward slashes to backslashes and this causes
417 | # its own set of problems.
418 | if url.startswith('file://'):
419 | filename = urllib.parse.urlparse(url).path
420 | if re.match(r'^/[a-zA-Z]:', filename):
421 | filename = filename[1:]
422 | return filename if return_filename else open(filename, "rb")
423 |
424 | assert is_url(url)
425 |
426 | # Lookup from cache.
427 | if cache_dir is None:
428 | cache_dir = make_cache_dir_path('downloads')
429 |
430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
431 | if cache:
432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
433 | if len(cache_files) == 1:
434 | filename = cache_files[0]
435 | return filename if return_filename else open(filename, "rb")
436 |
437 | # Download.
438 | url_name = None
439 | url_data = None
440 | with requests.Session() as session:
441 | if verbose:
442 | print("Downloading %s ..." % url, end="", flush=True)
443 | for attempts_left in reversed(range(num_attempts)):
444 | try:
445 | with session.get(url) as res:
446 | res.raise_for_status()
447 | if len(res.content) == 0:
448 | raise IOError("No data received")
449 |
450 | if len(res.content) < 8192:
451 | content_str = res.content.decode("utf-8")
452 | if "download_warning" in res.headers.get("Set-Cookie", ""):
453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
454 | if len(links) == 1:
455 | url = requests.compat.urljoin(url, links[0])
456 | raise IOError("Google Drive virus checker nag")
457 | if "Google Drive - Quota exceeded" in content_str:
458 | raise IOError("Google Drive download quota exceeded -- please try again later")
459 |
460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
461 | url_name = match[1] if match else url
462 | url_data = res.content
463 | if verbose:
464 | print(" done")
465 | break
466 | except KeyboardInterrupt:
467 | raise
468 | except:
469 | if not attempts_left:
470 | if verbose:
471 | print(" failed")
472 | raise
473 | if verbose:
474 | print(".", end="", flush=True)
475 |
476 | # Save to cache.
477 | if cache:
478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
479 | safe_name = safe_name[:min(len(safe_name), 128)]
480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482 | os.makedirs(cache_dir, exist_ok=True)
483 | with open(temp_file, "wb") as f:
484 | f.write(url_data)
485 | os.replace(temp_file, cache_file) # atomic
486 | if return_filename:
487 | return cache_file
488 |
489 | # Return data as file object.
490 | assert not return_filename
491 | return io.BytesIO(url_data)
492 |
--------------------------------------------------------------------------------
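(Editorial note) A small sketch of how open_url() behaves for the input kinds it accepts; the local path is hypothetical, and the URL is the pretrained-model link from generate.py's examples:

import dnnlib

# Local path (hypothetical file): returned or opened directly, no caching involved.
with dnnlib.util.open_url('datasets/cifar10.zip') as f:
    head = f.read(4)

# http(s) URL: downloaded with retries, then cached under
# ~/.cache/dnnlib/downloads/<md5(url)>_<name>; later calls reuse the cached file.
url = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'
path = dnnlib.util.open_url(url, return_filename=True)  # cache=True is the default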