├── .gitignore ├── Dockerfile ├── LICENSE.txt ├── README.md ├── aydao_flesh_digressions.py ├── blend_models.py ├── calc_metrics.py ├── dataset_tool.py ├── dnnlib ├── __init__.py ├── tflib │ ├── __init__.py │ ├── autosummary.py │ ├── custom_ops.py │ ├── network.py │ ├── ops │ │ ├── __init__.py │ │ ├── fused_bias_act.cu │ │ ├── fused_bias_act.py │ │ ├── upfirdn_2d.cu │ │ └── upfirdn_2d.py │ ├── optimizer.py │ └── tfutil.py └── util.py ├── docs ├── license.html ├── stylegan2-ada-teaser-1024x252.png ├── stylegan2-ada-training-curves.png └── train-help.txt ├── ffhq_dataset ├── __init__.py ├── face_alignment.py └── landmarks_detector.py ├── generate.py ├── grid_vid.py ├── metrics ├── __init__.py ├── frechet_inception_distance.py ├── inception_score.py ├── kernel_inception_distance.py ├── linear_separability.py ├── metric_base.py ├── metric_defaults.py ├── perceptual_path_length.py └── precision_recall.py ├── notebooks └── closed_form_wip.ipynb ├── projector.py ├── style_mixing.py ├── train.py ├── training ├── __init__.py ├── augment.py ├── dataset.py ├── loss.py ├── misc.py ├── networks.py └── training_loop.py └── utils ├── align_faces.py └── tffreeze.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | FROM tensorflow/tensorflow:1.14.0-gpu-py3 10 | 11 | RUN pip install scipy==1.3.3 12 | RUN pip install requests==2.22.0 13 | RUN pip install Pillow==6.2.1 14 | RUN pip install h5py==2.9.0 15 | RUN pip install imageio==2.9.0 16 | RUN pip install imageio-ffmpeg==0.4.2 17 | RUN pip install tqdm==4.49.0 18 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 
24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 
86 | 
87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95 | THE POSSIBILITY OF SUCH DAMAGES.
96 | 
97 | =======================================================================
98 | 
--------------------------------------------------------------------------------
/aydao_flesh_digressions.py:
--------------------------------------------------------------------------------
1 | #
2 | # ~~ Flesh Digressions ~~
3 | # Or, Circular Interpolation of the StyleGAN Synthesis Network's Constant Layer
4 | # ~~~ aydao ~~~~ 2020 ~~~
5 | #
6 | # Based on halcy's circular interpolation script https://pastebin.com/RTtV2UY7
7 | #
8 | import warnings
9 | warnings.filterwarnings('ignore', category=FutureWarning)
10 | warnings.filterwarnings('ignore', category=DeprecationWarning)
11 | import tensorflow as tf
12 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
13 | import dnnlib
14 | import dnnlib.tflib as tflib
15 | import math
16 | import moviepy.editor
17 | from numpy import linalg
18 | import numpy as np
19 | import pickle
20 | import argparse
21 | from datetime import datetime
22 | 
23 | def circular_interpolation(radius, latents_persistent, latents_interpolate):
24 | 
25 | latents_a, latents_b, latents_c = latents_persistent
26 | 
27 | latents_axis_x = (latents_a - latents_b).flatten() / linalg.norm(latents_a - latents_b)
28 | latents_axis_y = (latents_a - latents_c).flatten() / linalg.norm(latents_a - latents_c)
29 | 
30 | latents_x = math.sin(math.pi * 2.0 * latents_interpolate) * radius
31 | latents_y = math.cos(math.pi * 2.0 * latents_interpolate) * radius
32 | 
33 | latents = latents_a + latents_x * latents_axis_x + latents_y * latents_axis_y
34 | return latents
35 | 
36 | def generate_from_generator_adaptive(psi,radius_large,radius_small,step1,step2,video_length,seed,Gs):
37 | # psi = args.psi # 0.7
38 | # radius_large = args.radius_large # 600.0
39 | # radius_small = args.radius_small # 40.0
40 | current_position_increment = step1 # 0.005
41 | current_position_style_increment = step2 # 0.0025
42 | # video_length = args.video_length # 1.0
43 | output_format = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
44 | 
45 | # latents for the circular interpolation in latent space
46 | if seed:
47 | np.random.seed(seed) # seed the global NumPy RNG so the run is reproducible
48 | rnd = np.random
49 | latents_a = rnd.randn(1, Gs.input_shape[1])
50 | latents_b = rnd.randn(1, Gs.input_shape[1])
51 | latents_c = rnd.randn(1, Gs.input_shape[1])
52 | latents_persistent_small = (latents_a, latents_b, latents_c)
53 | 
54 | # latents for the circular interpolation of the unrolled constant layer
55 | latent_size = 512 # default StyleGAN latent size
56 | constant_layer_size = 4 # default StyleGAN constant layer size is 4x4
57 | constant_layer_total = latent_size * constant_layer_size * constant_layer_size # 8192
58 | latents_aa = rnd.randn(1, constant_layer_total)
59 | latents_bb = rnd.randn(1, constant_layer_total)
60 | latents_cc = rnd.randn(1, constant_layer_total)
61 | latents_persistent_large = (latents_aa, latents_bb, latents_cc)
62 | 
63 | # initialize the circular interpolation
64 | current_position = 0.0
65 | current_position_style = 0.0
66 | current_latent = circular_interpolation(radius_small, latents_persistent_small, current_position)
67 | current_image = Gs.run(current_latent, None, truncation_psi=psi, randomize_noise=False, output_transform=output_format)[0]
68 | output_frames = []
69 | 
70 | # Create the frames while interpolating along the circle, in both the latent space and the constant layer
71 | while(current_position_style < video_length):
72 | 
73 | current_position += current_position_increment
74 | current_position_style += current_position_style_increment
75 | 
76 | # interpolate the weights of the constant layer
77 | w = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == 'G_synthesis_1/4x4/Const/const:0'][0]
78 | v1 = tf.identity(tflib.run(['G_synthesis_1/4x4/Const/const:0'])[0])
79 | v2 = tf.reshape(v1, [1, constant_layer_total])
80 | v2 += circular_interpolation(radius_large, latents_persistent_large, current_position + np.pi)
81 | v2 = tf.reshape(v2, [1, latent_size, constant_layer_size, constant_layer_size])
82 | tf.get_default_session().run(tf.assign(w, v2))
83 | 
84 | # interpolate along the latent space
85 | current_latent = circular_interpolation(radius_small, latents_persistent_small, current_position_style)
86 | current_image = Gs.run(current_latent, None, truncation_psi=psi, randomize_noise=False, output_transform=output_format)[0]
87 | output_frames.append(current_image)
88 | 
89 | tf.get_default_session().run(tf.assign(w, v1))
90 | 
91 | # progress report: the loop stops once current_position_style reaches video_length (1.0 by default)
92 | print('Stopping at', video_length, 'currently at', current_position_style, flush=True)
93 | 
94 | return output_frames
95 | 
96 | def main(pkl,psi,radius_large,radius_small,step1,step2,seed,video_length=1.0):
97 | 
98 | tflib.init_tf()
99 | print('Loading networks from "%s"...' % pkl)
100 | with dnnlib.util.open_url(pkl) as fp:
101 | _G, _D, Gs = pickle.load(fp)
102 | 
103 | frames = generate_from_generator_adaptive(psi,radius_large,radius_small,step1,step2,video_length,seed, Gs)
104 | clip = moviepy.editor.ImageSequenceClip(frames, fps=30)
105 | 
106 | # Generate video at the current date and timestamp
107 | timestamp = datetime.now().strftime("%d-%m-%Y-%I-%M-%S-%p")
108 | mp4_file = './circular-'+timestamp+'.mp4'
109 | mp4_codec = 'libx264'
110 | mp4_bitrate = '15M'
111 | mp4_fps = 24
112 | 
113 | clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)
114 | 
115 | sess = tf.get_default_session()
116 | sess.close()
117 | 
118 | if __name__ == "__main__":
119 | 
120 | parser = argparse.ArgumentParser(
121 | description='Creates a video of a circular interpolation of the constant layer for an input StyleGAN model.',
122 | formatter_class=argparse.RawDescriptionHelpFormatter
123 | )
124 | 
125 | parser.add_argument('--pkl', help='A .pkl of a StyleGAN network model', required=True)
126 | parser.add_argument('--psi', help='The truncation psi used in the generator', default=0.7, type=float)
127 | parser.add_argument('--radius_large', help='The radius for the constant layer interpolation', default=300.0, type=float)
128 | parser.add_argument('--radius_small', help='The radius for the latent space interpolation', default=40.0, type=float)
129 | parser.add_argument('--step1', help='The value of the step/increment for the constant layer interpolation', default=0.005, type=float)
130 | parser.add_argument('--step2', help='The value of the step/increment for the latent space interpolation', default=0.0025, type=float)
131 | parser.add_argument('--seed', help='Seed value for random', default=None, type=int)
132 | parser.add_argument('--video_length', help='The length of the video in terms of circular interpolation (recommended to keep at 1.0)', default=1.0, type=float)
133 | 
134 | args = parser.parse_args()
135 | 
136 | main(args.pkl, args.psi, args.radius_large, args.radius_small, args.step1, args.step2, args.seed, args.video_length)
137 | 
--------------------------------------------------------------------------------
/blend_models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import sys, getopt, os
3 | 
4 | import numpy as np
5 | import dnnlib
6 | import dnnlib.tflib as tflib
7 | from dnnlib.tflib import tfutil
8 | from dnnlib.tflib.autosummary import autosummary
9 | import math
10 | 
11 | 
12 | from training import dataset
13 | from training import misc
14 | import pickle
15 | 
16 | from pathlib import Path
17 | import typer
18 | from typing import Optional
19 | 
20 | def extract_conv_names(model):
21 | # layers are G_synthesis/{res}x{res}/...
22 | # make a list of (name, resolution, level, position)
23 | # Currently assuming square(?)
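# A hedged sketch of the tuples this returns, assuming the standard
# StyleGAN2 TF variable names (entries illustrative, not exhaustive):
#   ("G_synthesis/4x4/Const/const",     "4x4", 0, 0)
#   ("G_synthesis/4x4/ToRGB/weight",    "4x4", 1, 1)
#   ("G_synthesis/8x8/Conv0_up/weight", "8x8", 0, 2)
# position advances once per (resolution, level) pair, from low to high res.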
24 | 
25 | model_names = list(model.trainables.keys())
26 | conv_names = []
27 | 
28 | resolutions = [4*2**x for x in range(9)]
29 | 
30 | level_names = [["Conv0_up", "Const"],
31 | ["Conv1", "ToRGB"]]
32 | 
33 | position = 0
34 | # option not to split levels
35 | for res in resolutions:
36 | root_name = f"G_synthesis/{res}x{res}/"
37 | for level, level_suffixes in enumerate(level_names):
38 | for suffix in level_suffixes:
39 | search_name = root_name + suffix
40 | matched_names = [x for x in model_names if x.startswith(search_name)]
41 | to_add = [(name, f"{res}x{res}", level, position) for name in matched_names]
42 | conv_names.extend(to_add)
43 | position += 1
44 | 
45 | return conv_names
46 | 
47 | 
48 | def blend_models(model_1, model_2, resolution, level, blend_width=None, verbose=False):
49 | 
50 | # y is the blending amount: y = 0 means all model_1, y = 1 means all model_2
51 | 
52 | # TODO add small x offset for smoother blend animations
53 | resolution = f"{resolution}x{resolution}"
54 | 
55 | model_1_names = extract_conv_names(model_1)
56 | model_2_names = extract_conv_names(model_2)
57 | 
58 | assert all((x == y for x, y in zip(model_1_names, model_2_names)))
59 | 
60 | model_out = model_1.clone()
61 | 
62 | short_names = [(x[1:3]) for x in model_1_names]
63 | full_names = [(x[0]) for x in model_1_names]
64 | mid_point_idx = short_names.index((resolution, level))
65 | mid_point_pos = model_1_names[mid_point_idx][3]
66 | 
67 | ys = []
68 | for name, _, _, position in model_1_names: # avoid shadowing the resolution/level arguments
69 | # low to high (res)
70 | x = position - mid_point_pos
71 | if blend_width:
72 | exponent = -x/blend_width
73 | y = 1 / (1 + math.exp(exponent))
74 | else:
75 | y = 1 if x > 1 else 0
76 | 
77 | ys.append(y)
78 | if verbose:
79 | print(f"Blending {name} by {y}")
80 | 
81 | tfutil.set_vars(
82 | tfutil.run(
83 | {model_out.vars[name]: (model_2.vars[name] * y + model_1.vars[name] * (1-y))
84 | for name, y
85 | in zip(full_names, ys)}
86 | )
87 | )
88 | 
89 | return model_out
90 | 
91 | def main(low_res_pkl: Path, # Pickle file from which to take low res layers
92 | high_res_pkl: Path, # Pickle file from which to take high res layers
93 | resolution: int, # Resolution level at which to switch between models
94 | level: int = 0, # Switch at Conv block 0 or 1?
95 | blend_width: Optional[float] = None, # None = hard switch, float = smooth switch (logistic) with given width 96 | output_grid: Optional[Path] = "blended.jpg", # Path of image file to save example grid (None = don't save) 97 | seed: int = 0, # seed for random grid 98 | output_pkl: Optional[Path] = None, # Output path of pickle (None = don't save) 99 | verbose: bool = False, # Print out the exact blending fraction 100 | ): 101 | 102 | grid_size = (3, 3) 103 | 104 | tflib.init_tf() 105 | 106 | with tf.Session() as sess, tf.device('/gpu:0'): 107 | low_res_G, low_res_D, low_res_Gs = misc.load_pkl(low_res_pkl) 108 | high_res_G, high_res_D, high_res_Gs = misc.load_pkl(high_res_pkl) 109 | 110 | out = blend_models(low_res_Gs, high_res_Gs, resolution, level, blend_width=blend_width, verbose=verbose) 111 | 112 | if output_grid: 113 | rnd = np.random.RandomState(seed) 114 | grid_latents = rnd.randn(np.prod(grid_size), *out.input_shape[1:]) 115 | grid_fakes = out.run(grid_latents, None, is_validation=True, minibatch_size=1) 116 | misc.save_image_grid(grid_fakes, output_grid, drange= [-1,1], grid_size=grid_size) 117 | 118 | # TODO modify all the networks 119 | if output_pkl: 120 | misc.save_pkl((low_res_G, low_res_D, out), output_pkl) 121 | 122 | 123 | if __name__ == '__main__': 124 | typer.run(main) 125 | -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | import os 12 | import argparse 13 | import json 14 | import pickle 15 | import dnnlib 16 | import dnnlib.tflib as tflib 17 | 18 | from metrics import metric_defaults 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | class UserError(Exception): 23 | pass 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def calc_metrics(network_pkl, metric_names, metricdata, mirror, gpus): 28 | tflib.init_tf() 29 | 30 | # Initialize metrics. 31 | metrics = [] 32 | for name in metric_names: 33 | if name not in metric_defaults.metric_defaults: 34 | raise UserError('\n'.join(['--metrics can only contain the following values:', 'none'] + list(metric_defaults.metric_defaults.keys()))) 35 | metrics.append(dnnlib.util.construct_class_by_name(**metric_defaults.metric_defaults[name])) 36 | 37 | # Load network. 38 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 39 | raise UserError('--network must point to a file or URL') 40 | print(f'Loading network from "{network_pkl}"...') 41 | with dnnlib.util.open_url(network_pkl) as f: 42 | _G, _D, Gs = pickle.load(f) 43 | Gs.print_layers() 44 | 45 | # Look up training options. 
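# The layout assumed here matches the first example in the epilog below:
# a run directory such as ~/training-runs/00000-ffhq10k-res64-auto1/ that
# holds both network-snapshot-*.pkl and training_options.json.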
46 | run_dir = None
47 | training_options = None
48 | if os.path.isfile(network_pkl):
49 | potential_run_dir = os.path.dirname(network_pkl)
50 | potential_json_file = os.path.join(potential_run_dir, 'training_options.json')
51 | if os.path.isfile(potential_json_file):
52 | print(f'Looking up training options from "{potential_json_file}"...')
53 | run_dir = potential_run_dir
54 | with open(potential_json_file, 'rt') as f:
55 | training_options = json.load(f, object_pairs_hook=dnnlib.EasyDict)
56 | if training_options is None:
57 | print('Could not look up training options; will rely on --metricdata and --mirror')
58 | 
59 | # Choose dataset options.
60 | dataset_options = dnnlib.EasyDict()
61 | if training_options is not None:
62 | dataset_options.update(training_options.metric_dataset_args)
63 | dataset_options.resolution = Gs.output_shapes[0][-1]
64 | dataset_options.max_label_size = Gs.input_shapes[1][-1]
65 | if metricdata is not None:
66 | if not os.path.isdir(metricdata):
67 | raise UserError('--metricdata must point to a directory containing *.tfrecords')
68 | dataset_options.path = metricdata
69 | if mirror is not None:
70 | dataset_options.mirror_augment = mirror
71 | if 'path' not in dataset_options:
72 | raise UserError('--metricdata must be specified explicitly')
73 | 
74 | # Print dataset options.
75 | print()
76 | print('Dataset options:')
77 | print(json.dumps(dataset_options, indent=2))
78 | 
79 | # Evaluate metrics.
80 | for metric in metrics:
81 | print()
82 | print(f'Evaluating {metric.name}...')
83 | metric.configure(dataset_args=dataset_options, run_dir=run_dir)
84 | metric.run(network_pkl=network_pkl, num_gpus=gpus)
85 | 
86 | #----------------------------------------------------------------------------
87 | 
88 | def _str_to_bool(v):
89 | if isinstance(v, bool):
90 | return v
91 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
92 | return True
93 | if v.lower() in ('no', 'false', 'f', 'n', '0'):
94 | return False
95 | raise argparse.ArgumentTypeError('Boolean value expected.')
96 | 
97 | def _parse_comma_sep(s):
98 | if s is None or s.lower() == 'none' or s == '':
99 | return []
100 | return s.split(',')
101 | 
102 | #----------------------------------------------------------------------------
103 | 
104 | _cmdline_help_epilog = '''examples:
105 | 
106 | # Previous training run: look up options automatically, save result to text file.
107 | python %(prog)s --metrics=pr50k3_full \\
108 | --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
109 | 
110 | # Pretrained network pickle: specify dataset explicitly, print result to stdout.
111 | python %(prog)s --metrics=fid50k_full --metricdata=~/datasets/ffhq --mirror=1 \\
112 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl
113 | 
114 | available metrics:
115 | 
116 | ADA paper:
117 | fid50k_full Frechet inception distance against the full dataset.
118 | kid50k_full Kernel inception distance against the full dataset.
119 | pr50k3_full Precision and recall against the full dataset.
120 | is50k Inception score for CIFAR-10.
121 | 
122 | Legacy: StyleGAN2
123 | fid50k Frechet inception distance against 50k real images.
124 | kid50k Kernel inception distance against 50k real images.
125 | pr50k3 Precision and recall against 50k real images.
126 | ppl2_wend Perceptual path length in W at path endpoints against full image.
127 | 
128 | Legacy: StyleGAN
129 | ppl_zfull Perceptual path length in Z for full paths against cropped image.
130 | ppl_wfull Perceptual path length in W for full paths against cropped image. 131 | ppl_zend Perceptual path length in Z at path endpoints against cropped image. 132 | ppl_wend Perceptual path length in W at path endpoints against cropped image. 133 | ls Linear separability with respect to CelebA attributes. 134 | ''' 135 | 136 | #---------------------------------------------------------------------------- 137 | 138 | def main(): 139 | parser = argparse.ArgumentParser( 140 | description='Calculate quality metrics for previous training run or pretrained network pickle.', 141 | epilog=_cmdline_help_epilog, 142 | formatter_class=argparse.RawDescriptionHelpFormatter 143 | ) 144 | 145 | parser.add_argument('--network', help='Network pickle filename or URL', dest='network_pkl', metavar='PATH') 146 | parser.add_argument('--metrics', help='Comma-separated list or "none" (default: %(default)s)', dest='metric_names', type=_parse_comma_sep, default='fid50k_full', metavar='LIST') 147 | parser.add_argument('--metricdata', help='Dataset to evaluate metrics against (default: look up from training options)', metavar='PATH') 148 | parser.add_argument('--mirror', help='Whether the dataset was augmented with x-flips during training (default: look up from training options)', type=_str_to_bool, metavar='BOOL') 149 | parser.add_argument('--gpus', help='Number of GPUs to use (default: %(default)s)', type=int, default=1, metavar='INT') 150 | 151 | args = parser.parse_args() 152 | try: 153 | calc_metrics(**vars(args)) 154 | except UserError as err: 155 | print(f'Error: {err}') 156 | exit(1) 157 | 158 | #---------------------------------------------------------------------------- 159 | 160 | if __name__ == "__main__": 161 | main() 162 | 163 | #---------------------------------------------------------------------------- 164 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from . import autosummary 10 | from . import network 11 | from . import optimizer 12 | from . import tfutil 13 | from . 
import custom_ops 14 | 15 | from .tfutil import * 16 | from .network import Network 17 | 18 | from .optimizer import Optimizer 19 | 20 | from .custom_ops import get_plugin 21 | -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Helper for adding automatically tracked values to Tensorboard. 10 | 11 | Autosummary creates an identity op that internally keeps track of the input 12 | values and automatically shows up in TensorBoard. The reported value 13 | represents an average over input components. The average is accumulated 14 | constantly over time and flushed when save_summaries() is called. 15 | 16 | Notes: 17 | - The output tensor must be used as an input for something else in the 18 | graph. Otherwise, the autosummary op will not get executed, and the average 19 | value will not get accumulated. 20 | - It is perfectly fine to include autosummaries with the same name in 21 | several places throughout the graph, even if they are executed concurrently. 22 | - It is ok to also pass in a python scalar or numpy array. In this case, it 23 | is added to the average immediately. 24 | """ 25 | 26 | from collections import OrderedDict 27 | import numpy as np 28 | import tensorflow as tf 29 | from tensorboard import summary as summary_lib 30 | from tensorboard.plugins.custom_scalar import layout_pb2 31 | 32 | from . import tfutil 33 | from .tfutil import TfExpression 34 | from .tfutil import TfExpressionEx 35 | 36 | # Enable "Custom scalars" tab in TensorBoard for advanced formatting. 37 | # Disabled by default to reduce tfevents file size. 38 | enable_custom_scalars = False 39 | 40 | _dtype = tf.float64 41 | _vars = OrderedDict() # name => [var, ...] 
42 | _immediate = OrderedDict() # name => update_op, update_value 43 | _finalized = False 44 | _merge_op = None 45 | 46 | 47 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 48 | """Internal helper for creating autosummary accumulators.""" 49 | assert not _finalized 50 | name_id = name.replace("/", "_") 51 | v = tf.cast(value_expr, _dtype) 52 | 53 | if v.shape.is_fully_defined(): 54 | size = np.prod(v.shape.as_list()) 55 | size_expr = tf.constant(size, dtype=_dtype) 56 | else: 57 | size = None 58 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 59 | 60 | if size == 1: 61 | if v.shape.ndims != 0: 62 | v = tf.reshape(v, []) 63 | v = [size_expr, v, tf.square(v)] 64 | else: 65 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 66 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 67 | 68 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 69 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 70 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 71 | 72 | if name in _vars: 73 | _vars[name].append(var) 74 | else: 75 | _vars[name] = [var] 76 | return update_op 77 | 78 | 79 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx: 80 | """Create a new autosummary. 81 | 82 | Args: 83 | name: Name to use in TensorBoard 84 | value: TensorFlow expression or python value to track 85 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 86 | 87 | Example use of the passthru mechanism: 88 | 89 | n = autosummary('l2loss', loss, passthru=n) 90 | 91 | This is a shorthand for the following code: 92 | 93 | with tf.control_dependencies([autosummary('l2loss', loss)]): 94 | n = tf.identity(n) 95 | """ 96 | tfutil.assert_tf_initialized() 97 | name_id = name.replace("/", "_") 98 | 99 | if tfutil.is_tf_expression(value): 100 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 101 | condition = tf.convert_to_tensor(condition, name='condition') 102 | update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op) 103 | with tf.control_dependencies([update_op]): 104 | return tf.identity(value if passthru is None else passthru) 105 | 106 | else: # python scalar or numpy array 107 | assert not tfutil.is_tf_expression(passthru) 108 | assert not tfutil.is_tf_expression(condition) 109 | if condition: 110 | if name not in _immediate: 111 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 112 | update_value = tf.placeholder(_dtype) 113 | update_op = _create_var(name, update_value) 114 | _immediate[name] = update_op, update_value 115 | update_op, update_value = _immediate[name] 116 | tfutil.run(update_op, {update_value: value}) 117 | return value if passthru is None else passthru 118 | 119 | 120 | def finalize_autosummaries() -> None: 121 | """Create the necessary ops to include autosummaries in TensorBoard report. 122 | Note: This should be done only once per graph. 123 | """ 124 | global _finalized 125 | tfutil.assert_tf_initialized() 126 | 127 | if _finalized: 128 | return None 129 | 130 | _finalized = True 131 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 132 | 133 | # Create summary ops. 
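# Each accumulator in _vars holds [sum(1), sum(x), sum(x**2)] (see _create_var
# above). After tf.add_n() and the division by moments[0] below, moments equals
# [1, mean(x), mean(x**2)], so the values reported to TensorBoard are
# mean = moments[1] and std = sqrt(mean(x**2) - mean(x)**2).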
134 | with tf.device(None), tf.control_dependencies(None): 135 | for name, vars_list in _vars.items(): 136 | name_id = name.replace("/", "_") 137 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 138 | moments = tf.add_n(vars_list) 139 | moments /= moments[0] 140 | with tf.control_dependencies([moments]): # read before resetting 141 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 142 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 143 | mean = moments[1] 144 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 145 | tf.summary.scalar(name, mean) 146 | if enable_custom_scalars: 147 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 148 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 149 | 150 | # Setup layout for custom scalars. 151 | layout = None 152 | if enable_custom_scalars: 153 | cat_dict = OrderedDict() 154 | for series_name in sorted(_vars.keys()): 155 | p = series_name.split("/") 156 | cat = p[0] if len(p) >= 2 else "" 157 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 158 | if cat not in cat_dict: 159 | cat_dict[cat] = OrderedDict() 160 | if chart not in cat_dict[cat]: 161 | cat_dict[cat][chart] = [] 162 | cat_dict[cat][chart].append(series_name) 163 | categories = [] 164 | for cat_name, chart_dict in cat_dict.items(): 165 | charts = [] 166 | for chart_name, series_names in chart_dict.items(): 167 | series = [] 168 | for series_name in series_names: 169 | series.append(layout_pb2.MarginChartContent.Series( 170 | value=series_name, 171 | lower="xCustomScalars/" + series_name + "/margin_lo", 172 | upper="xCustomScalars/" + series_name + "/margin_hi")) 173 | margin = layout_pb2.MarginChartContent(series=series) 174 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 175 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 176 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 177 | return layout 178 | 179 | def save_summaries(file_writer, global_step=None): 180 | """Call FileWriter.add_summary() with all summaries in the default graph, 181 | automatically finalizing and merging them on the first call. 182 | """ 183 | global _merge_op 184 | tfutil.assert_tf_initialized() 185 | 186 | if _merge_op is None: 187 | layout = finalize_autosummaries() 188 | if layout is not None: 189 | file_writer.add_summary(layout) 190 | with tf.device(None), tf.control_dependencies(None): 191 | _merge_op = tf.summary.merge_all() 192 | 193 | file_writer.add_summary(_merge_op.eval(), global_step) 194 | -------------------------------------------------------------------------------- /dnnlib/tflib/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """TensorFlow custom ops builder. 
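get_plugin() compiles the given .cu source with nvcc on first use, caches the
resulting binary under an MD5 content hash, and loads it via
tf.load_op_library().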
10 | """ 11 | 12 | import glob 13 | import os 14 | import re 15 | import uuid 16 | import hashlib 17 | import tempfile 18 | import shutil 19 | import tensorflow as tf 20 | from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module 21 | 22 | from .. import util 23 | 24 | #---------------------------------------------------------------------------- 25 | # Global options. 26 | 27 | cuda_cache_path = None 28 | cuda_cache_version_tag = 'v1' 29 | do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change. 30 | verbose = True # Print status messages to stdout. 31 | 32 | #---------------------------------------------------------------------------- 33 | # Internal helper funcs. 34 | 35 | def _find_compiler_bindir(): 36 | hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) 37 | if hostx64_paths != []: 38 | return hostx64_paths[0] 39 | hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) 40 | if hostx64_paths != []: 41 | return hostx64_paths[0] 42 | hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) 43 | if hostx64_paths != []: 44 | return hostx64_paths[0] 45 | vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin' 46 | if os.path.isdir(vc_bin_dir): 47 | return vc_bin_dir 48 | return None 49 | 50 | def _get_compute_cap(device): 51 | caps_str = device.physical_device_desc 52 | m = re.search('compute capability: (\\d+).(\\d+)', caps_str) 53 | major = m.group(1) 54 | minor = m.group(2) 55 | return (major, minor) 56 | 57 | def _get_cuda_gpu_arch_string(): 58 | gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] 59 | if len(gpus) == 0: 60 | raise RuntimeError('No GPU devices found') 61 | (major, minor) = _get_compute_cap(gpus[0]) 62 | return 'sm_%s%s' % (major, minor) 63 | 64 | def _run_cmd(cmd): 65 | with os.popen(cmd) as pipe: 66 | output = pipe.read() 67 | status = pipe.close() 68 | if status is not None: 69 | raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) 70 | 71 | def _prepare_nvcc_cli(opts): 72 | cmd = 'nvcc ' + opts.strip() 73 | cmd += ' --disable-warnings' 74 | cmd += ' --include-path "%s"' % tf.sysconfig.get_include() 75 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') 76 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') 77 | cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') 78 | 79 | compiler_bindir = _find_compiler_bindir() 80 | if compiler_bindir is None: 81 | # Require that _find_compiler_bindir succeeds on Windows. Allow 82 | # nvcc to use whatever is the default on Linux. 83 | if os.name == 'nt': 84 | raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) 85 | else: 86 | cmd += ' --compiler-bindir "%s"' % compiler_bindir 87 | cmd += ' 2>&1' 88 | return cmd 89 | 90 | #---------------------------------------------------------------------------- 91 | # Main entry point. 
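# A hedged usage sketch, mirroring the call sites in ops/fused_bias_act.py
# (`x`, `b`, and `empty` stand for tensors supplied by the caller):
#
#   from dnnlib.tflib import custom_ops
#   plugin = custom_ops.get_plugin('dnnlib/tflib/ops/fused_bias_act.cu')
#   y = plugin.fused_bias_act(x=x, b=b, xref=empty, yref=empty, grad=0,
#                             axis=1, act=3, alpha=0.2, gain=1.0, clamp=-1.0)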
92 | 93 | _plugin_cache = dict() 94 | 95 | def get_plugin(cuda_file, extra_nvcc_options=[]): 96 | cuda_file_base = os.path.basename(cuda_file) 97 | cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) 98 | 99 | # Already in cache? 100 | if cuda_file in _plugin_cache: 101 | return _plugin_cache[cuda_file] 102 | 103 | # Setup plugin. 104 | if verbose: 105 | print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) 106 | try: 107 | # Hash CUDA source. 108 | md5 = hashlib.md5() 109 | with open(cuda_file, 'rb') as f: 110 | md5.update(f.read()) 111 | md5.update(b'\n') 112 | 113 | # Hash headers included by the CUDA code by running it through the preprocessor. 114 | if not do_not_hash_included_headers: 115 | if verbose: 116 | print('Preprocessing... ', end='', flush=True) 117 | with tempfile.TemporaryDirectory() as tmp_dir: 118 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) 119 | _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) 120 | with open(tmp_file, 'rb') as f: 121 | bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros 122 | good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') 123 | for ln in f: 124 | if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas 125 | ln = ln.replace(bad_file_str, good_file_str) 126 | md5.update(ln) 127 | md5.update(b'\n') 128 | 129 | # Select compiler options. 130 | compile_opts = '' 131 | if os.name == 'nt': 132 | compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') 133 | elif os.name == 'posix': 134 | compile_opts += f' --compiler-options \'-fPIC\'' 135 | compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\'' 136 | compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\'' 137 | else: 138 | assert False # not Windows or Linux, w00t? 139 | compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}' 140 | compile_opts += ' --use_fast_math' 141 | for opt in extra_nvcc_options: 142 | compile_opts += ' ' + opt 143 | nvcc_cmd = _prepare_nvcc_cli(compile_opts) 144 | 145 | # Hash build configuration. 146 | md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') 147 | md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') 148 | md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') 149 | 150 | # Compile if not already compiled. 151 | cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path 152 | bin_file_ext = '.dll' if os.name == 'nt' else '.so' 153 | bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) 154 | if not os.path.isfile(bin_file): 155 | if verbose: 156 | print('Compiling... ', end='', flush=True) 157 | with tempfile.TemporaryDirectory() as tmp_dir: 158 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) 159 | _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) 160 | os.makedirs(cache_dir, exist_ok=True) 161 | intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) 162 | shutil.copyfile(tmp_file, intermediate_file) 163 | os.rename(intermediate_file, bin_file) # atomic 164 | 165 | # Load. 166 | if verbose: 167 | print('Loading... 
', end='', flush=True)
168 | plugin = tf.load_op_library(bin_file)
169 | 
170 | # Add to cache.
171 | _plugin_cache[cuda_file] = plugin
172 | if verbose:
173 | print('Done.', flush=True)
174 | return plugin
175 | 
176 | except:
177 | if verbose:
178 | print('Failed!', flush=True)
179 | raise
180 | 
181 | #----------------------------------------------------------------------------
182 | 
--------------------------------------------------------------------------------
/dnnlib/tflib/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | # empty
--------------------------------------------------------------------------------
/dnnlib/tflib/ops/fused_bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #define EIGEN_USE_GPU
10 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
11 | #include "tensorflow/core/framework/op.h"
12 | #include "tensorflow/core/framework/op_kernel.h"
13 | #include "tensorflow/core/framework/shape_inference.h"
14 | #include <stdio.h>
15 | 
16 | using namespace tensorflow;
17 | using namespace tensorflow::shape_inference;
18 | 
19 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
20 | 
21 | //------------------------------------------------------------------------
22 | // CUDA kernel.
23 | 
24 | template <class T>
25 | struct FusedBiasActKernelParams
26 | {
27 | const T* x; // [sizeX]
28 | const T* b; // [sizeB] or NULL
29 | const T* xref; // [sizeX] or NULL
30 | const T* yref; // [sizeX] or NULL
31 | T* y; // [sizeX]
32 | 
33 | int grad;
34 | int axis;
35 | int act;
36 | float alpha;
37 | float gain;
38 | float clamp;
39 | 
40 | int sizeX;
41 | int sizeB;
42 | int stepB;
43 | int loopX;
44 | };
45 | 
46 | template <class T>
47 | static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
48 | {
49 | const float expRange = 80.0f;
50 | const float halfExpRange = 40.0f;
51 | const float seluScale = 1.0507009873554804934193349852946f;
52 | const float seluAlpha = 1.6732632423543772848170429916717f;
53 | 
54 | // Loop over elements.
55 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
56 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
57 | {
58 | // Load and apply bias.
59 | float x = (float)p.x[xi];
60 | if (p.b)
61 | x += (float)p.b[(xi / p.stepB) % p.sizeB];
62 | float xref = (p.xref) ? (float)p.xref[xi] : 0.0f;
63 | float yref = (p.yref) ? (float)p.yref[xi] : 0.0f;
64 | float yy = (p.gain != 0.0f) ?
yref / p.gain : 0.0f; 65 | 66 | // Evaluate activation func. 67 | float y; 68 | switch (p.act * 10 + p.grad) 69 | { 70 | // linear 71 | default: 72 | case 10: y = x; break; 73 | case 11: y = x; break; 74 | case 12: y = 0.0f; break; 75 | 76 | // relu 77 | case 20: y = (x > 0.0f) ? x : 0.0f; break; 78 | case 21: y = (yy > 0.0f) ? x : 0.0f; break; 79 | case 22: y = 0.0f; break; 80 | 81 | // lrelu 82 | case 30: y = (x > 0.0f) ? x : x * p.alpha; break; 83 | case 31: y = (yy > 0.0f) ? x : x * p.alpha; break; 84 | case 32: y = 0.0f; break; 85 | 86 | // tanh 87 | case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break; 88 | case 41: y = x * (1.0f - yy * yy); break; 89 | case 42: y = x * (1.0f - yy * yy) * (-2.0f * yy); break; 90 | 91 | // sigmoid 92 | case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break; 93 | case 51: y = x * yy * (1.0f - yy); break; 94 | case 52: y = x * yy * (1.0f - yy) * (1.0f - 2.0f * yy); break; 95 | 96 | // elu 97 | case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break; 98 | case 61: y = (yy >= 0.0f) ? x : x * (yy + 1.0f); break; 99 | case 62: y = (yy >= 0.0f) ? 0.0f : x * (yy + 1.0f); break; 100 | 101 | // selu 102 | case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break; 103 | case 71: y = (yy >= 0.0f) ? x * seluScale : x * (yy + seluScale * seluAlpha); break; 104 | case 72: y = (yy >= 0.0f) ? 0.0f : x * (yy + seluScale * seluAlpha); break; 105 | 106 | // softplus 107 | case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break; 108 | case 81: y = x * (1.0f - expf(-yy)); break; 109 | case 82: { float c = expf(-yy); y = x * c * (1.0f - c); } break; 110 | 111 | // swish 112 | case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break; 113 | case 91: 114 | case 92: 115 | { 116 | float c = expf(xref); 117 | float d = c + 1.0f; 118 | if (p.grad == 1) 119 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 120 | else 121 | y = (xref > halfExpRange) ? 0.0f : x * c * (xref * (2.0f - d) + 2.0f * d) / (d * d * d); 122 | yref = (xref < -expRange) ? 0.0f : xref / (expf(-xref) + 1.0f) * p.gain; 123 | } 124 | break; 125 | } 126 | 127 | // Apply gain. 128 | y *= p.gain; 129 | 130 | // Clamp. 131 | if (p.clamp >= 0.0f) 132 | { 133 | if (p.grad == 0) 134 | y = (fabsf(y) < p.clamp) ? y : (y >= 0.0f) ? p.clamp : -p.clamp; 135 | else 136 | y = (fabsf(yref) < p.clamp) ? y : 0.0f; 137 | } 138 | 139 | // Store. 140 | p.y[xi] = (T)y; 141 | } 142 | } 143 | 144 | //------------------------------------------------------------------------ 145 | // TensorFlow op. 
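// The kernel above dispatches on p.act * 10 + p.grad: e.g. lrelu (act = 3)
// runs case 30 in the forward pass, case 31 for the first-order gradient,
// and case 32 for the second-order gradient. The Python wrapper in
// ops/fused_bias_act.py passes grad = 0, 1, or 2 accordingly.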
146 | 
147 | template <class T>
148 | struct FusedBiasActOp : public OpKernel
149 | {
150 | FusedBiasActKernelParams<T> m_attribs;
151 | 
152 | FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
153 | {
154 | memset(&m_attribs, 0, sizeof(m_attribs));
155 | OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
156 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
157 | OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
158 | OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
159 | OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
160 | OP_REQUIRES_OK(ctx, ctx->GetAttr("clamp", &m_attribs.clamp));
161 | OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
162 | OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
163 | OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
164 | }
165 | 
166 | void Compute(OpKernelContext* ctx)
167 | {
168 | FusedBiasActKernelParams<T> p = m_attribs;
169 | cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
170 | 
171 | const Tensor& x = ctx->input(0); // [...]
172 | const Tensor& b = ctx->input(1); // [sizeB] or [0]
173 | const Tensor& xref = ctx->input(2); // x.shape or [0]
174 | const Tensor& yref = ctx->input(3); // x.shape or [0]
175 | p.x = x.flat<T>().data();
176 | p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
177 | p.xref = (xref.NumElements()) ? xref.flat<T>().data() : NULL;
178 | p.yref = (yref.NumElements()) ? yref.flat<T>().data() : NULL;
179 | OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
180 | OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
181 | OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
182 | OP_REQUIRES(ctx, xref.NumElements() == 0 || xref.NumElements() == x.NumElements(), errors::InvalidArgument("xref has wrong number of elements"));
183 | OP_REQUIRES(ctx, yref.NumElements() == 0 || yref.NumElements() == x.NumElements(), errors::InvalidArgument("yref has wrong number of elements"));
184 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
185 | 
186 | p.sizeX = (int)x.NumElements();
187 | p.sizeB = (int)b.NumElements();
188 | p.stepB = 1;
189 | for (int i = m_attribs.axis + 1; i < x.dims(); i++)
190 | p.stepB *= (int)x.dim_size(i);
191 | 
192 | Tensor* y = NULL; // x.shape
193 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
194 | p.y = y->flat<T>().data();
195 | 
196 | p.loopX = 4;
197 | int blockSize = 4 * 32;
198 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
199 | void* args[] = {&p};
200 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
201 | }
202 | };
203 | 
204 | REGISTER_OP("FusedBiasAct")
205 | .Input ("x: T")
206 | .Input ("b: T")
207 | .Input ("xref: T")
208 | .Input ("yref: T")
209 | .Output ("y: T")
210 | .Attr ("T: {float, half}")
211 | .Attr ("grad: int = 0")
212 | .Attr ("axis: int = 1")
213 | .Attr ("act: int = 0")
214 | .Attr ("alpha: float = 0.0")
215 | .Attr ("gain: float = 1.0")
216 | .Attr ("clamp: float = -1.0");
217 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
218 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
219 | 
//------------------------------------------------------------------------ 221 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/fused_bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom TensorFlow ops for efficient bias and activation.""" 10 | 11 | import os 12 | import numpy as np 13 | import tensorflow as tf 14 | from .. import custom_ops 15 | from ...util import EasyDict 16 | 17 | def _get_plugin(): 18 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | activation_funcs = { 23 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True), 24 | 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True), 25 | 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True), 26 | 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False), 27 | 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False), 28 | 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False), 29 | 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False), 30 | 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False), 31 | 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False), 32 | } 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 37 | r"""Fused bias and activation function. 38 | 39 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 40 | and scales the result by `gain`. Each of the steps is optional. In most cases, 41 | the fused op is considerably more efficient than performing the same calculation 42 | using standard TensorFlow ops. It supports first and second order gradients, 43 | but not third order gradients. 44 | 45 | Args: 46 | x: Input activation tensor. Can have any shape, but if `b` is defined, the 47 | dimension corresponding to `axis`, as well as the rank, must be known. 48 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 49 | as `x`. The shape must be known, and it must match the dimension of `x` 50 | corresponding to `axis`. 51 | axis: The dimension in `x` corresponding to the elements of `b`. 52 | The value of `axis` is ignored if `b` is not specified. 
53 | act: Name of the activation function to evaluate, or `"linear"` to disable. 54 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 55 | See `activation_funcs` for a full list. `None` is not allowed. 56 | alpha: Shape parameter for the activation function, or `None` to use the default. 57 | gain: Scaling factor for the output tensor, or `None` to use default. 58 | See `activation_funcs` for the default scaling of each activation function. 59 | If unsure, consider specifying `1.0`. 60 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 61 | the clamping (default). 62 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 63 | 64 | Returns: 65 | Tensor of the same shape and datatype as `x`. 66 | """ 67 | 68 | impl_dict = { 69 | 'ref': _fused_bias_act_ref, 70 | 'cuda': _fused_bias_act_cuda, 71 | } 72 | return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp) 73 | 74 | #---------------------------------------------------------------------------- 75 | 76 | def _fused_bias_act_ref(x, b, axis, act, alpha, gain, clamp): 77 | """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops.""" 78 | 79 | # Validate arguments. 80 | x = tf.convert_to_tensor(x) 81 | b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype) 82 | act_spec = activation_funcs[act] 83 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 84 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 85 | if alpha is None: 86 | alpha = act_spec.def_alpha 87 | if gain is None: 88 | gain = act_spec.def_gain 89 | 90 | # Add bias. 91 | if b.shape[0] != 0: 92 | x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)]) 93 | 94 | # Evaluate activation function. 95 | x = act_spec.func(x, alpha=alpha) 96 | 97 | # Scale by gain. 98 | if gain != 1: 99 | x *= gain 100 | 101 | # Clamp. 102 | if clamp is not None: 103 | clamp = np.asarray(clamp, dtype=x.dtype.name) 104 | assert clamp.shape == () and clamp >= 0 105 | x = tf.clip_by_value(x, -clamp, clamp) 106 | return x 107 | 108 | #---------------------------------------------------------------------------- 109 | 110 | def _fused_bias_act_cuda(x, b, axis, act, alpha, gain, clamp): 111 | """Fast CUDA implementation of `fused_bias_act()` using custom ops.""" 112 | 113 | # Validate arguments. 114 | x = tf.convert_to_tensor(x) 115 | empty_tensor = tf.constant([], dtype=x.dtype) 116 | b = tf.convert_to_tensor(b) if b is not None else empty_tensor 117 | act_spec = activation_funcs[act] 118 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 119 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 120 | if alpha is None: 121 | alpha = act_spec.def_alpha 122 | if gain is None: 123 | gain = act_spec.def_gain 124 | 125 | # Special cases. 126 | if act == 'linear' and b is None and gain == 1.0: 127 | return x 128 | if act_spec.cuda_idx is None: 129 | return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp) 130 | 131 | # CUDA op. 
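# The op is JIT-compiled from fused_bias_act.cu via custom_ops.get_plugin()
# on first use. cuda_kwargs below maps act_spec.cuda_idx, alpha, gain, and
# clamp onto the attributes declared by REGISTER_OP("FusedBiasAct"); the grad
# attribute (0, 1, or 2) is chosen per call in the gradient functions below.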
132 | cuda_op = _get_plugin().fused_bias_act 133 | cuda_kwargs = dict(axis=int(axis), act=int(act_spec.cuda_idx), gain=float(gain)) 134 | if alpha is not None: 135 | cuda_kwargs['alpha'] = float(alpha) 136 | if clamp is not None: 137 | clamp = np.asarray(clamp, dtype=x.dtype.name) 138 | assert clamp.shape == () and clamp >= 0 139 | cuda_kwargs['clamp'] = float(clamp.astype(np.float32)) 140 | def ref(tensor, name): 141 | return tensor if act_spec.ref == name else empty_tensor 142 | 143 | # Forward pass: y = func(x, b). 144 | def func_y(x, b): 145 | y = cuda_op(x=x, b=b, xref=empty_tensor, yref=empty_tensor, grad=0, **cuda_kwargs) 146 | y.set_shape(x.shape) 147 | return y 148 | 149 | # Backward pass: dx, db = grad(dy, x, y) 150 | def grad_dx(dy, x, y): 151 | dx = cuda_op(x=dy, b=empty_tensor, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs) 152 | dx.set_shape(x.shape) 153 | return dx 154 | def grad_db(dx): 155 | if b.shape[0] == 0: 156 | return empty_tensor 157 | db = dx 158 | if axis < x.shape.rank - 1: 159 | db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank))) 160 | if axis > 0: 161 | db = tf.reduce_sum(db, list(range(axis))) 162 | db.set_shape(b.shape) 163 | return db 164 | 165 | # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y) 166 | def grad2_d_dy(d_dx, d_db, x, y): 167 | d_dy = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs) 168 | d_dy.set_shape(x.shape) 169 | return d_dy 170 | def grad2_d_x(d_dx, d_db, x, y): 171 | d_x = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=2, **cuda_kwargs) 172 | d_x.set_shape(x.shape) 173 | return d_x 174 | 175 | # Fast version for piecewise-linear activation funcs. 176 | @tf.custom_gradient 177 | def func_zero_2nd_grad(x, b): 178 | y = func_y(x, b) 179 | @tf.custom_gradient 180 | def grad(dy): 181 | dx = grad_dx(dy, x, y) 182 | db = grad_db(dx) 183 | def grad2(d_dx, d_db): 184 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 185 | return d_dy 186 | return (dx, db), grad2 187 | return y, grad 188 | 189 | # Slow version for general activation funcs. 190 | @tf.custom_gradient 191 | def func_nonzero_2nd_grad(x, b): 192 | y = func_y(x, b) 193 | def grad_wrap(dy): 194 | @tf.custom_gradient 195 | def grad_impl(dy, x): 196 | dx = grad_dx(dy, x, y) 197 | db = grad_db(dx) 198 | def grad2(d_dx, d_db): 199 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 200 | d_x = grad2_d_x(d_dx, d_db, x, y) 201 | return d_dy, d_x 202 | return (dx, db), grad2 203 | return grad_impl(dy, x) 204 | return y, grad_wrap 205 | 206 | # Which version to use? 207 | if act_spec.zero_2nd_grad: 208 | return func_zero_2nd_grad(x, b) 209 | return func_nonzero_2nd_grad(x, b) 210 | 211 | #---------------------------------------------------------------------------- 212 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Miscellaneous helper utils for Tensorflow.""" 10 | 11 | import os 12 | import numpy as np 13 | import tensorflow as tf 14 | 15 | # Silence deprecation warnings from TensorFlow 1.13 onwards 16 | import logging 17 | logging.getLogger('tensorflow').setLevel(logging.ERROR) 18 | import tensorflow.contrib # requires TensorFlow 1.x! 19 | tf.contrib = tensorflow.contrib 20 | 21 | from typing import Any, Iterable, List, Union 22 | 23 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 24 | """A type that represents a valid Tensorflow expression.""" 25 | 26 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 27 | """A type that can be converted to a valid Tensorflow expression.""" 28 | 29 | 30 | def run(*args, **kwargs) -> Any: 31 | """Run the specified ops in the default session.""" 32 | assert_tf_initialized() 33 | return tf.get_default_session().run(*args, **kwargs) 34 | 35 | 36 | def is_tf_expression(x: Any) -> bool: 37 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 38 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 39 | 40 | 41 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 42 | """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code.""" 43 | return [dim.value for dim in shape] 44 | 45 | 46 | def flatten(x: TfExpressionEx) -> TfExpression: 47 | """Shortcut function for flattening a tensor.""" 48 | with tf.name_scope("Flatten"): 49 | return tf.reshape(x, [-1]) 50 | 51 | 52 | def log2(x: TfExpressionEx) -> TfExpression: 53 | """Logarithm in base 2.""" 54 | with tf.name_scope("Log2"): 55 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 56 | 57 | 58 | def exp2(x: TfExpressionEx) -> TfExpression: 59 | """Exponent in base 2.""" 60 | with tf.name_scope("Exp2"): 61 | return tf.exp(x * np.float32(np.log(2.0))) 62 | 63 | 64 | def erfinv(y: TfExpressionEx) -> TfExpression: 65 | """Inverse of the error function.""" 66 | # pylint: disable=no-name-in-module 67 | from tensorflow.python.ops.distributions import special_math 68 | return special_math.erfinv(y) 69 | 70 | 71 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 72 | """Linear interpolation.""" 73 | with tf.name_scope("Lerp"): 74 | return a + (b - a) * t 75 | 76 | 77 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 78 | """Linear interpolation with clip.""" 79 | with tf.name_scope("LerpClip"): 80 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 81 | 82 | 83 | def absolute_name_scope(scope: str) -> tf.name_scope: 84 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 85 | return tf.name_scope(scope + "/") 86 | 87 | 88 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 89 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 90 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 91 | 92 | 93 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 94 | # Defaults. 95 | cfg = dict() 96 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 97 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 98 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 
1 = Print warnings and errors, but disable debug info. 99 | cfg["env.HDF5_USE_FILE_LOCKING"] = "FALSE" # Disable HDF5 file locking to avoid concurrency issues with network shares. 100 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 101 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 102 | 103 | # Remove defaults for environment variables that are already set. 104 | for key in list(cfg): 105 | fields = key.split(".") 106 | if fields[0] == "env": 107 | assert len(fields) == 2 108 | if fields[1] in os.environ: 109 | del cfg[key] 110 | 111 | # User overrides. 112 | if config_dict is not None: 113 | cfg.update(config_dict) 114 | return cfg 115 | 116 | 117 | def init_tf(config_dict: dict = None) -> None: 118 | """Initialize TensorFlow session using good default settings.""" 119 | # Skip if already initialized. 120 | if tf.get_default_session() is not None: 121 | return 122 | 123 | # Setup config dict and random seeds. 124 | cfg = _sanitize_tf_config(config_dict) 125 | np_random_seed = cfg["rnd.np_random_seed"] 126 | if np_random_seed is not None: 127 | np.random.seed(np_random_seed) 128 | tf_random_seed = cfg["rnd.tf_random_seed"] 129 | if tf_random_seed == "auto": 130 | tf_random_seed = np.random.randint(1 << 31) 131 | if tf_random_seed is not None: 132 | tf.set_random_seed(tf_random_seed) 133 | 134 | # Setup environment variables. 135 | for key, value in cfg.items(): 136 | fields = key.split(".") 137 | if fields[0] == "env": 138 | assert len(fields) == 2 139 | os.environ[fields[1]] = str(value) 140 | 141 | # Create default TensorFlow session. 142 | create_session(cfg, force_as_default=True) 143 | 144 | 145 | def assert_tf_initialized(): 146 | """Check that TensorFlow session has been initialized.""" 147 | if tf.get_default_session() is None: 148 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 149 | 150 | 151 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 152 | """Create tf.Session based on config dict.""" 153 | # Setup TensorFlow config proto. 154 | cfg = _sanitize_tf_config(config_dict) 155 | config_proto = tf.ConfigProto() 156 | for key, value in cfg.items(): 157 | fields = key.split(".") 158 | if fields[0] not in ["rnd", "env"]: 159 | obj = config_proto 160 | for field in fields[:-1]: 161 | obj = getattr(obj, field) 162 | setattr(obj, fields[-1], value) 163 | 164 | # Create session. 165 | session = tf.Session(config=config_proto) 166 | if force_as_default: 167 | # pylint: disable=protected-access 168 | session._default_session = session.as_default() 169 | session._default_session.enforce_nesting = False 170 | session._default_session.__enter__() 171 | return session 172 | 173 | 174 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 175 | """Initialize all tf.Variables that have not already been initialized. 
176 | 177 | Equivalent to the following, but more efficient and does not bloat the tf graph: 178 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 179 | """ 180 | assert_tf_initialized() 181 | if target_vars is None: 182 | target_vars = tf.global_variables() 183 | 184 | test_vars = [] 185 | test_ops = [] 186 | 187 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 188 | for var in target_vars: 189 | assert is_tf_expression(var) 190 | 191 | try: 192 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 193 | except KeyError: 194 | # Op does not exist => variable may be uninitialized. 195 | test_vars.append(var) 196 | 197 | with absolute_name_scope(var.name.split(":")[0]): 198 | test_ops.append(tf.is_variable_initialized(var)) 199 | 200 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 201 | run([var.initializer for var in init_vars]) 202 | 203 | 204 | def set_vars(var_to_value_dict: dict) -> None: 205 | """Set the values of given tf.Variables. 206 | 207 | Equivalent to the following, but more efficient and does not bloat the tf graph: 208 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 209 | """ 210 | assert_tf_initialized() 211 | ops = [] 212 | feed_dict = {} 213 | 214 | for var, value in var_to_value_dict.items(): 215 | assert is_tf_expression(var) 216 | 217 | try: 218 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 219 | except KeyError: 220 | with absolute_name_scope(var.name.split(":")[0]): 221 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 222 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 223 | 224 | ops.append(setter) 225 | feed_dict[setter.op.inputs[1]] = value 226 | 227 | run(ops, feed_dict) 228 | 229 | 230 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 231 | """Create tf.Variable with large initial value without bloating the tf graph.""" 232 | assert_tf_initialized() 233 | assert isinstance(initial_value, np.ndarray) 234 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 235 | var = tf.Variable(zeros, *args, **kwargs) 236 | set_vars({var: initial_value}) 237 | return var 238 | 239 | 240 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 241 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 242 | Can be used as an input transformation for Network.run(). 243 | """ 244 | images = tf.cast(images, tf.float32) 245 | if nhwc_to_nchw: 246 | images = tf.transpose(images, [0, 3, 1, 2]) 247 | return images * ((drange[1] - drange[0]) / 255) + drange[0] 248 | 249 | 250 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): 251 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 252 | Can be used as an output transformation for Network.run(). 
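    Maps `drange` onto [0, 255] as y = (x - drange[0]) * 255 / (drange[1] - drange[0]);
    an extra +0.5 is folded into the bias so the saturating cast rounds to nearest.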
253 | """ 254 | images = tf.cast(images, tf.float32) 255 | if shrink > 1: 256 | ksize = [1, 1, shrink, shrink] 257 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 258 | if nchw_to_nhwc: 259 | images = tf.transpose(images, [0, 2, 3, 1]) 260 | scale = 255 / (drange[1] - drange[0]) 261 | images = images * scale + (0.5 - drange[0] * scale) 262 | return tf.saturate_cast(images, tf.uint8) 263 | -------------------------------------------------------------------------------- /docs/license.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Nvidia Source Code License-NC 7 | 8 | 56 | 57 | 58 | 59 |

NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)

1. Definitions

“Licensor” means any person or entity that distributes its Work.

“Software” means the original work of authorship made available under this License.

“Work” means the Software and any additions to or derivative works of the Software that are made available under this License.

The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.

3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /docs/stylegan2-ada-teaser-1024x252.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvschultz/stylegan2-ada/8f4ab24f494483542d31bf10f4fdb0005dc62739/docs/stylegan2-ada-teaser-1024x252.png -------------------------------------------------------------------------------- /docs/stylegan2-ada-training-curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvschultz/stylegan2-ada/8f4ab24f494483542d31bf10f4fdb0005dc62739/docs/stylegan2-ada-training-curves.png -------------------------------------------------------------------------------- /docs/train-help.txt: -------------------------------------------------------------------------------- 1 | usage: train.py [-h] --outdir DIR [--gpus INT] [--snap INT] [--seed INT] [-n] 2 | --data PATH [--res INT] [--mirror BOOL] [--metrics LIST] 3 | [--metricdata PATH] 4 | [--cfg {auto,stylegan2,paper256,paper512,paper1024,cifar,cifarbaseline}] 5 | [--gamma FLOAT] [--kimg INT] [--aug {noaug,ada,fixed,adarv}] 6 | [--p FLOAT] [--target TARGET] 7 | [--augpipe {blit,geom,color,filter,noise,cutout,bg,bgc,bgcf,bgcfn,bgcfnc}] 8 | [--cmethod {nocmethod,bcr,zcr,pagan,wgangp,auxrot,spectralnorm,shallowmap,adropout}] 9 | [--dcap FLOAT] [--resume RESUME] [--freezed INT] 10 | 11 | Train a GAN using the techniques described in the paper 12 | "Training Generative Adversarial Networks with Limited Data". 13 | 14 | optional arguments: 15 | -h, --help show this help message and exit 16 | 17 | general options: 18 | --outdir DIR Where to save the results (required) 19 | --gpus INT Number of GPUs to use (default: 1 gpu) 20 | --snap INT Snapshot interval (default: 50 ticks) 21 | --seed INT Random seed (default: 1000) 22 | -n, --dry-run Print training options and exit 23 | 24 | training dataset: 25 | --data PATH Training dataset path (required) 26 | --res INT Dataset resolution (default: highest available) 27 | --mirror BOOL Augment dataset with x-flips (default: false) 28 | 29 | metrics: 30 | --metrics LIST Comma-separated list or "none" (default: fid50k_full) 31 | --metricdata PATH Dataset to evaluate metrics against (optional) 32 | 33 | base config: 34 | --cfg {auto,stylegan2,paper256,paper512,paper1024,cifar,cifarbaseline} 35 | Base config (default: auto) 36 | --gamma FLOAT Override R1 gamma 37 | --kimg INT Override training duration 38 | 39 | discriminator augmentation: 40 | --aug {noaug,ada,fixed,adarv} 41 | Augmentation mode (default: ada) 42 | --p FLOAT Specify augmentation probability for --aug=fixed 43 | --target TARGET Override ADA target for --aug=ada and --aug=adarv 44 | --augpipe {blit,geom,color,filter,noise,cutout,bg,bgc,bgcf,bgcfn,bgcfnc} 45 | Augmentation pipeline (default: bgc) 46 | 47 | comparison methods: 48 | --cmethod {nocmethod,bcr,zcr,pagan,wgangp,auxrot,spectralnorm,shallowmap,adropout} 49 | Comparison method (default: nocmethod) 50 | --dcap FLOAT Multiplier for discriminator capacity 51 | 52 | transfer learning: 53 | --resume RESUME Resume from network pickle (default: noresume) 54 | --freezed INT Freeze-D (default: 0 discriminator layers) 55 | 56 | examples: 57 | 58 | # Train custom dataset using 1 GPU. 59 | python train.py --outdir=~/training-runs --gpus=1 --data=~/datasets/custom 60 | 61 | # Train class-conditional CIFAR-10 using 2 GPUs. 
62 | python train.py --outdir=~/training-runs --gpus=2 --data=~/datasets/cifar10c \ 63 | --cfg=cifar 64 | 65 | # Transfer learn MetFaces from FFHQ using 4 GPUs. 66 | python train.py --outdir=~/training-runs --gpus=4 --data=~/datasets/metfaces \ 67 | --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10 68 | 69 | # Reproduce original StyleGAN2 config F. 70 | python train.py --outdir=~/training-runs --gpus=8 --data=~/datasets/ffhq \ 71 | --cfg=stylegan2 --res=1024 --mirror=1 --aug=noaug 72 | 73 | available base configs (--cfg): 74 | auto Automatically select reasonable defaults based on resolution 75 | and GPU count. Good starting point for new datasets. 76 | stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024. 77 | paper256 Reproduce results for FFHQ and LSUN Cat at 256x256. 78 | paper512 Reproduce results for BreCaHAD and AFHQ at 512x512. 79 | paper1024 Reproduce results for MetFaces at 1024x1024. 80 | cifar Reproduce results for CIFAR-10 (tuned configuration). 81 | cifarbaseline Reproduce results for CIFAR-10 (baseline configuration). 82 | 83 | transfer learning source networks (--resume): 84 | ffhq256 FFHQ trained at 256x256 resolution. 85 | ffhq512 FFHQ trained at 512x512 resolution. 86 | ffhq1024 FFHQ trained at 1024x1024 resolution. 87 | celebahq256 CelebA-HQ trained at 256x256 resolution. 88 | lsundog256 LSUN Dog trained at 256x256 resolution. 89 | Custom network pickle. 90 | -------------------------------------------------------------------------------- /ffhq_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvschultz/stylegan2-ada/8f4ab24f494483542d31bf10f4fdb0005dc62739/ffhq_dataset/__init__.py -------------------------------------------------------------------------------- /ffhq_dataset/face_alignment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | import os 4 | import PIL.Image 5 | 6 | 7 | def image_align(src_file, dst_file, face_landmarks, output_size=1024, transform_size=4096, enable_padding=True, x_scale=1, y_scale=1, em_scale=0.1, alpha=False): 8 | # Align function from FFHQ dataset pre-processing step 9 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 10 | 11 | lm = np.array(face_landmarks) 12 | lm_chin = lm[0 : 17] # left-right 13 | lm_eyebrow_left = lm[17 : 22] # left-right 14 | lm_eyebrow_right = lm[22 : 27] # left-right 15 | lm_nose = lm[27 : 31] # top-down 16 | lm_nostrils = lm[31 : 36] # top-down 17 | lm_eye_left = lm[36 : 42] # left-clockwise 18 | lm_eye_right = lm[42 : 48] # left-clockwise 19 | lm_mouth_outer = lm[48 : 60] # left-clockwise 20 | lm_mouth_inner = lm[60 : 68] # left-clockwise 21 | 22 | # Calculate auxiliary vectors. 23 | eye_left = np.mean(lm_eye_left, axis=0) 24 | eye_right = np.mean(lm_eye_right, axis=0) 25 | eye_avg = (eye_left + eye_right) * 0.5 26 | eye_to_eye = eye_right - eye_left 27 | mouth_left = lm_mouth_outer[0] 28 | mouth_right = lm_mouth_outer[6] 29 | mouth_avg = (mouth_left + mouth_right) * 0.5 30 | eye_to_mouth = mouth_avg - eye_avg 31 | 32 | # Choose oriented crop rectangle. 
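    # The crop is an oriented square: `x` blends the eye-to-eye axis with the
    # 90-degree-rotated eye-to-mouth axis, normalized and then scaled to cover the
    # face; `y` is `x` rotated by 90 degrees; `c` is the eye midpoint shifted toward
    # the mouth by `em_scale`; `quad` lists the four corners of the square.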
33 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 34 | x /= np.hypot(*x) 35 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 36 | x *= x_scale 37 | y = np.flipud(x) * [-y_scale, y_scale] 38 | c = eye_avg + eye_to_mouth * em_scale 39 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 40 | qsize = np.hypot(*x) * 2 41 | 42 | # Load in-the-wild image. 43 | if not os.path.isfile(src_file): 44 | print('\nCannot find source image. Please run "--wilds" before "--align".') 45 | return 46 | img = PIL.Image.open(src_file).convert('RGBA').convert('RGB') 47 | 48 | # Shrink. 49 | shrink = int(np.floor(qsize / output_size * 0.5)) 50 | if shrink > 1: 51 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 52 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 53 | quad /= shrink 54 | qsize /= shrink 55 | 56 | # Crop. 57 | border = max(int(np.rint(qsize * 0.1)), 3) 58 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 59 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 60 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 61 | img = img.crop(crop) 62 | quad -= crop[0:2] 63 | 64 | # Pad. 65 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 66 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 67 | if enable_padding and max(pad) > border - 4: 68 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 69 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 70 | h, w, _ = img.shape 71 | y, x, _ = np.ogrid[:h, :w, :1] 72 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 73 | blur = qsize * 0.02 74 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 75 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 76 | img = np.uint8(np.clip(np.rint(img), 0, 255)) 77 | if alpha: 78 | mask = 1-np.clip(3.0 * mask, 0.0, 1.0) 79 | mask = np.uint8(np.clip(np.rint(mask*255), 0, 255)) 80 | img = np.concatenate((img, mask), axis=2) 81 | img = PIL.Image.fromarray(img, 'RGBA') 82 | else: 83 | img = PIL.Image.fromarray(img, 'RGB') 84 | quad += pad[:2] 85 | 86 | # Transform. 87 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 88 | if output_size < transform_size: 89 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 90 | 91 | # Save aligned image. 
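    # PNG is used so that the alpha channel produced above (when alpha=True) is
    # preserved in the output file.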
92 | img.save(dst_file, 'PNG') 93 | -------------------------------------------------------------------------------- /ffhq_dataset/landmarks_detector.py: -------------------------------------------------------------------------------- 1 | import dlib 2 | 3 | 4 | class LandmarksDetector: 5 | def __init__(self, predictor_model_path): 6 | """ 7 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file 8 | """ 9 | self.detector = dlib.get_frontal_face_detector() # cnn_face_detection_model_v1 also can be used 10 | self.shape_predictor = dlib.shape_predictor(predictor_model_path) 11 | 12 | def get_landmarks(self, image): 13 | img = dlib.load_rgb_image(image) 14 | dets = self.detector(img, 1) 15 | 16 | for detection in dets: 17 | try: 18 | face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()] 19 | yield face_landmarks 20 | except: 21 | print("Exception in get_landmarks()!") 22 | -------------------------------------------------------------------------------- /grid_vid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: lzhbrian (https://lzhbrian.me) 3 | Date: 2020.1.20 4 | Note: mainly modified from: https://github.com/tkarras/progressive_growing_of_gans/blob/master/util_scripts.py#L50 5 | """ 6 | 7 | import numpy as np 8 | from PIL import Image 9 | import os 10 | import scipy 11 | import pickle 12 | import moviepy 13 | import dnnlib 14 | import dnnlib.tflib as tflib 15 | from tqdm import tqdm 16 | 17 | from pathlib import Path 18 | import typer 19 | 20 | 21 | 22 | def load_net(fpath): 23 | tflib.init_tf() 24 | with open(fpath, 'rb') as stream: 25 | _G, _D, Gs = pickle.load(stream, encoding='latin1') 26 | 27 | return Gs 28 | 29 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 30 | 31 | def create_image_grid(images, grid_size=None): 32 | assert images.ndim == 3 or images.ndim == 4 33 | num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2] 34 | 35 | if grid_size is not None: 36 | grid_w, grid_h = tuple(grid_size) 37 | else: 38 | grid_w = max(int(np.ceil(np.sqrt(num))), 1) 39 | grid_h = max((num - 1) // grid_w + 1, 1) 40 | 41 | grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype) 42 | for idx in range(num): 43 | x = (idx % grid_w) * img_w 44 | y = (idx // grid_w) * img_h 45 | grid[..., y : y + img_h, x : x + img_w] = images[idx] 46 | return grid 47 | 48 | # grid_size=[4,4], mp4_fps=25, duration_sec=10.0, smoothing_sec=2.0, truncation_psi=0.7) 49 | from typing import Tuple 50 | 51 | def generate_interpolation_video(net: Path, 52 | mp4: Path = Path("output.mp4"), 53 | truncation_psi:float =0.5, 54 | grid_size: Tuple[int, int]=(1,1), 55 | duration_sec:float =60.0, 56 | smoothing_sec:float =1.0, 57 | mp4_fps:int=30, 58 | mp4_codec='libx264', 59 | random_seed:int = 1000, 60 | minibatch_size:int = 8, 61 | output_width: int = typer.Option(None)): 62 | 63 | Gs = load_net(net) 64 | num_frames = int(np.rint(duration_sec * mp4_fps)) 65 | random_state = np.random.RandomState(random_seed) 66 | 67 | print('Generating latent vectors...') 68 | shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:] # [frame, image, channel, component] 69 | all_latents = random_state.randn(*shape).astype(np.float32) 70 | all_latents = scipy.ndimage.gaussian_filter(all_latents, [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape), mode='wrap') 71 | all_latents /= np.sqrt(np.mean(np.square(all_latents))) 72 | 73 | # Frame generation 
func for moviepy. 74 | def make_frame(t): 75 | frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1)) 76 | latents = all_latents[frame_idx] 77 | labels = np.zeros([latents.shape[0], 0], np.float32) 78 | images = Gs.run(latents, None, truncation_psi=truncation_psi, randomize_noise=False, output_transform=fmt, minibatch_size=minibatch_size) 79 | 80 | images = images.transpose(0, 3, 1, 2) #NHWC -> NCHW 81 | grid = create_image_grid(images, grid_size).transpose(1, 2, 0) # HWC 82 | if grid.shape[2] == 1: 83 | grid = grid.repeat(3, 2) # grayscale => RGB 84 | return grid 85 | 86 | # Generate video. 87 | import moviepy.editor # pip install moviepy 88 | c = moviepy.editor.VideoClip(make_frame, duration=duration_sec) 89 | if output_width: 90 | c = c.resize(width=output_width) 91 | c.write_videofile(str(mp4), fps=mp4_fps, codec=mp4_codec) 92 | return c 93 | 94 | if __name__ == "__main__": 95 | typer.run(generate_interpolation_video) -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash equilibrium".""" 11 | 12 | import os 13 | import pickle 14 | import numpy as np 15 | import scipy 16 | import tensorflow as tf 17 | import dnnlib 18 | import dnnlib.tflib as tflib 19 | 20 | from metrics import metric_base 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | class FID(metric_base.MetricBase): 25 | def __init__(self, max_reals, num_fakes, minibatch_per_gpu, use_cached_real_stats=True, **kwargs): 26 | super().__init__(**kwargs) 27 | self.max_reals = max_reals 28 | self.num_fakes = num_fakes 29 | self.minibatch_per_gpu = minibatch_per_gpu 30 | self.use_cached_real_stats = use_cached_real_stats 31 | 32 | def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ 33 | minibatch_size = num_gpus * self.minibatch_per_gpu 34 | with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_features.pkl') as f: # identical to http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 35 | feature_net = pickle.load(f) 36 | 37 | # Calculate statistics for reals. 
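        # The statistics are accumulated in a single streaming pass: `mu_real` sums
        # the Inception features and `sigma_real` their outer products, so that after
        # normalization sigma_real = E[f f^T] - mu mu^T is the feature covariance.
        # The result is cached on disk, keyed by the dataset arguments.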
38 | cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals) 39 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 40 | if self.use_cached_real_stats and os.path.isfile(cache_file): 41 | with open(cache_file, 'rb') as f: 42 | mu_real, sigma_real = pickle.load(f) 43 | else: 44 | nfeat = feature_net.output_shape[1] 45 | mu_real = np.zeros(nfeat) 46 | sigma_real = np.zeros([nfeat, nfeat]) 47 | num_real = 0 48 | for images, _labels, num in self._iterate_reals(minibatch_size): 49 | if self.max_reals is not None: 50 | num = min(num, self.max_reals - num_real) 51 | if images.shape[1] == 1: 52 | images = np.tile(images, [1, 3, 1, 1]) 53 | for feat in list(feature_net.run(images, num_gpus=num_gpus, assume_frozen=True))[:num]: 54 | mu_real += feat 55 | sigma_real += np.outer(feat, feat) 56 | num_real += 1 57 | if self.max_reals is not None and num_real >= self.max_reals: 58 | break 59 | mu_real /= num_real 60 | sigma_real /= num_real 61 | sigma_real -= np.outer(mu_real, mu_real) 62 | with open(cache_file, 'wb') as f: 63 | pickle.dump((mu_real, sigma_real), f) 64 | 65 | # Construct TensorFlow graph. 66 | result_expr = [] 67 | for gpu_idx in range(num_gpus): 68 | with tf.device('/gpu:%d' % gpu_idx): 69 | Gs_clone = Gs.clone() 70 | feature_net_clone = feature_net.clone() 71 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 72 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 73 | images = Gs_clone.get_output_for(latents, labels, **G_kwargs) 74 | if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1]) 75 | images = tflib.convert_images_to_uint8(images) 76 | result_expr.append(feature_net_clone.get_output_for(images)) 77 | 78 | # Calculate statistics for fakes. 79 | feat_fake = [] 80 | for begin in range(0, self.num_fakes, minibatch_size): 81 | self._report_progress(begin, self.num_fakes) 82 | feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0)) 83 | feat_fake = np.stack(feat_fake[:self.num_fakes]) 84 | mu_fake = np.mean(feat_fake, axis=0) 85 | sigma_fake = np.cov(feat_fake, rowvar=False) 86 | 87 | # Calculate FID. 88 | m = np.square(mu_fake - mu_real).sum() 89 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member 90 | dist = m + np.trace(sigma_fake + sigma_real - 2*s) 91 | self._report_result(np.real(dist)) 92 | 93 | #---------------------------------------------------------------------------- 94 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Inception Score (IS) from the paper 10 | "Improved techniques for training GANs".""" 11 | 12 | import pickle 13 | import numpy as np 14 | import tensorflow as tf 15 | import dnnlib 16 | import dnnlib.tflib as tflib 17 | 18 | from metrics import metric_base 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | class IS(metric_base.MetricBase): 23 | def __init__(self, num_images, num_splits, minibatch_per_gpu, **kwargs): 24 | super().__init__(**kwargs) 25 | self.num_images = num_images 26 | self.num_splits = num_splits 27 | self.minibatch_per_gpu = minibatch_per_gpu 28 | 29 | def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ 30 | minibatch_size = num_gpus * self.minibatch_per_gpu 31 | with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_softmax.pkl') as f: 32 | inception = pickle.load(f) 33 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 34 | 35 | # Construct TensorFlow graph. 36 | result_expr = [] 37 | for gpu_idx in range(num_gpus): 38 | with tf.device(f'/gpu:{gpu_idx}'): 39 | Gs_clone = Gs.clone() 40 | inception_clone = inception.clone() 41 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 42 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 43 | images = Gs_clone.get_output_for(latents, labels, **G_kwargs) 44 | if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1]) 45 | images = tflib.convert_images_to_uint8(images) 46 | result_expr.append(inception_clone.get_output_for(images)) 47 | 48 | # Calculate activations for fakes. 49 | for begin in range(0, self.num_images, minibatch_size): 50 | self._report_progress(begin, self.num_images) 51 | end = min(begin + minibatch_size, self.num_images) 52 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] 53 | 54 | # Calculate IS. 55 | scores = [] 56 | for i in range(self.num_splits): 57 | part = activations[i * self.num_images // self.num_splits : (i + 1) * self.num_images // self.num_splits] 58 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 59 | kl = np.mean(np.sum(kl, 1)) 60 | scores.append(np.exp(kl)) 61 | self._report_result(np.mean(scores), suffix='_mean') 62 | self._report_result(np.std(scores), suffix='_std') 63 | 64 | #---------------------------------------------------------------------------- 65 | -------------------------------------------------------------------------------- /metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Kernel Inception Distance (KID) from the paper 10 | "Demystifying MMD GANs".""" 11 | 12 | import os 13 | import pickle 14 | import numpy as np 15 | import tensorflow as tf 16 | import dnnlib 17 | import dnnlib.tflib as tflib 18 | 19 | from metrics import metric_base 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def compute_kid(feat_real, feat_fake, num_subsets=100, max_subset_size=1000): 24 | n = feat_real.shape[1] 25 | m = min(min(feat_real.shape[0], feat_fake.shape[0]), max_subset_size) 26 | t = 0 27 | for _subset_idx in range(num_subsets): 28 | x = feat_fake[np.random.choice(feat_fake.shape[0], m, replace=False)] 29 | y = feat_real[np.random.choice(feat_real.shape[0], m, replace=False)] 30 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 31 | b = (x @ y.T / n + 1) ** 3 32 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 33 | return t / num_subsets / m 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | class KID(metric_base.MetricBase): 38 | def __init__(self, max_reals, num_fakes, minibatch_per_gpu, use_cached_real_stats=True, **kwargs): 39 | super().__init__(**kwargs) 40 | self.max_reals = max_reals 41 | self.num_fakes = num_fakes 42 | self.minibatch_per_gpu = minibatch_per_gpu 43 | self.use_cached_real_stats = use_cached_real_stats 44 | 45 | def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ 46 | minibatch_size = num_gpus * self.minibatch_per_gpu 47 | with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/inception_v3_features.pkl') as f: # identical to http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 48 | feature_net = pickle.load(f) 49 | 50 | # Calculate statistics for reals. 51 | cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals) 52 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 53 | if self.use_cached_real_stats and os.path.isfile(cache_file): 54 | with open(cache_file, 'rb') as f: 55 | feat_real = pickle.load(f) 56 | else: 57 | feat_real = [] 58 | for images, _labels, num in self._iterate_reals(minibatch_size): 59 | if self.max_reals is not None: 60 | num = min(num, self.max_reals - len(feat_real)) 61 | if images.shape[1] == 1: 62 | images = np.tile(images, [1, 3, 1, 1]) 63 | feat_real += list(feature_net.run(images, num_gpus=num_gpus, assume_frozen=True))[:num] 64 | if self.max_reals is not None and len(feat_real) >= self.max_reals: 65 | break 66 | feat_real = np.stack(feat_real) 67 | with open(cache_file, 'wb') as f: 68 | pickle.dump(feat_real, f) 69 | 70 | # Construct TensorFlow graph. 71 | result_expr = [] 72 | for gpu_idx in range(num_gpus): 73 | with tf.device('/gpu:%d' % gpu_idx): 74 | Gs_clone = Gs.clone() 75 | feature_net_clone = feature_net.clone() 76 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 77 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 78 | images = Gs_clone.get_output_for(latents, labels, **G_kwargs) 79 | if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1]) 80 | images = tflib.convert_images_to_uint8(images) 81 | result_expr.append(feature_net_clone.get_output_for(images)) 82 | 83 | # Calculate statistics for fakes. 
84 | feat_fake = [] 85 | for begin in range(0, self.num_fakes, minibatch_size): 86 | self._report_progress(begin, self.num_fakes) 87 | feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0)) 88 | feat_fake = np.stack(feat_fake[:self.num_fakes]) 89 | 90 | # Calculate KID. 91 | kid = compute_kid(feat_real, feat_fake) 92 | self._report_result(np.real(kid), fmt='%-12.8f') 93 | 94 | #---------------------------------------------------------------------------- 95 | -------------------------------------------------------------------------------- /metrics/linear_separability.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Linear Separability (LS) from the paper 10 | "A Style-Based Generator Architecture for Generative Adversarial Networks".""" 11 | 12 | import pickle 13 | from collections import defaultdict 14 | import numpy as np 15 | import sklearn.svm 16 | import tensorflow as tf 17 | import dnnlib 18 | import dnnlib.tflib as tflib 19 | 20 | from metrics import metric_base 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | classifier_urls = [ 25 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-00-male.pkl', 26 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-01-smiling.pkl', 27 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-02-attractive.pkl', 28 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-03-wavy-hair.pkl', 29 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-04-young.pkl', 30 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-05-5-o-clock-shadow.pkl', 31 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-06-arched-eyebrows.pkl', 32 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-07-bags-under-eyes.pkl', 33 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-08-bald.pkl', 34 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-09-bangs.pkl', 35 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-10-big-lips.pkl', 36 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-11-big-nose.pkl', 37 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-12-black-hair.pkl', 38 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-13-blond-hair.pkl', 39 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-14-blurry.pkl', 40 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-15-brown-hair.pkl', 41 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-16-bushy-eyebrows.pkl', 42 | 
'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-17-chubby.pkl', 43 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-18-double-chin.pkl', 44 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-19-eyeglasses.pkl', 45 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-20-goatee.pkl', 46 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-21-gray-hair.pkl', 47 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-22-heavy-makeup.pkl', 48 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-23-high-cheekbones.pkl', 49 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-24-mouth-slightly-open.pkl', 50 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-25-mustache.pkl', 51 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-26-narrow-eyes.pkl', 52 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-27-no-beard.pkl', 53 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-28-oval-face.pkl', 54 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-29-pale-skin.pkl', 55 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-30-pointy-nose.pkl', 56 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-31-receding-hairline.pkl', 57 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-32-rosy-cheeks.pkl', 58 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-33-sideburns.pkl', 59 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-34-straight-hair.pkl', 60 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-35-wearing-earrings.pkl', 61 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-36-wearing-hat.pkl', 62 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-37-wearing-lipstick.pkl', 63 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-38-wearing-necklace.pkl', 64 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/celebahq-classifier-39-wearing-necktie.pkl', 65 | ] 66 | 67 | #---------------------------------------------------------------------------- 68 | 69 | def prob_normalize(p): 70 | p = np.asarray(p).astype(np.float32) 71 | assert len(p.shape) == 2 72 | return p / np.sum(p) 73 | 74 | def mutual_information(p): 75 | p = prob_normalize(p) 76 | px = np.sum(p, axis=1) 77 | py = np.sum(p, axis=0) 78 | result = 0.0 79 | for x in range(p.shape[0]): 80 | p_x = px[x] 81 | for y in range(p.shape[1]): 82 | p_xy = p[x][y] 83 | p_y = py[y] 84 | if p_xy > 0.0: 85 | result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output 86 | return result 87 | 88 | def entropy(p): 89 | p = prob_normalize(p) 90 | result = 0.0 91 | for x in range(p.shape[0]): 92 | for y in range(p.shape[1]): 93 | p_xy = p[x][y] 94 | if p_xy > 0.0: 95 | result -= p_xy * np.log2(p_xy) 96 | return result 97 | 98 | def conditional_entropy(p): 99 | # H(Y|X) where X corresponds to axis 0, Y to axis 1 100 | # i.e., How many 
bits of additional information are needed to where we are on axis 1 if we know where we are on axis 0? 101 | p = prob_normalize(p) 102 | y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y) 103 | return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up. 104 | 105 | #---------------------------------------------------------------------------- 106 | 107 | class LS(metric_base.MetricBase): 108 | def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs): 109 | assert num_keep <= num_samples 110 | super().__init__(**kwargs) 111 | self.num_samples = num_samples 112 | self.num_keep = num_keep 113 | self.attrib_indices = attrib_indices 114 | self.minibatch_per_gpu = minibatch_per_gpu 115 | 116 | def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ 117 | minibatch_size = num_gpus * self.minibatch_per_gpu 118 | 119 | # Construct TensorFlow graph for each GPU. 120 | result_expr = [] 121 | for gpu_idx in range(num_gpus): 122 | with tf.device(f'/gpu:{gpu_idx}'): 123 | Gs_clone = Gs.clone() 124 | 125 | # Generate images. 126 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 127 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 128 | dlatents = Gs_clone.components.mapping.get_output_for(latents, labels, **G_kwargs) 129 | images = Gs_clone.get_output_for(latents, None, **G_kwargs) 130 | if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1]) 131 | 132 | # Downsample to 256x256. The attribute classifiers were built for 256x256. 133 | if images.shape[2] > 256: 134 | factor = images.shape[2] // 256 135 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 136 | images = tf.reduce_mean(images, axis=[3, 5]) 137 | 138 | # Run classifier for each attribute. 139 | result_dict = dict(latents=latents, dlatents=dlatents[:,-1]) 140 | for attrib_idx in self.attrib_indices: 141 | with dnnlib.util.open_url(classifier_urls[attrib_idx]) as f: 142 | classifier = pickle.load(f) 143 | logits = classifier.get_output_for(images, None) 144 | predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1)) 145 | result_dict[attrib_idx] = predictions 146 | result_expr.append(result_dict) 147 | 148 | # Sampling loop. 149 | results = [] 150 | for begin in range(0, self.num_samples, minibatch_size): 151 | self._report_progress(begin, self.num_samples) 152 | results += tflib.run(result_expr) 153 | results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()} 154 | 155 | # Calculate conditional entropy for each attribute. 156 | conditional_entropies = defaultdict(list) 157 | for attrib_idx in self.attrib_indices: 158 | # Prune the least confident samples. 159 | pruned_indices = list(range(self.num_samples)) 160 | pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i])) 161 | pruned_indices = pruned_indices[:self.num_keep] 162 | 163 | # Fit SVM to the remaining samples. 
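            # A linear SVM is fit in both Z (latents) and W (dlatents) space to predict
            # the attribute classifier's labels; the 2x2 joint distribution of (SVM
            # prediction, classifier label) yields the conditional entropy H(label | SVM),
            # and the final score per space is 2 ** (sum over attributes).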
164 | svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1) 165 | for space in ['latents', 'dlatents']: 166 | svm_inputs = results[space][pruned_indices] 167 | try: 168 | svm = sklearn.svm.LinearSVC() 169 | svm.fit(svm_inputs, svm_targets) 170 | svm.score(svm_inputs, svm_targets) 171 | svm_outputs = svm.predict(svm_inputs) 172 | except: 173 | svm_outputs = svm_targets # assume perfect prediction 174 | 175 | # Calculate conditional entropy. 176 | p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)] 177 | conditional_entropies[space].append(conditional_entropy(p)) 178 | 179 | # Calculate separability scores. 180 | scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()} 181 | self._report_result(scores['latents'], suffix='_z') 182 | self._report_result(scores['dlatents'], suffix='_w') 183 | 184 | #---------------------------------------------------------------------------- 185 | -------------------------------------------------------------------------------- /metrics/metric_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Common definitions for quality metrics.""" 10 | 11 | import os 12 | import time 13 | import hashlib 14 | import pickle 15 | import numpy as np 16 | import tensorflow as tf 17 | import dnnlib 18 | import dnnlib.tflib as tflib 19 | 20 | from training import dataset 21 | 22 | #---------------------------------------------------------------------------- 23 | # Base class for metrics. 24 | 25 | class MetricBase: 26 | def __init__(self, name, force_dataset_args={}, force_G_kwargs={}): 27 | # Constructor args. 28 | self.name = name 29 | self.force_dataset_args = force_dataset_args 30 | self.force_G_kwargs = force_G_kwargs 31 | 32 | # Configuration. 33 | self._dataset_args = dnnlib.EasyDict() 34 | self._run_dir = None 35 | self._progress_fn = None 36 | 37 | # Internal state. 
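        # `_results` collects EasyDicts appended by _report_result(); `_dataset` is
        # opened lazily by _get_dataset_obj() and closed again at the end of run().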
38 | self._results = [] 39 | self._network_name = '' 40 | self._eval_time = 0 41 | self._dataset = None 42 | 43 | def configure(self, dataset_args={}, run_dir=None, progress_fn=None): 44 | self._dataset_args = dnnlib.EasyDict(dataset_args) 45 | self._dataset_args.update(self.force_dataset_args) 46 | self._run_dir = run_dir 47 | self._progress_fn = progress_fn 48 | 49 | def run(self, network_pkl, num_gpus=1, G_kwargs=dict(is_validation=True)): 50 | self._results = [] 51 | self._network_name = os.path.splitext(os.path.basename(network_pkl))[0] 52 | self._eval_time = 0 53 | self._dataset = None 54 | 55 | with tf.Graph().as_default(), tflib.create_session().as_default(): # pylint: disable=not-context-manager 56 | self._report_progress(0, 1) 57 | time_begin = time.time() 58 | with dnnlib.util.open_url(network_pkl) as f: 59 | G, D, Gs = pickle.load(f) 60 | 61 | G_kwargs = dnnlib.EasyDict(G_kwargs) 62 | G_kwargs.update(self.force_G_kwargs) 63 | self._evaluate(G=G, D=D, Gs=Gs, G_kwargs=G_kwargs, num_gpus=num_gpus) 64 | 65 | self._eval_time = time.time() - time_begin # pylint: disable=attribute-defined-outside-init 66 | self._report_progress(1, 1) 67 | if self._dataset is not None: 68 | self._dataset.close() 69 | self._dataset = None 70 | 71 | result_str = self.get_result_str() 72 | print(result_str) 73 | if self._run_dir is not None and os.path.isdir(self._run_dir): 74 | with open(os.path.join(self._run_dir, f'metric-{self.name}.txt'), 'at') as f: 75 | f.write(result_str + '\n') 76 | 77 | def get_result_str(self): 78 | title = self._network_name 79 | if len(title) > 29: 80 | title = '...' + title[-26:] 81 | result_str = f'{title:<30s} time {dnnlib.util.format_time(self._eval_time):<12s}' 82 | for res in self._results: 83 | result_str += f' {self.name}{res.suffix} {res.fmt % res.value}' 84 | return result_str.strip() 85 | 86 | def update_autosummaries(self): 87 | for res in self._results: 88 | tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value) 89 | 90 | def _evaluate(self, **_kwargs): 91 | raise NotImplementedError # to be overridden by subclasses 92 | 93 | def _report_result(self, value, suffix='', fmt='%-10.4f'): 94 | self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)] 95 | 96 | def _report_progress(self, cur, total): 97 | if self._progress_fn is not None: 98 | self._progress_fn(cur, total) 99 | 100 | def _get_cache_file_for_reals(self, extension='pkl', **kwargs): 101 | all_args = dnnlib.EasyDict(metric_name=self.name) 102 | all_args.update(self._dataset_args) 103 | all_args.update(kwargs) 104 | md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')) 105 | dataset_name = os.path.splitext(os.path.basename(self._dataset_args.path))[0] 106 | return dnnlib.make_cache_dir_path('metrics', f'{md5.hexdigest()}-{self.name}-{dataset_name}.{extension}') 107 | 108 | def _get_dataset_obj(self): 109 | if self._dataset is None: 110 | self._dataset = dataset.load_dataset(**self._dataset_args) 111 | return self._dataset 112 | 113 | def _iterate_reals(self, minibatch_size): 114 | print(f'Calculating real image statistics for {self.name}...') 115 | dataset_obj = self._get_dataset_obj() 116 | while True: 117 | images = [] 118 | labels = [] 119 | for _ in range(minibatch_size): 120 | image, label = dataset_obj.get_minibatch_np(1) 121 | if image is None: 122 | break 123 | images.append(image) 124 | labels.append(label) 125 | num = len(images) 126 | if num == 0: 127 | break 128 | images = np.concatenate(images + [images[-1]] * (minibatch_size - num), axis=0) 129 
| labels = np.concatenate(labels + [labels[-1]] * (minibatch_size - num), axis=0) 130 | yield images, labels, num 131 | if num < minibatch_size: 132 | break 133 | 134 | def _get_random_labels_tf(self, minibatch_size): 135 | return self._get_dataset_obj().get_random_labels_tf(minibatch_size) 136 | 137 | #---------------------------------------------------------------------------- 138 | -------------------------------------------------------------------------------- /metrics/metric_defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Default metric definitions.""" 10 | 11 | from dnnlib import EasyDict 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | metric_defaults = EasyDict([(args.name, args) for args in [ 16 | # ADA paper. 17 | EasyDict(name='fid50k_full', class_name='metrics.frechet_inception_distance.FID', max_reals=None, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)), 18 | EasyDict(name='kid50k_full', class_name='metrics.kernel_inception_distance.KID', max_reals=1000000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)), 19 | EasyDict(name='pr50k3_full', class_name='metrics.precision_recall.PR', max_reals=200000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None, repeat=False, mirror_augment=False)), 20 | EasyDict(name='is50k', class_name='metrics.inception_score.IS', num_images=50000, num_splits=10, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)), 21 | 22 | # Legacy: StyleGAN2. 23 | EasyDict(name='fid50k', class_name='metrics.frechet_inception_distance.FID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)), 24 | EasyDict(name='kid50k', class_name='metrics.kernel_inception_distance.KID', max_reals=50000, num_fakes=50000, minibatch_per_gpu=8, force_dataset_args=dict(shuffle=False, max_images=None)), 25 | EasyDict(name='pr50k3', class_name='metrics.precision_recall.PR', max_reals=50000, num_fakes=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000, force_dataset_args=dict(shuffle=False, max_images=None)), 26 | EasyDict(name='ppl2_wend', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)), 27 | 28 | # Legacy: StyleGAN. 
29 | EasyDict(name='ppl_zfull', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)), 30 | EasyDict(name='ppl_wfull', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)), 31 | EasyDict(name='ppl_zend', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)), 32 | EasyDict(name='ppl_wend', class_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, minibatch_per_gpu=2, force_dataset_args=dict(shuffle=False, max_images=None), force_G_kwargs=dict(dtype='float32', mapping_dtype='float32', num_fp16_res=0)), 33 | EasyDict(name='ls', class_name='metrics.linear_separability.LS', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4, force_dataset_args=dict(shuffle=False, max_images=None)), 34 | ]]) 35 | 36 | #---------------------------------------------------------------------------- 37 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper 10 | "A Style-Based Generator Architecture for Generative Adversarial Networks".""" 11 | 12 | import pickle 13 | import numpy as np 14 | import tensorflow as tf 15 | import dnnlib 16 | import dnnlib.tflib as tflib 17 | 18 | from metrics import metric_base 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | # Normalize batch of vectors. 23 | def normalize(v): 24 | return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)) 25 | 26 | # Spherical interpolation of a batch of vectors. 
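#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources): the
# slerp() defined next interpolates along a great circle of the unit
# sphere. The same formula restated in NumPy, with an endpoint sanity
# check (np_slerp is an illustrative name, not part of this repo):
import numpy as np

def np_slerp(a, b, t):
    a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
    d = np.clip(np.dot(a, b), -1.0, 1.0)   # cos of the angle between a and b
    p = t * np.arccos(d)                   # angle travelled at time t
    c = b - d * a                          # component of b orthogonal to a
    c = c / np.linalg.norm(c)
    return a * np.cos(p) + c * np.sin(p)

a, b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
assert np.allclose(np_slerp(a, b, 0.0), a)
assert np.allclose(np_slerp(a, b, 1.0), b)
#----------------------------------------------------------------------------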
27 | def slerp(a, b, t): 28 | a = normalize(a) 29 | b = normalize(b) 30 | d = tf.reduce_sum(a * b, axis=-1, keepdims=True) 31 | p = t * tf.math.acos(d) 32 | c = normalize(b - d * a) 33 | d = a * tf.math.cos(p) + c * tf.math.sin(p) 34 | return normalize(d) 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class PPL(metric_base.MetricBase): 39 | def __init__(self, num_samples, epsilon, space, sampling, crop, minibatch_per_gpu, **kwargs): 40 | assert space in ['z', 'w'] 41 | assert sampling in ['full', 'end'] 42 | super().__init__(**kwargs) 43 | self.num_samples = num_samples 44 | self.epsilon = epsilon 45 | self.space = space 46 | self.sampling = sampling 47 | self.crop = crop 48 | self.minibatch_per_gpu = minibatch_per_gpu 49 | 50 | def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ 51 | minibatch_size = num_gpus * self.minibatch_per_gpu 52 | 53 | # Construct TensorFlow graph. 54 | distance_expr = [] 55 | for gpu_idx in range(num_gpus): 56 | with tf.device(f'/gpu:{gpu_idx}'): 57 | Gs_clone = Gs.clone() 58 | noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')] 59 | 60 | # Generate random latents and interpolation t-values. 61 | lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:]) 62 | lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0) 63 | labels = tf.reshape(tf.tile(self._get_random_labels_tf(self.minibatch_per_gpu), [1, 2]), [self.minibatch_per_gpu * 2, -1]) 64 | 65 | # Interpolate in W or Z. 66 | if self.space == 'w': 67 | dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, labels, **G_kwargs) 68 | dlat_t01 = tf.cast(dlat_t01, tf.float32) 69 | dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2] 70 | dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis]) 71 | dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon) 72 | dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape) 73 | else: # space == 'z' 74 | lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2] 75 | lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis]) 76 | lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon) 77 | lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape) 78 | dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, labels, **G_kwargs) 79 | 80 | # Synthesize images. 81 | with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch 82 | images = Gs_clone.components.synthesis.get_output_for(dlat_e01, randomize_noise=False, **G_kwargs) 83 | images = tf.cast(images, tf.float32) 84 | 85 | # Crop only the face region. 86 | if self.crop: 87 | c = int(images.shape[2] // 8) 88 | images = images[:, :, c*3 : c*7, c*2 : c*6] 89 | 90 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 91 | factor = images.shape[2] // 256 92 | if factor > 1: 93 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 94 | images = tf.reduce_mean(images, axis=[3,5]) 95 | 96 | # Scale dynamic range from [-1,1] to [0,255] for VGG. 97 | images = (images + 1) * (255 / 2) 98 | if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1]) 99 | 100 | # Evaluate perceptual distance. 
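#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources): the
# lines that follow divide the perceptual distance by epsilon**2, turning a
# finite difference into a squared local path length. The same estimator in
# a self-contained toy, with plain squared L2 standing in for the LPIPS
# measure (toy_generator and sq_dist are illustrative stand-ins):
import numpy as np

def toy_generator(z):                       # stand-in for Gs: R^2 -> R^16
    return np.concatenate([np.sin(3 * z), np.cos(2 * z)] * 4)

def sq_dist(a, b):                          # stand-in for the LPIPS measure
    return float(np.sum((a - b) ** 2))

eps = 1e-4
rng = np.random.RandomState(0)
z = rng.randn(2)
d = rng.randn(2); d /= np.linalg.norm(d)    # unit step direction
ppl_sample = sq_dist(toy_generator(z), toy_generator(z + eps * d)) / eps**2
# Averaging such samples over many (z, d) pairs -- with the 1st/99th
# percentile outlier rejection used below -- gives the reported PPL.
#----------------------------------------------------------------------------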
101 | img_e0, img_e1 = images[0::2], images[1::2] 102 | with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl') as f: 103 | distance_measure = pickle.load(f) 104 | distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2)) 105 | 106 | # Sampling loop. 107 | all_distances = [] 108 | for begin in range(0, self.num_samples, minibatch_size): 109 | self._report_progress(begin, self.num_samples) 110 | all_distances += tflib.run(distance_expr) 111 | all_distances = np.concatenate(all_distances, axis=0) 112 | 113 | # Reject outliers. 114 | lo = np.percentile(all_distances, 1, interpolation='lower') 115 | hi = np.percentile(all_distances, 99, interpolation='higher') 116 | filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances) 117 | self._report_result(np.mean(filtered_distances)) 118 | 119 | #---------------------------------------------------------------------------- 120 | -------------------------------------------------------------------------------- /metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper 10 | "Improved Precision and Recall Metric for Assessing Generative Models".""" 11 | 12 | import os 13 | import pickle 14 | import numpy as np 15 | import tensorflow as tf 16 | import dnnlib 17 | import dnnlib.tflib as tflib 18 | 19 | from metrics import metric_base 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def batch_pairwise_distances(U, V): 24 | """ Compute pairwise distances between two batches of feature vectors.""" 25 | with tf.variable_scope('pairwise_dist_block'): 26 | # Squared norms of each row in U and V. 27 | norm_u = tf.reduce_sum(tf.square(U), 1) 28 | norm_v = tf.reduce_sum(tf.square(V), 1) 29 | 30 | # norm_u as a row and norm_v as a column vectors. 31 | norm_u = tf.reshape(norm_u, [-1, 1]) 32 | norm_v = tf.reshape(norm_v, [1, -1]) 33 | 34 | # Pairwise squared Euclidean distances. 35 | D = tf.maximum(norm_u - 2*tf.matmul(U, V, False, True) + norm_v, 0.0) 36 | 37 | return D 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | class DistanceBlock(): 42 | """Distance block.""" 43 | def __init__(self, num_features, num_gpus): 44 | self.num_features = num_features 45 | self.num_gpus = num_gpus 46 | 47 | # Initialize TF graph to calculate pairwise distances. 
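#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources):
# batch_pairwise_distances() above relies on the expansion
# ||u - v||^2 = ||u||^2 - 2*u.v + ||v||^2. A quick NumPy sanity check of
# that identity (shapes chosen arbitrarily):
import numpy as np

rng = np.random.RandomState(1)
U, V = rng.randn(4, 8), rng.randn(5, 8)
D_expanded = (U**2).sum(1)[:, None] - 2 * U @ V.T + (V**2).sum(1)[None, :]
D_direct = ((U[:, None, :] - V[None, :, :]) ** 2).sum(-1)
assert np.allclose(D_expanded, D_direct)
# The tf.maximum(..., 0.0) in the graph version guards against small
# negative values caused by floating-point round-off.
#----------------------------------------------------------------------------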
48 | with tf.device('/cpu:0'): 49 | self._features_batch1 = tf.placeholder(tf.float16, shape=[None, self.num_features]) 50 | self._features_batch2 = tf.placeholder(tf.float16, shape=[None, self.num_features]) 51 | features_split2 = tf.split(self._features_batch2, self.num_gpus, axis=0) 52 | distances_split = [] 53 | for gpu_idx in range(self.num_gpus): 54 | with tf.device(f'/gpu:{gpu_idx}'): 55 | distances_split.append(batch_pairwise_distances(self._features_batch1, features_split2[gpu_idx])) 56 | self._distance_block = tf.concat(distances_split, axis=1) 57 | 58 | def pairwise_distances(self, U, V): 59 | """Evaluate pairwise distances between two batches of feature vectors.""" 60 | return self._distance_block.eval(feed_dict={self._features_batch1: U, self._features_batch2: V}) 61 | 62 | #---------------------------------------------------------------------------- 63 | 64 | class ManifoldEstimator(): 65 | """Finds an estimate for the manifold of given feature vectors.""" 66 | def __init__(self, distance_block, features, row_batch_size, col_batch_size, nhood_sizes, clamp_to_percentile=None): 67 | """Find an estimate of the manifold of given feature vectors.""" 68 | num_images = features.shape[0] 69 | self.nhood_sizes = nhood_sizes 70 | self.num_nhoods = len(nhood_sizes) 71 | self.row_batch_size = row_batch_size 72 | self.col_batch_size = col_batch_size 73 | self._ref_features = features 74 | self._distance_block = distance_block 75 | 76 | # Estimate manifold of features by calculating distances to kth nearest neighbor of each sample. 77 | self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float16) 78 | distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float16) 79 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) 80 | 81 | for begin1 in range(0, num_images, row_batch_size): 82 | end1 = min(begin1 + row_batch_size, num_images) 83 | row_batch = features[begin1:end1] 84 | 85 | for begin2 in range(0, num_images, col_batch_size): 86 | end2 = min(begin2 + col_batch_size, num_images) 87 | col_batch = features[begin2:end2] 88 | 89 | # Compute distances between batches. 90 | distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch, col_batch) 91 | 92 | # Find the kth nearest neighbor from the current batch. 
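#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources): the
# next statement uses np.partition to read off each sample's k-NN radius,
# i.e. the distance to its k-th nearest neighbour, which becomes the radius
# of that sample's hypersphere. In miniature (values are made up; the 0.0
# entry is the sample's distance to itself):
import numpy as np

dists = np.array([[0.0, 0.9, 0.1, 0.4, 0.7, 0.2]])
k = 3                                            # nhood_size = 3
radius = np.partition(dists, k, axis=1)[:, k]    # -> [0.4]
# Sorted row: [0.0, 0.1, 0.2, 0.4, 0.7, 0.9]; index k=3 skips the self-
# distance at index 0, so 0.4 is the distance to the 3rd real neighbour.
#----------------------------------------------------------------------------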
93 | self.D[begin1:end1, :] = np.partition(distance_batch[0:end1-begin1, :], seq, axis=1)[:, self.nhood_sizes] 94 | 95 | if clamp_to_percentile is not None: 96 | max_distances = np.percentile(self.D, clamp_to_percentile, axis=0) 97 | self.D[self.D > max_distances] = 0 #max_distances # 0 98 | 99 | def evaluate(self, eval_features, return_realism=False, return_neighbors=False): 100 | """Evaluate if new feature vectors are in the estimated manifold.""" 101 | num_eval_images = eval_features.shape[0] 102 | num_ref_images = self.D.shape[0] 103 | distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float16) 104 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) 105 | #max_realism_score = np.zeros([num_eval_images,], dtype=np.float32) 106 | realism_score = np.zeros([num_eval_images,], dtype=np.float32) 107 | nearest_indices = np.zeros([num_eval_images,], dtype=np.int32) 108 | 109 | for begin1 in range(0, num_eval_images, self.row_batch_size): 110 | end1 = min(begin1 + self.row_batch_size, num_eval_images) 111 | feature_batch = eval_features[begin1:end1] 112 | 113 | for begin2 in range(0, num_ref_images, self.col_batch_size): 114 | end2 = min(begin2 + self.col_batch_size, num_ref_images) 115 | ref_batch = self._ref_features[begin2:end2] 116 | 117 | distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(feature_batch, ref_batch) 118 | 119 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold. 120 | # If a feature vector is inside a hypersphere of some reference sample, then the new sample lies on the estimated manifold. 121 | # The radii of the hyperspheres are determined from distances of neighborhood size k. 122 | samples_in_manifold = distance_batch[0:end1-begin1, :, None] <= self.D 123 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32) 124 | 125 | #max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1) 126 | #nearest_indices[begin1:end1] = np.argmax(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1) 127 | nearest_indices[begin1:end1] = np.argmin(distance_batch[0:end1-begin1, :], axis=1) 128 | realism_score[begin1:end1] = self.D[nearest_indices[begin1:end1], 0] / np.min(distance_batch[0:end1-begin1, :], axis=1) 129 | 130 | if return_realism and return_neighbors: 131 | return batch_predictions, realism_score, nearest_indices 132 | elif return_realism: 133 | return batch_predictions, realism_score 134 | elif return_neighbors: 135 | return batch_predictions, nearest_indices 136 | 137 | return batch_predictions 138 | 139 | #---------------------------------------------------------------------------- 140 | 141 | def knn_precision_recall_features(ref_features, eval_features, feature_net, nhood_sizes, 142 | row_batch_size, col_batch_size, num_gpus): 143 | """Calculates k-NN precision and recall for two sets of feature vectors.""" 144 | state = dnnlib.EasyDict() 145 | #num_images = ref_features.shape[0] 146 | num_features = feature_net.output_shape[1] 147 | state.ref_features = ref_features 148 | state.eval_features = eval_features 149 | 150 | # Initialize DistanceBlock and ManifoldEstimators. 
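#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources): the two
# manifold estimators constructed next drive the membership test behind
# precision and recall: an eval feature lies on the estimated manifold if it
# falls inside the k-NN hypersphere of at least one reference feature. The
# test in miniature (values are made up):
import numpy as np

radii = np.array([0.5, 0.2, 0.8])     # k-NN radii of three reference points
dists = np.array([0.6, 0.3, 0.7])     # distances from one eval point to them
in_manifold = np.any(dists <= radii)  # True: covered by the third sphere
# Precision averages this test over fakes against the real manifold; recall
# averages it over reals against the fake manifold.
#----------------------------------------------------------------------------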
151 | distance_block = DistanceBlock(num_features, num_gpus) 152 | state.ref_manifold = ManifoldEstimator(distance_block, state.ref_features, row_batch_size, col_batch_size, nhood_sizes) 153 | state.eval_manifold = ManifoldEstimator(distance_block, state.eval_features, row_batch_size, col_batch_size, nhood_sizes) 154 | 155 | # Evaluate precision and recall using k-nearest neighbors. 156 | #print(f'Evaluating k-NN precision and recall with {num_images} samples...') 157 | #start = time.time() 158 | 159 | # Precision: How many points from eval_features are in ref_features manifold. 160 | state.precision, state.realism_scores, state.nearest_neighbors = state.ref_manifold.evaluate(state.eval_features, return_realism=True, return_neighbors=True) 161 | state.knn_precision = state.precision.mean(axis=0) 162 | 163 | # Recall: How many points from ref_features are in eval_features manifold. 164 | state.recall = state.eval_manifold.evaluate(state.ref_features) 165 | state.knn_recall = state.recall.mean(axis=0) 166 | 167 | #elapsed_time = time.time() - start 168 | #print(f'Done evaluation in: {elapsed_time:g}s') 169 | 170 | return state 171 | 172 | #---------------------------------------------------------------------------- 173 | 174 | class PR(metric_base.MetricBase): 175 | def __init__(self, max_reals, num_fakes, nhood_size, minibatch_per_gpu, row_batch_size, col_batch_size, **kwargs): 176 | super().__init__(**kwargs) 177 | self.max_reals = max_reals 178 | self.num_fakes = num_fakes 179 | self.nhood_size = nhood_size 180 | self.minibatch_per_gpu = minibatch_per_gpu 181 | self.row_batch_size = row_batch_size 182 | self.col_batch_size = col_batch_size 183 | 184 | def _evaluate(self, Gs, G_kwargs, num_gpus, **_kwargs): # pylint: disable=arguments-differ 185 | minibatch_size = num_gpus * self.minibatch_per_gpu 186 | with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16.pkl') as f: 187 | feature_net = pickle.load(f) 188 | 189 | # Calculate features for reals. 190 | cache_file = self._get_cache_file_for_reals(max_reals=self.max_reals) 191 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 192 | if os.path.isfile(cache_file): 193 | with open(cache_file, 'rb') as f: 194 | feat_real = pickle.load(f) 195 | else: 196 | feat_real = [] 197 | for images, _labels, num in self._iterate_reals(minibatch_size): 198 | if images.shape[1] == 1: images = np.tile(images, [1, 3, 1, 1]) 199 | feat_real += list(feature_net.run(images, num_gpus=num_gpus, assume_frozen=True))[:num] 200 | if self.max_reals is not None and len(feat_real) >= self.max_reals: 201 | break 202 | if self.max_reals is not None and len(feat_real) > self.max_reals: 203 | feat_real = feat_real[:self.max_reals] 204 | feat_real = np.stack(feat_real) 205 | with open(cache_file, 'wb') as f: 206 | pickle.dump(feat_real, f) 207 | 208 | # Construct TensorFlow graph. 209 | result_expr = [] 210 | for gpu_idx in range(num_gpus): 211 | with tf.device(f'/gpu:{gpu_idx}'): 212 | Gs_clone = Gs.clone() 213 | feature_net_clone = feature_net.clone() 214 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 215 | labels = self._get_random_labels_tf(self.minibatch_per_gpu) 216 | images = Gs_clone.get_output_for(latents, labels, **G_kwargs) 217 | if images.shape[1] == 1: images = tf.tile(images, [1, 3, 1, 1]) 218 | images = tflib.convert_images_to_uint8(images) 219 | result_expr.append(feature_net_clone.get_output_for(images)) 220 | 221 | # Calculate features for fakes. 
222 | feat_fake = [] 223 | for begin in range(0, self.num_fakes, minibatch_size): 224 | self._report_progress(begin, self.num_fakes) 225 | feat_fake += list(np.concatenate(tflib.run(result_expr), axis=0)) 226 | feat_fake = np.stack(feat_fake[:self.num_fakes]) 227 | 228 | # Calculate precision and recall. 229 | state = knn_precision_recall_features(ref_features=feat_real, eval_features=feat_fake, feature_net=feature_net, 230 | nhood_sizes=[self.nhood_size], row_batch_size=self.row_batch_size, col_batch_size=self.col_batch_size, num_gpus=num_gpus) 231 | self._report_result(state.knn_precision[0], suffix='_precision') 232 | self._report_result(state.knn_recall[0], suffix='_recall') 233 | 234 | #---------------------------------------------------------------------------- 235 | -------------------------------------------------------------------------------- /projector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Project given image to the latent space of pretrained network pickle.""" 10 | 11 | import argparse 12 | import os 13 | import pickle 14 | import imageio 15 | 16 | import numpy as np 17 | import PIL.Image 18 | import tensorflow as tf 19 | import tqdm 20 | 21 | import dnnlib 22 | import dnnlib.tflib as tflib 23 | 24 | class Projector: 25 | def __init__(self): 26 | self.num_steps = 1000 27 | self.dlatent_avg_samples = 10000 28 | self.initial_learning_rate = 0.1 29 | self.initial_noise_factor = 0.05 30 | self.lr_rampdown_length = 0.25 31 | self.lr_rampup_length = 0.05 32 | self.noise_ramp_length = 0.75 33 | self.regularize_noise_weight = 1e5 34 | self.verbose = True 35 | 36 | self._Gs = None 37 | self._minibatch_size = None 38 | self._dlatent_avg = None 39 | self._dlatent_std = None 40 | self._noise_vars = None 41 | self._noise_init_op = None 42 | self._noise_normalize_op = None 43 | self._dlatents_var = None 44 | self._dlatent_noise_in = None 45 | self._dlatents_expr = None 46 | self._images_float_expr = None 47 | self._images_uint8_expr = None 48 | self._target_images_var = None 49 | self._lpips = None 50 | self._dist = None 51 | self._loss = None 52 | self._reg_sizes = None 53 | self._lrate_in = None 54 | self._opt = None 55 | self._opt_step = None 56 | self._cur_step = None 57 | 58 | def _info(self, *args): 59 | if self.verbose: 60 | print('Projector:', *args) 61 | 62 | def set_network(self, Gs, dtype='float16'): 63 | if Gs is None: 64 | self._Gs = None 65 | return 66 | self._Gs = Gs.clone(randomize_noise=False, dtype=dtype, num_fp16_res=0, fused_modconv=True) 67 | 68 | # Compute dlatent stats.
69 | self._info(f'Computing W midpoint and stddev using {self.dlatent_avg_samples} samples...') 70 | latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:]) 71 | dlatent_samples = self._Gs.components.mapping.run(latent_samples, None) # [N, L, C] 72 | dlatent_samples = dlatent_samples[:, :1, :].astype(np.float32) # [N, 1, C] 73 | self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 1, C] 74 | self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg) ** 2) / self.dlatent_avg_samples) ** 0.5 75 | self._info(f'std = {self._dlatent_std:g}') 76 | 77 | # Setup noise inputs. 78 | self._info('Setting up noise inputs...') 79 | self._noise_vars = [] 80 | noise_init_ops = [] 81 | noise_normalize_ops = [] 82 | while True: 83 | n = f'G_synthesis/noise{len(self._noise_vars)}' 84 | if not n in self._Gs.vars: 85 | break 86 | v = self._Gs.vars[n] 87 | self._noise_vars.append(v) 88 | noise_init_ops.append(tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32))) 89 | noise_mean = tf.reduce_mean(v) 90 | noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5 91 | noise_normalize_ops.append(tf.assign(v, (v - noise_mean) / noise_std)) 92 | self._noise_init_op = tf.group(*noise_init_ops) 93 | self._noise_normalize_op = tf.group(*noise_normalize_ops) 94 | 95 | # Build image output graph. 96 | self._info('Building image output graph...') 97 | self._minibatch_size = 1 98 | self._dlatents_var = tf.Variable(tf.zeros([self._minibatch_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var') 99 | self._dlatent_noise_in = tf.placeholder(tf.float32, [], name='noise_in') 100 | dlatents_noise = tf.random.normal(shape=self._dlatents_var.shape) * self._dlatent_noise_in 101 | self._dlatents_expr = tf.tile(self._dlatents_var + dlatents_noise, [1, self._Gs.components.synthesis.input_shape[1], 1]) 102 | self._images_float_expr = tf.cast(self._Gs.components.synthesis.get_output_for(self._dlatents_expr), tf.float32) 103 | self._images_uint8_expr = tflib.convert_images_to_uint8(self._images_float_expr, nchw_to_nhwc=True) 104 | 105 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 106 | proc_images_expr = (self._images_float_expr + 1) * (255 / 2) 107 | sh = proc_images_expr.shape.as_list() 108 | if sh[2] > 256: 109 | factor = sh[2] // 256 110 | proc_images_expr = tf.reduce_mean(tf.reshape(proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5]) 111 | 112 | # Build loss graph. 113 | self._info('Building loss graph...') 114 | self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var') 115 | if self._lpips is None: 116 | with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl') as f: 117 | self._lpips = pickle.load(f) 118 | self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var) 119 | self._loss = tf.reduce_sum(self._dist) 120 | 121 | # Build noise regularization graph. 
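#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources): the
# regularization graph assembled next penalizes spatial autocorrelation of
# each noise map at several scales: mean(v * roll(v, 1, axis))**2 is near
# zero for white noise and grows if the optimizer smuggles image content
# into the noise. A NumPy miniature of one term:
import numpy as np

rng = np.random.RandomState(2)
v = rng.randn(1, 1, 16, 16)                       # one noise map, NCHW
corr_x = np.mean(v * np.roll(v, 1, axis=3)) ** 2  # horizontal shift-product
corr_y = np.mean(v * np.roll(v, 1, axis=2)) ** 2  # vertical shift-product
# Both terms stay close to 0 for uncorrelated noise, which is what this
# penalty (scaled by regularize_noise_weight) pushes the noise maps toward.
#----------------------------------------------------------------------------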
122 | self._info('Building noise regularization graph...') 123 | reg_loss = 0.0 124 | for v in self._noise_vars: 125 | sz = v.shape[2] 126 | while True: 127 | reg_loss += tf.reduce_mean(v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(v * tf.roll(v, shift=1, axis=2))**2 128 | if sz <= 8: 129 | break # Small enough already 130 | v = tf.reshape(v, [1, 1, sz//2, 2, sz//2, 2]) # Downscale 131 | v = tf.reduce_mean(v, axis=[3, 5]) 132 | sz = sz // 2 133 | self._loss += reg_loss * self.regularize_noise_weight 134 | 135 | # Setup optimizer. 136 | self._info('Setting up optimizer...') 137 | self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in') 138 | self._opt = tflib.Optimizer(learning_rate=self._lrate_in) 139 | self._opt.register_gradients(self._loss, [self._dlatents_var] + self._noise_vars) 140 | self._opt_step = self._opt.apply_updates() 141 | 142 | def start(self, target_images): 143 | assert self._Gs is not None 144 | 145 | # Prepare target images. 146 | self._info('Preparing target images...') 147 | target_images = np.asarray(target_images, dtype='float32') 148 | target_images = (target_images + 1) * (255 / 2) 149 | sh = target_images.shape 150 | assert sh[0] == self._minibatch_size 151 | if sh[2] > self._target_images_var.shape[2]: 152 | factor = sh[2] // self._target_images_var.shape[2] 153 | target_images = np.reshape(target_images, [-1, sh[1], sh[2] // factor, factor, sh[3] // factor, factor]).mean((3, 5)) 154 | 155 | # Initialize optimization state. 156 | self._info('Initializing optimization state...') 157 | dlatents = np.tile(self._dlatent_avg, [self._minibatch_size, 1, 1]) 158 | tflib.set_vars({self._target_images_var: target_images, self._dlatents_var: dlatents}) 159 | tflib.run(self._noise_init_op) 160 | self._opt.reset_optimizer_state() 161 | self._cur_step = 0 162 | 163 | def step(self): 164 | assert self._cur_step is not None 165 | if self._cur_step >= self.num_steps: 166 | return 0, 0 167 | 168 | # Choose hyperparameters. 169 | t = self._cur_step / self.num_steps 170 | dlatent_noise = self._dlatent_std * self.initial_noise_factor * max(0.0, 1.0 - t / self.noise_ramp_length) ** 2 171 | lr_ramp = min(1.0, (1.0 - t) / self.lr_rampdown_length) 172 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) 173 | lr_ramp = lr_ramp * min(1.0, t / self.lr_rampup_length) 174 | learning_rate = self.initial_learning_rate * lr_ramp 175 | 176 | # Execute optimization step. 177 | feed_dict = {self._dlatent_noise_in: dlatent_noise, self._lrate_in: learning_rate} 178 | _, dist_value, loss_value = tflib.run([self._opt_step, self._dist, self._loss], feed_dict) 179 | tflib.run(self._noise_normalize_op) 180 | self._cur_step += 1 181 | return dist_value, loss_value 182 | 183 | @property 184 | def cur_step(self): 185 | return self._cur_step 186 | 187 | @property 188 | def dlatents(self): 189 | return tflib.run(self._dlatents_expr, {self._dlatent_noise_in: 0}) 190 | 191 | @property 192 | def noises(self): 193 | return tflib.run(self._noise_vars) 194 | 195 | @property 196 | def images_float(self): 197 | return tflib.run(self._images_float_expr, {self._dlatent_noise_in: 0}) 198 | 199 | @property 200 | def images_uint8(self): 201 | return tflib.run(self._images_uint8_expr, {self._dlatent_noise_in: 0}) 202 | 203 | #---------------------------------------------------------------------------- 204 | 205 | def project(network_pkl: str, target_fname: str, outdir: str, save_video: bool, seed: int): 206 | # Load networks. 
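#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources): an
# aside on Projector.step() above. Its learning-rate schedule combines a
# linear warm-up with a cosine ramp-down; isolated as a plain function
# (proj_lr is an illustrative name; constants mirror the defaults above):
import numpy as np

def proj_lr(t, base=0.1, rampdown=0.25, rampup=0.05):
    ramp = min(1.0, (1.0 - t) / rampdown)     # 1 until the last 25% of steps
    ramp = 0.5 - 0.5 * np.cos(ramp * np.pi)   # cosine ease-out to zero
    ramp = ramp * min(1.0, t / rampup)        # linear warm-up over first 5%
    return base * ramp

assert proj_lr(0.0) == 0.0 and proj_lr(1.0) == 0.0
# proj_lr(0.05) == 0.1 (full rate); proj_lr(0.9) ~= 0.035 (ramping down).
#----------------------------------------------------------------------------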
207 | tflib.init_tf({'rnd.np_random_seed': seed}) 208 | print('Loading networks from "%s"...' % network_pkl) 209 | with dnnlib.util.open_url(network_pkl) as fp: 210 | _G, _D, Gs = pickle.load(fp) 211 | 212 | # Load target image. 213 | target_pil = PIL.Image.open(target_fname) 214 | w, h = target_pil.size 215 | s = min(w, h) 216 | target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)) 217 | target_pil= target_pil.convert('RGB') 218 | target_pil = target_pil.resize((Gs.output_shape[3], Gs.output_shape[2]), PIL.Image.ANTIALIAS) 219 | target_uint8 = np.array(target_pil, dtype=np.uint8) 220 | target_float = target_uint8.astype(np.float32).transpose([2, 0, 1]) * (2 / 255) - 1 221 | 222 | # Initialize projector. 223 | proj = Projector() 224 | proj.set_network(Gs) 225 | proj.start([target_float]) 226 | 227 | # Setup output directory. 228 | os.makedirs(outdir, exist_ok=True) 229 | target_pil.save(f'{outdir}/target.png') 230 | writer = None 231 | if save_video: 232 | writer = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=60, codec='libx264', bitrate='16M') 233 | 234 | # Run projector. 235 | with tqdm.trange(proj.num_steps) as t: 236 | for step in t: 237 | assert step == proj.cur_step 238 | if writer is not None: 239 | writer.append_data(np.concatenate([target_uint8, proj.images_uint8[0]], axis=1)) 240 | dist, loss = proj.step() 241 | t.set_postfix(dist=f'{dist[0]:.4f}', loss=f'{loss:.2f}') 242 | 243 | # Save results. 244 | PIL.Image.fromarray(proj.images_uint8[0], 'RGB').save(f'{outdir}/proj.png') 245 | np.savez(f'{outdir}/dlatents.npz', dlatents=proj.dlatents) 246 | if writer is not None: 247 | writer.close() 248 | 249 | #---------------------------------------------------------------------------- 250 | 251 | def _str_to_bool(v): 252 | if isinstance(v, bool): 253 | return v 254 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 255 | return True 256 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 257 | return False 258 | raise argparse.ArgumentTypeError('Boolean value expected.') 259 | 260 | #---------------------------------------------------------------------------- 261 | 262 | _examples = '''examples: 263 | 264 | python %(prog)s --outdir=out --target=targetimg.png \\ 265 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl 266 | ''' 267 | 268 | #---------------------------------------------------------------------------- 269 | 270 | def main(): 271 | parser = argparse.ArgumentParser( 272 | description='Project given image to the latent space of pretrained network pickle.', 273 | epilog=_examples, 274 | formatter_class=argparse.RawDescriptionHelpFormatter 275 | ) 276 | 277 | parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) 278 | parser.add_argument('--target', help='Target image file to project to', dest='target_fname', required=True) 279 | parser.add_argument('--save-video', help='Save an mp4 video of optimization progress (default: true)', type=_str_to_bool, default=True) 280 | parser.add_argument('--seed', help='Random seed', type=int, default=303) 281 | parser.add_argument('--outdir', help='Where to save the output images', required=True, metavar='DIR') 282 | project(**vars(parser.parse_args())) 283 | 284 | #---------------------------------------------------------------------------- 285 | 286 | if __name__ == "__main__": 287 | main() 288 | 289 | #---------------------------------------------------------------------------- 290 | 
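#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources):
# project() above is the CLI entry point; programmatic use of the Projector
# class looks roughly like this (sketch only -- assumes an initialized TF
# session and an already-loaded Gs network, as in project()):
#
#   proj = Projector()
#   proj.set_network(Gs)
#   proj.start([target_float])            # target in [-1, 1], CHW, batch of 1
#   while proj.cur_step < proj.num_steps:
#       dist, loss = proj.step()
#   w = proj.dlatents                     # projected dlatents, [1, layers, C]
#----------------------------------------------------------------------------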
-------------------------------------------------------------------------------- /style_mixing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate style mixing image matrix using pretrained network pickle.""" 10 | 11 | import argparse 12 | import os 13 | import pickle 14 | import re 15 | 16 | import numpy as np 17 | import PIL.Image 18 | 19 | import dnnlib 20 | import dnnlib.tflib as tflib 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_styles, outdir, minibatch_size=4): 25 | tflib.init_tf() 26 | print('Loading networks from "%s"...' % network_pkl) 27 | with dnnlib.util.open_url(network_pkl) as fp: 28 | _G, _D, Gs = pickle.load(fp) 29 | 30 | w_avg = Gs.get_var('dlatent_avg') # [component] 31 | Gs_syn_kwargs = { 32 | 'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), 33 | 'randomize_noise': False, 34 | 'minibatch_size': minibatch_size 35 | } 36 | 37 | print('Generating W vectors...') 38 | all_seeds = list(set(row_seeds + col_seeds)) 39 | all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component] 40 | all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component] 41 | all_w = w_avg + (all_w - w_avg) * truncation_psi # [minibatch, layer, component] 42 | w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} # [layer, component] 43 | 44 | print('Generating images...') 45 | all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs) # [minibatch, height, width, channel] 46 | image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))} 47 | 48 | print('Generating style-mixed images...') 49 | for row_seed in row_seeds: 50 | for col_seed in col_seeds: 51 | w = w_dict[row_seed].copy() 52 | w[col_styles] = w_dict[col_seed][col_styles] 53 | image = Gs.components.synthesis.run(w[np.newaxis], **Gs_syn_kwargs)[0] 54 | image_dict[(row_seed, col_seed)] = image 55 | 56 | print('Saving images...') 57 | os.makedirs(outdir, exist_ok=True) 58 | for (row_seed, col_seed), image in image_dict.items(): 59 | PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png') 60 | 61 | print('Saving image grid...') 62 | _N, _C, H, W = Gs.output_shape 63 | canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black') 64 | for row_idx, row_seed in enumerate([None] + row_seeds): 65 | for col_idx, col_seed in enumerate([None] + col_seeds): 66 | if row_seed is None and col_seed is None: 67 | continue 68 | key = (row_seed, col_seed) 69 | if row_seed is None: 70 | key = (col_seed, col_seed) 71 | if col_seed is None: 72 | key = (row_seed, row_seed) 73 | canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx)) 74 | canvas.save(f'{outdir}/grid.png') 75 | 76 | #---------------------------------------------------------------------------- 77 | 78 | def _parse_num_range(s): 79 | '''Accept 
either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' 80 | 81 | range_re = re.compile(r'^(\d+)-(\d+)$') 82 | m = range_re.match(s) 83 | if m: 84 | return list(range(int(m.group(1)), int(m.group(2))+1)) 85 | vals = s.split(',') 86 | return [int(x) for x in vals] 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | _examples = '''examples: 91 | 92 | python %(prog)s --outdir=out --trunc=1 --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\ 93 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl 94 | ''' 95 | 96 | #---------------------------------------------------------------------------- 97 | 98 | def main(): 99 | parser = argparse.ArgumentParser( 100 | description='Generate style mixing image matrix using pretrained network pickle.', 101 | epilog=_examples, 102 | formatter_class=argparse.RawDescriptionHelpFormatter 103 | ) 104 | 105 | parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True) 106 | parser.add_argument('--rows', dest='row_seeds', type=_parse_num_range, help='Random seeds to use for image rows', required=True) 107 | parser.add_argument('--cols', dest='col_seeds', type=_parse_num_range, help='Random seeds to use for image columns', required=True) 108 | parser.add_argument('--styles', dest='col_styles', type=_parse_num_range, help='Style layer range (default: %(default)s)', default='0-6') 109 | parser.add_argument('--trunc', dest='truncation_psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5) 110 | parser.add_argument('--outdir', help='Where to save the output images', required=True, metavar='DIR') 111 | 112 | args = parser.parse_args() 113 | style_mixing_example(**vars(args)) 114 | 115 | #---------------------------------------------------------------------------- 116 | 117 | if __name__ == "__main__": 118 | main() 119 | 120 | #---------------------------------------------------------------------------- 121 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Streaming images and labels from dataset created with dataset_tool.py.""" 10 | 11 | import os 12 | import glob 13 | import numpy as np 14 | import tensorflow as tf 15 | import dnnlib.tflib as tflib 16 | 17 | #---------------------------------------------------------------------------- 18 | # Dataset class that loads images from tfrecords files. 19 | 20 | class TFRecordDataset: 21 | def __init__(self, 22 | tfrecord_dir, # Directory containing a collection of tfrecords files. 23 | resolution = None, # Dataset resolution, None = autodetect. 24 | label_file = None, # Relative path of the labels file, None = autodetect. 25 | max_label_size = 0, # 0 = no labels, 'full' = full labels, = N first label components. 26 | max_images = None, # Maximum number of images to use, None = use all images. 27 | max_validation = 10000, # Maximum size of the validation set, None = use all available images. 28 | mirror_augment = False, # Apply mirror augment? 29 | mirror_augment_v= False, # Apply mirror augment vertically? 30 | repeat = True, # Repeat dataset indefinitely? 31 | shuffle = True, # Shuffle images? 32 | shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. 33 | prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching. 34 | buffer_mb = 256, # Read buffer size (megabytes). 35 | num_threads = 2, # Number of concurrent threads. 36 | _is_validation = False, 37 | use_raw=False, 38 | ): 39 | self.tfrecord_dir = tfrecord_dir 40 | self.resolution = None 41 | self.resolution_log2 = None 42 | self.shape = [] # [channels, height, width] 43 | self.dtype = 'uint8' 44 | self.label_file = label_file 45 | self.label_size = None # components 46 | self.label_dtype = None 47 | self.has_validation_set = None 48 | self.mirror_augment = mirror_augment 49 | self.mirror_augment_v = mirror_augment_v 50 | self.repeat = repeat 51 | self.shuffle = shuffle 52 | self._max_validation = max_validation 53 | self._np_labels = None 54 | self._tf_minibatch_in = None 55 | self._tf_labels_var = None 56 | self._tf_labels_dataset = None 57 | self._tf_datasets = dict() 58 | self._tf_iterator = None 59 | self._tf_init_ops = dict() 60 | self._tf_minibatch_np = None 61 | self._cur_minibatch = -1 62 | self._cur_lod = -1 63 | 64 | # List files in the dataset directory. 65 | assert os.path.isdir(self.tfrecord_dir) 66 | all_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*'))) 67 | self.has_validation_set = (self._max_validation > 0) and any(os.path.basename(f).startswith('validation-') for f in all_files) 68 | all_files = [f for f in all_files if os.path.basename(f).startswith('validation-') == _is_validation] 69 | 70 | # Inspect tfrecords files. 71 | tfr_files = [f for f in all_files if f.endswith('.tfrecords')] 72 | assert len(tfr_files) >= 1 73 | tfr_shapes = [] 74 | for tfr_file in tfr_files: 75 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) 76 | for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt): 77 | if use_raw: 78 | tfr_shapes.append(self.parse_tfrecord_np_raw(record)) 79 | else: 80 | tfr_shapes.append(self.parse_tfrecord_np(record).shape) 81 | break 82 | 83 | # Autodetect label filename. 
84 | if self.label_file is None: 85 | guess = [f for f in all_files if f.endswith('.labels')] 86 | if len(guess): 87 | self.label_file = guess[0] 88 | elif not os.path.isfile(self.label_file): 89 | guess = os.path.join(self.tfrecord_dir, self.label_file) 90 | if os.path.isfile(guess): 91 | self.label_file = guess 92 | 93 | # Determine shape and resolution. 94 | max_shape = max(tfr_shapes, key=np.prod) 95 | self.resolution = resolution if resolution is not None else max_shape[1] 96 | self.resolution_log2 = int(np.log2(self.resolution)) 97 | self.shape = [max_shape[0], self.resolution, self.resolution] 98 | tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes] 99 | assert all(shape[0] == max_shape[0] for shape in tfr_shapes) 100 | assert all(shape[1] == shape[2] for shape in tfr_shapes) 101 | assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods)) 102 | # Breaks raw functions 103 | # assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1)) 104 | 105 | # Load labels. 106 | assert max_label_size == 'full' or max_label_size >= 0 107 | self._np_labels = np.zeros([1<<30, 0], dtype=np.float32) 108 | if self.label_file is not None and max_label_size != 0: 109 | self._np_labels = np.load(self.label_file) 110 | assert self._np_labels.ndim == 2 111 | if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size: 112 | self._np_labels = self._np_labels[:, :max_label_size] 113 | if max_images is not None and self._np_labels.shape[0] > max_images: 114 | self._np_labels = self._np_labels[:max_images] 115 | self.label_size = self._np_labels.shape[1] 116 | self.label_dtype = self._np_labels.dtype.name 117 | 118 | # Build TF expressions. 119 | with tf.name_scope('Dataset'), tf.device('/cpu:0'), tf.control_dependencies(None): 120 | self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[]) 121 | self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var') 122 | self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var) 123 | for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods): 124 | if tfr_lod < 0: 125 | continue 126 | 127 | tfr_file = tfr_files[-1] # should be the highest resolution tf_record file 128 | tfr_shape = tfr_shapes[-1] # again the highest resolution shape 129 | dset = tf.data.TFRecordDataset(tfr_file, compression_type="", buffer_size=buffer_mb<< 20) 130 | if max_images is not None: 131 | dset = dset.take(max_images) 132 | if use_raw: 133 | dset = dset.map(self.parse_tfrecord_tf_raw, num_parallel_calls=num_threads) 134 | else: 135 | dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads) 136 | dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset)) 137 | 138 | bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize 139 | if self.shuffle and shuffle_mb > 0: 140 | dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1) 141 | if self.repeat: 142 | dset = dset.repeat() 143 | if prefetch_mb > 0: 144 | dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1) 145 | dset = dset.batch(self._tf_minibatch_in) 146 | self._tf_datasets[tfr_lod] = dset 147 | self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes) 148 | self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()} 149 | 150 | def close(self): 151 | pass 152 | 153 | # Use the given 
minibatch size and level-of-detail for the data returned by get_minibatch_tf(). 154 | def configure(self, minibatch_size, lod=0): 155 | lod = int(np.floor(lod)) 156 | assert minibatch_size >= 1 and lod in self._tf_datasets 157 | if self._cur_minibatch != minibatch_size or self._cur_lod != lod: 158 | self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size}) 159 | self._cur_minibatch = minibatch_size 160 | # breaks raw loading? 161 | #self._cur_lod = lod 162 | 163 | # Get next minibatch as TensorFlow expressions. 164 | def get_minibatch_tf(self): 165 | images, labels = self._tf_iterator.get_next() 166 | if self.mirror_augment or self.mirror_augment_v: 167 | images = tf.cast(images, tf.float32) 168 | if self.mirror_augment: 169 | images = tf.where(tf.random_uniform([tf.shape(images)[0]]) < 0.5, images, tf.reverse(images, [3])) 170 | if self.mirror_augment_v: 171 | images = tf.where(tf.random_uniform([tf.shape(images)[0]]) < 0.5, images, tf.reverse(images, [2])) 172 | images = tf.cast(images, self.dtype) 173 | return images, labels 174 | 175 | # Get next minibatch as NumPy arrays. 176 | def get_minibatch_np(self, minibatch_size, lod=0): # => (images, labels) or (None, None) 177 | self.configure(minibatch_size, lod) 178 | if self._tf_minibatch_np is None: 179 | with tf.name_scope('Dataset'): 180 | self._tf_minibatch_np = self.get_minibatch_tf() 181 | try: 182 | return tflib.run(self._tf_minibatch_np) 183 | except tf.errors.OutOfRangeError: 184 | return None, None 185 | 186 | # Get random labels as TensorFlow expression. 187 | def get_random_labels_tf(self, minibatch_size): # => labels 188 | with tf.name_scope('Dataset'): 189 | if self.label_size > 0: 190 | with tf.device('/cpu:0'): 191 | return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32)) 192 | return tf.zeros([minibatch_size, 0], self.label_dtype) 193 | 194 | # Get random labels as NumPy array. 195 | def get_random_labels_np(self, minibatch_size): # => labels 196 | if self.label_size > 0: 197 | return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])] 198 | return np.zeros([minibatch_size, 0], self.label_dtype) 199 | 200 | # Load validation set as NumPy array. 201 | def load_validation_set_np(self): 202 | images = [] 203 | labels = [] 204 | if self.has_validation_set: 205 | validation_set = TFRecordDataset( 206 | tfrecord_dir=self.tfrecord_dir, resolution=self.shape[2], max_label_size=self.label_size, 207 | max_images=self._max_validation, repeat=False, shuffle=False, prefetch_mb=0, _is_validation=True) 208 | validation_set.configure(1) 209 | while True: 210 | image, label = validation_set.get_minibatch_np(1) 211 | if image is None: 212 | break 213 | images.append(image) 214 | labels.append(label) 215 | images = np.concatenate(images, axis=0) if len(images) else np.zeros([0] + self.shape, dtype=self.dtype) 216 | labels = np.concatenate(labels, axis=0) if len(labels) else np.zeros([0, self.label_size], self.label_dtype) 217 | assert list(images.shape[1:]) == self.shape 218 | assert labels.shape[1] == self.label_size 219 | assert images.shape[0] <= self._max_validation 220 | return images, labels 221 | 222 | # Parse individual image from a tfrecords file into TensorFlow expression. 
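#----------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the original sources): the
# parsers below expect records with an int64 'shape' and raw uint8 'data'
# (or an encoded 'img' in the raw variant). A record in the non-raw format
# can be produced like this (sketch; mirrors the layout that
# dataset_tool.py writes):
import numpy as np
import tensorflow as tf

img = np.zeros([3, 64, 64], dtype=np.uint8)     # CHW, as stored on disk
ex = tf.train.Example(features=tf.train.Features(feature={
    'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=img.shape)),
    'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))}))
record_bytes = ex.SerializeToString()           # one image per record
#----------------------------------------------------------------------------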
223 | @staticmethod 224 | def parse_tfrecord_tf(record): 225 | features = tf.parse_single_example(record, features={ 226 | 'shape': tf.FixedLenFeature([3], tf.int64), 227 | 'data': tf.FixedLenFeature([], tf.string)}) 228 | data = tf.decode_raw(features['data'], tf.uint8) 229 | return tf.reshape(data, features['shape']) 230 | 231 | @staticmethod 232 | def parse_tfrecord_tf_raw(record): 233 | features = tf.parse_single_example(record, 234 | features={ 235 | "shape": tf.FixedLenFeature([3], tf.int64), 236 | "img": tf.FixedLenFeature([], tf.string), 237 | } 238 | ) 239 | image = tf.image.decode_image(features['img']) 240 | return tf.transpose(image, [2,0,1]) 241 | 242 | # Parse individual image from a tfrecords file into NumPy array. 243 | @staticmethod 244 | def parse_tfrecord_np(record): 245 | ex = tf.train.Example() 246 | ex.ParseFromString(record) 247 | shape = ex.features.feature['shape'].int64_list.value # pylint: disable=no-member 248 | data = ex.features.feature['data'].bytes_list.value[0] # pylint: disable=no-member 249 | return np.frombuffer(data, np.uint8).reshape(shape) 250 | 251 | @staticmethod 252 | def parse_tfrecord_np_raw(record): 253 | ex = tf.train.Example() 254 | ex.ParseFromString(record) 255 | shape = ex.features.feature[ 256 | "shape" 257 | ].int64_list.value # temporary pylint workaround # pylint: disable=no-member 258 | img = ex.features.feature["img"].bytes_list.value[ 259 | 0 260 | ] # temporary pylint workaround # pylint: disable=no-member 261 | return shape 262 | 263 | #---------------------------------------------------------------------------- 264 | # Construct a dataset object using the given options. 265 | 266 | def load_dataset(path=None, use_raw=False, resolution=None, max_images=None, max_label_size=0, mirror_augment=False, mirror_augment_v=False, repeat=True, shuffle=True, seed=None): 267 | _ = seed 268 | assert os.path.isdir(path) 269 | return TFRecordDataset( 270 | tfrecord_dir=path, use_raw=use_raw, 271 | resolution=resolution, max_images=max_images, max_label_size=max_label_size, 272 | mirror_augment=mirror_augment, mirror_augment_v=mirror_augment_v, repeat=repeat, shuffle=shuffle) 273 | 274 | #---------------------------------------------------------------------------- 275 | -------------------------------------------------------------------------------- /training/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Miscellaneous utility functions.""" 8 | 9 | import os 10 | import pickle 11 | import numpy as np 12 | import PIL.Image 13 | import PIL.ImageFont 14 | import dnnlib 15 | import glob 16 | import re 17 | #---------------------------------------------------------------------------- 18 | # Convenience wrappers for pickle that are able to load data produced by 19 | # older versions of the code, and from external URLs.
20 |
21 | def open_file_or_url(file_or_url):
22 |     if dnnlib.util.is_url(file_or_url):
23 |         return dnnlib.util.open_url(file_or_url, cache_dir='.stylegan2-cache')
24 |     return open(file_or_url, 'rb')
25 |
26 | def load_pkl(file_or_url):
27 |     with open_file_or_url(file_or_url) as file:
28 |         return pickle.load(file, encoding='latin1')
29 |
30 | def locate_latest_pkl(result_dir):
31 |     allpickles = sorted(glob.glob(os.path.join(result_dir, '0*', 'network-snapshot-*.pkl'))) # restrict to snapshots so the kimg regex below always matches
32 |     if len(allpickles) == 0:
33 |         return None, 0.0
34 |     latest_pickle = allpickles[-1]
35 |     resume_run_id = os.path.basename(os.path.dirname(latest_pickle))
36 |     RE_KIMG = re.compile(r'network-snapshot-(\d+)\.pkl')
37 |     kimg = int(RE_KIMG.match(os.path.basename(latest_pickle)).group(1))
38 |     return (latest_pickle, float(kimg))
39 |
40 | def save_pkl(obj, filename):
41 |     with open(filename, 'wb') as file:
42 |         pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
43 |
44 | #----------------------------------------------------------------------------
45 | # Image utils.
46 |
47 | def adjust_dynamic_range(data, drange_in, drange_out):
48 |     if drange_in != drange_out:
49 |         scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
50 |         bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
51 |         data = data * scale + bias
52 |     return data
53 |
54 | def create_image_grid(images, grid_size=None):
55 |     assert images.ndim == 3 or images.ndim == 4
56 |     num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]
57 |
58 |     if grid_size is not None:
59 |         grid_w, grid_h = tuple(grid_size)
60 |     else:
61 |         grid_w = max(int(np.ceil(np.sqrt(num))), 1)
62 |         grid_h = max((num - 1) // grid_w + 1, 1)
63 |
64 |     grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype)
65 |     for idx in range(num):
66 |         x = (idx % grid_w) * img_w
67 |         y = (idx // grid_w) * img_h
68 |         grid[..., y : y + img_h, x : x + img_w] = images[idx]
69 |     return grid
70 |
71 | def convert_to_pil_image(image, drange=[0,1]):
72 |     assert image.ndim == 2 or image.ndim == 3
73 |     if image.ndim == 3:
74 |         if image.shape[0] == 1:
75 |             image = image[0] # grayscale CHW => HW
76 |         else:
77 |             image = image.transpose(1, 2, 0) # CHW -> HWC
78 |
79 |     image = adjust_dynamic_range(image, drange, [0,255])
80 |     image = np.rint(image).clip(0, 255).astype(np.uint8)
81 |     fmt = 'RGB' if image.ndim == 3 else 'L'
82 |     return PIL.Image.fromarray(image, fmt)
83 |
84 | def save_image_grid(images, filename, drange=[0,1], grid_size=None):
85 |     convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)
86 |
87 | def apply_mirror_augment(minibatch):
88 |     mask = np.random.rand(minibatch.shape[0]) < 0.5
89 |     minibatch = np.array(minibatch)
90 |     minibatch[mask] = minibatch[mask, :, :, ::-1]
91 |     return minibatch
92 |
93 | def apply_mirror_augment_v(minibatch):
94 |     mask = np.random.rand(minibatch.shape[0]) < 0.5
95 |     minibatch = np.array(minibatch)
96 |     minibatch[mask] = minibatch[mask, :, ::-1, :]
97 |     return minibatch
98 |
99 | #----------------------------------------------------------------------------
100 | # Loading data from previous training runs.
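# (For instance, with a hypothetical run directory, the helper below yields a
# dict of the previous run's arguments:
#
#     cfg = parse_config_for_previous_run('results/00000-mydataset')
#     # => {'train': <run_func_kwargs>, 'dataset': <dataset_args>}
# )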
101 |
102 | def parse_config_for_previous_run(run_dir):
103 |     with open(os.path.join(run_dir, 'submit_config.pkl'), 'rb') as f:
104 |         data = pickle.load(f)
105 |     data = data.get('run_func_kwargs', {})
106 |     return dict(train=data, dataset=data.get('dataset_args', {}))
107 |
108 | #----------------------------------------------------------------------------
109 | # Size and contents of the image snapshot grids that are exported
110 | # periodically during training.
111 |
112 | def setup_snapshot_image_grid(training_set,
113 |     size    = '1080p',      # '1080p', '4k', or '8k' = resolution of the display the grid is intended for.
114 |     layout  = 'random'):    # 'random' = grid contents are selected randomly; 'row_per_class', 'col_per_class', and 'class4x4' lay out images by class label.
115 |
116 |     # Select size.
117 |     gw = 1; gh = 1
118 |     if size == '1080p':
119 |         gw = np.clip(1920 // training_set.shape[2], 3, 32)
120 |         gh = np.clip(1080 // training_set.shape[1], 2, 32)
121 |     if size == '4k':
122 |         gw = np.clip(3840 // training_set.shape[2], 7, 32)
123 |         gh = np.clip(2160 // training_set.shape[1], 4, 32)
124 |     if size == '8k':
125 |         gw = np.clip(7680 // training_set.shape[2], 7, 32)
126 |         gh = np.clip(4320 // training_set.shape[1], 4, 32)
127 |
128 |     # Initialize data arrays.
129 |     reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)
130 |     labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
131 |
132 |     # Random layout.
133 |     if layout == 'random':
134 |         reals[:], labels[:] = training_set.get_minibatch_np(gw * gh)
135 |
136 |     # Class-conditional layouts.
137 |     class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4])
138 |     if layout in class_layouts:
139 |         bw, bh = class_layouts[layout]
140 |         nw = (gw - 1) // bw + 1
141 |         nh = (gh - 1) // bh + 1
142 |         blocks = [[] for _i in range(nw * nh)]
143 |         for _iter in range(1000000):
144 |             real, label = training_set.get_minibatch_np(1)
145 |             idx = np.argmax(label[0])
146 |             while idx < len(blocks) and len(blocks[idx]) >= bw * bh:
147 |                 idx += training_set.label_size
148 |             if idx < len(blocks):
149 |                 blocks[idx].append((real, label))
150 |                 if all(len(block) >= bw * bh for block in blocks):
151 |                     break
152 |         for i, block in enumerate(blocks):
153 |             for j, (real, label) in enumerate(block):
154 |                 x = (i % nw) * bw + j % bw
155 |                 y = (i // nw) * bh + j // bw
156 |                 if x < gw and y < gh:
157 |                     reals[x + y * gw] = real[0]
158 |                     labels[x + y * gw] = label[0]
159 |
160 |     return (gw, gh), reals, labels
161 |
162 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
/training/training_loop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | """Main training loop.""" 10 | 11 | import os 12 | import pickle 13 | import time 14 | import PIL.Image 15 | import numpy as np 16 | import tensorflow as tf 17 | import dnnlib 18 | import dnnlib.tflib as tflib 19 | from dnnlib.tflib.autosummary import autosummary 20 | 21 | from training import dataset 22 | 23 | #---------------------------------------------------------------------------- 24 | # Select size and contents of the image snapshot grids that are exported 25 | # periodically during training. 26 | 27 | def setup_snapshot_image_grid(training_set): 28 | gw = np.clip(7680 // training_set.shape[2], 7, 32) 29 | gh = np.clip(4320 // training_set.shape[1], 4, 32) 30 | 31 | # Unconditional. 32 | if training_set.label_size == 0: 33 | reals, labels = training_set.get_minibatch_np(gw * gh) 34 | return (gw, gh), reals, labels 35 | 36 | # Row per class. 37 | cw, ch = (gw, 1) 38 | nw = (gw - 1) // cw + 1 39 | nh = (gh - 1) // ch + 1 40 | 41 | # Collect images. 42 | blocks = [[] for _i in range(nw * nh)] 43 | for _iter in range(1000000): 44 | real, label = training_set.get_minibatch_np(1) 45 | idx = np.argmax(label[0]) 46 | while idx < len(blocks) and len(blocks[idx]) >= cw * ch: 47 | idx += training_set.label_size 48 | if idx < len(blocks): 49 | blocks[idx].append((real, label)) 50 | if all(len(block) >= cw * ch for block in blocks): 51 | break 52 | 53 | # Layout grid. 54 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) 55 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) 56 | for i, block in enumerate(blocks): 57 | for j, (real, label) in enumerate(block): 58 | x = (i % nw) * cw + j % cw 59 | y = (i // nw) * ch + j // cw 60 | if x < gw and y < gh: 61 | reals[x + y * gw] = real[0] 62 | labels[x + y * gw] = label[0] 63 | return (gw, gh), reals, labels 64 | 65 | #---------------------------------------------------------------------------- 66 | 67 | def save_image_grid(images, filename, drange, grid_size): 68 | lo, hi = drange 69 | gw, gh = grid_size 70 | images = np.asarray(images, dtype=np.float32) 71 | images = (images - lo) * (255 / (hi - lo)) 72 | images = np.rint(images).clip(0, 255).astype(np.uint8) 73 | _N, C, H, W = images.shape 74 | images = images.reshape(gh, gw, C, H, W) 75 | images = images.transpose(0, 3, 1, 4, 2) 76 | images = images.reshape(gh * H, gw * W, C) 77 | PIL.Image.fromarray(images, {3: 'RGB', 1: 'L'}[C]).save(filename) 78 | 79 | #---------------------------------------------------------------------------- 80 | # Main training script. 81 | 82 | def training_loop( 83 | run_dir = '.', # Output directory. 84 | G_args = {}, # Options for generator network. 85 | D_args = {}, # Options for discriminator network. 86 | G_opt_args = {}, # Options for generator optimizer. 87 | D_opt_args = {}, # Options for discriminator optimizer. 88 | loss_args = {}, # Options for loss function. 89 | train_dataset_args = {}, # Options for dataset to train with. 90 | metric_dataset_args = {}, # Options for dataset to evaluate metrics against. 91 | augment_args = {}, # Options for adaptive augmentations. 92 | metric_arg_list = [], # Metrics to evaluate during training. 93 | num_gpus = 1, # Number of GPUs to use. 94 | minibatch_size = 32, # Global minibatch size. 95 | minibatch_gpu = 4, # Number of samples processed at a time by one GPU. 96 | G_smoothing_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights. 97 | G_smoothing_rampup = None, # EMA ramp-up coefficient. 
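# (Worked example: with the defaults above, minibatch_size=32 and
# G_smoothing_kimg=10 give a per-step EMA decay in the main loop below of
# Gs_beta = 0.5 ** (32 / 10000) ~= 0.9978.)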
98 |     minibatch_repeats       = 4,        # Number of minibatches to run in the inner loop.
99 |     lazy_regularization     = True,     # Perform regularization as a separate training step?
100 |     G_reg_interval          = 4,        # How often to perform regularization for G? Ignored if lazy_regularization=False.
101 |     D_reg_interval          = 16,       # How often to perform regularization for D? Ignored if lazy_regularization=False.
102 |     nimg                    = 0,        # Current image count.
103 |     total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
104 |     kimg_per_tick           = 4,        # Progress snapshot interval.
105 |     image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.jpg' and 'fakes_init.jpg'.
106 |     network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable network snapshots.
107 |     resume_pkl              = None,     # Network pickle to resume training from.
108 |     abort_fn                = None,     # Callback function for determining whether to abort training.
109 |     progress_fn             = None,     # Callback function for updating training progress.
110 | ):
111 |     assert minibatch_size % (num_gpus * minibatch_gpu) == 0
112 |     start_time = time.time()
113 |
114 |     print('Loading training set...')
115 |     training_set = dataset.load_dataset(**train_dataset_args)
116 |     print('Image shape:', np.int32(training_set.shape).tolist())
117 |     print('Label shape:', [training_set.label_size])
118 |     print()
119 |
120 |     print('Constructing networks...')
121 |     with tf.device('/gpu:0'):
122 |         G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
123 |         D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
124 |         Gs = G.clone('Gs')
125 |         if resume_pkl is not None:
126 |             print(f'Resuming from "{resume_pkl}"')
127 |             with dnnlib.util.open_url(resume_pkl) as f:
128 |                 rG, rD, rGs = pickle.load(f)
129 |             G.copy_vars_from(rG)
130 |             D.copy_vars_from(rD)
131 |             Gs.copy_vars_from(rGs)
132 |     G.print_layers()
133 |     D.print_layers()
134 |
135 |     print('Exporting sample images...')
136 |     grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(training_set)
137 |     save_image_grid(grid_reals, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size)
138 |     grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
139 |     grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
140 |     save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size)
141 |
142 |     print(f'Replicating networks across {num_gpus} GPUs...')
143 |     G_gpus = [G]
144 |     D_gpus = [D]
145 |     for gpu in range(1, num_gpus):
146 |         with tf.device(f'/gpu:{gpu}'):
147 |             G_gpus.append(G.clone(f'{G.name}_gpu{gpu}'))
148 |             D_gpus.append(D.clone(f'{D.name}_gpu{gpu}'))
149 |
150 |     print('Initializing augmentations...')
151 |     aug = None
152 |     if augment_args.get('class_name', None) is not None:
153 |         aug = dnnlib.util.construct_class_by_name(**augment_args)
154 |         aug.init_validation_set(D_gpus=D_gpus, training_set=training_set)
155 |
156 |     print('Setting up optimizers...')
157 |     G_opt_args = dict(G_opt_args)
158 |     D_opt_args = dict(D_opt_args)
159 |     for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
160 |         args['minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu
161 |         if lazy_regularization:
162 |             mb_ratio = reg_interval / (reg_interval + 1)
163 |             args['learning_rate'] *=
mb_ratio 164 | if 'beta1' in args: args['beta1'] **= mb_ratio 165 | if 'beta2' in args: args['beta2'] **= mb_ratio 166 | G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) 167 | D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) 168 | G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) 169 | D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) 170 | 171 | print('Constructing training graph...') 172 | data_fetch_ops = [] 173 | training_set.configure(minibatch_gpu) 174 | for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)): 175 | with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'): 176 | 177 | # Fetch training data via temporary variables. 178 | with tf.name_scope('DataFetch'): 179 | real_images_var = tf.Variable(name='images', trainable=False, initial_value=tf.zeros([minibatch_gpu] + training_set.shape)) 180 | real_labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([minibatch_gpu, training_set.label_size])) 181 | real_images_write, real_labels_write = training_set.get_minibatch_tf() 182 | real_images_write = tflib.convert_images_from_uint8(real_images_write) 183 | data_fetch_ops += [tf.assign(real_images_var, real_images_write)] 184 | data_fetch_ops += [tf.assign(real_labels_var, real_labels_write)] 185 | 186 | # Evaluate loss function and register gradients. 187 | fake_labels = training_set.get_random_labels_tf(minibatch_gpu) 188 | terms = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, aug=aug, fake_labels=fake_labels, real_images=real_images_var, real_labels=real_labels_var, **loss_args) 189 | if lazy_regularization: 190 | if terms.G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(terms.G_reg * G_reg_interval), G_gpu.trainables) 191 | if terms.D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(terms.D_reg * D_reg_interval), D_gpu.trainables) 192 | else: 193 | if terms.G_reg is not None: terms.G_loss += terms.G_reg 194 | if terms.D_reg is not None: terms.D_loss += terms.D_reg 195 | G_opt.register_gradients(tf.reduce_mean(terms.G_loss), G_gpu.trainables) 196 | D_opt.register_gradients(tf.reduce_mean(terms.D_loss), D_gpu.trainables) 197 | 198 | print('Finalizing training ops...') 199 | data_fetch_op = tf.group(*data_fetch_ops) 200 | G_train_op = G_opt.apply_updates() 201 | D_train_op = D_opt.apply_updates() 202 | G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) 203 | D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) 204 | Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[]) 205 | Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in) 206 | Gs_epochs = tf.placeholder(tf.float32, name='Gs_epochs', shape=[]) 207 | Gs_epochs_op = Gs.update_epochs(Gs_epochs) 208 | tflib.init_uninitialized_vars() 209 | with tf.device('/gpu:0'): 210 | peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() 211 | 212 | print('Initializing metrics...') 213 | summary_log = tf.summary.FileWriter(run_dir) 214 | metrics = [] 215 | for args in metric_arg_list: 216 | metric = dnnlib.util.construct_class_by_name(**args) 217 | metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir) 218 | metrics.append(metric) 219 | 220 | print(f'Training for {total_kimg} kimg...') 221 | print() 222 | if progress_fn is not None: 223 | progress_fn(0, total_kimg) 224 | tick_start_time = time.time() 225 | maintenance_time = tick_start_time - start_time 226 | cur_nimg = nimg 227 | cur_tick = -1 228 | tick_start_nimg = cur_nimg 229 | running_mb_counter = 0 230 | 231 | done = False 232 | while 
not done: 233 | 234 | # Compute EMA decay parameter. 235 | Gs_nimg = G_smoothing_kimg * 1000.0 236 | if G_smoothing_rampup is not None: 237 | Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup) 238 | Gs_beta = 0.5 ** (minibatch_size / max(Gs_nimg, 1e-8)) 239 | 240 | epochs = float(100 * cur_nimg / (total_kimg * 1000)) # 100 total top k "epochs" in total_kimg 241 | 242 | # Run training ops. 243 | for _repeat_idx in range(minibatch_repeats): 244 | rounds = range(0, minibatch_size, minibatch_gpu * num_gpus) 245 | run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) 246 | run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) 247 | cur_nimg += minibatch_size 248 | running_mb_counter += 1 249 | 250 | # Fast path without gradient accumulation. 251 | if len(rounds) == 1: 252 | tflib.run([G_train_op, data_fetch_op]) 253 | if run_G_reg: 254 | tflib.run(G_reg_op) 255 | tflib.run([D_train_op, Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs}) 256 | if run_D_reg: 257 | tflib.run(D_reg_op) 258 | 259 | # Slow path with gradient accumulation. 260 | else: 261 | for _round in rounds: 262 | tflib.run(G_train_op) 263 | if run_G_reg: 264 | tflib.run(G_reg_op) 265 | tflib.run([Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs}) 266 | for _round in rounds: 267 | tflib.run(data_fetch_op) 268 | tflib.run(D_train_op) 269 | if run_D_reg: 270 | tflib.run(D_reg_op) 271 | 272 | # Run validation. 273 | if aug is not None: 274 | aug.run_validation(minibatch_size=minibatch_size) 275 | 276 | # Tune augmentation parameters. 277 | if aug is not None: 278 | aug.tune(minibatch_size * minibatch_repeats) 279 | 280 | # Perform maintenance tasks once per tick. 281 | done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None and abort_fn()) 282 | if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000: 283 | cur_tick += 1 284 | tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 285 | tick_start_nimg = cur_nimg 286 | tick_end_time = time.time() 287 | total_time = tick_end_time - start_time 288 | tick_time = tick_end_time - tick_start_time 289 | 290 | # Report progress. 291 | print(' '.join([ 292 | f"tick {autosummary('Progress/tick', cur_tick):<5d}", 293 | f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}", 294 | f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}", 295 | f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}", 296 | f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}", 297 | f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}", 298 | f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}", 299 | f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}", 300 | ])) 301 | autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) 302 | autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) 303 | if progress_fn is not None: 304 | progress_fn(cur_nimg // 1000, total_kimg) 305 | 306 | # Save snapshots. 
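# (Note: the 'network-snapshot-%06d.pkl' naming written below is the same
# pattern that training/misc.py:locate_latest_pkl() matches when locating a
# pickle to resume from.)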
307 |             if image_snapshot_ticks is not None and (done or cur_tick % image_snapshot_ticks == 0):
308 |                 grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
309 |                 save_image_grid(grid_fakes, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size)
310 |             if network_snapshot_ticks is not None and (done or cur_tick % network_snapshot_ticks == 0):
311 |                 pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
312 |                 with open(pkl, 'wb') as f:
313 |                     pickle.dump((G, D, Gs), f)
314 |                 if len(metrics):
315 |                     print('Evaluating metrics...')
316 |                     for metric in metrics:
317 |                         metric.run(pkl, num_gpus=num_gpus)
318 |
319 |             # Update summaries.
320 |             for metric in metrics:
321 |                 metric.update_autosummaries()
322 |             tflib.autosummary.save_summaries(summary_log, cur_nimg)
323 |             tick_start_time = time.time()
324 |             maintenance_time = tick_start_time - tick_end_time
325 |
326 |     print()
327 |     print('Exiting...')
328 |     summary_log.close()
329 |     training_set.close()
330 |
331 | #----------------------------------------------------------------------------
332 |
--------------------------------------------------------------------------------
/utils/align_faces.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import bz2
4 | from keras.utils import get_file
5 | from ffhq_dataset.face_alignment import image_align
6 | from ffhq_dataset.landmarks_detector import LandmarksDetector
7 |
8 | LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
9 |
10 |
11 | def unpack_bz2(src_path):
12 |     data = bz2.BZ2File(src_path).read()
13 |     dst_path = src_path[:-4]
14 |     with open(dst_path, 'wb') as fp:
15 |         fp.write(data)
16 |     return dst_path
17 |
18 |
19 | if __name__ == "__main__":
20 |     """
21 |     Extracts and aligns all faces from images using dlib and a function from the original FFHQ dataset preparation step
22 |     python utils/align_faces.py /raw_images /aligned_images
23 |     """
24 |
25 |     landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2',
26 |                                                LANDMARKS_MODEL_URL, cache_subdir='temp'))
27 |     RAW_IMAGES_DIR = sys.argv[1]
28 |     ALIGNED_IMAGES_DIR = sys.argv[2]
29 |
30 |     landmarks_detector = LandmarksDetector(landmarks_model_path)
31 |     for img_name in [x for x in os.listdir(RAW_IMAGES_DIR) if x[0] not in '._']:
32 |         raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
33 |         for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
34 |             face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
35 |             aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
36 |             os.makedirs(ALIGNED_IMAGES_DIR, exist_ok=True)
37 |             image_align(raw_img_path, aligned_face_path, face_landmarks)
38 |
--------------------------------------------------------------------------------
/utils/tffreeze.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 |
3 | import tensorflow as tf
4 |
5 | # The original freeze_graph function
6 | # from tensorflow.python.tools.freeze_graph import freeze_graph
7 |
8 | dir = os.path.dirname(os.path.realpath(__file__)) # note: unused; shadows the built-in dir()
9 |
10 | def freeze_graph(model_dir, output_node_names):
11 |     """Extract the subgraph defined by the output nodes and convert
12 |     all its variables into constants
13 |
14 |     Args:
15 |         model_dir: the root folder containing the checkpoint state file
16 |         output_node_names: a string containing all the output node names,
17 |             comma separated
18 |     """
19 |     if not tf.gfile.Exists(model_dir):
20 |         raise AssertionError(
21 |             "Export directory doesn't exist. Please specify an export "
22 |             "directory: %s" % model_dir)
23 |
24 |     if not output_node_names:
25 |         print("You need to supply the name of a node to --output_node_names.")
26 |         return -1
27 |
28 |     # We retrieve our checkpoint fullpath
29 |     checkpoint = tf.train.get_checkpoint_state(model_dir)
30 |     input_checkpoint = checkpoint.model_checkpoint_path
31 |
32 |     # We determine the full filename of our frozen graph
33 |     absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
34 |     output_graph = absolute_model_dir + "/frozen_model.pb"
35 |
36 |     # We clear devices to allow TensorFlow to control on which device it will load operations
37 |     clear_devices = True
38 |
39 |     # We start a session using a temporary fresh Graph
40 |     with tf.Session(graph=tf.Graph()) as sess:
41 |         # We import the meta graph in the current default Graph
42 |         saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
43 |
44 |         # We restore the weights
45 |         saver.restore(sess, input_checkpoint)
46 |
47 |         # We use a built-in TF helper to export variables to constants
48 |         output_graph_def = tf.graph_util.convert_variables_to_constants(
49 |             sess, # The session is used to retrieve the weights
50 |             tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
51 |             output_node_names.split(",") # The output node names are used to select the useful nodes
52 |         )
53 |
54 |         # Finally we serialize and dump the output graph to the filesystem
55 |         with tf.gfile.GFile(output_graph, "wb") as f:
56 |             f.write(output_graph_def.SerializeToString())
57 |         print("%d ops in the final graph." % len(output_graph_def.node))
58 |
59 |     return output_graph_def
60 |
61 | if __name__ == '__main__':
62 |     parser = argparse.ArgumentParser()
63 |     parser.add_argument("--model_dir", type=str, default="", help="Model folder to export")
64 |     parser.add_argument("--output_node_names", type=str, default="", help="The name of the output nodes, comma separated.")
65 |     args = parser.parse_args()
66 |
67 |     freeze_graph(args.model_dir, args.output_node_names)
68 |
--------------------------------------------------------------------------------
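# (Illustrative invocation of utils/tffreeze.py; the checkpoint directory and
# node name below are hypothetical, and the directory must contain a valid
# checkpoint state file:
#
#     python utils/tffreeze.py --model_dir=./models/run0 --output_node_names=G_out
#
# Candidate node names can be listed after restoring the graph, e.g. with
# [n.name for n in tf.get_default_graph().as_graph_def().node].)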