├── README.md
├── dataset_tool_edm.py
├── di_train.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── metrics
│   ├── __init__.py
│   ├── di_frechet_inception_distance.py
│   ├── di_inception_score.py
│   ├── di_kernel_inception_distance.py
│   ├── di_metric_main.py
│   ├── di_metric_utils.py
│   ├── di_precision_recall.py
│   └── perceptual_path_length.py
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── distributed.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
└── training
    ├── __init__.py
    ├── augment.py
    ├── dataset.py
    ├── di_loss.py
    ├── di_training_loop.py
    └── networks.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models
2 | Official PyTorch implementation of the NeurIPS 2023 paper:<br>
3 | **Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models**<br>
4 | Weijian Luo, Tianyang Hu, Shifeng Zhang, Jiacheng Sun, Zhenguo Li and Zhihua Zhang<br>
5 | https://openreview.net/forum?id=MLIs5iRq4w
6 |
7 | Abstract: *Due to the ease of training, ability to scale, and high sample quality, diffusion models (DMs) have become the preferred option for generative modeling, with numerous pre-trained models available for a wide variety of datasets. Containing intricate information about data distributions, pre-trained DMs are valuable assets for downstream applications. In this work, we consider learning from pre-trained DMs and transferring their knowledge to other generative models in a data-free fashion. Specifically, we propose a general framework called Diff-Instruct to instruct the training of arbitrary generative models as long as the generated samples are differentiable with respect to the model parameters. Our proposed Diff-Instruct is built on a rigorous mathematical foundation where the instruction process directly corresponds to minimizing a novel divergence we call Integral Kullback-Leibler (IKL) divergence. IKL is tailored for DMs by calculating the integral of the KL divergence along a diffusion process, which we show to be more robust in comparing distributions with misaligned supports. We also reveal non-trivial connections of our method to existing works such as DreamFusion (Poole et al., 2022) and generative adversarial training. To demonstrate the effectiveness and universality of Diff-Instruct, we consider two scenarios: distilling pre-trained diffusion models and refining existing GAN models. The experiments on distilling pre-trained diffusion models show that Diff-Instruct results in state-of-the-art single-step diffusion-based models. The experiments on refining GAN models show that Diff-Instruct can consistently improve the pre-trained generators of GAN models across various settings. Our official code is released at https://github.com/pkulwj1994/diff_instruct.*
8 |
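For reference, the Integral Kullback-Leibler (IKL) divergence mentioned in the abstract integrates a weighted KL divergence between the diffused generator distribution and the diffused data distribution along the forward diffusion process. A sketch of the definition, with notation simplified from the paper (the weighting w(t) and the time-t marginals q_t, p_t are spelled out here for orientation only):

```latex
% Sketch of the IKL divergence between the generator distribution q and the
% data distribution p: q_t and p_t are the time-t marginals obtained by
% diffusing q and p forward, and w(t) > 0 weights each diffusion time.
\mathcal{D}_{\mathrm{IKL}}(q \,\|\, p)
    = \int_{0}^{T} w(t)\, \mathrm{KL}(q_t \,\|\, p_t)\, \mathrm{d}t
    = \int_{0}^{T} w(t)\, \mathbb{E}_{x_t \sim q_t}\!\left[ \log \frac{q_t(x_t)}{p_t(x_t)} \right] \mathrm{d}t .
```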
9 | The code is based on the PyTorch implementation of the EDM diffusion model: https://github.com/NVlabs/edm.
10 |
11 | ## Preparing the conda environment
12 |
13 | ```.bash
14 | git clone https://github.com/pkulwj1994/diff_instruct.git
15 | cd diff_instruct
16 | source activate
17 | conda create -n di_v100 python=3.8
18 | conda activate di_v100
19 | pip install torch==1.12.1 torchvision==0.13.1 tqdm click psutil scipy
20 | ```
21 | ## Pre-trained models
22 |
23 | We use pre-trained EDM models as teachers; a minimal loading sketch follows the link below:
24 |
25 | - [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/)
26 |
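As a minimal sketch (not part of the training scripts), a teacher checkpoint can be loaded with the `dnnlib.util.open_url` helper bundled in this repo; the URL below is the CIFAR-10 entry from `resume_specs` in `di_train.py`, and the `'ema'` pickle key follows the EDM convention:

```.python
import pickle

import dnnlib

# Minimal sketch: download (with local caching) and unpickle an EDM teacher model.
url = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl'
with dnnlib.util.open_url(url) as f:
    teacher = pickle.load(f)['ema']  # 'ema' holds the exponential-moving-average weights
print(teacher.img_resolution, teacher.img_channels)  # 32, 3 for CIFAR-10
```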
27 |
28 | ## Preparing datasets
29 |
30 | Datasets are stored in the same format as in [StyleGAN](https://github.com/NVlabs/stylegan3): uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see `python dataset_tool_edm.py --help` for more information.
31 |
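As a quick sanity check (a minimal sketch, assuming only the archive layout described above), a prepared dataset can be inspected like this:

```.python
import json
import zipfile

# Minimal sketch: peek inside an archive produced by dataset_tool_edm.py.
# The path is an example; point it at your own --dest file.
with zipfile.ZipFile('/data/datasets/cifar10-32x32.zip') as z:
    pngs = [n for n in z.namelist() if n.endswith('.png')]
    print(f'{len(pngs)} images, first entry: {pngs[0]}')
    meta = json.loads(z.read('dataset.json'))  # [[filename, label], ...] or null
    print('labeled:', meta['labels'] is not None)
```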
32 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive:
33 |
34 | ```.bash
35 | python dataset_tool_edm.py --source=/data/downloads/cifar-10-python.tar.gz --dest=/data/datasets/cifar10-32x32.zip
36 | ```
37 |
38 | **ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert to ZIP archive at 64x64 resolution:
39 |
40 | ```.bash
41 | python dataset_tool_edm.py --source=/data/downloads/imagenet/ILSVRC/Data/CLS-LOC/train --dest=/data/datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
42 | ```
43 |
44 | ## Distill single-step models for unconditional CIFAR-10 generation on a single V100 GPU (results in an FID <= 4.5)
45 |
46 | You can run diffusion distillation using `di_train.py`. For example:
47 |
48 | ```.bash
49 | # Train a one-step DI model for unconditional CIFAR-10 using 1 GPU
50 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 --master_port=25678 di_train.py --outdir=/logs/di/ci10-uncond --data=/data/datasets/cifar10-32x32.zip --arch=ddpmpp --batch 128 --edm_model cifar10-uncond --cond=0 --metrics fid50k_full --tick 10 --snap 50 --lr 0.00001 --glr 0.00001 --init_sigma 1.0 --fp16=0 --lr_warmup_kimg -1 --ls 1.0 --sgls 1.0
51 | ```
52 |
53 | During training, the FID is computed automatically at every snapshot, i.e., every `--snap` ticks.
54 |
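Snapshots are written as pickles in the run directory. The following is a hypothetical sketch of one-step sampling from such a snapshot; the snapshot filename is a placeholder, and the `'ema'` key, the `G(x, sigma)` call signature, and the noise scaling by `--init_sigma` follow EDM conventions rather than a documented API of this repo:

```.python
import pickle

import torch

# Hypothetical sketch: one-step generation from a saved snapshot.
with open('network-snapshot.pkl', 'rb') as f:  # placeholder filename
    G = pickle.load(f)['ema'].eval().to('cuda')

init_sigma = 1.0  # assumed to match the --init_sigma used during training
z = torch.randn(16, G.img_channels, G.img_resolution, G.img_resolution, device='cuda')
with torch.no_grad():
    x = G(z * init_sigma, torch.full([16], init_sigma, device='cuda'))  # single forward pass
images = (x * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu()  # map [-1, 1] to uint8
```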
55 | ## License
56 |
57 | All material, including source code and pre-trained models, is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).
58 |
59 | ## Citation
60 |
61 | ```
62 | @article{luo2024diffinstruct,
63 | title={Diff-instruct: A universal approach for transferring knowledge from pre-trained diffusion models},
64 | author={Luo, Weijian and Hu, Tianyang and Zhang, Shifeng and Sun, Jiacheng and Li, Zhenguo and Zhang, Zhihua},
65 | journal={Advances in Neural Information Processing Systems},
66 | volume={36},
67 | year={2024}
68 | }
69 | ```
70 |
71 | ## Development
72 |
73 | This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.
74 |
75 | ## Acknowledgments
76 |
77 | We thank the EDM paper "Elucidating the Design Space of Diffusion-Based Generative Models" for its great implementation of EDM diffusion models at https://github.com/NVlabs/edm. We thank Shuchen Xue and Zhengyang Geng for constructive feedback on the code implementation.
78 |
79 |
80 |
--------------------------------------------------------------------------------
/dataset_tool_edm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Tool for creating ZIP/PNG based datasets."""
9 |
10 | import functools
11 | import gzip
12 | import io
13 | import json
14 | import os
15 | import pickle
16 | import re
17 | import sys
18 | import tarfile
19 | import zipfile
20 | from pathlib import Path
21 | from typing import Callable, Optional, Tuple, Union
22 | import click
23 | import numpy as np
24 | import PIL.Image
25 | from tqdm import tqdm
26 |
27 | #----------------------------------------------------------------------------
28 | # Parse a 'M,N' or 'MxN' integer tuple.
29 | # Example: '4x2' returns (4,2)
30 |
31 | def parse_tuple(s: str) -> Tuple[int, int]:
32 | m = re.match(r'^(\d+)[x,](\d+)$', s)
33 | if m:
34 | return int(m.group(1)), int(m.group(2))
35 | raise click.ClickException(f'cannot parse tuple {s}')
36 |
37 | #----------------------------------------------------------------------------
38 |
39 | def maybe_min(a: int, b: Optional[int]) -> int:
40 | if b is not None:
41 | return min(a, b)
42 | return a
43 |
44 | #----------------------------------------------------------------------------
45 |
46 | def file_ext(name: Union[str, Path]) -> str:
47 | return str(name).split('.')[-1]
48 |
49 | #----------------------------------------------------------------------------
50 |
51 | def is_image_ext(fname: Union[str, Path]) -> bool:
52 | ext = file_ext(fname).lower()
53 | return f'.{ext}' in PIL.Image.EXTENSION
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | def open_image_folder(source_dir, *, max_images: Optional[int]):
58 | input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
59 | arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
60 | max_idx = maybe_min(len(input_images), max_images)
61 |
62 | # Load labels.
63 | labels = dict()
64 | meta_fname = os.path.join(source_dir, 'dataset.json')
65 | if os.path.isfile(meta_fname):
66 | with open(meta_fname, 'r') as file:
67 | data = json.load(file)['labels']
68 | if data is not None:
69 | labels = {x[0]: x[1] for x in data}
70 |
71 | # No labels available => determine from top-level directory names.
72 | if len(labels) == 0:
73 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
74 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
75 | if len(toplevel_indices) > 1:
76 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}
77 |
78 | def iterate_images():
79 | for idx, fname in enumerate(input_images):
80 | img = np.array(PIL.Image.open(fname))
81 | yield dict(img=img, label=labels.get(arch_fnames.get(fname)))
82 | if idx >= max_idx - 1:
83 | break
84 | return max_idx, iterate_images()
85 |
86 | #----------------------------------------------------------------------------
87 |
88 | def open_image_zip(source, *, max_images: Optional[int]):
89 | with zipfile.ZipFile(source, mode='r') as z:
90 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
91 | max_idx = maybe_min(len(input_images), max_images)
92 |
93 | # Load labels.
94 | labels = dict()
95 | if 'dataset.json' in z.namelist():
96 | with z.open('dataset.json', 'r') as file:
97 | data = json.load(file)['labels']
98 | if data is not None:
99 | labels = {x[0]: x[1] for x in data}
100 |
101 | def iterate_images():
102 | with zipfile.ZipFile(source, mode='r') as z:
103 | for idx, fname in enumerate(input_images):
104 | with z.open(fname, 'r') as file:
105 | img = np.array(PIL.Image.open(file))
106 | yield dict(img=img, label=labels.get(fname))
107 | if idx >= max_idx - 1:
108 | break
109 | return max_idx, iterate_images()
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
114 | import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python
115 | import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb
116 |
117 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
118 | max_idx = maybe_min(txn.stat()['entries'], max_images)
119 |
120 | def iterate_images():
121 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
122 | for idx, (_key, value) in enumerate(txn.cursor()):
123 | try:
124 | try:
125 | img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
126 | if img is None:
127 | raise IOError('cv2.imdecode failed')
128 | img = img[:, :, ::-1] # BGR => RGB
129 | except IOError:
130 | img = np.array(PIL.Image.open(io.BytesIO(value)))
131 | yield dict(img=img, label=None)
132 | if idx >= max_idx - 1:
133 | break
134 | except:
135 | print(sys.exc_info()[1])
136 |
137 | return max_idx, iterate_images()
138 |
139 | #----------------------------------------------------------------------------
140 |
141 | def open_cifar10(tarball: str, *, max_images: Optional[int]):
142 | images = []
143 | labels = []
144 |
145 | with tarfile.open(tarball, 'r:gz') as tar:
146 | for batch in range(1, 6):
147 | member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
148 | with tar.extractfile(member) as file:
149 | data = pickle.load(file, encoding='latin1')
150 | images.append(data['data'].reshape(-1, 3, 32, 32))
151 | labels.append(data['labels'])
152 |
153 | images = np.concatenate(images)
154 | labels = np.concatenate(labels)
155 | images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
156 | assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
157 | assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
158 | assert np.min(images) == 0 and np.max(images) == 255
159 | assert np.min(labels) == 0 and np.max(labels) == 9
160 |
161 | max_idx = maybe_min(len(images), max_images)
162 |
163 | def iterate_images():
164 | for idx, img in enumerate(images):
165 | yield dict(img=img, label=int(labels[idx]))
166 | if idx >= max_idx - 1:
167 | break
168 |
169 | return max_idx, iterate_images()
170 |
171 | #----------------------------------------------------------------------------
172 |
173 | def open_mnist(images_gz: str, *, max_images: Optional[int]):
174 | labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
175 | assert labels_gz != images_gz
176 | images = []
177 | labels = []
178 |
179 | with gzip.open(images_gz, 'rb') as f:
180 | images = np.frombuffer(f.read(), np.uint8, offset=16)
181 | with gzip.open(labels_gz, 'rb') as f:
182 | labels = np.frombuffer(f.read(), np.uint8, offset=8)
183 |
184 | images = images.reshape(-1, 28, 28)
185 | images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
186 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
187 | assert labels.shape == (60000,) and labels.dtype == np.uint8
188 | assert np.min(images) == 0 and np.max(images) == 255
189 | assert np.min(labels) == 0 and np.max(labels) == 9
190 |
191 | max_idx = maybe_min(len(images), max_images)
192 |
193 | def iterate_images():
194 | for idx, img in enumerate(images):
195 | yield dict(img=img, label=int(labels[idx]))
196 | if idx >= max_idx - 1:
197 | break
198 |
199 | return max_idx, iterate_images()
200 |
201 | #----------------------------------------------------------------------------
202 |
203 | def make_transform(
204 | transform: Optional[str],
205 | output_width: Optional[int],
206 | output_height: Optional[int]
207 | ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
208 | def scale(width, height, img):
209 | w = img.shape[1]
210 | h = img.shape[0]
211 | if width == w and height == h:
212 | return img
213 | img = PIL.Image.fromarray(img)
214 | ww = width if width is not None else w
215 | hh = height if height is not None else h
216 | img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
217 | return np.array(img)
218 |
219 | def center_crop(width, height, img):
220 | crop = np.min(img.shape[:2])
221 | img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
222 | if img.ndim == 2:
223 | img = img[:, :, np.newaxis].repeat(3, axis=2)
224 | img = PIL.Image.fromarray(img, 'RGB')
225 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
226 | return np.array(img)
227 |
228 | def center_crop_wide(width, height, img):
229 | ch = int(np.round(width * img.shape[0] / img.shape[1]))
230 | if img.shape[1] < width or ch < height:
231 | return None
232 |
233 | img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
234 | if img.ndim == 2:
235 | img = img[:, :, np.newaxis].repeat(3, axis=2)
236 | img = PIL.Image.fromarray(img, 'RGB')
237 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
238 | img = np.array(img)
239 |
240 | canvas = np.zeros([width, width, 3], dtype=np.uint8)
241 | canvas[(width - height) // 2 : (width + height) // 2, :] = img
242 | return canvas
243 |
244 | if transform is None:
245 | return functools.partial(scale, output_width, output_height)
246 | if transform == 'center-crop':
247 | if output_width is None or output_height is None:
248 |             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
249 | return functools.partial(center_crop, output_width, output_height)
250 | if transform == 'center-crop-wide':
251 | if output_width is None or output_height is None:
252 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
253 | return functools.partial(center_crop_wide, output_width, output_height)
254 | assert False, 'unknown transform'
255 |
256 | #----------------------------------------------------------------------------
257 |
258 | def open_dataset(source, *, max_images: Optional[int]):
259 | if os.path.isdir(source):
260 | if source.rstrip('/').endswith('_lmdb'):
261 | return open_lmdb(source, max_images=max_images)
262 | else:
263 | return open_image_folder(source, max_images=max_images)
264 | elif os.path.isfile(source):
265 | if os.path.basename(source) == 'cifar-10-python.tar.gz':
266 | return open_cifar10(source, max_images=max_images)
267 | elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
268 | return open_mnist(source, max_images=max_images)
269 | elif file_ext(source) == 'zip':
270 | return open_image_zip(source, max_images=max_images)
271 | else:
272 | assert False, 'unknown archive type'
273 | else:
274 | raise click.ClickException(f'Missing input file or directory: {source}')
275 |
276 | #----------------------------------------------------------------------------
277 |
278 | def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
279 | dest_ext = file_ext(dest)
280 |
281 | if dest_ext == 'zip':
282 | if os.path.dirname(dest) != '':
283 | os.makedirs(os.path.dirname(dest), exist_ok=True)
284 | zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
285 | def zip_write_bytes(fname: str, data: Union[bytes, str]):
286 | zf.writestr(fname, data)
287 | return '', zip_write_bytes, zf.close
288 | else:
289 |         # If the output folder already exists, check that it is
290 | # empty.
291 | #
292 | # Note: creating the output directory is not strictly
293 | # necessary as folder_write_bytes() also mkdirs, but it's better
294 | # to give an error message earlier in case the dest folder
295 | # somehow cannot be created.
296 | if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
297 | raise click.ClickException('--dest folder must be empty')
298 | os.makedirs(dest, exist_ok=True)
299 |
300 | def folder_write_bytes(fname: str, data: Union[bytes, str]):
301 | os.makedirs(os.path.dirname(fname), exist_ok=True)
302 | with open(fname, 'wb') as fout:
303 | if isinstance(data, str):
304 | data = data.encode('utf8')
305 | fout.write(data)
306 | return dest, folder_write_bytes, lambda: None
307 |
308 | #----------------------------------------------------------------------------
309 |
310 | @click.command()
311 | @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
312 | @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
313 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
314 | @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
315 | @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
316 |
317 | def main(
318 | source: str,
319 | dest: str,
320 | max_images: Optional[int],
321 | transform: Optional[str],
322 | resolution: Optional[Tuple[int, int]]
323 | ):
324 | """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
325 |
326 | The input dataset format is guessed from the --source argument:
327 |
328 | \b
329 | --source *_lmdb/ Load LSUN dataset
330 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset
331 | --source train-images-idx3-ubyte.gz Load MNIST dataset
332 | --source path/ Recursively load all images from path/
333 | --source dataset.zip Recursively load all images from dataset.zip
334 |
335 | Specifying the output format and path:
336 |
337 | \b
338 | --dest /path/to/dir Save output files under /path/to/dir
339 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
340 |
341 | The output dataset format can be either an image folder or an uncompressed zip archive.
342 | Zip archives makes it easier to move datasets around file servers and clusters, and may
343 | offer better training performance on network file systems.
344 |
345 | Images within the dataset archive will be stored as uncompressed PNG.
346 |     Uncompressed PNGs can be efficiently decoded in the training loop.
347 |
348 | Class labels are stored in a file called 'dataset.json' that is stored at the
349 | dataset root folder. This file has the following structure:
350 |
351 | \b
352 | {
353 | "labels": [
354 | ["00000/img00000000.png",6],
355 | ["00000/img00000001.png",9],
356 |             ... repeated for every image in the dataset
357 | ["00049/img00049999.png",1]
358 | ]
359 | }
360 |
361 | If the 'dataset.json' file cannot be found, class labels are determined from
362 | top-level directory names.
363 |
364 | Image scale/crop and resolution requirements:
365 |
366 | Output images must be square-shaped and they must all have the same power-of-two
367 | dimensions.
368 |
369 | To scale arbitrary input image size to a specific width and height, use the
370 | --resolution option. Output resolution will be either the original
371 | input resolution (if resolution was not specified) or the one specified with
372 | --resolution option.
373 |
374 | Use the --transform=center-crop or --transform=center-crop-wide options to apply a
375 | center crop transform on the input image. These options should be used with the
376 | --resolution option. For example:
377 |
378 | \b
379 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
380 | --transform=center-crop-wide --resolution=512x384
381 | """
382 |
383 | PIL.Image.init()
384 |
385 | if dest == '':
386 | raise click.ClickException('--dest output filename or directory must not be an empty string')
387 |
388 | num_files, input_iter = open_dataset(source, max_images=max_images)
389 | archive_root_dir, save_bytes, close_dest = open_dest(dest)
390 |
391 | if resolution is None: resolution = (None, None)
392 | transform_image = make_transform(transform, *resolution)
393 |
394 | dataset_attrs = None
395 |
396 | labels = []
397 | for idx, image in tqdm(enumerate(input_iter), total=num_files):
398 | idx_str = f'{idx:08d}'
399 | archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
400 |
401 | # Apply crop and resize.
402 | img = transform_image(image['img'])
403 | if img is None:
404 | continue
405 |
406 | # Error check to require uniform image attributes across
407 | # the whole dataset.
408 | channels = img.shape[2] if img.ndim == 3 else 1
409 | cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels}
410 | if dataset_attrs is None:
411 | dataset_attrs = cur_image_attrs
412 | width = dataset_attrs['width']
413 | height = dataset_attrs['height']
414 | if width != height:
415 | raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
416 | if dataset_attrs['channels'] not in [1, 3]:
417 | raise click.ClickException('Input images must be stored as RGB or grayscale')
418 | if width != 2 ** int(np.floor(np.log2(width))):
419 | raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
420 | elif dataset_attrs != cur_image_attrs:
421 | err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
422 | raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
423 |
424 | # Save the image as an uncompressed PNG.
425 | img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels])
426 | image_bits = io.BytesIO()
427 | img.save(image_bits, format='png', compress_level=0, optimize=False)
428 | save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
429 | labels.append([archive_fname, image['label']] if image['label'] is not None else None)
430 |
431 | metadata = {'labels': labels if all(x is not None for x in labels) else None}
432 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
433 | close_dest()
434 |
435 | #----------------------------------------------------------------------------
436 |
437 | if __name__ == "__main__":
438 | main()
439 |
440 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
/di_train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Weijian Luo, Peking University . All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Train one-step diffusion-based generative model using the techniques described in the
9 | paper "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models"
10 | Weijian Luo, Tianyang Hu, Shifeng Zhang, Jiacheng Sun, Zhenguo Li and Zhihua Zhang.
11 |
12 | https://github.com/pkulwj1994/diff_instruct
13 |
14 | Code was modified from the paper "Elucidating the Design Space of Diffusion-Based Generative Models"
15 | https://github.com/NVlabs/edm
16 | """
17 |
18 | import os
19 | import re
20 | import json
21 | import click
22 | import torch
23 | import dnnlib
24 | from torch_utils import distributed as dist
25 | from training import di_training_loop as training_loop
26 |
27 | import warnings
28 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
29 |
30 | #----------------------------------------------------------------------------
31 | # Parse a comma separated list of numbers or ranges and return a list of ints.
32 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
33 |
34 | def parse_int_list(s):
35 | if isinstance(s, list): return s
36 | ranges = []
37 | range_re = re.compile(r'^(\d+)-(\d+)$')
38 | for p in s.split(','):
39 | m = range_re.match(p)
40 | if m:
41 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
42 | else:
43 | ranges.append(int(p))
44 | return ranges
45 |
46 | class CommaSeparatedList(click.ParamType):
47 | name = 'list'
48 |
49 | def convert(self, value, param, ctx):
50 | _ = param, ctx
51 | if value is None or value.lower() == 'none' or value == '':
52 | return []
53 | return value.split(',')
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | @click.command()
58 |
59 | # Main options.
60 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True)
61 | @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
62 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
63 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True)
64 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True)
65 |
66 | # Hyperparameters.
67 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
68 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
69 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
70 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
71 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
72 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
73 | @click.option('--glr', help='Generator learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
74 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
75 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
76 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
77 | @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
78 |
79 | # Performance-related.
80 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
81 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
82 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
83 | @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
84 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
85 |
86 | # I/O-related.
87 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
88 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
89 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
90 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
91 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
92 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
93 | @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
94 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
95 |
96 | @click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
97 | @click.option('--edm_model', help='Pre-trained EDM teacher model', metavar='MODEL', type=click.Choice(['cifar10-uncond', 'cifar10-cond', 'ffhq64', 'afhq64-v2', 'imagenet64-cond', 'ffhq64-uncond', 'afhqv2_64-uncond']), default='cifar10-cond', show_default=True)
98 | @click.option('--init_sigma', help='Initial sigma of the one-step generator input noise', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
99 |
100 | @click.option('--ema_mu', help='ema rate', metavar='FLOAT', type=click.FloatRange(min=-1.5, min_open=True), default=-1.0, show_default=True)
101 | @click.option('--lr_warmup_kimg', help='lr warmup', metavar='KIMG', type=click.IntRange(min=-2), default=-1, show_default=True)
102 | @click.option('--sgls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
103 |
104 |
105 | def main(**kwargs):
106 |     """Train a one-step diffusion-based generative model using the techniques described in the
107 |     paper "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models".
108 | 
109 |     Examples:
110 | 
111 |     \b
112 |     # Train a one-step DI model for class-conditional CIFAR-10 using 8 GPUs
113 |     torchrun --standalone --nproc_per_node=8 di_train.py --outdir=training-runs \\
114 |         --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
115 |     """
116 | opts = dnnlib.EasyDict(kwargs)
117 | torch.multiprocessing.set_start_method('spawn')
118 | dist.init()
119 |
120 | # Initialize config dict.
121 | c = dnnlib.EasyDict()
122 | c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
123 | c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
124 | c.network_kwargs = dnnlib.EasyDict()
125 | c.loss_kwargs = dnnlib.EasyDict()
126 |
127 | c.sg_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.0,0.999], eps=1e-8)
128 | c.g_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.glr, betas=[0.0,0.999], eps=1e-8)
129 |
130 | c.init_sigma = opts.init_sigma
131 | c.ema_mu = opts.ema_mu
132 | c.use_fp16 = opts.fp16
133 | c.lr_rampup_kimg = opts.lr_warmup_kimg
134 |
135 |
136 | # Validate dataset options.
137 | try:
138 | dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
139 | dataset_name = dataset_obj.name
140 | c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
141 | c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
142 | if opts.cond and not dataset_obj.has_labels:
143 | raise click.ClickException('--cond=True requires labels specified in dataset.json')
144 | del dataset_obj # conserve memory
145 | except IOError as err:
146 | raise click.ClickException(f'--data: {err}')
147 |
148 | # Network architecture.
149 | if opts.arch == 'ddpmpp':
150 | c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
151 | c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2])
152 | elif opts.arch == 'ncsnpp':
153 | c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
154 | c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2])
155 | else:
156 | assert opts.arch == 'adm'
157 | c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
158 |
159 | # Training options.
160 | c.total_kimg = max(int(opts.duration * 1000), 1)
161 | c.ema_halflife_kimg = int(opts.ema * 1000)
162 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
163 | c.update(loss_scaling=opts.ls, sgls=opts.sgls, cudnn_benchmark=opts.bench)
164 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap)
165 |
166 | # Random seed.
167 | if opts.seed is not None:
168 | c.seed = opts.seed
169 | else:
170 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
171 | torch.distributed.broadcast(seed, src=0)
172 | c.seed = int(seed)
173 |
174 | # Preconditioning & loss function.
175 | if opts.precond == 'vp':
176 | c.network_kwargs.class_name = 'training.networks.VPPrecond'
177 | c.loss_kwargs.class_name = 'training.loss.VPLoss'
178 | elif opts.precond == 've':
179 | c.network_kwargs.class_name = 'training.networks.VEPrecond'
180 | c.loss_kwargs.class_name = 'training.loss.VELoss'
181 | else:
182 | assert opts.precond == 'edm'
183 | c.network_kwargs.class_name = 'training.networks.EDMPrecond'
184 | c.loss_kwargs.class_name = 'training.loss.EDMLoss'
185 |
186 | c.loss_kwargs.class_name = 'training.di_loss.DI_EDMLoss'
187 | c.metrics = opts.metrics
188 |
189 | # Network options.
190 | if opts.cbase is not None:
191 | c.network_kwargs.model_channels = opts.cbase
192 | if opts.cres is not None:
193 | c.network_kwargs.channel_mult = opts.cres
194 | if opts.augment:
195 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
196 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
197 | c.network_kwargs.augment_dim = 9
198 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
199 |
215 | resume_specs = {
216 | 'cifar10-uncond': 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl',
217 | 'cifar10-cond': 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl',
218 | }
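    # NOTE: only the CIFAR-10 teacher checkpoints are listed here, although
    # --edm_model advertises more choices; selecting an unlisted model raises
    # a KeyError until its EDM pickle URL is added to this dict.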
219 |
220 | c.resume_pkl = resume_specs[opts.edm_model]
221 | if opts.transfer is not None:
222 | c.transfer_pkl = opts.transfer
223 | c.ema_rampup_ratio = None
224 | else:
225 | c.transfer_pkl = None
226 |
227 | # Description string.
228 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
229 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
230 | desc = f'{dataset_name:s}-{cond_str:s}-ls{opts.ls}-sgls{opts.sgls}-glr{opts.glr}-sglr{opts.lr}-sigma{opts.init_sigma}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}-lrwarmkimg{opts.lr_warmup_kimg}'
231 | if opts.desc is not None:
232 | desc += f'-{opts.desc}'
233 |
234 | # Pick output directory.
235 | if dist.get_rank() != 0:
236 | c.run_dir = None
237 | elif opts.nosubdir:
238 | c.run_dir = opts.outdir
239 | else:
240 | prev_run_dirs = []
241 | if os.path.isdir(opts.outdir):
242 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
243 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
244 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
245 | cur_run_id = max(prev_run_ids, default=-1) + 1
246 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
247 | assert not os.path.exists(c.run_dir)
248 |
249 | # Print options.
250 | dist.print0()
251 | dist.print0('Training options:')
252 | dist.print0(json.dumps(c, indent=2))
253 | dist.print0()
254 | dist.print0(f'Output directory: {c.run_dir}')
255 | dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
256 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
257 | dist.print0(f'Network architecture: {opts.arch}')
258 | dist.print0(f'Preconditioning & loss: {opts.precond}')
259 | dist.print0(f'Number of GPUs: {dist.get_world_size()}')
260 | dist.print0(f'Batch size: {c.batch_size}')
261 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
262 | dist.print0()
263 |
264 | # Dry run?
265 | if opts.dry_run:
266 | dist.print0('Dry run; exiting.')
267 | return
268 |
269 | # Create output directory.
270 | dist.print0('Creating output directory...')
271 | if dist.get_rank() == 0:
272 | os.makedirs(c.run_dir, exist_ok=True)
273 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
274 | json.dump(c, f, indent=2)
275 | dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
276 |
277 | # Train.
278 | training_loop.training_loop(**c)
279 |
280 | #----------------------------------------------------------------------------
281 |
282 | if __name__ == "__main__":
283 | main()
284 |
285 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import tempfile
27 | import urllib
28 | import urllib.request
29 | import uuid
30 |
31 | from distutils.util import strtobool
32 | from typing import Any, List, Tuple, Union, Optional
33 |
34 |
35 | # Util classes
36 | # ------------------------------------------------------------------------------------------
37 |
38 |
39 | class EasyDict(dict):
40 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41 |
42 | def __getattr__(self, name: str) -> Any:
43 | try:
44 | return self[name]
45 | except KeyError:
46 | raise AttributeError(name)
47 |
48 | def __setattr__(self, name: str, value: Any) -> None:
49 | self[name] = value
50 |
51 | def __delattr__(self, name: str) -> None:
52 | del self[name]
53 |
54 |
55 | class Logger(object):
56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57 |
58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
59 | self.file = None
60 |
61 | if file_name is not None:
62 | self.file = open(file_name, file_mode)
63 |
64 | self.should_flush = should_flush
65 | self.stdout = sys.stdout
66 | self.stderr = sys.stderr
67 |
68 | sys.stdout = self
69 | sys.stderr = self
70 |
71 | def __enter__(self) -> "Logger":
72 | return self
73 |
74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
75 | self.close()
76 |
77 | def write(self, text: Union[str, bytes]) -> None:
78 | """Write text to stdout (and a file) and optionally flush."""
79 | if isinstance(text, bytes):
80 | text = text.decode()
81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
82 | return
83 |
84 | if self.file is not None:
85 | self.file.write(text)
86 |
87 | self.stdout.write(text)
88 |
89 | if self.should_flush:
90 | self.flush()
91 |
92 | def flush(self) -> None:
93 | """Flush written text to both stdout and a file, if open."""
94 | if self.file is not None:
95 | self.file.flush()
96 |
97 | self.stdout.flush()
98 |
99 | def close(self) -> None:
100 | """Flush, close possible files, and remove stdout/stderr mirroring."""
101 | self.flush()
102 |
103 | # if using multiple loggers, prevent closing in wrong order
104 | if sys.stdout is self:
105 | sys.stdout = self.stdout
106 | if sys.stderr is self:
107 | sys.stderr = self.stderr
108 |
109 | if self.file is not None:
110 | self.file.close()
111 | self.file = None
112 |
113 |
114 | # Cache directories
115 | # ------------------------------------------------------------------------------------------
116 |
117 | _dnnlib_cache_dir = None
118 |
119 | def set_cache_dir(path: str) -> None:
120 | global _dnnlib_cache_dir
121 | _dnnlib_cache_dir = path
122 |
123 | def make_cache_dir_path(*paths: str) -> str:
124 | if _dnnlib_cache_dir is not None:
125 | return os.path.join(_dnnlib_cache_dir, *paths)
126 | if 'DNNLIB_CACHE_DIR' in os.environ:
127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
128 | if 'HOME' in os.environ:
129 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
130 | if 'USERPROFILE' in os.environ:
131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
133 |
134 | # Small util functions
135 | # ------------------------------------------------------------------------------------------
136 |
137 |
138 | def format_time(seconds: Union[int, float]) -> str:
139 | """Convert the seconds to human readable string with days, hours, minutes and seconds."""
140 | s = int(np.rint(seconds))
141 |
142 | if s < 60:
143 | return "{0}s".format(s)
144 | elif s < 60 * 60:
145 | return "{0}m {1:02}s".format(s // 60, s % 60)
146 | elif s < 24 * 60 * 60:
147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
148 | else:
149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
150 |
151 |
152 | def format_time_brief(seconds: Union[int, float]) -> str:
153 | """Convert the seconds to human readable string with days, hours, minutes and seconds."""
154 | s = int(np.rint(seconds))
155 |
156 | if s < 60:
157 | return "{0}s".format(s)
158 | elif s < 60 * 60:
159 | return "{0}m {1:02}s".format(s // 60, s % 60)
160 | elif s < 24 * 60 * 60:
161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
162 | else:
163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
164 |
165 |
166 | def ask_yes_no(question: str) -> bool:
167 | """Ask the user the question until the user inputs a valid answer."""
168 | while True:
169 | try:
170 | print("{0} [y/n]".format(question))
171 | return strtobool(input().lower())
172 | except ValueError:
173 | pass
174 |
175 |
176 | def tuple_product(t: Tuple) -> Any:
177 | """Calculate the product of the tuple elements."""
178 | result = 1
179 |
180 | for v in t:
181 | result *= v
182 |
183 | return result
184 |
185 |
186 | _str_to_ctype = {
187 | "uint8": ctypes.c_ubyte,
188 | "uint16": ctypes.c_uint16,
189 | "uint32": ctypes.c_uint32,
190 | "uint64": ctypes.c_uint64,
191 | "int8": ctypes.c_byte,
192 | "int16": ctypes.c_int16,
193 | "int32": ctypes.c_int32,
194 | "int64": ctypes.c_int64,
195 | "float32": ctypes.c_float,
196 | "float64": ctypes.c_double
197 | }
198 |
199 |
200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
202 | type_str = None
203 |
204 | if isinstance(type_obj, str):
205 | type_str = type_obj
206 | elif hasattr(type_obj, "__name__"):
207 | type_str = type_obj.__name__
208 | elif hasattr(type_obj, "name"):
209 | type_str = type_obj.name
210 | else:
211 | raise RuntimeError("Cannot infer type name from input")
212 |
213 | assert type_str in _str_to_ctype.keys()
214 |
215 | my_dtype = np.dtype(type_str)
216 | my_ctype = _str_to_ctype[type_str]
217 |
218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
219 |
220 | return my_dtype, my_ctype
221 |
222 |
223 | def is_pickleable(obj: Any) -> bool:
224 | try:
225 | with io.BytesIO() as stream:
226 | pickle.dump(obj, stream)
227 | return True
228 | except:
229 | return False
230 |
231 |
232 | # Functionality to import modules/objects by name, and call functions by name
233 | # ------------------------------------------------------------------------------------------
234 |
235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
236 | """Searches for the underlying module behind the name to some python object.
237 | Returns the module and the object name (original name with module part removed)."""
238 |
239 | # allow convenience shorthands, substitute them by full names
240 | obj_name = re.sub("^np.", "numpy.", obj_name)
241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name)
242 |
243 | # list alternatives for (module_name, local_obj_name)
244 | parts = obj_name.split(".")
245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
246 |
247 | # try each alternative in turn
248 | for module_name, local_obj_name in name_pairs:
249 | try:
250 | module = importlib.import_module(module_name) # may raise ImportError
251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
252 | return module, local_obj_name
253 | except:
254 | pass
255 |
256 | # maybe some of the modules themselves contain errors?
257 | for module_name, _local_obj_name in name_pairs:
258 | try:
259 | importlib.import_module(module_name) # may raise ImportError
260 | except ImportError:
261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
262 | raise
263 |
264 | # maybe the requested attribute is missing?
265 | for module_name, local_obj_name in name_pairs:
266 | try:
267 | module = importlib.import_module(module_name) # may raise ImportError
268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
269 | except ImportError:
270 | pass
271 |
272 | # we are out of luck, but we have no idea why
273 | raise ImportError(obj_name)
274 |
275 |
276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
277 | """Traverses the object name and returns the last (rightmost) python object."""
278 | if obj_name == '':
279 | return module
280 | obj = module
281 | for part in obj_name.split("."):
282 | obj = getattr(obj, part)
283 | return obj
284 |
285 |
286 | def get_obj_by_name(name: str) -> Any:
287 | """Finds the python object with the given name."""
288 | module, obj_name = get_module_from_obj_name(name)
289 | return get_obj_from_module(module, obj_name)
290 |
291 |
292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
293 | """Finds the python object with the given name and calls it as a function."""
294 | assert func_name is not None
295 | func_obj = get_obj_by_name(func_name)
296 | assert callable(func_obj)
297 | return func_obj(*args, **kwargs)
298 |
299 |
300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
301 | """Finds the python class with the given name and constructs it with the given arguments."""
302 | return call_func_by_name(*args, func_name=class_name, **kwargs)
303 |
304 |
305 | def get_module_dir_by_obj_name(obj_name: str) -> str:
306 | """Get the directory path of the module containing the given object name."""
307 | module, _ = get_module_from_obj_name(obj_name)
308 | return os.path.dirname(inspect.getfile(module))
309 |
310 |
311 | def is_top_level_function(obj: Any) -> bool:
312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
314 |
315 |
316 | def get_top_level_function_name(obj: Any) -> str:
317 | """Return the fully-qualified name of a top-level function."""
318 | assert is_top_level_function(obj)
319 | module = obj.__module__
320 | if module == '__main__':
321 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
322 | return module + "." + obj.__name__
323 |
324 |
325 | # File system helpers
326 | # ------------------------------------------------------------------------------------------
327 |
328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
329 | """List all files recursively in a given directory while ignoring given file and directory names.
330 | Returns list of tuples containing both absolute and relative paths."""
331 | assert os.path.isdir(dir_path)
332 | base_name = os.path.basename(os.path.normpath(dir_path))
333 |
334 | if ignores is None:
335 | ignores = []
336 |
337 | result = []
338 |
339 | for root, dirs, files in os.walk(dir_path, topdown=True):
340 | for ignore_ in ignores:
341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
342 |
343 | # dirs need to be edited in-place
344 | for d in dirs_to_remove:
345 | dirs.remove(d)
346 |
347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
348 |
349 | absolute_paths = [os.path.join(root, f) for f in files]
350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
351 |
352 | if add_base_to_relative:
353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
354 |
355 | assert len(absolute_paths) == len(relative_paths)
356 | result += zip(absolute_paths, relative_paths)
357 |
358 | return result
359 |
360 |
361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
362 | """Takes in a list of tuples of (src, dst) paths and copies files.
363 | Will create all necessary directories."""
364 | for file in files:
365 | target_dir_name = os.path.dirname(file[1])
366 |
367 | # will create all intermediate-level directories
368 | if not os.path.exists(target_dir_name):
369 | os.makedirs(target_dir_name)
370 |
371 | shutil.copyfile(file[0], file[1])
372 |
373 |
374 | # URL helpers
375 | # ------------------------------------------------------------------------------------------
376 |
377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
378 | """Determine whether the given object is a valid URL string."""
379 | if not isinstance(obj, str) or not "://" in obj:
380 | return False
381 | if allow_file_urls and obj.startswith('file://'):
382 | return True
383 | try:
384 | res = requests.compat.urlparse(obj)
385 | if not res.scheme or not res.netloc or not "." in res.netloc:
386 | return False
387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
388 | if not res.scheme or not res.netloc or not "." in res.netloc:
389 | return False
390 | except:
391 | return False
392 | return True
393 |
394 |
395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
396 | """Download the given URL and return a binary-mode file object to access the data."""
397 | assert num_attempts >= 1
398 | assert not (return_filename and (not cache))
399 |
400 |     # Doesn't look like a URL scheme, so interpret it as a local filename.
401 | if not re.match('^[a-z]+://', url):
402 | return url if return_filename else open(url, "rb")
403 |
404 | # Handle file URLs. This code handles unusual file:// patterns that
405 | # arise on Windows:
406 | #
407 | # file:///c:/foo.txt
408 | #
409 | # which would translate to a local '/c:/foo.txt' filename that's
410 | # invalid. Drop the forward slash for such pathnames.
411 | #
412 | # If you touch this code path, you should test it on both Linux and
413 | # Windows.
414 | #
415 |     # Some internet resources suggest using urllib.request.url2pathname(),
416 |     # but that converts forward slashes to backslashes and this causes
417 |     # its own set of problems.
418 | if url.startswith('file://'):
419 | filename = urllib.parse.urlparse(url).path
420 | if re.match(r'^/[a-zA-Z]:', filename):
421 | filename = filename[1:]
422 | return filename if return_filename else open(filename, "rb")
423 |
424 | assert is_url(url)
425 |
426 | # Lookup from cache.
427 | if cache_dir is None:
428 | cache_dir = make_cache_dir_path('downloads')
429 |
430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
431 | if cache:
432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
433 | if len(cache_files) == 1:
434 | filename = cache_files[0]
435 | return filename if return_filename else open(filename, "rb")
436 |
437 | # Download.
438 | url_name = None
439 | url_data = None
440 | with requests.Session() as session:
441 | if verbose:
442 | print("Downloading %s ..." % url, end="", flush=True)
443 | for attempts_left in reversed(range(num_attempts)):
444 | try:
445 | with session.get(url) as res:
446 | res.raise_for_status()
447 | if len(res.content) == 0:
448 | raise IOError("No data received")
449 |
450 | if len(res.content) < 8192:
451 | content_str = res.content.decode("utf-8")
452 | if "download_warning" in res.headers.get("Set-Cookie", ""):
453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
454 | if len(links) == 1:
455 | url = requests.compat.urljoin(url, links[0])
456 | raise IOError("Google Drive virus checker nag")
457 | if "Google Drive - Quota exceeded" in content_str:
458 | raise IOError("Google Drive download quota exceeded -- please try again later")
459 |
460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
461 | url_name = match[1] if match else url
462 | url_data = res.content
463 | if verbose:
464 | print(" done")
465 | break
466 | except KeyboardInterrupt:
467 | raise
468 | except:
469 | if not attempts_left:
470 | if verbose:
471 | print(" failed")
472 | raise
473 | if verbose:
474 | print(".", end="", flush=True)
475 |
476 | # Save to cache.
477 | if cache:
478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
479 | safe_name = safe_name[:min(len(safe_name), 128)]
480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482 | os.makedirs(cache_dir, exist_ok=True)
483 | with open(temp_file, "wb") as f:
484 | f.write(url_data)
485 | os.replace(temp_file, cache_file) # atomic
486 | if return_filename:
487 | return cache_file
488 |
489 | # Return data as file object.
490 | assert not return_filename
491 | return io.BytesIO(url_data)
492 |
--------------------------------------------------------------------------------
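A minimal usage sketch for `open_url` above, assuming it is run from the repository root so that `dnnlib` and `README.md` resolve (the path is illustrative): strings without a `scheme://` prefix fall through to a plain local `open()`, while real downloads are retried up to `num_attempts` times and cached under `make_cache_dir_path('downloads')`.

```python
import dnnlib

# No "scheme://" prefix, so open_url returns a plain binary file object.
with dnnlib.util.open_url('README.md') as f:
    print(f.read(32))
```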
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/metrics/di_frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Frechet Inception Distance (FID) from the paper
10 | "GANs trained by a two time-scale update rule converge to a local Nash
11 | equilibrium". Matches the original implementation by Heusel et al. at
12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13 |
14 | import numpy as np
15 | import scipy.linalg
16 | from . import di_metric_utils as metric_utils
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def compute_fid(opts, max_real, num_gen):
21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24 |
25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28 |
29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32 |
33 | if opts.rank != 0:
34 | return float('nan')
35 |
36 | m = np.square(mu_gen - mu_real).sum()
37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39 | return float(fid)
40 |
41 | #----------------------------------------------------------------------------
42 |
--------------------------------------------------------------------------------
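For reference, `compute_fid` above evaluates the closed-form Frechet distance between Gaussians, d^2 = ||mu_g - mu_r||^2 + Tr(Sigma_g + Sigma_r - 2(Sigma_g Sigma_r)^{1/2}). A toy NumPy check on synthetic 8-dimensional features (sizes and names here are illustrative, not from the repo):

```python
import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
feat_real = rng.normal(0.0, 1.0, size=(1000, 8))
feat_gen = rng.normal(0.5, 1.0, size=(1000, 8))   # means shifted by 0.5 per dim

mu_r, sigma_r = feat_real.mean(0), np.cov(feat_real, rowvar=False)
mu_g, sigma_g = feat_gen.mean(0), np.cov(feat_gen, rowvar=False)

# Same arithmetic as compute_fid above.
m = np.square(mu_g - mu_r).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_g, sigma_r), disp=False)
fid = np.real(m + np.trace(sigma_g + sigma_r - s * 2))
print(fid)   # close to 8 * 0.5**2 = 2.0 for this toy setup
```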
/metrics/di_inception_score.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Inception Score (IS) from the paper "Improved techniques for training
10 | GANs". Matches the original implementation by Salimans et al. at
11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12 |
13 | import numpy as np
14 | from . import di_metric_utils as metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_is(opts, num_gen, num_splits):
19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21 | # detector_url = '/home/luoweijian/work/edm/cache/inception-2015-12-05.pt'
22 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
23 |
24 | gen_probs = metric_utils.compute_feature_stats_for_generator(
25 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
26 | capture_all=True, max_items=num_gen).get_all()
27 |
28 | if opts.rank != 0:
29 | return float('nan'), float('nan')
30 |
31 | scores = []
32 | for i in range(num_splits):
33 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
34 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
35 | kl = np.mean(np.sum(kl, axis=1))
36 | scores.append(np.exp(kl))
37 | return float(np.mean(scores)), float(np.std(scores))
38 |
39 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
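`compute_is` above estimates exp(E_x[KL(p(y|x) || p(y))]) over split subsets. A toy sketch of the same split logic, using synthetic Dirichlet draws in place of Inception softmax outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(10), size=5000)   # stand-in for p(y|x) rows

num_splits = 10
scores = []
for i in range(num_splits):
    part = probs[i * len(probs) // num_splits : (i + 1) * len(probs) // num_splits]
    kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
    scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
print(np.mean(scores), np.std(scores))   # IS >= 1; near-uniform marginals give a modest score
```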
/metrics/di_kernel_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10 | GANs". Matches the original implementation by Binkowski et al. at
11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12 |
13 | import numpy as np
14 | from . import di_metric_utils as metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22 |
23 | real_features = metric_utils.compute_feature_stats_for_dataset(
24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26 |
27 | gen_features = metric_utils.compute_feature_stats_for_generator(
28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30 |
31 | if opts.rank != 0:
32 | return float('nan')
33 |
34 | n = real_features.shape[1]
35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36 | t = 0
37 | for _subset_idx in range(num_subsets):
38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41 | b = (x @ y.T / n + 1) ** 3
42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43 | kid = t / num_subsets / m
44 | return float(kid)
45 |
46 | #----------------------------------------------------------------------------
47 |
--------------------------------------------------------------------------------
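The loop in `compute_kid` above is the block-wise unbiased estimator of MMD^2 under the polynomial kernel k(x, y) = (x.y/n + 1)^3. A single-subset sanity check on matched toy features (names illustrative; the unbiased estimate can come out slightly negative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 16, 512                                  # feature dim, subset size
x = rng.normal(size=(m, n))
y = rng.normal(size=(m, n))                     # same distribution => KID near 0

a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
b = (x @ y.T / n + 1) ** 3
t = (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
print(t / m)                                    # unbiased MMD^2 estimate, close to zero here
```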
/metrics/di_metric_main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import time
11 | import json
12 | import torch
13 | import dnnlib
14 |
15 | from . import di_metric_utils as metric_utils
16 | from . import di_frechet_inception_distance as frechet_inception_distance
17 | from . import di_kernel_inception_distance as kernel_inception_distance
18 | from . import di_precision_recall as precision_recall
19 | from . import perceptual_path_length
20 | from . import di_inception_score as inception_score
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | _metric_dict = dict() # name => fn
25 |
26 | def register_metric(fn):
27 | assert callable(fn)
28 | _metric_dict[fn.__name__] = fn
29 | return fn
30 |
31 | def is_valid_metric(metric):
32 | return metric in _metric_dict
33 |
34 | def list_valid_metrics():
35 | return list(_metric_dict.keys())
36 |
37 | #----------------------------------------------------------------------------
38 |
39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40 | assert is_valid_metric(metric)
41 | opts = metric_utils.MetricOptions(**kwargs)
42 |
43 | # Calculate.
44 | start_time = time.time()
45 | results = _metric_dict[metric](opts)
46 | total_time = time.time() - start_time
47 |
48 | # Broadcast results.
49 | for key, value in list(results.items()):
50 | if opts.num_gpus > 1:
51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52 | torch.distributed.broadcast(tensor=value, src=0)
53 | value = float(value.cpu())
54 | results[key] = value
55 |
56 | # Decorate with metadata.
57 | return dnnlib.EasyDict(
58 | results = dnnlib.EasyDict(results),
59 | metric = metric,
60 | total_time = total_time,
61 | total_time_str = dnnlib.util.format_time(total_time),
62 | num_gpus = opts.num_gpus,
63 | )
64 |
65 | #----------------------------------------------------------------------------
66 |
67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68 | metric = result_dict['metric']
69 | assert is_valid_metric(metric)
70 | if run_dir is not None and snapshot_pkl is not None:
71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72 |
73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74 | print(jsonl_line)
75 | if run_dir is not None and os.path.isdir(run_dir):
76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77 | f.write(jsonl_line + '\n')
78 |
79 | #----------------------------------------------------------------------------
80 | # Primary metrics.
81 |
82 | @register_metric
83 | def fid50k_full(opts):
84 | opts.dataset_kwargs.update(max_size=None, xflip=False)
85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
86 | return dict(fid50k_full=fid)
87 |
88 | @register_metric
89 | def kid50k_full(opts):
90 | opts.dataset_kwargs.update(max_size=None, xflip=False)
91 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
92 | return dict(kid50k_full=kid)
93 |
94 | @register_metric
95 | def pr50k3_full(opts):
96 | opts.dataset_kwargs.update(max_size=None, xflip=False)
97 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
98 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
99 |
100 | @register_metric
101 | def ppl2_wend(opts):
102 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
103 | return dict(ppl2_wend=ppl)
104 |
105 | @register_metric
106 | def is50k(opts):
107 | opts.dataset_kwargs.update(max_size=None, xflip=False)
108 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
109 | return dict(is50k_mean=mean, is50k_std=std)
110 |
111 | #----------------------------------------------------------------------------
112 | # Legacy metrics.
113 |
114 | @register_metric
115 | def fid50k(opts):
116 | opts.dataset_kwargs.update(max_size=None)
117 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
118 | return dict(fid50k=fid)
119 |
120 | @register_metric
121 | def kid50k(opts):
122 | opts.dataset_kwargs.update(max_size=None)
123 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124 | return dict(kid50k=kid)
125 |
126 | @register_metric
127 | def pr50k3(opts):
128 | opts.dataset_kwargs.update(max_size=None)
129 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130 | return dict(pr50k3_precision=precision, pr50k3_recall=recall)
131 |
132 | @register_metric
133 | def ppl_zfull(opts):
134 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
135 | return dict(ppl_zfull=ppl)
136 |
137 | @register_metric
138 | def ppl_wfull(opts):
139 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
140 | return dict(ppl_wfull=ppl)
141 |
142 | @register_metric
143 | def ppl_zend(opts):
144 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
145 | return dict(ppl_zend=ppl)
146 |
147 | @register_metric
148 | def ppl_wend(opts):
149 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
150 | return dict(ppl_wend=ppl)
151 |
152 | #----------------------------------------------------------------------------
153 |
--------------------------------------------------------------------------------
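A hypothetical single-GPU driver for the registry above; `G`, `data_path`, and `init_sigma` are placeholders supplied by the training loop, and the `ImageFolderDataset` class name is assumed from `training/dataset.py`:

```python
import torch
import dnnlib
from metrics import di_metric_main as metric_main

def evaluate_fid(G, data_path, init_sigma):
    # dataset_kwargs is forwarded to dnnlib.util.construct_class_by_name.
    dataset_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset', path=data_path)
    result = metric_main.calc_metric(
        metric='fid50k_full', G=G, init_sigma=init_sigma,
        dataset_kwargs=dataset_kwargs, num_gpus=1, rank=0,
        device=torch.device('cuda'))
    metric_main.report_metric(result, run_dir='.')
    return result.results.fid50k_full
```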
/metrics/di_metric_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import time
11 | import hashlib
12 | import pickle
13 | import copy
14 | import uuid
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | class MetricOptions:
22 | def __init__(self, G=None, init_sigma=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
23 | assert 0 <= rank < num_gpus
24 | self.G = G
25 | self.G_kwargs = dnnlib.EasyDict(G_kwargs)
26 | self.init_sigma = init_sigma
27 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
28 | self.num_gpus = num_gpus
29 | self.rank = rank
30 | self.device = device if device is not None else torch.device('cuda', rank)
31 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
32 | self.cache = cache
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | _feature_detector_cache = dict()
37 |
38 | def get_feature_detector_name(url):
39 | return os.path.splitext(url.split('/')[-1])[0]
40 |
41 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
42 | assert 0 <= rank < num_gpus
43 | key = (url, device)
44 | if key not in _feature_detector_cache:
45 | is_leader = (rank == 0)
46 | if not is_leader and num_gpus > 1:
47 | torch.distributed.barrier() # leader goes first
48 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
49 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
50 | if is_leader and num_gpus > 1:
51 | torch.distributed.barrier() # others follow
52 | return _feature_detector_cache[key]
53 |
54 | #----------------------------------------------------------------------------
55 |
56 | class FeatureStats:
57 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
58 | self.capture_all = capture_all
59 | self.capture_mean_cov = capture_mean_cov
60 | self.max_items = max_items
61 | self.num_items = 0
62 | self.num_features = None
63 | self.all_features = None
64 | self.raw_mean = None
65 | self.raw_cov = None
66 |
67 | def set_num_features(self, num_features):
68 | if self.num_features is not None:
69 | assert num_features == self.num_features
70 | else:
71 | self.num_features = num_features
72 | self.all_features = []
73 | self.raw_mean = np.zeros([num_features], dtype=np.float64)
74 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
75 |
76 | def is_full(self):
77 | return (self.max_items is not None) and (self.num_items >= self.max_items)
78 |
79 | def append(self, x):
80 | x = np.asarray(x, dtype=np.float32)
81 | assert x.ndim == 2
82 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
83 | if self.num_items >= self.max_items:
84 | return
85 | x = x[:self.max_items - self.num_items]
86 |
87 | self.set_num_features(x.shape[1])
88 | self.num_items += x.shape[0]
89 | if self.capture_all:
90 | self.all_features.append(x)
91 | if self.capture_mean_cov:
92 | x64 = x.astype(np.float64)
93 | self.raw_mean += x64.sum(axis=0)
94 | self.raw_cov += x64.T @ x64
95 |
96 | def append_torch(self, x, num_gpus=1, rank=0):
97 | assert isinstance(x, torch.Tensor) and x.ndim == 2
98 | assert 0 <= rank < num_gpus
99 | if num_gpus > 1:
100 | ys = []
101 | for src in range(num_gpus):
102 | y = x.clone()
103 | torch.distributed.broadcast(y, src=src)
104 | ys.append(y)
105 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
106 | self.append(x.cpu().numpy())
107 |
108 | def get_all(self):
109 | assert self.capture_all
110 | return np.concatenate(self.all_features, axis=0)
111 |
112 | def get_all_torch(self):
113 | return torch.from_numpy(self.get_all())
114 |
115 | def get_mean_cov(self):
116 | assert self.capture_mean_cov
117 | mean = self.raw_mean / self.num_items
118 | cov = self.raw_cov / self.num_items
119 | cov = cov - np.outer(mean, mean)
120 | return mean, cov
121 |
122 | def save(self, pkl_file):
123 | with open(pkl_file, 'wb') as f:
124 | pickle.dump(self.__dict__, f)
125 |
126 | @staticmethod
127 | def load(pkl_file):
128 | with open(pkl_file, 'rb') as f:
129 | s = dnnlib.EasyDict(pickle.load(f))
130 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
131 | obj.__dict__.update(s)
132 | return obj
133 |
134 | #----------------------------------------------------------------------------
135 |
136 | class ProgressMonitor:
137 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
138 | self.tag = tag
139 | self.num_items = num_items
140 | self.verbose = verbose
141 | self.flush_interval = flush_interval
142 | self.progress_fn = progress_fn
143 | self.pfn_lo = pfn_lo
144 | self.pfn_hi = pfn_hi
145 | self.pfn_total = pfn_total
146 | self.start_time = time.time()
147 | self.batch_time = self.start_time
148 | self.batch_items = 0
149 | if self.progress_fn is not None:
150 | self.progress_fn(self.pfn_lo, self.pfn_total)
151 |
152 | def update(self, cur_items):
153 | assert (self.num_items is None) or (cur_items <= self.num_items)
154 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
155 | return
156 | cur_time = time.time()
157 | total_time = cur_time - self.start_time
158 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
159 | if (self.verbose) and (self.tag is not None):
160 | print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
161 | self.batch_time = cur_time
162 | self.batch_items = cur_items
163 |
164 | if (self.progress_fn is not None) and (self.num_items is not None):
165 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
166 |
167 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
168 | return ProgressMonitor(
169 | tag = tag,
170 | num_items = num_items,
171 | flush_interval = flush_interval,
172 | verbose = self.verbose,
173 | progress_fn = self.progress_fn,
174 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
175 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
176 | pfn_total = self.pfn_total,
177 | )
178 |
179 | #----------------------------------------------------------------------------
180 |
181 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
182 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
183 | if data_loader_kwargs is None:
184 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
185 |
186 | # Try to lookup from cache.
187 | cache_file = None
188 | if opts.cache:
189 | # Choose cache file name.
190 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
191 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
192 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
193 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
194 |
195 | # Check if the file exists (all processes must agree).
196 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False
197 | if opts.num_gpus > 1:
198 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
199 | torch.distributed.broadcast(tensor=flag, src=0)
200 | flag = (float(flag.cpu()) != 0)
201 |
202 | # Load.
203 | if flag:
204 | return FeatureStats.load(cache_file)
205 |
206 | # Initialize.
207 | num_items = len(dataset)
208 | if max_items is not None:
209 | num_items = min(num_items, max_items)
210 | stats = FeatureStats(max_items=num_items, **stats_kwargs)
211 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
212 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
213 |
214 | # Main loop.
215 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] # round-robin shard of dataset indices for this rank
216 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
217 | if images.shape[1] == 1:
218 | images = images.repeat([1, 3, 1, 1])
219 | features = detector(images.to(opts.device), **detector_kwargs)
220 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
221 | progress.update(stats.num_items)
222 |
223 | # Save to cache.
224 | if cache_file is not None and opts.rank == 0:
225 | os.makedirs(os.path.dirname(cache_file), exist_ok=True)
226 | temp_file = cache_file + '.' + uuid.uuid4().hex
227 | stats.save(temp_file)
228 | os.replace(temp_file, cache_file) # atomic
229 | return stats
230 |
231 | #----------------------------------------------------------------------------
232 |
233 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=128, batch_gen=None, jit=False, **stats_kwargs):
234 | if batch_gen is None:
235 | batch_gen = min(batch_size, 4)
236 | assert batch_size % batch_gen == 0
237 |
238 | # Setup generator and load labels.
239 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
240 | init_sigma = opts.init_sigma
241 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
242 |
243 | # Image generation func. Runs the one-step generator at noise level init_sigma and converts outputs to uint8 images.
244 | def run_generator(z, c, init_sigma):
245 | # img = G(z=z, c=c, **opts.G_kwargs)
246 | img = G(z, init_sigma*torch.ones(z.shape[0],1,1,1).to(z.device), c, augment_labels=torch.zeros(z.shape[0], 9).to(z.device))
247 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
248 | return img
249 |
250 | # JIT.
251 | if jit:
252 | z = init_sigma*torch.zeros([batch_gen, G.img_channels, G.img_resolution, G.img_resolution], device=opts.device)
253 | c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
254 | run_generator = torch.jit.trace(run_generator, [z, c, torch.as_tensor(init_sigma, device=opts.device)], check_trace=False) # trace needs tensor example inputs
255 |
256 | # Initialize.
257 | stats = FeatureStats(**stats_kwargs)
258 | assert stats.max_items is not None
259 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
260 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
261 |
262 | # Main loop.
263 | while not stats.is_full():
264 | images = []
265 | for _i in range(batch_size // batch_gen):
266 | z = init_sigma*torch.randn([batch_gen, G.img_channels, G.img_resolution, G.img_resolution], device=opts.device)
267 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
268 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
269 | images.append(run_generator(z, c, init_sigma))
270 | images = torch.cat(images)
271 | if images.shape[1] == 1:
272 | images = images.repeat([1, 3, 1, 1])
273 | features = detector(images, **detector_kwargs)
274 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
275 | progress.update(stats.num_items)
276 | return stats
277 |
278 | #----------------------------------------------------------------------------
279 |
--------------------------------------------------------------------------------
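A quick check that `FeatureStats` above accumulates the same (biased) mean and covariance as a direct NumPy computation, assuming the repo root is on `PYTHONPATH` so `dnnlib` imports:

```python
import numpy as np
from metrics.di_metric_utils import FeatureStats

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 4)).astype(np.float32)

stats = FeatureStats(capture_mean_cov=True, max_items=100)
stats.append(x[:50])                     # streaming: two chunks instead of one pass
stats.append(x[50:])
mean, cov = stats.get_mean_cov()
assert np.allclose(mean, x.mean(0), atol=1e-5)
assert np.allclose(cov, np.cov(x, rowvar=False, bias=True), atol=1e-4)
```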
/metrics/di_precision_recall.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall
10 | Metric for Assessing Generative Models". Matches the original implementation
11 | by Kynkaanniemi et al. at
12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13 |
14 | import torch
15 | from . import di_metric_utils as metric_utils
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20 | assert 0 <= rank < num_gpus
21 | num_cols = col_features.shape[0]
22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) # zero-pad rows so the features split evenly into num_batches chunks
24 | dist_batches = []
25 | for col_batch in col_batches[rank :: num_gpus]:
26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27 | for src in range(num_gpus):
28 | dist_broadcast = dist_batch.clone()
29 | if num_gpus > 1:
30 | torch.distributed.broadcast(dist_broadcast, src=src)
31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
38 | detector_kwargs = dict(return_features=True)
39 |
40 | real_features = metric_utils.compute_feature_stats_for_dataset(
41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43 |
44 | gen_features = metric_utils.compute_feature_stats_for_generator(
45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47 |
48 | results = dict()
49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50 | kth = []
51 | for manifold_batch in manifold.split(row_batch_size):
52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54 | kth = torch.cat(kth) if opts.rank == 0 else None
55 | pred = []
56 | for probes_batch in probes.split(row_batch_size):
57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60 | return results['precision'], results['recall']
61 |
62 | #----------------------------------------------------------------------------
63 |
--------------------------------------------------------------------------------
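Stripped of the multi-GPU broadcasting, the precision estimate above reduces to a k-NN radius test: each real feature gets the distance to its (k+1)-th nearest real neighbor (the +1 skips the point itself), and a generated feature counts as precise if it lands inside any such ball. A toy single-process sketch:

```python
import torch

real, gen = torch.randn(200, 16), torch.randn(200, 16)
nhood_size = 3

d_real = torch.cdist(real, real)
kth = d_real.kthvalue(nhood_size + 1, dim=1).values    # per-real-point radius

d_gen = torch.cdist(gen, real)
precision = (d_gen <= kth.unsqueeze(0)).any(dim=1).float().mean()
print(float(precision))
```

Swapping the roles of `real` and `gen` gives recall, exactly as in the loop over `(name, manifold, probes)` above.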
/metrics/perceptual_path_length.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10 | Architecture for Generative Adversarial Networks". Matches the original
11 | implementation by Karras et al. at
12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13 |
14 | import copy
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 | from . import di_metric_utils as metric_utils
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | # Spherical interpolation of a batch of vectors.
23 | def slerp(a, b, t):
24 | a = a / a.norm(dim=-1, keepdim=True)
25 | b = b / b.norm(dim=-1, keepdim=True)
26 | d = (a * b).sum(dim=-1, keepdim=True)
27 | p = t * torch.acos(d)
28 | c = b - d * a
29 | c = c / c.norm(dim=-1, keepdim=True)
30 | d = a * torch.cos(p) + c * torch.sin(p)
31 | d = d / d.norm(dim=-1, keepdim=True)
32 | return d
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | class PPLSampler(torch.nn.Module):
37 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
38 | assert space in ['z', 'w']
39 | assert sampling in ['full', 'end']
40 | super().__init__()
41 | self.G = copy.deepcopy(G)
42 | self.G_kwargs = G_kwargs
43 | self.epsilon = epsilon
44 | self.space = space
45 | self.sampling = sampling
46 | self.crop = crop
47 | self.vgg16 = copy.deepcopy(vgg16)
48 |
49 | def forward(self, c):
50 | # Generate random latents and interpolation t-values.
51 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
52 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
53 |
54 | # Interpolate in W or Z.
55 | if self.space == 'w':
56 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
57 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
58 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
59 | else: # space == 'z'
60 | zt0 = slerp(z0, z1, t.unsqueeze(1))
61 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
62 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
63 |
64 | # Randomize noise buffers.
65 | for name, buf in self.G.named_buffers():
66 | if name.endswith('.noise_const'):
67 | buf.copy_(torch.randn_like(buf))
68 |
69 | # Generate images.
70 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
71 |
72 | # Center crop.
73 | if self.crop:
74 | assert img.shape[2] == img.shape[3]
75 | c = img.shape[2] // 8
76 | img = img[:, :, c*3 : c*7, c*2 : c*6]
77 |
78 | # Downsample to 256x256.
79 | factor = self.G.img_resolution // 256
80 | if factor > 1:
81 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
82 |
83 | # Scale dynamic range from [-1,1] to [0,255].
84 | img = (img + 1) * (255 / 2)
85 | if self.G.img_channels == 1:
86 | img = img.repeat([1, 3, 1, 1])
87 |
88 | # Evaluate differential LPIPS.
89 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
90 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91 | return dist
92 |
93 | #----------------------------------------------------------------------------
94 |
95 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
96 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
97 | vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99 |
100 | # Setup sampler.
101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102 | sampler.eval().requires_grad_(False).to(opts.device)
103 | if jit:
104 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
105 | sampler = torch.jit.trace(sampler, [c], check_trace=False)
106 |
107 | # Sampling loop.
108 | dist = []
109 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
110 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
111 | progress.update(batch_start)
112 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
113 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
114 | x = sampler(c)
115 | for src in range(opts.num_gpus):
116 | y = x.clone()
117 | if opts.num_gpus > 1:
118 | torch.distributed.broadcast(y, src=src)
119 | dist.append(y)
120 | progress.update(num_samples)
121 |
122 | # Compute PPL.
123 | if opts.rank != 0:
124 | return float('nan')
125 | dist = torch.cat(dist)[:num_samples].cpu().numpy()
126 | lo = np.percentile(dist, 1, interpolation='lower')
127 | hi = np.percentile(dist, 99, interpolation='higher')
128 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
129 | return float(ppl)
130 |
131 | #----------------------------------------------------------------------------
132 |
--------------------------------------------------------------------------------
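A small sanity check of `slerp` above (run from the repo root so the `metrics` package imports): at t=0 it reduces to the normalized first endpoint, and interpolants stay on the unit sphere.

```python
import torch
from metrics.perceptual_path_length import slerp

a, b = torch.randn(4, 512), torch.randn(4, 512)
d = slerp(a, b, torch.full([4, 1], 0.25))
assert torch.allclose(d.norm(dim=-1), torch.ones(4), atol=1e-5)
assert torch.allclose(slerp(a, b, torch.zeros(4, 1)),
                      a / a.norm(dim=-1, keepdim=True), atol=1e-5)
```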
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import glob
11 | import torch
12 | import torch.utils.cpp_extension
13 | import importlib
14 | import hashlib
15 | import shutil
16 | from pathlib import Path
17 |
18 | from torch.utils.file_baton import FileBaton
19 |
20 | #----------------------------------------------------------------------------
21 | # Global options.
22 |
23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24 |
25 | #----------------------------------------------------------------------------
26 | # Internal helper funcs.
27 |
28 | def _find_compiler_bindir():
29 | patterns = [
30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34 | ]
35 | for pattern in patterns:
36 | matches = sorted(glob.glob(pattern))
37 | if len(matches):
38 | return matches[-1]
39 | return None
40 |
41 | #----------------------------------------------------------------------------
42 | # Main entry point for compiling and loading C++/CUDA plugins.
43 |
44 | _cached_plugins = dict()
45 |
46 | def get_plugin(module_name, sources, **build_kwargs):
47 | assert verbosity in ['none', 'brief', 'full']
48 |
49 | # Already cached?
50 | if module_name in _cached_plugins:
51 | return _cached_plugins[module_name]
52 |
53 | # Print status.
54 | if verbosity == 'full':
55 | print(f'Setting up PyTorch plugin "{module_name}"...')
56 | elif verbosity == 'brief':
57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58 |
59 | try: # pylint: disable=too-many-nested-blocks
60 | # Make sure we can find the necessary compiler binaries.
61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62 | compiler_bindir = _find_compiler_bindir()
63 | if compiler_bindir is None:
64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65 | os.environ['PATH'] += ';' + compiler_bindir
66 |
67 | # Compile and load.
68 | verbose_build = (verbosity == 'full')
69 |
70 | # Incremental build md5sum trickery. Copies all the input source files
71 | # into a cached build directory under a combined md5 digest of the input
72 | # source files. Copying is done only if the combined digest has changed.
73 | # This keeps input file timestamps and filenames the same as in previous
74 | # extension builds, allowing for fast incremental rebuilds.
75 | #
76 | # This optimization is done only in case all the source files reside in
77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78 | # environment variable is set (we take this as a signal that the user
79 | # actually cares about this.)
80 | source_dirs_set = set(os.path.dirname(source) for source in sources)
81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83 |
84 | # Compute a combined hash digest for all source files in the same
85 | # custom op directory (usually .cu, .cpp, .py and .h files).
86 | hash_md5 = hashlib.md5()
87 | for src in all_source_files:
88 | with open(src, 'rb') as f:
89 | hash_md5.update(f.read())
90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92 |
93 | if not os.path.isdir(digest_build_dir):
94 | os.makedirs(digest_build_dir, exist_ok=True)
95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96 | if baton.try_acquire():
97 | try:
98 | for src in all_source_files:
99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100 | finally:
101 | baton.release()
102 | else:
103 | # Someone else is copying source files under the digest dir,
104 | # wait until done and continue.
105 | baton.wait()
106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108 | verbose=verbose_build, sources=digest_sources, **build_kwargs)
109 | else:
110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111 | module = importlib.import_module(module_name)
112 |
113 | except:
114 | if verbosity == 'brief':
115 | print('Failed!')
116 | raise
117 |
118 | # Print status and add to cache.
119 | if verbosity == 'full':
120 | print(f'Done setting up PyTorch plugin "{module_name}".')
121 | elif verbosity == 'brief':
122 | print('Done.')
123 | _cached_plugins[module_name] = module
124 | return module
125 |
126 | #----------------------------------------------------------------------------
127 |
--------------------------------------------------------------------------------
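A sketch of the call pattern for `get_plugin`, assuming it mirrors how `torch_utils/ops/bias_act.py` compiles its kernel (the exact source list and compiler flags here are assumptions; `get_plugin` forwards `**build_kwargs` to `torch.utils.cpp_extension.load`):

```python
import os
from torch_utils import custom_ops

src_dir = os.path.join('torch_utils', 'ops')
plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=[os.path.join(src_dir, s) for s in ['bias_act.cpp', 'bias_act.cu']],
    extra_cuda_cflags=['--use_fast_math'])   # assumed flag, passed through to cpp_extension.load
```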
/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 | if 'MASTER_PORT' not in os.environ:
18 | os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
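A minimal launch sketch: with none of the `torchrun` environment variables set, `init()` above falls back to a single-process group (rank 0, world size 1) on localhost; a CUDA device is still required by `torch.cuda.set_device`.

```python
from torch_utils import distributed as dist

dist.init()
dist.print0(f'rank {dist.get_rank()} of {dist.get_world_size()}')
```

Under `torchrun --nproc_per_node=N`, the same code picks up `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` from the environment, and `print0` keeps the log single-voiced.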
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import re
9 | import contextlib
10 | import numpy as np
11 | import torch
12 | import warnings
13 | import dnnlib
14 |
15 | #----------------------------------------------------------------------------
16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17 | # same constant is used multiple times.
18 |
19 | _constant_cache = dict()
20 |
21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
22 | value = np.asarray(value)
23 | if shape is not None:
24 | shape = tuple(shape)
25 | if dtype is None:
26 | dtype = torch.get_default_dtype()
27 | if device is None:
28 | device = torch.device('cpu')
29 | if memory_format is None:
30 | memory_format = torch.contiguous_format
31 |
32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
33 | tensor = _constant_cache.get(key, None)
34 | if tensor is None:
35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
36 | if shape is not None:
37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
38 | tensor = tensor.contiguous(memory_format=memory_format)
39 | _constant_cache[key] = tensor
40 | return tensor
41 |
42 | #----------------------------------------------------------------------------
43 | # Replace NaN/Inf with specified numerical values.
44 |
45 | try:
46 | nan_to_num = torch.nan_to_num # 1.8.0a0
47 | except AttributeError:
48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
49 | assert isinstance(input, torch.Tensor)
50 | if posinf is None:
51 | posinf = torch.finfo(input.dtype).max
52 | if neginf is None:
53 | neginf = torch.finfo(input.dtype).min
54 | assert nan == 0
55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
56 |
57 | #----------------------------------------------------------------------------
58 | # Symbolic assert.
59 |
60 | try:
61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
62 | except AttributeError:
63 | symbolic_assert = torch.Assert # 1.7.0
64 |
65 | #----------------------------------------------------------------------------
66 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
68 |
69 | @contextlib.contextmanager
70 | def suppress_tracer_warnings():
71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
72 | warnings.filters.insert(0, flt)
73 | yield
74 | warnings.filters.remove(flt)
75 |
76 | #----------------------------------------------------------------------------
77 | # Assert that the shape of a tensor matches the given list of integers.
78 | # None indicates that the size of a dimension is allowed to vary.
79 | # Performs symbolic assertion when used in torch.jit.trace().
80 |
81 | def assert_shape(tensor, ref_shape):
82 | if tensor.ndim != len(ref_shape):
83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
85 | if ref_size is None:
86 | pass
87 | elif isinstance(ref_size, torch.Tensor):
88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
90 | elif isinstance(size, torch.Tensor):
91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
93 | elif size != ref_size:
94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95 |
96 | #----------------------------------------------------------------------------
97 | # Function decorator that calls torch.autograd.profiler.record_function().
98 |
99 | def profiled_function(fn):
100 | def decorator(*args, **kwargs):
101 | with torch.autograd.profiler.record_function(fn.__name__):
102 | return fn(*args, **kwargs)
103 | decorator.__name__ = fn.__name__
104 | return decorator
105 |
106 | #----------------------------------------------------------------------------
107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
108 | # indefinitely, shuffling items as it goes.
109 |
110 | class InfiniteSampler(torch.utils.data.Sampler):
111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
112 | assert len(dataset) > 0
113 | assert num_replicas > 0
114 | assert 0 <= rank < num_replicas
115 | assert 0 <= window_size <= 1
116 | super().__init__(dataset)
117 | self.dataset = dataset
118 | self.rank = rank
119 | self.num_replicas = num_replicas
120 | self.shuffle = shuffle
121 | self.seed = seed
122 | self.window_size = window_size
123 |
124 | def __iter__(self):
125 | order = np.arange(len(self.dataset))
126 | rnd = None
127 | window = 0
128 | if self.shuffle:
129 | rnd = np.random.RandomState(self.seed)
130 | rnd.shuffle(order)
131 | window = int(np.rint(order.size * self.window_size))
132 |
133 | idx = 0
134 | while True:
135 | i = idx % order.size
136 | if idx % self.num_replicas == self.rank:
137 | yield order[i]
138 | if window >= 2:
139 | j = (i - rnd.randint(window)) % order.size
140 | order[i], order[j] = order[j], order[i]
141 | idx += 1
142 |
143 | #----------------------------------------------------------------------------
144 | # Utilities for operating with torch.nn.Module parameters and buffers.
145 |
146 | def params_and_buffers(module):
147 | assert isinstance(module, torch.nn.Module)
148 | return list(module.parameters()) + list(module.buffers())
149 |
150 | def named_params_and_buffers(module):
151 | assert isinstance(module, torch.nn.Module)
152 | return list(module.named_parameters()) + list(module.named_buffers())
153 |
154 | @torch.no_grad()
155 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
156 | assert isinstance(src_module, torch.nn.Module)
157 | assert isinstance(dst_module, torch.nn.Module)
158 | src_tensors = dict(named_params_and_buffers(src_module))
159 | for name, tensor in named_params_and_buffers(dst_module):
160 | assert (name in src_tensors) or (not require_all)
161 | if name in src_tensors:
162 | tensor.copy_(src_tensors[name])
163 |
164 | #----------------------------------------------------------------------------
165 | # Context manager for easily enabling/disabling DistributedDataParallel
166 | # synchronization.
167 |
168 | @contextlib.contextmanager
169 | def ddp_sync(module, sync):
170 | assert isinstance(module, torch.nn.Module)
171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172 | yield
173 | else:
174 | with module.no_sync():
175 | yield
176 |
177 | #----------------------------------------------------------------------------
178 | # Check DistributedDataParallel consistency across processes.
179 |
180 | def check_ddp_consistency(module, ignore_regex=None):
181 | assert isinstance(module, torch.nn.Module)
182 | for name, tensor in named_params_and_buffers(module):
183 | fullname = type(module).__name__ + '.' + name
184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185 | continue
186 | tensor = tensor.detach()
187 | if tensor.is_floating_point():
188 | tensor = nan_to_num(tensor)
189 | other = tensor.clone()
190 | torch.distributed.broadcast(tensor=other, src=0)
191 | assert (tensor == other).all(), fullname
192 |
193 | #----------------------------------------------------------------------------
194 | # Print summary table of module hierarchy.
195 |
196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197 | assert isinstance(module, torch.nn.Module)
198 | assert not isinstance(module, torch.jit.ScriptModule)
199 | assert isinstance(inputs, (tuple, list))
200 |
201 | # Register hooks.
202 | entries = []
203 | nesting = [0]
204 | def pre_hook(_mod, _inputs):
205 | nesting[0] += 1
206 | def post_hook(mod, _inputs, outputs):
207 | nesting[0] -= 1
208 | if nesting[0] <= max_nesting:
209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214 |
215 | # Run module.
216 | outputs = module(*inputs)
217 | for hook in hooks:
218 | hook.remove()
219 |
220 | # Identify unique outputs, parameters, and buffers.
221 | tensors_seen = set()
222 | for e in entries:
223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227 |
228 | # Filter out redundant entries.
229 | if skip_redundant:
230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231 |
232 | # Construct table.
233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234 | rows += [['---'] * len(rows[0])]
235 | param_total = 0
236 | buffer_total = 0
237 | submodule_names = {mod: name for name, mod in module.named_modules()}
238 | for e in entries:
239 | name = '' if e.mod is module else submodule_names[e.mod]
240 | param_size = sum(t.numel() for t in e.unique_params)
241 | buffer_size = sum(t.numel() for t in e.unique_buffers)
242 | output_shapes = [str(list(t.shape)) for t in e.outputs]
243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244 | rows += [[
245 | name + (':0' if len(e.outputs) >= 2 else ''),
246 | str(param_size) if param_size else '-',
247 | str(buffer_size) if buffer_size else '-',
248 | (output_shapes + ['-'])[0],
249 | (output_dtypes + ['-'])[0],
250 | ]]
251 | for idx in range(1, len(e.outputs)):
252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253 | param_total += param_size
254 | buffer_total += buffer_size
255 | rows += [['---'] * len(rows[0])]
256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257 |
258 | # Print table.
259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260 | print()
261 | for row in rows:
262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263 | print()
264 | return outputs
265 |
266 | #----------------------------------------------------------------------------
267 |
--------------------------------------------------------------------------------
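The helpers above are easiest to understand from a short usage sketch. The following is illustrative only and assumes the `torch_utils` package from this repository is importable; the toy network and tensor shapes are made up for the example:

```python
import torch
from torch_utils import misc

# Toy network; any nn.Module works.
net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 3, 3, padding=1),
)
x = torch.randn(1, 3, 32, 32)

# One row per submodule: parameter/buffer counts, output shape, dtype.
misc.print_module_summary(net, [x])

# ddp_sync() is a passthrough for plain modules; for a module wrapped in
# DistributedDataParallel, sync=False suppresses the gradient all-reduce,
# which is useful when accumulating gradients over several micro-batches.
with misc.ddp_sync(net, sync=False):
    net(x).sum().backward()
```

--------------------------------------------------------------------------------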
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/bias_act.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/bias_act.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_resample.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_resample.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/fma.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/fma.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/fma.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/fma.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/upfirdn2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/upfirdn2d.cpython-37.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel<scalar_t>(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
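The launch configuration above uses a ceil-division rule: each block of 4 * 32 = 128 threads makes loopX = 4 strided passes over the data, so one grid step covers 512 elements. A quick illustrative check of the arithmetic:

```python
# Mirrors: gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1 with C integer division.
def grid_size(size_x, loop_x=4, block_size=128):
    return (size_x - 1) // (loop_x * block_size) + 1

assert grid_size(1) == 1
assert grid_size(512) == 1   # exactly one grid step
assert grid_size(513) == 2   # one extra element forces a second step
```

--------------------------------------------------------------------------------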
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double>     { typedef double scalar_t; };
17 | template <> struct InternalType<float>      { typedef float  scalar_t; };
18 | template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType<T>::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
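A notable trick in the kernel above is that the first-derivative branches (G == 1) are written in terms of the saved forward output yref rather than the input, e.g. dy * (1 - yref^2) for tanh. A quick PyTorch check of that identity (illustrative, not part of the repository):

```python
import torch

x = torch.randn(16, dtype=torch.float64, requires_grad=True)
y = torch.tanh(x)
(autograd_dx,) = torch.autograd.grad(y.sum(), x)

# Matches the kernel's G == 1 branch for tanh: y = x * (one - yy * yy),
# where yy is the saved forward output and x plays the role of dy.
manual_dx = 1 - y.detach() ** 2
assert torch.allclose(autograd_dx, manual_dx)
```

--------------------------------------------------------------------------------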
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import warnings
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | import traceback
17 |
18 | from .. import custom_ops
19 | from .. import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | activation_funcs = {
24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33 | }
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | _inited = False
38 | _plugin = None
39 | _null_tensor = torch.empty([0])
40 |
41 | # def _init():
42 | # global _inited, _plugin
43 | # if not _inited:
44 | # _inited = True
45 | # sources = ['bias_act.cpp', 'bias_act.cu']
46 | # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47 | # try:
48 | # _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49 | # except:
50 | # warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51 | # return _plugin is not None
52 |
53 | def _init():
54 | return False
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
59 | r"""Fused bias and activation function.
60 |
61 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
62 | and scales the result by `gain`. Each of the steps is optional. In most cases,
63 | the fused op is considerably more efficient than performing the same calculation
64 | using standard PyTorch ops. It supports first and second order gradients,
65 | but not third order gradients.
66 |
67 | Args:
68 | x: Input activation tensor. Can be of any shape.
69 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
70 | as `x`. The shape must be known, and it must match the dimension of `x`
71 | corresponding to `dim`.
72 | dim: The dimension in `x` corresponding to the elements of `b`.
73 | The value of `dim` is ignored if `b` is not specified.
74 | act: Name of the activation function to evaluate, or `"linear"` to disable.
75 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
76 | See `activation_funcs` for a full list. `None` is not allowed.
77 | alpha: Shape parameter for the activation function, or `None` to use the default.
78 | gain: Scaling factor for the output tensor, or `None` to use default.
79 | See `activation_funcs` for the default scaling of each activation function.
80 | If unsure, consider specifying 1.
81 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
82 | the clamping (default).
83 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
84 |
85 | Returns:
86 | Tensor of the same shape and datatype as `x`.
87 | """
88 | assert isinstance(x, torch.Tensor)
89 | assert impl in ['ref', 'cuda']
90 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
91 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
92 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
93 |
94 | #----------------------------------------------------------------------------
95 |
96 | @misc.profiled_function
97 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
98 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
99 | """
100 | assert isinstance(x, torch.Tensor)
101 | assert clamp is None or clamp >= 0
102 | spec = activation_funcs[act]
103 | alpha = float(alpha if alpha is not None else spec.def_alpha)
104 | gain = float(gain if gain is not None else spec.def_gain)
105 | clamp = float(clamp if clamp is not None else -1)
106 |
107 | # Add bias.
108 | if b is not None:
109 | assert isinstance(b, torch.Tensor) and b.ndim == 1
110 | assert 0 <= dim < x.ndim
111 | assert b.shape[0] == x.shape[dim]
112 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
113 |
114 | # Evaluate activation function.
115 | alpha = float(alpha)
116 | x = spec.func(x, alpha=alpha)
117 |
118 | # Scale by gain.
119 | gain = float(gain)
120 | if gain != 1:
121 | x = x * gain
122 |
123 | # Clamp.
124 | if clamp >= 0:
125 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
126 | return x
127 |
128 | #----------------------------------------------------------------------------
129 |
130 | _bias_act_cuda_cache = dict()
131 |
132 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
133 | """Fast CUDA implementation of `bias_act()` using custom ops.
134 | """
135 | # Parse arguments.
136 | assert clamp is None or clamp >= 0
137 | spec = activation_funcs[act]
138 | alpha = float(alpha if alpha is not None else spec.def_alpha)
139 | gain = float(gain if gain is not None else spec.def_gain)
140 | clamp = float(clamp if clamp is not None else -1)
141 |
142 | # Lookup from cache.
143 | key = (dim, act, alpha, gain, clamp)
144 | if key in _bias_act_cuda_cache:
145 | return _bias_act_cuda_cache[key]
146 |
147 | # Forward op.
148 | class BiasActCuda(torch.autograd.Function):
149 | @staticmethod
150 | def forward(ctx, x, b): # pylint: disable=arguments-differ
151 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
152 | x = x.contiguous(memory_format=ctx.memory_format)
153 | b = b.contiguous() if b is not None else _null_tensor
154 | y = x
155 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
156 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
157 | ctx.save_for_backward(
158 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
159 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
160 | y if 'y' in spec.ref else _null_tensor)
161 | return y
162 |
163 | @staticmethod
164 | def backward(ctx, dy): # pylint: disable=arguments-differ
165 | dy = dy.contiguous(memory_format=ctx.memory_format)
166 | x, b, y = ctx.saved_tensors
167 | dx = None
168 | db = None
169 |
170 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
171 | dx = dy
172 | if act != 'linear' or gain != 1 or clamp >= 0:
173 | dx = BiasActCudaGrad.apply(dy, x, b, y)
174 |
175 | if ctx.needs_input_grad[1]:
176 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
177 |
178 | return dx, db
179 |
180 | # Backward op.
181 | class BiasActCudaGrad(torch.autograd.Function):
182 | @staticmethod
183 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
184 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
185 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
186 | ctx.save_for_backward(
187 | dy if spec.has_2nd_grad else _null_tensor,
188 | x, b, y)
189 | return dx
190 |
191 | @staticmethod
192 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
193 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
194 | dy, x, b, y = ctx.saved_tensors
195 | d_dy = None
196 | d_x = None
197 | d_b = None
198 | d_y = None
199 |
200 | if ctx.needs_input_grad[0]:
201 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
202 |
203 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
204 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
205 |
206 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
207 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
208 |
209 | return d_dy, d_x, d_b, d_y
210 |
211 | # Add to cache.
212 | _bias_act_cuda_cache[key] = BiasActCuda
213 | return BiasActCuda
214 |
215 | #----------------------------------------------------------------------------
216 |
--------------------------------------------------------------------------------
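A minimal usage sketch for `bias_act()` (illustrative shapes; with `_init()` hard-wired to return False as above, every call goes through `_bias_act_ref()`, so it also runs on CPU tensors):

```python
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 8, 16, 16)   # NCHW activations
b = torch.zeros(8)              # one bias element per channel (dim=1)

# Fused bias + leaky ReLU with the default gain of sqrt(2), clamped to [-256, 256].
y = bias_act.bias_act(x, b, dim=1, act='lrelu', clamp=256)
assert y.shape == x.shape and y.dtype == x.dtype
```

--------------------------------------------------------------------------------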
/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import warnings
13 | import contextlib
14 | import torch
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24 |
25 | @contextlib.contextmanager
26 | def no_weight_gradients():
27 | global weight_gradients_disabled
28 | old = weight_gradients_disabled
29 | weight_gradients_disabled = True
30 | yield
31 | weight_gradients_disabled = old
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36 | if _should_use_custom_op(input):
37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39 |
40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41 | if _should_use_custom_op(input):
42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | def _should_use_custom_op(input):
48 | assert isinstance(input, torch.Tensor)
49 | if (not enabled) or (not torch.backends.cudnn.enabled):
50 | return False
51 | if input.device.type != 'cuda':
52 | return False
53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
54 | return True
55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56 | return False
57 |
58 | def _tuple_of_ints(xs, ndim):
59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60 | assert len(xs) == ndim
61 | assert all(isinstance(x, int) for x in xs)
62 | return xs
63 |
64 | #----------------------------------------------------------------------------
65 |
66 | _conv2d_gradfix_cache = dict()
67 |
68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69 | # Parse arguments.
70 | ndim = 2
71 | weight_shape = tuple(weight_shape)
72 | stride = _tuple_of_ints(stride, ndim)
73 | padding = _tuple_of_ints(padding, ndim)
74 | output_padding = _tuple_of_ints(output_padding, ndim)
75 | dilation = _tuple_of_ints(dilation, ndim)
76 |
77 | # Lookup from cache.
78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79 | if key in _conv2d_gradfix_cache:
80 | return _conv2d_gradfix_cache[key]
81 |
82 | # Validate arguments.
83 | assert groups >= 1
84 | assert len(weight_shape) == ndim + 2
85 | assert all(stride[i] >= 1 for i in range(ndim))
86 | assert all(padding[i] >= 0 for i in range(ndim))
87 | assert all(dilation[i] >= 0 for i in range(ndim))
88 | if not transpose:
89 | assert all(output_padding[i] == 0 for i in range(ndim))
90 | else: # transpose
91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92 |
93 | # Helpers.
94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95 | def calc_output_padding(input_shape, output_shape):
96 | if transpose:
97 | return [0, 0]
98 | return [
99 | input_shape[i + 2]
100 | - (output_shape[i + 2] - 1) * stride[i]
101 | - (1 - 2 * padding[i])
102 | - dilation[i] * (weight_shape[i + 2] - 1)
103 | for i in range(ndim)
104 | ]
105 |
106 | # Forward & backward.
107 | class Conv2d(torch.autograd.Function):
108 | @staticmethod
109 | def forward(ctx, input, weight, bias):
110 | assert weight.shape == weight_shape
111 | if not transpose:
112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113 | else: # transpose
114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115 | ctx.save_for_backward(input, weight)
116 | return output
117 |
118 | @staticmethod
119 | def backward(ctx, grad_output):
120 | input, weight = ctx.saved_tensors
121 | grad_input = None
122 | grad_weight = None
123 | grad_bias = None
124 |
125 | if ctx.needs_input_grad[0]:
126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128 | assert grad_input.shape == input.shape
129 |
130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
132 | assert grad_weight.shape == weight_shape
133 |
134 | if ctx.needs_input_grad[2]:
135 | grad_bias = grad_output.sum([0, 2, 3])
136 |
137 | return grad_input, grad_weight, grad_bias
138 |
139 | # Gradient with respect to the weights.
140 | class Conv2dGradWeight(torch.autograd.Function):
141 | @staticmethod
142 | def forward(ctx, grad_output, input):
143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146 | assert grad_weight.shape == weight_shape
147 | ctx.save_for_backward(grad_output, input)
148 | return grad_weight
149 |
150 | @staticmethod
151 | def backward(ctx, grad2_grad_weight):
152 | grad_output, input = ctx.saved_tensors
153 | grad2_grad_output = None
154 | grad2_input = None
155 |
156 | if ctx.needs_input_grad[0]:
157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158 | assert grad2_grad_output.shape == grad_output.shape
159 |
160 | if ctx.needs_input_grad[1]:
161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163 | assert grad2_input.shape == input.shape
164 |
165 | return grad2_grad_output, grad2_input
166 |
167 | _conv2d_gradfix_cache[key] = Conv2d
168 | return Conv2d
169 |
170 | #----------------------------------------------------------------------------
171 |
--------------------------------------------------------------------------------
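The point of the wrapper above is that gradients of gradients (e.g. R1-style penalties in GAN training) flow through conv2d cleanly. A sketch with toy shapes; with `enabled = False`, the calls below simply dispatch to `torch.nn.functional.conv2d`:

```python
import torch
from torch_utils.ops import conv2d_gradfix

x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)

# Double backward, as needed for gradient-penalty terms: differentiate a
# function of the input gradient with respect to the parameters.
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)
g.square().sum().backward()

# no_weight_gradients() skips the weight-gradient computation, but only on
# the custom-op code path; with the plain F.conv2d fallback it is a no-op.
with conv2d_gradfix.no_weight_gradients():
    conv2d_gradfix.conv2d(x.detach().requires_grad_(True), w, padding=1).sum().backward()
```

--------------------------------------------------------------------------------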
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | w = w.flip([2, 3])
37 |
38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42 | if out_channels <= 4 and groups == 1:
43 | in_shape = x.shape
44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46 | else:
47 | x = x.to(memory_format=torch.contiguous_format)
48 | w = w.to(memory_format=torch.contiguous_format)
49 | x = conv2d_gradfix.conv2d(x, w, groups=groups)
50 | return x.to(memory_format=torch.channels_last)
51 |
52 | # Otherwise => execute using conv2d_gradfix.
53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54 | return op(x, w, stride=stride, padding=padding, groups=groups)
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | @misc.profiled_function
59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60 | r"""2D convolution with optional up/downsampling.
61 |
62 | Padding is performed only once at the beginning, not between the operations.
63 |
64 | Args:
65 | x: Input tensor of shape
66 | `[batch_size, in_channels, in_height, in_width]`.
67 | w: Weight tensor of shape
68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70 | calling upfirdn2d.setup_filter(). None = identity (default).
71 | up: Integer upsampling factor (default: 1).
72 | down: Integer downsampling factor (default: 1).
73 | padding: Padding with respect to the upsampled image. Can be a single number
74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75 | (default: 0).
76 | groups: Split input channels into N groups (default: 1).
77 | flip_weight: False = convolution, True = correlation (default: True).
78 | flip_filter: False = convolution, True = correlation (default: False).
79 |
80 | Returns:
81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82 | """
83 | # Validate arguments.
84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87 | assert isinstance(up, int) and (up >= 1)
88 | assert isinstance(down, int) and (down >= 1)
89 | assert isinstance(groups, int) and (groups >= 1)
90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91 | fw, fh = _get_filter_size(f)
92 | px0, px1, py0, py1 = _parse_padding(padding)
93 |
94 | # Adjust padding to account for up/downsampling.
95 | if up > 1:
96 | px0 += (fw + up - 1) // 2
97 | px1 += (fw - up) // 2
98 | py0 += (fh + up - 1) // 2
99 | py1 += (fh - up) // 2
100 | if down > 1:
101 | px0 += (fw - down + 1) // 2
102 | px1 += (fw - down) // 2
103 | py0 += (fh - down + 1) // 2
104 | py1 += (fh - down) // 2
105 |
106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110 | return x
111 |
112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116 | return x
117 |
118 | # Fast path: downsampling only => use strided convolution.
119 | if down > 1 and up == 1:
120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122 | return x
123 |
124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125 | if up > 1:
126 | if groups == 1:
127 | w = w.transpose(0, 1)
128 | else:
129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130 | w = w.transpose(1, 2)
131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132 | px0 -= kw - 1
133 | px1 -= kw - up
134 | py0 -= kh - 1
135 | py1 -= kh - up
136 | pxt = max(min(-px0, -px1), 0)
137 | pyt = max(min(-py0, -py1), 0)
138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140 | if down > 1:
141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142 | return x
143 |
144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145 | if up == 1 and down == 1:
146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148 |
149 | # Fallback: Generic reference implementation.
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152 | if down > 1:
153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154 | return x
155 |
156 | #----------------------------------------------------------------------------
157 |
--------------------------------------------------------------------------------
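Putting the pieces together, a sketch of the common StyleGAN2-style use case, 2x upsampling fused with a 3x3 convolution (shapes are illustrative):

```python
import torch
from torch_utils.ops import upfirdn2d, conv2d_resample

x = torch.randn(1, 3, 16, 16)
w = torch.randn(8, 3, 3, 3)                 # [out, in, kh, kw], same dtype as x
f = upfirdn2d.setup_filter([1, 3, 3, 1])    # StyleGAN2's default low-pass filter

# Upsample 2x, low-pass filter, then convolve; padding is w.r.t. the upsampled image.
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
assert y.shape == (1, 8, 32, 32)
```

--------------------------------------------------------------------------------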
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
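A short sketch of how the broadcast-aware backward behaves (illustrative shapes):

```python
import torch
from torch_utils.ops import fma

a = torch.randn(4, 1, 8, requires_grad=True)
b = torch.randn(1, 6, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)

out = fma.fma(a, b, c)          # same value as a * b + c
out.sum().backward()

# _unbroadcast() sums the incoming gradient over the broadcast dimensions,
# so each gradient comes back in the shape of its own input.
assert a.grad.shape == a.shape and b.grad.shape == b.shape and c.grad.shape == c.shape
```

--------------------------------------------------------------------------------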
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import warnings
15 | import torch
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | def grid_sample(input, grid):
28 | if _should_use_custom_op():
29 | return _GridSample2dForward.apply(input, grid)
30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def _should_use_custom_op():
35 | if not enabled:
36 | return False
37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
38 | return True
39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40 | return False
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | class _GridSample2dForward(torch.autograd.Function):
45 | @staticmethod
46 | def forward(ctx, input, grid):
47 | assert input.ndim == 4
48 | assert grid.ndim == 4
49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50 | ctx.save_for_backward(input, grid)
51 | return output
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | input, grid = ctx.saved_tensors
56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57 | return grad_input, grad_grid
58 |
59 | #----------------------------------------------------------------------------
60 |
61 | class _GridSample2dBackward(torch.autograd.Function):
62 | @staticmethod
63 | def forward(ctx, grad_output, input, grid):
64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 | ctx.save_for_backward(grid)
67 | return grad_input, grad_grid
68 |
69 | @staticmethod
70 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 | _ = grad2_grad_grid # unused
72 | grid, = ctx.saved_tensors
73 | grad2_grad_output = None
74 | grad2_input = None
75 | grad2_grid = None
76 |
77 | if ctx.needs_input_grad[0]:
78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 |
80 | assert not ctx.needs_input_grad[2]
81 | return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
--------------------------------------------------------------------------------
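A minimal usage sketch (illustrative shapes; the identity grid makes the result easy to verify):

```python
import torch
from torch_utils.ops import grid_sample_gradfix

x = torch.randn(1, 3, 8, 8, requires_grad=True)

# Identity sampling grid in normalized [-1, 1] coordinates, shape [N, H, W, 2].
theta = torch.eye(2, 3).unsqueeze(0)
grid = torch.nn.functional.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)

# With enabled = False this is plain F.grid_sample; setting enabled = True on a
# supported PyTorch build routes through the double-backprop-friendly op.
y = grid_sample_gradfix.grid_sample(x, grid)
assert torch.allclose(y, x, atol=1e-5)
```

--------------------------------------------------------------------------------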
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29 |
30 | // Create output tensor.
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37 |
38 | // Initialize CUDA kernel parameters.
39 | upfirdn2d_kernel_params p;
40 | p.x = x.data_ptr();
41 | p.f = f.data_ptr();
42 | p.y = y.data_ptr();
43 | p.up = make_int2(upx, upy);
44 | p.down = make_int2(downx, downy);
45 | p.pad0 = make_int2(padx0, pady0);
46 | p.flip = (flip) ? 1 : 0;
47 | p.gain = gain;
48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56 |
57 | // Choose CUDA kernel.
58 | upfirdn2d_kernel_spec spec;
59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60 | {
61 | spec = choose_upfirdn2d_kernel<scalar_t>(p);
62 | });
63 |
64 | // Set looping options.
65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66 | p.loopMinor = spec.loopMinor;
67 | p.loopX = spec.loopX;
68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70 |
71 | // Compute grid size.
72 | dim3 blockSize, gridSize;
73 | if (spec.tileOutW < 0) // large
74 | {
75 | blockSize = dim3(4, 32, 1);
76 | gridSize = dim3(
77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79 | p.launchMajor);
80 | }
81 | else // small
82 | {
83 | blockSize = dim3(256, 1, 1);
84 | gridSize = dim3(
85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87 | p.launchMajor);
88 | }
89 |
90 | // Launch CUDA kernel.
91 | void* args[] = {&p};
92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 | return y;
94 | }
95 |
96 | //------------------------------------------------------------------------
97 |
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 | m.def("upfirdn2d", &upfirdn2d);
101 | }
102 |
103 | //------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
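The output-size rule above (upsample by up, pad, filter with an f-tap FIR kernel, keep every down-th sample) reduces to one integer formula. An illustrative check:

```python
# Mirrors: outW = (inW * upx + padx0 + padx1 - fw + downx) / downx with C integer division.
def out_size(in_size, up, down, pad0, pad1, taps):
    return (in_size * up + pad0 + pad1 - taps + down) // down

assert out_size(16, up=1, down=1, pad0=0, pad1=0, taps=1) == 16  # identity
assert out_size(16, up=2, down=1, pad0=1, pad1=1, taps=4) == 31  # 2x upsample, 4-tap filter
assert out_size(16, up=1, down=2, pad0=1, pad1=1, taps=4) == 8   # 2x downsample
```

--------------------------------------------------------------------------------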
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient resampling of 2D images."""
10 |
11 | import os
12 | import warnings
13 | import numpy as np
14 | import torch
15 | import traceback
16 |
17 | from .. import custom_ops
18 | from .. import misc
19 | from . import conv2d_gradfix
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | _inited = False
24 | _plugin = None
25 |
26 | # def _init():
27 | # global _inited, _plugin
28 | # if not _inited:
29 | # sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
30 | # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
31 | # try:
32 | # _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
33 | # except:
34 | # warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
35 | # return _plugin is not None
36 |
37 | def _init(): # plugin build disabled in this repo; upfirdn2d() always falls back to the reference implementation
38 |     return False
39 |
40 | def _parse_scaling(scaling):
41 | if isinstance(scaling, int):
42 | scaling = [scaling, scaling]
43 | assert isinstance(scaling, (list, tuple))
44 | assert all(isinstance(x, int) for x in scaling)
45 | sx, sy = scaling
46 | assert sx >= 1 and sy >= 1
47 | return sx, sy
48 |
49 | def _parse_padding(padding):
50 | if isinstance(padding, int):
51 | padding = [padding, padding]
52 | assert isinstance(padding, (list, tuple))
53 | assert all(isinstance(x, int) for x in padding)
54 | if len(padding) == 2:
55 | padx, pady = padding
56 | padding = [padx, padx, pady, pady]
57 | padx0, padx1, pady0, pady1 = padding
58 | return padx0, padx1, pady0, pady1
59 |
60 | def _get_filter_size(f):
61 | if f is None:
62 | return 1, 1
63 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
64 | fw = f.shape[-1]
65 | fh = f.shape[0]
66 | with misc.suppress_tracer_warnings():
67 | fw = int(fw)
68 | fh = int(fh)
69 | misc.assert_shape(f, [fh, fw][:f.ndim])
70 | assert fw >= 1 and fh >= 1
71 | return fw, fh
72 |
73 | #----------------------------------------------------------------------------
74 |
75 | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
76 | r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
77 |
78 | Args:
79 | f: Torch tensor, numpy array, or python list of the shape
80 | `[filter_height, filter_width]` (non-separable),
81 | `[filter_taps]` (separable),
82 | `[]` (impulse), or
83 | `None` (identity).
84 | device: Result device (default: cpu).
85 | normalize: Normalize the filter so that it retains the magnitude
86 | for constant input signal (DC)? (default: True).
87 | flip_filter: Flip the filter? (default: False).
88 | gain: Overall scaling factor for signal magnitude (default: 1).
89 | separable: Return a separable filter? (default: select automatically).
90 |
91 | Returns:
92 | Float32 tensor of the shape
93 | `[filter_height, filter_width]` (non-separable) or
94 | `[filter_taps]` (separable).
95 | """
96 | # Validate.
97 | if f is None:
98 | f = 1
99 | f = torch.as_tensor(f, dtype=torch.float32)
100 | assert f.ndim in [0, 1, 2]
101 | assert f.numel() > 0
102 | if f.ndim == 0:
103 | f = f[np.newaxis]
104 |
105 | # Separable?
106 | if separable is None:
107 | separable = (f.ndim == 1 and f.numel() >= 8)
108 | if f.ndim == 1 and not separable:
109 | f = f.ger(f)
110 | assert f.ndim == (1 if separable else 2)
111 |
112 | # Apply normalize, flip, gain, and device.
113 | if normalize:
114 | f /= f.sum()
115 | if flip_filter:
116 | f = f.flip(list(range(f.ndim)))
117 | f = f * (gain ** (f.ndim / 2))
118 | f = f.to(device=device)
119 | return f
120 |
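# Example (illustrative): the binomial filter used throughout EDM/StyleGAN-style
# resampling. With only 4 taps, the default heuristic returns a non-separable
# 4x4 outer-product filter:
#
#     f  = setup_filter([1, 3, 3, 1])                   # 4x4 filter, sums to 1
#     f1 = setup_filter([1, 3, 3, 1], separable=True)   # tensor([0.125, 0.375, 0.375, 0.125])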
121 | #----------------------------------------------------------------------------
122 |
123 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
124 | r"""Pad, upsample, filter, and downsample a batch of 2D images.
125 |
126 | Performs the following sequence of operations for each channel:
127 |
128 | 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
129 |
130 | 2. Pad the image with the specified number of zeros on each side (`padding`).
131 | Negative padding corresponds to cropping the image.
132 |
133 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
134 | so that the footprint of all output pixels lies within the input image.
135 |
136 | 4. Downsample the image by keeping every Nth pixel (`down`).
137 |
138 | This sequence of operations bears close resemblance to scipy.signal.upfirdn().
139 | The fused op is considerably more efficient than performing the same calculation
140 | using standard PyTorch ops. It supports gradients of arbitrary order.
141 |
142 | Args:
143 | x: Float32/float64/float16 input tensor of the shape
144 | `[batch_size, num_channels, in_height, in_width]`.
145 | f: Float32 FIR filter of the shape
146 | `[filter_height, filter_width]` (non-separable),
147 | `[filter_taps]` (separable), or
148 | `None` (identity).
149 | up: Integer upsampling factor. Can be a single int or a list/tuple
150 | `[x, y]` (default: 1).
151 | down: Integer downsampling factor. Can be a single int or a list/tuple
152 | `[x, y]` (default: 1).
153 | padding: Padding with respect to the upsampled image. Can be a single number
154 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
155 | (default: 0).
156 | flip_filter: False = convolution, True = correlation (default: False).
157 | gain: Overall scaling factor for signal magnitude (default: 1).
158 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
159 |
160 | Returns:
161 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
162 | """
163 | assert isinstance(x, torch.Tensor)
164 | assert impl in ['ref', 'cuda']
165 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
166 | return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
167 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
168 |
169 | #----------------------------------------------------------------------------
170 |
171 | @misc.profiled_function
172 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
173 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
174 | """
175 | # Validate arguments.
176 | assert isinstance(x, torch.Tensor) and x.ndim == 4
177 | if f is None:
178 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
179 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
180 | assert f.dtype == torch.float32 and not f.requires_grad
181 | batch_size, num_channels, in_height, in_width = x.shape
182 | upx, upy = _parse_scaling(up)
183 | downx, downy = _parse_scaling(down)
184 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
185 |
186 | # Upsample by inserting zeros.
187 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
188 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
189 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
190 |
191 | # Pad or crop.
192 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
193 | x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
194 |
195 | # Setup filter.
196 | f = f * (gain ** (f.ndim / 2))
197 | f = f.to(x.dtype)
198 | if not flip_filter:
199 | f = f.flip(list(range(f.ndim)))
200 |
201 | # Convolve with the filter.
202 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
203 | if f.ndim == 4:
204 | x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
205 | else:
206 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
207 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
208 |
209 | # Downsample by throwing away pixels.
210 | x = x[:, :, ::downy, ::downx]
211 | return x
212 |
213 | #----------------------------------------------------------------------------
214 |
215 | _upfirdn2d_cuda_cache = dict()
216 |
217 | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
218 | """Fast CUDA implementation of `upfirdn2d()` using custom ops.
219 | """
220 | # Parse arguments.
221 | upx, upy = _parse_scaling(up)
222 | downx, downy = _parse_scaling(down)
223 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
224 |
225 | # Lookup from cache.
226 | key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
227 | if key in _upfirdn2d_cuda_cache:
228 | return _upfirdn2d_cuda_cache[key]
229 |
230 | # Forward op.
231 | class Upfirdn2dCuda(torch.autograd.Function):
232 | @staticmethod
233 | def forward(ctx, x, f): # pylint: disable=arguments-differ
234 | assert isinstance(x, torch.Tensor) and x.ndim == 4
235 | if f is None:
236 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
237 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
238 | y = x
239 | if f.ndim == 2:
240 | y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
241 | else:
242 | y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
243 | y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
244 | ctx.save_for_backward(f)
245 | ctx.x_shape = x.shape
246 | return y
247 |
248 | @staticmethod
249 | def backward(ctx, dy): # pylint: disable=arguments-differ
250 | f, = ctx.saved_tensors
251 | _, _, ih, iw = ctx.x_shape
252 | _, _, oh, ow = dy.shape
253 | fw, fh = _get_filter_size(f)
254 | p = [
255 | fw - padx0 - 1,
256 | iw * upx - ow * downx + padx0 - upx + 1,
257 | fh - pady0 - 1,
258 | ih * upy - oh * downy + pady0 - upy + 1,
259 | ]
260 | dx = None
261 | df = None
262 |
263 | if ctx.needs_input_grad[0]:
264 | dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
265 |
266 | assert not ctx.needs_input_grad[1]
267 | return dx, df
268 |
269 | # Add to cache.
270 | _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
271 | return Upfirdn2dCuda
272 |
273 | #----------------------------------------------------------------------------
274 |
275 | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
276 | r"""Filter a batch of 2D images using the given 2D FIR filter.
277 |
278 | By default, the result is padded so that its shape matches the input.
279 | User-specified padding is applied on top of that, with negative values
280 | indicating cropping. Pixels outside the image are assumed to be zero.
281 |
282 | Args:
283 | x: Float32/float64/float16 input tensor of the shape
284 | `[batch_size, num_channels, in_height, in_width]`.
285 | f: Float32 FIR filter of the shape
286 | `[filter_height, filter_width]` (non-separable),
287 | `[filter_taps]` (separable), or
288 | `None` (identity).
289 | padding: Padding with respect to the output. Can be a single number or a
290 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
291 | (default: 0).
292 | flip_filter: False = convolution, True = correlation (default: False).
293 | gain: Overall scaling factor for signal magnitude (default: 1).
294 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
295 |
296 | Returns:
297 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
298 | """
299 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
300 | fw, fh = _get_filter_size(f)
301 | p = [
302 | padx0 + fw // 2,
303 | padx1 + (fw - 1) // 2,
304 | pady0 + fh // 2,
305 | pady1 + (fh - 1) // 2,
306 | ]
307 | return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
308 |
309 | #----------------------------------------------------------------------------
310 |
311 | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
312 | r"""Upsample a batch of 2D images using the given 2D FIR filter.
313 |
314 | By default, the result is padded so that its shape is a multiple of the input.
315 | User-specified padding is applied on top of that, with negative values
316 | indicating cropping. Pixels outside the image are assumed to be zero.
317 |
318 | Args:
319 | x: Float32/float64/float16 input tensor of the shape
320 | `[batch_size, num_channels, in_height, in_width]`.
321 | f: Float32 FIR filter of the shape
322 | `[filter_height, filter_width]` (non-separable),
323 | `[filter_taps]` (separable), or
324 | `None` (identity).
325 | up: Integer upsampling factor. Can be a single int or a list/tuple
326 |                      `[x, y]` (default: 2).
327 | padding: Padding with respect to the output. Can be a single number or a
328 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
329 | (default: 0).
330 | flip_filter: False = convolution, True = correlation (default: False).
331 | gain: Overall scaling factor for signal magnitude (default: 1).
332 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
333 |
334 | Returns:
335 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
336 | """
337 | upx, upy = _parse_scaling(up)
338 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
339 | fw, fh = _get_filter_size(f)
340 | p = [
341 | padx0 + (fw + upx - 1) // 2,
342 | padx1 + (fw - upx) // 2,
343 | pady0 + (fh + upy - 1) // 2,
344 | pady1 + (fh - upy) // 2,
345 | ]
346 | return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
347 |
348 | #----------------------------------------------------------------------------
349 |
350 | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
351 | r"""Downsample a batch of 2D images using the given 2D FIR filter.
352 |
353 | By default, the result is padded so that its shape is a fraction of the input.
354 | User-specified padding is applied on top of that, with negative values
355 | indicating cropping. Pixels outside the image are assumed to be zero.
356 |
357 | Args:
358 | x: Float32/float64/float16 input tensor of the shape
359 | `[batch_size, num_channels, in_height, in_width]`.
360 | f: Float32 FIR filter of the shape
361 | `[filter_height, filter_width]` (non-separable),
362 | `[filter_taps]` (separable), or
363 | `None` (identity).
364 | down: Integer downsampling factor. Can be a single int or a list/tuple
365 |                      `[x, y]` (default: 2).
366 | padding: Padding with respect to the input. Can be a single number or a
367 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
368 | (default: 0).
369 | flip_filter: False = convolution, True = correlation (default: False).
370 | gain: Overall scaling factor for signal magnitude (default: 1).
371 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
372 |
373 | Returns:
374 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
375 | """
376 | downx, downy = _parse_scaling(down)
377 | padx0, padx1, pady0, pady1 = _parse_padding(padding)
378 | fw, fh = _get_filter_size(f)
379 | p = [
380 | padx0 + (fw - downx + 1) // 2,
381 | padx1 + (fw - downx) // 2,
382 | pady0 + (fh - downy + 1) // 2,
383 | pady1 + (fh - downy) // 2,
384 | ]
385 | return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
386 |
387 | #----------------------------------------------------------------------------
388 |
--------------------------------------------------------------------------------
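A short usage sketch for the wrappers above. Because `_init()` returns False in this repo, every call takes the `_upfirdn2d_ref()` path, so the example also runs on CPU:

    # Illustrative usage of the resampling ops.
    import torch
    from torch_utils.ops import upfirdn2d

    x = torch.randn(4, 3, 32, 32)                  # [batch, channels, H, W]
    f = upfirdn2d.setup_filter([1, 3, 3, 1])       # binomial FIR filter

    y = upfirdn2d.upsample2d(x, f)                 # -> [4, 3, 64, 64]
    z = upfirdn2d.downsample2d(y, f)               # -> [4, 3, 32, 32]
    w = upfirdn2d.filter2d(x, f)                   # -> [4, 3, 32, 32] (shape-preserving)

--------------------------------------------------------------------------------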
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for pickling Python code alongside other data.
9 |
10 | The pickled code is automatically imported into a separate Python module
11 | during unpickling. This way, any previously exported pickles will remain
12 | usable even if the original code is no longer available, or if the current
13 | version of the code is not consistent with what was originally pickled."""
14 |
15 | import sys
16 | import pickle
17 | import io
18 | import inspect
19 | import copy
20 | import uuid
21 | import types
22 | import dnnlib
23 |
24 | #----------------------------------------------------------------------------
25 |
26 | _version = 6 # internal version number
27 | _decorators = set() # {decorator_class, ...}
28 | _import_hooks = [] # [hook_function, ...]
29 | _module_to_src_dict = dict() # {module: src, ...}
30 | _src_to_module_dict = dict() # {src: module, ...}
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def persistent_class(orig_class):
35 | r"""Class decorator that extends a given class to save its source code
36 | when pickled.
37 |
38 | Example:
39 |
40 | from torch_utils import persistence
41 |
42 | @persistence.persistent_class
43 | class MyNetwork(torch.nn.Module):
44 | def __init__(self, num_inputs, num_outputs):
45 | super().__init__()
46 | self.fc = MyLayer(num_inputs, num_outputs)
47 | ...
48 |
49 | @persistence.persistent_class
50 | class MyLayer(torch.nn.Module):
51 | ...
52 |
53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
54 | source code alongside other internal state (e.g., parameters, buffers,
55 | and submodules). This way, any previously exported pickle will remain
56 | usable even if the class definitions have been modified or are no
57 | longer available.
58 |
59 | The decorator saves the source code of the entire Python module
60 | containing the decorated class. It does *not* save the source code of
61 | any imported modules. Thus, the imported modules must be available
62 | during unpickling, also including `torch_utils.persistence` itself.
63 |
64 | It is ok to call functions defined in the same module from the
65 | decorated class. However, if the decorated class depends on other
66 | classes defined in the same module, they must be decorated as well.
67 | This is illustrated in the above example in the case of `MyLayer`.
68 |
69 | It is also possible to employ the decorator just-in-time before
70 | calling the constructor. For example:
71 |
72 | cls = MyLayer
73 | if want_to_make_it_persistent:
74 | cls = persistence.persistent_class(cls)
75 | layer = cls(num_inputs, num_outputs)
76 |
77 | As an additional feature, the decorator also keeps track of the
78 | arguments that were used to construct each instance of the decorated
79 | class. The arguments can be queried via `obj.init_args` and
80 | `obj.init_kwargs`, and they are automatically pickled alongside other
81 | object state. This feature can be disabled on a per-instance basis
82 | by setting `self._record_init_args = False` in the constructor.
83 |
84 | A typical use case is to first unpickle a previous instance of a
85 | persistent class, and then upgrade it to use the latest version of
86 | the source code:
87 |
88 | with open('old_pickle.pkl', 'rb') as f:
89 | old_net = pickle.load(f)
90 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92 | """
93 | assert isinstance(orig_class, type)
94 | if is_persistent(orig_class):
95 | return orig_class
96 |
97 | assert orig_class.__module__ in sys.modules
98 | orig_module = sys.modules[orig_class.__module__]
99 | orig_module_src = _module_to_src(orig_module)
100 |
101 | class Decorator(orig_class):
102 | _orig_module_src = orig_module_src
103 | _orig_class_name = orig_class.__name__
104 |
105 | def __init__(self, *args, **kwargs):
106 | super().__init__(*args, **kwargs)
107 | record_init_args = getattr(self, '_record_init_args', True)
108 | self._init_args = copy.deepcopy(args) if record_init_args else None
109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
110 | assert orig_class.__name__ in orig_module.__dict__
111 | _check_pickleable(self.__reduce__())
112 |
113 | @property
114 | def init_args(self):
115 | assert self._init_args is not None
116 | return copy.deepcopy(self._init_args)
117 |
118 | @property
119 | def init_kwargs(self):
120 | assert self._init_kwargs is not None
121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
122 |
123 | def __reduce__(self):
124 | fields = list(super().__reduce__())
125 | fields += [None] * max(3 - len(fields), 0)
126 | if fields[0] is not _reconstruct_persistent_obj:
127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
128 | fields[0] = _reconstruct_persistent_obj # reconstruct func
129 | fields[1] = (meta,) # reconstruct args
130 | fields[2] = None # state dict
131 | return tuple(fields)
132 |
133 | Decorator.__name__ = orig_class.__name__
134 | Decorator.__module__ = orig_class.__module__
135 | _decorators.add(Decorator)
136 | return Decorator
137 |
138 | #----------------------------------------------------------------------------
139 |
140 | def is_persistent(obj):
141 | r"""Test whether the given object or class is persistent, i.e.,
142 | whether it will save its source code when pickled.
143 | """
144 | try:
145 | if obj in _decorators:
146 | return True
147 | except TypeError:
148 | pass
149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150 |
151 | #----------------------------------------------------------------------------
152 |
153 | def import_hook(hook):
154 | r"""Register an import hook that is called whenever a persistent object
155 | is being unpickled. A typical use case is to patch the pickled source
156 | code to avoid errors and inconsistencies when the API of some imported
157 | module has changed.
158 |
159 | The hook should have the following signature:
160 |
161 | hook(meta) -> modified meta
162 |
163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
164 |
165 | type: Type of the persistent object, e.g. `'class'`.
166 | version: Internal version number of `torch_utils.persistence`.
167 |         module_src:  Original source code of the Python module.
168 | class_name: Class name in the original Python module.
169 | state: Internal state of the object.
170 |
171 | Example:
172 |
173 | @persistence.import_hook
174 | def wreck_my_network(meta):
175 | if meta.class_name == 'MyNetwork':
176 | print('MyNetwork is being imported. I will wreck it!')
177 | meta.module_src = meta.module_src.replace("True", "False")
178 | return meta
179 | """
180 | assert callable(hook)
181 | _import_hooks.append(hook)
182 |
183 | #----------------------------------------------------------------------------
184 |
185 | def _reconstruct_persistent_obj(meta):
186 | r"""Hook that is called internally by the `pickle` module to unpickle
187 | a persistent object.
188 | """
189 | meta = dnnlib.EasyDict(meta)
190 | meta.state = dnnlib.EasyDict(meta.state)
191 | for hook in _import_hooks:
192 | meta = hook(meta)
193 | assert meta is not None
194 |
195 | assert meta.version == _version
196 | module = _src_to_module(meta.module_src)
197 |
198 | assert meta.type == 'class'
199 | orig_class = module.__dict__[meta.class_name]
200 | decorator_class = persistent_class(orig_class)
201 | obj = decorator_class.__new__(decorator_class)
202 |
203 | setstate = getattr(obj, '__setstate__', None)
204 | if callable(setstate):
205 | setstate(meta.state) # pylint: disable=not-callable
206 | else:
207 | obj.__dict__.update(meta.state)
208 | return obj
209 |
210 | #----------------------------------------------------------------------------
211 |
212 | def _module_to_src(module):
213 | r"""Query the source code of a given Python module.
214 | """
215 | src = _module_to_src_dict.get(module, None)
216 | if src is None:
217 | src = inspect.getsource(module)
218 | _module_to_src_dict[module] = src
219 | _src_to_module_dict[src] = module
220 | return src
221 |
222 | def _src_to_module(src):
223 | r"""Get or create a Python module for the given source code.
224 | """
225 | module = _src_to_module_dict.get(src, None)
226 | if module is None:
227 | module_name = "_imported_module_" + uuid.uuid4().hex
228 | module = types.ModuleType(module_name)
229 | sys.modules[module_name] = module
230 | _module_to_src_dict[module] = src
231 | _src_to_module_dict[src] = module
232 | exec(src, module.__dict__) # pylint: disable=exec-used
233 | return module
234 |
235 | #----------------------------------------------------------------------------
236 |
237 | def _check_pickleable(obj):
238 | r"""Check that the given object is pickleable, raising an exception if
239 | it is not. This function is expected to be considerably more efficient
240 | than actually pickling the object.
241 | """
242 | def recurse(obj):
243 | if isinstance(obj, (list, tuple, set)):
244 | return [recurse(x) for x in obj]
245 | if isinstance(obj, dict):
246 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
248 | return None # Python primitive types are pickleable.
249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
250 | return None # NumPy arrays and PyTorch tensors are pickleable.
251 | if is_persistent(obj):
252 | return None # Persistent objects are pickleable, by virtue of the constructor check.
253 | return obj
254 | with io.BytesIO() as f:
255 | pickle.dump(recurse(obj), f)
256 |
257 | #----------------------------------------------------------------------------
258 |
--------------------------------------------------------------------------------
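A minimal usage sketch for `persistent_class` (illustrative; assumes the snippet lives in an importable file so that `inspect.getsource()` can capture the module source):

    import pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class TinyNet(torch.nn.Module):      # made-up example class
        def __init__(self, dim):
            super().__init__()
            self.fc = torch.nn.Linear(dim, dim)

    net = TinyNet(8)
    blob = pickle.dumps(net)             # module source travels inside the pickle
    clone = pickle.loads(blob)           # rebuilt from the embedded source
    assert clone.init_args == (8,)       # constructor arguments are recorded

--------------------------------------------------------------------------------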
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for reporting and collecting training statistics across
9 | multiple processes and devices. The interface is designed to minimize
10 | synchronization overhead as well as the amount of boilerplate in user
11 | code."""
12 |
13 | import re
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 |
18 | from . import misc
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
24 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
25 | _rank = 0 # Rank of the current process.
26 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
27 | _sync_called = False # Has _sync() been called yet?
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
30 |
31 | #----------------------------------------------------------------------------
32 |
33 | def init_multiprocessing(rank, sync_device):
34 | r"""Initializes `torch_utils.training_stats` for collecting statistics
35 | across multiple processes.
36 |
37 | This function must be called after
38 | `torch.distributed.init_process_group()` and before `Collector.update()`.
39 | The call is not necessary if multi-process collection is not needed.
40 |
41 | Args:
42 | rank: Rank of the current process.
43 | sync_device: PyTorch device to use for inter-process
44 | communication, or None to disable multi-process
45 | collection. Typically `torch.device('cuda', rank)`.
46 | """
47 | global _rank, _sync_device
48 | assert not _sync_called
49 | _rank = rank
50 | _sync_device = sync_device
51 |
52 | #----------------------------------------------------------------------------
53 |
54 | @misc.profiled_function
55 | def report(name, value):
56 | r"""Broadcasts the given set of scalars to all interested instances of
57 | `Collector`, across device and process boundaries.
58 |
59 | This function is expected to be extremely cheap and can be safely
60 | called from anywhere in the training loop, loss function, or inside a
61 | `torch.nn.Module`.
62 |
63 | Warning: The current implementation expects the set of unique names to
64 | be consistent across processes. Please make sure that `report()` is
65 | called at least once for each unique name by each process, and in the
66 | same order. If a given process has no scalars to broadcast, it can do
67 | `report(name, [])` (empty list).
68 |
69 | Args:
70 | name: Arbitrary string specifying the name of the statistic.
71 | Averages are accumulated separately for each unique name.
72 | value: Arbitrary set of scalars. Can be a list, tuple,
73 | NumPy array, PyTorch tensor, or Python scalar.
74 |
75 | Returns:
76 | The same `value` that was passed in.
77 | """
78 | if name not in _counters:
79 | _counters[name] = dict()
80 |
81 | elems = torch.as_tensor(value)
82 | if elems.numel() == 0:
83 | return value
84 |
85 | elems = elems.detach().flatten().to(_reduce_dtype)
86 | moments = torch.stack([
87 | torch.ones_like(elems).sum(),
88 | elems.sum(),
89 | elems.square().sum(),
90 | ])
91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
92 | moments = moments.to(_counter_dtype)
93 |
94 | device = moments.device
95 | if device not in _counters[name]:
96 | _counters[name][device] = torch.zeros_like(moments)
97 | _counters[name][device].add_(moments)
98 | return value
99 |
100 | #----------------------------------------------------------------------------
101 |
102 | def report0(name, value):
103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
104 | but ignores any scalars provided by the other processes.
105 | See `report()` for further details.
106 | """
107 | report(name, value if _rank == 0 else [])
108 | return value
109 |
110 | #----------------------------------------------------------------------------
111 |
112 | class Collector:
113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
114 | computes their long-term averages (mean and standard deviation) over
115 | user-defined periods of time.
116 |
117 | The averages are first collected into internal counters that are not
118 | directly visible to the user. They are then copied to the user-visible
119 | state as a result of calling `update()` and can then be queried using
120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
121 | internal counters for the next round, so that the user-visible state
122 | effectively reflects averages collected between the last two calls to
123 | `update()`.
124 |
125 | Args:
126 | regex: Regular expression defining which statistics to
127 | collect. The default is to collect everything.
128 | keep_previous: Whether to retain the previous averages if no
129 | scalars were collected on a given round
130 | (default: True).
131 | """
132 | def __init__(self, regex='.*', keep_previous=True):
133 | self._regex = re.compile(regex)
134 | self._keep_previous = keep_previous
135 | self._cumulative = dict()
136 | self._moments = dict()
137 | self.update()
138 | self._moments.clear()
139 |
140 | def names(self):
141 | r"""Returns the names of all statistics broadcasted so far that
142 | match the regular expression specified at construction time.
143 | """
144 | return [name for name in _counters if self._regex.fullmatch(name)]
145 |
146 | def update(self):
147 | r"""Copies current values of the internal counters to the
148 | user-visible state and resets them for the next round.
149 |
150 | If `keep_previous=True` was specified at construction time, the
151 | operation is skipped for statistics that have received no scalars
152 | since the last update, retaining their previous averages.
153 |
154 | This method performs a number of GPU-to-CPU transfers and one
155 | `torch.distributed.all_reduce()`. It is intended to be called
156 | periodically in the main training loop, typically once every
157 | N training steps.
158 | """
159 | if not self._keep_previous:
160 | self._moments.clear()
161 | for name, cumulative in _sync(self.names()):
162 | if name not in self._cumulative:
163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
164 | delta = cumulative - self._cumulative[name]
165 | self._cumulative[name].copy_(cumulative)
166 | if float(delta[0]) != 0:
167 | self._moments[name] = delta
168 |
169 | def _get_delta(self, name):
170 | r"""Returns the raw moments that were accumulated for the given
171 | statistic between the last two calls to `update()`, or zero if
172 | no scalars were collected.
173 | """
174 | assert self._regex.fullmatch(name)
175 | if name not in self._moments:
176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
177 | return self._moments[name]
178 |
179 | def num(self, name):
180 | r"""Returns the number of scalars that were accumulated for the given
181 | statistic between the last two calls to `update()`, or zero if
182 | no scalars were collected.
183 | """
184 | delta = self._get_delta(name)
185 | return int(delta[0])
186 |
187 | def mean(self, name):
188 | r"""Returns the mean of the scalars that were accumulated for the
189 | given statistic between the last two calls to `update()`, or NaN if
190 | no scalars were collected.
191 | """
192 | delta = self._get_delta(name)
193 | if int(delta[0]) == 0:
194 | return float('nan')
195 | return float(delta[1] / delta[0])
196 |
197 | def std(self, name):
198 | r"""Returns the standard deviation of the scalars that were
199 | accumulated for the given statistic between the last two calls to
200 | `update()`, or NaN if no scalars were collected.
201 | """
202 | delta = self._get_delta(name)
203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
204 | return float('nan')
205 | if int(delta[0]) == 1:
206 | return float(0)
207 | mean = float(delta[1] / delta[0])
208 | raw_var = float(delta[2] / delta[0])
209 | return np.sqrt(max(raw_var - np.square(mean), 0))
210 |
211 | def as_dict(self):
212 | r"""Returns the averages accumulated between the last two calls to
213 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
214 |
215 | dnnlib.EasyDict(
216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
217 | ...
218 | )
219 | """
220 | stats = dnnlib.EasyDict()
221 | for name in self.names():
222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
223 | return stats
224 |
225 | def __getitem__(self, name):
226 | r"""Convenience getter.
227 | `collector[name]` is a synonym for `collector.mean(name)`.
228 | """
229 | return self.mean(name)
230 |
231 | #----------------------------------------------------------------------------
232 |
233 | def _sync(names):
234 | r"""Synchronize the global cumulative counters across devices and
235 | processes. Called internally by `Collector.update()`.
236 | """
237 | if len(names) == 0:
238 | return []
239 | global _sync_called
240 | _sync_called = True
241 |
242 | # Collect deltas within current rank.
243 | deltas = []
244 | device = _sync_device if _sync_device is not None else torch.device('cpu')
245 | for name in names:
246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
247 | for counter in _counters[name].values():
248 | delta.add_(counter.to(device))
249 | counter.copy_(torch.zeros_like(counter))
250 | deltas.append(delta)
251 | deltas = torch.stack(deltas)
252 |
253 | # Sum deltas across ranks.
254 | if _sync_device is not None:
255 | torch.distributed.all_reduce(deltas)
256 |
257 | # Update cumulative values.
258 | deltas = deltas.cpu()
259 | for idx, name in enumerate(names):
260 | if name not in _cumulative:
261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
262 | _cumulative[name].add_(deltas[idx])
263 |
264 | # Return name-value pairs.
265 | return [(name, _cumulative[name]) for name in names]
266 |
267 | #----------------------------------------------------------------------------
268 | # Convenience.
269 |
270 | default_collector = Collector()
271 |
272 | #----------------------------------------------------------------------------
273 |
--------------------------------------------------------------------------------
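A minimal single-process sketch of the reporting workflow above (names are illustrative):

    import torch
    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')
    for step in range(1, 101):
        loss = torch.rand([])                    # stand-in for a real training loss
        training_stats.report('Loss/G', loss)    # cheap; callable from anywhere
        if step % 50 == 0:
            collector.update()                   # move counters to user-visible state
            print(step, collector.mean('Loss/G'), collector.std('Loss/G'))

--------------------------------------------------------------------------------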
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Streaming images and labels from datasets created with dataset_tool.py."""
9 |
10 | import os
11 | import numpy as np
12 | import zipfile
13 | import PIL.Image
14 | import json
15 | import torch
16 | import dnnlib
17 |
18 | try:
19 | import pyspng
20 | except ImportError:
21 | pyspng = None
22 |
23 | #----------------------------------------------------------------------------
24 | # Abstract base class for datasets.
25 |
26 | class Dataset(torch.utils.data.Dataset):
27 | def __init__(self,
28 | name, # Name of the dataset.
29 | raw_shape, # Shape of the raw image data (NCHW).
30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
33 | random_seed = 0, # Random seed to use when applying max_size.
34 | cache = False, # Cache images in CPU memory?
35 | ):
36 | self._name = name
37 | self._raw_shape = list(raw_shape)
38 | self._use_labels = use_labels
39 | self._cache = cache
40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...}
41 | self._raw_labels = None
42 | self._label_shape = None
43 |
44 | # Apply max_size.
45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
46 | if (max_size is not None) and (self._raw_idx.size > max_size):
47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
48 | self._raw_idx = np.sort(self._raw_idx[:max_size])
49 |
50 | # Apply xflip.
51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
52 | if xflip:
53 | self._raw_idx = np.tile(self._raw_idx, 2)
54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
55 |
56 | def _get_raw_labels(self):
57 | if self._raw_labels is None:
58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
59 | if self._raw_labels is None:
60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
61 | assert isinstance(self._raw_labels, np.ndarray)
62 | assert self._raw_labels.shape[0] == self._raw_shape[0]
63 | assert self._raw_labels.dtype in [np.float32, np.int64]
64 | if self._raw_labels.dtype == np.int64:
65 | assert self._raw_labels.ndim == 1
66 | assert np.all(self._raw_labels >= 0)
67 | return self._raw_labels
68 |
69 | def close(self): # to be overridden by subclass
70 | pass
71 |
72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
73 | raise NotImplementedError
74 |
75 | def _load_raw_labels(self): # to be overridden by subclass
76 | raise NotImplementedError
77 |
78 | def __getstate__(self):
79 | return dict(self.__dict__, _raw_labels=None)
80 |
81 | def __del__(self):
82 | try:
83 | self.close()
84 | except:
85 | pass
86 |
87 | def __len__(self):
88 | return self._raw_idx.size
89 |
90 | def __getitem__(self, idx):
91 | raw_idx = self._raw_idx[idx]
92 | image = self._cached_images.get(raw_idx, None)
93 | if image is None:
94 | image = self._load_raw_image(raw_idx)
95 | if self._cache:
96 | self._cached_images[raw_idx] = image
97 | assert isinstance(image, np.ndarray)
98 | assert list(image.shape) == self.image_shape
99 | assert image.dtype == np.uint8
100 | if self._xflip[idx]:
101 | assert image.ndim == 3 # CHW
102 | image = image[:, :, ::-1]
103 | return image.copy(), self.get_label(idx)
104 |
105 | def get_label(self, idx):
106 | label = self._get_raw_labels()[self._raw_idx[idx]]
107 | if label.dtype == np.int64:
108 | onehot = np.zeros(self.label_shape, dtype=np.float32)
109 | onehot[label] = 1
110 | label = onehot
111 | return label.copy()
112 |
113 | def get_details(self, idx):
114 | d = dnnlib.EasyDict()
115 | d.raw_idx = int(self._raw_idx[idx])
116 | d.xflip = (int(self._xflip[idx]) != 0)
117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
118 | return d
119 |
120 | @property
121 | def name(self):
122 | return self._name
123 |
124 | @property
125 | def image_shape(self):
126 | return list(self._raw_shape[1:])
127 |
128 | @property
129 | def num_channels(self):
130 | assert len(self.image_shape) == 3 # CHW
131 | return self.image_shape[0]
132 |
133 | @property
134 | def resolution(self):
135 | assert len(self.image_shape) == 3 # CHW
136 | assert self.image_shape[1] == self.image_shape[2]
137 | return self.image_shape[1]
138 |
139 | @property
140 | def label_shape(self):
141 | if self._label_shape is None:
142 | raw_labels = self._get_raw_labels()
143 | if raw_labels.dtype == np.int64:
144 | self._label_shape = [int(np.max(raw_labels)) + 1]
145 | else:
146 | self._label_shape = raw_labels.shape[1:]
147 | return list(self._label_shape)
148 |
149 | @property
150 | def label_dim(self):
151 | assert len(self.label_shape) == 1
152 | return self.label_shape[0]
153 |
154 | @property
155 | def has_labels(self):
156 | return any(x != 0 for x in self.label_shape)
157 |
158 | @property
159 | def has_onehot_labels(self):
160 | return self._get_raw_labels().dtype == np.int64
161 |
162 | #----------------------------------------------------------------------------
163 | # Dataset subclass that loads images recursively from the specified directory
164 | # or ZIP file.
165 |
166 | class ImageFolderDataset(Dataset):
167 | def __init__(self,
168 | path, # Path to directory or zip.
169 | resolution = None, # Ensure specific resolution, None = highest available.
170 | use_pyspng = True, # Use pyspng if available?
171 | **super_kwargs, # Additional arguments for the Dataset base class.
172 | ):
173 | self._path = path
174 | self._use_pyspng = use_pyspng
175 | self._zipfile = None
176 |
177 | if os.path.isdir(self._path):
178 | self._type = 'dir'
179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
180 | elif self._file_ext(self._path) == '.zip':
181 | self._type = 'zip'
182 | self._all_fnames = set(self._get_zipfile().namelist())
183 | else:
184 | raise IOError('Path must point to a directory or zip')
185 |
186 | PIL.Image.init()
187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
188 | if len(self._image_fnames) == 0:
189 | raise IOError('No image files found in the specified path')
190 |
191 | name = os.path.splitext(os.path.basename(self._path))[0]
192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
194 | raise IOError('Image files do not match the specified resolution')
195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
196 |
197 | @staticmethod
198 | def _file_ext(fname):
199 | return os.path.splitext(fname)[1].lower()
200 |
201 | def _get_zipfile(self):
202 | assert self._type == 'zip'
203 | if self._zipfile is None:
204 | self._zipfile = zipfile.ZipFile(self._path)
205 | return self._zipfile
206 |
207 | def _open_file(self, fname):
208 | if self._type == 'dir':
209 | return open(os.path.join(self._path, fname), 'rb')
210 | if self._type == 'zip':
211 | return self._get_zipfile().open(fname, 'r')
212 | return None
213 |
214 | def close(self):
215 | try:
216 | if self._zipfile is not None:
217 | self._zipfile.close()
218 | finally:
219 | self._zipfile = None
220 |
221 | def __getstate__(self):
222 | return dict(super().__getstate__(), _zipfile=None)
223 |
224 | def _load_raw_image(self, raw_idx):
225 | fname = self._image_fnames[raw_idx]
226 | with self._open_file(fname) as f:
227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
228 | image = pyspng.load(f.read())
229 | else:
230 | image = np.array(PIL.Image.open(f))
231 | if image.ndim == 2:
232 | image = image[:, :, np.newaxis] # HW => HWC
233 | image = image.transpose(2, 0, 1) # HWC => CHW
234 | return image
235 |
236 | def _load_raw_labels(self):
237 | fname = 'dataset.json'
238 | if fname not in self._all_fnames:
239 | return None
240 | with self._open_file(fname) as f:
241 | labels = json.load(f)['labels']
242 | if labels is None:
243 | return None
244 | labels = dict(labels)
245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
246 | labels = np.array(labels)
247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
248 | return labels
249 |
250 | #----------------------------------------------------------------------------
251 |
--------------------------------------------------------------------------------
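A minimal loading sketch for `ImageFolderDataset`; the zip path below is a placeholder for an archive produced by dataset_tool_edm.py:

    import torch
    from training.dataset import ImageFolderDataset

    dataset = ImageFolderDataset(path='datasets/cifar10-32x32.zip', use_labels=True, xflip=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, images.dtype)    # e.g. torch.Size([64, 3, 32, 32]) torch.uint8
    print(labels.shape)                  # one-hot float32 labels, e.g. [64, 10]

--------------------------------------------------------------------------------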
/training/di_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Weijian Luo, Peking University. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Train one-step diffusion-based generative model using the techniques described in the
9 | paper "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models"
10 | https://github.com/pkulwj1994/diff_instruct
11 |
12 | Code was modified from the paper "Elucidating the Design Space of Diffusion-Based Generative Models"
13 | https://github.com/NVlabs/edm
14 | """
15 |
16 | """Loss functions used in the paper
17 | "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models"."""
18 |
19 | import torch
20 | from torch_utils import persistence
21 | from torch.distributions.log_normal import LogNormal
22 | import numpy as np
23 |
24 | #----------------------------------------------------------------------------
25 | # Loss functions for Diff-Instruct, built on the EDM formulation, from the paper
26 | # "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models".
27 |
28 | @persistence.persistent_class
29 | class DI_EDMLoss:
30 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
31 | self.P_mean = P_mean
32 | self.P_std = P_std
33 | self.sigma_data = sigma_data
34 |
35 | def gloss(self, Sd, Sg, images, labels=None, augment_pipe=None):
36 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
37 |
38 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
39 | weight = 1.0
40 |
41 |         y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, torch.zeros(images.shape[0], 9).to(images.device)) # zeros match the 9-dim augment labels expected by the networks
42 | n = torch.randn_like(y) * sigma
43 |
44 |         Sg.train(); Sd.train()
45 |         with torch.no_grad():
46 |             cuda_rng_state = torch.cuda.get_rng_state() # save CUDA RNG state for a matched second evaluation
47 |             Dd_yn = Sd(y + n, sigma, labels, augment_labels=augment_labels) # frozen teacher score network
48 |             torch.cuda.set_rng_state(cuda_rng_state) # restore: Sg sees identical randomness in stochastic layers
49 |             Dg_yn = Sg(y + n, sigma, labels, augment_labels=augment_labels) # auxiliary score network tracking the generator
50 |         Sd.eval()
51 |
52 |         loss = weight * ((Dg_yn - Dd_yn) * images) # scores are stop-gradient; the gradient w.r.t. the generator flows only through `images`
53 |
54 | return loss
55 |
56 | def __call__(self, net, images, labels=None, augment_pipe=None):
57 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
58 |
59 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
60 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
61 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
62 | n = torch.randn_like(y) * sigma
63 |
64 | net.train()
65 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
66 |
67 | loss = weight * ((D_yn - y) ** 2)
68 | return loss
--------------------------------------------------------------------------------
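A sketch of how `gloss()` would drive a one-step generator update; `G`, `Sd`, and `Sg` below are toy stand-ins for the EDM-preconditioned networks built in di_training_loop.py and only mimic the call signatures:

    # Illustrative Diff-Instruct generator step (needs a CUDA build, since
    # gloss() saves/restores the CUDA RNG state).
    import torch

    class ToyScore(torch.nn.Module):     # stands in for a denoiser D(x; sigma)
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 1)
        def forward(self, x, sigma, labels=None, augment_labels=None):
            return self.conv(x)

    G  = torch.nn.ConvTranspose2d(8, 3, 32)   # latents -> 3x32x32 "images"
    Sd = ToyScore().requires_grad_(False)     # frozen teacher score network
    Sg = ToyScore()                           # auxiliary score network, trained
                                              # separately via DI_EDMLoss.__call__()
    loss_fn = DI_EDMLoss()
    images = G(torch.randn(4, 8, 1, 1))       # differentiable w.r.t. G
    loss_fn.gloss(Sd, Sg, images).sum().backward()  # grad flows only through `images`

--------------------------------------------------------------------------------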