├── LICENSE.txt ├── Learn_direction_in_latent_space.ipynb ├── Play_with_latent_directions.ipynb ├── README.md ├── align_images.py ├── config.py ├── dataset_tool.py ├── dnnlib ├── __init__.py ├── submission │ ├── __init__.py │ ├── _internal │ │ └── run.py │ ├── run_context.py │ └── submit.py ├── tflib │ ├── __init__.py │ ├── autosummary.py │ ├── network.py │ ├── optimizer.py │ └── tfutil.py └── util.py ├── encode_images.py ├── encoder ├── __init__.py ├── generator_model.py └── perceptual_model.py ├── ffhq_dataset ├── __init__.py ├── face_alignment.py ├── landmarks_detector.py ├── latent_directions │ ├── age.npy │ ├── gender.npy │ └── smile.npy └── latent_representations │ ├── donald_trump_01.npy │ └── hillary_clinton_01.npy ├── generate_figures.py ├── metrics ├── __init__.py ├── frechet_inception_distance.py ├── linear_separability.py ├── metric_base.py └── perceptual_path_length.py ├── pretrained_example.py ├── requirements.txt ├── run_metrics.py ├── teaser.png ├── train.py └── training ├── __init__.py ├── dataset.py ├── loss.py ├── misc.py ├── networks_progan.py ├── networks_stylegan.py └── training_loop.py /align_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import bz2 4 | from keras.utils import get_file 5 | from ffhq_dataset.face_alignment import image_align 6 | from ffhq_dataset.landmarks_detector import LandmarksDetector 7 | 8 | LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2' 9 | 10 | 11 | def unpack_bz2(src_path): 12 | data = bz2.BZ2File(src_path).read() 13 | dst_path = src_path[:-4] 14 | with open(dst_path, 'wb') as fp: 15 | fp.write(data) 16 | return dst_path 17 | 18 | 19 | if __name__ == "__main__": 20 | """ 21 | Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step 22 | python align_images.py /raw_images /aligned_images 23 | """ 24 | 25 | landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2', 26 | LANDMARKS_MODEL_URL, cache_subdir='temp')) 27 | RAW_IMAGES_DIR = sys.argv[1] 28 | ALIGNED_IMAGES_DIR = sys.argv[2] 29 | 30 | landmarks_detector = LandmarksDetector(landmarks_model_path) 31 | for img_name in os.listdir(RAW_IMAGES_DIR): 32 | raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name) 33 | for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1): 34 | face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i) 35 | aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) 36 | 37 | image_align(raw_img_path, aligned_face_path, face_landmarks) 38 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Global configuration.""" 9 | 10 | #---------------------------------------------------------------------------- 11 | # Paths. 
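# These values plug into the submit machinery in dnnlib/submission/submit.py: an entry
# point would typically pass result_dir as SubmitConfig.run_dir_root (each run then gets
# a fresh numbered subdirectory under it), while run_dir_ignore is meant to be appended to
# SubmitConfig.run_dir_ignore so the bulky results/datasets/cache trees stay out of the
# source snapshot copied into every run dir.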
12 | 13 | result_dir = 'results' 14 | data_dir = 'datasets' 15 | cache_dir = 'cache' 16 | run_dir_ignore = ['results', 'datasets', 'cache'] 17 | 18 | #---------------------------------------------------------------------------- 19 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import submission 9 | 10 | from .submission.run_context import RunContext 11 | 12 | from .submission.submit import SubmitTarget 13 | from .submission.submit import PathType 14 | from .submission.submit import SubmitConfig 15 | from .submission.submit import get_path_from_template 16 | from .submission.submit import submit_run 17 | 18 | from .util import EasyDict 19 | 20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. 21 | -------------------------------------------------------------------------------- /dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import run_context 9 | from . import submit 10 | -------------------------------------------------------------------------------- /dnnlib/submission/_internal/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for launching run functions in computing clusters. 9 | 10 | During the submit process, this file is copied to the appropriate run dir. 11 | When the job is launched in the cluster, this module is the first thing that 12 | is run inside the docker container. 
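Note that in this release SubmitTarget only defines LOCAL and submit_run() calls run_wrapper() in-process, so this script is effectively a stub that only comes into play when a copied run dir is launched on a remote machine.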
13 | """ 14 | 15 | import os 16 | import pickle 17 | import sys 18 | 19 | # PYTHONPATH should have been set so that the run_dir/src is in it 20 | import dnnlib 21 | 22 | def main(): 23 | if len(sys.argv) < 4: 24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") 25 | 26 | run_dir = str(sys.argv[1]) 27 | task_name = str(sys.argv[2]) 28 | host_name = str(sys.argv[3]) 29 | 30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl") 31 | 32 | # SubmitConfig should have been pickled to the run dir 33 | if not os.path.exists(submit_config_path): 34 | raise RuntimeError("SubmitConfig pickle file does not exist!") 35 | 36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) 37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name) 38 | 39 | submit_config.task_name = task_name 40 | submit_config.host_name = host_name 41 | 42 | dnnlib.submission.submit.run_wrapper(submit_config) 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helpers for managing the run/training loop.""" 9 | 10 | import datetime 11 | import json 12 | import os 13 | import pprint 14 | import time 15 | import types 16 | 17 | from typing import Any 18 | 19 | from . import submit 20 | 21 | 22 | class RunContext(object): 23 | """Helper class for managing the run/training loop. 24 | 25 | The context will hide the implementation details of a basic run/training loop. 26 | It will set things up properly, tell if the run should be stopped, and then clean up. 27 | The user should call update() periodically and use should_stop() to determine whether the run should be stopped. 28 | 29 | Args: 30 | submit_config: The SubmitConfig that is used for the current run. 31 | config_module: The whole config module that is used for the current run. 32 | max_epoch: Optional cached value for the max_epoch variable used in update.
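    A minimal usage sketch (the loop body and names such as num_epochs are
    hypothetical; update(), should_stop() and the context-manager protocol
    below are the actual API of this class):

        with dnnlib.RunContext(submit_config) as ctx:
            for epoch in range(num_epochs):
                train_one_epoch()                                  # placeholder for real work
                ctx.update(cur_epoch=epoch, max_epoch=num_epochs)  # housekeeping + abort check
                if ctx.should_stop():                              # True once abort.txt appears in the run dir
                    break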
33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): 36 | self.submit_config = submit_config 37 | self.should_stop_flag = False 38 | self.has_closed = False 39 | self.start_time = time.time() 40 | self.last_update_time = time.time() 41 | self.last_update_interval = 0.0 42 | self.max_epoch = max_epoch 43 | 44 | # pretty-print all the relevant content of the config module to a text file 45 | if config_module is not None: 46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: 47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} 48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) 49 | 50 | # write out details about the run to a text file 51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 54 | 55 | def __enter__(self) -> "RunContext": 56 | return self 57 | 58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 59 | self.close() 60 | 61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: 62 | """Do general housekeeping and keep the state of the context up-to-date. 63 | Should be called often enough but not in a tight loop.""" 64 | assert not self.has_closed 65 | 66 | self.last_update_interval = time.time() - self.last_update_time 67 | self.last_update_time = time.time() 68 | 69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): 70 | self.should_stop_flag = True 71 | 72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch # resolved here but currently unused 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | 99 | self.has_closed = True 100 | -------------------------------------------------------------------------------- /dnnlib/submission/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License.
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Submit a function to be run either locally or in a computing cluster.""" 9 | 10 | import copy 11 | import io 12 | import os 13 | import pathlib 14 | import pickle 15 | import platform 16 | import pprint 17 | import re 18 | import shutil 19 | import time 20 | import traceback 21 | 22 | import zipfile 23 | 24 | from enum import Enum 25 | 26 | from .. import util 27 | from ..util import EasyDict 28 | 29 | 30 | class SubmitTarget(Enum): 31 | """The target where the function should be run. 32 | 33 | LOCAL: Run it locally. 34 | """ 35 | LOCAL = 1 36 | 37 | 38 | class PathType(Enum): 39 | """Determines in which format should a path be formatted. 40 | 41 | WINDOWS: Format with Windows style. 42 | LINUX: Format with Linux/Posix style. 43 | AUTO: Use current OS type to select either WINDOWS or LINUX. 44 | """ 45 | WINDOWS = 1 46 | LINUX = 2 47 | AUTO = 3 48 | 49 | 50 | _user_name_override = None 51 | 52 | 53 | class SubmitConfig(util.EasyDict): 54 | """Strongly typed config dict needed to submit runs. 55 | 56 | Attributes: 57 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. 58 | run_desc: Description of the run. Will be used in the run dir and task name. 59 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. 60 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir. 61 | submit_target: Submit target enum value. Used to select where the run is actually launched. 62 | num_gpus: Number of GPUs used/requested for the run. 63 | print_info: Whether to print debug information when submitting. 64 | ask_confirmation: Whether to ask a confirmation before submitting. 65 | run_id: Automatically populated value during submit. 66 | run_name: Automatically populated value during submit. 67 | run_dir: Automatically populated value during submit. 68 | run_func_name: Automatically populated value during submit. 69 | run_func_kwargs: Automatically populated value during submit. 70 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. 71 | task_name: Automatically populated value during submit. 72 | host_name: Automatically populated value during submit. 
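    Example (a sketch of how these fields are typically filled in; run_func_name
    usually points at a top-level function such as the training loop shipped in
    training/training_loop.py, and train_kwargs is hypothetical):

        submit_config = dnnlib.SubmitConfig()
        submit_config.run_dir_root = 'results'
        submit_config.run_desc = 'sgan-test'
        dnnlib.submit_run(submit_config, 'training.training_loop.training_loop', **train_kwargs)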
73 | """ 74 | 75 | def __init__(self): 76 | super().__init__() 77 | 78 | # run (set these) 79 | self.run_dir_root = "" # should always be passed through get_path_from_template 80 | self.run_desc = "" 81 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] 82 | self.run_dir_extra_files = None 83 | 84 | # submit (set these) 85 | self.submit_target = SubmitTarget.LOCAL 86 | self.num_gpus = 1 87 | self.print_info = False 88 | self.ask_confirmation = False 89 | 90 | # (automatically populated) 91 | self.run_id = None 92 | self.run_name = None 93 | self.run_dir = None 94 | self.run_func_name = None 95 | self.run_func_kwargs = None 96 | self.user_name = None 97 | self.task_name = None 98 | self.host_name = "localhost" 99 | 100 | 101 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: 102 | """Replace tags in the given path template and return either Windows or Linux formatted path.""" 103 | # automatically select path type depending on running OS 104 | if path_type == PathType.AUTO: 105 | if platform.system() == "Windows": 106 | path_type = PathType.WINDOWS 107 | elif platform.system() == "Linux": 108 | path_type = PathType.LINUX 109 | else: 110 | raise RuntimeError("Unknown platform") 111 | 112 | path_template = path_template.replace("<USERNAME>", get_user_name()) 113 | 114 | # return correctly formatted path 115 | if path_type == PathType.WINDOWS: 116 | return str(pathlib.PureWindowsPath(path_template)) 117 | elif path_type == PathType.LINUX: 118 | return str(pathlib.PurePosixPath(path_template)) 119 | else: 120 | raise RuntimeError("Unknown platform") 121 | 122 | 123 | def get_template_from_path(path: str) -> str: 124 | """Convert a normal path back to its template representation.""" 125 | # replace all path parts with the template tags 126 | path = path.replace("\\", "/") 127 | return path 128 | 129 | 130 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: 131 | """Convert a normal path to a template and then convert it back to a normal path with the given path type.""" 132 | path_template = get_template_from_path(path) 133 | path = get_path_from_template(path_template, path_type) 134 | return path 135 | 136 | 137 | def set_user_name_override(name: str) -> None: 138 | """Set the global username override value.""" 139 | global _user_name_override 140 | _user_name_override = name 141 | 142 | 143 | def get_user_name(): 144 | """Get the current user name.""" 145 | if _user_name_override is not None: 146 | return _user_name_override 147 | elif platform.system() == "Windows": 148 | return os.getlogin() 149 | elif platform.system() == "Linux": 150 | try: 151 | import pwd # pylint: disable=import-error 152 | return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member 153 | except: 154 | return "unknown" 155 | else: 156 | raise RuntimeError("Unknown platform") 157 | 158 | 159 | def _create_run_dir_local(submit_config: SubmitConfig) -> str: 160 | """Create a new run dir with increasing ID number at the start.""" 161 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) 162 | 163 | if not os.path.exists(run_dir_root): 164 | print("Creating the run dir root: {}".format(run_dir_root)) 165 | os.makedirs(run_dir_root) 166 | 167 | submit_config.run_id = _get_next_run_id_local(run_dir_root) 168 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) 169 | run_dir = os.path.join(run_dir_root, submit_config.run_name) 170 | 171 | if
os.path.exists(run_dir): 172 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) 173 | 174 | print("Creating the run dir: {}".format(run_dir)) 175 | os.makedirs(run_dir) 176 | 177 | return run_dir 178 | 179 | 180 | def _get_next_run_id_local(run_dir_root: str) -> int: 181 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.""" 182 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] 183 | r = re.compile("^\\d+") # match one or more digits at the start of the string 184 | run_id = 0 185 | 186 | for dir_name in dir_names: 187 | m = r.match(dir_name) 188 | 189 | if m is not None: 190 | i = int(m.group()) 191 | run_id = max(run_id, i + 1) 192 | 193 | return run_id 194 | 195 | 196 | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: 197 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" 198 | print("Copying files to the run dir") 199 | files = [] 200 | 201 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) 202 | assert '.' in submit_config.run_func_name 203 | for _idx in range(submit_config.run_func_name.count('.') - 1): 204 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) 205 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) 206 | 207 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") 208 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) 209 | 210 | if submit_config.run_dir_extra_files is not None: 211 | files += submit_config.run_dir_extra_files 212 | 213 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] 214 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] 215 | 216 | util.copy_files_and_create_dirs(files) 217 | 218 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) 219 | 220 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: 221 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) 222 | 223 | 224 | def run_wrapper(submit_config: SubmitConfig) -> None: 225 | """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" 226 | is_local = submit_config.submit_target == SubmitTarget.LOCAL 227 | 228 | checker = None 229 | 230 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing 231 | if is_local: 232 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) 233 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) 234 | logger = util.Logger(file_name=None, should_flush=True) 235 | 236 | import dnnlib 237 | dnnlib.submit_config = submit_config 238 | 239 | try: 240 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) 241 | start_time = time.time() 242 | util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) 243 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() 
- start_time))) 244 | except: 245 | if is_local: 246 | raise 247 | else: 248 | traceback.print_exc() 249 | 250 | log_src = os.path.join(submit_config.run_dir, "log.txt") 251 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) 252 | shutil.copyfile(log_src, log_dst) 253 | finally: 254 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() 255 | 256 | dnnlib.submit_config = None 257 | logger.close() 258 | 259 | if checker is not None: 260 | checker.stop() 261 | 262 | 263 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: 264 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.""" 265 | submit_config = copy.copy(submit_config) 266 | 267 | if submit_config.user_name is None: 268 | submit_config.user_name = get_user_name() 269 | 270 | submit_config.run_func_name = run_func_name 271 | submit_config.run_func_kwargs = run_func_kwargs 272 | 273 | assert submit_config.submit_target == SubmitTarget.LOCAL 274 | if submit_config.submit_target in {SubmitTarget.LOCAL}: 275 | run_dir = _create_run_dir_local(submit_config) 276 | 277 | submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) 278 | submit_config.run_dir = run_dir 279 | _populate_run_dir(run_dir, submit_config) 280 | 281 | if submit_config.print_info: 282 | print("\nSubmit config:\n") 283 | pprint.pprint(submit_config, indent=4, width=200, compact=False) 284 | print() 285 | 286 | if submit_config.ask_confirmation: 287 | if not util.ask_yes_no("Continue submitting the job?"): 288 | return 289 | 290 | run_wrapper(submit_config) 291 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import autosummary 9 | from . import network 10 | from . import optimizer 11 | from . import tfutil 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for adding automatically tracked values to Tensorboard. 9 | 10 | Autosummary creates an identity op that internally keeps track of the input 11 | values and automatically shows up in TensorBoard. The reported value 12 | represents an average over input components. The average is accumulated 13 | constantly over time and flushed when save_summaries() is called. 
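Internally each autosummary accumulates the moments [count, sum(x), sum(x**2)] in a float64 variable; save_summaries() then reports the mean and a one-standard-deviation band derived from them.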
14 | 15 | Notes: 16 | - The output tensor must be used as an input for something else in the 17 | graph. Otherwise, the autosummary op will not get executed, and the average 18 | value will not get accumulated. 19 | - It is perfectly fine to include autosummaries with the same name in 20 | several places throughout the graph, even if they are executed concurrently. 21 | - It is ok to also pass in a python scalar or numpy array. In this case, it 22 | is added to the average immediately. 23 | """ 24 | 25 | from collections import OrderedDict 26 | import numpy as np 27 | import tensorflow as tf 28 | from tensorboard import summary as summary_lib 29 | from tensorboard.plugins.custom_scalar import layout_pb2 30 | 31 | from . import tfutil 32 | from .tfutil import TfExpression 33 | from .tfutil import TfExpressionEx 34 | 35 | _dtype = tf.float64 36 | _vars = OrderedDict() # name => [var, ...] 37 | _immediate = OrderedDict() # name => update_op, update_value 38 | _finalized = False 39 | _merge_op = None 40 | 41 | 42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 43 | """Internal helper for creating autosummary accumulators.""" 44 | assert not _finalized 45 | name_id = name.replace("/", "_") 46 | v = tf.cast(value_expr, _dtype) 47 | 48 | if v.shape.is_fully_defined(): 49 | size = np.prod(tfutil.shape_to_list(v.shape)) 50 | size_expr = tf.constant(size, dtype=_dtype) 51 | else: 52 | size = None 53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 54 | 55 | if size == 1: 56 | if v.shape.ndims != 0: 57 | v = tf.reshape(v, []) 58 | v = [size_expr, v, tf.square(v)] 59 | else: 60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 62 | 63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 66 | 67 | if name in _vars: 68 | _vars[name].append(var) 69 | else: 70 | _vars[name] = [var] 71 | return update_op 72 | 73 | 74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: 75 | """Create a new autosummary. 76 | 77 | Args: 78 | name: Name to use in TensorBoard 79 | value: TensorFlow expression or python value to track 80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 
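Returns:
    The value (or passthru) tensor with the accumulator update attached as a control dependency. For python/numpy inputs the update is applied immediately and the input is returned unchanged.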
81 | 82 | Example use of the passthru mechanism: 83 | 84 | n = autosummary('l2loss', loss, passthru=n) 85 | 86 | This is a shorthand for the following code: 87 | 88 | with tf.control_dependencies([autosummary('l2loss', loss)]): 89 | n = tf.identity(n) 90 | """ 91 | tfutil.assert_tf_initialized() 92 | name_id = name.replace("/", "_") 93 | 94 | if tfutil.is_tf_expression(value): 95 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 96 | update_op = _create_var(name, value) 97 | with tf.control_dependencies([update_op]): 98 | return tf.identity(value if passthru is None else passthru) 99 | 100 | else: # python scalar or numpy array 101 | if name not in _immediate: 102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 103 | update_value = tf.placeholder(_dtype) 104 | update_op = _create_var(name, update_value) 105 | _immediate[name] = update_op, update_value 106 | 107 | update_op, update_value = _immediate[name] 108 | tfutil.run(update_op, {update_value: value}) 109 | return value if passthru is None else passthru 110 | 111 | 112 | def finalize_autosummaries() -> None: 113 | """Create the necessary ops to include autosummaries in TensorBoard report. 114 | Note: This should be done only once per graph. 115 | """ 116 | global _finalized 117 | tfutil.assert_tf_initialized() 118 | 119 | if _finalized: 120 | return None 121 | 122 | _finalized = True 123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 124 | 125 | # Create summary ops. 126 | with tf.device(None), tf.control_dependencies(None): 127 | for name, vars_list in _vars.items(): 128 | name_id = name.replace("/", "_") 129 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 130 | moments = tf.add_n(vars_list) 131 | moments /= moments[0] 132 | with tf.control_dependencies([moments]): # read before resetting 133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 135 | mean = moments[1] 136 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 137 | tf.summary.scalar(name, mean) 138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 140 | 141 | # Group by category and chart name. 142 | cat_dict = OrderedDict() 143 | for series_name in sorted(_vars.keys()): 144 | p = series_name.split("/") 145 | cat = p[0] if len(p) >= 2 else "" 146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 147 | if cat not in cat_dict: 148 | cat_dict[cat] = OrderedDict() 149 | if chart not in cat_dict[cat]: 150 | cat_dict[cat][chart] = [] 151 | cat_dict[cat][chart].append(series_name) 152 | 153 | # Setup custom_scalar layout. 
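# A series name is split on "/": the first component becomes the category and the
# middle components the chart, e.g. an illustrative name "Loss/G/reg" lands in
# category "Loss", chart "G". Each chart is then rendered in TensorBoard's Custom
# Scalars tab as a mean line with the margin_lo/margin_hi series forming a +/-1 std band.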
154 | categories = [] 155 | for cat_name, chart_dict in cat_dict.items(): 156 | charts = [] 157 | for chart_name, series_names in chart_dict.items(): 158 | series = [] 159 | for series_name in series_names: 160 | series.append(layout_pb2.MarginChartContent.Series( 161 | value=series_name, 162 | lower="xCustomScalars/" + series_name + "/margin_lo", 163 | upper="xCustomScalars/" + series_name + "/margin_hi")) 164 | margin = layout_pb2.MarginChartContent(series=series) 165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 168 | return layout 169 | 170 | def save_summaries(file_writer, global_step=None): 171 | """Call FileWriter.add_summary() with all summaries in the default graph, 172 | automatically finalizing and merging them on the first call. 173 | """ 174 | global _merge_op 175 | tfutil.assert_tf_initialized() 176 | 177 | if _merge_op is None: 178 | layout = finalize_autosummaries() 179 | if layout is not None: 180 | file_writer.add_summary(layout) 181 | with tf.device(None), tf.control_dependencies(None): 182 | _merge_op = tf.summary.merge_all() 183 | 184 | file_writer.add_summary(_merge_op.eval(), global_step) 185 | -------------------------------------------------------------------------------- /dnnlib/tflib/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper wrapper for a Tensorflow optimizer.""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from collections import OrderedDict 14 | from typing import List, Union 15 | 16 | from . import autosummary 17 | from . import tfutil 18 | from .. import util 19 | 20 | from .tfutil import TfExpression, TfExpressionEx 21 | 22 | try: 23 | # TensorFlow 1.13 24 | from tensorflow.python.ops import nccl_ops 25 | except: 26 | # Older TensorFlow versions 27 | import tensorflow.contrib.nccl as nccl_ops 28 | 29 | class Optimizer: 30 | """A Wrapper for tf.train.Optimizer. 31 | 32 | Automatically takes care of: 33 | - Gradient averaging for multi-GPU training. 34 | - Dynamic loss scaling and typecasts for FP16 training. 35 | - Ignoring corrupted gradients that contain NaNs/Infs. 36 | - Reporting statistics. 37 | - Well-chosen default settings. 38 | """ 39 | 40 | def __init__(self, 41 | name: str = "Train", 42 | tf_optimizer: str = "tf.train.AdamOptimizer", 43 | learning_rate: TfExpressionEx = 0.001, 44 | use_loss_scaling: bool = False, 45 | loss_scaling_init: float = 64.0, 46 | loss_scaling_inc: float = 0.0005, 47 | loss_scaling_dec: float = 1.0, 48 | **kwargs): 49 | 50 | # Init fields. 
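# Loss-scaling bookkeeping: the per-device scaling factor is stored as log2 (see
# get_loss_scaling_var), grown by loss_scaling_inc after every overflow-free step
# and shrunk by loss_scaling_dec on NaN/Inf, with apply_loss_scaling() /
# undo_loss_scaling() multiplying by 2**(+/-var) accordingly.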
51 | self.name = name 52 | self.learning_rate = tf.convert_to_tensor(learning_rate) 53 | self.id = self.name.replace("/", ".") 54 | self.scope = tf.get_default_graph().unique_name(self.id) 55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer) 56 | self.optimizer_kwargs = dict(kwargs) 57 | self.use_loss_scaling = use_loss_scaling 58 | self.loss_scaling_init = loss_scaling_init 59 | self.loss_scaling_inc = loss_scaling_inc 60 | self.loss_scaling_dec = loss_scaling_dec 61 | self._grad_shapes = None # [shape, ...] 62 | self._dev_opt = OrderedDict() # device => optimizer 63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] 64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) 65 | self._updates_applied = False 66 | 67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: 68 | """Register the gradients of the given loss function with respect to the given variables. 69 | Intended to be called once per GPU.""" 70 | assert not self._updates_applied 71 | 72 | # Validate arguments. 73 | if isinstance(trainable_vars, dict): 74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars 75 | 76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) 78 | 79 | if self._grad_shapes is None: 80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] 81 | 82 | assert len(trainable_vars) == len(self._grad_shapes) 83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) 84 | 85 | dev = loss.device 86 | 87 | assert all(var.device == dev for var in trainable_vars) 88 | 89 | # Register device and compute gradients. 90 | with tf.name_scope(self.id + "_grad"), tf.device(dev): 91 | if dev not in self._dev_opt: 92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) 93 | assert callable(self.optimizer_class) 94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) 95 | self._dev_grads[dev] = [] 96 | 97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) 98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage 99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros 100 | self._dev_grads[dev].append(grads) 101 | 102 | def apply_updates(self) -> tf.Operation: 103 | """Construct training op to update the registered variables based on their gradients.""" 104 | tfutil.assert_tf_initialized() 105 | assert not self._updates_applied 106 | self._updates_applied = True 107 | devices = list(self._dev_grads.keys()) 108 | total_grads = sum(len(grads) for grads in self._dev_grads.values()) 109 | assert len(devices) >= 1 and total_grads >= 1 110 | ops = [] 111 | 112 | with tfutil.absolute_name_scope(self.scope): 113 | # Cast gradients to FP32 and calculate partial sum within each device. 114 | dev_grads = OrderedDict() # device => [(grad, var), ...] 
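# The construction below proceeds in three stages: (1) per device, cast every
# registered gradient to float32 and sum duplicates of the same variable;
# (2) all-reduce the per-device sums across GPUs with NCCL; (3) per device,
# rescale (1/total_grads plus undoing loss scaling), check for NaN/Inf, and
# apply the update only when all gradients are finite.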
115 | 116 | for dev_idx, dev in enumerate(devices): 117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): 118 | sums = [] 119 | 120 | for gv in zip(*self._dev_grads[dev]): 121 | assert all(v is gv[0][1] for g, v in gv) 122 | g = [tf.cast(g, tf.float32) for g, v in gv] 123 | g = g[0] if len(g) == 1 else tf.add_n(g) 124 | sums.append((g, gv[0][1])) 125 | 126 | dev_grads[dev] = sums 127 | 128 | # Sum gradients across devices. 129 | if len(devices) > 1: 130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None): 131 | for var_idx, grad_shape in enumerate(self._grad_shapes): 132 | g = [dev_grads[dev][var_idx][0] for dev in devices] 133 | 134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors 135 | g = nccl_ops.all_sum(g) 136 | 137 | for dev, gg in zip(devices, g): 138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) 139 | 140 | # Apply updates separately on each device. 141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()): 142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): 143 | # Scale gradients as needed. 144 | if self.use_loss_scaling or total_grads > 1: 145 | with tf.name_scope("Scale"): 146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef") 147 | coef = self.undo_loss_scaling(coef) 148 | grads = [(g * coef, v) for g, v in grads] 149 | 150 | # Check for overflows. 151 | with tf.name_scope("CheckOverflow"): 152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) 153 | 154 | # Update weights and adjust loss scaling. 155 | with tf.name_scope("UpdateWeights"): 156 | # pylint: disable=cell-var-from-loop 157 | opt = self._dev_opt[dev] 158 | ls_var = self.get_loss_scaling_var(dev) 159 | 160 | if not self.use_loss_scaling: 161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) 162 | else: 163 | ops.append(tf.cond(grad_ok, 164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), 165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) 166 | 167 | # Report statistics on the last device. 168 | if dev == devices[-1]: 169 | with tf.name_scope("Statistics"): 170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) 171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) 172 | 173 | if self.use_loss_scaling: 174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) 175 | 176 | # Initialize variables and group everything into a single op. 
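# reset_optimizer_state() doubles as initialization: it runs the initializers of the
# internal variables (e.g. Adam moment accumulators) that the apply_gradients() calls
# above created inside each per-device optimizer; the loss-scaling vars follow right after.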
177 | self.reset_optimizer_state() 178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) 179 | 180 | return tf.group(*ops, name="TrainingOp") 181 | 182 | def reset_optimizer_state(self) -> None: 183 | """Reset internal state of the underlying optimizer.""" 184 | tfutil.assert_tf_initialized() 185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) 186 | 187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: 188 | """Get or create variable representing log2 of the current dynamic loss scaling factor.""" 189 | if not self.use_loss_scaling: 190 | return None 191 | 192 | if device not in self._dev_ls_var: 193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): 194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") 195 | 196 | return self._dev_ls_var[device] 197 | 198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression: 199 | """Apply dynamic loss scaling for the given expression.""" 200 | assert tfutil.is_tf_expression(value) 201 | 202 | if not self.use_loss_scaling: 203 | return value 204 | 205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) 206 | 207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression: 208 | """Undo the effect of dynamic loss scaling for the given expression.""" 209 | assert tfutil.is_tf_expression(value) 210 | 211 | if not self.use_loss_scaling: 212 | return value 213 | 214 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type 215 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous helper utils for Tensorflow.""" 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from typing import Any, Iterable, List, Union 15 | 16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 17 | """A type that represents a valid Tensorflow expression.""" 18 | 19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 20 | """A type that can be converted to a valid Tensorflow expression.""" 21 | 22 | 23 | def run(*args, **kwargs) -> Any: 24 | """Run the specified ops in the default session.""" 25 | assert_tf_initialized() 26 | return tf.get_default_session().run(*args, **kwargs) 27 | 28 | 29 | def is_tf_expression(x: Any) -> bool: 30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 32 | 33 | 34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 35 | """Convert a Tensorflow shape to a list of ints.""" 36 | return [dim.value for dim in shape] 37 | 38 | 39 | def flatten(x: TfExpressionEx) -> TfExpression: 40 | """Shortcut function for flattening a tensor.""" 41 | with tf.name_scope("Flatten"): 42 | return tf.reshape(x, [-1]) 43 | 44 | 45 | def log2(x: TfExpressionEx) -> TfExpression: 46 | """Logarithm in base 2.""" 47 | with tf.name_scope("Log2"): 48 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 49 | 50 | 51 | def exp2(x: TfExpressionEx) -> TfExpression: 52 | """Exponent in base 2.""" 53 | with tf.name_scope("Exp2"): 54 | return tf.exp(x * np.float32(np.log(2.0))) 55 | 56 | 57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 58 | """Linear interpolation.""" 59 | with tf.name_scope("Lerp"): 60 | return a + (b - a) * t 61 | 62 | 63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 64 | """Linear interpolation with clip.""" 65 | with tf.name_scope("LerpClip"): 66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 67 | 68 | 69 | def absolute_name_scope(scope: str) -> tf.name_scope: 70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 71 | return tf.name_scope(scope + "/") 72 | 73 | 74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 77 | 78 | 79 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 80 | # Defaults. 81 | cfg = dict() 82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 87 | 88 | # User overrides. 
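# Example override (hypothetical call): init_tf({"gpu_options.allow_growth": False,
# "rnd.np_random_seed": 1000}). The "rnd.*" and "env.*" keys are consumed by init_tf()
# itself; every other dotted key is written into the matching tf.ConfigProto field
# by create_session().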
89 | if config_dict is not None: 90 | cfg.update(config_dict) 91 | return cfg 92 | 93 | 94 | def init_tf(config_dict: dict = None) -> None: 95 | """Initialize TensorFlow session using good default settings.""" 96 | # Skip if already initialized. 97 | if tf.get_default_session() is not None: 98 | return 99 | 100 | # Setup config dict and random seeds. 101 | cfg = _sanitize_tf_config(config_dict) 102 | np_random_seed = cfg["rnd.np_random_seed"] 103 | if np_random_seed is not None: 104 | np.random.seed(np_random_seed) 105 | tf_random_seed = cfg["rnd.tf_random_seed"] 106 | if tf_random_seed == "auto": 107 | tf_random_seed = np.random.randint(1 << 31) 108 | if tf_random_seed is not None: 109 | tf.set_random_seed(tf_random_seed) 110 | 111 | # Setup environment variables. 112 | for key, value in list(cfg.items()): 113 | fields = key.split(".") 114 | if fields[0] == "env": 115 | assert len(fields) == 2 116 | os.environ[fields[1]] = str(value) 117 | 118 | # Create default TensorFlow session. 119 | create_session(cfg, force_as_default=True) 120 | 121 | 122 | def assert_tf_initialized(): 123 | """Check that TensorFlow session has been initialized.""" 124 | if tf.get_default_session() is None: 125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 126 | 127 | 128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 129 | """Create tf.Session based on config dict.""" 130 | # Setup TensorFlow config proto. 131 | cfg = _sanitize_tf_config(config_dict) 132 | config_proto = tf.ConfigProto() 133 | for key, value in cfg.items(): 134 | fields = key.split(".") 135 | if fields[0] not in ["rnd", "env"]: 136 | obj = config_proto 137 | for field in fields[:-1]: 138 | obj = getattr(obj, field) 139 | setattr(obj, fields[-1], value) 140 | 141 | # Create session. 142 | session = tf.Session(config=config_proto) 143 | if force_as_default: 144 | # pylint: disable=protected-access 145 | session._default_session = session.as_default() 146 | session._default_session.enforce_nesting = False 147 | session._default_session.__enter__() # pylint: disable=no-member 148 | 149 | return session 150 | 151 | 152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 153 | """Initialize all tf.Variables that have not already been initialized. 154 | 155 | Equivalent to the following, but more efficient and does not bloat the tf graph: 156 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 157 | """ 158 | assert_tf_initialized() 159 | if target_vars is None: 160 | target_vars = tf.global_variables() 161 | 162 | test_vars = [] 163 | test_ops = [] 164 | 165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 166 | for var in target_vars: 167 | assert is_tf_expression(var) 168 | 169 | try: 170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 171 | except KeyError: 172 | # Op does not exist => variable may be uninitialized. 173 | test_vars.append(var) 174 | 175 | with absolute_name_scope(var.name.split(":")[0]): 176 | test_ops.append(tf.is_variable_initialized(var)) 177 | 178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 179 | run([var.initializer for var in init_vars]) 180 | 181 | 182 | def set_vars(var_to_value_dict: dict) -> None: 183 | """Set the values of given tf.Variables. 
184 | 185 | Equivalent to the following, but more efficient and does not bloat the tf graph: 186 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 187 | """ 188 | assert_tf_initialized() 189 | ops = [] 190 | feed_dict = {} 191 | 192 | for var, value in var_to_value_dict.items(): 193 | assert is_tf_expression(var) 194 | 195 | try: 196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 197 | except KeyError: 198 | with absolute_name_scope(var.name.split(":")[0]): 199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 201 | 202 | ops.append(setter) 203 | feed_dict[setter.op.inputs[1]] = value 204 | 205 | run(ops, feed_dict) 206 | 207 | 208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 209 | """Create tf.Variable with large initial value without bloating the tf graph.""" 210 | assert_tf_initialized() 211 | assert isinstance(initial_value, np.ndarray) 212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 213 | var = tf.Variable(zeros, *args, **kwargs) 214 | set_vars({var: initial_value}) 215 | return var 216 | 217 | 218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 220 | Can be used as an input transformation for Network.run(). 221 | """ 222 | images = tf.cast(images, tf.float32) 223 | if nhwc_to_nchw: 224 | images = tf.transpose(images, [0, 3, 1, 2]) 225 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255) 226 | 227 | 228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1, uint8_cast=True): 229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 230 | Can be used as an output transformation for Network.run(). 231 | """ 232 | images = tf.cast(images, tf.float32) 233 | if shrink > 1: 234 | ksize = [1, 1, shrink, shrink] 235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 236 | if nchw_to_nhwc: 237 | images = tf.transpose(images, [0, 2, 3, 1]) 238 | scale = 255 / (drange[1] - drange[0]) 239 | images = images * scale + (0.5 - drange[0] * scale) 240 | if uint8_cast: 241 | images = tf.saturate_cast(images, tf.uint8) 242 | return images 243 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import uuid 27 | 28 | from distutils.util import strtobool 29 | from typing import Any, List, Tuple, Union 30 | 31 | 32 | # Util classes 33 | # ------------------------------------------------------------------------------------------ 34 | 35 | 36 | class EasyDict(dict): 37 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 38 | 39 | def __getattr__(self, name: str) -> Any: 40 | try: 41 | return self[name] 42 | except KeyError: 43 | raise AttributeError(name) 44 | 45 | def __setattr__(self, name: str, value: Any) -> None: 46 | self[name] = value 47 | 48 | def __delattr__(self, name: str) -> None: 49 | del self[name] 50 | 51 | 52 | class Logger(object): 53 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 54 | 55 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 56 | self.file = None 57 | 58 | if file_name is not None: 59 | self.file = open(file_name, file_mode) 60 | 61 | self.should_flush = should_flush 62 | self.stdout = sys.stdout 63 | self.stderr = sys.stderr 64 | 65 | sys.stdout = self 66 | sys.stderr = self 67 | 68 | def __enter__(self) -> "Logger": 69 | return self 70 | 71 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 72 | self.close() 73 | 74 | def write(self, text: str) -> None: 75 | """Write text to stdout (and a file) and optionally flush.""" 76 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 77 | return 78 | 79 | if self.file is not None: 80 | self.file.write(text) 81 | 82 | self.stdout.write(text) 83 | 84 | if self.should_flush: 85 | self.flush() 86 | 87 | def flush(self) -> None: 88 | """Flush written text to both stdout and a file, if open.""" 89 | if self.file is not None: 90 | self.file.flush() 91 | 92 | self.stdout.flush() 93 | 94 | def close(self) -> None: 95 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 96 | self.flush() 97 | 98 | # if using multiple loggers, prevent closing in wrong order 99 | if sys.stdout is self: 100 | sys.stdout = self.stdout 101 | if sys.stderr is self: 102 | sys.stderr = self.stderr 103 | 104 | if self.file is not None: 105 | self.file.close() 106 | 107 | 108 | # Small util functions 109 | # ------------------------------------------------------------------------------------------ 110 | 111 | 112 | def format_time(seconds: Union[int, float]) -> str: 113 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 114 | s = int(np.rint(seconds)) 115 | 116 | if s < 60: 117 | return "{0}s".format(s) 118 | elif s < 60 * 60: 119 | return "{0}m {1:02}s".format(s // 60, s % 60) 120 | elif s < 24 * 60 * 60: 121 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 122 | else: 123 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 124 | 125 | 126 | def ask_yes_no(question: str) -> bool: 127 | """Ask the user the question until the user inputs a valid answer.""" 128 | while True: 129 | try: 130 | print("{0} 
[y/n]".format(question)) 131 | return strtobool(input().lower()) 132 | except ValueError: 133 | pass 134 | 135 | 136 | def tuple_product(t: Tuple) -> Any: 137 | """Calculate the product of the tuple elements.""" 138 | result = 1 139 | 140 | for v in t: 141 | result *= v 142 | 143 | return result 144 | 145 | 146 | _str_to_ctype = { 147 | "uint8": ctypes.c_ubyte, 148 | "uint16": ctypes.c_uint16, 149 | "uint32": ctypes.c_uint32, 150 | "uint64": ctypes.c_uint64, 151 | "int8": ctypes.c_byte, 152 | "int16": ctypes.c_int16, 153 | "int32": ctypes.c_int32, 154 | "int64": ctypes.c_int64, 155 | "float32": ctypes.c_float, 156 | "float64": ctypes.c_double 157 | } 158 | 159 | 160 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 161 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 162 | type_str = None 163 | 164 | if isinstance(type_obj, str): 165 | type_str = type_obj 166 | elif hasattr(type_obj, "__name__"): 167 | type_str = type_obj.__name__ 168 | elif hasattr(type_obj, "name"): 169 | type_str = type_obj.name 170 | else: 171 | raise RuntimeError("Cannot infer type name from input") 172 | 173 | assert type_str in _str_to_ctype.keys() 174 | 175 | my_dtype = np.dtype(type_str) 176 | my_ctype = _str_to_ctype[type_str] 177 | 178 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 179 | 180 | return my_dtype, my_ctype 181 | 182 | 183 | def is_pickleable(obj: Any) -> bool: 184 | try: 185 | with io.BytesIO() as stream: 186 | pickle.dump(obj, stream) 187 | return True 188 | except: 189 | return False 190 | 191 | 192 | # Functionality to import modules/objects by name, and call functions by name 193 | # ------------------------------------------------------------------------------------------ 194 | 195 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 196 | """Searches for the underlying module behind the name to some python object. 197 | Returns the module and the object name (original name with module part removed).""" 198 | 199 | # allow convenience shorthands, substitute them by full names 200 | obj_name = re.sub("^np.", "numpy.", obj_name) 201 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 202 | 203 | # list alternatives for (module_name, local_obj_name) 204 | parts = obj_name.split(".") 205 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 206 | 207 | # try each alternative in turn 208 | for module_name, local_obj_name in name_pairs: 209 | try: 210 | module = importlib.import_module(module_name) # may raise ImportError 211 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 212 | return module, local_obj_name 213 | except: 214 | pass 215 | 216 | # maybe some of the modules themselves contain errors? 217 | for module_name, _local_obj_name in name_pairs: 218 | try: 219 | importlib.import_module(module_name) # may raise ImportError 220 | except ImportError: 221 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 222 | raise 223 | 224 | # maybe the requested attribute is missing? 
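# (if the import itself succeeds here, get_obj_from_module() raises the underlying
# AttributeError, which is not caught so the caller sees the real cause; only
# ImportErrors are swallowed)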
225 | for module_name, local_obj_name in name_pairs: 226 | try: 227 | module = importlib.import_module(module_name) # may raise ImportError 228 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 229 | except ImportError: 230 | pass 231 | 232 | # we are out of luck, but we have no idea why 233 | raise ImportError(obj_name) 234 | 235 | 236 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 237 | """Traverses the object name and returns the last (rightmost) python object.""" 238 | if obj_name == '': 239 | return module 240 | obj = module 241 | for part in obj_name.split("."): 242 | obj = getattr(obj, part) 243 | return obj 244 | 245 | 246 | def get_obj_by_name(name: str) -> Any: 247 | """Finds the python object with the given name.""" 248 | module, obj_name = get_module_from_obj_name(name) 249 | return get_obj_from_module(module, obj_name) 250 | 251 | 252 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 253 | """Finds the python object with the given name and calls it as a function.""" 254 | assert func_name is not None 255 | func_obj = get_obj_by_name(func_name) 256 | assert callable(func_obj) 257 | return func_obj(*args, **kwargs) 258 | 259 | 260 | def get_module_dir_by_obj_name(obj_name: str) -> str: 261 | """Get the directory path of the module containing the given object name.""" 262 | module, _ = get_module_from_obj_name(obj_name) 263 | return os.path.dirname(inspect.getfile(module)) 264 | 265 | 266 | def is_top_level_function(obj: Any) -> bool: 267 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 268 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 269 | 270 | 271 | def get_top_level_function_name(obj: Any) -> str: 272 | """Return the fully-qualified name of a top-level function.""" 273 | assert is_top_level_function(obj) 274 | return obj.__module__ + "." + obj.__name__ 275 | 276 | 277 | # File system helpers 278 | # ------------------------------------------------------------------------------------------ 279 | 280 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 281 | """List all files recursively in a given directory while ignoring given file and directory names. 282 | Returns list of tuples containing both absolute and relative paths.""" 283 | assert os.path.isdir(dir_path) 284 | base_name = os.path.basename(os.path.normpath(dir_path)) 285 | 286 | if ignores is None: 287 | ignores = [] 288 | 289 | result = [] 290 | 291 | for root, dirs, files in os.walk(dir_path, topdown=True): 292 | for ignore_ in ignores: 293 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 294 | 295 | # dirs need to be edited in-place 296 | for d in dirs_to_remove: 297 | dirs.remove(d) 298 | 299 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 300 | 301 | absolute_paths = [os.path.join(root, f) for f in files] 302 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 303 | 304 | if add_base_to_relative: 305 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 306 | 307 | assert len(absolute_paths) == len(relative_paths) 308 | result += zip(absolute_paths, relative_paths) 309 | 310 | return result 311 | 312 | 313 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 314 | """Takes in a list of tuples of (src, dst) paths and copies files. 
315 | Will create all necessary directories.""" 316 | for file in files: 317 | target_dir_name = os.path.dirname(file[1]) 318 | 319 | # will create all intermediate-level directories 320 | if not os.path.exists(target_dir_name): 321 | os.makedirs(target_dir_name) 322 | 323 | shutil.copyfile(file[0], file[1]) 324 | 325 | 326 | # URL helpers 327 | # ------------------------------------------------------------------------------------------ 328 | 329 | def is_url(obj: Any) -> bool: 330 | """Determine whether the given object is a valid URL string.""" 331 | if not isinstance(obj, str) or not "://" in obj: 332 | return False 333 | try: 334 | res = requests.compat.urlparse(obj) 335 | if not res.scheme or not res.netloc or not "." in res.netloc: 336 | return False 337 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 338 | if not res.scheme or not res.netloc or not "." in res.netloc: 339 | return False 340 | except: 341 | return False 342 | return True 343 | 344 | 345 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any: 346 | """Download the given URL and return a binary-mode file object to access the data.""" 347 | assert is_url(url) 348 | assert num_attempts >= 1 349 | 350 | # Lookup from cache. 351 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 352 | if cache_dir is not None: 353 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 354 | if len(cache_files) == 1: 355 | return open(cache_files[0], "rb") 356 | 357 | # Download. 358 | url_name = None 359 | url_data = None 360 | with requests.Session() as session: 361 | if verbose: 362 | print("Downloading %s ..." % url, end="", flush=True) 363 | for attempts_left in reversed(range(num_attempts)): 364 | try: 365 | with session.get(url) as res: 366 | res.raise_for_status() 367 | if len(res.content) == 0: 368 | raise IOError("No data received") 369 | 370 | if len(res.content) < 8192: 371 | content_str = res.content.decode("utf-8") 372 | if "download_warning" in res.headers.get("Set-Cookie", ""): 373 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 374 | if len(links) == 1: 375 | url = requests.compat.urljoin(url, links[0]) 376 | raise IOError("Google Drive virus checker nag") 377 | if "Google Drive - Quota exceeded" in content_str: 378 | raise IOError("Google Drive quota exceeded") 379 | 380 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 381 | url_name = match[1] if match else url 382 | url_data = res.content 383 | if verbose: 384 | print(" done") 385 | break 386 | except: 387 | if not attempts_left: 388 | if verbose: 389 | print(" failed") 390 | raise 391 | if verbose: 392 | print(".", end="", flush=True) 393 | 394 | # Save to cache. 395 | if cache_dir is not None: 396 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 397 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 398 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 399 | os.makedirs(cache_dir, exist_ok=True) 400 | with open(temp_file, "wb") as f: 401 | f.write(url_data) 402 | os.replace(temp_file, cache_file) # atomic 403 | 404 | # Return data as file object. 
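# (the BytesIO returned below supports the usual file API, so call sites can write
#   with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
#       data = pickle.load(f)
# which is exactly how encode_images.py and pretrained_example.py consume it)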
405 | return io.BytesIO(url_data) 406 | -------------------------------------------------------------------------------- /encode_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | from tqdm import tqdm 5 | import PIL.Image 6 | import numpy as np 7 | import dnnlib 8 | import dnnlib.tflib as tflib 9 | import config 10 | from encoder.generator_model import Generator 11 | from encoder.perceptual_model import PerceptualModel 12 | 13 | URL_FFHQ = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl 14 | 15 | 16 | def split_to_batches(l, n): 17 | for i in range(0, len(l), n): 18 | yield l[i:i + n] 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual loss') 23 | parser.add_argument('src_dir', help='Directory with images for encoding') 24 | parser.add_argument('generated_images_dir', help='Directory for storing generated images') 25 | parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations') 26 | 27 | # for now it's unclear whether a larger batch leads to better performance/quality 28 | parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int) 29 | 30 | # Perceptual model params 31 | parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int) 32 | parser.add_argument('--lr', default=1., help='Learning rate for perceptual model', type=float) 33 | parser.add_argument('--iterations', default=1000, help='Number of optimization steps for each batch', type=int) 34 | 35 | # Generator params 36 | parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=lambda x: str(x).lower() in ('true', '1', 'yes')) # type=bool would treat any non-empty string, even 'False', as True 37 | args, other_args = parser.parse_known_args() 38 | 39 | ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)] 40 | ref_images = list(filter(os.path.isfile, ref_images)) 41 | 42 | if len(ref_images) == 0: 43 | raise Exception('%s is empty' % args.src_dir) 44 | 45 | os.makedirs(args.generated_images_dir, exist_ok=True) 46 | os.makedirs(args.dlatent_dir, exist_ok=True) 47 | 48 | # Initialize generator and perceptual model 49 | tflib.init_tf() 50 | with dnnlib.util.open_url(URL_FFHQ, cache_dir=config.cache_dir) as f: 51 | generator_network, discriminator_network, Gs_network = pickle.load(f) 52 | 53 | generator = Generator(Gs_network, args.batch_size, randomize_noise=args.randomize_noise) 54 | perceptual_model = PerceptualModel(args.image_size, layer=9, batch_size=args.batch_size) 55 | perceptual_model.build_perceptual_model(generator.generated_image) 56 | 57 | # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space 58 | for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=len(ref_images)//args.batch_size): 59 | names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch] 60 | 61 | perceptual_model.set_reference_images(images_batch) 62 | op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations, learning_rate=args.lr) 63 | pbar = tqdm(op, leave=False, total=args.iterations) 64 | for loss in pbar: 65 | pbar.set_description(' '.join(names)+' Loss: %.2f' % loss) 66 | print(' '.join(names), ' loss:', loss) 67 | 68 | # Generate images from found dlatents and save them 69 | generated_images = 
generator.generate_images() 70 | generated_dlatents = generator.get_dlatents() 71 | for img_array, dlatent, img_name in zip(generated_images, generated_dlatents, names): 72 | img = PIL.Image.fromarray(img_array, 'RGB') 73 | img.save(os.path.join(args.generated_images_dir, f'{img_name}.png'), 'PNG') 74 | np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent) 75 | 76 | generator.reset_dlatents() 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/encoder/__init__.py -------------------------------------------------------------------------------- /encoder/generator_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import dnnlib.tflib as tflib 4 | from functools import partial 5 | 6 | 7 | def create_stub(name, batch_size): 8 | return tf.constant(0, dtype='float32', shape=(batch_size, 0)) 9 | 10 | 11 | def create_variable_for_generator(name, batch_size): 12 | return tf.get_variable('learnable_dlatents', 13 | shape=(batch_size, 18, 512), 14 | dtype='float32', 15 | initializer=tf.initializers.random_normal()) 16 | 17 | 18 | class Generator: 19 | def __init__(self, model, batch_size, randomize_noise=False): 20 | self.batch_size = batch_size 21 | 22 | self.initial_dlatents = np.zeros((self.batch_size, 18, 512)) 23 | model.components.synthesis.run(self.initial_dlatents, 24 | randomize_noise=randomize_noise, minibatch_size=self.batch_size, 25 | custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size), 26 | partial(create_stub, batch_size=batch_size)], 27 | structure='fixed') 28 | 29 | self.sess = tf.get_default_session() 30 | self.graph = tf.get_default_graph() 31 | 32 | self.dlatent_variable = next(v for v in tf.global_variables() if 'learnable_dlatents' in v.name) 33 | self.set_dlatents(self.initial_dlatents) 34 | 35 | self.generator_output = self.graph.get_tensor_by_name('G_synthesis_1/_Run/concat:0') 36 | self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False) 37 | self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8) 38 | 39 | def reset_dlatents(self): 40 | self.set_dlatents(self.initial_dlatents) 41 | 42 | def set_dlatents(self, dlatents): 43 | assert (dlatents.shape == (self.batch_size, 18, 512)) 44 | self.sess.run(tf.assign(self.dlatent_variable, dlatents)) 45 | 46 | def get_dlatents(self): 47 | return self.sess.run(self.dlatent_variable) 48 | 49 | def generate_images(self, dlatents=None): 50 | if dlatents is not None: # a numpy array has no single truth value, so test against None explicitly 51 | self.set_dlatents(dlatents) 52 | return self.sess.run(self.generated_image_uint8) 53 | -------------------------------------------------------------------------------- /encoder/perceptual_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from keras.models import Model 4 | from keras.applications.vgg16 import VGG16, preprocess_input 5 | from keras.preprocessing import image 6 | import keras.backend as K 7 | 8 | 9 | def load_images(images_list, img_size): 10 | loaded_images = list() 11 | for img_path in images_list: 12 | img = image.load_img(img_path, target_size=(img_size, img_size)) 13 | img = 
np.expand_dims(img, 0) 14 | loaded_images.append(img) 15 | loaded_images = np.vstack(loaded_images) 16 | preprocessed_images = preprocess_input(loaded_images) 17 | return preprocessed_images 18 | 19 | 20 | class PerceptualModel: 21 | def __init__(self, img_size, layer=9, batch_size=1, sess=None): 22 | self.sess = tf.get_default_session() if sess is None else sess 23 | K.set_session(self.sess) 24 | self.img_size = img_size 25 | self.layer = layer 26 | self.batch_size = batch_size 27 | 28 | self.perceptual_model = None 29 | self.ref_img_features = None 30 | self.features_weight = None 31 | self.loss = None 32 | 33 | def build_perceptual_model(self, generated_image_tensor): 34 | vgg16 = VGG16(include_top=False, input_shape=(self.img_size, self.img_size, 3)) 35 | self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output) 36 | generated_image = preprocess_input(tf.image.resize_images(generated_image_tensor, 37 | (self.img_size, self.img_size), method=1)) 38 | generated_img_features = self.perceptual_model(generated_image) 39 | 40 | self.ref_img_features = tf.get_variable('ref_img_features', shape=generated_img_features.shape, 41 | dtype='float32', initializer=tf.initializers.zeros()) 42 | self.features_weight = tf.get_variable('features_weight', shape=generated_img_features.shape, 43 | dtype='float32', initializer=tf.initializers.zeros()) 44 | self.sess.run([self.ref_img_features.initializer, self.features_weight.initializer]) # initialize both feature variables before the first assign 45 | 46 | self.loss = tf.losses.mean_squared_error(self.features_weight * self.ref_img_features, 47 | self.features_weight * generated_img_features) / 82890.0 48 | 49 | def set_reference_images(self, images_list): 50 | assert(len(images_list) != 0 and len(images_list) <= self.batch_size) 51 | loaded_image = load_images(images_list, self.img_size) 52 | image_features = self.perceptual_model.predict_on_batch(loaded_image) 53 | 54 | # handles the case where the number of images is less than the actual batch size 55 | # can be optimized further 56 | weight_mask = np.ones(self.features_weight.shape) 57 | if len(images_list) != self.batch_size: 58 | features_space = list(self.features_weight.shape[1:]) 59 | existing_features_shape = [len(images_list)] + features_space 60 | empty_features_shape = [self.batch_size - len(images_list)] + features_space 61 | 62 | existing_examples = np.ones(shape=existing_features_shape) 63 | empty_examples = np.zeros(shape=empty_features_shape) 64 | weight_mask = np.vstack([existing_examples, empty_examples]) 65 | 66 | image_features = np.vstack([image_features, np.zeros(empty_features_shape)]) 67 | 68 | self.sess.run(tf.assign(self.features_weight, weight_mask)) 69 | self.sess.run(tf.assign(self.ref_img_features, image_features)) 70 | 71 | def optimize(self, vars_to_optimize, iterations=500, learning_rate=1.): 72 | vars_to_optimize = vars_to_optimize if isinstance(vars_to_optimize, list) else [vars_to_optimize] 73 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) 74 | min_op = optimizer.minimize(self.loss, var_list=[vars_to_optimize]) 75 | for _ in range(iterations): 76 | _, loss = self.sess.run([min_op, self.loss]) 77 | yield loss 78 | 79 | -------------------------------------------------------------------------------- /ffhq_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/ffhq_dataset/__init__.py -------------------------------------------------------------------------------- 
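The encoder pieces above recover an 18x512 dlatent per face, and the .npy files under ffhq_dataset/latent_directions shift such a dlatent along a semantic axis. Below is a minimal sketch of that edit, assuming the pre-trained FFHQ pickle referenced in encode_images.py and the bundled smile direction; the coefficient, the layer slice, and the output filename are illustrative choices, not values prescribed by this repo.

import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
from encoder.generator_model import Generator

URL_FFHQ = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'  # karras2019stylegan-ffhq-1024x1024.pkl

tflib.init_tf()
with dnnlib.util.open_url(URL_FFHQ, cache_dir=config.cache_dir) as f:
    _G, _D, Gs = pickle.load(f)
generator = Generator(Gs, batch_size=1, randomize_noise=False)

# an (18, 512) dlatent saved by encode_images.py and a bundled direction vector
dlatent = np.load('ffhq_dataset/latent_representations/donald_trump_01.npy')
direction = np.load('ffhq_dataset/latent_directions/smile.npy')

coeff = 1.5  # illustrative strength; negative values push the other way
edited = dlatent.copy()
edited[:8] = (dlatent + coeff * direction)[:8]  # move only the coarse/mid layers

generator.set_dlatents(edited.reshape((1, 18, 512)))
img = PIL.Image.fromarray(generator.generate_images()[0], 'RGB')
img.save('smile_edit.png')  # illustrative output path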
/ffhq_dataset/face_alignment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | import os 4 | import PIL.Image 5 | 6 | 7 | def image_align(src_file, dst_file, face_landmarks, output_size=1024, transform_size=4096, enable_padding=True): 8 | # Align function from FFHQ dataset pre-processing step 9 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 10 | 11 | lm = np.array(face_landmarks) 12 | lm_chin = lm[0 : 17] # left-right 13 | lm_eyebrow_left = lm[17 : 22] # left-right 14 | lm_eyebrow_right = lm[22 : 27] # left-right 15 | lm_nose = lm[27 : 31] # top-down 16 | lm_nostrils = lm[31 : 36] # top-down 17 | lm_eye_left = lm[36 : 42] # left-clockwise 18 | lm_eye_right = lm[42 : 48] # left-clockwise 19 | lm_mouth_outer = lm[48 : 60] # left-clockwise 20 | lm_mouth_inner = lm[60 : 68] # left-clockwise 21 | 22 | # Calculate auxiliary vectors. 23 | eye_left = np.mean(lm_eye_left, axis=0) 24 | eye_right = np.mean(lm_eye_right, axis=0) 25 | eye_avg = (eye_left + eye_right) * 0.5 26 | eye_to_eye = eye_right - eye_left 27 | mouth_left = lm_mouth_outer[0] 28 | mouth_right = lm_mouth_outer[6] 29 | mouth_avg = (mouth_left + mouth_right) * 0.5 30 | eye_to_mouth = mouth_avg - eye_avg 31 | 32 | # Choose oriented crop rectangle. 33 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 34 | x /= np.hypot(*x) 35 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 36 | y = np.flipud(x) * [-1, 1] 37 | c = eye_avg + eye_to_mouth * 0.1 38 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 39 | qsize = np.hypot(*x) * 2 40 | 41 | # Load in-the-wild image. 42 | if not os.path.isfile(src_file): 43 | print('\nCannot find source image. Please run "--wilds" before "--align".') 44 | return 45 | img = PIL.Image.open(src_file) 46 | 47 | # Shrink. 48 | shrink = int(np.floor(qsize / output_size * 0.5)) 49 | if shrink > 1: 50 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 51 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 52 | quad /= shrink 53 | qsize /= shrink 54 | 55 | # Crop. 56 | border = max(int(np.rint(qsize * 0.1)), 3) 57 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 58 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 59 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 60 | img = img.crop(crop) 61 | quad -= crop[0:2] 62 | 63 | # Pad. 
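# (the pad amounts computed below measure how far the oriented crop quad still
# sticks out of the image on each side; reflection padding plus the blurred
# mask feather those synthesized borders into the photo, so the aligned crop
# shows no hard seams)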
64 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 65 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 66 | if enable_padding and max(pad) > border - 4: 67 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 68 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 69 | h, w, _ = img.shape 70 | y, x, _ = np.ogrid[:h, :w, :1] 71 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 72 | blur = qsize * 0.02 73 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 74 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 75 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 76 | quad += pad[:2] 77 | 78 | # Transform. 79 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 80 | if output_size < transform_size: 81 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 82 | 83 | # Save aligned image. 84 | img.save(dst_file, 'PNG') 85 | -------------------------------------------------------------------------------- /ffhq_dataset/landmarks_detector.py: -------------------------------------------------------------------------------- 1 | import dlib 2 | 3 | 4 | class LandmarksDetector: 5 | def __init__(self, predictor_model_path): 6 | """ 7 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file 8 | """ 9 | self.detector = dlib.get_frontal_face_detector() # cnn_face_detection_model_v1 also can be used 10 | self.shape_predictor = dlib.shape_predictor(predictor_model_path) 11 | 12 | def get_landmarks(self, image): 13 | img = dlib.load_rgb_image(image) 14 | dets = self.detector(img, 1) 15 | 16 | for detection in dets: 17 | face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()] 18 | yield face_landmarks 19 | -------------------------------------------------------------------------------- /ffhq_dataset/latent_directions/age.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/ffhq_dataset/latent_directions/age.npy -------------------------------------------------------------------------------- /ffhq_dataset/latent_directions/gender.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/ffhq_dataset/latent_directions/gender.npy -------------------------------------------------------------------------------- /ffhq_dataset/latent_directions/smile.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/ffhq_dataset/latent_directions/smile.npy -------------------------------------------------------------------------------- /ffhq_dataset/latent_representations/donald_trump_01.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/ffhq_dataset/latent_representations/donald_trump_01.npy -------------------------------------------------------------------------------- /ffhq_dataset/latent_representations/hillary_clinton_01.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/ffhq_dataset/latent_representations/hillary_clinton_01.npy -------------------------------------------------------------------------------- /generate_figures.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators.""" 9 | 10 | import os 11 | import pickle 12 | import numpy as np 13 | import PIL.Image 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | import config 17 | 18 | #---------------------------------------------------------------------------- 19 | # Helpers for loading and using pre-trained generators. 20 | 21 | url_ffhq = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl 22 | url_celebahq = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl 23 | url_bedrooms = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl 24 | url_cars = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl 25 | url_cats = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl 26 | 27 | synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8) 28 | 29 | _Gs_cache = dict() 30 | 31 | def load_Gs(url): 32 | if url not in _Gs_cache: 33 | with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: 34 | _G, _D, Gs = pickle.load(f) 35 | _Gs_cache[url] = Gs 36 | return _Gs_cache[url] 37 | 38 | #---------------------------------------------------------------------------- 39 | # Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images. 
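# (each entry of lods describes one column of the grid: column col holds
# rows * 2**lod images rendered at 1 / 2**lod scale, so lods=[0, 1, 2] would
# give a column of full-size images, one of half-size, and one of quarter-size)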
40 | 41 | def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed): 42 | print(png) 43 | latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1]) 44 | images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb] 45 | 46 | canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white') 47 | image_iter = iter(list(images)) 48 | for col, lod in enumerate(lods): 49 | for row in range(rows * 2**lod): 50 | image = PIL.Image.fromarray(next(image_iter), 'RGB') 51 | image = image.crop((cx, cy, cx + cw, cy + ch)) 52 | image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS) 53 | canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod)) 54 | canvas.save(png) 55 | 56 | #---------------------------------------------------------------------------- 57 | # Figure 3: Style mixing. 58 | 59 | def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges): 60 | print(png) 61 | src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) 62 | dst_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds) 63 | src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] 64 | dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component] 65 | src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs) 66 | dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs) 67 | 68 | canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white') 69 | for col, src_image in enumerate(list(src_images)): 70 | canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0)) 71 | for row, dst_image in enumerate(list(dst_images)): 72 | canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h)) 73 | row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds)) 74 | row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]] 75 | row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs) 76 | for col, image in enumerate(list(row_images)): 77 | canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h)) 78 | canvas.save(png) 79 | 80 | #---------------------------------------------------------------------------- 81 | # Figure 4: Noise detail. 82 | 83 | def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds): 84 | print(png) 85 | canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white') 86 | for row, seed in enumerate(seeds): 87 | latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples) 88 | images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs) 89 | canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h)) 90 | for i in range(4): 91 | crop = PIL.Image.fromarray(images[i + 1], 'RGB') 92 | crop = crop.crop((650, 180, 906, 436)) 93 | crop = crop.resize((w//2, h//2), PIL.Image.NEAREST) 94 | canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2)) 95 | diff = np.std(np.mean(images, axis=3), axis=0) * 4 96 | diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8) 97 | canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h)) 98 | canvas.save(png) 99 | 100 | #---------------------------------------------------------------------------- 101 | # Figure 5: Noise components. 
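# (each range in noise_ranges picks which of the 18 per-layer noise inputs stay
# enabled; the rest are zeroed through set_vars, so range(0, 8) keeps only
# coarse-layer noise and range(8, 18) keeps only the fine-layer noise that
# mostly affects hair and skin micro-detail)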
102 | 103 | def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips): 104 | print(png) 105 | Gsc = Gs.clone() 106 | noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')] 107 | noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...] 108 | latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds) 109 | all_images = [] 110 | for noise_range in noise_ranges: 111 | tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)}) 112 | range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs) 113 | range_images[flips, :, :] = range_images[flips, :, ::-1] 114 | all_images.append(list(range_images)) 115 | 116 | canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white') 117 | for col, col_images in enumerate(zip(*all_images)): 118 | canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0)) 119 | canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0)) 120 | canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h)) 121 | canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h)) 122 | canvas.save(png) 123 | 124 | #---------------------------------------------------------------------------- 125 | # Figure 8: Truncation trick. 126 | 127 | def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis): 128 | print(png) 129 | latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds) 130 | dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component] 131 | dlatent_avg = Gs.get_var('dlatent_avg') # [component] 132 | 133 | canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white') 134 | for row, dlatent in enumerate(list(dlatents)): 135 | row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg 136 | row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs) 137 | for col, image in enumerate(list(row_images)): 138 | canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h)) 139 | canvas.save(png) 140 | 141 | #---------------------------------------------------------------------------- 142 | # Main program. 
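# (each draw_* call below regenerates one figure of the StyleGAN paper into
# config.result_dir, fetching the matching pre-trained generator through
# load_Gs on first use)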
143 | 144 | def main(): 145 | tflib.init_tf() 146 | os.makedirs(config.result_dir, exist_ok=True) 147 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-ffhq.png'), load_Gs(url_ffhq), cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5) 148 | draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(url_ffhq), w=1024, h=1024, src_seeds=[639,701,687,615,2268], dst_seeds=[888,829,1898,1733,1614,845], style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)]) 149 | draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(url_ffhq), w=1024, h=1024, num_samples=100, seeds=[1157,1012]) 150 | draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[1967,1555], noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)], flips=[1]) 151 | draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1]) 152 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0) 153 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure11-uncurated-cars.png'), load_Gs(url_cars), cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2) 154 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure12-uncurated-cats.png'), load_Gs(url_cats), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1) 155 | 156 | #---------------------------------------------------------------------------- 157 | 158 | if __name__ == "__main__": 159 | main() 160 | 161 | #---------------------------------------------------------------------------- 162 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Frechet Inception Distance (FID).""" 9 | 10 | import os 11 | import numpy as np 12 | import scipy 13 | import tensorflow as tf 14 | import dnnlib.tflib as tflib 15 | 16 | from metrics import metric_base 17 | from training import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class FID(metric_base.MetricBase): 22 | def __init__(self, num_images, minibatch_per_gpu, **kwargs): 23 | super().__init__(**kwargs) 24 | self.num_images = num_images 25 | self.minibatch_per_gpu = minibatch_per_gpu 26 | 27 | def _evaluate(self, Gs, num_gpus): 28 | minibatch_size = num_gpus * self.minibatch_per_gpu 29 | inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl 30 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 31 | 32 | # Calculate statistics for reals. 33 | cache_file = self._get_cache_file_for_reals(num_images=self.num_images) 34 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 35 | if os.path.isfile(cache_file): 36 | mu_real, sigma_real = misc.load_pkl(cache_file) 37 | else: 38 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)): 39 | begin = idx * minibatch_size 40 | end = min(begin + minibatch_size, self.num_images) 41 | activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True) 42 | if end == self.num_images: 43 | break 44 | mu_real = np.mean(activations, axis=0) 45 | sigma_real = np.cov(activations, rowvar=False) 46 | misc.save_pkl((mu_real, sigma_real), cache_file) 47 | 48 | # Construct TensorFlow graph. 49 | result_expr = [] 50 | for gpu_idx in range(num_gpus): 51 | with tf.device('/gpu:%d' % gpu_idx): 52 | Gs_clone = Gs.clone() 53 | inception_clone = inception.clone() 54 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 55 | images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) 56 | images = tflib.convert_images_to_uint8(images) 57 | result_expr.append(inception_clone.get_output_for(images)) 58 | 59 | # Calculate statistics for fakes. 60 | for begin in range(0, self.num_images, minibatch_size): 61 | end = min(begin + minibatch_size, self.num_images) 62 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] 63 | mu_fake = np.mean(activations, axis=0) 64 | sigma_fake = np.cov(activations, rowvar=False) 65 | 66 | # Calculate FID. 67 | m = np.square(mu_fake - mu_real).sum() 68 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member 69 | dist = m + np.trace(sigma_fake + sigma_real - 2*s) 70 | self._report_result(np.real(dist)) 71 | 72 | #---------------------------------------------------------------------------- 73 | -------------------------------------------------------------------------------- /metrics/linear_separability.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Linear Separability (LS).""" 9 | 10 | from collections import defaultdict 11 | import numpy as np 12 | import sklearn.svm 13 | import tensorflow as tf 14 | import dnnlib.tflib as tflib 15 | 16 | from metrics import metric_base 17 | from training import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | classifier_urls = [ 22 | 'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl 23 | 'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl 24 | 'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl 25 | 'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl 26 | 'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl 27 | 'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl 28 | 'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl 29 | 'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl 30 | 'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl 31 | 'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl 32 | 'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl 33 | 'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl 34 | 'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl 35 | 'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl 36 | 'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl 37 | 'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl 38 | 'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl 39 | 'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl 40 | 'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl 41 | 'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl 42 | 'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl 43 | 'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl 44 | 'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl 45 | 'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl 46 | 'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl 47 | 'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl 48 | 'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl 49 | 'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl 50 | 
'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl 51 | 'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl 52 | 'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl 53 | 'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl 54 | 'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl 55 | 'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl 56 | 'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl 57 | 'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl 58 | 'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl 59 | 'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl 60 | 'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl 61 | 'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl 62 | ] 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | def prob_normalize(p): 67 | p = np.asarray(p).astype(np.float32) 68 | assert len(p.shape) == 2 69 | return p / np.sum(p) 70 | 71 | def mutual_information(p): 72 | p = prob_normalize(p) 73 | px = np.sum(p, axis=1) 74 | py = np.sum(p, axis=0) 75 | result = 0.0 76 | for x in range(p.shape[0]): 77 | p_x = px[x] 78 | for y in range(p.shape[1]): 79 | p_xy = p[x][y] 80 | p_y = py[y] 81 | if p_xy > 0.0: 82 | result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output 83 | return result 84 | 85 | def entropy(p): 86 | p = prob_normalize(p) 87 | result = 0.0 88 | for x in range(p.shape[0]): 89 | for y in range(p.shape[1]): 90 | p_xy = p[x][y] 91 | if p_xy > 0.0: 92 | result -= p_xy * np.log2(p_xy) 93 | return result 94 | 95 | def conditional_entropy(p): 96 | # H(Y|X) where X corresponds to axis 0, Y to axis 1 97 | # i.e., How many bits of additional information are needed to where we are on axis 1 if we know where we are on axis 0? 98 | p = prob_normalize(p) 99 | y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y) 100 | return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up. 101 | 102 | #---------------------------------------------------------------------------- 103 | 104 | class LS(metric_base.MetricBase): 105 | def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs): 106 | assert num_keep <= num_samples 107 | super().__init__(**kwargs) 108 | self.num_samples = num_samples 109 | self.num_keep = num_keep 110 | self.attrib_indices = attrib_indices 111 | self.minibatch_per_gpu = minibatch_per_gpu 112 | 113 | def _evaluate(self, Gs, num_gpus): 114 | minibatch_size = num_gpus * self.minibatch_per_gpu 115 | 116 | # Construct TensorFlow graph for each GPU. 117 | result_expr = [] 118 | for gpu_idx in range(num_gpus): 119 | with tf.device('/gpu:%d' % gpu_idx): 120 | Gs_clone = Gs.clone() 121 | 122 | # Generate images. 
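# (per-minibatch pipeline: z ~ N(0, I) -> mapping network -> w -> synthesis
# network -> image; result_dict later keeps dlatents[:, -1], a single layer's
# w, which suffices because every layer receives the same vector when no
# style mixing is applied)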
123 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 124 | dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True) 125 | images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True) 126 | 127 | # Downsample to 256x256. The attribute classifiers were built for 256x256. 128 | if images.shape[2] > 256: 129 | factor = images.shape[2] // 256 130 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 131 | images = tf.reduce_mean(images, axis=[3, 5]) 132 | 133 | # Run classifier for each attribute. 134 | result_dict = dict(latents=latents, dlatents=dlatents[:,-1]) 135 | for attrib_idx in self.attrib_indices: 136 | classifier = misc.load_pkl(classifier_urls[attrib_idx]) 137 | logits = classifier.get_output_for(images, None) 138 | predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1)) 139 | result_dict[attrib_idx] = predictions 140 | result_expr.append(result_dict) 141 | 142 | # Sampling loop. 143 | results = [] 144 | for _ in range(0, self.num_samples, minibatch_size): 145 | results += tflib.run(result_expr) 146 | results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()} 147 | 148 | # Calculate conditional entropy for each attribute. 149 | conditional_entropies = defaultdict(list) 150 | for attrib_idx in self.attrib_indices: 151 | # Prune the least confident samples. 152 | pruned_indices = list(range(self.num_samples)) 153 | pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i])) 154 | pruned_indices = pruned_indices[:self.num_keep] 155 | 156 | # Fit SVM to the remaining samples. 157 | svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1) 158 | for space in ['latents', 'dlatents']: 159 | svm_inputs = results[space][pruned_indices] 160 | try: 161 | svm = sklearn.svm.LinearSVC() 162 | svm.fit(svm_inputs, svm_targets) 163 | svm.score(svm_inputs, svm_targets) 164 | svm_outputs = svm.predict(svm_inputs) 165 | except: 166 | svm_outputs = svm_targets # assume perfect prediction 167 | 168 | # Calculate conditional entropy. 169 | p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)] 170 | conditional_entropies[space].append(conditional_entropy(p)) 171 | 172 | # Calculate separability scores. 173 | scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()} 174 | self._report_result(scores['latents'], suffix='_z') 175 | self._report_result(scores['dlatents'], suffix='_w') 176 | 177 | #---------------------------------------------------------------------------- 178 | -------------------------------------------------------------------------------- /metrics/metric_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Common definitions for GAN metrics.""" 9 | 10 | import os 11 | import time 12 | import hashlib 13 | import numpy as np 14 | import tensorflow as tf 15 | import dnnlib 16 | import dnnlib.tflib as tflib 17 | 18 | import config 19 | from training import misc 20 | from training import dataset 21 | 22 | #---------------------------------------------------------------------------- 23 | # Standard metrics. 24 | 25 | fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8) 26 | ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16) 27 | ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16) 28 | ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16) 29 | ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16) 30 | ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4) 31 | dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging 32 | 33 | #---------------------------------------------------------------------------- 34 | # Base class for metrics. 35 | 36 | class MetricBase: 37 | def __init__(self, name): 38 | self.name = name 39 | self._network_pkl = None 40 | self._dataset_args = None 41 | self._mirror_augment = None 42 | self._results = [] 43 | self._eval_time = None 44 | 45 | def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True): 46 | self._network_pkl = network_pkl 47 | self._dataset_args = dataset_args 48 | self._mirror_augment = mirror_augment 49 | self._results = [] 50 | 51 | if (dataset_args is None or mirror_augment is None) and run_dir is not None: 52 | run_config = misc.parse_config_for_previous_run(run_dir) 53 | self._dataset_args = dict(run_config['dataset']) 54 | self._dataset_args['shuffle_mb'] = 0 55 | self._mirror_augment = run_config['train'].get('mirror_augment', False) 56 | 57 | time_begin = time.time() 58 | with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager 59 | _G, _D, Gs = misc.load_pkl(self._network_pkl) 60 | self._evaluate(Gs, num_gpus=num_gpus) 61 | self._eval_time = time.time() - time_begin 62 | 63 | if log_results: 64 | result_str = self.get_result_str() 65 | if run_dir is not None: 66 | log = os.path.join(run_dir, 'metric-%s.txt' % self.name) 67 | with dnnlib.util.Logger(log, 'a'): 68 | print(result_str) 69 | else: 70 | print(result_str) 71 | 72 | def get_result_str(self): 73 | network_name = os.path.splitext(os.path.basename(self._network_pkl))[0] 74 | if len(network_name) > 29: 75 | network_name = '...' 
+ network_name[-26:] 76 | result_str = '%-30s' % network_name 77 | result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time) 78 | for res in self._results: 79 | result_str += ' ' + self.name + res.suffix + ' ' 80 | result_str += res.fmt % res.value 81 | return result_str 82 | 83 | def update_autosummaries(self): 84 | for res in self._results: 85 | tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value) 86 | 87 | def _evaluate(self, Gs, num_gpus): 88 | raise NotImplementedError # to be overridden by subclasses 89 | 90 | def _report_result(self, value, suffix='', fmt='%-10.4f'): 91 | self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)] 92 | 93 | def _get_cache_file_for_reals(self, extension='pkl', **kwargs): 94 | all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment) 95 | all_args.update(self._dataset_args) 96 | all_args.update(kwargs) 97 | md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')) 98 | dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1] 99 | return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension)) 100 | 101 | def _iterate_reals(self, minibatch_size): 102 | dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args) 103 | while True: 104 | images, _labels = dataset_obj.get_minibatch_np(minibatch_size) 105 | if self._mirror_augment: 106 | images = misc.apply_mirror_augment(images) 107 | yield images 108 | 109 | def _iterate_fakes(self, Gs, minibatch_size, num_gpus): 110 | while True: 111 | latents = np.random.randn(minibatch_size, *Gs.input_shape[1:]) 112 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 113 | images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True) 114 | yield images 115 | 116 | #---------------------------------------------------------------------------- 117 | # Group of multiple metrics. 118 | 119 | class MetricGroup: 120 | def __init__(self, metric_kwarg_list): 121 | self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list] 122 | 123 | def run(self, *args, **kwargs): 124 | for metric in self.metrics: 125 | metric.run(*args, **kwargs) 126 | 127 | def get_result_str(self): 128 | return ' '.join(metric.get_result_str() for metric in self.metrics) 129 | 130 | def update_autosummaries(self): 131 | for metric in self.metrics: 132 | metric.update_autosummaries() 133 | 134 | #---------------------------------------------------------------------------- 135 | # Dummy metric for debugging purposes. 136 | 137 | class DummyMetric(MetricBase): 138 | def _evaluate(self, Gs, num_gpus): 139 | _ = Gs, num_gpus 140 | self._report_result(0.0) 141 | 142 | #---------------------------------------------------------------------------- 143 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Perceptual Path Length (PPL).""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | import dnnlib.tflib as tflib 13 | 14 | from metrics import metric_base 15 | from training import misc 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | # Normalize batch of vectors. 20 | def normalize(v): 21 | return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)) 22 | 23 | # Spherical interpolation of a batch of vectors. 24 | def slerp(a, b, t): 25 | a = normalize(a) 26 | b = normalize(b) 27 | d = tf.reduce_sum(a * b, axis=-1, keepdims=True) 28 | p = t * tf.math.acos(d) 29 | c = normalize(b - d * a) 30 | d = a * tf.math.cos(p) + c * tf.math.sin(p) 31 | return normalize(d) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPL(metric_base.MetricBase): 36 | def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__(**kwargs) 40 | self.num_samples = num_samples 41 | self.epsilon = epsilon 42 | self.space = space 43 | self.sampling = sampling 44 | self.minibatch_per_gpu = minibatch_per_gpu 45 | 46 | def _evaluate(self, Gs, num_gpus): 47 | minibatch_size = num_gpus * self.minibatch_per_gpu 48 | 49 | # Construct TensorFlow graph. 50 | distance_expr = [] 51 | for gpu_idx in range(num_gpus): 52 | with tf.device('/gpu:%d' % gpu_idx): 53 | Gs_clone = Gs.clone() 54 | noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')] 55 | 56 | # Generate random latents and interpolation t-values. 57 | lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:]) 58 | lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0) 59 | 60 | # Interpolate in W or Z. 61 | if self.space == 'w': 62 | dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True) 63 | dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2] 64 | dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis]) 65 | dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon) 66 | dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape) 67 | else: # space == 'z' 68 | lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2] 69 | lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis]) 70 | lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon) 71 | lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape) 72 | dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True) 73 | 74 | # Synthesize images. 75 | with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch 76 | images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False) 77 | 78 | # Crop only the face region. 79 | c = int(images.shape[2] // 8) 80 | images = images[:, :, c*3 : c*7, c*2 : c*6] 81 | 82 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 83 | if images.shape[2] > 256: 84 | factor = images.shape[2] // 256 85 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 86 | images = tf.reduce_mean(images, axis=[3,5]) 87 | 88 | # Scale dynamic range from [-1,1] to [0,255] for VGG. 
89 | images = (images + 1) * (255 / 2) 90 | 91 | # Evaluate perceptual distance. 92 | img_e0, img_e1 = images[0::2], images[1::2] 93 | distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl 94 | distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2)) 95 | 96 | # Sampling loop. 97 | all_distances = [] 98 | for _ in range(0, self.num_samples, minibatch_size): 99 | all_distances += tflib.run(distance_expr) 100 | all_distances = np.concatenate(all_distances, axis=0) 101 | 102 | # Reject outliers. 103 | lo = np.percentile(all_distances, 1, interpolation='lower') 104 | hi = np.percentile(all_distances, 99, interpolation='higher') 105 | filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances) 106 | self._report_result(np.mean(filtered_distances)) 107 | 108 | #---------------------------------------------------------------------------- 109 | -------------------------------------------------------------------------------- /pretrained_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Minimal script for generating an image using pre-trained StyleGAN generator.""" 9 | 10 | import os 11 | import pickle 12 | import numpy as np 13 | import PIL.Image 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | import config 17 | 18 | def main(): 19 | # Initialize TensorFlow. 20 | tflib.init_tf() 21 | 22 | # Load pre-trained network. 23 | url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl 24 | with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: 25 | _G, _D, Gs = pickle.load(f) 26 | # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. 27 | # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. 28 | # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot. 29 | 30 | # Print network details. 31 | Gs.print_layers() 32 | 33 | # Pick latent vector. 34 | rnd = np.random.RandomState(5) 35 | latents = rnd.randn(1, Gs.input_shape[1]) 36 | 37 | # Generate image. 38 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 39 | images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt) 40 | 41 | # Save image. 
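    # (With the output_transform above, images is a uint8 array of shape
    # [minibatch, height, width, channels], so images[0] can be handed
    # straight to PIL.Image.fromarray below.)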
42 |     os.makedirs(config.result_dir, exist_ok=True)
43 |     png_filename = os.path.join(config.result_dir, 'example.png')
44 |     PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
45 | 
46 | if __name__ == "__main__":
47 |     main()
48 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/requirements.txt
--------------------------------------------------------------------------------
/run_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 | 
8 | """Main entry point for evaluating metrics on pre-trained StyleGAN and ProGAN networks."""
9 | 
10 | import dnnlib
11 | from dnnlib import EasyDict
12 | import dnnlib.tflib as tflib
13 | 
14 | import config
15 | from metrics import metric_base
16 | from training import misc
17 | 
18 | #----------------------------------------------------------------------------
19 | 
20 | def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment):
21 |     ctx = dnnlib.RunContext(submit_config)
22 |     tflib.init_tf()
23 |     print('Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl))
24 |     metric = dnnlib.util.call_func_by_name(**metric_args)
25 |     print()
26 |     metric.run(network_pkl, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus)
27 |     print()
28 |     ctx.close()
29 | 
30 | #----------------------------------------------------------------------------
31 | 
32 | def run_snapshot(submit_config, metric_args, run_id, snapshot):
33 |     ctx = dnnlib.RunContext(submit_config)
34 |     tflib.init_tf()
35 |     print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot))
36 |     run_dir = misc.locate_run_dir(run_id)
37 |     network_pkl = misc.locate_network_pkl(run_dir, snapshot)
38 |     metric = dnnlib.util.call_func_by_name(**metric_args)
39 |     print()
40 |     metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
41 |     print()
42 |     ctx.close()
43 | 
44 | #----------------------------------------------------------------------------
45 | 
46 | def run_all_snapshots(submit_config, metric_args, run_id):
47 |     ctx = dnnlib.RunContext(submit_config)
48 |     tflib.init_tf()
49 |     print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id))
50 |     run_dir = misc.locate_run_dir(run_id)
51 |     network_pkls = misc.list_network_pkls(run_dir)
52 |     metric = dnnlib.util.call_func_by_name(**metric_args)
53 |     print()
54 |     for idx, network_pkl in enumerate(network_pkls):
55 |         ctx.update('', idx, len(network_pkls))
56 |         metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
57 |     print()
58 |     ctx.close()
59 | 
60 | #----------------------------------------------------------------------------
61 | 
62 | def main():
63 |     submit_config = dnnlib.SubmitConfig()
64 | 
65 |     # Which metrics to evaluate?
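    # (Each entry below is a predefined EasyDict from metric_base naming the
    # metric class and its arguments; uncomment extra lines to evaluate more
    # metrics per run.)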
66 | metrics = [] 67 | metrics += [metric_base.fid50k] 68 | #metrics += [metric_base.ppl_zfull] 69 | #metrics += [metric_base.ppl_wfull] 70 | #metrics += [metric_base.ppl_zend] 71 | #metrics += [metric_base.ppl_wend] 72 | #metrics += [metric_base.ls] 73 | #metrics += [metric_base.dummy] 74 | 75 | # Which networks to evaluate them on? 76 | tasks = [] 77 | tasks += [EasyDict(run_func_name='run_metrics.run_pickle', network_pkl='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', dataset_args=EasyDict(tfrecord_dir='ffhq', shuffle_mb=0), mirror_augment=True)] # karras2019stylegan-ffhq-1024x1024.pkl 78 | #tasks += [EasyDict(run_func_name='run_metrics.run_snapshot', run_id=100, snapshot=25000)] 79 | #tasks += [EasyDict(run_func_name='run_metrics.run_all_snapshots', run_id=100)] 80 | 81 | # How many GPUs to use? 82 | submit_config.num_gpus = 1 83 | #submit_config.num_gpus = 2 84 | #submit_config.num_gpus = 4 85 | #submit_config.num_gpus = 8 86 | 87 | # Execute. 88 | submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) 89 | submit_config.run_dir_ignore += config.run_dir_ignore 90 | for task in tasks: 91 | for metric in metrics: 92 | submit_config.run_desc = '%s-%s' % (task.run_func_name, metric.name) 93 | if task.run_func_name.endswith('run_snapshot'): 94 | submit_config.run_desc += '-%s-%s' % (task.run_id, task.snapshot) 95 | if task.run_func_name.endswith('run_all_snapshots'): 96 | submit_config.run_desc += '-%s' % task.run_id 97 | submit_config.run_desc += '-%dgpu' % submit_config.num_gpus 98 | dnnlib.submit_run(submit_config, metric_args=metric, **task) 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | if __name__ == "__main__": 103 | main() 104 | 105 | #---------------------------------------------------------------------------- 106 | -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Puzer/stylegan-encoder/1e7e47f9bbb0ca391cdc250af5ad2468250a803c/teaser.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Main entry point for training StyleGAN and ProGAN networks.""" 9 | 10 | import copy 11 | import dnnlib 12 | from dnnlib import EasyDict 13 | 14 | import config 15 | from metrics import metric_base 16 | 17 | #---------------------------------------------------------------------------- 18 | # Official training configs for StyleGAN, targeted mainly for FFHQ. 19 | 20 | if 1: 21 | desc = 'sgan' # Description string included in result subdir name. 22 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. 23 | G = EasyDict(func_name='training.networks_stylegan.G_style') # Options for generator network. 24 | D = EasyDict(func_name='training.networks_stylegan.D_basic') # Options for discriminator network. 25 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. 
26 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. 27 | G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating') # Options for generator loss. 28 | D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss. 29 | dataset = EasyDict() # Options for load_dataset(). 30 | sched = EasyDict() # Options for TrainingSchedule. 31 | grid = EasyDict(size='4k', layout='random') # Options for setup_snapshot_image_grid(). 32 | metrics = [metric_base.fid50k] # Options for MetricGroup. 33 | submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). 34 | tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). 35 | 36 | # Dataset. 37 | desc += '-ffhq'; dataset = EasyDict(tfrecord_dir='ffhq'); train.mirror_augment = True 38 | #desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True 39 | #desc += '-bedroom'; dataset = EasyDict(tfrecord_dir='lsun-bedroom-full'); train.mirror_augment = False 40 | #desc += '-car'; dataset = EasyDict(tfrecord_dir='lsun-car-512x384'); train.mirror_augment = False 41 | #desc += '-cat'; dataset = EasyDict(tfrecord_dir='lsun-cat-full'); train.mirror_augment = False 42 | 43 | # Number of GPUs. 44 | #desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4} 45 | #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8} 46 | #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16} 47 | desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32} 48 | 49 | # Default options. 50 | train.total_kimg = 25000 51 | sched.lod_initial_resolution = 8 52 | sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} 53 | sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) 54 | 55 | # WGAN-GP loss for CelebA-HQ. 56 | #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) 57 | 58 | # Table 1. 59 | #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 60 | #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 61 | #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False 62 | #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0 63 | #desc += '-mixing-regularization' # default 64 | 65 | # Table 2. 66 | #desc += '-mix0'; G.style_mixing_prob = 0.0 67 | #desc += '-mix50'; G.style_mixing_prob = 0.5 68 | #desc += '-mix90'; G.style_mixing_prob = 0.9 # default 69 | #desc += '-mix100'; G.style_mixing_prob = 1.0 70 | 71 | # Table 4. 
72 | #desc += '-traditional-0'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 73 | #desc += '-traditional-8'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 8; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False 74 | #desc += '-stylebased-0'; G.mapping_layers = 0 75 | #desc += '-stylebased-1'; G.mapping_layers = 1 76 | #desc += '-stylebased-2'; G.mapping_layers = 2 77 | #desc += '-stylebased-8'; G.mapping_layers = 8 # default 78 | 79 | #---------------------------------------------------------------------------- 80 | # Official training configs for Progressive GAN, targeted mainly for CelebA-HQ. 81 | 82 | if 0: 83 | desc = 'pgan' # Description string included in result subdir name. 84 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. 85 | G = EasyDict(func_name='training.networks_progan.G_paper') # Options for generator network. 86 | D = EasyDict(func_name='training.networks_progan.D_paper') # Options for discriminator network. 87 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. 88 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. 89 | G_loss = EasyDict(func_name='training.loss.G_wgan') # Options for generator loss. 90 | D_loss = EasyDict(func_name='training.loss.D_wgan_gp') # Options for discriminator loss. 91 | dataset = EasyDict() # Options for load_dataset(). 92 | sched = EasyDict() # Options for TrainingSchedule. 93 | grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid(). 94 | metrics = [metric_base.fid50k] # Options for MetricGroup. 95 | submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). 96 | tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). 97 | 98 | # Dataset (choose one). 
99 | desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True 100 | #desc += '-celeba'; dataset = EasyDict(tfrecord_dir='celeba'); train.mirror_augment = True 101 | #desc += '-cifar10'; dataset = EasyDict(tfrecord_dir='cifar10') 102 | #desc += '-cifar100'; dataset = EasyDict(tfrecord_dir='cifar100') 103 | #desc += '-svhn'; dataset = EasyDict(tfrecord_dir='svhn') 104 | #desc += '-mnist'; dataset = EasyDict(tfrecord_dir='mnist') 105 | #desc += '-mnistrgb'; dataset = EasyDict(tfrecord_dir='mnistrgb') 106 | #desc += '-syn1024rgb'; dataset = EasyDict(class_name='training.dataset.SyntheticDataset', resolution=1024, num_channels=3) 107 | #desc += '-lsun-airplane'; dataset = EasyDict(tfrecord_dir='lsun-airplane-100k'); train.mirror_augment = True 108 | #desc += '-lsun-bedroom'; dataset = EasyDict(tfrecord_dir='lsun-bedroom-100k'); train.mirror_augment = True 109 | #desc += '-lsun-bicycle'; dataset = EasyDict(tfrecord_dir='lsun-bicycle-100k'); train.mirror_augment = True 110 | #desc += '-lsun-bird'; dataset = EasyDict(tfrecord_dir='lsun-bird-100k'); train.mirror_augment = True 111 | #desc += '-lsun-boat'; dataset = EasyDict(tfrecord_dir='lsun-boat-100k'); train.mirror_augment = True 112 | #desc += '-lsun-bottle'; dataset = EasyDict(tfrecord_dir='lsun-bottle-100k'); train.mirror_augment = True 113 | #desc += '-lsun-bridge'; dataset = EasyDict(tfrecord_dir='lsun-bridge-100k'); train.mirror_augment = True 114 | #desc += '-lsun-bus'; dataset = EasyDict(tfrecord_dir='lsun-bus-100k'); train.mirror_augment = True 115 | #desc += '-lsun-car'; dataset = EasyDict(tfrecord_dir='lsun-car-100k'); train.mirror_augment = True 116 | #desc += '-lsun-cat'; dataset = EasyDict(tfrecord_dir='lsun-cat-100k'); train.mirror_augment = True 117 | #desc += '-lsun-chair'; dataset = EasyDict(tfrecord_dir='lsun-chair-100k'); train.mirror_augment = True 118 | #desc += '-lsun-churchoutdoor'; dataset = EasyDict(tfrecord_dir='lsun-churchoutdoor-100k'); train.mirror_augment = True 119 | #desc += '-lsun-classroom'; dataset = EasyDict(tfrecord_dir='lsun-classroom-100k'); train.mirror_augment = True 120 | #desc += '-lsun-conferenceroom'; dataset = EasyDict(tfrecord_dir='lsun-conferenceroom-100k'); train.mirror_augment = True 121 | #desc += '-lsun-cow'; dataset = EasyDict(tfrecord_dir='lsun-cow-100k'); train.mirror_augment = True 122 | #desc += '-lsun-diningroom'; dataset = EasyDict(tfrecord_dir='lsun-diningroom-100k'); train.mirror_augment = True 123 | #desc += '-lsun-diningtable'; dataset = EasyDict(tfrecord_dir='lsun-diningtable-100k'); train.mirror_augment = True 124 | #desc += '-lsun-dog'; dataset = EasyDict(tfrecord_dir='lsun-dog-100k'); train.mirror_augment = True 125 | #desc += '-lsun-horse'; dataset = EasyDict(tfrecord_dir='lsun-horse-100k'); train.mirror_augment = True 126 | #desc += '-lsun-kitchen'; dataset = EasyDict(tfrecord_dir='lsun-kitchen-100k'); train.mirror_augment = True 127 | #desc += '-lsun-livingroom'; dataset = EasyDict(tfrecord_dir='lsun-livingroom-100k'); train.mirror_augment = True 128 | #desc += '-lsun-motorbike'; dataset = EasyDict(tfrecord_dir='lsun-motorbike-100k'); train.mirror_augment = True 129 | #desc += '-lsun-person'; dataset = EasyDict(tfrecord_dir='lsun-person-100k'); train.mirror_augment = True 130 | #desc += '-lsun-pottedplant'; dataset = EasyDict(tfrecord_dir='lsun-pottedplant-100k'); train.mirror_augment = True 131 | #desc += '-lsun-restaurant'; dataset = EasyDict(tfrecord_dir='lsun-restaurant-100k'); train.mirror_augment = True 132 | #desc += '-lsun-sheep'; 
dataset = EasyDict(tfrecord_dir='lsun-sheep-100k'); train.mirror_augment = True 133 | #desc += '-lsun-sofa'; dataset = EasyDict(tfrecord_dir='lsun-sofa-100k'); train.mirror_augment = True 134 | #desc += '-lsun-tower'; dataset = EasyDict(tfrecord_dir='lsun-tower-100k'); train.mirror_augment = True 135 | #desc += '-lsun-train'; dataset = EasyDict(tfrecord_dir='lsun-train-100k'); train.mirror_augment = True 136 | #desc += '-lsun-tvmonitor'; dataset = EasyDict(tfrecord_dir='lsun-tvmonitor-100k'); train.mirror_augment = True 137 | 138 | # Conditioning & snapshot options. 139 | #desc += '-cond'; dataset.max_label_size = 'full' # conditioned on full label 140 | #desc += '-cond1'; dataset.max_label_size = 1 # conditioned on first component of the label 141 | #desc += '-g4k'; grid.size = '4k' 142 | #desc += '-grpc'; grid.layout = 'row_per_class' 143 | 144 | # Config presets (choose one). 145 | #desc += '-preset-v1-1gpu'; submit_config.num_gpus = 1; D.mbstd_group_size = 16; sched.minibatch_base = 16; sched.minibatch_dict = {256: 14, 512: 6, 1024: 3}; sched.lod_training_kimg = 800; sched.lod_transition_kimg = 800; train.total_kimg = 19000 146 | desc += '-preset-v2-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.G_lrate_dict = {1024: 0.0015}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 147 | #desc += '-preset-v2-2gpus'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}; sched.G_lrate_dict = {512: 0.0015, 1024: 0.002}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 148 | #desc += '-preset-v2-4gpus'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}; sched.G_lrate_dict = {256: 0.0015, 512: 0.002, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 149 | #desc += '-preset-v2-8gpus'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}; sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 150 | 151 | # Numerical precision (choose one). 152 | desc += '-fp32'; sched.max_minibatch_per_gpu = {256: 16, 512: 8, 1024: 4} 153 | #desc += '-fp16'; G.dtype = 'float16'; D.dtype = 'float16'; G.pixelnorm_epsilon=1e-4; G_opt.use_loss_scaling = True; D_opt.use_loss_scaling = True; sched.max_minibatch_per_gpu = {512: 16, 1024: 8} 154 | 155 | # Disable individual features. 156 | #desc += '-nogrowing'; sched.lod_initial_resolution = 1024; sched.lod_training_kimg = 0; sched.lod_transition_kimg = 0; train.total_kimg = 10000 157 | #desc += '-nopixelnorm'; G.use_pixelnorm = False 158 | #desc += '-nowscale'; G.use_wscale = False; D.use_wscale = False 159 | #desc += '-noleakyrelu'; G.use_leakyrelu = False 160 | #desc += '-nosmoothing'; train.G_smoothing_kimg = 0.0 161 | #desc += '-norepeat'; train.minibatch_repeats = 1 162 | #desc += '-noreset'; train.reset_opt_for_new_lod = False 163 | 164 | # Special modes. 
165 | #desc += '-BENCHMARK'; sched.lod_initial_resolution = 4; sched.lod_training_kimg = 3; sched.lod_transition_kimg = 3; train.total_kimg = (8*2+1)*3; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000 166 | #desc += '-BENCHMARK0'; sched.lod_initial_resolution = 1024; train.total_kimg = 10; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000 167 | #desc += '-VERBOSE'; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1; train.network_snapshot_ticks = 100 168 | #desc += '-GRAPH'; train.save_tf_graph = True 169 | #desc += '-HIST'; train.save_weight_histograms = True 170 | 171 | #---------------------------------------------------------------------------- 172 | # Main entry point for training. 173 | # Calls the function indicated by 'train' using the selected options. 174 | 175 | def main(): 176 | kwargs = EasyDict(train) 177 | kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss) 178 | kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config) 179 | kwargs.submit_config = copy.deepcopy(submit_config) 180 | kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) 181 | kwargs.submit_config.run_dir_ignore += config.run_dir_ignore 182 | kwargs.submit_config.run_desc = desc 183 | dnnlib.submit_run(**kwargs) 184 | 185 | #---------------------------------------------------------------------------- 186 | 187 | if __name__ == "__main__": 188 | main() 189 | 190 | #---------------------------------------------------------------------------- 191 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Multi-resolution input data pipeline.""" 9 | 10 | import os 11 | import glob 12 | import numpy as np 13 | import tensorflow as tf 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | 17 | #---------------------------------------------------------------------------- 18 | # Parse individual image from a tfrecords file. 
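# Each record stores one CHW uint8 image as two features: 'shape' (int64[3])
# and 'data' (the raw image bytes). For reference, a compatible record could
# be produced roughly like this (a sketch only; dataset_tool.py is the real
# producer, and img / writer / path here are hypothetical):
#
#   img = ...                                  # uint8 array of shape [c, h, w]
#   writer = tf.python_io.TFRecordWriter(path)
#   ex = tf.train.Example(features=tf.train.Features(feature={
#       'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=img.shape)),
#       'data':  tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))}))
#   writer.write(ex.SerializeToString())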
19 | 20 | def parse_tfrecord_tf(record): 21 | features = tf.parse_single_example(record, features={ 22 | 'shape': tf.FixedLenFeature([3], tf.int64), 23 | 'data': tf.FixedLenFeature([], tf.string)}) 24 | data = tf.decode_raw(features['data'], tf.uint8) 25 | return tf.reshape(data, features['shape']) 26 | 27 | def parse_tfrecord_np(record): 28 | ex = tf.train.Example() 29 | ex.ParseFromString(record) 30 | shape = ex.features.feature['shape'].int64_list.value # temporary pylint workaround # pylint: disable=no-member 31 | data = ex.features.feature['data'].bytes_list.value[0] # temporary pylint workaround # pylint: disable=no-member 32 | return np.fromstring(data, np.uint8).reshape(shape) 33 | 34 | #---------------------------------------------------------------------------- 35 | # Dataset class that loads data from tfrecords files. 36 | 37 | class TFRecordDataset: 38 | def __init__(self, 39 | tfrecord_dir, # Directory containing a collection of tfrecords files. 40 | resolution = None, # Dataset resolution, None = autodetect. 41 | label_file = None, # Relative path of the labels file, None = autodetect. 42 | max_label_size = 0, # 0 = no labels, 'full' = full labels, = N first label components. 43 | repeat = True, # Repeat dataset indefinitely. 44 | shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. 45 | prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching. 46 | buffer_mb = 256, # Read buffer size (megabytes). 47 | num_threads = 2): # Number of concurrent threads. 48 | 49 | self.tfrecord_dir = tfrecord_dir 50 | self.resolution = None 51 | self.resolution_log2 = None 52 | self.shape = [] # [channel, height, width] 53 | self.dtype = 'uint8' 54 | self.dynamic_range = [0, 255] 55 | self.label_file = label_file 56 | self.label_size = None # [component] 57 | self.label_dtype = None 58 | self._np_labels = None 59 | self._tf_minibatch_in = None 60 | self._tf_labels_var = None 61 | self._tf_labels_dataset = None 62 | self._tf_datasets = dict() 63 | self._tf_iterator = None 64 | self._tf_init_ops = dict() 65 | self._tf_minibatch_np = None 66 | self._cur_minibatch = -1 67 | self._cur_lod = -1 68 | 69 | # List tfrecords files and inspect their shapes. 70 | assert os.path.isdir(self.tfrecord_dir) 71 | tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords'))) 72 | assert len(tfr_files) >= 1 73 | tfr_shapes = [] 74 | for tfr_file in tfr_files: 75 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) 76 | for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt): 77 | tfr_shapes.append(parse_tfrecord_np(record).shape) 78 | break 79 | 80 | # Autodetect label filename. 81 | if self.label_file is None: 82 | guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels'))) 83 | if len(guess): 84 | self.label_file = guess[0] 85 | elif not os.path.isfile(self.label_file): 86 | guess = os.path.join(self.tfrecord_dir, self.label_file) 87 | if os.path.isfile(guess): 88 | self.label_file = guess 89 | 90 | # Determine shape and resolution. 
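        # (lod = level of detail: lod 0 is the full resolution and each
        # increment halves it, so e.g. a 1024x1024 dataset is expected to
        # provide tfrecords for lods 0..8, i.e. 1024x1024 down to 4x4.)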
91 | max_shape = max(tfr_shapes, key=np.prod) 92 | self.resolution = resolution if resolution is not None else max_shape[1] 93 | self.resolution_log2 = int(np.log2(self.resolution)) 94 | self.shape = [max_shape[0], self.resolution, self.resolution] 95 | tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes] 96 | assert all(shape[0] == max_shape[0] for shape in tfr_shapes) 97 | assert all(shape[1] == shape[2] for shape in tfr_shapes) 98 | assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods)) 99 | assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1)) 100 | 101 | # Load labels. 102 | assert max_label_size == 'full' or max_label_size >= 0 103 | self._np_labels = np.zeros([1<<20, 0], dtype=np.float32) 104 | if self.label_file is not None and max_label_size != 0: 105 | self._np_labels = np.load(self.label_file) 106 | assert self._np_labels.ndim == 2 107 | if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size: 108 | self._np_labels = self._np_labels[:, :max_label_size] 109 | self.label_size = self._np_labels.shape[1] 110 | self.label_dtype = self._np_labels.dtype.name 111 | 112 | # Build TF expressions. 113 | with tf.name_scope('Dataset'), tf.device('/cpu:0'): 114 | self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[]) 115 | self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var') 116 | self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var) 117 | for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods): 118 | if tfr_lod < 0: 119 | continue 120 | dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20) 121 | dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads) 122 | dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset)) 123 | bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize 124 | if shuffle_mb > 0: 125 | dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1) 126 | if repeat: 127 | dset = dset.repeat() 128 | if prefetch_mb > 0: 129 | dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1) 130 | dset = dset.batch(self._tf_minibatch_in) 131 | self._tf_datasets[tfr_lod] = dset 132 | self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes) 133 | self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()} 134 | 135 | # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf(). 136 | def configure(self, minibatch_size, lod=0): 137 | lod = int(np.floor(lod)) 138 | assert minibatch_size >= 1 and lod in self._tf_datasets 139 | if self._cur_minibatch != minibatch_size or self._cur_lod != lod: 140 | self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size}) 141 | self._cur_minibatch = minibatch_size 142 | self._cur_lod = lod 143 | 144 | # Get next minibatch as TensorFlow expressions. 145 | def get_minibatch_tf(self): # => images, labels 146 | return self._tf_iterator.get_next() 147 | 148 | # Get next minibatch as NumPy arrays. 149 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels 150 | self.configure(minibatch_size, lod) 151 | if self._tf_minibatch_np is None: 152 | self._tf_minibatch_np = self.get_minibatch_tf() 153 | return tflib.run(self._tf_minibatch_np) 154 | 155 | # Get random labels as TensorFlow expression. 
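    # (For unconditional training, e.g. FFHQ with the default max_label_size=0,
    # label_size is 0 and the label getters below return empty [minibatch, 0]
    # arrays.)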
156 | def get_random_labels_tf(self, minibatch_size): # => labels 157 | if self.label_size > 0: 158 | with tf.device('/cpu:0'): 159 | return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32)) 160 | return tf.zeros([minibatch_size, 0], self.label_dtype) 161 | 162 | # Get random labels as NumPy array. 163 | def get_random_labels_np(self, minibatch_size): # => labels 164 | if self.label_size > 0: 165 | return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])] 166 | return np.zeros([minibatch_size, 0], self.label_dtype) 167 | 168 | #---------------------------------------------------------------------------- 169 | # Base class for datasets that are generated on the fly. 170 | 171 | class SyntheticDataset: 172 | def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dynamic_range=[0,255], label_size=0, label_dtype='float32'): 173 | self.resolution = resolution 174 | self.resolution_log2 = int(np.log2(resolution)) 175 | self.shape = [num_channels, resolution, resolution] 176 | self.dtype = dtype 177 | self.dynamic_range = dynamic_range 178 | self.label_size = label_size 179 | self.label_dtype = label_dtype 180 | self._tf_minibatch_var = None 181 | self._tf_lod_var = None 182 | self._tf_minibatch_np = None 183 | self._tf_labels_np = None 184 | 185 | assert self.resolution == 2 ** self.resolution_log2 186 | with tf.name_scope('Dataset'): 187 | self._tf_minibatch_var = tf.Variable(np.int32(0), name='minibatch_var') 188 | self._tf_lod_var = tf.Variable(np.int32(0), name='lod_var') 189 | 190 | def configure(self, minibatch_size, lod=0): 191 | lod = int(np.floor(lod)) 192 | assert minibatch_size >= 1 and 0 <= lod <= self.resolution_log2 193 | tflib.set_vars({self._tf_minibatch_var: minibatch_size, self._tf_lod_var: lod}) 194 | 195 | def get_minibatch_tf(self): # => images, labels 196 | with tf.name_scope('SyntheticDataset'): 197 | shrink = tf.cast(2.0 ** tf.cast(self._tf_lod_var, tf.float32), tf.int32) 198 | shape = [self.shape[0], self.shape[1] // shrink, self.shape[2] // shrink] 199 | images = self._generate_images(self._tf_minibatch_var, self._tf_lod_var, shape) 200 | labels = self._generate_labels(self._tf_minibatch_var) 201 | return images, labels 202 | 203 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels 204 | self.configure(minibatch_size, lod) 205 | if self._tf_minibatch_np is None: 206 | self._tf_minibatch_np = self.get_minibatch_tf() 207 | return tflib.run(self._tf_minibatch_np) 208 | 209 | def get_random_labels_tf(self, minibatch_size): # => labels 210 | with tf.name_scope('SyntheticDataset'): 211 | return self._generate_labels(minibatch_size) 212 | 213 | def get_random_labels_np(self, minibatch_size): # => labels 214 | self.configure(minibatch_size) 215 | if self._tf_labels_np is None: 216 | self._tf_labels_np = self.get_random_labels_tf(minibatch_size) 217 | return tflib.run(self._tf_labels_np) 218 | 219 | def _generate_images(self, minibatch, lod, shape): # to be overridden by subclasses # pylint: disable=unused-argument 220 | return tf.zeros([minibatch] + shape, self.dtype) 221 | 222 | def _generate_labels(self, minibatch): # to be overridden by subclasses 223 | return tf.zeros([minibatch, self.label_size], self.label_dtype) 224 | 225 | #---------------------------------------------------------------------------- 226 | # Helper func for constructing a dataset object using the given options. 
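# Typical usage (a sketch mirroring how metric_base calls it above):
#
#   dataset_obj = load_dataset(data_dir=config.data_dir, tfrecord_dir='ffhq', verbose=True)
#   images, labels = dataset_obj.get_minibatch_np(8)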
227 | 228 | def load_dataset(class_name='training.dataset.TFRecordDataset', data_dir=None, verbose=False, **kwargs): 229 | adjusted_kwargs = dict(kwargs) 230 | if 'tfrecord_dir' in adjusted_kwargs and data_dir is not None: 231 | adjusted_kwargs['tfrecord_dir'] = os.path.join(data_dir, adjusted_kwargs['tfrecord_dir']) 232 | if verbose: 233 | print('Streaming data using %s...' % class_name) 234 | dataset = dnnlib.util.get_obj_by_name(class_name)(**adjusted_kwargs) 235 | if verbose: 236 | print('Dataset shape =', np.int32(dataset.shape).tolist()) 237 | print('Dynamic range =', dataset.dynamic_range) 238 | print('Label size =', dataset.label_size) 239 | return dataset 240 | 241 | #---------------------------------------------------------------------------- 242 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Loss functions.""" 9 | 10 | import tensorflow as tf 11 | import dnnlib.tflib as tflib 12 | from dnnlib.tflib.autosummary import autosummary 13 | 14 | #---------------------------------------------------------------------------- 15 | # Convenience func that casts all of its arguments to tf.float32. 16 | 17 | def fp32(*values): 18 | if len(values) == 1 and isinstance(values[0], tuple): 19 | values = values[0] 20 | values = tuple(tf.cast(v, tf.float32) for v in values) 21 | return values if len(values) >= 2 else values[0] 22 | 23 | #---------------------------------------------------------------------------- 24 | # WGAN & WGAN-GP loss functions. 25 | 26 | def G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 27 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 28 | labels = training_set.get_random_labels_tf(minibatch_size) 29 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 30 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 31 | loss = -fake_scores_out 32 | return loss 33 | 34 | def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 35 | wgan_epsilon = 0.001): # Weight for the epsilon term, \epsilon_{drift}. 36 | 37 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 38 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 39 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 40 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 41 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 42 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 43 | loss = fake_scores_out - real_scores_out 44 | 45 | with tf.name_scope('EpsilonPenalty'): 46 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) 47 | loss += epsilon_penalty * wgan_epsilon 48 | return loss 49 | 50 | def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 51 | wgan_lambda = 10.0, # Weight for the gradient penalty term. 
52 | wgan_epsilon = 0.001, # Weight for the epsilon term, \epsilon_{drift}. 53 | wgan_target = 1.0): # Target value for gradient magnitudes. 54 | 55 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 56 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 57 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 58 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 59 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 60 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 61 | loss = fake_scores_out - real_scores_out 62 | 63 | with tf.name_scope('GradientPenalty'): 64 | mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) 65 | mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) 66 | mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) 67 | mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) 68 | mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) 69 | mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) 70 | mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) 71 | mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) 72 | gradient_penalty = tf.square(mixed_norms - wgan_target) 73 | loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) 74 | 75 | with tf.name_scope('EpsilonPenalty'): 76 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) 77 | loss += epsilon_penalty * wgan_epsilon 78 | return loss 79 | 80 | #---------------------------------------------------------------------------- 81 | # Hinge loss functions. (Use G_wgan with these) 82 | 83 | def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument 84 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 85 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 86 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 87 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 88 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 89 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 90 | loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) 91 | return loss 92 | 93 | def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 94 | wgan_lambda = 10.0, # Weight for the gradient penalty term. 95 | wgan_target = 1.0): # Target value for gradient magnitudes. 
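    # (Same interpolation-based gradient penalty as D_wgan_gp above: penalize
    # (||grad_x D(x_mixed)||_2 - wgan_target)^2 on random real/fake mixtures,
    # weighted by wgan_lambda / wgan_target**2.)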
96 | 97 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 98 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 99 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 100 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 101 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 102 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 103 | loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) 104 | 105 | with tf.name_scope('GradientPenalty'): 106 | mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) 107 | mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) 108 | mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) 109 | mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) 110 | mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) 111 | mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) 112 | mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) 113 | mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) 114 | gradient_penalty = tf.square(mixed_norms - wgan_target) 115 | loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) 116 | return loss 117 | 118 | 119 | #---------------------------------------------------------------------------- 120 | # Loss functions advocated by the paper 121 | # "Which Training Methods for GANs do actually Converge?" 122 | 123 | def G_logistic_saturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 124 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 125 | labels = training_set.get_random_labels_tf(minibatch_size) 126 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 127 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 128 | loss = -tf.nn.softplus(fake_scores_out) # log(1 - logistic(fake_scores_out)) 129 | return loss 130 | 131 | def G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 132 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 133 | labels = training_set.get_random_labels_tf(minibatch_size) 134 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 135 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 136 | loss = tf.nn.softplus(-fake_scores_out) # -log(logistic(fake_scores_out)) 137 | return loss 138 | 139 | def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument 140 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 141 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 142 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 143 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 144 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 145 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 146 | loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) 147 | loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type 148 | return loss 149 | 150 | 
def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0): # pylint: disable=unused-argument 151 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 152 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 153 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 154 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 155 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 156 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 157 | loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) 158 | loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type 159 | 160 | if r1_gamma != 0.0: 161 | with tf.name_scope('R1Penalty'): 162 | real_loss = opt.apply_loss_scaling(tf.reduce_sum(real_scores_out)) 163 | real_grads = opt.undo_loss_scaling(fp32(tf.gradients(real_loss, [reals])[0])) 164 | r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3]) 165 | r1_penalty = autosummary('Loss/r1_penalty', r1_penalty) 166 | loss += r1_penalty * (r1_gamma * 0.5) 167 | 168 | if r2_gamma != 0.0: 169 | with tf.name_scope('R2Penalty'): 170 | fake_loss = opt.apply_loss_scaling(tf.reduce_sum(fake_scores_out)) 171 | fake_grads = opt.undo_loss_scaling(fp32(tf.gradients(fake_loss, [fake_images_out])[0])) 172 | r2_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3]) 173 | r2_penalty = autosummary('Loss/r2_penalty', r2_penalty) 174 | loss += r2_penalty * (r2_gamma * 0.5) 175 | return loss 176 | 177 | #---------------------------------------------------------------------------- 178 | -------------------------------------------------------------------------------- /training/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Miscellaneous utility functions.""" 9 | 10 | import os 11 | import glob 12 | import pickle 13 | import re 14 | import numpy as np 15 | from collections import defaultdict 16 | import PIL.Image 17 | import dnnlib 18 | 19 | import config 20 | from training import dataset 21 | 22 | #---------------------------------------------------------------------------- 23 | # Convenience wrappers for pickle that are able to load data produced by 24 | # older versions of the code, and from external URLs. 25 | 26 | def open_file_or_url(file_or_url): 27 | if dnnlib.util.is_url(file_or_url): 28 | return dnnlib.util.open_url(file_or_url, cache_dir=config.cache_dir) 29 | return open(file_or_url, 'rb') 30 | 31 | def load_pkl(file_or_url): 32 | with open_file_or_url(file_or_url) as file: 33 | return pickle.load(file, encoding='latin1') 34 | 35 | def save_pkl(obj, filename): 36 | with open(filename, 'wb') as file: 37 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) 38 | 39 | #---------------------------------------------------------------------------- 40 | # Image utils. 
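# adjust_dynamic_range below linearly remaps pixel values between conventions,
# e.g. from the generator's [-1, 1] output range to [0, 255] for saving:
#
#   img_uint8_range = adjust_dynamic_range(img, drange_in=[-1, 1], drange_out=[0, 255])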
41 | 42 | def adjust_dynamic_range(data, drange_in, drange_out): 43 | if drange_in != drange_out: 44 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0])) 45 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale) 46 | data = data * scale + bias 47 | return data 48 | 49 | def create_image_grid(images, grid_size=None): 50 | assert images.ndim == 3 or images.ndim == 4 51 | num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2] 52 | 53 | if grid_size is not None: 54 | grid_w, grid_h = tuple(grid_size) 55 | else: 56 | grid_w = max(int(np.ceil(np.sqrt(num))), 1) 57 | grid_h = max((num - 1) // grid_w + 1, 1) 58 | 59 | grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype) 60 | for idx in range(num): 61 | x = (idx % grid_w) * img_w 62 | y = (idx // grid_w) * img_h 63 | grid[..., y : y + img_h, x : x + img_w] = images[idx] 64 | return grid 65 | 66 | def convert_to_pil_image(image, drange=[0,1]): 67 | assert image.ndim == 2 or image.ndim == 3 68 | if image.ndim == 3: 69 | if image.shape[0] == 1: 70 | image = image[0] # grayscale CHW => HW 71 | else: 72 | image = image.transpose(1, 2, 0) # CHW -> HWC 73 | 74 | image = adjust_dynamic_range(image, drange, [0,255]) 75 | image = np.rint(image).clip(0, 255).astype(np.uint8) 76 | fmt = 'RGB' if image.ndim == 3 else 'L' 77 | return PIL.Image.fromarray(image, fmt) 78 | 79 | def save_image(image, filename, drange=[0,1], quality=95): 80 | img = convert_to_pil_image(image, drange) 81 | if '.jpg' in filename: 82 | img.save(filename,"JPEG", quality=quality, optimize=True) 83 | else: 84 | img.save(filename) 85 | 86 | def save_image_grid(images, filename, drange=[0,1], grid_size=None): 87 | convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename) 88 | 89 | #---------------------------------------------------------------------------- 90 | # Locating results. 
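# (A run may be referenced by its numeric id, which is matched against the
# zero-padded '<id>-<desc>' subdirectories of config.result_dir, or by an
# explicit run directory path; locate_network_pkl also accepts a direct
# path to a .pkl file.)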
91 | 92 | def locate_run_dir(run_id_or_run_dir): 93 | if isinstance(run_id_or_run_dir, str): 94 | if os.path.isdir(run_id_or_run_dir): 95 | return run_id_or_run_dir 96 | converted = dnnlib.submission.submit.convert_path(run_id_or_run_dir) 97 | if os.path.isdir(converted): 98 | return converted 99 | 100 | run_dir_pattern = re.compile('^0*%s-' % str(run_id_or_run_dir)) 101 | for search_dir in ['']: 102 | full_search_dir = config.result_dir if search_dir == '' else os.path.normpath(os.path.join(config.result_dir, search_dir)) 103 | run_dir = os.path.join(full_search_dir, str(run_id_or_run_dir)) 104 | if os.path.isdir(run_dir): 105 | return run_dir 106 | run_dirs = sorted(glob.glob(os.path.join(full_search_dir, '*'))) 107 | run_dirs = [run_dir for run_dir in run_dirs if run_dir_pattern.match(os.path.basename(run_dir))] 108 | run_dirs = [run_dir for run_dir in run_dirs if os.path.isdir(run_dir)] 109 | if len(run_dirs) == 1: 110 | return run_dirs[0] 111 | raise IOError('Cannot locate result subdir for run', run_id_or_run_dir) 112 | 113 | def list_network_pkls(run_id_or_run_dir, include_final=True): 114 | run_dir = locate_run_dir(run_id_or_run_dir) 115 | pkls = sorted(glob.glob(os.path.join(run_dir, 'network-*.pkl'))) 116 | if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl': 117 | if include_final: 118 | pkls.append(pkls[0]) 119 | del pkls[0] 120 | return pkls 121 | 122 | def locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): 123 | for candidate in [snapshot_or_network_pkl, run_id_or_run_dir_or_network_pkl]: 124 | if isinstance(candidate, str): 125 | if os.path.isfile(candidate): 126 | return candidate 127 | converted = dnnlib.submission.submit.convert_path(candidate) 128 | if os.path.isfile(converted): 129 | return converted 130 | 131 | pkls = list_network_pkls(run_id_or_run_dir_or_network_pkl) 132 | if len(pkls) >= 1 and snapshot_or_network_pkl is None: 133 | return pkls[-1] 134 | 135 | for pkl in pkls: 136 | try: 137 | name = os.path.splitext(os.path.basename(pkl))[0] 138 | number = int(name.split('-')[-1]) 139 | if number == snapshot_or_network_pkl: 140 | return pkl 141 | except ValueError: pass 142 | except IndexError: pass 143 | raise IOError('Cannot locate network pkl for snapshot', snapshot_or_network_pkl) 144 | 145 | def get_id_string_for_network_pkl(network_pkl): 146 | p = network_pkl.replace('.pkl', '').replace('\\', '/').split('/') 147 | return '-'.join(p[max(len(p) - 2, 0):]) 148 | 149 | #---------------------------------------------------------------------------- 150 | # Loading data from previous training runs. 151 | 152 | def load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): 153 | return load_pkl(locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl)) 154 | 155 | def parse_config_for_previous_run(run_id): 156 | run_dir = locate_run_dir(run_id) 157 | 158 | # Parse config.txt. 159 | cfg = defaultdict(dict) 160 | with open(os.path.join(run_dir, 'config.txt'), 'rt') as f: 161 | for line in f: 162 | line = re.sub(r"^{?\s*'(\w+)':\s*{(.*)(},|}})$", r"\1 = {\2}", line.strip()) 163 | if line.startswith('dataset =') or line.startswith('train ='): 164 | exec(line, cfg, cfg) # pylint: disable=exec-used 165 | 166 | # Handle legacy options. 
167 | if 'file_pattern' in cfg['dataset']: 168 | cfg['dataset']['tfrecord_dir'] = cfg['dataset'].pop('file_pattern').replace('-r??.tfrecords', '') 169 | if 'mirror_augment' in cfg['dataset']: 170 | cfg['train']['mirror_augment'] = cfg['dataset'].pop('mirror_augment') 171 | if 'max_labels' in cfg['dataset']: 172 | v = cfg['dataset'].pop('max_labels') 173 | if v is None: v = 0 174 | if v == 'all': v = 'full' 175 | cfg['dataset']['max_label_size'] = v 176 | if 'max_images' in cfg['dataset']: 177 | cfg['dataset'].pop('max_images') 178 | return cfg 179 | 180 | def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment 181 | cfg = parse_config_for_previous_run(run_id) 182 | cfg['dataset'].update(kwargs) 183 | dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **cfg['dataset']) 184 | mirror_augment = cfg['train'].get('mirror_augment', False) 185 | return dataset_obj, mirror_augment 186 | 187 | def apply_mirror_augment(minibatch): 188 | mask = np.random.rand(minibatch.shape[0]) < 0.5 189 | minibatch = np.array(minibatch) 190 | minibatch[mask] = minibatch[mask, :, :, ::-1] 191 | return minibatch 192 | 193 | #---------------------------------------------------------------------------- 194 | # Size and contents of the image snapshot grids that are exported 195 | # periodically during training. 196 | 197 | def setup_snapshot_image_grid(G, training_set, 198 | size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display. 199 | layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label. 200 | 201 | # Select size. 202 | gw = 1; gh = 1 203 | if size == '1080p': 204 | gw = np.clip(1920 // G.output_shape[3], 3, 32) 205 | gh = np.clip(1080 // G.output_shape[2], 2, 32) 206 | if size == '4k': 207 | gw = np.clip(3840 // G.output_shape[3], 7, 32) 208 | gh = np.clip(2160 // G.output_shape[2], 4, 32) 209 | 210 | # Initialize data arrays. 211 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) 212 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) 213 | latents = np.random.randn(gw * gh, *G.input_shape[1:]) 214 | 215 | # Random layout. 216 | if layout == 'random': 217 | reals[:], labels[:] = training_set.get_minibatch_np(gw * gh) 218 | 219 | # Class-conditional layouts. 
220 | class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4]) 221 | if layout in class_layouts: 222 | bw, bh = class_layouts[layout] 223 | nw = (gw - 1) // bw + 1 224 | nh = (gh - 1) // bh + 1 225 | blocks = [[] for _i in range(nw * nh)] 226 | for _iter in range(1000000): 227 | real, label = training_set.get_minibatch_np(1) 228 | idx = np.argmax(label[0]) 229 | while idx < len(blocks) and len(blocks[idx]) >= bw * bh: 230 | idx += training_set.label_size 231 | if idx < len(blocks): 232 | blocks[idx].append((real, label)) 233 | if all(len(block) >= bw * bh for block in blocks): 234 | break 235 | for i, block in enumerate(blocks): 236 | for j, (real, label) in enumerate(block): 237 | x = (i % nw) * bw + j % bw 238 | y = (i // nw) * bh + j // bw 239 | if x < gw and y < gh: 240 | reals[x + y * gw] = real[0] 241 | labels[x + y * gw] = label[0] 242 | 243 | return (gw, gh), reals, labels, latents 244 | 245 | #---------------------------------------------------------------------------- 246 | -------------------------------------------------------------------------------- /training/networks_progan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Network architectures used in the ProGAN paper.""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | # NOTE: Do not import any application-specific modules here! 14 | # Specify all network parameters as kwargs. 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def lerp(a, b, t): return a + (b - a) * t 19 | def lerp_clip(a, b, t): return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 20 | def cset(cur_lambda, new_cond, new_lambda): return lambda: tf.cond(new_cond, new_lambda, cur_lambda) 21 | 22 | #---------------------------------------------------------------------------- 23 | # Get/create weight tensor for a convolutional or fully-connected layer. 24 | 25 | def get_weight(shape, gain=np.sqrt(2), use_wscale=False): 26 | fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out] 27 | std = gain / np.sqrt(fan_in) # He init 28 | if use_wscale: 29 | wscale = tf.constant(np.float32(std), name='wscale') 30 | w = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale 31 | else: 32 | w = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std)) 33 | return w 34 | 35 | #---------------------------------------------------------------------------- 36 | # Fully-connected layer. 37 | 38 | def dense(x, fmaps, gain=np.sqrt(2), use_wscale=False): 39 | if len(x.shape) > 2: 40 | x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])]) 41 | w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 42 | w = tf.cast(w, x.dtype) 43 | return tf.matmul(x, w) 44 | 45 | #---------------------------------------------------------------------------- 46 | # Convolutional layer. 
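# (Like dense() above, conv2d obtains its weights via get_weight(): with
# use_wscale=True the variable is stored at unit variance and multiplied by
# the He constant gain/sqrt(fan_in) at run time, equalizing the effective
# learning rate across layers instead of baking the scale into the initializer.)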
47 | 48 | def conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 49 | assert kernel >= 1 and kernel % 2 == 1 50 | w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 51 | w = tf.cast(w, x.dtype) 52 | return tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='SAME', data_format='NCHW') 53 | 54 | #---------------------------------------------------------------------------- 55 | # Apply bias to the given activation tensor. 56 | 57 | def apply_bias(x): 58 | b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros()) 59 | b = tf.cast(b, x.dtype) 60 | if len(x.shape) == 2: 61 | return x + b 62 | return x + tf.reshape(b, [1, -1, 1, 1]) 63 | 64 | #---------------------------------------------------------------------------- 65 | # Leaky ReLU activation. Same as tf.nn.leaky_relu, but supports FP16. 66 | 67 | def leaky_relu(x, alpha=0.2): 68 | with tf.name_scope('LeakyRelu'): 69 | alpha = tf.constant(alpha, dtype=x.dtype, name='alpha') 70 | return tf.maximum(x * alpha, x) 71 | 72 | #---------------------------------------------------------------------------- 73 | # Nearest-neighbor upscaling layer. 74 | 75 | def upscale2d(x, factor=2): 76 | assert isinstance(factor, int) and factor >= 1 77 | if factor == 1: return x 78 | with tf.variable_scope('Upscale2D'): 79 | s = x.shape 80 | x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) 81 | x = tf.tile(x, [1, 1, 1, factor, 1, factor]) 82 | x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) 83 | return x 84 | 85 | #---------------------------------------------------------------------------- 86 | # Fused upscale2d + conv2d. 87 | # Faster and uses less memory than performing the operations separately. 88 | 89 | def upscale2d_conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 90 | assert kernel >= 1 and kernel % 2 == 1 91 | w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 92 | w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in] 93 | w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') 94 | w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) 95 | w = tf.cast(w, x.dtype) 96 | os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2] 97 | return tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW') 98 | 99 | #---------------------------------------------------------------------------- 100 | # Box filter downscaling layer. 101 | 102 | def downscale2d(x, factor=2): 103 | assert isinstance(factor, int) and factor >= 1 104 | if factor == 1: return x 105 | with tf.variable_scope('Downscale2D'): 106 | ksize = [1, 1, factor, factor] 107 | return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') # NOTE: requires tf_config['graph_options.place_pruned_graph'] = True 108 | 109 | #---------------------------------------------------------------------------- 110 | # Fused conv2d + downscale2d. 111 | # Faster and uses less memory than performing the operations separately. 
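# The fusion pads the k x k kernel to (k+1) x (k+1) as the sum of four
# shifted copies scaled by 0.25 (a 2x2 box filter baked into the kernel), so
# a single stride-2 convolution reproduces conv2d followed by downscale2d.
# Equivalence sketch (hypothetical sizes; each call creates its own 'weight'
# variable, so the outputs only match numerically if weights are shared):
#
#   x = tf.random_normal([8, 64, 32, 32])
#   with tf.variable_scope('Ref'):
#       ref = downscale2d(conv2d(x, fmaps=128, kernel=3))   # two separate ops
#   with tf.variable_scope('Fused'):
#       out = conv2d_downscale2d(x, fmaps=128, kernel=3)    # one fused op
#   # both tensors have shape [8, 128, 16, 16]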
112 | 113 | def conv2d_downscale2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): 114 | assert kernel >= 1 and kernel % 2 == 1 115 | w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) 116 | w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') 117 | w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25 118 | w = tf.cast(w, x.dtype) 119 | return tf.nn.conv2d(x, w, strides=[1,1,2,2], padding='SAME', data_format='NCHW') 120 | 121 | #---------------------------------------------------------------------------- 122 | # Pixelwise feature vector normalization. 123 | 124 | def pixel_norm(x, epsilon=1e-8): 125 | with tf.variable_scope('PixelNorm'): 126 | return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon) 127 | 128 | #---------------------------------------------------------------------------- 129 | # Minibatch standard deviation. 130 | 131 | def minibatch_stddev_layer(x, group_size=4, num_new_features=1): 132 | with tf.variable_scope('MinibatchStddev'): 133 | group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. 134 | s = x.shape # [NCHW] Input shape. 135 | y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c. 136 | y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32. 137 | y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group. 138 | y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group. 139 | y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group. 140 | y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True) # [Mn111] Take average over fmaps and pixels. 141 | y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups 142 | y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type. 143 | y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels. 144 | return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap. 145 | 146 | #---------------------------------------------------------------------------- 147 | # Networks used in the ProgressiveGAN paper. 148 | 149 | def G_paper( 150 | latents_in, # First input: Latent vectors [minibatch, latent_size]. 151 | labels_in, # Second input: Labels [minibatch, label_size]. 152 | num_channels = 1, # Number of output color channels. Overridden based on dataset. 153 | resolution = 32, # Output resolution. Overridden based on dataset. 154 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 155 | fmap_base = 8192, # Overall multiplier for the number of feature maps. 156 | fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. 157 | fmap_max = 512, # Maximum number of feature maps in any layer. 158 | latent_size = None, # Dimensionality of the latent vectors. None = min(fmap_base, fmap_max). 159 | normalize_latents = True, # Normalize latent vectors before feeding them to the network? 160 | use_wscale = True, # Enable equalized learning rate? 161 | use_pixelnorm = True, # Enable pixelwise feature vector normalization? 162 | pixelnorm_epsilon = 1e-8, # Constant epsilon for pixelwise feature vector normalization. 163 | use_leakyrelu = True, # True = leaky ReLU, False = ReLU. 164 | dtype = 'float32', # Data type to use for activations and outputs. 
165 | fused_scale = True, # True = use fused upscale2d + conv2d, False = separate upscale2d layers. 166 | structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically. 167 | is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. 168 | **_kwargs): # Ignore unrecognized keyword args. 169 | 170 | resolution_log2 = int(np.log2(resolution)) 171 | assert resolution == 2**resolution_log2 and resolution >= 4 172 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 173 | def PN(x): return pixel_norm(x, epsilon=pixelnorm_epsilon) if use_pixelnorm else x 174 | if latent_size is None: latent_size = nf(0) 175 | if structure is None: structure = 'linear' if is_template_graph else 'recursive' 176 | act = leaky_relu if use_leakyrelu else tf.nn.relu 177 | 178 | latents_in.set_shape([None, latent_size]) 179 | labels_in.set_shape([None, label_size]) 180 | combo_in = tf.cast(tf.concat([latents_in, labels_in], axis=1), dtype) 181 | lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) 182 | images_out = None 183 | 184 | # Building blocks. 185 | def block(x, res): # res = 2..resolution_log2 186 | with tf.variable_scope('%dx%d' % (2**res, 2**res)): 187 | if res == 2: # 4x4 188 | if normalize_latents: x = pixel_norm(x, epsilon=pixelnorm_epsilon) 189 | with tf.variable_scope('Dense'): 190 | x = dense(x, fmaps=nf(res-1)*16, gain=np.sqrt(2)/4, use_wscale=use_wscale) # override gain to match the original Theano implementation 191 | x = tf.reshape(x, [-1, nf(res-1), 4, 4]) 192 | x = PN(act(apply_bias(x))) 193 | with tf.variable_scope('Conv'): 194 | x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 195 | else: # 8x8 and up 196 | if fused_scale: 197 | with tf.variable_scope('Conv0_up'): 198 | x = PN(act(apply_bias(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 199 | else: 200 | x = upscale2d(x) 201 | with tf.variable_scope('Conv0'): 202 | x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 203 | with tf.variable_scope('Conv1'): 204 | x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) 205 | return x 206 | def torgb(x, res): # res = 2..resolution_log2 207 | lod = resolution_log2 - res 208 | with tf.variable_scope('ToRGB_lod%d' % lod): 209 | return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale)) 210 | 211 | # Linear structure: simple but inefficient. 212 | if structure == 'linear': 213 | x = block(combo_in, 2) 214 | images_out = torgb(x, 2) 215 | for res in range(3, resolution_log2 + 1): 216 | lod = resolution_log2 - res 217 | x = block(x, res) 218 | img = torgb(x, res) 219 | images_out = upscale2d(images_out) 220 | with tf.variable_scope('Grow_lod%d' % lod): 221 | images_out = lerp_clip(img, images_out, lod_in - lod) 222 | 223 | # Recursive structure: complex but efficient. 
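# 'lod' (level of detail) counts how many resolution doublings remain to
# full size: lod = resolution_log2 - res. grow() assembles a chain of
# tf.cond ops via cset(), so at run time only the branch matching the
# current lod_in executes. Worked example (hypothetical): with
# resolution_log2 = 10 (a 1024x1024 network) and lod_in = 3.3, images render
# at 2 ** (10 - 3) = 128x128, lerping 30% of the upscaled 64x64 ToRGB output
# into the 128x128 one while the new layers fade in.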
224 | if structure == 'recursive': 225 | def grow(x, res, lod): 226 | y = block(x, res) 227 | img = lambda: upscale2d(torgb(y, res), 2**lod) 228 | if res > 2: img = cset(img, (lod_in > lod), lambda: upscale2d(lerp(torgb(y, res), upscale2d(torgb(x, res - 1)), lod_in - lod), 2**lod)) 229 | if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1)) 230 | return img() 231 | images_out = grow(combo_in, 2, resolution_log2 - 2) 232 | 233 | assert images_out.dtype == tf.as_dtype(dtype) 234 | images_out = tf.identity(images_out, name='images_out') 235 | return images_out 236 | 237 | 238 | def D_paper( 239 | images_in, # First input: Images [minibatch, channel, height, width]. 240 | labels_in, # Second input: Labels [minibatch, label_size]. 241 | num_channels = 1, # Number of input color channels. Overridden based on dataset. 242 | resolution = 32, # Input resolution. Overridden based on dataset. 243 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 244 | fmap_base = 8192, # Overall multiplier for the number of feature maps. 245 | fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. 246 | fmap_max = 512, # Maximum number of feature maps in any layer. 247 | use_wscale = True, # Enable equalized learning rate? 248 | mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable. 249 | dtype = 'float32', # Data type to use for activations and outputs. 250 | fused_scale = True, # True = use fused conv2d + downscale2d, False = separate downscale2d layers. 251 | structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically 252 | is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. 253 | **_kwargs): # Ignore unrecognized keyword args. 254 | 255 | resolution_log2 = int(np.log2(resolution)) 256 | assert resolution == 2**resolution_log2 and resolution >= 4 257 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 258 | if structure is None: structure = 'linear' if is_template_graph else 'recursive' 259 | act = leaky_relu 260 | 261 | images_in.set_shape([None, num_channels, resolution, resolution]) 262 | labels_in.set_shape([None, label_size]) 263 | images_in = tf.cast(images_in, dtype) 264 | labels_in = tf.cast(labels_in, dtype) 265 | lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) 266 | scores_out = None 267 | 268 | # Building blocks. 
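# The discriminator mirrors the generator: fromrgb() is a 1x1 convolution
# from RGB into feature space at any resolution, each block() halves the
# spatial size, and the final 4x4 block appends the minibatch-stddev feature
# map to its input before the dense scoring head. With the defaults
# fmap_base=8192, fmap_decay=1.0, fmap_max=512, the stage widths are
# nf(1..4) = 512, nf(5) = 256, nf(6) = 128, nf(7) = 64, nf(8) = 32.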
269 | def fromrgb(x, res): # res = 2..resolution_log2 270 | with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)): 271 | return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, use_wscale=use_wscale))) 272 | def block(x, res): # res = 2..resolution_log2 273 | with tf.variable_scope('%dx%d' % (2**res, 2**res)): 274 | if res >= 3: # 8x8 and up 275 | with tf.variable_scope('Conv0'): 276 | x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) 277 | if fused_scale: 278 | with tf.variable_scope('Conv1_down'): 279 | x = act(apply_bias(conv2d_downscale2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) 280 | else: 281 | with tf.variable_scope('Conv1'): 282 | x = act(apply_bias(conv2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) 283 | x = downscale2d(x) 284 | else: # 4x4 285 | if mbstd_group_size > 1: 286 | x = minibatch_stddev_layer(x, mbstd_group_size) 287 | with tf.variable_scope('Conv'): 288 | x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) 289 | with tf.variable_scope('Dense0'): 290 | x = act(apply_bias(dense(x, fmaps=nf(res-2), use_wscale=use_wscale))) 291 | with tf.variable_scope('Dense1'): 292 | x = apply_bias(dense(x, fmaps=1, gain=1, use_wscale=use_wscale)) 293 | return x 294 | 295 | # Linear structure: simple but inefficient. 296 | if structure == 'linear': 297 | img = images_in 298 | x = fromrgb(img, resolution_log2) 299 | for res in range(resolution_log2, 2, -1): 300 | lod = resolution_log2 - res 301 | x = block(x, res) 302 | img = downscale2d(img) 303 | y = fromrgb(img, res - 1) 304 | with tf.variable_scope('Grow_lod%d' % lod): 305 | x = lerp_clip(x, y, lod_in - lod) 306 | scores_out = block(x, 2) 307 | 308 | # Recursive structure: complex but efficient. 309 | if structure == 'recursive': 310 | def grow(res, lod): 311 | x = lambda: fromrgb(downscale2d(images_in, 2**lod), res) 312 | if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1)) 313 | x = block(x(), res); y = lambda: x 314 | if res > 2: y = cset(y, (lod_in > lod), lambda: lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod)) 315 | return y() 316 | scores_out = grow(2, resolution_log2 - 2) 317 | 318 | assert scores_out.dtype == tf.as_dtype(dtype) 319 | scores_out = tf.identity(scores_out, name='scores_out') 320 | return scores_out 321 | 322 | #---------------------------------------------------------------------------- 323 | -------------------------------------------------------------------------------- /training/training_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Main training script.""" 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | import dnnlib 14 | import dnnlib.tflib as tflib 15 | from dnnlib.tflib.autosummary import autosummary 16 | 17 | import config 18 | import train 19 | from training import dataset 20 | from training import misc 21 | from metrics import metric_base 22 | 23 | #---------------------------------------------------------------------------- 24 | # Just-in-time processing of training images before feeding them to the networks. 
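# process_reals() below (1) maps pixel data from the dataset range to
# drange_net, (2) optionally mirror-augments inside the graph, (3) box-blurs
# reals down and back up by one level and lerps by the fractional part of
# lod, so reals match what the generator can currently produce, and
# (4) nearest-neighbor upscales to the full network resolution. Call sketch
# (hypothetical values):
#
#   reals, labels = training_set.get_minibatch_tf()
#   reals = process_reals(reals, lod=3.0, mirror_augment=True,
#                         drange_data=[0, 255], drange_net=[-1, 1])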
25 | 26 | def process_reals(x, lod, mirror_augment, drange_data, drange_net): 27 | with tf.name_scope('ProcessReals'): 28 | with tf.name_scope('DynamicRange'): 29 | x = tf.cast(x, tf.float32) 30 | x = misc.adjust_dynamic_range(x, drange_data, drange_net) 31 | if mirror_augment: 32 | with tf.name_scope('MirrorAugment'): 33 | s = tf.shape(x) 34 | mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0) 35 | mask = tf.tile(mask, [1, s[1], s[2], s[3]]) 36 | x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3])) 37 | with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail. 38 | s = tf.shape(x) 39 | y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2]) 40 | y = tf.reduce_mean(y, axis=[3, 5], keepdims=True) 41 | y = tf.tile(y, [1, 1, 1, 2, 1, 2]) 42 | y = tf.reshape(y, [-1, s[1], s[2], s[3]]) 43 | x = tflib.lerp(x, y, lod - tf.floor(lod)) 44 | with tf.name_scope('UpscaleLOD'): # Upscale to match the expected input/output size of the networks. 45 | s = tf.shape(x) 46 | factor = tf.cast(2 ** tf.floor(lod), tf.int32) 47 | x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) 48 | x = tf.tile(x, [1, 1, 1, factor, 1, factor]) 49 | x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) 50 | return x 51 | 52 | #---------------------------------------------------------------------------- 53 | # Evaluate time-varying training parameters. 54 | 55 | def training_schedule( 56 | cur_nimg, 57 | training_set, 58 | num_gpus, 59 | lod_initial_resolution = 4, # Image resolution used at the beginning. 60 | lod_training_kimg = 600, # Thousands of real images to show before doubling the resolution. 61 | lod_transition_kimg = 600, # Thousands of real images to show when fading in new layers. 62 | minibatch_base = 16, # Maximum minibatch size, divided evenly among GPUs. 63 | minibatch_dict = {}, # Resolution-specific overrides. 64 | max_minibatch_per_gpu = {}, # Resolution-specific maximum minibatch size per GPU. 65 | G_lrate_base = 0.001, # Learning rate for the generator. 66 | G_lrate_dict = {}, # Resolution-specific overrides. 67 | D_lrate_base = 0.001, # Learning rate for the discriminator. 68 | D_lrate_dict = {}, # Resolution-specific overrides. 69 | lrate_rampup_kimg = 0, # Duration of learning rate ramp-up. 70 | tick_kimg_base = 160, # Default interval of progress snapshots. 71 | tick_kimg_dict = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:30, 1024:20}): # Resolution-specific overrides. 72 | 73 | # Initialize result dict. 74 | s = dnnlib.EasyDict() 75 | s.kimg = cur_nimg / 1000.0 76 | 77 | # Training phase. 78 | phase_dur = lod_training_kimg + lod_transition_kimg 79 | phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0 80 | phase_kimg = s.kimg - phase_idx * phase_dur 81 | 82 | # Level-of-detail and resolution. 83 | s.lod = training_set.resolution_log2 84 | s.lod -= np.floor(np.log2(lod_initial_resolution)) 85 | s.lod -= phase_idx 86 | if lod_transition_kimg > 0: 87 | s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg 88 | s.lod = max(s.lod, 0.0) 89 | s.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(s.lod))) 90 | 91 | # Minibatch size. 92 | s.minibatch = minibatch_dict.get(s.resolution, minibatch_base) 93 | s.minibatch -= s.minibatch % num_gpus 94 | if s.resolution in max_minibatch_per_gpu: 95 | s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus) 96 | 97 | # Learning rate. 
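# With the default lrate_rampup_kimg = 0 the rates below are used as-is; a
# positive value ramps both learning rates linearly from zero. Hypothetical
# numbers: lrate_rampup_kimg = 40 at s.kimg = 10 scales G_lrate and D_lrate
# by min(10 / 40, 1.0) = 0.25.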
98 | s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base) 99 | s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base) 100 | if lrate_rampup_kimg > 0: 101 | rampup = min(s.kimg / lrate_rampup_kimg, 1.0) 102 | s.G_lrate *= rampup 103 | s.D_lrate *= rampup 104 | 105 | # Other parameters. 106 | s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base) 107 | return s 108 | 109 | #---------------------------------------------------------------------------- 110 | # Main training script. 111 | 112 | def training_loop( 113 | submit_config, 114 | G_args = {}, # Options for generator network. 115 | D_args = {}, # Options for discriminator network. 116 | G_opt_args = {}, # Options for generator optimizer. 117 | D_opt_args = {}, # Options for discriminator optimizer. 118 | G_loss_args = {}, # Options for generator loss. 119 | D_loss_args = {}, # Options for discriminator loss. 120 | dataset_args = {}, # Options for dataset.load_dataset(). 121 | sched_args = {}, # Options for train.TrainingSchedule. 122 | grid_args = {}, # Options for train.setup_snapshot_image_grid(). 123 | metric_arg_list = [], # Options for MetricGroup. 124 | tf_config = {}, # Options for tflib.init_tf(). 125 | G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. 126 | D_repeats = 1, # How many times the discriminator is trained per G iteration. 127 | minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. 128 | reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? 129 | total_kimg = 15000, # Total length of the training, measured in thousands of real images. 130 | mirror_augment = False, # Enable mirror augment? 131 | drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. 132 | image_snapshot_ticks = 1, # How often to export image snapshots? 133 | network_snapshot_ticks = 10, # How often to export network snapshots? 134 | save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? 135 | save_weight_histograms = False, # Include weight histograms in the tfevents file? 136 | resume_run_id = None, # Run ID or network pkl to resume training from, None = start from scratch. 137 | resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. 138 | resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. 139 | resume_time = 0.0): # Assumed wallclock time at the beginning. Affects reporting. 140 | 141 | # Initialize dnnlib and TensorFlow. 142 | ctx = dnnlib.RunContext(submit_config, train) 143 | tflib.init_tf(tf_config) 144 | 145 | # Load training set. 146 | training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) 147 | 148 | # Construct networks. 149 | with tf.device('/gpu:0'): 150 | if resume_run_id is not None: 151 | network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) 152 | print('Loading networks from "%s"...' 
% network_pkl) 153 | G, D, Gs = misc.load_pkl(network_pkl) 154 | else: 155 | print('Constructing networks...') 156 | G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) 157 | D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) 158 | Gs = G.clone('Gs') 159 | G.print_layers(); D.print_layers() 160 | 161 | print('Building TensorFlow graph...') 162 | with tf.name_scope('Inputs'), tf.device('/cpu:0'): 163 | lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) 164 | lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) 165 | minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) 166 | minibatch_split = minibatch_in // submit_config.num_gpus 167 | Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 168 | 169 | G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) 170 | D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) 171 | for gpu in range(submit_config.num_gpus): 172 | with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): 173 | G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') 174 | D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') 175 | lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)] 176 | reals, labels = training_set.get_minibatch_tf() 177 | reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) 178 | with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops): 179 | G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) 180 | with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): 181 | D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) 182 | G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) 183 | D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) 184 | G_train_op = G_opt.apply_updates() 185 | D_train_op = D_opt.apply_updates() 186 | 187 | Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) 188 | with tf.device('/gpu:0'): 189 | try: 190 | peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() 191 | except tf.errors.NotFoundError: 192 | peak_gpu_mem_op = tf.constant(0) 193 | 194 | print('Setting up snapshot image grid...') 195 | grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args) 196 | sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) 197 | grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) 198 | 199 | print('Setting up run dir...') 200 | misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) 201 | misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) 202 | summary_log = tf.summary.FileWriter(submit_config.run_dir) 203 | if save_tf_graph: 204 | summary_log.add_graph(tf.get_default_graph()) 205 | if 
save_weight_histograms: 206 | G.setup_weight_histograms(); D.setup_weight_histograms() 207 | metrics = metric_base.MetricGroup(metric_arg_list) 208 | 209 | print('Training...\n') 210 | ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) 211 | maintenance_time = ctx.get_last_update_interval() 212 | cur_nimg = int(resume_kimg * 1000) 213 | cur_tick = 0 214 | tick_start_nimg = cur_nimg 215 | prev_lod = -1.0 216 | while cur_nimg < total_kimg * 1000: 217 | if ctx.should_stop(): break 218 | 219 | # Choose training parameters and configure training ops. 220 | sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) 221 | training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) 222 | if reset_opt_for_new_lod: 223 | if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod): 224 | G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state() 225 | prev_lod = sched.lod 226 | 227 | # Run training ops. 228 | for _mb_repeat in range(minibatch_repeats): 229 | for _D_repeat in range(D_repeats): 230 | tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch}) 231 | cur_nimg += sched.minibatch 232 | tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch}) 233 | 234 | # Perform maintenance tasks once per tick. 235 | done = (cur_nimg >= total_kimg * 1000) 236 | if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: 237 | cur_tick += 1 238 | tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 239 | tick_start_nimg = cur_nimg 240 | tick_time = ctx.get_time_since_last_update() 241 | total_time = ctx.get_time_since_start() + resume_time 242 | 243 | # Report progress. 244 | print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % ( 245 | autosummary('Progress/tick', cur_tick), 246 | autosummary('Progress/kimg', cur_nimg / 1000.0), 247 | autosummary('Progress/lod', sched.lod), 248 | autosummary('Progress/minibatch', sched.minibatch), 249 | dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), 250 | autosummary('Timing/sec_per_tick', tick_time), 251 | autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), 252 | autosummary('Timing/maintenance_sec', maintenance_time), 253 | autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) 254 | autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) 255 | autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) 256 | 257 | # Save snapshots. 258 | if cur_tick % image_snapshot_ticks == 0 or done: 259 | grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) 260 | misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) 261 | if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: 262 | pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) 263 | misc.save_pkl((G, D, Gs), pkl) 264 | metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config) 265 | 266 | # Update summaries and RunContext. 
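# The autosummary scalars collected during the tick are flushed below into
# the tfevents file in submit_config.run_dir, so a run can be monitored with
# TensorBoard, e.g. (assuming TensorBoard is installed; the log path depends
# on the configured result dir):
#
#   tensorboard --logdir <result_dir>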
267 | metrics.update_autosummaries() 268 | tflib.autosummary.save_summaries(summary_log, cur_nimg) 269 | ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) 270 | maintenance_time = ctx.get_last_update_interval() - tick_time 271 | 272 | # Write final results. 273 | misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) 274 | summary_log.close() 275 | 276 | ctx.close() 277 | 278 | #---------------------------------------------------------------------------- 279 | --------------------------------------------------------------------------------
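# A minimal launch sketch (illustrative only: train.py is the canonical entry
# point and configures many more options; the SubmitConfig fields set here
# are assumptions based on how training_loop() reads them):
#
#   import dnnlib
#   from dnnlib import EasyDict
#
#   submit_config = dnnlib.SubmitConfig()
#   submit_config.num_gpus = 1
#   submit_config.run_desc = 'progan-example'
#
#   dnnlib.submit_run(
#       submit_config,
#       'training.training_loop.training_loop',
#       dataset_args=EasyDict(tfrecord_dir='my_dataset'),  # hypothetical TFRecord dir
#       G_args=EasyDict(func_name='training.networks_progan.G_paper'),
#       D_args=EasyDict(func_name='training.networks_progan.D_paper'),
#       total_kimg=15000)
#
#   # G_loss_args / D_loss_args must also name functions from
#   # training/loss.py via func_name (omitted here).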