├── LICENSE.txt ├── NOTICE.txt ├── README.md ├── calc_metrics.py ├── dnnlib ├── __init__.py └── util.py ├── figures ├── d_vis.png ├── feat_vis.png ├── model.png └── sample.png ├── generate.py ├── legacy.py ├── metrics ├── __init__.py ├── frechet_inception_distance.py ├── inception_score.py ├── kernel_inception_distance.py ├── metric_main.py ├── metric_utils.py ├── perceptual_path_length.py └── precision_recall.py ├── tools └── visualize_gfeat.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── train.py └── training ├── __init__.py ├── augment.py ├── dataset.py ├── loss.py ├── loss_ggdr.py ├── networks.py ├── networks_ggdr.py └── training_loop.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. 
copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. 
Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 
86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | GGDR is based on the StyleGAN2-ADA project (https://github.com/NVlabs/stylegan2-ada-pytorch), 2 | and borrows heavily from its code. 3 | 4 | training/loss_ggdr.py, training/networks_ggdr.py 5 | - Copyright (c) 2022-present NAVER Corp. 6 | all other files 7 | - Copyright (c) 2021, NVIDIA Corporation. 8 | 9 | under NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 10 | 11 | --- 12 | 13 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 14 | 15 | 16 | ======================================================================= 17 | 18 | 1. Definitions 19 | 20 | "Licensor" means any person or entity that distributes its Work. 21 | 22 | "Software" means the original work of authorship made available under 23 | this License. 24 | 25 | "Work" means the Software and any additions to or derivative works of 26 | the Software that are made available under this License. 27 | 28 | The terms "reproduce," "reproduction," "derivative works," and 29 | "distribution" have the meaning as provided under U.S.
copyright law; 30 | provided, however, that for the purposes of this License, derivative 31 | works shall not include works that remain separable from, or merely 32 | link (or bind by name) to the interfaces of, the Work. 33 | 34 | Works, including the Software, are "made available" under this License 35 | by including in or with the Work either (a) a copyright notice 36 | referencing the applicability of this License to the Work, or (b) a 37 | copy of this License. 38 | 39 | 2. License Grants 40 | 41 | 2.1 Copyright Grant. Subject to the terms and conditions of this 42 | License, each Licensor grants to you a perpetual, worldwide, 43 | non-exclusive, royalty-free, copyright license to reproduce, 44 | prepare derivative works of, publicly display, publicly perform, 45 | sublicense and distribute its Work and any resulting derivative 46 | works in any form. 47 | 48 | 3. Limitations 49 | 50 | 3.1 Redistribution. You may reproduce or distribute the Work only 51 | if (a) you do so under this License, (b) you include a complete 52 | copy of this License with your distribution, and (c) you retain 53 | without modification any copyright, patent, trademark, or 54 | attribution notices that are present in the Work. 55 | 56 | 3.2 Derivative Works. You may specify that additional or different 57 | terms apply to the use, reproduction, and distribution of your 58 | derivative works of the Work ("Your Terms") only if (a) Your Terms 59 | provide that the use limitation in Section 3.3 applies to your 60 | derivative works, and (b) you identify the specific derivative 61 | works that are subject to Your Terms. Notwithstanding Your Terms, 62 | this License (including the redistribution requirements in Section 63 | 3.1) will continue to apply to the Work itself. 64 | 65 | 3.3 Use Limitation. The Work and any derivative works thereof only 66 | may be used or intended for use non-commercially. 
Notwithstanding 67 | the foregoing, NVIDIA and its affiliates may use the Work and any 68 | derivative works commercially. As used herein, "non-commercially" 69 | means for research or evaluation purposes only. 70 | 71 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 72 | against any Licensor (including any claim, cross-claim or 73 | counterclaim in a lawsuit) to enforce any patents that you allege 74 | are infringed by any Work, then your rights under this License from 75 | such Licensor (including the grant in Section 2.1) will terminate 76 | immediately. 77 | 78 | 3.5 Trademarks. This License does not grant any rights to use any 79 | Licensor’s or its affiliates’ names, logos, or trademarks, except 80 | as necessary to reproduce the notices described in this License. 81 | 82 | 3.6 Termination. If you violate any term of this License, then your 83 | rights under this License (including the grant in Section 2.1) will 84 | terminate immediately. 85 | 86 | 4. Disclaimer of Warranty. 87 | 88 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 89 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 90 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 91 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 92 | THIS LICENSE. 93 | 94 | 5. Limitation of Liability. 
95 | 96 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 97 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 98 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 99 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 100 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 101 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 102 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 103 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 104 | THE POSSIBILITY OF SUCH DAMAGES. 105 | 106 | ======================================================================= 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GGDR - Generator-Guided Regularization for Discriminator (Official PyTorch Implementation) 2 | **[Generator Knows What Discriminator Should Learn in Unconditional GANs (ECCV 2022)](http://arxiv.org/abs/2207.13320)** \ 3 | Gayoung Lee1, Hyunsu Kim1, Junho Kim1, Seonghyeon Kim2, Jung-Woo Ha1, Yunjey Choi1 4 | 5 | 1NAVER AI Lab, 2NAVER CLOVA 6 | 7 |
8 | 9 |
10 | 11 | > **Abstract** *Recent conditional image generation methods benefit from dense supervision such as segmentation label maps to achieve high-fidelity. However, it is rarely explored to employ dense supervision for unconditional image generation. Here we explore the efficacy of dense supervision in unconditional generation and find generator feature maps can be an alternative of cost-expensive semantic label maps. From our empirical evidences, we propose a new **generator-guided discriminator regularization (GGDR)** in which the generator feature maps supervise the discriminator to have rich semantic representations in unconditional generation. In specific, we employ an encoder-decoder architecture for discriminator, which is trained to reconstruct the generator feature maps given fake images as inputs. Extensive experiments on multiple datasets show that our GGDR consistently improves the performance of baseline methods in terms of quantitative and qualitative aspects. Code will be publicly available for the research community.* 12 | 13 | ## Credit 14 | We built GGDR on top of [StyleGAN2-ADA-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch) and reuse much of its code. 15 | 16 | ## Usage 17 | Usage of this repository is almost the same as [StyleGAN2-ADA-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch), except for the GGDR options. See their repository for more detailed instructions. 18 | 19 | #### Training StyleGAN2-ADA with GGDR 20 | ``` 21 | > python train.py --outdir=training-runs --reg_type=ggdr --ggdr_res=64 --gpus=8 --cfg=paper256 --data=./datasets/ffhq256.zip 22 | ``` 23 | Below are some additional arguments that can be customized. 24 | - ```--reg_type=ggdr``` Enable GGDR (default: disabled) 25 | - ```--ggdr_res=64``` Set the resolution of the generator feature map used as the GGDR target (default: 64). If you use smaller images (e.g. cifar10), it is recommended to set this to resolution / 4 (e.g. 8 for cifar10).
26 | - ```--aug=noaug``` Disables ADA (default: enabled) 27 | - ```--mirror=1``` Enables x-flips (default: disabled) 28 | 29 | #### Inference with trained model 30 | ``` 31 | > python generate.py --outdir=out --seeds=100-200 --network=PATH_TO_MODEL 32 | ``` 33 | 34 | ## Results 35 | ### Selective samples in the paper 36 |
37 | 38 |
39 | 40 | ### Discriminator feature map visualization 41 |
42 | 43 |
44 | 45 | 46 | 47 | ## License 48 | Licensed under NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA). 49 | 50 | ## Citation 51 | ```bibtex 52 | @inproceedings{lee2022ggdr, 53 | title={Generator Knows What Discriminator Should Learn in Unconditional GANs}, 54 | author={Lee, Gayoung and Kim, Hyunsu and Kim, Junho and Kim, Seonghyeon and Ha, Jung-Woo and Choi, Yunjey}, 55 | booktitle={ECCV}, 56 | year={2022} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | import os 12 | import click 13 | import json 14 | import tempfile 15 | import copy 16 | import torch 17 | import dnnlib 18 | 19 | import legacy 20 | from metrics import metric_main 21 | from metrics import metric_utils 22 | from torch_utils import training_stats 23 | from torch_utils import custom_ops 24 | from torch_utils import misc 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def subprocess_fn(rank, args, temp_dir): 29 | dnnlib.util.Logger(should_flush=True) 30 | 31 | # Init torch.distributed. 
32 | if args.num_gpus > 1: 33 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 34 | if os.name == 'nt': 35 | init_method = 'file:///' + init_file.replace('\\', '/') 36 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 37 | else: 38 | init_method = f'file://{init_file}' 39 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 40 | 41 | # Init torch_utils. 42 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 43 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 44 | if rank != 0 or not args.verbose: 45 | custom_ops.verbosity = 'none' 46 | 47 | # Print network summary. 48 | device = torch.device('cuda', rank) 49 | torch.backends.cudnn.benchmark = True 50 | torch.backends.cuda.matmul.allow_tf32 = False 51 | torch.backends.cudnn.allow_tf32 = False 52 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 53 | if rank == 0 and args.verbose: 54 | z = torch.empty([1, G.z_dim], device=device) 55 | c = torch.empty([1, G.c_dim], device=device) 56 | misc.print_module_summary(G, [z, c]) 57 | 58 | # Calculate each metric. 59 | for metric in args.metrics: 60 | if rank == 0 and args.verbose: 61 | print(f'Calculating {metric}...') 62 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 63 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, 64 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) 65 | if rank == 0: 66 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 67 | if rank == 0 and args.verbose: 68 | print() 69 | 70 | # Done. 
71 | if rank == 0 and args.verbose: 72 | print('Exiting...') 73 | 74 | #---------------------------------------------------------------------------- 75 | 76 | class CommaSeparatedList(click.ParamType): 77 | name = 'list' 78 | 79 | def convert(self, value, param, ctx): 80 | _ = param, ctx 81 | if value is None or value.lower() == 'none' or value == '': 82 | return [] 83 | return value.split(',') 84 | 85 | #---------------------------------------------------------------------------- 86 | 87 | @click.command() 88 | @click.pass_context 89 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) 90 | @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True) 91 | @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH') 92 | @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL') 93 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) 94 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) 95 | 96 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): 97 | """Calculate quality metrics for previous training run or pretrained network pickle. 98 | 99 | Examples: 100 | 101 | \b 102 | # Previous training run: look up options automatically, save result to JSONL file. 103 | python calc_metrics.py --metrics=pr50k3_full \\ 104 | --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl 105 | 106 | \b 107 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout. 
108 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\ 109 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl 110 | 111 | Available metrics: 112 | 113 | \b 114 | ADA paper: 115 | fid50k_full Frechet inception distance against the full dataset. 116 | kid50k_full Kernel inception distance against the full dataset. 117 | pr50k3_full Precision and recall againt the full dataset. 118 | is50k Inception score for CIFAR-10. 119 | 120 | \b 121 | StyleGAN and StyleGAN2 papers: 122 | fid50k Frechet inception distance against 50k real images. 123 | kid50k Kernel inception distance against 50k real images. 124 | pr50k3 Precision and recall against 50k real images. 125 | ppl2_wend Perceptual path length in W at path endpoints against full image. 126 | ppl_zfull Perceptual path length in Z for full paths against cropped image. 127 | ppl_wfull Perceptual path length in W for full paths against cropped image. 128 | ppl_zend Perceptual path length in Z at path endpoints against cropped image. 129 | ppl_wend Perceptual path length in W at path endpoints against cropped image. 130 | """ 131 | dnnlib.util.Logger(should_flush=True) 132 | 133 | # Validate arguments. 134 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) 135 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 136 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 137 | if not args.num_gpus >= 1: 138 | ctx.fail('--gpus must be at least 1') 139 | 140 | # Load network. 
141 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 142 | ctx.fail('--network must point to a file or URL') 143 | if args.verbose: 144 | print(f'Loading network from "{network_pkl}"...') 145 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 146 | network_dict = legacy.load_network_pkl(f) 147 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 148 | 149 | # Initialize dataset options. 150 | if data is not None: 151 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) 152 | elif network_dict['training_set_kwargs'] is not None: 153 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 154 | else: 155 | ctx.fail('Could not look up dataset options; please specify --data') 156 | 157 | # Finalize dataset options. 158 | args.dataset_kwargs.resolution = args.G.img_resolution 159 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 160 | if mirror is not None: 161 | args.dataset_kwargs.xflip = mirror 162 | 163 | # Print dataset options. 164 | if args.verbose: 165 | print('Dataset options:') 166 | print(json.dumps(args.dataset_kwargs, indent=2)) 167 | 168 | # Locate run dir. 169 | args.run_dir = None 170 | if os.path.isfile(network_pkl): 171 | pkl_dir = os.path.dirname(network_pkl) 172 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 173 | args.run_dir = pkl_dir 174 | 175 | # Launch processes. 
176 | if args.verbose: 177 | print('Launching processes...') 178 | torch.multiprocessing.set_start_method('spawn') 179 | with tempfile.TemporaryDirectory() as temp_dir: 180 | if args.num_gpus == 1: 181 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 182 | else: 183 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 184 | 185 | #---------------------------------------------------------------------------- 186 | 187 | if __name__ == "__main__": 188 | calc_metrics() # pylint: disable=no-value-for-parameter 189 | 190 | #---------------------------------------------------------------------------- 191 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /figures/d_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/d_vis.png -------------------------------------------------------------------------------- /figures/feat_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/feat_vis.png -------------------------------------------------------------------------------- /figures/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/model.png -------------------------------------------------------------------------------- /figures/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/sample.png -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def num_range(s: str) -> List[int]: 26 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' 27 | 28 | range_re = re.compile(r'^(\d+)-(\d+)$') 29 | m = range_re.match(s) 30 | if m: 31 | return list(range(int(m.group(1)), int(m.group(2))+1)) 32 | vals = s.split(',') 33 | return [int(x) for x in vals] 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | @click.command() 38 | @click.pass_context 39 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 40 | @click.option('--seeds', type=num_range, help='List of random seeds') 41 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 42 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 43 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 44 | @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE') 45 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 46 | def generate_images( 47 | ctx: click.Context, 48 | network_pkl: str, 49 | seeds: Optional[List[int]], 50 | truncation_psi: float, 51 | noise_mode: str, 52 | outdir: str, 53 | class_idx: Optional[int], 54 | projected_w: Optional[str] 55 | ): 56 | """Generate images using pretrained network pickle. 
57 | 58 | Examples: 59 | 60 | \b 61 | # Generate curated MetFaces images without truncation (Fig.10 left) 62 | python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\ 63 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 64 | 65 | \b 66 | # Generate uncurated MetFaces images with truncation (Fig.12 upper left) 67 | python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 68 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 69 | 70 | \b 71 | # Generate class conditional CIFAR-10 images (Fig.17 left, Car) 72 | python generate.py --outdir=out --seeds=0-35 --class=1 \\ 73 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl 74 | 75 | \b 76 | # Render an image from projected W 77 | python generate.py --outdir=out --projected_w=projected_w.npz \\ 78 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 79 | """ 80 | 81 | print('Loading networks from "%s"...' % network_pkl) 82 | device = torch.device('cuda') 83 | with dnnlib.util.open_url(network_pkl) as f: 84 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 85 | 86 | print("start here") 87 | 88 | print(G) 89 | 90 | os.makedirs(outdir, exist_ok=True) 91 | 92 | # Synthesize the result of a W projection. 
93 | if projected_w is not None: 94 | if seeds is not None: 95 | print ('warn: --seeds is ignored when using --projected-w') 96 | print(f'Generating images from projected W "{projected_w}"') 97 | ws = np.load(projected_w)['w'] 98 | ws = torch.tensor(ws, device=device) # pylint: disable=not-callable 99 | assert ws.shape[1:] == (G.num_ws, G.w_dim) 100 | for idx, w in enumerate(ws): 101 | img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode) 102 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 103 | img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png') 104 | return 105 | 106 | if seeds is None: 107 | ctx.fail('--seeds option is required when not using --projected-w') 108 | 109 | # Labels. 110 | label = torch.zeros([1, G.c_dim], device=device) 111 | if G.c_dim != 0: 112 | if class_idx is None: 113 | ctx.fail('Must specify class label with --class when using a conditional network') 114 | label[:, class_idx] = 1 115 | else: 116 | if class_idx is not None: 117 | print ('warn: --class=lbl ignored when running on an unconditional network') 118 | 119 | # Generate images. 120 | for seed_idx, seed in enumerate(seeds): 121 | print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) 122 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 123 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) 124 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 125 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 126 | 127 | 128 | #---------------------------------------------------------------------------- 129 | 130 | if __name__ == "__main__": 131 | generate_images() # pylint: disable=no-value-for-parameter 132 | 133 | #---------------------------------------------------------------------------- 134 | -------------------------------------------------------------------------------- /legacy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import click 10 | import pickle 11 | import re 12 | import copy 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | from torch_utils import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def load_network_pkl(f, force_fp16=False): 21 | data = _LegacyUnpickler(f).load() 22 | 23 | # Legacy TensorFlow pickle => convert. 
24 | if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): 25 | tf_G, tf_D, tf_Gs = data 26 | G = convert_tf_generator(tf_G) 27 | D = convert_tf_discriminator(tf_D) 28 | G_ema = convert_tf_generator(tf_Gs) 29 | data = dict(G=G, D=D, G_ema=G_ema) 30 | 31 | # Add missing fields. 32 | if 'training_set_kwargs' not in data: 33 | data['training_set_kwargs'] = None 34 | if 'augment_pipe' not in data: 35 | data['augment_pipe'] = None 36 | 37 | # Validate contents. 38 | # assert isinstance(data['G'], torch.nn.Module) 39 | # assert isinstance(data['D'], torch.nn.Module) 40 | assert isinstance(data['G_ema'], torch.nn.Module) 41 | # assert isinstance(data['training_set_kwargs'], (dict, type(None))) 42 | # assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) 43 | 44 | # Force FP16. 45 | if force_fp16: 46 | for key in ['G', 'D', 'G_ema']: 47 | old = data[key] 48 | kwargs = copy.deepcopy(old.init_kwargs) 49 | if key.startswith('G'): 50 | kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {})) 51 | kwargs.synthesis_kwargs.num_fp16_res = 4 52 | kwargs.synthesis_kwargs.conv_clamp = 256 53 | if key.startswith('D'): 54 | kwargs.num_fp16_res = 4 55 | kwargs.conv_clamp = 256 56 | if kwargs != old.init_kwargs: 57 | new = type(old)(**kwargs).eval().requires_grad_(False) 58 | misc.copy_params_and_buffers(old, new, require_all=True) 59 | data[key] = new 60 | return data 61 | 62 | #---------------------------------------------------------------------------- 63 | 64 | class _TFNetworkStub(dnnlib.EasyDict): 65 | pass 66 | 67 | class _LegacyUnpickler(pickle.Unpickler): 68 | def find_class(self, module, name): 69 | if module == 'dnnlib.tflib.network' and name == 'Network': 70 | return _TFNetworkStub 71 | return super().find_class(module, name) 72 | 73 | #---------------------------------------------------------------------------- 74 | 75 | def _collect_tf_params(tf_net): 76 | # pylint: 
disable=protected-access 77 | tf_params = dict() 78 | def recurse(prefix, tf_net): 79 | for name, value in tf_net.variables: 80 | tf_params[prefix + name] = value 81 | for name, comp in tf_net.components.items(): 82 | recurse(prefix + name + '/', comp) 83 | recurse('', tf_net) 84 | return tf_params 85 | 86 | #---------------------------------------------------------------------------- 87 | 88 | def _populate_module_params(module, *patterns): 89 | for name, tensor in misc.named_params_and_buffers(module): 90 | found = False 91 | value = None 92 | for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): 93 | match = re.fullmatch(pattern, name) 94 | if match: 95 | found = True 96 | if value_fn is not None: 97 | value = value_fn(*match.groups()) 98 | break 99 | try: 100 | assert found 101 | if value is not None: 102 | tensor.copy_(torch.from_numpy(np.array(value))) 103 | except: 104 | print(name, list(tensor.shape)) 105 | raise 106 | 107 | #---------------------------------------------------------------------------- 108 | 109 | def convert_tf_generator(tf_G): 110 | if tf_G.version < 4: 111 | raise ValueError('TensorFlow pickle version too low') 112 | 113 | # Collect kwargs. 114 | tf_kwargs = tf_G.static_kwargs 115 | known_kwargs = set() 116 | def kwarg(tf_name, default=None, none=None): 117 | known_kwargs.add(tf_name) 118 | val = tf_kwargs.get(tf_name, default) 119 | return val if val is not None else none 120 | 121 | # Convert kwargs. 
122 | kwargs = dnnlib.EasyDict( 123 | z_dim = kwarg('latent_size', 512), 124 | c_dim = kwarg('label_size', 0), 125 | w_dim = kwarg('dlatent_size', 512), 126 | img_resolution = kwarg('resolution', 1024), 127 | img_channels = kwarg('num_channels', 3), 128 | mapping_kwargs = dnnlib.EasyDict( 129 | num_layers = kwarg('mapping_layers', 8), 130 | embed_features = kwarg('label_fmaps', None), 131 | layer_features = kwarg('mapping_fmaps', None), 132 | activation = kwarg('mapping_nonlinearity', 'lrelu'), 133 | lr_multiplier = kwarg('mapping_lrmul', 0.01), 134 | w_avg_beta = kwarg('w_avg_beta', 0.995, none=1), 135 | ), 136 | synthesis_kwargs = dnnlib.EasyDict( 137 | channel_base = kwarg('fmap_base', 16384) * 2, 138 | channel_max = kwarg('fmap_max', 512), 139 | num_fp16_res = kwarg('num_fp16_res', 0), 140 | conv_clamp = kwarg('conv_clamp', None), 141 | architecture = kwarg('architecture', 'skip'), 142 | resample_filter = kwarg('resample_kernel', [1,3,3,1]), 143 | use_noise = kwarg('use_noise', True), 144 | activation = kwarg('nonlinearity', 'lrelu'), 145 | ), 146 | ) 147 | 148 | # Check for unknown kwargs. 149 | kwarg('truncation_psi') 150 | kwarg('truncation_cutoff') 151 | kwarg('style_mixing_prob') 152 | kwarg('structure') 153 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) 154 | if len(unknown_kwargs) > 0: 155 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) 156 | 157 | # Collect params. 158 | tf_params = _collect_tf_params(tf_G) 159 | for name, value in list(tf_params.items()): 160 | match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name) 161 | if match: 162 | r = kwargs.img_resolution // (2 ** int(match.group(1))) 163 | tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value 164 | kwargs.synthesis.kwargs.architecture = 'orig' 165 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') 166 | 167 | # Convert params. 
168 | from training import networks 169 | G = networks.Generator(**kwargs).eval().requires_grad_(False) 170 | # pylint: disable=unnecessary-lambda 171 | _populate_module_params(G, 172 | r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], 173 | r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), 174 | r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], 175 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), 176 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], 177 | r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], 178 | r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), 179 | r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], 180 | r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], 181 | r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], 182 | r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), 183 | r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, 184 | r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), 185 | r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], 186 | r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], 187 | r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], 188 | r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), 189 | r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, 190 | 
r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), 191 | r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], 192 | r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], 193 | r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], 194 | r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), 195 | r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, 196 | r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), 197 | r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], 198 | r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), 199 | r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, 200 | r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), 201 | r'.*\.resample_filter', None, 202 | ) 203 | return G 204 | 205 | #---------------------------------------------------------------------------- 206 | 207 | def convert_tf_discriminator(tf_D): 208 | if tf_D.version < 4: 209 | raise ValueError('TensorFlow pickle version too low') 210 | 211 | # Collect kwargs. 212 | tf_kwargs = tf_D.static_kwargs 213 | known_kwargs = set() 214 | def kwarg(tf_name, default=None): 215 | known_kwargs.add(tf_name) 216 | return tf_kwargs.get(tf_name, default) 217 | 218 | # Convert kwargs. 
219 | kwargs = dnnlib.EasyDict( 220 | c_dim = kwarg('label_size', 0), 221 | img_resolution = kwarg('resolution', 1024), 222 | img_channels = kwarg('num_channels', 3), 223 | architecture = kwarg('architecture', 'resnet'), 224 | channel_base = kwarg('fmap_base', 16384) * 2, 225 | channel_max = kwarg('fmap_max', 512), 226 | num_fp16_res = kwarg('num_fp16_res', 0), 227 | conv_clamp = kwarg('conv_clamp', None), 228 | cmap_dim = kwarg('mapping_fmaps', None), 229 | block_kwargs = dnnlib.EasyDict( 230 | activation = kwarg('nonlinearity', 'lrelu'), 231 | resample_filter = kwarg('resample_kernel', [1,3,3,1]), 232 | freeze_layers = kwarg('freeze_layers', 0), 233 | ), 234 | mapping_kwargs = dnnlib.EasyDict( 235 | num_layers = kwarg('mapping_layers', 0), 236 | embed_features = kwarg('mapping_fmaps', None), 237 | layer_features = kwarg('mapping_fmaps', None), 238 | activation = kwarg('nonlinearity', 'lrelu'), 239 | lr_multiplier = kwarg('mapping_lrmul', 0.1), 240 | ), 241 | epilogue_kwargs = dnnlib.EasyDict( 242 | mbstd_group_size = kwarg('mbstd_group_size', None), 243 | mbstd_num_channels = kwarg('mbstd_num_features', 1), 244 | activation = kwarg('nonlinearity', 'lrelu'), 245 | ), 246 | ) 247 | 248 | # Check for unknown kwargs. 249 | kwarg('structure') 250 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) 251 | if len(unknown_kwargs) > 0: 252 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) 253 | 254 | # Collect params. 255 | tf_params = _collect_tf_params(tf_D) 256 | for name, value in list(tf_params.items()): 257 | match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) 258 | if match: 259 | r = kwargs.img_resolution // (2 ** int(match.group(1))) 260 | tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value 261 | kwargs.architecture = 'orig' 262 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') 263 | 264 | # Convert params. 
265 | from training import networks 266 | D = networks.Discriminator(**kwargs).eval().requires_grad_(False) 267 | # pylint: disable=unnecessary-lambda 268 | _populate_module_params(D, 269 | r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), 270 | r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], 271 | r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), 272 | r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], 273 | r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), 274 | r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), 275 | r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], 276 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), 277 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], 278 | r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), 279 | r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], 280 | r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), 281 | r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], 282 | r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), 283 | r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], 284 | r'.*\.resample_filter', None, 285 | ) 286 | return D 287 | 288 | #---------------------------------------------------------------------------- 289 | 290 | @click.command() 291 | @click.option('--source', help='Input pickle', required=True, metavar='PATH') 292 | @click.option('--dest', help='Output pickle', required=True, metavar='PATH') 293 | @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) 294 | def convert_network_pickle(source, dest, force_fp16): 295 | 
"""Convert legacy network pickle into the native PyTorch format. 296 | 297 | The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. 298 | It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. 299 | 300 | Example: 301 | 302 | \b 303 | python legacy.py \\ 304 | --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ 305 | --dest=stylegan2-cat-config-f.pkl 306 | """ 307 | print(f'Loading "{source}"...') 308 | with dnnlib.util.open_url(source) as f: 309 | data = load_network_pkl(f, force_fp16=force_fp16) 310 | print(f'Saving "{dest}"...') 311 | with open(dest, 'wb') as f: 312 | pickle.dump(data, f) 313 | print('Done.') 314 | 315 | #---------------------------------------------------------------------------- 316 | 317 | if __name__ == "__main__": 318 | convert_network_pickle() # pylint: disable=no-value-for-parameter 319 | 320 | #---------------------------------------------------------------------------- 321 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 
# All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Frechet Inception Distance (FID) from the paper
"GANs trained by a two time-scale update rule converge to a local Nash
equilibrium". Matches the original implementation by Heusel et al. at
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""

