├── torch_utils
├── __init__.py
├── distributed.py
├── persistence.py
├── training_stats.py
└── misc.py
├── training
├── __init__.py
├── loss.py
├── dataset.py
├── training_loop.py
├── augment.py
└── networks.py
├── dnnlib
├── __init__.py
└── util.py
├── Dockerfile
├── fid.py
├── README.md
├── generate.py
├── LICENSE
├── train.py
└── dataset_tool.py
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Garena Online Private Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
FROM nvcr.io/nvidia/pytorch:22.08-py3

# Container timezone; /etc/localtime and /etc/timezone are both set from it.
ENV TZ=Asia/Singapore
RUN ln -snf "/usr/share/zoneinfo/$TZ" /etc/localtime && echo "$TZ" > /etc/timezone

# Python runtime flags. Written in the `ENV key=value` form; the legacy
# whitespace-separated `ENV key value` form is discouraged by Docker.
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Extra Python packages on top of the base image; --no-cache-dir keeps the
# pip download cache out of the image layer.
RUN pip install --no-cache-dir imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0

WORKDIR /workspace

# Minimal entrypoint script that simply execs whatever command is passed in.
RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
ENTRYPOINT ["/entry.sh"]
28 |
--------------------------------------------------------------------------------
/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
def init():
    """Initialize torch.distributed from environment variables.

    Any rendezvous variable that is not already set gets a single-process
    default. The default process group is then created, the process is bound
    to the CUDA device given by LOCAL_RANK, and `training_stats` is told this
    process's rank and which device (if any) to use for synchronization.
    """
    single_process_defaults = {
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '29500',
        'RANK': '0',
        'LOCAL_RANK': '0',
        'WORLD_SIZE': '1',
    }
    for key, value in single_process_defaults.items():
        os.environ.setdefault(key, value)

    # gloo on Windows (os.name == 'nt'), nccl everywhere else.
    if os.name == 'nt':
        backend = 'gloo'
    else:
        backend = 'nccl'
    torch.distributed.init_process_group(backend=backend, init_method='env://')
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    torch.cuda.set_device(local_rank)

    # A sync device is only passed when there is more than one process.
    sync_device = None
    if get_world_size() > 1:
        sync_device = torch.device('cuda')
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
def get_rank():
    """Return this process's rank, or 0 when torch.distributed is not initialized."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
37 |
38 | #----------------------------------------------------------------------------
39 |
def get_world_size():
    """Return the number of processes, or 1 when torch.distributed is not initialized."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
42 |
43 | #----------------------------------------------------------------------------
44 |
def should_stop():
    """Report whether the caller should stop; this implementation never asks for a stop."""
    stop_requested = False
    return stop_requested
47 |
48 | #----------------------------------------------------------------------------
49 |
def update_progress(cur, total):
    """Progress hook that accepts the current and total counts and ignores both."""
    del cur, total  # Intentionally unused.
52 |
53 | #----------------------------------------------------------------------------
54 |
def print0(*args, **kwargs):
    """Forward all arguments to `print`, but only on the rank-0 process."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Garena Online Private Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Loss functions used in the paper "Fast Diffusion Model"."""
16 |
17 | import torch
18 | from torch_utils import persistence
19 | from torch_utils import distributed as dist
20 | import numpy as np
21 |
22 | #----------------------------------------------------------------------------
23 | # VP loss function with FDM loss weight warmup strategy.
24 |
@persistence.persistent_class
class VPLoss:
    """VP loss with an optional warmup of the loss weight.

    When `warmup_ite` is truthy, the per-sample weight is capped at
    `clamp_cur`. The cap starts at 5 and is multiplied by `warmup_step` on
    every call until it is no longer below `clamp_max` (500). `warmup_step`
    equals 100 ** (1 / warmup_ite), so the cap grows 100x over `warmup_ite`
    calls.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5, warmup_ite=None):
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.epsilon_t = epsilon_t
        self.warmup_ite = warmup_ite
        # Current and final upper bounds applied to the loss weight.
        self.clamp_cur = 5.
        self.clamp_max = 500.
        if self.warmup_ite:
            self.warmup_step = np.exp(np.log(100) / self.warmup_ite)

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the elementwise weighted squared error between net output and clean images."""
        batch_size = images.shape[0]
        u = torch.rand([batch_size, 1, 1, 1], device=images.device)
        # The time argument equals 1 at u = 0 and approaches epsilon_t as u -> 1.
        sigma = self.sigma(1 + u * (self.epsilon_t - 1))
        weight = 1 / sigma ** 2
        if self.warmup_ite and self.clamp_cur < self.clamp_max:
            weight.clamp_max_(self.clamp_cur)
            self.clamp_cur *= self.warmup_step
        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)
        noise = torch.randn_like(y) * sigma
        denoised = net(y + noise, sigma, labels, augment_labels=augment_labels)
        return weight * ((denoised - y) ** 2)

    def sigma(self, t):
        """Return sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1) as a tensor."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()
54 |
55 | #----------------------------------------------------------------------------
56 | # VE loss function with FDM loss weight warmup strategy.
57 |
@persistence.persistent_class
class VELoss:
    """VE loss with an optional warmup of the loss weight.

    When `warmup_ite` is truthy, the per-sample weight is capped at
    `clamp_cur`. The cap starts at 5 and is multiplied by `warmup_step` on
    every call until it is no longer below `clamp_max` (500). `warmup_step`
    equals 100 ** (1 / warmup_ite).
    """

    def __init__(self, sigma_min=0.02, sigma_max=100, warmup_ite=None):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.warmup_ite = warmup_ite
        # Current and final upper bounds applied to the loss weight.
        self.clamp_cur = 5.
        self.clamp_max = 500.
        if self.warmup_ite:
            self.warmup_step = np.exp(np.log(100) / self.warmup_ite)

    def __call__(self, net, images, labels, augment_pipe=None):
        """Return the elementwise weighted squared error between net output and clean images."""
        batch_size = images.shape[0]
        u = torch.rand([batch_size, 1, 1, 1], device=images.device)
        # Geometric interpolation between sigma_min (u = 0) and sigma_max (u -> 1).
        sigma_ratio = self.sigma_max / self.sigma_min
        sigma = self.sigma_min * (sigma_ratio ** u)
        weight = 1 / sigma ** 2
        if self.warmup_ite and self.clamp_cur < self.clamp_max:
            weight.clamp_max_(self.clamp_cur)
            self.clamp_cur *= self.warmup_step
        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)
        noise = torch.randn_like(y) * sigma
        denoised = net(y + noise, sigma, labels, augment_labels=augment_labels)
        return weight * ((denoised - y) ** 2)
82 |
83 | #----------------------------------------------------------------------------
84 | # EDM loss function with our FDM weight warmup strategy.
85 |
@persistence.persistent_class
class EDMLoss:
    """EDM loss with an optional warmup of the loss weight.

    When `warmup_ite` is truthy, the per-sample weight is capped at
    `clamp_cur`. The cap starts at 5 and is multiplied by `warmup_step` on
    every call until it is no longer below `clamp_max` (500). `warmup_step`
    equals 100 ** (1 / warmup_ite).
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, warmup_ite=None):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.warmup_ite = warmup_ite
        # Current and final upper bounds applied to the loss weight.
        self.clamp_cur = 5.
        self.clamp_max = 500.
        if self.warmup_ite:
            self.warmup_step = np.exp(np.log(100) / self.warmup_ite)

    def __call__(self, net, images, labels=None, augment_pipe=None):
        """Return the elementwise weighted squared error between net output and clean images."""
        batch_size = images.shape[0]
        z = torch.randn([batch_size, 1, 1, 1], device=images.device)
        # sigma = exp(P_mean + P_std * z) with z drawn from a standard normal.
        sigma = (z * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        if self.warmup_ite and self.clamp_cur < self.clamp_max:
            weight.clamp_max_(self.clamp_cur)
            self.clamp_cur *= self.warmup_step
        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)
        noise = torch.randn_like(y) * sigma
        denoised = net(y + noise, sigma, labels, augment_labels=augment_labels)
        return weight * ((denoised - y) ** 2)
111 |
--------------------------------------------------------------------------------
/fid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Script for calculating Frechet Inception Distance (FID)."""
9 |
10 | import os
11 | import click
12 | import tqdm
13 | import pickle
14 | import numpy as np
15 | import scipy.linalg
16 | import torch
17 | import dnnlib
18 | from torch_utils import distributed as dist
19 | from training import dataset
20 |
21 | #----------------------------------------------------------------------------
22 |
def calculate_inception_stats(
    image_path, num_expected=None, seed=0, max_batch_size=64,
    num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
):
    """Compute the mean and covariance of Inception-v3 features over a set of images.

    The images are split into batches that are distributed round-robin over
    all ranks; per-rank sums are combined with `all_reduce`.

    Returns:
        Tuple `(mu, sigma)` of NumPy arrays: the feature mean and the feature
        covariance (normalized by N - 1).

    Raises:
        click.ClickException: If fewer than `num_expected` images (when given)
            or fewer than 2 images are found.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_rank0 = (rank == 0)

    # Rank 0 goes first.
    if not is_rank0:
        torch.distributed.barrier()

    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    dist.print0('Loading Inception-v3 model...')
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    with dnnlib.util.open_url(detector_url, verbose=is_rank0) as f:
        # NOTE(review): unpickling runs code from the downloaded file; the URL must remain trusted.
        detector_net = pickle.load(f).to(device)

    # List images.
    dist.print0(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    num_images = len(dataset_obj)
    if num_expected is not None and num_images < num_expected:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
    if num_images < 2:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')

    # Other ranks follow.
    if is_rank0:
        torch.distributed.barrier()

    # Divide images into batches. The batch count is a multiple of the world
    # size, so every rank iterates the same number of batches and the barrier
    # inside the loop below stays in lockstep.
    batches_per_rank = (num_images - 1) // (max_batch_size * world_size) + 1
    num_batches = batches_per_rank * world_size
    all_batches = torch.arange(num_images).tensor_split(num_batches)
    rank_batches = all_batches[rank::world_size]
    data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)

    # Accumulate statistics.
    dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(not is_rank0)):
        torch.distributed.barrier()
        if images.shape[0] == 0:
            continue  # tensor_split may yield empty batches.
        if images.shape[1] == 1:
            # Replicate a single channel three times.
            images = images.repeat([1, 3, 1, 1])
        features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
        mu += features.sum(0)
        sigma += features.T @ features

    # Calculate grand totals.
    torch.distributed.all_reduce(mu)
    torch.distributed.all_reduce(sigma)
    mu /= num_images
    sigma -= mu.ger(mu) * num_images
    sigma /= num_images - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()
79 |
80 | #----------------------------------------------------------------------------
81 |
def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Return the FID between the statistics (mu, sigma) and (mu_ref, sigma_ref).

    Computed as sum((mu - mu_ref)**2) + trace(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref)).
    """
    mean_term = np.square(mu - mu_ref).sum()
    cov_sqrt, _err = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    trace_term = np.trace(sigma + sigma_ref - cov_sqrt * 2)
    # Only the real part of the result is reported.
    return float(np.real(mean_term + trace_term))
87 |
88 | #----------------------------------------------------------------------------
89 |
# Top-level click command group; the `calc` and `ref` subcommands below are
# registered on it via `@main.command()`.
# NOTE: the docstring is runtime text (click shows it as the help message), so
# its wording and escape sequences are intentionally left untouched.
@click.group()
def main():
    """Calculate Frechet Inception Distance (FID).

    Examples:

    \b
    # Generate 50000 images and save them as fid-tmp/*/*.png
    torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

    \b
    # Calculate FID
    torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
        --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz

    \b
    # Compute dataset reference statistics
    python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
    """
110 |
111 | #----------------------------------------------------------------------------
112 |
@main.command()
@click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
@click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True)
@click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
def calc(image_path, ref_path, num_expected, seed, batch):
    """Calculate FID for a given set of images."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()
    # Must be queried after dist.init(); before that every process reports rank 0.
    is_rank0 = dist.get_rank() == 0

    # Only rank 0 loads the reference statistics, since only rank 0 computes the FID.
    dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
    ref_stats = None
    if is_rank0:
        with dnnlib.util.open_url(ref_path) as f:
            ref_stats = dict(np.load(f))

    mu, sigma = calculate_inception_stats(
        image_path=image_path,
        num_expected=num_expected,
        seed=seed,
        max_batch_size=batch,
    )
    dist.print0('Calculating FID...')
    if is_rank0:
        fid = calculate_fid_from_inception_stats(mu, sigma, ref_stats['mu'], ref_stats['sigma'])
        print(f'{fid:g}')
    # Keep the other ranks alive until rank 0 has finished.
    torch.distributed.barrier()
