├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── CoModGAN.ipynb ├── Dockerfile ├── LICENSE.txt ├── README.md ├── calc_metrics.py ├── dataset_tool.py ├── dnnlib ├── __init__.py └── util.py ├── docker_run.sh ├── docs ├── dataset-tool-help.txt ├── license.html ├── stylegan2-ada-teaser-1024x252.png ├── stylegan2-ada-training-curves.png └── train-help.txt ├── generate.py ├── legacy.py ├── metrics ├── __init__.py ├── frechet_inception_distance.py ├── inception_score.py ├── kernel_inception_distance.py ├── metric_main.py ├── metric_utils.py ├── perceptual_path_length.py └── precision_recall.py ├── projector.py ├── run.sh ├── style_mixing.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── train.py └── training ├── __init__.py ├── augment.py ├── countless ├── .gitignore ├── README.md ├── __init__.py ├── countless2d.py ├── countless3d.py ├── images │ ├── gcim.jpg │ ├── gray_segmentation.png │ ├── segmentation.png │ └── sparse.png ├── memprof │ ├── countless2d_gcim_N_1000.png │ ├── countless2d_quick_gcim_N_1000.png │ ├── countless3d.png │ ├── countless3d_dynamic.png │ ├── countless3d_dynamic_generalized.png │ └── countless3d_generalized.png ├── requirements.txt └── test.py ├── dataset.py ├── loss.py ├── masks.py ├── networks.py ├── networks_6channel.py ├── seg_mask.py ├── training_loop.py └── utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 
12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. In '...' directory, run command '...' 16 | 2. See error (copy&paste full log, including exceptions and **stacktraces**). 17 | 18 | Please copy&paste text instead of screenshots for better searchability. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Linux Ubuntu 20.04, Windows 10] 28 | - PyTorch version (e.g., pytorch 1.7.1) 29 | - CUDA toolkit version (e.g., CUDA 11.0) 30 | - NVIDIA driver version 31 | - GPU [e.g., Titan V, RTX 3090] 32 | - Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:20.12-py3) 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .cache/ 3 | generate1.py 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | FROM nvcr.io/nvidia/pytorch:20.12-py3 10 | 11 | ENV PYTHONDONTWRITEBYTECODE 1 12 | ENV PYTHONUNBUFFERED 1 13 | 14 | RUN pip install imageio-ffmpeg==0.4.3 pyspng==0.1.0 15 | 16 | WORKDIR /workspace 17 | 18 | # Unset TORCH_CUDA_ARCH_LIST and exec. This makes pytorch run-time 19 | # extension builds significantly faster as we only compile for the 20 | # currently active GPU configuration. 21 | RUN (printf '#!/bin/bash\nunset TORCH_CUDA_ARCH_LIST\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 22 | ENTRYPOINT ["/entry.sh"] 23 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. 
This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CoModGAN Pytorch 2 | 3 | * Built on top of [StyleGANv2 ADA implementation](https://github.com/NVlabs/stylegan2-ada-pytorch) 4 | 5 | **Compatibility** 6 | * Compatible with old network pickles created using the TensorFlow version. 
7 | * New ZIP/PNG based dataset format for maximal interoperability with existing 3rd party tools. 8 | * TFRecords datasets are no longer supported — they need to be converted to the new format. 9 | * New JSON-based format for logs, metrics, and training curves. 10 | * Training curves are also exported in the old TFEvents format if TensorBoard is installed. 11 | * Command line syntax is mostly unchanged, with a few exceptions (e.g., `dataset_tool.py`). 12 | * Comparison methods are not supported (`--cmethod`, `--dcap`, `--cfg=cifarbaseline`, `--aug=adarv`) 13 | * **Truncation is now disabled by default.** 14 | 15 | ## Data repository 16 | 17 | | Path | Description 18 | | :--- | :---------- 19 | | [stylegan2-ada-pytorch](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/) | Main directory hosted on Amazon S3 20 | |   ├  [ada-paper.pdf](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/ada-paper.pdf) | Paper PDF 21 | |   ├  [images](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/images/) | Curated example images produced using the pre-trained models 22 | |   ├  [videos](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/videos/) | Curated example interpolation videos 23 | |   └  [pretrained](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/) | Pre-trained models 24 | |     ├  ffhq.pkl | FFHQ at 1024x1024, trained using original StyleGAN2 25 | |     ├  brecahad.pkl | BreCaHAD at 512x512, trained from scratch using ADA 26 | |     ├  [paper-fig7c-training-set-sweeps](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/paper-fig7c-training-set-sweeps/) | Models used in Fig.7c (sweep over training set size) 27 | |     ├  [paper-fig11a-small-datasets](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/paper-fig11a-small-datasets/) | Models used in Fig.11a (small datasets & transfer learning) 28 | |     ├  
[paper-fig11b-cifar10](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/paper-fig11b-cifar10/) | Models used in Fig.11b (CIFAR-10) 29 | |     ├  [transfer-learning-source-nets](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/) | Models used as starting point for transfer learning 30 | |     └  [metrics](https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/) | Feature detectors used by the quality metrics 31 | 32 | 33 | ## Getting started 34 | 35 | Pre-trained networks are stored as `*.pkl` files that can be referenced using local filenames or URLs: 36 | 37 | ```.bash 38 | # Generate curated MetFaces images without truncation (Fig.10 left) 39 | python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \ 40 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 41 | 42 | # Generate uncurated MetFaces images with truncation (Fig.12 upper left) 43 | python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \ 44 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 45 | 46 | # Generate class conditional CIFAR-10 images (Fig.17 left, Car) 47 | python generate.py --outdir=out --seeds=0-35 --class=1 \ 48 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl 49 | 50 | # Style mixing example 51 | python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \ 52 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 53 | ``` 54 | 55 | Outputs from the above commands are placed under `out/*.png`, controlled by `--outdir`. Downloaded network pickles are cached under `$HOME/.cache/dnnlib`, which can be overridden by setting the `DNNLIB_CACHE_DIR` environment variable. The default PyTorch extension build directory is `$HOME/.cache/torch_extensions`, which can be overridden by setting `TORCH_EXTENSIONS_DIR`. 
56 | 57 | ## Projecting images to latent space 58 | 59 | To find the matching latent vector for a given image file, run: 60 | 61 | ```.bash 62 | python projector.py --outdir=out --target=~/mytargetimg.png \ 63 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl 64 | ``` 65 | 66 | For optimal results, the target image should be cropped and aligned similar to the [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset). The above command saves the projection target `out/target.png`, result `out/proj.png`, latent vector `out/projected_w.npz`, and progression video `out/proj.mp4`. You can render the resulting latent vector by specifying `--projected_w` for `generate.py`: 67 | 68 | ```.bash 69 | python generate.py --outdir=out --projected_w=out/projected_w.npz \ 70 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl 71 | ``` 72 | 73 | ## Using networks from Python 74 | 75 | You can use pre-trained networks in your own Python code as follows: 76 | 77 | ```.python 78 | with open('ffhq.pkl', 'rb') as f: 79 | G = pickle.load(f)['G_ema'].cuda() # torch.nn.Module 80 | z = torch.randn([1, G.z_dim]).cuda() # latent codes 81 | c = None # class labels (not used in this example) 82 | img = G(z, c) # NCHW, float32, dynamic range [-1, +1] 83 | ``` 84 | 85 | The above code requires `torch_utils` and `dnnlib` to be accessible via `PYTHONPATH`. It does not need source code for the networks themselves — their class definitions are loaded from the pickle via `torch_utils.persistence`. 86 | 87 | The pickle contains three networks. `'G'` and `'D'` are instantaneous snapshots taken during training, and `'G_ema'` represents a moving average of the generator weights over several training steps. The networks are regular instances of `torch.nn.Module`, with all of their parameters and buffers placed on the CPU at import and gradient computation disabled by default. 
88 | 89 | The generator consists of two submodules, `G.mapping` and `G.synthesis`, that can be executed separately. They also support various additional options: 90 | 91 | ```.python 92 | w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8) 93 | img = G.synthesis(w, noise_mode='const', force_fp32=True) 94 | ``` 95 | 96 | Please refer to [`generate.py`](./generate.py), [`style_mixing.py`](./style_mixing.py), and [`projector.py`](./projector.py) for further examples. 97 | 98 | ## Preparing datasets 99 | 100 | Datasets are stored as uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. 101 | 102 | Custom datasets can be created from a folder containing images; see [`python dataset_tool.py --help`](./docs/dataset-tool-help.txt) for more information. Alternatively, the folder can also be used directly as a dataset, without running it through `dataset_tool.py` first, but doing so may lead to suboptimal performance. 103 | 104 | Legacy TFRecords datasets are not supported — see below for instructions on how to convert them. 105 | 106 | **FFHQ**: 107 | 108 | Step 1: Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as TFRecords. 109 | 110 | Step 2: Extract images from TFRecords using `dataset_tool.py` from the [TensorFlow version of StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada/): 111 | 112 | ```.bash 113 | # Using dataset_tool.py from TensorFlow version at 114 | # https://github.com/NVlabs/stylegan2-ada/ 115 | python ../stylegan2-ada/dataset_tool.py unpack \ 116 | --tfrecord_dir=~/ffhq-dataset/tfrecords/ffhq --output_dir=/tmp/ffhq-unpacked 117 | ``` 118 | 119 | Step 3: Create ZIP archive using `dataset_tool.py` from this repository: 120 | 121 | ```.bash 122 | # Original 1024x1024 resolution. 123 | python dataset_tool.py --source=/tmp/ffhq-unpacked --dest=~/datasets/ffhq.zip 124 | 125 | # Scaled down 256x256 resolution. 
126 | python dataset_tool.py --source=/tmp/ffhq-unpacked --dest=~/datasets/ffhq256x256.zip \ 127 | --width=256 --height=256 128 | ``` 129 | 130 | ## Training new networks 131 | 132 | In its most basic form, training new networks boils down to: 133 | 134 | ```.bash 135 | python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1 --dry-run 136 | python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1 137 | ``` 138 | 139 | The first command is optional; it validates the arguments, prints out the training configuration, and exits. The second command kicks off the actual training. 140 | 141 | In this example, the results are saved to a newly created directory `~/training-runs/<ID>-mydataset-auto1`, controlled by `--outdir`. The training exports network pickles (`network-snapshot-<INT>.pkl`) and example images (`fakes<INT>.png`) at regular intervals (controlled by `--snap`). For each pickle, it also evaluates FID (controlled by `--metrics`) and logs the resulting scores in `metric-fid50k_full.jsonl` (as well as TFEvents if TensorBoard is installed). 142 | 143 | The name of the output directory reflects the training configuration. For example, `00000-mydataset-auto1` indicates that the *base configuration* was `auto1`, meaning that the hyperparameters were selected automatically for training on one GPU. The base configuration is controlled by `--cfg`: 144 | 145 | | Base config | Description 146 | | :-------------------- | :---------- 147 | | `auto` (default) | Automatically select reasonable defaults based on resolution and GPU count. Serves as a good starting point for new datasets but does not necessarily lead to optimal results. 148 | | `stylegan2` | Reproduce results for StyleGAN2 config F at 1024x1024 using 1, 2, 4, or 8 GPUs. 149 | | `paper256` | Reproduce results for FFHQ and LSUN Cat at 256x256 using 1, 2, 4, or 8 GPUs. 150 | | `paper512` | Reproduce results for BreCaHAD and AFHQ at 512x512 using 1, 2, 4, or 8 GPUs.
151 | | `paper1024` | Reproduce results for MetFaces at 1024x1024 using 1, 2, 4, or 8 GPUs. 152 | | `cifar` | Reproduce results for CIFAR-10 (tuned configuration) using 1 or 2 GPUs. 153 | 154 | The training configuration can be further customized with additional command line options: 155 | 156 | * `--aug=noaug` disables ADA. 157 | * `--cond=1` enables class-conditional training (requires a dataset with labels). 158 | * `--mirror=1` amplifies the dataset with x-flips. Often beneficial, even with ADA. 159 | * `--resume=ffhq1024 --snap=10` performs transfer learning from FFHQ trained at 1024x1024. 160 | * `--resume=~/training-runs/<NAME>/network-snapshot-<INT>.pkl` resumes a previous training run. 161 | * `--gamma=10` overrides R1 gamma. We recommend trying a couple of different values for each new dataset. 162 | * `--aug=ada --target=0.7` adjusts ADA target value (default: 0.6). 163 | * `--augpipe=blit` enables pixel blitting but disables all other augmentations. 164 | * `--augpipe=bgcfnc` enables all available augmentations (blit, geom, color, filter, noise, cutout). 165 | 166 | Please refer to [`python train.py --help`](./docs/train-help.txt) for the full list. 167 | 168 | 169 | References: 170 | 1. [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500), Heusel et al. 2017 171 | 2. [Demystifying MMD GANs](https://arxiv.org/abs/1801.01401), Bińkowski et al. 2018 172 | 3. [Improved Precision and Recall Metric for Assessing Generative Models](https://arxiv.org/abs/1904.06991), Kynkäänniemi et al. 2019 173 | 4. [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498), Salimans et al. 2016 174 | 5. [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948), Karras et al.
2018 175 | 176 | ## Citation 177 | 178 | ``` 179 | @inproceedings{Karras2020ada, 180 | title = {Training Generative Adversarial Networks with Limited Data}, 181 | author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila}, 182 | booktitle = {Proc. NeurIPS}, 183 | year = {2020} 184 | } 185 | 186 | @inproceedings{zhao2021comodgan, 187 | title={Large Scale Image Completion via Co-Modulated Generative Adversarial Networks}, 188 | author={Zhao, Shengyu and Cui, Jonathan and Sheng, Yilun and Dong, Yue and Liang, Xiao and Chang, Eric I and Xu, Yan}, 189 | booktitle={International Conference on Learning Representations (ICLR)}, 190 | year={2021} 191 | } 192 | ``` 193 | -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | import os 12 | import click 13 | import json 14 | import tempfile 15 | import copy 16 | import torch 17 | import dnnlib 18 | 19 | import legacy 20 | from metrics import metric_main 21 | from metrics import metric_utils 22 | from torch_utils import training_stats 23 | from torch_utils import custom_ops 24 | from torch_utils import misc 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def subprocess_fn(rank, args, temp_dir): 29 | dnnlib.util.Logger(should_flush=True) 30 | 31 | # Init torch.distributed. 32 | if args.num_gpus > 1: 33 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 34 | if os.name == 'nt': 35 | init_method = 'file:///' + init_file.replace('\\', '/') 36 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 37 | else: 38 | init_method = f'file://{init_file}' 39 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 40 | 41 | # Init torch_utils. 42 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 43 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 44 | if rank != 0 or not args.verbose: 45 | custom_ops.verbosity = 'none' 46 | 47 | # Select this rank's CUDA device, set cuDNN benchmark on and TF32 off, copy G to the device, then print network summary. 48 | device = torch.device('cuda', rank) 49 | torch.backends.cudnn.benchmark = True 50 | torch.backends.cuda.matmul.allow_tf32 = False 51 | torch.backends.cudnn.allow_tf32 = False 52 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 53 | if rank == 0 and args.verbose: 54 | z = torch.empty([1, G.z_dim], device=device) 55 | c = torch.empty([1, G.c_dim], device=device) 56 | misc.print_module_summary(G, [z, c]) 57 | 58 | # Calculate each metric.
59 | for metric in args.metrics: 60 | if rank == 0 and args.verbose: 61 | print(f'Calculating {metric}...') 62 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 63 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, 64 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) 65 | if rank == 0: 66 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 67 | if rank == 0 and args.verbose: 68 | print() 69 | 70 | # Done. 71 | if rank == 0 and args.verbose: 72 | print('Exiting...') 73 | 74 | #---------------------------------------------------------------------------- 75 | 76 | class CommaSeparatedList(click.ParamType): 77 | name = 'list' 78 | 79 | def convert(self, value, param, ctx): 80 | _ = param, ctx 81 | if value is None or value.lower() == 'none' or value == '': 82 | return [] 83 | return value.split(',') 84 | 85 | #---------------------------------------------------------------------------- 86 | 87 | @click.command() 88 | @click.pass_context 89 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) 90 | @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True) 91 | @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH') 92 | @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL') 93 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) 94 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) 95 | 96 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): 97 | """Calculate quality metrics for previous training run
or pretrained network pickle. 98 | 99 | Examples: 100 | 101 | \b 102 | # Previous training run: look up options automatically, save result to JSONL file. 103 | python calc_metrics.py --metrics=pr50k3_full \\ 104 | --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl 105 | 106 | \b 107 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout. 108 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\ 109 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl 110 | 111 | Available metrics: 112 | 113 | \b 114 | ADA paper: 115 | fid50k_full Frechet inception distance against the full dataset. 116 | kid50k_full Kernel inception distance against the full dataset. 117 | pr50k3_full Precision and recall against the full dataset. 118 | is50k Inception score for CIFAR-10. 119 | 120 | \b 121 | StyleGAN and StyleGAN2 papers: 122 | fid50k Frechet inception distance against 50k real images. 123 | kid50k Kernel inception distance against 50k real images. 124 | pr50k3 Precision and recall against 50k real images. 125 | ppl2_wend Perceptual path length in W at path endpoints against full image. 126 | ppl_zfull Perceptual path length in Z for full paths against cropped image. 127 | ppl_wfull Perceptual path length in W for full paths against cropped image. 128 | ppl_zend Perceptual path length in Z at path endpoints against cropped image. 129 | ppl_wend Perceptual path length in W at path endpoints against cropped image. 130 | """ 131 | dnnlib.util.Logger(should_flush=True) 132 | 133 | # Validate arguments.
134 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) 135 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 136 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 137 | if not args.num_gpus >= 1: 138 | ctx.fail('--gpus must be at least 1') 139 | 140 | # Load network. 141 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 142 | ctx.fail('--network must point to a file or URL') 143 | if args.verbose: 144 | print(f'Loading network from "{network_pkl}"...') 145 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 146 | network_dict = legacy.load_network_pkl(f) 147 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 148 | 149 | # Initialize dataset options. 150 | if data is not None: 151 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) 152 | elif network_dict['training_set_kwargs'] is not None: 153 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 154 | else: 155 | ctx.fail('Could not look up dataset options; please specify --data') 156 | 157 | # Finalize dataset options. 158 | args.dataset_kwargs.resolution = args.G.img_resolution 159 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 160 | if mirror is not None: 161 | args.dataset_kwargs.xflip = mirror 162 | 163 | # Print dataset options. 164 | if args.verbose: 165 | print('Dataset options:') 166 | print(json.dumps(args.dataset_kwargs, indent=2)) 167 | 168 | # Locate run dir: only set when network_pkl is a local file whose directory also contains training_options.json; otherwise stays None. 169 | args.run_dir = None 170 | if os.path.isfile(network_pkl): 171 | pkl_dir = os.path.dirname(network_pkl) 172 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 173 | args.run_dir = pkl_dir 174 | 175 | # Launch processes: run in this process for a single GPU, otherwise spawn one process per GPU.
176 | if args.verbose: 177 | print('Launching processes...') 178 | torch.multiprocessing.set_start_method('spawn') 179 | with tempfile.TemporaryDirectory() as temp_dir: 180 | if args.num_gpus == 1: 181 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 182 | else: 183 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 184 | 185 | #---------------------------------------------------------------------------- 186 | 187 | if __name__ == "__main__": 188 | calc_metrics() # pylint: disable=no-value-for-parameter 189 | 190 | #---------------------------------------------------------------------------- 191 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /docker_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 | 11 | set -e 12 | 13 | # Wrapper script for setting up `docker run` to properly 14 | # cache downloaded files, custom extension builds and 15 | # mount the source directory into the container and make it 16 | # run as non-root user. 17 | # 18 | # Use it like: 19 | # 20 | # ./docker_run.sh python generate.py --help 21 | # 22 | # To override the default `sg2ada:latest` image, run: 23 | # 24 | # IMAGE=my_image:v1.0 ./docker_run.sh python generate.py --help 25 | # NOTE(review): `$@` is expanded unquoted in the `docker run` call below, so arguments containing spaces are word-split; `rest` is assigned but never used. 26 | 27 | rest=$@ 28 | 29 | IMAGE="${IMAGE:-sg2ada:latest}" 30 | 31 | CONTAINER_ID=$(docker inspect --format="{{.Id}}" ${IMAGE} 2> /dev/null) 32 | if [[ "${CONTAINER_ID}" ]]; then 33 | docker run --shm-size=2g --gpus all -it --rm -v `pwd`:/scratch --user $(id -u):$(id -g) \ 34 | --workdir=/scratch -e HOME=/scratch $IMAGE $@ 35 | else 36 | echo "Unknown container image: ${IMAGE}" 37 | exit 1 38 | fi 39 | -------------------------------------------------------------------------------- /docs/dataset-tool-help.txt: -------------------------------------------------------------------------------- 1 | Usage: dataset_tool.py [OPTIONS] 2 | 3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA 4 | PyTorch. 5 | 6 | The input dataset format is guessed from the --source argument: 7 | 8 | --source *_lmdb/ - Load LSUN dataset 9 | --source cifar-10-python.tar.gz - Load CIFAR-10 dataset 10 | --source path/ - Recursively load all images from path/ 11 | --source dataset.zip - Recursively load all images from dataset.zip 12 | 13 | The output dataset format can be either an image folder or a zip archive. 14 | Specifying the output format and path: 15 | 16 | --dest /path/to/dir - Save output files under /path/to/dir 17 | --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive 18 | 19 | Images within the dataset archive will be stored as uncompressed PNG.
20 | 21 | Image scale/crop and resolution requirements: 22 | 23 | Output images must be square-shaped and they must all have the same power- 24 | of-two dimensions. 25 | 26 | To scale arbitrary input image size to a specific width and height, use 27 | the --width and --height options. Output resolution will be either the 28 | original input resolution (if --width/--height was not specified) or the 29 | one specified with --width/height. 30 | 31 | Use the --transform=center-crop or --transform=center-crop-wide options to 32 | apply a center crop transform on the input image. These options should be 33 | used with the --width and --height options. For example: 34 | 35 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 36 | --transform=center-crop-wide --width 512 --height=384 37 | 38 | Options: 39 | --source PATH Directory or archive name for input dataset 40 | [required] 41 | --dest PATH Output directory or archive name for output 42 | dataset [required] 43 | --max-images INTEGER Output only up to `max-images` images 44 | --resize-filter [box|lanczos] Filter to use when resizing images for 45 | output resolution [default: lanczos] 46 | --transform [center-crop|center-crop-wide] 47 | Input crop/resize mode 48 | --width INTEGER Output width 49 | --height INTEGER Output height 50 | --help Show this message and exit. 51 | -------------------------------------------------------------------------------- /docs/license.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Nvidia Source Code License-NC 7 | 8 | 56 | 57 | 58 | 59 |

NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)

60 | 61 |
62 | 63 |

1. Definitions

64 | 65 |

“Licensor” means any person or entity that distributes its Work.

66 | 67 |

“Software” means the original work of authorship made available under 68 | this License.

69 | 70 |

“Work” means the Software and any additions to or derivative works of 71 | the Software that are made available under this License.

72 | 73 |

The terms “reproduce,” “reproduction,” “derivative works,” and 74 | “distribution” have the meaning as provided under U.S. copyright law; 75 | provided, however, that for the purposes of this License, derivative 76 | works shall not include works that remain separable from, or merely 77 | link (or bind by name) to the interfaces of, the Work.

78 | 79 |

Works, including the Software, are “made available” under this License 80 | by including in or with the Work either (a) a copyright notice 81 | referencing the applicability of this License to the Work, or (b) a 82 | copy of this License.

83 | 84 |

2. License Grants

85 | 86 |

2.1 Copyright Grant. Subject to the terms and conditions of this 87 | License, each Licensor grants to you a perpetual, worldwide, 88 | non-exclusive, royalty-free, copyright license to reproduce, 89 | prepare derivative works of, publicly display, publicly perform, 90 | sublicense and distribute its Work and any resulting derivative 91 | works in any form.

92 | 93 |

3. Limitations

94 | 95 |

3.1 Redistribution. You may reproduce or distribute the Work only 96 | if (a) you do so under this License, (b) you include a complete 97 | copy of this License with your distribution, and (c) you retain 98 | without modification any copyright, patent, trademark, or 99 | attribution notices that are present in the Work.

100 | 101 |

3.2 Derivative Works. You may specify that additional or different 102 | terms apply to the use, reproduction, and distribution of your 103 | derivative works of the Work (“Your Terms”) only if (a) Your Terms 104 | provide that the use limitation in Section 3.3 applies to your 105 | derivative works, and (b) you identify the specific derivative 106 | works that are subject to Your Terms. Notwithstanding Your Terms, 107 | this License (including the redistribution requirements in Section 108 | 3.1) will continue to apply to the Work itself.

109 | 110 |

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for 111 | use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work 112 | and any derivative works commercially. As used herein, “non-commercially” means for research or 113 | evaluation purposes only. 114 | 115 |

3.4 Patent Claims. If you bring or threaten to bring a patent claim 116 | against any Licensor (including any claim, cross-claim or 117 | counterclaim in a lawsuit) to enforce any patents that you allege 118 | are infringed by any Work, then your rights under this License from 119 | such Licensor (including the grant in Section 2.1) will terminate immediately. 120 | 121 |

3.5 Trademarks. This License does not grant any rights to use any 122 | Licensor’s or its affiliates’ names, logos, or trademarks, except 123 | as necessary to reproduce the notices described in this License.

124 | 125 |

3.6 Termination. If you violate any term of this License, then your 126 | rights under this License (including the grant in Section 2.1) 127 | will terminate immediately.

128 | 129 |

4. Disclaimer of Warranty.

130 | 131 |

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY 132 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 133 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 134 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 135 | THIS LICENSE.

136 | 137 |

5. Limitation of Liability.

138 | 139 |

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 140 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 141 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 142 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 143 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 144 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 145 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 146 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 147 | THE POSSIBILITY OF SUCH DAMAGES.

148 | 149 |
150 |
151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /docs/stylegan2-ada-teaser-1024x252.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/docs/stylegan2-ada-teaser-1024x252.png -------------------------------------------------------------------------------- /docs/stylegan2-ada-training-curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/docs/stylegan2-ada-training-curves.png -------------------------------------------------------------------------------- /docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train a GAN using the techniques described in the paper "Training 4 | Generative Adversarial Networks with Limited Data". 5 | 6 | Examples: 7 | 8 | # Train with custom images using 1 GPU. 9 | python train.py --outdir=~/training-runs --data=~/my-image-folder 10 | 11 | # Train class-conditional CIFAR-10 using 2 GPUs. 12 | python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \ 13 | --gpus=2 --cfg=cifar --cond=1 14 | 15 | # Transfer learn MetFaces from FFHQ using 4 GPUs. 16 | python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \ 17 | --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10 18 | 19 | # Reproduce original StyleGAN2 config F. 20 | python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \ 21 | --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug 22 | 23 | Base configs (--cfg): 24 | auto Automatically select reasonable defaults based on resolution 25 | and GPU count. Good starting point for new datasets. 
26 | stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024. 27 | paper256 Reproduce results for FFHQ and LSUN Cat at 256x256. 28 | paper512 Reproduce results for BreCaHAD and AFHQ at 512x512. 29 | paper1024 Reproduce results for MetFaces at 1024x1024. 30 | cifar Reproduce results for CIFAR-10 at 32x32. 31 | 32 | Transfer learning source networks (--resume): 33 | ffhq256 FFHQ trained at 256x256 resolution. 34 | ffhq512 FFHQ trained at 512x512 resolution. 35 | ffhq1024 FFHQ trained at 1024x1024 resolution. 36 | celebahq256 CelebA-HQ trained at 256x256 resolution. 37 | lsundog256 LSUN Dog trained at 256x256 resolution. 38 | Custom network pickle. 39 | 40 | Options: 41 | --outdir DIR Where to save the results [required] 42 | --gpus INT Number of GPUs to use [default: 1] 43 | --snap INT Snapshot interval [default: 50 ticks] 44 | --metrics LIST Comma-separated list or "none" [default: 45 | fid50k_full] 46 | --seed INT Random seed [default: 0] 47 | -n, --dry-run Print training options and exit 48 | --data PATH Training data (directory or zip) [required] 49 | --cond BOOL Train conditional model based on dataset 50 | labels [default: false] 51 | --subset INT Train with only N images [default: all] 52 | --mirror BOOL Enable dataset x-flips [default: false] 53 | --cfg [auto|stylegan2|paper256|paper512|paper1024|cifar] 54 | Base config [default: auto] 55 | --gamma FLOAT Override R1 gamma 56 | --kimg INT Override training duration 57 | --batch INT Override batch size 58 | --aug [noaug|ada|fixed] Augmentation mode [default: ada] 59 | --p FLOAT Augmentation probability for --aug=fixed 60 | --target FLOAT ADA target value for --aug=ada 61 | --augpipe [blit|geom|color|filter|noise|cutout|bg|bgc|bgcf|bgcfn|bgcfnc] 62 | Augmentation pipeline [default: bgc] 63 | --resume PKL Resume training [default: noresume] 64 | --freezed INT Freeze-D [default: 0 layers] 65 | --fp32 BOOL Disable mixed-precision training 66 | --nhwc BOOL Use NHWC memory format with FP16 67 | --nobench BOOL 
Disable cuDNN benchmarking 68 | --allow-tf32 BOOL Allow PyTorch to use TF32 internally 69 | --workers INT Override number of DataLoader workers 70 | --help Show this message and exit. 71 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def num_range(s: str) -> List[int]: 26 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' 27 | 28 | range_re = re.compile(r'^(\d+)-(\d+)$') 29 | m = range_re.match(s) 30 | if m: 31 | return list(range(int(m.group(1)), int(m.group(2))+1)) 32 | vals = s.split(',') 33 | return [int(x) for x in vals] 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | @click.command() 38 | @click.pass_context 39 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 40 | @click.option('--seeds', type=num_range, help='List of random seeds') 41 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 42 | @click.option('--class', 
'class_idx', type=int, help='Class label (unconditional if not specified)') 43 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 44 | @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE') 45 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 46 | @click.option('--inpdir', help='Where to take the input images from', type=str, required=True, metavar='DIR') 47 | def generate_images( 48 | ctx: click.Context, 49 | network_pkl: str, 50 | seeds: Optional[List[int]], 51 | truncation_psi: float, 52 | noise_mode: str, 53 | inpdir: str, 54 | outdir: str, 55 | class_idx: Optional[int], 56 | projected_w: Optional[str] 57 | ): 58 | """Generate images using pretrained network pickle. 59 | 60 | Examples: 61 | 62 | \b 63 | # Generate curated MetFaces images without truncation (Fig.10 left) 64 | python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\ 65 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 66 | 67 | \b 68 | # Generate uncurated MetFaces images with truncation (Fig.12 upper left) 69 | python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 70 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 71 | 72 | \b 73 | # Generate class conditional CIFAR-10 images (Fig.17 left, Car) 74 | python generate.py --outdir=out --seeds=0-35 --class=1 \\ 75 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl 76 | 77 | \b 78 | # Render an image from projected W 79 | python generate.py --outdir=out --projected_w=projected_w.npz \\ 80 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 81 | """ 82 | 83 | print('Loading networks from "%s"...' 
% network_pkl) 84 | device = torch.device('cuda') 85 | with dnnlib.util.open_url(network_pkl) as f: 86 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 87 | 88 | os.makedirs(outdir, exist_ok=True) 89 | 90 | # Synthesize the result of a W projection. 91 | if projected_w is not None: 92 | if seeds is not None: 93 | print ('warn: --seeds is ignored when using --projected-w') 94 | print(f'Generating images from projected W "{projected_w}"') 95 | ws = np.load(projected_w)['w'] 96 | ws = torch.tensor(ws, device=device) # pylint: disable=not-callable 97 | assert ws.shape[1:] == (G.num_ws, G.w_dim) 98 | for idx, w in enumerate(ws): 99 | img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode) 100 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 101 | img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png') 102 | return 103 | 104 | if seeds is None: 105 | ctx.fail('--seeds option is required when not using --projected-w') 106 | 107 | # Labels. 108 | label = torch.zeros([1, G.c_dim], device=device) 109 | if G.c_dim != 0: 110 | if class_idx is None: 111 | ctx.fail('Must specify class label with --class when using a conditional network') 112 | label[:, class_idx] = 1 113 | else: 114 | if class_idx is not None: 115 | print ('warn: --class=lbl ignored when running on an unconditional network') 116 | 117 | # Generate images. 118 | imgdir = inpdir + '/images/' 119 | maskdir = inpdir + '/masks/' 120 | inps = os.listdir(imgdir) 121 | seed = seeds[0] 122 | for i, inp in enumerate(inps): 123 | print('Generating image for seed %d (%d) ...' 
% (seed, len(seeds))) 124 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 125 | inp_img = np.transpose(np.array(PIL.Image.open(imgdir + inp)), (2, 0, 1)) 126 | print(inp_img.shape) 127 | inp_img = (torch.from_numpy(inp_img).to(torch.float32) / 127.5 - 1).to(device) 128 | inp_img = inp_img.unsqueeze(0) 129 | inp_mask = np.transpose(np.array(PIL.Image.open(maskdir + inp)), (2, 0, 1)) 130 | print(inp_mask.shape) 131 | inp_mask = (torch.from_numpy(inp_mask).to(torch.float32) / 255.).to(device) 132 | inp_mask = inp_mask.unsqueeze(0) 133 | img = G(z, label, image_in=inp_img, mask_in=inp_mask, truncation_psi=truncation_psi, noise_mode=noise_mode) 134 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 135 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/'+ inp) 136 | 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | if __name__ == "__main__": 141 | generate_images() # pylint: disable=no-value-for-parameter 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 28 | 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . 
import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". 
Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /metrics/metric_main.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import time 11 | import json 12 | import torch 13 | import dnnlib 14 | 15 | from . import metric_utils 16 | from . import frechet_inception_distance 17 | from . import kernel_inception_distance 18 | from . import precision_recall 19 | from . import perceptual_path_length 20 | from . import inception_score 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | _metric_dict = dict() # name => fn 25 | 26 | def register_metric(fn): 27 | assert callable(fn) 28 | _metric_dict[fn.__name__] = fn 29 | return fn 30 | 31 | def is_valid_metric(metric): 32 | return metric in _metric_dict 33 | 34 | def list_valid_metrics(): 35 | return list(_metric_dict.keys()) 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 40 | assert is_valid_metric(metric) 41 | opts = metric_utils.MetricOptions(**kwargs) 42 | 43 | # Calculate. 44 | start_time = time.time() 45 | results = _metric_dict[metric](opts) 46 | total_time = time.time() - start_time 47 | 48 | # Broadcast results. 
49 | for key, value in list(results.items()): 50 | if opts.num_gpus > 1: 51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 52 | torch.distributed.broadcast(tensor=value, src=0) 53 | value = float(value.cpu()) 54 | results[key] = value 55 | 56 | # Decorate with metadata. 57 | return dnnlib.EasyDict( 58 | results = dnnlib.EasyDict(results), 59 | metric = metric, 60 | total_time = total_time, 61 | total_time_str = dnnlib.util.format_time(total_time), 62 | num_gpus = opts.num_gpus, 63 | ) 64 | 65 | #---------------------------------------------------------------------------- 66 | 67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 68 | metric = result_dict['metric'] 69 | assert is_valid_metric(metric) 70 | if run_dir is not None and snapshot_pkl is not None: 71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 72 | 73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 74 | print(jsonl_line) 75 | if run_dir is not None and os.path.isdir(run_dir): 76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 77 | f.write(jsonl_line + '\n') 78 | 79 | #---------------------------------------------------------------------------- 80 | # Primary metrics. 
81 | 82 | @register_metric 83 | def fid50k_full(opts): 84 | opts.dataset_kwargs.update(max_size=None, xflip=False) 85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 86 | return dict(fid50k_full=fid) 87 | 88 | @register_metric 89 | def kid50k_full(opts): 90 | opts.dataset_kwargs.update(max_size=None, xflip=False) 91 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 92 | return dict(kid50k_full=kid) 93 | 94 | @register_metric 95 | def pr50k3_full(opts): 96 | opts.dataset_kwargs.update(max_size=None, xflip=False) 97 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 98 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 99 | 100 | @register_metric 101 | def ppl2_wend(opts): 102 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 103 | return dict(ppl2_wend=ppl) 104 | 105 | @register_metric 106 | def is50k(opts): 107 | opts.dataset_kwargs.update(max_size=None, xflip=False) 108 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 109 | return dict(is50k_mean=mean, is50k_std=std) 110 | 111 | #---------------------------------------------------------------------------- 112 | # Legacy metrics. 
113 | 114 | @register_metric 115 | def fid50k(opts): 116 | opts.dataset_kwargs.update(max_size=None) 117 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 118 | return dict(fid50k=fid) 119 | 120 | @register_metric 121 | def kid50k(opts): 122 | opts.dataset_kwargs.update(max_size=None) 123 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 124 | return dict(kid50k=kid) 125 | 126 | @register_metric 127 | def pr50k3(opts): 128 | opts.dataset_kwargs.update(max_size=None) 129 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 130 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 131 | 132 | @register_metric 133 | def ppl_zfull(opts): 134 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2) 135 | return dict(ppl_zfull=ppl) 136 | 137 | @register_metric 138 | def ppl_wfull(opts): 139 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2) 140 | return dict(ppl_wfull=ppl) 141 | 142 | @register_metric 143 | def ppl_zend(opts): 144 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2) 145 | return dict(ppl_zend=ppl) 146 | 147 | @register_metric 148 | def ppl_wend(opts): 149 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2) 150 | return dict(ppl_wend=ppl) 151 | 152 | #---------------------------------------------------------------------------- 153 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # 
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
Architecture for Generative Adversarial Networks". Matches the original
implementation by Karras et al. at
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""

import copy
import numpy as np
import torch
import dnnlib
from . import metric_utils

#----------------------------------------------------------------------------

# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
    """Spherically interpolate between the directions of ``a`` and ``b``.

    Both inputs are normalized along the last dimension first, and the result
    is re-normalized, so the output is always a unit vector. ``t`` must
    broadcast against the ``keepdim`` dot-product tensor (callers pass
    ``t.unsqueeze(1)``).
    """
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    d = (a * b).sum(dim=-1, keepdim=True)
    # Rounding can push the cosine of two unit vectors slightly outside [-1, 1],
    # where acos() returns NaN; clamp so the angle stays well-defined.
    p = t * torch.acos(d.clamp(-1, 1))
    c = b - d * a
    c = c / c.norm(dim=-1, keepdim=True)
    d = a * torch.cos(p) + c * torch.sin(p)
    d = d / d.norm(dim=-1, keepdim=True)
    return d

#----------------------------------------------------------------------------

class PPLSampler(torch.nn.Module):
    """Module returning one squared, epsilon-normalized LPIPS distance per label in ``c``.

    Holds private deep copies of ``G`` and ``vgg16`` because ``forward``
    overwrites the generator's ``noise_const`` buffers.
    """
    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__()
        self.G = copy.deepcopy(G)
        self.G_kwargs = G_kwargs
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.crop = crop
        self.vgg16 = copy.deepcopy(vgg16)

    def forward(self, c):
        # Generate random latents and interpolation t-values.
        # 'end' sampling multiplies t by 0, so every path starts at its first endpoint.
        t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)

        # Interpolate in W or Z.
        if self.space == 'w':
            w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
            wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
            wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
        else: # space == 'z'
            zt0 = slerp(z0, z1, t.unsqueeze(1))
            zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
            wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)

        # Randomize noise buffers.
        for name, buf in self.G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Generate images.
        img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)

        # Center crop.
        if self.crop:
            assert img.shape[2] == img.shape[3]
            c = img.shape[2] // 8
            img = img[:, :, c*3 : c*7, c*2 : c*6]

        # Downsample to 256x256.
        # NOTE(review): factor is derived from the full G.img_resolution, not the
        # post-crop size, so a cropped image ends up smaller than 256x256; confirm
        # this matches the reference numbers before changing it.
        factor = self.G.img_resolution // 256
        if factor > 1:
            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])

        # Scale dynamic range from [-1,1] to [0,255].
        img = (img + 1) * (255 / 2)
        if self.G.img_channels == 1:
            img = img.repeat([1, 3, 1, 1])

        # Evaluate differential LPIPS.
        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
        dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
        return dist

#----------------------------------------------------------------------------

def _percentile(x, q, method):
    """``np.percentile`` with the rank-selection rule ``method`` on old and new NumPy.

    NumPy 1.22 renamed the ``interpolation`` keyword to ``method`` and deprecated
    the old name; fall back to ``interpolation`` only where ``method`` is rejected.
    """
    try:
        return np.percentile(x, q, method=method)
    except TypeError:
        return np.percentile(x, q, interpolation=method)

def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
    """Compute PPL over ``num_samples`` paths, averaging after dropping the outer 1% tails.

    Labels are drawn at random from the dataset built from ``opts.dataset_kwargs``.
    With several GPUs, every rank's per-batch distances are broadcast to all ranks.

    Returns:
        The PPL as a float on rank 0, ``nan`` on every other rank.
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)

    # Setup sampler.
    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
    sampler.eval().requires_grad_(False).to(opts.device)
    if jit:
        c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
        sampler = torch.jit.trace(sampler, [c], check_trace=False)

    # Sampling loop.
    dist = []
    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
        c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
        x = sampler(c)
        for src in range(opts.num_gpus):
            y = x.clone()
            if opts.num_gpus > 1:
                torch.distributed.broadcast(y, src=src)
            dist.append(y)
    progress.update(num_samples)

    # Compute PPL.
    if opts.rank != 0:
        return float('nan')
    dist = torch.cat(dist)[:num_samples].cpu().numpy()
    # Reject outliers: keep samples between the 1st and 99th percentiles (inclusive).
    lo = _percentile(dist, 1, 'lower')
    hi = _percentile(dist, 99, 'higher')
    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
    return float(ppl)

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /metrics/precision_recall.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Precision/Recall (PR) from the paper "Improved Precision and Recall
Metric for Assessing Generative Models". Matches the original implementation
by Kynkaanniemi et al. at
https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""

import torch
from . import metric_utils

#----------------------------------------------------------------------------

def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
    """Pairwise Euclidean distances between ``row_features`` and ``col_features``.

    ``col_features`` is zero-padded to a multiple of the batch count and split
    into column batches; each rank computes every ``num_gpus``-th batch and the
    results are broadcast so rank 0 can assemble the full matrix.

    Returns:
        ``[len(row_features), len(col_features)]`` distances on rank 0 (on CPU), ``None`` elsewhere.
    """
    assert 0 <= rank < num_gpus
    num_cols = col_features.shape[0]
    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
    col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
    dist_batches = []
    for col_batch in col_batches[rank :: num_gpus]:
        dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
        for src in range(num_gpus):
            dist_broadcast = dist_batch.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(dist_broadcast, src=src)
            dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
    # Slice off the columns that came from zero padding.
    return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None

#----------------------------------------------------------------------------

def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
    """Improved precision and recall from VGG16 features.

    Precision is the fraction of generated features lying within the
    ``nhood_size``-th nearest-neighbour radius of at least one real feature;
    recall swaps the roles of real and generated features.

    Returns:
        ``(precision, recall)`` as floats on rank 0, ``(nan, nan)`` on other ranks.
    """
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    detector_kwargs = dict(return_features=True)

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)

    results = dict()
    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
        kth = []
        for manifold_batch in manifold.split(row_batch_size):
            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            # nhood_size + 1 because each manifold point's distance to itself (0) is in its own row.
            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
        kth = torch.cat(kth) if opts.rank == 0 else None
        pred = []
        for probes_batch in probes.split(row_batch_size):
            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
    return results['precision'], results['recall']

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /projector.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
from time import perf_counter

import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

import dnnlib
import legacy

def project(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps                  = 1000,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.1,
    initial_noise_factor       = 0.05,
    lr_rampdown_length         = 0.25,
    lr_rampup_length           = 0.05,
    noise_ramp_length          = 0.75,
    regularize_noise_weight    = 1e5,
    verbose                    = False,
    device: torch.device
):
    """Fit a single W vector (and the generator's constant noise buffers) to ``target``.

    The loss is the squared distance between VGG16 LPIPS features of the target
    and of ``G.synthesis`` output, plus a multi-scale lag-1 autocorrelation
    penalty on the noise buffers. The learning rate ramps up linearly over the
    first ``lr_rampup_length`` and cosine-ramps down over the last
    ``lr_rampdown_length`` fraction of ``num_steps``; Gaussian noise added to W
    decays to zero by the ``noise_ramp_length`` fraction. Works on a deep copy
    of ``G``, so the caller's network and noise buffers are not modified.

    Returns:
        Tensor ``[num_steps, G.mapping.num_ws, C]``: W after every step, repeated across layers.
    """
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)       # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    # The noise buffers are registered with the optimizer now; they are made
    # trainable (requires_grad=True) just below.
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])

#----------------------------------------------------------------------------

def _synth_to_uint8(img):
    """Convert a ``[1, C, H, W]`` synthesis output (nominally in [-1, 1]) to an ``[H, W, C]`` uint8 array."""
    img = (img + 1) * (255/2)
    return img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
@click.option('--num-steps',              help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed',                   help='Random seed', type=int, default=303, show_default=True)
@click.option('--save-video',             help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--outdir',                 help='Where to save the output images', required=True, metavar='DIR')
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image: center-crop to a square, then resize to the generator resolution.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        try:
            for projected_w in projected_w_steps:
                synth_image = _synth_to_uint8(G.synthesis(projected_w.unsqueeze(0), noise_mode='const'))
                video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        finally:
            # Finalize the mp4 and release the encoder even if synthesis fails part-way.
            video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = _synth_to_uint8(G.synthesis(projected_w.unsqueeze(0), noise_mode='const'))
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())

#----------------------------------------------------------------------------

if __name__ == "__main__":
    run_projection() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /run.sh:
# --------------------------------------------------------------------------------
# python train.py --outdir="./training-runs" --data="../comp.zip" --mask_data="../mask.zip" --gpus=1
# --------------------------------------------------------------------------------
# /style_mixing.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Generate style mixing image matrix using pretrained network pickle."""

import os
import re
from typing import List

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy

#----------------------------------------------------------------------------

def num_range(s: str) -> List[int]:
    '''Parse ``'a-c'`` as the inclusive integer range a..c; otherwise parse ``'a,b,c'`` as a list of ints.'''
    bounds = re.match(r'^(\d+)-(\d+)$', s)
    if bounds is None:
        return [int(token) for token in s.split(',')]
    first, last = (int(group) for group in bounds.groups())
    return list(range(first, last + 1))

#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
@click.option('--cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
@click.option('--styles', 'col_styles', type=num_range, help='Style layer range', default='0-6', show_default=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--outdir', type=str, required=True)
def generate_style_mix(
    network_pkl: str,
    row_seeds: List[int],
    col_seeds: List[int],
    col_styles: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """
    def to_uint8_nhwc(images):
        # [N, C, H, W] generator output -> [N, H, W, C] uint8 via x * 127.5 + 128, clamped to [0, 255].
        return (images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].to(device) # type: ignore

    os.makedirs(outdir, exist_ok=True)

    print('Generating W vectors...')
    seeds = list(set(row_seeds + col_seeds))
    z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in seeds])
    ws = G.mapping(torch.from_numpy(z).to(device), None)
    w_avg = G.mapping.w_avg
    # Truncation: pull each W toward G.mapping.w_avg by the factor truncation_psi.
    ws = w_avg + (ws - w_avg) * truncation_psi
    w_by_seed = dict(zip(seeds, list(ws)))

    print('Generating images...')
    base_images = to_uint8_nhwc(G.synthesis(ws, noise_mode=noise_mode)).cpu().numpy()
    images = {(seed, seed): img for seed, img in zip(seeds, list(base_images))}

    print('Generating style-mixed images...')
    for row_seed in row_seeds:
        for col_seed in col_seeds:
            # Row seed's W, with the layers listed in col_styles taken from the column seed.
            mixed_w = w_by_seed[row_seed].clone()
            mixed_w[col_styles] = w_by_seed[col_seed][col_styles]
            mixed = to_uint8_nhwc(G.synthesis(mixed_w[np.newaxis], noise_mode=noise_mode))
            images[(row_seed, col_seed)] = mixed[0].cpu().numpy()

    print('Saving images...')
    os.makedirs(outdir, exist_ok=True)
    for (row_seed, col_seed), img in images.items():
        PIL.Image.fromarray(img, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')

    print('Saving image grid...')
    res = G.img_resolution
    canvas = PIL.Image.new('RGB', (res * (len(col_seeds) + 1), res * (len(row_seeds) + 1)), 'black')
    # Row 0 / column 0 are headers showing the unmixed column / row images; the corner stays black.
    for row_idx, row_seed in enumerate([0] + row_seeds):
        for col_idx, col_seed in enumerate([0] + col_seeds):
            if row_idx == 0 and col_idx == 0:
                continue
            if row_idx == 0:
                key = (col_seed, col_seed)
            elif col_idx == 0:
                key = (row_seed, row_seed)
            else:
                key = (row_seed, col_seed)
            canvas.paste(PIL.Image.fromarray(images[key], 'RGB'), (res * col_idx, res * row_idx))
    canvas.save(f'{outdir}/grid.png')


#----------------------------------------------------------------------------

if __name__ == "__main__":
    generate_style_mix() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /torch_utils/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
# --------------------------------------------------------------------------------
# /torch_utils/custom_ops.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import glob
import torch
import torch.utils.cpp_extension
import importlib
import hashlib
import shutil
from pathlib import Path

from torch.utils.file_baton import FileBaton

#----------------------------------------------------------------------------
# Global options.

verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    """Locate an MSVC ``bin`` directory from a fixed list of install layouts.

    Patterns are tried in order. For the first pattern with any glob match, the
    last match in sorted order is returned (presumably the newest toolset;
    confirm for unusual layouts). Returns None if nothing matches.
    """
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    """Build (if needed) and import the C++/CUDA extension ``module_name``.

    The imported module is memoized in ``_cached_plugins``, so later calls with
    the same name return it without rebuilding. ``build_kwargs`` are forwarded
    to ``torch.utils.cpp_extension.load``.

    On Windows, if ``where cl.exe`` fails, the directory returned by
    ``_find_compiler_bindir`` is appended to ``os.environ['PATH']`` (a
    process-wide side effect).

    When all ``sources`` share one directory and ``TORCH_EXTENSIONS_DIR`` is
    set, every file in that directory is copied into a build subdirectory named
    after the combined MD5 of those files, and the extension is built from the
    copies (see the comment block below).

    Raises:
        RuntimeError: on Windows when no compiler directory can be found.
        Any exception from building or importing is re-raised (after printing
        'Failed!' in 'brief' mode).
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                # Only the process that acquires the lock file copies the sources.
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            # Build from the copies; outputs still go to the shared build_dir.
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        # Finish the 'brief' status line, then propagate the original error unchanged.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /torch_utils/misc.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = {}

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a memoized tensor holding ``value``.

    Requests with the same value bytes, value shape/dtype, target ``shape``,
    ``dtype``, ``device`` and ``memory_format`` return the *same* tensor object,
    so it is shared between callers. ``dtype`` defaults to
    ``torch.get_default_dtype()``, ``device`` to CPU, and ``memory_format`` to
    ``torch.contiguous_format``. When ``shape`` is given, the value is broadcast
    to it before being made contiguous.
    """
    arr = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    cache_key = (arr.shape, arr.dtype, arr.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key)
    if cached is not None:
        return cached

    result = torch.as_tensor(arr.copy(), dtype=dtype, device=device)
    if shape is not None:
        result = torch.broadcast_tensors(result, torch.empty(shape))[0]
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.
if hasattr(torch, 'nan_to_num'): # 1.8.0a0
    nan_to_num = torch.nan_to_num
else:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback for PyTorch builds without ``torch.nan_to_num``; only ``nan == 0`` is supported.

        +/-inf are clamped to ``posinf``/``neginf``, defaulting to the largest
        and smallest finite values of ``input.dtype``.
        """
        assert isinstance(input, torch.Tensor)
        hi = torch.finfo(input.dtype).max if posinf is None else posinf
        lo = torch.finfo(input.dtype).min if neginf is None else neginf
        assert nan == 0
        # nansum over a singleton leading axis maps NaN -> 0 and leaves every other value as-is.
        return torch.clamp(input.unsqueeze(0).nansum(0), min=lo, max=hi, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

if hasattr(torch, '_assert'): # 1.8.0a0
    symbolic_assert = torch._assert # pylint: disable=protected-access
else:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    """``warnings.catch_warnings`` that also ignores ``torch.jit.TracerWarning`` while active.

    The filter change is undone on exit; ``__enter__`` returns the context manager itself.
    """
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().
def assert_shape(tensor, ref_shape):
    """Check that ``tensor`` has ``len(ref_shape)`` dimensions with matching sizes.

    ``None`` entries in ``ref_shape`` accept any size. If either side of a
    comparison is a tensor (per the section header, this happens under
    ``torch.jit.trace()``), the check is issued through ``symbolic_assert``
    instead of a Python comparison.

    Raises:
        AssertionError: on a rank mismatch, or a size mismatch between plain ints.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    """Wrap ``fn`` so every call runs inside ``record_function(fn.__name__)``.

    ``functools.wraps`` carries over ``__name__``, ``__qualname__``, ``__doc__``
    and ``__module__`` and sets ``__wrapped__``; the previous version copied
    only ``__name__``.
    """
    import functools # stdlib; imported locally to leave the module import list untouched

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.
108 | 109 | class InfiniteSampler(torch.utils.data.Sampler): 110 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 111 | assert len(dataset) > 0 112 | assert num_replicas > 0 113 | assert 0 <= rank < num_replicas 114 | assert 0 <= window_size <= 1 # fraction of the dataset used as the re-shuffle window 115 | super().__init__(dataset) 116 | self.dataset = dataset 117 | self.rank = rank 118 | self.num_replicas = num_replicas 119 | self.shuffle = shuffle 120 | self.seed = seed 121 | self.window_size = window_size 122 | 123 | def __iter__(self): 124 | order = np.arange(len(self.dataset)) 125 | rnd = None 126 | window = 0 # 0 disables the incremental re-shuffle below 127 | if self.shuffle: 128 | rnd = np.random.RandomState(self.seed) # deterministic for a given seed 129 | rnd.shuffle(order) 130 | window = int(np.rint(order.size * self.window_size)) # window size in items 131 | 132 | idx = 0 133 | while True: # infinite stream; the consumer decides when to stop 134 | i = idx % order.size 135 | if idx % self.num_replicas == self.rank: # this replica yields every num_replicas-th position 136 | yield order[i] 137 | if window >= 2: # swap slot i with a random slot up to window-1 positions earlier (wrapping around) 138 | j = (i - rnd.randint(window)) % order.size 139 | order[i], order[j] = order[j], order[i] 140 | idx += 1 141 | 142 | #---------------------------------------------------------------------------- 143 | # Utilities for operating with torch.nn.Module parameters and buffers.
144 | 145 | def params_and_buffers(module): 146 | assert isinstance(module, torch.nn.Module) 147 | return list(module.parameters()) + list(module.buffers()) 148 | 149 | def named_params_and_buffers(module): 150 | assert isinstance(module, torch.nn.Module) 151 | return list(module.named_parameters()) + list(module.named_buffers()) 152 | 153 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 154 | assert isinstance(src_module, torch.nn.Module) 155 | assert isinstance(dst_module, torch.nn.Module) 156 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} 157 | for name, tensor in named_params_and_buffers(dst_module): 158 | assert (name in src_tensors) or (not require_all) 159 | if name in src_tensors: 160 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 161 | 162 | #---------------------------------------------------------------------------- 163 | # Context manager for easily enabling/disabling DistributedDataParallel 164 | # synchronization. 165 | 166 | @contextlib.contextmanager 167 | def ddp_sync(module, sync): 168 | assert isinstance(module, torch.nn.Module) 169 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 170 | yield 171 | else: 172 | with module.no_sync(): 173 | yield 174 | 175 | #---------------------------------------------------------------------------- 176 | # Check DistributedDataParallel consistency across processes. 177 | 178 | def check_ddp_consistency(module, ignore_regex=None): 179 | assert isinstance(module, torch.nn.Module) 180 | for name, tensor in named_params_and_buffers(module): 181 | fullname = type(module).__name__ + '.' 
+ name 182 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 183 | continue 184 | tensor = tensor.detach() 185 | other = tensor.clone() 186 | torch.distributed.broadcast(tensor=other, src=0) 187 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname 188 | 189 | #---------------------------------------------------------------------------- 190 | # Print summary table of module hierarchy. 191 | 192 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 193 | assert isinstance(module, torch.nn.Module) 194 | assert not isinstance(module, torch.jit.ScriptModule) 195 | assert isinstance(inputs, (tuple, list)) 196 | 197 | # Register hooks. 198 | entries = [] 199 | nesting = [0] 200 | def pre_hook(_mod, _inputs): 201 | nesting[0] += 1 202 | def post_hook(mod, _inputs, outputs): 203 | nesting[0] -= 1 204 | if nesting[0] <= max_nesting: 205 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 206 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 207 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 208 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 209 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 210 | 211 | # Run module. 212 | outputs = module(*inputs) 213 | for hook in hooks: 214 | hook.remove() 215 | 216 | # Identify unique outputs, parameters, and buffers. 217 | tensors_seen = set() 218 | for e in entries: 219 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 220 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 221 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 222 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 223 | 224 | # Filter out redundant entries. 
225 | if skip_redundant: 226 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 227 | 228 | # Construct table. 229 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 230 | rows += [['---'] * len(rows[0])] 231 | param_total = 0 232 | buffer_total = 0 233 | submodule_names = {mod: name for name, mod in module.named_modules()} 234 | for e in entries: 235 | name = '' if e.mod is module else submodule_names[e.mod] 236 | param_size = sum(t.numel() for t in e.unique_params) 237 | buffer_size = sum(t.numel() for t in e.unique_buffers) 238 | output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] 239 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 240 | rows += [[ 241 | name + (':0' if len(e.outputs) >= 2 else ''), 242 | str(param_size) if param_size else '-', 243 | str(buffer_size) if buffer_size else '-', 244 | (output_shapes + ['-'])[0], 245 | (output_dtypes + ['-'])[0], 246 | ]] 247 | for idx in range(1, len(e.outputs)): 248 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 249 | param_total += param_size 250 | buffer_total += buffer_size 251 | rows += [['---'] * len(rows[0])] 252 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 253 | 254 | # Print table. 255 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 256 | print() 257 | for row in rows: 258 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 259 | print() 260 | return outputs 261 | 262 | #---------------------------------------------------------------------------- 263 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) // same sizes, and same strides on every dimension of size >= 2 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) // strides of size-1 dims never affect addressing 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) // launches the bias+activation kernel on x and returns a new tensor y; empty tensors mark absent optional inputs 33 | { 34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); // element counts are passed to the kernel as int 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); // run on x's device 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ?
xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; // the kernel reads bias element (xi / stepB) % sizeB 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; // elements processed per thread 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; // ceil(sizeX / (loopX * blockSize)) for sizeX >= 1 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; // half-precision data is computed in float 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; // 0 = forward value; 1 and 2 = the gradient passes requested by bias_act.py 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; // past this |x| the exp-based branches below return their asymptotic values 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; // each block covers loopX * blockDim.x consecutive elements 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ?
(scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; // reference output with the gain divided out 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ?
x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain.
133 | y *= gain * dy; 134 | 135 | // Clamp.
136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; // forward: saturate at +/-clamp 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; // gradient passes: zero where the forward output was clamped 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; // unknown activation index; the caller rejects a NULL kernel 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Explicit template instantiations. 168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto.
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from .. 
import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | activation_funcs = { 24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 33 | } 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | _inited = False 38 | _plugin = None 39 | _null_tensor = torch.empty([0]) 40 | 41 | def _init(): 42 | global _inited, _plugin 43 | if not _inited: 44 | _inited = True 45 | sources = ['bias_act.cpp', 'bias_act.cu'] 46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 47 | try: 48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 49 | except: 50 | warnings.warn('Failed to build 
CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 51 | return _plugin is not None 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 56 | r"""Fused bias and activation function. 57 | 58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 59 | and scales the result by `gain`. Each of the steps is optional. In most cases, 60 | the fused op is considerably more efficient than performing the same calculation 61 | using standard PyTorch ops. It supports first and second order gradients, 62 | but not third order gradients. 63 | 64 | Args: 65 | x: Input activation tensor. Can be of any shape. 66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 67 | as `x`. The shape must be known, and it must match the dimension of `x` 68 | corresponding to `dim`. 69 | dim: The dimension in `x` corresponding to the elements of `b`. 70 | The value of `dim` is ignored if `b` is not specified. 71 | act: Name of the activation function to evaluate, or `"linear"` to disable. 72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 73 | See `activation_funcs` for a full list. `None` is not allowed. 74 | alpha: Shape parameter for the activation function, or `None` to use the default. 75 | gain: Scaling factor for the output tensor, or `None` to use default. 76 | See `activation_funcs` for the default scaling of each activation function. 77 | If unsure, consider specifying 1. 78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 79 | the clamping (default). 80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 81 | 82 | Returns: 83 | Tensor of the same shape and datatype as `x`. 
84 | """ 85 | assert isinstance(x, torch.Tensor) 86 | assert impl in ['ref', 'cuda'] 87 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 90 | 91 | #---------------------------------------------------------------------------- 92 | 93 | @misc.profiled_function 94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 95 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 96 | """ 97 | assert isinstance(x, torch.Tensor) 98 | assert clamp is None or clamp >= 0 99 | spec = activation_funcs[act] 100 | alpha = float(alpha if alpha is not None else spec.def_alpha) 101 | gain = float(gain if gain is not None else spec.def_gain) 102 | clamp = float(clamp if clamp is not None else -1) 103 | 104 | # Add bias. 105 | if b is not None: 106 | assert isinstance(b, torch.Tensor) and b.ndim == 1 107 | assert 0 <= dim < x.ndim 108 | assert b.shape[0] == x.shape[dim] 109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 110 | 111 | # Evaluate activation function. 112 | alpha = float(alpha) 113 | x = spec.func(x, alpha=alpha) 114 | 115 | # Scale by gain. 116 | gain = float(gain) 117 | if gain != 1: 118 | x = x * gain 119 | 120 | # Clamp. 121 | if clamp >= 0: 122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 123 | return x 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | _bias_act_cuda_cache = dict() 128 | 129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 130 | """Fast CUDA implementation of `bias_act()` using custom ops. 131 | """ 132 | # Parse arguments. 
133 | assert clamp is None or clamp >= 0 134 | spec = activation_funcs[act] 135 | alpha = float(alpha if alpha is not None else spec.def_alpha) 136 | gain = float(gain if gain is not None else spec.def_gain) 137 | clamp = float(clamp if clamp is not None else -1) 138 | 139 | # Lookup from cache. 140 | key = (dim, act, alpha, gain, clamp) 141 | if key in _bias_act_cuda_cache: 142 | return _bias_act_cuda_cache[key] 143 | 144 | # Forward op. 145 | class BiasActCuda(torch.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, x, b): # pylint: disable=arguments-differ 148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 149 | x = x.contiguous(memory_format=ctx.memory_format) 150 | b = b.contiguous() if b is not None else _null_tensor 151 | y = x 152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 154 | ctx.save_for_backward( 155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 157 | y if 'y' in spec.ref else _null_tensor) 158 | return y 159 | 160 | @staticmethod 161 | def backward(ctx, dy): # pylint: disable=arguments-differ 162 | dy = dy.contiguous(memory_format=ctx.memory_format) 163 | x, b, y = ctx.saved_tensors 164 | dx = None 165 | db = None 166 | 167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 168 | dx = dy 169 | if act != 'linear' or gain != 1 or clamp >= 0: 170 | dx = BiasActCudaGrad.apply(dy, x, b, y) 171 | 172 | if ctx.needs_input_grad[1]: 173 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 174 | 175 | return dx, db 176 | 177 | # Backward op. 
178 | class BiasActCudaGrad(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 183 | ctx.save_for_backward( 184 | dy if spec.has_2nd_grad else _null_tensor, 185 | x, b, y) 186 | return dx 187 | 188 | @staticmethod 189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 191 | dy, x, b, y = ctx.saved_tensors 192 | d_dy = None 193 | d_x = None 194 | d_b = None 195 | d_y = None 196 | 197 | if ctx.needs_input_grad[0]: 198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 199 | 200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 202 | 203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 205 | 206 | return d_dy, d_x, d_b, d_y 207 | 208 | # Add to cache. 209 | _bias_act_cuda_cache[key] = BiasActCuda 210 | return BiasActCuda 211 | 212 | #---------------------------------------------------------------------------- 213 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, 
weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 
83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, 
input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): # Return w.shape as a list of Python ints and assert it. 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. Flips `w` spatially when `flip_weight=False` and special-cases unstrided, unpadded 1x1 convolutions on inputs whose channel dimension has unit stride. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: # Channel dim has unit stride, i.e. channels_last layout. 42 | if out_channels <= 4 and groups == 1: # Very few outputs: evaluate the 1x1 conv as a batched matmul. 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True).
78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, out_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution.
119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: # Reorder w to [in_channels, out_channels // groups, kh, kw] for the transposed convolution. 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation.
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape # Only c's shape (not its values) is needed by backward(). 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): # Sum gradient x over the dims that were broadcast so the result has exactly `shape`. 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) # Leading extra dims are now size 1; fold them away. 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def grid_sample(input, grid): 28 | if _should_use_custom_op(): 29 | return _GridSample2dForward.apply(input, grid) 30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _should_use_custom_op(): 35 | if not enabled: 36 | return False 37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 38 | return True 39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. 
Falling back to torch.nn.functional.grid_sample().') 40 | return False 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | class _GridSample2dForward(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input, grid): 47 | assert input.ndim == 4 48 | assert grid.ndim == 4 49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 50 | ctx.save_for_backward(input, grid) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, grid = ctx.saved_tensors 56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 57 | return grad_input, grad_grid 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class _GridSample2dBackward(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, grad_output, input, grid): 64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor.
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr<float>(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size.
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters.
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling.
This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself.
64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return
copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src: Original source code of the Python module. 162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. The source is executed with exec(), so only unpickle files from trusted sources.
218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? Once set, init_multiprocessing() can no longer be called (it asserts on this flag). 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`.
47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called # Configuration must be fixed before the first _sync() that syncs any statistic. 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. The elements are detached, flattened and cast to `_reduce_dtype` before reduction. 75 | 76 | Returns: 77 | The same `value` that was passed in.
78 | """ 79 | if name not in _counters: # Register the name even for empty input so every rank syncs the same names. 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ # [count, sum, sum of squares], matching _num_moments. 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) # Other ranks still register the name, keeping name sets consistent. 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything.
129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() # Take a baseline snapshot of the cumulative counters... 139 | self._moments.clear() # ...and discard its deltas, so the first real update() only sees stats reported after construction. 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] # Moments accumulated since this collector's previous update(). 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected.
174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: # Unseen names get a zero entry, so num/mean/std report 0/NaN. 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) # Population std via E[x^2] - E[x]^2, clamped at 0 against rounding. 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as a `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter.
228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. Returns a list of (name, cumulative) pairs whose tensors are the live `_cumulative` entries, not copies. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) # Reset the running counter now that it is folded into delta. 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) # Elementwise sum, so every rank must sync the same names in the same order. 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/countless/.gitignore: -------------------------------------------------------------------------------- 1 | results -------------------------------------------------------------------------------- /training/countless/README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless) 2 | 3 | Python COUNTLESS Downsampling 4 | ============================= 5 | 6 | To install: 7 | 8 | `pip install -r requirements.txt` 9 | 10 | To test: 11 | 12 | `python test.py` 13 | 14 | To benchmark countless2d: 15 | 16 | `python countless2d.py images/gray_segmentation.png` 17 | 18 | To benchmark countless3d: 19 | 20 | `python countless3d.py` 21 | 22 | Adjust N and the list of algorithms inside each script to modify the run parameters. 23 | 24 | 25 | Python3 is slightly faster than Python2.
-------------------------------------------------------------------------------- /training/countless/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/__init__.py -------------------------------------------------------------------------------- /training/countless/countless3d.py: -------------------------------------------------------------------------------- 1 | from six.moves import range 2 | from PIL import Image 3 | import numpy as np 4 | import io 5 | import time 6 | import math 7 | import random 8 | import sys 9 | from collections import defaultdict 10 | from copy import deepcopy 11 | from itertools import combinations 12 | from functools import reduce 13 | from tqdm import tqdm 14 | 15 | from memory_profiler import profile # NOTE(review): profile is never used in this module, yet this import makes memory_profiler a hard runtime dependency (Image, io, random, sys, defaultdict, deepcopy and tqdm also appear unused) 16 | 17 | def countless5(a,b,c,d,e): 18 | """First stage of generalizing from countless2d. 19 | 20 | You have five slots: A, B, C, D, E 21 | 22 | You can decide if something is the winner by first checking for 23 | matches of three, then matches of two, then picking just one if 24 | the other two tries fail. In countless2d, you just check for matches 25 | of two and then pick one of them otherwise. 26 | 27 | Unfortunately, you need to check all ten triples: ABC, ABD, ABE, ACD, ACE, ADE, BCD, BCE, BDE, & CDE. 28 | Then you need to check AB, AC, AD, BC, BD, CD 29 | We skip checking E because if none of these match, we pick E. We can 30 | skip checking AE, BE, CE, DE since if any of those match, E is our boy 31 | so it's redundant. 32 | 33 | So countless grows combinatorially in complexity.
34 | """ 35 | sections = [ a,b,c,d,e ] 36 | 37 | p2 = lambda q,r: q * (q == r) # q if q == r else 0 38 | p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) # q if q == r == s else 0 39 | 40 | lor = lambda x,y: x + (x == 0) * y # elementwise: x where x != 0, else y 41 | 42 | results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) 43 | results3 = reduce(lor, results3) 44 | 45 | results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) 46 | results2 = reduce(lor, results2) 47 | 48 | return reduce(lor, (results3, results2, e)) 49 | 50 | def countless8(a,b,c,d,e,f,g,h): 51 | """Extend countless5 to countless8. Same deal, except we also 52 | need to check for matches of length 4.""" 53 | sections = [ a, b, c, d, e, f, g, h ] 54 | 55 | p2 = lambda q,r: q * (q == r) 56 | p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) 57 | p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) ) 58 | 59 | lor = lambda x,y: x + (x == 0) * y 60 | 61 | results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) ) 62 | results4 = reduce(lor, results4) 63 | 64 | results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) 65 | results3 = reduce(lor, results3) 66 | 67 | # We can always use our shortcut of omitting the last element 68 | # for N choose 2 69 | results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) 70 | results2 = reduce(lor, results2) 71 | 72 | return reduce(lor, [ results4, results3, results2, h ]) 73 | 74 | def dynamic_countless3d(data): 75 | """countless8 + dynamic programming. ~2x faster. NOTE(review): on the first loop iteration `results2 = res` aliases subproblems2[(0,1)] (and `results3 = res` aliases subproblems3[(0,1,2)]), so the later in-place `+=` also rewrites those cached subproblems. That can let a pair masquerade as a triple (e.g. with only sections 1 and 2 equal to each other and sections 3, 4 and 5 equal to each other, the pair's value is returned). Verify against countless3d; binding `res.copy()` instead would avoid the aliasing.""" 76 | sections = [] 77 | 78 | # shift zeros up one so they don't interfere with bitwise operators 79 | # we'll shift down at the end (this mutates the caller's array in place; values at the dtype's maximum, e.g. 255 for uint8, would wrap to 0) 80 | data += 1 81 | 82 | # This loop splits the 3D array apart into eight arrays that are 83 | # all the result of striding by 2 along every axis, one per offset 84 | # (0,0,0), (0,0,1), ..., (1,1,1) in np.ndindex order.
85 | factor = (2,2,2) 86 | for offset in np.ndindex(factor): 87 | part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] 88 | sections.append(part) 89 | 90 | pick = lambda a,b: a * (a == b) 91 | lor = lambda x,y: x + (x == 0) * y 92 | 93 | subproblems2 = {} 94 | 95 | results2 = None 96 | for x,y in combinations(range(7), 2): # pairs skip section 7; it is the final fallback in the reduce below 97 | res = pick(sections[x], sections[y]) 98 | subproblems2[(x,y)] = res 99 | if results2 is not None: 100 | results2 += (results2 == 0) * res # NOTE(review): in place, so it also rewrites subproblems2[(0,1)], which results2 was bound to on the first iteration 101 | else: 102 | results2 = res 103 | 104 | subproblems3 = {} 105 | 106 | results3 = None 107 | for x,y,z in combinations(range(8), 3): 108 | res = pick(subproblems2[(x,y)], sections[z]) 109 | 110 | if z != 7: # only triples that can still be extended by a fourth section are cached 111 | subproblems3[(x,y,z)] = res 112 | 113 | if results3 is not None: 114 | results3 += (results3 == 0) * res # NOTE(review): same aliasing, here with subproblems3[(0,1,2)] 115 | else: 116 | results3 = res 117 | 118 | results3 = reduce(lor, (results3, results2, sections[-1])) 119 | 120 | # free memory 121 | results2 = None 122 | subproblems2 = None 123 | res = None 124 | 125 | results4 = ( pick(subproblems3[(x,y,z)], sections[w]) for x,y,z,w in combinations(range(8), 4) ) 126 | results4 = reduce(lor, results4) 127 | subproblems3 = None # free memory 128 | 129 | final_result = lor(results4, results3) - 1 130 | data -= 1 131 | return final_result 132 | 133 | def countless3d(data): 134 | """Now write countless8 in such a way that it could be used 135 | to process an image.""" 136 | sections = [] 137 | 138 | # shift zeros up one so they don't interfere with bitwise operators 139 | # we'll shift down at the end 140 | data += 1 141 | 142 | # This loop splits the 3D array apart into eight arrays that are 143 | # all the result of striding by 2 along every axis, one per offset 144 | # (0,0,0), (0,0,1), ..., (1,1,1) in np.ndindex order.
145 | factor = (2,2,2) 146 | for offset in np.ndindex(factor): 147 | part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] 148 | sections.append(part) 149 | 150 | p2 = lambda q,r: q * (q == r) 151 | p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) 152 | p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) ) 153 | 154 | lor = lambda x,y: x + (x == 0) * y 155 | 156 | results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) ) 157 | results4 = reduce(lor, results4) 158 | 159 | results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) 160 | results3 = reduce(lor, results3) 161 | 162 | results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) 163 | results2 = reduce(lor, results2) 164 | 165 | final_result = reduce(lor, (results4, results3, results2, sections[-1])) - 1 # undo the +1 shift 166 | data -= 1 167 | return final_result 168 | 169 | def countless_generalized(data, factor): # generalizes countless3d to an arbitrary per-axis factor 170 | assert len(data.shape) == len(factor) 171 | 172 | sections = [] 173 | 174 | mode_of = reduce(lambda x,y: x * y, factor) 175 | majority = int(math.ceil(float(mode_of) / 2)) # NOTE(review): if prod(factor) <= 2 this is <= 1 and pick() below reduces an empty generator, raising TypeError 176 | 177 | data += 1 # shift zeros up one; undone by data -= 1 before returning 178 | 179 | # Split the N-D array into prod(factor) strided sections, one per 180 | # offset in np.ndindex(factor); the last section is the fallback 181 | # winner when no group of equal values is found.
182 | for offset in np.ndindex(factor): 183 | part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] 184 | sections.append(part) 185 | 186 | def pick(elements): # elements[0] where all elements are equal, else 0 187 | eq = ( elements[i] == elements[i+1] for i in range(len(elements) - 1) ) 188 | anded = reduce(lambda p,q: p & q, eq) 189 | return elements[0] * anded 190 | 191 | def logical_or(x,y): # elementwise: x where x != 0, else y 192 | return x + (x == 0) * y 193 | 194 | result = ( pick(combo) for combo in combinations(sections, majority) ) 195 | result = reduce(logical_or, result) 196 | for i in range(majority - 1, 3-1, -1): # 3-1 b/c of exclusive bounds 197 | partial_result = ( pick(combo) for combo in combinations(sections, i) ) 198 | partial_result = reduce(logical_or, partial_result) 199 | result = logical_or(result, partial_result) 200 | 201 | partial_result = ( pick(combo) for combo in combinations(sections[:-1], 2) ) 202 | partial_result = reduce(logical_or, partial_result) 203 | result = logical_or(result, partial_result) 204 | 205 | result = logical_or(result, sections[-1]) - 1 206 | data -= 1 207 | return result 208 | 209 | def dynamic_countless_generalized(data, factor): # countless_generalized with cached partial matches (dynamic programming) 210 | assert len(data.shape) == len(factor) 211 | 212 | sections = [] 213 | 214 | mode_of = reduce(lambda x,y: x * y, factor) 215 | majority = int(math.ceil(float(mode_of) / 2)) 216 | 217 | data += 1 # offset from zero 218 | 219 | # Split the N-D array into prod(factor) strided sections, one per 220 | # offset in np.ndindex(factor); the last section is the fallback 221 | # winner when no group of equal values is found.
222 | for offset in np.ndindex(factor): 223 | part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] 224 | sections.append(part) 225 | 226 | pick = lambda a,b: a * (a == b) 227 | lor = lambda x,y: x + (x == 0) * y # logical or 228 | 229 | subproblems = [ {}, {} ] # [matches of size r-1 being read, matches of size r being built] 230 | results2 = None 231 | for x,y in combinations(range(len(sections) - 1), 2): 232 | res = pick(sections[x], sections[y]) 233 | subproblems[0][(x,y)] = res 234 | if results2 is not None: 235 | results2 = lor(results2, res) # not in place, so the cached subproblem stays intact 236 | else: 237 | results2 = res 238 | 239 | results = [ results2 ] 240 | for r in range(3, majority+1): 241 | r_results = None 242 | for combo in combinations(range(len(sections)), r): 243 | res = pick(subproblems[0][combo[:-1]], sections[combo[-1]]) 244 | 245 | if combo[-1] != len(sections) - 1: 246 | subproblems[1][combo] = res 247 | 248 | if r_results is not None: 249 | r_results = lor(r_results, res) 250 | else: 251 | r_results = res 252 | results.append(r_results) 253 | subproblems[0] = subproblems[1] 254 | subproblems[1] = {} 255 | 256 | results.reverse() # largest match size first, so it takes priority in lor 257 | final_result = lor(reduce(lor, results), sections[-1]) - 1 258 | data -= 1 259 | return final_result 260 | 261 | def downsample_with_averaging(array): 262 | """ 263 | Downsample `array` by a fixed factor of (2,2,2) using averaging. 264 | 265 | @return: The downsampled array, cast back to the dtype of `array`.
266 | """ 267 | factor = (2,2,2) 268 | 269 | if np.array_equal(factor[:3], np.array([1,1,1])): # never true while factor is hard-coded to (2,2,2) 270 | return array 271 | 272 | output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor)) # ceil so partial blocks at odd-sized edges still get an output cell 273 | temp = np.zeros(output_shape, float) 274 | counts = np.zeros(output_shape, np.int) # NOTE(review): the np.int alias was removed in NumPy 1.24; plain int is the replacement 275 | for offset in np.ndindex(factor): 276 | part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] 277 | indexing_expr = tuple(np.s_[:s] for s in part.shape) 278 | temp[indexing_expr] += part 279 | counts[indexing_expr] += 1 # number of contributing voxels per output cell 280 | return np.cast[array.dtype](temp / counts) # NOTE(review): np.cast was removed in NumPy 2.0; (temp / counts).astype(array.dtype) is the equivalent 281 | 282 | def downsample_with_max_pooling(array): # elementwise max over the eight (2,2,2)-strided sections of array 283 | 284 | factor = (2,2,2) 285 | 286 | sections = [] 287 | 288 | for offset in np.ndindex(factor): 289 | part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] 290 | sections.append(part) 291 | 292 | output = sections[0].copy() 293 | 294 | for section in sections[1:]: 295 | np.maximum(output, section, output) # in-place running max; NOTE(review): assumes every section has output's shape (even-sized axes), otherwise broadcasting fails 296 | 297 | return output 298 | 299 | def striding(array): 300 | """Downsample `array` by a fixed factor of (2,2,2) using striding. 301 | 302 | @return: A strided view of `array` (same dtype, shares memory).
303 | """ 304 | factor = (2,2,2) 305 | if np.all(np.array(factor, int) == 1): 306 | return array 307 | return array[tuple(np.s_[::f] for f in factor)] 308 | 309 | def benchmark(): 310 | def countless3d_generalized(img): 311 | return countless_generalized(img, (2,8,1)) # NOTE(review): despite the name, the factor is (2,8,1), not (2,2,2); confirm intended 312 | def countless3d_dynamic_generalized(img): 313 | return dynamic_countless_generalized(img, (8,8,1)) # NOTE(review): factor is (8,8,1), not (2,2,2); confirm intended 314 | 315 | methods = [ 316 | # countless3d, 317 | # dynamic_countless3d, 318 | countless3d_generalized, 319 | # countless3d_dynamic_generalized, 320 | # striding, 321 | # downsample_with_averaging, 322 | # downsample_with_max_pooling 323 | ] 324 | 325 | data = np.zeros(shape=(16**2, 16**2, 16**2), dtype=np.uint8) + 1 # 256**3 uint8 voxels (16 MiB) 326 | 327 | N = 5 328 | 329 | print('Algorithm\tMPx\tMB/sec\tSec\tN=%d' % N) 330 | 331 | for fn in methods: 332 | start = time.time() 333 | for _ in range(N): 334 | result = fn(data) 335 | end = time.time() 336 | 337 | total_time = (end - start) 338 | mpx = N * float(data.shape[0] * data.shape[1] * data.shape[2]) / total_time / 1024.0 / 1024.0 # millions (2**20) of voxels per second 339 | mbytes = mpx * np.dtype(data.dtype).itemsize 340 | # Output in tab separated format to enable copy-paste into excel/numbers 341 | print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time)) 342 | 343 | if __name__ == '__main__': 344 | benchmark() 345 | 346 | # Algorithm MPx MB/sec Sec N=5 347 | # countless3d 10.564 10.564 60.58 348 | # dynamic_countless3d 22.717 22.717 28.17 349 | # countless3d_generalized 9.702 9.702 65.96 350 | # countless3d_dynamic_generalized 22.720 22.720 28.17 351 | # striding 253360.506 253360.506 0.00 352 | # downsample_with_averaging 224.098 224.098 2.86 353 | # downsample_with_max_pooling 690.474 690.474 0.93 354 | 355 | 356 | 357 | -------------------------------------------------------------------------------- /training/countless/images/gcim.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/images/gcim.jpg -------------------------------------------------------------------------------- /training/countless/images/gray_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/images/gray_segmentation.png -------------------------------------------------------------------------------- /training/countless/images/segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/images/segmentation.png -------------------------------------------------------------------------------- /training/countless/images/sparse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/images/sparse.png -------------------------------------------------------------------------------- /training/countless/memprof/countless2d_gcim_N_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/memprof/countless2d_gcim_N_1000.png -------------------------------------------------------------------------------- /training/countless/memprof/countless2d_quick_gcim_N_1000.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/memprof/countless2d_quick_gcim_N_1000.png -------------------------------------------------------------------------------- /training/countless/memprof/countless3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/memprof/countless3d.png -------------------------------------------------------------------------------- /training/countless/memprof/countless3d_dynamic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/memprof/countless3d_dynamic.png -------------------------------------------------------------------------------- /training/countless/memprof/countless3d_dynamic_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/memprof/countless3d_dynamic_generalized.png -------------------------------------------------------------------------------- /training/countless/memprof/countless3d_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayush12gupta/CoModGAN-PyTorch-Implementation/0fc99e71286979dd305eab9fd66bf4bb3c3e4430/training/countless/memprof/countless3d_generalized.png -------------------------------------------------------------------------------- /training/countless/requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow>=6.2.0 2 | numpy>=1.16 3 | scipy 4 | tqdm 5 | 
memory_profiler 6 | six 7 | pytest -------------------------------------------------------------------------------- /training/countless/test.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | 5 | import countless2d 6 | import countless3d 7 | 8 | def test_countless2d(): 9 | def test_all_cases(fn, test_zero): 10 | case1 = np.array([ [ 1, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1)) # all different 11 | case2 = np.array([ [ 1, 1 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # two are same 12 | case1z = np.array([ [ 0, 1 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # all different 13 | case2z = np.array([ [ 0, 0 ], [ 2, 3 ] ]).reshape((2,2,1,1)) # two are same 14 | case3 = np.array([ [ 1, 1 ], [ 2, 2 ] ]).reshape((2,2,1,1)) # two groups are same 15 | case4 = np.array([ [ 1, 2 ], [ 2, 2 ] ]).reshape((2,2,1,1)) # 3 are the same 16 | case5 = np.array([ [ 5, 5 ], [ 5, 5 ] ]).reshape((2,2,1,1)) # all are the same 17 | 18 | is_255_handled = np.array([ [ 255, 255 ], [ 1, 2 ] ], dtype=np.uint8).reshape((2,2,1,1)) 19 | 20 | test = lambda case: fn(case) 21 | 22 | if test_zero: 23 | assert test(case1z) == [[[[3]]]] # d 24 | assert test(case2z) == [[[[0]]]] # a==b 25 | else: 26 | assert test(case1) == [[[[4]]]] # d 27 | assert test(case2) == [[[[1]]]] # a==b 28 | 29 | assert test(case3) == [[[[1]]]] # a==b 30 | assert test(case4) == [[[[2]]]] # b==c 31 | assert test(case5) == [[[[5]]]] # a==b 32 | 33 | assert test(is_255_handled) == [[[[255]]]] 34 | 35 | assert fn(case1).dtype == case1.dtype 36 | 37 | test_all_cases(countless2d.simplest_countless, False) 38 | test_all_cases(countless2d.quick_countless, False) 39 | test_all_cases(countless2d.quickest_countless, False) 40 | test_all_cases(countless2d.stippled_countless, False) 41 | 42 | 43 | 44 | methods = [ 45 | countless2d.zero_corrected_countless, 46 | countless2d.countless, 47 | countless2d.countless_if, 48 | # countless2d.counting, # counting doesn't respect order 
so harder to write a test 49 | ] 50 | 51 | for fn in methods: 52 | print(fn.__name__) 53 | test_all_cases(fn, True) 54 | 55 | def test_stippled_countless2d(): 56 | a = np.array([ [ 1, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1)) 57 | b = np.array([ [ 0, 2 ], [ 3, 4 ] ]).reshape((2,2,1,1)) 58 | c = np.array([ [ 1, 0 ], [ 3, 4 ] ]).reshape((2,2,1,1)) 59 | d = np.array([ [ 1, 2 ], [ 0, 4 ] ]).reshape((2,2,1,1)) 60 | e = np.array([ [ 1, 2 ], [ 3, 0 ] ]).reshape((2,2,1,1)) 61 | f = np.array([ [ 0, 0 ], [ 3, 4 ] ]).reshape((2,2,1,1)) 62 | g = np.array([ [ 0, 2 ], [ 0, 4 ] ]).reshape((2,2,1,1)) 63 | h = np.array([ [ 0, 2 ], [ 3, 0 ] ]).reshape((2,2,1,1)) 64 | i = np.array([ [ 1, 0 ], [ 0, 4 ] ]).reshape((2,2,1,1)) 65 | j = np.array([ [ 1, 2 ], [ 0, 0 ] ]).reshape((2,2,1,1)) 66 | k = np.array([ [ 1, 0 ], [ 3, 0 ] ]).reshape((2,2,1,1)) 67 | l = np.array([ [ 1, 0 ], [ 0, 0 ] ]).reshape((2,2,1,1)) 68 | m = np.array([ [ 0, 2 ], [ 0, 0 ] ]).reshape((2,2,1,1)) 69 | n = np.array([ [ 0, 0 ], [ 3, 0 ] ]).reshape((2,2,1,1)) 70 | o = np.array([ [ 0, 0 ], [ 0, 4 ] ]).reshape((2,2,1,1)) 71 | z = np.array([ [ 0, 0 ], [ 0, 0 ] ]).reshape((2,2,1,1)) 72 | 73 | test = countless2d.stippled_countless 74 | 75 | # Note: We only tested non-matching cases above, 76 | # cases f,g,h,i,j,k prove their duals work as well 77 | # b/c if two pixels are black, either one can be chosen 78 | # if they are different or the same. 
79 | 80 | assert test(a) == [[[[4]]]] 81 | assert test(b) == [[[[4]]]] 82 | assert test(c) == [[[[4]]]] 83 | assert test(d) == [[[[4]]]] 84 | assert test(e) == [[[[1]]]] 85 | assert test(f) == [[[[4]]]] 86 | assert test(g) == [[[[4]]]] 87 | assert test(h) == [[[[2]]]] 88 | assert test(i) == [[[[4]]]] 89 | assert test(j) == [[[[1]]]] 90 | assert test(k) == [[[[1]]]] 91 | assert test(l) == [[[[1]]]] 92 | assert test(m) == [[[[2]]]] 93 | assert test(n) == [[[[3]]]] 94 | assert test(o) == [[[[4]]]] 95 | assert test(z) == [[[[0]]]] 96 | 97 | bc = np.array([ [ 0, 2 ], [ 2, 4 ] ]).reshape((2,2,1,1)) 98 | bd = np.array([ [ 0, 2 ], [ 3, 2 ] ]).reshape((2,2,1,1)) 99 | cd = np.array([ [ 0, 2 ], [ 3, 3 ] ]).reshape((2,2,1,1)) 100 | 101 | assert test(bc) == [[[[2]]]] 102 | assert test(bd) == [[[[2]]]] 103 | assert test(cd) == [[[[3]]]] 104 | 105 | ab = np.array([ [ 1, 1 ], [ 0, 4 ] ]).reshape((2,2,1,1)) 106 | ac = np.array([ [ 1, 2 ], [ 1, 0 ] ]).reshape((2,2,1,1)) 107 | ad = np.array([ [ 1, 0 ], [ 3, 1 ] ]).reshape((2,2,1,1)) 108 | 109 | assert test(ab) == [[[[1]]]] 110 | assert test(ac) == [[[[1]]]] 111 | assert test(ad) == [[[[1]]]] 112 | 113 | def test_countless3d(): 114 | def test_all_cases(fn): 115 | alldifferent = [ 116 | [ 117 | [1,2], 118 | [3,4], 119 | ], 120 | [ 121 | [5,6], 122 | [7,8] 123 | ] 124 | ] 125 | allsame = [ 126 | [ 127 | [1,1], 128 | [1,1], 129 | ], 130 | [ 131 | [1,1], 132 | [1,1] 133 | ] 134 | ] 135 | 136 | assert fn(np.array(alldifferent)) == [[[8]]] 137 | assert fn(np.array(allsame)) == [[[1]]] 138 | 139 | twosame = deepcopy(alldifferent) 140 | twosame[1][1][0] = 2 141 | 142 | assert fn(np.array(twosame)) == [[[2]]] 143 | 144 | threemixed = [ 145 | [ 146 | [3,3], 147 | [1,2], 148 | ], 149 | [ 150 | [2,4], 151 | [4,3] 152 | ] 153 | ] 154 | assert fn(np.array(threemixed)) == [[[3]]] 155 | 156 | foursame = [ 157 | [ 158 | [4,4], 159 | [1,2], 160 | ], 161 | [ 162 | [2,4], 163 | [4,3] 164 | ] 165 | ] 166 | 167 | assert fn(np.array(foursame)) == [[[4]]] 
168 | 169 | fivesame = [ 170 | [ 171 | [5,4], 172 | [5,5], 173 | ], 174 | [ 175 | [2,4], 176 | [5,5] 177 | ] 178 | ] 179 | 180 | assert fn(np.array(fivesame)) == [[[5]]] 181 | 182 | def countless3d_generalized(img): 183 | return countless3d.countless_generalized(img, (2,2,2)) 184 | def countless3d_dynamic_generalized(img): 185 | return countless3d.dynamic_countless_generalized(img, (2,2,2)) 186 | 187 | methods = [ 188 | countless3d.countless3d, 189 | countless3d.dynamic_countless3d, 190 | countless3d_generalized, 191 | countless3d_dynamic_generalized, 192 | ] 193 | 194 | for fn in methods: 195 | test_all_cases(fn) -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import numpy as np 11 | import zipfile 12 | import PIL.Image 13 | import json 14 | import torch 15 | import dnnlib 16 | import random 17 | from .masks import get_mask_generator 18 | try: 19 | import pyspng 20 | except ImportError: 21 | pyspng = None 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | class Dataset(torch.utils.data.Dataset): 26 | def __init__(self, 27 | name, # Name of the dataset. 28 | raw_shape, # Shape of the raw image data (NCHW). 29 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 30 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 
31 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 32 | random_seed = 0, # Random seed to use when applying max_size. 33 | mask_generator_kind="mixed", 34 | mask_gen_kwargs = None, 35 | ): 36 | self._name = name 37 | self._raw_shape = list(raw_shape) 38 | self._use_labels = use_labels 39 | self._raw_labels = None 40 | self._label_shape = None 41 | self.iter_i = 0 42 | 43 | self.mask_generator = get_mask_generator(kind=mask_generator_kind, kwargs=mask_gen_kwargs) 44 | # Apply max_size. 45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 46 | if (max_size is not None) and (self._raw_idx.size > max_size): 47 | np.random.RandomState(random_seed).shuffle(self._raw_idx) 48 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 49 | 50 | # Apply xflip. 51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 52 | if xflip: 53 | self._raw_idx = np.tile(self._raw_idx, 2) 54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 55 | 56 | def _get_raw_labels(self): 57 | if self._raw_labels is None: 58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 59 | if self._raw_labels is None: 60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 61 | assert isinstance(self._raw_labels, np.ndarray) 62 | assert self._raw_labels.shape[0] == self._raw_shape[0] 63 | assert self._raw_labels.dtype in [np.float32, np.int64] 64 | if self._raw_labels.dtype == np.int64: 65 | assert self._raw_labels.ndim == 1 66 | assert np.all(self._raw_labels >= 0) 67 | return self._raw_labels 68 | 69 | def close(self): # to be overridden by subclass 70 | pass 71 | 72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def _load_mask_image(self, raw_idx): # to be overridden by subclass 76 | raise NotImplementedError 77 | 78 | def _load_raw_labels(self): # to be overridden by subclass 79 | raise NotImplementedError 80 | 
81 | def __getstate__(self): 82 | return dict(self.__dict__, _raw_labels=None) 83 | 84 | def __del__(self): 85 | try: 86 | self.close() 87 | except: 88 | pass 89 | 90 | def __len__(self): 91 | return self._raw_idx.size 92 | 93 | def __getitem__(self, idx): 94 | mask_idx = random.randint(0, len(self._raw_idx)-1) 95 | # mask_image = self._load_mask_image(self._raw_idx[mask_idx]) 96 | raw_image = self._load_raw_image(self._raw_idx[idx]) 97 | mask_image = self.mask_generator(raw_image, iter_i=self.iter_i) 98 | 99 | assert isinstance(raw_image, np.ndarray) 100 | assert list(raw_image.shape) == self.image_shape 101 | # assert list(mask_image.shape) == self.image_shape 102 | assert raw_image.dtype == np.uint8 103 | assert mask_image.dtype == np.uint8 104 | if self._xflip[idx]: 105 | assert raw_image.ndim == 3 # CHW 106 | raw_image = raw_image[:, :, ::-1] 107 | mask_image = mask_image[:, :, ::-1] 108 | self.iter_i += 1 109 | return raw_image.copy(), mask_image.copy(), self.get_label(idx) 110 | 111 | def get_label(self, idx): 112 | label = self._get_raw_labels()[self._raw_idx[idx]] 113 | if label.dtype == np.int64: 114 | onehot = np.zeros(self.label_shape, dtype=np.float32) 115 | onehot[label] = 1 116 | label = onehot 117 | return label.copy() 118 | 119 | def get_details(self, idx): 120 | d = dnnlib.EasyDict() 121 | d.raw_idx = int(self._raw_idx[idx]) 122 | d.xflip = (int(self._xflip[idx]) != 0) 123 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 124 | return d 125 | 126 | @property 127 | def name(self): 128 | return self._name 129 | 130 | @property 131 | def image_shape(self): 132 | return list(self._raw_shape[1:]) 133 | 134 | @property 135 | def num_channels(self): 136 | assert len(self.image_shape) == 3 # CHW 137 | return self.image_shape[0] 138 | 139 | @property 140 | def resolution(self): 141 | assert len(self.image_shape) == 3 # CHW 142 | assert self.image_shape[1] == self.image_shape[2] 143 | return self.image_shape[1] 144 | 145 | @property 146 | def 
label_shape(self): 147 | if self._label_shape is None: 148 | raw_labels = self._get_raw_labels() 149 | if raw_labels.dtype == np.int64: 150 | self._label_shape = [int(np.max(raw_labels)) + 1] 151 | else: 152 | self._label_shape = raw_labels.shape[1:] 153 | return list(self._label_shape) 154 | 155 | @property 156 | def label_dim(self): 157 | assert len(self.label_shape) == 1 158 | return self.label_shape[0] 159 | 160 | @property 161 | def has_labels(self): 162 | return any(x != 0 for x in self.label_shape) 163 | 164 | @property 165 | def has_onehot_labels(self): 166 | return self._get_raw_labels().dtype == np.int64 167 | 168 | #---------------------------------------------------------------------------- 169 | 170 | class ImageFolderDataset(Dataset): 171 | def __init__(self, 172 | path, # Path to directory or zip. 173 | mask_path, 174 | resolution = None, # Ensure specific resolution, None = highest available. 175 | **super_kwargs, # Additional arguments for the Dataset base class. 176 | ): 177 | self._path = path 178 | self._mask_path = mask_path 179 | self._zipfile = None 180 | 181 | if os.path.isdir(self._path): 182 | self._type = 'dir' 183 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 184 | elif self._file_ext(self._path) == '.zip': 185 | self._type = 'zip' 186 | self._all_fnames = set(self._get_zipfile().namelist()) 187 | else: 188 | raise IOError('Path must point to a directory or zip') 189 | print("tt ", self._all_fnames) 190 | 191 | PIL.Image.init() 192 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 193 | if len(self._image_fnames) == 0: 194 | raise IOError('No image files found in the specified path') 195 | 196 | name = os.path.splitext(os.path.basename(self._path))[0] 197 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 198 | if resolution is not None and 
(raw_shape[2] != resolution or raw_shape[3] != resolution): 199 | raise IOError('Image files do not match the specified resolution') 200 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 201 | 202 | @staticmethod 203 | def _file_ext(fname): 204 | return os.path.splitext(fname)[1].lower() 205 | 206 | def _get_zipfile(self, types='img'): 207 | assert self._type == 'zip' 208 | if self._zipfile is None: 209 | if types=='mask': 210 | self._zipfile = zipfile.ZipFile(self._mask_path) 211 | else: 212 | self._zipfile = zipfile.ZipFile(self._path) 213 | return self._zipfile 214 | 215 | def _open_file(self, fname, types='img'): 216 | if self._type == 'dir': 217 | if types=='mask': 218 | return open(os.path.join(self._mask_path, fname), 'rb') 219 | return open(os.path.join(self._path, fname), 'rb') 220 | if self._type == 'zip': 221 | if types=='mask': 222 | return zipfile.ZipFile(self._mask_path).open(fname, 'r') 223 | return zipfile.ZipFile(self._path).open(fname, 'r') 224 | # return self._get_zipfile(types=types).open(fname, 'r') 225 | return None 226 | 227 | def close(self): 228 | try: 229 | if self._zipfile is not None: 230 | self._zipfile.close() 231 | finally: 232 | self._zipfile = None 233 | 234 | def __getstate__(self): 235 | return dict(super().__getstate__(), _zipfile=None) 236 | 237 | def _load_raw_image(self, raw_idx): 238 | fname = self._image_fnames[raw_idx] 239 | with self._open_file(fname, 'img') as f: 240 | if pyspng is not None and self._file_ext(fname) == '.png': 241 | image = pyspng.load(f.read()) 242 | else: 243 | image = np.array(PIL.Image.open(f)) 244 | if image.ndim == 2: 245 | image = image[:, :, np.newaxis] # HW => HWC 246 | image = image.transpose(2, 0, 1) # HWC => CHW 247 | return image 248 | 249 | def _load_mask_image(self, raw_idx): 250 | fname = self._image_fnames[raw_idx] 251 | with self._open_file(fname, 'mask') as f: 252 | if pyspng is not None and self._file_ext(fname) == '.png': 253 | image = pyspng.load(f.read()) 254 | 
else: 255 | image = np.array(PIL.Image.open(f)) 256 | if image.ndim == 2: 257 | image = image[:, :, np.newaxis] # HW => HWC 258 | image = image.transpose(2, 0, 1) # HWC => CHW 259 | return image 260 | 261 | def _load_raw_labels(self): 262 | fname = 'dataset.json' 263 | if fname not in self._all_fnames: 264 | return None 265 | with self._open_file(fname) as f: 266 | labels = json.load(f)['labels'] 267 | if labels is None: 268 | return None 269 | labels = dict(labels) 270 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 271 | labels = np.array(labels) 272 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 273 | return labels 274 | 275 | #---------------------------------------------------------------------------- 276 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import numpy as np 10 | import torch 11 | from torch_utils import training_stats 12 | from torch_utils import misc 13 | from torch_utils.ops import conv2d_gradfix 14 | 15 | #---------------------------------------------------------------------------- 16 | 17 | class Loss: 18 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass 19 | raise NotImplementedError() 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | class StyleGAN2Loss(Loss): 24 | def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2): 25 | super().__init__() 26 | self.device = device 27 | self.G_mapping = G_mapping 28 | self.G_synthesis = G_synthesis 29 | self.D = D 30 | self.augment_pipe = augment_pipe 31 | self.style_mixing_prob = style_mixing_prob 32 | self.r1_gamma = r1_gamma 33 | self.pl_batch_shrink = pl_batch_shrink 34 | self.pl_decay = pl_decay 35 | self.pl_weight = pl_weight 36 | self.pl_mean = torch.zeros([], device=device) 37 | 38 | def run_G(self, z, c, img, mask, sync): 39 | with misc.ddp_sync(self.G_mapping, sync): 40 | ws = self.G_mapping(z, c) 41 | if self.style_mixing_prob > 0: 42 | with torch.autograd.profiler.record_function('style_mixing'): 43 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) 44 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) 45 | ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:] 46 | with misc.ddp_sync(self.G_synthesis, sync): 47 | img = self.G_synthesis(img, mask, ws) 48 | return img, ws 49 | 50 | def run_D(self, img, c, sync): 51 | if self.augment_pipe is not None: 52 | img = self.augment_pipe(img) 53 | with misc.ddp_sync(self.D, sync): 54 | logits = self.D(img, c) 55 | return logits 56 | 57 | 
def accumulate_gradients(self, phase, real_img, real_c, mask, gen_z, gen_c, sync, gain): 58 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] 59 | do_Gmain = (phase in ['Gmain', 'Gboth']) 60 | do_Dmain = (phase in ['Dmain', 'Dboth']) 61 | do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0) 62 | do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0) 63 | l1_weight = 100 64 | 65 | # Gmain: Maximize logits for generated images. 66 | if do_Gmain: 67 | with torch.autograd.profiler.record_function('Gmain_forward'): 68 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, real_img, mask, sync=(sync and not do_Gpl)) # May get synced by Gpl. 69 | gen_logits = self.run_D(gen_img, gen_c, sync=False) 70 | training_stats.report('Loss/scores/fake', gen_logits) 71 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 72 | loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) 73 | loss_l1 = abs(torch.nn.functional.l1_loss(gen_img, real_img))*l1_weight 74 | training_stats.report('Loss/G/loss', loss_Gmain) 75 | training_stats.report('Loss/G/L1loss', loss_l1) 76 | with torch.autograd.profiler.record_function('Gmain_backward'): 77 | (loss_Gmain + loss_l1).mean().mul(gain).backward() 78 | 79 | # Gpl: Apply path length regularization. 
80 | if do_Gpl: 81 | with torch.autograd.profiler.record_function('Gpl_forward'): 82 | batch_size = gen_z.shape[0] // self.pl_batch_shrink 83 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], real_img[:batch_size], mask[:batch_size], sync=sync) 84 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) 85 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(): 86 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] 87 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() 88 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) 89 | self.pl_mean.copy_(pl_mean.detach()) 90 | pl_penalty = (pl_lengths - pl_mean).square() 91 | training_stats.report('Loss/pl_penalty', pl_penalty) 92 | loss_Gpl = pl_penalty * self.pl_weight 93 | # print(gen_img.size(), real_img.size()) 94 | loss_l1 = abs(torch.nn.functional.l1_loss(gen_img, real_img[:batch_size]))*l1_weight 95 | training_stats.report('Loss/G/reg', loss_Gpl) 96 | with torch.autograd.profiler.record_function('Gpl_backward'): 97 | (gen_img[:, 0, 0, 0] * 0 + loss_Gpl + loss_l1).mean().mul(gain).backward() 98 | 99 | # Dmain: Minimize logits for generated images. 100 | loss_Dgen = 0 101 | if do_Dmain: 102 | with torch.autograd.profiler.record_function('Dgen_forward'): 103 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, real_img, mask, sync=False) 104 | gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal. 105 | training_stats.report('Loss/scores/fake', gen_logits) 106 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 107 | loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) 108 | with torch.autograd.profiler.record_function('Dgen_backward'): 109 | loss_Dgen.mean().mul(gain).backward() 110 | 111 | # Dmain: Maximize logits for real images. 112 | # Dr1: Apply R1 regularization. 
113 | if do_Dmain or do_Dr1: 114 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1' 115 | with torch.autograd.profiler.record_function(name + '_forward'): 116 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1) 117 | real_logits = self.run_D(real_img_tmp, real_c, sync=sync) 118 | training_stats.report('Loss/scores/real', real_logits) 119 | training_stats.report('Loss/signs/real', real_logits.sign()) 120 | 121 | loss_Dreal = 0 122 | if do_Dmain: 123 | loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) 124 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) 125 | 126 | loss_Dr1 = 0 127 | if do_Dr1: 128 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): 129 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] 130 | r1_penalty = r1_grads.square().sum([1,2,3]) 131 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2) 132 | training_stats.report('Loss/r1_penalty', r1_penalty) 133 | training_stats.report('Loss/D/reg', loss_Dr1) 134 | 135 | with torch.autograd.profiler.record_function(name + '_backward'): 136 | (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward() 137 | 138 | #---------------------------------------------------------------------------- 139 | -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch_utils import misc 5 | from torch_utils import persistence 6 | from torch_utils.ops import conv2d_resample 7 | from torch_utils.ops import upfirdn2d 8 | from torch_utils.ops import bias_act 9 | from torch_utils.ops import fma 10 | 11 | def make_kernel(k): 12 | k = torch.tensor(k, dtype=torch.float32) 13 | 14 | if k.ndim == 1: 15 | k = k[None, :] * k[:, None] 
16 | 17 | k /= k.sum() 18 | 19 | return k 20 | 21 | class Blur(nn.Module): 22 | def __init__(self, kernel, pad, upsample_factor=1): 23 | super().__init__() 24 | 25 | kernel = make_kernel(kernel) 26 | 27 | if upsample_factor > 1: 28 | kernel = kernel * (upsample_factor ** 2) 29 | 30 | self.register_buffer("kernel", kernel) 31 | 32 | self.pad = pad 33 | 34 | def forward(self, input): 35 | out = upfirdn2d(input, self.kernel, pad=self.pad) 36 | 37 | return out --------------------------------------------------------------------------------