import numpy as np
import scipy.linalg
from . import metric_utils

#----------------------------------------------------------------------------

def compute_fid(opts, max_real, num_gen):
    """Compute the FID between dataset images and generator samples.

    Args:
        opts: metric options forwarded to ``metric_utils``; ``opts.rank``
            decides whether this process produces the final number.
        max_real: cap on the number of dataset images featurized
            (forwarded as ``max_items``; ``None`` is passed through).
        num_gen: number of generated images featurized.

    Returns:
        The FID as a Python float on rank 0, ``float('nan')`` on other ranks.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
    shared = dict(opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, capture_mean_cov=True)

    # rel_lo/rel_hi are progress-range fractions handed to the progress monitor.
    real_stats = metric_utils.compute_feature_stats_for_dataset(rel_lo=0, rel_hi=0, max_items=max_real, **shared)
    mu_real, sigma_real = real_stats.get_mean_cov()
    gen_stats = metric_utils.compute_feature_stats_for_generator(rel_lo=0, rel_hi=1, max_items=num_gen, **shared)
    mu_gen, sigma_gen = gen_stats.get_mean_cov()

    # Only the leader process evaluates the distance.
    if opts.rank == 0:
        mean_term = np.square(mu_gen - mu_real).sum()
        covmean, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
        # np.real drops any imaginary component left over from sqrtm.
        return float(np.real(mean_term + np.trace(sigma_gen + sigma_real - covmean * 2)))
    return float('nan')

#----------------------------------------------------------------------------
# ==== /metrics/inception_score.py ====
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Inception Score (IS) from the paper "Improved techniques for training
GANs". Matches the original implementation by Salimans et al. at
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_is(opts, num_gen, num_splits):
    """Compute the Inception Score of ``num_gen`` generated images.

    The detector outputs (presumably softmax class probabilities -- confirm
    against the detector) are split into ``num_splits`` contiguous chunks.
    Each chunk scores ``exp(mean_i KL(p_i || mean_j p_j))``.

    Returns:
        ``(mean, std)`` of the per-chunk scores as Python floats on rank 0,
        ``(nan, nan)`` on other ranks.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.

    stats = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen)
    probs = stats.get_all()

    if opts.rank != 0:
        return float('nan'), float('nan')

    split_scores = []
    for split_idx in range(num_splits):
        lo = split_idx * num_gen // num_splits
        hi = (split_idx + 1) * num_gen // num_splits
        chunk = probs[lo:hi]
        marginal = np.mean(chunk, axis=0, keepdims=True)
        per_image_kl = np.sum(chunk * (np.log(chunk) - np.log(marginal)), axis=1)
        split_scores.append(np.exp(np.mean(per_image_kl)))
    return float(np.mean(split_scores)), float(np.std(split_scores))