137 |
138 | #----------------------------------------------------------------------------
139 |
@main.command()
@click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
@click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
def ref(dataset_path, dest_path, batch):
    """Calculate dataset reference statistics needed by 'calc'."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
    dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
    # Only rank 0 writes the output file.
    if dist.get_rank() == 0:
        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            # A bare file name has no directory component to create.
            os.makedirs(dest_dir, exist_ok=True)
        np.savez(dest_path, mu=mu, sigma=sigma)

    # Wait for rank 0 to finish writing before reporting completion.
    torch.distributed.barrier()
    dist.print0('Done.')
159 |
160 | #----------------------------------------------------------------------------
161 |
162 | if __name__ == "__main__":
163 | main()
164 |
165 | #----------------------------------------------------------------------------
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fast Diffusion Model
2 |
3 | This is an official PyTorch implementation of the Fast Diffusion Model (FDM). See the paper [here](https://arxiv.org/abs/2306.06991). If you find FDM helpful or inspiring for your projects, please cite the paper and star this repository. Thanks!
4 |
5 | ```bibtex
6 | @misc{wu2023fast,
7 | title={Fast Diffusion Model},
8 | author={Zike Wu and Pan Zhou and Kenji Kawaguchi and Hanwang Zhang},
9 | year={2023},
10 | eprint={2306.06991},
11 | archivePrefix={arXiv},
12 | primaryClass={cs.CV}
13 | }
14 | ```
15 |
16 | Acknowledgement: This repo is based on the following amazing projects: [EDM](https://github.com/NVlabs/edm) and [DPM-Solver](https://github.com/LuChengTHU/dpm-solver).
17 | ## Results
18 | Image synthesis performance (FID) after different training durations, measured in millions of training images (Mimg), is as follows.
19 | | Dataset | Duration (Mimg) | EDM | EDM-FDM | VP | VP-FDM | VE | VE-FDM |
20 | |:--------:|:---------------:|:----:|:-------:|:----:|:------:|:-----:|:------:|
21 | | CIFAR10 | 50 | 5.76 | 2.17 | 2.74 | 2.74 | 49.47 | 10.01 |
22 | | CIFAR10 | 100 | 1.99 | 1.93 | 2.24 | 2.24 | 4.05 | 3.26 |
23 | | CIFAR10 | 150 | 1.92 | 1.83 | 2.19 | 2.13 | 3.27 | 3.00 |
24 | | CIFAR10 | 200 | 1.88 | 1.79 | 2.15 | 2.08 | 3.09 | 2.85 |
25 | | FFHQ | 50 | 3.21 | 3.27 | 3.07 | 12.49 | 96.49 | 93.72 |
26 | | FFHQ | 100 | 2.87 | 2.69 | 2.83 | 2.80 | 94.14 | 88.42 |
27 | | FFHQ | 150 | 2.69 | 2.63 | 2.73 | 2.53 | 79.20 | 4.73 |
28 | | FFHQ | 200 | 2.65 | 2.59 | 2.69 | 2.43 | 38.97 | 3.04 |
29 | | AFHQv2 | 50 | 2.62 | 2.73 | 3.46 | 25.70 | 57.93 | 54.41 |
30 | | AFHQv2 | 100 | 2.57 | 2.05 | 2.81 | 2.65 | 57.87 | 52.45 |
31 | | AFHQv2 | 150 | 2.44 | 1.96 | 2.72 | 2.47 | 57.69 | 50.53 |
32 | | AFHQv2 | 200 | 2.37 | 1.93 | 2.61 | 2.39 | 57.48 | 47.30 |
33 |
34 | Image synthesis performance (FID) under different inference cost on AFHQv2 with EDM sampler.
35 | | NFE | EDM | EDM-FDM | VP | VP-FDM | VE | VE-FDM |
36 | |:---:|:----:|:-------:|:----:|:------:|:-----:|:------:|
37 | | 25 | 2.78 | 2.32 | 2.88 | 2.59 | 61.04 | 48.29 |
38 | | 49 | 2.39 | 1.93 | 2.64 | 2.41 | 57.59 | 47.49 |
39 | | 79 | 2.37 | 1.93 | 2.61 | 2.39 | 57.48 | 47.30 |
40 |
41 | Image synthesis performance (FID) under different inference cost on AFHQv2 with DPM-Solver++.
42 | | NFE | EDM | EDM-FDM | VP | VP-FDM | VE | VE-FDM |
43 | |:---:|:----:|:-------:|:----:|:------:|:-----:|:------:|
44 | | 25 | 2.60 | 2.09 | 2.99 | 2.64 | 59.26 | 49.51 |
45 | | 49 | 2.42 | 1.98 | 2.79 | 2.45 | 59.16 | 48.68 |
46 | | 79 | 2.39 | 1.95 | 2.78 | 2.42 | 58.91 | 48.66 |
47 |
48 | ## Requirements
49 | All experiments were conducted using PyTorch 1.13.0, CUDA 11.7.1, and CuDNN 8.5.0. We strongly recommend using the [provided Dockerfile](./Dockerfile) to build an image to reproduce our experiments.
50 |
51 | ## Pre-trained models
52 | We provide pre-trained models for our FDMs along with the baseline models on Hugging Face. Download the checkpoints [here](https://huggingface.co/sail/FDM/tree/main).
53 |
54 | ## Preparing datasets
55 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive:
56 |
57 | ```.bash
58 | python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
59 | --dest=datasets/cifar10-32x32.zip
60 | python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz
61 | ```
62 |
63 | **FFHQ:** Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and convert to ZIP archive at 64x64 resolution:
64 |
65 | ```.bash
66 | python dataset_tool.py --source=downloads/ffhq/images1024x1024 \
67 | --dest=datasets/ffhq-64x64.zip --resolution=64x64
68 | python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz
69 | ```
70 |
71 | **AFHQv2:** Download the updated [Animal Faces-HQ dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) (`afhq-v2-dataset`) and convert to ZIP archive at 64x64 resolution:
72 |
73 | ```.bash
74 | python dataset_tool.py --source=downloads/afhqv2 \
75 | --dest=datasets/afhqv2-64x64.zip --resolution=64x64
76 | python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz
77 | ```
78 | ## Training from scratch
79 | Train FDM for class-conditional CIFAR-10 using 8 GPUs:
80 | ```.bash
81 | # EDM-FDM
82 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
83 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp \
84 | --precond=fdm_edm --warmup_ite=200
85 |
86 | # VP-FDM
87 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
88 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp --cres=1,2,2,2 \
89 | --precond=fdm_vp --warmup_ite=400
90 |
91 | # VE-FDM
92 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
93 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ncsnpp --cres=1,2,2,2 \
94 | --precond=fdm_ve --warmup_ite=400
95 | ```
96 |
97 | Train FDM for unconditional FFHQ using 8 GPUs:
98 | ```.bash
99 | # EDM-FDM
100 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
101 | --data=datasets/ffhq-64x64.zip --cond=0 --arch=ddpmpp \
102 | --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 \
103 | --precond=fdm_edm --warmup_ite=800 --fdm_multipler=1
104 |
105 | # VP-FDM
106 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
107 | --data=datasets/ffhq-64x64.zip --cond=0 --arch=ddpmpp \
108 | --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 \
109 | --precond=fdm_vp --warmup_ite=400 --fdm_multipler=1
110 |
111 | # VE-FDM
112 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
113 | --data=datasets/ffhq-64x64.zip --cond=0 --arch=ncsnpp \
114 | --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 \
115 | --precond=fdm_ve --warmup_ite=400
116 | ```
117 |
118 | Train FDM for unconditional AFHQv2 using 8 GPUs:
119 | ```.bash
120 | # EDM-FDM
121 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
122 | --data=datasets/afhqv2-64x64.zip --cond=0 --arch=ddpmpp \
123 | --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 \
124 | --precond=fdm_edm --warmup_ite=400
125 |
126 | # VP-FDM
127 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
128 | --data=datasets/afhqv2-64x64.zip --cond=0 --arch=ddpmpp \
129 | --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 \
130 | --precond=fdm_vp --warmup_ite=400
131 |
132 | # VE-FDM
133 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-output \
134 | --data=datasets/afhqv2-64x64.zip --cond=0 --arch=ncsnpp \
135 | --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15 \
136 | --precond=fdm_ve --warmup_ite=400
137 | ```
138 |
139 | ## Calculating FID
140 | To compute Fréchet inception distance (FID) for a given model and sampler, first generate 50,000 random images and then compare them against the dataset reference statistics using `fid.py`. Replace `$PATH_TO_CHECKPOINT` with the path to the checkpoint:
141 |
142 | ```.bash
143 | # Generate 50000 images
144 | torchrun --standalone --nproc_per_node=8 generate.py --outdir=fid \
145 | --seeds=0-49999 --subdirs --network=$PATH_TO_CHECKPOINT
146 | # Calculate FID
147 | torchrun --standalone --nproc_per_node=8 fid.py calc --images=fid \
148 | --ref=fid-refs/cifar10-32x32.npz
149 | ```
150 |
151 | Note that the generated images should be evaluated against the same reference dataset that the model was originally trained on. Please make sure to replace the `--ref` option with the correct one (*e.g.*, `fid-refs/ffhq-64x64.npz` or `fid-refs/afhqv2-64x64.npz`) to obtain the right FID score. Additionally, you can use the `--solver=dpm` option to generate images with DPM-Solver++.
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Streaming images and labels from datasets created with dataset_tool.py."""
9 |
10 | import os
11 | import numpy as np
12 | import zipfile
13 | import PIL.Image
14 | import json
15 | import torch
16 | import dnnlib
17 |
18 | try:
19 | import pyspng
20 | except ImportError:
21 | pyspng = None
22 |
23 | #----------------------------------------------------------------------------
24 | # Abstract base class for datasets.
25 |
class Dataset(torch.utils.data.Dataset):
    """Abstract base class for image datasets with optional labels.

    Subclasses must implement `_load_raw_image()` and, when labels are
    enabled, `_load_raw_labels()`; they may also override `close()`.

    Indexing returns `(image, label)`, where `image` is a uint8 array shaped
    like `image_shape` and `label` is a float32 one-hot vector (for int64 raw
    labels) or the raw float32 label row.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
        cache       = False,    # Cache images in CPU memory?
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._cache = cache
        self._cached_images = dict() # {raw_idx: np.ndarray, ...}
        self._raw_labels = None      # Loaded lazily by _get_raw_labels().
        self._label_shape = None     # Computed lazily by the label_shape property.

        # Apply max_size: keep a seeded random subset, stored in ascending order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: the second half of the index repeats the first half, flagged for flipping.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        """Return the raw label array, loading and validating it on first use.

        When labels are disabled or the subclass returns None, a float32 array
        of shape [N, 0] is used instead.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        """Release any resources held by the dataset; the base class holds none."""
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        """Return the image with the given raw index as a uint8 array shaped like `image_shape`."""
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        """Return all raw labels as a float32 or int64 array, or None if there are none."""
        raise NotImplementedError

    def __getstate__(self):
        """Return the instance state for pickling, with the raw labels dropped."""
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        """Close the dataset on garbage collection, on a best-effort basis."""
        try:
            self.close()
        except Exception:
            # Cleanup failures are ignored on purpose. Only `Exception` is
            # caught, so KeyboardInterrupt and SystemExit are not swallowed
            # the way a bare `except:` would swallow them.
            pass

    def __len__(self):
        """Return the number of items, counting x-flipped duplicates."""
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return `(image, label)` for item `idx`; the image is a fresh copy."""
        raw_idx = self._raw_idx[idx]
        image = self._cached_images.get(raw_idx, None)
        if image is None:
            image = self._load_raw_image(raw_idx)
            if self._cache:
                self._cached_images[raw_idx] = image
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        # Copy so that callers never alias the cached or flipped-view array.
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        """Return the label for item `idx`; int64 class indices become float32 one-hot vectors."""
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return an EasyDict with `raw_idx`, `xflip` and `raw_label` for item `idx`."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        """Name of the dataset."""
        return self._name

    @property
    def image_shape(self):
        """Shape of a single image as a list (the raw shape without its first dimension)."""
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        """Number of image channels (first entry of `image_shape`)."""
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        """Side length of the images, which are asserted to be square."""
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        """Shape of a single label as a list.

        For int64 labels this is `[max_label + 1]`; otherwise it is the raw
        label shape without its first dimension.
        """
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        """Length of the label vector; labels are asserted to be one-dimensional."""
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        """True if any label dimension is non-zero."""
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        """True if the raw labels are int64 class indices."""
        return self._get_raw_labels().dtype == np.int64
161 |
162 | #----------------------------------------------------------------------------
163 | # Dataset subclass that loads images recursively from the specified directory
164 | # or ZIP file.
165 |
class ImageFolderDataset(Dataset):
    """Dataset whose images live in a directory tree or in a ZIP archive.

    All files under `path` are enumerated once at construction time. Files
    whose extension is registered with PIL are treated as images and sorted
    by name. Labels, if any, are read from a `dataset.json` entry stored
    alongside the images.
    """

    def __init__(self,
        path,                   # Path to directory or zip.
        resolution      = None, # Required image resolution, None = do not check.
        use_pyspng      = True, # Use pyspng if available?
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._use_pyspng = use_pyspng
        self._zipfile = None

        # Enumerate every file, relative to the dataset root.
        if os.path.isdir(self._path):
            self._type = 'dir'
            found = set()
            for root, _dirs, files in os.walk(self._path):
                for fname in files:
                    found.add(os.path.relpath(os.path.join(root, fname), start=self._path))
            self._all_fnames = found
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        # Keep only the files PIL recognizes by extension.
        PIL.Image.init()
        known_exts = PIL.Image.EXTENSION
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in known_exts)
        if not self._image_fnames:
            raise IOError('No image files found in the specified path')

        # The dataset shape is inferred from the first image only.
        name = os.path.splitext(os.path.basename(self._path))[0]
        first_image_shape = list(self._load_raw_image(0).shape)
        raw_shape = [len(self._image_fnames)] + first_image_shape
        if resolution is not None:
            if raw_shape[2] != resolution or raw_shape[3] != resolution:
                raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of `fname`, including the dot."""
        _stem, ext = os.path.splitext(fname)
        return ext.lower()

    def _get_zipfile(self):
        """Return the archive handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open `fname` (relative to the dataset root) for binary reading.

        Returns None if the dataset type is neither 'dir' nor 'zip'.
        """
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        """Close the archive handle, if one is open, and forget it."""
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # The open archive handle is replaced by None in the pickled state;
        # _get_zipfile() reopens it on demand.
        state = super().__getstate__()
        return dict(state, _zipfile=None)

    def _load_raw_image(self, raw_idx):
        """Load image number `raw_idx` and return it as a CHW array."""
        fname = self._image_fnames[raw_idx]
        is_png = self._file_ext(fname) == '.png'
        with self._open_file(fname) as f:
            if self._use_pyspng and pyspng is not None and is_png:
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC
        return image.transpose(2, 0, 1) # HWC => CHW

    def _load_raw_labels(self):
        """Read labels from `dataset.json`, ordered like `self._image_fnames`.

        Returns None if the file is absent or its 'labels' entry is null.
        A 1-D label array is cast to int64 and a 2-D one to float32.
        """
        meta_fname = 'dataset.json'
        if meta_fname not in self._all_fnames:
            return None
        with self._open_file(meta_fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        lookup = dict(labels)
        # Keys in the JSON use forward slashes; normalize our names to match.
        ordered = [lookup[name.replace('\\', '/')] for name in self._image_fnames]
        label_array = np.array(ordered)
        target_dtype = {1: np.int64, 2: np.float32}[label_array.ndim]
        return label_array.astype(target_dtype)
249 |
250 | #----------------------------------------------------------------------------
251 |
--------------------------------------------------------------------------------
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for pickling Python code alongside other data.
9 |
10 | The pickled code is automatically imported into a separate Python module
11 | during unpickling. This way, any previously exported pickles will remain
12 | usable even if the original code is no longer available, or if the current
13 | version of the code is not consistent with what was originally pickled."""
14 |
15 | import sys
16 | import pickle
17 | import io
18 | import inspect
19 | import copy
20 | import uuid
21 | import types
22 | import dnnlib
23 |
24 | #----------------------------------------------------------------------------
25 |
# Module-level registries shared by the functions in this file.
_version = 6 # internal version number, stored in every pickle and checked on load
_decorators = set() # {decorator_class, ...}
_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}
31 |
32 | #----------------------------------------------------------------------------
33 |
def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. This feature can be disabled on a per-instance basis
    by setting `self._record_init_args = False` in the constructor.

    A typical use case is to first unpickle a previous instance of a
    persistent class, and then upgrade it to use the latest version of
    the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)

    Args:
        orig_class: The class to decorate. Must be a type whose module is
                    present in `sys.modules`.

    Returns:
        A subclass of `orig_class` carrying the same `__name__` and
        `__module__`, or `orig_class` itself if it is already persistent.
    """
    assert isinstance(orig_class, type)
    # Already decorated: return as-is instead of wrapping a second time.
    if is_persistent(orig_class):
        return orig_class

    # Capture the source of the defining module once, at decoration time.
    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Read the opt-out flag after the original constructor has run,
            # so that the constructor has had a chance to set it.
            record_init_args = getattr(self, '_record_init_args', True)
            self._init_args = copy.deepcopy(args) if record_init_args else None
            self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
            # The class must be reachable by name in its module, because
            # unpickling looks it up there by `class_name`.
            assert orig_class.__name__ in orig_module.__dict__
            # Fail at construction time rather than at pickling time.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            # Returns a deep copy of the recorded positional arguments.
            assert self._init_args is not None
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            # Returns a deep copy of the recorded keyword arguments.
            assert self._init_kwargs is not None
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            # Pad to at least three entries so that fields[2] (the state) exists.
            fields += [None] * max(3 - len(fields), 0)
            # Skip the rewrite if the parent's __reduce__ already produced
            # our reconstruct function.
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict (now carried inside meta)
            return tuple(fields)

    # Make the subclass carry the same name and module as the original class.
    Decorator.__name__ = orig_class.__name__
    Decorator.__module__ = orig_class.__module__
    _decorators.add(Decorator)
    return Decorator
137 |
138 | #----------------------------------------------------------------------------
139 |
def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.

    Returns True if `obj` itself is a registered decorator class, or if
    `obj` is an instance whose exact type is one.
    """
    try:
        is_decorator_class = obj in _decorators
    except TypeError:
        # Unhashable objects cannot be looked up in a set; fall through to
        # the type-based check below.
        is_decorator_class = False
    if is_decorator_class:
        return True
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150 |
151 | #----------------------------------------------------------------------------
152 |
def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta

    Args:
        hook: Callable taking `meta` and returning the (possibly modified) meta.

    Returns:
        The same `hook`, so that the function can be used as a decorator
        without rebinding the decorated name to None.
    """
    assert callable(hook)
    _import_hooks.append(hook)
    return hook
182 |
183 | #----------------------------------------------------------------------------
184 |
def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.

    Runs every registered import hook on `meta`, materializes the pickled
    module source, creates an instance of the persistent class without
    calling its constructor, and restores its state.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)

    # Give each hook a chance to patch the metadata, in registration order.
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    src_module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    pickled_class = src_module.__dict__[meta.class_name]
    persistent_cls = persistent_class(pickled_class)
    obj = persistent_cls.__new__(persistent_cls)

    # Prefer the object's own __setstate__; otherwise fill __dict__ directly.
    restore_state = getattr(obj, '__setstate__', None)
    if not callable(restore_state):
        obj.__dict__.update(meta.state)
    else:
        restore_state(meta.state) # pylint: disable=not-callable
    return obj
209 |
210 | #----------------------------------------------------------------------------
211 |
212 | def _module_to_src(module):
213 | r"""Query the source code of a given Python module.
214 | """
215 | src = _module_to_src_dict.get(module, None)
216 | if src is None:
217 | src = inspect.getsource(module)
218 | _module_to_src_dict[module] = src
219 | _src_to_module_dict[src] = module
220 | return src
221 |
222 | def _src_to_module(src):
223 | r"""Get or create a Python module for the given source code.
224 | """
225 | module = _src_to_module_dict.get(src, None)
226 | if module is None:
227 | module_name = "_imported_module_" + uuid.uuid4().hex
228 | module = types.ModuleType(module_name)
229 | sys.modules[module_name] = module
230 | _module_to_src_dict[module] = src
231 | _src_to_module_dict[src] = module
232 | exec(src, module.__dict__) # pylint: disable=exec-used
233 | return module
234 |
235 | #----------------------------------------------------------------------------
236 |
237 | def _check_pickleable(obj):
238 | r"""Check that the given object is pickleable, raising an exception if
239 | it is not. This function is expected to be considerably more efficient
240 | than actually pickling the object.
241 | """
242 | def recurse(obj):
243 | if isinstance(obj, (list, tuple, set)):
244 | return [recurse(x) for x in obj]
245 | if isinstance(obj, dict):
246 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
248 | return None # Python primitive types are pickleable.
249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
250 | return None # NumPy arrays and PyTorch tensors are pickleable.
251 | if is_persistent(obj):
252 | return None # Persistent objects are pickleable, by virtue of the constructor check.
253 | return obj
254 | with io.BytesIO() as f:
255 | pickle.dump(recurse(obj), f)
256 |
257 | #----------------------------------------------------------------------------
258 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Garena Online Private Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Generate random images using EDM Sampler and DPM-Solver++."""
16 |
17 | import os
18 | import re
19 | import click
20 | import tqdm
21 | import pickle
22 | import numpy as np
23 | import torch
24 | import PIL.Image
25 | import dnnlib
26 | from torch_utils import distributed as dist
27 | from dpm_solver import NoiseScheduleEDM, DPM_Solver
28 |
29 | #----------------------------------------------------------------------------
30 | # EDM sampler
31 |
def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Generate samples by integrating from `sigma_max` down to zero noise.

    Each step optionally adds noise ("churn"), takes an Euler step, and,
    except on the final step, applies a 2nd order correction that averages
    the slopes at both ends of the step.

    Args:
        net:          Denoiser. Must provide `sigma_min`, `sigma_max`,
                      `round_sigma(sigma)` and be callable as
                      `net(x, sigma, class_labels)`.
        latents:      Initial noise tensor; scaled by the first noise level.
        class_labels: Passed through to `net` unchanged.
        randn_like:   Function producing noise shaped like its argument.
        num_steps:    Number of noise levels before the final zero level.
        sigma_min:    Lowest noise level; raised to `net.sigma_min` if below it.
        sigma_max:    Highest noise level; lowered to `net.sigma_max` if above it.
        rho:          Exponent of the time step discretization.
        S_churn, S_min, S_max, S_noise: Stochasticity controls; with the
                      default `S_churn=0` no extra noise is added.

    Returns:
        The sample after the last step, as a float64 tensor.
    """
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    # With a single step the interpolation denominator would be zero and
    # turn the whole schedule into NaN; clamp it so that the only level is
    # sigma_max.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    denom = max(num_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the last step, where t_next = 0).
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next
68 |
69 | #----------------------------------------------------------------------------
70 | # DPM-Solver++
71 |
def ablation_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    denoise=True, **kwargs
):
    """Generate samples with DPM-Solver++ (3rd order, single-step, logSNR spacing).

    `randn_like`, `rho` and any extra `kwargs` are accepted for signature
    compatibility with `edm_sampler` but are not used. When `denoise` is
    true, one of the `num_steps` steps is reserved for the solver's final
    denoise-to-zero step.
    """
    x_init = latents.to(torch.float64) * sigma_max
    schedule = NoiseScheduleEDM('linear')
    solver_steps = num_steps - 1 if denoise else num_steps

    with torch.no_grad():
        def noise_pred_fn(x, t):
            # Convert the network's denoised output into a noise prediction.
            denoised = net(x, t, class_labels).to(torch.float64)
            return (x - denoised) / t

        dpm_solver = DPM_Solver(noise_pred_fn, schedule, algorithm_type="dpmsolver++")
        x_next = dpm_solver.sample(
            x_init,
            steps=solver_steps,
            t_start=sigma_max,
            t_end=sigma_min,
            order=3,
            skip_type="logSNR",
            method="singlestep",
            denoise_to_zero=denoise,
            lower_order_final=True,
        )
    return x_next
95 |
96 | #----------------------------------------------------------------------------
97 | # Wrapper for torch.Generator that allows specifying a different random seed
98 | # for each sample in a minibatch.
99 |
class StackedRandomGenerator:
    """Bundle of `torch.Generator`s, one per sample in a minibatch.

    Every sampling method requires the leading dimension of the requested
    size to equal the number of seeds, draws each slice from its own
    generator, and stacks the results along a new leading dimension.
    """

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = []
        for seed in seeds:
            generator = torch.Generator(device)
            # Seeds are reduced modulo 2**32 before use.
            generator.manual_seed(int(seed) % (1 << 32))
            self.generators.append(generator)

    def randn(self, size, **kwargs):
        assert size[0] == len(self.generators)
        per_sample_size = size[1:]
        samples = [torch.randn(per_sample_size, generator=gen, **kwargs) for gen in self.generators]
        return torch.stack(samples)

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        per_sample_size = size[1:]
        samples = [torch.randint(*args, size=per_sample_size, generator=gen, **kwargs) for gen in self.generators]
        return torch.stack(samples)
