├── test_nvcc
├── .gitignore
├── imgs
│   └── eccv20_architecture.png
├── docs
│   ├── stylegan2-teaser-1024x256.png
│   ├── stylegan2-training-curves.png
│   ├── versions.html
│   └── license.html
├── metrics
│   ├── __init__.py
│   ├── metric_defaults.py
│   ├── inception_score.py
│   ├── frechet_inception_distance.py
│   ├── perceptual_path_length.py
│   ├── metric_base.py
│   ├── linear_separability.py
│   └── precision_recall.py
├── training
│   ├── __init__.py
│   ├── misc.py
│   └── dataset.py
├── dnnlib
│   ├── tflib
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── fused_bias_act.cu
│   │   │   └── fused_bias_act.py
│   │   ├── __init__.py
│   │   ├── custom_ops.py
│   │   ├── autosummary.py
│   │   └── tfutil.py
│   ├── submission
│   │   ├── internal
│   │   │   ├── __init__.py
│   │   │   └── local.py
│   │   ├── __init__.py
│   │   └── run_context.py
│   └── __init__.py
├── run_generate_grid.sh
├── run_pair_imgs.sh
├── Dockerfile
├── test_nvcc.cu
├── train_run.sh
├── grab_traversals.py
├── run_metrics.py
├── README.md
├── LICENSE.txt
├── run_generator_vc.py
├── run_pair_generator_vc.py
├── run_projector.py
├── pretrained_networks.py
├── run_generator.py
├── run_training.py
├── projector.py
├── projector_vc.py
└── run_unsupervised_acc.py

--------------------------------------------------------------------------------
/test_nvcc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuxinqimac/stylegan2vp/HEAD/test_nvcc
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | .*.swp
3 | *.swo
4 | .*.swo
5 | *.pkl
6 | *.pyc
7 | .*.pyc
8 | *.so
9 | .*.so
10 |
--------------------------------------------------------------------------------
/imgs/eccv20_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuxinqimac/stylegan2vp/HEAD/imgs/eccv20_architecture.png
--------------------------------------------------------------------------------
/docs/stylegan2-teaser-1024x256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuxinqimac/stylegan2vp/HEAD/docs/stylegan2-teaser-1024x256.png
--------------------------------------------------------------------------------
/docs/stylegan2-training-curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuxinqimac/stylegan2vp/HEAD/docs/stylegan2-training-curves.png
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | # empty
8 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | # empty 8 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | # empty 8 | -------------------------------------------------------------------------------- /dnnlib/submission/internal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from . import local 8 | -------------------------------------------------------------------------------- /dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from . import run_context 8 | from . import submit 9 | -------------------------------------------------------------------------------- /run_generate_grid.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 \ 2 | python run_generator_vc.py \ 3 | --network_pkl /mnt/hdd/repo_results/stylegan2vp/test/test.pkl \ 4 | --n_imgs 100 \ 5 | --n_discrete 0 \ 6 | --n_continuous 30 \ 7 | --model_type vc_gan_with_vc_head \ 8 | --result-dir /mnt/hdd/repo_results/stylegan2vp/test/generation 9 | -------------------------------------------------------------------------------- /run_pair_imgs.sh: -------------------------------------------------------------------------------- 1 | python run_pair_generator_vc.py \ 2 | --network_pkl /mnt/hdd/repo_results/stylegan2vp/test/test.pkl \ 3 | --n_imgs 10000 \ 4 | --n_discrete 0 \ 5 | --n_continuous 30 \ 6 | --batch_size 100 \ 7 | --latent_type onedim \ 8 | --model_type vc_gan_with_vc_head \ 9 | --result-dir /mnt/hdd/repo_results/stylegan2vp/test/pair_dataset 10 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | FROM tensorflow/tensorflow:1.15.0-gpu-py3 8 | 9 | RUN pip install scipy==1.3.3 10 | RUN pip install requests==2.22.0 11 | RUN pip install Pillow==6.2.1 12 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | from . import autosummary
8 | from . import network
9 | from . import optimizer
10 | from . import tfutil
11 | from . import custom_ops
12 |
13 | from .tfutil import *
14 | from .network import Network
15 |
16 | from .optimizer import Optimizer
17 |
18 | from .custom_ops import get_plugin
19 |
--------------------------------------------------------------------------------
/test_nvcc.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <cstdio>
8 |
9 | void checkCudaError(cudaError_t err)
10 | {
11 |     if (err != cudaSuccess)
12 |     {
13 |         printf("%s: %s\n", cudaGetErrorName(err), cudaGetErrorString(err));
14 |         exit(1);
15 |     }
16 | }
17 |
18 | __global__ void cudaKernel(void)
19 | {
20 |     printf("GPU says hello.\n");
21 | }
22 |
23 | int main(void)
24 | {
25 |     printf("CPU says hello.\n");
26 |     checkCudaError(cudaLaunchKernel((void*)cudaKernel, 1, 1, NULL, 0, NULL));
27 |     checkCudaError(cudaDeviceSynchronize());
28 |     return 0;
29 | }
30 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | from . import submission
8 |
9 | from .submission.run_context import RunContext
10 |
11 | from .submission.submit import SubmitTarget
12 | from .submission.submit import PathType
13 | from .submission.submit import SubmitConfig
14 | from .submission.submit import submit_run
15 | from .submission.submit import get_path_from_template
16 | from .submission.submit import convert_path
17 | from .submission.submit import make_run_dir_path
18 |
19 | from .util import EasyDict
20 |
21 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
22 |
--------------------------------------------------------------------------------
/dnnlib/submission/internal/local.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | class TargetOptions(): 8 | def __init__(self): 9 | self.do_not_copy_source_files = False 10 | 11 | class Target(): 12 | def __init__(self): 13 | pass 14 | 15 | def finalize_submit_config(self, submit_config, host_run_dir): 16 | print ('Local submit ', end='', flush=True) 17 | submit_config.run_dir = host_run_dir 18 | 19 | def submit(self, submit_config, host_run_dir): 20 | from ..submit import run_wrapper, convert_path 21 | print('- run_dir: %s' % convert_path(submit_config.run_dir), flush=True) 22 | return run_wrapper(submit_config) 23 | -------------------------------------------------------------------------------- /train_run.sh: -------------------------------------------------------------------------------- 1 | python run_training_vc.py \ 2 | --result-dir /mnt/hdd/repo_results/stylegan2vp/results_vc \ 3 | --data-dir /mnt/hdd/Datasets/CelebA_dataset \ 4 | --dataset celeba_tfr \ 5 | --num-gpus 2 \ 6 | --model_type vc_gan_with_vc_head \ 7 | --C_lambda 0.01 \ 8 | --random_eps True \ 9 | --latent_type uniform \ 10 | --delta_type onedim \ 11 | --module_list '[Conv-up-1, C_global-10, Conv-id-2, Noise-2, Conv-up-1, C_global-10, Conv-id-1, Noise-2, Conv-up-1, C_global-5, Conv-id-2, Noise-2, Conv-id-1, Noise-2, Conv-up-1, C_global-5, Conv-id-1, Noise-2, Conv-id-2, Conv-up-1, Conv-id-1]' 12 | 13 | #python run_training_vc.py \ 14 | #--result-dir /mnt/hdd/repo_results/stylegan2vp/results_info \ 15 | #--data-dir /mnt/hdd/Datasets/CelebA_dataset \ 16 | #--dataset celeba_tfr \ 17 | #--num-gpus 2 \ 18 | #--model_type info_gan \ 19 | #--C_lambda 0.01 \ 20 | #--latent_type uniform \ 21 | #--module_list '[Conv-up-1, C_global-10, Conv-id-2, Noise-2, Conv-up-1, C_global-10, Conv-id-1, Noise-2, Conv-up-1, C_global-5, Conv-id-2, Noise-2, Conv-id-1, Noise-2, Conv-up-1, C_global-5, Conv-id-1, Noise-2, Conv-id-2, Conv-up-1, Conv-id-1]' 22 | -------------------------------------------------------------------------------- /docs/versions.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | StyleGAN versions 7 | 8 | 45 | 46 | 47 | 48 |

StyleGAN2

Original StyleGAN
56 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /grab_traversals.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | # >.>.>.>.>.>.>.>.>.>.>.>.>.>.>.>. 5 | # Licensed under the Apache License, Version 2.0 (the "License") 6 | # You may obtain a copy of the License at 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # --- File Name: grab_traversals.py 10 | # --- Creation Date: 02-03-2020 11 | # --- Last Modified: Mon 02 Mar 2020 03:01:03 AEDT 12 | # --- Author: Xinqi Zhu 13 | # .<.<.<.<.<.<.<.<.<.<.<.<.<.<.<.< 14 | """ 15 | Grab a row of traversals from generated grids. 16 | """ 17 | 18 | import argparse 19 | import os 20 | import glob 21 | import pdb 22 | from PIL import Image 23 | 24 | 25 | def crop_images(args): 26 | if not os.path.exists(args.result_dir): 27 | os.makedirs(args.result_dir) 28 | source_imgs_path = sorted(glob.glob(os.path.join(args.source_dir, 'img_*.png'))) 29 | for source_path in source_imgs_path[:args.n_used_sources]: 30 | img_name = os.path.basename(source_path) 31 | img = Image.open(source_path) 32 | img_x, img_y = img.size 33 | y_s = (args.row - 1) * args.crop_h 34 | crop_row = img.crop((0, y_s, img_x, y_s + args.crop_h)) 35 | save_path = os.path.join(args.result_dir, img_name) 36 | crop_row.save(save_path) 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser(description='Project description.') 41 | parser.add_argument('--result_dir', 42 | help='Results directory.', 43 | type=str, 44 | default='/mnt/hdd/repo_results/test') 45 | parser.add_argument('--source_dir', 46 | help='Grid directory.', 47 | type=str, 48 | default='/mnt/hdd/Datasets/test_data') 49 | parser.add_argument('--row', 50 | help='Which row to grab. Starting from 1.', 51 | type=int) 52 | parser.add_argument('--crop_h', type=int, default=128) 53 | parser.add_argument('--crop_w', type=int, default=128) 54 | parser.add_argument('--n_used_sources', type=int, default=20) 55 | 56 | args = parser.parse_args() 57 | 58 | crop_images(args) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /metrics/metric_defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Default metric definitions.""" 8 | 9 | from dnnlib import EasyDict 10 | 11 | #---------------------------------------------------------------------------- 12 | 13 | metric_defaults = EasyDict([(args.name, args) for args in [ 14 | EasyDict(name='fid50k', func_name='metrics.frechet_inception_distance.FID', num_images=50000, minibatch_per_gpu=8), 15 | EasyDict(name='is50k', func_name='metrics.inception_score.IS', num_images=50000, num_splits=10, minibatch_per_gpu=8), 16 | EasyDict(name='ppl_zfull', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')), 17 | EasyDict(name='ppl_wfull', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')), 18 | EasyDict(name='ppl_zend', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')), 19 | EasyDict(name='ppl_wend', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')), 20 | EasyDict(name='ppl2_wend', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')), 21 | EasyDict(name='ls', func_name='metrics.linear_separability.LS', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4), 22 | EasyDict(name='pr50k3', func_name='metrics.precision_recall.PR', num_images=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000), 23 | ]]) 24 | 25 | #---------------------------------------------------------------------------- 26 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Inception Score (IS).""" 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import dnnlib.tflib as tflib 12 | 13 | from metrics import metric_base 14 | from training import misc 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | class IS(metric_base.MetricBase): 19 | def __init__(self, num_images, num_splits, minibatch_per_gpu, **kwargs): 20 | super().__init__(**kwargs) 21 | self.num_images = num_images 22 | self.num_splits = num_splits 23 | self.minibatch_per_gpu = minibatch_per_gpu 24 | 25 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 26 | minibatch_size = num_gpus * self.minibatch_per_gpu 27 | inception = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/inception_v3_softmax.pkl') 28 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 29 | 30 | # Construct TensorFlow graph. 
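# One replica per GPU: clone the generator and the Inception classifier onto
# each device, sample random latents, synthesize fakes, and collect the
# softmax class probabilities p(y|x) that the IS computation below consumes.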
31 | result_expr = [] 32 | for gpu_idx in range(num_gpus): 33 | with tf.device('/gpu:%d' % gpu_idx): 34 | Gs_clone = Gs.clone() 35 | inception_clone = inception.clone() 36 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 37 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 38 | images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs) 39 | images = tflib.convert_images_to_uint8(images) 40 | result_expr.append(inception_clone.get_output_for(images)) 41 | 42 | # Calculate activations for fakes. 43 | for begin in range(0, self.num_images, minibatch_size): 44 | self._report_progress(begin, self.num_images) 45 | end = min(begin + minibatch_size, self.num_images) 46 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] 47 | 48 | # Calculate IS. 49 | scores = [] 50 | for i in range(self.num_splits): 51 | part = activations[i * self.num_images // self.num_splits : (i + 1) * self.num_images // self.num_splits] 52 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 53 | kl = np.mean(np.sum(kl, 1)) 54 | scores.append(np.exp(kl)) 55 | self._report_result(np.mean(scores), suffix='_mean') 56 | self._report_result(np.std(scores), suffix='_std') 57 | 58 | #---------------------------------------------------------------------------- 59 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Frechet Inception Distance (FID).""" 8 | 9 | import os 10 | import numpy as np 11 | import scipy 12 | import tensorflow as tf 13 | import dnnlib.tflib as tflib 14 | 15 | from metrics import metric_base 16 | from training import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class FID(metric_base.MetricBase): 21 | def __init__(self, num_images, minibatch_per_gpu, **kwargs): 22 | super().__init__(**kwargs) 23 | self.num_images = num_images 24 | self.minibatch_per_gpu = minibatch_per_gpu 25 | 26 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 27 | minibatch_size = num_gpus * self.minibatch_per_gpu 28 | inception = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/inception_v3_features.pkl') 29 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 30 | 31 | # Calculate statistics for reals. 32 | cache_file = self._get_cache_file_for_reals(num_images=self.num_images) 33 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 34 | if os.path.isfile(cache_file): 35 | mu_real, sigma_real = misc.load_pkl(cache_file) 36 | else: 37 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)): 38 | begin = idx * minibatch_size 39 | end = min(begin + minibatch_size, self.num_images) 40 | activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True) 41 | if end == self.num_images: 42 | break 43 | mu_real = np.mean(activations, axis=0) 44 | sigma_real = np.cov(activations, rowvar=False) 45 | misc.save_pkl((mu_real, sigma_real), cache_file) 46 | 47 | # Construct TensorFlow graph. 
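# Per-GPU replicas as in the IS metric. Note that the generator is used
# directly here rather than cloned (the clone call is commented out below).
# The fake statistics computed from these replicas feed the Frechet distance
# d^2 = ||mu_f - mu_r||^2 + Tr(Sigma_f + Sigma_r - 2*sqrtm(Sigma_f . Sigma_r))
# evaluated at the end of this function.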
48 | result_expr = [] 49 | for gpu_idx in range(num_gpus): 50 | with tf.device('/gpu:%d' % gpu_idx): 51 | # Gs_clone = Gs.clone() 52 | Gs_clone = Gs 53 | inception_clone = inception.clone() 54 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 55 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 56 | images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs) 57 | images = tflib.convert_images_to_uint8(images) 58 | result_expr.append(inception_clone.get_output_for(images)) 59 | 60 | # Calculate statistics for fakes. 61 | for begin in range(0, self.num_images, minibatch_size): 62 | self._report_progress(begin, self.num_images) 63 | end = min(begin + minibatch_size, self.num_images) 64 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] 65 | mu_fake = np.mean(activations, axis=0) 66 | sigma_fake = np.cov(activations, rowvar=False) 67 | 68 | # Calculate FID. 69 | m = np.square(mu_fake - mu_real).sum() 70 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member 71 | dist = m + np.trace(sigma_fake + sigma_real - 2*s) 72 | self._report_result(np.real(dist)) 73 | 74 | #---------------------------------------------------------------------------- 75 | -------------------------------------------------------------------------------- /run_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import argparse 8 | import os 9 | import sys 10 | 11 | import dnnlib 12 | import dnnlib.tflib as tflib 13 | 14 | import pretrained_networks 15 | from metrics import metric_base 16 | from metrics.metric_defaults import metric_defaults 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def run(network_pkl, metrics, dataset, data_dir, mirror_augment): 21 | print('Evaluating metrics "%s" for "%s"...' 
% (','.join(metrics), network_pkl)) 22 | tflib.init_tf() 23 | network_pkl = pretrained_networks.get_path_or_url(network_pkl) 24 | dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0) 25 | num_gpus = dnnlib.submit_config.num_gpus 26 | metric_group = metric_base.MetricGroup([metric_defaults[metric] for metric in metrics]) 27 | metric_group.run(network_pkl, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=num_gpus) 28 | 29 | #---------------------------------------------------------------------------- 30 | 31 | def _str_to_bool(v): 32 | if isinstance(v, bool): 33 | return v 34 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 35 | return True 36 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 37 | return False 38 | else: 39 | raise argparse.ArgumentTypeError('Boolean value expected.') 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | _examples = '''examples: 44 | 45 | python %(prog)s --data-dir=~/datasets --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --metrics=fid50k,ppl_wend --dataset=ffhq --mirror-augment=true 46 | 47 | valid metrics: 48 | 49 | ''' + ', '.join(sorted([x for x in metric_defaults.keys()])) + ''' 50 | ''' 51 | 52 | def main(): 53 | parser = argparse.ArgumentParser( 54 | description='Run StyleGAN2 metrics.', 55 | epilog=_examples, 56 | formatter_class=argparse.RawDescriptionHelpFormatter 57 | ) 58 | parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR') 59 | parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) 60 | parser.add_argument('--metrics', help='Metrics to compute (default: %(default)s)', default='fid50k', type=lambda x: x.split(',')) 61 | parser.add_argument('--dataset', help='Training dataset', required=True) 62 | parser.add_argument('--data-dir', help='Dataset root directory', required=True) 63 | parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, type=_str_to_bool, metavar='BOOL') 64 | parser.add_argument('--num-gpus', help='Number of GPUs to use', type=int, default=1, metavar='N') 65 | 66 | args = parser.parse_args() 67 | 68 | if not os.path.exists(args.data_dir): 69 | print ('Error: dataset root directory does not exist.') 70 | sys.exit(1) 71 | 72 | kwargs = vars(args) 73 | sc = dnnlib.SubmitConfig() 74 | sc.num_gpus = kwargs.pop('num_gpus') 75 | sc.submit_target = dnnlib.SubmitTarget.LOCAL 76 | sc.local.do_not_copy_source_files = True 77 | sc.run_dir_root = kwargs.pop('result_dir') 78 | sc.run_desc = 'run-metrics' 79 | dnnlib.submit_run(sc, 'run_metrics.run', **kwargs) 80 | 81 | #---------------------------------------------------------------------------- 82 | 83 | if __name__ == "__main__": 84 | main() 85 | 86 | #---------------------------------------------------------------------------- 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Disentangled Representations with Latent Variation Predictability 2 | 3 | This repository contains the code for [Learning Disentangled Representations with Latent Variation Predictability]. 4 | 5 | ## Abstract 6 | 7 | Latent traversal is a popular approach to visualize the disentangled latent 8 | representations. 
Given a bunch of variations in a single unit of the latent
9 | representation, it is expected that there is a change in a single factor of
10 | variation of the data while others are fixed. However, this impressive
11 | experimental observation is rarely explicitly encoded in the objective
12 | function of learning disentangled representations. This paper defines the
13 | variation predictability of latent disentangled representations. Given
14 | image pairs generated by latent codes varying in a single dimension,
15 | this varied dimension could be closely correlated with these image
16 | pairs if the representation is well disentangled. Within an adversarial
17 | generation process, we encourage variation predictability by maximizing the
18 | mutual information between latent variations and corresponding image pairs.
19 | We further develop an evaluation metric that does not rely on the
20 | ground-truth generative factors to measure the disentanglement of latent
21 | representations. The proposed variation predictability is a general constraint
22 | that is applicable to the VAE and GAN frameworks for boosting
23 | disentanglement of latent representations. Experiments show that the proposed
24 | variation predictability correlates well with existing ground-truth-required
25 | metrics and the proposed algorithm is effective for disentanglement learning.
26 |
27 | ## Requirements
28 |
29 | * 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
30 | * TensorFlow 1.14 or 1.15 with GPU support. The code does not support TensorFlow 2.0.
31 |
32 | This project is based on StyleGAN2, which relies on custom TensorFlow ops that are compiled on the fly using [NVCC](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html).
33 | To test that your NVCC installation is working correctly, run:
34 |
35 | ```.bash
36 | nvcc test_nvcc.cu -o test_nvcc -run
37 | | CPU says hello.
38 | | GPU says hello.
39 | ```
40 | For more detailed instructions about the StyleGAN2 environment setup, see [StyleGAN2](https://github.com/NVlabs/stylegan2).
41 |
42 | ## CelebA Dataset
43 | To prepare the tfrecord version of the CelebA dataset, first download the original aligned-and-cropped version
44 | from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, then use the following command to
45 | create the tfrecord dataset:
46 |
47 | ```
48 | python dataset_tool.py create_celeba /path/to/new_tfr_dir /path/to/downloaded_celeba_dir
49 | ```
50 |
51 | ## Training
52 |
53 | To train a model on the CelebA dataset, run this command:
54 |
55 | ```
56 | sh train_run.sh
57 | ```
58 |
59 | You can modify this script to train different model variants.
60 | Note that the --data-dir flag takes the parent directory of
61 | the actual dataset, while --dataset takes the dataset directory name itself.
62 |
63 | ## Evaluation
64 |
65 | To evaluate trained models with the variation predictability (VP) metric, run:
66 |
67 | ```
68 | sh run_pair_imgs.sh
69 | ```
70 |
71 | to generate a dataset of image pairs. You need to modify this script to
72 | match your result-dir and the trained network pkl.
73 |
74 | Then use the repository https://github.com/zhuxinqimac/VP-metric-pytorch to
75 | compute the VP score on the generated dataset.
76 | You should run this evaluation procedure multiple times (e.g. 3)
77 | and average the resulting scores for your model.
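
For reference, the mutual-information objective described in the abstract is commonly implemented with an InfoGAN-style variational lower bound. The following is only a sketch of that standard bound, not necessarily the exact loss used in this codebase:

```
I(d; (x_1, x_2)) \ge \mathbb{E}_{d,(x_1,x_2)}[\log q(d \mid x_1, x_2)] + H(d)
```

Here d indexes the latent dimension that was varied, (x_1, x_2) is the resulting image pair, and q is a recognition network; in the `vc_gan_with_vc_head` checkpoints this role is played by the `I` network stored alongside `_G`, `_D`, and `Gs`.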
78 | 79 | ## Citation 80 | ``` 81 | @inproceedings{VPdis_eccv20, 82 | author={Xinqi Zhu and Chang Xu and Dacheng Tao}, 83 | title={Learning Disentangled Representations with Latent Variation Predictability}, 84 | booktitle={ECCV}, 85 | year={2020} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Helpers for managing the run/training loop.""" 8 | 9 | import datetime 10 | import json 11 | import os 12 | import pprint 13 | import time 14 | import types 15 | 16 | from typing import Any 17 | 18 | from . import submit 19 | 20 | # Singleton RunContext 21 | _run_context = None 22 | 23 | class RunContext(object): 24 | """Helper class for managing the run/training loop. 25 | 26 | The context will hide the implementation details of a basic run/training loop. 27 | It will set things up properly, tell if run should be stopped, and then cleans up. 28 | User should call update periodically and use should_stop to determine if run should be stopped. 29 | 30 | Args: 31 | submit_config: The SubmitConfig that is used for the current run. 32 | config_module: (deprecated) The whole config module that is used for the current run. 33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None): 36 | global _run_context 37 | # Only a single RunContext can be alive 38 | assert _run_context is None 39 | _run_context = self 40 | self.submit_config = submit_config 41 | self.should_stop_flag = False 42 | self.has_closed = False 43 | self.start_time = time.time() 44 | self.last_update_time = time.time() 45 | self.last_update_interval = 0.0 46 | self.progress_monitor_file_path = None 47 | 48 | # vestigial config_module support just prints a warning 49 | if config_module is not None: 50 | print("RunContext.config_module parameter support has been removed.") 51 | 52 | # write out details about the run to a text file 53 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 54 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 55 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 56 | 57 | def __enter__(self) -> "RunContext": 58 | return self 59 | 60 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 61 | self.close() 62 | 63 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: 64 | """Do general housekeeping and keep the state of the context up-to-date. 
65 | Should be called often enough but not in a tight loop.""" 66 | assert not self.has_closed 67 | 68 | self.last_update_interval = time.time() - self.last_update_time 69 | self.last_update_time = time.time() 70 | 71 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): 72 | self.should_stop_flag = True 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | self.has_closed = True 99 | 100 | # detach the global singleton 101 | global _run_context 102 | if _run_context is self: 103 | _run_context = None 104 | 105 | @staticmethod 106 | def get(): 107 | import dnnlib 108 | if _run_context is not None: 109 | return _run_context 110 | return RunContext(dnnlib.submit_config) 111 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 
100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /run_generator_vc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | # >.>.>.>.>.>.>.>.>.>.>.>.>.>.>.>. 5 | # Licensed under the Apache License, Version 2.0 (the "License") 6 | # You may obtain a copy of the License at 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # --- File Name: run_generator_vc.py 10 | # --- Creation Date: 08-02-2020 11 | # --- Last Modified: Fri 20 Mar 2020 15:48:21 AEDT 12 | # --- Author: Xinqi Zhu 13 | # .<.<.<.<.<.<.<.<.<.<.<.<.<.<.<.< 14 | """ 15 | Generator script for vc-gan and info-gan models. 16 | """ 17 | 18 | import argparse 19 | import numpy as np 20 | import PIL.Image 21 | import dnnlib 22 | import dnnlib.tflib as tflib 23 | import re 24 | import sys 25 | 26 | import pretrained_networks 27 | from training import misc 28 | from training.training_loop_dsp import get_grid_latents 29 | 30 | #---------------------------------------------------------------------------- 31 | 32 | 33 | def generate_images(network_pkl, 34 | n_imgs, 35 | model_type, 36 | n_discrete, 37 | n_continuous, 38 | n_samples_per=10): 39 | print('Loading networks from "%s"...' % network_pkl) 40 | tflib.init_tf() 41 | if (model_type == 'info_gan') or (model_type == 'vc_gan_with_vc_head'): 42 | _G, _D, I, Gs = misc.load_pkl(network_pkl) 43 | else: 44 | _G, _D, Gs = misc.load_pkl(network_pkl) 45 | 46 | # _G, _D, Gs = pretrained_networks.load_networks(network_pkl) 47 | 48 | Gs_kwargs = dnnlib.EasyDict() 49 | Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, 50 | nchw_to_nhwc=True) 51 | Gs_kwargs.randomize_noise = False 52 | 53 | for idx in range(n_imgs): 54 | print('Generating image %d/%d ...' % (idx, n_imgs)) 55 | 56 | if n_discrete == 0: 57 | grid_labels = np.zeros([n_continuous * n_samples_per, 0], 58 | dtype=np.float32) 59 | else: 60 | grid_labels = np.zeros( 61 | [n_discrete * n_continuous * n_samples_per, 0], 62 | dtype=np.float32) 63 | 64 | grid_size, grid_latents, grid_labels = get_grid_latents( 65 | n_discrete, n_continuous, n_samples_per, _G, grid_labels) 66 | grid_fakes = Gs.run(grid_latents, 67 | grid_labels, 68 | is_validation=True, 69 | minibatch_size=4, 70 | randomize_noise=False) 71 | misc.save_image_grid(grid_fakes, 72 | dnnlib.make_run_dir_path('img_%04d.png' % idx), 73 | drange=[-1, 1], 74 | grid_size=grid_size) 75 | 76 | 77 | #---------------------------------------------------------------------------- 78 | 79 | _examples = '''examples: 80 | 81 | # Generate images traversals 82 | python %(prog)s --network_pkl=results/info_gan.pkl --n_imgs=5 --result_dir ./results 83 | ''' 84 | 85 | 86 | #---------------------------------------------------------------------------- 87 | def main(): 88 | parser = argparse.ArgumentParser( 89 | description='''VC-GAN and INFO-GAN generator. 
90 | 91 | Run 'python %(prog)s --help' for subcommand help.''', 92 | epilog=_examples, 93 | formatter_class=argparse.RawDescriptionHelpFormatter) 94 | 95 | parser.add_argument('--network_pkl', 96 | help='Network pickle filename', 97 | required=True) 98 | parser.add_argument('--n_imgs', 99 | type=int, 100 | help='Number of images to generate', 101 | required=True) 102 | parser.add_argument('--n_discrete', 103 | type=int, 104 | help='Number of discrete latents', 105 | default=0) 106 | parser.add_argument('--n_continuous', 107 | type=int, 108 | help='Number of continuous latents', 109 | default=14) 110 | parser.add_argument('--n_samples_per', 111 | type=int, 112 | help='Number of samples per row', 113 | default=10) 114 | parser.add_argument('--model_type', 115 | type=str, 116 | help='Which model is this pkl', 117 | default='vc_gan_with_vc_head', 118 | choices=['info_gan', 'vc_gan', 'vc_gan_with_vc_head']) 119 | parser.add_argument( 120 | '--result-dir', 121 | help='Root directory for run results (default: %(default)s)', 122 | default='results', 123 | metavar='DIR') 124 | 125 | args = parser.parse_args() 126 | kwargs = vars(args) 127 | 128 | sc = dnnlib.SubmitConfig() 129 | sc.num_gpus = 1 130 | sc.submit_target = dnnlib.SubmitTarget.LOCAL 131 | sc.local.do_not_copy_source_files = True 132 | sc.run_dir_root = kwargs.pop('result_dir') 133 | 134 | dnnlib.submit_run(sc, 'run_generator_vc.generate_images', **kwargs) 135 | 136 | 137 | #---------------------------------------------------------------------------- 138 | 139 | if __name__ == "__main__": 140 | main() 141 | 142 | #---------------------------------------------------------------------------- 143 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Perceptual Path Length (PPL).""" 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import dnnlib.tflib as tflib 12 | 13 | from metrics import metric_base 14 | from training import misc 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | # Normalize batch of vectors. 19 | def normalize(v): 20 | return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)) 21 | 22 | # Spherical interpolation of a batch of vectors. 
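# For unit vectors a and b with angle omega = acos(<a, b>), slerp walks a
# fraction t along the great circle from a towards b: c is the unit component
# of b orthogonal to a, so a*cos(t*omega) + c*sin(t*omega) stays on the sphere.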
23 | def slerp(a, b, t): 24 | a = normalize(a) 25 | b = normalize(b) 26 | d = tf.reduce_sum(a * b, axis=-1, keepdims=True) 27 | p = t * tf.math.acos(d) 28 | c = normalize(b - d * a) 29 | d = a * tf.math.cos(p) + c * tf.math.sin(p) 30 | return normalize(d) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | class PPL(metric_base.MetricBase): 35 | def __init__(self, num_samples, epsilon, space, sampling, crop, minibatch_per_gpu, Gs_overrides, **kwargs): 36 | assert space in ['z', 'w'] 37 | assert sampling in ['full', 'end'] 38 | super().__init__(**kwargs) 39 | self.num_samples = num_samples 40 | self.epsilon = epsilon 41 | self.space = space 42 | self.sampling = sampling 43 | self.crop = crop 44 | self.minibatch_per_gpu = minibatch_per_gpu 45 | self.Gs_overrides = Gs_overrides 46 | 47 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 48 | Gs_kwargs = dict(Gs_kwargs) 49 | Gs_kwargs.update(self.Gs_overrides) 50 | minibatch_size = num_gpus * self.minibatch_per_gpu 51 | 52 | # Construct TensorFlow graph. 53 | distance_expr = [] 54 | for gpu_idx in range(num_gpus): 55 | with tf.device('/gpu:%d' % gpu_idx): 56 | Gs_clone = Gs.clone() 57 | noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')] 58 | 59 | # Generate random latents and interpolation t-values. 60 | lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:]) 61 | lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0) 62 | labels = tf.reshape(tf.tile(self._get_random_labels_tf(self.minibatch_per_gpu), [1, 2]), [self.minibatch_per_gpu * 2, -1]) 63 | 64 | # Interpolate in W or Z. 65 | if self.space == 'w': 66 | dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, labels, **Gs_kwargs) 67 | dlat_t01 = tf.cast(dlat_t01, tf.float32) 68 | dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2] 69 | dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis]) 70 | dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon) 71 | dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape) 72 | else: # space == 'z' 73 | lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2] 74 | lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis]) 75 | lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon) 76 | lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape) 77 | dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, labels, **Gs_kwargs) 78 | 79 | # Synthesize images. 80 | with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch 81 | images = Gs_clone.components.synthesis.get_output_for(dlat_e01, randomize_noise=False, **Gs_kwargs) 82 | images = tf.cast(images, tf.float32) 83 | 84 | # Crop only the face region. 85 | if self.crop: 86 | c = int(images.shape[2] // 8) 87 | images = images[:, :, c*3 : c*7, c*2 : c*6] 88 | 89 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 90 | factor = images.shape[2] // 256 91 | if factor > 1: 92 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 93 | images = tf.reduce_mean(images, axis=[3,5]) 94 | 95 | # Scale dynamic range from [-1,1] to [0,255] for VGG. 96 | images = (images + 1) * (255 / 2) 97 | 98 | # Evaluate perceptual distance. 
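# The LPIPS distance between the two endpoints of each epsilon-step is scaled
# by 1/epsilon^2, so its expectation approximates the squared perceptual path
# derivative ||dG(z(t))/dt||^2 that PPL averages over samples.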
99 | img_e0, img_e1 = images[0::2], images[1::2] 100 | distance_measure = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl') 101 | distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2)) 102 | 103 | # Sampling loop. 104 | all_distances = [] 105 | for begin in range(0, self.num_samples, minibatch_size): 106 | self._report_progress(begin, self.num_samples) 107 | all_distances += tflib.run(distance_expr) 108 | all_distances = np.concatenate(all_distances, axis=0) 109 | 110 | # Reject outliers. 111 | lo = np.percentile(all_distances, 1, interpolation='lower') 112 | hi = np.percentile(all_distances, 99, interpolation='higher') 113 | filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances) 114 | self._report_result(np.mean(filtered_distances)) 115 | 116 | #---------------------------------------------------------------------------- 117 | -------------------------------------------------------------------------------- /training/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Miscellaneous utility functions.""" 8 | 9 | import os 10 | import pickle 11 | import numpy as np 12 | import PIL.Image 13 | import PIL.ImageFont 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Convenience wrappers for pickle that are able to load data produced by 18 | # older versions of the code, and from external URLs. 19 | 20 | def open_file_or_url(file_or_url): 21 | if dnnlib.util.is_url(file_or_url): 22 | return dnnlib.util.open_url(file_or_url, cache_dir='.stylegan2-cache') 23 | return open(file_or_url, 'rb') 24 | 25 | def load_pkl(file_or_url): 26 | with open_file_or_url(file_or_url) as file: 27 | return pickle.load(file, encoding='latin1') 28 | 29 | def save_pkl(obj, filename): 30 | with open(filename, 'wb') as file: 31 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) 32 | 33 | #---------------------------------------------------------------------------- 34 | # Image utils. 
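# adjust_dynamic_range below applies the affine map out = in * scale + bias,
# with scale = (out_hi - out_lo) / (in_hi - in_lo) and bias = out_lo - in_lo * scale,
# i.e. a linear rescale of pixel values from drange_in to drange_out.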
35 | 36 | def adjust_dynamic_range(data, drange_in, drange_out): 37 | if drange_in != drange_out: 38 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0])) 39 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale) 40 | data = data * scale + bias 41 | return data 42 | 43 | def create_image_grid(images, grid_size=None): 44 | assert images.ndim == 3 or images.ndim == 4 45 | num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2] 46 | 47 | if grid_size is not None: 48 | grid_w, grid_h = tuple(grid_size) 49 | else: 50 | grid_w = max(int(np.ceil(np.sqrt(num))), 1) 51 | grid_h = max((num - 1) // grid_w + 1, 1) 52 | 53 | grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype) 54 | for idx in range(num): 55 | x = (idx % grid_w) * img_w 56 | y = (idx // grid_w) * img_h 57 | grid[..., y : y + img_h, x : x + img_w] = images[idx] 58 | return grid 59 | 60 | def convert_to_pil_image(image, drange=[0,1]): 61 | assert image.ndim == 2 or image.ndim == 3 62 | if image.ndim == 3: 63 | if image.shape[0] == 1: 64 | image = image[0] # grayscale CHW => HW 65 | else: 66 | image = image.transpose(1, 2, 0) # CHW -> HWC 67 | 68 | image = adjust_dynamic_range(image, drange, [0,255]) 69 | image = np.rint(image).clip(0, 255).astype(np.uint8) 70 | fmt = 'RGB' if image.ndim == 3 else 'L' 71 | return PIL.Image.fromarray(image, fmt) 72 | 73 | def save_image_grid(images, filename, drange=[0,1], grid_size=None): 74 | convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename) 75 | 76 | def apply_mirror_augment(minibatch): 77 | mask = np.random.rand(minibatch.shape[0]) < 0.5 78 | minibatch = np.array(minibatch) 79 | minibatch[mask] = minibatch[mask, :, :, ::-1] 80 | return minibatch 81 | 82 | #---------------------------------------------------------------------------- 83 | # Loading data from previous training runs. 84 | 85 | def parse_config_for_previous_run(run_dir): 86 | with open(os.path.join(run_dir, 'submit_config.pkl'), 'rb') as f: 87 | data = pickle.load(f) 88 | data = data.get('run_func_kwargs', {}) 89 | return dict(train=data, dataset=data.get('dataset_args', {})) 90 | 91 | #---------------------------------------------------------------------------- 92 | # Size and contents of the image snapshot grids that are exported 93 | # periodically during training. 94 | 95 | def setup_snapshot_image_grid(training_set, 96 | size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display. 97 | layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label. 98 | 99 | # Select size. 100 | gw = 1; gh = 1 101 | if size == '1080p': 102 | gw = np.clip(1920 // training_set.shape[2], 3, 32) 103 | gh = np.clip(1080 // training_set.shape[1], 2, 32) 104 | if size == '4k': 105 | gw = np.clip(3840 // training_set.shape[2], 7, 32) 106 | gh = np.clip(2160 // training_set.shape[1], 4, 32) 107 | if size == '8k': 108 | gw = np.clip(7680 // training_set.shape[2], 7, 32) 109 | gh = np.clip(4320 // training_set.shape[1], 4, 32) 110 | 111 | # Initialize data arrays. 112 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) 113 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) 114 | 115 | # Random layout. 
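# A single random minibatch of gw * gh reals (and their labels) fills the grid.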
116 | if layout == 'random': 117 | reals[:], labels[:] = training_set.get_minibatch_np(gw * gh) 118 | 119 | # Class-conditional layouts. 120 | class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4]) 121 | if layout in class_layouts: 122 | bw, bh = class_layouts[layout] 123 | nw = (gw - 1) // bw + 1 124 | nh = (gh - 1) // bh + 1 125 | blocks = [[] for _i in range(nw * nh)] 126 | for _iter in range(1000000): 127 | real, label = training_set.get_minibatch_np(1) 128 | idx = np.argmax(label[0]) 129 | while idx < len(blocks) and len(blocks[idx]) >= bw * bh: 130 | idx += training_set.label_size 131 | if idx < len(blocks): 132 | blocks[idx].append((real, label)) 133 | if all(len(block) >= bw * bh for block in blocks): 134 | break 135 | for i, block in enumerate(blocks): 136 | for j, (real, label) in enumerate(block): 137 | x = (i % nw) * bw + j % bw 138 | y = (i // nw) * bh + j // bw 139 | if x < gw and y < gh: 140 | reals[x + y * gw] = real[0] 141 | labels[x + y * gw] = label[0] 142 | 143 | return (gw, gh), reals, labels 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /docs/license.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Nvidia Source Code License-NC 7 | 8 | 56 | 57 | 58 | 59 |

Nvidia Source Code License-NC

1. Definitions

“Licensor” means any person or entity that distributes its Work.

“Software” means the original work of authorship made available under this License.

“Work” means the Software and any additions to or derivative works of the Software that are made available under this License.

“Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates.

The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately.

3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
-------------------------------------------------------------------------------- /run_pair_generator_vc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | # >.>.>.>.>.>.>.>.>.>.>.>.>.>.>.>. 5 | # Licensed under the Apache License, Version 2.0 (the "License") 6 | # You may obtain a copy of the License at 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # --- File Name: run_pair_generator_vc.py 10 | # --- Creation Date: 27-02-2020 11 | # --- Last Modified: Fri 20 Mar 2020 15:48:45 AEDT 12 | # --- Author: Xinqi Zhu 13 | # .<.<.<.<.<.<.<.<.<.<.<.<.<.<.<.< 14 | """ 15 | Generate an image-pair dataset 16 | """ 17 | 18 | import argparse 19 | import numpy as np 20 | from PIL import Image 21 | import dnnlib 22 | import dnnlib.tflib as tflib 23 | import re 24 | import os 25 | import sys 26 | 27 | import pretrained_networks 28 | from training import misc 29 | from training.training_loop_dsp import get_grid_latents 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | 34 | def generate_image_pairs(network_pkl, 35 | n_imgs, 36 | model_type, 37 | n_discrete, 38 | n_continuous, 39 | result_dir, 40 | batch_size=10, 41 | latent_type='onedim'): 42 | print('Loading networks from "%s"...' % network_pkl) 43 | tflib.init_tf() 44 | if (model_type == 'info_gan') or (model_type == 'vc_gan_with_vc_head'): 45 | _G, _D, I, Gs = misc.load_pkl(network_pkl) 46 | else: 47 | _G, _D, Gs = misc.load_pkl(network_pkl) 48 | 49 | if not os.path.exists(result_dir): 50 | os.makedirs(result_dir) 51 | 52 | # _G, _D, Gs = pretrained_networks.load_networks(network_pkl) 53 | 54 | Gs_kwargs = dnnlib.EasyDict() 55 | Gs_kwargs.randomize_noise = False 56 | 57 | n_batches = n_imgs // batch_size 58 | 59 | for i in range(n_batches): 60 | print('Generating image pairs %d/%d ...'
% (i, n_batches)) 61 | grid_labels = np.zeros([batch_size, 0], dtype=np.float32) 62 | 63 | if n_discrete > 0: 64 | cat_dim = np.random.randint(0, n_discrete, size=[batch_size]) 65 | cat_onehot = np.zeros((batch_size, n_discrete)) 66 | cat_onehot[np.arange(cat_dim.size), cat_dim] = 1 67 | 68 | z_1 = np.random.uniform(low=-2, 69 | high=2, 70 | size=[batch_size, n_continuous]) 71 | z_2 = np.random.uniform(low=-2, 72 | high=2, 73 | size=[batch_size, n_continuous]) 74 | if latent_type == 'onedim': 75 | delta_dim = np.random.randint(0, n_continuous, size=[batch_size]) 76 | delta_onehot = np.zeros((batch_size, n_continuous)) 77 | delta_onehot[np.arange(delta_dim.size), delta_dim] = 1 78 | z_2 = np.where(delta_onehot > 0, z_2, z_1) 79 | delta_z = z_1 - z_2 80 | 81 | if i == 0: 82 | labels = delta_z 83 | else: 84 | labels = np.concatenate([labels, delta_z], axis=0) 85 | 86 | if n_discrete > 0: 87 | z_1 = np.concatenate((cat_onehot, z_1), axis=1) 88 | z_2 = np.concatenate((cat_onehot, z_2), axis=1) 89 | 90 | fakes_1 = Gs.run(z_1, 91 | grid_labels, 92 | is_validation=True, 93 | minibatch_size=batch_size, 94 | **Gs_kwargs) 95 | fakes_2 = Gs.run(z_2, 96 | grid_labels, 97 | is_validation=True, 98 | minibatch_size=batch_size, 99 | **Gs_kwargs) 100 | print('fakes_1.shape:', fakes_1.shape) 101 | print('fakes_2.shape:', fakes_2.shape) 102 | 103 | for j in range(fakes_1.shape[0]): 104 | pair_np = np.concatenate([fakes_1[j], fakes_2[j]], axis=2) 105 | img = misc.convert_to_pil_image(pair_np, [-1, 1]) 106 | # pair_np = (pair_np * 255).astype(np.uint8) 107 | # img = Image.fromarray(pair_np) 108 | img.save( 109 | os.path.join(result_dir, 110 | 'pair_%06d.jpg' % (i * batch_size + j))) 111 | np.save(os.path.join(result_dir, 'labels.npy'), labels) 112 | 113 | 114 | #---------------------------------------------------------------------------- 115 | 116 | _examples = '''examples: 117 | 118 | # Generate image pairs 119 | python %(prog)s --network_pkl=results/info_gan.pkl --n_imgs=5 --result-dir ./results 120 | ''' 121 | 122 | 123 | #---------------------------------------------------------------------------- 124 | def main(): 125 | parser = argparse.ArgumentParser( 126 | description='VC-GAN and INFO-GAN image-pair generator.', 127 | epilog=_examples, 128 | formatter_class=argparse.RawDescriptionHelpFormatter) 129 | 130 | parser.add_argument('--network_pkl', 131 | help='Network pickle filename', 132 | required=True) 133 | parser.add_argument('--n_imgs', 134 | type=int, 135 | help='Number of image pairs to generate', 136 | required=True) 137 | parser.add_argument('--n_discrete', 138 | type=int, 139 | help='Number of discrete latents', 140 | default=0) 141 | parser.add_argument('--n_continuous', 142 | type=int, 143 | help='Number of continuous latents', 144 | default=14) 145 | parser.add_argument('--batch_size', 146 | type=int, 147 | help='Batch size for generation', 148 | default=10) 149 | parser.add_argument('--latent_type', 150 | type=str, 151 | help='What type of latent difference to use', 152 | default='onedim', 153 | choices=['onedim', 'fulldim']) 154 | parser.add_argument('--model_type', 155 | type=str, 156 | help='Which model type this pkl contains', 157 | default='vc_gan_with_vc_head', 158 | choices=['info_gan', 'vc_gan', 'vc_gan_with_vc_head']) 159 | parser.add_argument('--result-dir', 160 | help='Root directory to store this dataset', 161 | required=True, 162 | metavar='DIR') 163 | 164 | args = parser.parse_args() 165 | kwargs = vars(args) 166 | 167 | sc = dnnlib.SubmitConfig() 168 | sc.num_gpus = 1 169 |
sc.submit_target = dnnlib.SubmitTarget.LOCAL 170 | sc.local.do_not_copy_source_files = True 171 | sc.run_dir_root = kwargs['result_dir'] 172 | 173 | dnnlib.submit_run(sc, 'run_pair_generator_vc.generate_image_pairs', 174 | **kwargs) 175 | 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | if __name__ == "__main__": 180 | main() 181 | 182 | #---------------------------------------------------------------------------- 183 | -------------------------------------------------------------------------------- /run_projector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import argparse 8 | import numpy as np 9 | import dnnlib 10 | import dnnlib.tflib as tflib 11 | import re 12 | import sys 13 | 14 | import projector 15 | import pretrained_networks 16 | from training import dataset 17 | from training import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def project_image(proj, targets, png_prefix, num_snapshots): 22 | snapshot_steps = set(proj.num_steps - np.linspace(0, proj.num_steps, num_snapshots, endpoint=False, dtype=int)) 23 | misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1,1]) 24 | proj.start(targets) 25 | while proj.get_cur_step() < proj.num_steps: 26 | print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps), end='', flush=True) 27 | proj.step() 28 | if proj.get_cur_step() in snapshot_steps: 29 | misc.save_image_grid(proj.get_images(), png_prefix + 'step%04d.png' % proj.get_cur_step(), drange=[-1,1]) 30 | print('\r%-30s\r' % '', end='', flush=True) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def project_generated_images(network_pkl, seeds, num_snapshots, truncation_psi): 35 | print('Loading networks from "%s"...' % network_pkl) 36 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl) 37 | proj = projector.Projector() 38 | proj.set_network(Gs) 39 | noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')] 40 | 41 | Gs_kwargs = dnnlib.EasyDict() 42 | Gs_kwargs.randomize_noise = False 43 | Gs_kwargs.truncation_psi = truncation_psi 44 | 45 | for seed_idx, seed in enumerate(seeds): 46 | print('Projecting seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) 47 | rnd = np.random.RandomState(seed) 48 | z = rnd.randn(1, *Gs.input_shape[1:]) 49 | tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) 50 | images = Gs.run(z, None, **Gs_kwargs) 51 | project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('seed%04d-' % seed), num_snapshots=num_snapshots) 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def project_real_images(network_pkl, dataset_name, data_dir, num_images, num_snapshots): 56 | print('Loading networks from "%s"...' % network_pkl) 57 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl) 58 | proj = projector.Projector() 59 | proj.set_network(Gs) 60 | 61 | print('Loading images from "%s"...' 
% dataset_name) 62 | dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0) 63 | assert dataset_obj.shape == Gs.output_shape[1:] 64 | 65 | for image_idx in range(num_images): 66 | print('Projecting image %d/%d ...' % (image_idx, num_images)) 67 | images, _labels = dataset_obj.get_minibatch_np(1) 68 | images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1]) 69 | project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx), num_snapshots=num_snapshots) 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | def _parse_num_range(s): 74 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' 75 | 76 | range_re = re.compile(r'^(\d+)-(\d+)$') 77 | m = range_re.match(s) 78 | if m: 79 | return range(int(m.group(1)), int(m.group(2))+1) 80 | vals = s.split(',') 81 | return [int(x) for x in vals] 82 | 83 | #---------------------------------------------------------------------------- 84 | 85 | _examples = '''examples: 86 | 87 | # Project generated images 88 | python %(prog)s project-generated-images --network=gdrive:networks/stylegan2-car-config-f.pkl --seeds=0,1,5 89 | 90 | # Project real images 91 | python %(prog)s project-real-images --network=gdrive:networks/stylegan2-car-config-f.pkl --dataset=car --data-dir=~/datasets 92 | 93 | ''' 94 | 95 | #---------------------------------------------------------------------------- 96 | 97 | def main(): 98 | parser = argparse.ArgumentParser( 99 | description='''StyleGAN2 projector. 100 | 101 | Run 'python %(prog)s --help' for subcommand help.''', 102 | epilog=_examples, 103 | formatter_class=argparse.RawDescriptionHelpFormatter 104 | ) 105 | 106 | subparsers = parser.add_subparsers(help='Sub-commands', dest='command') 107 | 108 | project_generated_images_parser = subparsers.add_parser('project-generated-images', help='Project generated images') 109 | project_generated_images_parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) 110 | project_generated_images_parser.add_argument('--seeds', type=_parse_num_range, help='List of random seeds', default=range(3)) 111 | project_generated_images_parser.add_argument('--num-snapshots', type=int, help='Number of snapshots (default: %(default)s)', default=5) 112 | project_generated_images_parser.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', default=1.0) 113 | project_generated_images_parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR') 114 | 115 | project_real_images_parser = subparsers.add_parser('project-real-images', help='Project real images') 116 | project_real_images_parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) 117 | project_real_images_parser.add_argument('--data-dir', help='Dataset root directory', required=True) 118 | project_real_images_parser.add_argument('--dataset', help='Training dataset', dest='dataset_name', required=True) 119 | project_real_images_parser.add_argument('--num-snapshots', type=int, help='Number of snapshots (default: %(default)s)', default=5) 120 | project_real_images_parser.add_argument('--num-images', type=int, help='Number of images to project (default: %(default)s)', default=3) 121 | project_real_images_parser.add_argument('--result-dir', help='Root directory 
for run results (default: %(default)s)', default='results', metavar='DIR') 122 | 123 | args = parser.parse_args() 124 | subcmd = args.command 125 | if subcmd is None: 126 | print ('Error: missing subcommand. Re-run with --help for usage.') 127 | sys.exit(1) 128 | 129 | kwargs = vars(args) 130 | sc = dnnlib.SubmitConfig() 131 | sc.num_gpus = 1 132 | sc.submit_target = dnnlib.SubmitTarget.LOCAL 133 | sc.local.do_not_copy_source_files = True 134 | sc.run_dir_root = kwargs.pop('result_dir') 135 | sc.run_desc = kwargs.pop('command') 136 | 137 | func_name_map = { 138 | 'project-generated-images': 'run_projector.project_generated_images', 139 | 'project-real-images': 'run_projector.project_real_images' 140 | } 141 | dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs) 142 | 143 | #---------------------------------------------------------------------------- 144 | 145 | if __name__ == "__main__": 146 | main() 147 | 148 | #---------------------------------------------------------------------------- 149 | -------------------------------------------------------------------------------- /pretrained_networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """List of pre-trained StyleGAN2 networks located on Google Drive.""" 8 | 9 | import pickle 10 | import dnnlib 11 | import dnnlib.tflib as tflib 12 | 13 | #---------------------------------------------------------------------------- 14 | # StyleGAN2 Google Drive root: https://drive.google.com/open?id=1QHc-yF5C3DChRwSdZKcx1w6K8JvSxQi7 15 | 16 | gdrive_urls = { 17 | 'gdrive:networks/stylegan2-car-config-a.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-a.pkl', 18 | 'gdrive:networks/stylegan2-car-config-b.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-b.pkl', 19 | 'gdrive:networks/stylegan2-car-config-c.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-c.pkl', 20 | 'gdrive:networks/stylegan2-car-config-d.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-d.pkl', 21 | 'gdrive:networks/stylegan2-car-config-e.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-e.pkl', 22 | 'gdrive:networks/stylegan2-car-config-f.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl', 23 | 'gdrive:networks/stylegan2-cat-config-a.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-a.pkl', 24 | 'gdrive:networks/stylegan2-cat-config-f.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl', 25 | 'gdrive:networks/stylegan2-church-config-a.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-a.pkl', 26 | 'gdrive:networks/stylegan2-church-config-f.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl', 27 | 'gdrive:networks/stylegan2-ffhq-config-a.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-a.pkl', 28 | 'gdrive:networks/stylegan2-ffhq-config-b.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-b.pkl', 29 | 'gdrive:networks/stylegan2-ffhq-config-c.pkl': 
'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-c.pkl', 30 | 'gdrive:networks/stylegan2-ffhq-config-d.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-d.pkl', 31 | 'gdrive:networks/stylegan2-ffhq-config-e.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-e.pkl', 32 | 'gdrive:networks/stylegan2-ffhq-config-f.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl', 33 | 'gdrive:networks/stylegan2-horse-config-a.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-a.pkl', 34 | 'gdrive:networks/stylegan2-horse-config-f.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-f.pkl', 35 | 'gdrive:networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl', 36 | 'gdrive:networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl', 37 | 'gdrive:networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl', 38 | 'gdrive:networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl', 39 | 'gdrive:networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl', 40 | 'gdrive:networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl', 41 | 'gdrive:networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl', 42 | 'gdrive:networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl', 43 | 'gdrive:networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl', 44 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl', 45 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl', 46 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl', 47 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl', 48 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl', 49 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl', 50 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl': 
'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl', 51 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl', 52 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl', 53 | } 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def get_path_or_url(path_or_gdrive_path): 58 | return gdrive_urls.get(path_or_gdrive_path, path_or_gdrive_path) 59 | 60 | #---------------------------------------------------------------------------- 61 | 62 | _cached_networks = dict() 63 | 64 | def load_networks(path_or_gdrive_path): 65 | path_or_url = get_path_or_url(path_or_gdrive_path) 66 | if path_or_url in _cached_networks: 67 | return _cached_networks[path_or_url] 68 | 69 | if dnnlib.util.is_url(path_or_url): 70 | stream = dnnlib.util.open_url(path_or_url, cache_dir='.stylegan2-cache') 71 | else: 72 | stream = open(path_or_url, 'rb') 73 | 74 | tflib.init_tf() 75 | with stream: 76 | G, D, Gs = pickle.load(stream, encoding='latin1') 77 | _cached_networks[path_or_url] = G, D, Gs 78 | return G, D, Gs 79 | 80 | #---------------------------------------------------------------------------- 81 | -------------------------------------------------------------------------------- /metrics/metric_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Common definitions for GAN metrics.""" 8 | 9 | import os 10 | import time 11 | import hashlib 12 | import numpy as np 13 | import tensorflow as tf 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | 17 | from training import misc 18 | from training import dataset 19 | 20 | #---------------------------------------------------------------------------- 21 | # Base class for metrics. 
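A hypothetical orientation sketch (not part of metric_base.py): a concrete metric subclasses MetricBase below, overrides _evaluate(), and reports numbers through _report_result(); dataset access, result caching, and logging are all inherited. Note that in this fork, run() unpickles four networks (G, D, I, Gs), so the pickle must contain the extra I network. The class and metric name here are invented for illustration:

    from metrics.metric_base import MetricBase

    class MeanFakePixel(MetricBase):
        """Hypothetical toy metric: mean pixel value over one generated minibatch."""
        def _evaluate(self, Gs, Gs_kwargs, num_gpus):
            # _iterate_fakes() yields uint8 image minibatches sampled from Gs.
            images = next(self._iterate_fakes(Gs, minibatch_size=8, num_gpus=num_gpus))
            self._report_result(float(images.mean()))

    metric = MeanFakePixel(name='mean_fake_pixel')
    metric.run('results/00000-run/network-final.pkl', num_gpus=1)  # prints name, eval time, value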
22 | 23 | class MetricBase: 24 | def __init__(self, name): 25 | self.name = name 26 | self._dataset_obj = None 27 | self._progress_lo = None 28 | self._progress_hi = None 29 | self._progress_max = None 30 | self._progress_sec = None 31 | self._progress_time = None 32 | self._reset() 33 | 34 | def close(self): 35 | self._reset() 36 | 37 | def _reset(self, network_pkl=None, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None): 38 | if self._dataset_obj is not None: 39 | self._dataset_obj.close() 40 | 41 | self._network_pkl = network_pkl 42 | self._data_dir = data_dir 43 | self._dataset_args = dataset_args 44 | self._dataset_obj = None 45 | self._mirror_augment = mirror_augment 46 | self._eval_time = 0 47 | self._results = [] 48 | 49 | if (dataset_args is None or mirror_augment is None) and run_dir is not None: 50 | run_config = misc.parse_config_for_previous_run(run_dir) 51 | self._dataset_args = dict(run_config['dataset']) 52 | self._dataset_args['shuffle_mb'] = 0 53 | self._mirror_augment = run_config['train'].get('mirror_augment', False) 54 | 55 | def configure_progress_reports(self, plo, phi, pmax, psec=15): 56 | self._progress_lo = plo 57 | self._progress_hi = phi 58 | self._progress_max = pmax 59 | self._progress_sec = psec 60 | 61 | def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True, Gs_kwargs=dict(is_validation=True)): 62 | self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment) 63 | time_begin = time.time() 64 | with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager 65 | self._report_progress(0, 1) 66 | # _G, _D, Gs = misc.load_pkl(self._network_pkl) 67 | _G, _D, _I, Gs = misc.load_pkl(self._network_pkl) 68 | self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus) 69 | self._report_progress(1, 1) 70 | self._eval_time = time.time() - time_begin # pylint: disable=attribute-defined-outside-init 71 | 72 | if log_results: 73 | if run_dir is not None: 74 | log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name) 75 | with dnnlib.util.Logger(log_file, 'a'): 76 | print(self.get_result_str().strip()) 77 | else: 78 | print(self.get_result_str().strip()) 79 | 80 | def get_result_str(self): 81 | network_name = os.path.splitext(os.path.basename(self._network_pkl))[0] 82 | if len(network_name) > 29: 83 | network_name = '...' 
+ network_name[-26:] 84 | result_str = '%-30s' % network_name 85 | result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time) 86 | for res in self._results: 87 | result_str += ' ' + self.name + res.suffix + ' ' 88 | result_str += res.fmt % res.value 89 | return result_str 90 | 91 | def update_autosummaries(self): 92 | for res in self._results: 93 | tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value) 94 | 95 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 96 | raise NotImplementedError # to be overridden by subclasses 97 | 98 | def _report_result(self, value, suffix='', fmt='%-10.4f'): 99 | self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)] 100 | 101 | def _report_progress(self, pcur, pmax, status_str=''): 102 | if self._progress_lo is None or self._progress_hi is None or self._progress_max is None: 103 | return 104 | t = time.time() 105 | if self._progress_sec is not None and self._progress_time is not None and t < self._progress_time + self._progress_sec: 106 | return 107 | self._progress_time = t 108 | val = self._progress_lo + (pcur / pmax) * (self._progress_hi - self._progress_lo) 109 | dnnlib.RunContext.get().update(status_str, int(val), self._progress_max) 110 | 111 | def _get_cache_file_for_reals(self, extension='pkl', **kwargs): 112 | all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment) 113 | all_args.update(self._dataset_args) 114 | all_args.update(kwargs) 115 | md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')) 116 | dataset_name = self._dataset_args.get('tfrecord_dir', None) or self._dataset_args.get('h5_file', None) 117 | dataset_name = os.path.splitext(os.path.basename(dataset_name))[0] 118 | return os.path.join('.stylegan2-cache', '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension)) 119 | 120 | def _get_dataset_obj(self): 121 | if self._dataset_obj is None: 122 | self._dataset_obj = dataset.load_dataset(data_dir=self._data_dir, **self._dataset_args) 123 | return self._dataset_obj 124 | 125 | def _iterate_reals(self, minibatch_size): 126 | dataset_obj = self._get_dataset_obj() 127 | while True: 128 | images, _labels = dataset_obj.get_minibatch_np(minibatch_size) 129 | if self._mirror_augment: 130 | images = misc.apply_mirror_augment(images) 131 | yield images 132 | 133 | def _iterate_fakes(self, Gs, minibatch_size, num_gpus): 134 | while True: 135 | latents = np.random.randn(minibatch_size, *Gs.input_shape[1:]) 136 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 137 | images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True) 138 | yield images 139 | 140 | def _get_random_labels_tf(self, minibatch_size): 141 | return self._get_dataset_obj().get_random_labels_tf(minibatch_size) 142 | 143 | #---------------------------------------------------------------------------- 144 | # Group of multiple metrics. 
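Usage sketch for the group runner defined next (hypothetical, not part of the source file): MetricGroup instantiates each member via dnnlib.util.call_func_by_name from a kwargs dict, mirroring the entries in metrics/metric_defaults.py; the specific kwarg values below are illustrative only:

    from metrics.metric_base import MetricGroup

    group = MetricGroup([
        dict(func_name='metrics.frechet_inception_distance.FID',
             name='fid50k', num_images=50000, minibatch_per_gpu=8),
    ])
    group.run('results/00000-run/network-final.pkl',
              data_dir='datasets', dataset_args=dict(tfrecord_dir='my_dataset'),
              mirror_augment=False, num_gpus=1)
    print(group.get_result_str())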
145 | 146 | class MetricGroup: 147 | def __init__(self, metric_kwarg_list): 148 | self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list] 149 | 150 | def run(self, *args, **kwargs): 151 | for metric in self.metrics: 152 | metric.run(*args, **kwargs) 153 | 154 | def get_result_str(self): 155 | return ' '.join(metric.get_result_str() for metric in self.metrics) 156 | 157 | def update_autosummaries(self): 158 | for metric in self.metrics: 159 | metric.update_autosummaries() 160 | 161 | #---------------------------------------------------------------------------- 162 | # Dummy metric for debugging purposes. 163 | 164 | class DummyMetric(MetricBase): 165 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 166 | _ = Gs, Gs_kwargs, num_gpus 167 | self._report_result(0.0) 168 | 169 | #---------------------------------------------------------------------------- 170 | -------------------------------------------------------------------------------- /dnnlib/tflib/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """TensorFlow custom ops builder. 8 | """ 9 | 10 | import os 11 | import re 12 | import platform 13 | import uuid 14 | import hashlib 15 | import tempfile 16 | import shutil 17 | import subprocess 18 | import tensorflow as tf 19 | from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache') 25 | cuda_cache_version_tag = 'v1' 26 | do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! 27 | verbose = True # Print status messages to stdout. 28 | 29 | compiler_bindir_search_path = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin', 33 | ] 34 | 35 | #---------------------------------------------------------------------------- 36 | # Internal helper funcs. 
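Before the helpers, a hypothetical usage sketch (not part of custom_ops.py): callers such as dnnlib/tflib/ops/fused_bias_act.py hand get_plugin() a .cu source file; the first call compiles it with nvcc into _cudacache, and later calls hit the in-memory/disk cache. The attribute values below follow the REGISTER_OP declaration in fused_bias_act.cu (act=3 selects lrelu in the kernel's switch), and the relative path assumes the repo root as working directory:

    import numpy as np
    import tensorflow as tf
    import dnnlib.tflib as tflib
    from dnnlib.tflib import custom_ops

    tflib.init_tf()
    plugin = custom_ops.get_plugin('dnnlib/tflib/ops/fused_bias_act.cu')
    x = tf.constant(np.random.randn(1, 8, 4, 4).astype(np.float32))  # NCHW activations
    b = tf.zeros([8], dtype=tf.float32)                              # one bias per channel (axis=1)
    ref = tf.zeros([0], dtype=tf.float32)                            # empty: only used when grad > 0
    y = plugin.fused_bias_act(x=x, b=b, ref=ref, axis=1, act=3, alpha=0.2, gain=1.0)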
37 | 38 | def _find_gcc_version(): 39 | bashCommand = 'gcc --version' 40 | process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) 41 | output, error = process.communicate() 42 | gcc_v = int(output.strip().split()[2].decode('utf-8').split('.')[0]) 43 | return gcc_v 44 | 45 | def _find_compiler_bindir(): 46 | for compiler_path in compiler_bindir_search_path: 47 | if os.path.isdir(compiler_path): 48 | return compiler_path 49 | return None 50 | 51 | def _get_compute_cap(device): 52 | caps_str = device.physical_device_desc 53 | m = re.search('compute capability: (\\d+).(\\d+)', caps_str) 54 | major = m.group(1) 55 | minor = m.group(2) 56 | return (major, minor) 57 | 58 | def _get_cuda_gpu_arch_string(): 59 | gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] 60 | if len(gpus) == 0: 61 | raise RuntimeError('No GPU devices found') 62 | (major, minor) = _get_compute_cap(gpus[0]) 63 | return 'sm_%s%s' % (major, minor) 64 | 65 | def _run_cmd(cmd): 66 | with os.popen(cmd) as pipe: 67 | output = pipe.read() 68 | status = pipe.close() 69 | if status is not None: 70 | raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) 71 | 72 | def _prepare_nvcc_cli(opts): 73 | gcc_v = _find_gcc_version() 74 | if gcc_v >= 6: 75 | cmd = 'nvcc ' + opts.strip() 76 | else: 77 | cmd = 'nvcc --std=c++11 -DNDEBUG ' + opts.strip() 78 | cmd += ' --disable-warnings' 79 | cmd += ' --include-path "%s"' % tf.sysconfig.get_include() 80 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') 81 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') 82 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') 83 | 84 | compiler_bindir = _find_compiler_bindir() 85 | if compiler_bindir is None: 86 | # Require that _find_compiler_bindir succeeds on Windows. Allow 87 | # nvcc to use whatever is the default on Linux. 88 | if os.name == 'nt': 89 | raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) 90 | else: 91 | cmd += ' --compiler-bindir "%s"' % compiler_bindir 92 | cmd += ' 2>&1' 93 | return cmd 94 | 95 | #---------------------------------------------------------------------------- 96 | # Main entry point. 97 | 98 | _plugin_cache = dict() 99 | 100 | def get_plugin(cuda_file): 101 | cuda_file_base = os.path.basename(cuda_file) 102 | cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) 103 | 104 | # Already in cache? 105 | if cuda_file in _plugin_cache: 106 | return _plugin_cache[cuda_file] 107 | 108 | # Setup plugin. 109 | if verbose: 110 | print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) 111 | try: 112 | # Hash CUDA source. 113 | md5 = hashlib.md5() 114 | with open(cuda_file, 'rb') as f: 115 | md5.update(f.read()) 116 | md5.update(b'\n') 117 | 118 | # Hash headers included by the CUDA code by running it through the preprocessor. 119 | if not do_not_hash_included_headers: 120 | if verbose: 121 | print('Preprocessing... 
', end='', flush=True) 122 | with tempfile.TemporaryDirectory() as tmp_dir: 123 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) 124 | _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) 125 | with open(tmp_file, 'rb') as f: 126 | bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros 127 | good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') 128 | for ln in f: 129 | if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas 130 | ln = ln.replace(bad_file_str, good_file_str) 131 | md5.update(ln) 132 | md5.update(b'\n') 133 | 134 | # Select compiler options. 135 | compile_opts = '' 136 | if os.name == 'nt': 137 | compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') 138 | elif os.name == 'posix': 139 | compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so') 140 | if 'hpc' in list(platform.uname())[1]: 141 | compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\'' 142 | else: 143 | compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=1\'' 144 | else: 145 | assert False # not Windows or Linux, w00t? 146 | compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string() 147 | compile_opts += ' --use_fast_math' 148 | nvcc_cmd = _prepare_nvcc_cli(compile_opts) 149 | 150 | # Hash build configuration. 151 | md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') 152 | md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') 153 | md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') 154 | 155 | # Compile if not already compiled. 156 | bin_file_ext = '.dll' if os.name == 'nt' else '.so' 157 | bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) 158 | if not os.path.isfile(bin_file): 159 | if verbose: 160 | print('Compiling... ', end='', flush=True) 161 | with tempfile.TemporaryDirectory() as tmp_dir: 162 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) 163 | _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) 164 | os.makedirs(cuda_cache_path, exist_ok=True) 165 | intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) 166 | shutil.copyfile(tmp_file, intermediate_file) 167 | os.rename(intermediate_file, bin_file) # atomic 168 | 169 | # Load. 170 | if verbose: 171 | print('Loading... ', end='', flush=True) 172 | plugin = tf.load_op_library(bin_file) 173 | 174 | # Add to cache. 175 | _plugin_cache[cuda_file] = plugin 176 | if verbose: 177 | print('Done.', flush=True) 178 | return plugin 179 | 180 | except: 181 | if verbose: 182 | print('Failed!', flush=True) 183 | raise 184 | 185 | #---------------------------------------------------------------------------- 186 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/fused_bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #define EIGEN_USE_GPU 8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ 9 | #include "tensorflow/core/framework/op.h" 10 | #include "tensorflow/core/framework/op_kernel.h" 11 | #include "tensorflow/core/framework/shape_inference.h" 12 | #include <stdio.h> 13 | 14 | using namespace tensorflow; 15 | using namespace tensorflow::shape_inference; 16 | 17 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) 18 | 19 | //------------------------------------------------------------------------ 20 | // CUDA kernel. 21 | 22 | template <class T> 23 | struct FusedBiasActKernelParams 24 | { 25 | const T* x; // [sizeX] 26 | const T* b; // [sizeB] or NULL 27 | const T* ref; // [sizeX] or NULL 28 | T* y; // [sizeX] 29 | 30 | int grad; 31 | int axis; 32 | int act; 33 | float alpha; 34 | float gain; 35 | 36 | int sizeX; 37 | int sizeB; 38 | int stepB; 39 | int loopX; 40 | }; 41 | 42 | template <class T> 43 | static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p) 44 | { 45 | const float expRange = 80.0f; 46 | const float halfExpRange = 40.0f; 47 | const float seluScale = 1.0507009873554804934193349852946f; 48 | const float seluAlpha = 1.6732632423543772848170429916717f; 49 | 50 | // Loop over elements. 51 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 52 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 53 | { 54 | // Load and apply bias. 55 | float x = (float)p.x[xi]; 56 | if (p.b) 57 | x += (float)p.b[(xi / p.stepB) % p.sizeB]; 58 | float ref = (p.ref) ? (float)p.ref[xi] : 0.0f; 59 | if (p.gain != 0.0f & p.act != 9) 60 | ref /= p.gain; 61 | 62 | // Evaluate activation func. 63 | float y; 64 | switch (p.act * 10 + p.grad) 65 | { 66 | // linear 67 | default: 68 | case 10: y = x; break; 69 | case 11: y = x; break; 70 | case 12: y = 0.0f; break; 71 | 72 | // relu 73 | case 20: y = (x > 0.0f) ? x : 0.0f; break; 74 | case 21: y = (ref > 0.0f) ? x : 0.0f; break; 75 | case 22: y = 0.0f; break; 76 | 77 | // lrelu 78 | case 30: y = (x > 0.0f) ? x : x * p.alpha; break; 79 | case 31: y = (ref > 0.0f) ? x : x * p.alpha; break; 80 | case 32: y = 0.0f; break; 81 | 82 | // tanh 83 | case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break; 84 | case 41: y = x * (1.0f - ref * ref); break; 85 | case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break; 86 | 87 | // sigmoid 88 | case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break; 89 | case 51: y = x * ref * (1.0f - ref); break; 90 | case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break; 91 | 92 | // elu 93 | case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break; 94 | case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break; 95 | case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break; 96 | 97 | // selu 98 | case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break; 99 | case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break; 100 | case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break; 101 | 102 | // softplus 103 | case 80: y = (x > expRange) ?
x : logf(expf(x) + 1.0f); break; 104 | case 81: y = x * (1.0f - expf(-ref)); break; 105 | case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break; 106 | 107 | // swish 108 | case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break; 109 | case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break; 110 | case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break; 111 | } 112 | 113 | // Apply gain and store. 114 | p.y[xi] = (T)(y * p.gain); 115 | } 116 | } 117 | 118 | //------------------------------------------------------------------------ 119 | // TensorFlow op. 120 | 121 | template <class T> 122 | struct FusedBiasActOp : public OpKernel 123 | { 124 | FusedBiasActKernelParams<T> m_attribs; 125 | 126 | FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx) 127 | { 128 | memset(&m_attribs, 0, sizeof(m_attribs)); 129 | OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad)); 130 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis)); 131 | OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act)); 132 | OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha)); 133 | OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain)); 134 | OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative")); 135 | OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative")); 136 | OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative")); 137 | } 138 | 139 | void Compute(OpKernelContext* ctx) 140 | { 141 | FusedBiasActKernelParams<T> p = m_attribs; 142 | cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream(); 143 | 144 | const Tensor& x = ctx->input(0); // [...] 145 | const Tensor& b = ctx->input(1); // [sizeB] or [0] 146 | const Tensor& ref = ctx->input(2); // x.shape or [0] 147 | p.x = x.flat<T>().data(); 148 | p.b = (b.NumElements()) ? b.flat<T>().data() : NULL; 149 | p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL; 150 | OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds")); 151 | OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1")); 152 | OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements")); 153 | OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ?
0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements")); 154 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large")); 155 | 156 | p.sizeX = (int)x.NumElements(); 157 | p.sizeB = (int)b.NumElements(); 158 | p.stepB = 1; 159 | for (int i = m_attribs.axis + 1; i < x.dims(); i++) 160 | p.stepB *= (int)x.dim_size(i); 161 | 162 | Tensor* y = NULL; // x.shape 163 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); 164 | p.y = y->flat<T>().data(); 165 | 166 | p.loopX = 4; 167 | int blockSize = 4 * 32; 168 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 169 | void* args[] = {&p}; 170 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream)); 171 | } 172 | }; 173 | 174 | REGISTER_OP("FusedBiasAct") 175 | .Input ("x: T") 176 | .Input ("b: T") 177 | .Input ("ref: T") 178 | .Output ("y: T") 179 | .Attr ("T: {float, half}") 180 | .Attr ("grad: int = 0") 181 | .Attr ("axis: int = 1") 182 | .Attr ("act: int = 0") 183 | .Attr ("alpha: float = 0.0") 184 | .Attr ("gain: float = 1.0"); 185 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>); 186 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>); 187 | 188 | //------------------------------------------------------------------------ 189 | -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Helper for adding automatically tracked values to Tensorboard. 8 | 9 | Autosummary creates an identity op that internally keeps track of the input 10 | values and automatically shows up in TensorBoard. The reported value 11 | represents an average over input components. The average is accumulated 12 | constantly over time and flushed when save_summaries() is called. 13 | 14 | Notes: 15 | - The output tensor must be used as an input for something else in the 16 | graph. Otherwise, the autosummary op will not get executed, and the average 17 | value will not get accumulated. 18 | - It is perfectly fine to include autosummaries with the same name in 19 | several places throughout the graph, even if they are executed concurrently. 20 | - It is ok to also pass in a python scalar or numpy array. In this case, it 21 | is added to the average immediately. 22 | """ 23 | 24 | from collections import OrderedDict 25 | import numpy as np 26 | import tensorflow as tf 27 | from tensorboard import summary as summary_lib 28 | from tensorboard.plugins.custom_scalar import layout_pb2 29 | 30 | from . import tfutil 31 | from .tfutil import TfExpression 32 | from .tfutil import TfExpressionEx 33 | 34 | # Enable "Custom scalars" tab in TensorBoard for advanced formatting. 35 | # Disabled by default to reduce tfevents file size. 36 | enable_custom_scalars = False 37 | 38 | _dtype = tf.float64 39 | _vars = OrderedDict() # name => [var, ...]
40 | _immediate = OrderedDict() # name => update_op, update_value 41 | _finalized = False 42 | _merge_op = None 43 | 44 | 45 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 46 | """Internal helper for creating autosummary accumulators.""" 47 | assert not _finalized 48 | name_id = name.replace("/", "_") 49 | v = tf.cast(value_expr, _dtype) 50 | 51 | if v.shape.is_fully_defined(): 52 | size = np.prod(v.shape.as_list()) 53 | size_expr = tf.constant(size, dtype=_dtype) 54 | else: 55 | size = None 56 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 57 | 58 | if size == 1: 59 | if v.shape.ndims != 0: 60 | v = tf.reshape(v, []) 61 | v = [size_expr, v, tf.square(v)] 62 | else: 63 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 64 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 65 | 66 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 67 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 68 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 69 | 70 | if name in _vars: 71 | _vars[name].append(var) 72 | else: 73 | _vars[name] = [var] 74 | return update_op 75 | 76 | 77 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx: 78 | """Create a new autosummary. 79 | 80 | Args: 81 | name: Name to use in TensorBoard 82 | value: TensorFlow expression or python value to track 83 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 84 | 85 | Example use of the passthru mechanism: 86 | 87 | n = autosummary('l2loss', loss, passthru=n) 88 | 89 | This is a shorthand for the following code: 90 | 91 | with tf.control_dependencies([autosummary('l2loss', loss)]): 92 | n = tf.identity(n) 93 | """ 94 | tfutil.assert_tf_initialized() 95 | name_id = name.replace("/", "_") 96 | 97 | if tfutil.is_tf_expression(value): 98 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 99 | condition = tf.convert_to_tensor(condition, name='condition') 100 | update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op) 101 | with tf.control_dependencies([update_op]): 102 | return tf.identity(value if passthru is None else passthru) 103 | 104 | else: # python scalar or numpy array 105 | assert not tfutil.is_tf_expression(passthru) 106 | assert not tfutil.is_tf_expression(condition) 107 | if condition: 108 | if name not in _immediate: 109 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 110 | update_value = tf.placeholder(_dtype) 111 | update_op = _create_var(name, update_value) 112 | _immediate[name] = update_op, update_value 113 | update_op, update_value = _immediate[name] 114 | tfutil.run(update_op, {update_value: value}) 115 | return value if passthru is None else passthru 116 | 117 | 118 | def finalize_autosummaries() -> None: 119 | """Create the necessary ops to include autosummaries in TensorBoard report. 120 | Note: This should be done only once per graph. 121 | """ 122 | global _finalized 123 | tfutil.assert_tf_initialized() 124 | 125 | if _finalized: 126 | return None 127 | 128 | _finalized = True 129 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 130 | 131 | # Create summary ops. 
132 | with tf.device(None), tf.control_dependencies(None): 133 | for name, vars_list in _vars.items(): 134 | name_id = name.replace("/", "_") 135 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 136 | moments = tf.add_n(vars_list) 137 | moments /= moments[0] 138 | with tf.control_dependencies([moments]): # read before resetting 139 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 140 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 141 | mean = moments[1] 142 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 143 | tf.summary.scalar(name, mean) 144 | if enable_custom_scalars: 145 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 146 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 147 | 148 | # Setup layout for custom scalars. 149 | layout = None 150 | if enable_custom_scalars: 151 | cat_dict = OrderedDict() 152 | for series_name in sorted(_vars.keys()): 153 | p = series_name.split("/") 154 | cat = p[0] if len(p) >= 2 else "" 155 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 156 | if cat not in cat_dict: 157 | cat_dict[cat] = OrderedDict() 158 | if chart not in cat_dict[cat]: 159 | cat_dict[cat][chart] = [] 160 | cat_dict[cat][chart].append(series_name) 161 | categories = [] 162 | for cat_name, chart_dict in cat_dict.items(): 163 | charts = [] 164 | for chart_name, series_names in chart_dict.items(): 165 | series = [] 166 | for series_name in series_names: 167 | series.append(layout_pb2.MarginChartContent.Series( 168 | value=series_name, 169 | lower="xCustomScalars/" + series_name + "/margin_lo", 170 | upper="xCustomScalars/" + series_name + "/margin_hi")) 171 | margin = layout_pb2.MarginChartContent(series=series) 172 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 173 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 174 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 175 | return layout 176 | 177 | def save_summaries(file_writer, global_step=None): 178 | """Call FileWriter.add_summary() with all summaries in the default graph, 179 | automatically finalizing and merging them on the first call. 180 | """ 181 | global _merge_op 182 | tfutil.assert_tf_initialized() 183 | 184 | if _merge_op is None: 185 | layout = finalize_autosummaries() 186 | if layout is not None: 187 | file_writer.add_summary(layout) 188 | with tf.device(None), tf.control_dependencies(None): 189 | _merge_op = tf.summary.merge_all() 190 | 191 | file_writer.add_summary(_merge_op.eval(), global_step) 192 | -------------------------------------------------------------------------------- /run_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import argparse 8 | import numpy as np 9 | import PIL.Image 10 | import dnnlib 11 | import dnnlib.tflib as tflib 12 | import re 13 | import sys 14 | 15 | import pretrained_networks 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def generate_images(network_pkl, seeds, truncation_psi): 20 | print('Loading networks from "%s"...' 
% network_pkl) 21 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl) 22 | noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')] 23 | 24 | Gs_kwargs = dnnlib.EasyDict() 25 | Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 26 | Gs_kwargs.randomize_noise = False 27 | if truncation_psi is not None: 28 | Gs_kwargs.truncation_psi = truncation_psi 29 | 30 | for seed_idx, seed in enumerate(seeds): 31 | print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) 32 | rnd = np.random.RandomState(seed) 33 | z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component] 34 | tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width] 35 | images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel] 36 | PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('seed%04d.png' % seed)) 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_styles, minibatch_size=4): 41 | print('Loading networks from "%s"...' % network_pkl) 42 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl) 43 | w_avg = Gs.get_var('dlatent_avg') # [component] 44 | 45 | Gs_syn_kwargs = dnnlib.EasyDict() 46 | Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 47 | Gs_syn_kwargs.randomize_noise = False 48 | Gs_syn_kwargs.minibatch_size = minibatch_size 49 | 50 | print('Generating W vectors...') 51 | all_seeds = list(set(row_seeds + col_seeds)) 52 | all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component] 53 | all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component] 54 | all_w = w_avg + (all_w - w_avg) * truncation_psi # [minibatch, layer, component] 55 | w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} # [layer, component] 56 | 57 | print('Generating images...') 58 | all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs) # [minibatch, height, width, channel] 59 | image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))} 60 | 61 | print('Generating style-mixed images...') 62 | for row_seed in row_seeds: 63 | for col_seed in col_seeds: 64 | w = w_dict[row_seed].copy() 65 | w[col_styles] = w_dict[col_seed][col_styles] 66 | image = Gs.components.synthesis.run(w[np.newaxis], **Gs_syn_kwargs)[0] 67 | image_dict[(row_seed, col_seed)] = image 68 | 69 | print('Saving images...') 70 | for (row_seed, col_seed), image in image_dict.items(): 71 | PIL.Image.fromarray(image, 'RGB').save(dnnlib.make_run_dir_path('%d-%d.png' % (row_seed, col_seed))) 72 | 73 | print('Saving image grid...') 74 | _N, _C, H, W = Gs.output_shape 75 | canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black') 76 | for row_idx, row_seed in enumerate([None] + row_seeds): 77 | for col_idx, col_seed in enumerate([None] + col_seeds): 78 | if row_seed is None and col_seed is None: 79 | continue 80 | key = (row_seed, col_seed) 81 | if row_seed is None: 82 | key = (col_seed, col_seed) 83 | if col_seed is None: 84 | key = (row_seed, row_seed) 85 | canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx)) 86 | canvas.save(dnnlib.make_run_dir_path('grid.png')) 87 | 88 | 
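The core of the mixing loop above is a single row/column copy in W space; distilled as a standalone numpy sketch (the shapes are illustrative, e.g. 18 layers x 512 components for config-f at 1024x1024):

    import numpy as np

    num_layers, w_dim = 18, 512
    w_row = np.random.randn(num_layers, w_dim)  # stand-in for w_dict[row_seed]
    w_col = np.random.randn(num_layers, w_dim)  # stand-in for w_dict[col_seed]
    col_styles = list(range(0, 7))              # matches the default --col-styles=0-6

    w_mix = w_row.copy()
    w_mix[col_styles] = w_col[col_styles]       # take the selected style layers from the column seed
    # Gs.components.synthesis.run(w_mix[np.newaxis], **Gs_syn_kwargs) would then render the mix.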
#---------------------------------------------------------------------------- 89 | 90 | def _parse_num_range(s): 91 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' 92 | 93 | range_re = re.compile(r'^(\d+)-(\d+)$') 94 | m = range_re.match(s) 95 | if m: 96 | return range(int(m.group(1)), int(m.group(2))+1) 97 | vals = s.split(',') 98 | return [int(x) for x in vals] 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | _examples = '''examples: 103 | 104 | # Generate ffhq uncurated images (matches paper Figure 12) 105 | python %(prog)s generate-images --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --seeds=6600-6625 --truncation-psi=0.5 106 | 107 | # Generate ffhq curated images (matches paper Figure 11) 108 | python %(prog)s generate-images --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --seeds=66,230,389,1518 --truncation-psi=1.0 109 | 110 | # Generate uncurated car images (matches paper Figure 12) 111 | python %(prog)s generate-images --network=gdrive:networks/stylegan2-car-config-f.pkl --seeds=6000-6025 --truncation-psi=0.5 112 | 113 | # Generate style mixing example (matches style mixing video clip) 114 | python %(prog)s style-mixing-example --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --row-seeds=85,100,75,458,1500 --col-seeds=55,821,1789,293 --truncation-psi=1.0 115 | ''' 116 | 117 | #---------------------------------------------------------------------------- 118 | 119 | def main(): 120 | parser = argparse.ArgumentParser( 121 | description='''StyleGAN2 generator. 122 | 123 | Run 'python %(prog)s --help' for subcommand help.''', 124 | epilog=_examples, 125 | formatter_class=argparse.RawDescriptionHelpFormatter 126 | ) 127 | 128 | subparsers = parser.add_subparsers(help='Sub-commands', dest='command') 129 | 130 | parser_generate_images = subparsers.add_parser('generate-images', help='Generate images') 131 | parser_generate_images.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) 132 | parser_generate_images.add_argument('--seeds', type=_parse_num_range, help='List of random seeds', required=True) 133 | parser_generate_images.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5) 134 | parser_generate_images.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR') 135 | 136 | parser_style_mixing_example = subparsers.add_parser('style-mixing-example', help='Generate style mixing video') 137 | parser_style_mixing_example.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) 138 | parser_style_mixing_example.add_argument('--row-seeds', type=_parse_num_range, help='Random seeds to use for image rows', required=True) 139 | parser_style_mixing_example.add_argument('--col-seeds', type=_parse_num_range, help='Random seeds to use for image columns', required=True) 140 | parser_style_mixing_example.add_argument('--col-styles', type=_parse_num_range, help='Style layer range (default: %(default)s)', default='0-6') 141 | parser_style_mixing_example.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5) 142 | parser_style_mixing_example.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR') 143 | 144 | args = parser.parse_args() 145 | kwargs = vars(args) 146 | 
subcmd = kwargs.pop('command') 147 | 148 | if subcmd is None: 149 | print ('Error: missing subcommand. Re-run with --help for usage.') 150 | sys.exit(1) 151 | 152 | sc = dnnlib.SubmitConfig() 153 | sc.num_gpus = 1 154 | sc.submit_target = dnnlib.SubmitTarget.LOCAL 155 | sc.local.do_not_copy_source_files = True 156 | sc.run_dir_root = kwargs.pop('result_dir') 157 | sc.run_desc = subcmd 158 | 159 | func_name_map = { 160 | 'generate-images': 'run_generator.generate_images', 161 | 'style-mixing-example': 'run_generator.style_mixing_example' 162 | } 163 | dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs) 164 | 165 | #---------------------------------------------------------------------------- 166 | 167 | if __name__ == "__main__": 168 | main() 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import argparse 8 | import copy 9 | import os 10 | import sys 11 | 12 | import dnnlib 13 | from dnnlib import EasyDict 14 | 15 | from metrics.metric_defaults import metric_defaults 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | _valid_configs = [ 20 | # Table 1 21 | 'config-a', # Baseline StyleGAN 22 | 'config-b', # + Weight demodulation 23 | 'config-c', # + Lazy regularization 24 | 'config-d', # + Path length regularization 25 | 'config-e', # + No growing, new G & D arch. 26 | 'config-f', # + Large networks (default) 27 | 28 | # Table 2 29 | 'config-e-Gorig-Dorig', 'config-e-Gorig-Dresnet', 'config-e-Gorig-Dskip', 30 | 'config-e-Gresnet-Dorig', 'config-e-Gresnet-Dresnet', 'config-e-Gresnet-Dskip', 31 | 'config-e-Gskip-Dorig', 'config-e-Gskip-Dresnet', 'config-e-Gskip-Dskip', 32 | ] 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics): 37 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. 38 | G = EasyDict(func_name='training.networks_stylegan2.G_main') # Options for generator network. 39 | D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2') # Options for discriminator network. 40 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. 41 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. 42 | G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg') # Options for generator loss. 43 | D_loss = EasyDict(func_name='training.loss.D_logistic_r1') # Options for discriminator loss. 44 | sched = EasyDict() # Options for TrainingSchedule. 45 | grid = EasyDict(size='8k', layout='random') # Options for setup_snapshot_image_grid(). 46 | sc = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). 47 | tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). 
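    # Note: none of the EasyDicts above are consumed here directly; they are
    # forwarded to training_loop() as keyword arguments via the kwargs.update(...)
    # and dnnlib.submit_run(**kwargs) calls at the end of this function. As a
    # rough sketch of the resulting call (argument names as used below):
    #     training_loop(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt,
    #                   G_loss_args=G_loss, D_loss_args=D_loss, sched_args=sched,
    #                   dataset_args=dataset_args, grid_args=grid, ...)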
48 | 49 | train.data_dir = data_dir 50 | train.total_kimg = total_kimg 51 | train.mirror_augment = mirror_augment 52 | train.image_snapshot_ticks = train.network_snapshot_ticks = 10 53 | sched.G_lrate_base = sched.D_lrate_base = 0.002 54 | sched.minibatch_size_base = 32 55 | sched.minibatch_gpu_base = 4 56 | D_loss.gamma = 10 57 | metrics = [metric_defaults[x] for x in metrics] 58 | desc = 'stylegan2' 59 | 60 | desc += '-' + dataset 61 | dataset_args = EasyDict(tfrecord_dir=dataset) 62 | 63 | assert num_gpus in [1, 2, 4, 8] 64 | sc.num_gpus = num_gpus 65 | desc += '-%dgpu' % num_gpus 66 | 67 | assert config_id in _valid_configs 68 | desc += '-' + config_id 69 | 70 | # Configs A-E: Shrink networks to match original StyleGAN. 71 | if config_id != 'config-f': 72 | G.fmap_base = D.fmap_base = 8 << 10 73 | 74 | # Config E: Set gamma to 100 and override G & D architecture. 75 | if config_id.startswith('config-e'): 76 | D_loss.gamma = 100 77 | if 'Gorig' in config_id: G.architecture = 'orig' 78 | if 'Gskip' in config_id: G.architecture = 'skip' # (default) 79 | if 'Gresnet' in config_id: G.architecture = 'resnet' 80 | if 'Dorig' in config_id: D.architecture = 'orig' 81 | if 'Dskip' in config_id: D.architecture = 'skip' 82 | if 'Dresnet' in config_id: D.architecture = 'resnet' # (default) 83 | 84 | # Configs A-D: Enable progressive growing and switch to networks that support it. 85 | if config_id in ['config-a', 'config-b', 'config-c', 'config-d']: 86 | sched.lod_initial_resolution = 8 87 | sched.G_lrate_base = sched.D_lrate_base = 0.001 88 | sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} 89 | sched.minibatch_size_base = 32 # (default) 90 | sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32} 91 | sched.minibatch_gpu_base = 4 # (default) 92 | sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4} 93 | G.synthesis_func = 'G_synthesis_stylegan_revised' 94 | D.func_name = 'training.networks_stylegan2.D_stylegan' 95 | 96 | # Configs A-C: Disable path length regularization. 97 | if config_id in ['config-a', 'config-b', 'config-c']: 98 | G_loss = EasyDict(func_name='training.loss.G_logistic_ns') 99 | 100 | # Configs A-B: Disable lazy regularization. 101 | if config_id in ['config-a', 'config-b']: 102 | train.lazy_regularization = False 103 | 104 | # Config A: Switch to original StyleGAN networks. 
105 | if config_id == 'config-a': 106 | G = EasyDict(func_name='training.networks_stylegan.G_style') 107 | D = EasyDict(func_name='training.networks_stylegan.D_basic') 108 | 109 | if gamma is not None: 110 | D_loss.gamma = gamma 111 | 112 | sc.submit_target = dnnlib.SubmitTarget.LOCAL 113 | sc.local.do_not_copy_source_files = True 114 | kwargs = EasyDict(train) 115 | kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss) 116 | kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config) 117 | kwargs.submit_config = copy.deepcopy(sc) 118 | kwargs.submit_config.run_dir_root = result_dir 119 | kwargs.submit_config.run_desc = desc 120 | dnnlib.submit_run(**kwargs) 121 | 122 | #---------------------------------------------------------------------------- 123 | 124 | def _str_to_bool(v): 125 | if isinstance(v, bool): 126 | return v 127 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 128 | return True 129 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 130 | return False 131 | else: 132 | raise argparse.ArgumentTypeError('Boolean value expected.') 133 | 134 | def _parse_comma_sep(s): 135 | if s is None or s.lower() == 'none' or s == '': 136 | return [] 137 | return s.split(',') 138 | 139 | #---------------------------------------------------------------------------- 140 | 141 | _examples = '''examples: 142 | 143 | # Train StyleGAN2 using the FFHQ dataset 144 | python %(prog)s --num-gpus=8 --data-dir=~/datasets --config=config-f --dataset=ffhq --mirror-augment=true 145 | 146 | valid configs: 147 | 148 | ''' + ', '.join(_valid_configs) + ''' 149 | 150 | valid metrics: 151 | 152 | ''' + ', '.join(sorted([x for x in metric_defaults.keys()])) + ''' 153 | 154 | ''' 155 | 156 | def main(): 157 | parser = argparse.ArgumentParser( 158 | description='Train StyleGAN2.', 159 | epilog=_examples, 160 | formatter_class=argparse.RawDescriptionHelpFormatter 161 | ) 162 | parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR') 163 | parser.add_argument('--data-dir', help='Dataset root directory', required=True) 164 | parser.add_argument('--dataset', help='Training dataset', required=True) 165 | parser.add_argument('--config', help='Training config (default: %(default)s)', default='config-f', required=True, dest='config_id', metavar='CONFIG') 166 | parser.add_argument('--num-gpus', help='Number of GPUs (default: %(default)s)', default=1, type=int, metavar='N') 167 | parser.add_argument('--total-kimg', help='Training length in thousands of images (default: %(default)s)', metavar='KIMG', default=25000, type=int) 168 | parser.add_argument('--gamma', help='R1 regularization weight (default is config dependent)', default=None, type=float) 169 | parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool) 170 | parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='fid50k', type=_parse_comma_sep) 171 | 172 | args = parser.parse_args() 173 | 174 | if not os.path.exists(args.data_dir): 175 | print ('Error: dataset root directory does not exist.') 176 | sys.exit(1) 177 | 178 | if args.config_id not in _valid_configs: 179 | print ('Error: --config value must be one of: ', ', '.join(_valid_configs)) 180 | sys.exit(1) 181 | 182 | for metric in args.metrics: 183 | if metric not in metric_defaults: 184 | 
print ('Error: unknown metric \'%s\'' % metric) 185 | sys.exit(1) 186 | 187 | run(**vars(args)) 188 | 189 | #---------------------------------------------------------------------------- 190 | 191 | if __name__ == "__main__": 192 | main() 193 | 194 | #---------------------------------------------------------------------------- 195 | 196 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/fused_bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Custom TensorFlow ops for efficient bias and activation.""" 8 | 9 | import os 10 | import numpy as np 11 | import tensorflow as tf 12 | from .. import custom_ops 13 | from ...util import EasyDict 14 | 15 | def _get_plugin(): 16 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | activation_funcs = { 21 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True), 22 | 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True), 23 | 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True), 24 | 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False), 25 | 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False), 26 | 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False), 27 | 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False), 28 | 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False), 29 | 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False), 30 | } 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'): 35 | r"""Fused bias and activation function. 36 | 37 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 38 | and scales the result by `gain`. Each of the steps is optional. In most cases, 39 | the fused op is considerably more efficient than performing the same calculation 40 | using standard TensorFlow ops. It supports first and second order gradients, 41 | but not third order gradients. 42 | 43 | Args: 44 | x: Input activation tensor. Can have any shape, but if `b` is defined, the 45 | dimension corresponding to `axis`, as well as the rank, must be known. 46 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 47 | as `x`. The shape must be known, and it must match the dimension of `x` 48 | corresponding to `axis`. 49 | axis: The dimension in `x` corresponding to the elements of `b`. 50 | The value of `axis` is ignored if `b` is not specified. 
51 | act: Name of the activation function to evaluate, or `"linear"` to disable. 52 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 53 | See `activation_funcs` for a full list. `None` is not allowed. 54 | alpha: Shape parameter for the activation function, or `None` to use the default. 55 | gain: Scaling factor for the output tensor, or `None` to use default. 56 | See `activation_funcs` for the default scaling of each activation function. 57 | If unsure, consider specifying `1.0`. 58 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 59 | 60 | Returns: 61 | Tensor of the same shape and datatype as `x`. 62 | """ 63 | 64 | impl_dict = { 65 | 'ref': _fused_bias_act_ref, 66 | 'cuda': _fused_bias_act_cuda, 67 | } 68 | return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | def _fused_bias_act_ref(x, b, axis, act, alpha, gain): 73 | """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops.""" 74 | 75 | # Validate arguments. 76 | x = tf.convert_to_tensor(x) 77 | b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype) 78 | act_spec = activation_funcs[act] 79 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 80 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 81 | if alpha is None: 82 | alpha = act_spec.def_alpha 83 | if gain is None: 84 | gain = act_spec.def_gain 85 | 86 | # Add bias. 87 | if b.shape[0] != 0: 88 | x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)]) 89 | 90 | # Evaluate activation function. 91 | x = act_spec.func(x, alpha=alpha) 92 | 93 | # Scale by gain. 94 | if gain != 1: 95 | x *= gain 96 | return x 97 | 98 | #---------------------------------------------------------------------------- 99 | 100 | def _fused_bias_act_cuda(x, b, axis, act, alpha, gain): 101 | """Fast CUDA implementation of `fused_bias_act()` using custom ops.""" 102 | 103 | # Validate arguments. 104 | x = tf.convert_to_tensor(x) 105 | empty_tensor = tf.constant([], dtype=x.dtype) 106 | b = tf.convert_to_tensor(b) if b is not None else empty_tensor 107 | act_spec = activation_funcs[act] 108 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 109 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 110 | if alpha is None: 111 | alpha = act_spec.def_alpha 112 | if gain is None: 113 | gain = act_spec.def_gain 114 | 115 | # Special cases. 116 | if act == 'linear' and b is None and gain == 1.0: 117 | return x 118 | if act_spec.cuda_idx is None: 119 | return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) 120 | 121 | # CUDA kernel. 122 | cuda_kernel = _get_plugin().fused_bias_act 123 | cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain) 124 | 125 | # Forward pass: y = func(x, b). 
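    # Descriptive note on func_y and the grad helpers below: the same CUDA kernel
    # implements the forward pass and both gradient passes, selected by its `grad`
    # argument (0 = forward, 1 = first-order, 2 = second-order w.r.t. x). The
    # gradient calls pass either x or y as the `ref` tensor, per act_spec.ref, so
    # that activations such as relu/lrelu can derive their gradient from the
    # output alone.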
126 | def func_y(x, b): 127 | y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs) 128 | y.set_shape(x.shape) 129 | return y 130 | 131 | # Backward pass: dx, db = grad(dy, x, y) 132 | def grad_dx(dy, x, y): 133 | ref = {'x': x, 'y': y}[act_spec.ref] 134 | dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs) 135 | dx.set_shape(x.shape) 136 | return dx 137 | def grad_db(dx): 138 | if b.shape[0] == 0: 139 | return empty_tensor 140 | db = dx 141 | if axis < x.shape.rank - 1: 142 | db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank))) 143 | if axis > 0: 144 | db = tf.reduce_sum(db, list(range(axis))) 145 | db.set_shape(b.shape) 146 | return db 147 | 148 | # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y) 149 | def grad2_d_dy(d_dx, d_db, x, y): 150 | ref = {'x': x, 'y': y}[act_spec.ref] 151 | d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs) 152 | d_dy.set_shape(x.shape) 153 | return d_dy 154 | def grad2_d_x(d_dx, d_db, x, y): 155 | ref = {'x': x, 'y': y}[act_spec.ref] 156 | d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs) 157 | d_x.set_shape(x.shape) 158 | return d_x 159 | 160 | # Fast version for piecewise-linear activation funcs. 161 | @tf.custom_gradient 162 | def func_zero_2nd_grad(x, b): 163 | y = func_y(x, b) 164 | @tf.custom_gradient 165 | def grad(dy): 166 | dx = grad_dx(dy, x, y) 167 | db = grad_db(dx) 168 | def grad2(d_dx, d_db): 169 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 170 | return d_dy 171 | return (dx, db), grad2 172 | return y, grad 173 | 174 | # Slow version for general activation funcs. 175 | @tf.custom_gradient 176 | def func_nonzero_2nd_grad(x, b): 177 | y = func_y(x, b) 178 | def grad_wrap(dy): 179 | @tf.custom_gradient 180 | def grad_impl(dy, x): 181 | dx = grad_dx(dy, x, y) 182 | db = grad_db(dx) 183 | def grad2(d_dx, d_db): 184 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 185 | d_x = grad2_d_x(d_dx, d_db, x, y) 186 | return d_dy, d_x 187 | return (dx, db), grad2 188 | return grad_impl(dy, x) 189 | return y, grad_wrap 190 | 191 | # Which version to use? 192 | if act_spec.zero_2nd_grad: 193 | return func_zero_2nd_grad(x, b) 194 | return func_nonzero_2nd_grad(x, b) 195 | 196 | #---------------------------------------------------------------------------- 197 | -------------------------------------------------------------------------------- /projector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import dnnlib 10 | import dnnlib.tflib as tflib 11 | 12 | from training import misc 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class Projector: 17 | def __init__(self): 18 | self.num_steps = 1000 19 | self.dlatent_avg_samples = 10000 20 | self.initial_learning_rate = 0.1 21 | self.initial_noise_factor = 0.05 22 | self.lr_rampdown_length = 0.25 23 | self.lr_rampup_length = 0.05 24 | self.noise_ramp_length = 0.75 25 | self.regularize_noise_weight = 1e5 26 | self.verbose = False 27 | self.clone_net = True 28 | 29 | self._Gs = None 30 | self._minibatch_size = None 31 | self._dlatent_avg = None 32 | self._dlatent_std = None 33 | self._noise_vars = None 34 | self._noise_init_op = None 35 | self._noise_normalize_op = None 36 | self._dlatents_var = None 37 | self._noise_in = None 38 | self._dlatents_expr = None 39 | self._images_expr = None 40 | self._target_images_var = None 41 | self._lpips = None 42 | self._dist = None 43 | self._loss = None 44 | self._reg_sizes = None 45 | self._lrate_in = None 46 | self._opt = None 47 | self._opt_step = None 48 | self._cur_step = None 49 | 50 | def _info(self, *args): 51 | if self.verbose: 52 | print('Projector:', *args) 53 | 54 | def set_network(self, Gs, minibatch_size=1): 55 | assert minibatch_size == 1 56 | self._Gs = Gs 57 | self._minibatch_size = minibatch_size 58 | if self._Gs is None: 59 | return 60 | if self.clone_net: 61 | self._Gs = self._Gs.clone() 62 | 63 | # Find dlatent stats. 64 | self._info('Finding W midpoint and stddev using %d samples...' % self.dlatent_avg_samples) 65 | latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:]) 66 | dlatent_samples = self._Gs.components.mapping.run(latent_samples, None)[:, :1, :] # [N, 1, 512] 67 | self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 1, 512] 68 | self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg) ** 2) / self.dlatent_avg_samples) ** 0.5 69 | self._info('std = %g' % self._dlatent_std) 70 | 71 | # Find noise inputs. 72 | self._info('Setting up noise inputs...') 73 | self._noise_vars = [] 74 | noise_init_ops = [] 75 | noise_normalize_ops = [] 76 | while True: 77 | n = 'G_synthesis/noise%d' % len(self._noise_vars) 78 | if not n in self._Gs.vars: 79 | break 80 | v = self._Gs.vars[n] 81 | self._noise_vars.append(v) 82 | noise_init_ops.append(tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32))) 83 | noise_mean = tf.reduce_mean(v) 84 | noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5 85 | noise_normalize_ops.append(tf.assign(v, (v - noise_mean) / noise_std)) 86 | self._info(n, v) 87 | self._noise_init_op = tf.group(*noise_init_ops) 88 | self._noise_normalize_op = tf.group(*noise_normalize_ops) 89 | 90 | # Image output graph. 
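        # The optimization variable is a single w vector per image
        # ([minibatch, 1, 512]); it is tiled across every synthesis layer below,
        # so the search runs in W rather than the extended per-layer W+ space
        # (interpretive note). `noise_in` scales Gaussian noise added to w early
        # in the run; see noise_strength in step().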
91 | self._info('Building image output graph...') 92 | self._dlatents_var = tf.Variable(tf.zeros([self._minibatch_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var') 93 | self._noise_in = tf.placeholder(tf.float32, [], name='noise_in') 94 | dlatents_noise = tf.random.normal(shape=self._dlatents_var.shape) * self._noise_in 95 | self._dlatents_expr = tf.tile(self._dlatents_var + dlatents_noise, [1, self._Gs.components.synthesis.input_shape[1], 1]) 96 | self._images_expr = self._Gs.components.synthesis.get_output_for(self._dlatents_expr, randomize_noise=False) 97 | 98 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 99 | proc_images_expr = (self._images_expr + 1) * (255 / 2) 100 | sh = proc_images_expr.shape.as_list() 101 | if sh[2] > 256: 102 | factor = sh[2] // 256 103 | proc_images_expr = tf.reduce_mean(tf.reshape(proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5]) 104 | 105 | # Loss graph. 106 | self._info('Building loss graph...') 107 | self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var') 108 | if self._lpips is None: 109 | self._lpips = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl') 110 | self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var) 111 | self._loss = tf.reduce_sum(self._dist) 112 | 113 | # Noise regularization graph. 114 | self._info('Building noise regularization graph...') 115 | reg_loss = 0.0 116 | for v in self._noise_vars: 117 | sz = v.shape[2] 118 | while True: 119 | reg_loss += tf.reduce_mean(v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(v * tf.roll(v, shift=1, axis=2))**2 120 | if sz <= 8: 121 | break # Small enough already 122 | v = tf.reshape(v, [1, 1, sz//2, 2, sz//2, 2]) # Downscale 123 | v = tf.reduce_mean(v, axis=[3, 5]) 124 | sz = sz // 2 125 | self._loss += reg_loss * self.regularize_noise_weight 126 | 127 | # Optimizer. 128 | self._info('Setting up optimizer...') 129 | self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in') 130 | self._opt = dnnlib.tflib.Optimizer(learning_rate=self._lrate_in) 131 | self._opt.register_gradients(self._loss, [self._dlatents_var] + self._noise_vars) 132 | self._opt_step = self._opt.apply_updates() 133 | 134 | def run(self, target_images): 135 | # Run to completion. 136 | self.start(target_images) 137 | while self._cur_step < self.num_steps: 138 | self.step() 139 | 140 | # Collect results. 141 | pres = dnnlib.EasyDict() 142 | pres.dlatents = self.get_dlatents() 143 | pres.noises = self.get_noises() 144 | pres.images = self.get_images() 145 | return pres 146 | 147 | def start(self, target_images): 148 | assert self._Gs is not None 149 | 150 | # Prepare target images. 151 | self._info('Preparing target images...') 152 | target_images = np.asarray(target_images, dtype='float32') 153 | target_images = (target_images + 1) * (255 / 2) 154 | sh = target_images.shape 155 | assert sh[0] == self._minibatch_size 156 | if sh[2] > self._target_images_var.shape[2]: 157 | factor = sh[2] // self._target_images_var.shape[2] 158 | target_images = np.reshape(target_images, [-1, sh[1], sh[2] // factor, factor, sh[3] // factor, factor]).mean((3, 5)) 159 | 160 | # Initialize optimization state. 
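        # Each call to start() begins a fresh projection: dlatents are reset to
        # the average w computed in set_network(), the per-layer noise maps are
        # re-randomized, and the optimizer state is cleared.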
161 | self._info('Initializing optimization state...') 162 | tflib.set_vars({self._target_images_var: target_images, self._dlatents_var: np.tile(self._dlatent_avg, [self._minibatch_size, 1, 1])}) 163 | tflib.run(self._noise_init_op) 164 | self._opt.reset_optimizer_state() 165 | self._cur_step = 0 166 | 167 | def step(self): 168 | assert self._cur_step is not None 169 | if self._cur_step >= self.num_steps: 170 | return 171 | if self._cur_step == 0: 172 | self._info('Running...') 173 | 174 | # Hyperparameters. 175 | t = self._cur_step / self.num_steps 176 | noise_strength = self._dlatent_std * self.initial_noise_factor * max(0.0, 1.0 - t / self.noise_ramp_length) ** 2 177 | lr_ramp = min(1.0, (1.0 - t) / self.lr_rampdown_length) 178 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) 179 | lr_ramp = lr_ramp * min(1.0, t / self.lr_rampup_length) 180 | learning_rate = self.initial_learning_rate * lr_ramp 181 | 182 | # Train. 183 | feed_dict = {self._noise_in: noise_strength, self._lrate_in: learning_rate} 184 | _, dist_value, loss_value = tflib.run([self._opt_step, self._dist, self._loss], feed_dict) 185 | tflib.run(self._noise_normalize_op) 186 | 187 | # Print status. 188 | self._cur_step += 1 189 | if self._cur_step == self.num_steps or self._cur_step % 10 == 0: 190 | self._info('%-8d%-12g%-12g' % (self._cur_step, dist_value, loss_value)) 191 | if self._cur_step == self.num_steps: 192 | self._info('Done.') 193 | 194 | def get_cur_step(self): 195 | return self._cur_step 196 | 197 | def get_dlatents(self): 198 | return tflib.run(self._dlatents_expr, {self._noise_in: 0}) 199 | 200 | def get_noises(self): 201 | return tflib.run(self._noise_vars) 202 | 203 | def get_images(self): 204 | return tflib.run(self._images_expr, {self._noise_in: 0}) 205 | 206 | #---------------------------------------------------------------------------- 207 | -------------------------------------------------------------------------------- /projector_vc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | # >.>.>.>.>.>.>.>.>.>.>.>.>.>.>.>. 5 | # Licensed under the Apache License, Version 2.0 (the "License") 6 | # You may obtain a copy of the License at 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # --- File Name: projector_vc.py 10 | # --- Creation Date: 12-02-2020 11 | # --- Last Modified: Fri 20 Mar 2020 15:47:43 AEDT 12 | # --- Author: Xinqi Zhu 13 | # .<.<.<.<.<.<.<.<.<.<.<.<.<.<.<.< 14 | """ 15 | Projector of Variation Consistency Models 16 | """ 17 | 18 | from projector import Projector 19 | 20 | import numpy as np 21 | import tensorflow as tf 22 | import dnnlib 23 | import dnnlib.tflib as tflib 24 | 25 | from training import misc 26 | 27 | class ProjectorVC(Projector): 28 | def __init__(self): 29 | super().__init__() 30 | self.verbose = True 31 | self.num_steps = 200 32 | self.D_size = 0 33 | self.use_VGG = True 34 | 35 | def set_network(self, Gs, minibatch_size=1, D_size=0, use_VGG=True, num_steps=200): 36 | # assert minibatch_size == 1 37 | self.num_steps = num_steps 38 | self.D_size = D_size 39 | self.use_VGG = use_VGG 40 | self._Gs = Gs 41 | self._minibatch_size = minibatch_size 42 | if self._Gs is None: 43 | return 44 | if self.clone_net: 45 | self._Gs = self._Gs.clone() 46 | 47 | # Find dlatent stats. 48 | self._info('Finding W midpoint and stddev using %d samples...' 
% self.dlatent_avg_samples) 49 | # latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:]) 50 | latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, self._Gs.input_shapes[0][1]-self.D_size) # Only learn continuous latents 51 | # dlatent_samples = self._Gs.components.mapping.run(latent_samples, None) # [N, 512] 52 | dlatent_samples = latent_samples 53 | self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 512] 54 | self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg) ** 2) / self.dlatent_avg_samples) ** 0.5 55 | self._info('std = %g' % self._dlatent_std) 56 | 57 | # Find noise inputs. 58 | self._info('Setting up noise inputs...') 59 | self._noise_vars = [] 60 | noise_init_ops = [] 61 | noise_normalize_ops = [] 62 | while True: 63 | n = 'G_vc_synthesis/noise%d' % len(self._noise_vars) 64 | if not n in self._Gs.vars: 65 | break 66 | v = self._Gs.vars[n] 67 | self._noise_vars.append(v) 68 | noise_init_ops.append(tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32))) 69 | noise_mean = tf.reduce_mean(v) 70 | noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5 71 | noise_normalize_ops.append(tf.assign(v, (v - noise_mean) / noise_std)) 72 | self._info(n, v) 73 | self._noise_init_op = tf.group(*noise_init_ops) 74 | self._noise_normalize_op = tf.group(*noise_normalize_ops) 75 | 76 | # Image output graph. 77 | self._info('Building image output graph...') 78 | self._dlatents_var = tf.Variable(tf.zeros([self._minibatch_size * self.D_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var') 79 | self._noise_in = tf.placeholder(tf.float32, [], name='noise_in') 80 | dlatents_noise = tf.random.normal(shape=self._dlatents_var.shape) * self._noise_in 81 | self._dlatents_expr = self._dlatents_var + dlatents_noise 82 | 83 | # Add discrete latents 84 | if self.D_size > 0: 85 | discrete_latents = tf.range(self.D_size, dtype=tf.int32) 86 | discrete_latents = tf.one_hot(discrete_latents, self.D_size) # [D_size, D_size] 87 | discrete_latents = tf.tile(discrete_latents, [self._minibatch_size, 1]) 88 | self._dlatents_expr = tf.concat([discrete_latents, self._dlatents_expr], axis=1) 89 | 90 | self._images_expr = self._Gs.components.synthesis.get_output_for(self._dlatents_expr, randomize_noise=False) 91 | 92 | # Extend channels to 3 93 | if self._images_expr.shape.as_list()[1] == 1: 94 | self._images_expr = tf.tile(self._images_expr, [1, 3, 1, 1]) 95 | 96 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 97 | proc_images_expr = (self._images_expr + 1) * (255 / 2) 98 | sh = proc_images_expr.shape.as_list() 99 | if sh[2] > 256: 100 | factor = sh[2] // 256 101 | proc_images_expr = tf.reduce_mean(tf.reshape(proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5]) 102 | 103 | # Loss graph. 
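        # Unlike the base Projector, the distance here is either the LPIPS
        # (VGG16) perceptual distance or a plain pixelwise squared error,
        # selected by self.use_VGG; either way it is summed into a scalar loss.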
104 | self._info('Building loss graph...') 105 | self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var') 106 | print('self.proc_images_expr.shape:', proc_images_expr.shape.as_list()) 107 | print('self._target_images_var.shape:', self._target_images_var.shape.as_list()) 108 | if self.use_VGG: 109 | if self._lpips is None: 110 | self._lpips = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl') 111 | self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var) 112 | else: 113 | self._dist = (proc_images_expr - self._target_images_var) ** 2 114 | print('self._dist.shape:', self._dist.shape.as_list()) 115 | self._loss = tf.reduce_sum(self._dist) 116 | 117 | # Noise regularization graph. 118 | self._info('Building noise regularization graph...') 119 | reg_loss = 0.0 120 | for v in self._noise_vars: 121 | sz = v.shape[2] 122 | while True: 123 | reg_loss += tf.reduce_mean(v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(v * tf.roll(v, shift=1, axis=2))**2 124 | if sz <= 8: 125 | break # Small enough already 126 | v = tf.reshape(v, [1, 1, sz//2, 2, sz//2, 2]) # Downscale 127 | v = tf.reduce_mean(v, axis=[3, 5]) 128 | sz = sz // 2 129 | self._loss += reg_loss * self.regularize_noise_weight 130 | 131 | # Optimizer. 132 | self._info('Setting up optimizer...') 133 | self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in') 134 | self._opt = dnnlib.tflib.Optimizer(learning_rate=self._lrate_in) 135 | self._opt.register_gradients(self._loss, [self._dlatents_var] + self._noise_vars) 136 | self._opt_step = self._opt.apply_updates() 137 | 138 | 139 | def start(self, target_images): 140 | assert self._Gs is not None 141 | 142 | if target_images.shape[1] == 1: 143 | target_images = np.tile(target_images, [1, 3, 1, 1]) 144 | 145 | # Prepare target images. 146 | self._info('Preparing target images...') 147 | target_images = np.asarray(target_images, dtype='float32') 148 | target_images = (target_images + 1) * (255 / 2) 149 | sh = target_images.shape 150 | assert sh[0] == self._minibatch_size 151 | if sh[2] > self._target_images_var.shape[2]: 152 | factor = sh[2] // self._target_images_var.shape[2] 153 | target_images = np.reshape(target_images, [-1, sh[1], sh[2] // factor, factor, sh[3] // factor, factor]).mean((3, 5)) 154 | 155 | if self.D_size > 0: 156 | sh = target_images.shape 157 | target_images = np.reshape(target_images, [-1, 1, sh[1], sh[2], sh[3]]) 158 | target_images = np.tile(target_images, [1, self.D_size, 1, 1, 1]) 159 | target_images = np.reshape(target_images, [-1, sh[1], sh[2], sh[3]]) 160 | assert target_images.shape[0] == self._minibatch_size * self.D_size 161 | 162 | # Initialize optimization state. 163 | self._info('Initializing optimization state...') 164 | tflib.set_vars({self._target_images_var: target_images, self._dlatents_var: np.tile(self._dlatent_avg, [self._minibatch_size * self.D_size if self.D_size > 0 else self._minibatch_size, 1])}) 165 | tflib.run(self._noise_init_op) 166 | self._opt.reset_optimizer_state() 167 | self._cur_step = 0 168 | 169 | def step(self): 170 | assert self._cur_step is not None 171 | if self._cur_step >= self.num_steps: 172 | return 173 | if self._cur_step == 0: 174 | self._info('Running...') 175 | 176 | # Hyperparameters. 
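        # Schedule (t = fraction of steps completed): the added w-noise decays
        # quadratically and reaches zero at t = noise_ramp_length; the learning
        # rate ramps up linearly over the first lr_rampup_length of the run and
        # follows a cosine ramp-down over the final lr_rampdown_length.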
177 | t = self._cur_step / self.num_steps 178 | noise_strength = self._dlatent_std * self.initial_noise_factor * max(0.0, 1.0 - t / self.noise_ramp_length) ** 2 179 | lr_ramp = min(1.0, (1.0 - t) / self.lr_rampdown_length) 180 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) 181 | lr_ramp = lr_ramp * min(1.0, t / self.lr_rampup_length) 182 | learning_rate = self.initial_learning_rate * lr_ramp 183 | 184 | # Train. 185 | feed_dict = {self._noise_in: noise_strength, self._lrate_in: learning_rate} 186 | _, dist_value, loss_value = tflib.run([self._opt_step, self._dist, self._loss], feed_dict) 187 | tflib.run(self._noise_normalize_op) 188 | 189 | dist_value = np.reshape(dist_value, (-1, 10)) 190 | self.preds = np.argmin(dist_value, axis=1) 191 | 192 | # Print status. 193 | self._cur_step += 1 194 | if self._cur_step == self.num_steps or self._cur_step % 10 == 0: 195 | self._info('%-8d%-12g%-12g' % (self._cur_step, self.preds[0], loss_value)) 196 | if self._cur_step == self.num_steps: 197 | self._info('Done.') 198 | 199 | def get_dlatents(self): 200 | dlatents = tflib.run(self._dlatents_expr[self.preds], {self._noise_in: 0}) 201 | return dlatents 202 | 203 | def get_predictions(self): 204 | # self.preds.shape: [minibatch] 205 | return self.preds 206 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Multi-resolution input data pipeline.""" 8 | 9 | import os 10 | import glob 11 | import numpy as np 12 | import tensorflow as tf 13 | import dnnlib 14 | import dnnlib.tflib as tflib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Dataset class that loads data from tfrecords files. 18 | 19 | class TFRecordDataset: 20 | def __init__(self, 21 | tfrecord_dir, # Directory containing a collection of tfrecords files. 22 | resolution = None, # Dataset resolution, None = autodetect. 23 | label_file = None, # Relative path of the labels file, None = autodetect. 24 | max_label_size = 0, # 0 = no labels, 'full' = full labels, = N first label components. 25 | max_images = None, # Maximum number of images to use, None = use all images. 26 | repeat = True, # Repeat dataset indefinitely? 27 | shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. 28 | prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching. 29 | buffer_mb = 256, # Read buffer size (megabytes). 30 | num_threads = 2): # Number of concurrent threads. 31 | 32 | self.tfrecord_dir = tfrecord_dir 33 | self.resolution = None 34 | self.resolution_log2 = None 35 | self.shape = [] # [channels, height, width] 36 | self.dtype = 'uint8' 37 | self.dynamic_range = [0, 255] 38 | self.label_file = label_file 39 | self.label_size = None # components 40 | self.label_dtype = None 41 | self._np_labels = None 42 | self._tf_minibatch_in = None 43 | self._tf_labels_var = None 44 | self._tf_labels_dataset = None 45 | self._tf_datasets = dict() 46 | self._tf_iterator = None 47 | self._tf_init_ops = dict() 48 | self._tf_minibatch_np = None 49 | self._cur_minibatch = -1 50 | self._cur_lod = -1 51 | 52 | # List tfrecords files and inspect their shapes. 
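        # Each *.tfrecords file in the directory is expected to hold the same
        # images at one power-of-two resolution (one level of detail); only the
        # first record of each file is parsed here to discover its
        # [channels, height, width] shape.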
53 | assert os.path.isdir(self.tfrecord_dir) 54 | tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords'))) 55 | assert len(tfr_files) >= 1 56 | tfr_shapes = [] 57 | for tfr_file in tfr_files: 58 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) 59 | for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt): 60 | tfr_shapes.append(self.parse_tfrecord_np(record).shape) 61 | break 62 | 63 | # Autodetect label filename. 64 | if self.label_file is None: 65 | guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels'))) 66 | if len(guess): 67 | self.label_file = guess[0] 68 | elif not os.path.isfile(self.label_file): 69 | guess = os.path.join(self.tfrecord_dir, self.label_file) 70 | if os.path.isfile(guess): 71 | self.label_file = guess 72 | 73 | # Determine shape and resolution. 74 | max_shape = max(tfr_shapes, key=np.prod) 75 | self.resolution = resolution if resolution is not None else max_shape[1] 76 | self.resolution_log2 = int(np.log2(self.resolution)) 77 | self.shape = [max_shape[0], self.resolution, self.resolution] 78 | tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes] 79 | assert all(shape[0] == max_shape[0] for shape in tfr_shapes) 80 | assert all(shape[1] == shape[2] for shape in tfr_shapes) 81 | assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods)) 82 | assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1)) 83 | 84 | # Load labels. 85 | assert max_label_size == 'full' or max_label_size >= 0 86 | self._np_labels = np.zeros([1<<30, 0], dtype=np.float32) 87 | if self.label_file is not None and max_label_size != 0: 88 | self._np_labels = np.load(self.label_file) 89 | assert self._np_labels.ndim == 2 90 | if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size: 91 | self._np_labels = self._np_labels[:, :max_label_size] 92 | if max_images is not None and self._np_labels.shape[0] > max_images: 93 | self._np_labels = self._np_labels[:max_images] 94 | self.label_size = self._np_labels.shape[1] 95 | self.label_dtype = self._np_labels.dtype.name 96 | 97 | # Build TF expressions. 
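        # One tf.data pipeline is built per level of detail, each zipped with a
        # shared labels dataset; a single reinitializable iterator is switched
        # between them (and between minibatch sizes) by configure() below.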
98 | with tf.name_scope('Dataset'), tf.device('/cpu:0'): 99 | self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[]) 100 | self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var') 101 | self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var) 102 | for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods): 103 | if tfr_lod < 0: 104 | continue 105 | dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20) 106 | if max_images is not None: 107 | dset = dset.take(max_images) 108 | dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads) 109 | dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset)) 110 | bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize 111 | if shuffle_mb > 0: 112 | dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1) 113 | if repeat: 114 | dset = dset.repeat() 115 | if prefetch_mb > 0: 116 | dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1) 117 | dset = dset.batch(self._tf_minibatch_in) 118 | self._tf_datasets[tfr_lod] = dset 119 | self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes) 120 | self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()} 121 | 122 | def close(self): 123 | pass 124 | 125 | # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf(). 126 | def configure(self, minibatch_size, lod=0): 127 | lod = int(np.floor(lod)) 128 | assert minibatch_size >= 1 and lod in self._tf_datasets 129 | if self._cur_minibatch != minibatch_size or self._cur_lod != lod: 130 | self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size}) 131 | self._cur_minibatch = minibatch_size 132 | self._cur_lod = lod 133 | 134 | # Get next minibatch as TensorFlow expressions. 135 | def get_minibatch_tf(self): # => images, labels 136 | return self._tf_iterator.get_next() 137 | 138 | # Get next minibatch as NumPy arrays. 139 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels 140 | self.configure(minibatch_size, lod) 141 | with tf.name_scope('Dataset'): 142 | if self._tf_minibatch_np is None: 143 | self._tf_minibatch_np = self.get_minibatch_tf() 144 | return tflib.run(self._tf_minibatch_np) 145 | 146 | # Get random labels as TensorFlow expression. 147 | def get_random_labels_tf(self, minibatch_size): # => labels 148 | with tf.name_scope('Dataset'): 149 | if self.label_size > 0: 150 | with tf.device('/cpu:0'): 151 | return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32)) 152 | return tf.zeros([minibatch_size, 0], self.label_dtype) 153 | 154 | # Get random labels as NumPy array. 155 | def get_random_labels_np(self, minibatch_size): # => labels 156 | if self.label_size > 0: 157 | return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])] 158 | return np.zeros([minibatch_size, 0], self.label_dtype) 159 | 160 | # Parse individual image from a tfrecords file into TensorFlow expression. 
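    # Both parsers below assume records with two features: 'shape' (int64 [3])
    # and 'data' (the raw uint8 image bytes), which are reshaped into a CHW
    # image tensor.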
161 | @staticmethod 162 | def parse_tfrecord_tf(record): 163 | features = tf.parse_single_example(record, features={ 164 | 'shape': tf.FixedLenFeature([3], tf.int64), 165 | 'data': tf.FixedLenFeature([], tf.string)}) 166 | data = tf.decode_raw(features['data'], tf.uint8) 167 | return tf.reshape(data, features['shape']) 168 | 169 | # Parse individual image from a tfrecords file into NumPy array. 170 | @staticmethod 171 | def parse_tfrecord_np(record): 172 | ex = tf.train.Example() 173 | ex.ParseFromString(record) 174 | shape = ex.features.feature['shape'].int64_list.value # pylint: disable=no-member 175 | data = ex.features.feature['data'].bytes_list.value[0] # pylint: disable=no-member 176 | return np.fromstring(data, np.uint8).reshape(shape) 177 | 178 | #---------------------------------------------------------------------------- 179 | # Helper func for constructing a dataset object using the given options. 180 | 181 | def load_dataset(class_name=None, data_dir=None, verbose=False, **kwargs): 182 | kwargs = dict(kwargs) 183 | if 'tfrecord_dir' in kwargs: 184 | if class_name is None: 185 | class_name = __name__ + '.TFRecordDataset' 186 | if data_dir is not None: 187 | kwargs['tfrecord_dir'] = os.path.join(data_dir, kwargs['tfrecord_dir']) 188 | 189 | assert class_name is not None 190 | if verbose: 191 | print('Streaming data using %s...' % class_name) 192 | dataset = dnnlib.util.get_obj_by_name(class_name)(**kwargs) 193 | if verbose: 194 | print('Dataset shape =', np.int32(dataset.shape).tolist()) 195 | print('Dynamic range =', dataset.dynamic_range) 196 | print('Label size =', dataset.label_size) 197 | return dataset 198 | 199 | #---------------------------------------------------------------------------- 200 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Miscellaneous helper utils for Tensorflow.""" 8 | 9 | import os 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | # Silence deprecation warnings from TensorFlow 1.13 onwards 14 | import logging 15 | logging.getLogger('tensorflow').setLevel(logging.ERROR) 16 | import tensorflow.contrib # requires TensorFlow 1.x! 17 | tf.contrib = tensorflow.contrib 18 | 19 | from typing import Any, Iterable, List, Union 20 | 21 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 22 | """A type that represents a valid Tensorflow expression.""" 23 | 24 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 25 | """A type that can be converted to a valid Tensorflow expression.""" 26 | 27 | 28 | def run(*args, **kwargs) -> Any: 29 | """Run the specified ops in the default session.""" 30 | assert_tf_initialized() 31 | return tf.get_default_session().run(*args, **kwargs) 32 | 33 | 34 | def is_tf_expression(x: Any) -> bool: 35 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 36 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 37 | 38 | 39 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 40 | """Convert a Tensorflow shape to a list of ints. 
Retained for backwards compatibility -- use TensorShape.as_list() in new code.""" 41 | return [dim.value for dim in shape] 42 | 43 | 44 | def flatten(x: TfExpressionEx) -> TfExpression: 45 | """Shortcut function for flattening a tensor.""" 46 | with tf.name_scope("Flatten"): 47 | return tf.reshape(x, [-1]) 48 | 49 | 50 | def log2(x: TfExpressionEx) -> TfExpression: 51 | """Logarithm in base 2.""" 52 | with tf.name_scope("Log2"): 53 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 54 | 55 | 56 | def exp2(x: TfExpressionEx) -> TfExpression: 57 | """Exponent in base 2.""" 58 | with tf.name_scope("Exp2"): 59 | return tf.exp(x * np.float32(np.log(2.0))) 60 | 61 | 62 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 63 | """Linear interpolation.""" 64 | with tf.name_scope("Lerp"): 65 | return a + (b - a) * t 66 | 67 | 68 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 69 | """Linear interpolation with clip.""" 70 | with tf.name_scope("LerpClip"): 71 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 72 | 73 | 74 | def absolute_name_scope(scope: str) -> tf.name_scope: 75 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 76 | return tf.name_scope(scope + "/") 77 | 78 | 79 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 80 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 81 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 82 | 83 | 84 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 85 | # Defaults. 86 | cfg = dict() 87 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 88 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 89 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 90 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 91 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 92 | 93 | # Remove defaults for environment variables that are already set. 94 | for key in list(cfg): 95 | fields = key.split(".") 96 | if fields[0] == "env": 97 | assert len(fields) == 2 98 | if fields[1] in os.environ: 99 | del cfg[key] 100 | 101 | # User overrides. 102 | if config_dict is not None: 103 | cfg.update(config_dict) 104 | return cfg 105 | 106 | 107 | def init_tf(config_dict: dict = None) -> None: 108 | """Initialize TensorFlow session using good default settings.""" 109 | # Skip if already initialized. 110 | if tf.get_default_session() is not None: 111 | return 112 | 113 | # Setup config dict and random seeds. 114 | cfg = _sanitize_tf_config(config_dict) 115 | np_random_seed = cfg["rnd.np_random_seed"] 116 | if np_random_seed is not None: 117 | np.random.seed(np_random_seed) 118 | tf_random_seed = cfg["rnd.tf_random_seed"] 119 | if tf_random_seed == "auto": 120 | tf_random_seed = np.random.randint(1 << 31) 121 | if tf_random_seed is not None: 122 | tf.set_random_seed(tf_random_seed) 123 | 124 | # Setup environment variables. 
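    # Any cfg key of the form "env.<NAME>" is exported to os.environ before the
    # session is created; e.g. the default "env.TF_CPP_MIN_LOG_LEVEL" = "1"
    # hides TensorFlow's debug output unless that variable was already set by
    # the caller (see _sanitize_tf_config above).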
125 | for key, value in cfg.items(): 126 | fields = key.split(".") 127 | if fields[0] == "env": 128 | assert len(fields) == 2 129 | os.environ[fields[1]] = str(value) 130 | 131 | # Create default TensorFlow session. 132 | create_session(cfg, force_as_default=True) 133 | 134 | 135 | def assert_tf_initialized(): 136 | """Check that TensorFlow session has been initialized.""" 137 | if tf.get_default_session() is None: 138 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 139 | 140 | 141 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 142 | """Create tf.Session based on config dict.""" 143 | # Setup TensorFlow config proto. 144 | cfg = _sanitize_tf_config(config_dict) 145 | config_proto = tf.ConfigProto() 146 | for key, value in cfg.items(): 147 | fields = key.split(".") 148 | if fields[0] not in ["rnd", "env"]: 149 | obj = config_proto 150 | for field in fields[:-1]: 151 | obj = getattr(obj, field) 152 | setattr(obj, fields[-1], value) 153 | 154 | # Create session. 155 | session = tf.Session(config=config_proto) 156 | if force_as_default: 157 | # pylint: disable=protected-access 158 | session._default_session = session.as_default() 159 | session._default_session.enforce_nesting = False 160 | session._default_session.__enter__() 161 | return session 162 | 163 | 164 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 165 | """Initialize all tf.Variables that have not already been initialized. 166 | 167 | Equivalent to the following, but more efficient and does not bloat the tf graph: 168 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 169 | """ 170 | assert_tf_initialized() 171 | if target_vars is None: 172 | target_vars = tf.global_variables() 173 | 174 | test_vars = [] 175 | test_ops = [] 176 | 177 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 178 | for var in target_vars: 179 | assert is_tf_expression(var) 180 | 181 | try: 182 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 183 | except KeyError: 184 | # Op does not exist => variable may be uninitialized. 185 | test_vars.append(var) 186 | 187 | with absolute_name_scope(var.name.split(":")[0]): 188 | test_ops.append(tf.is_variable_initialized(var)) 189 | 190 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 191 | run([var.initializer for var in init_vars]) 192 | 193 | 194 | def set_vars(var_to_value_dict: dict) -> None: 195 | """Set the values of given tf.Variables. 
196 | 197 | Equivalent to the following, but more efficient and does not bloat the tf graph: 198 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 199 | """ 200 | assert_tf_initialized() 201 | ops = [] 202 | feed_dict = {} 203 | 204 | for var, value in var_to_value_dict.items(): 205 | assert is_tf_expression(var) 206 | 207 | try: 208 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 209 | except KeyError: 210 | with absolute_name_scope(var.name.split(":")[0]): 211 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 212 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 213 | 214 | ops.append(setter) 215 | feed_dict[setter.op.inputs[1]] = value 216 | 217 | run(ops, feed_dict) 218 | 219 | 220 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 221 | """Create tf.Variable with large initial value without bloating the tf graph.""" 222 | assert_tf_initialized() 223 | assert isinstance(initial_value, np.ndarray) 224 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 225 | var = tf.Variable(zeros, *args, **kwargs) 226 | set_vars({var: initial_value}) 227 | return var 228 | 229 | 230 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 231 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 232 | Can be used as an input transformation for Network.run(). 233 | """ 234 | images = tf.cast(images, tf.float32) 235 | if nhwc_to_nchw: 236 | images = tf.transpose(images, [0, 3, 1, 2]) 237 | return images * ((drange[1] - drange[0]) / 255) + drange[0] 238 | 239 | 240 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): 241 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 242 | Can be used as an output transformation for Network.run(). 243 | """ 244 | images = tf.cast(images, tf.float32) 245 | if shrink > 1: 246 | ksize = [1, 1, shrink, shrink] 247 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 248 | if nchw_to_nhwc: 249 | images = tf.transpose(images, [0, 2, 3, 1]) 250 | scale = 255 / (drange[1] - drange[0]) 251 | images = images * scale + (0.5 - drange[0] * scale) 252 | return tf.saturate_cast(images, tf.uint8) 253 | -------------------------------------------------------------------------------- /metrics/linear_separability.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Linear Separability (LS).""" 8 | 9 | from collections import defaultdict 10 | import numpy as np 11 | import sklearn.svm 12 | import tensorflow as tf 13 | import dnnlib.tflib as tflib 14 | 15 | from metrics import metric_base 16 | from training import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | classifier_urls = [ 21 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-00-male.pkl', 22 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-01-smiling.pkl', 23 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-02-attractive.pkl', 24 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-03-wavy-hair.pkl', 25 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-04-young.pkl', 26 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-05-5-o-clock-shadow.pkl', 27 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-06-arched-eyebrows.pkl', 28 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-07-bags-under-eyes.pkl', 29 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-08-bald.pkl', 30 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-09-bangs.pkl', 31 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-10-big-lips.pkl', 32 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-11-big-nose.pkl', 33 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-12-black-hair.pkl', 34 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-13-blond-hair.pkl', 35 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-14-blurry.pkl', 36 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-15-brown-hair.pkl', 37 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-16-bushy-eyebrows.pkl', 38 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-17-chubby.pkl', 39 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-18-double-chin.pkl', 40 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-19-eyeglasses.pkl', 41 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-20-goatee.pkl', 42 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-21-gray-hair.pkl', 43 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-22-heavy-makeup.pkl', 44 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-23-high-cheekbones.pkl', 45 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-24-mouth-slightly-open.pkl', 46 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-25-mustache.pkl', 47 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-26-narrow-eyes.pkl', 48 | 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-27-no-beard.pkl', 
49 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-28-oval-face.pkl',
50 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-29-pale-skin.pkl',
51 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-30-pointy-nose.pkl',
52 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-31-receding-hairline.pkl',
53 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-32-rosy-cheeks.pkl',
54 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-33-sideburns.pkl',
55 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-34-straight-hair.pkl',
56 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-35-wearing-earrings.pkl',
57 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-36-wearing-hat.pkl',
58 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-37-wearing-lipstick.pkl',
59 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-38-wearing-necklace.pkl',
60 |     'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/celebahq-classifier-39-wearing-necktie.pkl',
61 | ]
62 | 
63 | #----------------------------------------------------------------------------
64 | 
65 | def prob_normalize(p):
66 |     p = np.asarray(p).astype(np.float32)
67 |     assert len(p.shape) == 2
68 |     return p / np.sum(p)
69 | 
70 | def mutual_information(p):
71 |     p = prob_normalize(p)
72 |     px = np.sum(p, axis=1)
73 |     py = np.sum(p, axis=0)
74 |     result = 0.0
75 |     for x in range(p.shape[0]):
76 |         p_x = px[x]
77 |         for y in range(p.shape[1]):
78 |             p_xy = p[x][y]
79 |             p_y = py[y]
80 |             if p_xy > 0.0:
81 |                 result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output
82 |     return result
83 | 
84 | def entropy(p):
85 |     p = prob_normalize(p)
86 |     result = 0.0
87 |     for x in range(p.shape[0]):
88 |         for y in range(p.shape[1]):
89 |             p_xy = p[x][y]
90 |             if p_xy > 0.0:
91 |                 result -= p_xy * np.log2(p_xy)
92 |     return result
93 | 
94 | def conditional_entropy(p):
95 |     # H(Y|X) where X corresponds to axis 0, Y to axis 1
96 |     # i.e., how many bits of additional information are needed to specify where we are on axis 1 if we already know where we are on axis 0?
97 |     p = prob_normalize(p)
98 |     y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y)
99 |     return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies; clean those up.
100 | 
101 | #----------------------------------------------------------------------------
102 | 
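# --- Editor's note (worked example, not part of the original file) ---
# A quick numerical check of the helpers above, on a 2x2 joint distribution
# over (SVM prediction, classifier label):
#
#   p = [[0.4, 0.1],
#        [0.1, 0.4]]
#   entropy(np.sum(p, axis=0, keepdims=True))  # H(Y)   = 1.0 bit
#   mutual_information(p)                      # I(X;Y) ~ 0.278 bits
#   conditional_entropy(p)                     # H(Y|X) ~ 0.722 bits
#
# A perfectly separable attribute contributes H(Y|X) = 0 bits, so the
# exponentiated sum 2**sum(H(Y|X)) computed by the LS class below equals 1
# in the ideal case; higher values mean less linearly separable latents.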
103 | class LS(metric_base.MetricBase):
104 |     def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs):
105 |         assert num_keep <= num_samples
106 |         super().__init__(**kwargs)
107 |         self.num_samples = num_samples
108 |         self.num_keep = num_keep
109 |         self.attrib_indices = attrib_indices
110 |         self.minibatch_per_gpu = minibatch_per_gpu
111 | 
112 |     def _evaluate(self, Gs, Gs_kwargs, num_gpus):
113 |         minibatch_size = num_gpus * self.minibatch_per_gpu
114 | 
115 |         # Construct TensorFlow graph for each GPU.
116 |         result_expr = []
117 |         for gpu_idx in range(num_gpus):
118 |             with tf.device('/gpu:%d' % gpu_idx):
119 |                 Gs_clone = Gs.clone()
120 | 
121 |                 # Generate images.
122 |                 latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
123 |                 labels = self._get_random_labels_tf(self.minibatch_per_gpu)
124 |                 dlatents = Gs_clone.components.mapping.get_output_for(latents, labels, **Gs_kwargs)
125 |                 images = Gs_clone.get_output_for(latents, None, **Gs_kwargs)
126 | 
127 |                 # Downsample to 256x256. The attribute classifiers were built for 256x256.
128 |                 if images.shape[2] > 256:
129 |                     factor = images.shape[2] // 256
130 |                     images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
131 |                     images = tf.reduce_mean(images, axis=[3, 5])
132 | 
133 |                 # Run classifier for each attribute.
134 |                 result_dict = dict(latents=latents, dlatents=dlatents[:,-1])
135 |                 for attrib_idx in self.attrib_indices:
136 |                     classifier = misc.load_pkl(classifier_urls[attrib_idx])
137 |                     logits = classifier.get_output_for(images, None)
138 |                     predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))
139 |                     result_dict[attrib_idx] = predictions
140 |                 result_expr.append(result_dict)
141 | 
142 |         # Sampling loop.
143 |         results = []
144 |         for begin in range(0, self.num_samples, minibatch_size):
145 |             self._report_progress(begin, self.num_samples)
146 |             results += tflib.run(result_expr)
147 |         results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}
148 | 
149 |         # Calculate conditional entropy for each attribute.
150 |         conditional_entropies = defaultdict(list)
151 |         for attrib_idx in self.attrib_indices:
152 |             # Prune the least confident samples.
153 |             pruned_indices = list(range(self.num_samples))
154 |             pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
155 |             pruned_indices = pruned_indices[:self.num_keep]
156 | 
157 |             # Fit SVM to the remaining samples.
158 |             svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
159 |             for space in ['latents', 'dlatents']:
160 |                 svm_inputs = results[space][pruned_indices]
161 |                 try:
162 |                     svm = sklearn.svm.LinearSVC()
163 |                     svm.fit(svm_inputs, svm_targets)
164 |                     svm.score(svm_inputs, svm_targets)
165 |                     svm_outputs = svm.predict(svm_inputs)
166 |                 except Exception:
167 |                     svm_outputs = svm_targets # assume perfect prediction
168 | 
169 |                 # Calculate conditional entropy.
170 |                 p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)]
171 |                 conditional_entropies[space].append(conditional_entropy(p))
172 | 
173 |         # Calculate separability scores.
174 |         scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}
175 |         self._report_result(scores['latents'], suffix='_z')
176 |         self._report_result(scores['dlatents'], suffix='_w')
177 | 
178 | #----------------------------------------------------------------------------
179 | 
--------------------------------------------------------------------------------
/metrics/precision_recall.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Precision/Recall (PR).""" 8 | 9 | import os 10 | import numpy as np 11 | import tensorflow as tf 12 | import dnnlib 13 | import dnnlib.tflib as tflib 14 | 15 | from metrics import metric_base 16 | from training import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def batch_pairwise_distances(U, V): 21 | """ Compute pairwise distances between two batches of feature vectors.""" 22 | with tf.variable_scope('pairwise_dist_block'): 23 | # Squared norms of each row in U and V. 24 | norm_u = tf.reduce_sum(tf.square(U), 1) 25 | norm_v = tf.reduce_sum(tf.square(V), 1) 26 | 27 | # norm_u as a row and norm_v as a column vectors. 28 | norm_u = tf.reshape(norm_u, [-1, 1]) 29 | norm_v = tf.reshape(norm_v, [1, -1]) 30 | 31 | # Pairwise squared Euclidean distances. 32 | D = tf.maximum(norm_u - 2*tf.matmul(U, V, False, True) + norm_v, 0.0) 33 | 34 | return D 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class DistanceBlock(): 39 | """Distance block.""" 40 | def __init__(self, num_features, num_gpus): 41 | self.num_features = num_features 42 | self.num_gpus = num_gpus 43 | 44 | # Initialize TF graph to calculate pairwise distances. 45 | with tf.device('/cpu:0'): 46 | self._features_batch1 = tf.placeholder(tf.float16, shape=[None, self.num_features]) 47 | self._features_batch2 = tf.placeholder(tf.float16, shape=[None, self.num_features]) 48 | features_split2 = tf.split(self._features_batch2, self.num_gpus, axis=0) 49 | distances_split = [] 50 | for gpu_idx in range(self.num_gpus): 51 | with tf.device('/gpu:%d' % gpu_idx): 52 | distances_split.append(batch_pairwise_distances(self._features_batch1, features_split2[gpu_idx])) 53 | self._distance_block = tf.concat(distances_split, axis=1) 54 | 55 | def pairwise_distances(self, U, V): 56 | """Evaluate pairwise distances between two batches of feature vectors.""" 57 | return self._distance_block.eval(feed_dict={self._features_batch1: U, self._features_batch2: V}) 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class ManifoldEstimator(): 62 | """Finds an estimate for the manifold of given feature vectors.""" 63 | def __init__(self, distance_block, features, row_batch_size, col_batch_size, nhood_sizes, clamp_to_percentile=None): 64 | """Find an estimate of the manifold of given feature vectors.""" 65 | num_images = features.shape[0] 66 | self.nhood_sizes = nhood_sizes 67 | self.num_nhoods = len(nhood_sizes) 68 | self.row_batch_size = row_batch_size 69 | self.col_batch_size = col_batch_size 70 | self._ref_features = features 71 | self._distance_block = distance_block 72 | 73 | # Estimate manifold of features by calculating distances to kth nearest neighbor of each sample. 74 | self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float16) 75 | distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float16) 76 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) 77 | 78 | for begin1 in range(0, num_images, row_batch_size): 79 | end1 = min(begin1 + row_batch_size, num_images) 80 | row_batch = features[begin1:end1] 81 | 82 | for begin2 in range(0, num_images, col_batch_size): 83 | end2 = min(begin2 + col_batch_size, num_images) 84 | col_batch = features[begin2:end2] 85 | 86 | # Compute distances between batches. 
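# --- Editor's note (illustrative, not part of the original file) ---
# The nested loops here tile the full [num_images x num_images] distance
# matrix in [row_batch x col_batch] blocks, so it is never materialized at
# once. Once a row block is complete, np.partition below moves the smallest
# max(nhood_sizes)+1 squared distances into sorted positions; index 0 is the
# zero self-distance, so column k holds the distance to the k-th nearest
# *other* sample. A toy check:
#
#   row = np.array([0.0, 9.0, 1.0, 4.0])   # squared distances incl. self
#   np.partition(row, np.arange(4))[3]     # -> 9.0, the k=3 nearest neighbour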
87 | distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch, col_batch) 88 | 89 | # Find the kth nearest neighbor from the current batch. 90 | self.D[begin1:end1, :] = np.partition(distance_batch[0:end1-begin1, :], seq, axis=1)[:, self.nhood_sizes] 91 | 92 | if clamp_to_percentile is not None: 93 | max_distances = np.percentile(self.D, clamp_to_percentile, axis=0) 94 | self.D[self.D > max_distances] = 0 #max_distances # 0 95 | 96 | def evaluate(self, eval_features, return_realism=False, return_neighbors=False): 97 | """Evaluate if new feature vectors are in the estimated manifold.""" 98 | num_eval_images = eval_features.shape[0] 99 | num_ref_images = self.D.shape[0] 100 | distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float16) 101 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) 102 | #max_realism_score = np.zeros([num_eval_images,], dtype=np.float32) 103 | realism_score = np.zeros([num_eval_images,], dtype=np.float32) 104 | nearest_indices = np.zeros([num_eval_images,], dtype=np.int32) 105 | 106 | for begin1 in range(0, num_eval_images, self.row_batch_size): 107 | end1 = min(begin1 + self.row_batch_size, num_eval_images) 108 | feature_batch = eval_features[begin1:end1] 109 | 110 | for begin2 in range(0, num_ref_images, self.col_batch_size): 111 | end2 = min(begin2 + self.col_batch_size, num_ref_images) 112 | ref_batch = self._ref_features[begin2:end2] 113 | 114 | distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(feature_batch, ref_batch) 115 | 116 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold. 117 | # If a feature vector is inside a hypersphere of some reference sample, then the new sample lies on the estimated manifold. 118 | # The radii of the hyperspheres are determined from distances of neighborhood size k. 119 | samples_in_manifold = distance_batch[0:end1-begin1, :, None] <= self.D 120 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32) 121 | 122 | #max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1) 123 | #nearest_indices[begin1:end1] = np.argmax(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1) 124 | nearest_indices[begin1:end1] = np.argmin(distance_batch[0:end1-begin1, :], axis=1) 125 | realism_score[begin1:end1] = self.D[nearest_indices[begin1:end1], 0] / np.min(distance_batch[0:end1-begin1, :], axis=1) 126 | 127 | if return_realism and return_neighbors: 128 | return batch_predictions, realism_score, nearest_indices 129 | elif return_realism: 130 | return batch_predictions, realism_score 131 | elif return_neighbors: 132 | return batch_predictions, nearest_indices 133 | 134 | return batch_predictions 135 | 136 | #---------------------------------------------------------------------------- 137 | 138 | def knn_precision_recall_features(ref_features, eval_features, feature_net, nhood_sizes, 139 | row_batch_size, col_batch_size, num_gpus): 140 | """Calculates k-NN precision and recall for two sets of feature vectors.""" 141 | state = dnnlib.EasyDict() 142 | #num_images = ref_features.shape[0] 143 | num_features = feature_net.output_shape[1] 144 | state.ref_features = ref_features 145 | state.eval_features = eval_features 146 | 147 | # Initialize DistanceBlock and ManifoldEstimators. 
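# --- Editor's note (illustrative, not part of the original file) ---
# Two ManifoldEstimators are built from the same DistanceBlock below: one
# around the real (ref) features and one around the generated (eval)
# features. Each defines its "manifold" as the union of hyperspheres centred
# on its samples, with radii equal to the k-th nearest-neighbour distances
# computed in __init__. Distances are kept in float16 to halve the memory
# footprint of the [row_batch x num_images] blocks.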
148 | distance_block = DistanceBlock(num_features, num_gpus) 149 | state.ref_manifold = ManifoldEstimator(distance_block, state.ref_features, row_batch_size, col_batch_size, nhood_sizes) 150 | state.eval_manifold = ManifoldEstimator(distance_block, state.eval_features, row_batch_size, col_batch_size, nhood_sizes) 151 | 152 | # Evaluate precision and recall using k-nearest neighbors. 153 | #print('Evaluating k-NN precision and recall with %i samples...' % num_images) 154 | #start = time.time() 155 | 156 | # Precision: How many points from eval_features are in ref_features manifold. 157 | state.precision, state.realism_scores, state.nearest_neighbors = state.ref_manifold.evaluate(state.eval_features, return_realism=True, return_neighbors=True) 158 | state.knn_precision = state.precision.mean(axis=0) 159 | 160 | # Recall: How many points from ref_features are in eval_features manifold. 161 | state.recall = state.eval_manifold.evaluate(state.ref_features) 162 | state.knn_recall = state.recall.mean(axis=0) 163 | 164 | #elapsed_time = time.time() - start 165 | #print('Done evaluation in: %gs' % elapsed_time) 166 | 167 | return state 168 | 169 | #---------------------------------------------------------------------------- 170 | 171 | class PR(metric_base.MetricBase): 172 | def __init__(self, num_images, nhood_size, minibatch_per_gpu, row_batch_size, col_batch_size, **kwargs): 173 | super().__init__(**kwargs) 174 | self.num_images = num_images 175 | self.nhood_size = nhood_size 176 | self.minibatch_per_gpu = minibatch_per_gpu 177 | self.row_batch_size = row_batch_size 178 | self.col_batch_size = col_batch_size 179 | 180 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 181 | minibatch_size = num_gpus * self.minibatch_per_gpu 182 | feature_net = misc.load_pkl('http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16.pkl') 183 | 184 | # Calculate features for reals. 185 | cache_file = self._get_cache_file_for_reals(num_images=self.num_images) 186 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 187 | if os.path.isfile(cache_file): 188 | ref_features = misc.load_pkl(cache_file) 189 | else: 190 | ref_features = np.empty([self.num_images, feature_net.output_shape[1]], dtype=np.float32) 191 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)): 192 | begin = idx * minibatch_size 193 | end = min(begin + minibatch_size, self.num_images) 194 | ref_features[begin:end] = feature_net.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True) 195 | if end == self.num_images: 196 | break 197 | misc.save_pkl(ref_features, cache_file) 198 | 199 | # Construct TensorFlow graph. 200 | result_expr = [] 201 | for gpu_idx in range(num_gpus): 202 | with tf.device('/gpu:%d' % gpu_idx): 203 | Gs_clone = Gs.clone() 204 | feature_net_clone = feature_net.clone() 205 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 206 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 207 | images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs) 208 | images = tflib.convert_images_to_uint8(images) 209 | result_expr.append(feature_net_clone.get_output_for(images)) 210 | 211 | # Calculate features for fakes. 
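# --- Editor's note (illustrative, not part of the original file) ---
# In the loop below, tflib.run(result_expr) evaluates the per-GPU feature ops
# defined above and returns one array per GPU; concatenating along axis 0
# yields up to minibatch_size feature rows per step, and the [:end-begin]
# slice trims the final, possibly partial, minibatch so that eval_features
# ends up with shape [num_images, feature_net.output_shape[1]], matching
# ref_features.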
212 |         eval_features = np.empty([self.num_images, feature_net.output_shape[1]], dtype=np.float32)
213 |         for begin in range(0, self.num_images, minibatch_size):
214 |             self._report_progress(begin, self.num_images)
215 |             end = min(begin + minibatch_size, self.num_images)
216 |             eval_features[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
217 | 
218 |         # Calculate precision and recall.
219 |         state = knn_precision_recall_features(ref_features=ref_features, eval_features=eval_features, feature_net=feature_net,
220 |                                               nhood_sizes=[self.nhood_size], row_batch_size=self.row_batch_size, col_batch_size=self.col_batch_size, num_gpus=num_gpus)
221 |         self._report_result(state.knn_precision[0], suffix='_precision')
222 |         self._report_result(state.knn_recall[0], suffix='_recall')
223 | 
224 | #----------------------------------------------------------------------------
225 | 
--------------------------------------------------------------------------------
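[Editor's note] A minimal, self-contained NumPy sketch of the k-NN
precision/recall rule implemented above (an illustrative re-derivation, not
the TF pipeline; the array sizes and k are arbitrary placeholders):

import numpy as np

def pairwise_sq_dists(a, b):
    # Squared Euclidean distances between all rows of a and b.
    d = (a ** 2).sum(1)[:, None] - 2.0 * a @ b.T + (b ** 2).sum(1)[None, :]
    return np.maximum(d, 0.0)

def knn_radii(feats, k):
    # Distance to each sample's k-th nearest neighbour (self excluded).
    d = pairwise_sq_dists(feats, feats)
    return np.sort(d, axis=1)[:, k]  # column 0 is the zero self-distance

def in_manifold(query, ref, radii):
    # True where a query point falls inside any reference hypersphere.
    return (pairwise_sq_dists(query, ref) <= radii[None, :]).any(axis=1)

rng = np.random.RandomState(0)
real, fake = rng.randn(1000, 64), rng.randn(1000, 64)
precision = in_manifold(fake, real, knn_radii(real, k=3)).mean()
recall = in_manifold(real, fake, knn_radii(fake, k=3)).mean()
print('precision %.3f, recall %.3f' % (precision, recall))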
/run_unsupervised_acc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #-*- coding: utf-8 -*-
3 | 
4 | # >.>.>.>.>.>.>.>.>.>.>.>.>.>.>.>.
5 | # Licensed under the Apache License, Version 2.0 (the "License")
6 | # You may obtain a copy of the License at
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # --- File Name: run_unsupervised_acc.py
10 | # --- Creation Date: 12-02-2020
11 | # --- Last Modified: Thu 13 Feb 2020 04:04:47 AEDT
12 | # --- Author: Xinqi Zhu
13 | # .<.<.<.<.<.<.<.<.<.<.<.<.<.<.<.<
14 | """
15 | Calculate the unsupervised classification accuracy
16 | of Variation Consistency Models.
17 | """
18 | 
19 | import argparse
20 | import numpy as np
21 | import dnnlib
22 | import dnnlib.tflib as tflib
23 | import re
24 | import pdb
25 | import sys
26 | 
27 | import projector_vc
28 | import pretrained_networks
29 | from training import dataset
30 | from training import misc
31 | 
32 | #----------------------------------------------------------------------------
33 | 
34 | def project_image(proj, targets, png_prefix, num_snapshots):
35 |     snapshot_steps = set(proj.num_steps - np.linspace(0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))
36 |     misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1,1])
37 |     proj.start(targets)
38 |     while proj.get_cur_step() < proj.num_steps:
39 |         print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps), end='', flush=True)
40 |         proj.step()
41 |         if proj.get_cur_step() in snapshot_steps:
42 |             misc.save_image_grid(proj.get_images(), png_prefix + 'step%04d.png' % proj.get_cur_step(), drange=[-1,1])
43 |     # print(proj.get_predictions())
44 |     # print('\r%-30s\r' % '', end='', flush=True)
45 |     return proj.get_predictions()
46 | 
47 | #----------------------------------------------------------------------------
48 | 
49 | def project_generated_images(network_pkl, seeds, num_snapshots, truncation_psi,
50 |                              D_size=0, minibatch_size=1, use_VGG=True, num_steps=200):
51 |     tflib.init_tf()
52 |     print('Loading networks from "%s"...' % network_pkl)
53 |     _G, _D, I, Gs = misc.load_pkl(network_pkl)
54 |     # _G, _D, Gs = misc.load_pkl(network_pkl)
55 |     # _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
56 | 
57 |     proj = projector_vc.ProjectorVC()
58 |     proj.set_network(Gs, minibatch_size=minibatch_size, D_size=D_size, use_VGG=use_VGG, num_steps=num_steps)
59 |     noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
60 | 
61 |     Gs_kwargs = dnnlib.EasyDict()
62 |     Gs_kwargs.randomize_noise = False
63 |     Gs_kwargs.truncation_psi = truncation_psi
64 | 
65 |     for seed_idx, seed in enumerate(seeds):
66 |         print('Projecting seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
67 |         rnd = np.random.RandomState(seed)
68 |         z = rnd.randn(1, *Gs.input_shape[1:])
69 |         tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})
70 |         images = Gs.run(z, None, **Gs_kwargs)
71 |         project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('seed%04d-' % seed), num_snapshots=num_snapshots)
72 | 
73 | #----------------------------------------------------------------------------
74 | 
75 | def project_real_images(network_pkl, dataset_name, data_dir, num_images, num_snapshots,
76 |                         D_size=0, minibatch_size=1, use_VGG=True, num_steps=200):
77 |     tflib.init_tf()
78 |     print('Loading networks from "%s"...' % network_pkl)
79 |     _G, _D, I, Gs = misc.load_pkl(network_pkl)
80 |     # _G, _D, Gs = misc.load_pkl(network_pkl)
81 |     # _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
82 | 
83 |     proj = projector_vc.ProjectorVC()
84 |     proj.set_network(Gs, minibatch_size=minibatch_size, D_size=D_size, use_VGG=use_VGG, num_steps=num_steps)
85 | 
86 |     print('Loading images from "%s"...' % dataset_name)
87 |     dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size='full', repeat=False, shuffle_mb=0)
88 |     assert dataset_obj.shape == Gs.output_shape[1:]
89 | 
90 |     for image_idx in range(num_images):
91 |         print('Projecting image %d/%d ...' % (image_idx, num_images))
92 |         images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
93 |         print('images.shape:', images.shape)
94 |         print('_labels.shape:', _labels.shape)
95 |         print('_labels:', _labels)
96 |         print('argmax of _labels:', np.argmax(_labels, axis=1))
97 |         # pdb.set_trace()
98 |         images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
99 |         project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx), num_snapshots=num_snapshots)
100 | #----------------------------------------------------------------------------
101 | 
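# --- Editor's note (illustrative, not part of the original file) ---
# classify_images below assigns each discrete latent code to a class by
# majority vote over the training batches. A toy sketch of that mapping,
# assuming 3 codes and 3 classes (the numbers are made up):
#
#   import numpy as np
#   vote_matrix = np.array([[90,  5,  5],    # rows: predicted code
#                           [10, 80, 10],    # cols: ground-truth label
#                           [ 0, 15, 85]])
#   pred_to_label = np.argmax(vote_matrix, axis=1)   # -> array([0, 1, 2])
#   preds = np.array([0, 2, 1, 0])                   # codes for a minibatch
#   pred_to_label[preds]                             # -> array([0, 2, 1, 0])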
102 | def classify_images(network_pkl, train_dataset_name, data_dir, n_batches_of_train_imgs,
103 |                     test_dataset_name=None, D_size=0, minibatch_size=1, use_VGG=True, log_freq=10, num_steps=200):
104 |     tflib.init_tf()
105 |     print('Loading networks from "%s"...' % network_pkl)
106 |     _G, _D, I, Gs = misc.load_pkl(network_pkl)
107 |     # _G, _D, Gs = misc.load_pkl(network_pkl)
108 |     # _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
109 | 
110 |     proj = projector_vc.ProjectorVC()
111 |     proj.set_network(Gs, minibatch_size=minibatch_size, D_size=D_size, use_VGG=use_VGG, num_steps=num_steps)
112 | 
113 |     print('Loading images from "%s"...' % train_dataset_name)
114 |     dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=train_dataset_name, max_label_size='full', repeat=False, shuffle_mb=0)
115 |     assert dataset_obj.shape == Gs.output_shape[1:]
116 | 
117 |     vote_matrix = np.zeros((D_size, D_size), dtype=np.int32)
118 |     # Training
119 |     all_correct_train = 0
120 |     all_preds_train = 0
121 |     for image_idx in range(n_batches_of_train_imgs):
122 |         images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
123 |         images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
124 |         preds = project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx), num_snapshots=0)
125 |         labels = np.argmax(_labels, axis=1)
126 |         for i in range(len(preds)):
127 |             vote_matrix[preds[i], labels[i]] += 1
128 |         pred_to_label = np.argmax(vote_matrix, axis=1)
129 | 
130 |         # Calc training acc
131 |         preds_l = pred_to_label[preds]
132 |         all_preds_train += len(preds_l)
133 |         all_correct_train += np.sum(preds_l == labels)
134 |         if image_idx % log_freq == 0:
135 |             print('Training Acc: ', float(all_correct_train) / float(all_preds_train))
136 | 
137 |     print('Loading images from "%s"...' % test_dataset_name)
138 |     dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=test_dataset_name, max_label_size='full', repeat=False, shuffle_mb=0)
139 |     print('Testing set label size:', dataset_obj.label_size)
140 |     # pdb.set_trace()
141 |     assert dataset_obj.shape == Gs.output_shape[1:]
142 |     all_correct = 0
143 |     all_preds = 0
144 |     for image_idx in range(10000 // minibatch_size):
145 |         images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
146 |         images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
147 |         preds = project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx), num_snapshots=0)
148 |         preds_l = pred_to_label[preds]
149 |         labels = np.argmax(_labels, axis=1)
150 |         all_preds += len(preds_l)
151 |         all_correct += np.sum(preds_l == labels)
152 |         if image_idx % log_freq == 0:
153 |             print('Testing Acc: ', float(all_correct) / float(all_preds))
154 | 
155 | #----------------------------------------------------------------------------
156 | 
157 | def _parse_num_range(s):
158 |     '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
159 | 
160 |     range_re = re.compile(r'^(\d+)-(\d+)$')
161 |     m = range_re.match(s)
162 |     if m:
163 |         return list(range(int(m.group(1)), int(m.group(2))+1))
164 |     vals = s.split(',')
165 |     return [int(x) for x in vals]
166 | 
167 | #----------------------------------------------------------------------------
168 | 
169 | def _str_to_bool(v):
170 |     if isinstance(v, bool):
171 |         return v
172 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
173 |         return True
174 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
175 |         return False
176 |     else:
177 |         raise argparse.ArgumentTypeError('Boolean value expected.')
178 | 
179 | #----------------------------------------------------------------------------
180 | 
181 | _examples = '''examples:
182 | 
183 |   # Project generated images
184 |   python %(prog)s project-generated-images --network=gdrive:networks/stylegan2-car-config-f.pkl --seeds=0,1,5
185 | 
186 |   # Project real images
187 |   python %(prog)s project-real-images --network=gdrive:networks/stylegan2-car-config-f.pkl --dataset=car --data-dir=~/datasets
188 | 
189 | '''
190 | 
191 | #----------------------------------------------------------------------------
192 | 
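# --- Editor's note (illustrative, not part of the original file) ---
# Behaviour of the argument parsers above:
#
#   _parse_num_range('0-3')    # -> [0, 1, 2, 3]
#   _parse_num_range('1,5,9')  # -> [1, 5, 9]
#   _str_to_bool('yes')        # -> True; 'n' -> False; 'maybe' raises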
193 | def main():
194 |     parser = argparse.ArgumentParser(
195 |         description='''VC-Gan Classifier.
196 | 
197 | Run 'python %(prog)s --help' for subcommand help.''',
198 |         epilog=_examples,
199 |         formatter_class=argparse.RawDescriptionHelpFormatter
200 |     )
201 | 
202 |     subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
203 | 
204 |     project_generated_images_parser = subparsers.add_parser('project-generated-images', help='Project generated images')
205 |     project_generated_images_parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
206 |     project_generated_images_parser.add_argument('--seeds', type=_parse_num_range, help='List of random seeds', default=range(3))
207 |     project_generated_images_parser.add_argument('--num-snapshots', type=int, help='Number of snapshots (default: %(default)s)', default=5)
208 |     project_generated_images_parser.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', default=1.0)
209 |     project_generated_images_parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
210 |     project_generated_images_parser.add_argument('--D_size', type=int, help='Number of discrete latents', default=10)
211 |     project_generated_images_parser.add_argument('--minibatch_size', type=int, help='Minibatch size', default=1)
212 | 
213 |     project_real_images_parser = subparsers.add_parser('project-real-images', help='Project real images')
214 |     project_real_images_parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
215 |     project_real_images_parser.add_argument('--data-dir', help='Dataset root directory', required=True)
216 |     project_real_images_parser.add_argument('--dataset', help='Training dataset', dest='dataset_name', required=True)
217 |     project_real_images_parser.add_argument('--num-snapshots', type=int, help='Number of snapshots (default: %(default)s)', default=5)
218 |     project_real_images_parser.add_argument('--num-images', type=int, help='Number of images to project (default: %(default)s)', default=3)
219 |     project_real_images_parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
220 |     project_real_images_parser.add_argument('--D_size', type=int, help='Number of discrete latents', default=10)
221 |     project_real_images_parser.add_argument('--minibatch_size', type=int, help='Minibatch size', default=1)
222 |     project_real_images_parser.add_argument('--use_VGG', help='Whether to use VGG for distance evaluation', default=True, metavar='BOOL', type=_str_to_bool)
223 | 
224 |     classify_real_images_parser = subparsers.add_parser('classify-real-images', help='Classify real images')
225 |     classify_real_images_parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
226 |     classify_real_images_parser.add_argument('--data-dir', help='Dataset root directory', required=True)
227 |     classify_real_images_parser.add_argument('--train_dataset', help='Training dataset', dest='train_dataset_name', required=True)
228 |     classify_real_images_parser.add_argument('--test_dataset', help='Testing dataset', dest='test_dataset_name', required=True)
229 |     classify_real_images_parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
230 |     classify_real_images_parser.add_argument('--D_size', type=int, help='Number of discrete latents', default=10)
231 |     classify_real_images_parser.add_argument('--minibatch_size', type=int, help='Minibatch size', default=1)
232 |     classify_real_images_parser.add_argument('--use_VGG', help='Whether to use VGG for distance evaluation', default=True, metavar='BOOL', type=_str_to_bool)
233 |     classify_real_images_parser.add_argument('--n_batches_of_train_imgs', type=int, help='Number of batches for training', default=4000)
234 |     classify_real_images_parser.add_argument('--log_freq', type=int, help='Frequency for showing accuracy during training', default=200)
235 |     classify_real_images_parser.add_argument('--num_steps', type=int, help='Number of steps for inference', default=200)
236 | 
237 | 
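# --- Editor's note (hypothetical invocation; the paths and dataset names are
# placeholders, in the spirit of the _examples string above) ---
#
#   python run_unsupervised_acc.py classify-real-images \
#       --network=/path/to/network.pkl --data-dir=~/datasets \
#       --train_dataset=my_train_set --test_dataset=my_test_set \
#       --D_size=10 --minibatch_size=50 --result-dir=results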
238 |     args = parser.parse_args()
239 |     subcmd = args.command
240 |     if subcmd is None:
241 |         print('Error: missing subcommand. Re-run with --help for usage.')
242 |         sys.exit(1)
243 | 
244 |     kwargs = vars(args)
245 |     sc = dnnlib.SubmitConfig()
246 |     sc.num_gpus = 1
247 |     sc.submit_target = dnnlib.SubmitTarget.LOCAL
248 |     sc.local.do_not_copy_source_files = True
249 |     sc.run_dir_root = kwargs.pop('result_dir')
250 |     sc.run_desc = kwargs.pop('command')
251 | 
252 |     func_name_map = {
253 |         'project-generated-images': 'run_unsupervised_acc.project_generated_images',
254 |         'project-real-images': 'run_unsupervised_acc.project_real_images',
255 |         'classify-real-images': 'run_unsupervised_acc.classify_images'
256 |     }
257 |     dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs)
258 | 
259 | #----------------------------------------------------------------------------
260 | 
261 | if __name__ == "__main__":
262 |     main()
263 | 
264 | #----------------------------------------------------------------------------
265 | 
--------------------------------------------------------------------------------