#----------------------------------------------------------------------------
# ==== /metrics/kernel_inception_distance.py ====
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
GANs". Matches the original implementation by Binkowski et al. at
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    """Compute the KID between dataset features and generator features.

    Uses the cubic polynomial kernel ``(a . b / dim + 1) ** 3``.
    ``num_subsets`` random subsets of equal size are drawn without
    replacement from each feature set. The diagonal of the within-set
    kernel matrices is excluded from the estimate, and the result is
    averaged over subsets.

    Returns:
        The KID as a Python float on rank 0, ``float('nan')`` on other ranks.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
    shared = dict(opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, capture_all=True)

    real_features = metric_utils.compute_feature_stats_for_dataset(rel_lo=0, rel_hi=0, max_items=max_real, **shared).get_all()
    gen_features = metric_utils.compute_feature_stats_for_generator(rel_lo=0, rel_hi=1, max_items=num_gen, **shared).get_all()

    if opts.rank != 0:
        return float('nan')

    dim = real_features.shape[1]
    subset_size = min(real_features.shape[0], gen_features.shape[0], max_subset_size)

    def poly_kernel(a, b):
        return (a @ b.T / dim + 1) ** 3

    total = 0
    for _ in range(num_subsets):
        # Draw the generated subset first, then the real one (fixes RNG call order).
        x = gen_features[np.random.choice(gen_features.shape[0], subset_size, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], subset_size, replace=False)]
        within = poly_kernel(x, x) + poly_kernel(y, y)
        across = poly_kernel(x, y)
        total += (within.sum() - np.diag(within).sum()) / (subset_size - 1) - across.sum() * 2 / subset_size
    return float(total / num_subsets / subset_size)

#----------------------------------------------------------------------------
# ==== /metrics/metric_main.py ====
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import time
import json
import torch
import dnnlib

from . import metric_utils
from . import frechet_inception_distance
from . import kernel_inception_distance
from . import precision_recall
from . import perceptual_path_length
from . import inception_score

#----------------------------------------------------------------------------

_metric_dict = {}  # metric name => metric function, in registration order

def register_metric(fn):
    """Decorator: record ``fn`` in the registry under ``fn.__name__`` and return it unchanged."""
    assert callable(fn)
    key = fn.__name__
    _metric_dict[key] = fn
    return fn

def is_valid_metric(metric):
    """Return True when ``metric`` is the name of a registered metric."""
    return metric in _metric_dict

def list_valid_metrics():
    """Return the registered metric names in registration order."""
    return list(_metric_dict)

#----------------------------------------------------------------------------

def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    """Run the registered metric ``metric`` and wrap its results with metadata.

    Args:
        metric: name of a registered metric.
        **kwargs: forwarded to ``metric_utils.MetricOptions``.

    Returns:
        EasyDict with ``results``, ``metric``, ``total_time``,
        ``total_time_str`` and ``num_gpus`` (in that key order).
    """
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)
    metric_fn = _metric_dict[metric]

    # Calculate.
    t_start = time.time()
    results = metric_fn(opts)
    elapsed = time.time() - t_start

    # With several processes, rank 0's values win: broadcast each one as float64.
    if opts.num_gpus > 1:
        for key in list(results):
            value = torch.as_tensor(results[key], dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            results[key] = float(value.cpu())

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results        = dnnlib.EasyDict(results),
        metric         = metric,
        total_time     = elapsed,
        total_time_str = dnnlib.util.format_time(elapsed),
        num_gpus       = opts.num_gpus,
    )

#----------------------------------------------------------------------------

def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    """Print a metric result as one JSON line and append it to the run log.

    When both ``run_dir`` and ``snapshot_pkl`` are given, the snapshot path
    is recorded relative to ``run_dir``. The line is appended to
    ``<run_dir>/metric-<metric>.jsonl`` only if ``run_dir`` is an existing
    directory.
    """
    metric = result_dict['metric']
    assert is_valid_metric(metric)
    if run_dir is not None and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    record = dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())
    jsonl_line = json.dumps(record)
    print(jsonl_line)

    if run_dir is None or not os.path.isdir(run_dir):
        return
    log_path = os.path.join(run_dir, f'metric-{metric}.jsonl')
    with open(log_path, 'at') as f:
        f.write(jsonl_line + '\n')

#----------------------------------------------------------------------------
# Primary metrics.
@register_metric
def fid50k_full(opts):
    """FID: 50k generated images vs. the whole dataset (max_size cleared, xflip off)."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    score = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return {'fid50k_full': score}

@register_metric
def kid50k_full(opts):
    """KID: 50k generated images vs. up to 1M dataset images (max_size cleared, xflip off)."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    score = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return {'kid50k_full': score}

@register_metric
def pr50k3_full(opts):
    """Precision/recall with neighborhood size 3: 50k generated vs. up to 200k dataset images."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return {'pr50k3_full_precision': precision, 'pr50k3_full_recall': recall}

@register_metric
def ppl2_wend(opts):
    """Perceptual path length in W space, 'end' sampling, no cropping."""
    score = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    return {'ppl2_wend': score}