115 |
116 | #----------------------------------------------------------------------------
117 | # Parse a comma separated list of numbers or ranges and return a list of ints.
118 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
119 |
def parse_int_list(s):
    """Parse a comma separated list of numbers or inclusive ranges.

    A list argument is returned unchanged. Otherwise each comma separated
    item is either `A-B`, which expands to A, A+1, ..., B, or a single
    integer accepted by `int()`.
    """
    if isinstance(s, list):
        return s
    span_re = re.compile(r'^(\d+)-(\d+)$')
    values = []
    for item in s.split(','):
        span = span_re.match(item)
        if span is None:
            values.append(int(item))
            continue
        first, last = int(span.group(1)), int(span.group(2))
        values.extend(range(first, last + 1))
    return values
131 |
132 | #----------------------------------------------------------------------------
133 |
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
@click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True)
@click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True)
@click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
@click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
@click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)

@click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
@click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
@click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
@click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
@click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
@click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
@click.option('--solver', help='Ablate ODE solver', metavar='edm|dpm', type=click.Choice(['edm', 'dpm']), default='edm')

def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs):
    # Command-line entry point: load the EMA network from `network_pkl`,
    # generate one image per seed with the selected sampler, and write each
    # image to `outdir` as a PNG named after its seed.
    # (Kept as comments rather than a docstring, because click would show a
    # docstring as the command's --help text.)
    dist.init()

    # Split the seeds into batches; the batch count is a multiple of the
    # world size so that every rank iterates the same number of times.
    num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]

    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load network.
    # NOTE: pickle.load executes code from the file; only load trusted pickles.
    dist.print0(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
        net = pickle.load(f)['ema'].to(device)

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Resolve the sampler once, before the loop. Popping 'solver' inside the
    # loop would find the key only on the first batch and fall back to the
    # EDM sampler for every later batch, regardless of --solver.
    sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
    sampler_fn = edm_sampler if sampler_kwargs.pop('solver', 'edm') == 'edm' else ablation_sampler

    # Loop over batches.
    dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
    for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
        # The barrier runs before the empty-batch check so that ranks with an
        # empty batch still take part in it.
        torch.distributed.barrier()
        batch_size = len(batch_seeds)
        if batch_size == 0:
            continue

        # Pick latents and labels.
        rnd = StackedRandomGenerator(device, batch_seeds)
        latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
        class_labels = None
        if net.label_dim:
            class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
        if class_idx is not None:
            class_labels[:, :] = 0
            class_labels[:, class_idx] = 1

        # Generate images.
        images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs)

        # Save images.
        images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        for seed, image_np in zip(batch_seeds, images_np):
            image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir
            os.makedirs(image_dir, exist_ok=True)
            image_path = os.path.join(image_dir, f'{seed:06d}.png')
            if image_np.shape[2] == 1:
                PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
            else:
                PIL.Image.fromarray(image_np, 'RGB').save(image_path)

    # Done.
    torch.distributed.barrier()
    dist.print0('Done.')
208 |
209 | #----------------------------------------------------------------------------
210 |
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
213 |
214 | #----------------------------------------------------------------------------
215 |
--------------------------------------------------------------------------------
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for reporting and collecting training statistics across
9 | multiple processes and devices. The interface is designed to minimize
10 | synchronization overhead as well as the amount of boilerplate in user
11 | code."""
12 |
13 | import re
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 |
18 | from . import misc
19 |
20 | #----------------------------------------------------------------------------
21 |
# Module-level state shared by report(), Collector, and _sync().
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype = torch.float64 # Data type to use for the internal counters.
_rank = 0 # Rank of the current process.
_sync_device = None # Device to use for multiprocess communication. None = single-process.
_sync_called = False # Has _sync() been called yet?
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
30 |
31 | #----------------------------------------------------------------------------
32 |
def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    # Changing the configuration after the first synchronization is not allowed.
    assert not _sync_called
    _rank, _sync_device = rank, sync_device
51 |
52 | #----------------------------------------------------------------------------
53 |
@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    # Register the name even when there is nothing to accumulate.
    per_device = _counters.setdefault(name, dict())

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    # Reduce to [count, sum, sum of squares] on the tensor's own device.
    flat = elems.detach().flatten().to(_reduce_dtype)
    count = torch.ones_like(flat).sum()
    total = flat.sum()
    total_sq = flat.square().sum()
    moments = torch.stack([count, total, total_sq])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    # Accumulate into the running counter for this device.
    device = moments.device
    running = per_device.get(device)
    if running is None:
        running = per_device[device] = torch.zeros_like(moments)
    running.add_(moments)
    return value
99 |
100 | #----------------------------------------------------------------------------
101 |
def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.

    Returns:
        The same `value` that was passed in, on every rank.
    """
    # Other ranks still call report() with an empty list so that the name
    # gets registered on every process.
    payload = value if _rank == 0 else []
    report(name, payload)
    return value
109 |
110 | #----------------------------------------------------------------------------
111 |
112 | class Collector:
113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
114 | computes their long-term averages (mean and standard deviation) over
115 | user-defined periods of time.
116 |
117 | The averages are first collected into internal counters that are not
118 | directly visible to the user. They are then copied to the user-visible
119 | state as a result of calling `update()` and can then be queried using
120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
121 | internal counters for the next round, so that the user-visible state
122 | effectively reflects averages collected between the last two calls to
123 | `update()`.
124 |
125 | Args:
126 | regex: Regular expression defining which statistics to
127 | collect. The default is to collect everything.
128 | keep_previous: Whether to retain the previous averages if no
129 | scalars were collected on a given round
130 | (default: True).
131 | """
132 | def __init__(self, regex='.*', keep_previous=True):
133 | self._regex = re.compile(regex)
134 | self._keep_previous = keep_previous
135 | self._cumulative = dict()
136 | self._moments = dict()
137 | self.update()
138 | self._moments.clear()
139 |
140 | def names(self):
141 | r"""Returns the names of all statistics broadcasted so far that
142 | match the regular expression specified at construction time.
143 | """
144 | return [name for name in _counters if self._regex.fullmatch(name)]
145 |
146 | def update(self):
147 | r"""Copies current values of the internal counters to the
148 | user-visible state and resets them for the next round.
149 |
150 | If `keep_previous=True` was specified at construction time, the
151 | operation is skipped for statistics that have received no scalars
152 | since the last update, retaining their previous averages.
153 |
154 | This method performs a number of GPU-to-CPU transfers and one
155 | `torch.distributed.all_reduce()`. It is intended to be called
156 | periodically in the main training loop, typically once every
157 | N training steps.
158 | """
159 | if not self._keep_previous:
160 | self._moments.clear()
161 | for name, cumulative in _sync(self.names()):
162 | if name not in self._cumulative:
163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
164 | delta = cumulative - self._cumulative[name]
165 | self._cumulative[name].copy_(cumulative)
166 | if float(delta[0]) != 0:
167 | self._moments[name] = delta
168 |
169 | def _get_delta(self, name):
170 | r"""Returns the raw moments that were accumulated for the given
171 | statistic between the last two calls to `update()`, or zero if
172 | no scalars were collected.
173 | """
174 | assert self._regex.fullmatch(name)
175 | if name not in self._moments:
176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
177 | return self._moments[name]
178 |
179 | def num(self, name):
180 | r"""Returns the number of scalars that were accumulated for the given
181 | statistic between the last two calls to `update()`, or zero if
182 | no scalars were collected.
183 | """
184 | delta = self._get_delta(name)
185 | return int(delta[0])
186 |
def mean(self, name):
    r"""Returns the average of the scalars accumulated for the given
    statistic between the last two calls to `update()`.
    The result is NaN if no scalars were collected.
    """
    moments = self._get_delta(name)
    # moments[0] is the count, moments[1] the sum of the scalars.
    has_samples = int(moments[0]) != 0
    return float(moments[1] / moments[0]) if has_samples else float('nan')
196 |
def std(self, name):
    r"""Returns the standard deviation of the scalars accumulated for
    the given statistic between the last two calls to `update()`.

    The result is NaN if no scalars were collected or if their sum is
    not finite, and zero if exactly one scalar was collected.
    """
    moments = self._get_delta(name)
    count = int(moments[0])
    if count == 0 or not np.isfinite(float(moments[1])):
        return float('nan')
    if count == 1:
        return float(0)

    # Var[x] = E[x^2] - E[x]^2, clamped at zero to absorb round-off.
    mean = float(moments[1] / moments[0])
    raw_var = float(moments[2] / moments[0])
    variance = max(raw_var - np.square(mean), 0)
    return np.sqrt(variance)
210 |
def as_dict(self):
    r"""Returns the statistics accumulated between the last two calls
    to `update()` as a `dnnlib.EasyDict`. The contents are as follows:

        dnnlib.EasyDict(
            NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT),
            ...
        )
    """
    stats = dnnlib.EasyDict()
    for name in self.names():
        entry = dnnlib.EasyDict(
            num=self.num(name),
            mean=self.mean(name),
            std=self.std(name),
        )
        stats[name] = entry
    return stats
224 |
def __getitem__(self, name):
    r"""Subscript shorthand for the mean.
    `collector[name]` returns exactly what `collector.mean(name)` returns.
    """
    average = self.mean(name)
    return average
230 |
231 | #----------------------------------------------------------------------------
232 |
def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.

    Returns a list of `(name, cumulative)` pairs, one per entry of
    `names`, or an empty list if `names` is empty.
    """
    global _sync_called
    if len(names) == 0:
        return []
    _sync_called = True

    # Drain the per-device counters of this rank into one row per name.
    distributed = _sync_device is not None
    device = _sync_device if distributed else torch.device('cpu')
    rows = []
    for name in names:
        row = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            row.add_(counter.to(device))
            counter.zero_() # reset for the next round
        rows.append(row)
    stacked = torch.stack(rows)

    # Sum the rows across ranks.
    if distributed:
        torch.distributed.all_reduce(stacked)

    # Fold the summed rows into the running totals kept on the CPU.
    stacked = stacked.cpu()
    for row, name in zip(stacked, names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(row)

    # Name-value pairs, in the order of `names`.
    return [(name, _cumulative[name]) for name in names]
266 |
267 | #----------------------------------------------------------------------------
268 | # Convenience.
269 |
# Module-level collector created with the default constructor arguments.
default_collector = Collector()
271 |
272 | #----------------------------------------------------------------------------
273 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import re
9 | import contextlib
10 | import numpy as np
11 | import torch
12 | import warnings
13 | import dnnlib
14 |
15 | #----------------------------------------------------------------------------
16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17 | # same constant is used multiple times.
18 |
_constant_cache = {}

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding `value`.

    Tensors are cached by the raw bytes, shape and dtype of `value`
    together with the requested `shape`, `dtype`, `device` and
    `memory_format`, so identical requests return the same tensor object.

    Args:
        value: Anything accepted by `np.asarray()`.
        shape: Optional shape to broadcast the value to.
        dtype: Tensor dtype; defaults to `torch.get_default_dtype()`.
        device: Tensor device; defaults to the CPU.
        memory_format: Defaults to `torch.contiguous_format`.

    Returns:
        The cached `torch.Tensor`. It is shared between callers.
    """
    array = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    key = (array.shape, array.dtype, array.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(key, None)
    if cached is not None:
        return cached

    # Cache miss: build the tensor once and remember it.
    built = torch.as_tensor(array.copy(), dtype=dtype, device=device)
    if shape is not None:
        built, _ = torch.broadcast_tensors(built, torch.empty(shape))
    built = built.contiguous(memory_format=memory_format)
    _constant_cache[key] = built
    return built
41 |
42 | #----------------------------------------------------------------------------
43 | # Replace NaN/Inf with specified numerical values.
44 |
if hasattr(torch, 'nan_to_num'):
    nan_to_num = torch.nan_to_num # 1.8.0a0
else:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback used when `torch.nan_to_num` does not exist.

        NaNs become zero (only `nan=0` is supported); the result is
        clamped to `[neginf, posinf]`, which default to the smallest and
        largest finite values of the input dtype.
        """
        assert isinstance(input, torch.Tensor)
        upper = posinf if posinf is not None else torch.finfo(input.dtype).max
        lower = neginf if neginf is not None else torch.finfo(input.dtype).min
        assert nan == 0
        # nansum over a new singleton axis zeroes the NaNs elementwise.
        zeroed = input.unsqueeze(0).nansum(0)
        return torch.clamp(zeroed, min=lower, max=upper, out=out)
56 |
57 | #----------------------------------------------------------------------------
58 | # Symbolic assert.
59 |
# Use torch._assert (1.8.0a0) when it exists, otherwise torch.Assert (1.7.0).
symbolic_assert = torch._assert if hasattr(torch, '_assert') else torch.Assert # pylint: disable=protected-access
64 |
65 | #----------------------------------------------------------------------------
66 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
68 |
@contextlib.contextmanager
def suppress_tracer_warnings():
    """Temporarily ignore `torch.jit.TracerWarning` inside a `with` block.

    An 'ignore' filter is put at the front of `warnings.filters` on entry
    and removed again on exit. The removal runs in a `finally` clause so
    the filter does not stay installed when the body of the `with`
    statement raises.
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        warnings.filters.remove(flt)
75 |
76 | #----------------------------------------------------------------------------
77 | # Assert that the shape of a tensor matches the given list of integers.
78 | # None indicates that the size of a dimension is allowed to vary.
79 | # Performs symbolic assertion when used in torch.jit.trace().
80 |
def assert_shape(tensor, ref_shape):
    """Check that `tensor.shape` matches `ref_shape`.

    `ref_shape` holds one entry per dimension. `None` accepts any size.
    When either the actual or the reference size is a `torch.Tensor`,
    the comparison is delegated to `symbolic_assert`; otherwise a
    mismatch raises `AssertionError` immediately.

    Raises:
        AssertionError: On a wrong number of dimensions or a wrong size.
    """
    expected_ndim = len(ref_shape)
    if tensor.ndim != expected_ndim:
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {expected_ndim}')

    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            continue # this dimension is allowed to vary
        if isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95 |
96 | #----------------------------------------------------------------------------
97 | # Function decorator that calls torch.autograd.profiler.record_function().
98 |
def profiled_function(fn):
    """Decorate `fn` so that each call runs inside
    `torch.autograd.profiler.record_function(fn.__name__)`.

    The wrapper is built with `functools.wraps`, so besides `__name__`
    it also carries over `__doc__`, `__module__` and `__qualname__`, and
    exposes the undecorated function as `__wrapped__`.

    Args:
        fn: The callable to wrap; must have a `__name__`.

    Returns:
        A wrapper that takes the same arguments and returns the same
        value as `fn`.
    """
    import functools # local import keeps this block self-contained

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator
105 |
106 | #----------------------------------------------------------------------------
107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
108 | # indefinitely, shuffling items as it goes.
109 |
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that yields dataset indices forever.

    Index `idx` of the endless stream is handed to the replica for which
    `idx % num_replicas == rank`. With `shuffle=True` the order is
    shuffled once from `seed` and then keeps being perturbed: after each
    step the current slot is swapped with a randomly chosen slot at most
    `window - 1` positions behind it, where
    `window = rint(len(dataset) * window_size)`.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank, self.num_replicas = rank, num_replicas
        self.shuffle, self.seed = shuffle, seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd, window = None, 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        step = 0
        while True:
            pos = step % order.size
            # Only every num_replicas-th step belongs to this rank.
            if step % self.num_replicas == self.rank:
                yield order[pos]
            # Every rank performs the swap so the orders stay in lockstep.
            if window >= 2:
                other = (pos - rnd.randint(window)) % order.size
                order[[pos, other]] = order[[other, pos]]
            step += 1
142 |
143 | #----------------------------------------------------------------------------
144 | # Utilities for operating with torch.nn.Module parameters and buffers.
145 |
def params_and_buffers(module):
    """Return a list with all parameters of `module` followed by all of its buffers."""
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]
149 |
def named_params_and_buffers(module):
    """Return a list of `(name, tensor)` pairs: all named parameters of
    `module` followed by all of its named buffers."""
    assert isinstance(module, torch.nn.Module)
    return [*module.named_parameters(), *module.named_buffers()]
153 |
@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy parameters and buffers from `src_module` into `dst_module` by name.

    Every parameter/buffer of `dst_module` that has a same-named
    counterpart in `src_module` is overwritten in place. With
    `require_all=True`, a destination tensor without a counterpart
    triggers an `AssertionError`; otherwise it is left untouched.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    source = dict(named_params_and_buffers(src_module))
    for name, target in named_params_and_buffers(dst_module):
        found = name in source
        assert found or (not require_all)
        if found:
            target.copy_(source[name])
163 |
164 | #----------------------------------------------------------------------------
165 | # Context manager for easily enabling/disabling DistributedDataParallel
166 | # synchronization.
167 |
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager that runs its body inside `module.no_sync()`
    when `sync` is false and `module` is a `DistributedDataParallel`
    instance; in every other case the body runs unchanged.
    """
    assert isinstance(module, torch.nn.Module)
    skip_sync = (not sync) and isinstance(module, torch.nn.parallel.DistributedDataParallel)
    if skip_sync:
        with module.no_sync():
            yield
    else:
        yield
176 |
177 | #----------------------------------------------------------------------------
178 | # Check DistributedDataParallel consistency across processes.
179 |
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that every parameter and buffer of `module` equals the
    copy held by rank 0.

    Each tensor is compared against a clone that is overwritten by
    `torch.distributed.broadcast(src=0)`. Floating-point tensors are
    passed through `nan_to_num` first. Tensors whose full name
    (`<ModuleClass>.<tensor name>`) fully matches `ignore_regex` are
    skipped.

    Raises:
        AssertionError: With the full tensor name as the message, on
            the first mismatch.
    """
    assert isinstance(module, torch.nn.Module)
    prefix = type(module).__name__ + '.'
    for name, tensor in named_params_and_buffers(module):
        fullname = prefix + name
        ignored = ignore_regex is not None and re.fullmatch(ignore_regex, fullname)
        if ignored:
            continue
        local = tensor.detach()
        if local.is_floating_point():
            local = nan_to_num(local)
        reference = local.clone()
        torch.distributed.broadcast(tensor=reference, src=0)
        assert (local == reference).all(), fullname
192 |
193 | #----------------------------------------------------------------------------
194 | # Print summary table of module hierarchy.
195 |
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module(*inputs)` once and print a per-submodule summary table.

    Forward hooks are attached to every submodule for the duration of
    the call to record tensor outputs. The printed table lists, for each
    recorded submodule, the number of parameter and buffer elements not
    already attributed to an earlier row, plus the shape and dtype of
    each output tensor, followed by a totals row.

    Args:
        module: A `torch.nn.Module` that is not a `torch.jit.ScriptModule`.
        inputs: Tuple or list of positional arguments for `module`.
        max_nesting: A submodule call is recorded only if the hook nesting
            counter, after leaving the call, is `<= max_nesting`.
        skip_redundant: Drop rows that contribute no new parameters,
            buffers, or outputs.

    Returns:
        Whatever `module(*inputs)` returned.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module. The hooks are removed even if the forward pass raises,
    # so a failed summary does not leave them attached to the module.
    try:
        outputs = module(*inputs)
    finally:
        for hook in hooks:
            hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        # First row of an entry carries the sizes; ':0' marks multi-output modules.
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        # One extra row per additional output tensor.
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table, padding every cell to the width of its column.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    # NOTE(review): dumps the module's attribute dict after the table;
    # looks like a debugging aid - confirm it is still wanted.
    print(vars(module))
    print()
    return outputs
267 |
268 | #----------------------------------------------------------------------------
269 |
--------------------------------------------------------------------------------
/training/training_loop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Main training loop."""
9 |
10 | import os
11 | import time
12 | import copy
13 | import json
14 | import pickle
15 | import psutil
16 | import numpy as np
17 | import torch
18 | import dnnlib
19 | from torch_utils import distributed as dist
20 | from torch_utils import training_stats
21 | from torch_utils import misc
22 |
23 | #----------------------------------------------------------------------------
24 |
25 | def training_loop(
26 | run_dir = '.', # Output directory.
27 | dataset_kwargs = {}, # Options for training set.
28 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
29 | network_kwargs = {}, # Options for model and preconditioning.
30 | loss_kwargs = {}, # Options for loss function.
31 | optimizer_kwargs = {}, # Options for optimizer.
32 | augment_kwargs = None, # Options for augmentation pipeline, None = disable.
33 | seed = 0, # Global random seed.
34 | batch_size = 512, # Total batch size for one training iteration.
35 | batch_gpu = None, # Limit batch size per GPU, None = no limit.
36 | total_kimg = 200000, # Training duration, measured in thousands of training images.
37 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights.
38 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup.
39 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration.
40 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows.
41 | kimg_per_tick = 50, # Interval of progress prints.
42 | snapshot_ticks = 50, # How often to save network snapshots, None = disable.
43 | state_dump_ticks = 500, # How often to dump training state, None = disable.
44 | resume_pkl = None, # Start from the given network snapshot, None = random initialization.
45 | resume_state_dump = None, # Start from the given training state, None = reset training state.
46 | resume_kimg = 0, # Start from the given training progress.
47 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
48 | device = torch.device('cuda'),
49 | ):
50 | # Initialize.
51 | start_time = time.time()
52 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
53 | torch.manual_seed(np.random.randint(1 << 31))
54 | torch.backends.cudnn.benchmark = cudnn_benchmark
55 | torch.backends.cudnn.allow_tf32 = False
56 | torch.backends.cuda.matmul.allow_tf32 = False
57 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
58 |
59 | # Select batch size per GPU.
60 | batch_gpu_total = batch_size // dist.get_world_size()
61 | if batch_gpu is None or batch_gpu > batch_gpu_total:
62 | batch_gpu = batch_gpu_total
63 | num_accumulation_rounds = batch_gpu_total // batch_gpu
64 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()
65 |
66 | # Load dataset.
67 | dist.print0('Loading dataset...')
68 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
69 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
70 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
71 |
72 | # Construct network.
73 | dist.print0('Constructing network...')
74 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
75 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
76 | net.train().requires_grad_(True).to(device)
77 | if dist.get_rank() == 0:
78 | with torch.no_grad():
79 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
80 | sigma = torch.ones([batch_gpu], device=device)
81 | labels = torch.zeros([batch_gpu, net.label_dim], device=device)
82 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)
83 |
84 | # Setup optimizer.
85 | dist.print0('Setting up optimizer...')
86 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
87 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
88 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
89 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
90 | ema = copy.deepcopy(net).eval().requires_grad_(False)
91 |
92 | # Resume training from previous snapshot.
93 | if resume_pkl is not None:
94 | dist.print0(f'Loading network weights from "{resume_pkl}"...')
95 | if dist.get_rank() != 0:
96 | torch.distributed.barrier() # rank 0 goes first
97 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
98 | data = pickle.load(f)
99 | if dist.get_rank() == 0:
100 | torch.distributed.barrier() # other ranks follow
101 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
102 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
103 | del data # conserve memory
104 | if resume_state_dump:
105 | dist.print0(f'Loading training state from "{resume_state_dump}"...')
106 | data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
107 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
108 | optimizer.load_state_dict(data['optimizer_state'])
109 | del data # conserve memory
110 |
111 | # Train.
112 | dist.print0(f'Training for {total_kimg} kimg...')
113 | dist.print0()
114 | cur_nimg = resume_kimg * 1000
115 | cur_tick = 0
116 | tick_start_nimg = cur_nimg
117 | tick_start_time = time.time()
118 | maintenance_time = tick_start_time - start_time
119 | dist.update_progress(cur_nimg // 1000, total_kimg)
120 | stats_jsonl = None
121 | while True:
122 |
123 | # Accumulate gradients.
124 | optimizer.zero_grad(set_to_none=True)
125 | for round_idx in range(num_accumulation_rounds):
126 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
127 | images, labels = next(dataset_iterator)
128 | images = images.to(device).to(torch.float32) / 127.5 - 1
129 | labels = labels.to(device)
130 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
131 | training_stats.report('Loss/loss', loss)
132 | loss.sum().mul(loss_scaling / batch_gpu_total).backward()
133 |
134 | # Update weights.
135 | for g in optimizer.param_groups:
136 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
137 | for param in net.parameters():
138 | if param.grad is not None:
139 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
140 | optimizer.step()
141 |
142 | # Update EMA.
143 | ema_halflife_nimg = ema_halflife_kimg * 1000
144 | if ema_rampup_ratio is not None:
145 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
146 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
147 | for p_ema, p_net in zip(ema.parameters(), net.parameters()):
148 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
149 |
150 | # Perform maintenance tasks once per tick.
151 | cur_nimg += batch_size
152 | done = (cur_nimg >= total_kimg * 1000)
153 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
154 | continue
155 |
156 | # Print status line, accumulating the same information in training_stats.
157 | tick_end_time = time.time()
158 | fields = []
159 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
160 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
161 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
162 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
163 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
164 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
165 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
166 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
167 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
168 | fields += [f"loss {loss.sum().item():<10.2f}"]
169 | torch.cuda.reset_peak_memory_stats()
170 | dist.print0(' '.join(fields))
171 |
172 | # Check for abort.
173 | if (not done) and dist.should_stop():
174 | done = True
175 | dist.print0()
176 | dist.print0('Aborting...')
177 |
178 | # Save network snapshot.
179 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
180 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
181 | for key, value in data.items():
182 | if isinstance(value, torch.nn.Module):
183 | value = copy.deepcopy(value).eval().requires_grad_(False)
184 | misc.check_ddp_consistency(value)
185 | data[key] = value.cpu()
186 | del value # conserve memory
187 | if dist.get_rank() == 0:
188 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
189 | pickle.dump(data, f)
190 | del data # conserve memory
191 |
192 | # Save full dump of the training state.
193 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
194 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
195 |
196 | # Update logs.
197 | training_stats.default_collector.update()
198 | if dist.get_rank() == 0:
199 | if stats_jsonl is None:
200 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at')
201 | stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n')
202 | stats_jsonl.flush()
203 | dist.update_progress(cur_nimg // 1000, total_kimg)
204 |
205 | # Update state.
206 | cur_tick += 1
207 | tick_start_nimg = cur_nimg
208 | tick_start_time = time.time()
209 | maintenance_time = tick_start_time - tick_end_time
210 | if done:
211 | break
212 |
213 | # Done.
214 | dist.print0()
215 | dist.print0('Exiting...')
216 |
217 | #----------------------------------------------------------------------------
218 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Garena Online Private Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Train diffusion-based generative model using the techniques described in the
16 | paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""
17 |
18 | import os
19 | import re
20 | import json
21 | import click
22 | import torch
23 | import dnnlib
24 | from torch_utils import distributed as dist
25 | from training import training_loop
26 |
27 | import warnings
28 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
29 |
30 | #----------------------------------------------------------------------------
31 | # Parse a comma separated list of numbers or ranges and return a list of ints.
32 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
33 |
def parse_int_list(s):
    """Expand a comma-separated list of numbers and inclusive ranges.

    Example: '1,2,5-10' yields [1, 2, 5, 6, 7, 8, 9, 10]. A value that is
    already a list is handed back unchanged.
    """
    if isinstance(s, list):
        return s
    span_pattern = re.compile(r'^(\d+)-(\d+)$')
    values = []
    for token in s.split(','):
        span = span_pattern.match(token)
        if span is None:
            # Plain number.
            values.append(int(token))
            continue
        # 'A-B' range, both ends included.
        first, last = (int(group) for group in span.groups())
        values.extend(range(first, last + 1))
    return values
45 |
46 | #----------------------------------------------------------------------------
47 |
@click.command()

# Main options.
@click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True)
@click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
@click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp', type=click.Choice(['ddpmpp', 'ncsnpp']), default='ddpmpp', show_default=True)
@click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm', 'fdm_edm', 'fdm_vp', 'fdm_ve']), default='fdm_edm', show_default=True)

# Hyperparameters.
@click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
@click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
@click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
@click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
@click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
@click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
@click.option('--lr_rampup', help='Learning rate rampup', metavar='FLOAT', type=click.FloatRange(min=0, max=1000), default=10, show_default=True)
@click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
@click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
@click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
@click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--warmup_ite', help='Loss weight warmup iteration', metavar='FLOAT', type=float, default=None, show_default=True)
@click.option('--fdm_multiplier', help='FDM multiplier', metavar='FLOAT', type=float, default=2.0)

# Performance-related.
@click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
@click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=8, show_default=True)

