├── stylegan-teaser.png
├── metrics
│   ├── __init__.py
│   ├── frechet_inception_distance.py
│   ├── perceptual_path_length.py
│   ├── metric_base.py
│   └── linear_separability.py
├── training
│   ├── __init__.py
│   ├── test.py
│   ├── loss.py
│   ├── dataset.py
│   ├── misc.py
│   ├── networks_progan.py
│   └── networks_progan_parallel.py
├── dnnlib
│   ├── submission
│   │   ├── __init__.py
│   │   ├── _internal
│   │   │   └── run.py
│   │   ├── run_context.py
│   │   └── submit.py
│   ├── tflib
│   │   ├── __init__.py
│   │   ├── autosummary.py
│   │   ├── network.py
│   │   ├── tfutil.py
│   │   └── optimizer.py
│   ├── __init__.py
│   └── util.py
├── config.py
├── pretrained_example.py
├── run_metrics.py
├── generate_figures.py
├── train.py
└── LICENSE.txt

/stylegan-teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuronets/stylegan3d/HEAD/stylegan-teaser.png
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

# empty
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

# empty
--------------------------------------------------------------------------------
/dnnlib/submission/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

from . import run_context
from . import submit
--------------------------------------------------------------------------------
/dnnlib/tflib/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

from . import autosummary
from . import network
from . import optimizer
from . import tfutil

from .tfutil import *
from .network import Network

from .optimizer import Optimizer
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Global configuration."""

#----------------------------------------------------------------------------
# Paths.

result_dir = 'results_proto'
data_dir = 'datasets'
cache_dir = 'cache'
run_dir_ignore = ['results_proto', 'datasets', 'cache', 'results_256', 'results']

#----------------------------------------------------------------------------
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

from . import submission

from .submission.run_context import RunContext

from .submission.submit import SubmitTarget
from .submission.submit import PathType
from .submission.submit import SubmitConfig
from .submission.submit import get_path_from_template
from .submission.submit import submit_run

from .util import EasyDict

# submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
--------------------------------------------------------------------------------
/training/test.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib.autosummary import autosummary

import config
import train
from training import dataset
from training import misc
from metrics import metric_base

def mixing(resume_run_id, resume_snapshot=None):
    network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, D, Gs = misc.load_pkl(network_pkl)

    latents_1 = np.random.randn(1, *G.input_shape[1:])
    labels_1 = np.array([[1, 0, 0, 0, 0, 0]], dtype=np.float32)  # one-hot labels, shape [minibatch, label_size]

    latents_2 = np.random.randn(1, *G.input_shape[1:])
    labels_2 = np.array([[0, 1, 0, 0, 0, 0]], dtype=np.float32)

    w_1 = Gs.components.mapping.get_output_for(latents_1, labels_1, is_validation=True)
    w_2 = Gs.components.mapping.get_output_for(latents_2, labels_2, is_validation=True)

    print(w_1)
    print(w_2)


# def main():
#     kwargs = EasyDict(train)
#     kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
#     kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
#     kwargs.submit_config = copy.deepcopy(submit_config)
#     kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir)
#     kwargs.submit_config.run_dir_ignore += config.run_dir_ignore
#     kwargs.submit_config.run_desc = desc
#     dnnlib.submit_run(**kwargs)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    mixing(resume_run_id=sys.argv[1])
--------------------------------------------------------------------------------
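test.py stops after printing the two W tensors. A natural continuation, sketched here rather than taken from the repo, is to mix the two dlatents per layer and synthesize an image, mirroring what draw_style_mixing_figure() in generate_figures.py does. The layer split (coarse layers 0-3 from the second sample) is an illustrative assumption:

    # Hypothetical continuation of mixing(): evaluate, mix, and synthesize.
    w_1_val, w_2_val = tflib.run([w_1, w_2])      # each of shape [1, num_layers, w_dim]
    w_mix = w_1_val.copy()
    w_mix[:, 0:4] = w_2_val[:, 0:4]               # take coarse styles from the second sample
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.components.synthesis.run(w_mix, randomize_noise=False, output_transform=fmt)
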
/dnnlib/submission/_internal/run.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helper for launching run functions in computing clusters.

During the submit process, this file is copied to the appropriate run dir.
When the job is launched in the cluster, this module is the first thing that
is run inside the docker container.
"""

import os
import pickle
import sys

# PYTHONPATH should have been set so that the run_dir/src is in it
import dnnlib

def main():
    if len(sys.argv) < 4:
        raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")

    run_dir = str(sys.argv[1])
    task_name = str(sys.argv[2])
    host_name = str(sys.argv[3])

    submit_config_path = os.path.join(run_dir, "submit_config.pkl")

    # SubmitConfig should have been pickled to the run dir
    if not os.path.exists(submit_config_path):
        raise RuntimeError("SubmitConfig pickle file does not exist!")

    with open(submit_config_path, "rb") as f:
        submit_config: dnnlib.SubmitConfig = pickle.load(f)
    dnnlib.submission.submit.set_user_name_override(submit_config.user_name)

    submit_config.task_name = task_name
    submit_config.host_name = host_name

    dnnlib.submission.submit.run_wrapper(submit_config)

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/pretrained_example.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Minimal script for generating an image using pre-trained StyleGAN generator."""

import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config

def main():
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

    # Print network details.
    Gs.print_layers()

    # Pick latent vector.
    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, Gs.input_shape[1])

    # Generate image.
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

    # Save image.
    os.makedirs(config.result_dir, exist_ok=True)
    png_filename = os.path.join(config.result_dir, 'example.png')
    PIL.Image.fromarray(images[0], 'RGB').save(png_filename)

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
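Gs.run() hides the mapping/synthesis split. When the intermediate dlatents are needed (for truncation or mixing experiments), the two sub-networks can be run separately, as generate_figures.py does later in this repo. A minimal sketch, reusing latents and fmt from main() above:

    # Two-stage generation (sketch): latents -> dlatents -> image.
    dlatents = Gs.components.mapping.run(latents, None)   # [minibatch, num_layers, w_dim]
    images = Gs.components.synthesis.run(dlatents, randomize_noise=True, output_transform=fmt)
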
/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Frechet Inception Distance (FID)."""

import os
import numpy as np
import scipy
import tensorflow as tf
import dnnlib.tflib as tflib

from metrics import metric_base
from training import misc

#----------------------------------------------------------------------------

class FID(metric_base.MetricBase):
    def __init__(self, num_images, minibatch_per_gpu, **kwargs):
        super().__init__(**kwargs)
        self.num_images = num_images
        self.minibatch_per_gpu = minibatch_per_gpu

    def _evaluate(self, Gs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
        activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)

        # Calculate statistics for reals.
        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        if os.path.isfile(cache_file):
            mu_real, sigma_real = misc.load_pkl(cache_file)
        else:
            for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):
                begin = idx * minibatch_size
                end = min(begin + minibatch_size, self.num_images)
                activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)
                if end == self.num_images:
                    break
            mu_real = np.mean(activations, axis=0)
            sigma_real = np.cov(activations, rowvar=False)
            misc.save_pkl((mu_real, sigma_real), cache_file)

        # Construct TensorFlow graph.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                inception_clone = inception.clone()
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
                images = tflib.convert_images_to_uint8(images)
                result_expr.append(inception_clone.get_output_for(images))

        # Calculate statistics for fakes.
        for begin in range(0, self.num_images, minibatch_size):
            end = min(begin + minibatch_size, self.num_images)
            activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
        mu_fake = np.mean(activations, axis=0)
        sigma_fake = np.cov(activations, rowvar=False)

        # Calculate FID.
        m = np.square(mu_fake - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
        dist = m + np.trace(sigma_fake + sigma_real - 2*s)
        self._report_result(np.real(dist))

#----------------------------------------------------------------------------
--------------------------------------------------------------------------------
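The final block computes the Fréchet distance between two Gaussians fitted to the Inception activations of real and generated images. In the code's notation (mu_real, sigma_real, mu_fake, sigma_fake), the reported value is:

    \mathrm{FID} = \lVert \mu_{\mathrm{real}} - \mu_{\mathrm{fake}} \rVert_2^2
                 + \operatorname{Tr}\!\left( \Sigma_{\mathrm{real}} + \Sigma_{\mathrm{fake}}
                 - 2 \left( \Sigma_{\mathrm{fake}} \Sigma_{\mathrm{real}} \right)^{1/2} \right)

np.real() is applied at the end because scipy.linalg.sqrtm can return a complex result with a negligible imaginary component when the matrix product is nearly singular.
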
/dnnlib/submission/run_context.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helpers for managing the run/training loop."""

import datetime
import json
import os
import pprint
import time
import types

from typing import Any

from . import submit


class RunContext(object):
    """Helper class for managing the run/training loop.

    The context hides the implementation details of a basic run/training loop.
    It sets things up properly, tells whether the run should be stopped, and
    cleans up afterwards. The user should call update() periodically and use
    should_stop() to determine whether the run should be stopped.

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: The whole config module that is used for the current run.
        max_epoch: Optional cached value for the max_epoch variable used in update.
    """

    def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
        self.submit_config = submit_config
        self.should_stop_flag = False
        self.has_closed = False
        self.start_time = time.time()
        self.last_update_time = time.time()
        self.last_update_interval = 0.0
        self.max_epoch = max_epoch

        # pretty print all the relevant content of the config module to a text file
        if config_module is not None:
            with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
                filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
                pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)

        # write out details about the run to a text file
        self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
        with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def __enter__(self) -> "RunContext":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.
        Should be called often enough but not in a tight loop."""
        assert not self.has_closed

        self.last_update_interval = time.time() - self.last_update_time
        self.last_update_time = time.time()

        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

        max_epoch_val = self.max_epoch if max_epoch is None else max_epoch  # currently unused beyond this point

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.
        Should only be called once."""
        if not self.has_closed:
            # update the run.txt with stopping time
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
                pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

            self.has_closed = True
--------------------------------------------------------------------------------
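RunContext is designed to wrap the training loop as a context manager. A minimal sketch of the intended call pattern, assuming a populated SubmitConfig; the epoch loop and train_one_epoch() are placeholders, not repo code:

    # Hypothetical training loop built around RunContext.
    with dnnlib.RunContext(submit_config) as ctx:
        for epoch in range(num_epochs):
            loss = train_one_epoch()                 # placeholder for the real work
            ctx.update(loss=loss, cur_epoch=epoch, max_epoch=num_epochs)
            if ctx.should_stop():                    # set once abort.txt appears in the run dir
                break
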
/run_metrics.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Main entry point for evaluating quality metrics on trained StyleGAN and ProGAN networks."""

import dnnlib
from dnnlib import EasyDict
import dnnlib.tflib as tflib

import config
from metrics import metric_base
from training import misc

#----------------------------------------------------------------------------

def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment):
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl))
    metric = dnnlib.util.call_func_by_name(**metric_args)
    print()
    metric.run(network_pkl, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus)
    print()
    ctx.close()

#----------------------------------------------------------------------------

def run_snapshot(submit_config, metric_args, run_id, snapshot):
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot))
    run_dir = misc.locate_run_dir(run_id)
    network_pkl = misc.locate_network_pkl(run_dir, snapshot)
    metric = dnnlib.util.call_func_by_name(**metric_args)
    print()
    metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
    print()
    ctx.close()

#----------------------------------------------------------------------------

def run_all_snapshots(submit_config, metric_args, run_id):
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id))
    run_dir = misc.locate_run_dir(run_id)
    network_pkls = misc.list_network_pkls(run_dir)
    metric = dnnlib.util.call_func_by_name(**metric_args)
    print()
    for idx, network_pkl in enumerate(network_pkls):
        ctx.update('', idx, len(network_pkls))
        metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
    print()
    ctx.close()

#----------------------------------------------------------------------------

def main():
    submit_config = dnnlib.SubmitConfig()

    # Which metrics to evaluate?
    metrics = []
    metrics += [metric_base.fid50k]
    #metrics += [metric_base.ppl_zfull]
    #metrics += [metric_base.ppl_wfull]
    #metrics += [metric_base.ppl_zend]
    #metrics += [metric_base.ppl_wend]
    #metrics += [metric_base.ls]
    #metrics += [metric_base.dummy]

    # Which networks to evaluate them on?
    tasks = []
    tasks += [EasyDict(run_func_name='run_metrics.run_pickle', network_pkl='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', dataset_args=EasyDict(tfrecord_dir='ffhq', shuffle_mb=0), mirror_augment=True)] # karras2019stylegan-ffhq-1024x1024.pkl
    #tasks += [EasyDict(run_func_name='run_metrics.run_snapshot', run_id=100, snapshot=25000)]
    #tasks += [EasyDict(run_func_name='run_metrics.run_all_snapshots', run_id=100)]

    # How many GPUs to use?
    submit_config.num_gpus = 1
    #submit_config.num_gpus = 2
    #submit_config.num_gpus = 4
    #submit_config.num_gpus = 8

    # Execute.
    submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir)
    submit_config.run_dir_ignore += config.run_dir_ignore
    for task in tasks:
        for metric in metrics:
            submit_config.run_desc = '%s-%s' % (task.run_func_name, metric.name)
            if task.run_func_name.endswith('run_snapshot'):
                submit_config.run_desc += '-%s-%s' % (task.run_id, task.snapshot)
            if task.run_func_name.endswith('run_all_snapshots'):
                submit_config.run_desc += '-%s' % task.run_id
            submit_config.run_desc += '-%dgpu' % submit_config.num_gpus
            dnnlib.submit_run(submit_config, metric_args=metric, **task)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
--------------------------------------------------------------------------------
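The submit machinery above is only a thin wrapper; a metric can also be evaluated directly in-process. A minimal sketch, assuming a local network pickle and a prepared TFRecord dataset directory (both paths are placeholders):

    # Evaluate FID without going through dnnlib.submit_run (sketch).
    import dnnlib
    import dnnlib.tflib as tflib
    from metrics import metric_base

    tflib.init_tf()
    metric = dnnlib.util.call_func_by_name(**metric_base.fid50k)  # instantiates metrics.frechet_inception_distance.FID
    metric.run('path/to/network.pkl',
               dataset_args=dnnlib.EasyDict(tfrecord_dir='ffhq', shuffle_mb=0),
               mirror_augment=False, num_gpus=1)
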
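In equation form, each element accumulated into all_distances (see the file that follows) estimates the scaled perceptual difference between two images synthesized from nearby points on an interpolation path:

    \mathrm{PPL} = \mathbb{E}\!\left[ \frac{1}{\epsilon^{2}} \,
        d\!\left( G(\mathrm{interp}(z_1, z_2; t)),\; G(\mathrm{interp}(z_1, z_2; t+\epsilon)) \right) \right]

where d is the VGG-based perceptual distance, t ~ U(0, 1) for sampling='full' and t = 0 for sampling='end', and interp is slerp in Z or lerp in W, exactly as implemented above.
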
/metrics/perceptual_path_length.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Perceptual Path Length (PPL)."""

import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib

from metrics import metric_base
from training import misc

#----------------------------------------------------------------------------

# Normalize batch of vectors.
def normalize(v):
    return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))

# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
    a = normalize(a)
    b = normalize(b)
    d = tf.reduce_sum(a * b, axis=-1, keepdims=True)
    p = t * tf.math.acos(d)
    c = normalize(b - d * a)
    d = a * tf.math.cos(p) + c * tf.math.sin(p)
    return normalize(d)

#----------------------------------------------------------------------------

class PPL(metric_base.MetricBase):
    def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs):
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__(**kwargs)
        self.num_samples = num_samples
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.minibatch_per_gpu = minibatch_per_gpu

    def _evaluate(self, Gs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu

        # Construct TensorFlow graph.
        distance_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')]

                # Generate random latents and interpolation t-values.
                lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:])
                lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0)

                # Interpolate in W or Z.
                if self.space == 'w':
                    dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True)
                    dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2]
                    dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis])
                    dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)
                    dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape)
                else: # space == 'z'
                    lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2]
                    lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis])
                    lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon)
                    lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape)
                    dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True)

                # Synthesize images.
                with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch
                    images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False)

                # Crop only the face region.
                c = int(images.shape[2] // 8)
                images = images[:, :, c*3 : c*7, c*2 : c*6]

                # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
                if images.shape[2] > 256:
                    factor = images.shape[2] // 256
                    images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
                    images = tf.reduce_mean(images, axis=[3,5])

                # Scale dynamic range from [-1,1] to [0,255] for VGG.
                images = (images + 1) * (255 / 2)

                # Evaluate perceptual distance.
                img_e0, img_e1 = images[0::2], images[1::2]
                distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl
                distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2))

        # Sampling loop.
        all_distances = []
        for _ in range(0, self.num_samples, minibatch_size):
            all_distances += tflib.run(distance_expr)
        all_distances = np.concatenate(all_distances, axis=0)

        # Reject outliers.
        lo = np.percentile(all_distances, 1, interpolation='lower')
        hi = np.percentile(all_distances, 99, interpolation='higher')
        filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances)
        self._report_result(np.mean(filtered_distances))

#----------------------------------------------------------------------------
--------------------------------------------------------------------------------
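New metrics plug in by subclassing MetricBase (defined in the file that follows), overriding _evaluate(), and registering an EasyDict like fid50k below. A toy sketch; the metric itself is illustrative and not part of the repo, and the func_name only resolves if the class is actually placed in metric_base.py:

    import numpy as np
    import dnnlib
    from metrics import metric_base

    class MeanPixel(metric_base.MetricBase):
        """Toy metric: mean uint8 pixel value over generated images."""
        def __init__(self, num_images, minibatch_size, **kwargs):
            super().__init__(**kwargs)
            self.num_images = num_images
            self.minibatch_size = minibatch_size

        def _evaluate(self, Gs, num_gpus):
            values = []
            for images in self._iterate_fakes(Gs, self.minibatch_size, num_gpus):
                values.append(np.mean(images))
                if len(values) * self.minibatch_size >= self.num_images:
                    break
            self._report_result(float(np.mean(values)))

    # Registration mirrors fid50k etc.; func_name must match where the class lives.
    mean_pixel = dnnlib.EasyDict(func_name='metrics.metric_base.MeanPixel', name='mean_pixel', num_images=1024, minibatch_size=8)
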
/metrics/metric_base.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Common definitions for GAN metrics."""

import os
import time
import hashlib
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib

import config
from training import misc
from training import dataset

#----------------------------------------------------------------------------
# Standard metrics.

fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8)
ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16)
ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16)
ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16)
ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16)
ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4)
dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging

#----------------------------------------------------------------------------
# Base class for metrics.

class MetricBase:
    def __init__(self, name):
        self.name = name
        self._network_pkl = None
        self._dataset_args = None
        self._mirror_augment = None
        self._results = []
        self._eval_time = None

    def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):
        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []

        if (dataset_args is None or mirror_augment is None) and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args = dict(run_config['dataset'])
            self._dataset_args['shuffle_mb'] = 0
            self._mirror_augment = run_config['train'].get('mirror_augment', False)

        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin

        if log_results:
            result_str = self.get_result_str()
            if run_dir is not None:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):
                    print(result_str)
            else:
                print(result_str)

    def get_result_str(self):
        network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
        if len(network_name) > 29:
            network_name = '...' + network_name[-26:]
        result_str = '%-30s' % network_name
        result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)
        for res in self._results:
            result_str += ' ' + self.name + res.suffix + ' '
            result_str += res.fmt % res.value
        return result_str

    def update_autosummaries(self):
        for res in self._results:
            tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)

    def _evaluate(self, Gs, num_gpus):
        raise NotImplementedError # to be overridden by subclasses

    def _report_result(self, value, suffix='', fmt='%-10.4f'):
        self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]

    def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
        all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
        all_args.update(self._dataset_args)
        all_args.update(kwargs)
        md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
        dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1]
        return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))

    def _iterate_reals(self, minibatch_size):
        dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)
        while True:
            images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
            if self._mirror_augment:
                images = misc.apply_mirror_augment(images)
            yield images

    def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
        while True:
            latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])
            fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
            images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)
            yield images

#----------------------------------------------------------------------------
# Group of multiple metrics.

class MetricGroup:
    def __init__(self, metric_kwarg_list):
        self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list]

    def run(self, *args, **kwargs):
        for metric in self.metrics:
            metric.run(*args, **kwargs)

    def get_result_str(self):
        return ' '.join(metric.get_result_str() for metric in self.metrics)

    def update_autosummaries(self):
        for metric in self.metrics:
            metric.update_autosummaries()

#----------------------------------------------------------------------------
# Dummy metric for debugging purposes.

class DummyMetric(MetricBase):
    def _evaluate(self, Gs, num_gpus):
        _ = Gs, num_gpus
        self._report_result(0.0)

#----------------------------------------------------------------------------
--------------------------------------------------------------------------------
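Typical usage of the autosummary helper (the file that follows) threads a value through autosummary() inside the training graph and periodically flushes with save_summaries(). A minimal self-contained sketch; the constant loss and the 'logs' directory are placeholders:

    import tensorflow as tf
    import dnnlib.tflib as tflib
    from dnnlib.tflib import autosummary as asum

    tflib.init_tf()
    loss = tf.constant(0.5)                           # stand-in for a real training loss
    loss = asum.autosummary('Loss/total', loss)       # identity op with an accumulation side effect
    writer = tf.summary.FileWriter('logs', tf.get_default_graph())
    tflib.run(loss)                                   # executing the op accumulates the value
    asum.save_summaries(writer, global_step=0)
    writer.close()
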
/dnnlib/tflib/autosummary.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helper for adding automatically tracked values to TensorBoard.

Autosummary creates an identity op that internally keeps track of the input
values and automatically shows up in TensorBoard. The reported value
represents an average over input components. The average is accumulated
constantly over time and flushed when save_summaries() is called.

Notes:
- The output tensor must be used as an input for something else in the
  graph. Otherwise, the autosummary op will not get executed, and the average
  value will not get accumulated.
- It is perfectly fine to include autosummaries with the same name in
  several places throughout the graph, even if they are executed concurrently.
- It is also fine to pass in a Python scalar or NumPy array. In this case, it
  is added to the average immediately.
"""

from collections import OrderedDict
import numpy as np
import tensorflow as tf
from tensorboard import summary as summary_lib
from tensorboard.plugins.custom_scalar import layout_pb2

from . import tfutil
from .tfutil import TfExpression
from .tfutil import TfExpressionEx

_dtype = tf.float64
_vars = OrderedDict() # name => [var, ...]
_immediate = OrderedDict() # name => update_op, update_value
_finalized = False
_merge_op = None


def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Internal helper for creating autosummary accumulators."""
    assert not _finalized
    name_id = name.replace("/", "_")
    v = tf.cast(value_expr, _dtype)

    if v.shape.is_fully_defined():
        size = np.prod(tfutil.shape_to_list(v.shape))
        size_expr = tf.constant(size, dtype=_dtype)
    else:
        size = None
        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))

    if size == 1:
        if v.shape.ndims != 0:
            v = tf.reshape(v, [])
        v = [size_expr, v, tf.square(v)]
    else:
        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))

    with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
        update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))

    if name in _vars:
        _vars[name].append(var)
    else:
        _vars[name] = [var]
    return update_op


def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:     Name to use in TensorBoard
        value:    TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            update_op = _create_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else: # python scalar or numpy array
        if name not in _immediate:
            with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(_dtype)
                update_op = _create_var(name, update_value)
                _immediate[name] = update_op, update_value

        update_op, update_value = _immediate[name]
        tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru


def finalize_autosummaries():
    """Create the necessary ops to include autosummaries in TensorBoard report.
    Note: This should be done only once per graph.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]): # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Group by category and chart name.
    cat_dict = OrderedDict()
    for series_name in sorted(_vars.keys()):
        p = series_name.split("/")
        cat = p[0] if len(p) >= 2 else ""
        chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
        if cat not in cat_dict:
            cat_dict[cat] = OrderedDict()
        if chart not in cat_dict[cat]:
            cat_dict[cat][chart] = []
        cat_dict[cat][chart].append(series_name)

    # Setup custom_scalar layout.
    categories = []
    for cat_name, chart_dict in cat_dict.items():
        charts = []
        for chart_name, series_names in chart_dict.items():
            series = []
            for series_name in series_names:
                series.append(layout_pb2.MarginChartContent.Series(
                    value=series_name,
                    lower="xCustomScalars/" + series_name + "/margin_lo",
                    upper="xCustomScalars/" + series_name + "/margin_hi"))
            margin = layout_pb2.MarginChartContent(series=series)
            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
        categories.append(layout_pb2.Category(title=cat_name, chart=charts))
    layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout


def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    file_writer.add_summary(_merge_op.eval(), global_step)
--------------------------------------------------------------------------------
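draw_truncation_trick_figure() in the next file implements the truncation trick: each dlatent is pulled toward the average dlatent learned during training. For truncation strength psi:

    w' = \bar{w} + \psi \, ( w - \bar{w} )

which is exactly its row_dlatents expression; psi = 1 leaves a sample untouched, psi = 0 collapses it onto the average face, and negative psi reflects it through the average.
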
/generate_figures.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators."""

import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config

#----------------------------------------------------------------------------
# Helpers for loading and using pre-trained generators.

url_ffhq = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
url_celebahq = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl
url_bedrooms = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl
url_cars = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl
url_cats = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl

synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)

_Gs_cache = dict()

def load_Gs(url):
    if url not in _Gs_cache:
        with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
            _G, _D, Gs = pickle.load(f)
        _Gs_cache[url] = Gs
    return _Gs_cache[url]

#----------------------------------------------------------------------------
# Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images.

def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed):
    print(png)
    latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1])
    images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb]

    canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white')
    image_iter = iter(list(images))
    for col, lod in enumerate(lods):
        for row in range(rows * 2**lod):
            image = PIL.Image.fromarray(next(image_iter), 'RGB')
            image = image.crop((cx, cy, cx + cw, cy + ch))
            image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS)
            canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod))
    canvas.save(png)

#----------------------------------------------------------------------------
# Figure 3: Style mixing.

def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
    print(png)
    src_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds])
    dst_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds])
    src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)

    canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')
    for col, src_image in enumerate(list(src_images)):
        canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
    for row, dst_image in enumerate(list(dst_images)):
        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
        row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
        row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(list(row_images)):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
    canvas.save(png)

#----------------------------------------------------------------------------
# Figure 4: Noise detail.

def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds):
    print(png)
    canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white')
    for row, seed in enumerate(seeds):
        latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples)
        images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs)
        canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h))
        for i in range(4):
            crop = PIL.Image.fromarray(images[i + 1], 'RGB')
            crop = crop.crop((650, 180, 906, 436))
            crop = crop.resize((w//2, h//2), PIL.Image.NEAREST)
            canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2))
        diff = np.std(np.mean(images, axis=3), axis=0) * 4
        diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)
        canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h))
    canvas.save(png)

#----------------------------------------------------------------------------
# Figure 5: Noise components.

def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips):
    print(png)
    Gsc = Gs.clone()
    noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')]
    noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...]
    latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds])
    all_images = []
    for noise_range in noise_ranges:
        tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)})
        range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs)
        range_images[flips, :, :] = range_images[flips, :, ::-1]
        all_images.append(list(range_images))

    canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white')
    for col, col_images in enumerate(zip(*all_images)):
        canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0))
        canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0))
        canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h))
        canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h))
    canvas.save(png)

#----------------------------------------------------------------------------
# Figure 8: Truncation trick.

def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis):
    print(png)
    latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds])
    dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]

    canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')
    for row, dlatent in enumerate(list(dlatents)):
        row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        for col, image in enumerate(list(row_images)):
            canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h))
    canvas.save(png)

#----------------------------------------------------------------------------
# Main program.

def main():
    tflib.init_tf()
    os.makedirs(config.result_dir, exist_ok=True)
    draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-ffhq.png'), load_Gs(url_ffhq), cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5)
    draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(url_ffhq), w=1024, h=1024, src_seeds=[639,701,687,615,2268], dst_seeds=[888,829,1898,1733,1614,845], style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)])
    draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(url_ffhq), w=1024, h=1024, num_samples=100, seeds=[1157,1012])
    draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[1967,1555], noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)], flips=[1])
    draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1])
    draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0)
    draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure11-uncurated-cars.png'), load_Gs(url_cars), cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2)
    draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure12-uncurated-cats.png'), load_Gs(url_cats), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------
--------------------------------------------------------------------------------
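init_tf() in the tfutil module that follows routes the dotted keys of its config dict by prefix: 'rnd.*' controls seeding, 'env.*' becomes an environment variable, and everything else is applied to tf.ConfigProto via nested getattr/setattr. A minimal sketch of overriding the defaults, using only keys already defined in _sanitize_tf_config():

    # Initialize TF with a fixed seed and verbose logging (sketch).
    import dnnlib.tflib as tflib

    tflib.init_tf({
        'rnd.np_random_seed': 1000,           # fixed NumPy seed; the TF seed is derived via 'auto'
        'env.TF_CPP_MIN_LOG_LEVEL': '0',      # print all TensorFlow debug info
        'gpu_options.allow_growth': False,    # allocate all GPU memory up front
    })
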
/dnnlib/tflib/tfutil.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Miscellaneous helper utils for TensorFlow."""

import os
import numpy as np
import tensorflow as tf

from typing import Any, Iterable, List, Union

TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid TensorFlow expression."""

TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid TensorFlow expression."""


def run(*args, **kwargs) -> Any:
    """Run the specified ops in the default session."""
    assert_tf_initialized()
    return tf.get_default_session().run(*args, **kwargs)


def is_tf_expression(x: Any) -> bool:
    """Check whether the input is a valid TensorFlow expression, i.e., TensorFlow Tensor, Variable, or Operation."""
    return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))


def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
    """Convert a TensorFlow shape to a list of ints."""
    return [dim.value for dim in shape]


def flatten(x: TfExpressionEx) -> TfExpression:
    """Shortcut function for flattening a tensor."""
    with tf.name_scope("Flatten"):
        return tf.reshape(x, [-1])


def log2(x: TfExpressionEx) -> TfExpression:
    """Logarithm in base 2."""
    with tf.name_scope("Log2"):
        return tf.log(x) * np.float32(1.0 / np.log(2.0))


def exp2(x: TfExpressionEx) -> TfExpression:
    """Exponent in base 2."""
    with tf.name_scope("Exp2"):
        return tf.exp(x * np.float32(np.log(2.0)))


def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
    """Linear interpolation."""
    with tf.name_scope("Lerp"):
        return a + (b - a) * t


def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
    """Linear interpolation with clip."""
    with tf.name_scope("LerpClip"):
        return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)


def absolute_name_scope(scope: str) -> tf.name_scope:
    """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
    return tf.name_scope(scope + "/")


def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
    return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)


def _sanitize_tf_config(config_dict: dict = None) -> dict:
    # Defaults.
    cfg = dict()
    cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
    cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
    cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
    cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
    cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.

    # User overrides.
    if config_dict is not None:
        cfg.update(config_dict)
    return cfg


def init_tf(config_dict: dict = None) -> None:
    """Initialize TensorFlow session using good default settings."""
    # Skip if already initialized.
    if tf.get_default_session() is not None:
        return

    # Setup config dict and random seeds.
    cfg = _sanitize_tf_config(config_dict)
    np_random_seed = cfg["rnd.np_random_seed"]
    if np_random_seed is not None:
        np.random.seed(np_random_seed)
    tf_random_seed = cfg["rnd.tf_random_seed"]
    if tf_random_seed == "auto":
        tf_random_seed = np.random.randint(1 << 31)
    if tf_random_seed is not None:
        tf.set_random_seed(tf_random_seed)

    # Setup environment variables.
    for key, value in list(cfg.items()):
        fields = key.split(".")
        if fields[0] == "env":
            assert len(fields) == 2
            os.environ[fields[1]] = str(value)

    # Create default TensorFlow session.
    create_session(cfg, force_as_default=True)


def assert_tf_initialized():
    """Check that TensorFlow session has been initialized."""
    if tf.get_default_session() is None:
        raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")


def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
    """Create tf.Session based on config dict."""
    # Setup TensorFlow config proto.
    cfg = _sanitize_tf_config(config_dict)
    config_proto = tf.ConfigProto()
    for key, value in cfg.items():
        fields = key.split(".")
        if fields[0] not in ["rnd", "env"]:
            obj = config_proto
            for field in fields[:-1]:
                obj = getattr(obj, field)
            setattr(obj, fields[-1], value)

    # Create session.
    session = tf.Session(config=config_proto)
    if force_as_default:
        # pylint: disable=protected-access
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__() # pylint: disable=no-member

    return session


def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Initialize all tf.Variables that have not already been initialized.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()
    """
    assert_tf_initialized()
    if target_vars is None:
        target_vars = tf.global_variables()

    test_vars = []
    test_ops = []

    with tf.control_dependencies(None): # ignore surrounding control_dependencies
        for var in target_vars:
            assert is_tf_expression(var)

            try:
                tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                test_vars.append(var)

                with absolute_name_scope(var.name.split(":")[0]):
                    test_ops.append(tf.is_variable_initialized(var))

    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
    run([var.initializer for var in init_vars])


def set_vars(var_to_value_dict: dict) -> None:
    """Set the values of given tf.Variables.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
    """
    assert_tf_initialized()
    ops = []
    feed_dict = {}

    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)

        try:
            setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
        except KeyError:
            with absolute_name_scope(var.name.split(":")[0]):
                with tf.control_dependencies(None): # ignore surrounding control_dependencies
                    setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter

        ops.append(setter)
        feed_dict[setter.op.inputs[1]] = value

    run(ops, feed_dict)


def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
    """Create tf.Variable with large initial value without bloating the tf graph."""
    assert_tf_initialized()
    assert isinstance(initial_value, np.ndarray)
    zeros = tf.zeros(initial_value.shape, initial_value.dtype)
    var = tf.Variable(zeros, *args, **kwargs)
    set_vars({var: initial_value})
    return var


def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
    Can be used as an input transformation for Network.run().
    """
    images = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        images = tf.transpose(images, [0, 3, 1, 2])
    return (images - drange[0]) * ((drange[1] - drange[0]) / 255)


def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
    Can be used as an output transformation for Network.run().
    """
    images = tf.cast(images, tf.float32)
    if shrink > 1:
        ksize = [1, 1, shrink, shrink]
        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        images = tf.transpose(images, [0, 2, 3, 1])
    scale = 255 / (drange[1] - drange[0])
    images = images * scale + (0.5 - drange[0] * scale)
    return tf.saturate_cast(images, tf.uint8)
--------------------------------------------------------------------------------
/dnnlib/tflib/optimizer.py:
--------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helper wrapper for a TensorFlow optimizer."""

import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import List, Union

from . import autosummary
from . import tfutil
from ..
import util 19 | 20 | from .tfutil import TfExpression, TfExpressionEx 21 | 22 | try: 23 | # TensorFlow 1.13 24 | from tensorflow.python.ops import nccl_ops 25 | except: 26 | # Older TensorFlow versions 27 | import tensorflow.contrib.nccl as nccl_ops 28 | 29 | class Optimizer: 30 | """A Wrapper for tf.train.Optimizer. 31 | 32 | Automatically takes care of: 33 | - Gradient averaging for multi-GPU training. 34 | - Dynamic loss scaling and typecasts for FP16 training. 35 | - Ignoring corrupted gradients that contain NaNs/Infs. 36 | - Reporting statistics. 37 | - Well-chosen default settings. 38 | """ 39 | 40 | def __init__(self, 41 | name: str = "Train", 42 | tf_optimizer: str = "tf.train.AdamOptimizer", 43 | learning_rate: TfExpressionEx = 0.001, 44 | use_loss_scaling: bool = False, 45 | loss_scaling_init: float = 64.0, 46 | loss_scaling_inc: float = 0.0005, 47 | loss_scaling_dec: float = 1.0, 48 | **kwargs): 49 | 50 | # Init fields. 51 | self.name = name 52 | self.learning_rate = tf.convert_to_tensor(learning_rate) 53 | self.id = self.name.replace("/", ".") 54 | self.scope = tf.get_default_graph().unique_name(self.id) 55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer) 56 | self.optimizer_kwargs = dict(kwargs) 57 | self.use_loss_scaling = use_loss_scaling 58 | self.loss_scaling_init = loss_scaling_init 59 | self.loss_scaling_inc = loss_scaling_inc 60 | self.loss_scaling_dec = loss_scaling_dec 61 | self._grad_shapes = None # [shape, ...] 62 | self._dev_opt = OrderedDict() # device => optimizer 63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] 64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) 65 | self._updates_applied = False 66 | 67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: 68 | """Register the gradients of the given loss function with respect to the given variables. 69 | Intended to be called once per GPU.""" 70 | assert not self._updates_applied 71 | 72 | # Validate arguments. 73 | if isinstance(trainable_vars, dict): 74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars 75 | 76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) 78 | 79 | if self._grad_shapes is None: 80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] 81 | 82 | assert len(trainable_vars) == len(self._grad_shapes) 83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) 84 | 85 | dev = loss.device 86 | 87 | # assert all(var.device == dev for var in trainable_vars) 88 | 89 | # Register device and compute gradients. 
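# One underlying optimizer instance is created lazily per device, and the
# gradients from repeated register_gradients() calls on the same device are
# accumulated in self._dev_grads for averaging later in apply_updates().
# The loss is cast to float32 and (optionally) loss-scaled before
# differentiation so that FP16 gradients survive the limited exponent range,
# and disconnected gradients are replaced with zeros so that every variable
# always receives a well-defined update tensor.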
90 | with tf.name_scope(self.id + "_grad"), tf.device(dev): 91 | if dev not in self._dev_opt: 92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) 93 | assert callable(self.optimizer_class) 94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) 95 | self._dev_grads[dev] = [] 96 | 97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) 98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage 99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros 100 | self._dev_grads[dev].append(grads) 101 | 102 | def apply_updates(self) -> tf.Operation: 103 | """Construct training op to update the registered variables based on their gradients.""" 104 | tfutil.assert_tf_initialized() 105 | assert not self._updates_applied 106 | self._updates_applied = True 107 | devices = list(self._dev_grads.keys()) 108 | total_grads = sum(len(grads) for grads in self._dev_grads.values()) 109 | assert len(devices) >= 1 and total_grads >= 1 110 | ops = [] 111 | 112 | with tfutil.absolute_name_scope(self.scope): 113 | # Cast gradients to FP32 and calculate partial sum within each device. 114 | dev_grads = OrderedDict() # device => [(grad, var), ...] 115 | 116 | for dev_idx, dev in enumerate(devices): 117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): 118 | sums = [] 119 | 120 | for gv in zip(*self._dev_grads[dev]): 121 | assert all(v is gv[0][1] for g, v in gv) 122 | g = [tf.cast(g, tf.float32) for g, v in gv] 123 | g = g[0] if len(g) == 1 else tf.add_n(g) 124 | sums.append((g, gv[0][1])) 125 | 126 | dev_grads[dev] = sums 127 | 128 | # Sum gradients across devices. 129 | if len(devices) > 1: 130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None): 131 | for var_idx, grad_shape in enumerate(self._grad_shapes): 132 | g = [dev_grads[dev][var_idx][0] for dev in devices] 133 | 134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors 135 | g = nccl_ops.all_sum(g) 136 | 137 | for dev, gg in zip(devices, g): 138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) 139 | 140 | # Apply updates separately on each device. 141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()): 142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): 143 | # Scale gradients as needed. 144 | if self.use_loss_scaling or total_grads > 1: 145 | with tf.name_scope("Scale"): 146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef") 147 | coef = self.undo_loss_scaling(coef) 148 | grads = [(g * coef, v) for g, v in grads] 149 | 150 | # Check for overflows. 151 | with tf.name_scope("CheckOverflow"): 152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) 153 | 154 | # Update weights and adjust loss scaling. 155 | with tf.name_scope("UpdateWeights"): 156 | # pylint: disable=cell-var-from-loop 157 | opt = self._dev_opt[dev] 158 | ls_var = self.get_loss_scaling_var(dev) 159 | 160 | if not self.use_loss_scaling: 161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) 162 | else: 163 | ops.append(tf.cond(grad_ok, 164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), 165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) 166 | 167 | # Report statistics on the last device. 
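# Note: overflow_frequency below averages to the fraction of minibatches
# whose gradients contained NaNs/Infs (tf.where(grad_ok, 0, 1) emits 1 on
# overflow); it is the main signal for judging the loss-scaling behavior.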
168 | if dev == devices[-1]: 169 | with tf.name_scope("Statistics"): 170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) 171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) 172 | 173 | if self.use_loss_scaling: 174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) 175 | 176 | # Initialize variables and group everything into a single op. 177 | self.reset_optimizer_state() 178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) 179 | 180 | return tf.group(*ops, name="TrainingOp") 181 | 182 | def reset_optimizer_state(self) -> None: 183 | """Reset internal state of the underlying optimizer.""" 184 | tfutil.assert_tf_initialized() 185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) 186 | 187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: 188 | """Get or create variable representing log2 of the current dynamic loss scaling factor.""" 189 | if not self.use_loss_scaling: 190 | return None 191 | 192 | if device not in self._dev_ls_var: 193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): 194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") 195 | 196 | return self._dev_ls_var[device] 197 | 198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression: 199 | """Apply dynamic loss scaling for the given expression.""" 200 | assert tfutil.is_tf_expression(value) 201 | 202 | if not self.use_loss_scaling: 203 | return value 204 | 205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) 206 | 207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression: 208 | """Undo the effect of dynamic loss scaling for the given expression.""" 209 | assert tfutil.is_tf_expression(value) 210 | 211 | if not self.use_loss_scaling: 212 | return value 213 | 214 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type 215 | -------------------------------------------------------------------------------- /metrics/linear_separability.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
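# Worked example for the probability helpers defined below (values in bits):
#
#   p = [[0.4, 0.1],
#        [0.1, 0.4]]              # joint distribution over (X, Y)
#   mutual_information(p)        # ~0.278 bits
#   entropy([[0.5, 0.5]])        # H(Y) = 1.0 bit (marginal of p along axis 0)
#   conditional_entropy(p)       # H(Y|X) = 1.0 - 0.278 = ~0.722 bits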
7 | 8 | """Linear Separability (LS).""" 9 | 10 | from collections import defaultdict 11 | import numpy as np 12 | import sklearn.svm 13 | import tensorflow as tf 14 | import dnnlib.tflib as tflib 15 | 16 | from metrics import metric_base 17 | from training import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | classifier_urls = [ 22 | 'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl 23 | 'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl 24 | 'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl 25 | 'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl 26 | 'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl 27 | 'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl 28 | 'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl 29 | 'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl 30 | 'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl 31 | 'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl 32 | 'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl 33 | 'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl 34 | 'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl 35 | 'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl 36 | 'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl 37 | 'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl 38 | 'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl 39 | 'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl 40 | 'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl 41 | 'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl 42 | 'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl 43 | 'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl 44 | 'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl 45 | 'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl 46 | 'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl 47 | 'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl 48 | 'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl 49 | 'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl 50 | 
'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl 51 | 'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl 52 | 'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl 53 | 'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl 54 | 'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl 55 | 'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl 56 | 'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl 57 | 'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl 58 | 'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl 59 | 'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl 60 | 'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl 61 | 'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl 62 | ] 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | def prob_normalize(p): 67 | p = np.asarray(p).astype(np.float32) 68 | assert len(p.shape) == 2 69 | return p / np.sum(p) 70 | 71 | def mutual_information(p): 72 | p = prob_normalize(p) 73 | px = np.sum(p, axis=1) 74 | py = np.sum(p, axis=0) 75 | result = 0.0 76 | for x in range(p.shape[0]): 77 | p_x = px[x] 78 | for y in range(p.shape[1]): 79 | p_xy = p[x][y] 80 | p_y = py[y] 81 | if p_xy > 0.0: 82 | result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output 83 | return result 84 | 85 | def entropy(p): 86 | p = prob_normalize(p) 87 | result = 0.0 88 | for x in range(p.shape[0]): 89 | for y in range(p.shape[1]): 90 | p_xy = p[x][y] 91 | if p_xy > 0.0: 92 | result -= p_xy * np.log2(p_xy) 93 | return result 94 | 95 | def conditional_entropy(p): 96 | # H(Y|X) where X corresponds to axis 0, Y to axis 1 97 | # i.e., how many bits of additional information are needed to know where we are on axis 1 if we know where we are on axis 0? 98 | p = prob_normalize(p) 99 | y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y) 100 | return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies; clean those up. 101 | 102 | #---------------------------------------------------------------------------- 103 | 104 | class LS(metric_base.MetricBase): 105 | def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs): 106 | assert num_keep <= num_samples 107 | super().__init__(**kwargs) 108 | self.num_samples = num_samples 109 | self.num_keep = num_keep 110 | self.attrib_indices = attrib_indices 111 | self.minibatch_per_gpu = minibatch_per_gpu 112 | 113 | def _evaluate(self, Gs, num_gpus): 114 | minibatch_size = num_gpus * self.minibatch_per_gpu 115 | 116 | # Construct TensorFlow graph for each GPU. 117 | result_expr = [] 118 | for gpu_idx in range(num_gpus): 119 | with tf.device('/gpu:%d' % gpu_idx): 120 | Gs_clone = Gs.clone() 121 | 122 | # Generate images.
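# The clone's mapping network produces the per-layer dlatents (W); only the
# last layer's dlatent is kept for the SVM below, while the synthesis network
# turns the full dlatent tensor into images for the attribute classifiers.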
123 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 124 | dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True) 125 | images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True) 126 | 127 | # Downsample to 256x256. The attribute classifiers were built for 256x256. 128 | if images.shape[2] > 256: 129 | factor = images.shape[2] // 256 130 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 131 | images = tf.reduce_mean(images, axis=[3, 5]) 132 | 133 | # Run classifier for each attribute. 134 | result_dict = dict(latents=latents, dlatents=dlatents[:,-1]) 135 | for attrib_idx in self.attrib_indices: 136 | classifier = misc.load_pkl(classifier_urls[attrib_idx]) 137 | logits = classifier.get_output_for(images, None) 138 | predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1)) 139 | result_dict[attrib_idx] = predictions 140 | result_expr.append(result_dict) 141 | 142 | # Sampling loop. 143 | results = [] 144 | for _ in range(0, self.num_samples, minibatch_size): 145 | results += tflib.run(result_expr) 146 | results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()} 147 | 148 | # Calculate conditional entropy for each attribute. 149 | conditional_entropies = defaultdict(list) 150 | for attrib_idx in self.attrib_indices: 151 | # Prune the least confident samples. 152 | pruned_indices = list(range(self.num_samples)) 153 | pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i])) 154 | pruned_indices = pruned_indices[:self.num_keep] 155 | 156 | # Fit SVM to the remaining samples. 157 | svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1) 158 | for space in ['latents', 'dlatents']: 159 | svm_inputs = results[space][pruned_indices] 160 | try: 161 | svm = sklearn.svm.LinearSVC() 162 | svm.fit(svm_inputs, svm_targets) 163 | svm.score(svm_inputs, svm_targets) 164 | svm_outputs = svm.predict(svm_inputs) 165 | except: 166 | svm_outputs = svm_targets # assume perfect prediction 167 | 168 | # Calculate conditional entropy. 169 | p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)] 170 | conditional_entropies[space].append(conditional_entropy(p)) 171 | 172 | # Calculate separability scores. 173 | scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()} 174 | self._report_result(scores['latents'], suffix='_z') 175 | self._report_result(scores['dlatents'], suffix='_w') 176 | 177 | #---------------------------------------------------------------------------- 178 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
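# A minimal sketch of how these losses are wired up during training. The
# exact call site lives in the training loop (not shown here), so treat the
# names below as illustrative assumptions rather than the canonical API:
#
#   G_loss = dnnlib.util.call_func_by_name(
#       func_name='training.loss.G_logistic_nonsaturating',
#       G=G, D=D, opt=G_opt, training_set=training_set,
#       minibatch_size=minibatch_size)
#   G_opt.register_gradients(tf.reduce_mean(G_loss), G.trainables)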
7 | 8 | """Loss functions.""" 9 | 10 | import tensorflow as tf 11 | import dnnlib.tflib as tflib 12 | from dnnlib.tflib.autosummary import autosummary 13 | 14 | #---------------------------------------------------------------------------- 15 | # Convenience func that casts all of its arguments to tf.float32. 16 | 17 | def fp32(*values): 18 | if len(values) == 1 and isinstance(values[0], tuple): 19 | values = values[0] 20 | values = tuple(tf.cast(v, tf.float32) for v in values) 21 | return values if len(values) >= 2 else values[0] 22 | 23 | #---------------------------------------------------------------------------- 24 | # WGAN & WGAN-GP loss functions. 25 | 26 | def G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 27 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 28 | labels = training_set.get_random_labels_tf(minibatch_size) 29 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 30 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 31 | loss = -fake_scores_out 32 | return loss 33 | 34 | def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 35 | wgan_epsilon = 0.001): # Weight for the epsilon term, \epsilon_{drift}. 36 | 37 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 38 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 39 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 40 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 41 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 42 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 43 | loss = fake_scores_out - real_scores_out 44 | 45 | with tf.name_scope('EpsilonPenalty'): 46 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) 47 | loss += epsilon_penalty * wgan_epsilon 48 | return loss 49 | 50 | def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 51 | wgan_lambda = 10.0, # Weight for the gradient penalty term. 52 | wgan_epsilon = 0.001, # Weight for the epsilon term, \epsilon_{drift}. 53 | wgan_target = 1.0): # Target value for gradient magnitudes. 
54 | 55 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 56 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 57 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 58 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 59 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 60 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 61 | loss = fake_scores_out - real_scores_out 62 | 63 | with tf.name_scope('GradientPenalty'): 64 | mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) 65 | mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) 66 | mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) 67 | mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) 68 | mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) 69 | mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) 70 | mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) 71 | mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) 72 | gradient_penalty = tf.square(mixed_norms - wgan_target) 73 | loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) 74 | 75 | with tf.name_scope('EpsilonPenalty'): 76 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) 77 | loss += epsilon_penalty * wgan_epsilon 78 | return loss 79 | 80 | #---------------------------------------------------------------------------- 81 | # Hinge loss functions. (Use G_wgan with these) 82 | 83 | def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument 84 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 85 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 86 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 87 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 88 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 89 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 90 | loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) 91 | return loss 92 | 93 | def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 94 | wgan_lambda = 10.0, # Weight for the gradient penalty term. 95 | wgan_target = 1.0): # Target value for gradient magnitudes. 
96 | 97 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 98 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 99 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 100 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 101 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 102 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 103 | loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) 104 | 105 | with tf.name_scope('GradientPenalty'): 106 | mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) 107 | mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) 108 | mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) 109 | mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) 110 | mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) 111 | mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) 112 | mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) 113 | mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) 114 | gradient_penalty = tf.square(mixed_norms - wgan_target) 115 | loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) 116 | return loss 117 | 118 | 119 | #---------------------------------------------------------------------------- 120 | # Loss functions advocated by the paper 121 | # "Which Training Methods for GANs do actually Converge?" 122 | 123 | def G_logistic_saturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 124 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 125 | labels = training_set.get_random_labels_tf(minibatch_size) 126 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 127 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 128 | loss = -tf.nn.softplus(fake_scores_out) # log(1 - logistic(fake_scores_out)) 129 | return loss 130 | 131 | def G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 132 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 133 | labels = training_set.get_random_labels_tf(minibatch_size) 134 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 135 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 136 | loss = tf.nn.softplus(-fake_scores_out) # -log(logistic(fake_scores_out)) 137 | return loss 138 | 139 | def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument 140 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 141 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 142 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 143 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 144 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 145 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 146 | loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) 147 | loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type 148 | return loss 149 | 150 | 
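# Softplus identities behind the inline comments above and below, with
# sigma(x) denoting the logistic sigmoid 1 / (1 + e^-x):
#
#   softplus(x)  = log(1 + e^x)  = -log(1 - sigma(x))
#   softplus(-x) = log(1 + e^-x) = -log(sigma(x))
#
# so D_logistic minimizes -log(1 - sigma(fake)) - log(sigma(real)), and the
# simplegp variant below adds the zero-centered gradient penalties
# (gamma / 2) * E[ ||grad D||^2 ] evaluated at the reals (R1) or fakes (R2).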
def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0): # pylint: disable=unused-argument 151 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 152 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 153 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 154 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 155 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 156 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 157 | loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) 158 | loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type 159 | 160 | if r1_gamma != 0.0: 161 | with tf.name_scope('R1Penalty'): 162 | real_loss = opt.apply_loss_scaling(tf.reduce_sum(real_scores_out)) 163 | real_grads = opt.undo_loss_scaling(fp32(tf.gradients(real_loss, [reals])[0])) 164 | r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3]) 165 | r1_penalty = autosummary('Loss/r1_penalty', r1_penalty) 166 | loss += r1_penalty * (r1_gamma * 0.5) 167 | 168 | if r2_gamma != 0.0: 169 | with tf.name_scope('R2Penalty'): 170 | fake_loss = opt.apply_loss_scaling(tf.reduce_sum(fake_scores_out)) 171 | fake_grads = opt.undo_loss_scaling(fp32(tf.gradients(fake_loss, [fake_images_out])[0])) 172 | r2_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3]) 173 | r2_penalty = autosummary('Loss/r2_penalty', r2_penalty) 174 | loss += r2_penalty * (r2_gamma * 0.5) 175 | return loss 176 | 177 | #---------------------------------------------------------------------------- 178 | -------------------------------------------------------------------------------- /dnnlib/submission/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Submit a function to be run either locally or in a computing cluster.""" 9 | 10 | import copy 11 | import io 12 | import os 13 | import pathlib 14 | import pickle 15 | import platform 16 | import pprint 17 | import re 18 | import shutil 19 | import time 20 | import traceback 21 | 22 | import zipfile 23 | 24 | from enum import Enum 25 | 26 | from .. import util 27 | from ..util import EasyDict 28 | 29 | 30 | class SubmitTarget(Enum): 31 | """The target where the function should be run. 32 | 33 | LOCAL: Run it locally. 34 | """ 35 | LOCAL = 1 36 | 37 | 38 | class PathType(Enum): 39 | """Determines in which format should a path be formatted. 40 | 41 | WINDOWS: Format with Windows style. 42 | LINUX: Format with Linux/Posix style. 43 | AUTO: Use current OS type to select either WINDOWS or LINUX. 44 | """ 45 | WINDOWS = 1 46 | LINUX = 2 47 | AUTO = 3 48 | 49 | 50 | _user_name_override = None 51 | 52 | 53 | class SubmitConfig(util.EasyDict): 54 | """Strongly typed config dict needed to submit runs. 55 | 56 | Attributes: 57 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. 
58 | run_desc: Description of the run. Will be used in the run dir and task name. 59 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. 60 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir. 61 | submit_target: Submit target enum value. Used to select where the run is actually launched. 62 | num_gpus: Number of GPUs used/requested for the run. 63 | print_info: Whether to print debug information when submitting. 64 | ask_confirmation: Whether to ask a confirmation before submitting. 65 | run_id: Automatically populated value during submit. 66 | run_name: Automatically populated value during submit. 67 | run_dir: Automatically populated value during submit. 68 | run_func_name: Automatically populated value during submit. 69 | run_func_kwargs: Automatically populated value during submit. 70 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. 71 | task_name: Automatically populated value during submit. 72 | host_name: Automatically populated value during submit. 73 | """ 74 | 75 | def __init__(self): 76 | super().__init__() 77 | 78 | # run (set these) 79 | self.run_dir_root = "" # should always be passed through get_path_from_template 80 | self.run_desc = "" 81 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] 82 | self.run_dir_extra_files = None 83 | 84 | # submit (set these) 85 | self.submit_target = SubmitTarget.LOCAL 86 | self.num_gpus = 1 87 | self.print_info = False 88 | self.ask_confirmation = False 89 | 90 | # (automatically populated) 91 | self.run_id = None 92 | self.run_name = None 93 | self.run_dir = None 94 | self.run_func_name = None 95 | self.run_func_kwargs = None 96 | self.user_name = None 97 | self.task_name = None 98 | self.host_name = "localhost" 99 | 100 | 101 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: 102 | """Replace tags in the given path template and return either Windows or Linux formatted path.""" 103 | # automatically select path type depending on running OS 104 | if path_type == PathType.AUTO: 105 | if platform.system() == "Windows": 106 | path_type = PathType.WINDOWS 107 | elif platform.system() == "Linux": 108 | path_type = PathType.LINUX 109 | else: 110 | raise RuntimeError("Unknown platform") 111 | 112 | path_template = path_template.replace("<USERNAME>", get_user_name()) 113 | 114 | # return correctly formatted path 115 | if path_type == PathType.WINDOWS: 116 | return str(pathlib.PureWindowsPath(path_template)) 117 | elif path_type == PathType.LINUX: 118 | return str(pathlib.PurePosixPath(path_template)) 119 | else: 120 | raise RuntimeError("Unknown platform") 121 | 122 | 123 | def get_template_from_path(path: str) -> str: 124 | """Convert a normal path back to its template representation.""" 125 | # replace all path parts with the template tags 126 | path = path.replace("\\", "/") 127 | return path 128 | 129 | 130 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: 131 | """Convert a normal path to a template and then convert it back to a normal path with the given path type.""" 132 | path_template = get_template_from_path(path) 133 | path = get_path_from_template(path_template, path_type) 134 | return path 135 | 136 | 137 | def set_user_name_override(name: str) -> None: 138 | """Set the global username override value.""" 139 | global
_user_name_override 140 | _user_name_override = name 141 | 142 | 143 | def get_user_name(): 144 | """Get the current user name.""" 145 | if _user_name_override is not None: 146 | return _user_name_override 147 | elif platform.system() == "Windows": 148 | return os.getlogin() 149 | elif platform.system() == "Linux": 150 | try: 151 | import pwd # pylint: disable=import-error 152 | return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member 153 | except: 154 | return "unknown" 155 | else: 156 | raise RuntimeError("Unknown platform") 157 | 158 | 159 | def _create_run_dir_local(submit_config: SubmitConfig) -> str: 160 | """Create a new run dir with increasing ID number at the start.""" 161 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) 162 | 163 | if not os.path.exists(run_dir_root): 164 | print("Creating the run dir root: {}".format(run_dir_root)) 165 | os.makedirs(run_dir_root) 166 | 167 | submit_config.run_id = _get_next_run_id_local(run_dir_root) 168 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) 169 | run_dir = os.path.join(run_dir_root, submit_config.run_name) 170 | 171 | if os.path.exists(run_dir): 172 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) 173 | 174 | print("Creating the run dir: {}".format(run_dir)) 175 | os.makedirs(run_dir) 176 | 177 | return run_dir 178 | 179 | 180 | def _get_next_run_id_local(run_dir_root: str) -> int: 181 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.""" 182 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] 183 | r = re.compile("^\\d+") # match one or more digits at the start of the string 184 | run_id = 0 185 | 186 | for dir_name in dir_names: 187 | m = r.match(dir_name) 188 | 189 | if m is not None: 190 | i = int(m.group()) 191 | run_id = max(run_id, i + 1) 192 | 193 | return run_id 194 | 195 | 196 | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: 197 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" 198 | print("Copying files to the run dir") 199 | files = [] 200 | 201 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) 202 | assert '.' 
in submit_config.run_func_name 203 | for _idx in range(submit_config.run_func_name.count('.') - 1): 204 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) 205 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) 206 | 207 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") 208 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) 209 | 210 | if submit_config.run_dir_extra_files is not None: 211 | files += submit_config.run_dir_extra_files 212 | 213 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] 214 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] 215 | 216 | util.copy_files_and_create_dirs(files) 217 | 218 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) 219 | 220 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: 221 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) 222 | 223 | 224 | def run_wrapper(submit_config: SubmitConfig) -> None: 225 | """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" 226 | is_local = submit_config.submit_target == SubmitTarget.LOCAL 227 | 228 | checker = None 229 | 230 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing 231 | if is_local: 232 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) 233 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) 234 | logger = util.Logger(file_name=None, should_flush=True) 235 | 236 | import dnnlib 237 | dnnlib.submit_config = submit_config 238 | 239 | try: 240 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) 241 | start_time = time.time() 242 | util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) 243 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time))) 244 | except: 245 | if is_local: 246 | raise 247 | else: 248 | traceback.print_exc() 249 | 250 | log_src = os.path.join(submit_config.run_dir, "log.txt") 251 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) 252 | shutil.copyfile(log_src, log_dst) 253 | finally: 254 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() 255 | 256 | dnnlib.submit_config = None 257 | logger.close() 258 | 259 | if checker is not None: 260 | checker.stop() 261 | 262 | 263 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: 264 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.""" 265 | submit_config = copy.copy(submit_config) 266 | 267 | if submit_config.user_name is None: 268 | submit_config.user_name = get_user_name() 269 | 270 | submit_config.run_func_name = run_func_name 271 | submit_config.run_func_kwargs = run_func_kwargs 272 | 273 | assert submit_config.submit_target == SubmitTarget.LOCAL 274 | if submit_config.submit_target in {SubmitTarget.LOCAL}: 275 | run_dir = _create_run_dir_local(submit_config) 276 | 277 
| submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) 278 | submit_config.run_dir = run_dir 279 | _populate_run_dir(run_dir, submit_config) 280 | 281 | if submit_config.print_info: 282 | print("\nSubmit config:\n") 283 | pprint.pprint(submit_config, indent=4, width=200, compact=False) 284 | print() 285 | 286 | if submit_config.ask_confirmation: 287 | if not util.ask_yes_no("Continue submitting the job?"): 288 | return 289 | 290 | run_wrapper(submit_config) 291 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Multi-resolution input data pipeline.""" 9 | 10 | import os 11 | import glob 12 | import numpy as np 13 | import tensorflow as tf 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | 17 | #---------------------------------------------------------------------------- 18 | # Parse individual image from a tfrecords file. 19 | 20 | def parse_tfrecord_tf(record): 21 | features = tf.parse_single_example(record, features={ 22 | # 'axis': tf.FixedLenFeature([1], tf.int64), 23 | 'shape': tf.FixedLenFeature([4], tf.int64), 24 | 'data': tf.FixedLenFeature([], tf.string)}) 25 | data = tf.decode_raw(features['data'], tf.uint8) 26 | return tf.reshape(data, features['shape']) 27 | 28 | def parse_tfrecord_np(record): 29 | ex = tf.train.Example() 30 | ex.ParseFromString(record) 31 | shape = ex.features.feature['shape'].int64_list.value # temporary pylint workaround # pylint: disable=no-member 32 | data = ex.features.feature['data'].bytes_list.value[0] # temporary pylint workaround # pylint: disable=no-member 33 | return np.fromstring(data, np.uint8).reshape(shape) 34 | 35 | #---------------------------------------------------------------------------- 36 | # Dataset class that loads data from tfrecords files. 37 | 38 | class TFRecordDataset: 39 | def __init__(self, 40 | tfrecord_dir, # Directory containing a collection of tfrecords files. 41 | resolution = None, # Dataset resolution, None = autodetect. 42 | label_file = None, # Relative path of the labels file, None = autodetect. 43 | max_label_size = 0, # 0 = no labels, 'full' = full labels, = N first label components. 44 | repeat = True, # Repeat dataset indefinitely. 45 | shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. 46 | prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching. 47 | buffer_mb = 256, # Read buffer size (megabytes). 48 | num_threads = 2): # Number of concurrent threads. 
49 | 50 | self.tfrecord_dir = tfrecord_dir 51 | self.resolution = None 52 | self.resolution_log2 = None 53 | self.shape = [] # [channel, depth, height, width] 54 | self.dtype = 'uint8' 55 | self.dynamic_range = [0, 255] 56 | self.label_file = label_file 57 | self.label_size = None # [component] 58 | self.label_dtype = None 59 | self._np_labels = None 60 | self._tf_minibatch_in = None 61 | self._tf_labels_var = None 62 | self._tf_labels_dataset = None 63 | self._tf_datasets = dict() 64 | self._tf_iterator = None 65 | self._tf_init_ops = dict() 66 | self._tf_minibatch_np = None 67 | self._cur_minibatch = -1 68 | self._cur_lod = -1 69 | 70 | # List tfrecords files and inspect their shapes. 71 | assert os.path.isdir(self.tfrecord_dir) 72 | tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords'))) 73 | assert len(tfr_files) >= 1 74 | tfr_shapes = [] 75 | for tfr_file in tfr_files: 76 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) 77 | for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt): 78 | tfr_shapes.append(parse_tfrecord_np(record).shape) 79 | break 80 | 81 | # Autodetect label filename. 82 | if self.label_file is None: 83 | guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels'))) 84 | if len(guess): 85 | self.label_file = guess[0] 86 | elif not os.path.isfile(self.label_file): 87 | guess = os.path.join(self.tfrecord_dir, self.label_file) 88 | if os.path.isfile(guess): 89 | self.label_file = guess 90 | 91 | # Determine shape and resolution. 92 | max_shape = max(tfr_shapes, key=np.prod) 93 | self.resolution = resolution if resolution is not None else max_shape[1] 94 | self.resolution_log2 = int(np.log2(self.resolution)) 95 | self.shape = [max_shape[0], self.resolution, self.resolution, self.resolution] 96 | tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes] 97 | assert all(shape[0] == max_shape[0] for shape in tfr_shapes) 98 | assert all(shape[1] == shape[2] for shape in tfr_shapes) 99 | assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods)) 100 | assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1)) 101 | 102 | # Load labels. 103 | assert max_label_size == 'full' or max_label_size >= 0 104 | self._np_labels = np.zeros([1<<20, 0], dtype=np.float32) 105 | if self.label_file is not None and max_label_size != 0: 106 | self._np_labels = np.load(self.label_file) 107 | assert self._np_labels.ndim == 2 108 | if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size: 109 | self._np_labels = self._np_labels[:, :max_label_size] 110 | self.label_size = self._np_labels.shape[1] 111 | self.label_dtype = self._np_labels.dtype.name 112 | 113 | # Build TF expressions.
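# One tf.data pipeline is built per level-of-detail; configure() selects the
# active one by re-initializing a shared iterator. The shuffle/prefetch
# windows are given in megabytes and converted to item counts from the
# per-record byte size, and the minibatch size is fed in at run time through
# the _tf_minibatch_in placeholder.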
114 | with tf.name_scope('Dataset'), tf.device('/cpu:0'): 115 | self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[]) 116 | self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var') 117 | self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var) 118 | for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods): 119 | if tfr_lod < 0: 120 | continue 121 | dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20) 122 | dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads) 123 | dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset)) 124 | bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize 125 | if shuffle_mb > 0: 126 | dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1) 127 | if repeat: 128 | dset = dset.repeat() 129 | if prefetch_mb > 0: 130 | dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1) 131 | dset = dset.batch(self._tf_minibatch_in) 132 | self._tf_datasets[tfr_lod] = dset 133 | self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes) 134 | self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()} 135 | 136 | # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf(). 137 | def configure(self, minibatch_size, lod=0): 138 | lod = int(np.floor(lod)) 139 | assert minibatch_size >= 1 and lod in self._tf_datasets 140 | if self._cur_minibatch != minibatch_size or self._cur_lod != lod: 141 | self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size}) 142 | self._cur_minibatch = minibatch_size 143 | self._cur_lod = lod 144 | 145 | # Get next minibatch as TensorFlow expressions. 146 | def get_minibatch_tf(self): # => images, labels 147 | return self._tf_iterator.get_next() 148 | 149 | # Get next minibatch as NumPy arrays. 150 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels 151 | self.configure(minibatch_size, lod) 152 | if self._tf_minibatch_np is None: 153 | self._tf_minibatch_np = self.get_minibatch_tf() 154 | return tflib.run(self._tf_minibatch_np) 155 | 156 | # Get random labels as TensorFlow expression. 157 | def get_random_labels_tf(self, minibatch_size): # => labels 158 | if self.label_size > 0: 159 | with tf.device('/cpu:0'): 160 | return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32)) 161 | return tf.zeros([minibatch_size, 0], self.label_dtype) 162 | 163 | # Get random labels as NumPy array. 164 | def get_random_labels_np(self, minibatch_size): # => labels 165 | if self.label_size > 0: 166 | return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])] 167 | return np.zeros([minibatch_size, 0], self.label_dtype) 168 | 169 | #---------------------------------------------------------------------------- 170 | # Base class for datasets that are generated on the fly. 
171 | 172 | class SyntheticDataset: 173 | def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dynamic_range=[0,255], label_size=0, label_dtype='float32'): 174 | self.resolution = resolution 175 | self.resolution_log2 = int(np.log2(resolution)) 176 | self.shape = [num_channels, resolution, resolution, resolution] 177 | self.dtype = dtype 178 | self.dynamic_range = dynamic_range 179 | self.label_size = label_size 180 | self.label_dtype = label_dtype 181 | self._tf_minibatch_var = None 182 | self._tf_lod_var = None 183 | self._tf_minibatch_np = None 184 | self._tf_labels_np = None 185 | 186 | assert self.resolution == 2 ** self.resolution_log2 187 | with tf.name_scope('Dataset'): 188 | self._tf_minibatch_var = tf.Variable(np.int32(0), name='minibatch_var') 189 | self._tf_lod_var = tf.Variable(np.int32(0), name='lod_var') 190 | 191 | def configure(self, minibatch_size, lod=0): 192 | lod = int(np.floor(lod)) 193 | assert minibatch_size >= 1 and 0 <= lod <= self.resolution_log2 194 | tflib.set_vars({self._tf_minibatch_var: minibatch_size, self._tf_lod_var: lod}) 195 | 196 | def get_minibatch_tf(self): # => images, labels 197 | with tf.name_scope('SyntheticDataset'): 198 | shrink = tf.cast(2.0 ** tf.cast(self._tf_lod_var, tf.float32), tf.int32) 199 | shape = [self.shape[0], self.shape[1] // shrink, self.shape[2] // shrink, self.shape[3] // shrink] 200 | images = self._generate_images(self._tf_minibatch_var, self._tf_lod_var, shape) 201 | labels = self._generate_labels(self._tf_minibatch_var) 202 | return images, labels 203 | 204 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels 205 | self.configure(minibatch_size, lod) 206 | if self._tf_minibatch_np is None: 207 | self._tf_minibatch_np = self.get_minibatch_tf() 208 | return tflib.run(self._tf_minibatch_np) 209 | 210 | def get_random_labels_tf(self, minibatch_size): # => labels 211 | with tf.name_scope('SyntheticDataset'): 212 | return self._generate_labels(minibatch_size) 213 | 214 | def get_random_labels_np(self, minibatch_size): # => labels 215 | self.configure(minibatch_size) 216 | if self._tf_labels_np is None: 217 | self._tf_labels_np = self.get_random_labels_tf(minibatch_size) 218 | return tflib.run(self._tf_labels_np) 219 | 220 | def _generate_images(self, minibatch, lod, shape): # to be overridden by subclasses # pylint: disable=unused-argument 221 | return tf.zeros([minibatch] + shape, self.dtype) 222 | 223 | def _generate_labels(self, minibatch): # to be overridden by subclasses 224 | return tf.zeros([minibatch, self.label_size], self.label_dtype) 225 | 226 | #---------------------------------------------------------------------------- 227 | # Helper func for constructing a dataset object using the given options. 228 | 229 | def load_dataset(class_name='training.dataset.TFRecordDataset', data_dir=None, verbose=False, **kwargs): 230 | adjusted_kwargs = dict(kwargs) 231 | if 'tfrecord_dir' in adjusted_kwargs and data_dir is not None: 232 | adjusted_kwargs['tfrecord_dir'] = os.path.join(data_dir, adjusted_kwargs['tfrecord_dir']) 233 | if verbose: 234 | print('Streaming data using %s...' 
% class_name) 235 | dataset = dnnlib.util.get_obj_by_name(class_name)(**adjusted_kwargs) 236 | if verbose: 237 | print('Dataset shape =', np.int32(dataset.shape).tolist()) 238 | print('Dynamic range =', dataset.dynamic_range) 239 | print('Label size =', dataset.label_size) 240 | return dataset 241 | 242 | #---------------------------------------------------------------------------- 243 | -------------------------------------------------------------------------------- /training/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Miscellaneous utility functions.""" 9 | 10 | import os 11 | import glob 12 | import pickle 13 | import re 14 | import numpy as np 15 | from collections import defaultdict 16 | import PIL.Image 17 | import dnnlib 18 | 19 | import config 20 | from training import dataset 21 | import nibabel as nib 22 | 23 | #---------------------------------------------------------------------------- 24 | # Convenience wrappers for pickle that are able to load data produced by 25 | # older versions of the code, and from external URLs. 26 | 27 | def open_file_or_url(file_or_url): 28 | if dnnlib.util.is_url(file_or_url): 29 | return dnnlib.util.open_url(file_or_url, cache_dir=config.cache_dir) 30 | return open(file_or_url, 'rb') 31 | 32 | def load_pkl(file_or_url): 33 | with open_file_or_url(file_or_url) as file: 34 | return pickle.load(file, encoding='latin1') 35 | 36 | def save_pkl(obj, filename): 37 | with open(filename, 'wb') as file: 38 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) 39 | 40 | #---------------------------------------------------------------------------- 41 | # Image utils. 
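adjust_dynamic_range below is a pure affine remap: scale = (out_hi - out_lo) / (in_hi - in_lo) and bias = out_lo - in_lo * scale. A quick NumPy check of the common uint8-to-[-1,1] case (illustrative only):

import numpy as np

x = np.array([0.0, 127.5, 255.0], dtype=np.float32)
scale = (1.0 - (-1.0)) / (255.0 - 0.0)  # 2/255
bias = -1.0 - 0.0 * scale               # -1
print(x * scale + bias)                 # [-1.  0.  1.]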
42 | 43 | 44 | # 3D stuff 45 | 46 | def save_mri_image(image, filename, drange=[0,1]): 47 | 48 | image = image[0][0] 49 | image = adjust_dynamic_range(image, drange, [0,255]) 50 | image = np.rint(image).clip(0, 255).astype(np.uint8) 51 | mri = nib.Nifti1Image(image, np.eye(4)) 52 | nib.save(mri, filename) 53 | 54 | 55 | def adjust_dynamic_range(data, drange_in, drange_out): 56 | if drange_in != drange_out: 57 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0])) 58 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale) 59 | data = data * scale + bias 60 | return data 61 | 62 | def create_image_grid(images, grid_size=None): 63 | assert images.ndim == 3 or images.ndim == 4 64 | num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2] 65 | 66 | if grid_size is not None: 67 | grid_w, grid_h = tuple(grid_size) 68 | else: 69 | grid_w = max(int(np.ceil(np.sqrt(num))), 1) 70 | grid_h = max((num - 1) // grid_w + 1, 1) 71 | 72 | grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype) 73 | for idx in range(num): 74 | x = (idx % grid_w) * img_w 75 | y = (idx // grid_w) * img_h 76 | grid[..., y : y + img_h, x : x + img_w] = images[idx] 77 | return grid 78 | 79 | def convert_to_pil_image(image, drange=[0,1]): 80 | assert image.ndim == 2 or image.ndim == 3 81 | if image.ndim == 3: 82 | if image.shape[0] == 1: 83 | image = image[0] # grayscale CHW => HW 84 | else: 85 | image = image.transpose(1, 2, 0) # CHW -> HWC 86 | 87 | image = adjust_dynamic_range(image, drange, [0,255]) 88 | image = np.rint(image).clip(0, 255).astype(np.uint8) 89 | fmt = 'RGB' if image.ndim == 3 else 'L' 90 | return PIL.Image.fromarray(image, fmt) 91 | 92 | def save_image(image, filename, drange=[0,1], quality=95): 93 | img = convert_to_pil_image(image, drange) 94 | if '.jpg' in filename: 95 | img.save(filename,"JPEG", quality=quality, optimize=True) 96 | else: 97 | img.save(filename) 98 | 99 | def save_image_grid(images, filename, drange=[0,1], grid_size=None): 100 | convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename) 101 | 102 | #---------------------------------------------------------------------------- 103 | # Locating results. 
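The functions below resolve a run id to its result directory and snapshot pickles under config.result_dir. A usage sketch; the run id and directory name are hypothetical, and unpacking into (G, D, Gs) assumes the usual ProGAN/StyleGAN snapshot layout:

import dnnlib.tflib as tflib
from training import misc

tflib.init_tf()                       # unpickling Network objects requires a live TF session
run_dir = misc.locate_run_dir(42)     # e.g. 'results_proto/00042-sgan-mri-8gpu' (hypothetical)
pkls = misc.list_network_pkls(42)     # snapshot pickles, with 'network-final.pkl' moved last
G, D, Gs = misc.load_network_pkl(42)  # latest snapshot when no snapshot number is given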
104 | 105 | def locate_run_dir(run_id_or_run_dir): 106 | if isinstance(run_id_or_run_dir, str): 107 | if os.path.isdir(run_id_or_run_dir): 108 | return run_id_or_run_dir 109 | converted = dnnlib.submission.submit.convert_path(run_id_or_run_dir) 110 | if os.path.isdir(converted): 111 | return converted 112 | 113 | run_dir_pattern = re.compile('^0*%s-' % str(run_id_or_run_dir)) 114 | for search_dir in ['']: 115 | full_search_dir = config.result_dir if search_dir == '' else os.path.normpath(os.path.join(config.result_dir, search_dir)) 116 | run_dir = os.path.join(full_search_dir, str(run_id_or_run_dir)) 117 | if os.path.isdir(run_dir): 118 | return run_dir 119 | run_dirs = sorted(glob.glob(os.path.join(full_search_dir, '*'))) 120 | run_dirs = [run_dir for run_dir in run_dirs if run_dir_pattern.match(os.path.basename(run_dir))] 121 | run_dirs = [run_dir for run_dir in run_dirs if os.path.isdir(run_dir)] 122 | if len(run_dirs) == 1: 123 | return run_dirs[0] 124 | raise IOError('Cannot locate result subdir for run', run_id_or_run_dir) 125 | 126 | def list_network_pkls(run_id_or_run_dir, include_final=True): 127 | run_dir = locate_run_dir(run_id_or_run_dir) 128 | pkls = sorted(glob.glob(os.path.join(run_dir, 'network-*.pkl'))) 129 | if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl': 130 | if include_final: 131 | pkls.append(pkls[0]) 132 | del pkls[0] 133 | return pkls 134 | 135 | def locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): 136 | for candidate in [snapshot_or_network_pkl, run_id_or_run_dir_or_network_pkl]: 137 | if isinstance(candidate, str): 138 | if os.path.isfile(candidate): 139 | return candidate 140 | converted = dnnlib.submission.submit.convert_path(candidate) 141 | if os.path.isfile(converted): 142 | return converted 143 | 144 | pkls = list_network_pkls(run_id_or_run_dir_or_network_pkl) 145 | if len(pkls) >= 1 and snapshot_or_network_pkl is None: 146 | return pkls[-1] 147 | 148 | for pkl in pkls: 149 | try: 150 | name = os.path.splitext(os.path.basename(pkl))[0] 151 | number = int(name.split('-')[-1]) 152 | if number == snapshot_or_network_pkl: 153 | return pkl 154 | except ValueError: pass 155 | except IndexError: pass 156 | raise IOError('Cannot locate network pkl for snapshot', snapshot_or_network_pkl) 157 | 158 | def get_id_string_for_network_pkl(network_pkl): 159 | p = network_pkl.replace('.pkl', '').replace('\\', '/').split('/') 160 | return '-'.join(p[max(len(p) - 2, 0):]) 161 | 162 | #---------------------------------------------------------------------------- 163 | # Loading data from previous training runs. 164 | 165 | def load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): 166 | return load_pkl(locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl)) 167 | 168 | def parse_config_for_previous_run(run_id): 169 | run_dir = locate_run_dir(run_id) 170 | 171 | # Parse config.txt. 172 | cfg = defaultdict(dict) 173 | with open(os.path.join(run_dir, 'config.txt'), 'rt') as f: 174 | for line in f: 175 | line = re.sub(r"^{?\s*'(\w+)':\s*{(.*)(},|}})$", r"\1 = {\2}", line.strip()) 176 | if line.startswith('dataset =') or line.startswith('train ='): 177 | exec(line, cfg, cfg) # pylint: disable=exec-used 178 | 179 | # Handle legacy options. 
180 | if 'file_pattern' in cfg['dataset']: 181 | cfg['dataset']['tfrecord_dir'] = cfg['dataset'].pop('file_pattern').replace('-r??.tfrecords', '') 182 | if 'mirror_augment' in cfg['dataset']: 183 | cfg['train']['mirror_augment'] = cfg['dataset'].pop('mirror_augment') 184 | if 'max_labels' in cfg['dataset']: 185 | v = cfg['dataset'].pop('max_labels') 186 | if v is None: v = 0 187 | if v == 'all': v = 'full' 188 | cfg['dataset']['max_label_size'] = v 189 | if 'max_images' in cfg['dataset']: 190 | cfg['dataset'].pop('max_images') 191 | return cfg 192 | 193 | def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment 194 | cfg = parse_config_for_previous_run(run_id) 195 | cfg['dataset'].update(kwargs) 196 | dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **cfg['dataset']) 197 | mirror_augment = cfg['train'].get('mirror_augment', False) 198 | return dataset_obj, mirror_augment 199 | 200 | def apply_mirror_augment(minibatch): 201 | mask = np.random.rand(minibatch.shape[0]) < 0.5 202 | minibatch = np.array(minibatch) 203 | minibatch[mask] = minibatch[mask, :, :, ::-1] 204 | return minibatch 205 | 206 | #---------------------------------------------------------------------------- 207 | # Size and contents of the image snapshot grids that are exported 208 | # periodically during training. 209 | 210 | def setup_snapshot_image_grid(G, training_set, 211 | size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display. 212 | layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label. 213 | 214 | # Select size. 215 | gw = 1; gh = 1 216 | if size == '1080p': 217 | gw = np.clip(1920 // G.output_shape[3], 3, 32) 218 | gh = np.clip(1080 // G.output_shape[2], 2, 32) 219 | if size == '4k': 220 | gw = np.clip(3840 // G.output_shape[3], 7, 7) 221 | gh = np.clip(2160 // G.output_shape[2], 6, 6) 222 | 223 | # Initialize data arrays. 224 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) 225 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) 226 | latents = np.random.randn(gw * gh, *G.input_shape[1:]) 227 | 228 | # Random layout. 229 | if layout == 'random': 230 | reals[:], labels[:] = training_set.get_minibatch_np(gw * gh) 231 | 232 | # Class-conditional layouts. 233 | class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4]) 234 | if layout in class_layouts: 235 | bw, bh = class_layouts[layout] 236 | nw = (gw - 1) // bw + 1 237 | nh = (gh - 1) // bh + 1 238 | blocks = [[] for _i in range(nw * nh)] 239 | for _iter in range(1000000): 240 | real, label = training_set.get_minibatch_np(1) 241 | idx = np.argmax(label[0]) 242 | while idx < len(blocks) and len(blocks[idx]) >= bw * bh: 243 | idx += training_set.label_size 244 | if idx < len(blocks): 245 | blocks[idx].append((real, label)) 246 | if all(len(block) >= bw * bh for block in blocks): 247 | break 248 | for i, block in enumerate(blocks): 249 | for j, (real, label) in enumerate(block): 250 | x = (i % nw) * bw + j % bw 251 | y = (i // nw) * bh + j // bw 252 | if x < gw and y < gh: 253 | reals[x + y * gw] = real[0] 254 | labels[x + y * gw] = label[0] 255 | 256 | # print(labels) 257 | 258 | return (gw, gh), reals, labels, latents 259 | 260 | 261 | def setup_test_snapshot_image_grid(G, training_set, 262 | size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display. 
263 | layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label. 264 | 265 | # Select size. 266 | gw = 7; gh = 6 267 | # if size == '1080p': 268 | # gw = np.clip(1920 // G.output_shape[3], 3, 32) 269 | # gh = np.clip(1080 // G.output_shape[2], 2, 32) 270 | # if size == '4k': 271 | # gw = np.clip(3840 // G.output_shape[3], 7, 32) 272 | # gh = np.clip(2160 // G.output_shape[2], 4, 32) 273 | # if size=='custom': 274 | # gw = np.clip(1920 // G.output_shape[3], 3, 32) 275 | # gh = np.clip(1536 // G.output_shape[2], 2, 32) 276 | 277 | # Initialize data arrays. 278 | # np.random.seed(23) 279 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) 280 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) 281 | latents = np.repeat([np.zeros((7, *G.input_shape[1:]))], (gw*gh)//7, axis=0).reshape((gw*gh,*G.input_shape[1:])) 282 | 283 | # latents[1]=latents[0] 284 | 285 | # ------------------------------------ 286 | 287 | # Interpolation 288 | 289 | # temp = latents[1] 290 | # temp = np.random.randn(1, *G.input_shape[1:]) 291 | # for i in range(1,5): 292 | # latents[i] = (latents[0]*(i)+temp*(5-i))/4 293 | # latents[5]=latents[4] 294 | # latents[4]=latents[3] 295 | # latents[3]=(latents[0]+temp)/2.0 296 | 297 | # ------------------------------------ 298 | 299 | # for i in range(gh): 300 | # for j in range(gw): 301 | # labels = np.zeros(6) 302 | # labels[i]= 303 | 304 | # print(latents[2]-latents[1]) 305 | 306 | 307 | # Random layout. 308 | if layout == 'random': 309 | reals[:], labels[:] = training_set.get_minibatch_np(gw * gh) 310 | 311 | # Class-conditional layouts. 312 | class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4]) 313 | if layout in class_layouts: 314 | bw, bh = class_layouts[layout] 315 | nw = (gw - 1) // bw + 1 316 | nh = (gh - 1) // bh + 1 317 | blocks = [[] for _i in range(nw * nh)] 318 | for _iter in range(1000000): 319 | real, label = training_set.get_minibatch_np(1) 320 | idx = np.argmax(label[0]) 321 | while idx < len(blocks) and len(blocks[idx]) >= bw * bh: 322 | idx += training_set.label_size 323 | if idx < len(blocks): 324 | blocks[idx].append((real, label)) 325 | if all(len(block) >= bw * bh for block in blocks): 326 | break 327 | for i, block in enumerate(blocks): 328 | for j, (real, label) in enumerate(block): 329 | x = (i % nw) * bw + j % bw 330 | y = (i // nw) * bh + j // bw 331 | if x < gw and y < gh: 332 | reals[x + y * gw] = real[0] 333 | labels[x + y * gw] = label[0] 334 | 335 | print(labels) 336 | labels[0] = [0,0,0,0,0,0] 337 | 338 | return (gw, gh), reals, labels, latents 339 | 340 | #---------------------------------------------------------------------------- 341 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | # import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import uuid 27 | 28 | from distutils.util import strtobool 29 | from typing import Any, List, Tuple, Union 30 | 31 | 32 | # Util classes 33 | # ------------------------------------------------------------------------------------------ 34 | 35 | 36 | class EasyDict(dict): 37 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 38 | 39 | def __getattr__(self, name: str) -> Any: 40 | try: 41 | return self[name] 42 | except KeyError: 43 | raise AttributeError(name) 44 | 45 | def __setattr__(self, name: str, value: Any) -> None: 46 | self[name] = value 47 | 48 | def __delattr__(self, name: str) -> None: 49 | del self[name] 50 | 51 | 52 | class Logger(object): 53 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 54 | 55 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 56 | self.file = None 57 | 58 | if file_name is not None: 59 | self.file = open(file_name, file_mode) 60 | 61 | self.should_flush = should_flush 62 | self.stdout = sys.stdout 63 | self.stderr = sys.stderr 64 | 65 | sys.stdout = self 66 | sys.stderr = self 67 | 68 | def __enter__(self) -> "Logger": 69 | return self 70 | 71 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 72 | self.close() 73 | 74 | def write(self, text: str) -> None: 75 | """Write text to stdout (and a file) and optionally flush.""" 76 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 77 | return 78 | 79 | if self.file is not None: 80 | self.file.write(text) 81 | 82 | self.stdout.write(text) 83 | 84 | if self.should_flush: 85 | self.flush() 86 | 87 | def flush(self) -> None: 88 | """Flush written text to both stdout and a file, if open.""" 89 | if self.file is not None: 90 | self.file.flush() 91 | 92 | self.stdout.flush() 93 | 94 | def close(self) -> None: 95 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 96 | self.flush() 97 | 98 | # if using multiple loggers, prevent closing in wrong order 99 | if sys.stdout is self: 100 | sys.stdout = self.stdout 101 | if sys.stderr is self: 102 | sys.stderr = self.stderr 103 | 104 | if self.file is not None: 105 | self.file.close() 106 | 107 | 108 | # Small util functions 109 | # ------------------------------------------------------------------------------------------ 110 | 111 | 112 | def format_time(seconds: Union[int, float]) -> str: 113 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 114 | s = int(np.rint(seconds)) 115 | 116 | if s < 60: 117 | return "{0}s".format(s) 118 | elif s < 60 * 60: 119 | return "{0}m {1:02}s".format(s // 60, s % 60) 120 | elif s < 24 * 60 * 60: 121 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 122 | else: 123 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 124 | 125 | 126 | def ask_yes_no(question: str) -> bool: 127 | """Ask the user the question until the user inputs a valid answer.""" 128 | while True: 129 | try: 130 | print("{0} 
[y/n]".format(question)) 131 | return strtobool(input().lower()) 132 | except ValueError: 133 | pass 134 | 135 | 136 | def tuple_product(t: Tuple) -> Any: 137 | """Calculate the product of the tuple elements.""" 138 | result = 1 139 | 140 | for v in t: 141 | result *= v 142 | 143 | return result 144 | 145 | 146 | _str_to_ctype = { 147 | "uint8": ctypes.c_ubyte, 148 | "uint16": ctypes.c_uint16, 149 | "uint32": ctypes.c_uint32, 150 | "uint64": ctypes.c_uint64, 151 | "int8": ctypes.c_byte, 152 | "int16": ctypes.c_int16, 153 | "int32": ctypes.c_int32, 154 | "int64": ctypes.c_int64, 155 | "float32": ctypes.c_float, 156 | "float64": ctypes.c_double 157 | } 158 | 159 | 160 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 161 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 162 | type_str = None 163 | 164 | if isinstance(type_obj, str): 165 | type_str = type_obj 166 | elif hasattr(type_obj, "__name__"): 167 | type_str = type_obj.__name__ 168 | elif hasattr(type_obj, "name"): 169 | type_str = type_obj.name 170 | else: 171 | raise RuntimeError("Cannot infer type name from input") 172 | 173 | assert type_str in _str_to_ctype.keys() 174 | 175 | my_dtype = np.dtype(type_str) 176 | my_ctype = _str_to_ctype[type_str] 177 | 178 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 179 | 180 | return my_dtype, my_ctype 181 | 182 | 183 | def is_pickleable(obj: Any) -> bool: 184 | try: 185 | with io.BytesIO() as stream: 186 | pickle.dump(obj, stream) 187 | return True 188 | except: 189 | return False 190 | 191 | 192 | # Functionality to import modules/objects by name, and call functions by name 193 | # ------------------------------------------------------------------------------------------ 194 | 195 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 196 | """Searches for the underlying module behind the name to some python object. 197 | Returns the module and the object name (original name with module part removed).""" 198 | 199 | # allow convenience shorthands, substitute them by full names 200 | obj_name = re.sub("^np.", "numpy.", obj_name) 201 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 202 | 203 | # list alternatives for (module_name, local_obj_name) 204 | parts = obj_name.split(".") 205 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 206 | 207 | # try each alternative in turn 208 | for module_name, local_obj_name in name_pairs: 209 | try: 210 | module = importlib.import_module(module_name) # may raise ImportError 211 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 212 | return module, local_obj_name 213 | except: 214 | pass 215 | 216 | # maybe some of the modules themselves contain errors? 217 | for module_name, _local_obj_name in name_pairs: 218 | try: 219 | importlib.import_module(module_name) # may raise ImportError 220 | except ImportError: 221 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 222 | raise 223 | 224 | # maybe the requested attribute is missing? 
225 | for module_name, local_obj_name in name_pairs: 226 | try: 227 | module = importlib.import_module(module_name) # may raise ImportError 228 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 229 | except ImportError: 230 | pass 231 | 232 | # we are out of luck, but we have no idea why 233 | raise ImportError(obj_name) 234 | 235 | 236 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 237 | """Traverses the object name and returns the last (rightmost) python object.""" 238 | if obj_name == '': 239 | return module 240 | obj = module 241 | for part in obj_name.split("."): 242 | obj = getattr(obj, part) 243 | return obj 244 | 245 | 246 | def get_obj_by_name(name: str) -> Any: 247 | """Finds the python object with the given name.""" 248 | module, obj_name = get_module_from_obj_name(name) 249 | return get_obj_from_module(module, obj_name) 250 | 251 | 252 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 253 | """Finds the python object with the given name and calls it as a function.""" 254 | assert func_name is not None 255 | func_obj = get_obj_by_name(func_name) 256 | assert callable(func_obj) 257 | return func_obj(*args, **kwargs) 258 | 259 | 260 | def get_module_dir_by_obj_name(obj_name: str) -> str: 261 | """Get the directory path of the module containing the given object name.""" 262 | module, _ = get_module_from_obj_name(obj_name) 263 | return os.path.dirname(inspect.getfile(module)) 264 | 265 | 266 | def is_top_level_function(obj: Any) -> bool: 267 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 268 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 269 | 270 | 271 | def get_top_level_function_name(obj: Any) -> str: 272 | """Return the fully-qualified name of a top-level function.""" 273 | assert is_top_level_function(obj) 274 | return obj.__module__ + "." + obj.__name__ 275 | 276 | 277 | # File system helpers 278 | # ------------------------------------------------------------------------------------------ 279 | 280 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 281 | """List all files recursively in a given directory while ignoring given file and directory names. 282 | Returns list of tuples containing both absolute and relative paths.""" 283 | assert os.path.isdir(dir_path) 284 | base_name = os.path.basename(os.path.normpath(dir_path)) 285 | 286 | if ignores is None: 287 | ignores = [] 288 | 289 | result = [] 290 | 291 | for root, dirs, files in os.walk(dir_path, topdown=True): 292 | for ignore_ in ignores: 293 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 294 | 295 | # dirs need to be edited in-place 296 | for d in dirs_to_remove: 297 | dirs.remove(d) 298 | 299 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 300 | 301 | absolute_paths = [os.path.join(root, f) for f in files] 302 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 303 | 304 | if add_base_to_relative: 305 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 306 | 307 | assert len(absolute_paths) == len(relative_paths) 308 | result += zip(absolute_paths, relative_paths) 309 | 310 | return result 311 | 312 | 313 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 314 | """Takes in a list of tuples of (src, dst) paths and copies files. 
315 | Will create all necessary directories.""" 316 | for file in files: 317 | target_dir_name = os.path.dirname(file[1]) 318 | 319 | if '/results/' in file[0] or '/patch_results/' in file[0]: 320 | continue 321 | # will create all intermediate-level directories 322 | if not os.path.exists(target_dir_name): 323 | os.makedirs(target_dir_name) 324 | 325 | shutil.copyfile(file[0], file[1]) 326 | 327 | 328 | # URL helpers 329 | # ------------------------------------------------------------------------------------------ 330 | 331 | def is_url(obj: Any) -> bool: 332 | """Determine whether the given object is a valid URL string.""" 333 | if not isinstance(obj, str) or not "://" in obj: 334 | return False 335 | try: 336 | res = requests.compat.urlparse(obj) 337 | if not res.scheme or not res.netloc or not "." in res.netloc: 338 | return False 339 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 340 | if not res.scheme or not res.netloc or not "." in res.netloc: 341 | return False 342 | except: 343 | return False 344 | return True 345 | 346 | 347 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any: 348 | """Download the given URL and return a binary-mode file object to access the data.""" 349 | assert is_url(url) 350 | assert num_attempts >= 1 351 | 352 | # Lookup from cache. 353 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 354 | if cache_dir is not None: 355 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 356 | if len(cache_files) == 1: 357 | return open(cache_files[0], "rb") 358 | 359 | # Download. 360 | url_name = None 361 | url_data = None 362 | with requests.Session() as session: 363 | if verbose: 364 | print("Downloading %s ..." % url, end="", flush=True) 365 | for attempts_left in reversed(range(num_attempts)): 366 | try: 367 | with session.get(url) as res: 368 | res.raise_for_status() 369 | if len(res.content) == 0: 370 | raise IOError("No data received") 371 | 372 | if len(res.content) < 8192: 373 | content_str = res.content.decode("utf-8") 374 | if "download_warning" in res.headers.get("Set-Cookie", ""): 375 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 376 | if len(links) == 1: 377 | url = requests.compat.urljoin(url, links[0]) 378 | raise IOError("Google Drive virus checker nag") 379 | if "Google Drive - Quota exceeded" in content_str: 380 | raise IOError("Google Drive quota exceeded") 381 | 382 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 383 | url_name = match[1] if match else url 384 | url_data = res.content 385 | if verbose: 386 | print(" done") 387 | break 388 | except: 389 | if not attempts_left: 390 | if verbose: 391 | print(" failed") 392 | raise 393 | if verbose: 394 | print(".", end="", flush=True) 395 | 396 | # Save to cache. 397 | if cache_dir is not None: 398 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 399 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 400 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 401 | os.makedirs(cache_dir, exist_ok=True) 402 | with open(temp_file, "wb") as f: 403 | f.write(url_data) 404 | os.replace(temp_file, cache_file) # atomic 405 | 406 | # Return data as file object. 
407 | return io.BytesIO(url_data) 408 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Main entry point for training StyleGAN and ProGAN networks.""" 9 | 10 | import copy 11 | import dnnlib 12 | from dnnlib import EasyDict 13 | 14 | import config 15 | from metrics import metric_base 16 | 17 | #---------------------------------------------------------------------------- 18 | # Official training configs for StyleGAN, targeted mainly for FFHQ. 19 | 20 | if 1: 21 | desc = 'sgan' # Description string included in result subdir name. 22 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. 23 | mixing = EasyDict(run_func_name='training.training_loop.mixing') 24 | test_d = EasyDict(run_func_name='training.training_loop.test_d') 25 | G = EasyDict(func_name='training.networks_stylegan.G_style') # Options for generator network. 26 | D = EasyDict(func_name='training.networks_stylegan.D_basic') # Options for discriminator network. 27 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. 28 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. 29 | G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating') # Options for generator loss. 30 | D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss. 31 | dataset = EasyDict() # Options for load_dataset(). 32 | sched = EasyDict() # Options for TrainingSchedule. 33 | grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid(). 34 | metrics = [metric_base.fid50k] # Options for MetricGroup. 35 | submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). 36 | tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). 37 | 38 | # Dataset. 39 | # desc+= '-mixing' 40 | 41 | desc += '-mri'; dataset = EasyDict(tfrecord_dir='ixi-128'); train.mirror_augment = False 42 | # desc += '-test-d' 43 | 44 | # desc += '-ffhq'; dataset = EasyDict(tfrecord_dir='ffhq'); train.mirror_augment = True 45 | #desc += '-ffhq512'; dataset = EasyDict(tfrecord_dir='ffhq', resolution=512); train.mirror_augment = True 46 | #desc += '-ffhq256'; dataset = EasyDict(tfrecord_dir='ffhq', resolution=256); train.mirror_augment = True 47 | #desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True 48 | #desc += '-bedroom'; dataset = EasyDict(tfrecord_dir='lsun-bedroom-full'); train.mirror_augment = False 49 | #desc += '-car'; dataset = EasyDict(tfrecord_dir='lsun-car-512x384'); train.mirror_augment = False 50 | #desc += '-cat'; dataset = EasyDict(tfrecord_dir='lsun-cat-full'); train.mirror_augment = False 51 | 52 | # Number of GPUs. 
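The presets below set sched.minibatch_dict, the total minibatch size keyed by resolution; in the reference ProGAN/StyleGAN schedule this total is split evenly across submit_config.num_gpus, with sched.minibatch_base as the fallback for unlisted resolutions. Reading the active '-8gpu' line as plain arithmetic (a sketch, not the schedule code):

num_gpus = 8
minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 128, 64: 64, 128: 32, 256: 16}
for res in sorted(minibatch_dict):
    print('res %3d^3: %3d total -> %2d per GPU' % (res, minibatch_dict[res], minibatch_dict[res] // num_gpus))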
53 | # desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 128, 64: 64, 128: 32, 256: 16} 54 | # desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8} 55 | #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16} 56 | desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 128, 64: 64, 128: 32, 256: 16} 57 | 58 | # Default options. 59 | train.total_kimg = 25000 60 | sched.lod_initial_resolution = 4 61 | sched.G_lrate_dict = {128: 0.005, 256: 0.002, 512: 0.003, 1024: 0.003} 62 | sched.D_lrate_dict = {128: 0.002, 256: 0.002, 512: 0.003, 1024: 0.003} 63 | 64 | # train.resume_run_id=51 65 | # train.test=True 66 | # train.is_training=False 67 | # train.resume_snapshot=10360 68 | # train.resume_kimg=6060 69 | 70 | # test_d.resume_run_id=46 71 | 72 | # mixing.resume_run_id=29 73 | 74 | # WGAN-GP loss for CelebA-HQ. 75 | #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) 76 | 77 | # Table 1. 78 | #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 79 | #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 80 | #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False 81 | #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0 82 | #desc += '-mixing-regularization' # default 83 | 84 | # Table 2. 85 | #desc += '-mix0'; G.style_mixing_prob = 0.0 86 | #desc += '-mix50'; G.style_mixing_prob = 0.5 87 | #desc += '-mix90'; G.style_mixing_prob = 0.9 # default 88 | #desc += '-mix100'; G.style_mixing_prob = 1.0 89 | 90 | # Table 4. 91 | #desc += '-traditional-0'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 92 | #desc += '-traditional-8'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 8; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 93 | #desc += '-stylebased-0'; G.mapping_layers = 0 94 | #desc += '-stylebased-1'; G.mapping_layers = 1 95 | #desc += '-stylebased-2'; G.mapping_layers = 2 96 | #desc += '-stylebased-8'; G.mapping_layers = 8 # default 97 | 98 | # desc += '-cond'; dataset.max_label_size = 'full' # conditioned on full label 99 | # desc += '-grpc'; grid.layout = 'row_per_class' 100 | # grid.size='4k' 101 | 102 | #---------------------------------------------------------------------------- 103 | # Official training configs for Progressive GAN, targeted mainly for CelebA-HQ. 104 | 105 | if 0: 106 | desc = 'pgan' # Description string included in result subdir name. 107 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. 108 | G = EasyDict(func_name='training.networks_progan.G_paper') # Options for generator network. 
109 | D = EasyDict(func_name='training.networks_progan.D_paper') # Options for discriminator network. 110 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. 111 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. 112 | G_loss = EasyDict(func_name='training.loss.G_wgan') # Options for generator loss. 113 | D_loss = EasyDict(func_name='training.loss.D_wgan_gp') # Options for discriminator loss. 114 | dataset = EasyDict() # Options for load_dataset(). 115 | sched = EasyDict() # Options for TrainingSchedule. 116 | grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid(). 117 | metrics = [metric_base.fid50k] # Options for MetricGroup. 118 | submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). 119 | tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). 120 | 121 | # Dataset (choose one). 122 | desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True 123 | #desc += '-celeba'; dataset = EasyDict(tfrecord_dir='celeba'); train.mirror_augment = True 124 | #desc += '-cifar10'; dataset = EasyDict(tfrecord_dir='cifar10') 125 | #desc += '-cifar100'; dataset = EasyDict(tfrecord_dir='cifar100') 126 | #desc += '-svhn'; dataset = EasyDict(tfrecord_dir='svhn') 127 | #desc += '-mnist'; dataset = EasyDict(tfrecord_dir='mnist') 128 | #desc += '-mnistrgb'; dataset = EasyDict(tfrecord_dir='mnistrgb') 129 | #desc += '-syn1024rgb'; dataset = EasyDict(class_name='training.dataset.SyntheticDataset', resolution=1024, num_channels=3) 130 | #desc += '-lsun-airplane'; dataset = EasyDict(tfrecord_dir='lsun-airplane-100k'); train.mirror_augment = True 131 | #desc += '-lsun-bedroom'; dataset = EasyDict(tfrecord_dir='lsun-bedroom-100k'); train.mirror_augment = True 132 | #desc += '-lsun-bicycle'; dataset = EasyDict(tfrecord_dir='lsun-bicycle-100k'); train.mirror_augment = True 133 | #desc += '-lsun-bird'; dataset = EasyDict(tfrecord_dir='lsun-bird-100k'); train.mirror_augment = True 134 | #desc += '-lsun-boat'; dataset = EasyDict(tfrecord_dir='lsun-boat-100k'); train.mirror_augment = True 135 | #desc += '-lsun-bottle'; dataset = EasyDict(tfrecord_dir='lsun-bottle-100k'); train.mirror_augment = True 136 | #desc += '-lsun-bridge'; dataset = EasyDict(tfrecord_dir='lsun-bridge-100k'); train.mirror_augment = True 137 | #desc += '-lsun-bus'; dataset = EasyDict(tfrecord_dir='lsun-bus-100k'); train.mirror_augment = True 138 | #desc += '-lsun-car'; dataset = EasyDict(tfrecord_dir='lsun-car-100k'); train.mirror_augment = True 139 | #desc += '-lsun-cat'; dataset = EasyDict(tfrecord_dir='lsun-cat-100k'); train.mirror_augment = True 140 | #desc += '-lsun-chair'; dataset = EasyDict(tfrecord_dir='lsun-chair-100k'); train.mirror_augment = True 141 | #desc += '-lsun-churchoutdoor'; dataset = EasyDict(tfrecord_dir='lsun-churchoutdoor-100k'); train.mirror_augment = True 142 | #desc += '-lsun-classroom'; dataset = EasyDict(tfrecord_dir='lsun-classroom-100k'); train.mirror_augment = True 143 | #desc += '-lsun-conferenceroom'; dataset = EasyDict(tfrecord_dir='lsun-conferenceroom-100k'); train.mirror_augment = True 144 | #desc += '-lsun-cow'; dataset = EasyDict(tfrecord_dir='lsun-cow-100k'); train.mirror_augment = True 145 | #desc += '-lsun-diningroom'; dataset = EasyDict(tfrecord_dir='lsun-diningroom-100k'); train.mirror_augment = True 146 | #desc += '-lsun-diningtable'; dataset = EasyDict(tfrecord_dir='lsun-diningtable-100k'); train.mirror_augment = True 147 | 
#desc += '-lsun-dog'; dataset = EasyDict(tfrecord_dir='lsun-dog-100k'); train.mirror_augment = True 148 | #desc += '-lsun-horse'; dataset = EasyDict(tfrecord_dir='lsun-horse-100k'); train.mirror_augment = True 149 | #desc += '-lsun-kitchen'; dataset = EasyDict(tfrecord_dir='lsun-kitchen-100k'); train.mirror_augment = True 150 | #desc += '-lsun-livingroom'; dataset = EasyDict(tfrecord_dir='lsun-livingroom-100k'); train.mirror_augment = True 151 | #desc += '-lsun-motorbike'; dataset = EasyDict(tfrecord_dir='lsun-motorbike-100k'); train.mirror_augment = True 152 | #desc += '-lsun-person'; dataset = EasyDict(tfrecord_dir='lsun-person-100k'); train.mirror_augment = True 153 | #desc += '-lsun-pottedplant'; dataset = EasyDict(tfrecord_dir='lsun-pottedplant-100k'); train.mirror_augment = True 154 | #desc += '-lsun-restaurant'; dataset = EasyDict(tfrecord_dir='lsun-restaurant-100k'); train.mirror_augment = True 155 | #desc += '-lsun-sheep'; dataset = EasyDict(tfrecord_dir='lsun-sheep-100k'); train.mirror_augment = True 156 | #desc += '-lsun-sofa'; dataset = EasyDict(tfrecord_dir='lsun-sofa-100k'); train.mirror_augment = True 157 | #desc += '-lsun-tower'; dataset = EasyDict(tfrecord_dir='lsun-tower-100k'); train.mirror_augment = True 158 | #desc += '-lsun-train'; dataset = EasyDict(tfrecord_dir='lsun-train-100k'); train.mirror_augment = True 159 | #desc += '-lsun-tvmonitor'; dataset = EasyDict(tfrecord_dir='lsun-tvmonitor-100k'); train.mirror_augment = True 160 | 161 | # Conditioning & snapshot options. 162 | #desc += '-cond'; dataset.max_label_size = 'full' # conditioned on full label 163 | #desc += '-cond1'; dataset.max_label_size = 1 # conditioned on first component of the label 164 | #desc += '-g4k'; grid.size = '4k' 165 | #desc += '-grpc'; grid.layout = 'row_per_class' 166 | 167 | # Config presets (choose one). 168 | #desc += '-preset-v1-1gpu'; submit_config.num_gpus = 1; D.mbstd_group_size = 16; sched.minibatch_base = 16; sched.minibatch_dict = {256: 14, 512: 6, 1024: 3}; sched.lod_training_kimg = 800; sched.lod_transition_kimg = 800; train.total_kimg = 19000 169 | desc += '-preset-v2-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.G_lrate_dict = {1024: 0.0015}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 170 | #desc += '-preset-v2-2gpus'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}; sched.G_lrate_dict = {512: 0.0015, 1024: 0.002}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 171 | #desc += '-preset-v2-4gpus'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}; sched.G_lrate_dict = {256: 0.0015, 512: 0.002, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 172 | #desc += '-preset-v2-8gpus'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}; sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 173 | 174 | # Numerical precision (choose one). 
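sched.max_minibatch_per_gpu below caps the per-GPU batch at the largest resolutions, where activation memory dominates; the '-fp16' alternative additionally switches both networks to float16 and enables loss scaling in the optimizers. A hedged reading of the fp32 cap, assuming the schedule clamps the per-GPU share as in the reference ProGAN code:

num_gpus = 8
minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 128, 64: 64, 128: 32, 256: 16}
max_minibatch_per_gpu = {256: 16, 512: 8, 1024: 4}
res = 256
per_gpu = min(minibatch_dict[res] // num_gpus, max_minibatch_per_gpu[res])
print(per_gpu)  # 2: the 8-way split binds here, not the cap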
175 | desc += '-fp32'; sched.max_minibatch_per_gpu = {256: 16, 512: 8, 1024: 4} 176 | #desc += '-fp16'; G.dtype = 'float16'; D.dtype = 'float16'; G.pixelnorm_epsilon=1e-4; G_opt.use_loss_scaling = True; D_opt.use_loss_scaling = True; sched.max_minibatch_per_gpu = {512: 16, 1024: 8} 177 | 178 | # Disable individual features. 179 | #desc += '-nogrowing'; sched.lod_initial_resolution = 1024; sched.lod_training_kimg = 0; sched.lod_transition_kimg = 0; train.total_kimg = 10000 180 | #desc += '-nopixelnorm'; G.use_pixelnorm = False 181 | #desc += '-nowscale'; G.use_wscale = False; D.use_wscale = False 182 | #desc += '-noleakyrelu'; G.use_leakyrelu = False 183 | #desc += '-nosmoothing'; train.G_smoothing_kimg = 0.0 184 | #desc += '-norepeat'; train.minibatch_repeats = 1 185 | #desc += '-noreset'; train.reset_opt_for_new_lod = False 186 | 187 | # Special modes. 188 | #desc += '-BENCHMARK'; sched.lod_initial_resolution = 4; sched.lod_training_kimg = 3; sched.lod_transition_kimg = 3; train.total_kimg = (8*2+1)*3; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000 189 | #desc += '-BENCHMARK0'; sched.lod_initial_resolution = 1024; train.total_kimg = 10; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000 190 | #desc += '-VERBOSE'; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1; train.network_snapshot_ticks = 100 191 | #desc += '-GRAPH'; train.save_tf_graph = True 192 | #desc += '-HIST'; train.save_weight_histograms = True 193 | 194 | #---------------------------------------------------------------------------- 195 | # Main entry point for training. 196 | # Calls the function indicated by 'train' using the selected options. 197 | 198 | def main(): 199 | kwargs = EasyDict(train) 200 | kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss) 201 | kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config) 202 | # kwargs.update(dataset_args=dataset) 203 | kwargs.submit_config = copy.deepcopy(submit_config) 204 | kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) 205 | kwargs.submit_config.run_dir_ignore += config.run_dir_ignore 206 | kwargs.submit_config.run_desc = desc 207 | dnnlib.submit_run(**kwargs) 208 | 209 | #---------------------------------------------------------------------------- 210 | 211 | if __name__ == "__main__": 212 | main() 213 | 214 | #---------------------------------------------------------------------------- 215 | -------------------------------------------------------------------------------- /training/networks_progan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | # NOTE: Do not import any application-specific modules here! 
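The lerp/lerp_clip/cset helpers defined just below drive the progressive-growing fade-in: the new resolution's output is blended with the upscaled previous output using weight lod_in - lod, clipped to [0, 1]. A toy NumPy trace of that blend (illustrative values only):

import numpy as np

def lerp_clip_np(a, b, t):
    return a + (b - a) * np.clip(t, 0.0, 1.0)

img_new, img_old = 1.0, 0.0        # stand-ins for torgb(new) and the upscaled old output
for lod_in in (3.0, 2.7, 2.0):     # fading in the layer at lod = 2
    print(lerp_clip_np(img_new, img_old, lod_in - 2.0))
# prints 0.0, 0.3, 1.0: fully old output, then 30% new, then fully new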
12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def lerp(a, b, t): return a + (b - a) * t 16 | def lerp_clip(a, b, t): return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 17 | def cset(cur_lambda, new_cond, new_lambda): return lambda: tf.cond(new_cond, new_lambda, cur_lambda) 18 | 19 | #---------------------------------------------------------------------------- 20 | # Get/create weight tensor for a convolutional or fully-connected layer. 21 | 22 | def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None): 23 | if fan_in is None: fan_in = np.prod(shape[:-1]) 24 | std = gain / np.sqrt(fan_in) # He init 25 | if use_wscale: 26 | wscale = tf.constant(np.float32(std), name='wscale') 27 | return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale 28 | else: 29 | return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std)) 30 | 31 | #---------------------------------------------------------------------------- 32 | # Fully-connected layer. 33 | 34 | def dense(x, fmaps, gain=np.sqrt(2), use_wscale=False): 35 | if len(x.shape) > 2: 36 | x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])]) 37 | w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 38 | w = tf.cast(w, x.dtype) 39 | return tf.matmul(x, w) 40 | 41 | #---------------------------------------------------------------------------- 42 | # Convolutional layer. 43 | 44 | def conv3d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 45 | assert kernel >= 1 and kernel % 2 == 1 46 | w = get_weight([kernel, kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 47 | w = tf.cast(w, x.dtype) 48 | return tf.nn.conv3d(x, w, strides=[1,1,1,1,1], padding='SAME', data_format='NCDHW') 49 | 50 | #---------------------------------------------------------------------------- 51 | # Apply bias to the given activation tensor. 52 | 53 | def apply_bias(x): 54 | b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros()) 55 | b = tf.cast(b, x.dtype) 56 | if len(x.shape) == 2: 57 | return x + b 58 | else: 59 | return x + tf.reshape(b, [1, -1, 1, 1, 1]) 60 | 61 | #---------------------------------------------------------------------------- 62 | # Leaky ReLU activation. Same as tf.nn.leaky_relu, but supports FP16. 63 | 64 | def leaky_relu(x, alpha=0.2): 65 | with tf.name_scope('LeakyRelu'): 66 | alpha = tf.constant(alpha, dtype=x.dtype, name='alpha') 67 | return tf.maximum(x * alpha, x) 68 | 69 | #---------------------------------------------------------------------------- 70 | # Nearest-neighbor upscaling layer. 71 | 72 | def upscale3d(x, factor=2): 73 | assert isinstance(factor, int) and factor >= 1 74 | if factor == 1: return x 75 | with tf.variable_scope('Upscale3D'): 76 | s = x.shape 77 | x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1, s[4], 1]) 78 | x = tf.tile(x, [1, 1, 1, factor, 1, factor, 1, factor]) 79 | x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor, s[4] * factor]) 80 | return x 81 | 82 | #---------------------------------------------------------------------------- 83 | # Fused upscale3d + conv3d. 84 | # Faster and uses less memory than performing the operations separately.
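The fused layer below folds nearest-neighbor 2x upscaling into a single stride-2 transposed convolution: the k^3 weight is zero-padded and its eight unit-shifted copies are summed, giving an effective (k+1)^3 kernel. A 1-D analogue of the weight transform, as a sketch of the idea rather than the actual op:

import numpy as np

w = np.array([1.0, 2.0, 3.0])   # a k=3 kernel
w_pad = np.pad(w, 1)            # [0, 1, 2, 3, 0]
w_eff = w_pad[1:] + w_pad[:-1]  # [1, 3, 5, 3] -- the (k+1)-tap stride-2 transposed-conv kernel
print(w_eff)

In 3D the same shift-and-add runs over all three spatial axes, which is what the tf.add_n over eight slices computes.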
85 | 86 | def upscale3d_conv3d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 87 | assert kernel >= 1 and kernel % 2 == 1 88 | w = get_weight([kernel, kernel, kernel, fmaps, x.shape[1].value], gain=gain, use_wscale=use_wscale, fan_in=(kernel**3)*x.shape[1].value) # fan-in of the full 3D kernel, consistent with get_weight's default 89 | w = tf.pad(w, [[1,1], [1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') 90 | w = tf.add_n([w[1:, 1:, 1:], w[1:, 1:, :-1], w[1:, :-1, 1:], w[1:, :-1, :-1], w[:-1, 1:, 1:], w[:-1, 1:, :-1], w[:-1, :-1, 1:], w[:-1, :-1, :-1]]) 91 | w = tf.cast(w, x.dtype) 92 | os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2, x.shape[4] * 2] 93 | return tf.nn.conv3d_transpose(x, w, os, strides=[1,1,2,2,2], padding='SAME', data_format='NCDHW') 94 | 95 | #---------------------------------------------------------------------------- 96 | # Box filter downscaling layer. 97 | 98 | def downscale3d(x, factor=2): 99 | assert isinstance(factor, int) and factor >= 1 100 | if factor == 1: return x 101 | with tf.variable_scope('Downscale3D'): 102 | ksize = [1, 1, factor, factor, factor] 103 | return tf.nn.avg_pool3d(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCDHW') # NOTE: requires tf_config['graph_options.place_pruned_graph'] = True 104 | 105 | #---------------------------------------------------------------------------- 106 | # Fused conv3d + downscale3d. 107 | # Faster and uses less memory than performing the operations separately. 108 | 109 | def conv3d_downscale3d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 110 | assert kernel >= 1 and kernel % 2 == 1 111 | w = get_weight([kernel, kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 112 | w = tf.pad(w, [[1,1], [1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') 113 | w = tf.add_n([w[1:, 1:, 1:], w[1:, 1:, :-1], w[1:, :-1, 1:], w[1:, :-1, :-1], w[:-1, 1:, 1:], w[:-1, 1:, :-1], w[:-1, :-1, 1:], w[:-1, :-1, :-1]]) * 0.125 114 | w = tf.cast(w, x.dtype) 115 | return tf.nn.conv3d(x, w, strides=[1,1,2,2,2], padding='SAME', data_format='NCDHW') 116 | 117 | #---------------------------------------------------------------------------- 118 | # Pixelwise feature vector normalization. 119 | 120 | def pixel_norm(x, epsilon=1e-8): 121 | with tf.variable_scope('PixelNorm'): 122 | return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon) 123 | 124 | #---------------------------------------------------------------------------- 125 | # Minibatch standard deviation. 126 | 127 | def minibatch_stddev_layer(x, group_size=4): 128 | with tf.variable_scope('MinibatchStddev'): 129 | group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. 130 | s = x.shape # [NCDHW] Input shape. 131 | y = tf.reshape(x, [group_size, -1, s[1], s[2], s[3], s[4]]) # [GMCDHW] Split minibatch into M groups of size G. 132 | y = tf.cast(y, tf.float32) # [GMCDHW] Cast to FP32. 133 | y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMCDHW] Subtract mean over group. 134 | y = tf.reduce_mean(tf.square(y), axis=0) # [MCDHW] Calc variance over group. 135 | y = tf.sqrt(y + 1e-8) # [MCDHW] Calc stddev over group. 136 | y = tf.reduce_mean(y, axis=[1,2,3,4], keepdims=True) # [M1111] Take average over fmaps and voxels. 137 | y = tf.cast(y, x.dtype) # [M1111] Cast back to original data type. 138 | y = tf.tile(y, [group_size, 1, s[2], s[3], s[4]]) # [N1DHW] Replicate over group and voxels. 139 | return tf.concat([x, y], axis=1) # [NCDHW] Append as new fmap.
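A NumPy re-derivation of minibatch_stddev_layer above for an NCDHW batch, useful for checking the shape bookkeeping; a sketch that mirrors the TF graph, not a drop-in replacement:

import numpy as np

def minibatch_stddev_np(x, group_size=4):
    G = min(group_size, x.shape[0])                # minibatch must be divisible by G
    N, C, D, H, W = x.shape
    y = x.reshape(G, -1, C, D, H, W)               # split minibatch into M groups of size G
    y = y - y.mean(axis=0, keepdims=True)          # subtract per-group mean
    y = np.sqrt(np.square(y).mean(axis=0) + 1e-8)  # stddev over each group
    y = y.mean(axis=(1, 2, 3, 4), keepdims=True)   # average over fmaps and voxels -> [M,1,1,1,1]
    y = np.tile(y, (G, 1, D, H, W))                # replicate over group and voxels -> [N,1,D,H,W]
    return np.concatenate([x, y], axis=1)          # append as an extra feature map

x = np.random.randn(8, 2, 4, 4, 4).astype(np.float32)
print(minibatch_stddev_np(x).shape)  # (8, 3, 4, 4, 4)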
140 | 141 | #---------------------------------------------------------------------------- 142 | # Generator network used in the paper. 143 | 144 | def G_paper( 145 | latents_in, # First input: Latent vectors [minibatch, latent_size]. 146 | labels_in, # Second input: Labels [minibatch, label_size]. 147 | num_channels = 1, # Number of output color channels. Overridden based on dataset. 148 | resolution = 32, # Output resolution. Overridden based on dataset. 149 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 150 | fmap_base = 2048, # Overall multiplier for the number of feature maps. 151 | fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. 152 | fmap_max = 512, # Maximum number of feature maps in any layer. 153 | latent_size = 2048, # Dimensionality of the latent vectors. None = min(fmap_base, fmap_max). 154 | normalize_latents = True, # Normalize latent vectors before feeding them to the network? 155 | use_wscale = True, # Enable equalized learning rate? 156 | use_pixelnorm = True, # Enable pixelwise feature vector normalization? 157 | pixelnorm_epsilon = 1e-8, # Constant epsilon for pixelwise feature vector normalization. 158 | use_leakyrelu = True, # True = leaky ReLU, False = ReLU. 159 | dtype = 'float32', # Data type to use for activations and outputs. 160 | fused_scale = True, # True = use fused upscale3d + conv3d, False = separate upscale3d layers. 161 | structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically. 162 | is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. 163 | **kwargs): # Ignore unrecognized keyword args. 164 | 165 | resolution_log2 = int(np.log2(resolution)) 166 | assert resolution == 2**resolution_log2 and resolution >= 4 167 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 168 | def PN(x): return pixel_norm(x, epsilon=pixelnorm_epsilon) if use_pixelnorm else x 169 | if latent_size is None: latent_size = nf(0) 170 | if structure is None: structure = 'linear' if is_template_graph else 'recursive' 171 | act = leaky_relu if use_leakyrelu else tf.nn.relu 172 | 173 | latents_in.set_shape([None, latent_size]) 174 | labels_in.set_shape([None, label_size]) 175 | combo_in = tf.cast(tf.concat([latents_in, labels_in], axis=1), dtype) 176 | lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) 177 | 178 | # Building blocks.
179 | def block(x, res): # res = 2..resolution_log2 180 | with tf.variable_scope('%dx%d' % (2**res, 2**res)): 181 | if res == 2: # 4x4 182 | if normalize_latents: x = pixel_norm(x, epsilon=pixelnorm_epsilon) 183 | with tf.variable_scope('Dense'): 184 | x = dense(x, fmaps=nf(res-1)*64, gain=np.sqrt(2)/4, use_wscale=use_wscale) # override gain to match the original Theano implementation 185 | x = tf.reshape(x, [-1, nf(res-1), 4, 4, 4]) 186 | x = PN(act(apply_bias(x))) 187 | with tf.variable_scope('Conv'): 188 | x = PN(act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 189 | else: # 8x8 and up 190 | if fused_scale: 191 | with tf.variable_scope('Conv0_up'): 192 | x = PN(act(apply_bias(upscale3d_conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 193 | else: 194 | x = upscale3d(x) 195 | with tf.variable_scope('Conv0'): 196 | x = PN(act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 197 | # with tf.variable_scope('Conv1'): 198 | # x = PN(act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 199 | return x 200 | def torgb(x, res): # res = 2..resolution_log2 201 | lod = resolution_log2 - res 202 | with tf.variable_scope('ToRGB_lod%d' % lod): 203 | return apply_bias(conv3d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale)) 204 | 205 | # Linear structure: simple but inefficient. 206 | if structure == 'linear': 207 | x = block(combo_in, 2) 208 | images_out = torgb(x, 2) 209 | for res in range(3, resolution_log2 + 1): 210 | lod = resolution_log2 - res 211 | x = block(x, res) 212 | img = torgb(x, res) 213 | images_out = upscale3d(images_out) 214 | with tf.variable_scope('Grow_lod%d' % lod): 215 | images_out = lerp_clip(img, images_out, lod_in - lod) 216 | 217 | # Recursive structure: complex but efficient. 218 | if structure == 'recursive': 219 | def grow(x, res, lod): 220 | y = block(x, res) 221 | img = lambda: upscale3d(torgb(y, res), 2**lod) 222 | if res > 2: img = cset(img, (lod_in > lod), lambda: upscale3d(lerp(torgb(y, res), upscale3d(torgb(x, res - 1)), lod_in - lod), 2**lod)) 223 | if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1)) 224 | return img() 225 | images_out = grow(combo_in, 2, resolution_log2 - 2) 226 | 227 | assert images_out.dtype == tf.as_dtype(dtype) 228 | images_out = tf.identity(images_out, name='images_out') 229 | return images_out 230 | 231 | #---------------------------------------------------------------------------- 232 | # Discriminator network used in the paper. 233 | 234 | def D_paper( 235 | images_in, # Input: Images [minibatch, channel, depth, height, width]. 236 | labels_in, # Second input: Labels [minibatch, label_size]. 237 | num_channels = 1, # Number of input color channels. Overridden based on dataset. 238 | resolution = 32, # Input resolution. Overridden based on dataset. 239 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 240 | fmap_base = 2048, # Overall multiplier for the number of feature maps. 241 | fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. 242 | fmap_max = 512, # Maximum number of feature maps in any layer. 243 | use_wscale = True, # Enable equalized learning rate? 244 | mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable. 245 | dtype = 'float32', # Data type to use for activations and outputs. 246 | fused_scale = True, # True = use fused conv3d + downscale3d, False = separate downscale3d layers.
232 | # Discriminator network used in the paper. 233 | 234 | def D_paper( 235 | images_in, # First input: Images [minibatch, channel, depth, height, width]. 236 | labels_in, # Second input: Labels [minibatch, label_size]. 237 | num_channels = 1, # Number of input color channels. Overridden based on dataset. 238 | resolution = 32, # Input resolution. Overridden based on dataset. 239 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 240 | fmap_base = 2048, # Overall multiplier for the number of feature maps. 241 | fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. 242 | fmap_max = 512, # Maximum number of feature maps in any layer. 243 | use_wscale = True, # Enable equalized learning rate? 244 | mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable. 245 | dtype = 'float32', # Data type to use for activations and outputs. 246 | fused_scale = True, # True = use fused conv3d + downscale3d, False = separate downscale3d layers. 247 | structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically. 248 | is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. 249 | **kwargs): # Ignore unrecognized keyword args. 250 | 251 | resolution_log2 = int(np.log2(resolution)) 252 | assert resolution == 2**resolution_log2 and resolution >= 4 253 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 254 | if structure is None: structure = 'linear' if is_template_graph else 'recursive' 255 | act = leaky_relu 256 | 257 | images_in.set_shape([None, num_channels, resolution, resolution, resolution]) 258 | images_in = tf.cast(images_in, dtype) 259 | labels_in.set_shape([None, label_size]) 260 | labels_in = tf.cast(labels_in, dtype) 261 | lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) 262 | 263 | scores_out = None 264 | 265 | # Building blocks. 266 | def fromrgb(x, res): # res = 2..resolution_log2 267 | with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)): 268 | # print(res, nf(res-1)) 269 | return act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=1, use_wscale=use_wscale))) 270 | def block(x, res): # res = 2..resolution_log2 271 | with tf.variable_scope('%dx%d' % (2**res, 2**res)): 272 | if res >= 3: # 8x8x8 and up 273 | with tf.variable_scope('Conv0'): 274 | x = act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) 275 | if fused_scale: 276 | with tf.variable_scope('Conv1_down'): 277 | x = act(apply_bias(conv3d_downscale3d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) 278 | else: 279 | with tf.variable_scope('Conv1'): 280 | x = act(apply_bias(conv3d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) 281 | x = downscale3d(x) 282 | else: # 4x4x4 283 | if mbstd_group_size > 1: 284 | x = minibatch_stddev_layer(x, mbstd_group_size) 285 | with tf.variable_scope('Conv'): 286 | x = act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) 287 | with tf.variable_scope('Dense0'): 288 | x = act(apply_bias(dense(x, fmaps=nf(res-2), use_wscale=use_wscale))) 289 | with tf.variable_scope('Dense1'): 290 | x = apply_bias(dense(x, fmaps=1, gain=1, use_wscale=use_wscale)) 291 | return x 292 | 293 | # Linear structure: simple but inefficient. 294 | if structure == 'linear': 295 | img = images_in 296 | x = fromrgb(img, resolution_log2) 297 | for res in range(resolution_log2, 2, -1): 298 | lod = resolution_log2 - res 299 | x = block(x, res) 300 | img = downscale3d(img) 301 | y = fromrgb(img, res - 1) 302 | with tf.variable_scope('Grow_lod%d' % lod): 303 | x = lerp_clip(x, y, lod_in - lod) 304 | scores_out = block(x, 2) 305 |
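# The growing loop above walks from the full resolution down to the 8^3 block,
# with lod = resolution_log2 - res counting halvings from full size; e.g. for
# resolution = 32 (resolution_log2 = 5): 32^3 -> lod 0, 16^3 -> lod 1,
# 8^3 -> lod 2, and the final 4^3 block runs outside the loop.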
306 | # Recursive structure: complex but efficient. 307 | if structure == 'recursive': 308 | def grow(res, lod): 309 | x = lambda: fromrgb(downscale3d(images_in, 2**lod), res) 310 | if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1)) 311 | x = block(x(), res); y = lambda: x 312 | if res > 2: y = cset(y, (lod_in > lod), lambda: lerp(x, fromrgb(downscale3d(images_in, 2**(lod+1)), res - 1), lod_in - lod)) 313 | return y() 314 | scores_out = grow(2, resolution_log2 - 2) 315 | 316 | # assert combo_out.dtype == tf.as_dtype(dtype) 317 | # scores_out = tf.identity(combo_out[:, :1], name='scores_out') 318 | # labels_out = tf.identity(combo_out[:, 1:], name='labels_out') 319 | assert scores_out.dtype == tf.as_dtype(dtype) 320 | 321 | scores_out = tf.identity(scores_out, name='scores_out') 322 | 323 | 324 | return scores_out 325 | 326 | #---------------------------------------------------------------------------- 327 | -------------------------------------------------------------------------------- /training/networks_progan_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | # NOTE: Do not import any application-specific modules here! 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def lerp(a, b, t): return a + (b - a) * t 16 | def lerp_clip(a, b, t): return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 17 | def cset(cur_lambda, new_cond, new_lambda): return lambda: tf.cond(new_cond, new_lambda, cur_lambda) 18 | 19 | #---------------------------------------------------------------------------- 20 | # Get/create weight tensor for a convolutional or fully-connected layer. 21 | 22 | def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None): 23 | if fan_in is None: fan_in = np.prod(shape[:-1]) 24 | std = gain / np.sqrt(fan_in) # He init 25 | if use_wscale: 26 | wscale = tf.constant(np.float32(std), name='wscale') 27 | return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale 28 | else: 29 | return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std)) 30 | 31 | #---------------------------------------------------------------------------- 32 | # Fully-connected layer. 33 | 34 | def dense(x, fmaps, gain=np.sqrt(2), use_wscale=False): 35 | if len(x.shape) > 2: 36 | x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])]) 37 | w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 38 | w = tf.cast(w, x.dtype) 39 | return tf.matmul(x, w) 40 |
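#----------------------------------------------------------------------------
# With use_wscale=True, get_weight() above stores unit-variance weights and
# folds the He-init constant into a runtime multiplier instead, which equalizes
# the effective learning rate across layers. A standalone sketch of the scale
# involved (illustrative only; _demo_wscale is not part of the original code):

def _demo_wscale():
    shape = [3, 3, 3, 16, 32]           # 3-D conv weights: [k, k, k, fmaps_in, fmaps_out]
    fan_in = np.prod(shape[:-1])        # 3*3*3*16 = 432 inputs per output unit
    std = np.sqrt(2) / np.sqrt(fan_in)  # He-init standard deviation, ~0.068
    # use_wscale=False: weights are initialized as N(0, std).
    # use_wscale=True:  weights are initialized as N(0, 1) and multiplied by
    #                   wscale = std in the graph, so the optimizer sees
    #                   comparably scaled gradients in every layer.
    return std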
41 | #---------------------------------------------------------------------------- 42 | # Convolutional layer. 43 | 44 | def conv3d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 45 | assert kernel >= 1 and kernel % 2 == 1 46 | w = get_weight([kernel, kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 47 | w = tf.cast(w, x.dtype) 48 | return tf.nn.conv3d(x, w, strides=[1,1,1,1,1], padding='SAME', data_format='NCDHW') 49 | 50 | #---------------------------------------------------------------------------- 51 | # Apply bias to the given activation tensor. 52 | 53 | def apply_bias(x): 54 | b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros()) 55 | b = tf.cast(b, x.dtype) 56 | if len(x.shape) == 2: 57 | return x + b 58 | else: 59 | return x + tf.reshape(b, [1, -1, 1, 1, 1]) 60 | 61 | #---------------------------------------------------------------------------- 62 | # Leaky ReLU activation. Same as tf.nn.leaky_relu, but supports FP16. 63 | 64 | def leaky_relu(x, alpha=0.2): 65 | with tf.name_scope('LeakyRelu'): 66 | alpha = tf.constant(alpha, dtype=x.dtype, name='alpha') 67 | return tf.maximum(x * alpha, x) 68 | 69 | #---------------------------------------------------------------------------- 70 | # Nearest-neighbor upscaling layer. 71 | 72 | def upscale3d(x, factor=2): 73 | assert isinstance(factor, int) and factor >= 1 74 | if factor == 1: return x 75 | with tf.variable_scope('Upscale3D'): 76 | s = x.shape 77 | x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1, s[4], 1]) 78 | x = tf.tile(x, [1, 1, 1, factor, 1, factor, 1, factor]) 79 | x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor, s[4] * factor]) 80 | return x 81 | 82 | #---------------------------------------------------------------------------- 83 | # Fused upscale3d + conv3d. 84 | # Faster and uses less memory than performing the operations separately. 85 | 86 | def upscale3d_conv3d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 87 | assert kernel >= 1 and kernel % 2 == 1 88 | w = get_weight([kernel, kernel, kernel, fmaps, x.shape[1].value], gain=gain, use_wscale=use_wscale, fan_in=(kernel**3)*x.shape[1].value) # fan_in given explicitly: the weight layout puts fmaps before input channels, and a 3-D kernel has kernel**3 spatial taps 89 | w = tf.pad(w, [[1,1], [1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') 90 | w = tf.add_n([w[1:, 1:, 1:], w[1:, 1:, :-1], w[1:, :-1, 1:], w[1:, :-1, :-1], w[:-1, 1:, 1:], w[:-1, 1:, :-1], w[:-1, :-1, 1:], w[:-1, :-1, :-1]]) 91 | w = tf.cast(w, x.dtype) 92 | os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2, x.shape[4] * 2] 93 | return tf.nn.conv3d_transpose(x, w, os, strides=[1,1,2,2,2], padding='SAME', data_format='NCDHW') 94 | 95 | #---------------------------------------------------------------------------- 96 | # Box filter downscaling layer. 97 | 98 | def downscale3d(x, factor=2): 99 | assert isinstance(factor, int) and factor >= 1 100 | if factor == 1: return x 101 | with tf.variable_scope('Downscale3D'): 102 | ksize = [1, 1, factor, factor, factor] 103 | return tf.nn.avg_pool3d(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCDHW') # NOTE: requires tf_config['graph_options.place_pruned_graph'] = True 104 |
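#----------------------------------------------------------------------------
# upscale3d() repeats every voxel factor**3 times and downscale3d() box-filters
# them back, so chaining the two returns the original tensor. A minimal shape
# sketch (illustrative only; _demo_scale3d is not part of the original code):

def _demo_scale3d():
    x = tf.constant(np.random.randn(1, 8, 4, 4, 4).astype(np.float32))  # [NCDHW]
    y = upscale3d(x)    # -> [1, 8, 8, 8, 8], nearest-neighbor repeat
    z = downscale3d(y)  # -> [1, 8, 4, 4, 4], averaging the repeats recovers x
    return y, z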
105 | #---------------------------------------------------------------------------- 106 | # Fused conv3d + downscale3d. 107 | # Faster and uses less memory than performing the operations separately. 108 | 109 | def conv3d_downscale3d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 110 | assert kernel >= 1 and kernel % 2 == 1 111 | w = get_weight([kernel, kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 112 | w = tf.pad(w, [[1,1], [1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') 113 | w = tf.add_n([w[1:, 1:, 1:], w[1:, 1:, :-1], w[1:, :-1, 1:], w[1:, :-1, :-1], w[:-1, 1:, 1:], w[:-1, 1:, :-1], w[:-1, :-1, 1:], w[:-1, :-1, :-1]]) * 0.125 # average the 8 shifted copies 114 | w = tf.cast(w, x.dtype) 115 | return tf.nn.conv3d(x, w, strides=[1,1,2,2,2], padding='SAME', data_format='NCDHW') 116 | 117 | #---------------------------------------------------------------------------- 118 | # Pixelwise feature vector normalization. 119 | 120 | def pixel_norm(x, epsilon=1e-8): 121 | with tf.variable_scope('PixelNorm'): 122 | return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon) 123 | 124 | #---------------------------------------------------------------------------- 125 | # Minibatch standard deviation. 126 | 127 | def minibatch_stddev_layer(x, group_size=4): 128 | with tf.variable_scope('MinibatchStddev'): 129 | group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. 130 | s = x.shape # [NCDHW] Input shape. 131 | y = tf.reshape(x, [group_size, -1, s[1], s[2], s[3], s[4]]) # [GMCDHW] Split minibatch into M groups of size G. 132 | y = tf.cast(y, tf.float32) # [GMCDHW] Cast to FP32. 133 | y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMCDHW] Subtract mean over group. 134 | y = tf.reduce_mean(tf.square(y), axis=0) # [MCDHW] Calc variance over group. 135 | y = tf.sqrt(y + 1e-8) # [MCDHW] Calc stddev over group. 136 | y = tf.reduce_mean(y, axis=[1,2,3,4], keepdims=True) # [M1111] Take average over fmaps and voxels. 137 | y = tf.cast(y, x.dtype) # [M1111] Cast back to original data type. 138 | y = tf.tile(y, [group_size, 1, s[2], s[3], s[4]]) # [N1DHW] Replicate over group and voxels. 139 | return tf.concat([x, y], axis=1) # [NCDHW] Append as new fmap. 140 | 141 | #----------------------------------------------------------------------------
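#----------------------------------------------------------------------------
# Shape walk-through for minibatch_stddev_layer(), re-derived in plain NumPy
# (illustrative only; _demo_mbstd is not part of the original code). The layer
# appends exactly one feature map carrying each group's mean stddev:

def _demo_mbstd():
    x = np.random.randn(16, 64, 4, 4, 4).astype(np.float32)  # [NCDHW]
    G = 4
    y = x.reshape(G, -1, *x.shape[1:])                       # [G, M, C, D, H, W] = [4, 4, 64, 4, 4, 4]
    y = np.sqrt(((y - y.mean(0)) ** 2).mean(0) + 1e-8)       # [M, C, D, H, W] stddev over each group
    y = y.mean(axis=(1, 2, 3, 4), keepdims=True)             # [M, 1, 1, 1, 1] one scalar per group
    y = np.tile(y, (G, 1) + x.shape[2:])                     # [16, 1, 4, 4, 4] replicate over group and voxels
    return np.concatenate([x, y], axis=1)                    # [16, 65, 4, 4, 4]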
142 | # Generator network used in the paper. 143 | 144 | def G_paper( 145 | latents_in, # First input: Latent vectors [minibatch, latent_size]. 146 | labels_in, # Second input: Labels [minibatch, label_size]. 147 | num_channels = 1, # Number of output color channels. Overridden based on dataset. 148 | resolution = 32, # Output resolution. Overridden based on dataset. 149 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 150 | fmap_base = 2048, # Overall multiplier for the number of feature maps. 151 | fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. 152 | fmap_max = 256, # Maximum number of feature maps in any layer. 153 | latent_size = 2048, # Dimensionality of the latent vectors. None = min(fmap_base, fmap_max). 154 | normalize_latents = True, # Normalize latent vectors before feeding them to the network? 155 | use_wscale = True, # Enable equalized learning rate? 156 | use_pixelnorm = True, # Enable pixelwise feature vector normalization? 157 | pixelnorm_epsilon = 1e-8, # Constant epsilon for pixelwise feature vector normalization. 158 | use_leakyrelu = True, # True = leaky ReLU, False = ReLU. 159 | dtype = 'float32', # Data type to use for activations and outputs. 160 | fused_scale = True, # True = use fused upscale3d + conv3d, False = separate upscale3d layers. 161 | structure = 'linear', # 'linear' = human-readable, 'recursive' = efficient, None = select automatically. 162 | is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. 163 | **kwargs): # Ignore unrecognized keyword args. 164 | 165 | resolution_log2 = int(np.log2(resolution)) 166 | assert resolution == 2**resolution_log2 and resolution >= 4 167 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 168 | def PN(x): return pixel_norm(x, epsilon=pixelnorm_epsilon) if use_pixelnorm else x 169 | if latent_size is None: latent_size = nf(0) 170 | if structure is None: structure = 'linear' if is_template_graph else 'recursive' 171 | act = leaky_relu if use_leakyrelu else tf.nn.relu 172 | 173 | latents_in.set_shape([None, latent_size]) 174 | labels_in.set_shape([None, label_size]) 175 | combo_in = tf.cast(tf.concat([latents_in, labels_in], axis=1), dtype) 176 | lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) 177 | 178 | # Building blocks. 179 | def block(x, res): # res = 2..resolution_log2 180 | with tf.variable_scope('%dx%d' % (2**res, 2**res)): 181 | if res == 2: # 4x4x4 182 | if normalize_latents: x = pixel_norm(x, epsilon=pixelnorm_epsilon) 183 | with tf.variable_scope('Dense'): 184 | x = dense(x, fmaps=nf(res-1)*64, gain=np.sqrt(2)/4, use_wscale=use_wscale) # override gain to match the original Theano implementation 185 | x = tf.reshape(x, [-1, nf(res-1), 4, 4, 4]) 186 | x = PN(act(apply_bias(x))) 187 | with tf.variable_scope('Conv'): 188 | x = PN(act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 189 | else: # 8x8x8 and up 190 | if fused_scale: 191 | with tf.variable_scope('Conv0_up'): 192 | x = PN(act(apply_bias(upscale3d_conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 193 | else: 194 | x = upscale3d(x) 195 | with tf.variable_scope('Conv0'): 196 | x = PN(act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 197 | # with tf.variable_scope('Conv1'): 198 | # x = PN(act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 199 | return x 200 | def torgb(x, res): # res = 2..resolution_log2 201 | lod = resolution_log2 - res 202 | with tf.variable_scope('ToRGB_lod%d' % lod): 203 | return apply_bias(conv3d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale)) 204 | 205 | # Linear structure: simple but inefficient. 206 | if structure == 'linear': 207 | x = block(combo_in, 2) 208 | images_out = torgb(x, 2) 209 | for res in range(3, resolution_log2 + 1): 210 | lod = resolution_log2 - res 211 | x = block(x, res) 212 | img = torgb(x, res) 213 | images_out = upscale3d(images_out) 214 | with tf.variable_scope('Grow_lod%d' % lod): 215 | images_out = lerp_clip(img, images_out, lod_in - lod) 216 |
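# The recursive structure below builds one tf.cond chain per level via cset():
# each call wraps the current lambda in a new condition, so at run time only
# the branch matching lod_in is evaluated. E.g. for resolution_log2 = 5,
# grow(combo_in, 2, 3) nests levels for lods 3..0 and lod_in alone selects
# which torgb/upscale path executes, instead of always running every level
# as the linear structure does.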
217 | # Recursive structure: complex but efficient. 218 | if structure == 'recursive': 219 | def grow(x, res, lod): 220 | y = block(x, res) 221 | img = lambda: upscale3d(torgb(y, res), 2**lod) 222 | if res > 2: img = cset(img, (lod_in > lod), lambda: upscale3d(lerp(torgb(y, res), upscale3d(torgb(x, res - 1)), lod_in - lod), 2**lod)) 223 | if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1)) 224 | return img() 225 | images_out = grow(combo_in, 2, resolution_log2 - 2) 226 | 227 | assert images_out.dtype == tf.as_dtype(dtype) 228 | images_out = tf.identity(images_out, name='images_out') 229 | return images_out 230 | 231 | #----------------------------------------------------------------------------
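#----------------------------------------------------------------------------
# Minimal usage sketch, assuming raw TF1 placeholders and a direct call rather
# than the dnnlib.tflib.Network wrapper that the training code normally uses
# (illustrative only; _demo_build_g is not part of the original code):

def _demo_build_g():
    latents = tf.placeholder(tf.float32, [None, 2048])  # matches latent_size above
    labels = tf.placeholder(tf.float32, [None, 0])      # label_size = 0: no labels
    with tf.variable_scope('G'):
        volumes = G_paper(latents, labels, resolution=32, structure='linear')
    return volumes  # [None, 1, 32, 32, 32]: single-channel 3-D volumes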
232 | # Discriminator network used in the paper. 233 | 234 | def D_paper( 235 | images_in, # First input: Images [minibatch, channel, depth, height, width]. 236 | labels_in, # Second input: Labels [minibatch, label_size]. 237 | num_channels = 1, # Number of input color channels. Overridden based on dataset. 238 | resolution = 32, # Input resolution. Overridden based on dataset. 239 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 240 | fmap_base = 2048, # Overall multiplier for the number of feature maps. 241 | fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. 242 | fmap_max = 256, # Maximum number of feature maps in any layer. 243 | use_wscale = True, # Enable equalized learning rate? 244 | mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable. 245 | dtype = 'float32', # Data type to use for activations and outputs. 246 | fused_scale = True, # True = use fused conv3d + downscale3d, False = separate downscale3d layers. 247 | structure = 'linear', # 'linear' = human-readable, 'recursive' = efficient, None = select automatically. 248 | is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. 249 | **kwargs): # Ignore unrecognized keyword args. 250 | 251 | resolution_log2 = int(np.log2(resolution)) 252 | assert resolution == 2**resolution_log2 and resolution >= 4 253 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 254 | if structure is None: structure = 'linear' if is_template_graph else 'recursive' 255 | act = leaky_relu 256 | 257 | images_in.set_shape([None, num_channels, resolution, resolution, resolution]) 258 | images_in = tf.cast(images_in, dtype) 259 | labels_in.set_shape([None, label_size]) 260 | labels_in = tf.cast(labels_in, dtype) 261 | lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) 262 | 263 | scores_out = None 264 | 265 | # Building blocks. 266 | def fromrgb(x, res): # res = 2..resolution_log2 267 | with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)): 268 | # print(res, nf(res-1)) 269 | return act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=1, use_wscale=use_wscale))) 270 | def block(x, res): # res = 2..resolution_log2 271 | with tf.variable_scope('%dx%d' % (2**res, 2**res)): 272 | if res >= 3: # 8x8x8 and up 273 | with tf.variable_scope('Conv0'): 274 | x = act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) 275 | if fused_scale: 276 | with tf.variable_scope('Conv1_down'): 277 | x = act(apply_bias(conv3d_downscale3d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) 278 | else: 279 | with tf.variable_scope('Conv1'): 280 | x = act(apply_bias(conv3d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) 281 | x = downscale3d(x) 282 | else: # 4x4x4 283 | if mbstd_group_size > 1: 284 | x = minibatch_stddev_layer(x, mbstd_group_size) 285 | with tf.variable_scope('Conv'): 286 | x = act(apply_bias(conv3d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) 287 | with tf.variable_scope('Dense0'): 288 | x = act(apply_bias(dense(x, fmaps=nf(res-2), use_wscale=use_wscale))) 289 | with tf.variable_scope('Dense1'): 290 | x = apply_bias(dense(x, fmaps=1, gain=1, use_wscale=use_wscale)) 291 | return x 292 | 293 | # Linear structure: simple but inefficient. 294 | if structure == 'linear': 295 | img = images_in 296 | x = fromrgb(img, resolution_log2) 297 | for res in range(resolution_log2, 2, -1): 298 | # Split the pyramid across two GPUs: the early, high-resolution blocks run on /gpu:1 and the later, low-resolution blocks on /gpu:0. 299 | device = '/gpu:0' if res < resolution_log2 / 2 else '/gpu:1' 300 | with tf.device(device): 301 | lod = resolution_log2 - res 302 | x = block(x, res) 303 | img = downscale3d(img) 304 | y = fromrgb(img, res - 1) 305 | with tf.variable_scope('Grow_lod%d' % lod): 306 | x = lerp_clip(x, y, lod_in - lod) 307 | scores_out = block(x, 2) 308 | 309 | # Recursive structure: complex but efficient. 310 | if structure == 'recursive': 311 | def grow(res, lod): 312 | x = lambda: fromrgb(downscale3d(images_in, 2**lod), res) 313 | if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1)) 314 | x = block(x(), res); y = lambda: x 315 | if res > 2: y = cset(y, (lod_in > lod), lambda: lerp(x, fromrgb(downscale3d(images_in, 2**(lod+1)), res - 1), lod_in - lod)) 316 | return y() 317 | scores_out = grow(2, resolution_log2 - 2) 318 | 319 | # assert combo_out.dtype == tf.as_dtype(dtype) 320 | # scores_out = tf.identity(combo_out[:, :1], name='scores_out') 321 | # labels_out = tf.identity(combo_out[:, 1:], name='labels_out') 322 | assert scores_out.dtype == tf.as_dtype(dtype) 323 | scores_out = tf.identity(scores_out, name='scores_out') 324 | return scores_out 325 | 326 | #---------------------------------------------------------------------------- 327 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | 3 | 4 | Attribution-NonCommercial 4.0 International 5 | 6 | ======================================================================= 7 | 8 | Creative Commons Corporation ("Creative Commons") is not a law firm and 9 | does not provide legal services or legal advice.
Distribution of 10 | Creative Commons public licenses does not create a lawyer-client or 11 | other relationship. Creative Commons makes its licenses and related 12 | information available on an "as-is" basis. Creative Commons gives no 13 | warranties regarding its licenses, any material licensed under their 14 | terms and conditions, or any related information. Creative Commons 15 | disclaims all liability for damages resulting from their use to the 16 | fullest extent possible. 17 | 18 | Using Creative Commons Public Licenses 19 | 20 | Creative Commons public licenses provide a standard set of terms and 21 | conditions that creators and other rights holders may use to share 22 | original works of authorship and other material subject to copyright 23 | and certain other rights specified in the public license below. The 24 | following considerations are for informational purposes only, are not 25 | exhaustive, and do not form part of our licenses. 26 | 27 | Considerations for licensors: Our public licenses are 28 | intended for use by those authorized to give the public 29 | permission to use material in ways otherwise restricted by 30 | copyright and certain other rights. Our licenses are 31 | irrevocable. Licensors should read and understand the terms 32 | and conditions of the license they choose before applying it. 33 | Licensors should also secure all rights necessary before 34 | applying our licenses so that the public can reuse the 35 | material as expected. Licensors should clearly mark any 36 | material not subject to the license. This includes other CC- 37 | licensed material, or material used under an exception or 38 | limitation to copyright. More considerations for licensors: 39 | wiki.creativecommons.org/Considerations_for_licensors 40 | 41 | Considerations for the public: By using one of our public 42 | licenses, a licensor grants the public permission to use the 43 | licensed material under specified terms and conditions. If 44 | the licensor's permission is not necessary for any reason--for 45 | example, because of any applicable exception or limitation to 46 | copyright--then that use is not regulated by the license. Our 47 | licenses grant only permissions under copyright and certain 48 | other rights that a licensor has authority to grant. Use of 49 | the licensed material may still be restricted for other 50 | reasons, including because others have copyright or other 51 | rights in the material. A licensor may make special requests, 52 | such as asking that all changes be marked or described. 53 | Although not required by our licenses, you are encouraged to 54 | respect those requests where reasonable. More_considerations 55 | for the public: 56 | wiki.creativecommons.org/Considerations_for_licensees 57 | 58 | ======================================================================= 59 | 60 | Creative Commons Attribution-NonCommercial 4.0 International Public 61 | License 62 | 63 | By exercising the Licensed Rights (defined below), You accept and agree 64 | to be bound by the terms and conditions of this Creative Commons 65 | Attribution-NonCommercial 4.0 International Public License ("Public 66 | License"). To the extent this Public License may be interpreted as a 67 | contract, You are granted the Licensed Rights in consideration of Your 68 | acceptance of these terms and conditions, and the Licensor grants You 69 | such rights in consideration of benefits the Licensor receives from 70 | making the Licensed Material available under these terms and 71 | conditions. 
72 | 73 | 74 | Section 1 -- Definitions. 75 | 76 | a. Adapted Material means material subject to Copyright and Similar 77 | Rights that is derived from or based upon the Licensed Material 78 | and in which the Licensed Material is translated, altered, 79 | arranged, transformed, or otherwise modified in a manner requiring 80 | permission under the Copyright and Similar Rights held by the 81 | Licensor. For purposes of this Public License, where the Licensed 82 | Material is a musical work, performance, or sound recording, 83 | Adapted Material is always produced where the Licensed Material is 84 | synched in timed relation with a moving image. 85 | 86 | b. Adapter's License means the license You apply to Your Copyright 87 | and Similar Rights in Your contributions to Adapted Material in 88 | accordance with the terms and conditions of this Public License. 89 | 90 | c. Copyright and Similar Rights means copyright and/or similar rights 91 | closely related to copyright including, without limitation, 92 | performance, broadcast, sound recording, and Sui Generis Database 93 | Rights, without regard to how the rights are labeled or 94 | categorized. For purposes of this Public License, the rights 95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 96 | Rights. 97 | d. Effective Technological Measures means those measures that, in the 98 | absence of proper authority, may not be circumvented under laws 99 | fulfilling obligations under Article 11 of the WIPO Copyright 100 | Treaty adopted on December 20, 1996, and/or similar international 101 | agreements. 102 | 103 | e. Exceptions and Limitations means fair use, fair dealing, and/or 104 | any other exception or limitation to Copyright and Similar Rights 105 | that applies to Your use of the Licensed Material. 106 | 107 | f. Licensed Material means the artistic or literary work, database, 108 | or other material to which the Licensor applied this Public 109 | License. 110 | 111 | g. Licensed Rights means the rights granted to You subject to the 112 | terms and conditions of this Public License, which are limited to 113 | all Copyright and Similar Rights that apply to Your use of the 114 | Licensed Material and that the Licensor has authority to license. 115 | 116 | h. Licensor means the individual(s) or entity(ies) granting rights 117 | under this Public License. 118 | 119 | i. NonCommercial means not primarily intended for or directed towards 120 | commercial advantage or monetary compensation. For purposes of 121 | this Public License, the exchange of the Licensed Material for 122 | other material subject to Copyright and Similar Rights by digital 123 | file-sharing or similar means is NonCommercial provided there is 124 | no payment of monetary compensation in connection with the 125 | exchange. 126 | 127 | j. Share means to provide material to the public by any means or 128 | process that requires permission under the Licensed Rights, such 129 | as reproduction, public display, public performance, distribution, 130 | dissemination, communication, or importation, and to make material 131 | available to the public including in ways that members of the 132 | public may access the material from a place and at a time 133 | individually chosen by them. 134 | 135 | k. 
Sui Generis Database Rights means rights other than copyright 136 | resulting from Directive 96/9/EC of the European Parliament and of 137 | the Council of 11 March 1996 on the legal protection of databases, 138 | as amended and/or succeeded, as well as other essentially 139 | equivalent rights anywhere in the world. 140 | 141 | l. You means the individual or entity exercising the Licensed Rights 142 | under this Public License. Your has a corresponding meaning. 143 | 144 | 145 | Section 2 -- Scope. 146 | 147 | a. License grant. 148 | 149 | 1. Subject to the terms and conditions of this Public License, 150 | the Licensor hereby grants You a worldwide, royalty-free, 151 | non-sublicensable, non-exclusive, irrevocable license to 152 | exercise the Licensed Rights in the Licensed Material to: 153 | 154 | a. reproduce and Share the Licensed Material, in whole or 155 | in part, for NonCommercial purposes only; and 156 | 157 | b. produce, reproduce, and Share Adapted Material for 158 | NonCommercial purposes only. 159 | 160 | 2. Exceptions and Limitations. For the avoidance of doubt, where 161 | Exceptions and Limitations apply to Your use, this Public 162 | License does not apply, and You do not need to comply with 163 | its terms and conditions. 164 | 165 | 3. Term. The term of this Public License is specified in Section 166 | 6(a). 167 | 168 | 4. Media and formats; technical modifications allowed. The 169 | Licensor authorizes You to exercise the Licensed Rights in 170 | all media and formats whether now known or hereafter created, 171 | and to make technical modifications necessary to do so. The 172 | Licensor waives and/or agrees not to assert any right or 173 | authority to forbid You from making technical modifications 174 | necessary to exercise the Licensed Rights, including 175 | technical modifications necessary to circumvent Effective 176 | Technological Measures. For purposes of this Public License, 177 | simply making modifications authorized by this Section 2(a) 178 | (4) never produces Adapted Material. 179 | 180 | 5. Downstream recipients. 181 | 182 | a. Offer from the Licensor -- Licensed Material. Every 183 | recipient of the Licensed Material automatically 184 | receives an offer from the Licensor to exercise the 185 | Licensed Rights under the terms and conditions of this 186 | Public License. 187 | 188 | b. No downstream restrictions. You may not offer or impose 189 | any additional or different terms or conditions on, or 190 | apply any Effective Technological Measures to, the 191 | Licensed Material if doing so restricts exercise of the 192 | Licensed Rights by any recipient of the Licensed 193 | Material. 194 | 195 | 6. No endorsement. Nothing in this Public License constitutes or 196 | may be construed as permission to assert or imply that You 197 | are, or that Your use of the Licensed Material is, connected 198 | with, or sponsored, endorsed, or granted official status by, 199 | the Licensor or others designated to receive attribution as 200 | provided in Section 3(a)(1)(A)(i). 201 | 202 | b. Other rights. 203 | 204 | 1. Moral rights, such as the right of integrity, are not 205 | licensed under this Public License, nor are publicity, 206 | privacy, and/or other similar personality rights; however, to 207 | the extent possible, the Licensor waives and/or agrees not to 208 | assert any such rights held by the Licensor to the limited 209 | extent necessary to allow You to exercise the Licensed 210 | Rights, but not otherwise. 211 | 212 | 2. 
Patent and trademark rights are not licensed under this 213 | Public License. 214 | 215 | 3. To the extent possible, the Licensor waives any right to 216 | collect royalties from You for the exercise of the Licensed 217 | Rights, whether directly or through a collecting society 218 | under any voluntary or waivable statutory or compulsory 219 | licensing scheme. In all other cases the Licensor expressly 220 | reserves any right to collect such royalties, including when 221 | the Licensed Material is used other than for NonCommercial 222 | purposes. 223 | 224 | 225 | Section 3 -- License Conditions. 226 | 227 | Your exercise of the Licensed Rights is expressly made subject to the 228 | following conditions. 229 | 230 | a. Attribution. 231 | 232 | 1. If You Share the Licensed Material (including in modified 233 | form), You must: 234 | 235 | a. retain the following if it is supplied by the Licensor 236 | with the Licensed Material: 237 | 238 | i. identification of the creator(s) of the Licensed 239 | Material and any others designated to receive 240 | attribution, in any reasonable manner requested by 241 | the Licensor (including by pseudonym if 242 | designated); 243 | 244 | ii. a copyright notice; 245 | 246 | iii. a notice that refers to this Public License; 247 | 248 | iv. a notice that refers to the disclaimer of 249 | warranties; 250 | 251 | v. a URI or hyperlink to the Licensed Material to the 252 | extent reasonably practicable; 253 | 254 | b. indicate if You modified the Licensed Material and 255 | retain an indication of any previous modifications; and 256 | 257 | c. indicate the Licensed Material is licensed under this 258 | Public License, and include the text of, or the URI or 259 | hyperlink to, this Public License. 260 | 261 | 2. You may satisfy the conditions in Section 3(a)(1) in any 262 | reasonable manner based on the medium, means, and context in 263 | which You Share the Licensed Material. For example, it may be 264 | reasonable to satisfy the conditions by providing a URI or 265 | hyperlink to a resource that includes the required 266 | information. 267 | 268 | 3. If requested by the Licensor, You must remove any of the 269 | information required by Section 3(a)(1)(A) to the extent 270 | reasonably practicable. 271 | 272 | 4. If You Share Adapted Material You produce, the Adapter's 273 | License You apply must not prevent recipients of the Adapted 274 | Material from complying with this Public License. 275 | 276 | 277 | Section 4 -- Sui Generis Database Rights. 278 | 279 | Where the Licensed Rights include Sui Generis Database Rights that 280 | apply to Your use of the Licensed Material: 281 | 282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 283 | to extract, reuse, reproduce, and Share all or a substantial 284 | portion of the contents of the database for NonCommercial purposes 285 | only; 286 | 287 | b. if You include all or a substantial portion of the database 288 | contents in a database in which You have Sui Generis Database 289 | Rights, then the database in which You have Sui Generis Database 290 | Rights (but not its individual contents) is Adapted Material; and 291 | 292 | c. You must comply with the conditions in Section 3(a) if You Share 293 | all or a substantial portion of the contents of the database. 294 | 295 | For the avoidance of doubt, this Section 4 supplements and does not 296 | replace Your obligations under this Public License where the Licensed 297 | Rights include other Copyright and Similar Rights. 
298 | 299 | 300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 301 | 302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 312 | 313 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 322 | 323 | c. The disclaimer of warranties and limitation of liability provided 324 | above shall be interpreted in a manner that, to the extent 325 | possible, most closely approximates an absolute disclaimer and 326 | waiver of all liability. 327 | 328 | 329 | Section 6 -- Term and Termination. 330 | 331 | a. This Public License applies for the term of the Copyright and 332 | Similar Rights licensed here. However, if You fail to comply with 333 | this Public License, then Your rights under this Public License 334 | terminate automatically. 335 | 336 | b. Where Your right to use the Licensed Material has terminated under 337 | Section 6(a), it reinstates: 338 | 339 | 1. automatically as of the date the violation is cured, provided 340 | it is cured within 30 days of Your discovery of the 341 | violation; or 342 | 343 | 2. upon express reinstatement by the Licensor. 344 | 345 | For the avoidance of doubt, this Section 6(b) does not affect any 346 | right the Licensor may have to seek remedies for Your violations 347 | of this Public License. 348 | 349 | c. For the avoidance of doubt, the Licensor may also offer the 350 | Licensed Material under separate terms or conditions or stop 351 | distributing the Licensed Material at any time; however, doing so 352 | will not terminate this Public License. 353 | 354 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 355 | License. 356 | 357 | 358 | Section 7 -- Other Terms and Conditions. 359 | 360 | a. The Licensor shall not be bound by any additional or different 361 | terms or conditions communicated by You unless expressly agreed. 362 | 363 | b. Any arrangements, understandings, or agreements regarding the 364 | Licensed Material not stated herein are separate from and 365 | independent of the terms and conditions of this Public License. 366 | 367 | 368 | Section 8 -- Interpretation. 369 | 370 | a. 
For the avoidance of doubt, this Public License does not, and 371 | shall not be interpreted to, reduce, limit, restrict, or impose 372 | conditions on any use of the Licensed Material that could lawfully 373 | be made without permission under this Public License. 374 | 375 | b. To the extent possible, if any provision of this Public License is 376 | deemed unenforceable, it shall be automatically reformed to the 377 | minimum extent necessary to make it enforceable. If the provision 378 | cannot be reformed, it shall be severed from this Public License 379 | without affecting the enforceability of the remaining terms and 380 | conditions. 381 | 382 | c. No term or condition of this Public License will be waived and no 383 | failure to comply consented to unless expressly agreed to by the 384 | Licensor. 385 | 386 | d. Nothing in this Public License constitutes or may be interpreted 387 | as a limitation upon, or waiver of, any privileges and immunities 388 | that apply to the Licensor or You, including from the legal 389 | processes of any jurisdiction or authority. 390 | 391 | ======================================================================= 392 | 393 | Creative Commons is not a party to its public 394 | licenses. Notwithstanding, Creative Commons may elect to apply one of 395 | its public licenses to material it publishes and in those instances 396 | will be considered the "Licensor." The text of the Creative Commons 397 | public licenses is dedicated to the public domain under the CC0 Public 398 | Domain Dedication. Except for the limited purpose of indicating that 399 | material is shared under a Creative Commons public license or as 400 | otherwise permitted by the Creative Commons policies published at 401 | creativecommons.org/policies, Creative Commons does not authorize the 402 | use of the trademark "Creative Commons" or any other trademark or logo 403 | of Creative Commons without its prior written consent including, 404 | without limitation, in connection with any unauthorized modifications 405 | to any of its public licenses or any other arrangements, 406 | understandings, or agreements concerning use of licensed material. For 407 | the avoidance of doubt, this paragraph does not form part of the 408 | public licenses. 409 | 410 | Creative Commons may be contacted at creativecommons.org. 411 | --------------------------------------------------------------------------------