@register_metric
def is50k(opts):
    """Inception Score over 50k generated images in 10 splits."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return {'is50k_mean': mean, 'is50k_std': std}

#----------------------------------------------------------------------------
# Legacy metrics.
@register_metric
def fid50k(opts):
    """Legacy FID: 50k generated vs. 50k dataset images (max_size cleared)."""
    opts.dataset_kwargs.update(max_size=None)
    score = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return {'fid50k': score}

@register_metric
def kid50k(opts):
    """Legacy KID: 50k generated vs. 50k dataset images (max_size cleared)."""
    opts.dataset_kwargs.update(max_size=None)
    score = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return {'kid50k': score}

@register_metric
def pr50k3(opts):
    """Legacy precision/recall with neighborhood size 3 on 50k vs. 50k images."""
    opts.dataset_kwargs.update(max_size=None)
    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return {'pr50k3_precision': precision, 'pr50k3_recall': recall}

def _legacy_ppl(opts, space, sampling):
    # Shared settings of the legacy PPL variants: 50k samples, epsilon 1e-4, cropping on, batch size 2.
    return perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space=space, sampling=sampling, crop=True, batch_size=2)

@register_metric
def ppl_zfull(opts):
    """Legacy PPL in Z space with 'full' sampling."""
    return {'ppl_zfull': _legacy_ppl(opts, space='z', sampling='full')}

@register_metric
def ppl_wfull(opts):
    """Legacy PPL in W space with 'full' sampling."""
    return {'ppl_wfull': _legacy_ppl(opts, space='w', sampling='full')}

@register_metric
def ppl_zend(opts):
    """Legacy PPL in Z space with 'end' sampling."""
    return {'ppl_zend': _legacy_ppl(opts, space='z', sampling='end')}

@register_metric
def ppl_wend(opts):
    """Legacy PPL in W space with 'end' sampling."""
    return {'ppl_wend': _legacy_ppl(opts, space='w', sampling='end')}

#----------------------------------------------------------------------------
# ==== /metrics/metric_utils.py ====
# Copyright
(c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import time 11 | import hashlib 12 | import pickle 13 | import copy 14 | import uuid 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | import dnnlib 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | class MetricOptions: 23 | def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True): 24 | assert 0 <= rank < num_gpus 25 | self.G = G 26 | self.G_kwargs = dnnlib.EasyDict(G_kwargs) 27 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) 28 | self.num_gpus = num_gpus 29 | self.rank = rank 30 | self.device = device if device is not None else torch.device('cuda', rank) 31 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() 32 | self.cache = cache 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | _feature_detector_cache = dict() 37 | 38 | def get_feature_detector_name(url): 39 | return os.path.splitext(url.split('/')[-1])[0] 40 | 41 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): 42 | assert 0 <= rank < num_gpus 43 | key = (url, device) 44 | if key not in _feature_detector_cache: 45 | is_leader = (rank == 0) 46 | if not is_leader and num_gpus > 1: 47 | torch.distributed.barrier() # leader goes first 48 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: 49 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) 
50 | if is_leader and num_gpus > 1: 51 | torch.distributed.barrier() # others follow 52 | return _feature_detector_cache[key] 53 | 54 | #---------------------------------------------------------------------------- 55 | 56 | class FeatureStats: 57 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): 58 | self.capture_all = capture_all 59 | self.capture_mean_cov = capture_mean_cov 60 | self.max_items = max_items 61 | self.num_items = 0 62 | self.num_features = None 63 | self.all_features = None 64 | self.raw_mean = None 65 | self.raw_cov = None 66 | 67 | def set_num_features(self, num_features): 68 | if self.num_features is not None: 69 | assert num_features == self.num_features 70 | else: 71 | self.num_features = num_features 72 | self.all_features = [] 73 | self.raw_mean = np.zeros([num_features], dtype=np.float64) 74 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) 75 | 76 | def is_full(self): 77 | return (self.max_items is not None) and (self.num_items >= self.max_items) 78 | 79 | def append(self, x): 80 | x = np.asarray(x, dtype=np.float32) 81 | assert x.ndim == 2 82 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): 83 | if self.num_items >= self.max_items: 84 | return 85 | x = x[:self.max_items - self.num_items] 86 | 87 | self.set_num_features(x.shape[1]) 88 | self.num_items += x.shape[0] 89 | if self.capture_all: 90 | self.all_features.append(x) 91 | if self.capture_mean_cov: 92 | x64 = x.astype(np.float64) 93 | self.raw_mean += x64.sum(axis=0) 94 | self.raw_cov += x64.T @ x64 95 | 96 | def append_torch(self, x, num_gpus=1, rank=0): 97 | assert isinstance(x, torch.Tensor) and x.ndim == 2 98 | assert 0 <= rank < num_gpus 99 | if num_gpus > 1: 100 | ys = [] 101 | for src in range(num_gpus): 102 | y = x.clone() 103 | torch.distributed.broadcast(y, src=src) 104 | ys.append(y) 105 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples 106 | self.append(x.cpu().numpy()) 
107 | 108 | def get_all(self): 109 | assert self.capture_all 110 | return np.concatenate(self.all_features, axis=0) 111 | 112 | def get_all_torch(self): 113 | return torch.from_numpy(self.get_all()) 114 | 115 | def get_mean_cov(self): 116 | assert self.capture_mean_cov 117 | mean = self.raw_mean / self.num_items 118 | cov = self.raw_cov / self.num_items 119 | cov = cov - np.outer(mean, mean) 120 | return mean, cov 121 | 122 | def save(self, pkl_file): 123 | with open(pkl_file, 'wb') as f: 124 | pickle.dump(self.__dict__, f) 125 | 126 | @staticmethod 127 | def load(pkl_file): 128 | with open(pkl_file, 'rb') as f: 129 | s = dnnlib.EasyDict(pickle.load(f)) 130 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) 131 | obj.__dict__.update(s) 132 | return obj 133 | 134 | #---------------------------------------------------------------------------- 135 | 136 | class ProgressMonitor: 137 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): 138 | self.tag = tag 139 | self.num_items = num_items 140 | self.verbose = verbose 141 | self.flush_interval = flush_interval 142 | self.progress_fn = progress_fn 143 | self.pfn_lo = pfn_lo 144 | self.pfn_hi = pfn_hi 145 | self.pfn_total = pfn_total 146 | self.start_time = time.time() 147 | self.batch_time = self.start_time 148 | self.batch_items = 0 149 | if self.progress_fn is not None: 150 | self.progress_fn(self.pfn_lo, self.pfn_total) 151 | 152 | def update(self, cur_items): 153 | assert (self.num_items is None) or (cur_items <= self.num_items) 154 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): 155 | return 156 | cur_time = time.time() 157 | total_time = cur_time - self.start_time 158 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) 159 | if (self.verbose) and (self.tag is not None): 160 | print(f'{self.tag:<19s} items 
{cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') 161 | self.batch_time = cur_time 162 | self.batch_items = cur_items 163 | 164 | if (self.progress_fn is not None) and (self.num_items is not None): 165 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) 166 | 167 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): 168 | return ProgressMonitor( 169 | tag = tag, 170 | num_items = num_items, 171 | flush_interval = flush_interval, 172 | verbose = self.verbose, 173 | progress_fn = self.progress_fn, 174 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, 175 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, 176 | pfn_total = self.pfn_total, 177 | ) 178 | 179 | #---------------------------------------------------------------------------- 180 | 181 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs): 182 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 183 | if data_loader_kwargs is None: 184 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) 185 | 186 | # Try to lookup from cache. 187 | cache_file = None 188 | if opts.cache: 189 | # Choose cache file name. 190 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) 191 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) 192 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' 193 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') 194 | 195 | # Check if the file exists (all processes must agree). 
196 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False 197 | if opts.num_gpus > 1: 198 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) 199 | torch.distributed.broadcast(tensor=flag, src=0) 200 | flag = (float(flag.cpu()) != 0) 201 | 202 | # Load. 203 | if flag: 204 | return FeatureStats.load(cache_file) 205 | 206 | # Initialize. 207 | num_items = len(dataset) 208 | if max_items is not None: 209 | num_items = min(num_items, max_items) 210 | stats = FeatureStats(max_items=num_items, **stats_kwargs) 211 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) 212 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 213 | 214 | # Main loop. 215 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] 216 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): 217 | if images.shape[1] == 1: 218 | images = images.repeat([1, 3, 1, 1]) 219 | features = detector(images.to(opts.device), **detector_kwargs) 220 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 221 | progress.update(stats.num_items) 222 | 223 | # Save to cache. 224 | if cache_file is not None and opts.rank == 0: 225 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 226 | temp_file = cache_file + '.' 
+ uuid.uuid4().hex 227 | stats.save(temp_file) 228 | os.replace(temp_file, cache_file) # atomic 229 | return stats 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def lap_to_img(lap_imgs): 234 | img = 0 235 | h, w = lap_imgs[-1].size(2), lap_imgs[-1].size(3) 236 | for la in lap_imgs: 237 | img = img + F.interpolate(la, size=(h, w)) 238 | 239 | return img 240 | 241 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs): 242 | if batch_gen is None: 243 | batch_gen = min(batch_size, 4) 244 | assert batch_size % batch_gen == 0 245 | 246 | # Setup generator and load labels. 247 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) 248 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 249 | 250 | # TODO : change this (LAP) 251 | # Image generation func. 252 | def run_generator(z, c): 253 | img = G(z=z, c=c, **opts.G_kwargs) 254 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 255 | return img 256 | 257 | # JIT. 258 | if jit: 259 | z = torch.zeros([batch_gen, G.z_dim], device=opts.device) 260 | c = torch.zeros([batch_gen, G.c_dim], device=opts.device) 261 | run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False) 262 | 263 | # Initialize. 264 | stats = FeatureStats(**stats_kwargs) 265 | assert stats.max_items is not None 266 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) 267 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 268 | 269 | # Main loop. 
270 | while not stats.is_full(): 271 | images = [] 272 | for _i in range(batch_size // batch_gen): 273 | z = torch.randn([batch_gen, G.z_dim], device=opts.device) 274 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] 275 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 276 | images.append(run_generator(z, c)) 277 | images = torch.cat(images) 278 | if images.shape[1] == 1: 279 | images = images.repeat([1, 3, 1, 1]) 280 | features = detector(images, **detector_kwargs) 281 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 282 | progress.update(stats.num_items) 283 | return stats 284 | 285 | #---------------------------------------------------------------------------- 286 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | from . import metric_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | # Spherical interpolation of a batch of vectors. 
23 | def slerp(a, b, t): 24 | a = a / a.norm(dim=-1, keepdim=True) 25 | b = b / b.norm(dim=-1, keepdim=True) 26 | d = (a * b).sum(dim=-1, keepdim=True) 27 | p = t * torch.acos(d) 28 | c = b - d * a 29 | c = c / c.norm(dim=-1, keepdim=True) 30 | d = a * torch.cos(p) + c * torch.sin(p) 31 | d = d / d.norm(dim=-1, keepdim=True) 32 | return d 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | class PPLSampler(torch.nn.Module): 37 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 38 | assert space in ['z', 'w'] 39 | assert sampling in ['full', 'end'] 40 | super().__init__() 41 | self.G = copy.deepcopy(G) 42 | self.G_kwargs = G_kwargs 43 | self.epsilon = epsilon 44 | self.space = space 45 | self.sampling = sampling 46 | self.crop = crop 47 | self.vgg16 = copy.deepcopy(vgg16) 48 | 49 | def forward(self, c): 50 | # Generate random latents and interpolation t-values. 51 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 52 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 53 | 54 | # Interpolate in W or Z. 55 | if self.space == 'w': 56 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 57 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 58 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 59 | else: # space == 'z' 60 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 61 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 62 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 63 | 64 | # Randomize noise buffers. 65 | for name, buf in self.G.named_buffers(): 66 | if name.endswith('.noise_const'): 67 | buf.copy_(torch.randn_like(buf)) 68 | 69 | # Generate images. 70 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 71 | 72 | # Center crop. 
73 | if self.crop: 74 | assert img.shape[2] == img.shape[3] 75 | c = img.shape[2] // 8 76 | img = img[:, :, c*3 : c*7, c*2 : c*6] 77 | 78 | # Downsample to 256x256. 79 | factor = self.G.img_resolution // 256 80 | if factor > 1: 81 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 82 | 83 | # Scale dynamic range from [-1,1] to [0,255]. 84 | img = (img + 1) * (255 / 2) 85 | if self.G.img_channels == 1: 86 | img = img.repeat([1, 3, 1, 1]) 87 | 88 | # Evaluate differential LPIPS. 89 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 90 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 91 | return dist 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False): 96 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 97 | vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 99 | 100 | # Setup sampler. 101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 102 | sampler.eval().requires_grad_(False).to(opts.device) 103 | if jit: 104 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) 105 | sampler = torch.jit.trace(sampler, [c], check_trace=False) 106 | 107 | # Sampling loop. 
108 | dist = [] 109 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 110 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 111 | progress.update(batch_start) 112 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] 113 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 114 | x = sampler(c) 115 | for src in range(opts.num_gpus): 116 | y = x.clone() 117 | if opts.num_gpus > 1: 118 | torch.distributed.broadcast(y, src=src) 119 | dist.append(y) 120 | progress.update(num_samples) 121 | 122 | # Compute PPL. 123 | if opts.rank != 0: 124 | return float('nan') 125 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 126 | lo = np.percentile(dist, 1, interpolation='lower') 127 | hi = np.percentile(dist, 99, interpolation='higher') 128 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 129 | return float(ppl) 130 | 131 | #---------------------------------------------------------------------------- 132 | -------------------------------------------------------------------------------- /metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . 
import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in 
manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /tools/visualize_gfeat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import click 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision 9 | from kmeans_pytorch import kmeans 10 | 11 | import sys 12 | from pathlib import Path 13 | file = Path(__file__).resolve() 14 | parent, root = file.parent, file.parents[1] 15 | sys.path.append(str(root)) 16 | 17 | import legacy 18 | import dnnlib 19 | 20 | 21 | @click.command() 22 | @click.pass_context 23 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 24 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 25 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 26 | @click.option('--num_iters', help='Number of iteration for visualization', type=int, default=1) 27 | @click.option('--batch_size', 
help='Batch size for clustering', type=int, default=64) 28 | def generate_images( 29 | ctx: click.Context, 30 | network_pkl: str, 31 | truncation_psi: float, 32 | outdir: str, 33 | num_iters: int, 34 | batch_size: int, 35 | ): 36 | """K-means visualization of generator feature maps. Cluster the images in the same batch(So the batch size matters here) 37 | 38 | Usage: 39 | python tools/visualize_gfeat.py --outdir=out --network=your_network_path.pkl 40 | """ 41 | torch.manual_seed(0) 42 | random.seed(0) 43 | np.random.seed(0) 44 | 45 | print('Loading networks from "%s"...' % network_pkl) 46 | device = torch.device('cuda') 47 | 48 | with dnnlib.util.open_url(network_pkl) as f: 49 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 50 | 51 | os.makedirs(f'{outdir}', exist_ok=True) 52 | 53 | for iter_idx in range(num_iters): 54 | 55 | z = torch.from_numpy(np.random.randn(batch_size, G.z_dim)).to(device) 56 | ws = G.mapping(z, c=None, truncation_psi=truncation_psi) 57 | 58 | fake_imgs, fake_feat = G.synthesis(ws, get_feat=True) 59 | 60 | vis_img = [] 61 | 62 | # the feature maps are saved in the dictionary whose keys are their 63 | # resolutions. 
64 | target_layers = [16, 32, 64] 65 | num_clusters = 6 66 | 67 | for res in target_layers: 68 | img = get_cluster_vis(fake_feat[res], num_clusters=num_clusters, target_res=res) # bnum, 256, 256 69 | vis_img.append(img) 70 | 71 | for idx, val in enumerate(vis_img): 72 | vis_img[idx] = F.interpolate(val, size=(256, 256)) 73 | 74 | vis_img = torch.cat(vis_img, dim=0) # bnum * res_num, 256, 256 75 | vis_img = (vis_img + 1) * 127.5 / 255.0 76 | fake_imgs = (fake_imgs + 1) * 127.5 / 255.0 77 | fake_imgs = F.interpolate(fake_imgs, size=(256, 256)) 78 | 79 | vis_img = torch.cat([fake_imgs, vis_img], dim=0) 80 | vis_img = torchvision.utils.make_grid(vis_img, normalize=False, nrow=batch_size) 81 | torchvision.utils.save_image(vis_img, f'{outdir}/{iter_idx}.png') 82 | 83 | 84 | def get_colors(): 85 | dummy_color = np.array([ 86 | [178, 34, 34], # firebrick 87 | [0, 139, 139], # dark cyan 88 | [245, 222, 179], # wheat 89 | [25, 25, 112], # midnight blue 90 | [255, 140, 0], # dark orange 91 | [128, 128, 0], # olive 92 | [50, 50, 50], # dark grey 93 | [34, 139, 34], # forest green 94 | [100, 149, 237], # corn flower blue 95 | [153, 50, 204], # dark orchid 96 | [240, 128, 128], # light coral 97 | ]) 98 | 99 | for t in (0.6, 0.3): # just increase the number of colors for big K 100 | dummy_color = np.concatenate((dummy_color, dummy_color * t)) 101 | 102 | dummy_color = (np.array(dummy_color) - 128.0) / 128.0 103 | dummy_color = torch.from_numpy(dummy_color) 104 | 105 | return dummy_color 106 | 107 | 108 | def get_cluster_vis(feat, num_clusters=10, target_res=16): 109 | # feat : NCHW 110 | print(feat.size()) 111 | img_num, C, H, W = feat.size() 112 | feat = feat.permute(0, 2, 3, 1).contiguous().view(img_num * H * W, -1) 113 | feat = feat.to(torch.float32).cuda() 114 | cluster_ids_x, cluster_centers = kmeans( 115 | X=feat, num_clusters=num_clusters, distance='cosine', 116 | tol=1e-4, 117 | device=torch.device("cuda:0")) 118 | 119 | cluster_ids_x = cluster_ids_x.cuda() 120 | 
cluster_centers = cluster_centers.cuda() 121 | color_rgb = get_colors().cuda() 122 | vis_img = [] 123 | for idx in range(img_num): 124 | num_pixel = target_res * target_res 125 | current_res = cluster_ids_x[num_pixel * idx:num_pixel * (idx + 1)].cuda() 126 | color_ids = torch.index_select(color_rgb, 0, current_res) 127 | color_map = color_ids.permute(1, 0).view(1, 3, target_res, target_res) 128 | color_map = F.interpolate(color_map, size=(256, 256)) 129 | vis_img.append(color_map.cuda()) 130 | 131 | vis_img = torch.cat(vis_img, dim=0) 132 | 133 | return vis_img 134 | 135 | 136 | if __name__ == "__main__": 137 | generate_images() # pylint: disable=no-value-for-parameter 138 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
import os
import glob
import torch
import torch.utils.cpp_extension
import importlib
import hashlib
import shutil
from pathlib import Path

from torch.utils.file_baton import FileBaton

#----------------------------------------------------------------------------
# Global options.

verbosity = 'full' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    """Locate an MSVC ``bin`` directory by globbing known Visual Studio install paths.

    Patterns are tried in order; for the first pattern with any match, the last
    match in (lexicographically) sorted order is returned — presumably the newest
    install, TODO confirm for multi-digit version folders.

    Returns:
        The directory path as a string, or ``None`` if no pattern matched.
    """
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict() # module_name -> imported extension module

def get_plugin(module_name, sources, **build_kwargs):
    """Compile (if needed) and import a C++/CUDA extension, memoized per module name.

    Args:
        module_name: Extension name; used for the build directory and the import.
        sources: Source file paths handed to ``torch.utils.cpp_extension.load``.
        **build_kwargs: Extra keyword arguments forwarded to ``load``.

    Returns:
        The imported extension module (the cached object on repeat calls).

    Raises:
        RuntimeError: On Windows, when ``cl.exe`` is not on PATH and
            ``_find_compiler_bindir`` finds no MSVC installation.
        Any exception from building or importing is re-raised unchanged
        (after printing 'Failed!' in 'brief' mode).
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        # On Windows, a non-zero exit from `where cl.exe` means cl.exe is not on PATH.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            # Every regular file in that directory, not only `sources`, feeds the digest.
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            # Relies on a private torch helper; may change across torch versions.
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                # Only the process that acquires the lock file copies the sources.
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        # Bare except on purpose: report status for any failure, then re-raise it.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding ``value``, optionally broadcast to ``shape``.

    The cache key covers the value's shape, dtype and raw bytes plus the
    requested shape, dtype, device and memory format. The SAME tensor object is
    returned for equal keys, so callers must not modify the result in place.

    Args:
        value: Anything accepted by ``np.asarray``.
        shape: Optional target shape; the value is broadcast to it.
        dtype: Torch dtype; defaults to ``torch.get_default_dtype()``.
        device: Target device; defaults to CPU.
        memory_format: Defaults to ``torch.contiguous_format``.
    """
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.
if hasattr(torch, 'nan_to_num'):
    nan_to_num = torch.nan_to_num # 1.8.0a0
else:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback ``torch.nan_to_num`` for PyTorch builds that lack it.

        NaN becomes 0 (the only supported ``nan`` value), and values are clamped
        to ``[neginf, posinf]``, which default to the most negative / largest
        finite value of ``input.dtype``.
        """
        assert isinstance(input, torch.Tensor)
        posinf = torch.finfo(input.dtype).max if posinf is None else posinf
        neginf = torch.finfo(input.dtype).min if neginf is None else neginf
        assert nan == 0
        # nansum over a singleton leading axis turns NaN into 0 and leaves every
        # other element untouched; the clamp then bounds +/-Inf.
        zero_filled = input.unsqueeze(0).nansum(0)
        return torch.clamp(zero_filled, min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

if hasattr(torch, '_assert'):
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
else:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    """``warnings.catch_warnings`` that additionally ignores ``torch.jit.TracerWarning``.

    The base class snapshots the warning filters on entry and restores them on
    exit; entering yields the context manager itself.
    """
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().
def assert_shape(tensor, ref_shape):
    """Raise if ``tensor.shape`` does not match ``ref_shape``.

    ``ref_shape`` has one entry per dimension. ``None`` accepts any size, and an
    int requires exactly that size. When either the reference entry or the
    actual size is a tensor (as happens under ``torch.jit.trace``), the check
    goes through ``symbolic_assert`` instead of a Python comparison.

    Raises:
        AssertionError: On a rank mismatch or a concrete size mismatch.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    """Wrap ``fn`` so every call runs inside ``record_function(fn.__name__)``.

    ``functools.wraps`` carries over ``__name__``, ``__qualname__``, ``__doc__``,
    ``__module__`` and sets ``__wrapped__``. Previously only ``__name__`` was
    copied, which hid the docstring and signature and broke qualified-name lookup.
    """
    import functools # local import: the module's import block lives elsewhere

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.
class InfiniteSampler(torch.utils.data.Sampler):
    """Endless sampler that yields this replica's share of dataset indices.

    Slots of an index array are visited cyclically forever, and global step
    ``k`` belongs to replica ``k % num_replicas``. With ``shuffle`` enabled, the
    index array is shuffled once using ``seed``. After every step, the current
    slot is then swapped with a slot up to ``span - 1`` positions behind it
    (wrapping around), where ``span = round(window_size * len(dataset))``.

    Args:
        dataset: Sized dataset; only ``len(dataset)`` is used.
        rank: This replica's index, in ``[0, num_replicas)``.
        num_replicas: Number of replicas sharing the index stream.
        shuffle: Enable the initial shuffle and the rolling window swaps.
        seed: Seed for the private ``np.random.RandomState``.
        window_size: Swap window as a fraction of the dataset size, in ``[0, 1]``.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset, self.rank, self.num_replicas = dataset, rank, num_replicas
        self.shuffle, self.seed, self.window_size = shuffle, seed, window_size

    def __iter__(self):
        indices = np.arange(len(self.dataset))
        total = indices.size
        rng, span = None, 0
        if self.shuffle:
            rng = np.random.RandomState(self.seed)
            rng.shuffle(indices)
            span = int(np.rint(total * self.window_size))

        step = 0
        while True:
            slot = step % total
            if step % self.num_replicas == self.rank:
                yield indices[slot]
            if span >= 2:
                # Swap happens on every step (not only ours) so all replicas stay in sync.
                other = (slot - rng.randint(span)) % total
                indices[slot], indices[other] = indices[other], indices[slot]
            step += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.
def params_and_buffers(module):
    """Return a list of all parameters of `module` followed by all its buffers."""
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    """Return `(name, tensor)` pairs for all parameters of `module`, then all buffers."""
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy parameter/buffer values from `src_module` into same-named tensors of `dst_module`.

    Runs under `torch.no_grad()`: the in-place `copy_` into a leaf parameter
    that requires grad would otherwise raise in grad mode. Each destination
    tensor keeps its own `requires_grad` flag.

    Args:
        src_module: Module providing the values.
        dst_module: Module whose tensors are overwritten in place.
        require_all: If True, every destination tensor must have a
            same-named counterpart in `src_module` (checked with `assert`).
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    """Run the `with` body inside `module.no_sync()` when `sync` is False and `module` is DDP.

    For non-DDP modules, or when `sync` is True, the body runs unchanged.
    """
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    """Assert every parameter/buffer of `module` equals rank 0's copy.

    Each tensor is broadcast from rank 0 and compared after `nan_to_num`.
    Tensors whose `<ModuleType>.<name>` fully matches `ignore_regex` are
    skipped. Requires an initialized `torch.distributed` process group.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module(*inputs)` once and print a per-submodule summary table.

    Each row lists a submodule's name, the number of parameter and buffer
    elements not already counted for an earlier row, and the shape and dtype
    of each tensor output. Modules that return several tensors get one row
    per output (`name:0`, `name:1`, ...).

    Args:
        module: Module to run (must not be a `torch.jit.ScriptModule`).
        inputs: Tuple/list of positional arguments for `module`.
        max_nesting: Only submodules at nesting depth <= this are recorded.
        skip_redundant: Drop rows with no unique params, buffers or outputs.

    Returns:
        The outputs of `module(*inputs)`.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        # One shape/dtype per output; previously every row reused outputs[0]'s shape.
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------
# -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? 
xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 

#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

// Arithmetic type used inside the kernel for storage type T:
// double computes in double; float and half both compute in float.
template <class T> struct InternalType;
template <> struct InternalType<double>     { typedef double scalar_t; };
template <> struct InternalType<float>      { typedef float  scalar_t; };
template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.
//
// T is the storage type, A the activation index (1 = linear ... 9 = swish),
// and p.grad selects the forward value (0) or a derivative (1, 2).

template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements: each thread handles up to p.loopX elements spaced
    // blockDim.x apart within this block's tile of loopX * blockDim.x elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0; // undo the forward gain on the saved output
        scalar_t y = 0;

        // Apply bias. In gradient passes x holds the incoming gradient, so the
        // bias goes onto the saved forward input instead.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish (its backward references the input, so it recomputes yref for clamping)
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp. Forward saturates to +/-clamp; gradients are zeroed where the
        // forward output was clamped.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection. Returns NULL for an unknown activation index.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);

//------------------------------------------------------------------------

// -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: --------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.
# NOTE(review): the commented lines up to the bias_act.py banner are the tail
# of the C++ header /torch_utils/ops/bias_act.h, preserved verbatim.
#  Any use, reproduction, disclosure or
# // distribution of this software and related documentation without an express
# // license agreement from NVIDIA CORPORATION is strictly prohibited.
#
# //------------------------------------------------------------------------
# // CUDA kernel parameters.
#
# struct bias_act_kernel_params
# {
#     const void* x;      // [sizeX]
#     const void* b;      // [sizeB] or NULL
#     const void* xref;   // [sizeX] or NULL
#     const void* yref;   // [sizeX] or NULL
#     const void* dy;     // [sizeX] or NULL
#     void*       y;      // [sizeX]
#
#     int         grad;
#     int         act;
#     float       alpha;
#     float       gain;
#     float       clamp;
#
#     int         sizeX;
#     int         sizeB;
#     int         stepB;
#     int         loopX;
# };
#
# //------------------------------------------------------------------------
# // CUDA kernel selection.
#
# template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
#
# //------------------------------------------------------------------------
#
# -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient bias and activation."""

import os
import warnings
import numpy as np
import torch
import dnnlib
import traceback

from .. import custom_ops
from .. import misc

#----------------------------------------------------------------------------
# Activation registry. `func` is the reference PyTorch implementation,
# `def_alpha`/`def_gain` are the defaults, `cuda_idx` selects the CUDA kernel,
# `ref` names which forward tensor ('x' input, 'y' output) the backward pass
# needs, and `has_2nd_grad` says whether second derivatives are non-zero.

activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_:         x,                                          def_alpha=0,    def_gain=1,             cuda_idx=1, ref='',  has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.relu(x),                def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_:  torch.nn.functional.leaky_relu(x, alpha),   def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_:         torch.tanh(x),                              def_alpha=0,    def_gain=1,             cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x),                           def_alpha=0,    def_gain=1,             cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.elu(x),                 def_alpha=0,    def_gain=1,             cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.selu(x),                def_alpha=0,    def_gain=1,             cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.softplus(x),            def_alpha=0,    def_gain=1,             cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x) * x,                       def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=9, ref='x', has_2nd_grad=True),
}

