├── .gitignore ├── LICENCE ├── README.md ├── calc_metrics.py ├── dataset_tool.py ├── dnnlib ├── __init__.py └── util.py ├── environment.yml ├── gen_images.py ├── gen_video.py ├── legacy.py ├── media └── banner.png ├── metrics ├── equivariance.py ├── frechet_inception_distance.py ├── inception_score.py ├── kernel_inception_distance.py ├── metric_main.py ├── metric_utils.py ├── perceptual_path_length.py └── precision_recall.py ├── pg_modules ├── blocks.py ├── diffaug.py ├── discriminator.py ├── networks_fastgan.py ├── networks_stylegan2.py └── projector.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py ├── training_stats.py └── utils_spectrum.py ├── train.py └── training ├── dataset.py ├── loss.py └── training_loop.py /.gitignore: -------------------------------------------------------------------------------- 1 | g++ 2 | gcc 3 | 4 | best_models/* 5 | 6 | data 7 | data/* 8 | !data/.placeholder 9 | 10 | training-runs/* 11 | out/* 12 | 13 | 14 | *.zip 15 | 16 | **/__pycache__ 17 | __pycache__ 18 | .ipynb_checkpoints/ 19 | tags 20 | *.swp 21 | *.pth 22 | *.pt 23 | *.npz 24 | *.tar 25 | *.gz 26 | *.pkl 27 | *.mp4 28 | *.pyc 29 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | #### [[Project]](https://sites.google.com/view/projected-gan/) [[PDF]](http://www.cvlibs.net/publications/Sauer2021NEURIPS.pdf) [[Supplementary]](http://www.cvlibs.net/publications/Sauer2021NEURIPS_supplementary.pdf) [[Talk]](https://recorder-v3.slideslive.com/#/share?share=50538&s=bf7a6393-410c-49d9-8edf-c61fa486c354) [[CGP Summary]](https://www.casualganpapers.com/data-efficient-fast-gan-training-small-datasets/ProjectedGAN-explained.html) [[Replicate Demo]](https://replicate.com/xl-sr/projected_gan) [[Hugging Face Spaces Demo]](https://huggingface.co/spaces/autonomousvision/projected_gan)
4 | 
5 | For a quick start, try the Colab:  [![Projected GAN Quickstart](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/xl-sr/757757ff8709ad1721c6d9462efdc347/projected_gan.ipynb)
6 | 
7 | This repository contains the code for our NeurIPS 2021 paper "Projected GANs Converge Faster"
8 | 
9 | by [Axel Sauer](https://axelsauer.com/), [Kashyap Chitta](https://kashyap7x.github.io/), [Jens Müller](https://hci.iwr.uni-heidelberg.de/users/jmueller), and [Andreas Geiger](http://www.cvlibs.net/).
10 | 
11 | If you find our code or paper useful, please cite
12 | ```bibtex
13 | @InProceedings{Sauer2021NEURIPS,
14 |   author = {Axel Sauer and Kashyap Chitta and Jens M{\"{u}}ller and Andreas Geiger},
15 |   title = {Projected GANs Converge Faster},
16 |   booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
17 |   year = {2021},
18 | }
19 | ```
20 | ## Related Projects ##
21 | - [StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets](https://github.com/autonomousvision/stylegan_xl)
22 | 
23 | ## ToDos
24 | - [x] Initial code release
25 | - [x] Easy-to-use colab
26 | - [x] StyleGAN3 support (moved to https://github.com/autonomousvision/stylegan_xl)
27 | - [x] Providing pretrained models
28 | 
29 | ## Requirements ##
30 | - 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
31 | - Use the following commands with Miniconda3 to create and activate your PG Python environment:
32 |   - ```conda env create -f environment.yml```
33 |   - ```conda activate pg```
34 | - The StyleGAN2 generator relies on custom CUDA kernels, which are compiled on the fly. Hence you need:
35 |   - CUDA toolkit 11.1 or later.
36 |   - GCC 7 or later. The recommended GCC version depends on the CUDA version; see, for example, the CUDA 11.4 system requirements.
37 | - If you run into problems setting up the custom CUDA kernels, we refer to the [Troubleshooting docs](https://github.com/NVlabs/stylegan3/blob/main/docs/troubleshooting.md#why-is-cuda-toolkit-installation-necessary) of the original StyleGAN3 repo. The FastGAN generator does not need the custom kernels.
38 | 
39 | ## Data Preparation ##
40 | For a quick start, you can download the few-shot datasets provided by the authors of [FastGAN](https://github.com/odegeasslbc/FastGAN-pytorch) [here](https://drive.google.com/file/d/1aAJCZbXNHyraJ6Mi13dSbe7pTyfPXha0/view).
To prepare the dataset at the respective resolution, run, for example:
41 | ```
42 | python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
43 |   --resolution=256x256 --transform=center-crop
44 | ```
45 | You can get the datasets we used in our paper at their respective websites:
46 | 
47 | [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/), [FFHQ](https://github.com/NVlabs/ffhq-dataset), [Cityscapes](https://www.cityscapes-dataset.com/), [LSUN](https://github.com/fyu/lsun), [AFHQ](https://github.com/clovaai/stargan-v2), [Landscape](https://www.kaggle.com/arnaud58/landscape-pictures).
48 | 
49 | ## Training ##
50 | 
51 | Training your own PG on Pokemon using 8 GPUs:
52 | ```
53 | python train.py --outdir=./training-runs/ --cfg=fastgan --data=./data/pokemon256.zip \
54 |   --gpus=8 --batch=64 --mirror=1 --snap=50 --batch-gpu=8 --kimg=10000
55 | ```
56 | ```--batch``` specifies the overall batch size, ```--batch-gpu``` the batch size per GPU. If you use fewer GPUs, the training loop automatically accumulates gradients until the overall batch size is reached.
57 | 
58 | If you want to use the StyleGAN2 generator, pass ```--cfg=stylegan2```.
59 | We also added a lightweight version of FastGAN (```--cfg=fastgan_lite```). This backbone trains fast in terms of wall-clock
60 | time and yields better results on small datasets like Pokemon.
61 | Samples and metrics are saved in ```outdir```. To monitor training progress, you can inspect ```metric-fid50k_full.jsonl``` or run tensorboard in ```training-runs```.
62 | 
63 | ## Generating Samples & Interpolations ##
64 | 
65 | To generate samples and interpolation videos, run
66 | ```
67 | python gen_images.py --outdir=out --trunc=1.0 --seeds=10-15 \
68 |   --network=PATH_TO_NETWORK_PKL
69 | ```
70 | and
71 | ```
72 | python gen_video.py --output=lerp.mp4 --trunc=1.0 --seeds=0-31 --grid=4x2 \
73 |   --network=PATH_TO_NETWORK_PKL
74 | ```
75 | 
76 | We provide the following pretrained models (pass the URL as `PATH_TO_NETWORK_PKL`):
77 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/art_painting.pkl`<br>
78 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/church.pkl`
79 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/bedroom.pkl`
80 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/cityscapes.pkl`
81 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/clevr.pkl`
82 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/ffhq.pkl`
83 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/flowers.pkl`
84 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/landscape.pkl`
85 | > `https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/pokemon.pkl`
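
If you prefer to sample from a checkpoint in Python instead of via the scripts, the sketch below condenses the relevant parts of ```gen_images.py``` (run it from the repository root so that ```dnnlib``` and ```legacy``` are importable; the seed, model URL, and output filename are arbitrary choices, not requirements):
```
import numpy as np
import PIL.Image
import torch

import dnnlib
import legacy

network_pkl = 'https://s3.eu-central-1.amazonaws.com/avg-projects/projected_gan/models/pokemon.pkl'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the exponential moving average of the generator weights.
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

# Sample a latent; use an all-zero label (the provided models are unconditional).
z = torch.from_numpy(np.random.RandomState(42).randn(1, G.z_dim)).to(device).float()
c = torch.zeros([1, G.c_dim], device=device)

# Generate, map the output from [-1, 1] to 8-bit RGB, and save.
img = G(z, c, truncation_psi=1.0, noise_mode='const')
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('sample.png')
```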
86 | 
87 | ## Quality Metrics ##
88 | By default, ```train.py``` tracks FID50k during training. To calculate metrics for a specific network snapshot, run
89 | 
90 | ```
91 | python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL
92 | ```
93 | 
94 | To see the available metrics, run
95 | ```
96 | python calc_metrics.py --help
97 | ```
98 | 
99 | ## Using PG in your own project ##
100 | 
101 | Our implementation is modular, so it is straightforward to use PG in your own codebase. Simply copy the ```pg_modules``` folder to your project.
102 | Then, to get the projected multi-scale discriminator, run
103 | ```
104 | from pg_modules.discriminator import ProjectedDiscriminator
105 | D = ProjectedDiscriminator()
106 | ```
107 | The only thing you still need to do is make sure that the feature network is not trained, i.e., explicitly set
108 | ```
109 | D.feature_network.requires_grad_(False)
110 | ```
111 | in your training loop.
112 | 
113 | ## Acknowledgments ##
114 | Our codebase builds on and extends the awesome [StyleGAN2-ADA repo](https://github.com/NVlabs/stylegan2-ada-pytorch) and [StyleGAN3 repo](https://github.com/NVlabs/stylegan3), both by Karras et al.
115 | 
116 | Furthermore, we use parts of the code of [FastGAN](https://github.com/odegeasslbc/FastGAN-pytorch) and [MiDaS](https://github.com/isl-org/MiDaS).
117 | 
--------------------------------------------------------------------------------
/calc_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | """Calculate quality metrics for previous training run or pretrained network pickle."""
10 | 
11 | import os
12 | import click
13 | import json
14 | import tempfile
15 | import copy
16 | import torch
17 | 
18 | import dnnlib
19 | import legacy
20 | from metrics import metric_main
21 | from metrics import metric_utils
22 | from torch_utils import training_stats
23 | from torch_utils import custom_ops
24 | from torch_utils import misc
25 | from torch_utils.ops import conv2d_gradfix
26 | 
27 | #----------------------------------------------------------------------------
28 | 
29 | def subprocess_fn(rank, args, temp_dir):
30 |     dnnlib.util.Logger(should_flush=True)
31 | 
32 |     # Init torch.distributed.
33 |     if args.num_gpus > 1:
34 |         init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
35 |         if os.name == 'nt':
36 |             init_method = 'file:///' + init_file.replace('\\', '/')
37 |             torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
38 |         else:
39 |             init_method = f'file://{init_file}'
40 |             torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
41 | 
42 |     # Init torch_utils.
43 |     sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
44 |     training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
45 |     if rank != 0 or not args.verbose:
46 |         custom_ops.verbosity = 'none'
47 | 
48 |     # Configure torch.
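    # (Note: TF32 is disabled below so the metric computation runs in full fp32
    # precision on Ampere GPUs; this keeps reported metric values reproducible
    # across hardware generations.)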
49 |     device = torch.device('cuda', rank)
50 |     torch.backends.cuda.matmul.allow_tf32 = False
51 |     torch.backends.cudnn.allow_tf32 = False
52 |     conv2d_gradfix.enabled = True
53 | 
54 |     # Print network summary.
55 |     G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
56 |     if rank == 0 and args.verbose:
57 |         z = torch.empty([1, G.z_dim], device=device)
58 |         c = torch.empty([1, G.c_dim], device=device)
59 |         misc.print_module_summary(G, [z, c])
60 | 
61 |     # Calculate each metric.
62 |     for metric in args.metrics:
63 |         if rank == 0 and args.verbose:
64 |             print(f'Calculating {metric}...')
65 |         progress = metric_utils.ProgressMonitor(verbose=args.verbose)
66 |         result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
67 |             num_gpus=args.num_gpus, rank=rank, device=device, progress=progress, snapshot_pkl=args.network_pkl)
68 |         if rank == 0:
69 |             metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
70 |         if rank == 0 and args.verbose:
71 |             print()
72 | 
73 |     # Done.
74 |     if rank == 0 and args.verbose:
75 |         print('Exiting...')
76 | 
77 | #----------------------------------------------------------------------------
78 | 
79 | def parse_comma_separated_list(s):
80 |     if isinstance(s, list):
81 |         return s
82 |     if s is None or s.lower() == 'none' or s == '':
83 |         return []
84 |     return s.split(',')
85 | 
86 | #----------------------------------------------------------------------------
87 | 
88 | @click.command()
89 | @click.pass_context
90 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
91 | @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
92 | @click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
93 | @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
94 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
95 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
96 | 
97 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
98 |     """Calculate quality metrics for previous training run or pretrained network pickle.
99 | 
100 |     Examples:
101 | 
102 |     \b
103 |     # Previous training run: look up options automatically, save result to JSONL file.
104 |     python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
105 |         --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
106 | 
107 |     \b
108 |     # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
109 |     python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
110 |         --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
111 | 
112 |     \b
113 |     Recommended metrics:
114 |       fid50k_full  Frechet inception distance against the full dataset.
115 |       kid50k_full  Kernel inception distance against the full dataset.
116 |       pr50k3_full  Precision and recall against the full dataset.
117 |       ppl2_wend    Perceptual path length in W, endpoints, full image.
118 |       eqt50k_int   Equivariance w.r.t. integer translation (EQ-T).
119 |       eqt50k_frac  Equivariance w.r.t. fractional translation (EQ-T_frac).
120 |       eqr50k       Equivariance w.r.t. rotation (EQ-R).
121 | 122 | \b 123 | Legacy metrics: 124 | fid50k Frechet inception distance against 50k real images. 125 | kid50k Kernel inception distance against 50k real images. 126 | pr50k3 Precision and recall against 50k real images. 127 | is50k Inception score for CIFAR-10. 128 | """ 129 | dnnlib.util.Logger(should_flush=True) 130 | 131 | # Validate arguments. 132 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) 133 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 134 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 135 | if not args.num_gpus >= 1: 136 | ctx.fail('--gpus must be at least 1') 137 | 138 | # Load network. 139 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 140 | ctx.fail('--network must point to a file or URL') 141 | if args.verbose: 142 | print(f'Loading network from "{network_pkl}"...') 143 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 144 | network_dict = legacy.load_network_pkl(f) 145 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 146 | 147 | # Initialize dataset options. 148 | if data is not None: 149 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) 150 | elif network_dict['training_set_kwargs'] is not None: 151 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 152 | else: 153 | ctx.fail('Could not look up dataset options; please specify --data') 154 | 155 | # Finalize dataset options. 156 | args.dataset_kwargs.resolution = args.G.img_resolution 157 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 158 | if mirror is not None: 159 | args.dataset_kwargs.xflip = mirror 160 | 161 | # Print dataset options. 162 | if args.verbose: 163 | print('Dataset options:') 164 | print(json.dumps(args.dataset_kwargs, indent=2)) 165 | 166 | # Locate run dir. 167 | args.run_dir = None 168 | if os.path.isfile(network_pkl): 169 | pkl_dir = os.path.dirname(network_pkl) 170 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 171 | args.run_dir = pkl_dir 172 | 173 | # Launch processes. 174 | if args.verbose: 175 | print('Launching processes...') 176 | torch.multiprocessing.set_start_method('spawn') 177 | with tempfile.TemporaryDirectory() as temp_dir: 178 | if args.num_gpus == 1: 179 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 180 | else: 181 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | if __name__ == "__main__": 186 | calc_metrics() # pylint: disable=no-value-for-parameter 187 | 188 | #---------------------------------------------------------------------------- 189 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pg 2 | channels: 3 | - anaconda 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_gnu 10 | - absl-py=1.0.0=pyhd8ed1ab_0 11 | - aiohttp=3.7.0=py39h07f9747_0 12 | - async-timeout=3.0.1=py_1000 13 | - attrs=21.2.0=pyhd8ed1ab_0 14 | - blas=1.0=mkl 15 | - blinker=1.4=py_1 16 | - brotli=1.0.9=he6710b0_2 17 | - brotlipy=0.7.0=py39h27cfd23_1003 18 | - c-ares=1.18.1=h7f98852_0 19 | - ca-certificates=2021.10.8=ha878542_0 20 | - cachetools=4.2.4=pyhd8ed1ab_0 21 | - certifi=2021.10.8=py39hf3d152e_1 22 | - cffi=1.14.6=py39h400218f_0 23 | - chardet=3.0.4=py39h079e4ff_1008 24 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 25 | - click=8.0.3=pyhd3eb1b0_0 26 | - cryptography=35.0.0=py39hd23ed53_0 27 | - cudatoolkit=11.1.74=h6bb024c_0 28 | - cudnn=8.2.1.32=h86fa8c9_0 29 | - cycler=0.10.0=py39h06a4308_0 30 | - dataclasses=0.8=pyhc8e2a94_3 31 | - dbus=1.13.18=hb2f20db_0 32 | - dill=0.3.2=py_0 33 | - expat=2.4.1=h2531618_2 34 | - fontconfig=2.13.1=h6c09931_0 35 | - fonttools=4.25.0=pyhd3eb1b0_0 36 | - freetype=2.11.0=h70c0345_0 37 | - future=0.18.2=py39hf3d152e_4 38 | - glib=2.69.1=h5202010_0 39 | - google-auth=2.3.3=pyh6c4a22f_0 40 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 41 | - grpcio=1.38.1=py39hff7568b_0 42 | - gst-plugins-base=1.14.0=h8213a91_2 43 | - gstreamer=1.14.0=h28cd5cc_2 44 | - icu=58.2=he6710b0_3 45 | - idna=3.3=pyhd3eb1b0_0 46 | - imageio=2.9.0=pyhd3eb1b0_0 47 | - importlib-metadata=4.8.2=py39hf3d152e_0 48 | - intel-openmp=2021.4.0=h06a4308_3561 49 | - jpeg=9d=h7f8727e_0 50 | - kiwisolver=1.3.1=py39h2531618_0 51 | - lcms2=2.12=h3be6417_0 52 | - ld_impl_linux-64=2.35.1=h7274673_9 53 | - libblas=3.9.0=12_linux64_mkl 54 | - libffi=3.3=he6710b0_2 55 | - libgcc-ng=11.2.0=h1d223b6_11 56 | - libgfortran-ng=7.5.0=ha8ba4b0_17 57 | - libgfortran4=7.5.0=ha8ba4b0_17 58 | - libgomp=11.2.0=h1d223b6_11 59 | - liblapack=3.9.0=12_linux64_mkl 60 | - libpng=1.6.37=hbc83047_0 61 | - libprotobuf=3.18.0=h780b84a_1 62 | - libstdcxx-ng=11.2.0=he4da1e4_11 63 | - libtiff=4.2.0=h85742a9_0 64 | - libuuid=1.0.3=h7f8727e_2 65 | - libuv=1.40.0=h7b6447c_0 66 | - libwebp-base=1.2.0=h27cfd23_0 67 | - libxcb=1.14=h7b6447c_0 68 | - libxml2=2.9.12=h03d6c58_0 69 | - lz4-c=1.9.3=h295c915_1 70 | - magma=2.5.4=ha9b7cf9_2 71 | - markdown=3.3.6=pyhd8ed1ab_0 72 | - matplotlib=3.4.2=py39h06a4308_0 73 | - matplotlib-base=3.4.2=py39hab158f2_0 74 | - mkl=2021.4.0=h06a4308_640 75 | - mkl-service=2.4.0=py39h7f8727e_0 76 | - mkl_fft=1.3.1=py39hd3c417c_0 77 | - mkl_random=1.2.2=py39h51133e4_0 78 | - multidict=5.2.0=py39h3811e60_1 79 | - munkres=1.1.4=py_0 80 | - nccl=2.11.4.1=h97a9cb7_0 81 | - ncurses=6.3=h7f8727e_2 82 | - ninja=1.10.2=py39hd09550d_3 83 | - numpy=1.21.2=py39h20f2e39_0 84 | - numpy-base=1.21.2=py39h79a1101_0 85 | - oauthlib=3.1.1=pyhd8ed1ab_0 86 | - olefile=0.46=pyhd3eb1b0_0 87 | - openjpeg=2.4.0=h3ad879b_0 88 | - openssl=1.1.1l=h7f98852_0 89 | - pcre=8.45=h295c915_0 90 | - pillow=8.3.1=py39h2c7a002_0 91 | - pip=21.2.4=py39h06a4308_0 92 | - protobuf=3.18.0=py39he80948d_0 93 | - psutil=5.8.0=py39h3811e60_1 94 | - pyasn1=0.4.8=py_0 95 | - pyasn1-modules=0.2.7=py_0 96 | - pycparser=2.21=pyhd3eb1b0_0 97 | - pyjwt=2.3.0=pyhd8ed1ab_0 98 | - pyopenssl=21.0.0=pyhd3eb1b0_1 99 | - pyparsing=3.0.4=pyhd3eb1b0_0 100 | - 
pyqt=5.9.2=py39h2531618_6
101 |   - pysocks=1.7.1=py39h06a4308_0
102 |   - python=3.9.7=h12debd9_1
103 |   - python-dateutil=2.8.2=pyhd3eb1b0_0
104 |   - python_abi=3.9=2_cp39
105 |   - pytorch=1.9.1=cuda111py39hb4a4491_3
106 |   - pytorch-gpu=1.9.1=cuda111py39h788eb59_3
107 |   - pyu2f=0.1.5=pyhd8ed1ab_0
108 |   - qt=5.9.7=h5867ecd_1
109 |   - readline=8.1=h27cfd23_0
110 |   - requests=2.26.0=pyhd3eb1b0_0
111 |   - requests-oauthlib=1.3.0=pyh9f0ad1d_0
112 |   - rsa=4.8=pyhd8ed1ab_0
113 |   - scipy=1.7.1=py39h292c36d_2
114 |   - setuptools=58.0.4=py39h06a4308_0
115 |   - sip=4.19.13=py39h2531618_0
116 |   - six=1.16.0=pyhd3eb1b0_0
117 |   - sleef=3.5.1=h9b69904_2
118 |   - sqlite=3.36.0=hc218d9a_0
119 |   - tensorboard=2.7.0=pyhd8ed1ab_0
120 |   - tensorboard-data-server=0.6.0=py39h95dcef6_1
121 |   - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
122 |   - timm=0.4.12=pyhd8ed1ab_0
123 |   - tk=8.6.11=h1ccaba5_0
124 |   - torchvision=0.10.1=py39cuda111hcd06603_0_cuda
125 |   - tornado=6.1=py39h27cfd23_0
126 |   - tqdm=4.62.2=pyhd3eb1b0_1
127 |   - typing_extensions=3.10.0.2=pyh06a4308_0
128 |   - tzdata=2021e=hda174b7_0
129 |   - urllib3=1.26.7=pyhd3eb1b0_0
130 |   - werkzeug=2.0.1=pyhd8ed1ab_0
131 |   - wheel=0.37.0=pyhd3eb1b0_1
132 |   - xz=5.2.5=h7b6447c_0
133 |   - yarl=1.7.2=py39h3811e60_1
134 |   - zipp=3.6.0=pyhd8ed1ab_0
135 |   - zlib=1.2.11=h7b6447c_3
136 |   - zstd=1.4.9=haebb681_0
137 |   - pip:
138 |     - glfw==2.2.0
139 |     - imageio-ffmpeg==0.4.3
140 |     - imgui==1.3.0
141 |     - pyopengl==3.1.5
142 |     - pyspng==0.1.0
--------------------------------------------------------------------------------
/gen_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | """Generate images using pretrained network pickle."""
10 | 
11 | import os
12 | import re
13 | from typing import List, Optional, Tuple, Union
14 | 
15 | import click
16 | import dnnlib
17 | import numpy as np
18 | import PIL.Image
19 | import torch
20 | 
21 | import legacy
22 | 
23 | #----------------------------------------------------------------------------
24 | 
25 | def parse_range(s: Union[str, List]) -> List[int]:
26 |     '''Parse a comma separated list of numbers or ranges and return a list of ints.
27 | 
28 |     Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
29 |     '''
30 |     if isinstance(s, list): return s
31 |     ranges = []
32 |     range_re = re.compile(r'^(\d+)-(\d+)$')
33 |     for p in s.split(','):
34 |         m = range_re.match(p)
35 |         if m:
36 |             ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
37 |         else:
38 |             ranges.append(int(p))
39 |     return ranges
40 | 
41 | #----------------------------------------------------------------------------
42 | 
43 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
44 |     '''Parse a floating point 2-vector of syntax 'a,b'.
45 | 46 | Example: 47 | '0,1' returns (0,1) 48 | ''' 49 | if isinstance(s, tuple): return s 50 | parts = s.split(',') 51 | if len(parts) == 2: 52 | return (float(parts[0]), float(parts[1])) 53 | raise ValueError(f'cannot parse 2-vector {s}') 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def make_transform(translate: Tuple[float,float], angle: float): 58 | m = np.eye(3) 59 | s = np.sin(angle/360.0*np.pi*2) 60 | c = np.cos(angle/360.0*np.pi*2) 61 | m[0][0] = c 62 | m[0][1] = s 63 | m[0][2] = translate[0] 64 | m[1][0] = -s 65 | m[1][1] = c 66 | m[1][2] = translate[1] 67 | return m 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | @click.command() 72 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 73 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) 74 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 75 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 76 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 77 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') 78 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') 79 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 80 | def generate_images( 81 | network_pkl: str, 82 | seeds: List[int], 83 | truncation_psi: float, 84 | noise_mode: str, 85 | outdir: str, 86 | translate: Tuple[float,float], 87 | rotate: float, 88 | class_idx: Optional[int] 89 | ): 90 | """Generate images using pretrained network pickle. 91 | 92 | Examples: 93 | 94 | \b 95 | # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). 96 | python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ 97 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 98 | 99 | \b 100 | # Generate uncurated images with truncation using the MetFaces-U dataset 101 | python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 102 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl 103 | """ 104 | 105 | print('Loading networks from "%s"...' % network_pkl) 106 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 107 | with dnnlib.util.open_url(network_pkl) as f: 108 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 109 | 110 | os.makedirs(outdir, exist_ok=True) 111 | 112 | # Labels. 113 | label = torch.zeros([1, G.c_dim], device=device) 114 | if G.c_dim != 0: 115 | if class_idx is None: 116 | raise click.ClickException('Must specify class label with --class when using a conditional network') 117 | label[:, class_idx] = 1 118 | else: 119 | if class_idx is not None: 120 | print ('warn: --class=lbl ignored when running on an unconditional network') 121 | 122 | # Generate images. 123 | for seed_idx, seed in enumerate(seeds): 124 | print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) 125 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device).float() 126 | 127 | # Construct an inverse rotation/translation matrix and pass to the generator. The 128 | # generator expects this matrix as an inverse to avoid potentially failing numerical 129 | # operations in the network. 130 | if hasattr(G.synthesis, 'input'): 131 | m = make_transform(translate, rotate) 132 | m = np.linalg.inv(m) 133 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 134 | 135 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) 136 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 137 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 138 | 139 | 140 | #---------------------------------------------------------------------------- 141 | 142 | if __name__ == "__main__": 143 | generate_images() # pylint: disable=no-value-for-parameter 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /gen_video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate lerp videos using pretrained network pickle.""" 10 | 11 | import copy 12 | import os 13 | import re 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import click 17 | import dnnlib 18 | import imageio 19 | import numpy as np 20 | import scipy.interpolate 21 | import torch 22 | from tqdm import tqdm 23 | 24 | import legacy 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): 29 | batch_size, channels, img_h, img_w = img.shape 30 | if grid_w is None: 31 | grid_w = batch_size // grid_h 32 | assert batch_size == grid_w * grid_h 33 | if float_to_uint8: 34 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 35 | img = img.reshape(grid_h, grid_w, channels, img_h, img_w) 36 | img = img.permute(2, 0, 3, 1, 4) 37 | img = img.reshape(channels, grid_h * img_h, grid_w * img_w) 38 | if chw_to_hwc: 39 | img = img.permute(1, 2, 0) 40 | if to_numpy: 41 | img = img.cpu().numpy() 42 | return img 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), class_idx=None, **video_kwargs): 47 | grid_w = grid_dims[0] 48 | grid_h = grid_dims[1] 49 | 50 | if num_keyframes is None: 51 | if len(seeds) % (grid_w*grid_h) != 0: 52 | raise ValueError('Number of input seeds must be divisible by grid W*H') 53 | num_keyframes = len(seeds) // (grid_w*grid_h) 54 | 55 | all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) 56 | for idx in range(num_keyframes*grid_h*grid_w): 57 | all_seeds[idx] = seeds[idx % len(seeds)] 58 | 59 | if shuffle_seed is 
not None:
60 |         rng = np.random.RandomState(seed=shuffle_seed)
61 |         rng.shuffle(all_seeds)
62 | 
63 |     zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device).float()
64 |     # Labels.
65 |     label = torch.zeros([zs.size(0), G.c_dim], device=device)
66 |     if G.c_dim != 0:
67 |         if class_idx is None:
68 |             raise click.ClickException('Must specify class label with --class when using a conditional network')
69 |         label[:, class_idx] = 1
70 |     else:
71 |         if class_idx is not None:
72 |             print ('warn: --class=lbl ignored when running on an unconditional network')
73 | 
74 |     ws = G.mapping(z=zs, c=label, truncation_psi=psi)
75 |     _ = G.synthesis(ws[:1], c=label) # warm up
76 |     ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
77 | 
78 |     # Interpolation.
79 |     grid = []
80 |     for yi in range(grid_h):
81 |         row = []
82 |         for xi in range(grid_w):
83 |             x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
84 |             y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
85 |             interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
86 |             row.append(interp)
87 |         grid.append(row)
88 | 
89 |     # Render video.
90 |     video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
91 |     for frame_idx in tqdm(range(num_keyframes * w_frames)):
92 |         imgs = []
93 |         for yi in range(grid_h):
94 |             for xi in range(grid_w):
95 |                 interp = grid[yi][xi]
96 |                 w = torch.from_numpy(interp(frame_idx / w_frames)).to(device).float()
97 |                 img = G.synthesis(w.unsqueeze(0), c=label, noise_mode='const')[0]
98 |                 imgs.append(img)
99 |         video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
100 |     video_out.close()
101 | 
102 | #----------------------------------------------------------------------------
103 | 
104 | def parse_range(s: Union[str, List[int]]) -> List[int]:
105 |     '''Parse a comma separated list of numbers or ranges and return a list of ints.
106 | 
107 |     Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
108 |     '''
109 |     if isinstance(s, list): return s
110 |     ranges = []
111 |     range_re = re.compile(r'^(\d+)-(\d+)$')
112 |     for p in s.split(','):
113 |         m = range_re.match(p)
114 |         if m:
115 |             ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
116 |         else:
117 |             ranges.append(int(p))
118 |     return ranges
119 | 
120 | #----------------------------------------------------------------------------
121 | 
122 | def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
123 |     '''Parse a 'M,N' or 'MxN' integer tuple.
124 | 
125 |     Example:
126 |         '4x2' returns (4,2)
127 |         '0,1' returns (0,1)
128 |     '''
129 |     if isinstance(s, tuple): return s
130 |     m = re.match(r'^(\d+)[x,](\d+)$', s)
131 |     if m:
132 |         return (int(m.group(1)), int(m.group(2)))
133 |     raise ValueError(f'cannot parse tuple {s}')
134 | 
135 | #----------------------------------------------------------------------------
136 | 
137 | @click.command()
138 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
139 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
140 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
141 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
142 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through.
If not specified, determine based on the length of the seeds array given by --seeds.', default=None) 143 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) 144 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 145 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') 146 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 147 | def generate_images( 148 | network_pkl: str, 149 | seeds: List[int], 150 | shuffle_seed: Optional[int], 151 | truncation_psi: float, 152 | grid: Tuple[int,int], 153 | num_keyframes: Optional[int], 154 | w_frames: int, 155 | output: str, 156 | class_idx: Optional[int], 157 | ): 158 | """Render a latent vector interpolation video. 159 | 160 | Examples: 161 | 162 | \b 163 | # Render a 4x2 grid of interpolations for seeds 0 through 31. 164 | python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ 165 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 166 | 167 | Animation length and seed keyframes: 168 | 169 | The animation length is either determined based on the --seeds value or explicitly 170 | specified using the --num-keyframes option. 171 | 172 | When num keyframes is specified with --num-keyframes, the output video length 173 | will be 'num_keyframes*w_frames' frames. 174 | 175 | If --num-keyframes is not specified, the number of seeds given with 176 | --seeds must be divisible by grid size W*H (--grid). In this case the 177 | output video length will be '# seeds/(w*h)*w_frames' frames. 178 | """ 179 | 180 | print('Loading networks from "%s"...' % network_pkl) 181 | device = torch.device('cuda') 182 | with dnnlib.util.open_url(network_pkl) as f: 183 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 184 | 185 | gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, class_idx=class_idx) 186 | 187 | #---------------------------------------------------------------------------- 188 | 189 | if __name__ == "__main__": 190 | generate_images() # pylint: disable=no-value-for-parameter 191 | 192 | #---------------------------------------------------------------------------- 193 | -------------------------------------------------------------------------------- /media/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/projected-gan/e1c246b8bdce4fac3c2bfcb69df309fc27df9b86/media/banner.png -------------------------------------------------------------------------------- /metrics/equivariance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper 10 | "Alias-Free Generative Adversarial Networks".""" 11 | 12 | import copy 13 | import numpy as np 14 | import torch 15 | import torch.fft 16 | from torch_utils.ops import upfirdn2d 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | # Utilities. 21 | 22 | def sinc(x): 23 | y = (x * np.pi).abs() 24 | z = torch.sin(y) / y.clamp(1e-30, float('inf')) 25 | return torch.where(y < 1e-30, torch.ones_like(x), z) 26 | 27 | def lanczos_window(x, a): 28 | x = x.abs() / a 29 | return torch.where(x < 1, sinc(x), torch.zeros_like(x)) 30 | 31 | def rotation_matrix(angle): 32 | angle = torch.as_tensor(angle).to(torch.float32) 33 | mat = torch.eye(3, device=angle.device) 34 | mat[0, 0] = angle.cos() 35 | mat[0, 1] = angle.sin() 36 | mat[1, 0] = -angle.sin() 37 | mat[1, 1] = angle.cos() 38 | return mat 39 | 40 | #---------------------------------------------------------------------------- 41 | # Apply integer translation to a batch of 2D images. Corresponds to the 42 | # operator T_x in Appendix E.1. 43 | 44 | def apply_integer_translation(x, tx, ty): 45 | _N, _C, H, W = x.shape 46 | tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) 47 | ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) 48 | ix = tx.round().to(torch.int64) 49 | iy = ty.round().to(torch.int64) 50 | 51 | z = torch.zeros_like(x) 52 | m = torch.zeros_like(x) 53 | if abs(ix) < W and abs(iy) < H: 54 | y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)] 55 | z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y 56 | m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1 57 | return z, m 58 | 59 | #---------------------------------------------------------------------------- 60 | # Apply integer translation to a batch of 2D images. Corresponds to the 61 | # operator T_x in Appendix E.2. 62 | 63 | def apply_fractional_translation(x, tx, ty, a=3): 64 | _N, _C, H, W = x.shape 65 | tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) 66 | ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) 67 | ix = tx.floor().to(torch.int64) 68 | iy = ty.floor().to(torch.int64) 69 | fx = tx - ix 70 | fy = ty - iy 71 | b = a - 1 72 | 73 | z = torch.zeros_like(x) 74 | zx0 = max(ix - b, 0) 75 | zy0 = max(iy - b, 0) 76 | zx1 = min(ix + a, 0) + W 77 | zy1 = min(iy + a, 0) + H 78 | if zx0 < zx1 and zy0 < zy1: 79 | taps = torch.arange(a * 2, device=x.device) - b 80 | filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0) 81 | filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1) 82 | y = x 83 | y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0]) 84 | y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a]) 85 | y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)] 86 | z[:, :, zy0:zy1, zx0:zx1] = y 87 | 88 | m = torch.zeros_like(x) 89 | mx0 = max(ix + a, 0) 90 | my0 = max(iy + a, 0) 91 | mx1 = min(ix - b, 0) + W 92 | my1 = min(iy - b, 0) + H 93 | if mx0 < mx1 and my0 < my1: 94 | m[:, :, my0:my1, mx0:mx1] = 1 95 | return z, m 96 | 97 | #---------------------------------------------------------------------------- 98 | # Construct an oriented low-pass filter that applies the appropriate 99 | # bandlimit with respect to the input and output of the given affine 2D 100 | # image transformation. 
101 | 
102 | def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
103 |     assert a <= amax < aflt
104 |     mat = torch.as_tensor(mat).to(torch.float32)
105 | 
106 |     # Construct 2D filter taps in input & output coordinate spaces.
107 |     taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
108 |     yi, xi = torch.meshgrid(taps, taps)
109 |     xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
110 | 
111 |     # Convolution of two oriented 2D sinc filters.
112 |     fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
113 |     fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
114 |     f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
115 | 
116 |     # Convolution of two oriented 2D Lanczos windows.
117 |     wi = lanczos_window(xi, a) * lanczos_window(yi, a)
118 |     wo = lanczos_window(xo, a) * lanczos_window(yo, a)
119 |     w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
120 | 
121 |     # Construct windowed FIR filter.
122 |     f = f * w
123 | 
124 |     # Finalize.
125 |     c = (aflt - amax) * up
126 |     f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
127 |     f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
128 |     f = f / f.sum([0,2], keepdim=True) / (up ** 2)
129 |     f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
130 |     return f
131 | 
132 | #----------------------------------------------------------------------------
133 | # Apply the given affine transformation to a batch of 2D images.
134 | 
135 | def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
136 |     _N, _C, H, W = x.shape
137 |     mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
138 | 
139 |     # Construct filter.
140 |     f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
141 |     assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
142 |     p = f.shape[0] // 2
143 | 
144 |     # Construct sampling grid.
145 |     theta = mat.inverse()
146 |     theta[:2, 2] *= 2
147 |     theta[0, 2] += 1 / up / W
148 |     theta[1, 2] += 1 / up / H
149 |     theta[0, :] *= W / (W + p / up * 2)
150 |     theta[1, :] *= H / (H + p / up * 2)
151 |     theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
152 |     g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
153 | 
154 |     # Resample image.
155 |     y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
156 |     z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
157 | 
158 |     # Form mask.
159 |     m = torch.zeros_like(y)
160 |     c = p * 2 + 1
161 |     m[:, :, c:-c, c:-c] = 1
162 |     m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
163 |     return z, m
164 | 
165 | #----------------------------------------------------------------------------
166 | # Apply fractional rotation to a batch of 2D images. Corresponds to the
167 | # operator R_\alpha in Appendix E.3.
168 | 
169 | def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
170 |     angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
171 |     mat = rotation_matrix(angle)
172 |     return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
173 | 
174 | #----------------------------------------------------------------------------
175 | # Modify the frequency content of a batch of 2D images as if they had undergone
176 | # fractional rotation -- but without actually rotating them. Corresponds to
177 | # the operator R^*_\alpha in Appendix E.3.
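# (Used by EQ-R below: the generator's output for a rotated input is filtered
# with the same bandlimit as the rotated reference, so both sides of the
# comparison undergo identical filtering.)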
178 | 179 | def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs): 180 | angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device) 181 | mat = rotation_matrix(-angle) 182 | f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs) 183 | y = upfirdn2d.filter2d(x=x, f=f) 184 | m = torch.zeros_like(y) 185 | c = f.shape[0] // 2 186 | m[:, :, c:-c, c:-c] = 1 187 | return y, m 188 | 189 | #---------------------------------------------------------------------------- 190 | # Compute the selected equivariance metrics for the given generator. 191 | 192 | def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False): 193 | assert compute_eqt_int or compute_eqt_frac or compute_eqr 194 | 195 | # Setup generator and labels. 196 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) 197 | I = torch.eye(3, device=opts.device) 198 | M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None) 199 | if M is None: 200 | raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations') 201 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 202 | 203 | # Sampling loop. 204 | sums = None 205 | progress = opts.progress.sub(tag='eq sampling', num_items=num_samples) 206 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 207 | progress.update(batch_start) 208 | s = [] 209 | 210 | # Randomize noise buffers, if any. 211 | for name, buf in G.named_buffers(): 212 | if name.endswith('.noise_const'): 213 | buf.copy_(torch.randn_like(buf)) 214 | 215 | # Run mapping network. 216 | z = torch.randn([batch_size, G.z_dim], device=opts.device) 217 | c = next(c_iter) 218 | ws = G.mapping(z=z, c=c) 219 | 220 | # Generate reference image. 221 | M[:] = I 222 | orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 223 | 224 | # Integer translation (EQ-T). 225 | if compute_eqt_int: 226 | t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max 227 | t = (t * G.img_resolution).round() / G.img_resolution 228 | M[:] = I 229 | M[:2, 2] = -t 230 | img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 231 | ref, mask = apply_integer_translation(orig, t[0], t[1]) 232 | s += [(ref - img).square() * mask, mask] 233 | 234 | # Fractional translation (EQ-T_frac). 235 | if compute_eqt_frac: 236 | t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max 237 | M[:] = I 238 | M[:2, 2] = -t 239 | img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 240 | ref, mask = apply_fractional_translation(orig, t[0], t[1]) 241 | s += [(ref - img).square() * mask, mask] 242 | 243 | # Rotation (EQ-R). 244 | if compute_eqr: 245 | angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi) 246 | M[:] = rotation_matrix(-angle) 247 | img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) 248 | ref, ref_mask = apply_fractional_rotation(orig, angle) 249 | pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle) 250 | mask = ref_mask * pseudo_mask 251 | s += [(ref - pseudo).square() * mask, mask] 252 | 253 | # Accumulate results. 254 | s = torch.stack([x.to(torch.float64).sum() for x in s]) 255 | sums = sums + s if sums is not None else s 256 | progress.update(num_samples) 257 | 258 | # Compute PSNRs. 
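    # (Images live in [-1, 1], so the peak-to-peak signal range is 2 and
    # PSNR = 10*log10(2**2 / MSE) = 20*log10(2) - 10*log10(MSE), matching the
    # expression below.)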
259 | if opts.num_gpus > 1: 260 | torch.distributed.all_reduce(sums) 261 | sums = sums.cpu() 262 | mses = sums[0::2] / sums[1::2] 263 | psnrs = np.log10(2) * 20 - mses.log10() * 10 264 | psnrs = tuple(psnrs.numpy()) 265 | return psnrs[0] if len(psnrs) == 1 else psnrs 266 | 267 | #---------------------------------------------------------------------------- 268 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen, swav=False, sfid=False): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, swav=swav, sfid=sfid).get_mean_cov() 28 | 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, swav=swav, sfid=sfid).get_mean_cov() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # distribution of this software and related documentation without an express 2 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 3 | 4 | """Main API for computing and reporting quality metrics.""" 5 | 6 | import os 7 | import time 8 | import json 9 | import torch 10 | import dnnlib 11 | 12 | from . import metric_utils 13 | from . import frechet_inception_distance 14 | from . import kernel_inception_distance 15 | from . import precision_recall 16 | from . import perceptual_path_length 17 | from . import inception_score 18 | from . import equivariance 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _metric_dict = dict() # name => fn 23 | 24 | def register_metric(fn): 25 | assert callable(fn) 26 | _metric_dict[fn.__name__] = fn 27 | return fn 28 | 29 | def is_valid_metric(metric): 30 | return metric in _metric_dict 31 | 32 | def list_valid_metrics(): 33 | return list(_metric_dict.keys()) 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 38 | assert is_valid_metric(metric) 39 | opts = metric_utils.MetricOptions(**kwargs) 40 | 41 | # Calculate. 42 | start_time = time.time() 43 | results = _metric_dict[metric](opts) 44 | total_time = time.time() - start_time 45 | 46 | # Broadcast results. 47 | for key, value in list(results.items()): 48 | if opts.num_gpus > 1: 49 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 50 | torch.distributed.broadcast(tensor=value, src=0) 51 | value = float(value.cpu()) 52 | results[key] = value 53 | 54 | # Decorate with metadata. 
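# Note: only rank 0 computes the raw metric value inside each metric function
# (other ranks return NaN); the broadcast loop above then copies rank 0's result
# to every process so that all ranks return identical numbers.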
55 | return dnnlib.EasyDict( 56 | results = dnnlib.EasyDict(results), 57 | metric = metric, 58 | total_time = total_time, 59 | total_time_str = dnnlib.util.format_time(total_time), 60 | num_gpus = opts.num_gpus, 61 | ) 62 | 63 | #---------------------------------------------------------------------------- 64 | 65 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 66 | metric = result_dict['metric'] 67 | assert is_valid_metric(metric) 68 | if run_dir is not None and snapshot_pkl is not None: 69 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 70 | 71 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 72 | print(jsonl_line) 73 | if run_dir is not None and os.path.isdir(run_dir): 74 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 75 | f.write(jsonl_line + '\n') 76 | 77 | #---------------------------------------------------------------------------- 78 | # Recommended metrics. 79 | 80 | @register_metric 81 | def fid50k_full(opts): 82 | opts.dataset_kwargs.update(max_size=None, xflip=False) 83 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 84 | return dict(fid50k_full=fid) 85 | 86 | @register_metric 87 | def fid10k_full(opts): 88 | opts.dataset_kwargs.update(max_size=None, xflip=False) 89 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000) 90 | return dict(fid10k_full=fid) 91 | 92 | @register_metric 93 | def kid50k_full(opts): 94 | opts.dataset_kwargs.update(max_size=None, xflip=False) 95 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 96 | return dict(kid50k_full=kid) 97 | 98 | @register_metric 99 | def pr50k3_full(opts): 100 | opts.dataset_kwargs.update(max_size=None, xflip=False) 101 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 102 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 103 | 104 | @register_metric 105 | def ppl2_wend(opts): 106 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 107 | return dict(ppl2_wend=ppl) 108 | 109 | @register_metric 110 | def eqt50k_int(opts): 111 | opts.G_kwargs.update(force_fp32=True) 112 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 113 | return dict(eqt50k_int=psnr) 114 | 115 | @register_metric 116 | def eqt50k_frac(opts): 117 | opts.G_kwargs.update(force_fp32=True) 118 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 119 | return dict(eqt50k_frac=psnr) 120 | 121 | @register_metric 122 | def eqr50k(opts): 123 | opts.G_kwargs.update(force_fp32=True) 124 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 125 | return dict(eqr50k=psnr) 126 | 127 | # Legacy metrics. 
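# These cap the number of real images at 50k, matching older evaluation
# protocols; prefer the *_full variants above, which use the entire dataset.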
128 | 129 | @register_metric 130 | def fid50k(opts): 131 | opts.dataset_kwargs.update(max_size=None) 132 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 133 | return dict(fid50k=fid) 134 | 135 | @register_metric 136 | def kid50k(opts): 137 | opts.dataset_kwargs.update(max_size=None) 138 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 139 | return dict(kid50k=kid) 140 | 141 | @register_metric 142 | def pr50k3(opts): 143 | opts.dataset_kwargs.update(max_size=None) 144 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 145 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 146 | 147 | @register_metric 148 | def is50k(opts): 149 | opts.dataset_kwargs.update(max_size=None, xflip=False) 150 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 151 | return dict(is50k_mean=mean, is50k_std=std) 152 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | # Spherical interpolation of a batch of vectors. 22 | def slerp(a, b, t): 23 | a = a / a.norm(dim=-1, keepdim=True) 24 | b = b / b.norm(dim=-1, keepdim=True) 25 | d = (a * b).sum(dim=-1, keepdim=True) 26 | p = t * torch.acos(d) 27 | c = b - d * a 28 | c = c / c.norm(dim=-1, keepdim=True) 29 | d = a * torch.cos(p) + c * torch.sin(p) 30 | d = d / d.norm(dim=-1, keepdim=True) 31 | return d 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPLSampler(torch.nn.Module): 36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__() 40 | self.G = copy.deepcopy(G) 41 | self.G_kwargs = G_kwargs 42 | self.epsilon = epsilon 43 | self.space = space 44 | self.sampling = sampling 45 | self.crop = crop 46 | self.vgg16 = copy.deepcopy(vgg16) 47 | 48 | def forward(self, c): 49 | # Generate random latents and interpolation t-values. 50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 52 | 53 | # Interpolate in W or Z. 
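# (W is interpolated linearly, while Z uses spherical interpolation, since
# high-dimensional Gaussian samples concentrate near a hypersphere; epsilon
# offsets the second latent for the finite-difference LPIPS estimate below.)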
54 | if self.space == 'w': 55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 58 | else: # space == 'z' 59 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 62 | 63 | # Randomize noise buffers. 64 | for name, buf in self.G.named_buffers(): 65 | if name.endswith('.noise_const'): 66 | buf.copy_(torch.randn_like(buf)) 67 | 68 | # Generate images. 69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 70 | 71 | # Center crop. 72 | if self.crop: 73 | assert img.shape[2] == img.shape[3] 74 | c = img.shape[2] // 8 75 | img = img[:, :, c*3 : c*7, c*2 : c*6] 76 | 77 | # Downsample to 256x256. 78 | factor = self.G.img_resolution // 256 79 | if factor > 1: 80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 81 | 82 | # Scale dynamic range from [-1,1] to [0,255]. 83 | img = (img + 1) * (255 / 2) 84 | if self.G.img_channels == 1: 85 | img = img.repeat([1, 3, 1, 1]) 86 | 87 | # Evaluate differential LPIPS. 88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 90 | return dist 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 97 | 98 | # Setup sampler and labels. 99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 100 | sampler.eval().requires_grad_(False).to(opts.device) 101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 102 | 103 | # Sampling loop. 104 | dist = [] 105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 107 | progress.update(batch_start) 108 | x = sampler(next(c_iter)) 109 | for src in range(opts.num_gpus): 110 | y = x.clone() 111 | if opts.num_gpus > 1: 112 | torch.distributed.broadcast(y, src=src) 113 | dist.append(y) 114 | progress.update(num_samples) 115 | 116 | # Compute PPL. 117 | if opts.rank != 0: 118 | return float('nan') 119 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 120 | lo = np.percentile(dist, 1, interpolation='lower') 121 | hi = np.percentile(dist, 99, interpolation='higher') 122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 123 | return float(ppl) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | 
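In essence, `compute_pr` above performs a k-nearest-neighbour manifold test: a sample counts as covered when it falls within the k-th-neighbour radius of any point in the reference set. A single-GPU sketch of the same rule, assuming small in-memory feature arrays (the batching, broadcasting, and fp16 casts above exist to scale this to tens of thousands of features):

```python
import torch

def knn_coverage(reference, probes, nhood_size=3):
    # Radius of each reference point: distance to its (nhood_size+1)-th nearest
    # neighbour within the reference set (the +1 skips the point itself).
    radii = torch.cdist(reference, reference).kthvalue(nhood_size + 1, dim=1).values
    # A probe is covered if it lies inside any reference point's ball.
    hits = (torch.cdist(probes, reference) <= radii.unsqueeze(0)).any(dim=1)
    return hits.float().mean().item()

# precision = knn_coverage(real_features, gen_features)
# recall    = knn_coverage(gen_features, real_features)
```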
#---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /pg_modules/blocks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils import spectral_norm 6 | 7 | 8 | ### single layers 9 | 10 | 11 | def conv2d(*args, **kwargs): 12 | return spectral_norm(nn.Conv2d(*args, **kwargs)) 13 | 14 | 15 | def convTranspose2d(*args, **kwargs): 16 | return spectral_norm(nn.ConvTranspose2d(*args, **kwargs)) 17 | 18 | 19 | def embedding(*args, **kwargs): 20 | return spectral_norm(nn.Embedding(*args, **kwargs)) 21 | 22 | 23 | def linear(*args, **kwargs): 24 | return spectral_norm(nn.Linear(*args, **kwargs)) 25 | 26 | 27 | def NormLayer(c, mode='batch'): 28 | if mode == 'group': 29 | return nn.GroupNorm(c//2, c) 30 | elif mode == 'batch': 31 | return nn.BatchNorm2d(c) 32 | 33 | 34 | ### Activations 35 | 36 | 37 | class GLU(nn.Module): 38 | def forward(self, x): 39 | nc = x.size(1) 40 | assert nc % 2 == 0, 'channels dont divide 2!' 41 | nc = int(nc/2) 42 | return x[:, :nc] * torch.sigmoid(x[:, nc:]) 43 | 44 | 45 | class Swish(nn.Module): 46 | def forward(self, feat): 47 | return feat * torch.sigmoid(feat) 48 | 49 | 50 | ### Upblocks 51 | 52 | 53 | class InitLayer(nn.Module): 54 | def __init__(self, nz, channel, sz=4): 55 | super().__init__() 56 | 57 | self.init = nn.Sequential( 58 | convTranspose2d(nz, channel*2, sz, 1, 0, bias=False), 59 | NormLayer(channel*2), 60 | GLU(), 61 | ) 62 | 63 | def forward(self, noise): 64 | noise = noise.view(noise.shape[0], -1, 1, 1) 65 | return self.init(noise) 66 | 67 | 68 | def UpBlockSmall(in_planes, out_planes): 69 | block = nn.Sequential( 70 | nn.Upsample(scale_factor=2, mode='nearest'), 71 | conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False), 72 | NormLayer(out_planes*2), GLU()) 73 | return block 74 | 75 | 76 | class UpBlockSmallCond(nn.Module): 77 | def __init__(self, in_planes, out_planes, z_dim): 78 | super().__init__() 79 | self.in_planes = in_planes 80 | self.out_planes = out_planes 81 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 82 | self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False) 83 | 84 | which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim) 85 | self.bn = which_bn(2*out_planes) 86 | self.act = GLU() 87 | 88 | def forward(self, x, c): 89 | x = self.up(x) 90 | x = self.conv(x) 91 | x = self.bn(x, c) 92 | x = self.act(x) 93 | return x 94 | 95 | 96 | def UpBlockBig(in_planes, out_planes): 97 | block = nn.Sequential( 98 | nn.Upsample(scale_factor=2, mode='nearest'), 99 | conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False), 100 | NoiseInjection(), 101 | NormLayer(out_planes*2), GLU(), 102 | conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False), 103 | NoiseInjection(), 104 | NormLayer(out_planes*2), GLU() 105 | ) 106 | return block 107 | 108 | 109 | class UpBlockBigCond(nn.Module): 110 | def __init__(self, in_planes, out_planes, z_dim): 111 | super().__init__() 112 | self.in_planes = in_planes 113 | self.out_planes = out_planes 114 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 115 | self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False) 116 | self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False) 117 | 118 | which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim) 119 | self.bn1 = which_bn(2*out_planes) 120 | 
self.bn2 = which_bn(2*out_planes) 121 | self.act = GLU() 122 | self.noise = NoiseInjection() 123 | 124 | def forward(self, x, c): 125 | # block 1 126 | x = self.up(x) 127 | x = self.conv1(x) 128 | x = self.noise(x) 129 | x = self.bn1(x, c) 130 | x = self.act(x) 131 | 132 | # block 2 133 | x = self.conv2(x) 134 | x = self.noise(x) 135 | x = self.bn2(x, c) 136 | x = self.act(x) 137 | 138 | return x 139 | 140 | 141 | class SEBlock(nn.Module): 142 | def __init__(self, ch_in, ch_out): 143 | super().__init__() 144 | self.main = nn.Sequential( 145 | nn.AdaptiveAvgPool2d(4), 146 | conv2d(ch_in, ch_out, 4, 1, 0, bias=False), 147 | Swish(), 148 | conv2d(ch_out, ch_out, 1, 1, 0, bias=False), 149 | nn.Sigmoid(), 150 | ) 151 | 152 | def forward(self, feat_small, feat_big): 153 | return feat_big * self.main(feat_small) 154 | 155 | 156 | ### Downblocks 157 | 158 | 159 | class SeparableConv2d(nn.Module): 160 | def __init__(self, in_channels, out_channels, kernel_size, bias=False): 161 | super(SeparableConv2d, self).__init__() 162 | self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size, 163 | groups=in_channels, bias=bias, padding=1) 164 | self.pointwise = conv2d(in_channels, out_channels, 165 | kernel_size=1, bias=bias) 166 | 167 | def forward(self, x): 168 | out = self.depthwise(x) 169 | out = self.pointwise(out) 170 | return out 171 | 172 | 173 | class DownBlock(nn.Module): 174 | def __init__(self, in_planes, out_planes, separable=False): 175 | super().__init__() 176 | if not separable: 177 | self.main = nn.Sequential( 178 | conv2d(in_planes, out_planes, 4, 2, 1), 179 | NormLayer(out_planes), 180 | nn.LeakyReLU(0.2, inplace=True), 181 | ) 182 | else: 183 | self.main = nn.Sequential( 184 | SeparableConv2d(in_planes, out_planes, 3), 185 | NormLayer(out_planes), 186 | nn.LeakyReLU(0.2, inplace=True), 187 | nn.AvgPool2d(2, 2), 188 | ) 189 | 190 | def forward(self, feat): 191 | return self.main(feat) 192 | 193 | 194 | class DownBlockPatch(nn.Module): 195 | def __init__(self, in_planes, out_planes, separable=False): 196 | super().__init__() 197 | self.main = nn.Sequential( 198 | DownBlock(in_planes, out_planes, separable), 199 | conv2d(out_planes, out_planes, 1, 1, 0, bias=False), 200 | NormLayer(out_planes), 201 | nn.LeakyReLU(0.2, inplace=True), 202 | ) 203 | 204 | def forward(self, feat): 205 | return self.main(feat) 206 | 207 | 208 | ### CSM 209 | 210 | 211 | class ResidualConvUnit(nn.Module): 212 | def __init__(self, cin, activation, bn): 213 | super().__init__() 214 | self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True) 215 | self.skip_add = nn.quantized.FloatFunctional() 216 | 217 | def forward(self, x): 218 | return self.skip_add.add(self.conv(x), x) 219 | 220 | 221 | class FeatureFusionBlock(nn.Module): 222 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False): 223 | super().__init__() 224 | 225 | self.deconv = deconv 226 | self.align_corners = align_corners 227 | 228 | self.expand = expand 229 | out_features = features 230 | if self.expand==True: 231 | out_features = features//2 232 | 233 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 234 | self.skip_add = nn.quantized.FloatFunctional() 235 | 236 | def forward(self, *xs): 237 | output = xs[0] 238 | 239 | if len(xs) == 2: 240 | output = self.skip_add.add(output, xs[1]) 241 | 242 | output = nn.functional.interpolate( 243 | output, scale_factor=2, mode="bilinear", 
align_corners=self.align_corners
244 | )
245 | 
246 | output = self.out_conv(output)
247 | 
248 | return output
249 | 
250 | 
251 | ### Misc
252 | 
253 | 
254 | class NoiseInjection(nn.Module):
255 | def __init__(self):
256 | super().__init__()
257 | self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)
258 | 
259 | def forward(self, feat, noise=None):
260 | if noise is None:
261 | batch, _, height, width = feat.shape
262 | noise = torch.randn(batch, 1, height, width).to(feat.device)
263 | 
264 | return feat + self.weight * noise
265 | 
266 | 
267 | class CCBN(nn.Module):
268 | ''' conditional batchnorm '''
269 | def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
270 | super().__init__()
271 | self.output_size, self.input_size = output_size, input_size
272 | 
273 | # Prepare gain and bias layers
274 | self.gain = which_linear(input_size, output_size)
275 | self.bias = which_linear(input_size, output_size)
276 | 
277 | # epsilon to avoid dividing by 0
278 | self.eps = eps
279 | # Momentum
280 | self.momentum = momentum
281 | 
282 | self.register_buffer('stored_mean', torch.zeros(output_size))
283 | self.register_buffer('stored_var', torch.ones(output_size))
284 | 
285 | def forward(self, x, y):
286 | # Calculate class-conditional gains and biases
287 | gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
288 | bias = self.bias(y).view(y.size(0), -1, 1, 1)
289 | out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
290 | self.training, self.momentum, self.eps)
291 | return out * gain + bias
292 | 
293 | 
294 | class Interpolate(nn.Module):
295 | """Interpolation module."""
296 | 
297 | def __init__(self, size, mode='bilinear', align_corners=False):
298 | """Init.
299 | Args:
300 | size (int or tuple): output spatial size
301 | mode (str): interpolation mode
302 | """
303 | super(Interpolate, self).__init__()
304 | 
305 | self.interp = nn.functional.interpolate
306 | self.size = size
307 | self.mode = mode
308 | self.align_corners = align_corners
309 | 
310 | def forward(self, x):
311 | """Forward pass.
312 | Args: 313 | x (tensor): input 314 | Returns: 315 | tensor: interpolated data 316 | """ 317 | 318 | x = self.interp( 319 | x, 320 | size=self.size, 321 | mode=self.mode, 322 | align_corners=self.align_corners, 323 | ) 324 | 325 | return x 326 | -------------------------------------------------------------------------------- /pg_modules/diffaug.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def DiffAugment(x, policy='', channels_first=True): 10 | if policy: 11 | if not channels_first: 12 | x = x.permute(0, 3, 1, 2) 13 | for p in policy.split(','): 14 | for f in AUGMENT_FNS[p]: 15 | x = f(x) 16 | if not channels_first: 17 | x = x.permute(0, 2, 3, 1) 18 | x = x.contiguous() 19 | return x 20 | 21 | 22 | def rand_brightness(x): 23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 24 | return x 25 | 26 | 27 | def rand_saturation(x): 28 | x_mean = x.mean(dim=1, keepdim=True) 29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 30 | return x 31 | 32 | 33 | def rand_contrast(x): 34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 36 | return x 37 | 38 | 39 | def rand_translation(x, ratio=0.125): 40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 43 | grid_batch, grid_x, grid_y = torch.meshgrid( 44 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 45 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 46 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 47 | ) 48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 52 | return x 53 | 54 | 55 | def rand_cutout(x, ratio=0.2): 56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 59 | grid_batch, grid_x, grid_y = torch.meshgrid( 60 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 63 | ) 64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 67 | mask[grid_batch, grid_x, grid_y] = 0 68 | x = x * mask.unsqueeze(1) 69 | return x 70 | 71 | 72 | AUGMENT_FNS = { 73 | 'color': [rand_brightness, rand_saturation, rand_contrast], 74 | 'translation': [rand_translation], 75 | 'cutout': [rand_cutout], 76 | } 77 | 
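A quick usage sketch (shapes are illustrative; the policy string is the one `ProjectedDiscriminator` passes in the next file):

```python
import torch
from pg_modules.diffaug import DiffAugment

x = torch.randn(8, 3, 256, 256, requires_grad=True)   # e.g. generator output in [-1, 1]
y = DiffAugment(x, policy='color,translation,cutout')
assert y.shape == x.shape     # every augmentation is shape-preserving
assert y.requires_grad        # and differentiable, so gradients reach the generator
```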
-------------------------------------------------------------------------------- /pg_modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d 8 | from pg_modules.projector import F_RandomProj 9 | from pg_modules.diffaug import DiffAugment 10 | 11 | 12 | class SingleDisc(nn.Module): 13 | def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False): 14 | super().__init__() 15 | channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, 16 | 256: 32, 512: 16, 1024: 8} 17 | 18 | # interpolate for start sz that are not powers of two 19 | if start_sz not in channel_dict.keys(): 20 | sizes = np.array(list(channel_dict.keys())) 21 | start_sz = sizes[np.argmin(abs(sizes - start_sz))] 22 | self.start_sz = start_sz 23 | 24 | # if given ndf, allocate all layers with the same ndf 25 | if ndf is None: 26 | nfc = channel_dict 27 | else: 28 | nfc = {k: ndf for k, v in channel_dict.items()} 29 | 30 | # for feature map discriminators with nfc not in channel_dict 31 | # this is the case for the pretrained backbone (midas.pretrained) 32 | if nc is not None and head is None: 33 | nfc[start_sz] = nc 34 | 35 | layers = [] 36 | 37 | # Head if the initial input is the full modality 38 | if head: 39 | layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), 40 | nn.LeakyReLU(0.2, inplace=True)] 41 | 42 | # Down Blocks 43 | DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) 44 | while start_sz > end_sz: 45 | layers.append(DB(nfc[start_sz], nfc[start_sz//2])) 46 | start_sz = start_sz // 2 47 | 48 | layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False)) 49 | self.main = nn.Sequential(*layers) 50 | 51 | def forward(self, x, c): 52 | return self.main(x) 53 | 54 | 55 | class SingleDiscCond(nn.Module): 56 | def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128): 57 | super().__init__() 58 | self.cmap_dim = cmap_dim 59 | 60 | # midas channels 61 | channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64, 62 | 256: 32, 512: 16, 1024: 8} 63 | 64 | # interpolate for start sz that are not powers of two 65 | if start_sz not in channel_dict.keys(): 66 | sizes = np.array(list(channel_dict.keys())) 67 | start_sz = sizes[np.argmin(abs(sizes - start_sz))] 68 | self.start_sz = start_sz 69 | 70 | # if given ndf, allocate all layers with the same ndf 71 | if ndf is None: 72 | nfc = channel_dict 73 | else: 74 | nfc = {k: ndf for k, v in channel_dict.items()} 75 | 76 | # for feature map discriminators with nfc not in channel_dict 77 | # this is the case for the pretrained backbone (midas.pretrained) 78 | if nc is not None and head is None: 79 | nfc[start_sz] = nc 80 | 81 | layers = [] 82 | 83 | # Head if the initial input is the full modality 84 | if head: 85 | layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False), 86 | nn.LeakyReLU(0.2, inplace=True)] 87 | 88 | # Down Blocks 89 | DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable) 90 | while start_sz > end_sz: 91 | layers.append(DB(nfc[start_sz], nfc[start_sz//2])) 92 | start_sz = start_sz // 2 93 | self.main = nn.Sequential(*layers) 94 | 95 | # additions for conditioning on class 
information 96 | self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False) 97 | self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim) 98 | self.embed_proj = nn.Sequential( 99 | nn.Linear(self.embed.embedding_dim, self.cmap_dim), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | ) 102 | 103 | def forward(self, x, c): 104 | h = self.main(x) 105 | out = self.cls(h) 106 | 107 | # conditioning via projection 108 | cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1) 109 | out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) 110 | 111 | return out 112 | 113 | 114 | class MultiScaleD(nn.Module): 115 | def __init__( 116 | self, 117 | channels, 118 | resolutions, 119 | num_discs=1, 120 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing 121 | cond=0, 122 | separable=False, 123 | patch=False, 124 | **kwargs, 125 | ): 126 | super().__init__() 127 | 128 | assert num_discs in [1, 2, 3, 4] 129 | 130 | # the first disc is on the lowest level of the backbone 131 | self.disc_in_channels = channels[:num_discs] 132 | self.disc_in_res = resolutions[:num_discs] 133 | Disc = SingleDiscCond if cond else SingleDisc 134 | 135 | mini_discs = [] 136 | for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)): 137 | start_sz = res if not patch else 16 138 | mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)], 139 | self.mini_discs = nn.ModuleDict(mini_discs) 140 | 141 | def forward(self, features, c): 142 | all_logits = [] 143 | for k, disc in self.mini_discs.items(): 144 | all_logits.append(disc(features[k], c).view(features[k].size(0), -1)) 145 | 146 | all_logits = torch.cat(all_logits, dim=1) 147 | return all_logits 148 | 149 | 150 | class ProjectedDiscriminator(torch.nn.Module): 151 | def __init__( 152 | self, 153 | diffaug=True, 154 | interp224=True, 155 | backbone_kwargs={}, 156 | **kwargs 157 | ): 158 | super().__init__() 159 | self.diffaug = diffaug 160 | self.interp224 = interp224 161 | self.feature_network = F_RandomProj(**backbone_kwargs) 162 | self.discriminator = MultiScaleD( 163 | channels=self.feature_network.CHANNELS, 164 | resolutions=self.feature_network.RESOLUTIONS, 165 | **backbone_kwargs, 166 | ) 167 | 168 | def train(self, mode=True): 169 | self.feature_network = self.feature_network.train(False) 170 | self.discriminator = self.discriminator.train(mode) 171 | return self 172 | 173 | def eval(self): 174 | return self.train(False) 175 | 176 | def forward(self, x, c): 177 | if self.diffaug: 178 | x = DiffAugment(x, policy='color,translation,cutout') 179 | 180 | if self.interp224: 181 | x = F.interpolate(x, 224, mode='bilinear', align_corners=False) 182 | 183 | features = self.feature_network(x) 184 | logits = self.discriminator(features, c) 185 | 186 | return logits 187 | -------------------------------------------------------------------------------- /pg_modules/networks_fastgan.py: -------------------------------------------------------------------------------- 1 | # original implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py 2 | # 3 | # modified by Axel Sauer for "Projected GANs Converge Faster" 4 | # 5 | import torch.nn as nn 6 | from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d) 7 | 8 | 9 | def normalize_second_moment(x, dim=1, eps=1e-8): 10 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 11 | 12 | 13 | 
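# Note: normalize_second_moment above is the pixel norm of "Progressive Growing
# of GANs": x / sqrt(mean_c(x^2) + eps), which projects each latent onto a
# hypersphere before synthesis (see the forward passes below).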
class DummyMapping(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 | 
17 | def forward(self, z, c, **kwargs):
18 | return z.unsqueeze(1) # to fit the StyleGAN API
19 | 
20 | 
21 | class FastganSynthesis(nn.Module):
22 | def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False):
23 | super().__init__()
24 | self.img_resolution = img_resolution
25 | self.z_dim = z_dim
26 | 
27 | # channel multiplier
28 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
29 | 512:0.25, 1024:0.125}
30 | nfc = {}
31 | for k, v in nfc_multi.items():
32 | nfc[k] = int(v*ngf)
33 | 
34 | # layers
35 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4)
36 | 
37 | UpBlock = UpBlockSmall if lite else UpBlockBig
38 | 
39 | self.feat_8 = UpBlock(nfc[4], nfc[8])
40 | self.feat_16 = UpBlock(nfc[8], nfc[16])
41 | self.feat_32 = UpBlock(nfc[16], nfc[32])
42 | self.feat_64 = UpBlock(nfc[32], nfc[64])
43 | self.feat_128 = UpBlock(nfc[64], nfc[128])
44 | self.feat_256 = UpBlock(nfc[128], nfc[256])
45 | 
46 | self.se_64 = SEBlock(nfc[4], nfc[64])
47 | self.se_128 = SEBlock(nfc[8], nfc[128])
48 | self.se_256 = SEBlock(nfc[16], nfc[256])
49 | 
50 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)
51 | 
52 | if img_resolution > 256:
53 | self.feat_512 = UpBlock(nfc[256], nfc[512])
54 | self.se_512 = SEBlock(nfc[32], nfc[512])
55 | if img_resolution > 512:
56 | self.feat_1024 = UpBlock(nfc[512], nfc[1024])
57 | 
58 | def forward(self, input, c, **kwargs):
59 | # map noise to hypersphere as in "Progressive Growing of GANs"
60 | input = normalize_second_moment(input[:, 0])
61 | 
62 | feat_4 = self.init(input)
63 | feat_8 = self.feat_8(feat_4)
64 | feat_16 = self.feat_16(feat_8)
65 | feat_32 = self.feat_32(feat_16)
66 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32))
67 | feat_128 = self.se_128(feat_8, self.feat_128(feat_64))
68 | 
69 | if self.img_resolution >= 128:
70 | feat_last = feat_128
71 | 
72 | if self.img_resolution >= 256:
73 | feat_last = self.se_256(feat_16, self.feat_256(feat_last))
74 | 
75 | if self.img_resolution >= 512:
76 | feat_last = self.se_512(feat_32, self.feat_512(feat_last))
77 | 
78 | if self.img_resolution >= 1024:
79 | feat_last = self.feat_1024(feat_last)
80 | 
81 | return self.to_big(feat_last)
82 | 
83 | 
84 | class FastganSynthesisCond(nn.Module):
85 | def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False):
86 | super().__init__()
87 | 
88 | self.z_dim = z_dim
89 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
90 | 512:0.25, 1024:0.125, 2048:0.125}
91 | nfc = {}
92 | for k, v in nfc_multi.items():
93 | nfc[k] = int(v*ngf)
94 | 
95 | self.img_resolution = img_resolution
96 | 
97 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4)
98 | 
99 | UpBlock = UpBlockSmallCond if lite else UpBlockBigCond
100 | 
101 | self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
102 | self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
103 | self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
104 | self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
105 | self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim)
106 | self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim)
107 | 
108 | self.se_64 = SEBlock(nfc[4], nfc[64])
109 | self.se_128 = SEBlock(nfc[8], nfc[128])
110 | self.se_256 = SEBlock(nfc[16], nfc[256])
111 | 
112 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)
113 | 
114 | if img_resolution > 256:
115 | self.feat_512 = UpBlock(nfc[256], nfc[512], z_dim)
116 | self.se_512 = SEBlock(nfc[32], nfc[512])
117 | if img_resolution > 512:
118 | self.feat_1024 = UpBlock(nfc[512], nfc[1024], z_dim)
119 | 
120 | self.embed = nn.Embedding(num_classes, z_dim)
121 | 
122 | def forward(self, input, c, update_emas=False):
123 | c = self.embed(c.argmax(1))
124 | 
125 | # map noise to hypersphere as in "Progressive Growing of GANs"
126 | input = normalize_second_moment(input[:, 0])
127 | 
128 | feat_4 = self.init(input)
129 | feat_8 = self.feat_8(feat_4, c)
130 | feat_16 = self.feat_16(feat_8, c)
131 | feat_32 = self.feat_32(feat_16, c)
132 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c))
133 | feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c))
134 | 
135 | if self.img_resolution >= 128:
136 | feat_last = feat_128
137 | 
138 | if self.img_resolution >= 256:
139 | feat_last = self.se_256(feat_16, self.feat_256(feat_last, c))
140 | 
141 | if self.img_resolution >= 512:
142 | feat_last = self.se_512(feat_32, self.feat_512(feat_last, c))
143 | 
144 | if self.img_resolution >= 1024:
145 | feat_last = self.feat_1024(feat_last, c)
146 | 
147 | return self.to_big(feat_last)
148 | 
149 | 
150 | class Generator(nn.Module):
151 | def __init__(
152 | self,
153 | z_dim=256,
154 | c_dim=0,
155 | w_dim=0,
156 | img_resolution=256,
157 | img_channels=3,
158 | ngf=128,
159 | cond=0,
160 | mapping_kwargs={},
161 | synthesis_kwargs={}
162 | ):
163 | super().__init__()
164 | self.z_dim = z_dim
165 | self.c_dim = c_dim
166 | self.w_dim = w_dim
167 | self.img_resolution = img_resolution
168 | self.img_channels = img_channels
169 | 
170 | # Mapping and Synthesis Networks
171 | self.mapping = DummyMapping() # to fit the StyleGAN API
172 | Synthesis = FastganSynthesisCond if cond else FastganSynthesis
173 | self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs)
174 | 
175 | def forward(self, z, c, **kwargs):
176 | w = self.mapping(z, c)
177 | img = self.synthesis(w, c)
178 | return img
179 | -------------------------------------------------------------------------------- /pg_modules/projector.py: -------------------------------------------------------------------------------- 1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | from pg_modules.blocks import FeatureFusionBlock
5 | 
6 | 
7 | def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
8 | # shapes
9 | out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
10 | 
11 | scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
12 | scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
13 | scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
14 | scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
15 | 
16 | scratch.CHANNELS = out_channels
17 | 
18 | return scratch
19 | 
20 | 
21 | def _make_scratch_csm(scratch, in_channels, cout, expand):
22 | scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
23 | scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
24 | scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
25 | scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))
26 | 
27 | # last refinenet does not expand to save channels in higher dimensions
28 | scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4
29 | 
30 | return scratch
31 | 
32 | 
33 | def
_make_efficientnet(model): 34 | pretrained = nn.Module() 35 | pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2]) 36 | pretrained.layer1 = nn.Sequential(*model.blocks[2:3]) 37 | pretrained.layer2 = nn.Sequential(*model.blocks[3:5]) 38 | pretrained.layer3 = nn.Sequential(*model.blocks[5:9]) 39 | return pretrained 40 | 41 | 42 | def calc_channels(pretrained, inp_res=224): 43 | channels = [] 44 | tmp = torch.zeros(1, 3, inp_res, inp_res) 45 | 46 | # forward pass 47 | tmp = pretrained.layer0(tmp) 48 | channels.append(tmp.shape[1]) 49 | tmp = pretrained.layer1(tmp) 50 | channels.append(tmp.shape[1]) 51 | tmp = pretrained.layer2(tmp) 52 | channels.append(tmp.shape[1]) 53 | tmp = pretrained.layer3(tmp) 54 | channels.append(tmp.shape[1]) 55 | 56 | return channels 57 | 58 | 59 | def _make_projector(im_res, cout, proj_type, expand=False): 60 | assert proj_type in [0, 1, 2], "Invalid projection type" 61 | 62 | ### Build pretrained feature network 63 | model = timm.create_model('tf_efficientnet_lite0', pretrained=True) 64 | pretrained = _make_efficientnet(model) 65 | 66 | # determine resolution of feature maps, this is later used to calculate the number 67 | # of down blocks in the discriminators. Interestingly, the best results are achieved 68 | # by fixing this to 256, ie., we use the same number of down blocks per discriminator 69 | # independent of the dataset resolution 70 | im_res = 256 71 | pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32] 72 | pretrained.CHANNELS = calc_channels(pretrained) 73 | 74 | if proj_type == 0: return pretrained, None 75 | 76 | ### Build CCM 77 | scratch = nn.Module() 78 | scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand) 79 | pretrained.CHANNELS = scratch.CHANNELS 80 | 81 | if proj_type == 1: return pretrained, scratch 82 | 83 | ### build CSM 84 | scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand) 85 | 86 | # CSM upsamples x2 so the feature map resolution doubles 87 | pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS] 88 | pretrained.CHANNELS = scratch.CHANNELS 89 | 90 | return pretrained, scratch 91 | 92 | 93 | class F_RandomProj(nn.Module): 94 | def __init__( 95 | self, 96 | im_res=256, 97 | cout=64, 98 | expand=True, 99 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing 100 | **kwargs, 101 | ): 102 | super().__init__() 103 | self.proj_type = proj_type 104 | self.cout = cout 105 | self.expand = expand 106 | 107 | # build pretrained feature network and random decoder (scratch) 108 | self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand) 109 | self.CHANNELS = self.pretrained.CHANNELS 110 | self.RESOLUTIONS = self.pretrained.RESOLUTIONS 111 | 112 | def forward(self, x): 113 | # predict feature maps 114 | out0 = self.pretrained.layer0(x) 115 | out1 = self.pretrained.layer1(out0) 116 | out2 = self.pretrained.layer2(out1) 117 | out3 = self.pretrained.layer3(out2) 118 | 119 | # start enumerating at the lowest layer (this is where we put the first discriminator) 120 | out = { 121 | '0': out0, 122 | '1': out1, 123 | '2': out2, 124 | '3': out3, 125 | } 126 | 127 | if self.proj_type == 0: return out 128 | 129 | out0_channel_mixed = self.scratch.layer0_ccm(out['0']) 130 | out1_channel_mixed = self.scratch.layer1_ccm(out['1']) 131 | out2_channel_mixed = self.scratch.layer2_ccm(out['2']) 132 | out3_channel_mixed 
= self.scratch.layer3_ccm(out['3']) 133 | 134 | out = { 135 | '0': out0_channel_mixed, 136 | '1': out1_channel_mixed, 137 | '2': out2_channel_mixed, 138 | '3': out3_channel_mixed, 139 | } 140 | 141 | if self.proj_type == 1: return out 142 | 143 | # from bottom to top 144 | out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed) 145 | out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed) 146 | out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed) 147 | out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed) 148 | 149 | out = { 150 | '0': out0_scale_mixed, 151 | '1': out1_scale_mixed, 152 | '2': out2_scale_mixed, 153 | '3': out3_scale_mixed, 154 | } 155 | 156 | return out 157 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 
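# (These helpers locate an MSVC toolchain on Windows and derive a GPU-specific
# name used to key the compiled-plugin cache.)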
28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 
106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 | 
14 | //------------------------------------------------------------------------
15 | 
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 | 
30 | //------------------------------------------------------------------------
31 | 
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 | 
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 | 
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 | 
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 | 
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel<scalar_t>(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 | 
83 | // Launch CUDA kernel.
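// (Each thread processes up to p.loopX elements spaced blockDim.x apart, so
// gridSize below is ceil(sizeX / (loopX * blockSize)) with blockSize = 128.)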
84 |     p.loopX = 4;
85 |     int blockSize = 4 * 32;
86 |     int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 |     void* args[] = {&p};
88 |     AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 |     return y;
90 | }
91 | 
92 | //------------------------------------------------------------------------
93 | 
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 |     m.def("bias_act", &bias_act);
97 | }
98 | 
99 | //------------------------------------------------------------------------
100 | 
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 | 
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 | 
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double>    { typedef double scalar_t; };
17 | template <> struct InternalType<float>     { typedef float  scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
19 | 
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 | 
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 |     typedef typename InternalType<T>::scalar_t scalar_t;
27 |     int G                 = p.grad;
28 |     scalar_t alpha        = (scalar_t)p.alpha;
29 |     scalar_t gain         = (scalar_t)p.gain;
30 |     scalar_t clamp        = (scalar_t)p.clamp;
31 |     scalar_t one          = (scalar_t)1;
32 |     scalar_t two          = (scalar_t)2;
33 |     scalar_t expRange     = (scalar_t)80;
34 |     scalar_t halfExpRange = (scalar_t)40;
35 |     scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
36 |     scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;
37 | 
38 |     // Loop over elements.
39 |     int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 |     for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 |     {
42 |         // Load.
43 |         scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 |         scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 |         scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 |         scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 |         scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 |         scalar_t yy = (gain != 0) ? yref / gain : 0;
49 |         scalar_t y = 0;
50 | 
51 |         // Apply bias.
52 |         ((G == 0) ? x : xref) += b;
53 | 
54 |         // linear
55 |         if (A == 1)
56 |         {
57 |             if (G == 0) y = x;
58 |             if (G == 1) y = x;
59 |         }
60 | 
61 |         // relu
62 |         if (A == 2)
63 |         {
64 |             if (G == 0) y = (x > 0) ? x : 0;
65 |             if (G == 1) y = (yy > 0) ? x : 0;
66 |         }
67 | 
68 |         // lrelu
69 |         if (A == 3)
70 |         {
71 |             if (G == 0) y = (x > 0) ? x : x * alpha;
72 |             if (G == 1) y = (yy > 0) ? x : x * alpha;
73 |         }
74 | 
75 |         // tanh
76 |         if (A == 4)
77 |         {
78 |             if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 |             if (G == 1) y = x * (one - yy * yy);
80 |             if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 |         }
82 | 
83 |         // sigmoid
84 |         if (A == 5)
85 |         {
86 |             if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 |             if (G == 1) y = x * yy * (one - yy);
88 |             if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 |         }
90 | 
91 |         // elu
92 |         if (A == 6)
93 |         {
94 |             if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 |             if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 |         }
98 | 
99 |         // selu
100 |         if (A == 7)
101 |         {
102 |             if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 |             if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 |             if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 |         }
106 | 
107 |         // softplus
108 |         if (A == 8)
109 |         {
110 |             if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 |             if (G == 1) y = x * (one - exp(-yy));
112 |             if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 |         }
114 | 
115 |         // swish
116 |         if (A == 9)
117 |         {
118 |             if (G == 0)
119 |                 y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 |             else
121 |             {
122 |                 scalar_t c = exp(xref);
123 |                 scalar_t d = c + one;
124 |                 if (G == 1)
125 |                     y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 |                 else
127 |                     y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 |                 yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 |             }
130 |         }
131 | 
132 |         // Apply gain.
133 |         y *= gain * dy;
134 | 
135 |         // Clamp.
136 |         if (clamp >= 0)
137 |         {
138 |             if (G == 0)
139 |                 y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 |             else
141 |                 y = (yref > -clamp & yref < clamp) ? y : 0;
142 |         }
143 | 
144 |         // Store.
145 |         ((T*)p.y)[xi] = (T)y;
146 |     }
147 | }
148 | 
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 | 
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 |     if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 |     if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 |     if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 |     if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 |     if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 |     if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 |     if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 |     if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 |     if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 |     return NULL;
164 | }
165 | 
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 | 
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 | 
173 | //------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 | 
12 | struct bias_act_kernel_params
13 | {
14 |     const void* x;      // [sizeX]
15 |     const void* b;      // [sizeB] or NULL
16 |     const void* xref;   // [sizeX] or NULL
17 |     const void* yref;   // [sizeX] or NULL
18 |     const void* dy;     // [sizeX] or NULL
19 |     void*       y;      // [sizeX]
20 | 
21 |     int         grad;
22 |     int         act;
23 |     float       alpha;
24 |     float       gain;
25 |     float       clamp;
26 | 
27 |     int         sizeX;
28 |     int         sizeB;
29 |     int         stepB;
30 |     int         loopX;
31 | };
32 | 
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 | 
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 | 
38 | //------------------------------------------------------------------------
39 | 
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | """Custom PyTorch ops for efficient bias and activation."""
10 | 
11 | import os
12 | import numpy as np
13 | import torch
14 | import dnnlib
15 | 
16 | from .. import custom_ops
17 | from .. import misc
18 | 
19 | #----------------------------------------------------------------------------
20 | 
21 | activation_funcs = {
22 |     'linear':   dnnlib.EasyDict(func=lambda x, **_:        x,                                        def_alpha=0,   def_gain=1,          cuda_idx=1, ref='',  has_2nd_grad=False),
23 |     'relu':     dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.relu(x),              def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24 |     'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25 |     'tanh':     dnnlib.EasyDict(func=lambda x, **_:        torch.tanh(x),                            def_alpha=0,   def_gain=1,          cuda_idx=4, ref='y', has_2nd_grad=True),
26 |     'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:        torch.sigmoid(x),                         def_alpha=0,   def_gain=1,          cuda_idx=5, ref='y', has_2nd_grad=True),
27 |     'elu':      dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.elu(x),               def_alpha=0,   def_gain=1,          cuda_idx=6, ref='y', has_2nd_grad=True),
28 |     'selu':     dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.selu(x),              def_alpha=0,   def_gain=1,          cuda_idx=7, ref='y', has_2nd_grad=True),
29 |     'softplus': dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.softplus(x),          def_alpha=0,   def_gain=1,          cuda_idx=8, ref='y', has_2nd_grad=True),
30 |     'swish':    dnnlib.EasyDict(func=lambda x, **_:        torch.sigmoid(x) * x,                     def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31 | }
32 | 
33 | #----------------------------------------------------------------------------
34 | 
35 | _plugin = None
36 | _null_tensor = torch.empty([0])
37 | 
38 | def _init():
39 |     global _plugin
40 |     if _plugin is None:
41 |         _plugin = custom_ops.get_plugin(
42 |             module_name='bias_act_plugin',
43 |             sources=['bias_act.cpp', 'bias_act.cu'],
44 |             headers=['bias_act.h'],
45 |             source_dir=os.path.dirname(__file__),
46 |             extra_cuda_cflags=['--use_fast_math'],
47 |         )
48 |     return True
49 | 
50 | #----------------------------------------------------------------------------
51 | 
52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53 |     r"""Fused bias and activation function.
54 | 
55 |     Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56 |     and scales the result by `gain`. Each of the steps is optional. In most cases,
57 |     the fused op is considerably more efficient than performing the same calculation
58 |     using standard PyTorch ops. It supports first and second order gradients,
59 |     but not third order gradients.
60 | 
61 |     Args:
62 |         x:      Input activation tensor. Can be of any shape.
63 |         b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64 |                 as `x`. The shape must be known, and it must match the dimension of `x`
65 |                 corresponding to `dim`.
66 |         dim:    The dimension in `x` corresponding to the elements of `b`.
67 |                 The value of `dim` is ignored if `b` is not specified.
68 |         act:    Name of the activation function to evaluate, or `"linear"` to disable.
69 |                 Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70 |                 See `activation_funcs` for a full list. `None` is not allowed.
71 |         alpha:  Shape parameter for the activation function, or `None` to use the default.
72 |         gain:   Scaling factor for the output tensor, or `None` to use default.
73 |                 See `activation_funcs` for the default scaling of each activation function.
74 |                 If unsure, consider specifying 1.
75 |         clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76 |                 the clamping (default).
77 |         impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78 | 
79 |     Returns:
80 |         Tensor of the same shape and datatype as `x`.
81 |     """
82 |     assert isinstance(x, torch.Tensor)
83 |     assert impl in ['ref', 'cuda']
84 |     if impl == 'cuda' and x.device.type == 'cuda' and _init():
85 |         return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86 |     return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87 | 
88 | #----------------------------------------------------------------------------
89 | 
90 | @misc.profiled_function
91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92 |     """Slow reference implementation of `bias_act()` using standard PyTorch ops.
93 |     """
94 |     assert isinstance(x, torch.Tensor)
95 |     assert clamp is None or clamp >= 0
96 |     spec = activation_funcs[act]
97 |     alpha = float(alpha if alpha is not None else spec.def_alpha)
98 |     gain = float(gain if gain is not None else spec.def_gain)
99 |     clamp = float(clamp if clamp is not None else -1)
100 | 
101 |     # Add bias.
102 |     if b is not None:
103 |         assert isinstance(b, torch.Tensor) and b.ndim == 1
104 |         assert 0 <= dim < x.ndim
105 |         assert b.shape[0] == x.shape[dim]
106 |         x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107 | 
108 |     # Evaluate activation function.
109 |     alpha = float(alpha)
110 |     x = spec.func(x, alpha=alpha)
111 | 
112 |     # Scale by gain.
113 |     gain = float(gain)
114 |     if gain != 1:
115 |         x = x * gain
116 | 
117 |     # Clamp.
118 |     if clamp >= 0:
119 |         x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120 |     return x
121 | 
122 | #----------------------------------------------------------------------------
123 | 
124 | _bias_act_cuda_cache = dict()
125 | 
126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127 |     """Fast CUDA implementation of `bias_act()` using custom ops.
128 |     """
129 |     # Parse arguments.
130 |     assert clamp is None or clamp >= 0
131 |     spec = activation_funcs[act]
132 |     alpha = float(alpha if alpha is not None else spec.def_alpha)
133 |     gain = float(gain if gain is not None else spec.def_gain)
134 |     clamp = float(clamp if clamp is not None else -1)
135 | 
136 |     # Lookup from cache.
137 |     key = (dim, act, alpha, gain, clamp)
138 |     if key in _bias_act_cuda_cache:
139 |         return _bias_act_cuda_cache[key]
140 | 
141 |     # Forward op.
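    # Note: BiasActCuda invokes the compiled plugin once in forward(); its
    # backward() defers to BiasActCudaGrad further below, which is itself an
    # autograd Function, so second-order gradients reuse the same CUDA kernel.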
142 | class BiasActCuda(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, x, b): # pylint: disable=arguments-differ 145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 146 | x = x.contiguous(memory_format=ctx.memory_format) 147 | b = b.contiguous() if b is not None else _null_tensor 148 | y = x 149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 151 | ctx.save_for_backward( 152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 154 | y if 'y' in spec.ref else _null_tensor) 155 | return y 156 | 157 | @staticmethod 158 | def backward(ctx, dy): # pylint: disable=arguments-differ 159 | dy = dy.contiguous(memory_format=ctx.memory_format) 160 | x, b, y = ctx.saved_tensors 161 | dx = None 162 | db = None 163 | 164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 165 | dx = dy 166 | if act != 'linear' or gain != 1 or clamp >= 0: 167 | dx = BiasActCudaGrad.apply(dy, x, b, y) 168 | 169 | if ctx.needs_input_grad[1]: 170 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 171 | 172 | return dx, db 173 | 174 | # Backward op. 175 | class BiasActCudaGrad(torch.autograd.Function): 176 | @staticmethod 177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 180 | ctx.save_for_backward( 181 | dy if spec.has_2nd_grad else _null_tensor, 182 | x, b, y) 183 | return dx 184 | 185 | @staticmethod 186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 188 | dy, x, b, y = ctx.saved_tensors 189 | d_dy = None 190 | d_x = None 191 | d_b = None 192 | d_y = None 193 | 194 | if ctx.needs_input_grad[0]: 195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 196 | 197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 199 | 200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 202 | 203 | return d_dy, d_x, d_b, d_y 204 | 205 | # Add to cache. 206 | _bias_act_cuda_cache[key] = BiasActCuda 207 | return BiasActCuda 208 | 209 | #---------------------------------------------------------------------------- 210 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
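The module that follows re-implements `torch.nn.functional.conv2d` as a pair of autograd Functions so that losses involving gradients (e.g., R1 regularization) can themselves be differentiated. A hedged usage sketch, assuming a CUDA device; the shapes are illustrative and not from the repo:

```python
import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True  # opt in; otherwise the module falls back to plain conv2d
x = torch.randn(2, 3, 16, 16, device='cuda', requires_grad=True)
w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)
(grad_x,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # first-order gradient
grad_x.square().sum().backward()                                # second order, through the gradient
```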
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | 15 | # pylint: disable=redefined-builtin 16 | # pylint: disable=arguments-differ 17 | # pylint: disable=protected-access 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | enabled = False # Enable the custom op by setting this to true. 22 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 23 | 24 | @contextlib.contextmanager 25 | def no_weight_gradients(disable=True): 26 | global weight_gradients_disabled 27 | old = weight_gradients_disabled 28 | if disable: 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | return True 54 | 55 | def _tuple_of_ints(xs, ndim): 56 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 57 | assert len(xs) == ndim 58 | assert all(isinstance(x, int) for x in xs) 59 | return xs 60 | 61 | #---------------------------------------------------------------------------- 62 | 63 | _conv2d_gradfix_cache = dict() 64 | _null_tensor = torch.empty([0]) 65 | 66 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 67 | # Parse arguments. 68 | ndim = 2 69 | weight_shape = tuple(weight_shape) 70 | stride = _tuple_of_ints(stride, ndim) 71 | padding = _tuple_of_ints(padding, ndim) 72 | output_padding = _tuple_of_ints(output_padding, ndim) 73 | dilation = _tuple_of_ints(dilation, ndim) 74 | 75 | # Lookup from cache. 76 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 77 | if key in _conv2d_gradfix_cache: 78 | return _conv2d_gradfix_cache[key] 79 | 80 | # Validate arguments. 
81 | assert groups >= 1 82 | assert len(weight_shape) == ndim + 2 83 | assert all(stride[i] >= 1 for i in range(ndim)) 84 | assert all(padding[i] >= 0 for i in range(ndim)) 85 | assert all(dilation[i] >= 0 for i in range(ndim)) 86 | if not transpose: 87 | assert all(output_padding[i] == 0 for i in range(ndim)) 88 | else: # transpose 89 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 90 | 91 | # Helpers. 92 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 93 | def calc_output_padding(input_shape, output_shape): 94 | if transpose: 95 | return [0, 0] 96 | return [ 97 | input_shape[i + 2] 98 | - (output_shape[i + 2] - 1) * stride[i] 99 | - (1 - 2 * padding[i]) 100 | - dilation[i] * (weight_shape[i + 2] - 1) 101 | for i in range(ndim) 102 | ] 103 | 104 | # Forward & backward. 105 | class Conv2d(torch.autograd.Function): 106 | @staticmethod 107 | def forward(ctx, input, weight, bias): 108 | assert weight.shape == weight_shape 109 | ctx.save_for_backward( 110 | input if weight.requires_grad else _null_tensor, 111 | weight if input.requires_grad else _null_tensor, 112 | ) 113 | ctx.input_shape = input.shape 114 | 115 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 116 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 117 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 118 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 119 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 120 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 121 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 122 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 123 | 124 | # General case => cuDNN. 125 | if transpose: 126 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 127 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | input, weight = ctx.saved_tensors 132 | input_shape = ctx.input_shape 133 | grad_input = None 134 | grad_weight = None 135 | grad_bias = None 136 | 137 | if ctx.needs_input_grad[0]: 138 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 139 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 140 | grad_input = op.apply(grad_output, weight, None) 141 | assert grad_input.shape == input_shape 142 | 143 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 144 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 145 | assert grad_weight.shape == weight_shape 146 | 147 | if ctx.needs_input_grad[2]: 148 | grad_bias = grad_output.sum([0, 2, 3]) 149 | 150 | return grad_input, grad_weight, grad_bias 151 | 152 | # Gradient with respect to the weights. 
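    # Conv2dGradWeight below mirrors the structure above: a cuBLAS matmul fast
    # path for 1x1 kernels and the cuDNN backward-weight op otherwise, wrapped
    # as an autograd Function so the weight gradient is itself differentiable.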
153 | class Conv2dGradWeight(torch.autograd.Function): 154 | @staticmethod 155 | def forward(ctx, grad_output, input): 156 | ctx.save_for_backward( 157 | grad_output if input.requires_grad else _null_tensor, 158 | input if grad_output.requires_grad else _null_tensor, 159 | ) 160 | ctx.grad_output_shape = grad_output.shape 161 | ctx.input_shape = input.shape 162 | 163 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 164 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 165 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 166 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 167 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 168 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 169 | 170 | # General case => cuDNN. 171 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 172 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 173 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 174 | 175 | @staticmethod 176 | def backward(ctx, grad2_grad_weight): 177 | grad_output, input = ctx.saved_tensors 178 | grad_output_shape = ctx.grad_output_shape 179 | input_shape = ctx.input_shape 180 | grad2_grad_output = None 181 | grad2_input = None 182 | 183 | if ctx.needs_input_grad[0]: 184 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 185 | assert grad2_grad_output.shape == grad_output_shape 186 | 187 | if ctx.needs_input_grad[1]: 188 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 189 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 190 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 191 | assert grad2_input.shape == input_shape 192 | 193 | return grad2_grad_output, grad2_input 194 | 195 | _conv2d_gradfix_cache[key] = Conv2d 196 | return Conv2d 197 | 198 | #---------------------------------------------------------------------------- 199 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . 
import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
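    # (Reordering is safe here because a 1x1 kernel has no spatial support;
    # downsampling first simply shrinks the tensor the convolution must process.)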
94 |     if kw == 1 and kh == 1 and (down > 1 and up == 1):
95 |         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96 |         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97 |         return x
98 | 
99 |     # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100 |     if kw == 1 and kh == 1 and (up > 1 and down == 1):
101 |         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102 |         x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103 |         return x
104 | 
105 |     # Fast path: downsampling only => use strided convolution.
106 |     if down > 1 and up == 1:
107 |         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108 |         x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109 |         return x
110 | 
111 |     # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112 |     if up > 1:
113 |         if groups == 1:
114 |             w = w.transpose(0, 1)
115 |         else:
116 |             w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117 |             w = w.transpose(1, 2)
118 |             w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119 |         px0 -= kw - 1
120 |         px1 -= kw - up
121 |         py0 -= kh - 1
122 |         py1 -= kh - up
123 |         pxt = max(min(-px0, -px1), 0)
124 |         pyt = max(min(-py0, -py1), 0)
125 |         x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126 |         x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127 |         if down > 1:
128 |             x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129 |         return x
130 | 
131 |     # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132 |     if up == 1 and down == 1:
133 |         if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134 |             return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135 | 
136 |     # Fallback: Generic reference implementation.
137 |     x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138 |     x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139 |     if down > 1:
140 |         x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141 |     return x
142 | 
143 | #----------------------------------------------------------------------------
144 | 
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include <cuda_runtime.h>
10 | 
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 | 
14 | struct filtered_lrelu_kernel_params
15 | {
16 |     // These parameters decide which kernel to use.
17 |     int             up;        // upsampling ratio (1, 2, 4)
18 |     int             down;      // downsampling ratio (1, 2, 4)
19 |     int2            fuShape;   // [size, 1] | [size, size]
20 |     int2            fdShape;   // [size, 1] | [size, size]
21 | 
22 |     int             _dummy;    // Alignment.
23 | 
24 |     // Rest of the parameters.
25 |     const void*     x;         // Input tensor.
26 |     void*           y;         // Output tensor.
27 |     const void*     b;         // Bias tensor.
28 |     unsigned char*  s;         // Sign tensor in/out. NULL if unused.
29 |     const float*    fu;        // Upsampling filter.
30 |     const float*    fd;        // Downsampling filter.
31 | 
32 |     int2            pad0;      // Left/top padding.
33 |     float           gain;      // Additional gain factor.
34 |     float           slope;     // Leaky ReLU slope on negative side.
35 |     float           clamp;     // Clamp after nonlinearity.
36 |     int             flip;      // Filter kernel flip for gradient computation.
37 | 
38 |     int             tilesXdim; // Original number of horizontal output tiles.
39 |     int             tilesXrep; // Number of horizontal tiles per CTA.
40 |     int             blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 | 
42 |     int4            xShape;    // [width, height, channel, batch]
43 |     int4            yShape;    // [width, height, channel, batch]
44 |     int2            sShape;    // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 |     int2            sOfs;      // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 |     int             swLimit;   // Active width of sign tensor in bytes.
47 | 
48 |     longlong4       xStride;   // Strides of all tensors except signs, same component order as shapes.
49 |     longlong4       yStride;   //
50 |     int64_t         bStride;   //
51 |     longlong3       fuStride;  //
52 |     longlong3       fdStride;  //
53 | };
54 | 
55 | struct filtered_lrelu_act_kernel_params
56 | {
57 |     void*           x;         // Input/output, modified in-place.
58 |     unsigned char*  s;         // Sign tensor in/out. NULL if unused.
59 | 
60 |     float           gain;      // Additional gain factor.
61 |     float           slope;     // Leaky ReLU slope on negative side.
62 |     float           clamp;     // Clamp after nonlinearity.
63 | 
64 |     int4            xShape;    // [width, height, channel, batch]
65 |     longlong4       xStride;   // Input/output tensor strides, same order as in shape.
66 |     int2            sShape;    // [width, height] - width is in elements. Contiguous. Zeros if unused.
67 |     int2            sOfs;      // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68 | };
69 | 
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 | 
73 | struct filtered_lrelu_kernel_spec
74 | {
75 |     void*   setup;             // Function for filter kernel setup.
76 |     void*   exec;              // Function for main operation.
77 |     int2    tileOut;           // Width/height of launch tile.
78 |     int     numWarps;          // Number of warps per thread block, determines launch block size.
79 |     int     xrep;              // For processing multiple horizontal tiles per thread block.
80 |     int     dynamicSharedKB;   // How much dynamic shared memory the exec kernel wants.
81 | };
82 | 
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 | 
86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89 | 
90 | //------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include "filtered_lrelu.cu"
10 | 
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 | 
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 | 
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 | 
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
25 | 
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
28 | 
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include "filtered_lrelu.cu"
10 | 
11 | // Template/kernel specializations for sign read mode.
12 | 
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 | 
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 | 
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
25 | 
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
28 | 
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include "filtered_lrelu.cu"
10 | 
11 | // Template/kernel specializations for sign write mode.
12 | 
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 | 
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 | 
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
25 | 
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
28 | 
--------------------------------------------------------------------------------
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
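The `fma` helper defined below fuses `a * b + c` via `torch.addcmul()` and un-broadcasts gradients back to the operand shapes. A hedged usage sketch (shapes illustrative, not from the repo):

```python
import torch
from torch_utils.ops import fma

a = torch.randn(4, 1, requires_grad=True)
b = torch.randn(1, 5, requires_grad=True)
c = torch.randn(4, 5)
out = fma.fma(a, b, c)                      # computes a * b + c
assert torch.allclose(out, torch.addcmul(c, a, b))
out.sum().backward()                        # grads un-broadcast to operand shapes
assert a.grad.shape == a.shape and b.grad.shape == b.shape
```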
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 
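A hedged usage sketch of the `grid_sample()` wrapper defined below (not from the repo); with the `enabled` flag set above, calls route through the custom autograd Functions:

```python
import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True
img = torch.randn(1, 3, 8, 8, requires_grad=True)
grid = torch.rand(1, 8, 8, 2) * 2 - 1   # normalized sampling coordinates in [-1, 1]
out = grid_sample_gradfix.grid_sample(img, grid)
out.mean().backward()                   # differentiable w.r.t. img (and grid)
```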
23 | 
24 | #----------------------------------------------------------------------------
25 | 
26 | def grid_sample(input, grid):
27 |     if _should_use_custom_op():
28 |         return _GridSample2dForward.apply(input, grid)
29 |     return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
30 | 
31 | #----------------------------------------------------------------------------
32 | 
33 | def _should_use_custom_op():
34 |     return enabled
35 | 
36 | #----------------------------------------------------------------------------
37 | 
38 | class _GridSample2dForward(torch.autograd.Function):
39 |     @staticmethod
40 |     def forward(ctx, input, grid):
41 |         assert input.ndim == 4
42 |         assert grid.ndim == 4
43 |         output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
44 |         ctx.save_for_backward(input, grid)
45 |         return output
46 | 
47 |     @staticmethod
48 |     def backward(ctx, grad_output):
49 |         input, grid = ctx.saved_tensors
50 |         grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
51 |         return grad_input, grad_grid
52 | 
53 | #----------------------------------------------------------------------------
54 | 
55 | class _GridSample2dBackward(torch.autograd.Function):
56 |     @staticmethod
57 |     def forward(ctx, grad_output, input, grid):
58 |         op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
59 |         grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
60 |         ctx.save_for_backward(grid)
61 |         return grad_input, grad_grid
62 | 
63 |     @staticmethod
64 |     def backward(ctx, grad2_grad_input, grad2_grad_grid):
65 |         _ = grad2_grad_grid # unused
66 |         grid, = ctx.saved_tensors
67 |         grad2_grad_output = None
68 |         grad2_input = None
69 |         grad2_grid = None
70 | 
71 |         if ctx.needs_input_grad[0]:
72 |             grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
73 | 
74 |         assert not ctx.needs_input_grad[2]
75 |         return grad2_grad_output, grad2_input, grad2_grid
76 | 
77 | #----------------------------------------------------------------------------
78 | 
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 | 
14 | //------------------------------------------------------------------------
15 | 
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 |     // Validate arguments.
19 |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 |     TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 |     TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 |     TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 |     TORCH_CHECK(x.numel() > 0, "x has zero size");
25 |     TORCH_CHECK(f.numel() > 0, "f has zero size");
26 |     TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27 |     TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28 |     TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29 |     TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30 |     TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31 |     TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32 | 
33 |     // Create output tensor.
34 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35 |     int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36 |     int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37 |     TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38 |     torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39 |     TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40 |     TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41 | 
42 |     // Initialize CUDA kernel parameters.
43 |     upfirdn2d_kernel_params p;
44 |     p.x             = x.data_ptr();
45 |     p.f             = f.data_ptr<float>();
46 |     p.y             = y.data_ptr();
47 |     p.up            = make_int2(upx, upy);
48 |     p.down          = make_int2(downx, downy);
49 |     p.pad0          = make_int2(padx0, pady0);
50 |     p.flip          = (flip) ? 1 : 0;
51 |     p.gain          = gain;
52 |     p.inSize        = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53 |     p.inStride      = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54 |     p.filterSize    = make_int2((int)f.size(1), (int)f.size(0));
55 |     p.filterStride  = make_int2((int)f.stride(1), (int)f.stride(0));
56 |     p.outSize       = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57 |     p.outStride     = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58 |     p.sizeMajor     = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59 |     p.sizeMinor     = (p.inStride.z == 1) ? p.inSize.z : 1;
60 | 
61 |     // Choose CUDA kernel.
62 |     upfirdn2d_kernel_spec spec;
63 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64 |     {
65 |         spec = choose_upfirdn2d_kernel<scalar_t>(p);
66 |     });
67 | 
68 |     // Set looping options.
69 |     p.loopMajor     = (p.sizeMajor - 1) / 16384 + 1;
70 |     p.loopMinor     = spec.loopMinor;
71 |     p.loopX         = spec.loopX;
72 |     p.launchMinor   = (p.sizeMinor - 1) / p.loopMinor + 1;
73 |     p.launchMajor   = (p.sizeMajor - 1) / p.loopMajor + 1;
74 | 
75 |     // Compute grid size.
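    // (tileOutW < 0 selects the generic large-tile kernel with 4x32 thread blocks;
    // otherwise a specialized kernel processes tileOutW x tileOutH outputs per 256-thread block.)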
76 |     dim3 blockSize, gridSize;
77 |     if (spec.tileOutW < 0) // large
78 |     {
79 |         blockSize = dim3(4, 32, 1);
80 |         gridSize = dim3(
81 |             ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82 |             (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83 |             p.launchMajor);
84 |     }
85 |     else // small
86 |     {
87 |         blockSize = dim3(256, 1, 1);
88 |         gridSize = dim3(
89 |             ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90 |             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91 |             p.launchMajor);
92 |     }
93 | 
94 |     // Launch CUDA kernel.
95 |     void* args[] = {&p};
96 |     AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97 |     return y;
98 | }
99 | 
100 | //------------------------------------------------------------------------
101 | 
102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103 | {
104 |     m.def("upfirdn2d", &upfirdn2d);
105 | }
106 | 
107 | //------------------------------------------------------------------------
108 | 
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include <cuda_runtime.h>
10 | 
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 | 
14 | struct upfirdn2d_kernel_params
15 | {
16 |     const void*     x;
17 |     const float*    f;
18 |     void*           y;
19 | 
20 |     int2            up;
21 |     int2            down;
22 |     int2            pad0;
23 |     int             flip;
24 |     float           gain;
25 | 
26 |     int4            inSize;       // [width, height, channel, batch]
27 |     int4            inStride;
28 |     int2            filterSize;   // [width, height]
29 |     int2            filterStride;
30 |     int4            outSize;      // [width, height, channel, batch]
31 |     int4            outStride;
32 |     int             sizeMinor;
33 |     int             sizeMajor;
34 | 
35 |     int             loopMinor;
36 |     int             loopMajor;
37 |     int             loopX;
38 |     int             launchMinor;
39 |     int             launchMajor;
40 | };
41 | 
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 | 
45 | struct upfirdn2d_kernel_spec
46 | {
47 |     void*   kernel;
48 |     int     tileOutW;
49 |     int     tileOutH;
50 |     int     loopMinor;
51 |     int     loopX;
52 | };
53 | 
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 | 
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 | 
59 | //------------------------------------------------------------------------
60 | 
--------------------------------------------------------------------------------
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src: Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported.
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 
255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /torch_utils/utils_spectrum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fftn 3 | 4 | 5 | def roll_quadrants(data, backwards=False): 6 | """ 7 | Shift low frequencies to the center of the Fourier transform, i.e. reorder the frequency indices [0, ..., N-1] -> [-N/2, ..., +N/2] 8 | Args: 9 | data: fourier transform, (NxHxW) 10 | backwards: bool, if True undo the shift, i.e. move the zero frequency back to index 0 11 | 12 | Returns: 13 | Shifted fourier transform. 14 | """ 15 | dim = data.ndim - 1 16 | 17 | if dim != 2: 18 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 19 | if any(s % 2 == 0 for s in data.shape[1:]): 20 | raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.') 21 | 22 | # for each dimension swap left and right half 23 | dims = tuple(range(1, dim+1)) # add one for batch dimension 24 | shifts = torch.tensor(data.shape[1:]) // 2 #.div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd 25 | if backwards: 26 | shifts *= -1 27 | return data.roll(shifts.tolist(), dims=dims) 28 | 29 | 30 | def batch_fft(data, normalize=False): 31 | """ 32 | Compute fourier transform of batch. 33 | Args: 34 | data: input tensor, (NxHxW) 35 | normalize: bool, if True use the orthonormal FFT normalization ('ortho') 36 | Returns: 37 | Batch fourier transform of input data. 38 | """ 39 | 40 | dim = data.ndim - 1 # subtract one for batch dimension 41 | if dim != 2: 42 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 43 | 44 | dims = tuple(range(1, dim + 1)) # add one for batch dimension 45 | if normalize: 46 | norm = 'ortho' 47 | else: 48 | norm = 'backward' 49 | 50 | if not torch.is_complex(data): 51 | data = torch.complex(data, torch.zeros_like(data)) 52 | freq = fftn(data, dim=dims, norm=norm) 53 | 54 | return freq 55 | 56 | 57 | def azimuthal_average(image, center=None): 58 | # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/ 59 | """ 60 | Calculate the azimuthally averaged radial profile. 61 | Requires low frequencies to be at the center of the image. 62 | Args: 63 | image: Batch of 2D images, NxHxW 64 | center: The [x,y] pixel coordinates used as the center. The default is 65 | None, which then uses the center of the image (including 66 | fractional pixels). 67 | 68 | Returns: 69 | Azimuthal average over the image around the center 70 | """ 71 | # Check input shapes 72 | assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \ 73 | f'(but it is len(center)={len(center)}).'
74 | # Calculate the indices from the image 75 | H, W = image.shape[-2:] 76 | h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 77 | 78 | if center is None: 79 | center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0]) 80 | 81 | # Compute radius for each pixel wrt center 82 | r = torch.stack([w-center[0], h-center[1]]).norm(2, 0) 83 | 84 | # Get sorted radii 85 | r_sorted, ind = r.flatten().sort() 86 | i_sorted = image.flatten(-2, -1)[..., ind] 87 | 88 | # Get the integer part of the radii (bin size = 1) 89 | r_int = r_sorted.long() # attribute to the smaller integer 90 | 91 | # Find all pixels that fall within each radial bin. 92 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii 93 | rind = torch.where(deltar)[0] # location of changed radius 94 | 95 | # compute number of elements in each bin 96 | nind = rind + 1 # number of elements = idx + 1 97 | nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders 98 | nr = nind[1:] - nind[:-1] # number of radius bin, i.e. counter for bins belonging to each radius 99 | 100 | # Cumulative sum to figure out sums for each radius bin 101 | if H % 2 == 0: 102 | raise NotImplementedError('Not sure if implementation correct, please check') 103 | rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders 104 | else: 105 | rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders 106 | csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius 107 | tbin = csim[..., rind[1:]] - csim[..., rind[:-1]] 108 | # add mean 109 | tbin = torch.cat([csim[:, 0:1], tbin], 1) 110 | 111 | radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins 112 | 113 | return radial_prof 114 | 115 | 116 | def get_spectrum(data, normalize=False): 117 | dim = data.ndim - 1 # subtract one for batch dimension 118 | if dim != 2: 119 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 120 | 121 | freq = batch_fft(data, normalize=normalize) 122 | power_spec = freq.real ** 2 + freq.imag ** 2 123 | N = data.shape[1] 124 | if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum 125 | # and is not averaged with the mean value 126 | N_2 = N//2 127 | power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1) 128 | power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2) 129 | 130 | power_spec = roll_quadrants(power_spec) 131 | power_spec = azimuthal_average(power_spec) 132 | return power_spec 133 | 134 | 135 | def plot_std(mean, std, x=None, ax=None, **kwargs): 136 | import matplotlib.pyplot as plt 137 | if ax is None: 138 | fig, ax = plt.subplots(1) 139 | 140 | # plot error margins in same color as line 141 | err_kwargs = { 142 | 'alpha': 0.3 143 | } 144 | 145 | if 'c' in kwargs.keys(): 146 | err_kwargs['color'] = kwargs['c'] 147 | elif 'color' in kwargs.keys(): 148 | err_kwargs['color'] = kwargs['color'] 149 | 150 | if x is None: 151 | x = torch.linspace(0, 1, len(mean)) # use normalized x axis 152 | ax.plot(x, mean, **kwargs) 153 | ax.fill_between(x, mean-std, mean+std, **err_kwargs) 154 | 155 | return ax 156 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Streaming images and labels from datasets created with dataset_tool.py.""" 10 | 11 | import os 12 | import numpy as np 13 | import zipfile 14 | import PIL.Image 15 | import json 16 | import torch 17 | import dnnlib 18 | import copy 19 | 20 | try: 21 | import pyspng 22 | except ImportError: 23 | pyspng = None 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | class Dataset(torch.utils.data.Dataset): 28 | def __init__(self, 29 | name, # Name of the dataset. 30 | raw_shape, # Shape of the raw image data (NCHW). 31 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 32 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 33 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 34 | random_seed = 1, # Random seed to use when applying max_size. 35 | ): 36 | self._name = name 37 | self._raw_shape = list(raw_shape) 38 | self._use_labels = use_labels 39 | self._raw_labels = None 40 | self._label_shape = None 41 | 42 | # Apply max_size. 43 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 44 | self._base_raw_idx = copy.deepcopy(self._raw_idx) 45 | if (max_size is not None) and (self._raw_idx.size > max_size): 46 | np.random.RandomState(random_seed).shuffle(self._raw_idx) 47 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 48 | 49 | # Apply xflip. 
50 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 51 | if xflip: 52 | self._raw_idx = np.tile(self._raw_idx, 2) 53 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 54 | 55 | def set_dyn_len(self, new_len): 56 | self._raw_idx = self._base_raw_idx[:new_len] 57 | 58 | def set_classes(self, cls_list): 59 | self._raw_labels = self._load_raw_labels() 60 | new_idcs = [self._raw_labels == cl for cl in cls_list] 61 | new_idcs = np.sum(np.vstack(new_idcs), 0) # logical or 62 | new_idcs = np.where(new_idcs) # find location 63 | self._raw_idx = self._base_raw_idx[new_idcs] 64 | assert all(sorted(cls_list) == np.unique(self._raw_labels[self._raw_idx])) 65 | print(f"Training on the following classes: {cls_list}") 66 | 67 | def _get_raw_labels(self): 68 | if self._raw_labels is None: 69 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 70 | if self._raw_labels is None: 71 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 72 | assert isinstance(self._raw_labels, np.ndarray) 73 | assert self._raw_labels.shape[0] == self._raw_shape[0] 74 | assert self._raw_labels.dtype in [np.float32, np.int64] 75 | if self._raw_labels.dtype == np.int64: 76 | assert self._raw_labels.ndim == 1 77 | assert np.all(self._raw_labels >= 0) 78 | return self._raw_labels 79 | 80 | def close(self): # to be overridden by subclass 81 | pass 82 | 83 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 84 | raise NotImplementedError 85 | 86 | def _load_raw_labels(self): # to be overridden by subclass 87 | raise NotImplementedError 88 | 89 | def __getstate__(self): 90 | return dict(self.__dict__, _raw_labels=None) 91 | 92 | def __del__(self): 93 | try: 94 | self.close() 95 | except: 96 | pass 97 | 98 | def __len__(self): 99 | return self._raw_idx.size 100 | 101 | def __getitem__(self, idx): 102 | image = self._load_raw_image(self._raw_idx[idx]) 103 | assert isinstance(image, np.ndarray) 104 | assert list(image.shape) == self.image_shape 105 | assert image.dtype == np.uint8 106 | if self._xflip[idx]: 107 | assert image.ndim == 3 # CHW 108 | image = image[:, :, ::-1] 109 | return image.copy(), self.get_label(idx) 110 | 111 | def get_label(self, idx): 112 | label = self._get_raw_labels()[self._raw_idx[idx]] 113 | if label.dtype == np.int64: 114 | onehot = np.zeros(self.label_shape, dtype=np.float32) 115 | onehot[label] = 1 116 | label = onehot 117 | return label.copy() 118 | 119 | def get_details(self, idx): 120 | d = dnnlib.EasyDict() 121 | d.raw_idx = int(self._raw_idx[idx]) 122 | d.xflip = (int(self._xflip[idx]) != 0) 123 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 124 | return d 125 | 126 | @property 127 | def name(self): 128 | return self._name 129 | 130 | @property 131 | def image_shape(self): 132 | return list(self._raw_shape[1:]) 133 | 134 | @property 135 | def num_channels(self): 136 | assert len(self.image_shape) == 3 # CHW 137 | return self.image_shape[0] 138 | 139 | @property 140 | def resolution(self): 141 | assert len(self.image_shape) == 3 # CHW 142 | assert self.image_shape[1] == self.image_shape[2] 143 | return self.image_shape[1] 144 | 145 | @property 146 | def label_shape(self): 147 | if self._label_shape is None: 148 | raw_labels = self._get_raw_labels() 149 | if raw_labels.dtype == np.int64: 150 | self._label_shape = [int(np.max(raw_labels)) + 1] 151 | else: 152 | self._label_shape = raw_labels.shape[1:] 153 | return list(self._label_shape) 154 | 155 | @property 156 | def label_dim(self): 157 | assert 
len(self.label_shape) == 1 158 | return self.label_shape[0] 159 | 160 | @property 161 | def has_labels(self): 162 | return any(x != 0 for x in self.label_shape) 163 | 164 | @property 165 | def has_onehot_labels(self): 166 | return self._get_raw_labels().dtype == np.int64 167 | 168 | #---------------------------------------------------------------------------- 169 | 170 | class ImageFolderDataset(Dataset): 171 | def __init__(self, 172 | path, # Path to directory or zip. 173 | resolution = None, # Ensure specific resolution, None = highest available. 174 | **super_kwargs, # Additional arguments for the Dataset base class. 175 | ): 176 | self._path = path 177 | self._zipfile = None 178 | 179 | if os.path.isdir(self._path): 180 | self._type = 'dir' 181 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 182 | elif self._file_ext(self._path) == '.zip': 183 | self._type = 'zip' 184 | self._all_fnames = set(self._get_zipfile().namelist()) 185 | else: 186 | raise IOError('Path must point to a directory or zip') 187 | 188 | PIL.Image.init() 189 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 190 | if len(self._image_fnames) == 0: 191 | raise IOError('No image files found in the specified path') 192 | 193 | name = os.path.splitext(os.path.basename(self._path))[0] 194 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 195 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 196 | raise IOError('Image files do not match the specified resolution') 197 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 198 | 199 | @staticmethod 200 | def _file_ext(fname): 201 | return os.path.splitext(fname)[1].lower() 202 | 203 | def _get_zipfile(self): 204 | assert self._type == 'zip' 205 | if self._zipfile is None: 206 | self._zipfile = zipfile.ZipFile(self._path) 207 | return self._zipfile 208 | 209 | def _open_file(self, fname): 210 | if self._type == 'dir': 211 | return open(os.path.join(self._path, fname), 'rb') 212 | if self._type == 'zip': 213 | return self._get_zipfile().open(fname, 'r') 214 | return None 215 | 216 | def close(self): 217 | try: 218 | if self._zipfile is not None: 219 | self._zipfile.close() 220 | finally: 221 | self._zipfile = None 222 | 223 | def __getstate__(self): 224 | return dict(super().__getstate__(), _zipfile=None) 225 | 226 | def _load_raw_image(self, raw_idx): 227 | fname = self._image_fnames[raw_idx] 228 | with self._open_file(fname) as f: 229 | if pyspng is not None and self._file_ext(fname) == '.png': 230 | image = pyspng.load(f.read()) 231 | else: 232 | image = np.array(PIL.Image.open(f)) 233 | if image.ndim == 2: 234 | image = image[:, :, np.newaxis] # HW => HWC 235 | image = image.transpose(2, 0, 1) # HWC => CHW 236 | return image 237 | 238 | def _load_raw_labels(self): 239 | fname = 'dataset.json' 240 | if fname not in self._all_fnames: 241 | return None 242 | with self._open_file(fname) as f: 243 | labels = json.load(f)['labels'] 244 | if labels is None: 245 | return None 246 | labels = dict(labels) 247 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 248 | labels = np.array(labels) 249 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 250 | return labels 251 | 252 | #---------------------------------------------------------------------------- 253 | 
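For orientation, `Dataset` above handles the index bookkeeping (`max_size`, `xflip`, dynamic length, label caching) while `ImageFolderDataset` does the actual decoding from a directory or zip archive. The following minimal sketch shows how the class is typically consumed through a standard PyTorch `DataLoader`; the archive path and loader settings are hypothetical examples, not files shipped with the repository:

```python
# Usage sketch for ImageFolderDataset; the path is a hypothetical example
# (an archive prepared as described under Data Preparation).
import torch
from training.dataset import ImageFolderDataset

dataset = ImageFolderDataset(path='./data/pokemon256.zip', use_labels=False, xflip=True)
print(len(dataset), dataset.image_shape)  # xflip doubles the length; image_shape is CHW, e.g. [3, 256, 256]

loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
images, labels = next(iter(loader))       # uint8 images in [0, 255]; labels are empty when use_labels=False
images = images.to(torch.float32) / 127.5 - 1  # scale to the [-1, 1] range expected by the networks
```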
-------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | # 9 | # modified by Axel Sauer for "Projected GANs Converge Faster" 10 | # 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from torch_utils import training_stats 15 | from torch_utils.ops import upfirdn2d 16 | 17 | 18 | class Loss: 19 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass 20 | raise NotImplementedError() 21 | 22 | 23 | class ProjectedGANLoss(Loss): 24 | def __init__(self, device, G, D, G_ema, blur_init_sigma=0, blur_fade_kimg=0, **kwargs): 25 | super().__init__() 26 | self.device = device 27 | self.G = G 28 | self.G_ema = G_ema 29 | self.D = D 30 | self.blur_init_sigma = blur_init_sigma 31 | self.blur_fade_kimg = blur_fade_kimg 32 | 33 | def run_G(self, z, c, update_emas=False): 34 | ws = self.G.mapping(z, c, update_emas=update_emas) 35 | img = self.G.synthesis(ws, c, update_emas=False) 36 | return img 37 | 38 | def run_D(self, img, c, blur_sigma=0, update_emas=False): 39 | blur_size = np.floor(blur_sigma * 3) 40 | if blur_size > 0: 41 | with torch.autograd.profiler.record_function('blur'): 42 | f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2() 43 | img = upfirdn2d.filter2d(img, f / f.sum()) 44 | 45 | logits = self.D(img, c) 46 | return logits 47 | 48 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): 49 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] 50 | do_Gmain = (phase in ['Gmain', 'Gboth']) 51 | do_Dmain = (phase in ['Dmain', 'Dboth']) 52 | if phase in ['Dreg', 'Greg']: return # no regularization needed for PG 53 | 54 | # blurring schedule 55 | blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 1 else 0 56 | 57 | if do_Gmain: 58 | 59 | # Gmain: Maximize logits for generated images. 60 | with torch.autograd.profiler.record_function('Gmain_forward'): 61 | gen_img = self.run_G(gen_z, gen_c) 62 | gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) 63 | loss_Gmain = (-gen_logits).mean() 64 | 65 | # Logging 66 | training_stats.report('Loss/scores/fake', gen_logits) 67 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 68 | training_stats.report('Loss/G/loss', loss_Gmain) 69 | 70 | with torch.autograd.profiler.record_function('Gmain_backward'): 71 | loss_Gmain.backward() 72 | 73 | if do_Dmain: 74 | 75 | # Dmain: Minimize logits for generated images. 
76 | with torch.autograd.profiler.record_function('Dgen_forward'): 77 | gen_img = self.run_G(gen_z, gen_c, update_emas=True) 78 | gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) 79 | loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean() 80 | 81 | # Logging 82 | training_stats.report('Loss/scores/fake', gen_logits) 83 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 84 | 85 | with torch.autograd.profiler.record_function('Dgen_backward'): 86 | loss_Dgen.backward() 87 | 88 | # Dmain: Maximize logits for real images. 89 | with torch.autograd.profiler.record_function('Dreal_forward'): 90 | real_img_tmp = real_img.detach().requires_grad_(False) 91 | real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma) 92 | loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean() 93 | 94 | # Logging 95 | training_stats.report('Loss/scores/real', real_logits) 96 | training_stats.report('Loss/signs/real', real_logits.sign()) 97 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) 98 | 99 | with torch.autograd.profiler.record_function('Dreal_backward'): 100 | loss_Dreal.backward() 101 | --------------------------------------------------------------------------------
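A closing note on `ProjectedGANLoss`: the objective is the standard hinge loss (`relu(1 + D(fake))` and `relu(1 - D(real))` for the discriminator, `-D(fake)` for the generator), and the only schedule is the input blur in `run_D`, where `blur_sigma` fades linearly from `blur_init_sigma` to zero over the first `blur_fade_kimg` thousand images and the kernel is a base-2 Gaussian `exp2(-(x/sigma)^2)` of radius `3*sigma`. Below is a small self-contained sketch of that schedule; the config values are illustrative assumptions, not defaults from this repository:

```python
# Sketch of the discriminator input-blur schedule from ProjectedGANLoss.run_D.
import numpy as np
import torch

blur_init_sigma, blur_fade_kimg = 2.0, 200  # illustrative values, not repo defaults

def blur_sigma_at(cur_nimg):
    # Linear fade from blur_init_sigma down to 0 over blur_fade_kimg * 1000 images.
    return max(1 - cur_nimg / (blur_fade_kimg * 1e3), 0) * blur_init_sigma

for nimg in [0, 100_000, 200_000]:
    sigma = blur_sigma_at(nimg)
    blur_size = int(np.floor(sigma * 3))
    if blur_size > 0:
        x = torch.arange(-blur_size, blur_size + 1).float()
        f = x.div(sigma).square().neg().exp2()  # base-2 Gaussian, as in run_D
        f = f / f.sum()                         # normalized 1D kernel for upfirdn2d.filter2d
    print(f'nimg={nimg}: sigma={sigma:.2f}, kernel radius={blur_size}')
```

Blurring both real and fake discriminator inputs early in training is a stabilization trick: it hides high-frequency detail until the generator has learned the coarse structure, then fades out.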