# I/O-related.
@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
@click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
@click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=200, show_default=True)
@click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
@click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
@click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
@click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)

def main(**kwargs):
    """Train diffusion-based generative model using the techniques described in the
    paper "Elucidating the Design Space of Diffusion-Based Generative Models".

    The command-line options are collected into a config dict, printed,
    optionally written to the run directory as training_options.json, and
    finally forwarded as keyword arguments to training_loop.training_loop().
    """
    opts = dnnlib.EasyDict(kwargs)
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    # Initialize config dict.
    c = dnnlib.EasyDict()
    c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
    c.network_kwargs = dnnlib.EasyDict()
    c.loss_kwargs = dnnlib.EasyDict()
    c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)

    # Validate dataset options.
    try:
        dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
        dataset_name = dataset_obj.name
        c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
        c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
        if opts.cond and not dataset_obj.has_labels:
            raise click.ClickException('--cond=True requires labels specified in dataset.json')
        del dataset_obj # conserve memory
    except IOError as err:
        raise click.ClickException(f'--data: {err}')

    # Network architecture.
    if opts.arch == 'ddpmpp':
        c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
        c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2])
    elif opts.arch == 'ncsnpp':
        c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
        c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2])
    else:
        # Raising a bare string is itself a TypeError; raise a real exception
        # so the message actually reaches the user.
        raise click.ClickException(f'Unknown architecture: {opts.arch}')

    # Preconditioning & loss function.
    if opts.precond == 'vp':
        c.network_kwargs.class_name = 'training.networks.VPPrecond'
        c.loss_kwargs.class_name = 'training.loss.VPLoss'
    elif opts.precond == 've':
        c.network_kwargs.class_name = 'training.networks.VEPrecond'
        c.loss_kwargs.class_name = 'training.loss.VELoss'
    elif opts.precond == 'edm':
        c.network_kwargs.class_name = 'training.networks.EDMPrecond'
        c.loss_kwargs.class_name = 'training.loss.EDMLoss'
    elif opts.precond == 'fdm_vp':
        # VP-FDM
        c.network_kwargs.class_name = 'training.networks.FDM_VPPrecond'
        c.loss_kwargs.class_name = 'training.loss.VPLoss'
        c.network_kwargs.update(fdm_multiplier=opts.fdm_multiplier)
        c.loss_kwargs.update(warmup_ite=opts.warmup_ite)
    elif opts.precond == 'fdm_ve':
        # VE-FDM
        c.network_kwargs.class_name = 'training.networks.FDM_VEPrecond'
        c.loss_kwargs.class_name = 'training.loss.VELoss'
        c.network_kwargs.update(fdm_multiplier=opts.fdm_multiplier)
        c.loss_kwargs.update(warmup_ite=opts.warmup_ite)
    else:
        # EDM-FDM
        assert opts.precond == 'fdm_edm'
        c.network_kwargs.class_name = 'training.networks.FDM_EDMPrecond'
        c.loss_kwargs.class_name = 'training.loss.EDMLoss'
        c.network_kwargs.update(fdm_multiplier=opts.fdm_multiplier)
        c.loss_kwargs.update(warmup_ite=opts.warmup_ite)

    # Network options.
    if opts.cbase is not None:
        c.network_kwargs.model_channels = opts.cbase
    if opts.cres is not None:
        c.network_kwargs.channel_mult = opts.cres
    if opts.augment:
        c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
        c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
        c.network_kwargs.augment_dim = 9
    c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)

    # Training options. Durations given in MIMG are converted to kimg.
    c.total_kimg = max(int(opts.duration * 1000), 1)
    c.ema_halflife_kimg = int(opts.ema * 1000)
    c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
    c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
    c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)

    # Random seed.
    if opts.seed is not None:
        c.seed = opts.seed
    else:
        # Draw a seed and broadcast the value from rank 0 so all ranks agree.
        seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
        torch.distributed.broadcast(seed, src=0)
        c.seed = int(seed)

    # Transfer learning and resume.
    if opts.transfer is not None:
        if opts.resume is not None:
            raise click.ClickException('--transfer and --resume cannot be specified at the same time')
        c.resume_pkl = opts.transfer
        c.ema_rampup_ratio = None
    elif opts.resume is not None:
        # The dot before 'pt' is escaped so only a literal '.pt' suffix matches.
        match = re.fullmatch(r'training-state-(\d+)\.pt', os.path.basename(opts.resume))
        if not match or not os.path.isfile(opts.resume):
            raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
        # The matching network snapshot is expected next to the state dump.
        c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
        c.resume_kimg = int(match.group(1))
        c.resume_state_dump = opts.resume

    # Description string.
    cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
    dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
    desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Pick output directory.
    if dist.get_rank() == -1: # != 0:
        c.run_dir = None
    elif opts.nosubdir:
        c.run_dir = opts.outdir
    else:
        # Next run id = highest numeric prefix among existing subdirs, plus one.
        prev_run_dirs = []
        if os.path.isdir(opts.outdir):
            prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
        prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
        assert not os.path.exists(c.run_dir)

    # Print options.
    dist.print0()
    dist.print0('Training options:')
    dist.print0(json.dumps(c, indent=2))
    dist.print0()
    dist.print0(f'Output directory: {c.run_dir}')
    dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
    dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
    dist.print0(f'Network architecture: {opts.arch}')
    dist.print0(f'Preconditioning & loss: {opts.precond}')
    dist.print0(f'Number of GPUs: {dist.get_world_size()}')
    dist.print0(f'Batch size: {c.batch_size}')
    dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
    dist.print0(f'network_kwargs: {c.network_kwargs}')
    dist.print0(f'loss_kwargs: {c.loss_kwargs}')
    dist.print0()

    # Dry run?
    if opts.dry_run:
        dist.print0('Dry run; exiting.')
        return

    # Create output directory.
    dist.print0('Creating output directory...')
    if dist.get_rank() == 0:
        os.makedirs(c.run_dir, exist_ok=True)
        with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
            json.dump(c, f, indent=2)
        dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Train.
    training_loop.training_loop(**c, lr_rampup_kimg=opts.lr_rampup * 1000)
249 |
250 | #----------------------------------------------------------------------------
251 |
252 | if __name__ == "__main__":
253 | main()
254 |
255 | #----------------------------------------------------------------------------
256 |
--------------------------------------------------------------------------------
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import tempfile
27 | import urllib
28 | import urllib.request
29 | import uuid
30 |
31 | from distutils.util import strtobool
32 | from typing import Any, List, Tuple, Union, Optional
33 |
34 |
35 | # Util classes
36 | # ------------------------------------------------------------------------------------------
37 |
38 |
class EasyDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Reached only when regular attribute lookup fails; fall back to the
        # item with that key and report a miss as a missing attribute.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment stores an item instead of an instance attribute.
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the item (KeyError if it is absent).
        self.__delitem__(name)
54 |
class Logger(object):
    """Tee that replaces sys.stdout and sys.stderr.

    Everything written to either stream is forwarded to the stdout that was
    active at construction time and, when a file name is given, also appended
    to that file. Flushing after every write is optional.
    """

    def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
        self.file = open(file_name, file_mode) if file_name is not None else None
        self.should_flush = should_flush

        # Remember the streams being replaced so close() can put them back.
        self.stdout, self.stderr = sys.stdout, sys.stderr

        # From here on both standard streams route through this object.
        sys.stdout = sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Forward text to the saved stdout (and the file, if any); flush if configured."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        log_file = self.file
        if log_file is not None:
            log_file.write(text)
        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush the log file, if open, and then the saved stdout."""
        log_file = self.file
        if log_file is not None:
            log_file.flush()
        self.stdout.flush()

    def close(self) -> None:
        """Flush, restore the original stdout/stderr, and close the log file."""
        self.flush()

        # With several loggers stacked, only restore a stream that still
        # points at this instance so closing out of order does no harm.
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is None:
            return
        self.file.close()
        self.file = None
113 |
114 | # Cache directories
115 | # ------------------------------------------------------------------------------------------
116 |
# Explicit cache root chosen via set_cache_dir(); None means "not set".
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Set the cache root that make_cache_dir_path() checks first."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Join the given path components onto the cache root.

    The root is the first available of: the directory passed to
    set_cache_dir(), the DNNLIB_CACHE_DIR environment variable, then
    '.cache/dnnlib' under HOME, under USERPROFILE, or under the system
    temporary directory.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)

    env = os.environ
    if 'DNNLIB_CACHE_DIR' in env:
        return os.path.join(env['DNNLIB_CACHE_DIR'], *paths)

    # The remaining candidates all get a '.cache/dnnlib' suffix.
    for variable in ('HOME', 'USERPROFILE'):
        if variable in env:
            base = env[variable]
            break
    else:
        base = tempfile.gettempdir()
    return os.path.join(base, '.cache', 'dnnlib', *paths)
133 |
134 | # Small util functions
135 | # ------------------------------------------------------------------------------------------
136 |
137 |
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration in seconds as a short human-readable string.

    The value is rounded to whole seconds first. Below one day the seconds
    are shown; from one day upwards the output is days, hours and minutes.
    """
    total = int(np.rint(seconds))
    whole_minutes, secs = divmod(total, 60)
    whole_hours, mins = divmod(whole_minutes, 60)
    days, hrs = divmod(whole_hours, 24)

    if total < 60:
        return f"{total}s"
    if total < 60 * 60:
        return f"{whole_minutes}m {secs:02}s"
    if total < 24 * 60 * 60:
        return f"{whole_hours}h {mins:02}m {secs:02}s"
    return f"{days}d {hrs:02}h {mins:02}m"
150 |
151 |
def format_time_brief(seconds: Union[int, float]) -> str:
    """Render a duration in seconds using at most two units.

    The value is rounded to whole seconds first. The result is seconds only,
    minutes and seconds, hours and minutes, or days and hours, depending on
    the magnitude.
    """
    total = int(np.rint(seconds))
    whole_minutes, secs = divmod(total, 60)
    whole_hours, mins = divmod(whole_minutes, 60)
    days, hrs = divmod(whole_hours, 24)

    if total < 60:
        return f"{total}s"
    if total < 60 * 60:
        return f"{whole_minutes}m {secs:02}s"
    if total < 24 * 60 * 60:
        return f"{whole_hours}h {mins:02}m"
    return f"{days}d {hrs:02}h"