#----------------------------------------------------------------------------

_inited = False
_plugin = None
_null_tensor = torch.empty([0])  # placeholder passed to the plugin for "no tensor"

def _init():
    """Build/load the CUDA plugin once; return True if it is available.

    A build failure is reported with `warnings.warn` (including the
    traceback) and is not retried on later calls.
    """
    global _inited, _plugin
    if not _inited:
        _inited = True
        sources = ['bias_act.cpp', 'bias_act.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except Exception: # best-effort: fall back to the reference implementation
            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

#----------------------------------------------------------------------------

def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
                `"cuda"` falls back to `"ref"` for non-CUDA tensors or if the plugin
                cannot be built.

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------

@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)  # -1 means "no clamping"

    # Add bias, broadcasting it along dimension `dim`.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.

    Returns a `torch.autograd.Function` subclass specialised for the resolved
    `(dim, act, alpha, gain, clamp)`; classes are cached in
    `_bias_act_cuda_cache` under that key and reused on later calls.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            # A unit stride on dim 1 is treated as channels_last; keep that layout.
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            # Skip the kernel launch when the op is an identity.
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            # Save only the references the gradient kernels need for this activation.
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                # Bias gradient: reduce over every dimension except `dim`.
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op (first derivative); its own backward gives second derivatives.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda

