├── LICENSE.txt
├── NOTICE.txt
├── README.md
├── calc_metrics.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── figures
│   ├── d_vis.png
│   ├── feat_vis.png
│   ├── model.png
│   └── sample.png
├── generate.py
├── legacy.py
├── metrics
│   ├── __init__.py
│   ├── frechet_inception_distance.py
│   ├── inception_score.py
│   ├── kernel_inception_distance.py
│   ├── metric_main.py
│   ├── metric_utils.py
│   ├── perceptual_path_length.py
│   └── precision_recall.py
├── tools
│   └── visualize_gfeat.py
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
├── train.py
└── training
    ├── __init__.py
    ├── augment.py
    ├── dataset.py
    ├── loss.py
    ├── loss_ggdr.py
    ├── networks.py
    ├── networks_ggdr.py
    └── training_loop.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
2 |
3 |
4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
5 |
6 |
7 | =======================================================================
8 |
9 | 1. Definitions
10 |
11 | "Licensor" means any person or entity that distributes its Work.
12 |
13 | "Software" means the original work of authorship made available under
14 | this License.
15 |
16 | "Work" means the Software and any additions to or derivative works of
17 | the Software that are made available under this License.
18 |
19 | The terms "reproduce," "reproduction," "derivative works," and
20 | "distribution" have the meaning as provided under U.S. copyright law;
21 | provided, however, that for the purposes of this License, derivative
22 | works shall not include works that remain separable from, or merely
23 | link (or bind by name) to the interfaces of, the Work.
24 |
25 | Works, including the Software, are "made available" under this License
26 | by including in or with the Work either (a) a copyright notice
27 | referencing the applicability of this License to the Work, or (b) a
28 | copy of this License.
29 |
30 | 2. License Grants
31 |
32 | 2.1 Copyright Grant. Subject to the terms and conditions of this
33 | License, each Licensor grants to you a perpetual, worldwide,
34 | non-exclusive, royalty-free, copyright license to reproduce,
35 | prepare derivative works of, publicly display, publicly perform,
36 | sublicense and distribute its Work and any resulting derivative
37 | works in any form.
38 |
39 | 3. Limitations
40 |
41 | 3.1 Redistribution. You may reproduce or distribute the Work only
42 | if (a) you do so under this License, (b) you include a complete
43 | copy of this License with your distribution, and (c) you retain
44 | without modification any copyright, patent, trademark, or
45 | attribution notices that are present in the Work.
46 |
47 | 3.2 Derivative Works. You may specify that additional or different
48 | terms apply to the use, reproduction, and distribution of your
49 | derivative works of the Work ("Your Terms") only if (a) Your Terms
50 | provide that the use limitation in Section 3.3 applies to your
51 | derivative works, and (b) you identify the specific derivative
52 | works that are subject to Your Terms. Notwithstanding Your Terms,
53 | this License (including the redistribution requirements in Section
54 | 3.1) will continue to apply to the Work itself.
55 |
56 | 3.3 Use Limitation. The Work and any derivative works thereof only
57 | may be used or intended for use non-commercially. Notwithstanding
58 | the foregoing, NVIDIA and its affiliates may use the Work and any
59 | derivative works commercially. As used herein, "non-commercially"
60 | means for research or evaluation purposes only.
61 |
62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63 | against any Licensor (including any claim, cross-claim or
64 | counterclaim in a lawsuit) to enforce any patents that you allege
65 | are infringed by any Work, then your rights under this License from
66 | such Licensor (including the grant in Section 2.1) will terminate
67 | immediately.
68 |
69 | 3.5 Trademarks. This License does not grant any rights to use any
70 | Licensor’s or its affiliates’ names, logos, or trademarks, except
71 | as necessary to reproduce the notices described in this License.
72 |
73 | 3.6 Termination. If you violate any term of this License, then your
74 | rights under this License (including the grant in Section 2.1) will
75 | terminate immediately.
76 |
77 | 4. Disclaimer of Warranty.
78 |
79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83 | THIS LICENSE.
84 |
85 | 5. Limitation of Liability.
86 |
87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95 | THE POSSIBILITY OF SUCH DAMAGES.
96 |
97 | =======================================================================
98 |
--------------------------------------------------------------------------------
/NOTICE.txt:
--------------------------------------------------------------------------------
1 | GGDR is based on the StyleGAN2-ADA project (https://github.com/NVlabs/stylegan2-ada-pytorch),
2 | and much of its code is borrowed from that project.
3 |
4 | training/loss_ggdr.py, training/networks_ggdr.py
5 | - Copyright (c) 2022-present NAVER Corp.
6 | all other files
7 | - Copyright (c) 2021, NVIDIA Corporation.
8 |
9 | Licensed under the NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
10 |
11 | ---
12 |
13 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
14 |
15 |
16 | =======================================================================
17 |
18 | 1. Definitions
19 |
20 | "Licensor" means any person or entity that distributes its Work.
21 |
22 | "Software" means the original work of authorship made available under
23 | this License.
24 |
25 | "Work" means the Software and any additions to or derivative works of
26 | the Software that are made available under this License.
27 |
28 | The terms "reproduce," "reproduction," "derivative works," and
29 | "distribution" have the meaning as provided under U.S. copyright law;
30 | provided, however, that for the purposes of this License, derivative
31 | works shall not include works that remain separable from, or merely
32 | link (or bind by name) to the interfaces of, the Work.
33 |
34 | Works, including the Software, are "made available" under this License
35 | by including in or with the Work either (a) a copyright notice
36 | referencing the applicability of this License to the Work, or (b) a
37 | copy of this License.
38 |
39 | 2. License Grants
40 |
41 | 2.1 Copyright Grant. Subject to the terms and conditions of this
42 | License, each Licensor grants to you a perpetual, worldwide,
43 | non-exclusive, royalty-free, copyright license to reproduce,
44 | prepare derivative works of, publicly display, publicly perform,
45 | sublicense and distribute its Work and any resulting derivative
46 | works in any form.
47 |
48 | 3. Limitations
49 |
50 | 3.1 Redistribution. You may reproduce or distribute the Work only
51 | if (a) you do so under this License, (b) you include a complete
52 | copy of this License with your distribution, and (c) you retain
53 | without modification any copyright, patent, trademark, or
54 | attribution notices that are present in the Work.
55 |
56 | 3.2 Derivative Works. You may specify that additional or different
57 | terms apply to the use, reproduction, and distribution of your
58 | derivative works of the Work ("Your Terms") only if (a) Your Terms
59 | provide that the use limitation in Section 3.3 applies to your
60 | derivative works, and (b) you identify the specific derivative
61 | works that are subject to Your Terms. Notwithstanding Your Terms,
62 | this License (including the redistribution requirements in Section
63 | 3.1) will continue to apply to the Work itself.
64 |
65 | 3.3 Use Limitation. The Work and any derivative works thereof only
66 | may be used or intended for use non-commercially. Notwithstanding
67 | the foregoing, NVIDIA and its affiliates may use the Work and any
68 | derivative works commercially. As used herein, "non-commercially"
69 | means for research or evaluation purposes only.
70 |
71 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
72 | against any Licensor (including any claim, cross-claim or
73 | counterclaim in a lawsuit) to enforce any patents that you allege
74 | are infringed by any Work, then your rights under this License from
75 | such Licensor (including the grant in Section 2.1) will terminate
76 | immediately.
77 |
78 | 3.5 Trademarks. This License does not grant any rights to use any
79 | Licensor’s or its affiliates’ names, logos, or trademarks, except
80 | as necessary to reproduce the notices described in this License.
81 |
82 | 3.6 Termination. If you violate any term of this License, then your
83 | rights under this License (including the grant in Section 2.1) will
84 | terminate immediately.
85 |
86 | 4. Disclaimer of Warranty.
87 |
88 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
89 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
90 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
91 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
92 | THIS LICENSE.
93 |
94 | 5. Limitation of Liability.
95 |
96 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
97 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
98 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
99 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
100 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
101 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
102 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
103 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
104 | THE POSSIBILITY OF SUCH DAMAGES.
105 |
106 | =======================================================================
107 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GGDR - Generator-Guided Regularization for Discriminator (Official PyTorch Implementation)
2 | **[Generator Knows What Discriminator Should Learn in Unconditional GANs (ECCV 2022)](http://arxiv.org/abs/2207.13320)** \
3 | Gayoung Lee¹, Hyunsu Kim¹, Junho Kim¹, Seonghyeon Kim², Jung-Woo Ha¹, Yunjey Choi¹
4 |
5 | ¹NAVER AI Lab, ²NAVER CLOVA
6 |
7 |
8 |

9 |
10 |
11 | > **Abstract** *Recent conditional image generation methods benefit from dense supervision such as segmentation label maps to achieve high fidelity. However, employing dense supervision for unconditional image generation has rarely been explored. Here we explore the efficacy of dense supervision in unconditional generation and find that generator feature maps can be an alternative to costly semantic label maps. Based on this empirical evidence, we propose a new **generator-guided discriminator regularization (GGDR)** in which the generator feature maps supervise the discriminator to have rich semantic representations in unconditional generation. Specifically, we employ an encoder-decoder architecture for the discriminator, which is trained to reconstruct the generator feature maps given fake images as inputs. Extensive experiments on multiple datasets show that our GGDR consistently improves the performance of baseline methods both quantitatively and qualitatively. Code will be publicly available for the research community.*
12 |
13 | ## Credit
14 | We attach GGDR to [StyleGAN2-ADA-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch), and much of the code is borrowed from that repository.
15 |
16 | ## Usage
17 | Usage of this repository is almost the same as [StyleGAN2-ADA-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch), except for the GGDR options. See their repository for more detailed instructions.
18 |
19 | #### Training StyleGAN2-ADA with GGDR
20 | ```
21 | > python train.py --outdir=training-runs --reg_type=ggdr --ggdr_res=64 --gpus=8 --cfg=paper256 --data=./datasets/ffhq256.zip
22 | ```
23 | Below are some additional arguments that can be customized.
24 | - ```--reg_type=ggdr``` Enables GGDR (default: disabled). A sketch of the objective appears after this list.
25 | - ```--ggdr_res=64``` Sets the resolution of the generator feature map used as the GGDR target (default: 64). For smaller images (e.g., CIFAR-10), it is recommended to set this to (resolution) / 4 (e.g., 8 for CIFAR-10).
26 | - ```--aug=noaug``` Disables ADA (default: enabled)
27 | - ```--mirror=1``` Enables x-flips (default: disabled)
28 |
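   | For reference, below is a minimal, illustrative sketch of the GGDR objective: the discriminator's decoder output is pulled toward the (detached) generator feature map at the target resolution with a cosine-distance loss. Names here are hypothetical; see ```training/loss_ggdr.py``` for the actual implementation.
   |
   | ```python
   | import torch
   | import torch.nn.functional as F
   |
   | def ggdr_loss(dec_feat: torch.Tensor, gen_feat: torch.Tensor) -> torch.Tensor:
   |     """Sketch of a cosine-distance reconstruction loss, not the exact implementation.
   |
   |     dec_feat: discriminator decoder feature map, shape [N, C, H, W].
   |     gen_feat: generator feature map at --ggdr_res, same shape (the target).
   |     """
   |     target = gen_feat.detach()                          # do not backprop into G through the target
   |     cos = F.cosine_similarity(dec_feat, target, dim=1)  # per-pixel similarity, shape [N, H, W]
   |     return (1.0 - cos).mean()
   | ```
   |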
29 | #### Inference with trained model
30 | ```
31 | > python generate.py --outdir=out --seeds=100-200 --network=PATH_TO_MODEL
32 | ```
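   |
   | A trained pickle can also be used programmatically. Below is a minimal sketch following ```generate.py``` (```PATH_TO_MODEL``` is a placeholder):
   |
   | ```python
   | import numpy as np
   | import PIL.Image
   | import torch
   |
   | import dnnlib
   | import legacy
   |
   | device = torch.device('cuda')
   | with dnnlib.util.open_url('PATH_TO_MODEL') as f:
   |     G = legacy.load_network_pkl(f)['G_ema'].to(device)  # subclass of torch.nn.Module
   |
   | z = torch.from_numpy(np.random.RandomState(100).randn(1, G.z_dim)).to(device)
   | label = torch.zeros([1, G.c_dim], device=device)        # all-zero label for unconditional models
   | img = G(z, label, truncation_psi=0.7, noise_mode='const')
   | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
   | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('sample.png')
   | ```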
33 |
34 | ## Results
35 | ### Selected samples from the paper
36 |
37 |
38 | ![sample](figures/sample.png)
39 |
40 | ### Discriminator feature map visualization
41 |
42 | ![d_vis](figures/d_vis.png)
43 |
44 |
45 |
46 |
47 | ## License
48 | Licensed under NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA).
49 |
50 | ## Citation
51 | ```bibtex
52 | @inproceedings{lee2022ggdr,
53 | title={Generator Knows What Discriminator Should Learn in Unconditional GANs},
54 | author={Lee, Gayoung and Kim, Hyunsu and Kim, Junho and Kim, Seonghyeon and Ha, Jung-Woo and Choi, Yunjey},
55 | booktitle={ECCV},
56 | year={2022}
57 | }
58 | ```
59 |
--------------------------------------------------------------------------------
/calc_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Calculate quality metrics for previous training run or pretrained network pickle."""
10 |
11 | import os
12 | import click
13 | import json
14 | import tempfile
15 | import copy
16 | import torch
17 | import dnnlib
18 |
19 | import legacy
20 | from metrics import metric_main
21 | from metrics import metric_utils
22 | from torch_utils import training_stats
23 | from torch_utils import custom_ops
24 | from torch_utils import misc
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def subprocess_fn(rank, args, temp_dir):
29 | dnnlib.util.Logger(should_flush=True)
30 |
31 | # Init torch.distributed.
32 | if args.num_gpus > 1:
33 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
34 | if os.name == 'nt':
35 | init_method = 'file:///' + init_file.replace('\\', '/')
36 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
37 | else:
38 | init_method = f'file://{init_file}'
39 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
40 |
41 | # Init torch_utils.
42 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
43 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
44 | if rank != 0 or not args.verbose:
45 | custom_ops.verbosity = 'none'
46 |
47 | # Print network summary.
48 | device = torch.device('cuda', rank)
49 | torch.backends.cudnn.benchmark = True
50 | torch.backends.cuda.matmul.allow_tf32 = False
51 | torch.backends.cudnn.allow_tf32 = False
52 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
53 | if rank == 0 and args.verbose:
54 | z = torch.empty([1, G.z_dim], device=device)
55 | c = torch.empty([1, G.c_dim], device=device)
56 | misc.print_module_summary(G, [z, c])
57 |
58 | # Calculate each metric.
59 | for metric in args.metrics:
60 | if rank == 0 and args.verbose:
61 | print(f'Calculating {metric}...')
62 | progress = metric_utils.ProgressMonitor(verbose=args.verbose)
63 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
64 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
65 | if rank == 0:
66 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
67 | if rank == 0 and args.verbose:
68 | print()
69 |
70 | # Done.
71 | if rank == 0 and args.verbose:
72 | print('Exiting...')
73 |
74 | #----------------------------------------------------------------------------
75 |
76 | class CommaSeparatedList(click.ParamType):
77 | name = 'list'
78 |
79 | def convert(self, value, param, ctx):
80 | _ = param, ctx
81 | if value is None or value.lower() == 'none' or value == '':
82 | return []
83 | return value.split(',')
84 |
85 | #----------------------------------------------------------------------------
86 |
87 | @click.command()
88 | @click.pass_context
89 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
90 | @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
91 | @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
92 | @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
93 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
94 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
95 |
96 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
97 | """Calculate quality metrics for previous training run or pretrained network pickle.
98 |
99 | Examples:
100 |
101 | \b
102 | # Previous training run: look up options automatically, save result to JSONL file.
103 | python calc_metrics.py --metrics=pr50k3_full \\
104 | --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
105 |
106 | \b
107 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
108 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
109 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
110 |
111 | Available metrics:
112 |
113 | \b
114 | ADA paper:
115 | fid50k_full Frechet inception distance against the full dataset.
116 | kid50k_full Kernel inception distance against the full dataset.
117 | pr50k3_full  Precision and recall against the full dataset.
118 | is50k Inception score for CIFAR-10.
119 |
120 | \b
121 | StyleGAN and StyleGAN2 papers:
122 | fid50k Frechet inception distance against 50k real images.
123 | kid50k Kernel inception distance against 50k real images.
124 | pr50k3 Precision and recall against 50k real images.
125 | ppl2_wend Perceptual path length in W at path endpoints against full image.
126 | ppl_zfull Perceptual path length in Z for full paths against cropped image.
127 | ppl_wfull Perceptual path length in W for full paths against cropped image.
128 | ppl_zend Perceptual path length in Z at path endpoints against cropped image.
129 | ppl_wend Perceptual path length in W at path endpoints against cropped image.
130 | """
131 | dnnlib.util.Logger(should_flush=True)
132 |
133 | # Validate arguments.
134 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
135 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
136 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
137 | if not args.num_gpus >= 1:
138 | ctx.fail('--gpus must be at least 1')
139 |
140 | # Load network.
141 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
142 | ctx.fail('--network must point to a file or URL')
143 | if args.verbose:
144 | print(f'Loading network from "{network_pkl}"...')
145 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
146 | network_dict = legacy.load_network_pkl(f)
147 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module
148 |
149 | # Initialize dataset options.
150 | if data is not None:
151 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
152 | elif network_dict['training_set_kwargs'] is not None:
153 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
154 | else:
155 | ctx.fail('Could not look up dataset options; please specify --data')
156 |
157 | # Finalize dataset options.
158 | args.dataset_kwargs.resolution = args.G.img_resolution
159 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
160 | if mirror is not None:
161 | args.dataset_kwargs.xflip = mirror
162 |
163 | # Print dataset options.
164 | if args.verbose:
165 | print('Dataset options:')
166 | print(json.dumps(args.dataset_kwargs, indent=2))
167 |
168 | # Locate run dir.
169 | args.run_dir = None
170 | if os.path.isfile(network_pkl):
171 | pkl_dir = os.path.dirname(network_pkl)
172 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
173 | args.run_dir = pkl_dir
174 |
175 | # Launch processes.
176 | if args.verbose:
177 | print('Launching processes...')
178 | torch.multiprocessing.set_start_method('spawn')
179 | with tempfile.TemporaryDirectory() as temp_dir:
180 | if args.num_gpus == 1:
181 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
182 | else:
183 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
184 |
185 | #----------------------------------------------------------------------------
186 |
187 | if __name__ == "__main__":
188 | calc_metrics() # pylint: disable=no-value-for-parameter
189 |
190 | #----------------------------------------------------------------------------
191 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
/figures/d_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/d_vis.png
--------------------------------------------------------------------------------
/figures/feat_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/feat_vis.png
--------------------------------------------------------------------------------
/figures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/model.png
--------------------------------------------------------------------------------
/figures/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/naver-ai/GGDR/7071792bf715c05177b0c21577df14769fbdfb4e/figures/sample.png
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Generate images using pretrained network pickle."""
10 |
11 | import os
12 | import re
13 | from typing import List, Optional
14 |
15 | import click
16 | import dnnlib
17 | import numpy as np
18 | import PIL.Image
19 | import torch
20 |
21 | import legacy
22 |
23 | #----------------------------------------------------------------------------
24 |
25 | def num_range(s: str) -> List[int]:
26 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
27 |
28 | range_re = re.compile(r'^(\d+)-(\d+)$')
29 | m = range_re.match(s)
30 | if m:
31 | return list(range(int(m.group(1)), int(m.group(2))+1))
32 | vals = s.split(',')
33 | return [int(x) for x in vals]
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | @click.command()
38 | @click.pass_context
39 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
40 | @click.option('--seeds', type=num_range, help='List of random seeds')
41 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
42 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
43 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44 | @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
45 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
46 | def generate_images(
47 | ctx: click.Context,
48 | network_pkl: str,
49 | seeds: Optional[List[int]],
50 | truncation_psi: float,
51 | noise_mode: str,
52 | outdir: str,
53 | class_idx: Optional[int],
54 | projected_w: Optional[str]
55 | ):
56 | """Generate images using pretrained network pickle.
57 |
58 | Examples:
59 |
60 | \b
61 | # Generate curated MetFaces images without truncation (Fig.10 left)
62 | python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
63 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
64 |
65 | \b
66 | # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
67 | python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
68 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
69 |
70 | \b
71 | # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
72 | python generate.py --outdir=out --seeds=0-35 --class=1 \\
73 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
74 |
75 | \b
76 | # Render an image from projected W
77 | python generate.py --outdir=out --projected-w=projected_w.npz \\
78 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
79 | """
80 |
81 | print('Loading networks from "%s"...' % network_pkl)
82 | device = torch.device('cuda')
83 | with dnnlib.util.open_url(network_pkl) as f:
84 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
85 |
90 | os.makedirs(outdir, exist_ok=True)
91 |
92 | # Synthesize the result of a W projection.
93 | if projected_w is not None:
94 | if seeds is not None:
95 | print('warn: --seeds is ignored when using --projected-w')
96 | print(f'Generating images from projected W "{projected_w}"')
97 | ws = np.load(projected_w)['w']
98 | ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
99 | assert ws.shape[1:] == (G.num_ws, G.w_dim)
100 | for idx, w in enumerate(ws):
101 | img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
102 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
103 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
104 | return
105 |
106 | if seeds is None:
107 | ctx.fail('--seeds option is required when not using --projected-w')
108 |
109 | # Labels.
110 | label = torch.zeros([1, G.c_dim], device=device)
111 | if G.c_dim != 0:
112 | if class_idx is None:
113 | ctx.fail('Must specify class label with --class when using a conditional network')
114 | label[:, class_idx] = 1
115 | else:
116 | if class_idx is not None:
117 | print('warn: --class=lbl ignored when running on an unconditional network')
118 |
119 | # Generate images.
120 | for seed_idx, seed in enumerate(seeds):
121 | print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
122 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
123 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
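   |         # Convert from NCHW float in [-1, 1] to NHWC uint8 in [0, 255].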
124 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
125 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
126 |
127 |
128 | #----------------------------------------------------------------------------
129 |
130 | if __name__ == "__main__":
131 | generate_images() # pylint: disable=no-value-for-parameter
132 |
133 | #----------------------------------------------------------------------------
134 |
--------------------------------------------------------------------------------
/legacy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import click
10 | import pickle
11 | import re
12 | import copy
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | from torch_utils import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def load_network_pkl(f, force_fp16=False):
21 | data = _LegacyUnpickler(f).load()
22 |
23 | # Legacy TensorFlow pickle => convert.
24 | if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25 | tf_G, tf_D, tf_Gs = data
26 | G = convert_tf_generator(tf_G)
27 | D = convert_tf_discriminator(tf_D)
28 | G_ema = convert_tf_generator(tf_Gs)
29 | data = dict(G=G, D=D, G_ema=G_ema)
30 |
31 | # Add missing fields.
32 | if 'training_set_kwargs' not in data:
33 | data['training_set_kwargs'] = None
34 | if 'augment_pipe' not in data:
35 | data['augment_pipe'] = None
36 |
37 | # Validate contents.
38 | # assert isinstance(data['G'], torch.nn.Module)
39 | # assert isinstance(data['D'], torch.nn.Module)
40 | assert isinstance(data['G_ema'], torch.nn.Module)
41 | # assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42 | # assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43 |
44 | # Force FP16.
45 | if force_fp16:
46 | for key in ['G', 'D', 'G_ema']:
47 | old = data[key]
48 | kwargs = copy.deepcopy(old.init_kwargs)
49 | if key.startswith('G'):
50 | kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51 | kwargs.synthesis_kwargs.num_fp16_res = 4
52 | kwargs.synthesis_kwargs.conv_clamp = 256
53 | if key.startswith('D'):
54 | kwargs.num_fp16_res = 4
55 | kwargs.conv_clamp = 256
56 | if kwargs != old.init_kwargs:
57 | new = type(old)(**kwargs).eval().requires_grad_(False)
58 | misc.copy_params_and_buffers(old, new, require_all=True)
59 | data[key] = new
60 | return data
61 |
62 | #----------------------------------------------------------------------------
63 |
64 | class _TFNetworkStub(dnnlib.EasyDict):
65 | pass
66 |
67 | class _LegacyUnpickler(pickle.Unpickler):
68 | def find_class(self, module, name):
69 | if module == 'dnnlib.tflib.network' and name == 'Network':
70 | return _TFNetworkStub
71 | return super().find_class(module, name)
72 |
73 | #----------------------------------------------------------------------------
74 |
75 | def _collect_tf_params(tf_net):
76 | # pylint: disable=protected-access
77 | tf_params = dict()
78 | def recurse(prefix, tf_net):
79 | for name, value in tf_net.variables:
80 | tf_params[prefix + name] = value
81 | for name, comp in tf_net.components.items():
82 | recurse(prefix + name + '/', comp)
83 | recurse('', tf_net)
84 | return tf_params
85 |
86 | #----------------------------------------------------------------------------
87 |
88 | def _populate_module_params(module, *patterns):
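   |     # 'patterns' alternates (regex, value_fn): every param/buffer name must match
   |     # one regex, and its tensor is filled from value_fn(*match.groups()).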
89 | for name, tensor in misc.named_params_and_buffers(module):
90 | found = False
91 | value = None
92 | for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93 | match = re.fullmatch(pattern, name)
94 | if match:
95 | found = True
96 | if value_fn is not None:
97 | value = value_fn(*match.groups())
98 | break
99 | try:
100 | assert found
101 | if value is not None:
102 | tensor.copy_(torch.from_numpy(np.array(value)))
103 | except:
104 | print(name, list(tensor.shape))
105 | raise
106 |
107 | #----------------------------------------------------------------------------
108 |
109 | def convert_tf_generator(tf_G):
110 | if tf_G.version < 4:
111 | raise ValueError('TensorFlow pickle version too low')
112 |
113 | # Collect kwargs.
114 | tf_kwargs = tf_G.static_kwargs
115 | known_kwargs = set()
116 | def kwarg(tf_name, default=None, none=None):
117 | known_kwargs.add(tf_name)
118 | val = tf_kwargs.get(tf_name, default)
119 | return val if val is not None else none
120 |
121 | # Convert kwargs.
122 | kwargs = dnnlib.EasyDict(
123 | z_dim = kwarg('latent_size', 512),
124 | c_dim = kwarg('label_size', 0),
125 | w_dim = kwarg('dlatent_size', 512),
126 | img_resolution = kwarg('resolution', 1024),
127 | img_channels = kwarg('num_channels', 3),
128 | mapping_kwargs = dnnlib.EasyDict(
129 | num_layers = kwarg('mapping_layers', 8),
130 | embed_features = kwarg('label_fmaps', None),
131 | layer_features = kwarg('mapping_fmaps', None),
132 | activation = kwarg('mapping_nonlinearity', 'lrelu'),
133 | lr_multiplier = kwarg('mapping_lrmul', 0.01),
134 | w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135 | ),
136 | synthesis_kwargs = dnnlib.EasyDict(
137 | channel_base = kwarg('fmap_base', 16384) * 2,
138 | channel_max = kwarg('fmap_max', 512),
139 | num_fp16_res = kwarg('num_fp16_res', 0),
140 | conv_clamp = kwarg('conv_clamp', None),
141 | architecture = kwarg('architecture', 'skip'),
142 | resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143 | use_noise = kwarg('use_noise', True),
144 | activation = kwarg('nonlinearity', 'lrelu'),
145 | ),
146 | )
147 |
148 | # Check for unknown kwargs.
149 | kwarg('truncation_psi')
150 | kwarg('truncation_cutoff')
151 | kwarg('style_mixing_prob')
152 | kwarg('structure')
153 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154 | if len(unknown_kwargs) > 0:
155 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156 |
157 | # Collect params.
158 | tf_params = _collect_tf_params(tf_G)
159 | for name, value in list(tf_params.items()):
160 | match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161 | if match:
162 | r = kwargs.img_resolution // (2 ** int(match.group(1)))
163 | tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164 | kwargs.synthesis_kwargs.architecture = 'orig'
165 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166 |
167 | # Convert params.
168 | from training import networks
169 | G = networks.Generator(**kwargs).eval().requires_grad_(False)
170 | # pylint: disable=unnecessary-lambda
171 | _populate_module_params(G,
172 | r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173 | r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174 | r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177 | r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178 | r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179 | r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180 | r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181 | r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182 | r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183 | r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184 | r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185 | r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186 | r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187 | r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188 | r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189 | r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190 | r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191 | r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192 | r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193 | r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194 | r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195 | r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196 | r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197 | r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198 | r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199 | r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200 | r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201 | r'.*\.resample_filter', None,
202 | )
203 | return G
204 |
205 | #----------------------------------------------------------------------------
206 |
207 | def convert_tf_discriminator(tf_D):
208 | if tf_D.version < 4:
209 | raise ValueError('TensorFlow pickle version too low')
210 |
211 | # Collect kwargs.
212 | tf_kwargs = tf_D.static_kwargs
213 | known_kwargs = set()
214 | def kwarg(tf_name, default=None):
215 | known_kwargs.add(tf_name)
216 | return tf_kwargs.get(tf_name, default)
217 |
218 | # Convert kwargs.
219 | kwargs = dnnlib.EasyDict(
220 | c_dim = kwarg('label_size', 0),
221 | img_resolution = kwarg('resolution', 1024),
222 | img_channels = kwarg('num_channels', 3),
223 | architecture = kwarg('architecture', 'resnet'),
224 | channel_base = kwarg('fmap_base', 16384) * 2,
225 | channel_max = kwarg('fmap_max', 512),
226 | num_fp16_res = kwarg('num_fp16_res', 0),
227 | conv_clamp = kwarg('conv_clamp', None),
228 | cmap_dim = kwarg('mapping_fmaps', None),
229 | block_kwargs = dnnlib.EasyDict(
230 | activation = kwarg('nonlinearity', 'lrelu'),
231 | resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232 | freeze_layers = kwarg('freeze_layers', 0),
233 | ),
234 | mapping_kwargs = dnnlib.EasyDict(
235 | num_layers = kwarg('mapping_layers', 0),
236 | embed_features = kwarg('mapping_fmaps', None),
237 | layer_features = kwarg('mapping_fmaps', None),
238 | activation = kwarg('nonlinearity', 'lrelu'),
239 | lr_multiplier = kwarg('mapping_lrmul', 0.1),
240 | ),
241 | epilogue_kwargs = dnnlib.EasyDict(
242 | mbstd_group_size = kwarg('mbstd_group_size', None),
243 | mbstd_num_channels = kwarg('mbstd_num_features', 1),
244 | activation = kwarg('nonlinearity', 'lrelu'),
245 | ),
246 | )
247 |
248 | # Check for unknown kwargs.
249 | kwarg('structure')
250 | unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251 | if len(unknown_kwargs) > 0:
252 | raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253 |
254 | # Collect params.
255 | tf_params = _collect_tf_params(tf_D)
256 | for name, value in list(tf_params.items()):
257 | match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258 | if match:
259 | r = kwargs.img_resolution // (2 ** int(match.group(1)))
260 | tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261 | kwargs.architecture = 'orig'
262 | #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263 |
264 | # Convert params.
265 | from training import networks
266 | D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267 | # pylint: disable=unnecessary-lambda
268 | _populate_module_params(D,
269 | r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270 | r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271 | r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272 | r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273 | r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274 | r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275 | r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276 | r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277 | r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278 | r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279 | r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280 | r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281 | r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282 | r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283 | r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284 | r'.*\.resample_filter', None,
285 | )
286 | return D
287 |
288 | #----------------------------------------------------------------------------
289 |
290 | @click.command()
291 | @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292 | @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293 | @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294 | def convert_network_pickle(source, dest, force_fp16):
295 | """Convert legacy network pickle into the native PyTorch format.
296 |
297 | The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298 | It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299 |
300 | Example:
301 |
302 | \b
303 | python legacy.py \\
304 | --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305 | --dest=stylegan2-cat-config-f.pkl
306 | """
307 | print(f'Loading "{source}"...')
308 | with dnnlib.util.open_url(source) as f:
309 | data = load_network_pkl(f, force_fp16=force_fp16)
310 | print(f'Saving "{dest}"...')
311 | with open(dest, 'wb') as f:
312 | pickle.dump(data, f)
313 | print('Done.')
314 |
315 | #----------------------------------------------------------------------------
316 |
317 | if __name__ == "__main__":
318 | convert_network_pickle() # pylint: disable=no-value-for-parameter
319 |
320 | #----------------------------------------------------------------------------
321 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Frechet Inception Distance (FID) from the paper
10 | "GANs trained by a two time-scale update rule converge to a local Nash
11 | equilibrium". Matches the original implementation by Heusel et al. at
12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13 |
14 | import numpy as np
15 | import scipy.linalg
16 | from . import metric_utils
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def compute_fid(opts, max_real, num_gen):
21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24 |
25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28 |
29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32 |
33 | if opts.rank != 0:
34 | return float('nan')
35 |
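   |     # Frechet distance between Gaussians: ||mu_g - mu_r||^2 + Tr(sigma_g + sigma_r - 2*sqrtm(sigma_g @ sigma_r)).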
36 | m = np.square(mu_gen - mu_real).sum()
37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39 | return float(fid)
40 |
41 | #----------------------------------------------------------------------------
42 |
--------------------------------------------------------------------------------
/metrics/inception_score.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Inception Score (IS) from the paper "Improved techniques for training
10 | GANs". Matches the original implementation by Salimans et al. at
11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12 |
13 | import numpy as np
14 | from . import metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_is(opts, num_gen, num_splits):
19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22 |
23 | gen_probs = metric_utils.compute_feature_stats_for_generator(
24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25 | capture_all=True, max_items=num_gen).get_all()
26 |
27 | if opts.rank != 0:
28 | return float('nan'), float('nan')
29 |
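   |     # IS = exp(E_x[KL(p(y|x) || p(y))]), computed independently over num_splits splits.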
30 | scores = []
31 | for i in range(num_splits):
32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34 | kl = np.mean(np.sum(kl, axis=1))
35 | scores.append(np.exp(kl))
36 | return float(np.mean(scores)), float(np.std(scores))
37 |
38 | #----------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/metrics/kernel_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10 | GANs". Matches the original implementation by Binkowski et al. at
11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12 |
13 | import numpy as np
14 | from . import metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22 |
23 | real_features = metric_utils.compute_feature_stats_for_dataset(
24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26 |
27 | gen_features = metric_utils.compute_feature_stats_for_generator(
28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30 |
31 | if opts.rank != 0:
32 | return float('nan')
33 |
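   |     # Unbiased MMD^2 estimate with the polynomial kernel k(x, y) = (x.y / n + 1)^3, averaged over random subsets.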
34 | n = real_features.shape[1]
35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36 | t = 0
37 | for _subset_idx in range(num_subsets):
38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41 | b = (x @ y.T / n + 1) ** 3
42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43 | kid = t / num_subsets / m
44 | return float(kid)
45 |
46 | #----------------------------------------------------------------------------
47 |
--------------------------------------------------------------------------------
/metrics/metric_main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import time
11 | import json
12 | import torch
13 | import dnnlib
14 |
15 | from . import metric_utils
16 | from . import frechet_inception_distance
17 | from . import kernel_inception_distance
18 | from . import precision_recall
19 | from . import perceptual_path_length
20 | from . import inception_score
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | _metric_dict = dict() # name => fn
25 |
26 | def register_metric(fn):
27 | assert callable(fn)
28 | _metric_dict[fn.__name__] = fn
29 | return fn
30 |
31 | def is_valid_metric(metric):
32 | return metric in _metric_dict
33 |
34 | def list_valid_metrics():
35 | return list(_metric_dict.keys())
36 |
37 | #----------------------------------------------------------------------------
38 |
39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40 | assert is_valid_metric(metric)
41 | opts = metric_utils.MetricOptions(**kwargs)
42 |
43 | # Calculate.
44 | start_time = time.time()
45 | results = _metric_dict[metric](opts)
46 | total_time = time.time() - start_time
47 |
48 | # Broadcast results.
49 | for key, value in list(results.items()):
50 | if opts.num_gpus > 1:
51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52 | torch.distributed.broadcast(tensor=value, src=0)
53 | value = float(value.cpu())
54 | results[key] = value
55 |
56 | # Decorate with metadata.
57 | return dnnlib.EasyDict(
58 | results = dnnlib.EasyDict(results),
59 | metric = metric,
60 | total_time = total_time,
61 | total_time_str = dnnlib.util.format_time(total_time),
62 | num_gpus = opts.num_gpus,
63 | )
64 |
65 | #----------------------------------------------------------------------------
66 |
67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68 | metric = result_dict['metric']
69 | assert is_valid_metric(metric)
70 | if run_dir is not None and snapshot_pkl is not None:
71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72 |
73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74 | print(jsonl_line)
75 | if run_dir is not None and os.path.isdir(run_dir):
76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77 | f.write(jsonl_line + '\n')
78 |
79 | #----------------------------------------------------------------------------
80 | # Primary metrics.
81 |
82 | @register_metric
83 | def fid50k_full(opts):
84 | opts.dataset_kwargs.update(max_size=None, xflip=False)
85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
86 | return dict(fid50k_full=fid)
87 |
88 | @register_metric
89 | def kid50k_full(opts):
90 | opts.dataset_kwargs.update(max_size=None, xflip=False)
91 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
92 | return dict(kid50k_full=kid)
93 |
94 | @register_metric
95 | def pr50k3_full(opts):
96 | opts.dataset_kwargs.update(max_size=None, xflip=False)
97 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
98 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
99 |
100 | @register_metric
101 | def ppl2_wend(opts):
102 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
103 | return dict(ppl2_wend=ppl)
104 |
105 | @register_metric
106 | def is50k(opts):
107 | opts.dataset_kwargs.update(max_size=None, xflip=False)
108 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
109 | return dict(is50k_mean=mean, is50k_std=std)
110 |
111 | #----------------------------------------------------------------------------
112 | # Legacy metrics.
113 |
114 | @register_metric
115 | def fid50k(opts):
116 | opts.dataset_kwargs.update(max_size=None)
117 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
118 | return dict(fid50k=fid)
119 |
120 | @register_metric
121 | def kid50k(opts):
122 | opts.dataset_kwargs.update(max_size=None)
123 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124 | return dict(kid50k=kid)
125 |
126 | @register_metric
127 | def pr50k3(opts):
128 | opts.dataset_kwargs.update(max_size=None)
129 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130 | return dict(pr50k3_precision=precision, pr50k3_recall=recall)
131 |
132 | @register_metric
133 | def ppl_zfull(opts):
134 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
135 | return dict(ppl_zfull=ppl)
136 |
137 | @register_metric
138 | def ppl_wfull(opts):
139 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
140 | return dict(ppl_wfull=ppl)
141 |
142 | @register_metric
143 | def ppl_zend(opts):
144 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
145 | return dict(ppl_zend=ppl)
146 |
147 | @register_metric
148 | def ppl_wend(opts):
149 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
150 | return dict(ppl_wend=ppl)
151 |
152 | #----------------------------------------------------------------------------
153 |
--------------------------------------------------------------------------------
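Putting the registry together: `calc_metric` looks the metric function up by name, times it, broadcasts the scalar results across ranks, and wraps everything in an EasyDict for `report_metric`. A hypothetical single-GPU invocation; the generator `G_ema` and the dataset path are placeholders:

```python
import torch
import dnnlib
from metrics import metric_main

result = metric_main.calc_metric(
    metric='fid50k_full',                       # any name in list_valid_metrics()
    G=G_ema,                                    # trained generator (placeholder)
    dataset_kwargs=dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset',
        path='~/datasets/ffhq.zip'),            # placeholder path
    num_gpus=1, rank=0, device=torch.device('cuda'))

# Prints one JSONL line and, if run_dir exists, appends it to metric-fid50k_full.jsonl.
metric_main.report_metric(result, run_dir='training-runs/00000',
                          snapshot_pkl='network-snapshot.pkl')
```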
/metrics/metric_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import time
11 | import hashlib
12 | import pickle
13 | import copy
14 | import uuid
15 | import numpy as np
16 | import torch
17 | import torch.nn.functional as F
18 | import dnnlib
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | class MetricOptions:
23 | def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
24 | assert 0 <= rank < num_gpus
25 | self.G = G
26 | self.G_kwargs = dnnlib.EasyDict(G_kwargs)
27 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
28 | self.num_gpus = num_gpus
29 | self.rank = rank
30 | self.device = device if device is not None else torch.device('cuda', rank)
31 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
32 | self.cache = cache
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | _feature_detector_cache = dict()
37 |
38 | def get_feature_detector_name(url):
39 | return os.path.splitext(url.split('/')[-1])[0]
40 |
41 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
42 | assert 0 <= rank < num_gpus
43 | key = (url, device)
44 | if key not in _feature_detector_cache:
45 | is_leader = (rank == 0)
46 | if not is_leader and num_gpus > 1:
47 | torch.distributed.barrier() # leader goes first
48 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
49 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
50 | if is_leader and num_gpus > 1:
51 | torch.distributed.barrier() # others follow
52 | return _feature_detector_cache[key]
53 |
54 | #----------------------------------------------------------------------------
55 |
56 | class FeatureStats:
57 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
58 | self.capture_all = capture_all
59 | self.capture_mean_cov = capture_mean_cov
60 | self.max_items = max_items
61 | self.num_items = 0
62 | self.num_features = None
63 | self.all_features = None
64 | self.raw_mean = None
65 | self.raw_cov = None
66 |
67 | def set_num_features(self, num_features):
68 | if self.num_features is not None:
69 | assert num_features == self.num_features
70 | else:
71 | self.num_features = num_features
72 | self.all_features = []
73 | self.raw_mean = np.zeros([num_features], dtype=np.float64)
74 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
75 |
76 | def is_full(self):
77 | return (self.max_items is not None) and (self.num_items >= self.max_items)
78 |
79 | def append(self, x):
80 | x = np.asarray(x, dtype=np.float32)
81 | assert x.ndim == 2
82 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
83 | if self.num_items >= self.max_items:
84 | return
85 | x = x[:self.max_items - self.num_items]
86 |
87 | self.set_num_features(x.shape[1])
88 | self.num_items += x.shape[0]
89 | if self.capture_all:
90 | self.all_features.append(x)
91 | if self.capture_mean_cov:
92 | x64 = x.astype(np.float64)
93 | self.raw_mean += x64.sum(axis=0)
94 | self.raw_cov += x64.T @ x64
95 |
96 | def append_torch(self, x, num_gpus=1, rank=0):
97 | assert isinstance(x, torch.Tensor) and x.ndim == 2
98 | assert 0 <= rank < num_gpus
99 | if num_gpus > 1:
100 | ys = []
101 | for src in range(num_gpus):
102 | y = x.clone()
103 | torch.distributed.broadcast(y, src=src)
104 | ys.append(y)
105 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
106 | self.append(x.cpu().numpy())
107 |
108 | def get_all(self):
109 | assert self.capture_all
110 | return np.concatenate(self.all_features, axis=0)
111 |
112 | def get_all_torch(self):
113 | return torch.from_numpy(self.get_all())
114 |
115 | def get_mean_cov(self):
116 | assert self.capture_mean_cov
117 | mean = self.raw_mean / self.num_items
118 | cov = self.raw_cov / self.num_items
119 | cov = cov - np.outer(mean, mean)
120 | return mean, cov
121 |
122 | def save(self, pkl_file):
123 | with open(pkl_file, 'wb') as f:
124 | pickle.dump(self.__dict__, f)
125 |
126 | @staticmethod
127 | def load(pkl_file):
128 | with open(pkl_file, 'rb') as f:
129 | s = dnnlib.EasyDict(pickle.load(f))
130 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
131 | obj.__dict__.update(s)
132 | return obj
133 |
134 | #----------------------------------------------------------------------------
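FeatureStats never stores the full feature matrix when only FID is needed: it accumulates the running sums of x and x xᵀ in float64 and recovers the covariance as E[x xᵀ] − μμᵀ. A quick self-contained check of that bookkeeping:

```python
import numpy as np

rng = np.random.default_rng(0)
batches = [rng.standard_normal((100, 8)) for _ in range(5)]

# Streaming accumulation, exactly as in FeatureStats.append.
raw_mean, raw_cov, n = np.zeros(8), np.zeros((8, 8)), 0
for x in batches:
    x64 = x.astype(np.float64)
    raw_mean += x64.sum(axis=0)
    raw_cov += x64.T @ x64
    n += x.shape[0]
mean = raw_mean / n
cov = raw_cov / n - np.outer(mean, mean)

# Matches the batch computation on the concatenated features.
full = np.concatenate(batches)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(cov, np.cov(full, rowvar=False, bias=True))  # population cov
```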
135 |
136 | class ProgressMonitor:
137 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
138 | self.tag = tag
139 | self.num_items = num_items
140 | self.verbose = verbose
141 | self.flush_interval = flush_interval
142 | self.progress_fn = progress_fn
143 | self.pfn_lo = pfn_lo
144 | self.pfn_hi = pfn_hi
145 | self.pfn_total = pfn_total
146 | self.start_time = time.time()
147 | self.batch_time = self.start_time
148 | self.batch_items = 0
149 | if self.progress_fn is not None:
150 | self.progress_fn(self.pfn_lo, self.pfn_total)
151 |
152 | def update(self, cur_items):
153 | assert (self.num_items is None) or (cur_items <= self.num_items)
154 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
155 | return
156 | cur_time = time.time()
157 | total_time = cur_time - self.start_time
158 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
159 | if (self.verbose) and (self.tag is not None):
160 | print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
161 | self.batch_time = cur_time
162 | self.batch_items = cur_items
163 |
164 | if (self.progress_fn is not None) and (self.num_items is not None):
165 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
166 |
167 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
168 | return ProgressMonitor(
169 | tag = tag,
170 | num_items = num_items,
171 | flush_interval = flush_interval,
172 | verbose = self.verbose,
173 | progress_fn = self.progress_fn,
174 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
175 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
176 | pfn_total = self.pfn_total,
177 | )
178 |
179 | #----------------------------------------------------------------------------
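`sub()` rescales a child monitor's progress into a slice of the parent's [pfn_lo, pfn_hi] window, so nested stages share a single global progress callback. A small demonstration:

```python
# A child created with rel_lo=0.25, rel_hi=0.75 reports its own 0..100%
# as the parent's 25..75% band.
events = []
parent = ProgressMonitor(progress_fn=lambda cur, total: events.append((cur, total)),
                         pfn_lo=0, pfn_hi=1000, pfn_total=1000)
child = parent.sub(tag='demo', num_items=100, rel_lo=0.25, rel_hi=0.75)
child.update(100)
print(events[-1])   # (750.0, 1000)
```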
180 |
181 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
182 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
183 | if data_loader_kwargs is None:
184 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
185 |
186 | # Try to lookup from cache.
187 | cache_file = None
188 | if opts.cache:
189 | # Choose cache file name.
190 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
191 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
192 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
193 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
194 |
195 | # Check if the file exists (all processes must agree).
196 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False
197 | if opts.num_gpus > 1:
198 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
199 | torch.distributed.broadcast(tensor=flag, src=0)
200 | flag = (float(flag.cpu()) != 0)
201 |
202 | # Load.
203 | if flag:
204 | return FeatureStats.load(cache_file)
205 |
206 | # Initialize.
207 | num_items = len(dataset)
208 | if max_items is not None:
209 | num_items = min(num_items, max_items)
210 | stats = FeatureStats(max_items=num_items, **stats_kwargs)
211 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
212 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
213 |
214 | # Main loop.
215 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
216 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
217 | if images.shape[1] == 1:
218 | images = images.repeat([1, 3, 1, 1])
219 | features = detector(images.to(opts.device), **detector_kwargs)
220 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
221 | progress.update(stats.num_items)
222 |
223 | # Save to cache.
224 | if cache_file is not None and opts.rank == 0:
225 | os.makedirs(os.path.dirname(cache_file), exist_ok=True)
226 | temp_file = cache_file + '.' + uuid.uuid4().hex
227 | stats.save(temp_file)
228 | os.replace(temp_file, cache_file) # atomic
229 | return stats
230 |
231 | #----------------------------------------------------------------------------
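The `item_subset` construction above deals dataset indices round-robin across ranks, wrapping around so every rank runs the same number of batches; `append_torch` then re-interleaves the per-rank features, and the `max_items` trimming in `FeatureStats.append` drops the wrap-around duplicates deterministically. Concretely:

```python
num_items, num_gpus = 10, 4
for rank in range(num_gpus):
    subset = [(i * num_gpus + rank) % num_items
              for i in range((num_items - 1) // num_gpus + 1)]
    print(rank, subset)
# rank 0 -> [0, 4, 8]   rank 1 -> [1, 5, 9]
# rank 2 -> [2, 6, 0]   rank 3 -> [3, 7, 1]   (last two wrap around)
```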
232 |
233 | def lap_to_img(lap_imgs): # reconstruct an image by summing pyramid levels upsampled to the finest resolution
234 | img = 0
235 | h, w = lap_imgs[-1].size(2), lap_imgs[-1].size(3)
236 | for la in lap_imgs:
237 | img = img + F.interpolate(la, size=(h, w))
238 |
239 | return img
240 |
241 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
242 | if batch_gen is None:
243 | batch_gen = min(batch_size, 4)
244 | assert batch_size % batch_gen == 0
245 |
246 | # Setup generator and load labels.
247 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
248 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
249 |
250 | # TODO: change this (LAP)
251 | # Image generation func.
252 | def run_generator(z, c):
253 | img = G(z=z, c=c, **opts.G_kwargs)
254 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
255 | return img
256 |
257 | # JIT.
258 | if jit:
259 | z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
260 | c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
261 | run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
262 |
263 | # Initialize.
264 | stats = FeatureStats(**stats_kwargs)
265 | assert stats.max_items is not None
266 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
267 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
268 |
269 | # Main loop.
270 | while not stats.is_full():
271 | images = []
272 | for _i in range(batch_size // batch_gen):
273 | z = torch.randn([batch_gen, G.z_dim], device=opts.device)
274 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
275 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
276 | images.append(run_generator(z, c))
277 | images = torch.cat(images)
278 | if images.shape[1] == 1:
279 | images = images.repeat([1, 3, 1, 1])
280 | features = detector(images, **detector_kwargs)
281 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
282 | progress.update(stats.num_items)
283 | return stats
284 |
285 | #----------------------------------------------------------------------------
286 |
--------------------------------------------------------------------------------
/metrics/perceptual_path_length.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10 | Architecture for Generative Adversarial Networks". Matches the original
11 | implementation by Karras et al. at
12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13 |
14 | import copy
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 | from . import metric_utils
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | # Spherical interpolation of a batch of vectors.
23 | def slerp(a, b, t):
24 | a = a / a.norm(dim=-1, keepdim=True)
25 | b = b / b.norm(dim=-1, keepdim=True)
26 | d = (a * b).sum(dim=-1, keepdim=True)
27 | p = t * torch.acos(d)
28 | c = b - d * a
29 | c = c / c.norm(dim=-1, keepdim=True)
30 | d = a * torch.cos(p) + c * torch.sin(p)
31 | d = d / d.norm(dim=-1, keepdim=True)
32 | return d
33 |
34 | #----------------------------------------------------------------------------
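A quick sanity check of the `slerp` above: the endpoints reproduce the normalized inputs, and every interpolant stays on the unit sphere:

```python
import torch

a = torch.randn(4, 512, dtype=torch.float64)
b = torch.randn(4, 512, dtype=torch.float64)
for t in (0.0, 0.25, 0.5, 1.0):
    d = slerp(a, b, torch.full((4, 1), t, dtype=torch.float64))
    assert torch.allclose(d.norm(dim=-1), torch.ones(4, dtype=torch.float64))
assert torch.allclose(slerp(a, b, torch.zeros(4, 1, dtype=torch.float64)),
                      a / a.norm(dim=-1, keepdim=True))   # t=0 -> normalized a
assert torch.allclose(slerp(a, b, torch.ones(4, 1, dtype=torch.float64)),
                      b / b.norm(dim=-1, keepdim=True))   # t=1 -> normalized b
```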
35 |
36 | class PPLSampler(torch.nn.Module):
37 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
38 | assert space in ['z', 'w']
39 | assert sampling in ['full', 'end']
40 | super().__init__()
41 | self.G = copy.deepcopy(G)
42 | self.G_kwargs = G_kwargs
43 | self.epsilon = epsilon
44 | self.space = space
45 | self.sampling = sampling
46 | self.crop = crop
47 | self.vgg16 = copy.deepcopy(vgg16)
48 |
49 | def forward(self, c):
50 | # Generate random latents and interpolation t-values.
51 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
52 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
53 |
54 | # Interpolate in W or Z.
55 | if self.space == 'w':
56 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
57 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
58 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
59 | else: # space == 'z'
60 | zt0 = slerp(z0, z1, t.unsqueeze(1))
61 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
62 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
63 |
64 | # Randomize noise buffers.
65 | for name, buf in self.G.named_buffers():
66 | if name.endswith('.noise_const'):
67 | buf.copy_(torch.randn_like(buf))
68 |
69 | # Generate images.
70 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
71 |
72 | # Center crop.
73 | if self.crop:
74 | assert img.shape[2] == img.shape[3]
75 | c = img.shape[2] // 8
76 | img = img[:, :, c*3 : c*7, c*2 : c*6]
77 |
78 | # Downsample to 256x256.
79 | factor = self.G.img_resolution // 256
80 | if factor > 1:
81 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
82 |
83 | # Scale dynamic range from [-1,1] to [0,255].
84 | img = (img + 1) * (255 / 2)
85 | if self.G.img_channels == 1:
86 | img = img.repeat([1, 3, 1, 1])
87 |
88 | # Evaluate differential LPIPS.
89 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
90 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91 | return dist
92 |
93 | #----------------------------------------------------------------------------
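For reference, the quantity PPLSampler estimates is the expected squared LPIPS difference per unit step along the interpolation path (paraphrasing the StyleGAN paper, written here to match the code; φ is the VGG16 LPIPS embedding, and interp is lerp in W or slerp in Z):

```latex
\mathrm{PPL} \;=\; \mathbb{E}_{z_1, z_2, t}\!\left[
  \frac{1}{\epsilon^{2}}\,
  \bigl\lVert \phi\bigl(G(\mathrm{interp}(z_1, z_2;\, t))\bigr)
            - \phi\bigl(G(\mathrm{interp}(z_1, z_2;\, t+\epsilon))\bigr)
  \bigr\rVert_2^2
\right],
\qquad
t \sim \begin{cases}
  \mathcal{U}(0, 1) & \text{sampling='full'}\\
  0                 & \text{sampling='end'}
\end{cases}
```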
94 |
95 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
96 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
97 | vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99 |
100 | # Setup sampler.
101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102 | sampler.eval().requires_grad_(False).to(opts.device)
103 | if jit:
104 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
105 | sampler = torch.jit.trace(sampler, [c], check_trace=False)
106 |
107 | # Sampling loop.
108 | dist = []
109 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
110 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
111 | progress.update(batch_start)
112 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
113 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
114 | x = sampler(c)
115 | for src in range(opts.num_gpus):
116 | y = x.clone()
117 | if opts.num_gpus > 1:
118 | torch.distributed.broadcast(y, src=src)
119 | dist.append(y)
120 | progress.update(num_samples)
121 |
122 | # Compute PPL.
123 | if opts.rank != 0:
124 | return float('nan')
125 | dist = torch.cat(dist)[:num_samples].cpu().numpy()
126 | lo = np.percentile(dist, 1, interpolation='lower')
127 | hi = np.percentile(dist, 99, interpolation='higher')
128 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
129 | return float(ppl)
130 |
131 | #----------------------------------------------------------------------------
132 |
--------------------------------------------------------------------------------
/metrics/precision_recall.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall
10 | Metric for Assessing Generative Models". Matches the original implementation
11 | by Kynkaanniemi et al. at
12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13 |
14 | import torch
15 | from . import metric_utils
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20 | assert 0 <= rank < num_gpus
21 | num_cols = col_features.shape[0]
22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
24 | dist_batches = []
25 | for col_batch in col_batches[rank :: num_gpus]:
26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27 | for src in range(num_gpus):
28 | dist_broadcast = dist_batch.clone()
29 | if num_gpus > 1:
30 | torch.distributed.broadcast(dist_broadcast, src=src)
31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
38 | detector_kwargs = dict(return_features=True)
39 |
40 | real_features = metric_utils.compute_feature_stats_for_dataset(
41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43 |
44 | gen_features = metric_utils.compute_feature_stats_for_generator(
45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47 |
48 | results = dict()
49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50 | kth = []
51 | for manifold_batch in manifold.split(row_batch_size):
52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54 | kth = torch.cat(kth) if opts.rank == 0 else None
55 | pred = []
56 | for probes_batch in probes.split(row_batch_size):
57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60 | return results['precision'], results['recall']
61 |
62 | #----------------------------------------------------------------------------
63 |
--------------------------------------------------------------------------------
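The rule implemented above, in a single-process sketch: estimate each manifold point's k-NN radius (the (k+1)-th distance value, because the self-distance 0 is included), then count a probe as covered if it falls inside any such ball. Precision uses real features as the manifold and generated features as probes; recall swaps them.

```python
import torch

def manifold_coverage(manifold, probes, k=3):
    # k-NN radius of each manifold point (skip the self-distance at rank 1).
    radii = torch.cdist(manifold, manifold).kthvalue(k + 1, dim=1).values
    # A probe is covered if it lies within the k-NN ball of any manifold point.
    covered = (torch.cdist(probes, manifold) <= radii.unsqueeze(0)).any(dim=1)
    return covered.float().mean().item()

real, fake = torch.randn(1000, 64), torch.randn(1000, 64)
precision = manifold_coverage(real, fake)   # fraction of fakes near the real manifold
recall    = manifold_coverage(fake, real)   # fraction of reals near the fake manifold
```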
/tools/visualize_gfeat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import click
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision
9 | from kmeans_pytorch import kmeans
10 |
11 | import sys
12 | from pathlib import Path
13 | file = Path(__file__).resolve()
14 | parent, root = file.parent, file.parents[1]
15 | sys.path.append(str(root))
16 |
17 | import legacy
18 | import dnnlib
19 |
20 |
21 | @click.command()
22 | @click.pass_context
23 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
24 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
25 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
26 | @click.option('--num_iters', help='Number of iteration for visualization', type=int, default=1)
27 | @click.option('--batch_size', help='Batch size for clustering', type=int, default=64)
28 | def generate_images(
29 | ctx: click.Context,
30 | network_pkl: str,
31 | truncation_psi: float,
32 | outdir: str,
33 | num_iters: int,
34 | batch_size: int,
35 | ):
36 | """K-means visualization of generator feature maps. Clusters the feature maps of images within the same batch, so the batch size matters here.
37 |
38 | Usage:
39 | python tools/visualize_gfeat.py --outdir=out --network=your_network_path.pkl
40 | """
41 | torch.manual_seed(0)
42 | random.seed(0)
43 | np.random.seed(0)
44 |
45 | print('Loading networks from "%s"...' % network_pkl)
46 | device = torch.device('cuda')
47 |
48 | with dnnlib.util.open_url(network_pkl) as f:
49 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
50 |
51 | os.makedirs(f'{outdir}', exist_ok=True)
52 |
53 | for iter_idx in range(num_iters):
54 |
55 | z = torch.from_numpy(np.random.randn(batch_size, G.z_dim)).to(device)
56 | ws = G.mapping(z, c=None, truncation_psi=truncation_psi)
57 |
58 | fake_imgs, fake_feat = G.synthesis(ws, get_feat=True)
59 |
60 | vis_img = []
61 |
62 | # The feature maps are returned in a dictionary keyed by
63 | # their resolution.
64 | target_layers = [16, 32, 64]
65 | num_clusters = 6
66 |
67 | for res in target_layers:
68 | img = get_cluster_vis(fake_feat[res], num_clusters=num_clusters, target_res=res) # bnum, 256, 256
69 | vis_img.append(img)
70 |
71 | for idx, val in enumerate(vis_img):
72 | vis_img[idx] = F.interpolate(val, size=(256, 256))
73 |
74 | vis_img = torch.cat(vis_img, dim=0) # bnum * res_num, 256, 256
75 | vis_img = (vis_img + 1) * 127.5 / 255.0
76 | fake_imgs = (fake_imgs + 1) * 127.5 / 255.0
77 | fake_imgs = F.interpolate(fake_imgs, size=(256, 256))
78 |
79 | vis_img = torch.cat([fake_imgs, vis_img], dim=0)
80 | vis_img = torchvision.utils.make_grid(vis_img, normalize=False, nrow=batch_size)
81 | torchvision.utils.save_image(vis_img, f'{outdir}/{iter_idx}.png')
82 |
83 |
84 | def get_colors():
85 | dummy_color = np.array([
86 | [178, 34, 34], # firebrick
87 | [0, 139, 139], # dark cyan
88 | [245, 222, 179], # wheat
89 | [25, 25, 112], # midnight blue
90 | [255, 140, 0], # dark orange
91 | [128, 128, 0], # olive
92 | [50, 50, 50], # dark grey
93 | [34, 139, 34], # forest green
94 | [100, 149, 237], # corn flower blue
95 | [153, 50, 204], # dark orchid
96 | [240, 128, 128], # light coral
97 | ])
98 |
99 | for t in (0.6, 0.3): # just increase the number of colors for big K
100 | dummy_color = np.concatenate((dummy_color, dummy_color * t))
101 |
102 | dummy_color = (np.array(dummy_color) - 128.0) / 128.0
103 | dummy_color = torch.from_numpy(dummy_color)
104 |
105 | return dummy_color
106 |
107 |
108 | def get_cluster_vis(feat, num_clusters=10, target_res=16):
109 | # feat : NCHW
110 | print(feat.size())
111 | img_num, C, H, W = feat.size()
112 | feat = feat.permute(0, 2, 3, 1).contiguous().view(img_num * H * W, -1)
113 | feat = feat.to(torch.float32).cuda()
114 | cluster_ids_x, cluster_centers = kmeans(
115 | X=feat, num_clusters=num_clusters, distance='cosine',
116 | tol=1e-4,
117 | device=torch.device("cuda:0"))
118 |
119 | cluster_ids_x = cluster_ids_x.cuda()
120 | cluster_centers = cluster_centers.cuda()
121 | color_rgb = get_colors().cuda()
122 | vis_img = []
123 | for idx in range(img_num):
124 | num_pixel = target_res * target_res
125 | current_res = cluster_ids_x[num_pixel * idx:num_pixel * (idx + 1)].cuda()
126 | color_ids = torch.index_select(color_rgb, 0, current_res)
127 | color_map = color_ids.permute(1, 0).view(1, 3, target_res, target_res)
128 | color_map = F.interpolate(color_map, size=(256, 256))
129 | vis_img.append(color_map.cuda())
130 |
131 | vis_img = torch.cat(vis_img, dim=0)
132 |
133 | return vis_img
134 |
135 |
136 | if __name__ == "__main__":
137 | generate_images() # pylint: disable=no-value-for-parameter
138 |
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import glob
11 | import torch
12 | import torch.utils.cpp_extension
13 | import importlib
14 | import hashlib
15 | import shutil
16 | from pathlib import Path
17 |
18 | from torch.utils.file_baton import FileBaton
19 |
20 | #----------------------------------------------------------------------------
21 | # Global options.
22 |
23 | verbosity = 'full' # Verbosity level: 'none', 'brief', 'full'
24 |
25 | #----------------------------------------------------------------------------
26 | # Internal helper funcs.
27 |
28 | def _find_compiler_bindir():
29 | patterns = [
30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34 | ]
35 | for pattern in patterns:
36 | matches = sorted(glob.glob(pattern))
37 | if len(matches):
38 | return matches[-1]
39 | return None
40 |
41 | #----------------------------------------------------------------------------
42 | # Main entry point for compiling and loading C++/CUDA plugins.
43 |
44 | _cached_plugins = dict()
45 |
46 | def get_plugin(module_name, sources, **build_kwargs):
47 | assert verbosity in ['none', 'brief', 'full']
48 |
49 | # Already cached?
50 | if module_name in _cached_plugins:
51 | return _cached_plugins[module_name]
52 |
53 | # Print status.
54 | if verbosity == 'full':
55 | print(f'Setting up PyTorch plugin "{module_name}"...')
56 | elif verbosity == 'brief':
57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58 |
59 | try: # pylint: disable=too-many-nested-blocks
60 | # Make sure we can find the necessary compiler binaries.
61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62 | compiler_bindir = _find_compiler_bindir()
63 | if compiler_bindir is None:
64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65 | os.environ['PATH'] += ';' + compiler_bindir
66 |
67 | # Compile and load.
68 | verbose_build = (verbosity == 'full')
69 |
70 | # Incremental build md5sum trickery. Copies all the input source files
71 | # into a cached build directory under a combined md5 digest of the input
72 | # source files. Copying is done only if the combined digest has changed.
73 | # This keeps input file timestamps and filenames the same as in previous
74 | # extension builds, allowing for fast incremental rebuilds.
75 | #
76 | # This optimization is done only in case all the source files reside in
77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78 | # environment variable is set (we take this as a signal that the user
79 | # actually cares about this.)
80 | source_dirs_set = set(os.path.dirname(source) for source in sources)
81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83 |
84 | # Compute a combined hash digest for all source files in the same
85 | # custom op directory (usually .cu, .cpp, .py and .h files).
86 | hash_md5 = hashlib.md5()
87 | for src in all_source_files:
88 | with open(src, 'rb') as f:
89 | hash_md5.update(f.read())
90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92 |
93 | if not os.path.isdir(digest_build_dir):
94 | os.makedirs(digest_build_dir, exist_ok=True)
95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96 | if baton.try_acquire():
97 | try:
98 | for src in all_source_files:
99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100 | finally:
101 | baton.release()
102 | else:
103 | # Someone else is copying source files under the digest dir,
104 | # wait until done and continue.
105 | baton.wait()
106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108 | verbose=verbose_build, sources=digest_sources, **build_kwargs)
109 | else:
110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111 | module = importlib.import_module(module_name)
112 |
113 | except:
114 | if verbosity == 'brief':
115 | print('Failed!')
116 | raise
117 |
118 | # Print status and add to cache.
119 | if verbosity == 'full':
120 | print(f'Done setting up PyTorch plugin "{module_name}".')
121 | elif verbosity == 'brief':
122 | print('Done.')
123 | _cached_plugins[module_name] = module
124 | return module
125 |
126 | #----------------------------------------------------------------------------
127 |
--------------------------------------------------------------------------------
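For context, this is how the ops in torch_utils/ops typically initialize their plugin through `get_plugin`; a sketch mirroring the bias_act pattern, with the module name and build flags shown as illustrations:

```python
import os
from torch_utils import custom_ops

_plugin = None

def _init():
    global _plugin
    if _plugin is None:
        sources = ['bias_act.cpp', 'bias_act.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        # Compiled once per process; get_plugin caches by module name.
        _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources,
                                        extra_cuda_cflags=['--use_fast_math'])
    return True
```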
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import re
10 | import contextlib
11 | import numpy as np
12 | import torch
13 | import warnings
14 | import dnnlib
15 |
16 | #----------------------------------------------------------------------------
17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18 | # same constant is used multiple times.
19 |
20 | _constant_cache = dict()
21 |
22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23 | value = np.asarray(value)
24 | if shape is not None:
25 | shape = tuple(shape)
26 | if dtype is None:
27 | dtype = torch.get_default_dtype()
28 | if device is None:
29 | device = torch.device('cpu')
30 | if memory_format is None:
31 | memory_format = torch.contiguous_format
32 |
33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34 | tensor = _constant_cache.get(key, None)
35 | if tensor is None:
36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37 | if shape is not None:
38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39 | tensor = tensor.contiguous(memory_format=memory_format)
40 | _constant_cache[key] = tensor
41 | return tensor
42 |
43 | #----------------------------------------------------------------------------
44 | # Replace NaN/Inf with specified numerical values.
45 |
46 | try:
47 | nan_to_num = torch.nan_to_num # 1.8.0a0
48 | except AttributeError:
49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50 | assert isinstance(input, torch.Tensor)
51 | if posinf is None:
52 | posinf = torch.finfo(input.dtype).max
53 | if neginf is None:
54 | neginf = torch.finfo(input.dtype).min
55 | assert nan == 0
56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57 |
58 | #----------------------------------------------------------------------------
59 | # Symbolic assert.
60 |
61 | try:
62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63 | except AttributeError:
64 | symbolic_assert = torch.Assert # 1.7.0
65 |
66 | #----------------------------------------------------------------------------
67 | # Context manager to suppress known warnings in torch.jit.trace().
68 |
69 | class suppress_tracer_warnings(warnings.catch_warnings):
70 | def __enter__(self):
71 | super().__enter__()
72 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
73 | return self
74 |
75 | #----------------------------------------------------------------------------
76 | # Assert that the shape of a tensor matches the given list of integers.
77 | # None indicates that the size of a dimension is allowed to vary.
78 | # Performs symbolic assertion when used in torch.jit.trace().
79 |
80 | def assert_shape(tensor, ref_shape):
81 | if tensor.ndim != len(ref_shape):
82 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
83 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
84 | if ref_size is None:
85 | pass
86 | elif isinstance(ref_size, torch.Tensor):
87 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
88 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
89 | elif isinstance(size, torch.Tensor):
90 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
91 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
92 | elif size != ref_size:
93 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
94 |
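Usage of `assert_shape`: `None` leaves a dimension unconstrained, and tensor-valued sizes fall through to `symbolic_assert` so the check survives `torch.jit.trace()`:

```python
import torch

x = torch.zeros(8, 3, 256, 256)
assert_shape(x, [None, 3, 256, 256])    # OK: batch dimension may vary
try:
    assert_shape(x, [8, 1, 256, 256])
except AssertionError as e:
    print(e)                            # Wrong size for dimension 1: got 3, expected 1
```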
95 | #----------------------------------------------------------------------------
96 | # Function decorator that calls torch.autograd.profiler.record_function().
97 |
98 | def profiled_function(fn):
99 | def decorator(*args, **kwargs):
100 | with torch.autograd.profiler.record_function(fn.__name__):
101 | return fn(*args, **kwargs)
102 | decorator.__name__ = fn.__name__
103 | return decorator
104 |
105 | #----------------------------------------------------------------------------
106 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
107 | # indefinitely, shuffling items as it goes.
108 |
109 | class InfiniteSampler(torch.utils.data.Sampler):
110 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
111 | assert len(dataset) > 0
112 | assert num_replicas > 0
113 | assert 0 <= rank < num_replicas
114 | assert 0 <= window_size <= 1
115 | super().__init__(dataset)
116 | self.dataset = dataset
117 | self.rank = rank
118 | self.num_replicas = num_replicas
119 | self.shuffle = shuffle
120 | self.seed = seed
121 | self.window_size = window_size
122 |
123 | def __iter__(self):
124 | order = np.arange(len(self.dataset))
125 | rnd = None
126 | window = 0
127 | if self.shuffle:
128 | rnd = np.random.RandomState(self.seed)
129 | rnd.shuffle(order)
130 | window = int(np.rint(order.size * self.window_size))
131 |
132 | idx = 0
133 | while True:
134 | i = idx % order.size
135 | if idx % self.num_replicas == self.rank:
136 | yield order[i]
137 | if window >= 2:
138 | j = (i - rnd.randint(window)) % order.size
139 | order[i], order[j] = order[j], order[i]
140 | idx += 1
141 |
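A minimal sketch of how InfiniteSampler plugs into a DataLoader (any map-style dataset works; `window_size` trades global shuffling for locality):

```python
import torch

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4))
(batch,) = next(loader)   # next() never raises StopIteration; the stream is endless
```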
142 | #----------------------------------------------------------------------------
143 | # Utilities for operating with torch.nn.Module parameters and buffers.
144 |
145 | def params_and_buffers(module):
146 | assert isinstance(module, torch.nn.Module)
147 | return list(module.parameters()) + list(module.buffers())
148 |
149 | def named_params_and_buffers(module):
150 | assert isinstance(module, torch.nn.Module)
151 | return list(module.named_parameters()) + list(module.named_buffers())
152 |
153 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
154 | assert isinstance(src_module, torch.nn.Module)
155 | assert isinstance(dst_module, torch.nn.Module)
156 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
157 | for name, tensor in named_params_and_buffers(dst_module):
158 | assert (name in src_tensors) or (not require_all)
159 | if name in src_tensors:
160 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
161 |
162 | #----------------------------------------------------------------------------
163 | # Context manager for easily enabling/disabling DistributedDataParallel
164 | # synchronization.
165 |
166 | @contextlib.contextmanager
167 | def ddp_sync(module, sync):
168 | assert isinstance(module, torch.nn.Module)
169 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
170 | yield
171 | else:
172 | with module.no_sync():
173 | yield
174 |
175 | #----------------------------------------------------------------------------
176 | # Check DistributedDataParallel consistency across processes.
177 |
178 | def check_ddp_consistency(module, ignore_regex=None):
179 | assert isinstance(module, torch.nn.Module)
180 | for name, tensor in named_params_and_buffers(module):
181 | fullname = type(module).__name__ + '.' + name
182 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
183 | continue
184 | tensor = tensor.detach()
185 | other = tensor.clone()
186 | torch.distributed.broadcast(tensor=other, src=0)
187 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
188 |
189 | #----------------------------------------------------------------------------
190 | # Print summary table of module hierarchy.
191 |
192 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
193 | assert isinstance(module, torch.nn.Module)
194 | assert not isinstance(module, torch.jit.ScriptModule)
195 | assert isinstance(inputs, (tuple, list))
196 |
197 | # Register hooks.
198 | entries = []
199 | nesting = [0]
200 | def pre_hook(_mod, _inputs):
201 | nesting[0] += 1
202 | def post_hook(mod, _inputs, outputs):
203 | nesting[0] -= 1
204 | if nesting[0] <= max_nesting:
205 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
206 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
207 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
208 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
209 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
210 |
211 | # Run module.
212 | outputs = module(*inputs)
213 | for hook in hooks:
214 | hook.remove()
215 |
216 | # Identify unique outputs, parameters, and buffers.
217 | tensors_seen = set()
218 | for e in entries:
219 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
220 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
221 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
222 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
223 |
224 | # Filter out redundant entries.
225 | if skip_redundant:
226 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
227 |
228 | # Construct table.
229 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
230 | rows += [['---'] * len(rows[0])]
231 | param_total = 0
232 | buffer_total = 0
233 | submodule_names = {mod: name for name, mod in module.named_modules()}
234 | for e in entries:
235 | name = '' if e.mod is module else submodule_names[e.mod]
236 | param_size = sum(t.numel() for t in e.unique_params)
237 | buffer_size = sum(t.numel() for t in e.unique_buffers)
238 | output_shapes = [str(list(t.shape)) for t in e.outputs]
239 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
240 | rows += [[
241 | name + (':0' if len(e.outputs) >= 2 else ''),
242 | str(param_size) if param_size else '-',
243 | str(buffer_size) if buffer_size else '-',
244 | (output_shapes + ['-'])[0],
245 | (output_dtypes + ['-'])[0],
246 | ]]
247 | for idx in range(1, len(e.outputs)):
248 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
249 | param_total += param_size
250 | buffer_total += buffer_size
251 | rows += [['---'] * len(rows[0])]
252 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
253 |
254 | # Print table.
255 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
256 | print()
257 | for row in rows:
258 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
259 | print()
260 | return outputs
261 |
262 | #----------------------------------------------------------------------------
263 |
--------------------------------------------------------------------------------
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
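In plain PyTorch, the forward computation this extension fuses is roughly the following; a sketch consistent with the argument validation above, with illustrative lrelu defaults:

```python
import torch
import torch.nn.functional as F

def bias_act_ref(x, b=None, dim=1, act='lrelu', alpha=0.2, gain=2 ** 0.5, clamp=None):
    if b is not None:                     # broadcast bias along dimension `dim`
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
    if act == 'lrelu':
        x = F.leaky_relu(x, alpha)
    x = x * gain                          # e.g. sqrt(2) to preserve magnitude
    if clamp is not None and clamp >= 0:  # optional symmetric clamp
        x = x.clamp(-clamp, clamp)
    return x
```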
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double> { typedef double scalar_t; };
17 | template <> struct InternalType<float> { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
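
Note on the kernel above: the template parameter A selects the activation (matching `cuda_idx` in bias_act.py), and `p.grad` selects the pass: G == 0 evaluates the forward value, G == 1 the first derivative, and G == 2 the extra factor for second-order gradients, each expressed through the saved reference tensor (`yref`, or `xref` for swish). As a minimal sketch, not repo code, the 'lrelu' path (A == 3) reduces to the following torch equivalent:

    import torch

    def lrelu_fwd_bwd(x, b, dy, alpha=0.2, gain=2 ** 0.5):
        z = x + b                                    # apply bias (G == 0 adds b to x)
        y = torch.where(z > 0, z, z * alpha) * gain  # forward value, G == 0
        slope = torch.where(y > 0, torch.ones_like(y), torch.full_like(y, alpha))
        dx = dy * gain * slope                       # first gradient, G == 1 (ref = y)
        return y, dx
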
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import warnings
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | import traceback
17 |
18 | from .. import custom_ops
19 | from .. import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | activation_funcs = {
24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33 | }
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | _inited = False
38 | _plugin = None
39 | _null_tensor = torch.empty([0])
40 |
41 | def _init():
42 | global _inited, _plugin
43 | if not _inited:
44 | _inited = True
45 | sources = ['bias_act.cpp', 'bias_act.cu']
46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47 | try:
48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49 | except:
50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51 | return _plugin is not None
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56 | r"""Fused bias and activation function.
57 |
58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59 | and scales the result by `gain`. Each of the steps is optional. In most cases,
60 | the fused op is considerably more efficient than performing the same calculation
61 | using standard PyTorch ops. It supports first and second order gradients,
62 | but not third order gradients.
63 |
64 | Args:
65 | x: Input activation tensor. Can be of any shape.
66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67 | as `x`. The shape must be known, and it must match the dimension of `x`
68 | corresponding to `dim`.
69 | dim: The dimension in `x` corresponding to the elements of `b`.
70 | The value of `dim` is ignored if `b` is not specified.
71 | act: Name of the activation function to evaluate, or `"linear"` to disable.
72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73 | See `activation_funcs` for a full list. `None` is not allowed.
74 | alpha: Shape parameter for the activation function, or `None` to use the default.
75 | gain: Scaling factor for the output tensor, or `None` to use default.
76 | See `activation_funcs` for the default scaling of each activation function.
77 | If unsure, consider specifying 1.
78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79 | the clamping (default).
80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81 |
82 | Returns:
83 | Tensor of the same shape and datatype as `x`.
84 | """
85 | assert isinstance(x, torch.Tensor)
86 | assert impl in ['ref', 'cuda']
87 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90 |
91 | #----------------------------------------------------------------------------
92 |
93 | @misc.profiled_function
94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
95 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
96 | """
97 | assert isinstance(x, torch.Tensor)
98 | assert clamp is None or clamp >= 0
99 | spec = activation_funcs[act]
100 | alpha = float(alpha if alpha is not None else spec.def_alpha)
101 | gain = float(gain if gain is not None else spec.def_gain)
102 | clamp = float(clamp if clamp is not None else -1)
103 |
104 | # Add bias.
105 | if b is not None:
106 | assert isinstance(b, torch.Tensor) and b.ndim == 1
107 | assert 0 <= dim < x.ndim
108 | assert b.shape[0] == x.shape[dim]
109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
110 |
111 | # Evaluate activation function.
112 | alpha = float(alpha)
113 | x = spec.func(x, alpha=alpha)
114 |
115 | # Scale by gain.
116 | gain = float(gain)
117 | if gain != 1:
118 | x = x * gain
119 |
120 | # Clamp.
121 | if clamp >= 0:
122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
123 | return x
124 |
125 | #----------------------------------------------------------------------------
126 |
127 | _bias_act_cuda_cache = dict()
128 |
129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
130 | """Fast CUDA implementation of `bias_act()` using custom ops.
131 | """
132 | # Parse arguments.
133 | assert clamp is None or clamp >= 0
134 | spec = activation_funcs[act]
135 | alpha = float(alpha if alpha is not None else spec.def_alpha)
136 | gain = float(gain if gain is not None else spec.def_gain)
137 | clamp = float(clamp if clamp is not None else -1)
138 |
139 | # Lookup from cache.
140 | key = (dim, act, alpha, gain, clamp)
141 | if key in _bias_act_cuda_cache:
142 | return _bias_act_cuda_cache[key]
143 |
144 | # Forward op.
145 | class BiasActCuda(torch.autograd.Function):
146 | @staticmethod
147 | def forward(ctx, x, b): # pylint: disable=arguments-differ
148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
149 | x = x.contiguous(memory_format=ctx.memory_format)
150 | b = b.contiguous() if b is not None else _null_tensor
151 | y = x
152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
154 | ctx.save_for_backward(
155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157 | y if 'y' in spec.ref else _null_tensor)
158 | return y
159 |
160 | @staticmethod
161 | def backward(ctx, dy): # pylint: disable=arguments-differ
162 | dy = dy.contiguous(memory_format=ctx.memory_format)
163 | x, b, y = ctx.saved_tensors
164 | dx = None
165 | db = None
166 |
167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
168 | dx = dy
169 | if act != 'linear' or gain != 1 or clamp >= 0:
170 | dx = BiasActCudaGrad.apply(dy, x, b, y)
171 |
172 | if ctx.needs_input_grad[1]:
173 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
174 |
175 | return dx, db
176 |
177 | # Backward op.
178 | class BiasActCudaGrad(torch.autograd.Function):
179 | @staticmethod
180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
183 | ctx.save_for_backward(
184 | dy if spec.has_2nd_grad else _null_tensor,
185 | x, b, y)
186 | return dx
187 |
188 | @staticmethod
189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
191 | dy, x, b, y = ctx.saved_tensors
192 | d_dy = None
193 | d_x = None
194 | d_b = None
195 | d_y = None
196 |
197 | if ctx.needs_input_grad[0]:
198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
199 |
200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
202 |
203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
205 |
206 | return d_dy, d_x, d_b, d_y
207 |
208 | # Add to cache.
209 | _bias_act_cuda_cache[key] = BiasActCuda
210 | return BiasActCuda
211 |
212 | #----------------------------------------------------------------------------
213 |
--------------------------------------------------------------------------------
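
A minimal usage sketch of `bias_act()` above; `impl='ref'` keeps it runnable without building the CUDA plugin, while GPU tensors with `impl='cuda'` take the compiled path:

    import torch
    from torch_utils.ops import bias_act

    x = torch.randn(4, 8, 16, 16)                         # activations, channels at dim=1
    b = torch.randn(8)                                    # one bias per channel
    y = bias_act.bias_act(x, b, act='lrelu', impl='ref')  # fused bias + leaky ReLU + gain
    assert y.shape == x.shape and y.dtype == x.dtype
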
/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import warnings
13 | import contextlib
14 | import torch
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24 |
25 | @contextlib.contextmanager
26 | def no_weight_gradients():
27 | global weight_gradients_disabled
28 | old = weight_gradients_disabled
29 | weight_gradients_disabled = True
30 | yield
31 | weight_gradients_disabled = old
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36 | if _should_use_custom_op(input):
37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39 |
40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41 | if _should_use_custom_op(input):
42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | def _should_use_custom_op(input):
48 | assert isinstance(input, torch.Tensor)
49 | if (not enabled) or (not torch.backends.cudnn.enabled):
50 | return False
51 | if input.device.type != 'cuda':
52 | return False
53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
54 | return True
55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56 | return False
57 |
58 | def _tuple_of_ints(xs, ndim):
59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60 | assert len(xs) == ndim
61 | assert all(isinstance(x, int) for x in xs)
62 | return xs
63 |
64 | #----------------------------------------------------------------------------
65 |
66 | _conv2d_gradfix_cache = dict()
67 |
68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69 | # Parse arguments.
70 | ndim = 2
71 | weight_shape = tuple(weight_shape)
72 | stride = _tuple_of_ints(stride, ndim)
73 | padding = _tuple_of_ints(padding, ndim)
74 | output_padding = _tuple_of_ints(output_padding, ndim)
75 | dilation = _tuple_of_ints(dilation, ndim)
76 |
77 | # Lookup from cache.
78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79 | if key in _conv2d_gradfix_cache:
80 | return _conv2d_gradfix_cache[key]
81 |
82 | # Validate arguments.
83 | assert groups >= 1
84 | assert len(weight_shape) == ndim + 2
85 | assert all(stride[i] >= 1 for i in range(ndim))
86 | assert all(padding[i] >= 0 for i in range(ndim))
87 | assert all(dilation[i] >= 0 for i in range(ndim))
88 | if not transpose:
89 | assert all(output_padding[i] == 0 for i in range(ndim))
90 | else: # transpose
91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92 |
93 | # Helpers.
94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95 | def calc_output_padding(input_shape, output_shape):
96 | if transpose:
97 | return [0, 0]
98 | return [
99 | input_shape[i + 2]
100 | - (output_shape[i + 2] - 1) * stride[i]
101 | - (1 - 2 * padding[i])
102 | - dilation[i] * (weight_shape[i + 2] - 1)
103 | for i in range(ndim)
104 | ]
105 |
106 | # Forward & backward.
107 | class Conv2d(torch.autograd.Function):
108 | @staticmethod
109 | def forward(ctx, input, weight, bias):
110 | assert weight.shape == weight_shape
111 | if not transpose:
112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113 | else: # transpose
114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115 | ctx.save_for_backward(input, weight)
116 | return output
117 |
118 | @staticmethod
119 | def backward(ctx, grad_output):
120 | input, weight = ctx.saved_tensors
121 | grad_input = None
122 | grad_weight = None
123 | grad_bias = None
124 |
125 | if ctx.needs_input_grad[0]:
126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128 | assert grad_input.shape == input.shape
129 |
130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
132 | assert grad_weight.shape == weight_shape
133 |
134 | if ctx.needs_input_grad[2]:
135 | grad_bias = grad_output.sum([0, 2, 3])
136 |
137 | return grad_input, grad_weight, grad_bias
138 |
139 | # Gradient with respect to the weights.
140 | class Conv2dGradWeight(torch.autograd.Function):
141 | @staticmethod
142 | def forward(ctx, grad_output, input):
143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146 | assert grad_weight.shape == weight_shape
147 | ctx.save_for_backward(grad_output, input)
148 | return grad_weight
149 |
150 | @staticmethod
151 | def backward(ctx, grad2_grad_weight):
152 | grad_output, input = ctx.saved_tensors
153 | grad2_grad_output = None
154 | grad2_input = None
155 |
156 | if ctx.needs_input_grad[0]:
157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158 | assert grad2_grad_output.shape == grad_output.shape
159 |
160 | if ctx.needs_input_grad[1]:
161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163 | assert grad2_input.shape == input.shape
164 |
165 | return grad2_grad_output, grad2_input
166 |
167 | _conv2d_gradfix_cache[key] = Conv2d
168 | return Conv2d
169 |
170 | #----------------------------------------------------------------------------
171 |
--------------------------------------------------------------------------------
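
A short sketch of how the module above is switched on and used; on CPU or an unsupported PyTorch version it silently falls back to `torch.nn.functional.conv2d()`, so this runs either way:

    import torch
    from torch_utils.ops import conv2d_gradfix

    conv2d_gradfix.enabled = True                    # opt in to the custom op
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    w = torch.randn(4, 3, 3, 3, requires_grad=True)

    with conv2d_gradfix.no_weight_gradients():       # skip dL/dw, e.g. in a gradient-penalty pass
        y = conv2d_gradfix.conv2d(x, w, padding=1)
    (g,) = torch.autograd.grad(y.sum(), [x], create_graph=True)  # graph supports higher orders
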
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | w = w.flip([2, 3])
37 |
38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42 | if out_channels <= 4 and groups == 1:
43 | in_shape = x.shape
44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46 | else:
47 | x = x.to(memory_format=torch.contiguous_format)
48 | w = w.to(memory_format=torch.contiguous_format)
49 | x = conv2d_gradfix.conv2d(x, w, groups=groups)
50 | return x.to(memory_format=torch.channels_last)
51 |
52 | # Otherwise => execute using conv2d_gradfix.
53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54 | return op(x, w, stride=stride, padding=padding, groups=groups)
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | @misc.profiled_function
59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60 | r"""2D convolution with optional up/downsampling.
61 |
62 | Padding is performed only once at the beginning, not between the operations.
63 |
64 | Args:
65 | x: Input tensor of shape
66 | `[batch_size, in_channels, in_height, in_width]`.
67 | w: Weight tensor of shape
68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70 | calling upfirdn2d.setup_filter(). None = identity (default).
71 | up: Integer upsampling factor (default: 1).
72 | down: Integer downsampling factor (default: 1).
73 | padding: Padding with respect to the upsampled image. Can be a single number
74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75 | (default: 0).
76 | groups: Split input channels into N groups (default: 1).
77 | flip_weight: False = convolution, True = correlation (default: True).
78 | flip_filter: False = convolution, True = correlation (default: False).
79 |
80 | Returns:
81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82 | """
83 | # Validate arguments.
84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87 | assert isinstance(up, int) and (up >= 1)
88 | assert isinstance(down, int) and (down >= 1)
89 | assert isinstance(groups, int) and (groups >= 1)
90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91 | fw, fh = _get_filter_size(f)
92 | px0, px1, py0, py1 = _parse_padding(padding)
93 |
94 | # Adjust padding to account for up/downsampling.
95 | if up > 1:
96 | px0 += (fw + up - 1) // 2
97 | px1 += (fw - up) // 2
98 | py0 += (fh + up - 1) // 2
99 | py1 += (fh - up) // 2
100 | if down > 1:
101 | px0 += (fw - down + 1) // 2
102 | px1 += (fw - down) // 2
103 | py0 += (fh - down + 1) // 2
104 | py1 += (fh - down) // 2
105 |
106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110 | return x
111 |
112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116 | return x
117 |
118 | # Fast path: downsampling only => use strided convolution.
119 | if down > 1 and up == 1:
120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122 | return x
123 |
124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125 | if up > 1:
126 | if groups == 1:
127 | w = w.transpose(0, 1)
128 | else:
129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130 | w = w.transpose(1, 2)
131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132 | px0 -= kw - 1
133 | px1 -= kw - up
134 | py0 -= kh - 1
135 | py1 -= kh - up
136 | pxt = max(min(-px0, -px1), 0)
137 | pyt = max(min(-py0, -py1), 0)
138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140 | if down > 1:
141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142 | return x
143 |
144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145 | if up == 1 and down == 1:
146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148 |
149 | # Fallback: Generic reference implementation.
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152 | if down > 1:
153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154 | return x
155 |
156 | #----------------------------------------------------------------------------
157 |
--------------------------------------------------------------------------------
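
A runnable sketch of the calling convention documented above: the filter is prepared with `upfirdn2d.setup_filter()` as the docstring requires, and with `up=2` plus `kernel_size//2` padding the spatial resolution doubles:

    import torch
    from torch_utils.ops import upfirdn2d, conv2d_resample

    f = upfirdn2d.setup_filter([1, 3, 3, 1])         # separable low-pass filter
    x = torch.randn(2, 8, 16, 16)
    w = torch.randn(16, 8, 3, 3)                     # [out, in, kh, kw]
    y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
    print(y.shape)                                   # expected: [2, 16, 32, 32]
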
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
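
A small sketch of the broadcasting behaviour that `_unbroadcast()` handles in the backward pass: the incoming gradient has the broadcast shape, and is summed back down to each input's own shape:

    import torch
    from torch_utils.ops import fma

    a = torch.randn(4, 1, 8, requires_grad=True)     # broadcasts against b
    b = torch.randn(1, 6, 8, requires_grad=True)
    c = torch.randn(4, 6, 8, requires_grad=True)
    out = fma.fma(a, b, c)                           # a * b + c, shape [4, 6, 8]
    out.sum().backward()
    assert a.grad.shape == a.shape and b.grad.shape == b.shape  # grads un-broadcast
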
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import warnings
15 | import torch
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | def grid_sample(input, grid):
28 | if _should_use_custom_op():
29 | return _GridSample2dForward.apply(input, grid)
30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def _should_use_custom_op():
35 | if not enabled:
36 | return False
37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
38 | return True
39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40 | return False
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | class _GridSample2dForward(torch.autograd.Function):
45 | @staticmethod
46 | def forward(ctx, input, grid):
47 | assert input.ndim == 4
48 | assert grid.ndim == 4
49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50 | ctx.save_for_backward(input, grid)
51 | return output
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | input, grid = ctx.saved_tensors
56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57 | return grad_input, grad_grid
58 |
59 | #----------------------------------------------------------------------------
60 |
61 | class _GridSample2dBackward(torch.autograd.Function):
62 | @staticmethod
63 | def forward(ctx, grad_output, input, grid):
64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 | ctx.save_for_backward(grid)
67 | return grad_input, grad_grid
68 |
69 | @staticmethod
70 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 | _ = grad2_grad_grid # unused
72 | grid, = ctx.saved_tensors
73 | grad2_grad_output = None
74 | grad2_input = None
75 | grad2_grid = None
76 |
77 | if ctx.needs_input_grad[0]:
78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 |
80 | assert not ctx.needs_input_grad[2]
81 | return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
--------------------------------------------------------------------------------
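
A sketch of the intended use; the input/grid layout follows `torch.nn.functional.grid_sample`, and on unsupported PyTorch versions the call falls back to the stock op:

    import torch
    from torch_utils.ops import grid_sample_gradfix

    grid_sample_gradfix.enabled = True
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    theta = torch.eye(2, 3).unsqueeze(0)             # identity affine transform
    grid = torch.nn.functional.affine_grid(theta, [1, 3, 8, 8], align_corners=False)
    y = grid_sample_gradfix.grid_sample(x, grid)
    (g,) = torch.autograd.grad(y.sum(), [x], create_graph=True)  # higher-order capable
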
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29 |
30 | // Create output tensor.
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37 |
38 | // Initialize CUDA kernel parameters.
39 | upfirdn2d_kernel_params p;
40 | p.x = x.data_ptr();
41 | p.f = f.data_ptr();
42 | p.y = y.data_ptr();
43 | p.up = make_int2(upx, upy);
44 | p.down = make_int2(downx, downy);
45 | p.pad0 = make_int2(padx0, pady0);
46 | p.flip = (flip) ? 1 : 0;
47 | p.gain = gain;
48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56 |
57 | // Choose CUDA kernel.
58 | upfirdn2d_kernel_spec spec;
59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60 | {
61 | spec = choose_upfirdn2d_kernel<scalar_t>(p);
62 | });
63 |
64 | // Set looping options.
65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66 | p.loopMinor = spec.loopMinor;
67 | p.loopX = spec.loopX;
68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70 |
71 | // Compute grid size.
72 | dim3 blockSize, gridSize;
73 | if (spec.tileOutW < 0) // large
74 | {
75 | blockSize = dim3(4, 32, 1);
76 | gridSize = dim3(
77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79 | p.launchMajor);
80 | }
81 | else // small
82 | {
83 | blockSize = dim3(256, 1, 1);
84 | gridSize = dim3(
85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87 | p.launchMajor);
88 | }
89 |
90 | // Launch CUDA kernel.
91 | void* args[] = {&p};
92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 | return y;
94 | }
95 |
96 | //------------------------------------------------------------------------
97 |
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 | m.def("upfirdn2d", &upfirdn2d);
101 | }
102 |
103 | //------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
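
The outW/outH computation above (lines 32-33) encodes the standard upfirdn geometry: upsample by zero insertion, pad, convolve with the FIR filter, then downsample. A Python mirror of that formula, as a sketch:

    def upfirdn2d_out_size(in_w, in_h, fw, fh, upx=1, upy=1, downx=1, downy=1,
                           padx0=0, padx1=0, pady0=0, pady1=0):
        # Mirrors outW/outH in upfirdn2d.cpp: upsample, pad, filter, downsample.
        out_w = (in_w * upx + padx0 + padx1 - fw + downx) // downx
        out_h = (in_h * upy + pady0 + pady1 - fh + downy) // downy
        return out_w, out_h

    print(upfirdn2d_out_size(16, 16, 4, 4, upx=2, upy=2, padx0=2, padx1=1, pady0=2, pady1=1))
    # -> (32, 32): 2x upsampling with a 4-tap filter and asymmetric padding [2, 1]
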
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for pickling Python code alongside other data.
10 |
11 | The pickled code is automatically imported into a separate Python module
12 | during unpickling. This way, any previously exported pickles will remain
13 | usable even if the original code is no longer available, or if the current
14 | version of the code is not consistent with what was originally pickled."""
15 |
16 | import sys
17 | import pickle
18 | import io
19 | import inspect
20 | import copy
21 | import uuid
22 | import types
23 | import dnnlib
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | _version = 6 # internal version number
28 | _decorators = set() # {decorator_class, ...}
29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...}
31 | _src_to_module_dict = dict() # {src: module, ...}
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def persistent_class(orig_class):
36 | r"""Class decorator that extends a given class to save its source code
37 | when pickled.
38 |
39 | Example:
40 |
41 | from torch_utils import persistence
42 |
43 | @persistence.persistent_class
44 | class MyNetwork(torch.nn.Module):
45 | def __init__(self, num_inputs, num_outputs):
46 | super().__init__()
47 | self.fc = MyLayer(num_inputs, num_outputs)
48 | ...
49 |
50 | @persistence.persistent_class
51 | class MyLayer(torch.nn.Module):
52 | ...
53 |
54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55 | source code alongside other internal state (e.g., parameters, buffers,
56 | and submodules). This way, any previously exported pickle will remain
57 | usable even if the class definitions have been modified or are no
58 | longer available.
59 |
60 | The decorator saves the source code of the entire Python module
61 | containing the decorated class. It does *not* save the source code of
62 | any imported modules. Thus, the imported modules must be available
63 | during unpickling, including `torch_utils.persistence` itself.
64 |
65 | It is ok to call functions defined in the same module from the
66 | decorated class. However, if the decorated class depends on other
67 | classes defined in the same module, they must be decorated as well.
68 | This is illustrated in the above example in the case of `MyLayer`.
69 |
70 | It is also possible to employ the decorator just-in-time before
71 | calling the constructor. For example:
72 |
73 | cls = MyLayer
74 | if want_to_make_it_persistent:
75 | cls = persistence.persistent_class(cls)
76 | layer = cls(num_inputs, num_outputs)
77 |
78 | As an additional feature, the decorator also keeps track of the
79 | arguments that were used to construct each instance of the decorated
80 | class. The arguments can be queried via `obj.init_args` and
81 | `obj.init_kwargs`, and they are automatically pickled alongside other
82 | object state. A typical use case is to first unpickle a previous
83 | instance of a persistent class, and then upgrade it to use the latest
84 | version of the source code:
85 |
86 | with open('old_pickle.pkl', 'rb') as f:
87 | old_net = pickle.load(f)
88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 | """
91 | assert isinstance(orig_class, type)
92 | if is_persistent(orig_class):
93 | return orig_class
94 |
95 | assert orig_class.__module__ in sys.modules
96 | orig_module = sys.modules[orig_class.__module__]
97 | orig_module_src = _module_to_src(orig_module)
98 |
99 | class Decorator(orig_class):
100 | _orig_module_src = orig_module_src
101 | _orig_class_name = orig_class.__name__
102 |
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 | self._init_args = copy.deepcopy(args)
106 | self._init_kwargs = copy.deepcopy(kwargs)
107 | assert orig_class.__name__ in orig_module.__dict__
108 | _check_pickleable(self.__reduce__())
109 |
110 | @property
111 | def init_args(self):
112 | return copy.deepcopy(self._init_args)
113 |
114 | @property
115 | def init_kwargs(self):
116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 |
118 | def __reduce__(self):
119 | fields = list(super().__reduce__())
120 | fields += [None] * max(3 - len(fields), 0)
121 | if fields[0] is not _reconstruct_persistent_obj:
122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 | fields[0] = _reconstruct_persistent_obj # reconstruct func
124 | fields[1] = (meta,) # reconstruct args
125 | fields[2] = None # state dict
126 | return tuple(fields)
127 |
128 | Decorator.__name__ = orig_class.__name__
129 | _decorators.add(Decorator)
130 | return Decorator
131 |
132 | #----------------------------------------------------------------------------
133 |
134 | def is_persistent(obj):
135 | r"""Test whether the given object or class is persistent, i.e.,
136 | whether it will save its source code when pickled.
137 | """
138 | try:
139 | if obj in _decorators:
140 | return True
141 | except TypeError:
142 | pass
143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 |
145 | #----------------------------------------------------------------------------
146 |
147 | def import_hook(hook):
148 | r"""Register an import hook that is called whenever a persistent object
149 | is being unpickled. A typical use case is to patch the pickled source
150 | code to avoid errors and inconsistencies when the API of some imported
151 | module has changed.
152 |
153 | The hook should have the following signature:
154 |
155 | hook(meta) -> modified meta
156 |
157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 |
159 | type: Type of the persistent object, e.g. `'class'`.
160 | version: Internal version number of `torch_utils.persistence`.
161 | module_src: Original source code of the Python module.
162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object.
164 |
165 | Example:
166 |
167 | @persistence.import_hook
168 | def wreck_my_network(meta):
169 | if meta.class_name == 'MyNetwork':
170 | print('MyNetwork is being imported. I will wreck it!')
171 | meta.module_src = meta.module_src.replace("True", "False")
172 | return meta
173 | """
174 | assert callable(hook)
175 | _import_hooks.append(hook)
176 |
177 | #----------------------------------------------------------------------------
178 |
179 | def _reconstruct_persistent_obj(meta):
180 | r"""Hook that is called internally by the `pickle` module to unpickle
181 | a persistent object.
182 | """
183 | meta = dnnlib.EasyDict(meta)
184 | meta.state = dnnlib.EasyDict(meta.state)
185 | for hook in _import_hooks:
186 | meta = hook(meta)
187 | assert meta is not None
188 |
189 | assert meta.version == _version
190 | module = _src_to_module(meta.module_src)
191 |
192 | assert meta.type == 'class'
193 | orig_class = module.__dict__[meta.class_name]
194 | decorator_class = persistent_class(orig_class)
195 | obj = decorator_class.__new__(decorator_class)
196 |
197 | setstate = getattr(obj, '__setstate__', None)
198 | if callable(setstate):
199 | setstate(meta.state) # pylint: disable=not-callable
200 | else:
201 | obj.__dict__.update(meta.state)
202 | return obj
203 |
204 | #----------------------------------------------------------------------------
205 |
206 | def _module_to_src(module):
207 | r"""Query the source code of a given Python module.
208 | """
209 | src = _module_to_src_dict.get(module, None)
210 | if src is None:
211 | src = inspect.getsource(module)
212 | _module_to_src_dict[module] = src
213 | _src_to_module_dict[src] = module
214 | return src
215 |
216 | def _src_to_module(src):
217 | r"""Get or create a Python module for the given source code.
218 | """
219 | module = _src_to_module_dict.get(src, None)
220 | if module is None:
221 | module_name = "_imported_module_" + uuid.uuid4().hex
222 | module = types.ModuleType(module_name)
223 | sys.modules[module_name] = module
224 | _module_to_src_dict[module] = src
225 | _src_to_module_dict[src] = module
226 | exec(src, module.__dict__) # pylint: disable=exec-used
227 | return module
228 |
229 | #----------------------------------------------------------------------------
230 |
231 | def _check_pickleable(obj):
232 | r"""Check that the given object is pickleable, raising an exception if
233 | it is not. This function is expected to be considerably more efficient
234 | than actually pickling the object.
235 | """
236 | def recurse(obj):
237 | if isinstance(obj, (list, tuple, set)):
238 | return [recurse(x) for x in obj]
239 | if isinstance(obj, dict):
240 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 | return None # Python primitive types are pickleable.
243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244 | return None # NumPy arrays and PyTorch tensors are pickleable.
245 | if is_persistent(obj):
246 | return None # Persistent objects are pickleable, by virtue of the constructor check.
247 | return obj
248 | with io.BytesIO() as f:
249 | pickle.dump(recurse(obj), f)
250 |
251 | #----------------------------------------------------------------------------
252 |
--------------------------------------------------------------------------------
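
A round-trip sketch of the decorator above, using a hypothetical `Scale` class; it assumes the class is defined in a module file whose source `inspect.getsource()` can retrieve:

    import io, pickle, torch
    from torch_utils import persistence

    @persistence.persistent_class
    class Scale(torch.nn.Module):        # hypothetical example class
        def __init__(self, gain):
            super().__init__()
            self.gain = gain
        def forward(self, x):
            return x * self.gain

    buf = io.BytesIO()
    pickle.dump(Scale(2.0), buf)         # module source rides along in the pickle
    buf.seek(0)
    net = pickle.load(buf)               # re-imported into a fresh module on unpickling
    print(net.init_args)                 # (2.0,), recorded by the decorator
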
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for reporting and collecting training statistics across
10 | multiple processes and devices. The interface is designed to minimize
11 | synchronization overhead as well as the amount of boilerplate in user
12 | code."""
13 |
14 | import re
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | from . import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
26 | _rank = 0 # Rank of the current process.
27 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
28 | _sync_called = False # Has _sync() been called yet?
29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def init_multiprocessing(rank, sync_device):
35 | r"""Initializes `torch_utils.training_stats` for collecting statistics
36 | across multiple processes.
37 |
38 | This function must be called after
39 | `torch.distributed.init_process_group()` and before `Collector.update()`.
40 | The call is not necessary if multi-process collection is not needed.
41 |
42 | Args:
43 | rank: Rank of the current process.
44 | sync_device: PyTorch device to use for inter-process
45 | communication, or None to disable multi-process
46 | collection. Typically `torch.device('cuda', rank)`.
47 | """
48 | global _rank, _sync_device
49 | assert not _sync_called
50 | _rank = rank
51 | _sync_device = sync_device
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | @misc.profiled_function
56 | def report(name, value):
57 | r"""Broadcasts the given set of scalars to all interested instances of
58 | `Collector`, across device and process boundaries.
59 |
60 | This function is expected to be extremely cheap and can be safely
61 | called from anywhere in the training loop, loss function, or inside a
62 | `torch.nn.Module`.
63 |
64 | Warning: The current implementation expects the set of unique names to
65 | be consistent across processes. Please make sure that `report()` is
66 | called at least once for each unique name by each process, and in the
67 | same order. If a given process has no scalars to broadcast, it can do
68 | `report(name, [])` (empty list).
69 |
70 | Args:
71 | name: Arbitrary string specifying the name of the statistic.
72 | Averages are accumulated separately for each unique name.
73 | value: Arbitrary set of scalars. Can be a list, tuple,
74 | NumPy array, PyTorch tensor, or Python scalar.
75 |
76 | Returns:
77 | The same `value` that was passed in.
78 | """
79 | if name not in _counters:
80 | _counters[name] = dict()
81 |
82 | elems = torch.as_tensor(value)
83 | if elems.numel() == 0:
84 | return value
85 |
86 | elems = elems.detach().flatten().to(_reduce_dtype)
87 | moments = torch.stack([
88 | torch.ones_like(elems).sum(),
89 | elems.sum(),
90 | elems.square().sum(),
91 | ])
92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
93 | moments = moments.to(_counter_dtype)
94 |
95 | device = moments.device
96 | if device not in _counters[name]:
97 | _counters[name][device] = torch.zeros_like(moments)
98 | _counters[name][device].add_(moments)
99 | return value
100 |
101 | #----------------------------------------------------------------------------
102 |
103 | def report0(name, value):
104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105 | but ignores any scalars provided by the other processes.
106 | See `report()` for further details.
107 | """
108 | report(name, value if _rank == 0 else [])
109 | return value
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | class Collector:
114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
115 | computes their long-term averages (mean and standard deviation) over
116 | user-defined periods of time.
117 |
118 | The averages are first collected into internal counters that are not
119 | directly visible to the user. They are then copied to the user-visible
120 | state as a result of calling `update()` and can then be queried using
121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122 | internal counters for the next round, so that the user-visible state
123 | effectively reflects averages collected between the last two calls to
124 | `update()`.
125 |
126 | Args:
127 | regex: Regular expression defining which statistics to
128 | collect. The default is to collect everything.
129 | keep_previous: Whether to retain the previous averages if no
130 | scalars were collected on a given round
131 | (default: True).
132 | """
133 | def __init__(self, regex='.*', keep_previous=True):
134 | self._regex = re.compile(regex)
135 | self._keep_previous = keep_previous
136 | self._cumulative = dict()
137 | self._moments = dict()
138 | self.update()
139 | self._moments.clear()
140 |
141 | def names(self):
142 | r"""Returns the names of all statistics broadcasted so far that
143 | match the regular expression specified at construction time.
144 | """
145 | return [name for name in _counters if self._regex.fullmatch(name)]
146 |
147 | def update(self):
148 | r"""Copies current values of the internal counters to the
149 | user-visible state and resets them for the next round.
150 |
151 | If `keep_previous=True` was specified at construction time, the
152 | operation is skipped for statistics that have received no scalars
153 | since the last update, retaining their previous averages.
154 |
155 | This method performs a number of GPU-to-CPU transfers and one
156 | `torch.distributed.all_reduce()`. It is intended to be called
157 | periodically in the main training loop, typically once every
158 | N training steps.
159 | """
160 | if not self._keep_previous:
161 | self._moments.clear()
162 | for name, cumulative in _sync(self.names()):
163 | if name not in self._cumulative:
164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165 | delta = cumulative - self._cumulative[name]
166 | self._cumulative[name].copy_(cumulative)
167 | if float(delta[0]) != 0:
168 | self._moments[name] = delta
169 |
170 | def _get_delta(self, name):
171 | r"""Returns the raw moments that were accumulated for the given
172 | statistic between the last two calls to `update()`, or zero if
173 | no scalars were collected.
174 | """
175 | assert self._regex.fullmatch(name)
176 | if name not in self._moments:
177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178 | return self._moments[name]
179 |
180 | def num(self, name):
181 | r"""Returns the number of scalars that were accumulated for the given
182 | statistic between the last two calls to `update()`, or zero if
183 | no scalars were collected.
184 | """
185 | delta = self._get_delta(name)
186 | return int(delta[0])
187 |
188 | def mean(self, name):
189 | r"""Returns the mean of the scalars that were accumulated for the
190 | given statistic between the last two calls to `update()`, or NaN if
191 | no scalars were collected.
192 | """
193 | delta = self._get_delta(name)
194 | if int(delta[0]) == 0:
195 | return float('nan')
196 | return float(delta[1] / delta[0])
197 |
198 | def std(self, name):
199 | r"""Returns the standard deviation of the scalars that were
200 | accumulated for the given statistic between the last two calls to
201 | `update()`, or NaN if no scalars were collected.
202 | """
203 | delta = self._get_delta(name)
204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205 | return float('nan')
206 | if int(delta[0]) == 1:
207 | return float(0)
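        # delta holds the raw moments [count, sum(x), sum(x**2)], so the
        # variance follows from Var[x] = E[x**2] - E[x]**2; max(..., 0)
        # absorbs negative floating-point round-off.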
208 | mean = float(delta[1] / delta[0])
209 | raw_var = float(delta[2] / delta[0])
210 | return np.sqrt(max(raw_var - np.square(mean), 0))
211 |
212 | def as_dict(self):
213 | r"""Returns the averages accumulated between the last two calls to
214 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
215 |
216 | dnnlib.EasyDict(
217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218 | ...
219 | )
220 | """
221 | stats = dnnlib.EasyDict()
222 | for name in self.names():
223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224 | return stats
225 |
226 | def __getitem__(self, name):
227 | r"""Convenience getter.
228 | `collector[name]` is a synonym for `collector.mean(name)`.
229 | """
230 | return self.mean(name)
231 |
232 | #----------------------------------------------------------------------------
233 |
234 | def _sync(names):
235 | r"""Synchronize the global cumulative counters across devices and
236 | processes. Called internally by `Collector.update()`.
237 | """
238 | if len(names) == 0:
239 | return []
240 | global _sync_called
241 | _sync_called = True
242 |
243 | # Collect deltas within current rank.
244 | deltas = []
245 | device = _sync_device if _sync_device is not None else torch.device('cpu')
246 | for name in names:
247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248 | for counter in _counters[name].values():
249 | delta.add_(counter.to(device))
250 | counter.copy_(torch.zeros_like(counter))
251 | deltas.append(delta)
252 | deltas = torch.stack(deltas)
253 |
254 | # Sum deltas across ranks.
255 | if _sync_device is not None:
256 | torch.distributed.all_reduce(deltas)
257 |
258 | # Update cumulative values.
259 | deltas = deltas.cpu()
260 | for idx, name in enumerate(names):
261 | if name not in _cumulative:
262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263 | _cumulative[name].add_(deltas[idx])
264 |
265 | # Return name-value pairs.
266 | return [(name, _cumulative[name]) for name in names]
267 |
268 | #----------------------------------------------------------------------------
269 |
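# A minimal usage sketch, assuming the single-process defaults (rank 0, no
# sync device); names and values are illustrative. Collector aggregates
# everything report() has broadcast since its previous update().
#
#     collector = Collector(regex='Loss/.*')
#     for step in range(4):
#         report('Loss/G/loss', [0.1 * step, 0.2 * step])  # two scalars/step
#     collector.update()                                   # one sync round
#     collector.num('Loss/G/loss')    # -> 8 scalars this round
#     collector.mean('Loss/G/loss')   # mean over the round
#     collector['Loss/G/loss']        # synonym for mean()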
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import numpy as np
11 | import zipfile
12 | import PIL.Image
13 | import json
14 | import torch
15 | import dnnlib
16 |
17 | try:
18 | import pyspng
19 | except ImportError:
20 | pyspng = None
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | class Dataset(torch.utils.data.Dataset):
25 | def __init__(self,
26 | name, # Name of the dataset.
27 | raw_shape, # Shape of the raw image data (NCHW).
28 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
29 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
30 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
31 | random_seed = 0, # Random seed to use when applying max_size.
32 | ):
33 | self._name = name
34 | self._raw_shape = list(raw_shape)
35 | self._use_labels = use_labels
36 | self._raw_labels = None
37 | self._label_shape = None
38 |
39 | # Apply max_size.
40 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
41 | if (max_size is not None) and (self._raw_idx.size > max_size):
42 | np.random.RandomState(random_seed).shuffle(self._raw_idx)
43 | self._raw_idx = np.sort(self._raw_idx[:max_size])
44 |
45 | # Apply xflip.
46 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
47 | if xflip:
48 | self._raw_idx = np.tile(self._raw_idx, 2)
49 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
50 |
51 | def _get_raw_labels(self):
52 | if self._raw_labels is None:
53 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
54 | if self._raw_labels is None:
55 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
56 | assert isinstance(self._raw_labels, np.ndarray)
57 | assert self._raw_labels.shape[0] == self._raw_shape[0]
58 | assert self._raw_labels.dtype in [np.float32, np.int64]
59 | if self._raw_labels.dtype == np.int64:
60 | assert self._raw_labels.ndim == 1
61 | assert np.all(self._raw_labels >= 0)
62 | return self._raw_labels
63 |
64 | def close(self): # to be overridden by subclass
65 | pass
66 |
67 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
68 | raise NotImplementedError
69 |
70 | def _load_raw_labels(self): # to be overridden by subclass
71 | raise NotImplementedError
72 |
73 | def __getstate__(self):
74 | return dict(self.__dict__, _raw_labels=None)
75 |
76 | def __del__(self):
77 | try:
78 | self.close()
79 | except:
80 | pass
81 |
82 | def __len__(self):
83 | return self._raw_idx.size
84 |
85 | def __getitem__(self, idx):
86 | image = self._load_raw_image(self._raw_idx[idx])
87 | assert isinstance(image, np.ndarray)
88 | assert list(image.shape) == self.image_shape
89 | assert image.dtype == np.uint8
90 | if self._xflip[idx]:
91 | assert image.ndim == 3 # CHW
92 | image = image[:, :, ::-1]
93 | return image.copy(), self.get_label(idx)
94 |
95 | def get_label(self, idx):
96 | label = self._get_raw_labels()[self._raw_idx[idx]]
97 | if label.dtype == np.int64:
98 | onehot = np.zeros(self.label_shape, dtype=np.float32)
99 | onehot[label] = 1
100 | label = onehot
101 | return label.copy()
102 |
103 | def get_details(self, idx):
104 | d = dnnlib.EasyDict()
105 | d.raw_idx = int(self._raw_idx[idx])
106 | d.xflip = (int(self._xflip[idx]) != 0)
107 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
108 | return d
109 |
110 | @property
111 | def name(self):
112 | return self._name
113 |
114 | @property
115 | def image_shape(self):
116 | return list(self._raw_shape[1:])
117 |
118 | @property
119 | def num_channels(self):
120 | assert len(self.image_shape) == 3 # CHW
121 | return self.image_shape[0]
122 |
123 | @property
124 | def resolution(self):
125 | assert len(self.image_shape) == 3 # CHW
126 | assert self.image_shape[1] == self.image_shape[2]
127 | return self.image_shape[1]
128 |
129 | @property
130 | def label_shape(self):
131 | if self._label_shape is None:
132 | raw_labels = self._get_raw_labels()
133 | if raw_labels.dtype == np.int64:
134 | self._label_shape = [int(np.max(raw_labels)) + 1]
135 | else:
136 | self._label_shape = raw_labels.shape[1:]
137 | return list(self._label_shape)
138 |
139 | @property
140 | def label_dim(self):
141 | assert len(self.label_shape) == 1
142 | return self.label_shape[0]
143 |
144 | @property
145 | def has_labels(self):
146 | return any(x != 0 for x in self.label_shape)
147 |
148 | @property
149 | def has_onehot_labels(self):
150 | return self._get_raw_labels().dtype == np.int64
151 |
152 | #----------------------------------------------------------------------------
153 |
154 | class ImageFolderDataset(Dataset):
155 | def __init__(self,
156 | path, # Path to directory or zip.
157 | resolution = None, # Ensure specific resolution, None = highest available.
158 | **super_kwargs, # Additional arguments for the Dataset base class.
159 | ):
160 | self._path = path
161 | self._zipfile = None
162 |
163 | if os.path.isdir(self._path):
164 | self._type = 'dir'
165 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
166 | elif self._file_ext(self._path) == '.zip':
167 | self._type = 'zip'
168 | self._all_fnames = set(self._get_zipfile().namelist())
169 | else:
170 | raise IOError('Path must point to a directory or zip')
171 |
172 | PIL.Image.init()
173 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
174 | if len(self._image_fnames) == 0:
175 | raise IOError('No image files found in the specified path')
176 |
177 | name = os.path.splitext(os.path.basename(self._path))[0]
178 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
179 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
180 | raise IOError('Image files do not match the specified resolution')
181 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
182 |
183 | @staticmethod
184 | def _file_ext(fname):
185 | return os.path.splitext(fname)[1].lower()
186 |
187 | def _get_zipfile(self):
188 | assert self._type == 'zip'
189 | if self._zipfile is None:
190 | self._zipfile = zipfile.ZipFile(self._path)
191 | return self._zipfile
192 |
193 | def _open_file(self, fname):
194 | if self._type == 'dir':
195 | return open(os.path.join(self._path, fname), 'rb')
196 | if self._type == 'zip':
197 | return self._get_zipfile().open(fname, 'r')
198 | return None
199 |
200 | def close(self):
201 | try:
202 | if self._zipfile is not None:
203 | self._zipfile.close()
204 | finally:
205 | self._zipfile = None
206 |
207 | def __getstate__(self):
208 | return dict(super().__getstate__(), _zipfile=None)
209 |
210 | def _load_raw_image(self, raw_idx):
211 | fname = self._image_fnames[raw_idx]
212 | with self._open_file(fname) as f:
213 | if pyspng is not None and self._file_ext(fname) == '.png':
214 | image = pyspng.load(f.read())
215 | else:
216 | image = np.array(PIL.Image.open(f))
217 | if image.ndim == 2:
218 | image = image[:, :, np.newaxis] # HW => HWC
219 | image = image.transpose(2, 0, 1) # HWC => CHW
220 | return image
221 |
222 | def _load_raw_labels(self):
223 | fname = 'dataset.json'
224 | if fname not in self._all_fnames:
225 | return None
226 | with self._open_file(fname) as f:
227 | labels = json.load(f)['labels']
228 | if labels is None:
229 | return None
230 | labels = dict(labels)
231 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
232 | labels = np.array(labels)
233 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
234 | return labels
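    # A sketch of the dataset.json layout this loader expects (filenames are
    # illustrative): a top-level "labels" list pairing each image file with
    # either an int64 class index (one-hot encoded later by get_label) or a
    # float32 vector.
    #
    #     {"labels": [["00000/img00000000.png", 6],
    #                 ["00000/img00000001.png", 3]]}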
235 |
236 | #----------------------------------------------------------------------------
237 | #----------------------------------------------------------------------------
238 |
239 | class PairedImageFolderDataset(Dataset):
240 | def __init__(self,
241 | path, # Path to directory or zip.
242 | resolution = None, # Ensure specific resolution, None = highest available.
243 | **super_kwargs, # Additional arguments for the Dataset base class.
244 | ):
245 | self._rootpath = path
246 | self._path = os.path.join(path, 'images')
247 | self._labelpath = os.path.join(path, 'annotations')
248 |
249 | self._zipfile = None
250 |
251 | if os.path.isdir(self._path):
252 | self._type = 'dir'
253 | # image path
254 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
255 | else:
256 |             raise IOError('Path must point to a directory')
257 |
258 | PIL.Image.init()
259 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
260 | if len(self._image_fnames) == 0:
261 | raise IOError('No image files found in the specified path')
262 |
263 | name = os.path.splitext(os.path.basename(self._path))[0]
264 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
265 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
266 | raise IOError('Image files do not match the specified resolution')
267 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
268 |
269 | @staticmethod
270 | def _file_ext(fname):
271 | return os.path.splitext(fname)[1].lower()
272 |
273 | def _get_zipfile(self):
274 | assert self._type == 'zip'
275 | if self._zipfile is None:
276 | self._zipfile = zipfile.ZipFile(self._path)
277 | return self._zipfile
278 |
279 | def _open_file(self, fname):
280 | if self._type == 'dir':
281 | return open(os.path.join(self._path, fname), 'rb')
282 | if self._type == 'zip':
283 | return self._get_zipfile().open(fname, 'r')
284 | return None
285 |
286 | def close(self):
287 | try:
288 | if self._zipfile is not None:
289 | self._zipfile.close()
290 | finally:
291 | self._zipfile = None
292 |
293 | def __getstate__(self):
294 | return dict(super().__getstate__(), _zipfile=None)
295 |
296 | def _load_raw_image(self, raw_idx):
297 | fname = self._image_fnames[raw_idx]
298 | with self._open_file(fname) as f:
299 | if pyspng is not None and self._file_ext(fname) == '.png':
300 | image = pyspng.load(f.read())
301 | else:
302 | image = np.array(PIL.Image.open(f))
303 | if image.ndim == 2:
304 | image = image[:, :, np.newaxis] # HW => HWC
305 | image = image.transpose(2, 0, 1) # HWC => CHW
306 | return image
307 |
308 | def _load_raw_labels(self):
309 | return None
310 |
311 | def _load_raw_labelmap(self, raw_idx):
312 |         fname = self._image_fnames[raw_idx].replace('.jpg', '.png')  # annotation map shares the image stem
313 | with open(os.path.join(self._labelpath, fname), 'rb') as f:
314 | image = np.array(PIL.Image.open(f))
315 | return image
316 |
317 | def __getitem__(self, idx):
318 | image = self._load_raw_image(self._raw_idx[idx])
319 | assert isinstance(image, np.ndarray)
320 | assert list(image.shape) == self.image_shape
321 | assert image.dtype == np.uint8
322 | label = self._load_raw_labelmap(self._raw_idx[idx])
323 | if self._xflip[idx]:
324 | assert image.ndim == 3 # CHW
325 | image = image[:, :, ::-1]
326 | label = label[:, ::-1]
327 | return image.copy(), label.copy()
328 |
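# A minimal sketch of the on-disk layout PairedImageFolderDataset expects,
# inferred from the paths above (filenames illustrative):
#
#     <path>/images/0001.jpg        # RGB training image
#     <path>/annotations/0001.png   # per-pixel label map with the same stem
#
# __getitem__ yields a (CHW uint8 image, HW label map) pair and mirrors both
# together when the xflip copy of an index is drawn.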
--------------------------------------------------------------------------------
/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import torch
11 | from torch_utils import training_stats
12 | from torch_utils import misc
13 | from torch_utils.ops import conv2d_gradfix
14 |
15 | #----------------------------------------------------------------------------
16 |
17 | class Loss:
18 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
19 | raise NotImplementedError()
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | class StyleGAN2Loss(Loss):
24 | def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
25 | super().__init__()
26 | self.device = device
27 | self.G_mapping = G_mapping
28 | self.G_synthesis = G_synthesis
29 | self.D = D
30 | self.augment_pipe = augment_pipe
31 | self.style_mixing_prob = style_mixing_prob
32 | self.r1_gamma = r1_gamma
33 | self.pl_batch_shrink = pl_batch_shrink
34 | self.pl_decay = pl_decay
35 | self.pl_weight = pl_weight
36 | self.pl_mean = torch.zeros([], device=device)
37 |
38 | def run_G(self, z, c, sync):
39 | with misc.ddp_sync(self.G_mapping, sync):
40 | ws = self.G_mapping(z, c)
41 | if self.style_mixing_prob > 0:
42 | with torch.autograd.profiler.record_function('style_mixing'):
43 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
44 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
45 | ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
46 | with misc.ddp_sync(self.G_synthesis, sync):
47 | img = self.G_synthesis(ws)
48 | return img, ws
49 |
50 | def run_D(self, img, c, sync):
51 | if self.augment_pipe is not None:
52 | img = self.augment_pipe(img)
53 | with misc.ddp_sync(self.D, sync):
54 | logits = self.D(img, c)
55 | return logits
56 |
57 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
58 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
59 | do_Gmain = (phase in ['Gmain', 'Gboth'])
60 | do_Dmain = (phase in ['Dmain', 'Dboth'])
61 | do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
62 | do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
63 |
64 | # Gmain: Maximize logits for generated images.
65 | if do_Gmain:
66 | with torch.autograd.profiler.record_function('Gmain_forward'):
67 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
68 | gen_logits = self.run_D(gen_img, gen_c, sync=False)
69 | training_stats.report('Loss/scores/fake', gen_logits)
70 | training_stats.report('Loss/signs/fake', gen_logits.sign())
71 | loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
72 | training_stats.report('Loss/G/loss', loss_Gmain)
73 | with torch.autograd.profiler.record_function('Gmain_backward'):
74 | loss_Gmain.mean().mul(gain).backward()
75 |
76 | # Gpl: Apply path length regularization.
77 | if do_Gpl:
78 | with torch.autograd.profiler.record_function('Gpl_forward'):
79 | batch_size = gen_z.shape[0] // self.pl_batch_shrink
80 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
81 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
82 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
83 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
84 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
85 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
86 | self.pl_mean.copy_(pl_mean.detach())
87 | pl_penalty = (pl_lengths - pl_mean).square()
88 | training_stats.report('Loss/pl_penalty', pl_penalty)
89 | loss_Gpl = pl_penalty * self.pl_weight
90 | training_stats.report('Loss/G/reg', loss_Gpl)
91 | with torch.autograd.profiler.record_function('Gpl_backward'):
92 | (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
93 |
94 | # Dmain: Minimize logits for generated images.
95 | loss_Dgen = 0
96 | if do_Dmain:
97 | with torch.autograd.profiler.record_function('Dgen_forward'):
98 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
99 | gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
100 | training_stats.report('Loss/scores/fake', gen_logits)
101 | training_stats.report('Loss/signs/fake', gen_logits.sign())
102 | loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
103 | with torch.autograd.profiler.record_function('Dgen_backward'):
104 | loss_Dgen.mean().mul(gain).backward()
105 |
106 | # Dmain: Maximize logits for real images.
107 | # Dr1: Apply R1 regularization.
108 | if do_Dmain or do_Dr1:
109 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
110 | with torch.autograd.profiler.record_function(name + '_forward'):
111 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
112 | real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
113 | training_stats.report('Loss/scores/real', real_logits)
114 | training_stats.report('Loss/signs/real', real_logits.sign())
115 |
116 | loss_Dreal = 0
117 | if do_Dmain:
118 | loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
119 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
120 |
121 | loss_Dr1 = 0
122 | if do_Dr1:
123 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
124 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
125 | r1_penalty = r1_grads.square().sum([1,2,3])
126 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
127 | training_stats.report('Loss/r1_penalty', r1_penalty)
128 | training_stats.report('Loss/D/reg', loss_Dr1)
129 |
130 | with torch.autograd.profiler.record_function(name + '_backward'):
131 | (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
132 |
133 |
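# A standalone sketch of the R1 term computed above, for a hypothetical
# discriminator D; the penalty is (r1_gamma / 2) * ||grad_x D(x)||^2 taken at
# real samples.
#
#     def r1_penalty(D, real_img):
#         real_img = real_img.detach().requires_grad_(True)
#         logits = D(real_img)
#         grads = torch.autograd.grad(outputs=[logits.sum()], inputs=[real_img],
#                                     create_graph=True, only_inputs=True)[0]
#         return grads.square().sum([1, 2, 3])  # per-sample squared grad norm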
--------------------------------------------------------------------------------
/training/loss_ggdr.py:
--------------------------------------------------------------------------------
1 | # Generative Guided Discriminator Regularization (GGDR)
2 | # Copyright (c) 2022-present NAVER Corp.
3 | # Under NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
4 | # Augmentation (ADA)
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch_utils import training_stats
10 | from torch_utils import misc
11 | from torch_utils.ops import conv2d_gradfix
12 | from training.loss import Loss
13 |
14 |
15 | class StyleGAN2GGDRLoss(Loss):
16 | def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2, ggdr_res=64):
17 | super().__init__()
18 | self.device = device
19 | self.G_mapping = G_mapping
20 | self.G_synthesis = G_synthesis
21 | self.D = D
22 | self.augment_pipe = augment_pipe
23 | self.style_mixing_prob = style_mixing_prob
24 | self.r1_gamma = r1_gamma
25 | self.pl_batch_shrink = pl_batch_shrink
26 | self.pl_decay = pl_decay
27 | self.pl_weight = pl_weight
28 | self.pl_mean = torch.zeros([], device=device)
29 | self.ggdr_res = [ggdr_res]
30 |
31 | self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
32 |
33 | def run_G(self, z, c, ws=None, sync=True):
34 | with misc.ddp_sync(self.G_mapping, sync):
35 | if ws is None:
36 | ws = self.G_mapping(z, c)
37 | if self.style_mixing_prob > 0:
38 | with torch.autograd.profiler.record_function('style_mixing'):
39 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
40 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
41 | ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
42 | with misc.ddp_sync(self.G_synthesis, sync):
43 | img, output_feat = self.G_synthesis(ws, get_feat=True)
44 | return img, ws, output_feat
45 |
46 | def run_aug_if_needed(self, img, gfeats):
47 | """
48 |         Augment the image and the generator feature maps with the same transform so the GGDR targets stay aligned.
49 | """
50 | if self.augment_pipe is not None:
51 | aug_img, gfeats = self.augment_pipe(img, gfeats)
52 | else:
53 | aug_img = img
54 | return aug_img, gfeats
55 |
56 | def run_D(self, img, c, gfeats=None, sync=None):
57 | aug_img, gfeats = self.run_aug_if_needed(img, gfeats)
58 | with misc.ddp_sync(self.D, sync):
59 | logits, out = self.D(aug_img, c)
60 |
61 | return logits, out, aug_img, gfeats
62 |
63 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
64 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
65 | do_Gmain = (phase in ['Gmain', 'Gboth'])
66 | do_Dmain = (phase in ['Dmain', 'Dboth'])
67 | do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
68 | do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
69 |
70 | # Gmain: Maximize logits for generated images.
71 | if do_Gmain:
72 | with torch.autograd.profiler.record_function('Gmain_forward'):
73 | gen_img, _gen_ws, _gen_feat = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl))
74 | gen_logits, _recon_gen_fmaps, _, _ = self.run_D(gen_img, gen_c, sync=False)
75 | training_stats.report('Loss/scores/fake', gen_logits)
76 | training_stats.report('Loss/signs/fake', gen_logits.sign())
77 |
78 | loss_Gmain = torch.nn.functional.softplus(-gen_logits)
79 | training_stats.report('Loss/G/loss', loss_Gmain)
80 | with torch.autograd.profiler.record_function('Gmain_backward'):
81 | loss_Gmain.mean().mul(gain).backward()
82 |
83 | # Gpl: Apply path length regularization.
84 | if do_Gpl:
85 | with torch.autograd.profiler.record_function('Gpl_forward'):
86 | batch_size = gen_z.shape[0] // self.pl_batch_shrink
87 | gen_img, gen_ws, gen_fmaps = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
88 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
89 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
90 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
91 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
92 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
93 | self.pl_mean.copy_(pl_mean.detach())
94 | pl_penalty = (pl_lengths - pl_mean).square()
95 | training_stats.report('Loss/pl_penalty', pl_penalty)
96 | loss_Gpl = pl_penalty * self.pl_weight
97 | training_stats.report('Loss/G/reg', loss_Gpl)
98 | with torch.autograd.profiler.record_function('Gpl_backward'):
99 | (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
100 |
101 | # Dmain: Minimize logits for generated images.
102 | if do_Dmain:
103 | with torch.autograd.profiler.record_function('Dgen_forward'):
104 |                 # Generate fakes together with the generator feature maps used as GGDR targets.
105 | gen_img, _gen_ws, gen_fmaps = self.run_G(gen_z, gen_c, sync=sync)
106 |
107 | aug_gen_logits, aug_recon_gen_fmaps, aug_gen_img, aug_fmaps = \
108 | self.run_D(gen_img, gen_c, gen_fmaps, sync=sync)
109 |
110 | loss_gan_gen = torch.nn.functional.softplus(aug_gen_logits) + \
111 | aug_recon_gen_fmaps[max(aug_recon_gen_fmaps.keys())][:, 0, 0, 0] * 0
112 |
113 | loss_gen_reg = self.get_ggdr_reg(self.ggdr_res, aug_recon_gen_fmaps, aug_fmaps)
114 |
115 | loss_Dmain = loss_gan_gen + loss_gen_reg
116 |
117 | training_stats.report('Loss/D/loss_gan_gen', loss_gan_gen)
118 | training_stats.report('Loss/D/loss_gen_reg', loss_gen_reg)
119 |
120 | with torch.autograd.profiler.record_function('Dgen_backward'):
121 | loss_Dmain.mean().mul(gain).backward()
122 |
123 | # Dmain: Maximize logits for real images.
124 | # Dr1: Apply R1 regularization.
125 | if do_Dmain or do_Dr1:
126 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
127 | with torch.autograd.profiler.record_function(name + '_forward'):
128 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
129 | real_logits, aug_recon_real_fmaps, _, _ = self.run_D(real_img_tmp, real_c, sync=sync)
130 | training_stats.report('Loss/scores/real', real_logits)
131 | training_stats.report('Loss/signs/real', real_logits.sign())
132 |
133 | loss_Dreal = 0
134 | if do_Dmain:
135 | loss_Dreal = torch.nn.functional.softplus(-real_logits)
136 |                     training_stats.report('Loss/D/loss', loss_Dreal)  # gen-side term is reported above as Loss/D/loss_gan_gen
137 |
138 | loss_Dr1 = 0
139 | if do_Dr1:
140 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
141 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
142 | r1_penalty = r1_grads.square().sum([1, 2, 3])
143 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
144 | training_stats.report('Loss/r1_penalty', r1_penalty)
145 | training_stats.report('Loss/D/reg', loss_Dr1)
146 |
147 |                 # Touch the otherwise-unused decoder outputs so DDP sees gradients for every parameter.
148 | loss_not_used = aug_recon_real_fmaps[max(aug_recon_real_fmaps.keys())][:, 0, 0, 0] * 0
149 |
150 | with torch.autograd.profiler.record_function(name + '_backward'):
151 |                 (loss_Dreal + loss_Dr1 + real_logits * 0 + loss_not_used).mean().mul(gain).backward()  # loss_not_used is already zeroed
152 |
153 | def cosine_distance(self, x, y):
154 | return 1. - F.cosine_similarity(x, y).mean()
155 |
156 | def get_ggdr_reg(self, ggdr_resolutions, source, target):
157 | loss_gen_recon = 0
158 |
159 | for res in ggdr_resolutions:
160 | loss_gen_recon += 10 * self.cosine_distance(source[res], target[res]) / len(ggdr_resolutions)
161 |
162 | return loss_gen_recon
163 |
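# Note: with the default ggdr_res=64 the regularizer reduces to
# 10 * (1 - mean cosine similarity) between the discriminator-decoder output
# and the generator feature map at 64x64. F.cosine_similarity defaults to
# dim=1, so similarity is taken along channels, giving one score per spatial
# location before the mean.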
--------------------------------------------------------------------------------
/training/networks_ggdr.py:
--------------------------------------------------------------------------------
1 | # Generative Guided Discriminator Regularization (GGDR)
2 | # Copyright (c) 2022-present NAVER Corp.
3 | # Under NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
4 | # Augmentation (ADA)
5 |
6 | import numpy as np
7 | import torch
8 | from torch_utils import misc
9 | from torch_utils import persistence
10 | from training.networks import Conv2dLayer, MappingNetwork, DiscriminatorBlock, DiscriminatorEpilogue
11 | from training.networks import SynthesisNetwork as OrigSynthesisNetwork
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | @persistence.persistent_class
16 | class SynthesisNetwork(OrigSynthesisNetwork):
17 | def __init__(self,
18 | w_dim, # Intermediate latent (W) dimensionality.
19 | img_resolution, # Output image resolution.
20 | img_channels, # Number of color channels.
21 | channel_base = 32768, # Overall multiplier for the number of channels.
22 | channel_max = 512, # Maximum number of channels in any layer.
23 | num_fp16_res = 0, # Use FP16 for the N highest resolutions.
24 | **block_kwargs, # Arguments for SynthesisBlock.
25 | ):
26 | super().__init__(w_dim, img_resolution, img_channels, channel_base, channel_max, num_fp16_res, **block_kwargs)
27 |
28 | def forward(self, ws, get_feat=False, **block_kwargs):
29 | block_ws = []
30 | with torch.autograd.profiler.record_function('split_ws'):
31 | misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
32 | ws = ws.to(torch.float32)
33 | w_idx = 0
34 | for res in self.block_resolutions:
35 | block = getattr(self, f'b{res}')
36 | block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
37 | w_idx += block.num_conv
38 |
39 | x = img = None
40 |
41 | feats = {}
42 | for res, cur_ws in zip(self.block_resolutions, block_ws):
43 | block = getattr(self, f'b{res}')
44 | x, img = block(x, img, cur_ws, **block_kwargs)
45 |
46 | if get_feat:
47 | feats[res] = x.float()
48 |
49 | if get_feat:
50 | return img, feats
51 | else:
52 | return img
53 |
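# Note on the get_feat path above: the returned dict maps every synthesis
# block resolution (4, 8, ..., img_resolution) to its float32 feature map;
# StyleGAN2GGDRLoss consumes these as regression targets for the
# discriminator decoder.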
54 | #----------------------------------------------------------------------------
55 |
56 | @persistence.persistent_class
57 | class Generator(torch.nn.Module):
58 | def __init__(self,
59 | z_dim, # Input latent (Z) dimensionality.
60 | c_dim, # Conditioning label (C) dimensionality.
61 | w_dim, # Intermediate latent (W) dimensionality.
62 | img_resolution, # Output resolution.
63 | img_channels, # Number of output color channels.
64 | mapping_kwargs = {}, # Arguments for MappingNetwork.
65 | synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
66 | ):
67 | super().__init__()
68 | self.z_dim = z_dim
69 | self.c_dim = c_dim
70 | self.w_dim = w_dim
71 | self.img_resolution = img_resolution
72 | self.img_channels = img_channels
73 | self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
74 | self.num_ws = self.synthesis.num_ws
75 | self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
76 |
77 | def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
78 | ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
79 | img = self.synthesis(ws, **synthesis_kwargs)
80 | return img
81 |
82 | #----------------------------------------------------------------------------
83 |
84 | @persistence.persistent_class
85 | class Discriminator(torch.nn.Module):
86 | def __init__(self,
87 | c_dim, # Conditioning label (C) dimensionality.
88 | img_resolution, # Input resolution.
89 | img_channels, # Number of input color channels.
90 | architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
91 | channel_base = 32768, # Overall multiplier for the number of channels.
92 | channel_max = 512, # Maximum number of channels in any layer.
93 | num_fp16_res = 0, # Use FP16 for the N highest resolutions.
94 | conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
95 | cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
96 | block_kwargs = {}, # Arguments for DiscriminatorBlock.
97 | mapping_kwargs = {}, # Arguments for MappingNetwork.
98 | epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
99 |         w_dim               = 512,      # Intermediate latent (W) dimensionality.
100 |         decoder_res         = 64,      # Highest resolution produced by the GGDR feature decoder.
101 | ):
102 | super().__init__()
103 | self.c_dim = c_dim
104 | self.img_resolution = img_resolution
105 | self.img_resolution_log2 = int(np.log2(img_resolution))
106 | self.img_channels = img_channels
107 | self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
108 | channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
109 | fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
110 | self.fp16_resolution = fp16_resolution
111 |
112 | if cmap_dim is None:
113 | cmap_dim = channels_dict[4]
114 | if c_dim == 0:
115 | cmap_dim = 0
116 |
117 | common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
118 | cur_layer_idx = 0
119 | for res in self.block_resolutions:
120 | in_channels = channels_dict[res] if res < img_resolution else 0
121 | tmp_channels = channels_dict[res]
122 | out_channels = channels_dict[res // 2]
123 | use_fp16 = (res >= fp16_resolution)
124 | block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
125 | first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
126 | setattr(self, f'b{res}', block)
127 | cur_layer_idx += block.num_layers
128 |
129 | if c_dim > 0:
130 | self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
131 |
132 | self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
133 |
134 | # *************************************************
135 | # Decoder part for GGDR loss
136 | # *************************************************
137 | dec_kernel_size = 1
138 | self.dec_resolutions = [2 ** i for i in range(3, int(np.log2(decoder_res)) + 1)]
139 |
140 | for res in self.dec_resolutions:
141 | out_channels = channels_dict[res]
142 | in_channels = channels_dict[res // 2]
143 | if res != self.dec_resolutions[0]:
144 | in_channels *= 2
145 |
146 | block = Conv2dLayer(in_channels, out_channels, kernel_size=dec_kernel_size,
147 | activation='linear', up=2)
148 | setattr(self, f'b{res}_dec', block)
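        # With the default decoder_res=64 this builds 1x1 convs with 2x
        # upsampling at resolutions [8, 16, 32, 64]; every block after the
        # first sees its input channels doubled by the skip concatenation
        # in forward().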
149 |
150 | def forward(self, img, c, **block_kwargs):
151 | x = None
152 | feats = {}
153 | for res in self.block_resolutions:
154 | block = getattr(self, f'b{res}')
155 | x, img = block(x, img, **block_kwargs)
156 |             feats[res // 2] = x # keep feature maps for the U-Net decoder skip connections
157 |
158 | cmap = None
159 | if self.c_dim > 0:
160 | cmap = self.mapping(None, c)
161 |
162 | logits = self.b4(x, img, cmap) # original real/fake logits
163 |
164 | # Run decoder part
165 | fmaps = {}
166 | for idx, res in enumerate(self.dec_resolutions):
167 | block = getattr(self, f'b{res}_dec')
168 | if idx == 0:
169 | y = feats[res // 2]
170 | else:
171 | y = torch.cat([y, feats[res // 2]], dim=1)
172 | y = block(y)
173 | fmaps[res] = y
174 |
175 | return logits, fmaps
176 |
177 | #----------------------------------------------------------------------------
178 |
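# A minimal usage sketch (shapes assume img_resolution=256, c_dim=0, and the
# default decoder_res=64; unconditional, so c is ignored):
#
#     D = Discriminator(c_dim=0, img_resolution=256, img_channels=3)
#     logits, fmaps = D(torch.randn(4, 3, 256, 256), c=None)
#     # logits: [4, 1] real/fake scores from the original epilogue
#     # fmaps:  {8: ..., 16: ..., 32: ..., 64: ...} decoder feature maps,
#     #         consumed as `source` by StyleGAN2GGDRLoss.get_ggdr_reg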
--------------------------------------------------------------------------------