164 |
165 |
def ask_yes_no(question: str) -> bool:
    """Prompt on stdout with the question until a valid yes/no answer is read.

    Accepted answers (case-insensitive) are y, yes, t, true, on, 1 for True
    and n, no, f, false, off, 0 for False -- the same set that
    distutils.util.strtobool recognised. The parsing is done locally because
    distutils was removed from the standard library in Python 3.12.

    Returns:
        True for an affirmative answer, False for a negative one.
    """
    affirmative = {"y", "yes", "t", "true", "on", "1"}
    negative = {"n", "no", "f", "false", "off", "0"}
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in affirmative:
            return True
        if answer in negative:
            return False
        # Anything else: ask again.
174 |
175 |
def tuple_product(t: Tuple) -> Any:
    """Multiply all elements of the tuple together; an empty tuple gives 1."""
    product = 1
    for factor in t:
        product *= factor
    return product
184 |
185 |
186 | _str_to_ctype = {
187 | "uint8": ctypes.c_ubyte,
188 | "uint16": ctypes.c_uint16,
189 | "uint32": ctypes.c_uint32,
190 | "uint64": ctypes.c_uint64,
191 | "int8": ctypes.c_byte,
192 | "int16": ctypes.c_int16,
193 | "int32": ctypes.c_int32,
194 | "int64": ctypes.c_int64,
195 | "float32": ctypes.c_float,
196 | "float64": ctypes.c_double
197 | }
198 |
199 |
200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
202 | type_str = None
203 |
204 | if isinstance(type_obj, str):
205 | type_str = type_obj
206 | elif hasattr(type_obj, "__name__"):
207 | type_str = type_obj.__name__
208 | elif hasattr(type_obj, "name"):
209 | type_str = type_obj.name
210 | else:
211 | raise RuntimeError("Cannot infer type name from input")
212 |
213 | assert type_str in _str_to_ctype.keys()
214 |
215 | my_dtype = np.dtype(type_str)
216 | my_ctype = _str_to_ctype[type_str]
217 |
218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
219 |
220 | return my_dtype, my_ctype
221 |
222 |
def is_pickleable(obj: Any) -> bool:
    """Report whether the object can be serialized with pickle.

    The object is dumped into a throwaway in-memory buffer. Any ordinary
    exception raised during pickling yields False; KeyboardInterrupt and
    SystemExit are deliberately not swallowed so the process can still be
    interrupted or shut down while a large object is being probed.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        return False
230 |
231 |
232 | # Functionality to import modules/objects by name, and call functions by name
233 | # ------------------------------------------------------------------------------------------
234 |
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed).

    Raises:
        ImportError: if a candidate module fails to import for a reason other
            than not existing, or if nothing matches at all.
        AttributeError: if a module imports but lacks the requested attribute.
    """

    # allow convenience shorthands, substitute them by full names
    # (the dot is escaped so that only the literal prefixes "np." / "tf." are
    # rewritten, not any name that merely starts with "np" or "tf")
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module first
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Failures are diagnosed by the passes below.
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError as err:
            # A plain "module does not exist" is expected for most candidates;
            # anything else is a genuine import failure worth surfacing.
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object.

    An empty name returns the module itself; a missing attribute along the
    dotted path raises AttributeError.
    """
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
284 |
285 |
def get_obj_by_name(name: str) -> Any:
    """Resolve a fully-qualified dotted name to the python object it denotes."""
    return get_obj_from_module(*get_module_from_obj_name(name))
290 |
291 |
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Look up the callable named by func_name and invoke it with the remaining arguments.

    func_name must be given and must resolve to a callable; both conditions
    are checked with assertions.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    outcome = target(*args, **kwargs)
    return outcome
298 |
299 |
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Look up the class named by class_name and instantiate it with the remaining arguments."""
    instance = call_func_by_name(*args, func_name=class_name, **kwargs)
    return instance
303 |
304 |
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory holding the source file of the module that defines the named object."""
    owner, _unused_local_name = get_module_from_obj_name(obj_name)
    source_file = inspect.getfile(owner)
    return os.path.dirname(source_file)
309 |
310 |
def is_top_level_function(obj: Any) -> bool:
    """Tell whether obj is callable and its name is bound in the namespace of its defining module."""
    if not callable(obj):
        return False
    own_name = obj.__name__
    module_namespace = sys.modules[obj.__module__].__dict__
    return own_name in module_namespace
314 |
315 |
def get_top_level_function_name(obj: Any) -> str:
    """Build '<module>.<name>' for a top-level function.

    When the function lives in '__main__', the module part is replaced by the
    base name of the main script file without its extension.
    """
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        script_path = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_path))[0]
    return module_name + "." + obj.__name__
323 |
324 |
325 | # File system helpers
326 | # ------------------------------------------------------------------------------------------
327 |
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Walk a directory tree and collect its files, skipping ignored names.

    Each entry of ``ignores`` is an fnmatch pattern applied to both directory
    and file names; matching directories are not descended into. The result is
    a list of (absolute_path, relative_path) tuples, where the relative path
    is taken against ``dir_path`` and, if ``add_base_to_relative`` is set,
    prefixed with the last component of ``dir_path``.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = [] if ignores is None else ignores

    collected = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        for pattern in patterns:
            # Slice assignment prunes in place, which is what stops os.walk
            # from descending into the ignored directories.
            dirs[:] = [d for d in dirs if not fnmatch.fnmatch(d, pattern)]
            files = [f for f in files if not fnmatch.fnmatch(f, pattern)]

        for file_name in files:
            absolute = os.path.join(root, file_name)
            relative = os.path.relpath(absolute, dir_path)
            if add_base_to_relative:
                relative = os.path.join(base_name, relative)
            collected.append((absolute, relative))

    return collected
359 |
360 |
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories.

    A destination without a directory component (a bare file name relative to
    the current working directory) is copied as-is, since there is nothing to
    create for it.
    """
    for src_path, dst_path in files:
        target_dir_name = os.path.dirname(dst_path)

        # will create all intermediate-level directories; exist_ok avoids a
        # failure when the directory appears between a check and the creation
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src_path, dst_path)
372 |
373 |
374 | # URL helpers
375 | # ------------------------------------------------------------------------------------------
376 |
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    Args:
        obj: Candidate value; anything that is not a ``str`` is rejected.
        allow_file_urls: When true, any string starting with ``file://`` is
            accepted without further checks.

    Returns:
        True when ``obj`` contains ``://`` and both the URL itself and its
        root (the URL joined with ``/``) parse to a non-empty scheme and a
        network location containing a dot; False otherwise, including when
        parsing raises.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True

    def has_dotted_host(candidate: str) -> bool:
        parts = urllib.parse.urlparse(candidate)
        return bool(parts.scheme and parts.netloc and "." in parts.netloc)

    try:
        return has_dotted_host(obj) and has_dotted_host(urllib.parse.urljoin(obj, "/"))
    except Exception:
        # Malformed input (e.g. an unbalanced IPv6 bracket) is simply not a URL.
        # KeyboardInterrupt / SystemExit are deliberately not swallowed.
        return False
393 |
394 |
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: Remote URL, ``file://`` URL, or plain local file name.
        cache_dir: Directory for cached downloads; defaults to
            ``make_cache_dir_path('downloads')``.
        num_attempts: Maximum number of download attempts (must be >= 1).
        verbose: Print progress to stdout.
        return_filename: Return a file name instead of an open file object.
            Requires ``cache=True`` for remote URLs.
        cache: Look up and store downloads in ``cache_dir``.

    Returns:
        A binary file object (or ``io.BytesIO``) with the data, or the local
        file name when ``return_filename`` is true.

    Raises:
        The exception from the last failed attempt once all attempts are used.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname()
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # The cache key is derived from the URL as passed in, before any redirect
    # rewriting that may happen inside the download loop below.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        # Small payloads are scanned for Google Drive interstitial
                        # pages. Decode leniently: a small binary download is not
                        # valid UTF-8 and must not be treated as a failed attempt.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry against the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # KeyboardInterrupt / SystemExit are not Exception subclasses and
                # therefore propagate immediately instead of being retried.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        safe_name = safe_name[:min(len(safe_name), 128)]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
492 |
--------------------------------------------------------------------------------
/dataset_tool.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Tool for creating ZIP/PNG based datasets."""
9 |
10 | import functools
11 | import gzip
12 | import io
13 | import json
14 | import os
15 | import pickle
16 | import re
17 | import sys
18 | import tarfile
19 | import zipfile
20 | from pathlib import Path
21 | from typing import Callable, Optional, Tuple, Union
22 | import click
23 | import numpy as np
24 | import PIL.Image
25 | from tqdm import tqdm
26 |
# Compatibility shim: when PIL.Image has no 'Resampling' attribute, alias it to
# the PIL.Image module itself so that later lookups such as
# PIL.Image.Resampling.LANCZOS are resolved on PIL.Image directly.
if not hasattr(PIL.Image, 'Resampling'): # Pillow<9.0
    PIL.Image.Resampling = PIL.Image
29 |
30 | #----------------------------------------------------------------------------
31 | # Parse a 'M,N' or 'MxN' integer tuple.
32 | # Example: '4x2' returns (4,2)
33 |
def parse_tuple(s: str) -> Tuple[int, int]:
    """Parse a string of the form 'M,N' or 'MxN' into an integer pair.

    Raises click.ClickException when the string has any other shape.
    """
    parsed = re.match(r'^(\d+)[x,](\d+)$', s)
    if parsed is None:
        raise click.ClickException(f'cannot parse tuple {s}')
    first, second = parsed.groups()
    return int(first), int(second)
39 |
40 | #----------------------------------------------------------------------------
41 |
def maybe_min(a: int, b: Optional[int]) -> int:
    """Return min(a, b), or just ``a`` when ``b`` is None."""
    return a if b is None else min(a, b)
46 |
47 | #----------------------------------------------------------------------------
48 |
def file_ext(name: Union[str, Path]) -> str:
    """Return the text after the last '.' in ``name`` (the whole name if there is no dot)."""
    _head, _dot, tail = str(name).rpartition('.')
    return tail
51 |
52 | #----------------------------------------------------------------------------
53 |
def is_image_ext(fname: Union[str, Path]) -> bool:
    """Return True when the lower-cased extension of ``fname`` is a key of PIL.Image.EXTENSION."""
    dotted_ext = '.' + file_ext(fname).lower()
    return dotted_ext in PIL.Image.EXTENSION
57 |
58 | #----------------------------------------------------------------------------
59 |
def open_image_folder(source_dir, *, max_images: Optional[int]):
    """Open a directory tree of images as a dataset source.

    Labels are read from ``<source_dir>/dataset.json`` when that file exists
    and its 'labels' entry is not null; otherwise, if the images live under
    more than one top-level directory, the sorted top-level directory names
    are mapped to integer class indices.

    Returns:
        ``(max_idx, iterator)`` where ``max_idx`` is the image count capped by
        ``max_images`` and the iterator yields ``dict(img=ndarray, label=...)``
        for at most ``max_idx`` images in sorted path order.
    """
    input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
    arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
    max_idx = maybe_min(len(input_images), max_images)

    # Load labels.
    labels = dict()
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        # Read as UTF-8 explicitly instead of relying on the locale encoding.
        with open(meta_fname, 'r', encoding='utf-8') as file:
            data = json.load(file)['labels']
            if data is not None:
                labels = {x[0]: x[1] for x in data}

    # No labels available => determine from top-level directory names.
    if len(labels) == 0:
        toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
        toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
        if len(toplevel_indices) > 1:
            labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}

    def iterate_images():
        for idx, fname in enumerate(input_images):
            # Check the limit before yielding so that max_images=0 yields nothing.
            if idx >= max_idx:
                break
            # Close the image file as soon as the pixels have been copied out.
            with PIL.Image.open(fname) as pil_img:
                img = np.array(pil_img)
            yield dict(img=img, label=labels.get(arch_fnames.get(fname)))
    return max_idx, iterate_images()
88 |
89 | #----------------------------------------------------------------------------
90 |
def open_image_zip(source, *, max_images: Optional[int]):
    """Open a zip archive of images as a dataset source.

    Labels are read from a 'dataset.json' member at the archive root when it
    exists and its 'labels' entry is not null.

    Returns:
        ``(max_idx, iterator)`` where ``max_idx`` is the image count capped by
        ``max_images`` and the iterator yields ``dict(img=ndarray, label=...)``
        for at most ``max_idx`` images in sorted member-name order. The
        archive is re-opened when the iterator starts running.
    """
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
        max_idx = maybe_min(len(input_images), max_images)

        # Load labels.
        labels = dict()
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                data = json.load(file)['labels']
                if data is not None:
                    labels = {x[0]: x[1] for x in data}

    def iterate_images():
        with zipfile.ZipFile(source, mode='r') as z:
            for idx, fname in enumerate(input_images):
                # Check the limit before yielding so that max_images=0 yields nothing.
                if idx >= max_idx:
                    break
                with z.open(fname, 'r') as file:
                    with PIL.Image.open(file) as pil_img:
                        img = np.array(pil_img)
                yield dict(img=img, label=labels.get(fname))
    return max_idx, iterate_images()
113 |
114 | #----------------------------------------------------------------------------
115 |
def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
    """Open an LMDB directory of encoded images as a dataset source.

    Each value is decoded with OpenCV first; if that fails the value is
    decoded with PIL instead. Entries that cannot be decoded at all are
    reported on stdout and skipped.

    Returns:
        ``(max_idx, iterator)`` where ``max_idx`` is the entry count capped by
        ``max_images`` and the iterator yields ``dict(img=ndarray, label=None)``
        for entries whose cursor index is below ``max_idx``.
    """
    import cv2  # pyright: ignore [reportMissingImports] # pip install opencv-python
    import lmdb  # pyright: ignore [reportMissingImports] # pip install lmdb

    # Close the environment once the entry count is known, so that it is not
    # still open when the iterator opens the same directory again.
    with lmdb.open(lmdb_dir, readonly=True, lock=False) as env:
        with env.begin(write=False) as txn:
            max_idx = maybe_min(txn.stat()['entries'], max_images)

    def iterate_images():
        with lmdb.open(lmdb_dir, readonly=True, lock=False) as env:
            with env.begin(write=False) as txn:
                for idx, (_key, value) in enumerate(txn.cursor()):
                    # Check the limit before yielding so that max_images=0 yields nothing.
                    if idx >= max_idx:
                        break
                    try:
                        try:
                            img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
                            if img is None:
                                raise IOError('cv2.imdecode failed')
                            img = img[:, :, ::-1] # BGR => RGB
                        except IOError:
                            img = np.array(PIL.Image.open(io.BytesIO(value)))
                    except Exception:
                        # Best effort: report the undecodable entry and move on.
                        print(sys.exc_info()[1])
                        continue
                    # Yield outside the try block so that GeneratorExit (raised
                    # here when the consumer closes the iterator) is not swallowed.
                    yield dict(img=img, label=None)

    return max_idx, iterate_images()
141 |
142 | #----------------------------------------------------------------------------
143 |
def open_cifar10(tarball: str, *, max_images: Optional[int]):
    """Open the CIFAR-10 python tarball as a dataset source.

    Reads ``cifar-10-batches-py/data_batch_1`` .. ``data_batch_5`` from the
    gzip-compressed tar archive and checks that they amount to 50000 uint8
    images of shape 32x32x3 with labels in 0..9.

    Returns:
        ``(max_idx, iterator)`` where ``max_idx`` is the image count capped by
        ``max_images`` and the iterator yields ``dict(img=ndarray, label=int)``
        for at most ``max_idx`` images.
    """
    images = []
    labels = []

    with tarfile.open(tarball, 'r:gz') as tar:
        for batch in range(1, 6):
            member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
            with tar.extractfile(member) as file:
                # NOTE(security): unpickling can execute arbitrary code; only
                # use this on a tarball obtained from a trusted source.
                data = pickle.load(file, encoding='latin1')
            images.append(data['data'].reshape(-1, 3, 32, 32))
            labels.append(data['labels'])

    images = np.concatenate(images)
    labels = np.concatenate(labels)
    images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
    assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        for idx, img in enumerate(images):
            # Check the limit before yielding so that max_images=0 yields nothing.
            if idx >= max_idx:
                break
            yield dict(img=img, label=int(labels[idx]))

    return max_idx, iterate_images()
173 |
174 | #----------------------------------------------------------------------------
175 |
def open_mnist(images_gz: str, *, max_images: Optional[int]):
    """Open the MNIST training set (gzip-compressed idx files) as a dataset source.

    The labels file name is derived from ``images_gz`` by replacing
    '-images-idx3-ubyte.gz' with '-labels-idx1-ubyte.gz'. Images are read as
    28x28 uint8 and zero-padded by 2 pixels on each side to 32x32.

    Returns:
        ``(max_idx, iterator)`` where ``max_idx`` is the image count capped by
        ``max_images`` and the iterator yields ``dict(img=ndarray, label=int)``
        for at most ``max_idx`` images.
    """
    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
    assert labels_gz != images_gz

    # The offsets skip the leading header bytes of each file.
    with gzip.open(images_gz, 'rb') as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    with gzip.open(labels_gz, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        for idx, img in enumerate(images):
            # Check the limit before yielding so that max_images=0 yields nothing.
            if idx >= max_idx:
                break
            yield dict(img=img, label=int(labels[idx]))

    return max_idx, iterate_images()
203 |
204 | #----------------------------------------------------------------------------
205 |
def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int]
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    """Build the per-image transform selected on the command line.

    Args:
        transform: None (plain scaling), 'center-crop' or 'center-crop-wide'.
        output_width: Target width; may be None only for plain scaling.
        output_height: Target height; may be None only for plain scaling.

    Returns:
        A callable taking an image array and returning the transformed array.
        The 'center-crop-wide' transform returns None for images it rejects
        as too small.

    Raises:
        click.ClickException: A crop transform was requested without a resolution.
    """
    def scale(width, height, img):
        # Resize to (width, height); a None dimension keeps the input size.
        w = img.shape[1]
        h = img.shape[0]
        ww = width if width is not None else w
        hh = height if height is not None else h
        if ww == w and hh == h:
            # Nothing to resize, so skip the round trip through PIL.
            return img
        img = PIL.Image.fromarray(img)
        img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop(width, height, img):
        # Crop the largest centered square, then resize to (width, height).
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        if img.ndim == 2:
            img = img[:, :, np.newaxis].repeat(3, axis=2)
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        return np.array(img)

    def center_crop_wide(width, height, img):
        # Crop a centered horizontal band, resize it to (width, height) and
        # paste it vertically centered onto a black width x width canvas.
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        if img.ndim == 2:
            img = img[:, :, np.newaxis].repeat(3, axis=2)
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
        img = np.array(img)

        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if output_width is None or output_height is None:
            raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    assert False, 'unknown transform'
258 |
259 | #----------------------------------------------------------------------------
260 |
def open_dataset(source, *, max_images: Optional[int]):
    """Pick the dataset reader that matches ``source`` and open it.

    Directories whose name (ignoring trailing '/') ends in '_lmdb' are read as
    LMDB, other directories as image folders. Files are dispatched on their
    base name (CIFAR-10 tarball, MNIST images file) or on a 'zip' extension.

    Returns:
        The ``(max_idx, iterator)`` pair produced by the selected reader.

    Raises:
        click.ClickException: ``source`` does not exist or is a file of an
            unrecognized type.
    """
    if os.path.isdir(source):
        if source.rstrip('/').endswith('_lmdb'):
            return open_lmdb(source, max_images=max_images)
        return open_image_folder(source, max_images=max_images)

    if os.path.isfile(source):
        base_name = os.path.basename(source)
        if base_name == 'cifar-10-python.tar.gz':
            return open_cifar10(source, max_images=max_images)
        if base_name == 'train-images-idx3-ubyte.gz':
            return open_mnist(source, max_images=max_images)
        if file_ext(source) == 'zip':
            return open_image_zip(source, max_images=max_images)
        # Raise instead of assert: this is user input validation and must not
        # disappear when Python runs with assertions disabled (-O).
        raise click.ClickException(f'unknown archive type: {source}')

    raise click.ClickException(f'Missing input file or directory: {source}')
278 |
279 | #----------------------------------------------------------------------------
280 |
def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    """Prepare the output location and return the tools for writing into it.

    Returns:
        ``(root_dir, write_fn, close_fn)``. For a '.zip' destination the root
        is '' and ``write_fn`` stores entries in the archive, with ``close_fn``
        closing it. Otherwise the root is ``dest``, ``write_fn`` writes files
        to disk and ``close_fn`` does nothing.

    Raises:
        click.ClickException: The destination folder exists and is not empty.
    """
    if file_ext(dest) == 'zip':
        parent_dir = os.path.dirname(dest)
        if parent_dir != '':
            os.makedirs(parent_dir, exist_ok=True)
        archive = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)

        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            archive.writestr(fname, data)

        return '', zip_write_bytes, archive.close

    # If the output folder already exists, check that it is empty.
    #
    # Creating the folder up front is not strictly required, since
    # folder_write_bytes() also creates directories, but it surfaces an
    # unusable destination before any work has been done.
    if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
        raise click.ClickException('--dest folder must be empty')
    os.makedirs(dest, exist_ok=True)

    def folder_write_bytes(fname: str, data: Union[bytes, str]):
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        with open(fname, 'wb') as fout:
            payload = data.encode('utf8') if isinstance(data, str) else data
            fout.write(payload)

    return dest, folder_write_bytes, lambda: None
310 |
311 | #----------------------------------------------------------------------------
312 |
@click.command()
@click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
@click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
@click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
def main(
    source: str,
    dest: str,
    max_images: Optional[int],
    transform: Optional[str],
    resolution: Optional[Tuple[int, int]]
):
    """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.

    The input dataset format is guessed from the --source argument:

    \b
    --source *_lmdb/                    Load LSUN dataset
    --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
    --source train-images-idx3-ubyte.gz Load MNIST dataset
    --source path/                      Recursively load all images from path/
    --source dataset.zip                Recursively load all images from dataset.zip

    Specifying the output format and path:

    \b
    --dest /path/to/dir                 Save output files under /path/to/dir
    --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip

    The output dataset format can be either an image folder or an uncompressed zip archive.
    Zip archives make it easier to move datasets around file servers and clusters, and may
    offer better training performance on network file systems.

    Images within the dataset archive will be stored as uncompressed PNG.
    Uncompressed PNGs can be efficiently decoded in the training loop.

    Class labels are stored in a file called 'dataset.json' that is stored at the
    dataset root folder. This file has the following structure:

    \b
    {
        "labels": [
            ["00000/img00000000.png",6],
            ["00000/img00000001.png",9],
            ... repeated for every image in the dataset
            ["00049/img00049999.png",1]
        ]
    }

    If the 'dataset.json' file cannot be found, class labels are determined from
    top-level directory names.

    Image scale/crop and resolution requirements:

    Output images must be square-shaped and they must all have the same power-of-two
    dimensions.

    To scale arbitrary input image size to a specific width and height, use the
    --resolution option. Output resolution will be either the original
    input resolution (if resolution was not specified) or the one specified with
    --resolution option.

    Use the --transform=center-crop or --transform=center-crop-wide options to apply a
    center crop transform on the input image. These options should be used with the
    --resolution option. For example:

    \b
    python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
        --transform=center-crop-wide --resolution=512x384
    """

    PIL.Image.init()

    if dest == '':
        raise click.ClickException('--dest output filename or directory must not be an empty string')

    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)

    # Close the destination on every exit path so that an opened zip archive
    # is not left unfinalized when the conversion fails part-way.
    try:
        if resolution is None: resolution = (None, None)
        transform_image = make_transform(transform, *resolution)

        # Attributes of the first written image; all later images must match.
        dataset_attrs = None

        labels = []
        for idx, image in tqdm(enumerate(input_iter), total=num_files):
            idx_str = f'{idx:08d}'
            archive_fname = f'{idx_str[:5]}/img{idx_str}.png'

            # Apply crop and resize.
            img = transform_image(image['img'])
            if img is None:
                # The transform rejected this image; skip it.
                continue

            # Error check to require uniform image attributes across
            # the whole dataset.
            channels = img.shape[2] if img.ndim == 3 else 1
            cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
            if dataset_attrs is None:
                dataset_attrs = cur_image_attrs
                width = dataset_attrs['width']
                height = dataset_attrs['height']
                if width != height:
                    raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
                if dataset_attrs['channels'] not in [1, 3]:
                    raise click.ClickException('Input images must be stored as RGB or grayscale')
                if width != 2 ** int(np.floor(np.log2(width))):
                    raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
            elif dataset_attrs != cur_image_attrs:
                err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
                raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))

            # Save the image as an uncompressed PNG.
            img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
            image_bits = io.BytesIO()
            img.save(image_bits, format='png', compress_level=0, optimize=False)
            save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
            labels.append([archive_fname, image['label']] if image['label'] is not None else None)

        # Labels are only written when every saved image has one.
        metadata = {'labels': labels if all(x is not None for x in labels) else None}
        save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    finally:
        close_dest()
437 |
438 | #----------------------------------------------------------------------------
439 |
# Invoke the command-line entry point only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
442 |
443 | #----------------------------------------------------------------------------
444 |
--------------------------------------------------------------------------------
/training/augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Augmentation pipeline used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models".
10 | Built around the same concepts that were originally proposed in the paper
11 | "Training Generative Adversarial Networks with Limited Data"."""
12 |
13 | import numpy as np
14 | import torch
15 | from torch_utils import persistence
16 | from torch_utils import misc
17 |
18 | #----------------------------------------------------------------------------
19 | # Coefficients of various wavelet decomposition low-pass filters.
20 |
# Maps a wavelet name to its list of low-pass filter coefficients.
# In this table 'haar' and 'db1' hold identical values, as do 'db2'/'sym2'
# and 'db3'/'sym3'.
# NOTE(review): the 'dbN' / 'symN' keys presumably refer to the Daubechies and
# Symlet families -- confirm against the code that consumes this table.
wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1': [0.7071067811865476, 0.7071067811865476],
    'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}