#----------------------------------------------------------------------------
# -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.
#  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import warnings
import contextlib
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients():
    """Set `weight_gradients_disabled` to True for the duration of a `with` block.

    The previous value is restored on exit, including when the body raises,
    so an exception cannot leave weight gradients permanently disabled.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        weight_gradients_disabled = old

#----------------------------------------------------------------------------

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Drop-in for `torch.nn.functional.conv2d`; uses the custom op when enabled and supported."""
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    """Drop-in for `torch.nn.functional.conv_transpose2d`; uses the custom op when enabled and supported."""
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

#----------------------------------------------------------------------------

def _should_use_custom_op(input):
    """Return True if the custom op is enabled, cuDNN is on, `input` is on CUDA, and PyTorch is 1.7-1.9."""
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
        return True
    warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
    return False

def _tuple_of_ints(xs, ndim):
    """Normalize an int or a tuple/list of ints to a tuple of exactly `ndim` ints."""
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    """Return a cached `torch.autograd.Function` implementing (transposed) conv2d.

    Its backward pass is expressed with the same ops, so gradients of any order
    are available. Weight gradients use cuDNN's backward-weight op and are
    skipped while `weight_gradients_disabled` is set.
    """
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 1 for i in range(ndim))  # conv2d requires dilation > 0
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        # output_padding needed so the transposed conv in backward reproduces input_shape.
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            if not transpose:
                output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
            else: # transpose
                output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            ctx.save_for_backward(input, weight)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d

#----------------------------------------------------------------------------
# -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION.
# All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""2D convolution with optional up/downsampling."""

import torch

from .. import misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size

#----------------------------------------------------------------------------

def _get_weight_shape(w):
    """Return the shape of weight tensor `w` as a list of Python ints.

    The int conversion runs under `misc.suppress_tracer_warnings()` so that the
    sizes are treated as constants (e.g. when tracing); `misc.assert_shape`
    then re-checks `w` against the resulting list.
    """
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        shape = [int(sz) for sz in w.shape]
    misc.assert_shape(w, shape)
    return shape

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.

    Dispatches to `conv2d_gradfix.conv2d` or `conv2d_gradfix.conv_transpose2d`,
    optionally flipping the kernel first, and special-cases unstrided,
    unpadded 1x1 convolutions on channels-last inputs with few channels.

    Args:
        x:           Input tensor.
        w:           Weight tensor; its dims are unpacked as
                     `[out_channels, in_channels_per_group, kh, kw]`.
        stride:      Forwarded to the underlying op.
        padding:     Forwarded to the underlying op.
        groups:      Forwarded to the underlying op.
        transpose:   True = `conv_transpose2d()`, False = `conv2d()`.
        flip_weight: False = flip `w` spatially first (true convolution),
                     True = use `w` as-is (correlation).

    Returns:
        Output of the selected convolution.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    # `padding` may arrive as a scalar, list, or tuple, so all "zero" spellings are accepted.
    if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
        # Channel stride of 1 means the channel dim is innermost (typically channels_last).
        if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
            if out_channels <= 4 and groups == 1:
                # Very few output channels: compute the 1x1 conv as a batched matmul
                # over flattened spatial positions instead of calling the conv op.
                in_shape = x.shape
                x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
                x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
            else:
                # Otherwise run the conv in contiguous format; the result is converted back below.
                x = x.to(memory_format=torch.contiguous_format)
                w = w.to(memory_format=torch.contiguous_format)
                x = conv2d_gradfix.conv2d(x, w, groups=groups)
            return x.to(memory_format=torch.channels_last)

    # Otherwise => execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups)

#----------------------------------------------------------------------------

@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    The implementation takes the first applicable path among several fast
    paths (1x1 kernel with only down- or only upsampling, strided convolution,
    transposed strided convolution, plain conv2d) and otherwise falls back to
    upfirdn2d -> conv2d -> upfirdn2d.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    # The filter footprint (fw/fh) is split between the "before" and "after" sides.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # A transposed conv expects in/out channel axes swapped; with groups the
        # swap must happen within each group, hence the reshape/transpose/reshape.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        # Account for the growth that the transposed conv itself introduces.
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # Negative padding common to both sides (i.e. cropping) is delegated to the
        # transposed conv's `padding`; the remainder is applied by upfirdn2d below.
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        # conv2d padding is symmetric and non-negative, so only that case qualifies.
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    # Upsample + pad (the filter is only needed when actually upsampling), convolve, then downsample.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x

#----------------------------------------------------------------------------

