├── comp_gen.jpg
├── StyleGAN2
│   ├── dnnlib
│   │   ├── __pycache__
│   │   │   ├── util.cpython-36.pyc
│   │   │   └── __init__.cpython-36.pyc
│   │   └── __init__.py
│   ├── training
│   │   ├── __pycache__
│   │   │   ├── loss.cpython-36.pyc
│   │   │   ├── loss_p.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── augment.cpython-36.pyc
│   │   │   ├── dataset.cpython-36.pyc
│   │   │   ├── loss_mm.cpython-36.pyc
│   │   │   ├── loss_mvm.cpython-36.pyc
│   │   │   ├── networks.cpython-36.pyc
│   │   │   ├── loss_hinge.cpython-36.pyc
│   │   │   ├── loss_interp.cpython-36.pyc
│   │   │   ├── networks_mm.cpython-36.pyc
│   │   │   ├── training_loop.cpython-36.pyc
│   │   │   └── training_loop_mm.cpython-36.pyc
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── dataset.py
│   │   └── loss_interp.py
│   ├── metrics
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── metric_main.cpython-36.pyc
│   │   │   ├── metric_utils.cpython-36.pyc
│   │   │   ├── inception_score.cpython-36.pyc
│   │   │   ├── precision_recall.cpython-36.pyc
│   │   │   ├── perceptual_path_length.cpython-36.pyc
│   │   │   ├── kernel_inception_distance.cpython-36.pyc
│   │   │   └── frechet_inception_distance.cpython-36.pyc
│   │   ├── __init__.py
│   │   ├── inception_score.py
│   │   ├── frechet_inception_distance.py
│   │   ├── kernel_inception_distance.py
│   │   ├── precision_recall.py
│   │   ├── perceptual_path_length.py
│   │   ├── metric_main.py
│   │   └── metric_utils.py
│   ├── torch_utils
│   │   ├── __pycache__
│   │   │   ├── misc.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── custom_ops.cpython-36.pyc
│   │   │   ├── persistence.cpython-36.pyc
│   │   │   └── training_stats.cpython-36.pyc
│   │   ├── ops
│   │   │   ├── __pycache__
│   │   │   │   ├── fma.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── bias_act.cpython-36.pyc
│   │   │   │   ├── upfirdn2d.cpython-36.pyc
│   │   │   │   ├── conv2d_gradfix.cpython-36.pyc
│   │   │   │   ├── conv2d_resample.cpython-36.pyc
│   │   │   │   └── grid_sample_gradfix.cpython-36.pyc
│   │   │   ├── __init__.py
│   │   │   ├── bias_act.h
│   │   │   ├── upfirdn2d.h
│   │   │   ├── fma.py
│   │   │   ├── grid_sample_gradfix.py
│   │   │   ├── bias_act.cpp
│   │   │   ├── upfirdn2d.cpp
│   │   │   ├── bias_act.cu
│   │   │   ├── conv2d_resample.py
│   │   │   ├── conv2d_gradfix.py
│   │   │   └── bias_act.py
│   │   ├── __init__.py
│   │   ├── custom_ops.py
│   │   ├── persistence.py
│   │   ├── training_stats.py
│   │   └── misc.py
│   ├── Dockerfile
│   ├── docker_run.sh
│   ├── LICENSE.txt
│   ├── style_mixing.py
│   ├── generate.py
│   ├── calc_metrics.py
│   └── projector.py
├── README.md
└── interp_feature.py
/comp_gen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/comp_gen.jpg
--------------------------------------------------------------------------------
/StyleGAN2/dnnlib/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/dnnlib/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/dnnlib/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/dnnlib/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/loss_p.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/loss_p.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/__pycache__/misc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/__pycache__/misc.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/augment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/augment.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/loss_mm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/loss_mm.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/loss_mvm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/loss_mvm.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/metric_main.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/metric_main.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/metric_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/metric_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__pycache__/fma.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/ops/__pycache__/fma.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/loss_hinge.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/loss_hinge.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/loss_interp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/loss_interp.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/networks_mm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/networks_mm.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/inception_score.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/inception_score.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/__pycache__/custom_ops.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/__pycache__/custom_ops.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/__pycache__/persistence.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/__pycache__/persistence.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/training_loop.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/training_loop.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/precision_recall.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/precision_recall.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/ops/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__pycache__/bias_act.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/ops/__pycache__/bias_act.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__pycache__/upfirdn2d.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/ops/__pycache__/upfirdn2d.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/training/__pycache__/training_loop_mm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/training/__pycache__/training_loop_mm.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/__pycache__/training_stats.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/__pycache__/training_stats.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/perceptual_path_length.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/perceptual_path_length.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/kernel_inception_distance.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/kernel_inception_distance.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__pycache__/conv2d_resample.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/ops/__pycache__/conv2d_resample.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__pycache__/frechet_inception_distance.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/metrics/__pycache__/frechet_inception_distance.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dzld00/Adaptive-Feature-Interpolation-for-Low-Shot-Image-Generation/HEAD/StyleGAN2/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-36.pyc
--------------------------------------------------------------------------------
/StyleGAN2/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/StyleGAN2/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/StyleGAN2/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
/StyleGAN2/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | FROM nvcr.io/nvidia/pytorch:20.12-py3
10 |
11 | ENV PYTHONDONTWRITEBYTECODE 1
12 | ENV PYTHONUNBUFFERED 1
13 |
14 | RUN pip install imageio-ffmpeg==0.4.3 pyspng==0.1.0
15 |
16 | WORKDIR /workspace
17 |
18 | # Unset TORCH_CUDA_ARCH_LIST and exec. This makes pytorch run-time
19 | # extension builds significantly faster as we only compile for the
20 | # currently active GPU configuration.
21 | RUN (printf '#!/bin/bash\nunset TORCH_CUDA_ARCH_LIST\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
22 | ENTRYPOINT ["/entry.sh"]
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adaptive Feature Interpolation for Low-Shot Image Generation (ECCV 2022)
2 |
3 | Paper: [https://arxiv.org/abs/2112.02450](https://arxiv.org/abs/2112.02450)
4 |
5 | Below are images generated by models trained on datasets of 60-1000 samples. For fast access to the datasets, see https://github.com/odegeasslbc/FastGAN-pytorch (Liu's repo).
6 |
7 |
8 |
9 |
10 | (Left: StyleGAN2-ADA; Right: + Adaptive Feature Interpolation)
11 |
12 |
13 | # Usage:
14 | ### Adaptive Feature Interpolation
15 | Create a batch of new features from a batch of old features:
16 | ```
17 | new_feature = near_interp(old_feature, k, augment_prob)
18 | ```
19 | where `k` and `augment_prob` can be generated by the function `dynamic_prob` or defined by the user. Please refer to `interp_feature.py` for more details. An example implementation in StyleGAN2 can be found in the corresponding folder.
20 |
21 | # Citation
22 | ```
23 | @InProceedings{10.1007/978-3-031-19784-0_15,
24 | author="Dai, Mengyu
25 | and Hang, Haibin
26 | and Guo, Xiaoyang",
27 | title="Adaptive Feature Interpolation for Low-Shot Image Generation",
28 | booktitle="Computer Vision -- ECCV 2022",
29 | year="2022",
30 | publisher="Springer Nature Switzerland",
31 | address="Cham",
32 | pages="254--270"
33 | }
34 | ```
35 |
--------------------------------------------------------------------------------
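
To make the README's `near_interp` call concrete: `interp_feature.py` is not reproduced in this listing, so the following is a minimal, hypothetical sketch of k-nearest-neighbor feature interpolation with the README's `near_interp(old_feature, k, augment_prob)` signature. The body below is an illustrative assumption, not the repository's implementation.

```
# Hypothetical sketch of k-nearest-neighbor feature interpolation
# (illustrative only; NOT the repository's interp_feature.py).
import torch

def near_interp(feat, k, augment_prob):
    # feat: [N, D] batch of features. With probability augment_prob, replace
    # each row by a random convex combination of itself and its k nearest
    # neighbors within the batch; otherwise keep the original row.
    n = feat.shape[0]
    dist = torch.cdist(feat, feat)                    # [N, N] pairwise distances
    idx = dist.topk(k + 1, largest=False).indices     # each row + its k neighbors
    neighbors = feat[idx]                             # [N, k+1, D]
    w = torch.rand(n, k + 1, device=feat.device)
    w = w / w.sum(dim=1, keepdim=True)                # random convex weights
    mixed = (w.unsqueeze(-1) * neighbors).sum(dim=1)  # [N, D] interpolated batch
    keep = (torch.rand(n, device=feat.device) >= augment_prob).unsqueeze(1)
    return torch.where(keep, feat, mixed)

new_feature = near_interp(torch.randn(8, 512), k=3, augment_prob=0.5)
```
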
/StyleGAN2/docker_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | set -e
12 |
13 | # Wrapper script for setting up `docker run` to properly
14 | # cache downloaded files, custom extension builds and
15 | # mount the source directory into the container and make it
16 | # run as non-root user.
17 | #
18 | # Use it like:
19 | #
20 | # ./docker_run.sh python generate.py --help
21 | #
22 | # To override the default `sg2ada:latest` image, run:
23 | #
24 | # IMAGE=my_image:v1.0 ./docker_run.sh python generate.py --help
25 | #
26 |
27 | rest=$@
28 |
29 | IMAGE="${IMAGE:-sg2ada:latest}"
30 |
31 | CONTAINER_ID=$(docker inspect --format="{{.Id}}" ${IMAGE} 2> /dev/null)
32 | if [[ "${CONTAINER_ID}" ]]; then
33 |     docker run --shm-size=2g --gpus all -it --rm -v `pwd`:/scratch --user $(id -u):$(id -g) \
34 |         --workdir=/scratch -e HOME=/scratch $IMAGE $@
35 | else
36 |     echo "Unknown container image: ${IMAGE}"
37 |     exit 1
38 | fi
39 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 |     const void* x;      // [sizeX]
15 |     const void* b;      // [sizeB] or NULL
16 |     const void* xref;   // [sizeX] or NULL
17 |     const void* yref;   // [sizeX] or NULL
18 |     const void* dy;     // [sizeX] or NULL
19 |     void*       y;      // [sizeX]
20 | 
21 |     int         grad;
22 |     int         act;
23 |     float       alpha;
24 |     float       gain;
25 |     float       clamp;
26 | 
27 |     int         sizeX;
28 |     int         sizeB;
29 |     int         stepB;
30 |     int         loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/StyleGAN2/metrics/inception_score.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Inception Score (IS) from the paper "Improved techniques for training
10 | GANs". Matches the original implementation by Salimans et al. at
11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12 |
13 | import numpy as np
14 | from . import metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_is(opts, num_gen, num_splits):
19 |     # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 |     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21 |     detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22 | 
23 |     gen_probs = metric_utils.compute_feature_stats_for_generator(
24 |         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25 |         capture_all=True, max_items=num_gen).get_all()
26 | 
27 |     if opts.rank != 0:
28 |         return float('nan'), float('nan')
29 | 
30 |     scores = []
31 |     for i in range(num_splits):
32 |         part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33 |         kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34 |         kl = np.mean(np.sum(kl, axis=1))
35 |         scores.append(np.exp(kl))
36 |     return float(np.mean(scores)), float(np.std(scores))
37 |
38 | #----------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
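
For reference, the per-split loop in `compute_is` above evaluates the standard Inception Score, with `gen_probs` holding the detector's per-image class probabilities:

```
\mathrm{IS} = \exp\!\Big( \mathbb{E}_{x}\,\mathrm{KL}\big( p(y \mid x) \,\|\, p(y) \big) \Big),
\qquad p(y) = \mathbb{E}_{x}\big[ p(y \mid x) \big]
```

The code approximates p(y) by the per-split mean (`np.mean(part, axis=0, keepdims=True)`) and reports the mean and standard deviation of the per-split scores.
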
/StyleGAN2/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 |     const void*     x;
17 |     const float*    f;
18 |     void*           y;
19 | 
20 |     int2            up;
21 |     int2            down;
22 |     int2            pad0;
23 |     int             flip;
24 |     float           gain;
25 | 
26 |     int4            inSize;         // [width, height, channel, batch]
27 |     int4            inStride;
28 |     int2            filterSize;     // [width, height]
29 |     int2            filterStride;
30 |     int4            outSize;        // [width, height, channel, batch]
31 |     int4            outStride;
32 |     int             sizeMinor;
33 |     int             sizeMajor;
34 | 
35 |     int             loopMinor;
36 |     int             loopMajor;
37 |     int             loopX;
38 |     int             launchMinor;
39 |     int             launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 |     void*   kernel;
48 |     int     tileOutW;
49 |     int     tileOutH;
50 |     int     loopMinor;
51 |     int     loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/StyleGAN2/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Frechet Inception Distance (FID) from the paper
10 | "GANs trained by a two time-scale update rule converge to a local Nash
11 | equilibrium". Matches the original implementation by Heusel et al. at
12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13 |
14 | import numpy as np
15 | import scipy.linalg
16 | from . import metric_utils
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def compute_fid(opts, max_real, num_gen):
21 |     # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22 |     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23 |     detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24 | 
25 |     mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26 |         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27 |         rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28 | 
29 |     mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30 |         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31 |         rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32 | 
33 |     if opts.rank != 0:
34 |         return float('nan')
35 | 
36 |     m = np.square(mu_gen - mu_real).sum()
37 |     s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38 |     fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39 |     return float(fid)
40 |
41 | #----------------------------------------------------------------------------
42 |
--------------------------------------------------------------------------------
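
The last three statements of `compute_fid` above evaluate the closed-form Fréchet distance between the two Gaussians fitted to the real and generated feature sets:

```
\mathrm{FID} = \lVert \mu_g - \mu_r \rVert_2^2
+ \operatorname{Tr}\!\big( \Sigma_g + \Sigma_r - 2\,(\Sigma_g \Sigma_r)^{1/2} \big)
```

`scipy.linalg.sqrtm` can return a complex result with a negligible imaginary part due to numerical error, which is why the code takes `np.real(...)` at the end.
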
/StyleGAN2/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 |     return _FusedMultiplyAdd.apply(a, b, c)
17 | 
18 | #----------------------------------------------------------------------------
19 | 
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 |     @staticmethod
22 |     def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 |         out = torch.addcmul(c, a, b)
24 |         ctx.save_for_backward(a, b)
25 |         ctx.c_shape = c.shape
26 |         return out
27 | 
28 |     @staticmethod
29 |     def backward(ctx, dout): # pylint: disable=arguments-differ
30 |         a, b = ctx.saved_tensors
31 |         c_shape = ctx.c_shape
32 |         da = None
33 |         db = None
34 |         dc = None
35 | 
36 |         if ctx.needs_input_grad[0]:
37 |             da = _unbroadcast(dout * b, a.shape)
38 | 
39 |         if ctx.needs_input_grad[1]:
40 |             db = _unbroadcast(dout * a, b.shape)
41 | 
42 |         if ctx.needs_input_grad[2]:
43 |             dc = _unbroadcast(dout, c_shape)
44 | 
45 |         return da, db, dc
46 | 
47 | #----------------------------------------------------------------------------
48 | 
49 | def _unbroadcast(x, shape):
50 |     extra_dims = x.ndim - len(shape)
51 |     assert extra_dims >= 0
52 |     dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 |     if len(dim):
54 |         x = x.sum(dim=dim, keepdim=True)
55 |     if extra_dims:
56 |         x = x.reshape(-1, *x.shape[extra_dims+1:])
57 |     assert x.shape == shape
58 |     return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
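
A quick usage check for `fma` above, assuming the `StyleGAN2` directory is on `PYTHONPATH`: the forward pass is exactly `torch.addcmul(c, a, b)`, i.e. `a * b + c`; the custom backward only changes how gradients are summed back ("unbroadcast") to each input's shape.

```
import torch
from torch_utils.ops.fma import fma

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)  # broadcast against a * b

out = fma(a, b, c)
assert torch.allclose(out, a * b + c)

out.sum().backward()
assert c.grad.shape == c.shape  # dc was unbroadcast from (4, 8) back to (8,)
```
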
/StyleGAN2/metrics/kernel_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10 | GANs". Matches the original implementation by Binkowski et al. at
11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12 |
13 | import numpy as np
14 | from . import metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19 |     # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 |     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21 |     detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22 | 
23 |     real_features = metric_utils.compute_feature_stats_for_dataset(
24 |         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25 |         rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26 | 
27 |     gen_features = metric_utils.compute_feature_stats_for_generator(
28 |         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29 |         rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30 | 
31 |     if opts.rank != 0:
32 |         return float('nan')
33 | 
34 |     n = real_features.shape[1]
35 |     m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36 |     t = 0
37 |     for _subset_idx in range(num_subsets):
38 |         x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39 |         y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40 |         a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41 |         b = (x @ y.T / n + 1) ** 3
42 |         t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43 |     kid = t / num_subsets / m
44 |     return float(kid)
45 |
46 | #----------------------------------------------------------------------------
47 |
--------------------------------------------------------------------------------
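
The subset loop in `compute_kid` above is the block estimator of squared MMD with the cubic polynomial kernel of Binkowski et al.; with feature dimension n and subset size m, each subset contributes

```
k(x, y) = \Big( \tfrac{1}{n}\, x^{\top} y + 1 \Big)^{3}, \qquad
\widehat{\mathrm{MMD}}^{2} = \frac{1}{m(m-1)} \sum_{i \neq j} \big[ k(x_i, x_j) + k(y_i, y_j) \big]
- \frac{2}{m^{2}} \sum_{i, j} k(x_i, y_j)
```

and KID is the average of these estimates over `num_subsets` subsets.
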
/StyleGAN2/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import warnings
15 | import torch
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | def grid_sample(input, grid):
28 |     if _should_use_custom_op():
29 |         return _GridSample2dForward.apply(input, grid)
30 |     return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31 | 
32 | #----------------------------------------------------------------------------
33 | 
34 | def _should_use_custom_op():
35 |     if not enabled:
36 |         return False
37 |     if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
38 |         return True
39 |     warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40 |     return False
41 | 
42 | #----------------------------------------------------------------------------
43 | 
44 | class _GridSample2dForward(torch.autograd.Function):
45 |     @staticmethod
46 |     def forward(ctx, input, grid):
47 |         assert input.ndim == 4
48 |         assert grid.ndim == 4
49 |         output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50 |         ctx.save_for_backward(input, grid)
51 |         return output
52 | 
53 |     @staticmethod
54 |     def backward(ctx, grad_output):
55 |         input, grid = ctx.saved_tensors
56 |         grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57 |         return grad_input, grad_grid
58 | 
59 | #----------------------------------------------------------------------------
60 | 
61 | class _GridSample2dBackward(torch.autograd.Function):
62 |     @staticmethod
63 |     def forward(ctx, grad_output, input, grid):
64 |         op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65 |         grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 |         ctx.save_for_backward(grid)
67 |         return grad_input, grad_grid
68 | 
69 |     @staticmethod
70 |     def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 |         _ = grad2_grad_grid # unused
72 |         grid, = ctx.saved_tensors
73 |         grad2_grad_output = None
74 |         grad2_input = None
75 |         grad2_grid = None
76 | 
77 |         if ctx.needs_input_grad[0]:
78 |             grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 | 
80 |         assert not ctx.needs_input_grad[2]
81 |         return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
--------------------------------------------------------------------------------
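
Minimal usage of the wrapper above, assuming the `StyleGAN2` directory is on `PYTHONPATH`. With the module-level `enabled` flag left at `False` it simply forwards to `torch.nn.functional.grid_sample`; setting it to `True` opts in to the double-differentiable path on the PyTorch versions checked in `_should_use_custom_op`.

```
import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True  # opt in to the custom op

image = torch.randn(1, 3, 8, 8, requires_grad=True)  # NCHW input
grid = torch.rand(1, 8, 8, 2) * 2 - 1                # sample coordinates in [-1, 1]

out = grid_sample_gradfix.grid_sample(image, grid)
out.sum().backward()  # on supported versions, second-order gradients also work
```
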
/StyleGAN2/metrics/precision_recall.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall
10 | Metric for Assessing Generative Models". Matches the original implementation
11 | by Kynkaanniemi et al. at
12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13 |
14 | import torch
15 | from . import metric_utils
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20 |     assert 0 <= rank < num_gpus
21 |     num_cols = col_features.shape[0]
22 |     num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23 |     col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
24 |     dist_batches = []
25 |     for col_batch in col_batches[rank :: num_gpus]:
26 |         dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27 |         for src in range(num_gpus):
28 |             dist_broadcast = dist_batch.clone()
29 |             if num_gpus > 1:
30 |                 torch.distributed.broadcast(dist_broadcast, src=src)
31 |             dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32 |     return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33 | 
34 | #----------------------------------------------------------------------------
35 | 
36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37 |     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
38 |     detector_kwargs = dict(return_features=True)
39 | 
40 |     real_features = metric_utils.compute_feature_stats_for_dataset(
41 |         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42 |         rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43 | 
44 |     gen_features = metric_utils.compute_feature_stats_for_generator(
45 |         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46 |         rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47 | 
48 |     results = dict()
49 |     for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50 |         kth = []
51 |         for manifold_batch in manifold.split(row_batch_size):
52 |             dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53 |             kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54 |         kth = torch.cat(kth) if opts.rank == 0 else None
55 |         pred = []
56 |         for probes_batch in probes.split(row_batch_size):
57 |             dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58 |             pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59 |         results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60 |     return results['precision'], results['recall']
61 |
62 | #----------------------------------------------------------------------------
63 |
--------------------------------------------------------------------------------
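
In `compute_pr` above, each feature set takes a turn as the "manifold": a probe point counts as covered if it lies within the k-th-nearest-neighbor radius (k = `nhood_size`) of any manifold point. With real features R as the manifold and generated features G as probes,

```
\mathrm{precision} = \frac{1}{|G|} \sum_{g \in G}
\mathbf{1}\!\left[ \exists\, r \in R : \lVert g - r \rVert_2 \le \mathrm{NND}_k(r) \right]
```

where NND_k(r) is the distance from r to its k-th nearest neighbor within R; recall is the same expression with the roles of R and G exchanged.
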
/StyleGAN2/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
2 |
3 |
4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
5 |
6 |
7 | =======================================================================
8 |
9 | 1. Definitions
10 |
11 | "Licensor" means any person or entity that distributes its Work.
12 |
13 | "Software" means the original work of authorship made available under
14 | this License.
15 |
16 | "Work" means the Software and any additions to or derivative works of
17 | the Software that are made available under this License.
18 |
19 | The terms "reproduce," "reproduction," "derivative works," and
20 | "distribution" have the meaning as provided under U.S. copyright law;
21 | provided, however, that for the purposes of this License, derivative
22 | works shall not include works that remain separable from, or merely
23 | link (or bind by name) to the interfaces of, the Work.
24 |
25 | Works, including the Software, are "made available" under this License
26 | by including in or with the Work either (a) a copyright notice
27 | referencing the applicability of this License to the Work, or (b) a
28 | copy of this License.
29 |
30 | 2. License Grants
31 |
32 | 2.1 Copyright Grant. Subject to the terms and conditions of this
33 | License, each Licensor grants to you a perpetual, worldwide,
34 | non-exclusive, royalty-free, copyright license to reproduce,
35 | prepare derivative works of, publicly display, publicly perform,
36 | sublicense and distribute its Work and any resulting derivative
37 | works in any form.
38 |
39 | 3. Limitations
40 |
41 | 3.1 Redistribution. You may reproduce or distribute the Work only
42 | if (a) you do so under this License, (b) you include a complete
43 | copy of this License with your distribution, and (c) you retain
44 | without modification any copyright, patent, trademark, or
45 | attribution notices that are present in the Work.
46 |
47 | 3.2 Derivative Works. You may specify that additional or different
48 | terms apply to the use, reproduction, and distribution of your
49 | derivative works of the Work ("Your Terms") only if (a) Your Terms
50 | provide that the use limitation in Section 3.3 applies to your
51 | derivative works, and (b) you identify the specific derivative
52 | works that are subject to Your Terms. Notwithstanding Your Terms,
53 | this License (including the redistribution requirements in Section
54 | 3.1) will continue to apply to the Work itself.
55 |
56 | 3.3 Use Limitation. The Work and any derivative works thereof only
57 | may be used or intended for use non-commercially. Notwithstanding
58 | the foregoing, NVIDIA and its affiliates may use the Work and any
59 | derivative works commercially. As used herein, "non-commercially"
60 | means for research or evaluation purposes only.
61 |
62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63 | against any Licensor (including any claim, cross-claim or
64 | counterclaim in a lawsuit) to enforce any patents that you allege
65 | are infringed by any Work, then your rights under this License from
66 | such Licensor (including the grant in Section 2.1) will terminate
67 | immediately.
68 |
69 | 3.5 Trademarks. This License does not grant any rights to use any
70 | Licensor’s or its affiliates’ names, logos, or trademarks, except
71 | as necessary to reproduce the notices described in this License.
72 |
73 | 3.6 Termination. If you violate any term of this License, then your
74 | rights under this License (including the grant in Section 2.1) will
75 | terminate immediately.
76 |
77 | 4. Disclaimer of Warranty.
78 |
79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83 | THIS LICENSE.
84 |
85 | 5. Limitation of Liability.
86 |
87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95 | THE POSSIBILITY OF SUCH DAMAGES.
96 |
97 | =======================================================================
98 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 |     if (x.dim() != y.dim())
19 |         return false;
20 |     for (int64_t i = 0; i < x.dim(); i++)
21 |     {
22 |         if (x.size(i) != y.size(i))
23 |             return false;
24 |         if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 |             return false;
26 |     }
27 |     return true;
28 | }
29 | 
30 | //------------------------------------------------------------------------
31 | 
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 |     // Validate arguments.
35 |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 |     TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 |     TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 |     TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 |     TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
40 |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 |     TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 |     TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 |     TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 |     TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 | 
46 |     // Validate layout.
47 |     TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 |     TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 |     TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 |     TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 |     TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 | 
53 |     // Create output tensor.
54 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 |     torch::Tensor y = torch::empty_like(x);
56 |     TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 | 
58 |     // Initialize CUDA kernel parameters.
59 |     bias_act_kernel_params p;
60 |     p.x = x.data_ptr();
61 |     p.b = (b.numel()) ? b.data_ptr() : NULL;
62 |     p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 |     p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 |     p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 |     p.y = y.data_ptr();
66 |     p.grad = grad;
67 |     p.act = act;
68 |     p.alpha = alpha;
69 |     p.gain = gain;
70 |     p.clamp = clamp;
71 |     p.sizeX = (int)x.numel();
72 |     p.sizeB = (int)b.numel();
73 |     p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 | 
75 |     // Choose CUDA kernel.
76 |     void* kernel;
77 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 |     {
79 |         kernel = choose_bias_act_kernel<scalar_t>(p);
80 |     });
81 |     TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 | 
83 |     // Launch CUDA kernel.
84 |     p.loopX = 4;
85 |     int blockSize = 4 * 32;
86 |     int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 |     void* args[] = {&p};
88 |     AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 |     return y;
90 | }
91 | 
92 | //------------------------------------------------------------------------
93 | 
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 |     m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 |     // Validate arguments.
19 |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 |     TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 |     TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 |     TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 |     TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25 |     TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26 |     TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27 |     TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28 |     TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29 | 
30 |     // Create output tensor.
31 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32 |     int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33 |     int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34 |     TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35 |     torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36 |     TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37 | 
38 |     // Initialize CUDA kernel parameters.
39 |     upfirdn2d_kernel_params p;
40 |     p.x = x.data_ptr();
41 |     p.f = f.data_ptr<float>();
42 |     p.y = y.data_ptr();
43 |     p.up = make_int2(upx, upy);
44 |     p.down = make_int2(downx, downy);
45 |     p.pad0 = make_int2(padx0, pady0);
46 |     p.flip = (flip) ? 1 : 0;
47 |     p.gain = gain;
48 |     p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49 |     p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50 |     p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51 |     p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52 |     p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53 |     p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54 |     p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55 |     p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56 | 
57 |     // Choose CUDA kernel.
58 |     upfirdn2d_kernel_spec spec;
59 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60 |     {
61 |         spec = choose_upfirdn2d_kernel<scalar_t>(p);
62 |     });
63 | 
64 |     // Set looping options.
65 |     p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66 |     p.loopMinor = spec.loopMinor;
67 |     p.loopX = spec.loopX;
68 |     p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69 |     p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70 | 
71 |     // Compute grid size.
72 |     dim3 blockSize, gridSize;
73 |     if (spec.tileOutW < 0) // large
74 |     {
75 |         blockSize = dim3(4, 32, 1);
76 |         gridSize = dim3(
77 |             ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78 |             (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79 |             p.launchMajor);
80 |     }
81 |     else // small
82 |     {
83 |         blockSize = dim3(256, 1, 1);
84 |         gridSize = dim3(
85 |             ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86 |             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87 |             p.launchMajor);
88 |     }
89 | 
90 |     // Launch CUDA kernel.
91 |     void* args[] = {&p};
92 |     AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 |     return y;
94 | }
95 | 
96 | //------------------------------------------------------------------------
97 | 
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 |     m.def("upfirdn2d", &upfirdn2d);
101 | }
102 |
103 | //------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
/StyleGAN2/style_mixing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Generate style mixing image matrix using pretrained network pickle."""
10 |
11 | import os
12 | import re
13 | from typing import List
14 |
15 | import click
16 | import dnnlib
17 | import numpy as np
18 | import PIL.Image
19 | import torch
20 |
21 | import legacy
22 |
23 | #----------------------------------------------------------------------------
24 |
25 | def num_range(s: str) -> List[int]:
26 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
27 |
28 | range_re = re.compile(r'^(\d+)-(\d+)$')
29 | m = range_re.match(s)
30 | if m:
31 | return list(range(int(m.group(1)), int(m.group(2))+1))
32 | vals = s.split(',')
33 | return [int(x) for x in vals]
34 |
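# For example: num_range('1,3,5') -> [1, 3, 5] and num_range('0-3') ->
# [0, 1, 2, 3]; this is the parser behind --rows, --cols and --styles.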
35 | #----------------------------------------------------------------------------
36 |
37 | @click.command()
38 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
39 | @click.option('--rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
40 | @click.option('--cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
41 | @click.option('--styles', 'col_styles', type=num_range, help='Style layer range', default='0-6', show_default=True)
42 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
43 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44 | @click.option('--outdir', type=str, required=True)
45 | def generate_style_mix(
46 | network_pkl: str,
47 | row_seeds: List[int],
48 | col_seeds: List[int],
49 | col_styles: List[int],
50 | truncation_psi: float,
51 | noise_mode: str,
52 | outdir: str
53 | ):
54 | """Generate images using pretrained network pickle.
55 |
56 | Examples:
57 |
58 | \b
59 | python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
60 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
61 | """
62 | print('Loading networks from "%s"...' % network_pkl)
63 | device = torch.device('cuda')
64 | with dnnlib.util.open_url(network_pkl) as f:
65 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
66 |
67 | os.makedirs(outdir, exist_ok=True)
68 |
69 | print('Generating W vectors...')
70 | all_seeds = list(set(row_seeds + col_seeds))
71 | all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
72 | all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
73 | w_avg = G.mapping.w_avg
74 | all_w = w_avg + (all_w - w_avg) * truncation_psi
75 | w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
76 |
77 | print('Generating images...')
78 | all_images = G.synthesis(all_w, noise_mode=noise_mode)
79 | all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
80 | image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
81 |
82 | print('Generating style-mixed images...')
83 | for row_seed in row_seeds:
84 | for col_seed in col_seeds:
85 | w = w_dict[row_seed].clone()
86 | w[col_styles] = w_dict[col_seed][col_styles]
87 | image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
88 | image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
89 | image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()
90 |
91 | print('Saving images...')
92 | os.makedirs(outdir, exist_ok=True)
93 | for (row_seed, col_seed), image in image_dict.items():
94 | PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')
95 |
96 | print('Saving image grid...')
97 | W = G.img_resolution
98 | H = G.img_resolution
99 | canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
100 | for row_idx, row_seed in enumerate([0] + row_seeds):
101 | for col_idx, col_seed in enumerate([0] + col_seeds):
102 | if row_idx == 0 and col_idx == 0:
103 | continue
104 | key = (row_seed, col_seed)
105 | if row_idx == 0:
106 | key = (col_seed, col_seed)
107 | if col_idx == 0:
108 | key = (row_seed, row_seed)
109 | canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
110 | canvas.save(f'{outdir}/grid.png')
111 |
112 |
113 | #----------------------------------------------------------------------------
114 |
115 | if __name__ == "__main__":
116 | generate_style_mix() # pylint: disable=no-value-for-parameter
117 |
118 | #----------------------------------------------------------------------------
119 |
--------------------------------------------------------------------------------
/StyleGAN2/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Generate images using pretrained network pickle."""
10 |
11 | import os
12 | import re
13 | from typing import List, Optional
14 |
15 | import click
16 | import dnnlib
17 | import numpy as np
18 | import PIL.Image
19 | import torch
20 |
21 | import legacy
22 |
23 | #----------------------------------------------------------------------------
24 |
25 | def num_range(s: str) -> List[int]:
26 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
27 |
28 | range_re = re.compile(r'^(\d+)-(\d+)$')
29 | m = range_re.match(s)
30 | if m:
31 | return list(range(int(m.group(1)), int(m.group(2))+1))
32 | vals = s.split(',')
33 | return [int(x) for x in vals]
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | @click.command()
38 | @click.pass_context
39 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
40 | @click.option('--seeds', type=num_range, help='List of random seeds')
41 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
42 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
43 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44 | @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
45 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
46 | def generate_images(
47 | ctx: click.Context,
48 | network_pkl: str,
49 | seeds: Optional[List[int]],
50 | truncation_psi: float,
51 | noise_mode: str,
52 | outdir: str,
53 | class_idx: Optional[int],
54 | projected_w: Optional[str]
55 | ):
56 | """Generate images using pretrained network pickle.
57 |
58 | Examples:
59 |
60 | \b
61 | # Generate curated MetFaces images without truncation (Fig.10 left)
62 | python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
63 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
64 |
65 | \b
66 | # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
67 | python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
68 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
69 |
70 | \b
71 | # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
72 | python generate.py --outdir=out --seeds=0-35 --class=1 \\
73 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
74 |
75 | \b
76 | # Render an image from projected W
77 |     python generate.py --outdir=out --projected-w=projected_w.npz \\
78 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
79 | """
80 | #seeds = [i for i in range(num)]
81 |
82 | print('Loading networks from "%s"...' % network_pkl)
83 | device = torch.device('cuda')
84 | with dnnlib.util.open_url(network_pkl) as f:
85 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
86 |
87 | os.makedirs(outdir, exist_ok=True)
88 |
89 | # Synthesize the result of a W projection.
90 | if projected_w is not None:
91 | if seeds is not None:
92 | print ('warn: --seeds is ignored when using --projected-w')
93 | print(f'Generating images from projected W "{projected_w}"')
94 | ws = np.load(projected_w)['w']
95 | ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
96 | assert ws.shape[1:] == (G.num_ws, G.w_dim)
97 | for idx, w in enumerate(ws):
98 | img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
99 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
100 |             PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
101 | return
102 |
103 | if seeds is None:
104 | ctx.fail('--seeds option is required when not using --projected-w')
105 |
106 | # Labels.
107 | label = torch.zeros([1, G.c_dim], device=device)
108 | if G.c_dim != 0:
109 | if class_idx is None:
110 | ctx.fail('Must specify class label with --class when using a conditional network')
111 | label[:, class_idx] = 1
112 | else:
113 | if class_idx is not None:
114 | print ('warn: --class=lbl ignored when running on an unconditional network')
115 |
116 | # Generate images.
117 | for seed_idx, seed in enumerate(seeds):
118 | if seed_idx % 100 == 0:
119 | print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
120 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
121 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
122 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
123 | #PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
124 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.jpg')
125 |
126 |
127 | #----------------------------------------------------------------------------
128 |
129 | if __name__ == "__main__":
130 | generate_images() # pylint: disable=no-value-for-parameter
131 |
132 | #----------------------------------------------------------------------------
133 |
--------------------------------------------------------------------------------
/StyleGAN2/metrics/perceptual_path_length.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10 | Architecture for Generative Adversarial Networks". Matches the original
11 | implementation by Karras et al. at
12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13 |
14 | import copy
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 | from . import metric_utils
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | # Spherical interpolation of a batch of vectors.
23 | def slerp(a, b, t):
24 | a = a / a.norm(dim=-1, keepdim=True)
25 | b = b / b.norm(dim=-1, keepdim=True)
26 | d = (a * b).sum(dim=-1, keepdim=True)
27 | p = t * torch.acos(d)
28 | c = b - d * a
29 | c = c / c.norm(dim=-1, keepdim=True)
30 | d = a * torch.cos(p) + c * torch.sin(p)
31 | d = d / d.norm(dim=-1, keepdim=True)
32 | return d
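# For reference, this computes the standard spherical interpolation: with
# cos(theta) = <a, b> for the normalized inputs, the result is
#     cos(t*theta) * a + sin(t*theta) * c,
# where c is the component of b orthogonal to a, renormalized to unit length.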
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | class PPLSampler(torch.nn.Module):
37 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
38 | assert space in ['z', 'w']
39 | assert sampling in ['full', 'end']
40 | super().__init__()
41 | self.G = copy.deepcopy(G)
42 | self.G_kwargs = G_kwargs
43 | self.epsilon = epsilon
44 | self.space = space
45 | self.sampling = sampling
46 | self.crop = crop
47 | self.vgg16 = copy.deepcopy(vgg16)
48 |
49 | def forward(self, c):
50 | # Generate random latents and interpolation t-values.
51 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
52 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
53 |
54 | # Interpolate in W or Z.
55 | if self.space == 'w':
56 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
57 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
58 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
59 | else: # space == 'z'
60 | zt0 = slerp(z0, z1, t.unsqueeze(1))
61 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
62 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
63 |
64 | # Randomize noise buffers.
65 | for name, buf in self.G.named_buffers():
66 | if name.endswith('.noise_const'):
67 | buf.copy_(torch.randn_like(buf))
68 |
69 | # Generate images.
70 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
71 |
72 | # Center crop.
73 | if self.crop:
74 | assert img.shape[2] == img.shape[3]
75 | c = img.shape[2] // 8
76 | img = img[:, :, c*3 : c*7, c*2 : c*6]
77 |
78 | # Downsample to 256x256.
79 | factor = self.G.img_resolution // 256
80 | if factor > 1:
81 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
82 |
83 | # Scale dynamic range from [-1,1] to [0,255].
84 | img = (img + 1) * (255 / 2)
85 | if self.G.img_channels == 1:
86 | img = img.repeat([1, 3, 1, 1])
87 |
88 | # Evaluate differential LPIPS.
89 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
90 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91 | return dist
92 |
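# Note: each returned value is a finite-difference estimate of the squared
# perceptual path length, (LPIPS(t) - LPIPS(t + epsilon))^2 / epsilon^2;
# compute_ppl below averages these after trimming the distribution's tails.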
93 | #----------------------------------------------------------------------------
94 |
95 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
96 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
97 | vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99 |
100 | # Setup sampler.
101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102 | sampler.eval().requires_grad_(False).to(opts.device)
103 | if jit:
104 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
105 | sampler = torch.jit.trace(sampler, [c], check_trace=False)
106 |
107 | # Sampling loop.
108 | dist = []
109 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
110 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
111 | progress.update(batch_start)
112 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
113 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
114 | x = sampler(c)
115 | for src in range(opts.num_gpus):
116 | y = x.clone()
117 | if opts.num_gpus > 1:
118 | torch.distributed.broadcast(y, src=src)
119 | dist.append(y)
120 | progress.update(num_samples)
121 |
122 | # Compute PPL.
123 | if opts.rank != 0:
124 | return float('nan')
125 | dist = torch.cat(dist)[:num_samples].cpu().numpy()
126 | lo = np.percentile(dist, 1, interpolation='lower')
127 | hi = np.percentile(dist, 99, interpolation='higher')
128 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
129 | return float(ppl)
130 |
131 | #----------------------------------------------------------------------------
132 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import glob
11 | import torch
12 | import torch.utils.cpp_extension
13 | import importlib
14 | import hashlib
15 | import shutil
16 | from pathlib import Path
17 |
18 | from torch.utils.file_baton import FileBaton
19 |
20 | #----------------------------------------------------------------------------
21 | # Global options.
22 |
23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24 |
25 | #----------------------------------------------------------------------------
26 | # Internal helper funcs.
27 |
28 | def _find_compiler_bindir():
29 | patterns = [
30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34 | ]
35 | for pattern in patterns:
36 | matches = sorted(glob.glob(pattern))
37 | if len(matches):
38 | return matches[-1]
39 | return None
40 |
41 | #----------------------------------------------------------------------------
42 | # Main entry point for compiling and loading C++/CUDA plugins.
43 |
44 | _cached_plugins = dict()
45 |
46 | def get_plugin(module_name, sources, **build_kwargs):
47 | assert verbosity in ['none', 'brief', 'full']
48 |
49 | # Already cached?
50 | if module_name in _cached_plugins:
51 | return _cached_plugins[module_name]
52 |
53 | # Print status.
54 | if verbosity == 'full':
55 | print(f'Setting up PyTorch plugin "{module_name}"...')
56 | elif verbosity == 'brief':
57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58 |
59 | try: # pylint: disable=too-many-nested-blocks
60 | # Make sure we can find the necessary compiler binaries.
61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62 | compiler_bindir = _find_compiler_bindir()
63 | if compiler_bindir is None:
64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65 | os.environ['PATH'] += ';' + compiler_bindir
66 |
67 | # Compile and load.
68 | verbose_build = (verbosity == 'full')
69 |
70 | # Incremental build md5sum trickery. Copies all the input source files
71 | # into a cached build directory under a combined md5 digest of the input
72 | # source files. Copying is done only if the combined digest has changed.
73 | # This keeps input file timestamps and filenames the same as in previous
74 | # extension builds, allowing for fast incremental rebuilds.
75 | #
76 | # This optimization is done only in case all the source files reside in
77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78 | # environment variable is set (we take this as a signal that the user
79 | # actually cares about this.)
80 | source_dirs_set = set(os.path.dirname(source) for source in sources)
81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83 |
84 | # Compute a combined hash digest for all source files in the same
85 | # custom op directory (usually .cu, .cpp, .py and .h files).
86 | hash_md5 = hashlib.md5()
87 | for src in all_source_files:
88 | with open(src, 'rb') as f:
89 | hash_md5.update(f.read())
90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92 |
93 | if not os.path.isdir(digest_build_dir):
94 | os.makedirs(digest_build_dir, exist_ok=True)
95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96 | if baton.try_acquire():
97 | try:
98 | for src in all_source_files:
99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100 | finally:
101 | baton.release()
102 | else:
103 | # Someone else is copying source files under the digest dir,
104 | # wait until done and continue.
105 | baton.wait()
106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108 | verbose=verbose_build, sources=digest_sources, **build_kwargs)
109 | else:
110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111 | module = importlib.import_module(module_name)
112 |
113 | except:
114 | if verbosity == 'brief':
115 | print('Failed!')
116 | raise
117 |
118 | # Print status and add to cache.
119 | if verbosity == 'full':
120 | print(f'Done setting up PyTorch plugin "{module_name}".')
121 | elif verbosity == 'brief':
122 | print('Done.')
123 | _cached_plugins[module_name] = module
124 | return module
125 |
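# A sketch of how the ops in torch_utils/ops call this entry point (argument
# values here are illustrative; the actual callers resolve source paths
# relative to their own directory and pass the full bias_act argument list):
#
#     _plugin = get_plugin('bias_act_plugin',
#                          sources=['bias_act.cpp', 'bias_act.cu'],
#                          extra_cuda_cflags=['--use_fast_math'])
#     _plugin.bias_act(...)  # the function exported via PYBIND11_MODULE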
126 | #----------------------------------------------------------------------------
127 |
--------------------------------------------------------------------------------
/StyleGAN2/metrics/metric_main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import time
11 | import json
12 | import torch
13 | import dnnlib
14 |
15 | from . import metric_utils
16 | from . import frechet_inception_distance
17 | from . import kernel_inception_distance
18 | from . import precision_recall
19 | from . import perceptual_path_length
20 | from . import inception_score
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | _metric_dict = dict() # name => fn
25 |
26 | def register_metric(fn):
27 | assert callable(fn)
28 | _metric_dict[fn.__name__] = fn
29 | return fn
30 |
31 | def is_valid_metric(metric):
32 | return metric in _metric_dict
33 |
34 | def list_valid_metrics():
35 | return list(_metric_dict.keys())
36 |
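# A minimal sketch of how a new metric plugs into this registry (the name
# 'fid1k_demo' and its sample counts are hypothetical, not part of this repo):
#
#     @register_metric
#     def fid1k_demo(opts):
#         fid = frechet_inception_distance.compute_fid(opts, max_real=1000, num_gen=1000)
#         return dict(fid1k_demo=fid)
#
# calc_metric('fid1k_demo', ...) below would then resolve it via _metric_dict.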
37 | #----------------------------------------------------------------------------
38 |
39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40 | assert is_valid_metric(metric)
41 | opts = metric_utils.MetricOptions(**kwargs)
42 |
43 | # Calculate.
44 | start_time = time.time()
45 | results = _metric_dict[metric](opts)
46 | total_time = time.time() - start_time
47 |
48 | # Broadcast results.
49 | for key, value in list(results.items()):
50 | if opts.num_gpus > 1:
51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52 | torch.distributed.broadcast(tensor=value, src=0)
53 | value = float(value.cpu())
54 | results[key] = value
55 |
56 | # Decorate with metadata.
57 | return dnnlib.EasyDict(
58 | results = dnnlib.EasyDict(results),
59 | metric = metric,
60 | total_time = total_time,
61 | total_time_str = dnnlib.util.format_time(total_time),
62 | num_gpus = opts.num_gpus,
63 | )
64 |
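# Typical usage (a sketch; the keyword arguments mirror the fields of
# metric_utils.MetricOptions, and the values here are illustrative):
#
#     result = calc_metric('fid50k_full', G=G_ema, dataset_kwargs=dataset_kwargs,
#                          num_gpus=1, rank=0, device=torch.device('cuda'))
#     report_metric(result, run_dir=run_dir, snapshot_pkl='network-snapshot.pkl')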
65 | #----------------------------------------------------------------------------
66 |
67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68 | metric = result_dict['metric']
69 | assert is_valid_metric(metric)
70 | if run_dir is not None and snapshot_pkl is not None:
71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72 |
73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74 | print(jsonl_line)
75 | if run_dir is not None and os.path.isdir(run_dir):
76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77 | f.write(jsonl_line + '\n')
78 |
79 | #----------------------------------------------------------------------------
80 | # Primary metrics.
81 |
82 | @register_metric
83 | def fid50k_full(opts):
84 | opts.dataset_kwargs.update(max_size=None, xflip=False)
85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) #max_real=None
86 | return dict(fid50k_full=fid)
87 |
88 | @register_metric
89 | def kid50k_full(opts):
90 | opts.dataset_kwargs.update(max_size=None, xflip=False)
91 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=5000, num_subsets=100, max_subset_size=1000)
92 | return dict(kid50k_full=kid)
93 |
94 | @register_metric
95 | def pr50k3_full(opts):
96 | opts.dataset_kwargs.update(max_size=None, xflip=False)
97 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=5000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
98 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
99 |
100 | @register_metric
101 | def ppl2_wend(opts):
102 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
103 | return dict(ppl2_wend=ppl)
104 |
105 | @register_metric
106 | def is50k(opts):
107 | opts.dataset_kwargs.update(max_size=None, xflip=False)
108 | mean, std = inception_score.compute_is(opts, num_gen=5000, num_splits=10)
109 | return dict(is50k_mean=mean, is50k_std=std)
110 |
111 | #----------------------------------------------------------------------------
112 | # Legacy metrics.
113 |
114 | @register_metric
115 | def fid50k(opts):
116 | opts.dataset_kwargs.update(max_size=None)
117 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
118 | return dict(fid50k=fid)
119 |
120 | @register_metric
121 | def kid50k(opts):
122 | opts.dataset_kwargs.update(max_size=None)
123 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124 | return dict(kid50k=kid)
125 |
126 | @register_metric
127 | def pr50k3(opts):
128 | opts.dataset_kwargs.update(max_size=None)
129 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130 | return dict(pr50k3_precision=precision, pr50k3_recall=recall)
131 |
132 | @register_metric
133 | def ppl_zfull(opts):
134 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
135 | return dict(ppl_zfull=ppl)
136 |
137 | @register_metric
138 | def ppl_wfull(opts):
139 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
140 | return dict(ppl_wfull=ppl)
141 |
142 | @register_metric
143 | def ppl_zend(opts):
144 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
145 | return dict(ppl_zend=ppl)
146 |
147 | @register_metric
148 | def ppl_wend(opts):
149 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
150 | return dict(ppl_wend=ppl)
151 |
152 | #----------------------------------------------------------------------------
153 |
--------------------------------------------------------------------------------
/interp_feature.py:
--------------------------------------------------------------------------------
1 | ### Adaptive Feature Interpolation
2 | # Create a set of new features from old features:
3 | #     new_feature = near_interp(old_feature, k, augment_prob)
4 | # k and augment_prob can be generated adaptively by dynamic_prob(), which returns (augment_prob, k), or defined by the user.
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import numpy as np
9 | from sklearn.manifold import MDS
10 | import random
11 |
12 | def near_interp(embeddings, k, augment_prob):
13 | if k == 1 or augment_prob == 0:
14 | return embeddings
15 |
16 | k = min(k, embeddings.size()[0])
17 |
18 | pd = pairwise_distances(embeddings, embeddings)
19 | pd = pd/pd.max()
20 | pd_s = (1 / (1+pd))
21 |
22 |     # Select the k nearest neighbours of each point (the point itself is included, since its distance is zero)
23 | k_smallest = torch.topk(pd, k, largest=False).indices # shape: batch_size x k
24 |
25 | # Feature interpolation
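    # For each point, convex-combination weights are drawn from a Dirichlet
    # distribution whose concentration alpha_i is the similarity 1/(1+d_i)
    # (tempered by t) to the i-th nearest neighbour; the interpolated feature
    # is the resulting weighted average of those k neighbours.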
26 | t = 1
27 | alpha = torch.ones(k, device=embeddings.device)
28 | inner_embeddings = []
29 | for row in k_smallest:
30 | for i in range(k):
31 | alpha[i] = pd_s[row[0],row[i]]**t
32 |
33 | p = torch.distributions.dirichlet.Dirichlet(alpha).sample().to(embeddings.device)
34 |
35 | inner_pts = torch.matmul(p.reshape((1,-1)),embeddings.index_select(0,row))
36 | inner_embeddings.append(F.normalize(inner_pts))
37 |
38 | batch_size = embeddings.size()[0]
39 | out_embeddings = []
40 |
41 |
42 | # Output interpolated feature with probability p
43 | for idx in range(batch_size):
44 | p = random.random()
45 | if p < augment_prob:
46 | out_embeddings.append(inner_embeddings[idx])
47 | else:
48 | out_embeddings.append(embeddings[idx,:].unsqueeze(0))
49 |
50 | return torch.stack(out_embeddings).reshape((batch_size,-1))
51 |
52 |
53 | def dynamic_prob(embeddings):
54 | embeddings = F.normalize(embeddings)
55 | batch_size = embeddings.size()[0]
56 |
57 | D = pairwise_distances(embeddings, embeddings)
58 | D = D.detach().cpu().numpy()
59 | D = D / np.amax(D)
60 |
61 | #l_sorted = cmdscale(D)
62 | l_sorted = eigen_mds(D)
63 |
64 | # Calculate k,p based on number of large eigenvalues
65 |     k = batch_size - next((i for i, v in enumerate(l_sorted) if v < 0.1 * l_sorted[0]), batch_size - 1)  # default avoids StopIteration when all eigenvalues are large; yields k=1 (no interpolation)
66 | p = (k-1) / batch_size
67 |
68 | #k = 2
69 | #p = 0.9
70 |
71 | return p, k
72 |
73 |
74 | def cmdscale(D):
75 | """
76 | Classical multidimensional scaling (MDS)
77 |
78 | Parameters
79 | ----------
80 | D : (n, n) array
81 | Symmetric distance matrix.
82 |
83 |     Returns
84 |     -------
85 |     e : (n,) array
86 |         Eigenvalues of the double-centered matrix B = -H D^2 H / 2, sorted
87 |         in descending order. (The classical MDS configuration matrix Y can
88 |         be recovered from the positive-eigenvalue components; that step is
89 |         left commented out below, since only the eigenvalue spectrum is
90 |         needed by dynamic_prob.)
91 |     """
95 |
96 | # Number of points
97 | n = len(D)
98 |
99 | # Centering matrix
100 | H = np.eye(n) - np.ones((n, n))/n
101 |
102 | # YY^T
103 | B = -H.dot(D**2).dot(H)/2
104 |
105 | # Diagonalize
106 | evals, evecs = np.linalg.eigh(B)
107 |
108 | # Sort by eigenvalue in descending order
109 | idx = np.argsort(evals)[::-1]
110 | evals = evals[idx]
111 | evecs = evecs[:,idx]
112 |
113 | # Compute the coordinates using positive-eigenvalued components only
114 | # w, = np.where(evals > 0)
115 | # L = np.diag(np.sqrt(evals[w]))
116 | # V = evecs[:,w]
117 | # Y = V.dot(L)
118 |
119 | return np.sort(evals)[::-1]
120 |
121 |
122 | def eigen_mds(pd):
123 | mds = MDS(n_components=len(pd), dissimilarity='precomputed')
124 | pts = mds.fit_transform(pd)
125 |
126 | _,l_sorted,_ = np.linalg.svd(pts)
127 |
128 | return l_sorted
129 |
130 |
131 | def pairwise_distances(x, y):
132 |     '''
133 |     Input: x is a Nxd matrix
134 |            y is an optional Mxd matrix (pass None to use y = x)
135 |     Output: dist is a NxM matrix where dist[i,j] is the Euclidean distance
136 |             between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||
137 |             (the squared distances are clamped at zero before the sqrt).
138 |     '''
139 | x_norm = (x**2).sum(1).view(-1, 1)
140 | if y is not None:
141 | y_t = torch.transpose(y, 0, 1)
142 | y_norm = (y**2).sum(1).view(1, -1)
143 | else:
144 | y_t = torch.transpose(x, 0, 1)
145 | y_norm = x_norm.view(1, -1)
146 |
147 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
148 | # Ensure diagonal is zero if x=y
149 | # if y is None:
150 | # dist = dist - torch.diag(dist.diag)
151 | return torch.sqrt(torch.clamp(dist, 0.0, np.inf))
152 |
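# A minimal end-to-end sketch (shapes and values are illustrative; in GAN
# training, `feats` would be a batch of intermediate discriminator features):
#
#     feats = F.normalize(torch.randn(16, 128))        # 16 features of dim 128
#     augment_prob, k = dynamic_prob(feats)            # adaptive prob and k
#     new_feats = near_interp(feats, k, augment_prob)  # interpolated batch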
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double>    { typedef double scalar_t; };
17 | template <> struct InternalType<float>     { typedef float  scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 |     if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 |     if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 |     if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 |     if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 |     if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 |     if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 |     if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 |     if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 |     if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/StyleGAN2/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import torch
11 | from torch_utils import training_stats
12 | from torch_utils import misc
13 | from torch_utils.ops import conv2d_gradfix
14 |
15 | #----------------------------------------------------------------------------
16 |
17 | class Loss:
18 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
19 | raise NotImplementedError()
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | class StyleGAN2Loss(Loss):
24 | def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
25 | super().__init__()
26 | self.device = device
27 | self.G_mapping = G_mapping
28 | self.G_synthesis = G_synthesis
29 | self.D = D
30 | self.augment_pipe = augment_pipe
31 | self.style_mixing_prob = style_mixing_prob
32 | self.r1_gamma = r1_gamma
33 | self.pl_batch_shrink = pl_batch_shrink
34 | self.pl_decay = pl_decay
35 | self.pl_weight = pl_weight
36 | self.pl_mean = torch.zeros([], device=device)
37 |
38 | def run_G(self, z, c, sync):
39 | with misc.ddp_sync(self.G_mapping, sync):
40 | ws = self.G_mapping(z, c)
41 | if self.style_mixing_prob > 0:
42 | with torch.autograd.profiler.record_function('style_mixing'):
43 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
44 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
45 | ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
46 | with misc.ddp_sync(self.G_synthesis, sync):
47 | img = self.G_synthesis(ws)
48 | return img, ws
49 |
50 | def run_D(self, img, c, sync):
51 | if self.augment_pipe is not None:
52 | img = self.augment_pipe(img)
53 | with misc.ddp_sync(self.D, sync):
54 | logits = self.D(img, c)
55 | return logits
56 |
57 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
58 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
59 | do_Gmain = (phase in ['Gmain', 'Gboth'])
60 | do_Dmain = (phase in ['Dmain', 'Dboth'])
61 | do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
62 | do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
63 |
64 | # Gmain: Maximize logits for generated images.
65 | if do_Gmain:
66 | with torch.autograd.profiler.record_function('Gmain_forward'):
67 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
68 | gen_logits = self.run_D(gen_img, gen_c, sync=False)
69 | training_stats.report('Loss/scores/fake', gen_logits)
70 | training_stats.report('Loss/signs/fake', gen_logits.sign())
71 | loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
72 | #loss_Gmain = torch.nn.functional.relu(-gen_logits).mean()
73 | training_stats.report('Loss/G/loss', loss_Gmain)
74 | with torch.autograd.profiler.record_function('Gmain_backward'):
75 | loss_Gmain.mean().mul(gain).backward()
76 |
77 | # Gpl: Apply path length regularization.
78 | if do_Gpl:
79 | with torch.autograd.profiler.record_function('Gpl_forward'):
80 | batch_size = gen_z.shape[0] // self.pl_batch_shrink
81 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
82 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
83 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
84 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
85 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
86 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
87 | self.pl_mean.copy_(pl_mean.detach())
88 | pl_penalty = (pl_lengths - pl_mean).square()
89 | training_stats.report('Loss/pl_penalty', pl_penalty)
90 | loss_Gpl = pl_penalty * self.pl_weight
91 | training_stats.report('Loss/G/reg', loss_Gpl)
92 | with torch.autograd.profiler.record_function('Gpl_backward'):
93 | (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
94 |
95 | # Dmain: Minimize logits for generated images.
96 | loss_Dgen = 0
97 | if do_Dmain:
98 | with torch.autograd.profiler.record_function('Dgen_forward'):
99 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
100 | gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
101 | training_stats.report('Loss/scores/fake', gen_logits)
102 | training_stats.report('Loss/signs/fake', gen_logits.sign())
103 | loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
104 | #loss_Dgen = torch.nn.functional.relu(1 + gen_logits).mean()
105 | with torch.autograd.profiler.record_function('Dgen_backward'):
106 | loss_Dgen.mean().mul(gain).backward()
107 |
108 | # Dmain: Maximize logits for real images.
109 | # Dr1: Apply R1 regularization.
110 | if do_Dmain or do_Dr1:
111 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
112 | with torch.autograd.profiler.record_function(name + '_forward'):
113 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
114 | real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
115 | training_stats.report('Loss/scores/real', real_logits)
116 | training_stats.report('Loss/signs/real', real_logits.sign())
117 |
118 | loss_Dreal = 0
119 | if do_Dmain:
120 | loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
121 | #loss_Dreal = torch.nn.functional.relu(1-real_logits).mean()
122 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
123 |
124 | loss_Dr1 = 0
125 | if do_Dr1:
126 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
127 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
128 | r1_penalty = r1_grads.square().sum([1,2,3])
129 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
130 | training_stats.report('Loss/r1_penalty', r1_penalty)
131 | training_stats.report('Loss/D/reg', loss_Dr1)
132 |
133 | with torch.autograd.profiler.record_function(name + '_backward'):
134 | #(real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
135 | (real_logits.mean() * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
136 |
137 |
138 |
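# For reference, the two regularizers above implement, in math form,
#     R1  = (r1_gamma / 2) * E_x[ ||grad_x D(x)||^2 ]                 (Dr1)
#     PL  = pl_weight * E_{w,y}[ (||J_w^T y||_2 - a)^2 ]              (Gpl)
# where y is the unit-variance pl_noise image and a is the running mean
# self.pl_mean of the observed path lengths, as in the StyleGAN2 paper.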
139 | #----------------------------------------------------------------------------
140 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | w = w.flip([2, 3])
37 |
38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42 | if out_channels <= 4 and groups == 1:
43 | in_shape = x.shape
44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46 | else:
47 | x = x.to(memory_format=torch.contiguous_format)
48 | w = w.to(memory_format=torch.contiguous_format)
49 | x = conv2d_gradfix.conv2d(x, w, groups=groups)
50 | return x.to(memory_format=torch.channels_last)
51 |
52 | # Otherwise => execute using conv2d_gradfix.
53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54 | return op(x, w, stride=stride, padding=padding, groups=groups)
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | @misc.profiled_function
59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60 | r"""2D convolution with optional up/downsampling.
61 |
62 | Padding is performed only once at the beginning, not between the operations.
63 |
64 | Args:
65 | x: Input tensor of shape
66 | `[batch_size, in_channels, in_height, in_width]`.
67 | w: Weight tensor of shape
68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70 | calling upfirdn2d.setup_filter(). None = identity (default).
71 | up: Integer upsampling factor (default: 1).
72 | down: Integer downsampling factor (default: 1).
73 | padding: Padding with respect to the upsampled image. Can be a single number
74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75 | (default: 0).
76 | groups: Split input channels into N groups (default: 1).
77 | flip_weight: False = convolution, True = correlation (default: True).
78 | flip_filter: False = convolution, True = correlation (default: False).
79 |
80 | Returns:
81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82 | """
83 | # Validate arguments.
84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87 | assert isinstance(up, int) and (up >= 1)
88 | assert isinstance(down, int) and (down >= 1)
89 | assert isinstance(groups, int) and (groups >= 1)
90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91 | fw, fh = _get_filter_size(f)
92 | px0, px1, py0, py1 = _parse_padding(padding)
93 |
94 | # Adjust padding to account for up/downsampling.
95 | if up > 1:
96 | px0 += (fw + up - 1) // 2
97 | px1 += (fw - up) // 2
98 | py0 += (fh + up - 1) // 2
99 | py1 += (fh - up) // 2
100 | if down > 1:
101 | px0 += (fw - down + 1) // 2
102 | px1 += (fw - down) // 2
103 | py0 += (fh - down + 1) // 2
104 | py1 += (fh - down) // 2
105 |
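    # Worked example (illustrative): up=2 with a 4-tap filter adds
    # (4+2-1)//2 = 2 to px0/py0 and (4-2)//2 = 1 to px1/py1, the canonical
    # padding for 2x upsampling.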
106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110 | return x
111 |
112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116 | return x
117 |
118 | # Fast path: downsampling only => use strided convolution.
119 | if down > 1 and up == 1:
120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122 | return x
123 |
124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125 | if up > 1:
126 | if groups == 1:
127 | w = w.transpose(0, 1)
128 | else:
129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130 | w = w.transpose(1, 2)
131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132 | px0 -= kw - 1
133 | px1 -= kw - up
134 | py0 -= kh - 1
135 | py1 -= kh - up
136 | pxt = max(min(-px0, -px1), 0)
137 | pyt = max(min(-py0, -py1), 0)
138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140 | if down > 1:
141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142 | return x
143 |
144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145 | if up == 1 and down == 1:
146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148 |
149 | # Fallback: Generic reference implementation.
150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152 | if down > 1:
153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154 | return x
155 |
156 | #----------------------------------------------------------------------------
157 |
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import warnings
13 | import contextlib
14 | import torch
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24 |
25 | @contextlib.contextmanager
26 | def no_weight_gradients():
27 | global weight_gradients_disabled
28 | old = weight_gradients_disabled
29 | weight_gradients_disabled = True
30 | yield
31 | weight_gradients_disabled = old
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36 | if _should_use_custom_op(input):
37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39 |
40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41 | if _should_use_custom_op(input):
42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | def _should_use_custom_op(input):
48 | assert isinstance(input, torch.Tensor)
49 | if (not enabled) or (not torch.backends.cudnn.enabled):
50 | return False
51 | if input.device.type != 'cuda':
52 | return False
53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
54 | return True
55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56 | return False
57 |
58 | def _tuple_of_ints(xs, ndim):
59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60 | assert len(xs) == ndim
61 | assert all(isinstance(x, int) for x in xs)
62 | return xs
63 |
64 | #----------------------------------------------------------------------------
65 |
66 | _conv2d_gradfix_cache = dict()
67 |
68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69 | # Parse arguments.
70 | ndim = 2
71 | weight_shape = tuple(weight_shape)
72 | stride = _tuple_of_ints(stride, ndim)
73 | padding = _tuple_of_ints(padding, ndim)
74 | output_padding = _tuple_of_ints(output_padding, ndim)
75 | dilation = _tuple_of_ints(dilation, ndim)
76 |
77 | # Lookup from cache.
78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79 | if key in _conv2d_gradfix_cache:
80 | return _conv2d_gradfix_cache[key]
81 |
82 | # Validate arguments.
83 | assert groups >= 1
84 | assert len(weight_shape) == ndim + 2
85 | assert all(stride[i] >= 1 for i in range(ndim))
86 | assert all(padding[i] >= 0 for i in range(ndim))
87 | assert all(dilation[i] >= 0 for i in range(ndim))
88 | if not transpose:
89 | assert all(output_padding[i] == 0 for i in range(ndim))
90 | else: # transpose
91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92 |
93 | # Helpers.
94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95 | def calc_output_padding(input_shape, output_shape):
96 | if transpose:
97 | return [0, 0]
98 | return [
99 | input_shape[i + 2]
100 | - (output_shape[i + 2] - 1) * stride[i]
101 | - (1 - 2 * padding[i])
102 | - dilation[i] * (weight_shape[i + 2] - 1)
103 | for i in range(ndim)
104 | ]
105 |
106 | # Forward & backward.
107 | class Conv2d(torch.autograd.Function):
108 | @staticmethod
109 | def forward(ctx, input, weight, bias):
110 | assert weight.shape == weight_shape
111 | if not transpose:
112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113 | else: # transpose
114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115 | ctx.save_for_backward(input, weight)
116 | return output
117 |
118 | @staticmethod
119 | def backward(ctx, grad_output):
120 | input, weight = ctx.saved_tensors
121 | grad_input = None
122 | grad_weight = None
123 | grad_bias = None
124 |
125 | if ctx.needs_input_grad[0]:
126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128 | assert grad_input.shape == input.shape
129 |
130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
132 | assert grad_weight.shape == weight_shape
133 |
134 | if ctx.needs_input_grad[2]:
135 | grad_bias = grad_output.sum([0, 2, 3])
136 |
137 | return grad_input, grad_weight, grad_bias
138 |
139 | # Gradient with respect to the weights.
140 | class Conv2dGradWeight(torch.autograd.Function):
141 | @staticmethod
142 | def forward(ctx, grad_output, input):
143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146 | assert grad_weight.shape == weight_shape
147 | ctx.save_for_backward(grad_output, input)
148 | return grad_weight
149 |
150 | @staticmethod
151 | def backward(ctx, grad2_grad_weight):
152 | grad_output, input = ctx.saved_tensors
153 | grad2_grad_output = None
154 | grad2_input = None
155 |
156 | if ctx.needs_input_grad[0]:
157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158 | assert grad2_grad_output.shape == grad_output.shape
159 |
160 | if ctx.needs_input_grad[1]:
161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163 | assert grad2_input.shape == input.shape
164 |
165 | return grad2_grad_output, grad2_input
166 |
167 | _conv2d_gradfix_cache[key] = Conv2d
168 | return Conv2d
169 |
170 | #----------------------------------------------------------------------------
171 |
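A hedged sketch of why this wrapper exists: gradient-penalty losses differentiate through a gradient, and the stock cuDNN path did not always support that. Assumes a CUDA device; `enabled` is the module flag defined above.

    import torch
    from torch_utils.ops import conv2d_gradfix

    conv2d_gradfix.enabled = True  # opt in to the custom op

    x = torch.randn(4, 3, 16, 16, device='cuda', requires_grad=True)
    w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)
    y = conv2d_gradfix.conv2d(x, w, padding=1)

    # First-order gradient w.r.t. the input, kept in the graph...
    (g,) = torch.autograd.grad(y.sum(), [x], create_graph=True)
    # ...so that a penalty on it can itself be differentiated (second order).
    g.square().sum().backward()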
--------------------------------------------------------------------------------
/StyleGAN2/calc_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Calculate quality metrics for previous training run or pretrained network pickle."""
10 |
11 | import os
12 | import click
13 | import json
14 | import tempfile
15 | import copy
16 | import torch
17 | import dnnlib
18 |
19 | import legacy
20 | from metrics import metric_main
21 | from metrics import metric_utils
22 | from torch_utils import training_stats
23 | from torch_utils import custom_ops
24 | from torch_utils import misc
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def subprocess_fn(rank, args, temp_dir):
29 | dnnlib.util.Logger(should_flush=True)
30 |
31 | # Init torch.distributed.
32 | if args.num_gpus > 1:
33 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
34 | if os.name == 'nt':
35 | init_method = 'file:///' + init_file.replace('\\', '/')
36 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
37 | else:
38 | init_method = f'file://{init_file}'
39 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
40 |
41 | # Init torch_utils.
42 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
43 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
44 | if rank != 0 or not args.verbose:
45 | custom_ops.verbosity = 'none'
46 |
47 | # Print network summary.
48 | device = torch.device('cuda', rank)
49 | torch.backends.cudnn.benchmark = True
50 | torch.backends.cuda.matmul.allow_tf32 = False
51 | torch.backends.cudnn.allow_tf32 = False
52 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
53 | if rank == 0 and args.verbose:
54 | z = torch.empty([1, G.z_dim], device=device)
55 | c = torch.empty([1, G.c_dim], device=device)
56 | misc.print_module_summary(G, [z, c])
57 |
58 | # Calculate each metric.
59 | for metric in args.metrics:
60 | if rank == 0 and args.verbose:
61 | print(f'Calculating {metric}...')
62 | progress = metric_utils.ProgressMonitor(verbose=args.verbose)
63 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
64 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
65 | if rank == 0:
66 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
67 | if rank == 0 and args.verbose:
68 | print()
69 |
70 | # Done.
71 | if rank == 0 and args.verbose:
72 | print('Exiting...')
73 |
74 | #----------------------------------------------------------------------------
75 |
76 | class CommaSeparatedList(click.ParamType):
77 | name = 'list'
78 |
79 | def convert(self, value, param, ctx):
80 | _ = param, ctx
81 | if value is None or value.lower() == 'none' or value == '':
82 | return []
83 | return value.split(',')
84 |
85 | #----------------------------------------------------------------------------
86 |
87 | @click.command()
88 | @click.pass_context
89 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
90 | @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
91 | @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
92 | @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
93 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
94 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
95 |
96 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
97 | """Calculate quality metrics for previous training run or pretrained network pickle.
98 |
99 | Examples:
100 |
101 | \b
102 | # Previous training run: look up options automatically, save result to JSONL file.
103 | python calc_metrics.py --metrics=pr50k3_full \\
104 | --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
105 |
106 | \b
107 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
108 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
109 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
110 |
111 | Available metrics:
112 |
113 | \b
114 | ADA paper:
115 | fid50k_full Frechet inception distance against the full dataset.
116 | kid50k_full Kernel inception distance against the full dataset.
117 | pr50k3_full Precision and recall against the full dataset.
118 | is50k Inception score for CIFAR-10.
119 |
120 | \b
121 | StyleGAN and StyleGAN2 papers:
122 | fid50k Frechet inception distance against 50k real images.
123 | kid50k Kernel inception distance against 50k real images.
124 | pr50k3 Precision and recall against 50k real images.
125 | ppl2_wend Perceptual path length in W at path endpoints against full image.
126 | ppl_zfull Perceptual path length in Z for full paths against cropped image.
127 | ppl_wfull Perceptual path length in W for full paths against cropped image.
128 | ppl_zend Perceptual path length in Z at path endpoints against cropped image.
129 | ppl_wend Perceptual path length in W at path endpoints against cropped image.
130 | """
131 | dnnlib.util.Logger(should_flush=True)
132 |
133 | # Validate arguments.
134 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
135 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
136 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
137 | if args.num_gpus < 1:
138 | ctx.fail('--gpus must be at least 1')
139 |
140 | # Load network.
141 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
142 | ctx.fail('--network must point to a file or URL')
143 | if args.verbose:
144 | print(f'Loading network from "{network_pkl}"...')
145 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
146 | network_dict = legacy.load_network_pkl(f)
147 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module
148 |
149 | # Initialize dataset options.
150 | if data is not None:
151 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
152 | elif network_dict['training_set_kwargs'] is not None:
153 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
154 | else:
155 | ctx.fail('Could not look up dataset options; please specify --data')
156 |
157 | # Finalize dataset options.
158 | args.dataset_kwargs.resolution = args.G.img_resolution
159 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
160 | if mirror is not None:
161 | args.dataset_kwargs.xflip = mirror
162 |
163 | # Print dataset options.
164 | if args.verbose:
165 | print('Dataset options:')
166 | print(json.dumps(args.dataset_kwargs, indent=2))
167 |
168 | # Locate run dir.
169 | args.run_dir = None
170 | if os.path.isfile(network_pkl):
171 | pkl_dir = os.path.dirname(network_pkl)
172 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
173 | args.run_dir = pkl_dir
174 |
175 | # Launch processes.
176 | if args.verbose:
177 | print('Launching processes...')
178 | torch.multiprocessing.set_start_method('spawn')
179 | with tempfile.TemporaryDirectory() as temp_dir:
180 | if args.num_gpus == 1:
181 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
182 | else:
183 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
184 |
185 | #----------------------------------------------------------------------------
186 |
187 | if __name__ == "__main__":
188 | calc_metrics() # pylint: disable=no-value-for-parameter
189 |
190 | #----------------------------------------------------------------------------
191 |
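Beyond the CLI documented above, the same machinery can be driven programmatically on a single GPU, mirroring `subprocess_fn()`. A hedged sketch; both paths are placeholders:

    import torch
    import dnnlib
    import legacy
    from metrics import metric_main

    device = torch.device('cuda')
    with dnnlib.util.open_url('network-snapshot.pkl') as f:  # placeholder path
        G = legacy.load_network_pkl(f)['G_ema'].eval().requires_grad_(False).to(device)

    dataset_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset', path='data.zip')  # placeholder
    result = metric_main.calc_metric(metric='fid50k_full', G=G,
        dataset_kwargs=dataset_kwargs, num_gpus=1, rank=0, device=device)
    print(result.results)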
--------------------------------------------------------------------------------
/StyleGAN2/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import numpy as np
11 | import zipfile
12 | import PIL.Image
13 | import json
14 | import torch
15 | import dnnlib
16 |
17 | try:
18 | import pyspng
19 | except ImportError:
20 | pyspng = None
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | class Dataset(torch.utils.data.Dataset):
25 | def __init__(self,
26 | name, # Name of the dataset.
27 | raw_shape, # Shape of the raw image data (NCHW).
28 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
29 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
30 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
31 | random_seed = 0, # Random seed to use when applying max_size.
32 | ):
33 | self._name = name
34 | self._raw_shape = list(raw_shape)
35 | self._use_labels = use_labels
36 | self._raw_labels = None
37 | self._label_shape = None
38 |
39 | # Apply max_size.
40 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
41 | if (max_size is not None) and (self._raw_idx.size > max_size):
42 | np.random.RandomState(random_seed).shuffle(self._raw_idx)
43 | self._raw_idx = np.sort(self._raw_idx[:max_size])
44 |
45 | # Apply xflip.
46 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
47 | if xflip:
48 | self._raw_idx = np.tile(self._raw_idx, 2)
49 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
50 |
51 | def _get_raw_labels(self):
52 | if self._raw_labels is None:
53 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
54 | if self._raw_labels is None:
55 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
56 | assert isinstance(self._raw_labels, np.ndarray)
57 | assert self._raw_labels.shape[0] == self._raw_shape[0]
58 | assert self._raw_labels.dtype in [np.float32, np.int64]
59 | if self._raw_labels.dtype == np.int64:
60 | assert self._raw_labels.ndim == 1
61 | assert np.all(self._raw_labels >= 0)
62 | return self._raw_labels
63 |
64 | def close(self): # to be overridden by subclass
65 | pass
66 |
67 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
68 | raise NotImplementedError
69 |
70 | def _load_raw_labels(self): # to be overridden by subclass
71 | raise NotImplementedError
72 |
73 | def __getstate__(self):
74 | return dict(self.__dict__, _raw_labels=None)
75 |
76 | def __del__(self):
77 | try:
78 | self.close()
79 | except:
80 | pass
81 |
82 | def __len__(self):
83 | return self._raw_idx.size
84 |
85 | def __getitem__(self, idx):
86 | image = self._load_raw_image(self._raw_idx[idx])
87 | assert isinstance(image, np.ndarray)
88 | assert list(image.shape) == self.image_shape
89 | assert image.dtype == np.uint8
90 |
91 | if self._xflip[idx]:
92 | assert image.ndim == 3 # CHW
93 | image = image[:, :, ::-1]
94 | return image.copy(), self.get_label(idx)
95 |
96 | def get_label(self, idx):
97 | label = self._get_raw_labels()[self._raw_idx[idx]]
98 | if label.dtype == np.int64:
99 | onehot = np.zeros(self.label_shape, dtype=np.float32)
100 | onehot[label] = 1
101 | label = onehot
102 | return label.copy()
103 |
104 | def get_details(self, idx):
105 | d = dnnlib.EasyDict()
106 | d.raw_idx = int(self._raw_idx[idx])
107 | d.xflip = (int(self._xflip[idx]) != 0)
108 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
109 | return d
110 |
111 | @property
112 | def name(self):
113 | return self._name
114 |
115 | @property
116 | def image_shape(self):
117 | return list(self._raw_shape[1:])
118 |
119 | @property
120 | def num_channels(self):
121 | assert len(self.image_shape) == 3 # CHW
122 | return self.image_shape[0]
123 |
124 | @property
125 | def resolution(self):
126 | assert len(self.image_shape) == 3 # CHW
127 | assert self.image_shape[1] == self.image_shape[2]
128 | return self.image_shape[1]
129 |
130 | @property
131 | def label_shape(self):
132 | if self._label_shape is None:
133 | raw_labels = self._get_raw_labels()
134 | if raw_labels.dtype == np.int64:
135 | self._label_shape = [int(np.max(raw_labels)) + 1]
136 | else:
137 | self._label_shape = raw_labels.shape[1:]
138 | return list(self._label_shape)
139 |
140 | @property
141 | def label_dim(self):
142 | assert len(self.label_shape) == 1
143 | return self.label_shape[0]
144 |
145 | @property
146 | def has_labels(self):
147 | return any(x != 0 for x in self.label_shape)
148 |
149 | @property
150 | def has_onehot_labels(self):
151 | return self._get_raw_labels().dtype == np.int64
152 |
153 | #----------------------------------------------------------------------------
154 |
155 | class ImageFolderDataset(Dataset):
156 | def __init__(self,
157 | path, # Path to directory or zip.
158 | resolution = None, # Ensure specific resolution, None = highest available.
159 | **super_kwargs, # Additional arguments for the Dataset base class.
160 | ):
161 | self._path = path
162 | self._zipfile = None
163 |
164 | if os.path.isdir(self._path):
165 | self._type = 'dir'
166 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
167 | elif self._file_ext(self._path) == '.zip':
168 | self._type = 'zip'
169 | self._all_fnames = set(self._get_zipfile().namelist())
170 | else:
171 | raise IOError('Path must point to a directory or zip')
172 |
173 | PIL.Image.init()
174 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
175 | if len(self._image_fnames) == 0:
176 | raise IOError('No image files found in the specified path')
177 |
178 | name = os.path.splitext(os.path.basename(self._path))[0]
179 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
180 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
181 | raise IOError('Image files do not match the specified resolution')
182 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
183 |
184 | @staticmethod
185 | def _file_ext(fname):
186 | return os.path.splitext(fname)[1].lower()
187 |
188 | def _get_zipfile(self):
189 | assert self._type == 'zip'
190 | if self._zipfile is None:
191 | self._zipfile = zipfile.ZipFile(self._path)
192 | return self._zipfile
193 |
194 | def _open_file(self, fname):
195 | if self._type == 'dir':
196 | return open(os.path.join(self._path, fname), 'rb')
197 | if self._type == 'zip':
198 | return self._get_zipfile().open(fname, 'r')
199 | return None
200 |
201 | def close(self):
202 | try:
203 | if self._zipfile is not None:
204 | self._zipfile.close()
205 | finally:
206 | self._zipfile = None
207 |
208 | def __getstate__(self):
209 | return dict(super().__getstate__(), _zipfile=None)
210 |
211 | def _load_raw_image(self, raw_idx):
212 | fname = self._image_fnames[raw_idx]
213 | with self._open_file(fname) as f:
214 | if pyspng is not None and self._file_ext(fname) == '.png':
215 | image = pyspng.load(f.read())
216 | else:
217 | image = np.array(PIL.Image.open(f))
218 | if image.ndim == 2:
219 | image = image[:, :, np.newaxis] # HW => HWC
220 | image = image.transpose(2, 0, 1) # HWC => CHW
221 | return image
222 |
223 | def _load_raw_labels(self):
224 | fname = 'dataset.json'
225 | if fname not in self._all_fnames:
226 | return None
227 | with self._open_file(fname) as f:
228 | labels = json.load(f)['labels']
229 | if labels is None:
230 | return None
231 | labels = dict(labels)
232 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
233 | labels = np.array(labels)
234 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
235 | return labels
236 |
237 | #----------------------------------------------------------------------------
238 |
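A hedged usage sketch for `ImageFolderDataset` (the dataset path is a placeholder). Note that images come back as uint8 CHW arrays and that `xflip=True` doubles the nominal length:

    import torch
    from training.dataset import ImageFolderDataset

    dataset = ImageFolderDataset(path='datasets/mydata.zip', xflip=True)  # placeholder path
    print(len(dataset), dataset.image_shape, dataset.resolution)

    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, images.dtype)  # e.g. torch.Size([4, 3, H, W]), torch.uint8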
--------------------------------------------------------------------------------
/StyleGAN2/projector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Project given image to the latent space of pretrained network pickle."""
10 |
11 | import copy
12 | import os
13 | from time import perf_counter
14 |
15 | import click
16 | import imageio
17 | import numpy as np
18 | import PIL.Image
19 | import torch
20 | import torch.nn.functional as F
21 |
22 | import dnnlib
23 | import legacy
24 |
25 | def project(
26 | G,
27 | target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
28 | *,
29 | num_steps = 1000,
30 | w_avg_samples = 10000,
31 | initial_learning_rate = 0.1,
32 | initial_noise_factor = 0.05,
33 | lr_rampdown_length = 0.25,
34 | lr_rampup_length = 0.05,
35 | noise_ramp_length = 0.75,
36 | regularize_noise_weight = 1e5,
37 | verbose = False,
38 | device: torch.device
39 | ):
40 | assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
41 |
42 | def logprint(*args):
43 | if verbose:
44 | print(*args)
45 |
46 | G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
47 |
48 | # Compute w stats.
49 | logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
50 | z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
51 | w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
52 | w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
53 | w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
54 | w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
55 |
56 | # Setup noise inputs.
57 | noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
58 |
59 | # Load VGG16 feature detector.
60 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
61 | with dnnlib.util.open_url(url) as f:
62 | vgg16 = torch.jit.load(f).eval().to(device)
63 |
64 | # Features for target image.
65 | target_images = target.unsqueeze(0).to(device).to(torch.float32)
66 | if target_images.shape[2] > 256:
67 | target_images = F.interpolate(target_images, size=(256, 256), mode='area')
68 | target_features = vgg16(target_images, resize_images=False, return_lpips=True)
69 |
70 | w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
71 | w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
72 | optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
73 |
74 | # Init noise.
75 | for buf in noise_bufs.values():
76 | buf[:] = torch.randn_like(buf)
77 | buf.requires_grad = True
78 |
79 | for step in range(num_steps):
80 | # Learning rate schedule.
81 | t = step / num_steps
82 | w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
83 | lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
84 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
85 | lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
86 | lr = initial_learning_rate * lr_ramp
87 | for param_group in optimizer.param_groups:
88 | param_group['lr'] = lr
89 |
90 | # Synth images from opt_w.
91 | w_noise = torch.randn_like(w_opt) * w_noise_scale
92 | ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
93 | synth_images = G.synthesis(ws, noise_mode='const')
94 |
95 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
96 | synth_images = (synth_images + 1) * (255/2)
97 | if synth_images.shape[2] > 256:
98 | synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
99 |
100 | # Features for synth images.
101 | synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
102 | dist = (target_features - synth_features).square().sum()
103 |
104 | # Noise regularization.
105 | reg_loss = 0.0
106 | for v in noise_bufs.values():
107 | noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
108 | while True:
109 | reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
110 | reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
111 | if noise.shape[2] <= 8:
112 | break
113 | noise = F.avg_pool2d(noise, kernel_size=2)
114 | loss = dist + reg_loss * regularize_noise_weight
115 |
116 | # Step
117 | optimizer.zero_grad(set_to_none=True)
118 | loss.backward()
119 | optimizer.step()
120 | logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
121 |
122 | # Save projected W for each optimization step.
123 | w_out[step] = w_opt.detach()[0]
124 |
125 | # Normalize noise.
126 | with torch.no_grad():
127 | for buf in noise_bufs.values():
128 | buf -= buf.mean()
129 | buf *= buf.square().mean().rsqrt()
130 |
131 | return w_out.repeat([1, G.mapping.num_ws, 1])
132 |
133 | #----------------------------------------------------------------------------
134 |
135 | @click.command()
136 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
137 | @click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
138 | @click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
139 | @click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
140 | @click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
141 | @click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
142 | def run_projection(
143 | network_pkl: str,
144 | target_fname: str,
145 | outdir: str,
146 | save_video: bool,
147 | seed: int,
148 | num_steps: int
149 | ):
150 | """Project given image to the latent space of pretrained network pickle.
151 |
152 | Examples:
153 |
154 | \b
155 | python projector.py --outdir=out --target=~/mytargetimg.png \\
156 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
157 | """
158 | np.random.seed(seed)
159 | torch.manual_seed(seed)
160 |
161 | # Load networks.
162 | print('Loading networks from "%s"...' % network_pkl)
163 | device = torch.device('cuda')
164 | with dnnlib.util.open_url(network_pkl) as fp:
165 | G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
166 |
167 | # Load target image.
168 | target_pil = PIL.Image.open(target_fname).convert('RGB')
169 | w, h = target_pil.size
170 | s = min(w, h)
171 | target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
172 | target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
173 | target_uint8 = np.array(target_pil, dtype=np.uint8)
174 |
175 | # Optimize projection.
176 | start_time = perf_counter()
177 | projected_w_steps = project(
178 | G,
179 | target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
180 | num_steps=num_steps,
181 | device=device,
182 | verbose=True
183 | )
184 | print(f'Elapsed: {(perf_counter()-start_time):.1f} s')
185 |
186 | # Render debug output: optional video and projected image and W vector.
187 | os.makedirs(outdir, exist_ok=True)
188 | if save_video:
189 | video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
190 | print(f'Saving optimization progress video "{outdir}/proj.mp4"')
191 | for projected_w in projected_w_steps:
192 | synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
193 | synth_image = (synth_image + 1) * (255/2)
194 | synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
195 | video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
196 | video.close()
197 |
198 | # Save final projected frame and W vector.
199 | target_pil.save(f'{outdir}/target.png')
200 | projected_w = projected_w_steps[-1]
201 | synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
202 | synth_image = (synth_image + 1) * (255/2)
203 | synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
204 | PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
205 | np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
206 |
207 | #----------------------------------------------------------------------------
208 |
209 | if __name__ == "__main__":
210 | run_projection() # pylint: disable=no-value-for-parameter
211 |
212 | #----------------------------------------------------------------------------
213 |
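The saved `projected_w.npz` can be reloaded later to re-synthesize the projection. A hedged sketch; network and output paths are placeholders:

    import numpy as np
    import torch
    import dnnlib
    import legacy

    device = torch.device('cuda')
    with dnnlib.util.open_url('ffhq.pkl') as fp:  # placeholder path
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)

    ws = torch.tensor(np.load('out/projected_w.npz')['w'], device=device)  # [1, num_ws, w_dim]
    img = G.synthesis(ws, noise_mode='const')
    img = (img + 1) * (255 / 2)  # map [-1, 1] to [0, 255], as in run_projection() above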
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for pickling Python code alongside other data.
10 |
11 | The pickled code is automatically imported into a separate Python module
12 | during unpickling. This way, any previously exported pickles will remain
13 | usable even if the original code is no longer available, or if the current
14 | version of the code is not consistent with what was originally pickled."""
15 |
16 | import sys
17 | import pickle
18 | import io
19 | import inspect
20 | import copy
21 | import uuid
22 | import types
23 | import dnnlib
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | _version = 6 # internal version number
28 | _decorators = set() # {decorator_class, ...}
29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...}
31 | _src_to_module_dict = dict() # {src: module, ...}
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def persistent_class(orig_class):
36 | r"""Class decorator that extends a given class to save its source code
37 | when pickled.
38 |
39 | Example:
40 |
41 | from torch_utils import persistence
42 |
43 | @persistence.persistent_class
44 | class MyNetwork(torch.nn.Module):
45 | def __init__(self, num_inputs, num_outputs):
46 | super().__init__()
47 | self.fc = MyLayer(num_inputs, num_outputs)
48 | ...
49 |
50 | @persistence.persistent_class
51 | class MyLayer(torch.nn.Module):
52 | ...
53 |
54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55 | source code alongside other internal state (e.g., parameters, buffers,
56 | and submodules). This way, any previously exported pickle will remain
57 | usable even if the class definitions have been modified or are no
58 | longer available.
59 |
60 | The decorator saves the source code of the entire Python module
61 | containing the decorated class. It does *not* save the source code of
62 | any imported modules. Thus, the imported modules must be available
63 | during unpickling, including `torch_utils.persistence` itself.
64 |
65 | It is ok to call functions defined in the same module from the
66 | decorated class. However, if the decorated class depends on other
67 | classes defined in the same module, they must be decorated as well.
68 | This is illustrated in the above example in the case of `MyLayer`.
69 |
70 | It is also possible to employ the decorator just-in-time before
71 | calling the constructor. For example:
72 |
73 | cls = MyLayer
74 | if want_to_make_it_persistent:
75 | cls = persistence.persistent_class(cls)
76 | layer = cls(num_inputs, num_outputs)
77 |
78 | As an additional feature, the decorator also keeps track of the
79 | arguments that were used to construct each instance of the decorated
80 | class. The arguments can be queried via `obj.init_args` and
81 | `obj.init_kwargs`, and they are automatically pickled alongside other
82 | object state. A typical use case is to first unpickle a previous
83 | instance of a persistent class, and then upgrade it to use the latest
84 | version of the source code:
85 |
86 | with open('old_pickle.pkl', 'rb') as f:
87 | old_net = pickle.load(f)
88 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 | """
91 | assert isinstance(orig_class, type)
92 | if is_persistent(orig_class):
93 | return orig_class
94 |
95 | assert orig_class.__module__ in sys.modules
96 | orig_module = sys.modules[orig_class.__module__]
97 | orig_module_src = _module_to_src(orig_module)
98 |
99 | class Decorator(orig_class):
100 | _orig_module_src = orig_module_src
101 | _orig_class_name = orig_class.__name__
102 |
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 | self._init_args = copy.deepcopy(args)
106 | self._init_kwargs = copy.deepcopy(kwargs)
107 | assert orig_class.__name__ in orig_module.__dict__
108 | _check_pickleable(self.__reduce__())
109 |
110 | @property
111 | def init_args(self):
112 | return copy.deepcopy(self._init_args)
113 |
114 | @property
115 | def init_kwargs(self):
116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 |
118 | def __reduce__(self):
119 | fields = list(super().__reduce__())
120 | fields += [None] * max(3 - len(fields), 0)
121 | if fields[0] is not _reconstruct_persistent_obj:
122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 | fields[0] = _reconstruct_persistent_obj # reconstruct func
124 | fields[1] = (meta,) # reconstruct args
125 | fields[2] = None # state dict
126 | return tuple(fields)
127 |
128 | Decorator.__name__ = orig_class.__name__
129 | _decorators.add(Decorator)
130 | return Decorator
131 |
132 | #----------------------------------------------------------------------------
133 |
134 | def is_persistent(obj):
135 | r"""Test whether the given object or class is persistent, i.e.,
136 | whether it will save its source code when pickled.
137 | """
138 | try:
139 | if obj in _decorators:
140 | return True
141 | except TypeError:
142 | pass
143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 |
145 | #----------------------------------------------------------------------------
146 |
147 | def import_hook(hook):
148 | r"""Register an import hook that is called whenever a persistent object
149 | is being unpickled. A typical use case is to patch the pickled source
150 | code to avoid errors and inconsistencies when the API of some imported
151 | module has changed.
152 |
153 | The hook should have the following signature:
154 |
155 | hook(meta) -> modified meta
156 |
157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 |
159 | type: Type of the persistent object, e.g. `'class'`.
160 | version: Internal version number of `torch_utils.persistence`.
161 | module_src: Original source code of the Python module.
162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object.
164 |
165 | Example:
166 |
167 | @persistence.import_hook
168 | def wreck_my_network(meta):
169 | if meta.class_name == 'MyNetwork':
170 | print('MyNetwork is being imported. I will wreck it!')
171 | meta.module_src = meta.module_src.replace("True", "False")
172 | return meta
173 | """
174 | assert callable(hook)
175 | _import_hooks.append(hook)
176 |
177 | #----------------------------------------------------------------------------
178 |
179 | def _reconstruct_persistent_obj(meta):
180 | r"""Hook that is called internally by the `pickle` module to unpickle
181 | a persistent object.
182 | """
183 | meta = dnnlib.EasyDict(meta)
184 | meta.state = dnnlib.EasyDict(meta.state)
185 | for hook in _import_hooks:
186 | meta = hook(meta)
187 | assert meta is not None
188 |
189 | assert meta.version == _version
190 | module = _src_to_module(meta.module_src)
191 |
192 | assert meta.type == 'class'
193 | orig_class = module.__dict__[meta.class_name]
194 | decorator_class = persistent_class(orig_class)
195 | obj = decorator_class.__new__(decorator_class)
196 |
197 | setstate = getattr(obj, '__setstate__', None)
198 | if callable(setstate):
199 | setstate(meta.state) # pylint: disable=not-callable
200 | else:
201 | obj.__dict__.update(meta.state)
202 | return obj
203 |
204 | #----------------------------------------------------------------------------
205 |
206 | def _module_to_src(module):
207 | r"""Query the source code of a given Python module.
208 | """
209 | src = _module_to_src_dict.get(module, None)
210 | if src is None:
211 | src = inspect.getsource(module)
212 | _module_to_src_dict[module] = src
213 | _src_to_module_dict[src] = module
214 | return src
215 |
216 | def _src_to_module(src):
217 | r"""Get or create a Python module for the given source code.
218 | """
219 | module = _src_to_module_dict.get(src, None)
220 | if module is None:
221 | module_name = "_imported_module_" + uuid.uuid4().hex
222 | module = types.ModuleType(module_name)
223 | sys.modules[module_name] = module
224 | _module_to_src_dict[module] = src
225 | _src_to_module_dict[src] = module
226 | exec(src, module.__dict__) # pylint: disable=exec-used
227 | return module
228 |
229 | #----------------------------------------------------------------------------
230 |
231 | def _check_pickleable(obj):
232 | r"""Check that the given object is pickleable, raising an exception if
233 | it is not. This function is expected to be considerably more efficient
234 | than actually pickling the object.
235 | """
236 | def recurse(obj):
237 | if isinstance(obj, (list, tuple, set)):
238 | return [recurse(x) for x in obj]
239 | if isinstance(obj, dict):
240 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 | return None # Python primitive types are pickleable.
243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244 | return None # NumPy arrays and PyTorch tensors are pickleable.
245 | if is_persistent(obj):
246 | return None # Persistent objects are pickleable, by virtue of the constructor check.
247 | return obj
248 | with io.BytesIO() as f:
249 | pickle.dump(recurse(obj), f)
250 |
251 | #----------------------------------------------------------------------------
252 |
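A hedged round-trip sketch (the class name is hypothetical; run it as a script so `inspect.getsource()` can locate the module source). The pickle embeds this module's source, so the object unpickles even if the class definition is later edited:

    import pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class TinyNet(torch.nn.Module):  # hypothetical example class
        def __init__(self, num_features):
            super().__init__()
            self.fc = torch.nn.Linear(num_features, num_features)

    net = TinyNet(8)
    data = pickle.dumps(net)     # module source is embedded via __reduce__() above
    net2 = pickle.loads(data)    # rebuilt by _reconstruct_persistent_obj()
    print(net2.init_args)        # (8,)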
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import warnings
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | import traceback
17 |
18 | from .. import custom_ops
19 | from .. import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | activation_funcs = {
24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33 | }
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | _inited = False
38 | _plugin = None
39 | _null_tensor = torch.empty([0])
40 |
41 | def _init():
42 | global _inited, _plugin
43 | if not _inited:
44 | _inited = True
45 | sources = ['bias_act.cpp', 'bias_act.cu']
46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47 | try:
48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49 | except:
50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51 | return _plugin is not None
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56 | r"""Fused bias and activation function.
57 |
58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59 | and scales the result by `gain`. Each of the steps is optional. In most cases,
60 | the fused op is considerably more efficient than performing the same calculation
61 | using standard PyTorch ops. It supports first and second order gradients,
62 | but not third order gradients.
63 |
64 | Args:
65 | x: Input activation tensor. Can be of any shape.
66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67 | as `x`. The shape must be known, and it must match the dimension of `x`
68 | corresponding to `dim`.
69 | dim: The dimension in `x` corresponding to the elements of `b`.
70 | The value of `dim` is ignored if `b` is not specified.
71 | act: Name of the activation function to evaluate, or `"linear"` to disable.
72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73 | See `activation_funcs` for a full list. `None` is not allowed.
74 | alpha: Shape parameter for the activation function, or `None` to use the default.
75 | gain: Scaling factor for the output tensor, or `None` to use default.
76 | See `activation_funcs` for the default scaling of each activation function.
77 | If unsure, consider specifying 1.
78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79 | the clamping (default).
80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81 |
82 | Returns:
83 | Tensor of the same shape and datatype as `x`.
84 | """
85 | assert isinstance(x, torch.Tensor)
86 | assert impl in ['ref', 'cuda']
87 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90 |
91 | #----------------------------------------------------------------------------
92 |
93 | @misc.profiled_function
94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
95 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
96 | """
97 | assert isinstance(x, torch.Tensor)
98 | assert clamp is None or clamp >= 0
99 | spec = activation_funcs[act]
100 | alpha = float(alpha if alpha is not None else spec.def_alpha)
101 | gain = float(gain if gain is not None else spec.def_gain)
102 | clamp = float(clamp if clamp is not None else -1)
103 |
104 | # Add bias.
105 | if b is not None:
106 | assert isinstance(b, torch.Tensor) and b.ndim == 1
107 | assert 0 <= dim < x.ndim
108 | assert b.shape[0] == x.shape[dim]
109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
110 |
111 | # Evaluate activation function.
112 | alpha = float(alpha)
113 | x = spec.func(x, alpha=alpha)
114 |
115 | # Scale by gain.
116 | gain = float(gain)
117 | if gain != 1:
118 | x = x * gain
119 |
120 | # Clamp.
121 | if clamp >= 0:
122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
123 | return x
124 |
125 | #----------------------------------------------------------------------------
126 |
127 | _bias_act_cuda_cache = dict()
128 |
129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
130 | """Fast CUDA implementation of `bias_act()` using custom ops.
131 | """
132 | # Parse arguments.
133 | assert clamp is None or clamp >= 0
134 | spec = activation_funcs[act]
135 | alpha = float(alpha if alpha is not None else spec.def_alpha)
136 | gain = float(gain if gain is not None else spec.def_gain)
137 | clamp = float(clamp if clamp is not None else -1)
138 |
139 | # Lookup from cache.
140 | key = (dim, act, alpha, gain, clamp)
141 | if key in _bias_act_cuda_cache:
142 | return _bias_act_cuda_cache[key]
143 |
144 | # Forward op.
145 | class BiasActCuda(torch.autograd.Function):
146 | @staticmethod
147 | def forward(ctx, x, b): # pylint: disable=arguments-differ
148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
149 | x = x.contiguous(memory_format=ctx.memory_format)
150 | b = b.contiguous() if b is not None else _null_tensor
151 | y = x
152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
154 | ctx.save_for_backward(
155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157 | y if 'y' in spec.ref else _null_tensor)
158 | return y
159 |
160 | @staticmethod
161 | def backward(ctx, dy): # pylint: disable=arguments-differ
162 | dy = dy.contiguous(memory_format=ctx.memory_format)
163 | x, b, y = ctx.saved_tensors
164 | dx = None
165 | db = None
166 |
167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
168 | dx = dy
169 | if act != 'linear' or gain != 1 or clamp >= 0:
170 | dx = BiasActCudaGrad.apply(dy, x, b, y)
171 |
172 | if ctx.needs_input_grad[1]:
173 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
174 |
175 | return dx, db
176 |
177 | # Backward op.
178 | class BiasActCudaGrad(torch.autograd.Function):
179 | @staticmethod
180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
183 | ctx.save_for_backward(
184 | dy if spec.has_2nd_grad else _null_tensor,
185 | x, b, y)
186 | return dx
187 |
188 | @staticmethod
189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
191 | dy, x, b, y = ctx.saved_tensors
192 | d_dy = None
193 | d_x = None
194 | d_b = None
195 | d_y = None
196 |
197 | if ctx.needs_input_grad[0]:
198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
199 |
200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
202 |
203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
205 |
206 | return d_dy, d_x, d_b, d_y
207 |
208 | # Add to cache.
209 | _bias_act_cuda_cache[key] = BiasActCuda
210 | return BiasActCuda
211 |
212 | #----------------------------------------------------------------------------
213 |
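A hedged equivalence check for the reference path, using the `'lrelu'` defaults from `activation_funcs` above (alpha=0.2, gain=sqrt(2)). It runs on CPU, since `impl='ref'` bypasses the CUDA plugin:

    import numpy as np
    import torch
    from torch_utils.ops import bias_act

    x = torch.randn(2, 8, 4, 4)
    b = torch.randn(8)

    y = bias_act.bias_act(x, b=b, act='lrelu', impl='ref')
    manual = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2) * np.sqrt(2)
    print(torch.allclose(y, manual))  # expected: True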
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for reporting and collecting training statistics across
10 | multiple processes and devices. The interface is designed to minimize
11 | synchronization overhead as well as the amount of boilerplate in user
12 | code."""
13 |
14 | import re
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | from . import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
26 | _rank = 0 # Rank of the current process.
27 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
28 | _sync_called = False # Has _sync() been called yet?
29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def init_multiprocessing(rank, sync_device):
35 | r"""Initializes `torch_utils.training_stats` for collecting statistics
36 | across multiple processes.
37 |
38 | This function must be called after
39 | `torch.distributed.init_process_group()` and before `Collector.update()`.
40 | The call is not necessary if multi-process collection is not needed.
41 |
42 | Args:
43 | rank: Rank of the current process.
44 | sync_device: PyTorch device to use for inter-process
45 | communication, or None to disable multi-process
46 | collection. Typically `torch.device('cuda', rank)`.
47 | """
48 | global _rank, _sync_device
49 | assert not _sync_called
50 | _rank = rank
51 | _sync_device = sync_device
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | @misc.profiled_function
56 | def report(name, value):
57 | r"""Broadcasts the given set of scalars to all interested instances of
58 | `Collector`, across device and process boundaries.
59 |
60 | This function is expected to be extremely cheap and can be safely
61 | called from anywhere in the training loop, loss function, or inside a
62 | `torch.nn.Module`.
63 |
64 | Warning: The current implementation expects the set of unique names to
65 | be consistent across processes. Please make sure that `report()` is
66 | called at least once for each unique name by each process, and in the
67 | same order. If a given process has no scalars to broadcast, it can do
68 | `report(name, [])` (empty list).
69 |
70 | Args:
71 | name: Arbitrary string specifying the name of the statistic.
72 | Averages are accumulated separately for each unique name.
73 | value: Arbitrary set of scalars. Can be a list, tuple,
74 | NumPy array, PyTorch tensor, or Python scalar.
75 |
76 | Returns:
77 | The same `value` that was passed in.
78 | """
79 | if name not in _counters:
80 | _counters[name] = dict()
81 |
82 | elems = torch.as_tensor(value)
83 | if elems.numel() == 0:
84 | return value
85 |
86 | elems = elems.detach().flatten().to(_reduce_dtype)
87 | moments = torch.stack([
88 | torch.ones_like(elems).sum(),
89 | elems.sum(),
90 | elems.square().sum(),
91 | ])
92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
93 | moments = moments.to(_counter_dtype)
94 |
95 | device = moments.device
96 | if device not in _counters[name]:
97 | _counters[name][device] = torch.zeros_like(moments)
98 | _counters[name][device].add_(moments)
99 | return value
100 |
101 | #----------------------------------------------------------------------------
102 |
103 | def report0(name, value):
104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105 | but ignores any scalars provided by the other processes.
106 | See `report()` for further details.
107 | """
108 | report(name, value if _rank == 0 else [])
109 | return value
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | class Collector:
114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
115 | computes their long-term averages (mean and standard deviation) over
116 | user-defined periods of time.
117 |
118 | The averages are first collected into internal counters that are not
119 | directly visible to the user. They are then copied to the user-visible
120 | state as a result of calling `update()` and can then be queried using
121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122 | internal counters for the next round, so that the user-visible state
123 | effectively reflects averages collected between the last two calls to
124 | `update()`.
125 |
126 | Args:
127 | regex: Regular expression defining which statistics to
128 | collect. The default is to collect everything.
129 | keep_previous: Whether to retain the previous averages if no
130 | scalars were collected on a given round
131 | (default: True).
132 | """
133 | def __init__(self, regex='.*', keep_previous=True):
134 | self._regex = re.compile(regex)
135 | self._keep_previous = keep_previous
136 | self._cumulative = dict()
137 | self._moments = dict()
138 | self.update()
139 | self._moments.clear()
140 |
141 | def names(self):
142 | r"""Returns the names of all statistics broadcasted so far that
143 | match the regular expression specified at construction time.
144 | """
145 | return [name for name in _counters if self._regex.fullmatch(name)]
146 |
147 | def update(self):
148 | r"""Copies current values of the internal counters to the
149 | user-visible state and resets them for the next round.
150 |
151 | If `keep_previous=True` was specified at construction time, the
152 | operation is skipped for statistics that have received no scalars
153 | since the last update, retaining their previous averages.
154 |
155 | This method performs a number of GPU-to-CPU transfers and one
156 | `torch.distributed.all_reduce()`. It is intended to be called
157 | periodically in the main training loop, typically once every
158 | N training steps.
159 | """
160 | if not self._keep_previous:
161 | self._moments.clear()
162 | for name, cumulative in _sync(self.names()):
163 | if name not in self._cumulative:
164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165 | delta = cumulative - self._cumulative[name]
166 | self._cumulative[name].copy_(cumulative)
167 | if float(delta[0]) != 0:
168 | self._moments[name] = delta
169 |
170 | def _get_delta(self, name):
171 | r"""Returns the raw moments that were accumulated for the given
172 | statistic between the last two calls to `update()`, or zero if
173 | no scalars were collected.
174 | """
175 | assert self._regex.fullmatch(name)
176 | if name not in self._moments:
177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178 | return self._moments[name]
179 |
180 | def num(self, name):
181 | r"""Returns the number of scalars that were accumulated for the given
182 | statistic between the last two calls to `update()`, or zero if
183 | no scalars were collected.
184 | """
185 | delta = self._get_delta(name)
186 | return int(delta[0])
187 |
188 | def mean(self, name):
189 | r"""Returns the mean of the scalars that were accumulated for the
190 | given statistic between the last two calls to `update()`, or NaN if
191 | no scalars were collected.
192 | """
193 | delta = self._get_delta(name)
194 | if int(delta[0]) == 0:
195 | return float('nan')
196 | return float(delta[1] / delta[0])
197 |
198 | def std(self, name):
199 | r"""Returns the standard deviation of the scalars that were
200 | accumulated for the given statistic between the last two calls to
201 | `update()`, or NaN if no scalars were collected.
202 | """
203 | delta = self._get_delta(name)
204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205 | return float('nan')
206 | if int(delta[0]) == 1:
207 | return float(0)
208 | mean = float(delta[1] / delta[0])
209 | raw_var = float(delta[2] / delta[0])
210 | return np.sqrt(max(raw_var - np.square(mean), 0))
211 |
212 | def as_dict(self):
213 | r"""Returns the averages accumulated between the last two calls to
214 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
215 |
216 | dnnlib.EasyDict(
217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218 | ...
219 | )
220 | """
221 | stats = dnnlib.EasyDict()
222 | for name in self.names():
223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224 | return stats
225 |
226 | def __getitem__(self, name):
227 | r"""Convenience getter.
228 | `collector[name]` is a synonym for `collector.mean(name)`.
229 | """
230 | return self.mean(name)
231 |
232 | #----------------------------------------------------------------------------
233 |
234 | def _sync(names):
235 |     r"""Synchronizes the global cumulative counters across devices and
236 | processes. Called internally by `Collector.update()`.
237 | """
238 | if len(names) == 0:
239 | return []
240 | global _sync_called
241 | _sync_called = True
242 |
243 | # Collect deltas within current rank.
244 | deltas = []
245 | device = _sync_device if _sync_device is not None else torch.device('cpu')
246 | for name in names:
247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248 | for counter in _counters[name].values():
249 | delta.add_(counter.to(device))
250 | counter.copy_(torch.zeros_like(counter))
251 | deltas.append(delta)
252 | deltas = torch.stack(deltas)
253 |
254 | # Sum deltas across ranks.
255 | if _sync_device is not None:
256 | torch.distributed.all_reduce(deltas)
257 |
258 | # Update cumulative values.
259 | deltas = deltas.cpu()
260 | for idx, name in enumerate(names):
261 | if name not in _cumulative:
262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263 | _cumulative[name].add_(deltas[idx])
264 |
265 | # Return name-value pairs.
266 | return [(name, _cumulative[name]) for name in names]
267 |
268 | #----------------------------------------------------------------------------
269 |
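A minimal single-process usage sketch of the API above (`compute_loss` is a hypothetical stand-in for whatever produces the scalars; in multi-process training, `init_multiprocessing()` would be called first):

    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')
    for step in range(1000):
        loss = compute_loss()                          # hypothetical helper
        training_stats.report('Loss/G/loss', loss)     # cheap; safe to call anywhere
        if (step + 1) % 100 == 0:
            collector.update()                         # one all_reduce (multi-GPU) plus GPU-to-CPU copies
            print(step, collector.mean('Loss/G/loss'), collector.std('Loss/G/loss'))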
--------------------------------------------------------------------------------
/StyleGAN2/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import re
10 | import contextlib
11 | import numpy as np
12 | import torch
13 | import warnings
14 | import dnnlib
15 |
16 | #----------------------------------------------------------------------------
17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18 | # same constant is used multiple times.
19 |
20 | _constant_cache = dict()
21 |
22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23 | value = np.asarray(value)
24 | if shape is not None:
25 | shape = tuple(shape)
26 | if dtype is None:
27 | dtype = torch.get_default_dtype()
28 | if device is None:
29 | device = torch.device('cpu')
30 | if memory_format is None:
31 | memory_format = torch.contiguous_format
32 |
33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34 | tensor = _constant_cache.get(key, None)
35 | if tensor is None:
36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37 | if shape is not None:
38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39 | tensor = tensor.contiguous(memory_format=memory_format)
40 | _constant_cache[key] = tensor
41 | return tensor
42 |
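A small sketch of the caching behavior (a CUDA device is assumed purely for illustration):

    import torch
    from torch_utils import misc

    # Identical arguments produce the same cache key, so the CPU=>GPU copy happens only once.
    f = misc.constant([1, 3, 3, 1], dtype=torch.float32, device=torch.device('cuda', 0))
    g = misc.constant([1, 3, 3, 1], dtype=torch.float32, device=torch.device('cuda', 0))
    assert f is g   # same cached tensor object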
43 | #----------------------------------------------------------------------------
44 | # Replace NaN/Inf with specified numerical values.
45 |
46 | try:
47 | nan_to_num = torch.nan_to_num # 1.8.0a0
48 | except AttributeError:
49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50 | assert isinstance(input, torch.Tensor)
51 | if posinf is None:
52 | posinf = torch.finfo(input.dtype).max
53 | #posinf = 1
54 | if neginf is None:
55 | neginf = torch.finfo(input.dtype).min
56 | #neginf = 0
57 | assert nan == 0
58 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
59 |
60 | #----------------------------------------------------------------------------
61 | # Symbolic assert.
62 |
63 | try:
64 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
65 | except AttributeError:
66 | symbolic_assert = torch.Assert # 1.7.0
67 |
68 | #----------------------------------------------------------------------------
69 | # Context manager to suppress known warnings in torch.jit.trace().
70 |
71 | class suppress_tracer_warnings(warnings.catch_warnings):
72 | def __enter__(self):
73 | super().__enter__()
74 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
75 | return self
76 |
77 | #----------------------------------------------------------------------------
78 | # Assert that the shape of a tensor matches the given list of integers.
79 | # None indicates that the size of a dimension is allowed to vary.
80 | # Performs symbolic assertion when used in torch.jit.trace().
81 |
82 | def assert_shape(tensor, ref_shape):
83 | if tensor.ndim != len(ref_shape):
84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86 | if ref_size is None:
87 | pass
88 | elif isinstance(ref_size, torch.Tensor):
89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91 | elif isinstance(size, torch.Tensor):
92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94 | elif size != ref_size:
95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96 |
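For example (a sketch; the shapes are arbitrary):

    import torch
    from torch_utils import misc

    x = torch.zeros(8, 3, 64, 64)
    misc.assert_shape(x, [None, 3, 64, 64])   # batch dimension may vary -> passes
    misc.assert_shape(x, [8, 3, 32, 32])      # raises AssertionError for dimension 2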
97 | #----------------------------------------------------------------------------
98 | # Function decorator that calls torch.autograd.profiler.record_function().
99 |
100 | def profiled_function(fn):
101 | def decorator(*args, **kwargs):
102 | with torch.autograd.profiler.record_function(fn.__name__):
103 | return fn(*args, **kwargs)
104 | decorator.__name__ = fn.__name__
105 | return decorator
106 |
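Usage is a plain decorator; the function name then shows up as a named range in `torch.autograd.profiler` traces (`my_term` is a hypothetical example):

    from torch_utils import misc

    @misc.profiled_function
    def my_term(x):                 # hypothetical loss term
        return x.square().mean()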
107 | #----------------------------------------------------------------------------
108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
109 | # indefinitely, shuffling items as it goes.
110 |
111 | class InfiniteSampler(torch.utils.data.Sampler):
112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113 | assert len(dataset) > 0
114 | assert num_replicas > 0
115 | assert 0 <= rank < num_replicas
116 | assert 0 <= window_size <= 1
117 | super().__init__(dataset)
118 | self.dataset = dataset
119 | self.rank = rank
120 | self.num_replicas = num_replicas
121 | self.shuffle = shuffle
122 | self.seed = seed
123 | self.window_size = window_size
124 |
125 | def __iter__(self):
126 | order = np.arange(len(self.dataset))
127 | rnd = None
128 | window = 0
129 | if self.shuffle:
130 | rnd = np.random.RandomState(self.seed)
131 | rnd.shuffle(order)
132 | window = int(np.rint(order.size * self.window_size))
133 |
134 | idx = 0
135 | while True:
136 | i = idx % order.size
137 | if idx % self.num_replicas == self.rank:
138 | yield order[i]
139 | if window >= 2:
140 | j = (i - rnd.randint(window)) % order.size
141 | order[i], order[j] = order[j], order[i]
142 | idx += 1
143 |
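A sketch of how the sampler plugs into a DataLoader (`dataset` and the batch size are placeholders); each rank yields only every `num_replicas`-th index, so the streams are disjoint across processes:

    import torch

    sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=16)
    batch = next(iter(loader))   # never raises StopIteration; the stream is endless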
144 | #----------------------------------------------------------------------------
145 | # Utilities for operating with torch.nn.Module parameters and buffers.
146 |
147 | def params_and_buffers(module):
148 | assert isinstance(module, torch.nn.Module)
149 | return list(module.parameters()) + list(module.buffers())
150 |
151 | def named_params_and_buffers(module):
152 | assert isinstance(module, torch.nn.Module)
153 | return list(module.named_parameters()) + list(module.named_buffers())
154 |
155 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
156 | assert isinstance(src_module, torch.nn.Module)
157 | assert isinstance(dst_module, torch.nn.Module)
158 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
159 | for name, tensor in named_params_and_buffers(dst_module):
160 | assert (name in src_tensors) or (not require_all)
161 | if name in src_tensors:
162 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
163 |
164 | #----------------------------------------------------------------------------
165 | # Context manager for easily enabling/disabling DistributedDataParallel
166 | # synchronization.
167 |
168 | @contextlib.contextmanager
169 | def ddp_sync(module, sync):
170 | assert isinstance(module, torch.nn.Module)
171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172 | yield
173 | else:
174 | with module.no_sync():
175 | yield
176 |
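This mirrors the gradient-accumulation pattern in the training loop: suppress DDP's gradient all-reduce on all but the last accumulation round (a sketch; `module`, `inputs`, and `num_rounds` are placeholders):

    for round_idx in range(num_rounds):
        with misc.ddp_sync(module, sync=(round_idx == num_rounds - 1)):
            loss = module(inputs[round_idx]).mean()
            loss.backward()   # gradients are all-reduced only on the final round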
177 | #----------------------------------------------------------------------------
178 | # Check DistributedDataParallel consistency across processes.
179 |
180 | def check_ddp_consistency(module, ignore_regex=None):
181 | assert isinstance(module, torch.nn.Module)
182 | for name, tensor in named_params_and_buffers(module):
183 | fullname = type(module).__name__ + '.' + name
184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185 | continue
186 | tensor = tensor.detach()
187 | other = tensor.clone()
188 | torch.distributed.broadcast(tensor=other, src=0)
189 | #assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
190 |
191 | #----------------------------------------------------------------------------
192 | # Print summary table of module hierarchy.
193 |
194 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
195 | assert isinstance(module, torch.nn.Module)
196 | assert not isinstance(module, torch.jit.ScriptModule)
197 | assert isinstance(inputs, (tuple, list))
198 |
199 | # Register hooks.
200 | entries = []
201 | nesting = [0]
202 | def pre_hook(_mod, _inputs):
203 | nesting[0] += 1
204 | def post_hook(mod, _inputs, outputs):
205 | nesting[0] -= 1
206 | if nesting[0] <= max_nesting:
207 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
208 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
209 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
210 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
211 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
212 |
213 | # Run module.
214 | outputs = module(*inputs)
215 | for hook in hooks:
216 | hook.remove()
217 |
218 | # Identify unique outputs, parameters, and buffers.
219 | tensors_seen = set()
220 | for e in entries:
221 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
222 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
223 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
224 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
225 |
226 | # Filter out redundant entries.
227 | if skip_redundant:
228 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
229 |
230 | # Construct table.
231 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
232 | rows += [['---'] * len(rows[0])]
233 | param_total = 0
234 | buffer_total = 0
235 | submodule_names = {mod: name for name, mod in module.named_modules()}
236 | for e in entries:
237 | name = '' if e.mod is module else submodule_names[e.mod]
238 | param_size = sum(t.numel() for t in e.unique_params)
239 | buffer_size = sum(t.numel() for t in e.unique_buffers)
240 |         output_shapes = [str(list(t.shape)) for t in e.outputs]
241 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
242 | rows += [[
243 | name + (':0' if len(e.outputs) >= 2 else ''),
244 | str(param_size) if param_size else '-',
245 | str(buffer_size) if buffer_size else '-',
246 | (output_shapes + ['-'])[0],
247 | (output_dtypes + ['-'])[0],
248 | ]]
249 | for idx in range(1, len(e.outputs)):
250 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
251 | param_total += param_size
252 | buffer_total += buffer_size
253 | rows += [['---'] * len(rows[0])]
254 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
255 |
256 | # Print table.
257 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
258 | print()
259 | for row in rows:
260 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
261 | print()
262 | return outputs
263 |
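This mirrors how the training loop prints network summaries at startup (a sketch; `G`, `batch_gpu`, and `device` are placeholders for a generator and its inputs):

    z = torch.empty([batch_gpu, G.z_dim], device=device)
    c = torch.empty([batch_gpu, G.c_dim], device=device)
    img = misc.print_module_summary(G, [z, c])   # prints the table and returns G(z, c)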
264 | #----------------------------------------------------------------------------
265 |
--------------------------------------------------------------------------------
/StyleGAN2/metrics/metric_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import time
11 | import hashlib
12 | import pickle
13 | import copy
14 | import uuid
15 | import numpy as np
16 | import torch
17 | import dnnlib
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | class MetricOptions:
22 | def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
23 | assert 0 <= rank < num_gpus
24 | self.G = G
25 | self.G_kwargs = dnnlib.EasyDict(G_kwargs)
26 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
27 | self.num_gpus = num_gpus
28 | self.rank = rank
29 | self.device = device if device is not None else torch.device('cuda', rank)
30 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
31 | self.cache = cache
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | _feature_detector_cache = dict()
36 |
37 | def get_feature_detector_name(url):
38 | return os.path.splitext(url.split('/')[-1])[0]
39 |
40 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
41 | assert 0 <= rank < num_gpus
42 | key = (url, device)
43 | if key not in _feature_detector_cache:
44 | is_leader = (rank == 0)
45 | if not is_leader and num_gpus > 1:
46 | torch.distributed.barrier() # leader goes first
47 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
48 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
49 | if is_leader and num_gpus > 1:
50 | torch.distributed.barrier() # others follow
51 | return _feature_detector_cache[key]
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | class FeatureStats:
56 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
57 | self.capture_all = capture_all
58 | self.capture_mean_cov = capture_mean_cov
59 | self.max_items = max_items
60 | self.num_items = 0
61 | self.num_features = None
62 | self.all_features = None
63 | self.raw_mean = None
64 | self.raw_cov = None
65 |
66 | def set_num_features(self, num_features):
67 | if self.num_features is not None:
68 | assert num_features == self.num_features
69 | else:
70 | self.num_features = num_features
71 | self.all_features = []
72 | self.raw_mean = np.zeros([num_features], dtype=np.float64)
73 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
74 |
75 | def is_full(self):
76 | return (self.max_items is not None) and (self.num_items >= self.max_items)
77 |
78 | def append(self, x):
79 | x = np.asarray(x, dtype=np.float32)
80 | assert x.ndim == 2
81 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
82 | if self.num_items >= self.max_items:
83 | return
84 | x = x[:self.max_items - self.num_items]
85 |
86 | self.set_num_features(x.shape[1])
87 | self.num_items += x.shape[0]
88 | if self.capture_all:
89 | self.all_features.append(x)
90 | if self.capture_mean_cov:
91 | x64 = x.astype(np.float64)
92 | self.raw_mean += x64.sum(axis=0)
93 | self.raw_cov += x64.T @ x64
94 |
95 | def append_torch(self, x, num_gpus=1, rank=0):
96 | assert isinstance(x, torch.Tensor) and x.ndim == 2
97 | assert 0 <= rank < num_gpus
98 | if num_gpus > 1:
99 | ys = []
100 | for src in range(num_gpus):
101 | y = x.clone()
102 | torch.distributed.broadcast(y, src=src)
103 | ys.append(y)
104 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
105 | self.append(x.cpu().numpy())
106 |
107 | def get_all(self):
108 | assert self.capture_all
109 | return np.concatenate(self.all_features, axis=0)
110 |
111 | def get_all_torch(self):
112 | return torch.from_numpy(self.get_all())
113 |
114 | def get_mean_cov(self):
115 | assert self.capture_mean_cov
116 | mean = self.raw_mean / self.num_items
117 | cov = self.raw_cov / self.num_items
118 | cov = cov - np.outer(mean, mean)
119 | return mean, cov
120 |
121 | def save(self, pkl_file):
122 | with open(pkl_file, 'wb') as f:
123 | pickle.dump(self.__dict__, f)
124 |
125 | @staticmethod
126 | def load(pkl_file):
127 | with open(pkl_file, 'rb') as f:
128 | s = dnnlib.EasyDict(pickle.load(f))
129 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
130 | obj.__dict__.update(s)
131 | return obj
132 |
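`get_mean_cov()` relies on the identity Cov[x] = E[x xᵀ] − E[x] E[x]ᵀ, with both moments accumulated in float64. A small self-check sketch:

    import numpy as np

    feats = np.random.randn(1000, 8).astype(np.float32)
    stats = FeatureStats(capture_mean_cov=True)
    stats.append(feats)
    mean, cov = stats.get_mean_cov()
    assert np.allclose(mean, feats.mean(axis=0), atol=1e-4)
    assert np.allclose(cov, np.cov(feats, rowvar=False, bias=True), atol=1e-4)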
133 | #----------------------------------------------------------------------------
134 |
135 | class ProgressMonitor:
136 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
137 | self.tag = tag
138 | self.num_items = num_items
139 | self.verbose = verbose
140 | self.flush_interval = flush_interval
141 | self.progress_fn = progress_fn
142 | self.pfn_lo = pfn_lo
143 | self.pfn_hi = pfn_hi
144 | self.pfn_total = pfn_total
145 | self.start_time = time.time()
146 | self.batch_time = self.start_time
147 | self.batch_items = 0
148 | if self.progress_fn is not None:
149 | self.progress_fn(self.pfn_lo, self.pfn_total)
150 |
151 | def update(self, cur_items):
152 | assert (self.num_items is None) or (cur_items <= self.num_items)
153 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
154 | return
155 | cur_time = time.time()
156 | total_time = cur_time - self.start_time
157 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
158 | if (self.verbose) and (self.tag is not None):
159 | print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
160 | self.batch_time = cur_time
161 | self.batch_items = cur_items
162 |
163 | if (self.progress_fn is not None) and (self.num_items is not None):
164 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
165 |
166 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
167 | return ProgressMonitor(
168 | tag = tag,
169 | num_items = num_items,
170 | flush_interval = flush_interval,
171 | verbose = self.verbose,
172 | progress_fn = self.progress_fn,
173 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
174 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
175 | pfn_total = self.pfn_total,
176 | )
177 |
178 | #----------------------------------------------------------------------------
179 |
180 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
181 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
182 | if data_loader_kwargs is None:
183 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
184 |
185 | # Try to lookup from cache.
186 | cache_file = None
187 | if opts.cache:
188 | # Choose cache file name.
189 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
190 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
191 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
192 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
193 |
194 | # Check if the file exists (all processes must agree).
195 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False
196 | if opts.num_gpus > 1:
197 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
198 | torch.distributed.broadcast(tensor=flag, src=0)
199 | flag = (float(flag.cpu()) != 0)
200 |
201 | # Load.
202 | if flag:
203 | return FeatureStats.load(cache_file)
204 |
205 | # Initialize.
206 | num_items = len(dataset)
207 | if max_items is not None:
208 | num_items = min(num_items, max_items)
209 | stats = FeatureStats(max_items=num_items, **stats_kwargs)
210 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
211 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
212 |
213 | # Main loop.
214 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
215 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
216 | if images.shape[1] == 1:
217 | images = images.repeat([1, 3, 1, 1])
218 | features = detector(images.to(opts.device), **detector_kwargs)
219 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
220 | progress.update(stats.num_items)
221 |
222 | # Save to cache.
223 | if cache_file is not None and opts.rank == 0:
224 | os.makedirs(os.path.dirname(cache_file), exist_ok=True)
225 | temp_file = cache_file + '.' + uuid.uuid4().hex
226 | stats.save(temp_file)
227 | os.replace(temp_file, cache_file) # atomic
228 | return stats
229 |
230 | #----------------------------------------------------------------------------
231 |
232 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
233 | if batch_gen is None:
234 | batch_gen = min(batch_size, 4)
235 | assert batch_size % batch_gen == 0
236 |
237 | # Setup generator and load labels.
238 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
239 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
240 |
241 | # Image generation func.
242 | def run_generator(z, c):
243 | img = G(z=z, c=c, **opts.G_kwargs)
244 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
245 | return img
246 |
247 | # JIT.
248 | if jit:
249 | z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
250 | c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
251 | run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
252 |
253 | # Initialize.
254 | stats = FeatureStats(**stats_kwargs)
255 | assert stats.max_items is not None
256 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
257 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
258 |
259 | # Main loop.
260 | while not stats.is_full():
261 | images = []
262 | for _i in range(batch_size // batch_gen):
263 | z = torch.randn([batch_gen, G.z_dim], device=opts.device)
264 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
265 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
266 | images.append(run_generator(z, c))
267 | images = torch.cat(images)
268 | if images.shape[1] == 1:
269 | images = images.repeat([1, 3, 1, 1])
270 | features = detector(images, **detector_kwargs)
271 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
272 | progress.update(stats.num_items)
273 | return stats
274 |
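The two helpers above pair up in the metrics; for example, FID (mirroring `frechet_inception_distance.py`, where `max_real` and `num_gen` are parameters) computes matching mean/covariance stats for the dataset and the generator with the same detector:

    mu_real, sigma_real = compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()

    mu_gen, sigma_gen = compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()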
275 | #----------------------------------------------------------------------------
276 |
--------------------------------------------------------------------------------
/StyleGAN2/training/loss_interp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import torch
11 | from torch_utils import training_stats
12 | from torch_utils import misc
13 | from torch_utils.ops import conv2d_gradfix
14 |
15 | from torch import nn
16 | import torch.nn.functional as F
17 | from math import exp
18 | from sklearn.manifold import MDS
19 | import random
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | class Loss:
24 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
25 | raise NotImplementedError()
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | class StyleGAN2Loss(Loss):
30 | def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
31 | super().__init__()
32 | self.device = device
33 | self.G_mapping = G_mapping
34 | self.G_synthesis = G_synthesis
35 | self.D = D
36 | self.augment_pipe = augment_pipe
37 | self.style_mixing_prob = style_mixing_prob
38 | self.r1_gamma = r1_gamma
39 | self.pl_batch_shrink = pl_batch_shrink
40 | self.pl_decay = pl_decay
41 | self.pl_weight = pl_weight
42 | self.pl_mean = torch.zeros([], device=device)
43 |
44 | def run_G(self, z, c, sync):
45 | with misc.ddp_sync(self.G_mapping, sync):
46 | ws = self.G_mapping(z, c)
47 | if self.style_mixing_prob > 0:
48 | with torch.autograd.profiler.record_function('style_mixing'):
49 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
50 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
51 | ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
52 | with misc.ddp_sync(self.G_synthesis, sync):
53 | img = self.G_synthesis(ws)
54 | return img, ws
55 |
56 | def run_D(self, img, c, sync):
57 | if self.augment_pipe is not None:
58 | img = self.augment_pipe(img)
59 | with misc.ddp_sync(self.D, sync):
60 | logits = self.D(img, c)
61 | return logits
62 |
63 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
64 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
65 | do_Gmain = (phase in ['Gmain', 'Gboth'])
66 | do_Dmain = (phase in ['Dmain', 'Dboth'])
67 | do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
68 | do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
69 |
70 | augment_prob, kk = 0, 0
71 | # Gmain: Maximize logits for generated images.
72 | if do_Gmain:
73 | with torch.autograd.profiler.record_function('Gmain_forward'):
74 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
75 | gen_logits = self.run_D(gen_img, gen_c, sync=False)
76 | training_stats.report('Loss/scores/fake', gen_logits)
77 | training_stats.report('Loss/signs/fake', gen_logits.sign())
78 | #loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
79 | augment_prob, kk = dynamic_prob(gen_logits)
80 | batch_size = gen_logits.size()[0]
81 | gen_logits_aug = near_interp(gen_logits, kk, augment_prob)
82 | #loss_Gmain = torch.nn.functional.softplus(-gen_logits_aug).mean()
83 | loss_Gmain = torch.nn.functional.relu(-gen_logits_aug).mean()
84 | training_stats.report('Loss/G/loss', loss_Gmain)
85 | with torch.autograd.profiler.record_function('Gmain_backward'):
86 | loss_Gmain.mean().mul(gain).backward()
87 |
88 | # Gpl: Apply path length regularization.
89 | if do_Gpl:
90 | with torch.autograd.profiler.record_function('Gpl_forward'):
91 | batch_size = gen_z.shape[0] // self.pl_batch_shrink
92 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
93 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
94 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
95 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
96 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
97 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
98 | self.pl_mean.copy_(pl_mean.detach())
99 | pl_penalty = (pl_lengths - pl_mean).square()
100 | training_stats.report('Loss/pl_penalty', pl_penalty)
101 | loss_Gpl = pl_penalty * self.pl_weight
102 | training_stats.report('Loss/G/reg', loss_Gpl)
103 | with torch.autograd.profiler.record_function('Gpl_backward'):
104 | (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
105 |
106 | # Dmain: Minimize logits for generated images.
107 | loss_Dgen = 0
108 | if do_Dmain:
109 | with torch.autograd.profiler.record_function('Dgen_forward'):
110 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
111 | gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
112 | training_stats.report('Loss/scores/fake', gen_logits)
113 | training_stats.report('Loss/signs/fake', gen_logits.sign())
114 | #loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
115 | augment_prob, kk = dynamic_prob(gen_logits)
116 | batch_size = gen_logits.size()[0]
117 | gen_logits_aug = near_interp(gen_logits, kk, augment_prob)
118 | #loss_Dgen = torch.nn.functional.softplus(gen_logits_aug).mean()
119 | loss_Dgen = torch.nn.functional.relu(1 + gen_logits_aug).mean()
120 | with torch.autograd.profiler.record_function('Dgen_backward'):
121 | loss_Dgen.mean().mul(gain).backward()
122 |
123 | # Dmain: Maximize logits for real images.
124 | # Dr1: Apply R1 regularization.
125 | if do_Dmain or do_Dr1:
126 | name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
127 | with torch.autograd.profiler.record_function(name + '_forward'):
128 | real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
129 | real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
130 | training_stats.report('Loss/scores/real', real_logits)
131 | training_stats.report('Loss/signs/real', real_logits.sign())
132 |
133 | loss_Dreal = 0
134 | if do_Dmain:
135 | #loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
136 | augment_prob, kk = dynamic_prob(real_logits)
137 | batch_size = real_logits.size()[0]
138 | #print(augment_prob, kk)
139 | real_logits_aug = near_interp(real_logits, kk, augment_prob)
140 | #loss_Dreal = torch.nn.functional.softplus(-real_logits_aug).mean()
141 | loss_Dreal = torch.nn.functional.relu(1-real_logits_aug).mean()
142 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
143 |
144 | loss_Dr1 = 0
145 | if do_Dr1:
146 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
147 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
148 | r1_penalty = r1_grads.square().sum([1,2,3])
149 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
150 | training_stats.report('Loss/r1_penalty', r1_penalty)
151 | training_stats.report('Loss/D/reg', loss_Dr1)
152 |
153 | with torch.autograd.profiler.record_function(name + '_backward'):
154 | #(real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
155 | (real_logits.mean() * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
156 |
157 | return augment_prob, kk
158 |
159 |
160 |
161 | ### Adaptive Feature Interpolation----------------------------------------------------------------------------
162 |
163 | def near_interp(embeddings, k, augment_prob):
164 |     if k <= 1 or augment_prob <= 0:  # nothing to interpolate (covers the k = 0 fallback from dynamic_prob)
165 | return embeddings
166 |
167 | k = min(k, embeddings.size()[0])
168 |
169 | pd = pairwise_distances(embeddings, embeddings)
170 | pd = pd/pd.max()
171 | pd_s = (1 / (1+pd))
172 |
173 | k_smallest = torch.topk(pd, k, largest=False).indices # shape: batch_size x k
174 |
175 | t = 1
176 | alpha = torch.ones(k, device=embeddings.device)
177 | inner_embeddings = []
178 | for row in k_smallest:
179 | for i in range(k):
180 | alpha[i] = pd_s[row[0],row[i]]**t
181 |
182 | p = torch.distributions.dirichlet.Dirichlet(alpha).sample().to(embeddings.device)
183 | # print(p)
184 | inner_pts = torch.matmul(p.reshape((1,-1)),embeddings.index_select(0,row))
185 | inner_embeddings.append(F.normalize(inner_pts))
186 |
187 | batch_size = embeddings.size()[0]
188 | out_embeddings = []
189 | for idx in range(batch_size):
190 | p = random.random()
191 | if p < augment_prob:
192 | out_embeddings.append(inner_embeddings[idx])
193 | else:
194 | out_embeddings.append(embeddings[idx,:].unsqueeze(0))
195 |
196 | return torch.stack(out_embeddings).reshape((batch_size,-1))
197 |
198 |
199 | def dynamic_prob(embeddings):
200 | embeddings = F.normalize(embeddings)
201 | batch_size = embeddings.size()[0]
202 |
203 | D = pairwise_distances(embeddings, embeddings)
204 | D = D.detach().cpu().numpy()
205 | D = D / np.amax(D)
206 |
207 | #l_sorted = cmdscale(D)
208 | l_sorted = eigen_mds(D)
209 |
210 |     k = batch_size - next((i for i, v in enumerate(l_sorted) if v < 0.1 * l_sorted[0]), batch_size)  # fall back to k = 0 if no eigenvalue drops below 10% of the largest
211 | p = (k-1) / batch_size
212 |
213 | #k = 2
214 | #p = 0.9
215 |
216 | return p, k
217 |
218 |
219 | def cmdscale(D):
220 | """
221 | Classical multidimensional scaling (MDS)
222 |
223 | Parameters
224 | ----------
225 | D : (n, n) array
226 | Symmetric distance matrix.
227 |
228 | Returns
229 | -------
230 |     e : (n,) array
231 |         Eigenvalues of B, sorted in descending order.
232 |
233 |         Note: the configuration matrix Y (each column one dimension,
234 |         determined only up to an overall sign, i.e. a reflection) is
235 |         left commented out below; `dynamic_prob()` only needs the
236 |         eigenvalue spectrum, so just the sorted eigenvalues are
237 |         returned.
238 |
239 | """
240 |
241 | # Number of points
242 | n = len(D)
243 |
244 | # Centering matrix
245 | H = np.eye(n) - np.ones((n, n))/n
246 |
247 | # YY^T
248 | B = -H.dot(D**2).dot(H)/2
249 |
250 | # Diagonalize
251 | evals, evecs = np.linalg.eigh(B)
252 |
253 | # Sort by eigenvalue in descending order
254 | idx = np.argsort(evals)[::-1]
255 | evals = evals[idx]
256 | evecs = evecs[:,idx]
257 |
258 | # Compute the coordinates using positive-eigenvalued components only
259 | # w, = np.where(evals > 0)
260 | # L = np.diag(np.sqrt(evals[w]))
261 | # V = evecs[:,w]
262 | # Y = V.dot(L)
263 |
264 | return np.sort(evals)[::-1]
265 |
266 |
267 | def eigen_mds(pd):
268 | mds = MDS(n_components=len(pd), dissimilarity='precomputed')
269 | pts = mds.fit_transform(pd)
270 |
271 | _,l_sorted,_ = np.linalg.svd(pts)
272 |
273 | return l_sorted
274 |
275 |
276 | def pairwise_distances(x, y=None):
277 |     '''
278 |     Input: x is a Nxd matrix
279 |            y is an optional Mxd matrix
280 |     Output: dist is a NxM matrix where dist[i,j] is the Euclidean distance
281 |             between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||.
282 |             If y is not given then 'y=x' is used.
283 |     '''
284 | x_norm = (x**2).sum(1).view(-1, 1)
285 | if y is not None:
286 | y_t = torch.transpose(y, 0, 1)
287 | y_norm = (y**2).sum(1).view(1, -1)
288 | else:
289 | y_t = torch.transpose(x, 0, 1)
290 | y_norm = x_norm.view(1, -1)
291 |
292 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
293 | # Ensure diagonal is zero if x=y
294 | # if y is None:
295 | # dist = dist - torch.diag(dist.diag)
296 | return torch.sqrt(torch.clamp(dist, 0.0, np.inf))
297 |
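A shape-level sketch of the interpolation path on dummy feature vectors (sizes arbitrary; in the loss above the same path is applied to discriminator logits):

    import torch

    emb = torch.randn(16, 4)
    p, k = dynamic_prob(emb)            # interpolation probability and neighborhood size from the eigenvalue spectrum
    emb_aug = near_interp(emb, k, p)    # each row is replaced by a Dirichlet-weighted mix of its k neighbors with probability p
    assert emb_aug.shape == emb.shape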
298 | #----------------------------------------------------------------------------
299 |
--------------------------------------------------------------------------------