39 |
40 | #----------------------------------------------------------------------------
41 | # Helpers for constructing transformation matrices.
42 |
def matrix(*rows, device=None):
    """Assemble a matrix (or a batch of matrices) from rows of entries.

    Every row must have the same length. If no entry is a torch.Tensor, the
    rows are turned into a constant via misc.constant() on ``device``.
    Otherwise the first tensor entry serves as the reference: non-tensor
    entries are expanded to constants of its shape and device, and the result
    has the reference shape followed by ``(len(rows), row_length)``.
    """
    row_len = len(rows[0])
    assert all(len(row) == row_len for row in rows)

    flat = []
    for row in rows:
        flat.extend(row)

    tensors = [entry for entry in flat if isinstance(entry, torch.Tensor)]
    if not tensors:
        return misc.constant(np.asarray(rows), device=device)

    template = tensors[0]
    assert device is None or device == template.device

    expanded = []
    for entry in flat:
        if not isinstance(entry, torch.Tensor):
            entry = misc.constant(entry, shape=template.shape, device=template.device)
        expanded.append(entry)

    stacked = torch.stack(expanded, dim=-1)
    return stacked.reshape(template.shape + (len(rows), -1))
52 |
def translate2d(tx, ty, **kwargs):
    """Build a 3x3 matrix with ``tx`` / ``ty`` in the last column of an identity.

    Extra keyword arguments are forwarded to matrix().
    """
    rows = (
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)
59 |
def translate3d(tx, ty, tz, **kwargs):
    """Build a 4x4 matrix with ``tx`` / ``ty`` / ``tz`` in the last column of an identity.

    Extra keyword arguments are forwarded to matrix().
    """
    rows = (
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
67 |
def scale2d(sx, sy, **kwargs):
    """Build a 3x3 matrix with ``sx`` / ``sy`` on the diagonal and 1 in the last diagonal slot.

    Extra keyword arguments are forwarded to matrix().
    """
    rows = (
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)
74 |
def scale3d(sx, sy, sz, **kwargs):
    """Build a 4x4 matrix with ``sx`` / ``sy`` / ``sz`` on the diagonal and 1 in the last diagonal slot.

    Extra keyword arguments are forwarded to matrix().
    """
    rows = (
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
82 |
def rotate2d(theta, **kwargs):
    """Build a 3x3 rotation matrix from the angle tensor ``theta``.

    The upper-left 2x2 part is ``[[cos, sin(-theta)], [sin, cos]]``. Extra
    keyword arguments are forwarded to matrix().
    """
    cos_t = torch.cos(theta)
    rows = (
        [cos_t, torch.sin(-theta), 0],
        [torch.sin(theta), cos_t, 0],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)
89 |
def rotate3d(v, theta, **kwargs):
    """Build a 4x4 matrix for a rotation by ``theta`` about the axis ``v``.

    The axis components are read from the last dimension of ``v`` (indices
    0, 1, 2). Extra keyword arguments are forwarded to matrix().
    NOTE(review): the formula presumably assumes ``v`` is unit length -- confirm
    with the callers.
    """
    ax, ay, az = v[..., 0], v[..., 1], v[..., 2]
    sin_t = torch.sin(theta)
    cos_t = torch.cos(theta)
    k = 1 - cos_t
    rows = (
        [ax*ax*k+cos_t, ax*ay*k-az*sin_t, ax*az*k+ay*sin_t, 0],
        [ay*ax*k+az*sin_t, ay*ay*k+cos_t, ay*az*k-ax*sin_t, 0],
        [az*ax*k-ay*sin_t, az*ay*k+ax*sin_t, az*az*k+cos_t, 0],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
99 |
def translate2d_inv(tx, ty, **kwargs):
    """Return translate2d() for the negated offsets ``-tx`` and ``-ty``."""
    neg_tx = -tx
    neg_ty = -ty
    return translate2d(neg_tx, neg_ty, **kwargs)
102 |
def scale2d_inv(sx, sy, **kwargs):
    """Return scale2d() for the reciprocal factors ``1 / sx`` and ``1 / sy``."""
    inv_sx = 1 / sx
    inv_sy = 1 / sy
    return scale2d(inv_sx, inv_sy, **kwargs)
105 |
def rotate2d_inv(theta, **kwargs):
    """Return rotate2d() for the negated angle ``-theta``."""
    neg_theta = -theta
    return rotate2d(neg_theta, **kwargs)
108 |
109 | #----------------------------------------------------------------------------
110 | # Augmentation pipeline main class.
111 | # All augmentations are disabled by default; individual augmentations can
112 | # be enabled by setting their probability multipliers to 1.
113 |
@persistence.persistent_class
class AugmentPipe:
    """Randomized image augmentation pipeline that also reports what it drew.

    Calling an instance on a batch of images runs, in order, the enabled pixel
    blitting, geometric, and color transformations. For every enabled
    augmentation the random parameters of a sample are kept with probability
    `<multiplier> * p` and zeroed otherwise. A multiplier of 0 (the default)
    skips the augmentation altogether, which also removes its columns from
    the returned label tensor.
    """

    def __init__(self, p=1,
        xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125,
        scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
    ):
        """Store every setting as a float attribute of the same name."""
        super().__init__()
        self.p = float(p) # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.xflip = float(xflip) # Probability multiplier for x-flip.
        self.yflip = float(yflip) # Probability multiplier for y-flip.
        self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation.
        self.translate_int = float(translate_int) # Probability multiplier for integer translation.
        self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions.

        # Geometric transformations.
        self.scale = float(scale) # Probability multiplier for isotropic scaling.
        self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation.
        self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
        self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation.
        self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
        self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle.
        self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
        self.aniso_rotate_prob = float(aniso_rotate_prob) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
        self.translate_frac_std = float(translate_frac_std) # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness = float(brightness) # Probability multiplier for brightness.
        self.contrast = float(contrast) # Probability multiplier for contrast.
        self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
        self.hue = float(hue) # Probability multiplier for hue rotation.
        self.saturation = float(saturation) # Probability multiplier for saturation.
        self.brightness_std = float(brightness_std) # Standard deviation of brightness.
        self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
        self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
        self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.

    def __call__(self, images):
        """Augment a batch of images and return the parameters that were used.

        Args:
            images: 4-D tensor whose shape is unpacked as (N, C, H, W).

        Returns:
            Tuple `(images, labels)`. `labels` is a float32 tensor with N rows
            whose columns are the per-sample parameters appended by each
            enabled augmentation, in execution order.

        Raises:
            ValueError: If a color transformation is enabled and C is neither
                3 nor 1.
        """
        N, C, H, W = images.shape
        device = images.device
        # Seed the list with an empty [N, 0] tensor so the final torch.cat
        # still works when no augmentation is enabled.
        labels = [torch.zeros([images.shape[0], 0], device=device)]

        # ---------------
        # Pixel blitting.
        # ---------------

        if self.xflip > 0:
            w = torch.randint(2, [N, 1, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w))
            images = torch.where(w == 1, images.flip(3), images)
            labels += [w]

        if self.yflip > 0:
            w = torch.randint(2, [N, 1, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w))
            images = torch.where(w == 1, images.flip(2), images)
            labels += [w]

        if self.rotate_int > 0:
            # w in {0, 1, 2, 3} selects a combination of flips and a transpose.
            # NOTE(review): the transpose only broadcasts against `images`
            # when H == W - presumably square inputs are expected; confirm.
            w = torch.randint(4, [N, 1, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w))
            images = torch.where((w == 1) | (w == 2), images.flip(3), images)
            images = torch.where((w == 2) | (w == 3), images.flip(2), images)
            images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images)
            labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)]

        if self.translate_int > 0:
            w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1
            w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w))
            tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64)
            ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64)
            b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij')
            # Fold the shifted coordinates back into the valid index range by
            # mirror reflection, then gather from the flattened image.
            x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs()
            y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs()
            images = images.flatten()[(((b * C) + c) * H + y) * W + x]
            labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)]

        # ------------------------------------------------
        # Select parameters for geometric transformations.
        # ------------------------------------------------

        # NOTE(review): scale2d_inv / rotate2d_inv / translate2d_inv are
        # defined outside this view; presumably they build batched 3x3
        # inverse transform matrices - confirm against their definitions.
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        if self.scale > 0:
            w = torch.randn([N], device=device)
            w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w))
            s = w.mul(self.scale_std).exp2()
            G_inv = G_inv @ scale2d_inv(s, s)
            labels += [w]

        if self.rotate_frac > 0:
            w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max)
            w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w))
            G_inv = G_inv @ rotate2d_inv(-w)
            labels += [w.cos() - 1, w.sin()]

        if self.aniso > 0:
            w = torch.randn([N], device=device)
            r = (torch.rand([N], device=device) * 2 - 1) * np.pi
            w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w))
            # The frame rotation is gated by aniso_rotate_prob only; it is not
            # multiplied by self.p.
            r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r))
            s = w.mul(self.aniso_std).exp2()
            G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r)
            labels += [w * r.cos(), w * r.sin()]

        if self.translate_frac > 0:
            w = torch.randn([2, N], device=device)
            w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w))
            G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std))
            labels += [w[0], w[1]]

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Identity (not value) comparison: G_inv is rebound to a new tensor by
        # every enabled geometric branch above, so this is true exactly when
        # at least one of them ran.
        if G_inv is not I_3:
            # Transform the four image corners to work out how much padding
            # the resampling step needs on each side.
            cx = (W - 1) / 2
            cy = (H - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
            cp = G_inv @ cp.t() # [batch, xyz, idx]
            Hz = np.asarray(wavelets['sym6'], dtype=np.float32)
            Hz_pad = len(Hz) // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            # Clamp each margin to [0, W - 1] / [0, H - 1].
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([W - 1, H - 1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            # Zeros are interleaved between pixels (first along width, then
            # along height) and each pass is followed by a depthwise
            # convolution with the reversed filter taps.
            conv_weight = misc.constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
            conv_pad = (len(Hz) + 1) // 2
            images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1]
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad])
            images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :]
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0])
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            # The first two rows of G_inv are handed to affine_grid as theta;
            # grid_sample then resamples bilinearly with zero padding.
            shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

            # Downsample and crop.
            # Stride-2 depthwise convolutions (width, then height), each
            # followed by cropping Hz_pad pixels from both ends of that axis.
            conv_weight = misc.constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1])
            conv_pad = (len(Hz) - 1) // 2
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad]
            images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :]

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # NOTE(review): translate3d / scale3d / rotate3d are defined outside
        # this view; presumably they build batched 4x4 matrices - confirm.
        I_4 = torch.eye(4, device=device)
        M = I_4
        luma_axis = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)

        if self.brightness > 0:
            w = torch.randn([N], device=device)
            w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w))
            b = w * self.brightness_std
            M = translate3d(b, b, b) @ M
            labels += [w]

        if self.contrast > 0:
            w = torch.randn([N], device=device)
            w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w))
            c = w.mul(self.contrast_std).exp2()
            M = scale3d(c, c, c) @ M
            labels += [w]

        if self.lumaflip > 0:
            w = torch.randint(2, [N, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w))
            # With w == 1 this is I - 2 * outer(luma_axis, luma_axis); with
            # w == 0 it is the identity.
            M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M
            labels += [w]

        if self.hue > 0:
            w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max)
            w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w))
            M = rotate3d(luma_axis, w) @ M
            labels += [w.cos() - 1, w.sin()]

        if self.saturation > 0:
            w = torch.randn([N, 1, 1], device=device)
            w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w))
            # Keep the component along luma_axis and scale the remainder by
            # 2 ** (w * saturation_std).
            M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M
            labels += [w]

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Same identity check as for G_inv: true when any color branch ran.
        if M is not I_4:
            images = images.reshape([N, C, H * W])
            if C == 3:
                # Linear 3x3 part plus the translation column of M.
                images = M[:, :3, :3] @ images + M[:, :3, 3:]
            elif C == 1:
                # Single channel: average the first three rows of M, then
                # apply the summed gain and the offset.
                M = M[:, :3, :].mean(dim=1, keepdims=True)
                images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([N, C, H, W])

        # Flatten every recorded parameter to [N, -1] float32 and concatenate
        # along the column axis.
        labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1)
        return images, labels