# ---- /torch_utils/ops/fma.py ----
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    """Compute `a * b + c` via `torch.addcmul()`.

    `a`, `b` and `c` may be any mutually broadcastable tensors, including
    0-dim (scalar) tensors. In the backward pass each gradient is summed back
    down to the shape of the corresponding operand.

    Returns:
        Tensor equal to `a * b + c`, with the broadcast shape of the operands.
    """
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        # Only the shape of c is needed to reduce its gradient; c itself is not saved.
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        # d(a*b + c)/da = b, d/db = a, d/dc = 1; each is reduced over broadcast dims.
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    """Sum `x` over its broadcast dimensions so that the result has `shape`.

    Reduces every leading dim that `shape` lacks, plus every dim where `shape`
    has size 1 but `x` does not. That includes size-0 dims, which broadcast
    from size 1 and must reduce back to a size-1 dim. Works for a 0-dim
    target `shape` as well.

    Args:
        x:     Gradient tensor with the broadcast (output) shape.
        shape: Target shape; must be broadcastable to `x.shape`.

    Returns:
        Tensor of exactly `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [i for i in range(x.ndim) if x.shape[i] != 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if dim:
        x = x.sum(dim=dim, keepdim=True)
    # After the keepdim reduction the trailing dims must already match the target.
    assert x.shape[extra_dims:] == shape
    if extra_dims:
        # All leading extra dims are now size 1, so reshaping to `shape` just drops them.
        # Unlike `reshape(-1, ...)`, this also yields a 0-dim result for shape == ().
        x = x.reshape(shape)
    return x

#----------------------------------------------------------------------------

# ---- /torch_utils/ops/grid_sample_gradfix.py ----
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""

import warnings
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False # Enable the custom op by setting this to true.

#----------------------------------------------------------------------------

def grid_sample(input, grid):
    """Bilinear, zero-padded, `align_corners=False` grid sampling of a 2D input.

    Goes through the double-differentiable custom op only when `enabled` is
    set and the running PyTorch version is supported; otherwise calls
    `torch.nn.functional.grid_sample` with the same fixed options.
    """
    if not _should_use_custom_op():
        return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    return _GridSample2dForward.apply(input, grid)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    """Decide whether `grid_sample()` should route through the custom op.

    Returns False when the module-level `enabled` flag is off. When it is on
    but the PyTorch version prefix is not whitelisted, warns and returns False.
    """
    if enabled:
        if torch.__version__.startswith(('1.7.', '1.8.', '1.9')):
            return True
        warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    """Plain forward `grid_sample`, whose backward is itself a differentiable Function."""

    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        ctx.save_for_backward(input, grid)
        return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, saved_grid = ctx.saved_tensors
        # Delegating to another Function keeps this gradient computation on the autograd tape.
        return _GridSample2dBackward.apply(grad_output, saved_input, saved_grid)

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    """First-order gradient of grid sampling, differentiable w.r.t. `grad_output`."""

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        backward_op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # Trailing args (0, 0, False) presumably encode the bilinear / zeros /
        # align_corners=False configuration stated in the module docstring.
        grad_input, grad_grid = backward_op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        del grad2_grad_grid  # not used: no gradient flows back from grad_grid
        (saved_grid,) = ctx.saved_tensors
        grad2_grad_output = None
        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, saved_grid)
        # Second-order gradients w.r.t. the sampling grid are not supported.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, None, None

#----------------------------------------------------------------------------

# ---- /torch_utils/ops/upfirdn2d.cpp ----
# // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. 
This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""

import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import dnnlib

#----------------------------------------------------------------------------

_version = 6 # internal version number; unpickling asserts that meta.version equals this
_decorators = set() # {decorator_class, ...}
_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)

    Args:
        orig_class: The class to decorate.

    Returns:
        A subclass of `orig_class` (with the same `__name__`) that is
        registered as persistent, or `orig_class` itself if it already is.
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class # Already persistent => nothing to do.

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module) # Source of the entire defining module.

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Deep copies of the constructor arguments, exposed via init_args / init_kwargs.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            # Unpickling looks the class up by name in the module namespace, so it must be there.
            assert orig_class.__name__ in orig_module.__dict__
            # Fail at construction time rather than at pickling time if the state is not pickleable.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            # Guarantee at least three slots (reconstruct func, args, state).
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                # Move the state into `meta` together with the module source, so that
                # _reconstruct_persistent_obj can rebuild the class before restoring state.
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        pass # Unhashable objects cannot be set members; fall through to the type check.
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Hooks run in registration order, each receiving the previous hook's
    result; a hook must not return None.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.

    Applies the registered import hooks to `meta`, imports (or reuses) the
    module built from `meta.module_src`, and creates an instance of
    `meta.class_name` without calling `__init__`. The pickled state is then
    restored via `__setstate__` if available, else via `__dict__.update`.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    obj = decorator_class.__new__(decorator_class) # Bypasses __init__; state is restored below.

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.

    The result is cached in both lookup dicts, so a later `_src_to_module()`
    on the same source returns this module instead of re-executing it.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.

    A new module is registered in `sys.modules` under a random unique name
    and populated by executing `src`. Because this runs `exec()` on source
    code taken from a pickle, only unpickle data from trusted sources.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        exec(src, module.__dict__) # pylint: disable=exec-used
    return module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.

    Containers are walked recursively and leaves known to be pickleable are
    replaced by None, so that only the remaining objects are actually pickled
    (into a throwaway in-memory buffer).
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------

# ---- /torch_utils/training_stats.py ----
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 
47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 
129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 
174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 
228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import numpy as np 11 | import zipfile 12 | import PIL.Image 13 | import json 14 | import torch 15 | import dnnlib 16 | 17 | try: 18 | import pyspng 19 | except ImportError: 20 | pyspng = None 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | class Dataset(torch.utils.data.Dataset): 25 | def __init__(self, 26 | name, # Name of the dataset. 27 | raw_shape, # Shape of the raw image data (NCHW). 28 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 29 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 30 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 31 | random_seed = 0, # Random seed to use when applying max_size. 32 | ): 33 | self._name = name 34 | self._raw_shape = list(raw_shape) 35 | self._use_labels = use_labels 36 | self._raw_labels = None 37 | self._label_shape = None 38 | 39 | # Apply max_size. 
40 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 41 | if (max_size is not None) and (self._raw_idx.size > max_size): 42 | np.random.RandomState(random_seed).shuffle(self._raw_idx) 43 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 44 | 45 | # Apply xflip. 46 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 47 | if xflip: 48 | self._raw_idx = np.tile(self._raw_idx, 2) 49 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 50 | 51 | def _get_raw_labels(self): 52 | if self._raw_labels is None: 53 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 54 | if self._raw_labels is None: 55 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 56 | assert isinstance(self._raw_labels, np.ndarray) 57 | assert self._raw_labels.shape[0] == self._raw_shape[0] 58 | assert self._raw_labels.dtype in [np.float32, np.int64] 59 | if self._raw_labels.dtype == np.int64: 60 | assert self._raw_labels.ndim == 1 61 | assert np.all(self._raw_labels >= 0) 62 | return self._raw_labels 63 | 64 | def close(self): # to be overridden by subclass 65 | pass 66 | 67 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 68 | raise NotImplementedError 69 | 70 | def _load_raw_labels(self): # to be overridden by subclass 71 | raise NotImplementedError 72 | 73 | def __getstate__(self): 74 | return dict(self.__dict__, _raw_labels=None) 75 | 76 | def __del__(self): 77 | try: 78 | self.close() 79 | except: 80 | pass 81 | 82 | def __len__(self): 83 | return self._raw_idx.size 84 | 85 | def __getitem__(self, idx): 86 | image = self._load_raw_image(self._raw_idx[idx]) 87 | assert isinstance(image, np.ndarray) 88 | assert list(image.shape) == self.image_shape 89 | assert image.dtype == np.uint8 90 | if self._xflip[idx]: 91 | assert image.ndim == 3 # CHW 92 | image = image[:, :, ::-1] 93 | return image.copy(), self.get_label(idx) 94 | 95 | def get_label(self, idx): 96 | label = 
self._get_raw_labels()[self._raw_idx[idx]] 97 | if label.dtype == np.int64: 98 | onehot = np.zeros(self.label_shape, dtype=np.float32) 99 | onehot[label] = 1 100 | label = onehot 101 | return label.copy() 102 | 103 | def get_details(self, idx): 104 | d = dnnlib.EasyDict() 105 | d.raw_idx = int(self._raw_idx[idx]) 106 | d.xflip = (int(self._xflip[idx]) != 0) 107 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 108 | return d 109 | 110 | @property 111 | def name(self): 112 | return self._name 113 | 114 | @property 115 | def image_shape(self): 116 | return list(self._raw_shape[1:]) 117 | 118 | @property 119 | def num_channels(self): 120 | assert len(self.image_shape) == 3 # CHW 121 | return self.image_shape[0] 122 | 123 | @property 124 | def resolution(self): 125 | assert len(self.image_shape) == 3 # CHW 126 | assert self.image_shape[1] == self.image_shape[2] 127 | return self.image_shape[1] 128 | 129 | @property 130 | def label_shape(self): 131 | if self._label_shape is None: 132 | raw_labels = self._get_raw_labels() 133 | if raw_labels.dtype == np.int64: 134 | self._label_shape = [int(np.max(raw_labels)) + 1] 135 | else: 136 | self._label_shape = raw_labels.shape[1:] 137 | return list(self._label_shape) 138 | 139 | @property 140 | def label_dim(self): 141 | assert len(self.label_shape) == 1 142 | return self.label_shape[0] 143 | 144 | @property 145 | def has_labels(self): 146 | return any(x != 0 for x in self.label_shape) 147 | 148 | @property 149 | def has_onehot_labels(self): 150 | return self._get_raw_labels().dtype == np.int64 151 | 152 | #---------------------------------------------------------------------------- 153 | 154 | class ImageFolderDataset(Dataset): 155 | def __init__(self, 156 | path, # Path to directory or zip. 157 | resolution = None, # Ensure specific resolution, None = highest available. 158 | **super_kwargs, # Additional arguments for the Dataset base class. 
159 | ): 160 | self._path = path 161 | self._zipfile = None 162 | 163 | if os.path.isdir(self._path): 164 | self._type = 'dir' 165 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 166 | elif self._file_ext(self._path) == '.zip': 167 | self._type = 'zip' 168 | self._all_fnames = set(self._get_zipfile().namelist()) 169 | else: 170 | raise IOError('Path must point to a directory or zip') 171 | 172 | PIL.Image.init() 173 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 174 | if len(self._image_fnames) == 0: 175 | raise IOError('No image files found in the specified path') 176 | 177 | name = os.path.splitext(os.path.basename(self._path))[0] 178 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 179 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 180 | raise IOError('Image files do not match the specified resolution') 181 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 182 | 183 | @staticmethod 184 | def _file_ext(fname): 185 | return os.path.splitext(fname)[1].lower() 186 | 187 | def _get_zipfile(self): 188 | assert self._type == 'zip' 189 | if self._zipfile is None: 190 | self._zipfile = zipfile.ZipFile(self._path) 191 | return self._zipfile 192 | 193 | def _open_file(self, fname): 194 | if self._type == 'dir': 195 | return open(os.path.join(self._path, fname), 'rb') 196 | if self._type == 'zip': 197 | return self._get_zipfile().open(fname, 'r') 198 | return None 199 | 200 | def close(self): 201 | try: 202 | if self._zipfile is not None: 203 | self._zipfile.close() 204 | finally: 205 | self._zipfile = None 206 | 207 | def __getstate__(self): 208 | return dict(super().__getstate__(), _zipfile=None) 209 | 210 | def _load_raw_image(self, raw_idx): 211 | fname = self._image_fnames[raw_idx] 212 | with 
self._open_file(fname) as f: 213 | if pyspng is not None and self._file_ext(fname) == '.png': 214 | image = pyspng.load(f.read()) 215 | else: 216 | image = np.array(PIL.Image.open(f)) 217 | if image.ndim == 2: 218 | image = image[:, :, np.newaxis] # HW => HWC 219 | image = image.transpose(2, 0, 1) # HWC => CHW 220 | return image 221 | 222 | def _load_raw_labels(self): 223 | fname = 'dataset.json' 224 | if fname not in self._all_fnames: 225 | return None 226 | with self._open_file(fname) as f: 227 | labels = json.load(f)['labels'] 228 | if labels is None: 229 | return None 230 | labels = dict(labels) 231 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 232 | labels = np.array(labels) 233 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 234 | return labels 235 | 236 | #---------------------------------------------------------------------------- 237 | #---------------------------------------------------------------------------- 238 | 239 | class PairedImageFolderDataset(Dataset): 240 | def __init__(self, 241 | path, # Path to directory or zip. 242 | resolution = None, # Ensure specific resolution, None = highest available. 243 | **super_kwargs, # Additional arguments for the Dataset base class. 
244 | ): 245 | self._rootpath = path 246 | self._path = os.path.join(path, 'images') 247 | self._labelpath = os.path.join(path, 'annotations') 248 | 249 | self._zipfile = None 250 | 251 | if os.path.isdir(self._path): 252 | self._type = 'dir' 253 | # image path 254 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 255 | else: 256 | raise IOError('Path must point to a directory or zip') 257 | 258 | PIL.Image.init() 259 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 260 | if len(self._image_fnames) == 0: 261 | raise IOError('No image files found in the specified path') 262 | 263 | name = os.path.splitext(os.path.basename(self._path))[0] 264 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 265 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 266 | raise IOError('Image files do not match the specified resolution') 267 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 268 | 269 | @staticmethod 270 | def _file_ext(fname): 271 | return os.path.splitext(fname)[1].lower() 272 | 273 | def _get_zipfile(self): 274 | assert self._type == 'zip' 275 | if self._zipfile is None: 276 | self._zipfile = zipfile.ZipFile(self._path) 277 | return self._zipfile 278 | 279 | def _open_file(self, fname): 280 | if self._type == 'dir': 281 | return open(os.path.join(self._path, fname), 'rb') 282 | if self._type == 'zip': 283 | return self._get_zipfile().open(fname, 'r') 284 | return None 285 | 286 | def close(self): 287 | try: 288 | if self._zipfile is not None: 289 | self._zipfile.close() 290 | finally: 291 | self._zipfile = None 292 | 293 | def __getstate__(self): 294 | return dict(super().__getstate__(), _zipfile=None) 295 | 296 | def _load_raw_image(self, raw_idx): 297 | fname = self._image_fnames[raw_idx] 298 | with 
self._open_file(fname) as f: 299 | if pyspng is not None and self._file_ext(fname) == '.png': 300 | image = pyspng.load(f.read()) 301 | else: 302 | image = np.array(PIL.Image.open(f)) 303 | if image.ndim == 2: 304 | image = image[:, :, np.newaxis] # HW => HWC 305 | image = image.transpose(2, 0, 1) # HWC => CHW 306 | return image 307 | 308 | def _load_raw_labels(self): 309 | return None 310 | 311 | def _load_raw_labelmap(self, raw_idx): 312 | fname = self._image_fnames[raw_idx].replace('jpg', 'png') 313 | with open(os.path.join(self._labelpath, fname), 'rb') as f: 314 | image = np.array(PIL.Image.open(f)) 315 | return image 316 | 317 | def __getitem__(self, idx): 318 | image = self._load_raw_image(self._raw_idx[idx]) 319 | assert isinstance(image, np.ndarray) 320 | assert list(image.shape) == self.image_shape 321 | assert image.dtype == np.uint8 322 | label = self._load_raw_labelmap(self._raw_idx[idx]) 323 | if self._xflip[idx]: 324 | assert image.ndim == 3 # CHW 325 | image = image[:, :, ::-1] 326 | label = label[:, ::-1] 327 | return image.copy(), label.copy() 328 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
import numpy as np
import torch
from torch_utils import training_stats
from torch_utils import misc
from torch_utils.ops import conv2d_gradfix

#----------------------------------------------------------------------------

class Loss:
    """Interface for GAN losses: subclasses implement `accumulate_gradients()`."""
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
        raise NotImplementedError()

#----------------------------------------------------------------------------

class StyleGAN2Loss(Loss):
    """StyleGAN2 loss.

    Uses the softplus (logistic) adversarial loss for G and D, path-length
    regularization for G, and R1 gradient-penalty regularization for D.
    """
    def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
        super().__init__()
        self.device = device
        self.G_mapping = G_mapping
        self.G_synthesis = G_synthesis
        self.D = D
        self.augment_pipe = augment_pipe
        self.style_mixing_prob = style_mixing_prob
        self.r1_gamma = r1_gamma
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_weight = pl_weight
        self.pl_mean = torch.zeros([], device=device) # Running average of path lengths, updated with lerp(pl_decay) in Gpl.

    def run_G(self, z, c, sync):
        """Map (z, c) to W, optionally apply style mixing, and synthesize.

        Returns:
            (img, ws): generated images and the (possibly mixed) latents.
        """
        with misc.ddp_sync(self.G_mapping, sync):
            ws = self.G_mapping(z, c)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    # Random crossover layer in [1, num_ws). With probability
                    # 1 - style_mixing_prob it is set to num_ws, i.e. no mixing.
                    # NOTE: the RNG draw order here is part of the behaviour.
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
        with misc.ddp_sync(self.G_synthesis, sync):
            img = self.G_synthesis(ws)
        return img, ws

    def run_D(self, img, c, sync):
        """Apply the augmentation pipe (if any) and return D's logits."""
        if self.augment_pipe is not None:
            img = self.augment_pipe(img)
        with misc.ddp_sync(self.D, sync):
            logits = self.D(img, c)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        """Run forward and backward passes for one phase, accumulating gradients.

        Args:
            phase:  One of 'Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'.
            real_img, real_c: Real images and their conditioning labels.
            gen_z, gen_c: Latents and conditioning labels for generated images.
            sync:   Forwarded to `misc.ddp_sync` to control DDP gradient sync.
            gain:   Multiplier applied to each loss before `backward()`.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl   = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1   = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                # Run on a reduced batch (size divided by pl_batch_shrink).
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
                # Random image-space direction, scaled by 1/sqrt(H*W).
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                # `gen_img[...] * 0` adds nothing numerically but ties gen_img into
                # the graph. NOTE(review): presumably so every G parameter takes
                # part in backward (e.g. for DDP) -- confirm.
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                # Input gradients are only needed for the R1 penalty.
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                # `real_logits * 0` keeps real_logits in the graph even when only
                # one of the two terms is active.
                (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()

--------------------------------------------------------------------------------
/training/loss_ggdr.py:
--------------------------------------------------------------------------------
# Generative Guided Discriminator Regularization(GGDR)
# Copyright (c) 2022-present NAVER Corp.
# Under NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
# Augmentation (ADA)

import numpy as np
import torch
import torch.nn.functional as F
from torch_utils import training_stats
from torch_utils import misc
from torch_utils.ops import conv2d_gradfix
from training.loss import Loss


class StyleGAN2GGDRLoss(Loss):
    """StyleGAN2 loss extended for GGDR.

    The generator also returns intermediate feature maps
    (`get_feat=True`). The augmentation pipe is applied to images and
    feature maps together, and D returns a second output besides its logits.
    NOTE(review): the GGDR regularization term itself lies beyond this
    excerpt; presumably it uses `self.criterion` and `self.ggdr_res` --
    confirm.
    """
    def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2, ggdr_res=64):
        super().__init__()
        self.device = device
        self.G_mapping = G_mapping
        self.G_synthesis = G_synthesis
        self.D = D
        self.augment_pipe = augment_pipe
        self.style_mixing_prob = style_mixing_prob
        self.r1_gamma = r1_gamma
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_weight = pl_weight
        self.pl_mean = torch.zeros([], device=device) # Running average of path lengths, updated with lerp(pl_decay) in Gpl.
        # Stored as a one-element list. NOTE(review): presumably a list of
        # feature resolutions consumed later -- confirm.
        self.ggdr_res = [ggdr_res]

        # Not used in the visible part of this class.
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

    def run_G(self, z, c, ws=None, sync=True):
        """Generate images and intermediate feature maps.

        Args:
            z, c:  Latents and conditioning labels.
            ws:    Optional precomputed W latents; if given, the mapping
                   network is not run on (z, c).
            sync:  Forwarded to `misc.ddp_sync`.

        Returns:
            (img, ws, output_feat) where `output_feat` is whatever
            `G_synthesis(ws, get_feat=True)` returns alongside the image.
        """
        # NOTE(review): the nesting below was reconstructed from a collapsed
        # dump. Style mixing is assumed to sit at the same level as
        # `if ws is None`, i.e. it also applies to supplied `ws` (in place) --
        # confirm against upstream.
        with misc.ddp_sync(self.G_mapping, sync):
            if ws is None:
                ws = self.G_mapping(z, c)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
        with misc.ddp_sync(self.G_synthesis, sync):
            img, output_feat = self.G_synthesis(ws, get_feat=True)
        return img, ws, output_feat

    def run_aug_if_needed(self, img, gfeats):
        """
        Augment image and feature map consistently
        """
        # Assumes this augment_pipe accepts (img, feats) and returns both
        # (unlike the single-argument call in training/loss.py) -- confirm.
        if self.augment_pipe is not None:
            aug_img, gfeats = self.augment_pipe(img, gfeats)
        else:
            aug_img = img
        return aug_img, gfeats

    def run_D(self, img, c, gfeats=None, sync=None):
        """Augment (img, gfeats) if configured, then run D.

        Returns:
            (logits, out, aug_img, gfeats) where `out` is D's second output.
        """
        # NOTE(review): `sync` defaults to None and is passed as-is to
        # misc.ddp_sync; how None is treated depends on that helper.
        aug_img, gfeats = self.run_aug_if_needed(img, gfeats)
        with misc.ddp_sync(self.D, sync):
            logits, out = self.D(aug_img, c)

        return logits, out, aug_img, gfeats

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl   = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1   = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws, _gen_feat = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl))
                # gfeats is not passed here, so the augment pipe receives None.
                gen_logits, _recon_gen_fmaps, _, _ = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())

                loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws, gen_fmaps = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync) # gen_fmaps is unused in this phase.
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
102 | if do_Dmain: 103 | with torch.autograd.profiler.record_function('Dgen_forward'): 104 | # recon fake features and w 105 | gen_img, _gen_ws, gen_fmaps = self.run_G(gen_z, gen_c, sync=sync) 106 | 107 | aug_gen_logits, aug_recon_gen_fmaps, aug_gen_img, aug_fmaps = \ 108 | self.run_D(gen_img, gen_c, gen_fmaps, sync=sync) 109 | 110 | loss_gan_gen = torch.nn.functional.softplus(aug_gen_logits) + \ 111 | aug_recon_gen_fmaps[max(aug_recon_gen_fmaps.keys())][:, 0, 0, 0] * 0 112 | 113 | loss_gen_reg = self.get_ggdr_reg(self.ggdr_res, aug_recon_gen_fmaps, aug_fmaps) 114 | 115 | loss_Dmain = loss_gan_gen + loss_gen_reg 116 | 117 | training_stats.report('Loss/D/loss_gan_gen', loss_gan_gen) 118 | training_stats.report('Loss/D/loss_gen_reg', loss_gen_reg) 119 | 120 | with torch.autograd.profiler.record_function('Dgen_backward'): 121 | loss_Dmain.mean().mul(gain).backward() 122 | 123 | # Dmain: Maximize logits for real images. 124 | # Dr1: Apply R1 regularization. 125 | if do_Dmain or do_Dr1: 126 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1' 127 | with torch.autograd.profiler.record_function(name + '_forward'): 128 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1) 129 | real_logits, aug_recon_real_fmaps, _, _ = self.run_D(real_img_tmp, real_c, sync=sync) 130 | training_stats.report('Loss/scores/real', real_logits) 131 | training_stats.report('Loss/signs/real', real_logits.sign()) 132 | 133 | loss_Dreal = 0 134 | if do_Dmain: 135 | loss_Dreal = torch.nn.functional.softplus(-real_logits) 136 | training_stats.report(f'Loss/D/loss', loss_Dreal + loss_Dreal) 137 | 138 | loss_Dr1 = 0 139 | if do_Dr1: 140 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): 141 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] 142 | r1_penalty = r1_grads.square().sum([1, 2, 3]) 143 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2) 144 | 
training_stats.report('Loss/r1_penalty', r1_penalty) 145 | training_stats.report('Loss/D/reg', loss_Dr1) 146 | 147 | # collect not used branch for DDP training 148 | loss_not_used = aug_recon_real_fmaps[max(aug_recon_real_fmaps.keys())][:, 0, 0, 0] * 0 149 | 150 | with torch.autograd.profiler.record_function(name + '_backward'): 151 | (loss_Dreal + loss_Dr1 + real_logits * 0 + loss_not_used * 0).mean().mul(gain).backward() 152 | 153 | def cosine_distance(self, x, y): 154 | return 1. - F.cosine_similarity(x, y).mean() 155 | 156 | def get_ggdr_reg(self, ggdr_resolutions, source, target): 157 | loss_gen_recon = 0 158 | 159 | for res in ggdr_resolutions: 160 | loss_gen_recon += 10 * self.cosine_distance(source[res], target[res]) / len(ggdr_resolutions) 161 | 162 | return loss_gen_recon 163 | -------------------------------------------------------------------------------- /training/networks_ggdr.py: -------------------------------------------------------------------------------- 1 | # Generative Guided Discriminator Regularization(GGDR) 2 | # Copyright (c) 2022-present NAVER Corp. 3 | # Under NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator 4 | # Augmentation (ADA) 5 | 6 | import numpy as np 7 | import torch 8 | from torch_utils import misc 9 | from torch_utils import persistence 10 | from training.networks import Conv2dLayer, MappingNetwork, DiscriminatorBlock, DiscriminatorEpilogue 11 | from training.networks import SynthesisNetwork as OrigSynthesisNetwork 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | @persistence.persistent_class 16 | class SynthesisNetwork(OrigSynthesisNetwork): 17 | def __init__(self, 18 | w_dim, # Intermediate latent (W) dimensionality. 19 | img_resolution, # Output image resolution. 20 | img_channels, # Number of color channels. 21 | channel_base = 32768, # Overall multiplier for the number of channels. 22 | channel_max = 512, # Maximum number of channels in any layer. 
23 | num_fp16_res = 0, # Use FP16 for the N highest resolutions. 24 | **block_kwargs, # Arguments for SynthesisBlock. 25 | ): 26 | super().__init__(w_dim, img_resolution, img_channels, channel_base, channel_max, num_fp16_res, **block_kwargs) 27 | 28 | def forward(self, ws, get_feat=False, **block_kwargs): 29 | block_ws = [] 30 | with torch.autograd.profiler.record_function('split_ws'): 31 | misc.assert_shape(ws, [None, self.num_ws, self.w_dim]) 32 | ws = ws.to(torch.float32) 33 | w_idx = 0 34 | for res in self.block_resolutions: 35 | block = getattr(self, f'b{res}') 36 | block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) 37 | w_idx += block.num_conv 38 | 39 | x = img = None 40 | 41 | feats = {} 42 | for res, cur_ws in zip(self.block_resolutions, block_ws): 43 | block = getattr(self, f'b{res}') 44 | x, img = block(x, img, cur_ws, **block_kwargs) 45 | 46 | if get_feat: 47 | feats[res] = x.float() 48 | 49 | if get_feat: 50 | return img, feats 51 | else: 52 | return img 53 | 54 | #---------------------------------------------------------------------------- 55 | 56 | @persistence.persistent_class 57 | class Generator(torch.nn.Module): 58 | def __init__(self, 59 | z_dim, # Input latent (Z) dimensionality. 60 | c_dim, # Conditioning label (C) dimensionality. 61 | w_dim, # Intermediate latent (W) dimensionality. 62 | img_resolution, # Output resolution. 63 | img_channels, # Number of output color channels. 64 | mapping_kwargs = {}, # Arguments for MappingNetwork. 65 | synthesis_kwargs = {}, # Arguments for SynthesisNetwork. 
66 | ): 67 | super().__init__() 68 | self.z_dim = z_dim 69 | self.c_dim = c_dim 70 | self.w_dim = w_dim 71 | self.img_resolution = img_resolution 72 | self.img_channels = img_channels 73 | self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs) 74 | self.num_ws = self.synthesis.num_ws 75 | self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) 76 | 77 | def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs): 78 | ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) 79 | img = self.synthesis(ws, **synthesis_kwargs) 80 | return img 81 | 82 | #---------------------------------------------------------------------------- 83 | 84 | @persistence.persistent_class 85 | class Discriminator(torch.nn.Module): 86 | def __init__(self, 87 | c_dim, # Conditioning label (C) dimensionality. 88 | img_resolution, # Input resolution. 89 | img_channels, # Number of input color channels. 90 | architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'. 91 | channel_base = 32768, # Overall multiplier for the number of channels. 92 | channel_max = 512, # Maximum number of channels in any layer. 93 | num_fp16_res = 0, # Use FP16 for the N highest resolutions. 94 | conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. 95 | cmap_dim = None, # Dimensionality of mapped conditioning label, None = default. 96 | block_kwargs = {}, # Arguments for DiscriminatorBlock. 97 | mapping_kwargs = {}, # Arguments for MappingNetwork. 98 | epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue. 
99 | w_dim = 512, 100 | decoder_res = 64, 101 | ): 102 | super().__init__() 103 | self.c_dim = c_dim 104 | self.img_resolution = img_resolution 105 | self.img_resolution_log2 = int(np.log2(img_resolution)) 106 | self.img_channels = img_channels 107 | self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] 108 | channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} 109 | fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) 110 | self.fp16_resolution = fp16_resolution 111 | 112 | if cmap_dim is None: 113 | cmap_dim = channels_dict[4] 114 | if c_dim == 0: 115 | cmap_dim = 0 116 | 117 | common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) 118 | cur_layer_idx = 0 119 | for res in self.block_resolutions: 120 | in_channels = channels_dict[res] if res < img_resolution else 0 121 | tmp_channels = channels_dict[res] 122 | out_channels = channels_dict[res // 2] 123 | use_fp16 = (res >= fp16_resolution) 124 | block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res, 125 | first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) 126 | setattr(self, f'b{res}', block) 127 | cur_layer_idx += block.num_layers 128 | 129 | if c_dim > 0: 130 | self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs) 131 | 132 | self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs) 133 | 134 | # ************************************************* 135 | # Decoder part for GGDR loss 136 | # ************************************************* 137 | dec_kernel_size = 1 138 | self.dec_resolutions = [2 ** i for i in range(3, int(np.log2(decoder_res)) + 1)] 139 | 140 | for res in self.dec_resolutions: 141 | out_channels = channels_dict[res] 142 | in_channels = channels_dict[res // 2] 143 | if res != 
self.dec_resolutions[0]: 144 | in_channels *= 2 145 | 146 | block = Conv2dLayer(in_channels, out_channels, kernel_size=dec_kernel_size, 147 | activation='linear', up=2) 148 | setattr(self, f'b{res}_dec', block) 149 | 150 | def forward(self, img, c, **block_kwargs): 151 | x = None 152 | feats = {} 153 | for res in self.block_resolutions: 154 | block = getattr(self, f'b{res}') 155 | x, img = block(x, img, **block_kwargs) 156 | feats[res // 2] = x # keep feature maps for unet decoder 157 | 158 | cmap = None 159 | if self.c_dim > 0: 160 | cmap = self.mapping(None, c) 161 | 162 | logits = self.b4(x, img, cmap) # original real/fake logits 163 | 164 | # Run decoder part 165 | fmaps = {} 166 | for idx, res in enumerate(self.dec_resolutions): 167 | block = getattr(self, f'b{res}_dec') 168 | if idx == 0: 169 | y = feats[res // 2] 170 | else: 171 | y = torch.cat([y, feats[res // 2]], dim=1) 172 | y = block(y) 173 | fmaps[res] = y 174 | 175 | return logits, fmaps 176 | 177 | #---------------------------------------------------------------------------- 178 | --------------------------------------------------------------------------------