├── README.md ├── dataset_tool_edm.py ├── di_train.py ├── dnnlib ├── __init__.py └── util.py ├── metrics ├── __init__.py ├── di_frechet_inception_distance.py ├── di_inception_score.py ├── di_kernel_inception_distance.py ├── di_metric_main.py ├── di_metric_utils.py ├── di_precision_recall.py └── perceptual_path_length.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── distributed.py ├── misc.py ├── ops │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── bias_act.cpython-37.pyc │ │ ├── bias_act.cpython-38.pyc │ │ ├── conv2d_gradfix.cpython-37.pyc │ │ ├── conv2d_gradfix.cpython-38.pyc │ │ ├── conv2d_resample.cpython-37.pyc │ │ ├── conv2d_resample.cpython-38.pyc │ │ ├── fma.cpython-37.pyc │ │ ├── fma.cpython-38.pyc │ │ ├── upfirdn2d.cpython-37.pyc │ │ └── upfirdn2d.cpython-38.pyc │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py └── training ├── __init__.py ├── augment.py ├── dataset.py ├── di_loss.py ├── di_training_loop.py └── networks.py /README.md: -------------------------------------------------------------------------------- 1 | ## Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models (Diff-Instruct)
Official PyTorch implementation of the NeurIPS 2023 paper 2 | 3 | **Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models**
4 | Weijian Luo, Tianyang Hu, Shifeng Zhang, Jiacheng Sun, Zhenguo Li and Zhihua Zhang. 5 |
https://openreview.net/forum?id=MLIs5iRq4w
6 | 7 | Abstract: *Due to the ease of training, ability to scale, and high sample quality, diffusion models (DMs) have become the preferred option for generative modeling, with numerous pre-trained models available for a wide variety of datasets. Containing intricate information about data distributions, pre-trained DMs are valuable assets for downstream applications. In this work, we consider learning from pre-trained DMs and transferring their knowledge to other generative models in a data-free fashion. Specifically, we propose a general framework called Diff-Instruct to instruct the training of arbitrary generative models as long as the generated samples are differentiable with respect to the model parameters. Our proposed Diff-Instruct is built on a rigorous mathematical foundation where the instruction process directly corresponds to minimizing a novel divergence we call Integral Kullback-Leibler (IKL) divergence. IKL is tailored for DMs by calculating the integral of the KL divergence along a diffusion process, which we show to be more robust in comparing distributions with misaligned supports. We also reveal non-trivial connections of our method to existing works such as DreamFusion (Poole et al., 2022) and generative adversarial training. To demonstrate the effectiveness and universality of Diff-Instruct, we consider two scenarios: distilling pre-trained diffusion models and refining existing GAN models. The experiments on distilling pre-trained diffusion models show that Diff-Instruct results in state-of-the-art single-step diffusion-based models. The experiments on refining GAN models show that Diff-Instruct can consistently improve the pre-trained generators of GAN models across various settings. Our official code is released through https://github.com/pkulwj1994/diff_instruct.* 8 | 9 | The code is based on the PyTorch implementation of the EDM diffusion model: https://github.com/NVlabs/edm. 
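For readers who want the abstract's "integral of the KL divergence along a diffusion process" in symbols, the following is a sketch of the IKL divergence, paraphrased from the paper's definition. The notation is introduced here for illustration and does not come from this README: $q_t$ and $p_t$ denote the marginals of the two distributions after diffusing to time $t$, $w(t) > 0$ is a weighting function, and $T$ is the diffusion horizon.

```latex
\mathcal{D}_{\mathrm{IKL}}^{[0,T]}(q, p)
  = \int_{0}^{T} w(t)\, \mathcal{D}_{\mathrm{KL}}\big(q_t \,\|\, p_t\big)\, \mathrm{d}t
  = \int_{0}^{T} w(t)\, \mathbb{E}_{x_t \sim q_t}\!\left[ \log \frac{q_t(x_t)}{p_t(x_t)} \right] \mathrm{d}t
```

Because both marginals are smoothed by the diffusion for $t > 0$, each integrand can stay finite even when the supports of $q$ and $p$ do not overlap at $t = 0$, which is the robustness property the abstract refers to.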
10 | 11 | ## Prepare conda env 12 | 13 | git clone https://github.com/pkulwj1994/diff_instruct.git 14 | cd diff_instruct 15 | 16 | source activate 17 | conda create -n di_v100 python=3.8 18 | conda activate di_v100 19 | pip install torch==1.12.1 torchvision==0.13.1 tqdm click psutil scipy 20 | 21 | ## Pre-trained models 22 | 23 | We use pre-trained EDM models: 24 | 25 | - [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/) 26 | 27 | 28 | ## Preparing datasets 29 | 30 | Datasets are stored in the same format as in [StyleGAN](https://github.com/NVlabs/stylegan3): uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see [`python dataset_tool_edm.py --help`](./docs/dataset-tool-help.txt) for more information. 31 | 32 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive: 33 | 34 | ```.bash 35 | python dataset_tool_edm.py --source=/data/downloads/cifar-10-python.tar.gz --dest=/data/datasets/cifar10-32x32.zip 36 | ``` 37 | 38 | **ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert to ZIP archive at 64x64 resolution: 39 | 40 | ```.bash 41 | python dataset_tool_edm.py --source=/data/downloads/imagenet/ILSVRC/Data/CLS-LOC/train --dest=/data/datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop 42 | ``` 43 | 44 | ## Distill single-step models for CIFAR10 unconditional generation on a single V100 GPU (results in FID <= 4.5) 45 | 46 | You can run diffusion distillation using `di_train.py`. 
For example: 47 | 48 | ```.bash 49 | # Train one-step DI model for unconditional CIFAR-10 using 1 GPU 50 | CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 --master_port=25678 di_train.py --outdir=/logs/di/ci10-uncond --data=/data/datasets/cifar10-32x32.zip --arch=ddpmpp --batch 128 --edm_model cifar10-uncond --cond=0 --metrics fid50k_full --tick 10 --snap 50 --lr 0.00001 --glr 0.00001 --init_sigma 1.0 --fp16=0 --lr_warmup_kimg -1 --ls 1.0 --sgls 1.0 51 | ``` 52 | 53 | During training, the FID is computed automatically once per snapshot interval (set by `--snap`, in ticks). 54 | 55 | ## License 56 | 57 | All material, including source code and pre-trained models, is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/). 58 | 59 | ## Citation 60 | 61 | ``` 62 | @article{luo2024diffinstruct, 63 | title={Diff-instruct: A universal approach for transferring knowledge from pre-trained diffusion models}, 64 | author={Luo, Weijian and Hu, Tianyang and Zhang, Shifeng and Sun, Jiacheng and Li, Zhenguo and Zhang, Zhihua}, 65 | journal={Advances in Neural Information Processing Systems}, 66 | volume={36}, 67 | year={2024} 68 | } 69 | ``` 70 | 71 | ## Development 72 | 73 | This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests. 74 | 75 | ## Acknowledgments 76 | 77 | We thank the authors of the EDM paper "Elucidating the Design Space of Diffusion-Based Generative Models" for their great implementation of EDM diffusion models in https://github.com/NVlabs/edm. We thank Shuchen Xue and Zhengyang Geng for constructive feedback on code implementations. 
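## Example: inspecting a dataset archive

The archive layout described under "Preparing datasets" (an uncompressed ZIP holding PNG files plus a `dataset.json` with `[filename, label]` pairs) can be checked with a few lines of standard-library Python. This is a minimal sketch, not part of the repository: `read_labels` and `make_toy_archive` are hypothetical helper names, and the toy archive stores placeholder bytes rather than real PNG data.

```python
import io
import json
import zipfile


def read_labels(archive):
    """Return {image_name: label} from dataset.json, or {} if the archive is unlabeled."""
    with zipfile.ZipFile(archive, mode='r') as z:
        if 'dataset.json' not in z.namelist():
            return {}
        with z.open('dataset.json', 'r') as f:
            labels = json.load(f)['labels']
    # 'labels' is either None or a list of [filename, class_index] pairs.
    return {} if labels is None else {name: label for name, label in labels}


def make_toy_archive():
    """Build an in-memory archive with the same layout (placeholder bytes, not real PNGs)."""
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, mode='w', compression=zipfile.ZIP_STORED) as z:
        z.writestr('00000/img00000000.png', b'placeholder')
        z.writestr('00000/img00000001.png', b'placeholder')
        z.writestr('dataset.json', json.dumps({'labels': [
            ['00000/img00000000.png', 6],
            ['00000/img00000001.png', 9],
        ]}))
    buf.seek(0)
    return buf


toy_labels = read_labels(make_toy_archive())
print(toy_labels)
```

On a real archive, the same helper can be pointed at a path instead, for example `read_labels('/data/datasets/cifar10-32x32.zip')`.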
78 | 79 | 80 | -------------------------------------------------------------------------------- /dataset_tool_edm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Tool for creating ZIP/PNG based datasets.""" 9 | 10 | import functools 11 | import gzip 12 | import io 13 | import json 14 | import os 15 | import pickle 16 | import re 17 | import sys 18 | import tarfile 19 | import zipfile 20 | from pathlib import Path 21 | from typing import Callable, Optional, Tuple, Union 22 | import click 23 | import numpy as np 24 | import PIL.Image 25 | from tqdm import tqdm 26 | 27 | #---------------------------------------------------------------------------- 28 | # Parse a 'M,N' or 'MxN' integer tuple. 
29 | # Example: '4x2' returns (4,2) 30 | 31 | def parse_tuple(s: str) -> Tuple[int, int]: 32 | m = re.match(r'^(\d+)[x,](\d+)$', s) 33 | if m: 34 | return int(m.group(1)), int(m.group(2)) 35 | raise click.ClickException(f'cannot parse tuple {s}') 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | def maybe_min(a: int, b: Optional[int]) -> int: 40 | if b is not None: 41 | return min(a, b) 42 | return a 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def file_ext(name: Union[str, Path]) -> str: 47 | return str(name).split('.')[-1] 48 | 49 | #---------------------------------------------------------------------------- 50 | 51 | def is_image_ext(fname: Union[str, Path]) -> bool: 52 | ext = file_ext(fname).lower() 53 | return f'.{ext}' in PIL.Image.EXTENSION 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def open_image_folder(source_dir, *, max_images: Optional[int]): 58 | input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)] 59 | arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images} 60 | max_idx = maybe_min(len(input_images), max_images) 61 | 62 | # Load labels. 63 | labels = dict() 64 | meta_fname = os.path.join(source_dir, 'dataset.json') 65 | if os.path.isfile(meta_fname): 66 | with open(meta_fname, 'r') as file: 67 | data = json.load(file)['labels'] 68 | if data is not None: 69 | labels = {x[0]: x[1] for x in data} 70 | 71 | # No labels available => determine from top-level directory names. 
72 | if len(labels) == 0: 73 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()} 74 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))} 75 | if len(toplevel_indices) > 1: 76 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()} 77 | 78 | def iterate_images(): 79 | for idx, fname in enumerate(input_images): 80 | img = np.array(PIL.Image.open(fname)) 81 | yield dict(img=img, label=labels.get(arch_fnames.get(fname))) 82 | if idx >= max_idx - 1: 83 | break 84 | return max_idx, iterate_images() 85 | 86 | #---------------------------------------------------------------------------- 87 | 88 | def open_image_zip(source, *, max_images: Optional[int]): 89 | with zipfile.ZipFile(source, mode='r') as z: 90 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] 91 | max_idx = maybe_min(len(input_images), max_images) 92 | 93 | # Load labels. 
94 | labels = dict() 95 | if 'dataset.json' in z.namelist(): 96 | with z.open('dataset.json', 'r') as file: 97 | data = json.load(file)['labels'] 98 | if data is not None: 99 | labels = {x[0]: x[1] for x in data} 100 | 101 | def iterate_images(): 102 | with zipfile.ZipFile(source, mode='r') as z: 103 | for idx, fname in enumerate(input_images): 104 | with z.open(fname, 'r') as file: 105 | img = np.array(PIL.Image.open(file)) 106 | yield dict(img=img, label=labels.get(fname)) 107 | if idx >= max_idx - 1: 108 | break 109 | return max_idx, iterate_images() 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]): 114 | import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python 115 | import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb 116 | 117 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 118 | max_idx = maybe_min(txn.stat()['entries'], max_images) 119 | 120 | def iterate_images(): 121 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 122 | for idx, (_key, value) in enumerate(txn.cursor()): 123 | try: 124 | try: 125 | img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) 126 | if img is None: 127 | raise IOError('cv2.imdecode failed') 128 | img = img[:, :, ::-1] # BGR => RGB 129 | except IOError: 130 | img = np.array(PIL.Image.open(io.BytesIO(value))) 131 | yield dict(img=img, label=None) 132 | if idx >= max_idx - 1: 133 | break 134 | except: 135 | print(sys.exc_info()[1]) 136 | 137 | return max_idx, iterate_images() 138 | 139 | #---------------------------------------------------------------------------- 140 | 141 | def open_cifar10(tarball: str, *, max_images: Optional[int]): 142 | images = [] 143 | labels = [] 144 | 145 | with tarfile.open(tarball, 'r:gz') as tar: 146 | for batch in range(1, 6): 147 | member = 
tar.getmember(f'cifar-10-batches-py/data_batch_{batch}') 148 | with tar.extractfile(member) as file: 149 | data = pickle.load(file, encoding='latin1') 150 | images.append(data['data'].reshape(-1, 3, 32, 32)) 151 | labels.append(data['labels']) 152 | 153 | images = np.concatenate(images) 154 | labels = np.concatenate(labels) 155 | images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC 156 | assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 157 | assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] 158 | assert np.min(images) == 0 and np.max(images) == 255 159 | assert np.min(labels) == 0 and np.max(labels) == 9 160 | 161 | max_idx = maybe_min(len(images), max_images) 162 | 163 | def iterate_images(): 164 | for idx, img in enumerate(images): 165 | yield dict(img=img, label=int(labels[idx])) 166 | if idx >= max_idx - 1: 167 | break 168 | 169 | return max_idx, iterate_images() 170 | 171 | #---------------------------------------------------------------------------- 172 | 173 | def open_mnist(images_gz: str, *, max_images: Optional[int]): 174 | labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') 175 | assert labels_gz != images_gz 176 | images = [] 177 | labels = [] 178 | 179 | with gzip.open(images_gz, 'rb') as f: 180 | images = np.frombuffer(f.read(), np.uint8, offset=16) 181 | with gzip.open(labels_gz, 'rb') as f: 182 | labels = np.frombuffer(f.read(), np.uint8, offset=8) 183 | 184 | images = images.reshape(-1, 28, 28) 185 | images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) 186 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 187 | assert labels.shape == (60000,) and labels.dtype == np.uint8 188 | assert np.min(images) == 0 and np.max(images) == 255 189 | assert np.min(labels) == 0 and np.max(labels) == 9 190 | 191 | max_idx = maybe_min(len(images), max_images) 192 | 193 | def iterate_images(): 194 | for idx, img in enumerate(images): 195 | yield 
dict(img=img, label=int(labels[idx])) 196 | if idx >= max_idx - 1: 197 | break 198 | 199 | return max_idx, iterate_images() 200 | 201 | #---------------------------------------------------------------------------- 202 | 203 | def make_transform( 204 | transform: Optional[str], 205 | output_width: Optional[int], 206 | output_height: Optional[int] 207 | ) -> Callable[[np.ndarray], Optional[np.ndarray]]: 208 | def scale(width, height, img): 209 | w = img.shape[1] 210 | h = img.shape[0] 211 | if width == w and height == h: 212 | return img 213 | img = PIL.Image.fromarray(img) 214 | ww = width if width is not None else w 215 | hh = height if height is not None else h 216 | img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS) 217 | return np.array(img) 218 | 219 | def center_crop(width, height, img): 220 | crop = np.min(img.shape[:2]) 221 | img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] 222 | if img.ndim == 2: 223 | img = img[:, :, np.newaxis].repeat(3, axis=2) 224 | img = PIL.Image.fromarray(img, 'RGB') 225 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 226 | return np.array(img) 227 | 228 | def center_crop_wide(width, height, img): 229 | ch = int(np.round(width * img.shape[0] / img.shape[1])) 230 | if img.shape[1] < width or ch < height: 231 | return None 232 | 233 | img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] 234 | if img.ndim == 2: 235 | img = img[:, :, np.newaxis].repeat(3, axis=2) 236 | img = PIL.Image.fromarray(img, 'RGB') 237 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 238 | img = np.array(img) 239 | 240 | canvas = np.zeros([width, width, 3], dtype=np.uint8) 241 | canvas[(width - height) // 2 : (width + height) // 2, :] = img 242 | return canvas 243 | 244 | if transform is None: 245 | return functools.partial(scale, output_width, output_height) 246 | if transform == 'center-crop': 247 | if output_width is None or 
output_height is None: 248 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform') 249 | return functools.partial(center_crop, output_width, output_height) 250 | if transform == 'center-crop-wide': 251 | if output_width is None or output_height is None: 252 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform') 253 | return functools.partial(center_crop_wide, output_width, output_height) 254 | assert False, 'unknown transform' 255 | 256 | #---------------------------------------------------------------------------- 257 | 258 | def open_dataset(source, *, max_images: Optional[int]): 259 | if os.path.isdir(source): 260 | if source.rstrip('/').endswith('_lmdb'): 261 | return open_lmdb(source, max_images=max_images) 262 | else: 263 | return open_image_folder(source, max_images=max_images) 264 | elif os.path.isfile(source): 265 | if os.path.basename(source) == 'cifar-10-python.tar.gz': 266 | return open_cifar10(source, max_images=max_images) 267 | elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': 268 | return open_mnist(source, max_images=max_images) 269 | elif file_ext(source) == 'zip': 270 | return open_image_zip(source, max_images=max_images) 271 | else: 272 | assert False, 'unknown archive type' 273 | else: 274 | raise click.ClickException(f'Missing input file or directory: {source}') 275 | 276 | #---------------------------------------------------------------------------- 277 | 278 | def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]: 279 | dest_ext = file_ext(dest) 280 | 281 | if dest_ext == 'zip': 282 | if os.path.dirname(dest) != '': 283 | os.makedirs(os.path.dirname(dest), exist_ok=True) 284 | zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED) 285 | def zip_write_bytes(fname: str, data: Union[bytes, str]): 286 | zf.writestr(fname, data) 287 | return '', zip_write_bytes, zf.close 288 | 
else: 289 | # If the output folder already exists, check that it is 290 | # empty. 291 | # 292 | # Note: creating the output directory is not strictly 293 | # necessary as folder_write_bytes() also mkdirs, but it's better 294 | # to give an error message earlier in case the dest folder 295 | # somehow cannot be created. 296 | if os.path.isdir(dest) and len(os.listdir(dest)) != 0: 297 | raise click.ClickException('--dest folder must be empty') 298 | os.makedirs(dest, exist_ok=True) 299 | 300 | def folder_write_bytes(fname: str, data: Union[bytes, str]): 301 | os.makedirs(os.path.dirname(fname), exist_ok=True) 302 | with open(fname, 'wb') as fout: 303 | if isinstance(data, str): 304 | data = data.encode('utf8') 305 | fout.write(data) 306 | return dest, folder_write_bytes, lambda: None 307 | 308 | #---------------------------------------------------------------------------- 309 | 310 | @click.command() 311 | @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True) 312 | @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True) 313 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int) 314 | @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide'])) 315 | @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple) 316 | 317 | def main( 318 | source: str, 319 | dest: str, 320 | max_images: Optional[int], 321 | transform: Optional[str], 322 | resolution: Optional[Tuple[int, int]] 323 | ): 324 | """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch. 
325 | 326 | The input dataset format is guessed from the --source argument: 327 | 328 | \b 329 | --source *_lmdb/ Load LSUN dataset 330 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset 331 | --source train-images-idx3-ubyte.gz Load MNIST dataset 332 | --source path/ Recursively load all images from path/ 333 | --source dataset.zip Recursively load all images from dataset.zip 334 | 335 | Specifying the output format and path: 336 | 337 | \b 338 | --dest /path/to/dir Save output files under /path/to/dir 339 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip 340 | 341 | The output dataset format can be either an image folder or an uncompressed zip archive. 342 | Zip archives make it easier to move datasets around file servers and clusters, and may 343 | offer better training performance on network file systems. 344 | 345 | Images within the dataset archive will be stored as uncompressed PNG. 346 | Uncompressed PNGs can be efficiently decoded in the training loop. 347 | 348 | Class labels are stored in a file called 'dataset.json' that is stored at the 349 | dataset root folder. This file has the following structure: 350 | 351 | \b 352 | { 353 | "labels": [ 354 | ["00000/img00000000.png",6], 355 | ["00000/img00000001.png",9], 356 | ... repeated for every image in the dataset 357 | ["00049/img00049999.png",1] 358 | ] 359 | } 360 | 361 | If the 'dataset.json' file cannot be found, class labels are determined from 362 | top-level directory names. 363 | 364 | Image scale/crop and resolution requirements: 365 | 366 | Output images must be square-shaped and they must all have the same power-of-two 367 | dimensions. 368 | 369 | To scale arbitrary input image size to a specific width and height, use the 370 | --resolution option. Output resolution will be either the original 371 | input resolution (if resolution was not specified) or the one specified with 372 | --resolution option. 
373 | 374 | Use the --transform=center-crop or --transform=center-crop-wide options to apply a 375 | center crop transform on the input image. These options should be used with the 376 | --resolution option. For example: 377 | 378 | \b 379 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\ 380 | --transform=center-crop-wide --resolution=512x384 381 | """ 382 | 383 | PIL.Image.init() 384 | 385 | if dest == '': 386 | raise click.ClickException('--dest output filename or directory must not be an empty string') 387 | 388 | num_files, input_iter = open_dataset(source, max_images=max_images) 389 | archive_root_dir, save_bytes, close_dest = open_dest(dest) 390 | 391 | if resolution is None: resolution = (None, None) 392 | transform_image = make_transform(transform, *resolution) 393 | 394 | dataset_attrs = None 395 | 396 | labels = [] 397 | for idx, image in tqdm(enumerate(input_iter), total=num_files): 398 | idx_str = f'{idx:08d}' 399 | archive_fname = f'{idx_str[:5]}/img{idx_str}.png' 400 | 401 | # Apply crop and resize. 402 | img = transform_image(image['img']) 403 | if img is None: 404 | continue 405 | 406 | # Error check to require uniform image attributes across 407 | # the whole dataset. 408 | channels = img.shape[2] if img.ndim == 3 else 1 409 | cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels} 410 | if dataset_attrs is None: 411 | dataset_attrs = cur_image_attrs 412 | width = dataset_attrs['width'] 413 | height = dataset_attrs['height'] 414 | if width != height: 415 | raise click.ClickException(f'Image dimensions after scale and crop are required to be square. 
Got {width}x{height}') 416 | if dataset_attrs['channels'] not in [1, 3]: 417 | raise click.ClickException('Input images must be stored as RGB or grayscale') 418 | if width != 2 ** int(np.floor(np.log2(width))): 419 | raise click.ClickException('Image width/height after scale and crop are required to be power-of-two') 420 | elif dataset_attrs != cur_image_attrs: 421 | err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] 422 | raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err)) 423 | 424 | # Save the image as an uncompressed PNG. 425 | img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels]) 426 | image_bits = io.BytesIO() 427 | img.save(image_bits, format='png', compress_level=0, optimize=False) 428 | save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer()) 429 | labels.append([archive_fname, image['label']] if image['label'] is not None else None) 430 | 431 | metadata = {'labels': labels if all(x is not None for x in labels) else None} 432 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) 433 | close_dest() 434 | 435 | #---------------------------------------------------------------------------- 436 | 437 | if __name__ == "__main__": 438 | main() 439 | 440 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /di_train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Weijian Luo, Peking University . All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Train one-step diffusion-based generative model using the techniques described in the 9 | paper "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models" 10 | Weijian Luo, Tianyang Hu, Shifeng Zhang, Jiacheng Sun, Zhenguo Li and Zhihua Zhang. 11 | 12 | https://github.com/pkulwj1994/diff_instruct 13 | 14 | Code was modified from paper ""Elucidating the Design Space of Diffusion-Based Generative Models"" 15 | https://github.com/NVlabs/edm 16 | """ 17 | 18 | import os 19 | import re 20 | import json 21 | import click 22 | import torch 23 | import dnnlib 24 | from torch_utils import distributed as dist 25 | from training import di_training_loop as training_loop 26 | 27 | import warnings 28 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. 29 | 30 | #---------------------------------------------------------------------------- 31 | # Parse a comma separated list of numbers or ranges and return a list of ints. 
32 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 33 | 34 | def parse_int_list(s): 35 | if isinstance(s, list): return s 36 | ranges = [] 37 | range_re = re.compile(r'^(\d+)-(\d+)$') 38 | for p in s.split(','): 39 | m = range_re.match(p) 40 | if m: 41 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 42 | else: 43 | ranges.append(int(p)) 44 | return ranges 45 | 46 | class CommaSeparatedList(click.ParamType): 47 | name = 'list' 48 | 49 | def convert(self, value, param, ctx): 50 | _ = param, ctx 51 | if value is None or value.lower() == 'none' or value == '': 52 | return [] 53 | return value.split(',') 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | @click.command() 58 | 59 | # Main options. 60 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True) 61 | @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True) 62 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True) 63 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True) 64 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True) 65 | 66 | # Hyperparameters. 
67 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True) 68 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True) 69 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1)) 70 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int) 71 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list) 72 | @click.option('--lr', help='Learning rate for sg_optimizer', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True) 73 | @click.option('--glr', help='Learning rate for g_optimizer', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True) 74 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True) 75 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True) 76 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True) 77 | @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True) 78 | 79 | # Performance-related. 
80 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) 81 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 82 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) 83 | @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True) 84 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) 85 | 86 | # I/O-related. 87 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str) 88 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True) 89 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True) 90 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True) 91 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int) 92 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) 93 | @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str) 94 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True) 95 | 96 | @click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList()) 97 | @click.option('--edm_model', help='Pre-trained EDM model', metavar='NAME', type=click.Choice(['cifar10-uncond', 'cifar10-cond', 'ffhq64', 'afhq64-v2', 'imagenet64-cond', 'ffhq64-uncond', 'afhqv2_64-uncond']), default='cifar10-cond', show_default=True) 98 | @click.option('--init_sigma', help='Initial sigma', 
metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True) 99 | 100 | @click.option('--ema_mu', help='EMA rate', metavar='FLOAT', type=click.FloatRange(min=-1.5, min_open=True), default=-1.0, show_default=True) 101 | @click.option('--lr_warmup_kimg', help='Learning-rate warmup', metavar='KIMG', type=click.IntRange(min=-2), default=-1, show_default=True) 102 | @click.option('--sgls', help='Loss scaling (sg)', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) 103 | 104 | 105 | def main(**kwargs): 106 | """Train diffusion-based generative model using the techniques described in the 107 | paper "Elucidating the Design Space of Diffusion-Based Generative Models". 108 | 109 | Examples: 110 | 111 | \b 112 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs 113 | torchrun --standalone --nproc_per_node=8 di_train.py --outdir=training-runs \\ 114 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp 115 | """ 116 | opts = dnnlib.EasyDict(kwargs) 117 | torch.multiprocessing.set_start_method('spawn') 118 | dist.init() 119 | 120 | # Initialize config dict. 121 | c = dnnlib.EasyDict() 122 | c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache) 123 | c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2) 124 | c.network_kwargs = dnnlib.EasyDict() 125 | c.loss_kwargs = dnnlib.EasyDict() 126 | 127 | c.sg_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.0,0.999], eps=1e-8) 128 | c.g_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.glr, betas=[0.0,0.999], eps=1e-8) 129 | 130 | c.init_sigma = opts.init_sigma 131 | c.ema_mu = opts.ema_mu 132 | c.use_fp16 = opts.fp16 133 | c.lr_rampup_kimg = opts.lr_warmup_kimg 134 | 135 | 136 | # Validate dataset options. 
137 | try: 138 | dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs) 139 | dataset_name = dataset_obj.name 140 | c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution 141 | c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size 142 | if opts.cond and not dataset_obj.has_labels: 143 | raise click.ClickException('--cond=True requires labels specified in dataset.json') 144 | del dataset_obj # conserve memory 145 | except IOError as err: 146 | raise click.ClickException(f'--data: {err}') 147 | 148 | # Network architecture. 149 | if opts.arch == 'ddpmpp': 150 | c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') 151 | c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2]) 152 | elif opts.arch == 'ncsnpp': 153 | c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard') 154 | c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2]) 155 | else: 156 | assert opts.arch == 'adm' 157 | c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4]) 158 | 159 | # Training options. 160 | c.total_kimg = max(int(opts.duration * 1000), 1) 161 | c.ema_halflife_kimg = int(opts.ema * 1000) 162 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu) 163 | c.update(loss_scaling=opts.ls, sgls=opts.sgls, cudnn_benchmark=opts.bench) 164 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap) 165 | 166 | # Random seed. 167 | if opts.seed is not None: 168 | c.seed = opts.seed 169 | else: 170 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda')) 171 | torch.distributed.broadcast(seed, src=0) 172 | c.seed = int(seed) 173 | 174 | # Preconditioning & loss function. 
175 | if opts.precond == 'vp': 176 | c.network_kwargs.class_name = 'training.networks.VPPrecond' 177 | c.loss_kwargs.class_name = 'training.loss.VPLoss' 178 | elif opts.precond == 've': 179 | c.network_kwargs.class_name = 'training.networks.VEPrecond' 180 | c.loss_kwargs.class_name = 'training.loss.VELoss' 181 | else: 182 | assert opts.precond == 'edm' 183 | c.network_kwargs.class_name = 'training.networks.EDMPrecond' 184 | c.loss_kwargs.class_name = 'training.loss.EDMLoss' 185 | 186 | c.loss_kwargs.class_name = 'training.di_loss.DI_EDMLoss' 187 | c.metrics = opts.metrics 188 | 189 | # Network options. 190 | if opts.cbase is not None: 191 | c.network_kwargs.model_channels = opts.cbase 192 | if opts.cres is not None: 193 | c.network_kwargs.channel_mult = opts.cres 194 | if opts.augment: 195 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment) 196 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1) 197 | c.network_kwargs.augment_dim = 9 198 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16) 199 | 200 | # Training options. 201 | c.total_kimg = max(int(opts.duration * 1000), 1) 202 | c.ema_halflife_kimg = int(opts.ema * 1000) 203 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu) 204 | c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench) 205 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap) 206 | 207 | # Random seed. 
208 | if opts.seed is not None: 209 | c.seed = opts.seed 210 | else: 211 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda')) 212 | torch.distributed.broadcast(seed, src=0) 213 | c.seed = int(seed) 214 | 215 | resume_specs = { 216 | 'cifar10-uncond': 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl', 217 | 'cifar10-cond': 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl', 218 | } 219 | 220 | c.resume_pkl = resume_specs[opts.edm_model] 221 | if opts.transfer is not None: 222 | c.transfer_pkl = opts.transfer 223 | c.ema_rampup_ratio = None 224 | else: 225 | c.transfer_pkl = None 226 | 227 | # Description string. 228 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond' 229 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32' 230 | desc = f'{dataset_name:s}-{cond_str:s}-ls{opts.ls}-sgls{opts.sgls}-glr{opts.glr}-sglr{opts.lr}-sigma{opts.init_sigma}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}-lrwarmkimg{opts.lr_warmup_kimg}' 231 | if opts.desc is not None: 232 | desc += f'-{opts.desc}' 233 | 234 | # Pick output directory. 235 | if dist.get_rank() != 0: 236 | c.run_dir = None 237 | elif opts.nosubdir: 238 | c.run_dir = opts.outdir 239 | else: 240 | prev_run_dirs = [] 241 | if os.path.isdir(opts.outdir): 242 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))] 243 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] 244 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] 245 | cur_run_id = max(prev_run_ids, default=-1) + 1 246 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}') 247 | assert not os.path.exists(c.run_dir) 248 | 249 | # Print options. 
250 | dist.print0() 251 | dist.print0('Training options:') 252 | dist.print0(json.dumps(c, indent=2)) 253 | dist.print0() 254 | dist.print0(f'Output directory: {c.run_dir}') 255 | dist.print0(f'Dataset path: {c.dataset_kwargs.path}') 256 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}') 257 | dist.print0(f'Network architecture: {opts.arch}') 258 | dist.print0(f'Preconditioning & loss: {opts.precond}') 259 | dist.print0(f'Number of GPUs: {dist.get_world_size()}') 260 | dist.print0(f'Batch size: {c.batch_size}') 261 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}') 262 | dist.print0() 263 | 264 | # Dry run? 265 | if opts.dry_run: 266 | dist.print0('Dry run; exiting.') 267 | return 268 | 269 | # Create output directory. 270 | dist.print0('Creating output directory...') 271 | if dist.get_rank() == 0: 272 | os.makedirs(c.run_dir, exist_ok=True) 273 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: 274 | json.dump(c, f, indent=2) 275 | dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) 276 | 277 | # Train. 278 | training_loop.training_loop(**c) 279 | 280 | #---------------------------------------------------------------------------- 281 | 282 | if __name__ == "__main__": 283 | main() 284 | 285 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import tempfile 27 | import urllib 28 | import urllib.request 29 | import uuid 30 | 31 | from distutils.util import strtobool 32 | from typing import Any, List, Tuple, Union, Optional 33 | 34 | 35 | # Util classes 36 | # ------------------------------------------------------------------------------------------ 37 | 38 | 39 | class EasyDict(dict): 40 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 41 | 42 | def __getattr__(self, name: str) -> Any: 43 | try: 44 | return self[name] 45 | except KeyError: 46 | raise AttributeError(name) 47 | 48 | def __setattr__(self, name: str, value: Any) -> None: 49 | self[name] = value 50 | 51 | def __delattr__(self, name: str) -> None: 52 | del self[name] 53 | 54 | 55 | class Logger(object): 56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 57 | 58 | def __init__(self, file_name: Optional[str] = None, 
file_mode: str = "w", should_flush: bool = True): 59 | self.file = None 60 | 61 | if file_name is not None: 62 | self.file = open(file_name, file_mode) 63 | 64 | self.should_flush = should_flush 65 | self.stdout = sys.stdout 66 | self.stderr = sys.stderr 67 | 68 | sys.stdout = self 69 | sys.stderr = self 70 | 71 | def __enter__(self) -> "Logger": 72 | return self 73 | 74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 75 | self.close() 76 | 77 | def write(self, text: Union[str, bytes]) -> None: 78 | """Write text to stdout (and a file) and optionally flush.""" 79 | if isinstance(text, bytes): 80 | text = text.decode() 81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 82 | return 83 | 84 | if self.file is not None: 85 | self.file.write(text) 86 | 87 | self.stdout.write(text) 88 | 89 | if self.should_flush: 90 | self.flush() 91 | 92 | def flush(self) -> None: 93 | """Flush written text to both stdout and a file, if open.""" 94 | if self.file is not None: 95 | self.file.flush() 96 | 97 | self.stdout.flush() 98 | 99 | def close(self) -> None: 100 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 101 | self.flush() 102 | 103 | # if using multiple loggers, prevent closing in wrong order 104 | if sys.stdout is self: 105 | sys.stdout = self.stdout 106 | if sys.stderr is self: 107 | sys.stderr = self.stderr 108 | 109 | if self.file is not None: 110 | self.file.close() 111 | self.file = None 112 | 113 | 114 | # Cache directories 115 | # ------------------------------------------------------------------------------------------ 116 | 117 | _dnnlib_cache_dir = None 118 | 119 | def set_cache_dir(path: str) -> None: 120 | global _dnnlib_cache_dir 121 | _dnnlib_cache_dir = path 122 | 123 | def make_cache_dir_path(*paths: str) -> str: 124 | if _dnnlib_cache_dir is not None: 125 | return os.path.join(_dnnlib_cache_dir, *paths) 126 | if 'DNNLIB_CACHE_DIR' in 
os.environ: 127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 128 | if 'HOME' in os.environ: 129 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 130 | if 'USERPROFILE' in os.environ: 131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 133 | 134 | # Small util functions 135 | # ------------------------------------------------------------------------------------------ 136 | 137 | 138 | def format_time(seconds: Union[int, float]) -> str: 139 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 140 | s = int(np.rint(seconds)) 141 | 142 | if s < 60: 143 | return "{0}s".format(s) 144 | elif s < 60 * 60: 145 | return "{0}m {1:02}s".format(s // 60, s % 60) 146 | elif s < 24 * 60 * 60: 147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 148 | else: 149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 150 | 151 | 152 | def format_time_brief(seconds: Union[int, float]) -> str: 153 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 154 | s = int(np.rint(seconds)) 155 | 156 | if s < 60: 157 | return "{0}s".format(s) 158 | elif s < 60 * 60: 159 | return "{0}m {1:02}s".format(s // 60, s % 60) 160 | elif s < 24 * 60 * 60: 161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 162 | else: 163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 164 | 165 | 166 | def ask_yes_no(question: str) -> bool: 167 | """Ask the user the question until the user inputs a valid answer.""" 168 | while True: 169 | try: 170 | print("{0} [y/n]".format(question)) 171 | return strtobool(input().lower()) 172 | except ValueError: 173 | pass 174 | 175 | 176 | def tuple_product(t: Tuple) -> Any: 177 | """Calculate the product of the tuple elements.""" 178 | result = 1 179 
| 180 | for v in t: 181 | result *= v 182 | 183 | return result 184 | 185 | 186 | _str_to_ctype = { 187 | "uint8": ctypes.c_ubyte, 188 | "uint16": ctypes.c_uint16, 189 | "uint32": ctypes.c_uint32, 190 | "uint64": ctypes.c_uint64, 191 | "int8": ctypes.c_byte, 192 | "int16": ctypes.c_int16, 193 | "int32": ctypes.c_int32, 194 | "int64": ctypes.c_int64, 195 | "float32": ctypes.c_float, 196 | "float64": ctypes.c_double 197 | } 198 | 199 | 200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 202 | type_str = None 203 | 204 | if isinstance(type_obj, str): 205 | type_str = type_obj 206 | elif hasattr(type_obj, "__name__"): 207 | type_str = type_obj.__name__ 208 | elif hasattr(type_obj, "name"): 209 | type_str = type_obj.name 210 | else: 211 | raise RuntimeError("Cannot infer type name from input") 212 | 213 | assert type_str in _str_to_ctype.keys() 214 | 215 | my_dtype = np.dtype(type_str) 216 | my_ctype = _str_to_ctype[type_str] 217 | 218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 219 | 220 | return my_dtype, my_ctype 221 | 222 | 223 | def is_pickleable(obj: Any) -> bool: 224 | try: 225 | with io.BytesIO() as stream: 226 | pickle.dump(obj, stream) 227 | return True 228 | except: 229 | return False 230 | 231 | 232 | # Functionality to import modules/objects by name, and call functions by name 233 | # ------------------------------------------------------------------------------------------ 234 | 235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 236 | """Searches for the underlying module behind the name to some python object. 
237 | Returns the module and the object name (original name with module part removed).""" 238 | 239 | # allow convenience shorthands, substitute them by full names 240 | obj_name = re.sub("^np.", "numpy.", obj_name) 241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 242 | 243 | # list alternatives for (module_name, local_obj_name) 244 | parts = obj_name.split(".") 245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 246 | 247 | # try each alternative in turn 248 | for module_name, local_obj_name in name_pairs: 249 | try: 250 | module = importlib.import_module(module_name) # may raise ImportError 251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 252 | return module, local_obj_name 253 | except: 254 | pass 255 | 256 | # maybe some of the modules themselves contain errors? 257 | for module_name, _local_obj_name in name_pairs: 258 | try: 259 | importlib.import_module(module_name) # may raise ImportError 260 | except ImportError: 261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 262 | raise 263 | 264 | # maybe the requested attribute is missing? 
265 | for module_name, local_obj_name in name_pairs: 266 | try: 267 | module = importlib.import_module(module_name) # may raise ImportError 268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 269 | except ImportError: 270 | pass 271 | 272 | # we are out of luck, but we have no idea why 273 | raise ImportError(obj_name) 274 | 275 | 276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 277 | """Traverses the object name and returns the last (rightmost) python object.""" 278 | if obj_name == '': 279 | return module 280 | obj = module 281 | for part in obj_name.split("."): 282 | obj = getattr(obj, part) 283 | return obj 284 | 285 | 286 | def get_obj_by_name(name: str) -> Any: 287 | """Finds the python object with the given name.""" 288 | module, obj_name = get_module_from_obj_name(name) 289 | return get_obj_from_module(module, obj_name) 290 | 291 | 292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 293 | """Finds the python object with the given name and calls it as a function.""" 294 | assert func_name is not None 295 | func_obj = get_obj_by_name(func_name) 296 | assert callable(func_obj) 297 | return func_obj(*args, **kwargs) 298 | 299 | 300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 301 | """Finds the python class with the given name and constructs it with the given arguments.""" 302 | return call_func_by_name(*args, func_name=class_name, **kwargs) 303 | 304 | 305 | def get_module_dir_by_obj_name(obj_name: str) -> str: 306 | """Get the directory path of the module containing the given object name.""" 307 | module, _ = get_module_from_obj_name(obj_name) 308 | return os.path.dirname(inspect.getfile(module)) 309 | 310 | 311 | def is_top_level_function(obj: Any) -> bool: 312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 313 | return callable(obj) and obj.__name__ in 
sys.modules[obj.__module__].__dict__ 314 | 315 | 316 | def get_top_level_function_name(obj: Any) -> str: 317 | """Return the fully-qualified name of a top-level function.""" 318 | assert is_top_level_function(obj) 319 | module = obj.__module__ 320 | if module == '__main__': 321 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 322 | return module + "." + obj.__name__ 323 | 324 | 325 | # File system helpers 326 | # ------------------------------------------------------------------------------------------ 327 | 328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 329 | """List all files recursively in a given directory while ignoring given file and directory names. 330 | Returns list of tuples containing both absolute and relative paths.""" 331 | assert os.path.isdir(dir_path) 332 | base_name = os.path.basename(os.path.normpath(dir_path)) 333 | 334 | if ignores is None: 335 | ignores = [] 336 | 337 | result = [] 338 | 339 | for root, dirs, files in os.walk(dir_path, topdown=True): 340 | for ignore_ in ignores: 341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 342 | 343 | # dirs need to be edited in-place 344 | for d in dirs_to_remove: 345 | dirs.remove(d) 346 | 347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 348 | 349 | absolute_paths = [os.path.join(root, f) for f in files] 350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 351 | 352 | if add_base_to_relative: 353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 354 | 355 | assert len(absolute_paths) == len(relative_paths) 356 | result += zip(absolute_paths, relative_paths) 357 | 358 | return result 359 | 360 | 361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 362 | """Takes in a list of tuples of (src, dst) paths and copies files. 
363 | Will create all necessary directories.""" 364 | for file in files: 365 | target_dir_name = os.path.dirname(file[1]) 366 | 367 | # will create all intermediate-level directories 368 | if not os.path.exists(target_dir_name): 369 | os.makedirs(target_dir_name) 370 | 371 | shutil.copyfile(file[0], file[1]) 372 | 373 | 374 | # URL helpers 375 | # ------------------------------------------------------------------------------------------ 376 | 377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 378 | """Determine whether the given object is a valid URL string.""" 379 | if not isinstance(obj, str) or not "://" in obj: 380 | return False 381 | if allow_file_urls and obj.startswith('file://'): 382 | return True 383 | try: 384 | res = requests.compat.urlparse(obj) 385 | if not res.scheme or not res.netloc or not "." in res.netloc: 386 | return False 387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 388 | if not res.scheme or not res.netloc or not "." in res.netloc: 389 | return False 390 | except: 391 | return False 392 | return True 393 | 394 | 395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 396 | """Download the given URL and return a binary-mode file object to access the data.""" 397 | assert num_attempts >= 1 398 | assert not (return_filename and (not cache)) 399 | 400 | # Doesn't look like an URL scheme so interpret it as a local filename. 401 | if not re.match('^[a-z]+://', url): 402 | return url if return_filename else open(url, "rb") 403 | 404 | # Handle file URLs. This code handles unusual file:// patterns that 405 | # arise on Windows: 406 | # 407 | # file:///c:/foo.txt 408 | # 409 | # which would translate to a local '/c:/foo.txt' filename that's 410 | # invalid. Drop the forward slash for such pathnames. 411 | # 412 | # If you touch this code path, you should test it on both Linux and 413 | # Windows. 
414 | # 415 | # Some internet resources suggest using urllib.request.url2pathname() but 416 | # that converts forward slashes to backslashes and this causes 417 | # its own set of problems. 418 | if url.startswith('file://'): 419 | filename = urllib.parse.urlparse(url).path 420 | if re.match(r'^/[a-zA-Z]:', filename): 421 | filename = filename[1:] 422 | return filename if return_filename else open(filename, "rb") 423 | 424 | assert is_url(url) 425 | 426 | # Lookup from cache. 427 | if cache_dir is None: 428 | cache_dir = make_cache_dir_path('downloads') 429 | 430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 431 | if cache: 432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 433 | if len(cache_files) == 1: 434 | filename = cache_files[0] 435 | return filename if return_filename else open(filename, "rb") 436 | 437 | # Download. 438 | url_name = None 439 | url_data = None 440 | with requests.Session() as session: 441 | if verbose: 442 | print("Downloading %s ..." 
% url, end="", flush=True) 443 | for attempts_left in reversed(range(num_attempts)): 444 | try: 445 | with session.get(url) as res: 446 | res.raise_for_status() 447 | if len(res.content) == 0: 448 | raise IOError("No data received") 449 | 450 | if len(res.content) < 8192: 451 | content_str = res.content.decode("utf-8") 452 | if "download_warning" in res.headers.get("Set-Cookie", ""): 453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 454 | if len(links) == 1: 455 | url = requests.compat.urljoin(url, links[0]) 456 | raise IOError("Google Drive virus checker nag") 457 | if "Google Drive - Quota exceeded" in content_str: 458 | raise IOError("Google Drive download quota exceeded -- please try again later") 459 | 460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 461 | url_name = match[1] if match else url 462 | url_data = res.content 463 | if verbose: 464 | print(" done") 465 | break 466 | except KeyboardInterrupt: 467 | raise 468 | except: 469 | if not attempts_left: 470 | if verbose: 471 | print(" failed") 472 | raise 473 | if verbose: 474 | print(".", end="", flush=True) 475 | 476 | # Save to cache. 477 | if cache: 478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 479 | safe_name = safe_name[:min(len(safe_name), 128)] 480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 482 | os.makedirs(cache_dir, exist_ok=True) 483 | with open(temp_file, "wb") as f: 484 | f.write(url_data) 485 | os.replace(temp_file, cache_file) # atomic 486 | if return_filename: 487 | return cache_file 488 | 489 | # Return data as file object. 
490 | assert not return_filename 491 | return io.BytesIO(url_data) 492 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /metrics/di_frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . 
import di_metric_utils as metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 28 | 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /metrics/di_inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import di_metric_utils as metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | # detector_url = '/home/luoweijian/work/edm/cache/inception-2015-12-05.pt' 22 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 
23 | 24 | gen_probs = metric_utils.compute_feature_stats_for_generator( 25 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 26 | capture_all=True, max_items=num_gen).get_all() 27 | 28 | if opts.rank != 0: 29 | return float('nan'), float('nan') 30 | 31 | scores = [] 32 | for i in range(num_splits): 33 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 34 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 35 | kl = np.mean(np.sum(kl, axis=1)) 36 | scores.append(np.exp(kl)) 37 | return float(np.mean(scores)), float(np.std(scores)) 38 | 39 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /metrics/di_kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . 
import di_metric_utils as metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /metrics/di_metric_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import time 11 | import json 12 | import torch 13 | import dnnlib 14 | 15 | from . import di_metric_utils as metric_utils 16 | from . import di_frechet_inception_distance as frechet_inception_distance 17 | from . import di_kernel_inception_distance as kernel_inception_distance 18 | from . import di_precision_recall as precision_recall 19 | from . import perceptual_path_length 20 | from . import di_inception_score as inception_score 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | _metric_dict = dict() # name => fn 25 | 26 | def register_metric(fn): 27 | assert callable(fn) 28 | _metric_dict[fn.__name__] = fn 29 | return fn 30 | 31 | def is_valid_metric(metric): 32 | return metric in _metric_dict 33 | 34 | def list_valid_metrics(): 35 | return list(_metric_dict.keys()) 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 40 | assert is_valid_metric(metric) 41 | opts = metric_utils.MetricOptions(**kwargs) 42 | 43 | # Calculate. 44 | start_time = time.time() 45 | results = _metric_dict[metric](opts) 46 | total_time = time.time() - start_time 47 | 48 | # Broadcast results. 49 | for key, value in list(results.items()): 50 | if opts.num_gpus > 1: 51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 52 | torch.distributed.broadcast(tensor=value, src=0) 53 | value = float(value.cpu()) 54 | results[key] = value 55 | 56 | # Decorate with metadata. 
57 | return dnnlib.EasyDict( 58 | results = dnnlib.EasyDict(results), 59 | metric = metric, 60 | total_time = total_time, 61 | total_time_str = dnnlib.util.format_time(total_time), 62 | num_gpus = opts.num_gpus, 63 | ) 64 | 65 | #---------------------------------------------------------------------------- 66 | 67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 68 | metric = result_dict['metric'] 69 | assert is_valid_metric(metric) 70 | if run_dir is not None and snapshot_pkl is not None: 71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 72 | 73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 74 | print(jsonl_line) 75 | if run_dir is not None and os.path.isdir(run_dir): 76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 77 | f.write(jsonl_line + '\n') 78 | 79 | #---------------------------------------------------------------------------- 80 | # Primary metrics. 81 | 82 | @register_metric 83 | def fid50k_full(opts): 84 | opts.dataset_kwargs.update(max_size=None, xflip=False) 85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 86 | return dict(fid50k_full=fid) 87 | 88 | @register_metric 89 | def kid50k_full(opts): 90 | opts.dataset_kwargs.update(max_size=None, xflip=False) 91 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 92 | return dict(kid50k_full=kid) 93 | 94 | @register_metric 95 | def pr50k3_full(opts): 96 | opts.dataset_kwargs.update(max_size=None, xflip=False) 97 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 98 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 99 | 100 | @register_metric 101 | def ppl2_wend(opts): 102 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', 
crop=False, batch_size=2) 103 | return dict(ppl2_wend=ppl) 104 | 105 | @register_metric 106 | def is50k(opts): 107 | opts.dataset_kwargs.update(max_size=None, xflip=False) 108 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 109 | return dict(is50k_mean=mean, is50k_std=std) 110 | 111 | #---------------------------------------------------------------------------- 112 | # Legacy metrics. 113 | 114 | @register_metric 115 | def fid50k(opts): 116 | opts.dataset_kwargs.update(max_size=None) 117 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 118 | return dict(fid50k=fid) 119 | 120 | @register_metric 121 | def kid50k(opts): 122 | opts.dataset_kwargs.update(max_size=None) 123 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 124 | return dict(kid50k=kid) 125 | 126 | @register_metric 127 | def pr50k3(opts): 128 | opts.dataset_kwargs.update(max_size=None) 129 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 130 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 131 | 132 | @register_metric 133 | def ppl_zfull(opts): 134 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2) 135 | return dict(ppl_zfull=ppl) 136 | 137 | @register_metric 138 | def ppl_wfull(opts): 139 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2) 140 | return dict(ppl_wfull=ppl) 141 | 142 | @register_metric 143 | def ppl_zend(opts): 144 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2) 145 | return dict(ppl_zend=ppl) 146 | 147 | @register_metric 148 | def ppl_wend(opts): 149 | ppl = perceptual_path_length.compute_ppl(opts, 
num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2) 150 | return dict(ppl_wend=ppl) 151 | 152 | #---------------------------------------------------------------------------- 153 | -------------------------------------------------------------------------------- /metrics/di_metric_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import time 11 | import hashlib 12 | import pickle 13 | import copy 14 | import uuid 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class MetricOptions: 22 | def __init__(self, G=None, init_sigma=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True): 23 | assert 0 <= rank < num_gpus 24 | self.G = G 25 | self.G_kwargs = dnnlib.EasyDict(G_kwargs) 26 | self.init_sigma = init_sigma 27 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) 28 | self.num_gpus = num_gpus 29 | self.rank = rank 30 | self.device = device if device is not None else torch.device('cuda', rank) 31 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() 32 | self.cache = cache 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | _feature_detector_cache = dict() 37 | 38 | def get_feature_detector_name(url): 39 | return os.path.splitext(url.split('/')[-1])[0] 40 | 41 | def get_feature_detector(url, 
device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): 42 | assert 0 <= rank < num_gpus 43 | key = (url, device) 44 | if key not in _feature_detector_cache: 45 | is_leader = (rank == 0) 46 | if not is_leader and num_gpus > 1: 47 | torch.distributed.barrier() # leader goes first 48 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: 49 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) 50 | if is_leader and num_gpus > 1: 51 | torch.distributed.barrier() # others follow 52 | return _feature_detector_cache[key] 53 | 54 | #---------------------------------------------------------------------------- 55 | 56 | class FeatureStats: 57 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): 58 | self.capture_all = capture_all 59 | self.capture_mean_cov = capture_mean_cov 60 | self.max_items = max_items 61 | self.num_items = 0 62 | self.num_features = None 63 | self.all_features = None 64 | self.raw_mean = None 65 | self.raw_cov = None 66 | 67 | def set_num_features(self, num_features): 68 | if self.num_features is not None: 69 | assert num_features == self.num_features 70 | else: 71 | self.num_features = num_features 72 | self.all_features = [] 73 | self.raw_mean = np.zeros([num_features], dtype=np.float64) 74 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) 75 | 76 | def is_full(self): 77 | return (self.max_items is not None) and (self.num_items >= self.max_items) 78 | 79 | def append(self, x): 80 | x = np.asarray(x, dtype=np.float32) 81 | assert x.ndim == 2 82 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): 83 | if self.num_items >= self.max_items: 84 | return 85 | x = x[:self.max_items - self.num_items] 86 | 87 | self.set_num_features(x.shape[1]) 88 | self.num_items += x.shape[0] 89 | if self.capture_all: 90 | self.all_features.append(x) 91 | if self.capture_mean_cov: 92 | x64 = x.astype(np.float64) 93 | self.raw_mean += x64.sum(axis=0) 
94 | self.raw_cov += x64.T @ x64 95 | 96 | def append_torch(self, x, num_gpus=1, rank=0): 97 | assert isinstance(x, torch.Tensor) and x.ndim == 2 98 | assert 0 <= rank < num_gpus 99 | if num_gpus > 1: 100 | ys = [] 101 | for src in range(num_gpus): 102 | y = x.clone() 103 | torch.distributed.broadcast(y, src=src) 104 | ys.append(y) 105 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples 106 | self.append(x.cpu().numpy()) 107 | 108 | def get_all(self): 109 | assert self.capture_all 110 | return np.concatenate(self.all_features, axis=0) 111 | 112 | def get_all_torch(self): 113 | return torch.from_numpy(self.get_all()) 114 | 115 | def get_mean_cov(self): 116 | assert self.capture_mean_cov 117 | mean = self.raw_mean / self.num_items 118 | cov = self.raw_cov / self.num_items 119 | cov = cov - np.outer(mean, mean) 120 | return mean, cov 121 | 122 | def save(self, pkl_file): 123 | with open(pkl_file, 'wb') as f: 124 | pickle.dump(self.__dict__, f) 125 | 126 | @staticmethod 127 | def load(pkl_file): 128 | with open(pkl_file, 'rb') as f: 129 | s = dnnlib.EasyDict(pickle.load(f)) 130 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) 131 | obj.__dict__.update(s) 132 | return obj 133 | 134 | #---------------------------------------------------------------------------- 135 | 136 | class ProgressMonitor: 137 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): 138 | self.tag = tag 139 | self.num_items = num_items 140 | self.verbose = verbose 141 | self.flush_interval = flush_interval 142 | self.progress_fn = progress_fn 143 | self.pfn_lo = pfn_lo 144 | self.pfn_hi = pfn_hi 145 | self.pfn_total = pfn_total 146 | self.start_time = time.time() 147 | self.batch_time = self.start_time 148 | self.batch_items = 0 149 | if self.progress_fn is not None: 150 | self.progress_fn(self.pfn_lo, self.pfn_total) 151 | 152 | def update(self, cur_items): 153 | assert 
(self.num_items is None) or (cur_items <= self.num_items) 154 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): 155 | return 156 | cur_time = time.time() 157 | total_time = cur_time - self.start_time 158 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) 159 | if (self.verbose) and (self.tag is not None): 160 | print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') 161 | self.batch_time = cur_time 162 | self.batch_items = cur_items 163 | 164 | if (self.progress_fn is not None) and (self.num_items is not None): 165 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) 166 | 167 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): 168 | return ProgressMonitor( 169 | tag = tag, 170 | num_items = num_items, 171 | flush_interval = flush_interval, 172 | verbose = self.verbose, 173 | progress_fn = self.progress_fn, 174 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, 175 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, 176 | pfn_total = self.pfn_total, 177 | ) 178 | 179 | #---------------------------------------------------------------------------- 180 | 181 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs): 182 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 183 | if data_loader_kwargs is None: 184 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) 185 | 186 | # Try to lookup from cache. 187 | cache_file = None 188 | if opts.cache: 189 | # Choose cache file name. 
190 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) 191 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) 192 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' 193 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') 194 | 195 | # Check if the file exists (all processes must agree). 196 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False 197 | if opts.num_gpus > 1: 198 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) 199 | torch.distributed.broadcast(tensor=flag, src=0) 200 | flag = (float(flag.cpu()) != 0) 201 | 202 | # Load. 203 | if flag: 204 | return FeatureStats.load(cache_file) 205 | 206 | # Initialize. 207 | num_items = len(dataset) 208 | if max_items is not None: 209 | num_items = min(num_items, max_items) 210 | stats = FeatureStats(max_items=num_items, **stats_kwargs) 211 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) 212 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 213 | 214 | # Main loop. 215 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] 216 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): 217 | if images.shape[1] == 1: 218 | images = images.repeat([1, 3, 1, 1]) 219 | features = detector(images.to(opts.device), **detector_kwargs) 220 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 221 | progress.update(stats.num_items) 222 | 223 | # Save to cache. 224 | if cache_file is not None and opts.rank == 0: 225 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 226 | temp_file = cache_file + '.' 
+ uuid.uuid4().hex 227 | stats.save(temp_file) 228 | os.replace(temp_file, cache_file) # atomic 229 | return stats 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=128, batch_gen=None, jit=False, **stats_kwargs): 234 | if batch_gen is None: 235 | batch_gen = min(batch_size, 4) 236 | assert batch_size % batch_gen == 0 237 | 238 | # Setup generator and load labels. 239 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) 240 | init_sigma = opts.init_sigma 241 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 242 | 243 | # Image generation func. 244 | def run_generator(z, c, init_sigma): 245 | # img = G(z=z, c=c, **opts.G_kwargs) 246 | img = G(z, init_sigma*torch.ones(z.shape[0],1,1,1).to(z.device), c, augment_labels=torch.zeros(z.shape[0], 9).to(z.device)) 247 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 248 | return img 249 | 250 | # JIT. 251 | if jit: 252 | z = init_sigma*torch.zeros([batch_gen, G.img_channels, G.img_resolution, G.img_resolution], device=opts.device) 253 | c = torch.zeros([batch_gen, G.c_dim], device=opts.device) 254 | run_generator = torch.jit.trace(run_generator, [z, c, init_sigma], check_trace=False) 255 | 256 | # Initialize. 257 | stats = FeatureStats(**stats_kwargs) 258 | assert stats.max_items is not None 259 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) 260 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 261 | 262 | # Main loop. 
263 | while not stats.is_full(): 264 | images = [] 265 | for _i in range(batch_size // batch_gen): 266 | z = init_sigma*torch.randn([batch_gen, G.img_channels, G.img_resolution, G.img_resolution], device=opts.device) 267 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] 268 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 269 | images.append(run_generator(z, c, init_sigma)) 270 | images = torch.cat(images) 271 | if images.shape[1] == 1: 272 | images = images.repeat([1, 3, 1, 1]) 273 | features = detector(images, **detector_kwargs) 274 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 275 | progress.update(stats.num_items) 276 | return stats 277 | 278 | #---------------------------------------------------------------------------- 279 | -------------------------------------------------------------------------------- /metrics/di_precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . 
import di_metric_utils as metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch 
in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | from . 
import di_metric_utils as metric_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | # Spherical interpolation of a batch of vectors. 23 | def slerp(a, b, t): 24 | a = a / a.norm(dim=-1, keepdim=True) 25 | b = b / b.norm(dim=-1, keepdim=True) 26 | d = (a * b).sum(dim=-1, keepdim=True) 27 | p = t * torch.acos(d) 28 | c = b - d * a 29 | c = c / c.norm(dim=-1, keepdim=True) 30 | d = a * torch.cos(p) + c * torch.sin(p) 31 | d = d / d.norm(dim=-1, keepdim=True) 32 | return d 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | class PPLSampler(torch.nn.Module): 37 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 38 | assert space in ['z', 'w'] 39 | assert sampling in ['full', 'end'] 40 | super().__init__() 41 | self.G = copy.deepcopy(G) 42 | self.G_kwargs = G_kwargs 43 | self.epsilon = epsilon 44 | self.space = space 45 | self.sampling = sampling 46 | self.crop = crop 47 | self.vgg16 = copy.deepcopy(vgg16) 48 | 49 | def forward(self, c): 50 | # Generate random latents and interpolation t-values. 51 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 52 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 53 | 54 | # Interpolate in W or Z. 55 | if self.space == 'w': 56 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 57 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 58 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 59 | else: # space == 'z' 60 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 61 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 62 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 63 | 64 | # Randomize noise buffers. 65 | for name, buf in self.G.named_buffers(): 66 | if name.endswith('.noise_const'): 67 | buf.copy_(torch.randn_like(buf)) 68 | 69 | # Generate images. 
70 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 71 | 72 | # Center crop. 73 | if self.crop: 74 | assert img.shape[2] == img.shape[3] 75 | c = img.shape[2] // 8 76 | img = img[:, :, c*3 : c*7, c*2 : c*6] 77 | 78 | # Downsample to 256x256. 79 | factor = self.G.img_resolution // 256 80 | if factor > 1: 81 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 82 | 83 | # Scale dynamic range from [-1,1] to [0,255]. 84 | img = (img + 1) * (255 / 2) 85 | if self.G.img_channels == 1: 86 | img = img.repeat([1, 3, 1, 1]) 87 | 88 | # Evaluate differential LPIPS. 89 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 90 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 91 | return dist 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False): 96 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 97 | vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 99 | 100 | # Setup sampler. 101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 102 | sampler.eval().requires_grad_(False).to(opts.device) 103 | if jit: 104 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) 105 | sampler = torch.jit.trace(sampler, [c], check_trace=False) 106 | 107 | # Sampling loop. 
108 | dist = [] 109 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 110 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 111 | progress.update(batch_start) 112 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] 113 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 114 | x = sampler(c) 115 | for src in range(opts.num_gpus): 116 | y = x.clone() 117 | if opts.num_gpus > 1: 118 | torch.distributed.broadcast(y, src=src) 119 | dist.append(y) 120 | progress.update(num_samples) 121 | 122 | # Compute PPL. 123 | if opts.rank != 0: 124 | return float('nan') 125 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 126 | lo = np.percentile(dist, 1, interpolation='lower') 127 | hi = np.percentile(dist, 99, interpolation='higher') 128 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 129 | return float(ppl) 130 | 131 | #---------------------------------------------------------------------------- 132 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 27 | 28 | def _find_compiler_bindir(): 29 | patterns = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 34 | ] 35 | for pattern in patterns: 36 | matches = sorted(glob.glob(pattern)) 37 | if len(matches): 38 | return matches[-1] 39 | return None 40 | 41 | #---------------------------------------------------------------------------- 42 | # Main entry point for compiling and loading C++/CUDA plugins. 43 | 44 | _cached_plugins = dict() 45 | 46 | def get_plugin(module_name, sources, **build_kwargs): 47 | assert verbosity in ['none', 'brief', 'full'] 48 | 49 | # Already cached? 50 | if module_name in _cached_plugins: 51 | return _cached_plugins[module_name] 52 | 53 | # Print status. 
54 | if verbosity == 'full': 55 | print(f'Setting up PyTorch plugin "{module_name}"...') 56 | elif verbosity == 'brief': 57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 58 | 59 | try: # pylint: disable=too-many-nested-blocks 60 | # Make sure we can find the necessary compiler binaries. 61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 62 | compiler_bindir = _find_compiler_bindir() 63 | if compiler_bindir is None: 64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 65 | os.environ['PATH'] += ';' + compiler_bindir 66 | 67 | # Compile and load. 68 | verbose_build = (verbosity == 'full') 69 | 70 | # Incremental build md5sum trickery. Copies all the input source files 71 | # into a cached build directory under a combined md5 digest of the input 72 | # source files. Copying is done only if the combined digest has changed. 73 | # This keeps input file timestamps and filenames the same as in previous 74 | # extension builds, allowing for fast incremental rebuilds. 75 | # 76 | # This optimization is done only in case all the source files reside in 77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 78 | # environment variable is set (we take this as a signal that the user 79 | # actually cares about this.) 80 | source_dirs_set = set(os.path.dirname(source) for source in sources) 81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 83 | 84 | # Compute a combined hash digest for all source files in the same 85 | # custom op directory (usually .cu, .cpp, .py and .h files). 
86 | hash_md5 = hashlib.md5() 87 | for src in all_source_files: 88 | with open(src, 'rb') as f: 89 | hash_md5.update(f.read()) 90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 92 | 93 | if not os.path.isdir(digest_build_dir): 94 | os.makedirs(digest_build_dir, exist_ok=True) 95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 96 | if baton.try_acquire(): 97 | try: 98 | for src in all_source_files: 99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 100 | finally: 101 | baton.release() 102 | else: 103 | # Someone else is copying source files under the digest dir, 104 | # wait until done and continue. 105 | baton.wait() 106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 108 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 109 | else: 110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 111 | module = importlib.import_module(module_name) 112 | 113 | except: 114 | if verbosity == 'brief': 115 | print('Failed!') 116 | raise 117 | 118 | # Print status and add to cache. 119 | if verbosity == 'full': 120 | print(f'Done setting up PyTorch plugin "{module_name}".') 121 | elif verbosity == 'brief': 122 | print('Done.') 123 | _cached_plugins[module_name] = module 124 | return module 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | 
def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import numpy as np 11 | import torch 12 | import warnings 13 | import dnnlib 14 | 15 | #---------------------------------------------------------------------------- 16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 17 | # same constant is used multiple times. 
18 | 19 | _constant_cache = dict() 20 | 21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 22 | value = np.asarray(value) 23 | if shape is not None: 24 | shape = tuple(shape) 25 | if dtype is None: 26 | dtype = torch.get_default_dtype() 27 | if device is None: 28 | device = torch.device('cpu') 29 | if memory_format is None: 30 | memory_format = torch.contiguous_format 31 | 32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 33 | tensor = _constant_cache.get(key, None) 34 | if tensor is None: 35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 36 | if shape is not None: 37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 38 | tensor = tensor.contiguous(memory_format=memory_format) 39 | _constant_cache[key] = tensor 40 | return tensor 41 | 42 | #---------------------------------------------------------------------------- 43 | # Replace NaN/Inf with specified numerical values. 44 | 45 | try: 46 | nan_to_num = torch.nan_to_num # 1.8.0a0 47 | except AttributeError: 48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 49 | assert isinstance(input, torch.Tensor) 50 | if posinf is None: 51 | posinf = torch.finfo(input.dtype).max 52 | if neginf is None: 53 | neginf = torch.finfo(input.dtype).min 54 | assert nan == 0 55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Symbolic assert. 59 | 60 | try: 61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 62 | except AttributeError: 63 | symbolic_assert = torch.Assert # 1.7.0 64 | 65 | #---------------------------------------------------------------------------- 66 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 
67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 68 | 69 | @contextlib.contextmanager 70 | def suppress_tracer_warnings(): 71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 72 | warnings.filters.insert(0, flt) 73 | yield 74 | warnings.filters.remove(flt) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Assert that the shape of a tensor matches the given list of integers. 78 | # None indicates that the size of a dimension is allowed to vary. 79 | # Performs symbolic assertion when used in torch.jit.trace(). 80 | 81 | def assert_shape(tensor, ref_shape): 82 | if tensor.ndim != len(ref_shape): 83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 85 | if ref_size is None: 86 | pass 87 | elif isinstance(ref_size, torch.Tensor): 88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 90 | elif isinstance(size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 93 | elif size != ref_size: 94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 95 | 96 | #---------------------------------------------------------------------------- 97 | # Function decorator that calls torch.autograd.profiler.record_function(). 
98 | 99 | def profiled_function(fn): 100 | def decorator(*args, **kwargs): 101 | with torch.autograd.profiler.record_function(fn.__name__): 102 | return fn(*args, **kwargs) 103 | decorator.__name__ = fn.__name__ 104 | return decorator 105 | 106 | #---------------------------------------------------------------------------- 107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 108 | # indefinitely, shuffling items as it goes. 109 | 110 | class InfiniteSampler(torch.utils.data.Sampler): 111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 112 | assert len(dataset) > 0 113 | assert num_replicas > 0 114 | assert 0 <= rank < num_replicas 115 | assert 0 <= window_size <= 1 116 | super().__init__(dataset) 117 | self.dataset = dataset 118 | self.rank = rank 119 | self.num_replicas = num_replicas 120 | self.shuffle = shuffle 121 | self.seed = seed 122 | self.window_size = window_size 123 | 124 | def __iter__(self): 125 | order = np.arange(len(self.dataset)) 126 | rnd = None 127 | window = 0 128 | if self.shuffle: 129 | rnd = np.random.RandomState(self.seed) 130 | rnd.shuffle(order) 131 | window = int(np.rint(order.size * self.window_size)) 132 | 133 | idx = 0 134 | while True: 135 | i = idx % order.size 136 | if idx % self.num_replicas == self.rank: 137 | yield order[i] 138 | if window >= 2: 139 | j = (i - rnd.randint(window)) % order.size 140 | order[i], order[j] = order[j], order[i] 141 | idx += 1 142 | 143 | #---------------------------------------------------------------------------- 144 | # Utilities for operating with torch.nn.Module parameters and buffers. 
145 | 146 | def params_and_buffers(module): 147 | assert isinstance(module, torch.nn.Module) 148 | return list(module.parameters()) + list(module.buffers()) 149 | 150 | def named_params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.named_parameters()) + list(module.named_buffers()) 153 | 154 | @torch.no_grad() 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name]) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' 
+ name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 
229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/bias_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/bias_act.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/bias_act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc -------------------------------------------------------------------------------- 
/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/conv2d_resample.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_resample.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/fma.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/fma.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/fma.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/fma.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/upfirdn2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/upfirdn2d.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pkulwj1994/diff_instruct/0c37389fbfb1f5fc38708dec57d0dae257726460/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 
22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 
0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 
151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 
11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from .. 
import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | activation_funcs = { 24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 33 | } 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | _inited = False 38 | _plugin = None 39 | _null_tensor = torch.empty([0]) 40 | 41 | # def _init(): 42 | # global _inited, _plugin 43 | # if not _inited: 44 | # _inited = True 45 | # sources = ['bias_act.cpp', 'bias_act.cu'] 46 | # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 47 | # try: 48 | # _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 49 | # except: 50 | # 
warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 51 | # return _plugin is not None 52 | 53 | def _init(): 54 | return False 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 59 | r"""Fused bias and activation function. 60 | 61 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 62 | and scales the result by `gain`. Each of the steps is optional. In most cases, 63 | the fused op is considerably more efficient than performing the same calculation 64 | using standard PyTorch ops. It supports first and second order gradients, 65 | but not third order gradients. 66 | 67 | Args: 68 | x: Input activation tensor. Can be of any shape. 69 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 70 | as `x`. The shape must be known, and it must match the dimension of `x` 71 | corresponding to `dim`. 72 | dim: The dimension in `x` corresponding to the elements of `b`. 73 | The value of `dim` is ignored if `b` is not specified. 74 | act: Name of the activation function to evaluate, or `"linear"` to disable. 75 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 76 | See `activation_funcs` for a full list. `None` is not allowed. 77 | alpha: Shape parameter for the activation function, or `None` to use the default. 78 | gain: Scaling factor for the output tensor, or `None` to use default. 79 | See `activation_funcs` for the default scaling of each activation function. 80 | If unsure, consider specifying 1. 81 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 82 | the clamping (default). 83 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 84 | 85 | Returns: 86 | Tensor of the same shape and datatype as `x`. 
87 | """ 88 | assert isinstance(x, torch.Tensor) 89 | assert impl in ['ref', 'cuda'] 90 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 91 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 92 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 93 | 94 | #---------------------------------------------------------------------------- 95 | 96 | @misc.profiled_function 97 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 98 | """Slow reference implementation of `bias_act()` using standard PyTorch ops. 99 | """ 100 | assert isinstance(x, torch.Tensor) 101 | assert clamp is None or clamp >= 0 102 | spec = activation_funcs[act] 103 | alpha = float(alpha if alpha is not None else spec.def_alpha) 104 | gain = float(gain if gain is not None else spec.def_gain) 105 | clamp = float(clamp if clamp is not None else -1) 106 | 107 | # Add bias. 108 | if b is not None: 109 | assert isinstance(b, torch.Tensor) and b.ndim == 1 110 | assert 0 <= dim < x.ndim 111 | assert b.shape[0] == x.shape[dim] 112 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 113 | 114 | # Evaluate activation function. 115 | alpha = float(alpha) 116 | x = spec.func(x, alpha=alpha) 117 | 118 | # Scale by gain. 119 | gain = float(gain) 120 | if gain != 1: 121 | x = x * gain 122 | 123 | # Clamp. 124 | if clamp >= 0: 125 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 126 | return x 127 | 128 | #---------------------------------------------------------------------------- 129 | 130 | _bias_act_cuda_cache = dict() 131 | 132 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 133 | """Fast CUDA implementation of `bias_act()` using custom ops. 134 | """ 135 | # Parse arguments. 
136 | assert clamp is None or clamp >= 0 137 | spec = activation_funcs[act] 138 | alpha = float(alpha if alpha is not None else spec.def_alpha) 139 | gain = float(gain if gain is not None else spec.def_gain) 140 | clamp = float(clamp if clamp is not None else -1) 141 | 142 | # Lookup from cache. 143 | key = (dim, act, alpha, gain, clamp) 144 | if key in _bias_act_cuda_cache: 145 | return _bias_act_cuda_cache[key] 146 | 147 | # Forward op. 148 | class BiasActCuda(torch.autograd.Function): 149 | @staticmethod 150 | def forward(ctx, x, b): # pylint: disable=arguments-differ 151 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 152 | x = x.contiguous(memory_format=ctx.memory_format) 153 | b = b.contiguous() if b is not None else _null_tensor 154 | y = x 155 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 156 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 157 | ctx.save_for_backward( 158 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 159 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 160 | y if 'y' in spec.ref else _null_tensor) 161 | return y 162 | 163 | @staticmethod 164 | def backward(ctx, dy): # pylint: disable=arguments-differ 165 | dy = dy.contiguous(memory_format=ctx.memory_format) 166 | x, b, y = ctx.saved_tensors 167 | dx = None 168 | db = None 169 | 170 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 171 | dx = dy 172 | if act != 'linear' or gain != 1 or clamp >= 0: 173 | dx = BiasActCudaGrad.apply(dy, x, b, y) 174 | 175 | if ctx.needs_input_grad[1]: 176 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 177 | 178 | return dx, db 179 | 180 | # Backward op. 
181 | class BiasActCudaGrad(torch.autograd.Function): 182 | @staticmethod 183 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 184 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 185 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 186 | ctx.save_for_backward( 187 | dy if spec.has_2nd_grad else _null_tensor, 188 | x, b, y) 189 | return dx 190 | 191 | @staticmethod 192 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 193 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 194 | dy, x, b, y = ctx.saved_tensors 195 | d_dy = None 196 | d_x = None 197 | d_b = None 198 | d_y = None 199 | 200 | if ctx.needs_input_grad[0]: 201 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 202 | 203 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 204 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 205 | 206 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 207 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 208 | 209 | return d_dy, d_x, d_b, d_y 210 | 211 | # Add to cache. 212 | _bias_act_cuda_cache[key] = BiasActCuda 213 | return BiasActCuda 214 | 215 | #---------------------------------------------------------------------------- 216 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, 
weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 
83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, 
input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 42 | if out_channels <= 4 and groups == 1: 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True). 
78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution. 
119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation. 
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def grid_sample(input, grid): 28 | if _should_use_custom_op(): 29 | return _GridSample2dForward.apply(input, grid) 30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _should_use_custom_op(): 35 | if not enabled: 36 | return False 37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 38 | return True 39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. 
Falling back to torch.nn.functional.grid_sample().') 40 | return False 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | class _GridSample2dForward(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input, grid): 47 | assert input.ndim == 4 48 | assert grid.ndim == 4 49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 50 | ctx.save_for_backward(input, grid) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, grid = ctx.saved_tensors 56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 57 | return grad_input, grad_grid 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class _GridSample2dBackward(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, grad_output, input, grid): 64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr<float>(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient resampling of 2D images.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import traceback 16 | 17 | from .. 
import custom_ops 18 | from .. import misc 19 | from . import conv2d_gradfix 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _inited = False 24 | _plugin = None 25 | 26 | # def _init(): 27 | # global _inited, _plugin 28 | # if not _inited: 29 | # sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] 30 | # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 31 | # try: 32 | # _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 33 | # except: 34 | # warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 35 | # return _plugin is not None 36 | 37 | def _init(): 38 | return False 39 | 40 | def _parse_scaling(scaling): 41 | if isinstance(scaling, int): 42 | scaling = [scaling, scaling] 43 | assert isinstance(scaling, (list, tuple)) 44 | assert all(isinstance(x, int) for x in scaling) 45 | sx, sy = scaling 46 | assert sx >= 1 and sy >= 1 47 | return sx, sy 48 | 49 | def _parse_padding(padding): 50 | if isinstance(padding, int): 51 | padding = [padding, padding] 52 | assert isinstance(padding, (list, tuple)) 53 | assert all(isinstance(x, int) for x in padding) 54 | if len(padding) == 2: 55 | padx, pady = padding 56 | padding = [padx, padx, pady, pady] 57 | padx0, padx1, pady0, pady1 = padding 58 | return padx0, padx1, pady0, pady1 59 | 60 | def _get_filter_size(f): 61 | if f is None: 62 | return 1, 1 63 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 64 | fw = f.shape[-1] 65 | fh = f.shape[0] 66 | with misc.suppress_tracer_warnings(): 67 | fw = int(fw) 68 | fh = int(fh) 69 | misc.assert_shape(f, [fh, fw][:f.ndim]) 70 | assert fw >= 1 and fh >= 1 71 | return fw, fh 72 | 73 | #---------------------------------------------------------------------------- 74 | 75 | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, 
separable=None): 76 | r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. 77 | 78 | Args: 79 | f: Torch tensor, numpy array, or python list of the shape 80 | `[filter_height, filter_width]` (non-separable), 81 | `[filter_taps]` (separable), 82 | `[]` (impulse), or 83 | `None` (identity). 84 | device: Result device (default: cpu). 85 | normalize: Normalize the filter so that it retains the magnitude 86 | for constant input signal (DC)? (default: True). 87 | flip_filter: Flip the filter? (default: False). 88 | gain: Overall scaling factor for signal magnitude (default: 1). 89 | separable: Return a separable filter? (default: select automatically). 90 | 91 | Returns: 92 | Float32 tensor of the shape 93 | `[filter_height, filter_width]` (non-separable) or 94 | `[filter_taps]` (separable). 95 | """ 96 | # Validate. 97 | if f is None: 98 | f = 1 99 | f = torch.as_tensor(f, dtype=torch.float32) 100 | assert f.ndim in [0, 1, 2] 101 | assert f.numel() > 0 102 | if f.ndim == 0: 103 | f = f[np.newaxis] 104 | 105 | # Separable? 106 | if separable is None: 107 | separable = (f.ndim == 1 and f.numel() >= 8) 108 | if f.ndim == 1 and not separable: 109 | f = f.ger(f) 110 | assert f.ndim == (1 if separable else 2) 111 | 112 | # Apply normalize, flip, gain, and device. 113 | if normalize: 114 | f /= f.sum() 115 | if flip_filter: 116 | f = f.flip(list(range(f.ndim))) 117 | f = f * (gain ** (f.ndim / 2)) 118 | f = f.to(device=device) 119 | return f 120 | 121 | #---------------------------------------------------------------------------- 122 | 123 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): 124 | r"""Pad, upsample, filter, and downsample a batch of 2D images. 125 | 126 | Performs the following sequence of operations for each channel: 127 | 128 | 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 129 | 130 | 2. Pad the image with the specified number of zeros on each side (`padding`). 
131 | Negative padding corresponds to cropping the image. 132 | 133 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it 134 | so that the footprint of all output pixels lies within the input image. 135 | 136 | 4. Downsample the image by keeping every Nth pixel (`down`). 137 | 138 | This sequence of operations bears close resemblance to scipy.signal.upfirdn(). 139 | The fused op is considerably more efficient than performing the same calculation 140 | using standard PyTorch ops. It supports gradients of arbitrary order. 141 | 142 | Args: 143 | x: Float32/float64/float16 input tensor of the shape 144 | `[batch_size, num_channels, in_height, in_width]`. 145 | f: Float32 FIR filter of the shape 146 | `[filter_height, filter_width]` (non-separable), 147 | `[filter_taps]` (separable), or 148 | `None` (identity). 149 | up: Integer upsampling factor. Can be a single int or a list/tuple 150 | `[x, y]` (default: 1). 151 | down: Integer downsampling factor. Can be a single int or a list/tuple 152 | `[x, y]` (default: 1). 153 | padding: Padding with respect to the upsampled image. Can be a single number 154 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 155 | (default: 0). 156 | flip_filter: False = convolution, True = correlation (default: False). 157 | gain: Overall scaling factor for signal magnitude (default: 1). 158 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 159 | 160 | Returns: 161 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
162 | """ 163 | assert isinstance(x, torch.Tensor) 164 | assert impl in ['ref', 'cuda'] 165 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 166 | return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) 167 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) 168 | 169 | #---------------------------------------------------------------------------- 170 | 171 | @misc.profiled_function 172 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 173 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. 174 | """ 175 | # Validate arguments. 176 | assert isinstance(x, torch.Tensor) and x.ndim == 4 177 | if f is None: 178 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 179 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 180 | assert f.dtype == torch.float32 and not f.requires_grad 181 | batch_size, num_channels, in_height, in_width = x.shape 182 | upx, upy = _parse_scaling(up) 183 | downx, downy = _parse_scaling(down) 184 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 185 | 186 | # Upsample by inserting zeros. 187 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) 188 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) 189 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) 190 | 191 | # Pad or crop. 192 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) 193 | x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] 194 | 195 | # Setup filter. 196 | f = f * (gain ** (f.ndim / 2)) 197 | f = f.to(x.dtype) 198 | if not flip_filter: 199 | f = f.flip(list(range(f.ndim))) 200 | 201 | # Convolve with the filter. 
202 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) 203 | if f.ndim == 4: 204 | x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) 205 | else: 206 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) 207 | x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) 208 | 209 | # Downsample by throwing away pixels. 210 | x = x[:, :, ::downy, ::downx] 211 | return x 212 | 213 | #---------------------------------------------------------------------------- 214 | 215 | _upfirdn2d_cuda_cache = dict() 216 | 217 | def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): 218 | """Fast CUDA implementation of `upfirdn2d()` using custom ops. 219 | """ 220 | # Parse arguments. 221 | upx, upy = _parse_scaling(up) 222 | downx, downy = _parse_scaling(down) 223 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 224 | 225 | # Lookup from cache. 226 | key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 227 | if key in _upfirdn2d_cuda_cache: 228 | return _upfirdn2d_cuda_cache[key] 229 | 230 | # Forward op. 
231 | class Upfirdn2dCuda(torch.autograd.Function): 232 | @staticmethod 233 | def forward(ctx, x, f): # pylint: disable=arguments-differ 234 | assert isinstance(x, torch.Tensor) and x.ndim == 4 235 | if f is None: 236 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 237 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 238 | y = x 239 | if f.ndim == 2: 240 | y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) 241 | else: 242 | y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) 243 | y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) 244 | ctx.save_for_backward(f) 245 | ctx.x_shape = x.shape 246 | return y 247 | 248 | @staticmethod 249 | def backward(ctx, dy): # pylint: disable=arguments-differ 250 | f, = ctx.saved_tensors 251 | _, _, ih, iw = ctx.x_shape 252 | _, _, oh, ow = dy.shape 253 | fw, fh = _get_filter_size(f) 254 | p = [ 255 | fw - padx0 - 1, 256 | iw * upx - ow * downx + padx0 - upx + 1, 257 | fh - pady0 - 1, 258 | ih * upy - oh * downy + pady0 - upy + 1, 259 | ] 260 | dx = None 261 | df = None 262 | 263 | if ctx.needs_input_grad[0]: 264 | dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) 265 | 266 | assert not ctx.needs_input_grad[1] 267 | return dx, df 268 | 269 | # Add to cache. 270 | _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda 271 | return Upfirdn2dCuda 272 | 273 | #---------------------------------------------------------------------------- 274 | 275 | def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): 276 | r"""Filter a batch of 2D images using the given 2D FIR filter. 277 | 278 | By default, the result is padded so that its shape matches the input. 279 | User-specified padding is applied on top of that, with negative values 280 | indicating cropping. Pixels outside the image are assumed to be zero. 
281 | 282 | Args: 283 | x: Float32/float64/float16 input tensor of the shape 284 | `[batch_size, num_channels, in_height, in_width]`. 285 | f: Float32 FIR filter of the shape 286 | `[filter_height, filter_width]` (non-separable), 287 | `[filter_taps]` (separable), or 288 | `None` (identity). 289 | padding: Padding with respect to the output. Can be a single number or a 290 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 291 | (default: 0). 292 | flip_filter: False = convolution, True = correlation (default: False). 293 | gain: Overall scaling factor for signal magnitude (default: 1). 294 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 295 | 296 | Returns: 297 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 298 | """ 299 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 300 | fw, fh = _get_filter_size(f) 301 | p = [ 302 | padx0 + fw // 2, 303 | padx1 + (fw - 1) // 2, 304 | pady0 + fh // 2, 305 | pady1 + (fh - 1) // 2, 306 | ] 307 | return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 308 | 309 | #---------------------------------------------------------------------------- 310 | 311 | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 312 | r"""Upsample a batch of 2D images using the given 2D FIR filter. 313 | 314 | By default, the result is padded so that its shape is a multiple of the input. 315 | User-specified padding is applied on top of that, with negative values 316 | indicating cropping. Pixels outside the image are assumed to be zero. 317 | 318 | Args: 319 | x: Float32/float64/float16 input tensor of the shape 320 | `[batch_size, num_channels, in_height, in_width]`. 321 | f: Float32 FIR filter of the shape 322 | `[filter_height, filter_width]` (non-separable), 323 | `[filter_taps]` (separable), or 324 | `None` (identity). 325 | up: Integer upsampling factor. Can be a single int or a list/tuple 326 | `[x, y]` (default: 2). 
327 | padding: Padding with respect to the output. Can be a single number or a 328 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 329 | (default: 0). 330 | flip_filter: False = convolution, True = correlation (default: False). 331 | gain: Overall scaling factor for signal magnitude (default: 1). 332 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 333 | 334 | Returns: 335 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 336 | """ 337 | upx, upy = _parse_scaling(up) 338 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 339 | fw, fh = _get_filter_size(f) 340 | p = [ 341 | padx0 + (fw + upx - 1) // 2, 342 | padx1 + (fw - upx) // 2, 343 | pady0 + (fh + upy - 1) // 2, 344 | pady1 + (fh - upy) // 2, 345 | ] 346 | return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) 347 | 348 | #---------------------------------------------------------------------------- 349 | 350 | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 351 | r"""Downsample a batch of 2D images using the given 2D FIR filter. 352 | 353 | By default, the result is padded so that its shape is a fraction of the input. 354 | User-specified padding is applied on top of that, with negative values 355 | indicating cropping. Pixels outside the image are assumed to be zero. 356 | 357 | Args: 358 | x: Float32/float64/float16 input tensor of the shape 359 | `[batch_size, num_channels, in_height, in_width]`. 360 | f: Float32 FIR filter of the shape 361 | `[filter_height, filter_width]` (non-separable), 362 | `[filter_taps]` (separable), or 363 | `None` (identity). 364 | down: Integer downsampling factor. Can be a single int or a list/tuple 365 | `[x, y]` (default: 2). 366 | padding: Padding with respect to the input. Can be a single number or a 367 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 368 | (default: 0). 
369 | flip_filter: False = convolution, True = correlation (default: False). 370 | gain: Overall scaling factor for signal magnitude (default: 1). 371 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 372 | 373 | Returns: 374 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 375 | """ 376 | downx, downy = _parse_scaling(down) 377 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 378 | fw, fh = _get_filter_size(f) 379 | p = [ 380 | padx0 + (fw - downx + 1) // 2, 381 | padx1 + (fw - downx) // 2, 382 | pady0 + (fh - downy + 1) // 2, 383 | pady1 + (fh - downy) // 2, 384 | ] 385 | return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 386 | 387 | #---------------------------------------------------------------------------- 388 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. 
This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 
63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 83 | 84 | A typical use case is to first unpickle a previous instance of a 85 | persistent class, and then upgrade it to use the latest version of 86 | the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | record_init_args = getattr(self, '_record_init_args', True) 108 | self._init_args = copy.deepcopy(args) if record_init_args else None 109 | 
self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 110 | assert orig_class.__name__ in orig_module.__dict__ 111 | _check_pickleable(self.__reduce__()) 112 | 113 | @property 114 | def init_args(self): 115 | assert self._init_args is not None 116 | return copy.deepcopy(self._init_args) 117 | 118 | @property 119 | def init_kwargs(self): 120 | assert self._init_kwargs is not None 121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 122 | 123 | def __reduce__(self): 124 | fields = list(super().__reduce__()) 125 | fields += [None] * max(3 - len(fields), 0) 126 | if fields[0] is not _reconstruct_persistent_obj: 127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 128 | fields[0] = _reconstruct_persistent_obj # reconstruct func 129 | fields[1] = (meta,) # reconstruct args 130 | fields[2] = None # state dict 131 | return tuple(fields) 132 | 133 | Decorator.__name__ = orig_class.__name__ 134 | Decorator.__module__ = orig_class.__module__ 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 
158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src: Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object. 170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 
214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet? 
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. 58 | 59 | This function is expected to be extremely cheap and can be safely 60 | called from anywhere in the training loop, loss function, or inside a 61 | `torch.nn.Module`. 62 | 63 | Warning: The current implementation expects the set of unique names to 64 | be consistent across processes. Please make sure that `report()` is 65 | called at least once for each unique name by each process, and in the 66 | same order. If a given process has no scalars to broadcast, it can do 67 | `report(name, [])` (empty list). 68 | 69 | Args: 70 | name: Arbitrary string specifying the name of the statistic. 71 | Averages are accumulated separately for each unique name. 72 | value: Arbitrary set of scalars. 
Can be a list, tuple, 73 | NumPy array, PyTorch tensor, or Python scalar. 74 | 75 | Returns: 76 | The same `value` that was passed in. 77 | """ 78 | if name not in _counters: 79 | _counters[name] = dict() 80 | 81 | elems = torch.as_tensor(value) 82 | if elems.numel() == 0: 83 | return value 84 | 85 | elems = elems.detach().flatten().to(_reduce_dtype) 86 | moments = torch.stack([ 87 | torch.ones_like(elems).sum(), 88 | elems.sum(), 89 | elems.square().sum(), 90 | ]) 91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 92 | moments = moments.to(_counter_dtype) 93 | 94 | device = moments.device 95 | if device not in _counters[name]: 96 | _counters[name][device] = torch.zeros_like(moments) 97 | _counters[name][device].add_(moments) 98 | return value 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | def report0(name, value): 103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 104 | but ignores any scalars provided by the other processes. 105 | See `report()` for further details. 106 | """ 107 | report(name, value if _rank == 0 else []) 108 | return value 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | class Collector: 113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 114 | computes their long-term averages (mean and standard deviation) over 115 | user-defined periods of time. 116 | 117 | The averages are first collected into internal counters that are not 118 | directly visible to the user. They are then copied to the user-visible 119 | state as a result of calling `update()` and can then be queried using 120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 121 | internal counters for the next round, so that the user-visible state 122 | effectively reflects averages collected between the last two calls to 123 | `update()`. 
124 | 125 | Args: 126 | regex: Regular expression defining which statistics to 127 | collect. The default is to collect everything. 128 | keep_previous: Whether to retain the previous averages if no 129 | scalars were collected on a given round 130 | (default: True). 131 | """ 132 | def __init__(self, regex='.*', keep_previous=True): 133 | self._regex = re.compile(regex) 134 | self._keep_previous = keep_previous 135 | self._cumulative = dict() 136 | self._moments = dict() 137 | self.update() 138 | self._moments.clear() 139 | 140 | def names(self): 141 | r"""Returns the names of all statistics broadcasted so far that 142 | match the regular expression specified at construction time. 143 | """ 144 | return [name for name in _counters if self._regex.fullmatch(name)] 145 | 146 | def update(self): 147 | r"""Copies current values of the internal counters to the 148 | user-visible state and resets them for the next round. 149 | 150 | If `keep_previous=True` was specified at construction time, the 151 | operation is skipped for statistics that have received no scalars 152 | since the last update, retaining their previous averages. 153 | 154 | This method performs a number of GPU-to-CPU transfers and one 155 | `torch.distributed.all_reduce()`. It is intended to be called 156 | periodically in the main training loop, typically once every 157 | N training steps. 158 | """ 159 | if not self._keep_previous: 160 | self._moments.clear() 161 | for name, cumulative in _sync(self.names()): 162 | if name not in self._cumulative: 163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 164 | delta = cumulative - self._cumulative[name] 165 | self._cumulative[name].copy_(cumulative) 166 | if float(delta[0]) != 0: 167 | self._moments[name] = delta 168 | 169 | def _get_delta(self, name): 170 | r"""Returns the raw moments that were accumulated for the given 171 | statistic between the last two calls to `update()`, or zero if 172 | no scalars were collected. 
173 | """ 174 | assert self._regex.fullmatch(name) 175 | if name not in self._moments: 176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 177 | return self._moments[name] 178 | 179 | def num(self, name): 180 | r"""Returns the number of scalars that were accumulated for the given 181 | statistic between the last two calls to `update()`, or zero if 182 | no scalars were collected. 183 | """ 184 | delta = self._get_delta(name) 185 | return int(delta[0]) 186 | 187 | def mean(self, name): 188 | r"""Returns the mean of the scalars that were accumulated for the 189 | given statistic between the last two calls to `update()`, or NaN if 190 | no scalars were collected. 191 | """ 192 | delta = self._get_delta(name) 193 | if int(delta[0]) == 0: 194 | return float('nan') 195 | return float(delta[1] / delta[0]) 196 | 197 | def std(self, name): 198 | r"""Returns the standard deviation of the scalars that were 199 | accumulated for the given statistic between the last two calls to 200 | `update()`, or NaN if no scalars were collected. 201 | """ 202 | delta = self._get_delta(name) 203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 204 | return float('nan') 205 | if int(delta[0]) == 1: 206 | return float(0) 207 | mean = float(delta[1] / delta[0]) 208 | raw_var = float(delta[2] / delta[0]) 209 | return np.sqrt(max(raw_var - np.square(mean), 0)) 210 | 211 | def as_dict(self): 212 | r"""Returns the averages accumulated between the last two calls to 213 | `update()` as a `dnnlib.EasyDict`. The contents are as follows: 214 | 215 | dnnlib.EasyDict( 216 | NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT), 217 | ... 218 | ) 219 | """ 220 | stats = dnnlib.EasyDict() 221 | for name in self.names(): 222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 223 | return stats 224 | 225 | def __getitem__(self, name): 226 | r"""Convenience getter. 
227 | `collector[name]` is a synonym for `collector.mean(name)`. 228 | """ 229 | return self.mean(name) 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def _sync(names): 234 | r"""Synchronize the global cumulative counters across devices and 235 | processes. Called internally by `Collector.update()`. 236 | """ 237 | if len(names) == 0: 238 | return [] 239 | global _sync_called 240 | _sync_called = True 241 | 242 | # Collect deltas within current rank. 243 | deltas = [] 244 | device = _sync_device if _sync_device is not None else torch.device('cpu') 245 | for name in names: 246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 247 | for counter in _counters[name].values(): 248 | delta.add_(counter.to(device)) 249 | counter.copy_(torch.zeros_like(counter)) 250 | deltas.append(delta) 251 | deltas = torch.stack(deltas) 252 | 253 | # Sum deltas across ranks. 254 | if _sync_device is not None: 255 | torch.distributed.all_reduce(deltas) 256 | 257 | # Update cumulative values. 258 | deltas = deltas.cpu() 259 | for idx, name in enumerate(names): 260 | if name not in _cumulative: 261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 262 | _cumulative[name].add_(deltas[idx]) 263 | 264 | # Return name-value pairs. 265 | return [(name, _cumulative[name]) for name in names] 266 | 267 | #---------------------------------------------------------------------------- 268 | # Convenience. 269 | 270 | default_collector = Collector() 271 | 272 | #---------------------------------------------------------------------------- 273 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Streaming images and labels from datasets created with dataset_tool.py.""" 9 | 10 | import os 11 | import numpy as np 12 | import zipfile 13 | import PIL.Image 14 | import json 15 | import torch 16 | import dnnlib 17 | 18 | try: 19 | import pyspng 20 | except ImportError: 21 | pyspng = None 22 | 23 | #---------------------------------------------------------------------------- 24 | # Abstract base class for datasets. 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, 28 | name, # Name of the dataset. 29 | raw_shape, # Shape of the raw image data (NCHW). 30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 33 | random_seed = 0, # Random seed to use when applying max_size. 34 | cache = False, # Cache images in CPU memory? 
35 | ): 36 | self._name = name 37 | self._raw_shape = list(raw_shape) 38 | self._use_labels = use_labels 39 | self._cache = cache 40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 41 | self._raw_labels = None 42 | self._label_shape = None 43 | 44 | # Apply max_size. 45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 46 | if (max_size is not None) and (self._raw_idx.size > max_size): 47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 48 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 49 | 50 | # Apply xflip. 51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 52 | if xflip: 53 | self._raw_idx = np.tile(self._raw_idx, 2) 54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 55 | 56 | def _get_raw_labels(self): 57 | if self._raw_labels is None: 58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 59 | if self._raw_labels is None: 60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 61 | assert isinstance(self._raw_labels, np.ndarray) 62 | assert self._raw_labels.shape[0] == self._raw_shape[0] 63 | assert self._raw_labels.dtype in [np.float32, np.int64] 64 | if self._raw_labels.dtype == np.int64: 65 | assert self._raw_labels.ndim == 1 66 | assert np.all(self._raw_labels >= 0) 67 | return self._raw_labels 68 | 69 | def close(self): # to be overridden by subclass 70 | pass 71 | 72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def _load_raw_labels(self): # to be overridden by subclass 76 | raise NotImplementedError 77 | 78 | def __getstate__(self): 79 | return dict(self.__dict__, _raw_labels=None) 80 | 81 | def __del__(self): 82 | try: 83 | self.close() 84 | except: 85 | pass 86 | 87 | def __len__(self): 88 | return self._raw_idx.size 89 | 90 | def __getitem__(self, idx): 91 | raw_idx = self._raw_idx[idx] 92 | image = self._cached_images.get(raw_idx, None) 93 | if image is 
None: 94 | image = self._load_raw_image(raw_idx) 95 | if self._cache: 96 | self._cached_images[raw_idx] = image 97 | assert isinstance(image, np.ndarray) 98 | assert list(image.shape) == self.image_shape 99 | assert image.dtype == np.uint8 100 | if self._xflip[idx]: 101 | assert image.ndim == 3 # CHW 102 | image = image[:, :, ::-1] 103 | return image.copy(), self.get_label(idx) 104 | 105 | def get_label(self, idx): 106 | label = self._get_raw_labels()[self._raw_idx[idx]] 107 | if label.dtype == np.int64: 108 | onehot = np.zeros(self.label_shape, dtype=np.float32) 109 | onehot[label] = 1 110 | label = onehot 111 | return label.copy() 112 | 113 | def get_details(self, idx): 114 | d = dnnlib.EasyDict() 115 | d.raw_idx = int(self._raw_idx[idx]) 116 | d.xflip = (int(self._xflip[idx]) != 0) 117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 118 | return d 119 | 120 | @property 121 | def name(self): 122 | return self._name 123 | 124 | @property 125 | def image_shape(self): 126 | return list(self._raw_shape[1:]) 127 | 128 | @property 129 | def num_channels(self): 130 | assert len(self.image_shape) == 3 # CHW 131 | return self.image_shape[0] 132 | 133 | @property 134 | def resolution(self): 135 | assert len(self.image_shape) == 3 # CHW 136 | assert self.image_shape[1] == self.image_shape[2] 137 | return self.image_shape[1] 138 | 139 | @property 140 | def label_shape(self): 141 | if self._label_shape is None: 142 | raw_labels = self._get_raw_labels() 143 | if raw_labels.dtype == np.int64: 144 | self._label_shape = [int(np.max(raw_labels)) + 1] 145 | else: 146 | self._label_shape = raw_labels.shape[1:] 147 | return list(self._label_shape) 148 | 149 | @property 150 | def label_dim(self): 151 | assert len(self.label_shape) == 1 152 | return self.label_shape[0] 153 | 154 | @property 155 | def has_labels(self): 156 | return any(x != 0 for x in self.label_shape) 157 | 158 | @property 159 | def has_onehot_labels(self): 160 | return self._get_raw_labels().dtype == np.int64 
161 | 162 | #---------------------------------------------------------------------------- 163 | # Dataset subclass that loads images recursively from the specified directory 164 | # or ZIP file. 165 | 166 | class ImageFolderDataset(Dataset): 167 | def __init__(self, 168 | path, # Path to directory or zip. 169 | resolution = None, # Ensure specific resolution, None = highest available. 170 | use_pyspng = True, # Use pyspng if available? 171 | **super_kwargs, # Additional arguments for the Dataset base class. 172 | ): 173 | self._path = path 174 | self._use_pyspng = use_pyspng 175 | self._zipfile = None 176 | 177 | if os.path.isdir(self._path): 178 | self._type = 'dir' 179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 180 | elif self._file_ext(self._path) == '.zip': 181 | self._type = 'zip' 182 | self._all_fnames = set(self._get_zipfile().namelist()) 183 | else: 184 | raise IOError('Path must point to a directory or zip') 185 | 186 | PIL.Image.init() 187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 188 | if len(self._image_fnames) == 0: 189 | raise IOError('No image files found in the specified path') 190 | 191 | name = os.path.splitext(os.path.basename(self._path))[0] 192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 194 | raise IOError('Image files do not match the specified resolution') 195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 196 | 197 | @staticmethod 198 | def _file_ext(fname): 199 | return os.path.splitext(fname)[1].lower() 200 | 201 | def _get_zipfile(self): 202 | assert self._type == 'zip' 203 | if self._zipfile is None: 204 | self._zipfile = zipfile.ZipFile(self._path) 205 | return self._zipfile 206 | 207 | def _open_file(self, 
fname): 208 | if self._type == 'dir': 209 | return open(os.path.join(self._path, fname), 'rb') 210 | if self._type == 'zip': 211 | return self._get_zipfile().open(fname, 'r') 212 | return None 213 | 214 | def close(self): 215 | try: 216 | if self._zipfile is not None: 217 | self._zipfile.close() 218 | finally: 219 | self._zipfile = None 220 | 221 | def __getstate__(self): 222 | return dict(super().__getstate__(), _zipfile=None) 223 | 224 | def _load_raw_image(self, raw_idx): 225 | fname = self._image_fnames[raw_idx] 226 | with self._open_file(fname) as f: 227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 228 | image = pyspng.load(f.read()) 229 | else: 230 | image = np.array(PIL.Image.open(f)) 231 | if image.ndim == 2: 232 | image = image[:, :, np.newaxis] # HW => HWC 233 | image = image.transpose(2, 0, 1) # HWC => CHW 234 | return image 235 | 236 | def _load_raw_labels(self): 237 | fname = 'dataset.json' 238 | if fname not in self._all_fnames: 239 | return None 240 | with self._open_file(fname) as f: 241 | labels = json.load(f)['labels'] 242 | if labels is None: 243 | return None 244 | labels = dict(labels) 245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 246 | labels = np.array(labels) 247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 248 | return labels 249 | 250 | #---------------------------------------------------------------------------- 251 | -------------------------------------------------------------------------------- /training/di_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Weijian Luo, Peking University . All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Train one-step diffusion-based generative model using the techniques described in the 9 | paper "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models" 10 | https://github.com/pkulwj1994/diff_instruct 11 | 12 | Code was modified from the paper "Elucidating the Design Space of Diffusion-Based Generative Models" 13 | https://github.com/NVlabs/edm 14 | """ 15 | 16 | """Loss functions used in the paper 17 | "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models".""" 18 | 19 | import torch 20 | from torch_utils import persistence 21 | from torch.distributions.log_normal import LogNormal 22 | import numpy as np 23 | 24 | #---------------------------------------------------------------------------- 25 | # Loss functions in the EDM formulation (log-normal noise levels, sigma_data weighting) 26 | # used in the paper "Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models". 
27 | 28 | @persistence.persistent_class 29 | class DI_EDMLoss: 30 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 31 | self.P_mean = P_mean 32 | self.P_std = P_std 33 | self.sigma_data = sigma_data 34 | 35 | def gloss(self, Sd, Sg, images, labels=None, augment_pipe=None): 36 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 37 | 38 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 39 | weight = 1.0 40 | 41 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, torch.zeros(images.shape[0], 9).to(images.device)) 42 | n = torch.randn_like(y) * sigma 43 | 44 | Sg.train(), Sd.train() 45 | with torch.no_grad(): 46 | cuda_rng_state = torch.cuda.get_rng_state() 47 | Dd_yn = Sd(y + n, sigma, labels, augment_labels=augment_labels) 48 | torch.cuda.set_rng_state(cuda_rng_state) 49 | Dg_yn = Sg(y + n, sigma, labels, augment_labels=augment_labels) 50 | Sd.eval() 51 | 52 | loss = weight * ((Dg_yn - Dd_yn) * images) 53 | 54 | return loss 55 | 56 | def __call__(self, net, images, labels=None, augment_pipe=None): 57 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 58 | 59 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 60 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 61 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 62 | n = torch.randn_like(y) * sigma 63 | 64 | net.train() 65 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 66 | 67 | loss = weight * ((D_yn - y) ** 2) 68 | return loss --------------------------------------------------------------------------------
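A minimal, torch-free sketch of the noise-level sampling and loss weighting that `DI_EDMLoss.__call__` above appears to perform: `sigma` is drawn log-normally from `P_mean` and `P_std`, and the squared error is scaled by `(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2`. The helper names `sample_sigma` and `edm_weight` are hypothetical and not part of this repository; the scalar form is an illustration only, whereas the real code operates on batched tensors.

```python
import math
import random

def sample_sigma(rng, p_mean=-1.2, p_std=1.2):
    """Draw one noise level with ln(sigma) ~ N(p_mean, p_std^2) (hypothetical helper)."""
    return math.exp(rng.gauss(0.0, 1.0) * p_std + p_mean)

def edm_weight(sigma, sigma_data=0.5):
    """Per-sample loss weight (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2

# Seeded generator so the draws are reproducible within one Python version.
rng = random.Random(0)
sigmas = [sample_sigma(rng) for _ in range(5)]
weights = [edm_weight(s) for s in sigmas]

for s, w in zip(sigmas, weights):
    print(f"sigma={s:.4f} weight={w:.4f}")
```

Small noise levels receive large weights and large noise levels approach `1 / sigma_data^2`. Note that `gloss` above does not use this weighting; it fixes `weight = 1.0`.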