329 |
330 | #----------------------------------------------------------------------------
331 |
--------------------------------------------------------------------------------
/training/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Garena Online Private Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import numpy as np
17 | import torch
18 | from torch_utils import persistence
19 | from torch.nn.functional import silu
20 |
21 | #----------------------------------------------------------------------------
22 | # Unified routine for initializing weights and biases.
23 |
def weight_init(shape, mode, fan_in, fan_out):
    """Draw a randomly initialized tensor of the requested shape.

    Args:
        shape: Sequence of ints; unpacked into `torch.rand` / `torch.randn`.
        mode: One of 'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
            'kaiming_normal'.
        fan_in: Fan-in used to scale the samples.
        fan_out: Fan-out; only used by the two Xavier modes.

    Returns:
        The sampled tensor, scaled according to `mode`.

    Raises:
        ValueError: If `mode` is not one of the supported names.
    """
    if mode == 'xavier_uniform':
        bound = np.sqrt(6 / (fan_in + fan_out))
        return bound * (torch.rand(*shape) * 2 - 1)
    if mode == 'xavier_normal':
        std = np.sqrt(2 / (fan_in + fan_out))
        return std * torch.randn(*shape)
    if mode == 'kaiming_uniform':
        bound = np.sqrt(3 / fan_in)
        return bound * (torch.rand(*shape) * 2 - 1)
    if mode == 'kaiming_normal':
        std = np.sqrt(1 / fan_in)
        return std * torch.randn(*shape)
    raise ValueError(f'Invalid init mode "{mode}"')
30 |
31 | #----------------------------------------------------------------------------
32 | # Fully-connected layer.
33 |
@persistence.persistent_class
class Linear(torch.nn.Module):
    """Fully-connected layer whose parameters are drawn via `weight_init`.

    `init_weight` and `init_bias` are multipliers applied to the freshly
    sampled weight and bias tensors; with the default `init_bias=0` the bias
    therefore starts at zero.
    """

    def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        def sample(shape):
            # Weight and bias share the same mode and fan-in / fan-out.
            return weight_init(shape, mode=init_mode, fan_in=in_features, fan_out=out_features)

        self.weight = torch.nn.Parameter(sample([out_features, in_features]) * init_weight)
        if bias:
            self.bias = torch.nn.Parameter(sample([out_features]) * init_bias)
        else:
            self.bias = None

    def forward(self, x):
        """Return `x @ weight.T (+ bias)`, computed in the dtype of `x`."""
        out = x @ self.weight.to(x.dtype).t()
        if self.bias is None:
            return out
        # In-place add on the freshly created matmul result.
        return out.add_(self.bias.to(x.dtype))
49 |
50 | #----------------------------------------------------------------------------
51 | # Convolutional layer with optional up/downsampling.
52 |
@persistence.persistent_class
class Conv2d(torch.nn.Module):
    """2-D convolution with optional 2x up- or downsampling.

    Resampling uses a separable filter built from `resample_filter`. With
    `kernel=0` the layer has no weight or bias and only resamples. When
    `fused_resample` is set and a weight exists, forward uses a different
    ordering and padding of the filter and weight convolutions.
    """

    def __init__(self,
        in_channels, out_channels, kernel, bias=True, up=False, down=False,
        resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0,
    ):
        assert not (up and down)
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up = up
        self.down = down
        self.fused_resample = fused_resample

        def sample(shape):
            # Weight and bias share the same mode and fan-in / fan-out.
            return weight_init(
                shape,
                mode=init_mode,
                fan_in=in_channels*kernel*kernel,
                fan_out=out_channels*kernel*kernel,
            )

        if kernel:
            self.weight = torch.nn.Parameter(sample([out_channels, in_channels, kernel, kernel]) * init_weight)
        else:
            self.weight = None
        if kernel and bias:
            self.bias = torch.nn.Parameter(sample([out_channels]) * init_bias)
        else:
            self.bias = None

        # Separable 2-D filter of shape [1, 1, k, k], normalized to sum to 1.
        taps = torch.as_tensor(resample_filter, dtype=torch.float32)
        filter_2d = taps.ger(taps).unsqueeze(0).unsqueeze(1) / taps.sum().square()
        self.register_buffer('resample_filter', filter_2d if up or down else None)

    def forward(self, x):
        """Resample and/or convolve `x`, then add the bias if there is one."""
        conv2d = torch.nn.functional.conv2d
        conv_transpose2d = torch.nn.functional.conv_transpose2d

        weight = None if self.weight is None else self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        filt = None if self.resample_filter is None else self.resample_filter.to(x.dtype)
        weight_pad = 0 if weight is None else weight.shape[-1] // 2
        filt_pad = 0 if filt is None else (filt.shape[-1] - 1) // 2

        if self.fused_resample and self.up and weight is not None:
            # Fused upsampling: transposed filter first, then the weights.
            up_filter = filt.mul(4).tile([self.in_channels, 1, 1, 1])
            x = conv_transpose2d(x, up_filter, groups=self.in_channels, stride=2, padding=max(filt_pad - weight_pad, 0))
            x = conv2d(x, weight, padding=max(weight_pad - filt_pad, 0))
        elif self.fused_resample and self.down and weight is not None:
            # Fused downsampling: weights first, then the strided filter.
            x = conv2d(x, weight, padding=weight_pad+filt_pad)
            down_filter = filt.tile([self.out_channels, 1, 1, 1])
            x = conv2d(x, down_filter, groups=self.out_channels, stride=2)
        else:
            if self.up:
                up_filter = filt.mul(4).tile([self.in_channels, 1, 1, 1])
                x = conv_transpose2d(x, up_filter, groups=self.in_channels, stride=2, padding=filt_pad)
            if self.down:
                down_filter = filt.tile([self.in_channels, 1, 1, 1])
                x = conv2d(x, down_filter, groups=self.in_channels, stride=2, padding=filt_pad)
            if weight is not None:
                x = conv2d(x, weight, padding=weight_pad)

        if bias is not None:
            x = x.add_(bias.reshape(1, -1, 1, 1))
        return x
96 |
97 | #----------------------------------------------------------------------------
98 | # Group normalization.
99 |
@persistence.persistent_class
class GroupNorm(torch.nn.Module):
    """Group normalization with a learnable per-channel scale and shift.

    The group count is capped so that every group holds at least
    `min_channels_per_group` channels, and is never allowed to fall below one.

    Args:
        num_channels: Number of channels of the input.
        num_groups: Upper bound on the number of groups.
        min_channels_per_group: Minimum number of channels per group.
        eps: Epsilon passed to `torch.nn.functional.group_norm`.
    """

    def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5):
        super().__init__()
        # Clamp to >= 1: the floor division is 0 whenever num_channels is
        # smaller than min_channels_per_group, and a group count of zero makes
        # group_norm fail at forward time.
        self.num_groups = max(1, min(num_groups, num_channels // min_channels_per_group))
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(num_channels))
        self.bias = torch.nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        """Normalize `x`, with weight and bias cast to the dtype of `x`."""
        x = torch.nn.functional.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps)
        return x
112 |
113 | #----------------------------------------------------------------------------
114 | # Attention weight computation, i.e., softmax(Q^T * K).
115 | # Performs all computation using FP32, but uses the original datatype for
116 | # inputs/outputs/gradients to conserve memory.
117 |
class AttentionOp(torch.autograd.Function):
    """Attention weights `softmax(q^T k / sqrt(c))` with a hand-written backward.

    Both passes do their arithmetic in float32, while the returned weights and
    gradients are cast back to the dtype of the corresponding input.
    """

    @staticmethod
    def forward(ctx, q, k):
        # q is indexed as [n, c, q] and k as [n, c, k]; the softmax runs over
        # the key axis of the [n, q, k] result.
        scale = np.sqrt(k.shape[1])
        logits = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / scale).to(torch.float32))
        w = logits.softmax(dim=2).to(q.dtype)
        ctx.save_for_backward(q, k, w)
        return w

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        scale = np.sqrt(k.shape[1])
        # Gradient w.r.t. the pre-softmax logits, recovered from the saved
        # softmax output.
        dlogits = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32)
        dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), dlogits).to(q.dtype) / scale
        dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), dlogits).to(k.dtype) / scale
        return dq, dk
132 |
133 | #----------------------------------------------------------------------------
134 | # Unified U-Net block with optional up/downsampling and self-attention.
135 | # Represents the union of all features employed by the DDPM++, NCSN++, and
136 | # ADM architectures.
137 |
@persistence.persistent_class
class UNetBlock(torch.nn.Module):
    """Residual U-Net block with embedding conditioning and optional attention.

    The block normalizes and convolves its input (optionally resampling),
    injects the embedding either as an additive shift or as a scale/shift
    pair, applies dropout and a second convolution, and adds a skip
    connection. When attention is enabled, a self-attention sub-block follows.
    """

    def __init__(self,
        in_channels, out_channels, emb_channels, up=False, down=False, attention=False,
        num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5,
        resample_filter=[1,1], resample_proj=False, adaptive_scale=True,
        init=dict(), init_zero=dict(init_weight=0), init_attn=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.emb_channels = emb_channels
        # Zero heads means "no attention"; otherwise use the explicit count
        # or derive one from the channel width.
        if not attention:
            self.num_heads = 0
        elif num_heads is not None:
            self.num_heads = num_heads
        else:
            self.num_heads = out_channels // channels_per_head
        self.dropout = dropout
        self.skip_scale = skip_scale
        self.adaptive_scale = adaptive_scale

        # With adaptive scaling the embedding yields both a scale and a shift.
        affine_width = out_channels*(2 if adaptive_scale else 1)
        self.norm0 = GroupNorm(num_channels=in_channels, eps=eps)
        self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init)
        self.affine = Linear(in_features=emb_channels, out_features=affine_width, **init)
        self.norm1 = GroupNorm(num_channels=out_channels, eps=eps)
        self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero)

        # The skip path needs its own layer whenever the shape changes.
        if out_channels != in_channels or up or down:
            skip_kernel = 1 if resample_proj or out_channels!= in_channels else 0
            self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=skip_kernel, up=up, down=down, resample_filter=resample_filter, **init)
        else:
            self.skip = None

        if self.num_heads:
            qkv_init = init if init_attn is None else init_attn
            self.norm2 = GroupNorm(num_channels=out_channels, eps=eps)
            self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **qkv_init)
            self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero)

    def forward(self, x, emb):
        """Run the block on feature map `x` conditioned on embedding `emb`."""
        block_input = x
        h = self.conv0(silu(self.norm0(x)))

        # Broadcast the embedding projection over the two trailing axes.
        modulation = self.affine(emb).unsqueeze(2).unsqueeze(3).to(h.dtype)
        if self.adaptive_scale:
            scale, shift = modulation.chunk(chunks=2, dim=1)
            h = silu(torch.addcmul(shift, self.norm1(h), scale + 1))
        else:
            h = silu(self.norm1(h.add_(modulation)))

        h = self.conv1(torch.nn.functional.dropout(h, p=self.dropout, training=self.training))
        shortcut = block_input if self.skip is None else self.skip(block_input)
        h = h.add_(shortcut) * self.skip_scale

        if self.num_heads:
            heads = self.num_heads
            qkv = self.qkv(self.norm2(h))
            # Split heads into the batch axis and separate q / k / v.
            q, k, v = qkv.reshape(h.shape[0] * heads, h.shape[1] // heads, 3, -1).unbind(2)
            attn = AttentionOp.apply(q, k)
            mixed = torch.einsum('nqk,nck->ncq', attn, v)
            h = self.proj(mixed.reshape(*h.shape)).add_(h)
            h = h * self.skip_scale
        return h
193 |
194 | #----------------------------------------------------------------------------
195 | # Timestep embedding used in the DDPM++ and ADM architectures.
196 |
@persistence.persistent_class
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions.

    Produces `num_channels // 2` cosine features followed by the same number
    of sine features, using geometrically spaced frequencies derived from
    `max_positions`.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Return `[cos(x * f), sin(x * f)]` concatenated along dim 1."""
        half = self.num_channels // 2
        # With endpoint=True the last exponent reaches exactly 1.
        denominator = half - (1 if self.endpoint else 0)
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        exponents = exponents / denominator
        freqs = (1 / self.max_positions) ** exponents
        phases = x.ger(freqs.to(x.dtype))
        return torch.cat([phases.cos(), phases.sin()], dim=1)
212 |
213 | #----------------------------------------------------------------------------
214 | # Timestep embedding used in the NCSN++ architecture.
215 |
@persistence.persistent_class
class FourierEmbedding(torch.nn.Module):
    """Random Fourier feature embedding of a 1-D tensor.

    The frequencies are sampled once at construction from a normal
    distribution scaled by `scale`, and are stored as a buffer.
    """

    def __init__(self, num_channels, scale=16):
        super().__init__()
        random_freqs = torch.randn(num_channels // 2) * scale
        self.register_buffer('freqs', random_freqs)

    def forward(self, x):
        """Return `[cos(2*pi*f*x), sin(2*pi*f*x)]` concatenated along dim 1."""
        angular_freqs = (2 * np.pi * self.freqs).to(x.dtype)
        phases = x.ger(angular_freqs)
        return torch.cat([phases.cos(), phases.sin()], dim=1)
226 |
227 | #----------------------------------------------------------------------------
228 | # Reimplementation of the DDPM++ and NCSN++ architectures from the paper
229 | # "Score-Based Generative Modeling through Stochastic Differential
230 | # Equations". Equivalent to the original implementation by Song et al.,
231 | # available at https://github.com/yang-song/score_sde_pytorch
232 |
@persistence.persistent_class
class SongUNet(torch.nn.Module):
    """U-Net with an embedding mapping network, an encoder and a decoder.

    Encoder and decoder layers are stored in `ModuleDict`s whose keys encode
    the resolution and role of each layer (`..._conv`, `..._down`,
    `..._block{i}`, `..._aux_*`, `..._in{i}`, `..._up`). `forward` dispatches
    on those key names and relies on the insertion order established in
    `__init__`, so the two must be kept in sync.
    """

    def __init__(self,
        img_resolution, # Image resolution at input/output.
        in_channels, # Number of color channels at input.
        out_channels, # Number of color channels at output.
        label_dim = 0, # Number of class labels, 0 = unconditional.
        augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation.

        model_channels = 128, # Base multiplier for the number of channels.
        channel_mult = [1,2,2,2], # Per-resolution multipliers for the number of channels.
        channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector.
        num_blocks = 4, # Number of residual blocks per resolution.
        attn_resolutions = [16], # List of resolutions with self-attention.
        dropout = 0.10, # Dropout probability of intermediate activations.
        label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.

        embedding_type = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
        channel_mult_noise = 1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
        encoder_type = 'standard', # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
        decoder_type = 'standard', # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
        resample_filter = [1,1], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
    ):
        assert embedding_type in ['fourier', 'positional']
        assert encoder_type in ['standard', 'skip', 'residual']
        assert decoder_type in ['standard', 'skip']

        super().__init__()
        self.label_dropout = label_dropout
        emb_channels = model_channels * channel_mult_emb
        noise_channels = model_channels * channel_mult_noise
        # Initialization presets that are forwarded to the layers below.
        init = dict(init_mode='xavier_uniform')
        init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5)
        init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2))
        # Settings shared by every UNetBlock of this network.
        block_kwargs = dict(
            emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6,
            resample_filter=resample_filter, resample_proj=True, adaptive_scale=False,
            init=init, init_zero=init_zero, init_attn=init_attn,
        )

        # Mapping.
        self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True) if embedding_type == 'positional' else FourierEmbedding(num_channels=noise_channels)
        self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None
        self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None
        self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init)
        self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)

        # Encoder.
        self.enc = torch.nn.ModuleDict()
        cout = in_channels
        caux = in_channels # Channel count of the auxiliary branch.
        for level, mult in enumerate(channel_mult):
            # Resolution halves at every level.
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = model_channels
                self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
            else:
                self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
                if encoder_type == 'skip':
                    self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter)
                    self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init)
                if encoder_type == 'residual':
                    self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init)
                    caux = cout
            for idx in range(num_blocks):
                cin = cout
                cout = model_channels * mult
                attn = (res in attn_resolutions)
                self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
        # Output widths of the non-aux encoder layers, i.e. of the layers
        # whose outputs forward() appends to its skip list.
        skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name]

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        for level, mult in reversed(list(enumerate(channel_mult))):
            res = img_resolution >> level
            if level == len(channel_mult) - 1:
                # Deepest level: two blocks without skip inputs, the first
                # one with attention.
                self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
                self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
            else:
                self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
            for idx in range(num_blocks + 1):
                # Each decoder block consumes one encoder skip, popped in
                # reverse order of creation.
                cin = cout + skips.pop()
                cout = model_channels * mult
                attn = (idx == num_blocks and res in attn_resolutions)
                self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
            if decoder_type == 'skip' or level == 0:
                if decoder_type == 'skip' and level < len(channel_mult) - 1:
                    self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter)
                self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6)
                self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero)

    def forward(self, x, noise_labels, class_labels, augment_labels=None):
        """Run the network and return the accumulated auxiliary output.

        Args:
            x: Input feature map; its first dimension is used as the batch
                size when sampling the label-dropout mask.
            noise_labels: Values passed to the noise embedding layer.
            class_labels: Class conditioning; only read when the network was
                built with a non-zero `label_dim`.
            augment_labels: Optional augmentation conditioning; only used
                when the network was built with a non-zero `augment_dim`.

        Returns:
            The sum of the outputs of the decoder's `aux_conv` heads.
        """
        # Mapping.
        emb = self.map_noise(noise_labels)
        emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos
        if self.map_label is not None:
            tmp = class_labels
            if self.training and self.label_dropout:
                # Zero the whole label row of a sample with probability
                # label_dropout.
                tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype)
            emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features))
        if self.map_augment is not None and augment_labels is not None:
            emb = emb + self.map_augment(augment_labels)
        emb = silu(self.map_layer0(emb))
        emb = silu(self.map_layer1(emb))

        # Encoder.
        skips = []
        aux = x
        for name, block in self.enc.items():
            if 'aux_down' in name:
                aux = block(aux)
            elif 'aux_skip' in name:
                # Overwrite the most recent skip with the aux-augmented map.
                x = skips[-1] = x + block(aux)
            elif 'aux_residual' in name:
                x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2)
            else:
                # Only UNetBlocks take the embedding; plain convs do not.
                x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
                skips.append(x)

        # Decoder.
        aux = None
        tmp = None
        for name, block in self.dec.items():
            if 'aux_up' in name:
                aux = block(aux)
            elif 'aux_norm' in name:
                tmp = block(x)
            elif 'aux_conv' in name:
                tmp = block(silu(tmp))
                aux = tmp if aux is None else tmp + aux
            else:
                # A channel mismatch marks a block that was built to take a
                # concatenated encoder skip.
                if x.shape[1] != block.in_channels:
                    x = torch.cat([x, skips.pop()], dim=1)
                x = block(x, emb)
        return aux
369 |
370 |
371 | #----------------------------------------------------------------------------
372 | # Preconditioning corresponding to the variance preserving (VP) formulation
373 | # from the paper "Score-Based Generative Modeling through Stochastic
374 | # Differential Equations".
375 |
@persistence.persistent_class
class VPPrecond(torch.nn.Module):
    """Wrapper that applies VP-style preconditioning around a model.

    `forward` scales the noisy input, converts the noise level into the
    model's noise conditioning, and combines the model output with the input
    into a denoised estimate.
    """

    def __init__(self,
        img_resolution,             # Image resolution.
        img_channels,               # Number of color channels.
        label_dim = 0,              # Number of class labels, 0 = unconditional.
        use_fp16 = False,           # Execute the underlying model at FP16 precision?
        beta_d = 19.9,              # Extent of the noise level schedule.
        beta_min = 0.1,             # Initial slope of the noise level schedule.
        M = 1000,                   # Original number of timesteps in the DDPM formulation.
        epsilon_t = 1e-5,           # Minimum t-value used during training.
        model_type = 'SongUNet',    # Class name of the underlying model.
        **model_kwargs,             # Keyword arguments for the underlying model.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.M = M
        self.epsilon_t = epsilon_t
        # sigma() reads beta_d and beta_min, so they must already be set.
        self.sigma_min = float(self.sigma(epsilon_t))
        self.sigma_max = float(self.sigma(1))
        # The model class is looked up by name among this module's globals.
        model_cls = globals()[model_type]
        self.model = model_cls(
            img_resolution=img_resolution,
            in_channels=img_channels,
            out_channels=img_channels,
            label_dim=label_dim,
            **model_kwargs,
        )

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Return the denoised estimate for input `x` at noise level `sigma`."""
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            # Conditional model called without labels: use an all-zero row.
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        run_half = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if run_half else torch.float32

        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        c_noise = (self.M - 1) * self.sigma_inv(sigma)

        model_input = (c_in * x).to(dtype)
        F_x = self.model(model_input, c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        assert F_x.dtype == dtype
        return c_skip * x + c_out * F_x.to(torch.float32)

    def sigma(self, t):
        """Map time `t` to the corresponding noise level."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()

    def sigma_inv(self, sigma):
        """Map a noise level back to time; the inverse of `sigma`."""
        sigma = torch.as_tensor(sigma)
        discriminant = self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()
        return (discriminant.sqrt() - self.beta_min) / self.beta_d

    def round_sigma(self, sigma):
        """Return `sigma` as a tensor without changing its value."""
        return torch.as_tensor(sigma)
429 |
430 | #----------------------------------------------------------------------------
431 | # Preconditioning corresponding to the variance exploding (VE) formulation
432 | # from the paper "Score-Based Generative Modeling through Stochastic
433 | # Differential Equations".
434 |
@persistence.persistent_class
class VEPrecond(torch.nn.Module):
    """Wrapper that applies VE-style preconditioning around a model.

    `forward` feeds the unscaled input to the model, conditions it on
    `log(0.5 * sigma)`, and adds `sigma` times the model output to the input.
    """

    def __init__(self,
        img_resolution,             # Image resolution.
        img_channels,               # Number of color channels.
        label_dim = 0,              # Number of class labels, 0 = unconditional.
        use_fp16 = False,           # Execute the underlying model at FP16 precision?
        sigma_min = 0.02,           # Minimum supported noise level.
        sigma_max = 100,            # Maximum supported noise level.
        model_type = 'SongUNet',    # Class name of the underlying model.
        **model_kwargs,             # Keyword arguments for the underlying model.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # The model class is looked up by name among this module's globals.
        model_cls = globals()[model_type]
        self.model = model_cls(
            img_resolution=img_resolution,
            in_channels=img_channels,
            out_channels=img_channels,
            label_dim=label_dim,
            **model_kwargs,
        )

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Return the denoised estimate for input `x` at noise level `sigma`."""
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            # Conditional model called without labels: use an all-zero row.
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        run_half = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if run_half else torch.float32

        c_skip = 1
        c_out = sigma
        c_in = 1
        c_noise = (0.5 * sigma).log()

        model_input = (c_in * x).to(dtype)
        F_x = self.model(model_input, c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        assert F_x.dtype == dtype
        return c_skip * x + c_out * F_x.to(torch.float32)

    def round_sigma(self, sigma):
        """Return `sigma` as a tensor without changing its value."""
        return torch.as_tensor(sigma)
474 |
475 | #----------------------------------------------------------------------------
476 | # Preconditioning corresponding to the EDM formulation
477 | # from the paper "Elucidating the Design Space of Diffusion-Based
478 | # Generative Models".
479 |
@persistence.persistent_class
class EDMPrecond(torch.nn.Module):
    """Wrapper that applies EDM-style preconditioning around a network.

    The network class is looked up by name (`model_type`) in this module's
    globals and stored as `self.model`. `forward` mixes the input and the
    network output with coefficients derived from `sigma` and `sigma_data`.
    """

    def __init__(self,
        img_resolution,                     # Image resolution.
        img_channels,                       # Number of color channels.
        label_dim       = 0,                # Number of class labels, 0 = unconditional.
        use_fp16        = False,            # Execute the underlying model at FP16 precision?
        sigma_min       = 0,                # Minimum supported noise level.
        sigma_max       = float('inf'),     # Maximum supported noise level.
        sigma_data      = 0.5,              # Expected standard deviation of the training data.
        model_type      = 'SongUNet',       # Class name of the underlying model.
        **model_kwargs,                     # Keyword arguments for the underlying model.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

        # Resolve the network class by name from this module.
        network_cls = globals()[model_type]
        self.model = network_cls(
            img_resolution=img_resolution,
            in_channels=img_channels,
            out_channels=img_channels,
            label_dim=label_dim,
            **model_kwargs,
        )

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Evaluate the preconditioned network at noise level `sigma`.

        `x` and `sigma` are cast to float32; `sigma` is reshaped to
        (-1, 1, 1, 1) so it broadcasts against `x`. Labels are dropped when
        `label_dim == 0` and replaced by one all-zero row when missing.
        Returns `c_skip * x + c_out * F_x` in float32.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)

        # Half precision only on CUDA and only when not overridden.
        run_fp16 = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if run_fp16 else torch.float32

        # sigma^2 + sigma_data^2 is shared by all three input/output scalings.
        data_var = self.sigma_data ** 2
        total_var = sigma ** 2 + data_var
        c_skip = data_var / total_var
        c_out = sigma * self.sigma_data / total_var.sqrt()
        c_in = 1 / total_var.sqrt()
        c_noise = sigma.log() / 4

        net_input = (c_in * x).to(dtype)
        F_x = self.model(net_input, c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        assert F_x.dtype == dtype
        return c_skip * x + c_out * F_x.to(torch.float32)

    def round_sigma(self, sigma):
        """Return `sigma` as a tensor; no rounding is applied here."""
        as_tensor = torch.as_tensor(sigma)
        return as_tensor
521 |
522 |
523 | #----------------------------------------------------------------------------
# Fast Diffusion Model (FDM) variant of the EDM preconditioning (EDM-FDM).
# The underlying EDM formulation is from the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models".
527 |
@persistence.persistent_class
class FDM_EDMPrecond(torch.nn.Module):
    """EDM-style preconditioning wrapper whose input scaling is `s(sigma)`.

    `c_skip`, `c_out` and `c_noise` are computed from `sigma` and
    `sigma_data`; the input scaling `c_in` is given by `s(sigma)`, which is
    controlled by the `fdm_*` constructor arguments.
    """

    def __init__(self,
        img_resolution,                     # Image resolution.
        img_channels,                       # Number of color channels.
        label_dim       = 0,                # Number of class labels, 0 = unconditional.
        use_fp16        = False,            # Execute the underlying model at FP16 precision?
        sigma_min       = 0.002,            # Minimum supported noise level.
        sigma_max       = 80.0,             # Maximum supported noise level.
        sigma_data      = 0.5,              # Expected standard deviation of the training data.
        model_type      = 'SongUNet',       # Class name of the underlying model.
        fdm_beta_d      = 19.9,             # Extent of the FDM noise level schedule.
        fdm_beta_min    = 0.1,              # Initial slope of the FDM noise level schedule.
        fdm_multiplier  = 1.0,              # Multiplier of the FDM noise level schedule.
        **model_kwargs,                     # Keyword arguments for the underlying model.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

        # Resolve the network class by name from this module.
        network_cls = globals()[model_type]
        self.model = network_cls(
            img_resolution=img_resolution,
            in_channels=img_channels,
            out_channels=img_channels,
            label_dim=label_dim,
            **model_kwargs,
        )

        self.fdm_beta_d = fdm_beta_d
        self.fdm_beta_min = fdm_beta_min
        self.fdm_multiplier = fdm_multiplier

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Evaluate the preconditioned network at noise level `sigma`.

        `x` and `sigma` are cast to float32; `sigma` is reshaped to
        (-1, 1, 1, 1) so it broadcasts against `x`. Labels are dropped when
        `label_dim == 0` and replaced by one all-zero row when missing.
        Returns `c_skip * x + c_out * F_x` in float32.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)

        # Half precision only on CUDA and only when not overridden.
        run_fp16 = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if run_fp16 else torch.float32

        data_var = self.sigma_data ** 2
        total_var = sigma ** 2 + data_var
        c_skip = data_var / total_var
        c_out = sigma * self.sigma_data / total_var.sqrt()
        c_in = self.s(sigma)
        c_noise = sigma.log() / 4

        net_input = (c_in * x).to(dtype)
        F_x = self.model(net_input, c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        assert F_x.dtype == dtype
        return c_skip * x + c_out * F_x.to(torch.float32)

    def round_sigma(self, sigma):
        """Return `sigma` as a tensor; no rounding is applied here."""
        as_tensor = torch.as_tensor(sigma)
        return as_tensor

    def fdm_sigma_inv(self, sigma):
        """Linearly rescale `sigma` so `sigma_min` maps to 0 and `sigma_max` to 1."""
        sigma = torch.as_tensor(sigma)
        offset = sigma - self.sigma_min
        span = self.sigma_max - self.sigma_min
        return offset / span

    def fdm_beta_fn(self, t):
        """Return `fdm_beta_min * t + 0.5 * fdm_beta_d * t**2`."""
        linear_term = self.fdm_beta_min * t
        quadratic_term = 0.5 * self.fdm_beta_d * t**2
        return linear_term + quadratic_term

    def s(self, sigma):
        """Input scaling: `exp(-m * beta) * (1 + m * beta)` with `m = fdm_multiplier`.

        `beta` is `fdm_beta_fn` evaluated at the rescaled noise level
        returned by `fdm_sigma_inv(sigma)`.
        """
        t_norm = self.fdm_sigma_inv(sigma)
        beta_t = self.fdm_beta_fn(t_norm)
        decay = torch.exp(-self.fdm_multiplier * beta_t)
        growth = 1. + self.fdm_multiplier * beta_t
        return decay * growth
588 |
589 |
590 | #----------------------------------------------------------------------------
# Fast Diffusion Model (FDM) variant of the VP preconditioning (VP-FDM).
# The underlying variance preserving (VP) formulation is from the paper
# "Score-Based Generative Modeling through Stochastic Differential Equations".
594 |
@persistence.persistent_class
class FDM_VPPrecond(torch.nn.Module):
    """VP-style preconditioning wrapper whose input scaling is `s(sigma)`.

    `sigma_min` and `sigma_max` are derived in the constructor from
    `sigma(epsilon_t)` and `sigma(1)`. `forward` returns
    `x - sigma * F(s(sigma) * x, (M - 1) * sigma_inv(sigma))`.
    """

    def __init__(self,
        img_resolution,                 # Image resolution.
        img_channels,                   # Number of color channels.
        label_dim       = 0,            # Number of class labels, 0 = unconditional.
        use_fp16        = False,        # Execute the underlying model at FP16 precision?
        beta_d          = 19.9,         # Extent of the noise level schedule.
        beta_min        = 0.1,          # Initial slope of the noise level schedule.
        M               = 1000,         # Original number of timesteps in the DDPM formulation.
        epsilon_t       = 1e-5,         # Minimum t-value used during training.
        model_type      = 'SongUNet',   # Class name of the underlying model.
        fdm_beta_d      = 19.9,         # Extent of the FDM noise level schedule.
        fdm_beta_min    = 0.1,          # Initial slope of the FDM noise level schedule.
        fdm_multiplier  = 1.0,          # Multiplier of the FDM noise level schedule.
        **model_kwargs,                 # Keyword arguments for the underlying model.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.M = M
        self.epsilon_t = epsilon_t

        # The noise range follows from the schedule evaluated at the extreme
        # t-values; beta_d / beta_min must already be set at this point.
        self.sigma_min = float(self.sigma(epsilon_t))
        self.sigma_max = float(self.sigma(1))

        # Resolve the network class by name from this module.
        network_cls = globals()[model_type]
        self.model = network_cls(
            img_resolution=img_resolution,
            in_channels=img_channels,
            out_channels=img_channels,
            label_dim=label_dim,
            **model_kwargs,
        )

        self.fdm_beta_d = fdm_beta_d
        self.fdm_beta_min = fdm_beta_min
        self.fdm_multiplier = fdm_multiplier

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Evaluate the preconditioned network at noise level `sigma`.

        `x` and `sigma` are cast to float32; `sigma` is reshaped to
        (-1, 1, 1, 1) so it broadcasts against `x`. Labels are dropped when
        `label_dim == 0` and replaced by one all-zero row when missing.
        Returns `c_skip * x + c_out * F_x` in float32.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)

        # Half precision only on CUDA and only when not overridden.
        run_fp16 = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if run_fp16 else torch.float32

        # Preconditioning coefficients.
        c_skip = 1
        c_out = -sigma
        c_in = self.s(sigma)
        c_noise = (self.M - 1) * self.sigma_inv(sigma)

        net_input = (c_in * x).to(dtype)
        F_x = self.model(net_input, c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        assert F_x.dtype == dtype
        return c_skip * x + c_out * F_x.to(torch.float32)

    def sigma(self, t):
        """Return sqrt(exp(0.5 * beta_d * t^2 + beta_min * t) - 1) as a tensor."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t ** 2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()

    def sigma_inv(self, sigma):
        """Algebraic inverse of `sigma(t)`: map a noise level back to `t`."""
        sigma = torch.as_tensor(sigma)
        log_term = (1 + sigma ** 2).log()
        discriminant = self.beta_min ** 2 + 2 * self.beta_d * log_term
        return (discriminant.sqrt() - self.beta_min) / self.beta_d

    def round_sigma(self, sigma):
        """Return `sigma` as a tensor; no rounding is applied here."""
        as_tensor = torch.as_tensor(sigma)
        return as_tensor

    def fdm_sigma_inv(self, sigma):
        """Linearly rescale `sigma` so `sigma_min` maps to 0 and `sigma_max` to 1."""
        sigma = torch.as_tensor(sigma)
        offset = sigma - self.sigma_min
        span = self.sigma_max - self.sigma_min
        return offset / span

    def fdm_beta_fn(self, t):
        """Return `fdm_beta_min * t + 0.5 * fdm_beta_d * t**2`."""
        linear_term = self.fdm_beta_min * t
        quadratic_term = 0.5 * self.fdm_beta_d * t**2
        return linear_term + quadratic_term

    def s(self, sigma):
        """Input scaling: `exp(-m * beta) * (1 + m * beta)` with `m = fdm_multiplier`.

        `beta` is `fdm_beta_fn` evaluated at the rescaled noise level
        returned by `fdm_sigma_inv(sigma)`.
        """
        t_norm = self.fdm_sigma_inv(sigma)
        beta_t = self.fdm_beta_fn(t_norm)
        decay = torch.exp(-self.fdm_multiplier * beta_t)
        growth = 1. + self.fdm_multiplier * beta_t
        return decay * growth
667 |
668 | #----------------------------------------------------------------------------
# Fast Diffusion Model (FDM) variant of the VE preconditioning (VE-FDM).
# The underlying variance exploding (VE) formulation is from the paper
# "Score-Based Generative Modeling through Stochastic Differential Equations".
672 |
@persistence.persistent_class
class FDM_VEPrecond(torch.nn.Module):
    """VE-style preconditioning wrapper whose input scaling is `s(sigma)`.

    The network class is looked up by name (`model_type`) in this module's
    globals and stored as `self.model`. `forward` returns
    `x + sigma * F(s(sigma) * x, log(0.5 * sigma))`, where `F` is that
    network and `s` is controlled by the `fdm_*` constructor arguments.
    """

    def __init__(self,
        img_resolution,                 # Image resolution.
        img_channels,                   # Number of color channels.
        label_dim       = 0,            # Number of class labels, 0 = unconditional.
        use_fp16        = False,        # Execute the underlying model at FP16 precision?
        sigma_min       = 0.02,         # Minimum supported noise level.
        sigma_max       = 100,          # Maximum supported noise level.
        model_type      = 'SongUNet',   # Class name of the underlying model.
        fdm_beta_d      = 19.9,         # Extent of the FDM noise level schedule.
        fdm_beta_min    = 0.1,          # Initial slope of the FDM noise level schedule.
        fdm_multiplier  = 1.0,          # Multiplier of the FDM noise level schedule.
        **model_kwargs,                 # Keyword arguments for the underlying model.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs)

        self.fdm_beta_d = fdm_beta_d
        self.fdm_beta_min = fdm_beta_min
        self.fdm_multiplier = fdm_multiplier

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Evaluate the preconditioned network at noise level `sigma`.

        `x` and `sigma` are cast to float32; `sigma` is reshaped to
        (-1, 1, 1, 1) so it broadcasts against `x`. Labels are dropped when
        `label_dim == 0` and replaced by one all-zero row when missing.
        Returns `c_skip * x + c_out * F_x` in float32.
        """
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        # Half precision only on CUDA and only when not overridden.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        # Preconditioning coefficients.
        c_skip = 1
        c_out = sigma
        c_in = self.s(sigma)
        c_noise = (0.5 * sigma).log()

        F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        """Return `sigma` as a tensor; no rounding is applied here."""
        return torch.as_tensor(sigma)

    def fdm_sigma_inv(self, sigma):
        """Linearly rescale `sigma` so `sigma_min` maps to 0 and `sigma_max` to 1."""
        sigma = torch.as_tensor(sigma)
        return (sigma - self.sigma_min) / (self.sigma_max - self.sigma_min)

    def fdm_beta_fn(self, t):
        """Return `fdm_beta_min * t + 0.5 * fdm_beta_d * t**2`."""
        return self.fdm_beta_min * t + 0.5 * self.fdm_beta_d * t**2

    def s(self, sigma):
        """Input scaling: `exp(-m * beta) * (1 + m * beta)` with `m = fdm_multiplier`.

        `beta` is `fdm_beta_fn` evaluated at the rescaled noise level
        returned by `fdm_sigma_inv(sigma)`.
        """
        t = self.fdm_sigma_inv(sigma)
        beta = self.fdm_beta_fn(t)
        # Apply the configured multiplier, matching the EDM/VP FDM variants;
        # previously it was stored but ignored here. With the default of 1.0
        # the result is unchanged.
        return torch.exp(-self.fdm_multiplier * beta) * (1. + self.fdm_multiplier * beta)
--------------------------------------------------------------------------------