├── utils ├── __init__.py ├── README.md └── logger.py ├── docs ├── assets │ ├── igan.jpg │ ├── higan.jpg │ ├── teaser.jpg │ ├── genforce.png │ ├── mganprior.jpg │ ├── pix2pix.jpg │ ├── image2stylegan.jpg │ ├── interfacegan.jpg │ ├── teaser_video.gif │ ├── teaser_diffusion.gif │ ├── font.css │ └── style.css └── index.html ├── examples ├── 000001.png ├── 000002.png ├── 000003.png ├── 000004.png ├── 000005.png ├── 000006.png ├── 000007.png ├── 000008.png ├── 000009.png ├── 000010.png ├── 000011.png ├── 000012.png ├── 000013.png ├── 000014.png ├── 000015.png ├── 000016.png ├── 000017.png ├── target.list ├── context.list └── test.list ├── .gitignore ├── boundaries ├── stylegan_ffhq256 │ ├── age.npy │ ├── pose.npy │ ├── gender.npy │ ├── expression.npy │ └── eyeglasses.npy ├── stylegan_bedroom256 │ ├── cloth.npy │ ├── scary.npy │ ├── wood.npy │ ├── soothing.npy │ ├── cluttered_space.npy │ └── indoor_lighting.npy └── stylegan_tower256 │ ├── clouds.npy │ ├── sunny.npy │ └── vegetation.npy ├── metrics ├── __init__.py ├── frechet_inception_distance.py ├── perceptual_path_length.py ├── metric_base.py └── linear_separability.py ├── training ├── __init__.py ├── loss_encoder.py ├── misc.py ├── loss.py ├── training_loop_encoder.py └── dataset.py ├── dnnlib ├── submission │ ├── __init__.py │ ├── _internal │ │ └── run.py │ ├── run_context.py │ └── submit.py ├── tflib │ ├── __init__.py │ ├── autosummary.py │ ├── tfutil.py │ └── optimizer.py └── __init__.py ├── config.py ├── LICENSE.txt ├── perceptual_model.py ├── pretrained_example.py ├── train_encoder.py ├── run_metrics.py ├── README.md ├── interpolate.py ├── mix_style.py ├── manipulate.py ├── invert.py ├── diffuse.py └── generate_figures.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/assets/igan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/igan.jpg -------------------------------------------------------------------------------- /examples/000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000001.png -------------------------------------------------------------------------------- /examples/000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000002.png -------------------------------------------------------------------------------- /examples/000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000003.png -------------------------------------------------------------------------------- /examples/000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000004.png -------------------------------------------------------------------------------- /examples/000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000005.png -------------------------------------------------------------------------------- /examples/000006.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000006.png -------------------------------------------------------------------------------- /examples/000007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000007.png -------------------------------------------------------------------------------- /examples/000008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000008.png -------------------------------------------------------------------------------- /examples/000009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000009.png -------------------------------------------------------------------------------- /examples/000010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000010.png -------------------------------------------------------------------------------- /examples/000011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000011.png -------------------------------------------------------------------------------- /examples/000012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000012.png -------------------------------------------------------------------------------- /examples/000013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000013.png -------------------------------------------------------------------------------- /examples/000014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000014.png -------------------------------------------------------------------------------- /examples/000015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000015.png -------------------------------------------------------------------------------- /examples/000016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000016.png -------------------------------------------------------------------------------- /examples/000017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/examples/000017.png -------------------------------------------------------------------------------- /docs/assets/higan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/higan.jpg -------------------------------------------------------------------------------- /docs/assets/teaser.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/teaser.jpg -------------------------------------------------------------------------------- /examples/target.list: -------------------------------------------------------------------------------- 1 | examples/000001.png 2 | examples/000005.png 3 | examples/000006.png 4 | -------------------------------------------------------------------------------- /docs/assets/genforce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/genforce.png -------------------------------------------------------------------------------- /docs/assets/mganprior.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/mganprior.jpg -------------------------------------------------------------------------------- /docs/assets/pix2pix.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/pix2pix.jpg -------------------------------------------------------------------------------- /docs/assets/image2stylegan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/image2stylegan.jpg -------------------------------------------------------------------------------- /docs/assets/interfacegan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/interfacegan.jpg -------------------------------------------------------------------------------- /docs/assets/teaser_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/teaser_video.gif -------------------------------------------------------------------------------- /docs/assets/teaser_diffusion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/docs/assets/teaser_diffusion.gif -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | 4 | *.jpg 5 | *.png 6 | *.jpeg 7 | *.npy 8 | *.log 9 | /results/ 10 | -------------------------------------------------------------------------------- /boundaries/stylegan_ffhq256/age.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_ffhq256/age.npy -------------------------------------------------------------------------------- /boundaries/stylegan_ffhq256/pose.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_ffhq256/pose.npy -------------------------------------------------------------------------------- /boundaries/stylegan_bedroom256/cloth.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_bedroom256/cloth.npy -------------------------------------------------------------------------------- 
/boundaries/stylegan_bedroom256/scary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_bedroom256/scary.npy -------------------------------------------------------------------------------- /boundaries/stylegan_bedroom256/wood.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_bedroom256/wood.npy -------------------------------------------------------------------------------- /boundaries/stylegan_ffhq256/gender.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_ffhq256/gender.npy -------------------------------------------------------------------------------- /boundaries/stylegan_tower256/clouds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_tower256/clouds.npy -------------------------------------------------------------------------------- /boundaries/stylegan_tower256/sunny.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_tower256/sunny.npy -------------------------------------------------------------------------------- /boundaries/stylegan_ffhq256/expression.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_ffhq256/expression.npy -------------------------------------------------------------------------------- /boundaries/stylegan_ffhq256/eyeglasses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_ffhq256/eyeglasses.npy -------------------------------------------------------------------------------- /boundaries/stylegan_bedroom256/soothing.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_bedroom256/soothing.npy -------------------------------------------------------------------------------- /boundaries/stylegan_tower256/vegetation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_tower256/vegetation.npy -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Utility Functions 2 | 3 | Scripts under this folder are borrowed from [HiGAN](https://github.com/genforce/higan). 
4 | -------------------------------------------------------------------------------- /boundaries/stylegan_bedroom256/cluttered_space.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_bedroom256/cluttered_space.npy -------------------------------------------------------------------------------- /boundaries/stylegan_bedroom256/indoor_lighting.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/genforce/idinvert/HEAD/boundaries/stylegan_bedroom256/indoor_lighting.npy -------------------------------------------------------------------------------- /examples/context.list: -------------------------------------------------------------------------------- 1 | examples/000001.png 2 | examples/000002.png 3 | examples/000003.png 4 | examples/000004.png 5 | examples/000005.png 6 | examples/000006.png 7 | examples/000007.png 8 | examples/000008.png 9 | examples/000009.png 10 | examples/000010.png 11 | examples/000011.png 12 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /examples/test.list: -------------------------------------------------------------------------------- 1 | examples/000001.png 2 | examples/000002.png 3 | examples/000003.png 4 | examples/000004.png 5 | examples/000005.png 6 | examples/000006.png 7 | examples/000007.png 8 | examples/000008.png 9 | examples/000009.png 10 | examples/000010.png 11 | examples/000011.png 12 | examples/000012.png 13 | examples/000013.png 14 | examples/000014.png 15 | examples/000015.png 16 | examples/000016.png 17 | examples/000017.png 18 | -------------------------------------------------------------------------------- /dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import run_context 9 | from . 
import submit 10 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import autosummary 9 | from . import network 10 | from . import optimizer 11 | from . import tfutil 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Global configuration.""" 9 | 10 | #---------------------------------------------------------------------------- 11 | # Paths. 12 | 13 | result_dir = 'results' 14 | data_dir = 'datasets' 15 | cache_dir = 'cache' 16 | run_dir_ignore = ['results', 'datasets', 'cache'] 17 | 18 | #---------------------------------------------------------------------------- 19 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import submission 9 | 10 | from .submission.run_context import RunContext 11 | 12 | from .submission.submit import SubmitTarget 13 | from .submission.submit import PathType 14 | from .submission.submit import SubmitConfig 15 | from .submission.submit import get_path_from_template 16 | from .submission.submit import submit_run 17 | 18 | from .util import EasyDict 19 | 20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. 
21 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Yujun Shen 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /perceptual_model.py: -------------------------------------------------------------------------------- 1 | """Perceptual module for encoder training.""" 2 | 3 | from keras.models import Model 4 | from keras.layers import Flatten, Concatenate 5 | from keras.applications.vgg16 import VGG16, preprocess_input 6 | 7 | 8 | class PerceptualModel(Model): 9 | """Defines the VGG16 model for perceptual loss.""" 10 | 11 | def __init__(self, img_size, multi_layers=False): 12 | """Initializes with image size. 13 | 14 | Args: 15 | img_size: The (height, width) of the images fed to VGG16, e.g., (256, 256). 16 | multi_layers: Whether to use the outputs of multiple VGG16 layers or only a single layer.
17 | """ 18 | super().__init__() 19 | 20 | vgg = VGG16(include_top=False, input_shape=(img_size[0], img_size[1], 3)) 21 | if multi_layers: 22 | layer_ids = [2, 5, 9, 13, 17] 23 | layer_outputs = [ 24 | Flatten()(vgg.layers[layer_id].output) for layer_id in layer_ids] 25 | features = Concatenate(axis=-1)(layer_outputs) 26 | else: 27 | layer_ids = [13] # 13 -> conv4_3 28 | features = [ 29 | Flatten()(vgg.layers[layer_id].output) for layer_id in layer_ids] 30 | 31 | self._model = Model(inputs=vgg.input, outputs=features) 32 | 33 | def call(self, inputs, mask=None): 34 | return self._model(preprocess_input(inputs)) 35 | 36 | def compute_output_shape(self, input_shape): 37 | return self._model.compute_output_shape(input_shape) 38 | -------------------------------------------------------------------------------- /docs/assets/font.css: -------------------------------------------------------------------------------- 1 | /* Homepage Font */ 2 | 3 | /* latin-ext */ 4 | @font-face { 5 | font-family: 'Lato'; 6 | font-style: normal; 7 | font-weight: 400; 8 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2'); 9 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 10 | } 11 | 12 | /* latin */ 13 | @font-face { 14 | font-family: 'Lato'; 15 | font-style: normal; 16 | font-weight: 400; 17 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2'); 18 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 19 | } 20 | 21 | /* latin-ext */ 22 | @font-face { 23 | font-family: 'Lato'; 24 | font-style: normal; 25 | font-weight: 700; 26 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2'); 27 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 28 | } 29 | 30 | /* latin */ 31 | @font-face { 32 | font-family: 'Lato'; 33 | font-style: normal; 34 | font-weight: 700; 35 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2'); 36 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 37 | } 38 | -------------------------------------------------------------------------------- /dnnlib/submission/_internal/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for launching run functions in computing clusters. 9 | 10 | During the submit process, this file is copied to the appropriate run dir. 11 | When the job is launched in the cluster, this module is the first thing that 12 | is run inside the docker container. 
13 | """ 14 | 15 | import os 16 | import pickle 17 | import sys 18 | 19 | # PYTHONPATH should have been set so that the run_dir/src is in it 20 | import dnnlib 21 | 22 | def main(): 23 | if not len(sys.argv) >= 4: 24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") 25 | 26 | run_dir = str(sys.argv[1]) 27 | task_name = str(sys.argv[2]) 28 | host_name = str(sys.argv[3]) 29 | 30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl") 31 | 32 | # SubmitConfig should have been pickled to the run dir 33 | if not os.path.exists(submit_config_path): 34 | raise RuntimeError("SubmitConfig pickle file does not exist!") 35 | 36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) 37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name) 38 | 39 | submit_config.task_name = task_name 40 | submit_config.host_name = host_name 41 | 42 | dnnlib.submission.submit.run_wrapper(submit_config) 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /pretrained_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Minimal script for generating an image using pre-trained StyleGAN generator.""" 9 | 10 | import os 11 | import pickle 12 | import numpy as np 13 | import PIL.Image 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | import config 17 | 18 | def main(): 19 | # Initialize TensorFlow. 20 | tflib.init_tf() 21 | 22 | # Load pre-trained network. 23 | url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl 24 | with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: 25 | _G, _D, Gs = pickle.load(f) 26 | # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. 27 | # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. 28 | # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot. 29 | 30 | # Print network details. 31 | Gs.print_layers() 32 | 33 | # Pick latent vector. 34 | rnd = np.random.RandomState(5) 35 | latents = rnd.randn(1, Gs.input_shape[1]) 36 | 37 | # Generate image. 38 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 39 | images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt) 40 | 41 | # Save image. 
42 | os.makedirs(config.result_dir, exist_ok=True) 43 | png_filename = os.path.join(config.result_dir, 'example.png') 44 | PIL.Image.fromarray(images[0], 'RGB').save(png_filename) 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # python 3.7 2 | """Utility functions for logging.""" 3 | 4 | import os 5 | import sys 6 | import logging 7 | 8 | __all__ = ['setup_logger'] 9 | 10 | DEFAULT_WORK_DIR = 'results' 11 | 12 | def setup_logger(work_dir=None, logfile_name='log.txt', logger_name='logger'): 13 | """Sets up a logger for the target work directory. 14 | 15 | This function sets up a logger with `DEBUG` log level. Two handlers will 16 | be added to the logger automatically. One is the `sys.stdout` stream, with 17 | `INFO` log level, which prints important messages on the screen. The other 18 | is used to save all messages to the file `$WORK_DIR/$LOGFILE_NAME`. A time 19 | stamp and the log level are prepended to each message before it is logged. 20 | 21 | NOTE: If `logfile_name` is empty, the file stream will be skipped. Also, 22 | `DEFAULT_WORK_DIR` will be used as the default work directory. 23 | 24 | Args: 25 | work_dir: The work directory. All intermediate files will be saved here. 26 | (default: None) 27 | logfile_name: Name of the file to save log messages. (default: `log.txt`) 28 | logger_name: Unique name for the logger. (default: `logger`) 29 | 30 | Returns: 31 | A `logging.Logger` object. 32 | 33 | Raises: 34 | SystemExit: If the work directory already exists, or if a logger with the 35 | specified name `logger_name` has already been set up. 36 | """ 37 | 38 | logger = logging.getLogger(logger_name) 39 | if logger.hasHandlers(): # Already exists 40 | raise SystemExit(f'Logger name `{logger_name}` has already been set up!\n' 41 | f'Please use another name, or otherwise the messages ' 42 | f'may be mixed between these two loggers.') 43 | 44 | logger.setLevel(logging.DEBUG) 45 | formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s") 46 | 47 | # Print log messages with `INFO` level or above onto the screen. 48 | sh = logging.StreamHandler(stream=sys.stdout) 49 | sh.setLevel(logging.INFO) 50 | sh.setFormatter(formatter) 51 | logger.addHandler(sh) 52 | 53 | if not logfile_name: 54 | return logger 55 | 56 | work_dir = work_dir or DEFAULT_WORK_DIR 57 | logfile_name = os.path.join(work_dir, logfile_name) 58 | if os.path.isfile(logfile_name): 59 | print(f'Log file `{logfile_name}` already exists!') 60 | while True: 61 | decision = input('Would you like to overwrite it (Y/N): ') 62 | decision = decision.strip().lower() 63 | if decision == 'n': 64 | raise SystemExit('Please specify another one.') 65 | if decision == 'y': 66 | logger.warning(f'Overwriting log file `{logfile_name}`!') 67 | break 68 | 69 | os.makedirs(work_dir, exist_ok=True) 70 | 71 | # Save log messages of all levels to the log file.
72 | fh = logging.FileHandler(logfile_name) 73 | fh.setLevel(logging.DEBUG) 74 | fh.setFormatter(formatter) 75 | logger.addHandler(fh) 76 | 77 | return logger 78 | -------------------------------------------------------------------------------- /docs/assets/style.css: -------------------------------------------------------------------------------- 1 | /* Body */ 2 | body { 3 | background: #e3e5e8; 4 | color: #ffffff; 5 | font-family: 'Lato', Verdana, Helvetica, sans-serif; 6 | font-weight: 300; 7 | font-size: 14pt; 8 | } 9 | 10 | /* Hyperlinks */ 11 | a {text-decoration: none;} 12 | a:link {color: #1772d0;} 13 | a:visited {color: #1772d0;} 14 | a:active {color: red;} 15 | a:hover {color: #f09228;} 16 | 17 | /* Pre-formatted Text */ 18 | pre { 19 | margin: 5pt 0; 20 | border: 0; 21 | font-size: 12pt; 22 | background: #fcfcfc; 23 | } 24 | 25 | /* Project Page Style */ 26 | /* Section */ 27 | .section { 28 | width: 768pt; 29 | min-height: 100pt; 30 | margin: 15pt auto; 31 | padding: 20pt 30pt; 32 | border: 1pt hidden #000; 33 | text-align: justify; 34 | color: #000000; 35 | background: #ffffff; 36 | } 37 | 38 | /* Header (Title and Logo) */ 39 | .section .header { 40 | min-height: 80pt; 41 | margin-top: 30pt; 42 | } 43 | .section .header .logo { 44 | width: 80pt; 45 | margin-left: 10pt; 46 | float: left; 47 | } 48 | .section .header .logo img { 49 | width: 80pt; 50 | object-fit: cover; 51 | } 52 | .section .header .title { 53 | margin: 0 120pt; 54 | text-align: center; 55 | font-size: 22pt; 56 | } 57 | 58 | /* Author */ 59 | .section .author { 60 | margin: 5pt 0; 61 | text-align: center; 62 | font-size: 16pt; 63 | } 64 | 65 | /* Institution */ 66 | .section .institution { 67 | margin: 5pt 0; 68 | text-align: center; 69 | font-size: 16pt; 70 | } 71 | 72 | /* Hyperlink (such as Paper and Code) */ 73 | .section .link { 74 | margin: 5pt 0; 75 | text-align: center; 76 | font-size: 16pt; 77 | } 78 | 79 | /* Teaser */ 80 | .section .teaser { 81 | margin: 20pt 0; 82 | text-align: center; 83 | } 84 | .section .teaser img { 85 | width: 95%; 86 | } 87 | 88 | /* Section Title */ 89 | .section .title { 90 | text-align: center; 91 | font-size: 22pt; 92 | margin: 5pt 0 15pt 0; /* top right bottom left */ 93 | } 94 | 95 | /* Section Body */ 96 | .section .body { 97 | margin-bottom: 15pt; 98 | text-align: justify; 99 | font-size: 14pt; 100 | } 101 | 102 | /* BibTeX */ 103 | .section .bibtex { 104 | margin: 5pt 0; 105 | text-align: left; 106 | font-size: 22pt; 107 | } 108 | 109 | /* Related Work */ 110 | .section .ref { 111 | margin: 20pt 0 10pt 0; /* top right bottom left */ 112 | text-align: left; 113 | font-size: 18pt; 114 | font-weight: bold; 115 | } 116 | 117 | /* Citation */ 118 | .section .citation { 119 | min-height: 60pt; 120 | margin: 10pt 0; 121 | } 122 | .section .citation .image { 123 | width: 120pt; 124 | float: left; 125 | } 126 | .section .citation .image img { 127 | max-height: 60pt; 128 | width: 120pt; 129 | object-fit: cover; 130 | } 131 | .section .citation .comment{ 132 | margin-left: 130pt; 133 | text-align: left; 134 | font-size: 14pt; 135 | } 136 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Frechet Inception Distance (FID).""" 9 | 10 | import os 11 | import numpy as np 12 | import scipy 13 | import tensorflow as tf 14 | import dnnlib.tflib as tflib 15 | 16 | from metrics import metric_base 17 | from training import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class FID(metric_base.MetricBase): 22 | def __init__(self, num_images, minibatch_per_gpu, **kwargs): 23 | super().__init__(**kwargs) 24 | self.num_images = num_images 25 | self.minibatch_per_gpu = minibatch_per_gpu 26 | 27 | def _evaluate(self, Gs, num_gpus): 28 | minibatch_size = num_gpus * self.minibatch_per_gpu 29 | inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl 30 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 31 | 32 | # Calculate statistics for reals. 33 | cache_file = self._get_cache_file_for_reals(num_images=self.num_images) 34 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 35 | if os.path.isfile(cache_file): 36 | mu_real, sigma_real = misc.load_pkl(cache_file) 37 | else: 38 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)): 39 | begin = idx * minibatch_size 40 | end = min(begin + minibatch_size, self.num_images) 41 | activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True) 42 | if end == self.num_images: 43 | break 44 | mu_real = np.mean(activations, axis=0) 45 | sigma_real = np.cov(activations, rowvar=False) 46 | misc.save_pkl((mu_real, sigma_real), cache_file) 47 | 48 | # Construct TensorFlow graph. 49 | result_expr = [] 50 | for gpu_idx in range(num_gpus): 51 | with tf.device('/gpu:%d' % gpu_idx): 52 | Gs_clone = Gs.clone() 53 | inception_clone = inception.clone() 54 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 55 | images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) 56 | images = tflib.convert_images_to_uint8(images) 57 | result_expr.append(inception_clone.get_output_for(images)) 58 | 59 | # Calculate statistics for fakes. 60 | for begin in range(0, self.num_images, minibatch_size): 61 | end = min(begin + minibatch_size, self.num_images) 62 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] 63 | mu_fake = np.mean(activations, axis=0) 64 | sigma_fake = np.cov(activations, rowvar=False) 65 | 66 | # Calculate FID. 
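# FID is the Frechet distance between the two Gaussians fitted to the real and
# fake Inception features: ||mu_r - mu_f||^2 + Tr(Sigma_r + Sigma_f - 2 * (Sigma_r @ Sigma_f)^(1/2)).
# scipy.linalg.sqrtm can return a slightly complex matrix due to numerical
# error, hence the np.real() when reporting the result below.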
67 | m = np.square(mu_fake - mu_real).sum() 68 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member 69 | dist = m + np.trace(sigma_fake + sigma_real - 2*s) 70 | self._report_result(np.real(dist)) 71 | 72 | #---------------------------------------------------------------------------- 73 | -------------------------------------------------------------------------------- /train_encoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dnnlib 3 | from dnnlib import EasyDict 4 | import config 5 | import copy 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description='Training the in-domain encoder') 9 | parser.add_argument('training_data', type=str, 10 | help='path to training data (.tfrecords).') 11 | parser.add_argument('test_data', type=str, 12 | help='path to test data (.tfrecords).') 13 | parser.add_argument('decoder_pkl', type=str, 14 | help='path to the stylegan generator, which serves as a decoder here.') 15 | parser.add_argument('--num_gpus', type=int, default=8, 16 | help='number of GPUs to use during training (default: 8)') 17 | parser.add_argument('--image_size', type=int, default=256, 18 | help='the image size of the training dataset (default: 256)') 19 | parser.add_argument('--dataset_name', type=str, default='ffhq', 20 | help='the name of the training dataset (default: ffhq)') 21 | parser.add_argument('--mirror_augment', action='store_false', 22 | help='pass this flag to disable mirror augmentation (default: enabled)') 23 | args = parser.parse_args() 24 | 25 | train = EasyDict(run_func_name='training.training_loop_encoder.training_loop') 26 | Encoder = EasyDict(func_name='training.networks_encoder.Encoder') 27 | E_opt = EasyDict(beta1=0.9, beta2=0.99, epsilon=1e-8) 28 | D_opt = EasyDict(beta1=0.9, beta2=0.99, epsilon=1e-8) 29 | E_loss = EasyDict(func_name='training.loss_encoder.E_loss', feature_scale=0.00005, D_scale=0.08, perceptual_img_size=256) 30 | D_loss = EasyDict(func_name='training.loss_encoder.D_logistic_simplegp', r1_gamma=10.0) 31 | lr = EasyDict(learning_rate=0.0001, decay_step=30000, decay_rate=0.8, stair=False) 32 | Data_dir = EasyDict(data_train=args.training_data, data_test=args.test_data) 33 | Decoder_pkl = EasyDict(decoder_pkl=args.decoder_pkl) 34 | tf_config = {'rnd.np_random_seed': 1000} 35 | submit_config = dnnlib.SubmitConfig() 36 | 37 | desc = 'stylegan-encoder' 38 | desc += '-%dgpu' % (args.num_gpus) 39 | desc += '-%dx%d' % (args.image_size, args.image_size) 40 | desc += '-%s' % (args.dataset_name) 41 | 42 | train.mirror_augment = args.mirror_augment 43 | minibatch_per_gpu_train = {128: 16, 256: 16, 512: 8, 1024: 4} 44 | minibatch_per_gpu_test = {128: 1, 256: 1, 512: 1, 1024: 1} 45 | total_kimgs = {128: 12000, 256: 14000, 512: 16000, 1024: 18000} 46 | 47 | assert args.image_size in minibatch_per_gpu_train, 'Invalid image size' 48 | batch_size = minibatch_per_gpu_train.get(args.image_size) * args.num_gpus 49 | batch_size_test = minibatch_per_gpu_test.get(args.image_size) * args.num_gpus 50 | train.max_iters = int(total_kimgs.get(args.image_size) * 1000 / batch_size) 51 | 52 | kwargs = EasyDict(train) 53 | kwargs.update(Encoder_args=Encoder, E_opt_args=E_opt, D_opt_args=D_opt, E_loss_args=E_loss, D_loss_args=D_loss, lr_args=lr) 54 | kwargs.update(dataset_args=Data_dir, decoder_pkl=Decoder_pkl, tf_config=tf_config) 55 | kwargs.lr_args.decay_step = train.max_iters // 4 56 | kwargs.submit_config = copy.deepcopy(submit_config) 57 | kwargs.submit_config.num_gpus =
args.num_gpus 58 | kwargs.submit_config.image_size = args.image_size 59 | kwargs.submit_config.batch_size = batch_size 60 | kwargs.submit_config.batch_size_test = batch_size_test 61 | kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) 62 | kwargs.submit_config.run_dir_ignore += config.run_dir_ignore 63 | kwargs.submit_config.run_desc = desc 64 | 65 | dnnlib.submit_run(**kwargs) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /training/loss_encoder.py: -------------------------------------------------------------------------------- 1 | """Loss functions for training encoder.""" 2 | import tensorflow as tf 3 | from dnnlib.tflib.autosummary import autosummary 4 | 5 | 6 | #---------------------------------------------------------------------------- 7 | # Convenience func that casts all of its arguments to tf.float32. 8 | 9 | def fp32(*values): 10 | if len(values) == 1 and isinstance(values[0], tuple): 11 | values = values[0] 12 | values = tuple(tf.cast(v, tf.float32) for v in values) 13 | return values if len(values) >= 2 else values[0] 14 | 15 | 16 | #---------------------------------------------------------------------------- 17 | # Encoder loss function . 18 | def E_loss(E, G, D, perceptual_model, reals, feature_scale=0.00005, D_scale=0.1, perceptual_img_size=256): 19 | num_layers, latent_dim = G.components.synthesis.input_shape[1:3] 20 | latent_w = E.get_output_for(reals, is_training=True) 21 | latent_wp = tf.reshape(latent_w, [reals.shape[0], num_layers, latent_dim]) 22 | fake_X = G.components.synthesis.get_output_for(latent_wp, randomize_noise=False) 23 | fake_scores_out = fp32(D.get_output_for(fake_X, None)) 24 | 25 | with tf.variable_scope('recon_loss'): 26 | vgg16_input_real = tf.transpose(reals, perm=[0, 2, 3, 1]) 27 | vgg16_input_real = tf.image.resize_images(vgg16_input_real, size=[perceptual_img_size, perceptual_img_size], method=1) 28 | vgg16_input_real = ((vgg16_input_real + 1) / 2) * 255 29 | vgg16_input_fake = tf.transpose(fake_X, perm=[0, 2, 3, 1]) 30 | vgg16_input_fake = tf.image.resize_images(vgg16_input_fake, size=[perceptual_img_size, perceptual_img_size], method=1) 31 | vgg16_input_fake = ((vgg16_input_fake + 1) / 2) * 255 32 | vgg16_feature_real = perceptual_model(vgg16_input_real) 33 | vgg16_feature_fake = perceptual_model(vgg16_input_fake) 34 | recon_loss_feats = feature_scale * tf.reduce_mean(tf.square(vgg16_feature_real - vgg16_feature_fake)) 35 | recon_loss_pixel = tf.reduce_mean(tf.square(fake_X - reals)) 36 | recon_loss_feats = autosummary('Loss/scores/loss_feats', recon_loss_feats) 37 | recon_loss_pixel = autosummary('Loss/scores/loss_pixel', recon_loss_pixel) 38 | recon_loss = recon_loss_feats + recon_loss_pixel 39 | recon_loss = autosummary('Loss/scores/recon_loss', recon_loss) 40 | 41 | with tf.variable_scope('adv_loss'): 42 | D_scale = autosummary('Loss/scores/d_scale', D_scale) 43 | adv_loss = D_scale * tf.reduce_mean(tf.nn.softplus(-fake_scores_out)) 44 | adv_loss = autosummary('Loss/scores/adv_loss', adv_loss) 45 | 46 | loss = recon_loss + adv_loss 47 | 48 | return loss, recon_loss, adv_loss 49 | 50 | #---------------------------------------------------------------------------- 51 | # Discriminator loss function. 
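# Non-saturating logistic GAN loss with the R1 "simple" gradient penalty: the
# discriminator is pushed to score real images high and encoder reconstructions
# low, while r1_gamma / 2 * E[||grad_x D(x)||^2] penalizes its gradients on real
# images (Mescheder et al., 2018, "Which Training Methods for GANs do actually
# Converge?").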
52 | def D_logistic_simplegp(E, G, D, reals, r1_gamma=10.0): 53 | 54 | num_layers, latent_dim = G.components.synthesis.input_shape[1:3] 55 | latent_w = E.get_output_for(reals, is_training=True) 56 | latent_wp = tf.reshape(latent_w, [reals.shape[0], num_layers, latent_dim]) 57 | fake_X = G.components.synthesis.get_output_for(latent_wp, randomize_noise=False) 58 | real_scores_out = fp32(D.get_output_for(reals, None)) 59 | fake_scores_out = fp32(D.get_output_for(fake_X, None)) 60 | 61 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 62 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 63 | loss_fake = tf.reduce_mean(tf.nn.softplus(fake_scores_out)) 64 | loss_real = tf.reduce_mean(tf.nn.softplus(-real_scores_out)) 65 | 66 | with tf.name_scope('R1Penalty'): 67 | real_grads = fp32(tf.gradients(real_scores_out, [reals])[0]) 68 | r1_penalty = tf.reduce_mean(tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3])) 69 | r1_penalty = autosummary('Loss/r1_penalty', r1_penalty) 70 | loss_gp = r1_penalty * (r1_gamma * 0.5) 71 | loss = loss_fake + loss_real + loss_gp 72 | return loss, loss_fake, loss_real, loss_gp 73 | -------------------------------------------------------------------------------- /dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helpers for managing the run/training loop.""" 9 | 10 | import datetime 11 | import json 12 | import os 13 | import pprint 14 | import time 15 | import types 16 | 17 | from typing import Any 18 | 19 | from . import submit 20 | 21 | 22 | class RunContext(object): 23 | """Helper class for managing the run/training loop. 24 | 25 | The context hides the implementation details of a basic run/training loop. 26 | It sets things up properly, tells whether the run should be stopped, and cleans up afterwards. 27 | The user should call `update` periodically and use `should_stop` to determine whether the run should be stopped. 28 | 29 | Args: 30 | submit_config: The SubmitConfig that is used for the current run. 31 | config_module: The whole config module that is used for the current run. 32 | max_epoch: Optional cached value for the max_epoch variable used in update.
33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): 36 | self.submit_config = submit_config 37 | self.should_stop_flag = False 38 | self.has_closed = False 39 | self.start_time = time.time() 40 | self.last_update_time = time.time() 41 | self.last_update_interval = 0.0 42 | self.max_epoch = max_epoch 43 | 44 | # pretty print the all the relevant content of the config module to a text file 45 | if config_module is not None: 46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: 47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} 48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) 49 | 50 | # write out details about the run to a text file 51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 54 | 55 | def __enter__(self) -> "RunContext": 56 | return self 57 | 58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 59 | self.close() 60 | 61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: 62 | """Do general housekeeping and keep the state of the context up-to-date. 63 | Should be called often enough but not in a tight loop.""" 64 | assert not self.has_closed 65 | 66 | self.last_update_interval = time.time() - self.last_update_time 67 | self.last_update_time = time.time() 68 | 69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): 70 | self.should_stop_flag = True 71 | 72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | 99 | self.has_closed = True 100 | -------------------------------------------------------------------------------- /run_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Main entry point for evaluating metrics on pre-trained networks and training snapshots.""" 9 | 10 | import dnnlib 11 | from dnnlib import EasyDict 12 | import dnnlib.tflib as tflib 13 | 14 | import config 15 | from metrics import metric_base 16 | from training import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment): 21 | ctx = dnnlib.RunContext(submit_config) 22 | tflib.init_tf() 23 | print('Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl)) 24 | metric = dnnlib.util.call_func_by_name(**metric_args) 25 | print() 26 | metric.run(network_pkl, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus) 27 | print() 28 | ctx.close() 29 | 30 | #---------------------------------------------------------------------------- 31 | 32 | def run_snapshot(submit_config, metric_args, run_id, snapshot): 33 | ctx = dnnlib.RunContext(submit_config) 34 | tflib.init_tf() 35 | print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot)) 36 | run_dir = misc.locate_run_dir(run_id) 37 | network_pkl = misc.locate_network_pkl(run_dir, snapshot) 38 | metric = dnnlib.util.call_func_by_name(**metric_args) 39 | print() 40 | metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus) 41 | print() 42 | ctx.close() 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def run_all_snapshots(submit_config, metric_args, run_id): 47 | ctx = dnnlib.RunContext(submit_config) 48 | tflib.init_tf() 49 | print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id)) 50 | run_dir = misc.locate_run_dir(run_id) 51 | network_pkls = misc.list_network_pkls(run_dir) 52 | metric = dnnlib.util.call_func_by_name(**metric_args) 53 | print() 54 | for idx, network_pkl in enumerate(network_pkls): 55 | ctx.update('', idx, len(network_pkls)) 56 | metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus) 57 | print() 58 | ctx.close() 59 | 60 | #---------------------------------------------------------------------------- 61 | 62 | def main(): 63 | submit_config = dnnlib.SubmitConfig() 64 | 65 | # Which metrics to evaluate? 66 | metrics = [] 67 | metrics += [metric_base.fid50k] 68 | #metrics += [metric_base.ppl_zfull] 69 | #metrics += [metric_base.ppl_wfull] 70 | #metrics += [metric_base.ppl_zend] 71 | #metrics += [metric_base.ppl_wend] 72 | #metrics += [metric_base.ls] 73 | #metrics += [metric_base.dummy] 74 | 75 | # Which networks to evaluate them on? 76 | tasks = [] 77 | tasks += [EasyDict(run_func_name='run_metrics.run_pickle', network_pkl='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', dataset_args=EasyDict(tfrecord_dir='ffhq', shuffle_mb=0), mirror_augment=True)] # karras2019stylegan-ffhq-1024x1024.pkl 78 | #tasks += [EasyDict(run_func_name='run_metrics.run_snapshot', run_id=100, snapshot=25000)] 79 | #tasks += [EasyDict(run_func_name='run_metrics.run_all_snapshots', run_id=100)] 80 | 81 | # How many GPUs to use? 82 | submit_config.num_gpus = 1 83 | #submit_config.num_gpus = 2 84 | #submit_config.num_gpus = 4 85 | #submit_config.num_gpus = 8 86 | 87 | # Execute.
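# Each (task, metric) pair below is submitted as a separate run; run_desc is
# assembled from the run function, metric name, and GPU count, and becomes part
# of the result directory name under config.result_dir.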
88 | submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) 89 | submit_config.run_dir_ignore += config.run_dir_ignore 90 | for task in tasks: 91 | for metric in metrics: 92 | submit_config.run_desc = '%s-%s' % (task.run_func_name, metric.name) 93 | if task.run_func_name.endswith('run_snapshot'): 94 | submit_config.run_desc += '-%s-%s' % (task.run_id, task.snapshot) 95 | if task.run_func_name.endswith('run_all_snapshots'): 96 | submit_config.run_desc += '-%s' % task.run_id 97 | submit_config.run_desc += '-%dgpu' % submit_config.num_gpus 98 | dnnlib.submit_run(submit_config, metric_args=metric, **task) 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | if __name__ == "__main__": 103 | main() 104 | 105 | #---------------------------------------------------------------------------- 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # In-Domain GAN Inversion for Real Image Editing 2 | 3 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg?style=plastic) 4 | ![TensorFlow 1.12.2](https://img.shields.io/badge/tensorflow-1.12.2-green.svg?style=plastic) 5 | ![Keras 2.2.4](https://img.shields.io/badge/keras-2.2.4-green.svg?style=plastic) 6 | 7 | ![image](./docs/assets/teaser.jpg) 8 | 9 | **Figure:** *Real image editing using the proposed In-Domain GAN inversion with a fixed GAN generator.* 10 | 11 | > **In-Domain GAN Inversion for Real Image Editing**
12 | > Jiapeng Zhu*, Yujun Shen*, Deli Zhao, Bolei Zhou
13 | > *European Conference on Computer Vision (ECCV) 2020* 14 | 15 | In this repository, we propose an in-domain GAN inversion method, which not only faithfully reconstructs the input image but also ensures that the inverted code is **semantically meaningful** for editing. Basically, the in-domain GAN inversion consists of two steps: 16 | 17 | 1. Training a **domain-guided** encoder. 18 | 2. Performing **domain-regularized** optimization. 19 | 20 | **NEWS: Please also find [this repo](https://github.com/genforce/idinvert_pytorch), which is friendly to PyTorch users!** 21 | 22 | [[Paper](https://arxiv.org/pdf/2004.00049.pdf)] 23 | [[Project Page](https://genforce.github.io/idinvert/)] 24 | [[Demo](https://www.youtube.com/watch?v=3v6NHrhuyFY)] 25 | [[Colab](https://colab.research.google.com/github/genforce/idinvert_pytorch/blob/master/docs/Idinvert.ipynb)] 26 | 27 | ## Testing 28 | 29 | ### Pre-trained Models 30 | 31 | Please download the pre-trained models from the following links. Each model contains the GAN generator and discriminator, as well as the proposed **domain-guided encoder**. 32 | 33 | | Path | Description 34 | | :--- | :---------- 35 | |[face_256x256](https://drive.google.com/file/d/1azAzSZg6VfNydjWr4qfl8Z4LfxktTPqM/view?usp=sharing) | In-domain GAN trained on the [FFHQ](https://github.com/NVlabs/ffhq-dataset) dataset. 36 | |[tower_256x256](https://drive.google.com/file/d/1USfaSLor5d71IRoC8CWTbKJagS0-MJEv/view?usp=sharing) | In-domain GAN trained on the [LSUN Tower](https://github.com/fyu/lsun) dataset. 37 | |[bedroom_256x256](https://drive.google.com/file/d/1nRa4WAE1qF_j1CtH32hxjREK0o-rpucD/view?usp=sharing) | In-domain GAN trained on the [LSUN Bedroom](https://github.com/fyu/lsun) dataset. 38 | 39 | ### Inversion 40 | 41 | ```bash 42 | MODEL_PATH='styleganinv_face_256.pkl' 43 | IMAGE_LIST='examples/test.list' 44 | python invert.py $MODEL_PATH $IMAGE_LIST 45 | ``` 46 | 47 | NOTE: We find that 100 iterations are good enough for inverting an image, which takes about 8s on a P40 GPU. Users can always run more iterations (much slower) for a more precise reconstruction. 48 | 49 | ### Semantic Diffusion 50 | 51 | ```bash 52 | MODEL_PATH='styleganinv_face_256.pkl' 53 | TARGET_LIST='examples/target.list' 54 | CONTEXT_LIST='examples/context.list' 55 | python diffuse.py $MODEL_PATH $TARGET_LIST $CONTEXT_LIST 56 | ``` 57 | 58 | NOTE: The diffusion process is highly similar to image inversion. The main difference is that only the target patch is used to compute the loss for the **masked** optimization. 59 | 60 | ### Interpolation 61 | 62 | ```bash 63 | SRC_DIR='results/inversion/test' 64 | DST_DIR='results/inversion/test' 65 | python interpolate.py $MODEL_PATH $SRC_DIR $DST_DIR 66 | ``` 67 | 68 | ### Manipulation 69 | 70 | ```bash 71 | IMAGE_DIR='results/inversion/test' 72 | BOUNDARY='boundaries/expression.npy' 73 | python manipulate.py $MODEL_PATH $IMAGE_DIR $BOUNDARY 74 | ``` 75 | 76 | NOTE: Boundaries are obtained using [InterFaceGAN](https://github.com/genforce/interfacegan). 77 | 78 | ### Style Mixing 79 | 80 | ```bash 81 | STYLE_DIR='results/inversion/test' 82 | CONTENT_DIR='results/inversion/test' 83 | python mix_style.py $MODEL_PATH $STYLE_DIR $CONTENT_DIR 84 | ``` 85 | 86 | ## Training 87 | 88 | The GAN model used in this work is [StyleGAN](https://github.com/NVlabs/stylegan). Beyond the original repository, we make the following changes: 89 | 90 | - Change the single $w$ repeated for all layers to a different $w$ per layer (lines 428-435 in file `training/networks_stylegan.py`). 91 | - Add the *domain-guided* encoder in file `training/networks_encoder.py`. 92 | - Add losses for training the *domain-guided* encoder in file `training/loss_encoder.py` (see the sketch below). 93 | - Add the schedule for training the *domain-guided* encoder in file `training/training_loop_encoder.py`. 94 | - Add a perceptual model (VGG16) for computing perceptual loss in file `perceptual_model.py`. 95 | - Add the training script for the *domain-guided* encoder in file `train_encoder.py`.
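To make the objective concrete, the loss used to train the encoder combines a pixel reconstruction term, a VGG16 perceptual term, and an adversarial term (see `training/loss_encoder.py`). Below is a minimal NumPy sketch of how the three terms are weighted; the function and variable names are illustrative only, not part of this repo's API:

```python
import numpy as np

def encoder_loss_sketch(real, fake, feat_real, feat_fake, fake_scores,
                        feature_scale=0.00005, d_scale=0.08):
    """Illustrative weighting of the three encoder loss terms."""
    recon_pixel = np.mean(np.square(fake - real))                # pixel-wise MSE
    recon_feat = feature_scale * np.mean(
        np.square(feat_fake - feat_real))                        # VGG16 perceptual loss
    adv = d_scale * np.mean(np.logaddexp(0.0, -fake_scores))     # softplus(-D(fake))
    return recon_pixel + recon_feat + adv
```

The default weights (`feature_scale=0.00005`, `d_scale=0.08`) mirror the values passed by `train_encoder.py`.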
91 | - Add the *domain-guided* encoder in file `training/networks_encoder.py`. 92 | - Add losses for training the *domain-guided* encoder in file `training/loss_encoder.py`. 93 | - Add a schedule for training the *domain-guided* encoder in file `training/training_loop_encoder.py`. 94 | - Add a perceptual model (VGG16) for computing the perceptual loss in file `perceptual_model.py`. 95 | - Add a training script for the *domain-guided* encoder in file `train_encoder.py`. 96 | 97 | ### Step-1: Train your own generator 98 | 99 | ```bash 100 | python train.py 101 | ``` 102 | 103 | ### Step-2: Train your own encoder 104 | 105 | ```bash 106 | TRAINING_DATA=PATH_TO_TRAINING_DATA 107 | TESTING_DATA=PATH_TO_TESTING_DATA 108 | DECODER_PKL=PATH_TO_GENERATOR 109 | python train_encoder.py $TRAINING_DATA $TESTING_DATA $DECODER_PKL 110 | ``` 111 | 112 | Note that the file `dataset_tool.py`, which is borrowed from the [StyleGAN](https://github.com/NVlabs/stylegan) repo, is used to prepare a directory of data at all resolutions. Training the encoder does not rely on the progressive strategy; therefore, both the training data and the test data should be specified as the `.tfrecords` file with the highest resolution. 113 | 114 | ## BibTeX 115 | 116 | ```bibtex 117 | @inproceedings{zhu2020indomain, 118 | title = {In-domain GAN Inversion for Real Image Editing}, 119 | author = {Zhu, Jiapeng and Shen, Yujun and Zhao, Deli and Zhou, Bolei}, 120 | booktitle = {Proceedings of European Conference on Computer Vision (ECCV)}, 121 | year = {2020} 122 | } 123 | ``` 124 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Perceptual Path Length (PPL).""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | import dnnlib.tflib as tflib 13 | 14 | from metrics import metric_base 15 | from training import misc 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | # Normalize batch of vectors. 20 | def normalize(v): 21 | return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)) 22 | 23 | # Spherical interpolation of a batch of vectors. 24 | def slerp(a, b, t): 25 | a = normalize(a) 26 | b = normalize(b) 27 | d = tf.reduce_sum(a * b, axis=-1, keepdims=True) 28 | p = t * tf.math.acos(d) 29 | c = normalize(b - d * a) 30 | d = a * tf.math.cos(p) + c * tf.math.sin(p) 31 | return normalize(d) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPL(metric_base.MetricBase): 36 | def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__(**kwargs) 40 | self.num_samples = num_samples 41 | self.epsilon = epsilon 42 | self.space = space 43 | self.sampling = sampling 44 | self.minibatch_per_gpu = minibatch_per_gpu 45 | 46 | def _evaluate(self, Gs, num_gpus): 47 | minibatch_size = num_gpus * self.minibatch_per_gpu 48 | 49 | # Construct TensorFlow graph.
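# Each GPU gets an identical copy of this subgraph: sample pairs of latents,
# perturb the interpolation parameter t by a small epsilon (slerp in Z space,
# lerp in W space), synthesize both endpoint images, and score their
# perceptual distance scaled by 1 / epsilon**2.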
50 | distance_expr = [] 51 | for gpu_idx in range(num_gpus): 52 | with tf.device('/gpu:%d' % gpu_idx): 53 | Gs_clone = Gs.clone() 54 | noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')] 55 | 56 | # Generate random latents and interpolation t-values. 57 | lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:]) 58 | lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0) 59 | 60 | # Interpolate in W or Z. 61 | if self.space == 'w': 62 | dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True) 63 | dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2] 64 | dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis]) 65 | dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon) 66 | dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape) 67 | else: # space == 'z' 68 | lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2] 69 | lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis]) 70 | lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon) 71 | lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape) 72 | dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True) 73 | 74 | # Synthesize images. 75 | with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch 76 | images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False) 77 | 78 | # Crop only the face region. 79 | c = int(images.shape[2] // 8) 80 | images = images[:, :, c*3 : c*7, c*2 : c*6] 81 | 82 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 83 | if images.shape[2] > 256: 84 | factor = images.shape[2] // 256 85 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 86 | images = tf.reduce_mean(images, axis=[3,5]) 87 | 88 | # Scale dynamic range from [-1,1] to [0,255] for VGG. 89 | images = (images + 1) * (255 / 2) 90 | 91 | # Evaluate perceptual distance. 92 | img_e0, img_e1 = images[0::2], images[1::2] 93 | distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl 94 | distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2)) 95 | 96 | # Sampling loop. 97 | all_distances = [] 98 | for _ in range(0, self.num_samples, minibatch_size): 99 | all_distances += tflib.run(distance_expr) 100 | all_distances = np.concatenate(all_distances, axis=0) 101 | 102 | # Reject outliers. 103 | lo = np.percentile(all_distances, 1, interpolation='lower') 104 | hi = np.percentile(all_distances, 99, interpolation='higher') 105 | filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances) 106 | self._report_result(np.mean(filtered_distances)) 107 | 108 | #---------------------------------------------------------------------------- 109 | -------------------------------------------------------------------------------- /interpolate.py: -------------------------------------------------------------------------------- 1 | # python 3.6 2 | """Interpolates real images with In-domain GAN Inversion. 3 | 4 | The real images should be first inverted to latent codes with `invert.py`. 
After 5 | that, this script can be used for image interpolation. 6 | 7 | NOTE: This script will interpolate every image pair from source directory to 8 | target directory. 9 | """ 10 | 11 | import os 12 | import argparse 13 | import pickle 14 | from tqdm import tqdm 15 | import numpy as np 16 | import tensorflow as tf 17 | from dnnlib import tflib 18 | 19 | from utils.logger import setup_logger 20 | from utils.editor import interpolate 21 | from utils.visualizer import load_image 22 | from utils.visualizer import adjust_pixel_range 23 | from utils.visualizer import HtmlPageVisualizer 24 | 25 | 26 | def parse_args(): 27 | """Parses arguments.""" 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('model_path', type=str, 30 | help='Path to the pre-trained model.') 31 | parser.add_argument('src_dir', type=str, 32 | help='Source directory, which includes original images, ' 33 | 'inverted codes, and image list.') 34 | parser.add_argument('dst_dir', type=str, 35 | help='Target directory, which includes original images, ' 36 | 'inverted codes, and image list.') 37 | parser.add_argument('-o', '--output_dir', type=str, default='', 38 | help='Directory to save the results. If not specified, ' 39 | '`./results/interpolation` will be used by default.') 40 | parser.add_argument('--batch_size', type=int, default=32, 41 | help='Batch size. (default: 32)') 42 | parser.add_argument('--step', type=int, default=5, 43 | help='Number of steps for interpolation. (default: 5)') 44 | parser.add_argument('--viz_size', type=int, default=256, 45 | help='Image size for visualization. (default: 256)') 46 | parser.add_argument('--gpu_id', type=str, default='0', 47 | help='Which GPU(s) to use. (default: `0`)') 48 | return parser.parse_args() 49 | 50 | 51 | def main(): 52 | """Main function.""" 53 | args = parse_args() 54 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 55 | src_dir = args.src_dir 56 | src_dir_name = os.path.basename(src_dir.rstrip('/')) 57 | assert os.path.exists(src_dir) 58 | assert os.path.exists(f'{src_dir}/image_list.txt') 59 | assert os.path.exists(f'{src_dir}/inverted_codes.npy') 60 | dst_dir = args.dst_dir 61 | dst_dir_name = os.path.basename(dst_dir.rstrip('/')) 62 | assert os.path.exists(dst_dir) 63 | assert os.path.exists(f'{dst_dir}/image_list.txt') 64 | assert os.path.exists(f'{dst_dir}/inverted_codes.npy') 65 | output_dir = args.output_dir or 'results/interpolation' 66 | job_name = f'{src_dir_name}_TO_{dst_dir_name}' 67 | logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger') 68 | 69 | # Load model. 70 | logger.info(f'Loading generator.') 71 | tflib.init_tf({'rnd.np_random_seed': 1000}) 72 | with open(args.model_path, 'rb') as f: 73 | _, _, _, Gs = pickle.load(f) 74 | 75 | # Build graph. 76 | logger.info(f'Building graph.') 77 | sess = tf.get_default_session() 78 | num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3] 79 | wp = tf.placeholder( 80 | tf.float32, [args.batch_size, num_layers, latent_dim], name='latent_code') 81 | x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False) 82 | 83 | # Load image and codes. 
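# Both directories are expected to be outputs of `invert.py`: they must
# contain `image_list.txt`, the corresponding `*_ori.png` images, and the
# `inverted_codes.npy` file holding one latent code per image.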
84 | logger.info(f'Loading images and corresponding inverted latent codes.') 85 | src_list = [] 86 | with open(f'{src_dir}/image_list.txt', 'r') as f: 87 | for line in f: 88 | name = os.path.splitext(os.path.basename(line.strip()))[0] 89 | assert os.path.exists(f'{src_dir}/{name}_ori.png') 90 | src_list.append(name) 91 | src_codes = np.load(f'{src_dir}/inverted_codes.npy') 92 | assert src_codes.shape[0] == len(src_list) 93 | num_src = src_codes.shape[0] 94 | dst_list = [] 95 | with open(f'{dst_dir}/image_list.txt', 'r') as f: 96 | for line in f: 97 | name = os.path.splitext(os.path.basename(line.strip()))[0] 98 | assert os.path.exists(f'{dst_dir}/{name}_ori.png') 99 | dst_list.append(name) 100 | dst_codes = np.load(f'{dst_dir}/inverted_codes.npy') 101 | assert dst_codes.shape[0] == len(dst_list) 102 | num_dst = dst_codes.shape[0] 103 | 104 | # Interpolate images. 105 | logger.info(f'Start interpolation.') 106 | step = args.step + 2 107 | viz_size = None if args.viz_size == 0 else args.viz_size 108 | visualizer = HtmlPageVisualizer( 109 | num_rows=num_src * num_dst, num_cols=step + 2, viz_size=viz_size) 110 | visualizer.set_headers( 111 | ['Source', 'Source Inversion'] + 112 | [f'Step {i:02d}' for i in range(1, step - 1)] + 113 | ['Target Inversion', 'Target'] 114 | ) 115 | 116 | inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32) 117 | for src_idx in tqdm(range(num_src), leave=False): 118 | src_code = src_codes[src_idx:src_idx + 1] 119 | src_path = f'{src_dir}/{src_list[src_idx]}_ori.png' 120 | codes = interpolate(src_codes=np.repeat(src_code, num_dst, axis=0), 121 | dst_codes=dst_codes, 122 | step=step) 123 | for dst_idx in tqdm(range(num_dst), leave=False): 124 | dst_path = f'{dst_dir}/{dst_list[dst_idx]}_ori.png' 125 | output_images = [] 126 | for idx in range(0, step, args.batch_size): 127 | batch = codes[dst_idx, idx:idx + args.batch_size] 128 | inputs[0:len(batch)] = batch 129 | images = sess.run(x, feed_dict={wp: inputs}) 130 | output_images.append(images[0:len(batch)]) 131 | output_images = adjust_pixel_range(np.concatenate(output_images, axis=0)) 132 | 133 | row_idx = src_idx * num_dst + dst_idx 134 | visualizer.set_cell(row_idx, 0, image=load_image(src_path)) 135 | visualizer.set_cell(row_idx, step + 1, image=load_image(dst_path)) 136 | for s, output_image in enumerate(output_images): 137 | visualizer.set_cell(row_idx, s + 1, image=output_image) 138 | 139 | # Save results. 140 | visualizer.save(f'{output_dir}/{job_name}.html') 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /mix_style.py: -------------------------------------------------------------------------------- 1 | # python 3.6 2 | """Mixes styles with In-domain GAN Inversion. 3 | 4 | The real images should be first inverted to latent codes with `invert.py`. After 5 | that, this script can be used for style mixing. 6 | 7 | NOTE: This script will mix every `style-content` image pair from style 8 | directory to content directory. 
9 | """ 10 | 11 | import os 12 | import argparse 13 | import pickle 14 | from tqdm import tqdm 15 | import numpy as np 16 | import tensorflow as tf 17 | from dnnlib import tflib 18 | 19 | from utils.logger import setup_logger 20 | from utils.editor import mix_style 21 | from utils.visualizer import load_image 22 | from utils.visualizer import adjust_pixel_range 23 | from utils.visualizer import HtmlPageVisualizer 24 | 25 | 26 | def parse_args(): 27 | """Parses arguments.""" 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('model_path', type=str, 30 | help='Path to the pre-trained model.') 31 | parser.add_argument('style_dir', type=str, 32 | help='Style directory, which includes original images, ' 33 | 'inverted codes, and image list.') 34 | parser.add_argument('content_dir', type=str, 35 | help='Content directory, which includes original images, ' 36 | 'inverted codes, and image list.') 37 | parser.add_argument('-o', '--output_dir', type=str, default='', 38 | help='Directory to save the results. If not specified, ' 39 | '`./results/style_mixing` will be used by default.') 40 | parser.add_argument('--batch_size', type=int, default=32, 41 | help='Batch size. (default: 32)') 42 | parser.add_argument('--mix_layer_start_idx', type=int, default=10, 43 | help='0-based layer index. Style mixing is performed ' 44 | 'from this layer to the last layer. (default: 10)') 45 | parser.add_argument('--viz_size', type=int, default=256, 46 | help='Image size for visualization. (default: 256)') 47 | parser.add_argument('--gpu_id', type=str, default='0', 48 | help='Which GPU(s) to use. (default: `0`)') 49 | return parser.parse_args() 50 | 51 | 52 | def main(): 53 | """Main function.""" 54 | args = parse_args() 55 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 56 | style_dir = args.style_dir 57 | style_dir_name = os.path.basename(style_dir.rstrip('/')) 58 | assert os.path.exists(style_dir) 59 | assert os.path.exists(f'{style_dir}/image_list.txt') 60 | assert os.path.exists(f'{style_dir}/inverted_codes.npy') 61 | content_dir = args.content_dir 62 | content_dir_name = os.path.basename(content_dir.rstrip('/')) 63 | assert os.path.exists(content_dir) 64 | assert os.path.exists(f'{content_dir}/image_list.txt') 65 | assert os.path.exists(f'{content_dir}/inverted_codes.npy') 66 | output_dir = args.output_dir or 'results/style_mixing' 67 | job_name = f'{style_dir_name}_STYLIZE_{content_dir_name}' 68 | logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger') 69 | 70 | # Load model. 71 | logger.info(f'Loading generator.') 72 | tflib.init_tf({'rnd.np_random_seed': 1000}) 73 | with open(args.model_path, 'rb') as f: 74 | _, _, _, Gs = pickle.load(f) 75 | 76 | # Build graph. 77 | logger.info(f'Building graph.') 78 | sess = tf.get_default_session() 79 | num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3] 80 | wp = tf.placeholder( 81 | tf.float32, [args.batch_size, num_layers, latent_dim], name='latent_code') 82 | x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False) 83 | mix_layers = list(range(args.mix_layer_start_idx, num_layers)) 84 | 85 | # Load image and codes. 
86 | logger.info(f'Loading images and corresponding inverted latent codes.') 87 | style_list = [] 88 | with open(f'{style_dir}/image_list.txt', 'r') as f: 89 | for line in f: 90 | name = os.path.splitext(os.path.basename(line.strip()))[0] 91 | assert os.path.exists(f'{style_dir}/{name}_ori.png') 92 | style_list.append(name) 93 | logger.info(f'Loading inverted latent codes.') 94 | style_codes = np.load(f'{style_dir}/inverted_codes.npy') 95 | assert style_codes.shape[0] == len(style_list) 96 | num_styles = style_codes.shape[0] 97 | content_list = [] 98 | with open(f'{content_dir}/image_list.txt', 'r') as f: 99 | for line in f: 100 | name = os.path.splitext(os.path.basename(line.strip()))[0] 101 | assert os.path.exists(f'{content_dir}/{name}_ori.png') 102 | content_list.append(name) 103 | logger.info(f'Loading inverted latent codes.') 104 | content_codes = np.load(f'{content_dir}/inverted_codes.npy') 105 | assert content_codes.shape[0] == len(content_list) 106 | num_contents = content_codes.shape[0] 107 | 108 | # Mix styles. 109 | logger.info(f'Start style mixing.') 110 | viz_size = None if args.viz_size == 0 else args.viz_size 111 | visualizer = HtmlPageVisualizer( 112 | num_rows=num_styles + 1, num_cols=num_contents + 1, viz_size=viz_size) 113 | visualizer.set_headers( 114 | ['Style'] + 115 | [f'Content {i:03d}' for i in range(num_contents)] 116 | ) 117 | for style_idx, style_name in enumerate(style_list): 118 | style_image = load_image(f'{style_dir}/{style_name}_ori.png') 119 | visualizer.set_cell(style_idx + 1, 0, image=style_image) 120 | for content_idx, content_name in enumerate(content_list): 121 | content_image = load_image(f'{content_dir}/{content_name}_ori.png') 122 | visualizer.set_cell(0, content_idx + 1, image=content_image) 123 | 124 | codes = mix_style(style_codes=style_codes, 125 | content_codes=content_codes, 126 | num_layers=num_layers, 127 | mix_layers=mix_layers) 128 | inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32) 129 | for style_idx in tqdm(range(num_styles), leave=False): 130 | output_images = [] 131 | for idx in range(0, num_contents, args.batch_size): 132 | batch = codes[style_idx, idx:idx + args.batch_size] 133 | inputs[0:len(batch)] = batch 134 | images = sess.run(x, feed_dict={wp: inputs}) 135 | output_images.append(images[0:len(batch)]) 136 | output_images = adjust_pixel_range(np.concatenate(output_images, axis=0)) 137 | for content_idx, output_image in enumerate(output_images): 138 | visualizer.set_cell(style_idx + 1, content_idx + 1, image=output_image) 139 | 140 | # Save results. 141 | visualizer.save(f'{output_dir}/{job_name}.html') 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /metrics/metric_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Common definitions for GAN metrics.""" 9 | 10 | import os 11 | import time 12 | import hashlib 13 | import numpy as np 14 | import tensorflow as tf 15 | import dnnlib 16 | import dnnlib.tflib as tflib 17 | 18 | import config 19 | from training import misc 20 | from training import dataset 21 | 22 | #---------------------------------------------------------------------------- 23 | # Standard metrics. 24 | 25 | fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8) 26 | ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16) 27 | ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16) 28 | ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16) 29 | ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16) 30 | ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4) 31 | dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging 32 | 33 | #---------------------------------------------------------------------------- 34 | # Base class for metrics. 35 | 36 | class MetricBase: 37 | def __init__(self, name): 38 | self.name = name 39 | self._network_pkl = None 40 | self._dataset_args = None 41 | self._mirror_augment = None 42 | self._results = [] 43 | self._eval_time = None 44 | 45 | def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True): 46 | self._network_pkl = network_pkl 47 | self._dataset_args = dataset_args 48 | self._mirror_augment = mirror_augment 49 | self._results = [] 50 | 51 | if (dataset_args is None or mirror_augment is None) and run_dir is not None: 52 | run_config = misc.parse_config_for_previous_run(run_dir) 53 | self._dataset_args = dict(run_config['dataset']) 54 | self._dataset_args['shuffle_mb'] = 0 55 | self._mirror_augment = run_config['train'].get('mirror_augment', False) 56 | 57 | time_begin = time.time() 58 | with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager 59 | _G, _D, Gs = misc.load_pkl(self._network_pkl) 60 | self._evaluate(Gs, num_gpus=num_gpus) 61 | self._eval_time = time.time() - time_begin 62 | 63 | if log_results: 64 | result_str = self.get_result_str() 65 | if run_dir is not None: 66 | log = os.path.join(run_dir, 'metric-%s.txt' % self.name) 67 | with dnnlib.util.Logger(log, 'a'): 68 | print(result_str) 69 | else: 70 | print(result_str) 71 | 72 | def get_result_str(self): 73 | network_name = os.path.splitext(os.path.basename(self._network_pkl))[0] 74 | if len(network_name) > 29: 75 | network_name = '...' 
+ network_name[-26:] 76 | result_str = '%-30s' % network_name 77 | result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time) 78 | for res in self._results: 79 | result_str += ' ' + self.name + res.suffix + ' ' 80 | result_str += res.fmt % res.value 81 | return result_str 82 | 83 | def update_autosummaries(self): 84 | for res in self._results: 85 | tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value) 86 | 87 | def _evaluate(self, Gs, num_gpus): 88 | raise NotImplementedError # to be overridden by subclasses 89 | 90 | def _report_result(self, value, suffix='', fmt='%-10.4f'): 91 | self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)] 92 | 93 | def _get_cache_file_for_reals(self, extension='pkl', **kwargs): 94 | all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment) 95 | all_args.update(self._dataset_args) 96 | all_args.update(kwargs) 97 | md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')) 98 | dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1] 99 | return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension)) 100 | 101 | def _iterate_reals(self, minibatch_size): 102 | dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args) 103 | while True: 104 | images, _labels = dataset_obj.get_minibatch_np(minibatch_size) 105 | if self._mirror_augment: 106 | images = misc.apply_mirror_augment(images) 107 | yield images 108 | 109 | def _iterate_fakes(self, Gs, minibatch_size, num_gpus): 110 | while True: 111 | latents = np.random.randn(minibatch_size, *Gs.input_shape[1:]) 112 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 113 | images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True) 114 | yield images 115 | 116 | #---------------------------------------------------------------------------- 117 | # Group of multiple metrics. 118 | 119 | class MetricGroup: 120 | def __init__(self, metric_kwarg_list): 121 | self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list] 122 | 123 | def run(self, *args, **kwargs): 124 | for metric in self.metrics: 125 | metric.run(*args, **kwargs) 126 | 127 | def get_result_str(self): 128 | return ' '.join(metric.get_result_str() for metric in self.metrics) 129 | 130 | def update_autosummaries(self): 131 | for metric in self.metrics: 132 | metric.update_autosummaries() 133 | 134 | #---------------------------------------------------------------------------- 135 | # Dummy metric for debugging purposes. 136 | 137 | class DummyMetric(MetricBase): 138 | def _evaluate(self, Gs, num_gpus): 139 | _ = Gs, num_gpus 140 | self._report_result(0.0) 141 | 142 | #---------------------------------------------------------------------------- 143 | -------------------------------------------------------------------------------- /manipulate.py: -------------------------------------------------------------------------------- 1 | # python 3.6 2 | """Manipulates real images with In-domain GAN Inversion. 3 | 4 | The real images should be first inverted to latent codes with `invert.py`. After 5 | that, this script can be used for image manipulation with a given boundary. 
6 | """ 7 | 8 | import os.path 9 | import argparse 10 | import pickle 11 | import numpy as np 12 | from tqdm import tqdm 13 | import tensorflow as tf 14 | from dnnlib import tflib 15 | 16 | from utils.logger import setup_logger 17 | from utils.editor import manipulate 18 | from utils.visualizer import load_image 19 | from utils.visualizer import adjust_pixel_range 20 | from utils.visualizer import HtmlPageVisualizer 21 | 22 | 23 | def parse_args(): 24 | """Parses arguments.""" 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('model_path', type=str, 27 | help='Name of the model used for synthesis.') 28 | parser.add_argument('image_dir', type=str, 29 | help='Image directory, which includes original images, ' 30 | 'inverted codes, and image list.') 31 | parser.add_argument('boundary_path', type=str, 32 | help='Path to the boundary for semantic manipulation.') 33 | parser.add_argument('-o', '--output_dir', type=str, default='', 34 | help='Directory to save the results. If not specified, ' 35 | '`./results/manipulation` will be used by default.') 36 | parser.add_argument('--batch_size', type=int, default=32, 37 | help='Batch size. (default: 32)') 38 | parser.add_argument('--step', type=int, default=7, 39 | help='Number of manipulation steps. (default: 7)') 40 | parser.add_argument('--start_distance', type=float, default=-3.0, 41 | help='Start distance for manipulation. (default: -3.0)') 42 | parser.add_argument('--end_distance', type=float, default=3.0, 43 | help='End distance for manipulation. (default: 3.0)') 44 | parser.add_argument('--manipulate_layers', type=str, default='', 45 | help='Indices of the layers to perform manipulation. ' 46 | 'If not specified, all layers will be manipulated. ' 47 | 'More than one layers should be separated by `,`. ' 48 | '(default: None)') 49 | parser.add_argument('--viz_size', type=int, default=256, 50 | help='Image size for visualization. (default: 256)') 51 | parser.add_argument('--gpu_id', type=str, default='0', 52 | help='Which GPU(s) to use. (default: `0`)') 53 | return parser.parse_args() 54 | 55 | 56 | def main(): 57 | """Main function.""" 58 | args = parse_args() 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 60 | image_dir = args.image_dir 61 | image_dir_name = os.path.basename(image_dir.rstrip('/')) 62 | assert os.path.exists(image_dir) 63 | assert os.path.exists(f'{image_dir}/image_list.txt') 64 | assert os.path.exists(f'{image_dir}/inverted_codes.npy') 65 | boundary_path = args.boundary_path 66 | assert os.path.exists(boundary_path) 67 | boundary_name = os.path.splitext(os.path.basename(boundary_path))[0] 68 | output_dir = args.output_dir or 'results/manipulation' 69 | job_name = f'{boundary_name.upper()}_{image_dir_name}' 70 | logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger') 71 | 72 | # Load model. 73 | logger.info(f'Loading generator.') 74 | tflib.init_tf({'rnd.np_random_seed': 1000}) 75 | with open(args.model_path, 'rb') as f: 76 | _, _, _, Gs = pickle.load(f) 77 | 78 | # Build graph. 79 | logger.info(f'Building graph.') 80 | sess = tf.get_default_session() 81 | num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3] 82 | wp = tf.placeholder( 83 | tf.float32, [args.batch_size, num_layers, latent_dim], name='latent_code') 84 | x = Gs.components.synthesis.get_output_for(wp, randomize_noise=False) 85 | 86 | # Load image, codes, and boundary. 
87 | logger.info(f'Loading images and corresponding inverted latent codes.') 88 | image_list = [] 89 | with open(f'{image_dir}/image_list.txt', 'r') as f: 90 | for line in f: 91 | name = os.path.splitext(os.path.basename(line.strip()))[0] 92 | assert os.path.exists(f'{image_dir}/{name}_ori.png') 93 | assert os.path.exists(f'{image_dir}/{name}_inv.png') 94 | image_list.append(name) 95 | latent_codes = np.load(f'{image_dir}/inverted_codes.npy') 96 | assert latent_codes.shape[0] == len(image_list) 97 | num_images = latent_codes.shape[0] 98 | logger.info(f'Loading boundary.') 99 | boundary_file = np.load(boundary_path, allow_pickle=True)[()] 100 | if isinstance(boundary_file, dict): 101 | boundary = boundary_file['boundary'] 102 | manipulate_layers = boundary_file['meta_data']['manipulate_layers'] 103 | else: 104 | boundary = boundary_file 105 | manipulate_layers = args.manipulate_layers 106 | if manipulate_layers: 107 | logger.info(f' Manipulating on layers `{manipulate_layers}`.') 108 | else: 109 | logger.info(f' Manipulating on ALL layers.') 110 | 111 | # Manipulate images. 112 | logger.info(f'Start manipulation.') 113 | step = args.step 114 | viz_size = None if args.viz_size == 0 else args.viz_size 115 | visualizer = HtmlPageVisualizer( 116 | num_rows=num_images, num_cols=step + 3, viz_size=viz_size) 117 | visualizer.set_headers( 118 | ['Name', 'Origin', 'Inverted'] + 119 | [f'Step {i:02d}' for i in range(1, step + 1)] 120 | ) 121 | for img_idx, img_name in enumerate(image_list): 122 | ori_image = load_image(f'{image_dir}/{img_name}_ori.png') 123 | inv_image = load_image(f'{image_dir}/{img_name}_inv.png') 124 | visualizer.set_cell(img_idx, 0, text=img_name) 125 | visualizer.set_cell(img_idx, 1, image=ori_image) 126 | visualizer.set_cell(img_idx, 2, image=inv_image) 127 | 128 | codes = manipulate(latent_codes=latent_codes, 129 | boundary=boundary, 130 | start_distance=args.start_distance, 131 | end_distance=args.end_distance, 132 | step=step, 133 | layerwise_manipulation=True, 134 | num_layers=num_layers, 135 | manipulate_layers=manipulate_layers, 136 | is_code_layerwise=True, 137 | is_boundary_layerwise=True) 138 | inputs = np.zeros((args.batch_size, num_layers, latent_dim), np.float32) 139 | for img_idx in tqdm(range(num_images), leave=False): 140 | output_images = [] 141 | for idx in range(0, step, args.batch_size): 142 | batch = codes[img_idx, idx:idx + args.batch_size] 143 | inputs[0:len(batch)] = batch 144 | images = sess.run(x, feed_dict={wp: inputs}) 145 | output_images.append(images[0:len(batch)]) 146 | output_images = adjust_pixel_range(np.concatenate(output_images, axis=0)) 147 | for s, output_image in enumerate(output_images): 148 | visualizer.set_cell(img_idx, s + 3, image=output_image) 149 | 150 | # Save results. 151 | visualizer.save(f'{output_dir}/{job_name}.html') 152 | 153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for adding automatically tracked values to Tensorboard. 
9 | 10 | Autosummary creates an identity op that internally keeps track of the input 11 | values and automatically shows up in TensorBoard. The reported value 12 | represents an average over input components. The average is accumulated 13 | constantly over time and flushed when save_summaries() is called. 14 | 15 | Notes: 16 | - The output tensor must be used as an input for something else in the 17 | graph. Otherwise, the autosummary op will not get executed, and the average 18 | value will not get accumulated. 19 | - It is perfectly fine to include autosummaries with the same name in 20 | several places throughout the graph, even if they are executed concurrently. 21 | - It is ok to also pass in a python scalar or numpy array. In this case, it 22 | is added to the average immediately. 23 | """ 24 | 25 | from collections import OrderedDict 26 | import numpy as np 27 | import tensorflow as tf 28 | from tensorboard import summary as summary_lib 29 | from tensorboard.plugins.custom_scalar import layout_pb2 30 | 31 | from . import tfutil 32 | from .tfutil import TfExpression 33 | from .tfutil import TfExpressionEx 34 | 35 | _dtype = tf.float64 36 | _vars = OrderedDict() # name => [var, ...] 37 | _immediate = OrderedDict() # name => update_op, update_value 38 | _finalized = False 39 | _merge_op = None 40 | 41 | 42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 43 | """Internal helper for creating autosummary accumulators.""" 44 | assert not _finalized 45 | name_id = name.replace("/", "_") 46 | v = tf.cast(value_expr, _dtype) 47 | 48 | if v.shape.is_fully_defined(): 49 | size = np.prod(tfutil.shape_to_list(v.shape)) 50 | size_expr = tf.constant(size, dtype=_dtype) 51 | else: 52 | size = None 53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 54 | 55 | if size == 1: 56 | if v.shape.ndims != 0: 57 | v = tf.reshape(v, []) 58 | v = [size_expr, v, tf.square(v)] 59 | else: 60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 62 | 63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 66 | 67 | if name in _vars: 68 | _vars[name].append(var) 69 | else: 70 | _vars[name] = [var] 71 | return update_op 72 | 73 | 74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: 75 | """Create a new autosummary. 76 | 77 | Args: 78 | name: Name to use in TensorBoard 79 | value: TensorFlow expression or python value to track 80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 
81 | 82 | Example use of the passthru mechanism: 83 | 84 | n = autosummary('l2loss', loss, passthru=n) 85 | 86 | This is a shorthand for the following code: 87 | 88 | with tf.control_dependencies([autosummary('l2loss', loss)]): 89 | n = tf.identity(n) 90 | """ 91 | tfutil.assert_tf_initialized() 92 | name_id = name.replace("/", "_") 93 | 94 | if tfutil.is_tf_expression(value): 95 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 96 | update_op = _create_var(name, value) 97 | with tf.control_dependencies([update_op]): 98 | return tf.identity(value if passthru is None else passthru) 99 | 100 | else: # python scalar or numpy array 101 | if name not in _immediate: 102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 103 | update_value = tf.placeholder(_dtype) 104 | update_op = _create_var(name, update_value) 105 | _immediate[name] = update_op, update_value 106 | 107 | update_op, update_value = _immediate[name] 108 | tfutil.run(update_op, {update_value: value}) 109 | return value if passthru is None else passthru 110 | 111 | 112 | def finalize_autosummaries() -> None: 113 | """Create the necessary ops to include autosummaries in TensorBoard report. 114 | Note: This should be done only once per graph. 115 | """ 116 | global _finalized 117 | tfutil.assert_tf_initialized() 118 | 119 | if _finalized: 120 | return None 121 | 122 | _finalized = True 123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 124 | 125 | # Create summary ops. 126 | with tf.device(None), tf.control_dependencies(None): 127 | for name, vars_list in _vars.items(): 128 | name_id = name.replace("/", "_") 129 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 130 | moments = tf.add_n(vars_list) 131 | moments /= moments[0] 132 | with tf.control_dependencies([moments]): # read before resetting 133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 135 | mean = moments[1] 136 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 137 | tf.summary.scalar(name, mean) 138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 140 | 141 | # Group by category and chart name. 142 | cat_dict = OrderedDict() 143 | for series_name in sorted(_vars.keys()): 144 | p = series_name.split("/") 145 | cat = p[0] if len(p) >= 2 else "" 146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 147 | if cat not in cat_dict: 148 | cat_dict[cat] = OrderedDict() 149 | if chart not in cat_dict[cat]: 150 | cat_dict[cat][chart] = [] 151 | cat_dict[cat][chart].append(series_name) 152 | 153 | # Setup custom_scalar layout. 
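# Build one margin chart per summary: the solid line tracks the mean, and the
# shaded band spans mean +/- std via the `xCustomScalars/<name>/margin_lo`
# and `margin_hi` scalars registered above.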
154 | categories = [] 155 | for cat_name, chart_dict in cat_dict.items(): 156 | charts = [] 157 | for chart_name, series_names in chart_dict.items(): 158 | series = [] 159 | for series_name in series_names: 160 | series.append(layout_pb2.MarginChartContent.Series( 161 | value=series_name, 162 | lower="xCustomScalars/" + series_name + "/margin_lo", 163 | upper="xCustomScalars/" + series_name + "/margin_hi")) 164 | margin = layout_pb2.MarginChartContent(series=series) 165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 168 | return layout 169 | 170 | def save_summaries(file_writer, global_step=None): 171 | """Call FileWriter.add_summary() with all summaries in the default graph, 172 | automatically finalizing and merging them on the first call. 173 | """ 174 | global _merge_op 175 | tfutil.assert_tf_initialized() 176 | 177 | if _merge_op is None: 178 | layout = finalize_autosummaries() 179 | if layout is not None: 180 | file_writer.add_summary(layout) 181 | with tf.device(None), tf.control_dependencies(None): 182 | _merge_op = tf.summary.merge_all() 183 | 184 | file_writer.add_summary(_merge_op.eval(), global_step) 185 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | IDInvert 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 |
25 | 28 |
29 | In-Domain GAN Inversion for Real Image Editing 30 |
31 |
32 | 33 |
34 | Jiapeng Zhu1*,  35 | Yujun Shen1*,  36 | Deli Zhao2,  37 | Bolei Zhou1 38 |
39 |
40 | 1The Chinese University of Hong Kong         41 | 2Xiaomi AI Lab 42 |
43 | 49 |
50 | 51 |
52 |
53 | 54 | 55 | 56 | 57 |
58 |
Overview
59 |
60 | In this work, we argue that the GAN inversion task is required 61 | not only to reconstruct the target image in terms of pixel values, 62 | but also to keep the inverted code in the semantic domain of the original latent space of well-trained GANs. For this purpose, we propose In-Domain GAN inversion (IDInvert) by 63 | first training a novel domain-guided encoder which is able to produce in-domain latent codes, 64 | and then performing domain-regularized optimization which involves the encoder as a regularizer to land the 65 | code inside the latent space when it is fine-tuned. 66 | The in-domain codes produced by IDInvert enable high-quality real image editing with fixed GAN models. 67 |
68 |
69 | 70 | 71 | 72 | 73 |
74 |
Results
75 |
76 | Semantic diffusion results. 77 | 78 | 79 | 80 | 81 | 82 |
83 | 84 | Image editing results. 85 | 86 | 87 | 88 | 89 | 90 |
91 | 92 | See more results in the following demo video: 93 | 94 |
95 | 99 |
100 |
101 | This work is featured in the Two Minute Papers YouTube channel, as shown below: 102 | 103 |
104 | 108 |
109 | 110 |
111 |
112 | 113 | 114 | 115 | 116 |
117 |
BibTeX
118 |
119 | @inproceedings{zhu2020indomain,
120 |   title     = {In-domain GAN Inversion for Real Image Editing},
121 |   author    = {Zhu, Jiapeng and Shen, Yujun and Zhao, Deli and Zhou, Bolei},
122 |   booktitle = {Proceedings of European Conference on Computer Vision (ECCV)},
123 |   year      = {2020}
124 | }
125 | 
126 | 127 |
Related Work
128 |
129 |
130 |
131 | 132 | Y. Shen, J. Gu, X. Tang, B. Zhou. 133 | Interpreting Latent Space of GANs for Semantic Face Editing. 134 | CVPR 2020.
135 | Comment: 136 | Proposes a technique for semantic face editing in latent space. 137 |
138 |
139 |
140 |
141 | 149 |
150 |
151 |
152 |
153 | 154 | R. Abdal, Y. Qin, P. Wonka. 155 | Image2StyleGAN: How to Embed Images Into the StyleGAN Latent Space? 156 | ICCV 2019.
157 | Comment: 158 | Explores how to embed images into the latent space. 159 |
160 |
161 |
162 |
163 |
164 | 165 | P. Isola, J.Y. Zhu, T. Zhou, A. A. Efros. 166 | Image-to-Image Translation with Conditional Adversarial Nets. 167 | CVPR 2017.
168 | Comment: 169 | Investigates image-to-image translation using conditional GANs. 170 |
171 |
172 |
173 |
174 |
175 | 176 | J.Y. Zhu, P. Krähenbühl, E. Shechtman, A. A. Efros. 177 | Generative Visual Manipulation on the Natural Image Manifold. 178 | ECCV 2016.
179 | Comment: 180 | Proposes a method for realistic photo manipulation and a system for interactive drawing using GANs. 181 |
182 |
183 |
184 |
185 |
186 | 187 | J. Gu, Y. Shen, B. Zhou. 188 | Image Processing Using Multi-Code GAN Prior. 189 | CVPR 2020.
190 | Comment: 191 | Employs multiple latent codes to invert a GAN model as prior for real image processing. 192 |
193 |
194 |
195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /invert.py: -------------------------------------------------------------------------------- 1 | # python 3.6 2 | """Inverts given images to latent codes with In-Domain GAN Inversion. 3 | 4 | Basically, for a particular image (real or synthesized), this script first 5 | employs the domain-guided encoder to produce an initial point in the latent 6 | space and then performs domain-regularized optimization to refine the latent 7 | code. 8 | """ 9 | 10 | import os 11 | import argparse 12 | import pickle 13 | from tqdm import tqdm 14 | import numpy as np 15 | import tensorflow as tf 16 | from dnnlib import tflib 17 | 18 | from perceptual_model import PerceptualModel 19 | from utils.logger import setup_logger 20 | from utils.visualizer import adjust_pixel_range 21 | from utils.visualizer import HtmlPageVisualizer 22 | from utils.visualizer import save_image, load_image, resize_image 23 | 24 | 25 | def parse_args(): 26 | """Parses arguments.""" 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('model_path', type=str, 29 | help='Path to the pre-trained model.') 30 | parser.add_argument('image_list', type=str, 31 | help='List of images to invert.') 32 | parser.add_argument('-o', '--output_dir', type=str, default='', 33 | help='Directory to save the results. If not specified, ' 34 | '`./results/inversion/${IMAGE_LIST}` ' 35 | 'will be used by default.') 36 | parser.add_argument('--batch_size', type=int, default=4, 37 | help='Batch size. (default: 4)') 38 | parser.add_argument('--learning_rate', type=float, default=0.01, 39 | help='Learning rate for optimization. (default: 0.01)') 40 | parser.add_argument('--num_iterations', type=int, default=100, 41 | help='Number of optimization iterations. (default: 100)') 42 | parser.add_argument('--num_results', type=int, default=5, 43 | help='Number of intermediate optimization results to ' 44 | 'save for each sample. (default: 5)') 45 | parser.add_argument('-R', '--random_init', action='store_true', 46 | help='Whether to use random initialization instead of ' 47 | 'the output from encoder. (default: False)') 48 | parser.add_argument('-E', '--domain_regularizer', action='store_false', 49 | help='Whether to use domain regularizer for ' 50 | 'optimization. (default: True)') 51 | parser.add_argument('--loss_weight_feat', type=float, default=5e-5, 52 | help='The perceptual loss scale for optimization. ' 53 | '(default: 5e-5)') 54 | parser.add_argument('--loss_weight_enc', type=float, default=2.0, 55 | help='The encoder loss scale for optimization. ' 56 | '(default: 2.0)') 57 | parser.add_argument('--viz_size', type=int, default=256, 58 | help='Image size for visualization. (default: 256)') 59 | parser.add_argument('--gpu_id', type=str, default='0', 60 | help='Which GPU(s) to use. (default: `0`)') 61 | return parser.parse_args() 62 | 63 | 64 | def main(): 65 | """Main function.""" 66 | args = parse_args() 67 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 68 | assert os.path.exists(args.image_list) 69 | image_list_name = os.path.splitext(os.path.basename(args.image_list))[0] 70 | output_dir = args.output_dir or f'results/inversion/{image_list_name}' 71 | logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger') 72 | 73 | logger.info(f'Loading model.') 74 | tflib.init_tf({'rnd.np_random_seed': 1000}) 75 | with open(args.model_path, 'rb') as f: 76 | E, _, _, Gs = pickle.load(f) 77 | 78 | # Get input size.
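# The encoder expects NCHW inputs, so entries 2 and 3 of `input_shape` are the
# spatial resolution; the assert below checks that inputs are square.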
79 | image_size = E.input_shape[2] 80 | assert image_size == E.input_shape[3] 81 | 82 | # Build graph. 83 | logger.info(f'Building graph.') 84 | sess = tf.get_default_session() 85 | input_shape = E.input_shape 86 | input_shape[0] = args.batch_size 87 | x = tf.placeholder(tf.float32, shape=input_shape, name='real_image') 88 | x_255 = (tf.transpose(x, [0, 2, 3, 1]) + 1) / 2 * 255 89 | latent_shape = Gs.components.synthesis.input_shape 90 | latent_shape[0] = args.batch_size 91 | wp = tf.get_variable(shape=latent_shape, name='latent_code') 92 | x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False) 93 | x_rec_255 = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) / 2 * 255 94 | if args.random_init: 95 | logger.info(f' Use random initialization for optimization.') 96 | wp_rnd = tf.random.normal(shape=latent_shape, name='latent_code_init') 97 | setter = tf.assign(wp, wp_rnd) 98 | else: 99 | logger.info(f' Use encoder output as the initialization for optimization.') 100 | w_enc = E.get_output_for(x, is_training=False) 101 | wp_enc = tf.reshape(w_enc, latent_shape) 102 | setter = tf.assign(wp, wp_enc) 103 | 104 | # Settings for optimization. 105 | logger.info(f'Setting configuration for optimization.') 106 | perceptual_model = PerceptualModel([image_size, image_size], False) 107 | x_feat = perceptual_model(x_255) 108 | x_rec_feat = perceptual_model(x_rec_255) 109 | loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1]) 110 | loss_pix = tf.reduce_mean(tf.square(x - x_rec), axis=[1, 2, 3]) 111 | if args.domain_regularizer: 112 | logger.info(f' Involve encoder for optimization.') 113 | w_enc_new = E.get_output_for(x_rec, is_training=False) 114 | wp_enc_new = tf.reshape(w_enc_new, latent_shape) 115 | loss_enc = tf.reduce_mean(tf.square(wp - wp_enc_new), axis=[1, 2]) 116 | else: 117 | logger.info(f' Do NOT involve encoder for optimization.') 118 | loss_enc = 0 119 | loss = (loss_pix + 120 | args.loss_weight_feat * loss_feat + 121 | args.loss_weight_enc * loss_enc) 122 | optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate) 123 | train_op = optimizer.minimize(loss, var_list=[wp]) 124 | tflib.init_uninitialized_vars() 125 | 126 | # Load image list. 127 | logger.info(f'Loading image list.') 128 | image_list = [] 129 | with open(args.image_list, 'r') as f: 130 | for line in f: 131 | image_list.append(line.strip()) 132 | 133 | # Invert images. 134 | logger.info(f'Start inversion.') 135 | save_interval = args.num_iterations // args.num_results 136 | headers = ['Name', 'Original Image', 'Encoder Output'] 137 | for step in range(1, args.num_iterations + 1): 138 | if step == args.num_iterations or step % save_interval == 0: 139 | headers.append(f'Step {step:06d}') 140 | viz_size = None if args.viz_size == 0 else args.viz_size 141 | visualizer = HtmlPageVisualizer( 142 | num_rows=len(image_list), num_cols=len(headers), viz_size=viz_size) 143 | visualizer.set_headers(headers) 144 | 145 | images = np.zeros(input_shape, np.uint8) 146 | names = ['' for _ in range(args.batch_size)] 147 | latent_codes_enc = [] 148 | latent_codes = [] 149 | for img_idx in tqdm(range(0, len(image_list), args.batch_size), leave=False): 150 | # Load inputs. 
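# Each image is resized to the encoder resolution, transposed from HWC to
# CHW, and normalized from [0, 255] to [-1, 1] before entering the graph.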
151 | batch = image_list[img_idx:img_idx + args.batch_size] 152 | for i, image_path in enumerate(batch): 153 | image = resize_image(load_image(image_path), (image_size, image_size)) 154 | images[i] = np.transpose(image, [2, 0, 1]) 155 | names[i] = os.path.splitext(os.path.basename(image_path))[0] 156 | inputs = images.astype(np.float32) / 255 * 2.0 - 1.0 157 | # Run encoder. 158 | sess.run([setter], {x: inputs}) 159 | outputs = sess.run([wp, x_rec]) 160 | latent_codes_enc.append(outputs[0][0:len(batch)]) 161 | outputs[1] = adjust_pixel_range(outputs[1]) 162 | for i, _ in enumerate(batch): 163 | image = np.transpose(images[i], [1, 2, 0]) 164 | save_image(f'{output_dir}/{names[i]}_ori.png', image) 165 | save_image(f'{output_dir}/{names[i]}_enc.png', outputs[1][i]) 166 | visualizer.set_cell(i + img_idx, 0, text=names[i]) 167 | visualizer.set_cell(i + img_idx, 1, image=image) 168 | visualizer.set_cell(i + img_idx, 2, image=outputs[1][i]) 169 | # Optimize latent codes. 170 | col_idx = 3 171 | for step in tqdm(range(1, args.num_iterations + 1), leave=False): 172 | sess.run(train_op, {x: inputs}) 173 | if step == args.num_iterations or step % save_interval == 0: 174 | outputs = sess.run([wp, x_rec]) 175 | outputs[1] = adjust_pixel_range(outputs[1]) 176 | for i, _ in enumerate(batch): 177 | if step == args.num_iterations: 178 | save_image(f'{output_dir}/{names[i]}_inv.png', outputs[1][i]) 179 | visualizer.set_cell(i + img_idx, col_idx, image=outputs[1][i]) 180 | col_idx += 1 181 | latent_codes.append(outputs[0][0:len(batch)]) 182 | 183 | # Save results. 184 | os.system(f'cp {args.image_list} {output_dir}/image_list.txt') 185 | np.save(f'{output_dir}/encoded_codes.npy', 186 | np.concatenate(latent_codes_enc, axis=0)) 187 | np.save(f'{output_dir}/inverted_codes.npy', 188 | np.concatenate(latent_codes, axis=0)) 189 | visualizer.save(f'{output_dir}/inversion.html') 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | -------------------------------------------------------------------------------- /diffuse.py: -------------------------------------------------------------------------------- 1 | # python 3.6 2 | """Diffuses target images to context images with In-domain GAN Inversion. 3 | 4 | Basically, this script first copies the central region from the target image to 5 | the context image, and then performs in-domain GAN inversion on the stitched 6 | image. Different from `invert.py`, a masked reconstruction loss is used in the 7 | optimization stage. 8 | 9 | NOTE: This script will diffuse every image from `target_image_list` to every 10 | image from `context_image_list`.
11 | """ 12 | 13 | import os 14 | import argparse 15 | import pickle 16 | from tqdm import tqdm 17 | import numpy as np 18 | import tensorflow as tf 19 | from dnnlib import tflib 20 | 21 | from perceptual_model import PerceptualModel 22 | from utils.logger import setup_logger 23 | from utils.visualizer import adjust_pixel_range 24 | from utils.visualizer import HtmlPageVisualizer 25 | from utils.visualizer import load_image, resize_image 26 | 27 | 28 | def parse_args(): 29 | """Parses arguments.""" 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('model_path', type=str, 32 | help='Path to the pre-trained model.') 33 | parser.add_argument('target_list', type=str, 34 | help='List of target images to diffuse from.') 35 | parser.add_argument('context_list', type=str, 36 | help='List of context images to diffuse to.') 37 | parser.add_argument('-o', '--output_dir', type=str, default='', 38 | help='Directory to save the results. If not specified, ' 39 | '`./results/diffusion` will be used by default.') 40 | parser.add_argument('-s', '--crop_size', type=int, default=110, 41 | help='Crop size. (default: 110)') 42 | parser.add_argument('-x', '--center_x', type=int, default=125, 43 | help='X-coordinate (column) of the center of the cropped ' 44 | 'patch. This field should be adjusted according to ' 45 | 'dataset and image size. (default: 125)') 46 | parser.add_argument('-y', '--center_y', type=int, default=145, 47 | help='Y-coordinate (row) of the center of the cropped ' 48 | 'patch. This field should be adjusted according to ' 49 | 'dataset and image size. (default: 145)') 50 | parser.add_argument('--batch_size', type=int, default=4, 51 | help='Batch size. (default: 4)') 52 | parser.add_argument('--learning_rate', type=float, default=0.01, 53 | help='Learning rate for optimization. (default: 0.01)') 54 | parser.add_argument('--num_iterations', type=int, default=100, 55 | help='Number of optimization iterations. (default: 100)') 56 | parser.add_argument('--num_results', type=int, default=5, 57 | help='Number of intermediate optimization results to ' 58 | 'save for each sample. (default: 5)') 59 | parser.add_argument('--loss_weight_feat', type=float, default=5e-5, 60 | help='The perceptual loss scale for optimization. ' 61 | '(default: 5e-5)') 62 | parser.add_argument('--viz_size', type=int, default=256, 63 | help='Image size for visualization. (default: 256)') 64 | parser.add_argument('--gpu_id', type=str, default='0', 65 | help='Which GPU(s) to use. (default: `0`)') 66 | return parser.parse_args() 67 | 68 | 69 | def main(): 70 | """Main function.""" 71 | args = parse_args() 72 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 73 | assert os.path.exists(args.target_list) 74 | target_list_name = os.path.splitext(os.path.basename(args.target_list))[0] 75 | assert os.path.exists(args.context_list) 76 | context_list_name = os.path.splitext(os.path.basename(args.context_list))[0] 77 | output_dir = args.output_dir or f'results/diffusion' 78 | job_name = f'{target_list_name}_TO_{context_list_name}' 79 | logger = setup_logger(output_dir, f'{job_name}.log', f'{job_name}_logger') 80 | 81 | logger.info(f'Loading model.') 82 | tflib.init_tf({'rnd.np_random_seed': 1000}) 83 | with open(args.model_path, 'rb') as f: 84 | E, _, _, Gs = pickle.load(f) 85 | 86 | # Get input size. 
87 | image_size = E.input_shape[2] 88 | assert image_size == E.input_shape[3] 89 | crop_size = args.crop_size 90 | crop_x = args.center_x - crop_size // 2 91 | crop_y = args.center_y - crop_size // 2 92 | mask = np.zeros((1, image_size, image_size, 3), dtype=np.float32) 93 | mask[:, crop_y:crop_y + crop_size, crop_x:crop_x + crop_size, :] = 1.0 94 | 95 | # Build graph. 96 | logger.info(f'Building graph.') 97 | sess = tf.get_default_session() 98 | input_shape = E.input_shape 99 | input_shape[0] = args.batch_size 100 | x = tf.placeholder(tf.float32, shape=input_shape, name='real_image') 101 | x_mask = (tf.transpose(x, [0, 2, 3, 1]) + 1) * mask - 1 102 | x_mask_255 = (x_mask + 1) / 2 * 255 103 | latent_shape = Gs.components.synthesis.input_shape 104 | latent_shape[0] = args.batch_size 105 | wp = tf.get_variable(shape=latent_shape, name='latent_code') 106 | x_rec = Gs.components.synthesis.get_output_for(wp, randomize_noise=False) 107 | x_rec_mask = (tf.transpose(x_rec, [0, 2, 3, 1]) + 1) * mask - 1 108 | x_rec_mask_255 = (x_rec_mask + 1) / 2 * 255 109 | 110 | w_enc = E.get_output_for(x, is_training=False) 111 | wp_enc = tf.reshape(w_enc, latent_shape) 112 | setter = tf.assign(wp, wp_enc) 113 | 114 | # Settings for optimization. 115 | logger.info(f'Setting configuration for optimization.') 116 | perceptual_model = PerceptualModel([image_size, image_size], False) 117 | x_feat = perceptual_model(x_mask_255) 118 | x_rec_feat = perceptual_model(x_rec_mask_255) 119 | loss_feat = tf.reduce_mean(tf.square(x_feat - x_rec_feat), axis=[1]) 120 | loss_pix = tf.reduce_mean(tf.square(x_mask - x_rec_mask), axis=[1, 2, 3]) 121 | 122 | loss = loss_pix + args.loss_weight_feat * loss_feat 123 | optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate) 124 | train_op = optimizer.minimize(loss, var_list=[wp]) 125 | tflib.init_uninitialized_vars() 126 | 127 | # Load image list. 128 | logger.info(f'Loading target images and context images.') 129 | target_list = [] 130 | with open(args.target_list, 'r') as f: 131 | for line in f: 132 | target_list.append(line.strip()) 133 | num_targets = len(target_list) 134 | context_list = [] 135 | with open(args.context_list, 'r') as f: 136 | for line in f: 137 | context_list.append(line.strip()) 138 | num_contexts = len(context_list) 139 | num_pairs = num_targets * num_contexts 140 | 141 | # Invert images. 142 | logger.info(f'Start diffusion.') 143 | save_interval = args.num_iterations // args.num_results 144 | headers = ['Target Image', 'Context Image', 'Stitched Image', 145 | 'Encoder Output'] 146 | for step in range(1, args.num_iterations + 1): 147 | if step == args.num_iterations or step % save_interval == 0: 148 | headers.append(f'Step {step:06d}') 149 | viz_size = None if args.viz_size == 0 else args.viz_size 150 | visualizer = HtmlPageVisualizer( 151 | num_rows=num_pairs, num_cols=len(headers), viz_size=viz_size) 152 | visualizer.set_headers(headers) 153 | 154 | images = np.zeros(input_shape, np.uint8) 155 | latent_codes_enc = [] 156 | latent_codes = [] 157 | for target_idx in tqdm(range(num_targets), desc='Target ID', leave=False): 158 | # Load target. 
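# Paste the target's center patch into each context image to form the
# stitched input, then invert the stitched image starting from the encoder's
# output code.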
159 | target_image = resize_image(load_image(target_list[target_idx]), 160 | (image_size, image_size)) 161 | visualizer.set_cell(target_idx * num_contexts, 0, image=target_image) 162 | for context_idx in tqdm(range(0, num_contexts, args.batch_size), 163 | desc='Context ID', leave=False): 164 | row_idx = target_idx * num_contexts + context_idx 165 | batch = context_list[context_idx:context_idx + args.batch_size] 166 | for i, context_image_path in enumerate(batch): 167 | context_image = resize_image(load_image(context_image_path), 168 | (image_size, image_size)) 169 | visualizer.set_cell(row_idx + i, 1, image=context_image) 170 | context_image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size] = ( 171 | target_image[crop_y:crop_y + crop_size, crop_x:crop_x + crop_size]) 172 | visualizer.set_cell(row_idx + i, 2, image=context_image) 173 | images[i] = np.transpose(context_image, [2, 0, 1]) 174 | inputs = images.astype(np.float32) / 255 * 2.0 - 1.0 175 | # Run encoder. 176 | sess.run([setter], {x: inputs}) 177 | outputs = sess.run([wp, x_rec]) 178 | latent_codes_enc.append(outputs[0][0:len(batch)]) 179 | outputs[1] = adjust_pixel_range(outputs[1]) 180 | for i, _ in enumerate(batch): 181 | visualizer.set_cell(row_idx + i, 3, image=outputs[1][i]) 182 | # Optimize latent codes. 183 | col_idx = 4 184 | for step in tqdm(range(1, args.num_iterations + 1), leave=False): 185 | sess.run(train_op, {x: inputs}) 186 | if step == args.num_iterations or step % save_interval == 0: 187 | outputs = sess.run([wp, x_rec]) 188 | outputs[1] = adjust_pixel_range(outputs[1]) 189 | for i, _ in enumerate(batch): 190 | visualizer.set_cell(row_idx + i, col_idx, image=outputs[1][i]) 191 | col_idx += 1 192 | latent_codes.append(outputs[0][0:len(batch)]) 193 | 194 | # Save results. 195 | code_shape = [num_targets, num_contexts] + list(latent_shape[1:]) 196 | np.save(f'{output_dir}/{job_name}_encoded_codes.npy', 197 | np.concatenate(latent_codes_enc, axis=0).reshape(code_shape)) 198 | np.save(f'{output_dir}/{job_name}_inverted_codes.npy', 199 | np.concatenate(latent_codes, axis=0).reshape(code_shape)) 200 | visualizer.save(f'{output_dir}/{job_name}.html') 201 | 202 | 203 | if __name__ == '__main__': 204 | main() 205 | -------------------------------------------------------------------------------- /generate_figures.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators.""" 9 | 10 | import os 11 | import pickle 12 | import numpy as np 13 | import PIL.Image 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | import config 17 | 18 | #---------------------------------------------------------------------------- 19 | # Helpers for loading and using pre-trained generators. 
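# A minimal usage sketch for these helpers (the output filename is hypothetical):
# after tflib.init_tf(), load_Gs() returns the cached pre-trained generator, and
# Gs.run() maps latents straight to uint8 NHWC images via synthesis_kwargs:
#
#   tflib.init_tf()
#   Gs = load_Gs(url_ffhq)                                   # downloaded once, then cached
#   z = np.random.RandomState(0).randn(1, Gs.input_shape[1])
#   img = Gs.run(z, None, truncation_psi=0.7, **synthesis_kwargs)[0]
#   PIL.Image.fromarray(img, 'RGB').save('example.png')      # hypothetical filename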
20 | 21 | url_ffhq = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl 22 | url_celebahq = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl 23 | url_bedrooms = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl 24 | url_cars = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl 25 | url_cats = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl 26 | 27 | synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8) 28 | 29 | _Gs_cache = dict() 30 | 31 | def load_Gs(url): 32 | if url not in _Gs_cache: 33 | with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: 34 | _G, _D, Gs = pickle.load(f) 35 | _Gs_cache[url] = Gs 36 | return _Gs_cache[url] 37 | 38 | #---------------------------------------------------------------------------- 39 | # Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images. 40 | 41 | def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed): 42 | print(png) 43 | latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1]) 44 | images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb] 45 | 46 | canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white') 47 | image_iter = iter(list(images)) 48 | for col, lod in enumerate(lods): 49 | for row in range(rows * 2**lod): 50 | image = PIL.Image.fromarray(next(image_iter), 'RGB') 51 | image = image.crop((cx, cy, cx + cw, cy + ch)) 52 | image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS) 53 | canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod)) 54 | canvas.save(png) 55 | 56 | #---------------------------------------------------------------------------- 57 | # Figure 3: Style mixing. 
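# The mixing rule used below: copy the destination dlatents and overwrite a
# contiguous range of layers with the source dlatents, so coarse (early) or fine
# (late) styles cross over. A hedged one-pair sketch:
#
#   mixed = dst_dlatents[0].copy()                           # [num_layers, 512]
#   mixed[0:4] = src_dlatents[0][0:4]                        # take coarse styles from src
#   image = Gs.components.synthesis.run(mixed[np.newaxis],
#                                       randomize_noise=False, **synthesis_kwargs)[0]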
58 | 59 | def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges): 60 | print(png) 61 | src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) 62 | dst_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds) 63 | src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] 64 | dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component] 65 | src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs) 66 | dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs) 67 | 68 | canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white') 69 | for col, src_image in enumerate(list(src_images)): 70 | canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0)) 71 | for row, dst_image in enumerate(list(dst_images)): 72 | canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h)) 73 | row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds)) 74 | row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]] 75 | row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs) 76 | for col, image in enumerate(list(row_images)): 77 | canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h)) 78 | canvas.save(png) 79 | 80 | #---------------------------------------------------------------------------- 81 | # Figure 4: Noise detail. 82 | 83 | def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds): 84 | print(png) 85 | canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white') 86 | for row, seed in enumerate(seeds): 87 | latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples) 88 | images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs) 89 | canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h)) 90 | for i in range(4): 91 | crop = PIL.Image.fromarray(images[i + 1], 'RGB') 92 | crop = crop.crop((650, 180, 906, 436)) 93 | crop = crop.resize((w//2, h//2), PIL.Image.NEAREST) 94 | canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2)) 95 | diff = np.std(np.mean(images, axis=3), axis=0) * 4 96 | diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8) 97 | canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h)) 98 | canvas.save(png) 99 | 100 | #---------------------------------------------------------------------------- 101 | # Figure 5: Noise components. 102 | 103 | def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips): 104 | print(png) 105 | Gsc = Gs.clone() 106 | noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')] 107 | noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...] 
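# (Each noise_range below keeps the per-layer noise inputs whose layer index
# falls inside the range and zeroes all others, isolating the contribution of
# coarse versus fine noise layers; `flips` mirrors selected samples horizontally
# for presentation.)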
108 | latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds) 109 | all_images = [] 110 | for noise_range in noise_ranges: 111 | tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)}) 112 | range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs) 113 | range_images[flips, :, :] = range_images[flips, :, ::-1] 114 | all_images.append(list(range_images)) 115 | 116 | canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white') 117 | for col, col_images in enumerate(zip(*all_images)): 118 | canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0)) 119 | canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0)) 120 | canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h)) 121 | canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h)) 122 | canvas.save(png) 123 | 124 | #---------------------------------------------------------------------------- 125 | # Figure 8: Truncation trick. 126 | 127 | def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis): 128 | print(png) 129 | latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds) 130 | dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component] 131 | dlatent_avg = Gs.get_var('dlatent_avg') # [component] 132 | 133 | canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white') 134 | for row, dlatent in enumerate(list(dlatents)): 135 | row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg 136 | row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs) 137 | for col, image in enumerate(list(row_images)): 138 | canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h)) 139 | canvas.save(png) 140 | 141 | #---------------------------------------------------------------------------- 142 | # Main program. 
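# (Running this file directly reproduces the paper figures; each draw_* call
# below writes a single PNG into config.result_dir. The pickles referenced by
# the Google Drive URLs above are fetched through dnnlib.util.open_url with
# cache_dir=config.cache_dir, so each network is downloaded only once.)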
143 | 144 | def main(): 145 | tflib.init_tf() 146 | os.makedirs(config.result_dir, exist_ok=True) 147 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-ffhq.png'), load_Gs(url_ffhq), cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5) 148 | draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(url_ffhq), w=1024, h=1024, src_seeds=[639,701,687,615,2268], dst_seeds=[888,829,1898,1733,1614,845], style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)]) 149 | draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(url_ffhq), w=1024, h=1024, num_samples=100, seeds=[1157,1012]) 150 | draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[1967,1555], noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)], flips=[1]) 151 | draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1]) 152 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0) 153 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure11-uncurated-cars.png'), load_Gs(url_cars), cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2) 154 | draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure12-uncurated-cats.png'), load_Gs(url_cats), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1) 155 | 156 | #---------------------------------------------------------------------------- 157 | 158 | if __name__ == "__main__": 159 | main() 160 | 161 | #---------------------------------------------------------------------------- 162 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous helper utils for Tensorflow.""" 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from typing import Any, Iterable, List, Union 15 | 16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 17 | """A type that represents a valid Tensorflow expression.""" 18 | 19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 20 | """A type that can be converted to a valid Tensorflow expression.""" 21 | 22 | 23 | def run(*args, **kwargs) -> Any: 24 | """Run the specified ops in the default session.""" 25 | assert_tf_initialized() 26 | return tf.get_default_session().run(*args, **kwargs) 27 | 28 | 29 | def is_tf_expression(x: Any) -> bool: 30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 32 | 33 | 34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 35 | """Convert a Tensorflow shape to a list of ints.""" 36 | return [dim.value for dim in shape] 37 | 38 | 39 | def flatten(x: TfExpressionEx) -> TfExpression: 40 | """Shortcut function for flattening a tensor.""" 41 | with tf.name_scope("Flatten"): 42 | return tf.reshape(x, [-1]) 43 | 44 | 45 | def log2(x: TfExpressionEx) -> TfExpression: 46 | """Logarithm in base 2.""" 47 | with tf.name_scope("Log2"): 48 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 49 | 50 | 51 | def exp2(x: TfExpressionEx) -> TfExpression: 52 | """Exponent in base 2.""" 53 | with tf.name_scope("Exp2"): 54 | return tf.exp(x * np.float32(np.log(2.0))) 55 | 56 | 57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 58 | """Linear interpolation.""" 59 | with tf.name_scope("Lerp"): 60 | return a + (b - a) * t 61 | 62 | 63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 64 | """Linear interpolation with clip.""" 65 | with tf.name_scope("LerpClip"): 66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 67 | 68 | 69 | def absolute_name_scope(scope: str) -> tf.name_scope: 70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 71 | return tf.name_scope(scope + "/") 72 | 73 | 74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 77 | 78 | 79 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 80 | # Defaults. 81 | cfg = dict() 82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 87 | 88 | # User overrides. 
89 | if config_dict is not None: 90 | cfg.update(config_dict) 91 | return cfg 92 | 93 | 94 | def init_tf(config_dict: dict = None) -> None: 95 | """Initialize TensorFlow session using good default settings.""" 96 | # Skip if already initialized. 97 | if tf.get_default_session() is not None: 98 | return 99 | 100 | # Setup config dict and random seeds. 101 | cfg = _sanitize_tf_config(config_dict) 102 | np_random_seed = cfg["rnd.np_random_seed"] 103 | if np_random_seed is not None: 104 | np.random.seed(np_random_seed) 105 | tf_random_seed = cfg["rnd.tf_random_seed"] 106 | if tf_random_seed == "auto": 107 | tf_random_seed = np.random.randint(1 << 31) 108 | if tf_random_seed is not None: 109 | tf.set_random_seed(tf_random_seed) 110 | 111 | # Setup environment variables. 112 | for key, value in list(cfg.items()): 113 | fields = key.split(".") 114 | if fields[0] == "env": 115 | assert len(fields) == 2 116 | os.environ[fields[1]] = str(value) 117 | 118 | # Create default TensorFlow session. 119 | create_session(cfg, force_as_default=True) 120 | 121 | 122 | def assert_tf_initialized(): 123 | """Check that TensorFlow session has been initialized.""" 124 | if tf.get_default_session() is None: 125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 126 | 127 | 128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 129 | """Create tf.Session based on config dict.""" 130 | # Setup TensorFlow config proto. 131 | cfg = _sanitize_tf_config(config_dict) 132 | config_proto = tf.ConfigProto() 133 | for key, value in cfg.items(): 134 | fields = key.split(".") 135 | if fields[0] not in ["rnd", "env"]: 136 | obj = config_proto 137 | for field in fields[:-1]: 138 | obj = getattr(obj, field) 139 | setattr(obj, fields[-1], value) 140 | 141 | # Create session. 142 | session = tf.Session(config=config_proto) 143 | if force_as_default: 144 | # pylint: disable=protected-access 145 | session._default_session = session.as_default() 146 | session._default_session.enforce_nesting = False 147 | session._default_session.__enter__() # pylint: disable=no-member 148 | 149 | return session 150 | 151 | 152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 153 | """Initialize all tf.Variables that have not already been initialized. 154 | 155 | Equivalent to the following, but more efficient and does not bloat the tf graph: 156 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 157 | """ 158 | assert_tf_initialized() 159 | if target_vars is None: 160 | target_vars = tf.global_variables() 161 | 162 | test_vars = [] 163 | test_ops = [] 164 | 165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 166 | for var in target_vars: 167 | assert is_tf_expression(var) 168 | 169 | try: 170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 171 | except KeyError: 172 | # Op does not exist => variable may be uninitialized. 173 | test_vars.append(var) 174 | 175 | with absolute_name_scope(var.name.split(":")[0]): 176 | test_ops.append(tf.is_variable_initialized(var)) 177 | 178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 179 | run([var.initializer for var in init_vars]) 180 | 181 | 182 | def set_vars(var_to_value_dict: dict) -> None: 183 | """Set the values of given tf.Variables. 
184 | 185 | Equivalent to the following, but more efficient and does not bloat the tf graph: 186 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 187 | """ 188 | assert_tf_initialized() 189 | ops = [] 190 | feed_dict = {} 191 | 192 | for var, value in var_to_value_dict.items(): 193 | assert is_tf_expression(var) 194 | 195 | try: 196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 197 | except KeyError: 198 | with absolute_name_scope(var.name.split(":")[0]): 199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 201 | 202 | ops.append(setter) 203 | feed_dict[setter.op.inputs[1]] = value 204 | 205 | run(ops, feed_dict) 206 | 207 | 208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 209 | """Create tf.Variable with large initial value without bloating the tf graph.""" 210 | assert_tf_initialized() 211 | assert isinstance(initial_value, np.ndarray) 212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 213 | var = tf.Variable(zeros, *args, **kwargs) 214 | set_vars({var: initial_value}) 215 | return var 216 | 217 | 218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 220 | Can be used as an input transformation for Network.run(). 221 | """ 222 | images = tf.cast(images, tf.float32) 223 | if nhwc_to_nchw: 224 | images = tf.transpose(images, [0, 3, 1, 2]) 225 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255) 226 | 227 | 228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): 229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 230 | Can be used as an output transformation for Network.run(). 231 | """ 232 | images = tf.cast(images, tf.float32) 233 | if shrink > 1: 234 | ksize = [1, 1, shrink, shrink] 235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 236 | if nchw_to_nhwc: 237 | images = tf.transpose(images, [0, 2, 3, 1]) 238 | scale = 255 / (drange[1] - drange[0]) 239 | images = images * scale + (0.5 - drange[0] * scale) 240 | return tf.saturate_cast(images, tf.uint8) 241 | -------------------------------------------------------------------------------- /dnnlib/tflib/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper wrapper for a Tensorflow optimizer.""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from collections import OrderedDict 14 | from typing import List, Union 15 | 16 | from . import autosummary 17 | from . import tfutil 18 | from .. 
import util 19 | 20 | from .tfutil import TfExpression, TfExpressionEx 21 | 22 | try: 23 | # TensorFlow 1.13 24 | from tensorflow.python.ops import nccl_ops 25 | except: 26 | # Older TensorFlow versions 27 | import tensorflow.contrib.nccl as nccl_ops 28 | 29 | class Optimizer: 30 | """A Wrapper for tf.train.Optimizer. 31 | 32 | Automatically takes care of: 33 | - Gradient averaging for multi-GPU training. 34 | - Dynamic loss scaling and typecasts for FP16 training. 35 | - Ignoring corrupted gradients that contain NaNs/Infs. 36 | - Reporting statistics. 37 | - Well-chosen default settings. 38 | """ 39 | 40 | def __init__(self, 41 | name: str = "Train", 42 | tf_optimizer: str = "tf.train.AdamOptimizer", 43 | learning_rate: TfExpressionEx = 0.001, 44 | use_loss_scaling: bool = False, 45 | loss_scaling_init: float = 64.0, 46 | loss_scaling_inc: float = 0.0005, 47 | loss_scaling_dec: float = 1.0, 48 | **kwargs): 49 | 50 | # Init fields. 51 | self.name = name 52 | self.learning_rate = tf.convert_to_tensor(learning_rate) 53 | self.id = self.name.replace("/", ".") 54 | self.scope = tf.get_default_graph().unique_name(self.id) 55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer) 56 | self.optimizer_kwargs = dict(kwargs) 57 | self.use_loss_scaling = use_loss_scaling 58 | self.loss_scaling_init = loss_scaling_init 59 | self.loss_scaling_inc = loss_scaling_inc 60 | self.loss_scaling_dec = loss_scaling_dec 61 | self._grad_shapes = None # [shape, ...] 62 | self._dev_opt = OrderedDict() # device => optimizer 63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] 64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) 65 | self._updates_applied = False 66 | 67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: 68 | """Register the gradients of the given loss function with respect to the given variables. 69 | Intended to be called once per GPU.""" 70 | assert not self._updates_applied 71 | 72 | # Validate arguments. 73 | if isinstance(trainable_vars, dict): 74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars 75 | 76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) 78 | 79 | if self._grad_shapes is None: 80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] 81 | 82 | assert len(trainable_vars) == len(self._grad_shapes) 83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) 84 | 85 | dev = loss.device 86 | 87 | assert all(var.device == dev for var in trainable_vars) 88 | 89 | # Register device and compute gradients. 
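# (One underlying tf.train optimizer is created lazily per device; repeated
# register_gradients() calls on the same device append their gradient lists,
# which apply_updates() later sums across calls and devices and rescales by
# 1/total_grads.)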
90 | with tf.name_scope(self.id + "_grad"), tf.device(dev): 91 | if dev not in self._dev_opt: 92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) 93 | assert callable(self.optimizer_class) 94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) 95 | self._dev_grads[dev] = [] 96 | 97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) 98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage 99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros 100 | self._dev_grads[dev].append(grads) 101 | 102 | def apply_updates(self) -> tf.Operation: 103 | """Construct training op to update the registered variables based on their gradients.""" 104 | tfutil.assert_tf_initialized() 105 | assert not self._updates_applied 106 | self._updates_applied = True 107 | devices = list(self._dev_grads.keys()) 108 | total_grads = sum(len(grads) for grads in self._dev_grads.values()) 109 | assert len(devices) >= 1 and total_grads >= 1 110 | ops = [] 111 | 112 | with tfutil.absolute_name_scope(self.scope): 113 | # Cast gradients to FP32 and calculate partial sum within each device. 114 | dev_grads = OrderedDict() # device => [(grad, var), ...] 115 | 116 | for dev_idx, dev in enumerate(devices): 117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): 118 | sums = [] 119 | 120 | for gv in zip(*self._dev_grads[dev]): 121 | assert all(v is gv[0][1] for g, v in gv) 122 | g = [tf.cast(g, tf.float32) for g, v in gv] 123 | g = g[0] if len(g) == 1 else tf.add_n(g) 124 | sums.append((g, gv[0][1])) 125 | 126 | dev_grads[dev] = sums 127 | 128 | # Sum gradients across devices. 129 | if len(devices) > 1: 130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None): 131 | for var_idx, grad_shape in enumerate(self._grad_shapes): 132 | g = [dev_grads[dev][var_idx][0] for dev in devices] 133 | 134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors 135 | g = nccl_ops.all_sum(g) 136 | 137 | for dev, gg in zip(devices, g): 138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) 139 | 140 | # Apply updates separately on each device. 141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()): 142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): 143 | # Scale gradients as needed. 144 | if self.use_loss_scaling or total_grads > 1: 145 | with tf.name_scope("Scale"): 146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef") 147 | coef = self.undo_loss_scaling(coef) 148 | grads = [(g * coef, v) for g, v in grads] 149 | 150 | # Check for overflows. 151 | with tf.name_scope("CheckOverflow"): 152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) 153 | 154 | # Update weights and adjust loss scaling. 155 | with tf.name_scope("UpdateWeights"): 156 | # pylint: disable=cell-var-from-loop 157 | opt = self._dev_opt[dev] 158 | ls_var = self.get_loss_scaling_var(dev) 159 | 160 | if not self.use_loss_scaling: 161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) 162 | else: 163 | ops.append(tf.cond(grad_ok, 164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), 165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) 166 | 167 | # Report statistics on the last device. 
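# (The overflow_frequency autosummary below records 1 whenever non-finite
# gradients caused the weight update to be skipped, so its running average is
# the fraction of skipped minibatches.)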
168 | if dev == devices[-1]: 169 | with tf.name_scope("Statistics"): 170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) 171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) 172 | 173 | if self.use_loss_scaling: 174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) 175 | 176 | # Initialize variables and group everything into a single op. 177 | self.reset_optimizer_state() 178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) 179 | 180 | return tf.group(*ops, name="TrainingOp") 181 | 182 | def reset_optimizer_state(self) -> None: 183 | """Reset internal state of the underlying optimizer.""" 184 | tfutil.assert_tf_initialized() 185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) 186 | 187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: 188 | """Get or create variable representing log2 of the current dynamic loss scaling factor.""" 189 | if not self.use_loss_scaling: 190 | return None 191 | 192 | if device not in self._dev_ls_var: 193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): 194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") 195 | 196 | return self._dev_ls_var[device] 197 | 198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression: 199 | """Apply dynamic loss scaling for the given expression.""" 200 | assert tfutil.is_tf_expression(value) 201 | 202 | if not self.use_loss_scaling: 203 | return value 204 | 205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) 206 | 207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression: 208 | """Undo the effect of dynamic loss scaling for the given expression.""" 209 | assert tfutil.is_tf_expression(value) 210 | 211 | if not self.use_loss_scaling: 212 | return value 213 | 214 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type 215 | -------------------------------------------------------------------------------- /metrics/linear_separability.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Linear Separability (LS).""" 9 | 10 | from collections import defaultdict 11 | import numpy as np 12 | import sklearn.svm 13 | import tensorflow as tf 14 | import dnnlib.tflib as tflib 15 | 16 | from metrics import metric_base 17 | from training import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | classifier_urls = [ 22 | 'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl 23 | 'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl 24 | 'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl 25 | 'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl 26 | 'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl 27 | 'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl 28 | 'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl 29 | 'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl 30 | 'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl 31 | 'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl 32 | 'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl 33 | 'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl 34 | 'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl 35 | 'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl 36 | 'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl 37 | 'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl 38 | 'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl 39 | 'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl 40 | 'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl 41 | 'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl 42 | 'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl 43 | 'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl 44 | 'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl 45 | 'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl 46 | 'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl 47 | 'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl 48 | 'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl 49 | 'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl 50 | 
'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl 51 | 'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl 52 | 'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl 53 | 'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl 54 | 'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl 55 | 'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl 56 | 'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl 57 | 'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl 58 | 'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl 59 | 'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl 60 | 'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl 61 | 'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl 62 | ] 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | def prob_normalize(p): 67 | p = np.asarray(p).astype(np.float32) 68 | assert len(p.shape) == 2 69 | return p / np.sum(p) 70 | 71 | def mutual_information(p): 72 | p = prob_normalize(p) 73 | px = np.sum(p, axis=1) 74 | py = np.sum(p, axis=0) 75 | result = 0.0 76 | for x in range(p.shape[0]): 77 | p_x = px[x] 78 | for y in range(p.shape[1]): 79 | p_xy = p[x][y] 80 | p_y = py[y] 81 | if p_xy > 0.0: 82 | result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output 83 | return result 84 | 85 | def entropy(p): 86 | p = prob_normalize(p) 87 | result = 0.0 88 | for x in range(p.shape[0]): 89 | for y in range(p.shape[1]): 90 | p_xy = p[x][y] 91 | if p_xy > 0.0: 92 | result -= p_xy * np.log2(p_xy) 93 | return result 94 | 95 | def conditional_entropy(p): 96 | # H(Y|X) where X corresponds to axis 0, Y to axis 1 97 | # i.e., How many bits of additional information are needed to where we are on axis 1 if we know where we are on axis 0? 98 | p = prob_normalize(p) 99 | y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y) 100 | return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up. 101 | 102 | #---------------------------------------------------------------------------- 103 | 104 | class LS(metric_base.MetricBase): 105 | def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs): 106 | assert num_keep <= num_samples 107 | super().__init__(**kwargs) 108 | self.num_samples = num_samples 109 | self.num_keep = num_keep 110 | self.attrib_indices = attrib_indices 111 | self.minibatch_per_gpu = minibatch_per_gpu 112 | 113 | def _evaluate(self, Gs, num_gpus): 114 | minibatch_size = num_gpus * self.minibatch_per_gpu 115 | 116 | # Construct TensorFlow graph for each GPU. 117 | result_expr = [] 118 | for gpu_idx in range(num_gpus): 119 | with tf.device('/gpu:%d' % gpu_idx): 120 | Gs_clone = Gs.clone() 121 | 122 | # Generate images. 
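# (Both the sampled z latents and the final-layer w dlatents are kept in
# result_dict so the linear SVM below can be fit separately in Z and W space;
# the two conditional-entropy scores are reported with suffixes _z and _w.)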
123 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 124 | dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True) 125 | images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True) 126 | 127 | # Downsample to 256x256. The attribute classifiers were built for 256x256. 128 | if images.shape[2] > 256: 129 | factor = images.shape[2] // 256 130 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 131 | images = tf.reduce_mean(images, axis=[3, 5]) 132 | 133 | # Run classifier for each attribute. 134 | result_dict = dict(latents=latents, dlatents=dlatents[:,-1]) 135 | for attrib_idx in self.attrib_indices: 136 | classifier = misc.load_pkl(classifier_urls[attrib_idx]) 137 | logits = classifier.get_output_for(images, None) 138 | predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1)) 139 | result_dict[attrib_idx] = predictions 140 | result_expr.append(result_dict) 141 | 142 | # Sampling loop. 143 | results = [] 144 | for _ in range(0, self.num_samples, minibatch_size): 145 | results += tflib.run(result_expr) 146 | results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()} 147 | 148 | # Calculate conditional entropy for each attribute. 149 | conditional_entropies = defaultdict(list) 150 | for attrib_idx in self.attrib_indices: 151 | # Prune the least confident samples. 152 | pruned_indices = list(range(self.num_samples)) 153 | pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i])) 154 | pruned_indices = pruned_indices[:self.num_keep] 155 | 156 | # Fit SVM to the remaining samples. 157 | svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1) 158 | for space in ['latents', 'dlatents']: 159 | svm_inputs = results[space][pruned_indices] 160 | try: 161 | svm = sklearn.svm.LinearSVC() 162 | svm.fit(svm_inputs, svm_targets) 163 | svm.score(svm_inputs, svm_targets) 164 | svm_outputs = svm.predict(svm_inputs) 165 | except: 166 | svm_outputs = svm_targets # assume perfect prediction 167 | 168 | # Calculate conditional entropy. 169 | p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)] 170 | conditional_entropies[space].append(conditional_entropy(p)) 171 | 172 | # Calculate separability scores. 173 | scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()} 174 | self._report_result(scores['latents'], suffix='_z') 175 | self._report_result(scores['dlatents'], suffix='_w') 176 | 177 | #---------------------------------------------------------------------------- 178 | -------------------------------------------------------------------------------- /training/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous utility functions.""" 9 | 10 | import os 11 | import glob 12 | import pickle 13 | import re 14 | import numpy as np 15 | from collections import defaultdict 16 | import PIL.Image 17 | import dnnlib 18 | 19 | import config 20 | from training import dataset 21 | 22 | #---------------------------------------------------------------------------- 23 | # Convenience wrappers for pickle that are able to load data produced by 24 | # older versions of the code, and from external URLs. 25 | 26 | def open_file_or_url(file_or_url): 27 | if dnnlib.util.is_url(file_or_url): 28 | return dnnlib.util.open_url(file_or_url, cache_dir=config.cache_dir) 29 | return open(file_or_url, 'rb') 30 | 31 | def load_pkl(file_or_url): 32 | with open_file_or_url(file_or_url) as file: 33 | return pickle.load(file, encoding='latin1') 34 | 35 | def save_pkl(obj, filename): 36 | with open(filename, 'wb') as file: 37 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) 38 | 39 | #---------------------------------------------------------------------------- 40 | # Image utils. 41 | 42 | def adjust_dynamic_range(data, drange_in, drange_out): 43 | if drange_in != drange_out: 44 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0])) 45 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale) 46 | data = data * scale + bias 47 | return data 48 | 49 | def create_image_grid(images, grid_size=None): 50 | assert images.ndim == 3 or images.ndim == 4 51 | num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2] 52 | 53 | if grid_size is not None: 54 | grid_w, grid_h = tuple(grid_size) 55 | else: 56 | grid_w = max(int(np.ceil(np.sqrt(num))), 1) 57 | grid_h = max((num - 1) // grid_w + 1, 1) 58 | 59 | grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype) 60 | for idx in range(num): 61 | x = (idx % grid_w) * img_w 62 | y = (idx // grid_w) * img_h 63 | grid[..., y : y + img_h, x : x + img_w] = images[idx] 64 | return grid 65 | 66 | def convert_to_pil_image(image, drange=[0,1]): 67 | assert image.ndim == 2 or image.ndim == 3 68 | if image.ndim == 3: 69 | if image.shape[0] == 1: 70 | image = image[0] # grayscale CHW => HW 71 | else: 72 | image = image.transpose(1, 2, 0) # CHW -> HWC 73 | 74 | image = adjust_dynamic_range(image, drange, [0,255]) 75 | image = np.rint(image).clip(0, 255).astype(np.uint8) 76 | fmt = 'RGB' if image.ndim == 3 else 'L' 77 | return PIL.Image.fromarray(image, fmt) 78 | 79 | def save_image(image, filename, drange=[0,1], quality=95): 80 | img = convert_to_pil_image(image, drange) 81 | if '.jpg' in filename: 82 | img.save(filename,"JPEG", quality=quality, optimize=True) 83 | else: 84 | img.save(filename) 85 | 86 | def save_image_grid(images, filename, drange=[0,1], grid_size=None): 87 | convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename) 88 | 89 | #---------------------------------------------------------------------------- 90 | # Locating results. 
91 | 92 | def locate_run_dir(run_id_or_run_dir): 93 | if isinstance(run_id_or_run_dir, str): 94 | if os.path.isdir(run_id_or_run_dir): 95 | return run_id_or_run_dir 96 | converted = dnnlib.submission.submit.convert_path(run_id_or_run_dir) 97 | if os.path.isdir(converted): 98 | return converted 99 | 100 | run_dir_pattern = re.compile('^0*%s-' % str(run_id_or_run_dir)) 101 | for search_dir in ['']: 102 | full_search_dir = config.result_dir if search_dir == '' else os.path.normpath(os.path.join(config.result_dir, search_dir)) 103 | run_dir = os.path.join(full_search_dir, str(run_id_or_run_dir)) 104 | if os.path.isdir(run_dir): 105 | return run_dir 106 | run_dirs = sorted(glob.glob(os.path.join(full_search_dir, '*'))) 107 | run_dirs = [run_dir for run_dir in run_dirs if run_dir_pattern.match(os.path.basename(run_dir))] 108 | run_dirs = [run_dir for run_dir in run_dirs if os.path.isdir(run_dir)] 109 | if len(run_dirs) == 1: 110 | return run_dirs[0] 111 | raise IOError('Cannot locate result subdir for run', run_id_or_run_dir) 112 | 113 | def list_network_pkls(run_id_or_run_dir, include_final=True): 114 | run_dir = locate_run_dir(run_id_or_run_dir) 115 | pkls = sorted(glob.glob(os.path.join(run_dir, 'network-*.pkl'))) 116 | if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl': 117 | if include_final: 118 | pkls.append(pkls[0]) 119 | del pkls[0] 120 | return pkls 121 | 122 | def locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): 123 | for candidate in [snapshot_or_network_pkl, run_id_or_run_dir_or_network_pkl]: 124 | if isinstance(candidate, str): 125 | if os.path.isfile(candidate): 126 | return candidate 127 | converted = dnnlib.submission.submit.convert_path(candidate) 128 | if os.path.isfile(converted): 129 | return converted 130 | 131 | pkls = list_network_pkls(run_id_or_run_dir_or_network_pkl) 132 | if len(pkls) >= 1 and snapshot_or_network_pkl is None: 133 | return pkls[-1] 134 | 135 | for pkl in pkls: 136 | try: 137 | name = os.path.splitext(os.path.basename(pkl))[0] 138 | number = int(name.split('-')[-1]) 139 | if number == snapshot_or_network_pkl: 140 | return pkl 141 | except ValueError: pass 142 | except IndexError: pass 143 | raise IOError('Cannot locate network pkl for snapshot', snapshot_or_network_pkl) 144 | 145 | def get_id_string_for_network_pkl(network_pkl): 146 | p = network_pkl.replace('.pkl', '').replace('\\', '/').split('/') 147 | return '-'.join(p[max(len(p) - 2, 0):]) 148 | 149 | #---------------------------------------------------------------------------- 150 | # Loading data from previous training runs. 151 | 152 | def load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): 153 | return load_pkl(locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl)) 154 | 155 | def parse_config_for_previous_run(run_id): 156 | run_dir = locate_run_dir(run_id) 157 | 158 | # Parse config.txt. 159 | cfg = defaultdict(dict) 160 | with open(os.path.join(run_dir, 'config.txt'), 'rt') as f: 161 | for line in f: 162 | line = re.sub(r"^{?\s*'(\w+)':\s*{(.*)(},|}})$", r"\1 = {\2}", line.strip()) 163 | if line.startswith('dataset =') or line.startswith('train ='): 164 | exec(line, cfg, cfg) # pylint: disable=exec-used 165 | 166 | # Handle legacy options. 
167 | if 'file_pattern' in cfg['dataset']: 168 | cfg['dataset']['tfrecord_dir'] = cfg['dataset'].pop('file_pattern').replace('-r??.tfrecords', '') 169 | if 'mirror_augment' in cfg['dataset']: 170 | cfg['train']['mirror_augment'] = cfg['dataset'].pop('mirror_augment') 171 | if 'max_labels' in cfg['dataset']: 172 | v = cfg['dataset'].pop('max_labels') 173 | if v is None: v = 0 174 | if v == 'all': v = 'full' 175 | cfg['dataset']['max_label_size'] = v 176 | if 'max_images' in cfg['dataset']: 177 | cfg['dataset'].pop('max_images') 178 | return cfg 179 | 180 | def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment 181 | cfg = parse_config_for_previous_run(run_id) 182 | cfg['dataset'].update(kwargs) 183 | dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **cfg['dataset']) 184 | mirror_augment = cfg['train'].get('mirror_augment', False) 185 | return dataset_obj, mirror_augment 186 | 187 | def apply_mirror_augment(minibatch): 188 | mask = np.random.rand(minibatch.shape[0]) < 0.5 189 | minibatch = np.array(minibatch) 190 | minibatch[mask] = minibatch[mask, :, :, ::-1] 191 | return minibatch 192 | 193 | #---------------------------------------------------------------------------- 194 | # Size and contents of the image snapshot grids that are exported 195 | # periodically during training. 196 | 197 | def setup_snapshot_image_grid(G, training_set, 198 | size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display. 199 | layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label. 200 | 201 | # Select size. 202 | gw = 1; gh = 1 203 | if size == '1080p': 204 | gw = np.clip(1920 // G.output_shape[3], 3, 32) 205 | gh = np.clip(1080 // G.output_shape[2], 2, 32) 206 | if size == '4k': 207 | gw = np.clip(3840 // G.output_shape[3], 7, 32) 208 | gh = np.clip(2160 // G.output_shape[2], 4, 32) 209 | 210 | # Initialize data arrays. 211 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) 212 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) 213 | latents = np.random.randn(gw * gh, *G.input_shape[1:]) 214 | 215 | # Random layout. 216 | if layout == 'random': 217 | reals[:], labels[:] = training_set.get_minibatch_np(gw * gh) 218 | 219 | # Class-conditional layouts. 
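# (Layouts other than 'random' bucket real samples by argmax(label) into bw x bh
# blocks; minibatches of size 1 are drawn until every class block is full or the
# 1e6-iteration cap is hit.)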
220 | class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4]) 221 | if layout in class_layouts: 222 | bw, bh = class_layouts[layout] 223 | nw = (gw - 1) // bw + 1 224 | nh = (gh - 1) // bh + 1 225 | blocks = [[] for _i in range(nw * nh)] 226 | for _iter in range(1000000): 227 | real, label = training_set.get_minibatch_np(1) 228 | idx = np.argmax(label[0]) 229 | while idx < len(blocks) and len(blocks[idx]) >= bw * bh: 230 | idx += training_set.label_size 231 | if idx < len(blocks): 232 | blocks[idx].append((real, label)) 233 | if all(len(block) >= bw * bh for block in blocks): 234 | break 235 | for i, block in enumerate(blocks): 236 | for j, (real, label) in enumerate(block): 237 | x = (i % nw) * bw + j % bw 238 | y = (i // nw) * bh + j // bw 239 | if x < gw and y < gh: 240 | reals[x + y * gw] = real[0] 241 | labels[x + y * gw] = label[0] 242 | 243 | return (gw, gh), reals, labels, latents 244 | 245 | #---------------------------------------------------------------------------- 246 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Loss functions.""" 9 | 10 | import tensorflow as tf 11 | import dnnlib.tflib as tflib 12 | from dnnlib.tflib.autosummary import autosummary 13 | 14 | #---------------------------------------------------------------------------- 15 | # Convenience func that casts all of its arguments to tf.float32. 16 | 17 | def fp32(*values): 18 | if len(values) == 1 and isinstance(values[0], tuple): 19 | values = values[0] 20 | values = tuple(tf.cast(v, tf.float32) for v in values) 21 | return values if len(values) >= 2 else values[0] 22 | 23 | #---------------------------------------------------------------------------- 24 | # WGAN & WGAN-GP loss functions. 25 | 26 | def G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 27 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 28 | labels = training_set.get_random_labels_tf(minibatch_size) 29 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 30 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 31 | loss = -fake_scores_out 32 | return loss 33 | 34 | def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 35 | wgan_epsilon = 0.001): # Weight for the epsilon term, \epsilon_{drift}. 
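# (The epsilon drift term below penalizes real_scores_out**2 so the critic's raw
# outputs stay near zero; without it, WGAN scores can drift arbitrarily, since
# the loss depends only on their difference.)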
36 | 37 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 38 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 39 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 40 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 41 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 42 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 43 | loss = fake_scores_out - real_scores_out 44 | 45 | with tf.name_scope('EpsilonPenalty'): 46 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) 47 | loss += epsilon_penalty * wgan_epsilon 48 | return loss 49 | 50 | def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 51 | wgan_lambda = 10.0, # Weight for the gradient penalty term. 52 | wgan_epsilon = 0.001, # Weight for the epsilon term, \epsilon_{drift}. 53 | wgan_target = 1.0): # Target value for gradient magnitudes. 54 | 55 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 56 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 57 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 58 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 59 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 60 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 61 | loss = fake_scores_out - real_scores_out 62 | 63 | with tf.name_scope('GradientPenalty'): 64 | mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) 65 | mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) 66 | mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) 67 | mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) 68 | mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) 69 | mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) 70 | mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) 71 | mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) 72 | gradient_penalty = tf.square(mixed_norms - wgan_target) 73 | loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) 74 | 75 | with tf.name_scope('EpsilonPenalty'): 76 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) 77 | loss += epsilon_penalty * wgan_epsilon 78 | return loss 79 | 80 | #---------------------------------------------------------------------------- 81 | # Hinge loss functions. 
(Use G_wgan with these) 82 | 83 | def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument 84 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 85 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 86 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 87 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 88 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 89 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 90 | loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) 91 | return loss 92 | 93 | def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument 94 | wgan_lambda = 10.0, # Weight for the gradient penalty term. 95 | wgan_target = 1.0): # Target value for gradient magnitudes. 96 | 97 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 98 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 99 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 100 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 101 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 102 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 103 | loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) 104 | 105 | with tf.name_scope('GradientPenalty'): 106 | mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) 107 | mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) 108 | mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) 109 | mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) 110 | mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) 111 | mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) 112 | mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) 113 | mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) 114 | gradient_penalty = tf.square(mixed_norms - wgan_target) 115 | loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) 116 | return loss 117 | 118 | 119 | #---------------------------------------------------------------------------- 120 | # Loss functions advocated by the paper 121 | # "Which Training Methods for GANs do actually Converge?" 
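# (Here softplus(x) = log(1 + exp(x)) serves as a numerically stable
# -log(sigmoid(-x)); e.g. G_logistic_nonsaturating minimizes
# -log(sigmoid(D(G(z)))) and D_logistic minimizes
# -log(sigmoid(D(x))) - log(1 - sigmoid(D(G(z)))).)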
122 | 123 | def G_logistic_saturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 124 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 125 | labels = training_set.get_random_labels_tf(minibatch_size) 126 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 127 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 128 | loss = -tf.nn.softplus(fake_scores_out) # log(1 - logistic(fake_scores_out)) 129 | return loss 130 | 131 | def G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument 132 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 133 | labels = training_set.get_random_labels_tf(minibatch_size) 134 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 135 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 136 | loss = tf.nn.softplus(-fake_scores_out) # -log(logistic(fake_scores_out)) 137 | return loss 138 | 139 | def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument 140 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 141 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 142 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 143 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 144 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 145 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 146 | loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) 147 | loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type 148 | return loss 149 | 150 | def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0): # pylint: disable=unused-argument 151 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) 152 | fake_images_out = G.get_output_for(latents, labels, is_training=True) 153 | real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) 154 | fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) 155 | real_scores_out = autosummary('Loss/scores/real', real_scores_out) 156 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) 157 | loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) 158 | loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type 159 | 160 | if r1_gamma != 0.0: 161 | with tf.name_scope('R1Penalty'): 162 | real_loss = opt.apply_loss_scaling(tf.reduce_sum(real_scores_out)) 163 | real_grads = opt.undo_loss_scaling(fp32(tf.gradients(real_loss, [reals])[0])) 164 | r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3]) 165 | r1_penalty = autosummary('Loss/r1_penalty', r1_penalty) 166 | loss += r1_penalty * (r1_gamma * 0.5) 167 | 168 | if r2_gamma != 0.0: 169 | with tf.name_scope('R2Penalty'): 170 | fake_loss = opt.apply_loss_scaling(tf.reduce_sum(fake_scores_out)) 171 | fake_grads = opt.undo_loss_scaling(fp32(tf.gradients(fake_loss, [fake_images_out])[0])) 172 | r2_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3]) 173 | r2_penalty = autosummary('Loss/r2_penalty', r2_penalty) 174 | loss += r2_penalty * 
(r2_gamma * 0.5) 175 | return loss 176 | 177 | #---------------------------------------------------------------------------- 178 | -------------------------------------------------------------------------------- /training/training_loop_encoder.py: -------------------------------------------------------------------------------- 1 | """Main script for training the encoder. This script should be run 2 | only after the StyleGAN generator has been trained.""" 3 | 4 | import os 5 | import time 6 | import sys 7 | import tensorflow as tf 8 | import numpy as np 9 | import dnnlib 10 | from dnnlib import EasyDict 11 | import dnnlib.tflib as tflib 12 | from training import misc 13 | from perceptual_model import PerceptualModel 14 | from utils.visualizer import fuse_images 15 | from utils.visualizer import save_image 16 | from utils.visualizer import adjust_pixel_range 17 | 18 | 19 | def process_reals(x, mirror_augment, drange_data, drange_net): 20 | with tf.name_scope('ProcessReals'): 21 | with tf.name_scope('DynamicRange'): 22 | x = tf.cast(x, tf.float32) 23 | x = misc.adjust_dynamic_range(x, drange_data, drange_net) 24 | if mirror_augment: 25 | with tf.name_scope('MirrorAugment'): 26 | s = tf.shape(x) 27 | mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0) 28 | mask = tf.tile(mask, [1, s[1], s[2], s[3]]) 29 | x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3])) 30 | return x 31 | 32 | 33 | def parse_tfrecord_tf(record): 34 | features = tf.parse_single_example(record, features={ 35 | 'shape': tf.FixedLenFeature([3], tf.int64), 36 | 'data': tf.FixedLenFeature([], tf.string)}) 37 | data = tf.decode_raw(features['data'], tf.uint8) 38 | return tf.reshape(data, features['shape']) 39 | 40 | 41 | def get_train_data(sess, data_dir, submit_config, mode): 42 | if mode == 'train': 43 | shuffle = True; repeat = True; batch_size = submit_config.batch_size 44 | elif mode == 'test': 45 | shuffle = False; repeat = True; batch_size = submit_config.batch_size_test 46 | else: 47 | raise ValueError("mode must be in ['train', 'test'], but got {}".format(mode)) 48 | 49 | dset = tf.data.TFRecordDataset(data_dir) 50 | dset = dset.map(parse_tfrecord_tf, num_parallel_calls=16) 51 | if shuffle: 52 | bytes_per_item = np.prod([3, submit_config.image_size, submit_config.image_size]) * np.dtype('uint8').itemsize 53 | dset = dset.shuffle(((4096 << 20) - 1) // bytes_per_item + 1) 54 | if repeat: 55 | dset = dset.repeat() 56 | dset = dset.batch(batch_size) 57 | train_iterator = tf.data.Iterator.from_structure(dset.output_types, dset.output_shapes) 58 | training_init_op = train_iterator.make_initializer(dset) 59 | image_batch = train_iterator.get_next() 60 | sess.run(training_init_op) 61 | return image_batch 62 | 63 | 64 | def test(E, Gs, real_test, submit_config): 65 | with tf.name_scope("Run"), tf.control_dependencies(None): 66 | with tf.device("/cpu:0"): 67 | in_split = tf.split(real_test, submit_config.num_gpus) 68 | out_split = [] 69 | num_layers, latent_dim = Gs.components.synthesis.input_shape[1:3] 70 | for gpu in range(submit_config.num_gpus): 71 | with tf.device("/gpu:%d" % gpu): 72 | in_gpu = in_split[gpu] 73 | latent_w = E.get_output_for(in_gpu, is_training=False) 74 | latent_wp = tf.reshape(latent_w, [in_gpu.shape[0], num_layers, latent_dim]) 75 | fake_X_val = Gs.components.synthesis.get_output_for(latent_wp, randomize_noise=False) 76 | out_split.append(fake_X_val) 77 | 78 | with tf.device("/cpu:0"): 79 | out_expr = tf.concat(out_split, axis=0) 80 | 81 | return out_expr 82 | 83 | 84 | def training_loop( 85 | submit_config, 86 | 
Encoder_args = {}, 87 | E_opt_args = {}, 88 | D_opt_args = {}, 89 | E_loss_args = EasyDict(), 90 | D_loss_args = {}, 91 | lr_args = EasyDict(), 92 | tf_config = {}, 93 | dataset_args = EasyDict(), 94 | decoder_pkl = EasyDict(), 95 | drange_data = [0, 255], 96 | drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. 97 | mirror_augment = False, 98 | filter = 64, # Minimum number of feature maps in any layer. 99 | filter_max = 512, # Maximum number of feature maps in any layer. 100 | resume_run_id = None, # Run ID or network pkl to resume training from, None = start from scratch. 101 | resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. 102 | image_snapshot_ticks = 1, # How often to export image snapshots? 103 | network_snapshot_ticks = 10, # How often to export network snapshots? 104 | max_iters = 150000): 105 | 106 | tflib.init_tf(tf_config) 107 | 108 | with tf.name_scope('Input'): 109 | real_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='real_image_train') 110 | real_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='real_image_test') 111 | real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0) 112 | 113 | with tf.device('/gpu:0'): 114 | if resume_run_id is not None: 115 | network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) 116 | print('Loading networks from "%s"...' % network_pkl) 117 | E, G, D, Gs = misc.load_pkl(network_pkl) 118 | start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size 119 | print('Start: ', start) 120 | else: 121 | print('Constructing networks...') 122 | G, D, Gs = misc.load_pkl(decoder_pkl.decoder_pkl) 123 | num_layers = Gs.components.synthesis.input_shape[1] 124 | E = tflib.Network('E_gpu0', size=submit_config.image_size, filter=filter, filter_max=filter_max, 125 | num_layers=num_layers, is_training=True, num_gpus=submit_config.num_gpus, **Encoder_args) 126 | start = 0 127 | 128 | E.print_layers(); Gs.print_layers(); D.print_layers() 129 | 130 | global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step') 131 | learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step, 132 | lr_args.decay_rate, staircase=lr_args.stair) 133 | add_global0 = global_step0.assign_add(1) 134 | 135 | E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args) 136 | D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args) 137 | 138 | E_loss_rec = 0. 139 | E_loss_adv = 0. 140 | D_loss_real = 0. 141 | D_loss_fake = 0. 142 | D_loss_grad = 0. 
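# Build one replica of the graph per GPU: GPU 0 reuses the primary E/D/Gs
# networks while each additional GPU gets a shadow clone. The accumulators
# above collect per-GPU loss terms for logging only (they are averaged over
# num_gpus below); the training signal itself is registered with E_opt/D_opt
# once per GPU.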
143 | for gpu in range(submit_config.num_gpus): 144 | print('Building Graph on GPU %s' % str(gpu)) 145 | with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): 146 | E_gpu = E if gpu == 0 else E.clone(E.name[:-1] + str(gpu)) 147 | D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') 148 | G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow') 149 | perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False) 150 | real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net) 151 | with tf.name_scope('E_loss'), tf.control_dependencies(None): 152 | E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model, reals=real_gpu, **E_loss_args) 153 | E_loss_rec += recon_loss 154 | E_loss_adv += adv_loss 155 | with tf.name_scope('D_loss'), tf.control_dependencies(None): 156 | D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args) 157 | D_loss_real += loss_real 158 | D_loss_fake += loss_fake 159 | D_loss_grad += loss_gp 160 | with tf.control_dependencies([add_global0]): 161 | E_opt.register_gradients(E_loss, E_gpu.trainables) 162 | D_opt.register_gradients(D_loss, D_gpu.trainables) 163 | 164 | E_loss_rec /= submit_config.num_gpus 165 | E_loss_adv /= submit_config.num_gpus 166 | D_loss_real /= submit_config.num_gpus 167 | D_loss_fake /= submit_config.num_gpus 168 | D_loss_grad /= submit_config.num_gpus 169 | 170 | E_train_op = E_opt.apply_updates() 171 | D_train_op = D_opt.apply_updates() 172 | 173 | print('Building testing graph...') 174 | fake_X_val = test(E, Gs, real_test, submit_config) 175 | 176 | sess = tf.get_default_session() 177 | 178 | print('Getting training data...') 179 | image_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train') 180 | image_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test') 181 | 182 | summary_log = tf.summary.FileWriter(submit_config.run_dir) 183 | 184 | cur_nimg = start * submit_config.batch_size 185 | cur_tick = 0 186 | tick_start_nimg = cur_nimg 187 | start_time = time.time() 188 | 189 | print('Optimization starts!!!') 190 | for it in range(start, max_iters): 191 | 192 | batch_images = sess.run(image_batch_train) 193 | feed_dict = {real_train: batch_images} 194 | _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict) 195 | _, d_r_, d_f_, d_g_ = sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict) 196 | 197 | cur_nimg += submit_config.batch_size 198 | 199 | if it % 50 == 0: 200 | print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % ( 201 | it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time))) 202 | sys.stdout.flush() 203 | tflib.autosummary.save_summaries(summary_log, it) 204 | 205 | if cur_nimg >= tick_start_nimg + 65000: 206 | cur_tick += 1 207 | tick_start_nimg = cur_nimg 208 | 209 | if cur_tick % image_snapshot_ticks == 0: 210 | batch_images_test = sess.run(image_batch_test) 211 | batch_images_test = misc.adjust_dynamic_range(batch_images_test.astype(np.float32), [0, 255], [-1., 1.]) 212 | recon = sess.run(fake_X_val, feed_dict={real_test: batch_images_test}) 213 | orin_recon = np.concatenate([batch_images_test, recon], axis=0) 214 | orin_recon = adjust_pixel_range(orin_recon) 
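# Map pixel values from the network's [-1, 1] range back to displayable
# [0, 255] pixels, then tile originals (row 1) over reconstructions (row 2)
# into a single snapshot image below.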
215 | orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test) 216 | # Save image results during training: the first row shows the original images, the second row their reconstructions. 217 | save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), orin_recon) 218 | 219 | if cur_tick % network_snapshot_ticks == 0: 220 | pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg)) 221 | misc.save_pkl((E, G, D, Gs), pkl) 222 | 223 | misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) 224 | summary_log.close() 225 | -------------------------------------------------------------------------------- /dnnlib/submission/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Submit a function to be run either locally or in a computing cluster.""" 9 | 10 | import copy 11 | import io 12 | import os 13 | import pathlib 14 | import pickle 15 | import platform 16 | import pprint 17 | import re 18 | import shutil 19 | import time 20 | import traceback 21 | 22 | import zipfile 23 | 24 | from enum import Enum 25 | 26 | from .. import util 27 | from ..util import EasyDict 28 | 29 | 30 | class SubmitTarget(Enum): 31 | """The target where the function should be run. 32 | 33 | LOCAL: Run it locally. 34 | """ 35 | LOCAL = 1 36 | 37 | 38 | class PathType(Enum): 39 | """Determines the format in which a path should be formatted. 40 | 41 | WINDOWS: Format with Windows style. 42 | LINUX: Format with Linux/Posix style. 43 | AUTO: Use current OS type to select either WINDOWS or LINUX. 44 | """ 45 | WINDOWS = 1 46 | LINUX = 2 47 | AUTO = 3 48 | 49 | 50 | _user_name_override = None 51 | 52 | 53 | class SubmitConfig(util.EasyDict): 54 | """Strongly typed config dict needed to submit runs. 55 | 56 | Attributes: 57 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. 58 | run_desc: Description of the run. Will be used in the run dir and task name. 59 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. 60 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir. 61 | submit_target: Submit target enum value. Used to select where the run is actually launched. 62 | num_gpus: Number of GPUs used/requested for the run. 63 | print_info: Whether to print debug information when submitting. 64 | ask_confirmation: Whether to ask a confirmation before submitting. 65 | run_id: Automatically populated value during submit. 66 | run_name: Automatically populated value during submit. 67 | run_dir: Automatically populated value during submit. 68 | run_func_name: Automatically populated value during submit. 69 | run_func_kwargs: Automatically populated value during submit. 70 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. 71 | task_name: Automatically populated value during submit. 72 | host_name: Automatically populated value during submit. 
73 | """ 74 | 75 | def __init__(self): 76 | super().__init__() 77 | 78 | # run (set these) 79 | self.run_dir_root = "" # should always be passed through get_path_from_template 80 | self.run_desc = "" 81 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] 82 | self.run_dir_extra_files = None 83 | 84 | # submit (set these) 85 | self.submit_target = SubmitTarget.LOCAL 86 | self.num_gpus = 1 87 | self.print_info = False 88 | self.ask_confirmation = False 89 | 90 | # (automatically populated) 91 | self.run_id = None 92 | self.run_name = None 93 | self.run_dir = None 94 | self.run_func_name = None 95 | self.run_func_kwargs = None 96 | self.user_name = None 97 | self.task_name = None 98 | self.host_name = "localhost" 99 | 100 | 101 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: 102 | """Replace tags in the given path template and return either Windows or Linux formatted path.""" 103 | # automatically select path type depending on running OS 104 | if path_type == PathType.AUTO: 105 | if platform.system() == "Windows": 106 | path_type = PathType.WINDOWS 107 | elif platform.system() == "Linux": 108 | path_type = PathType.LINUX 109 | else: 110 | raise RuntimeError("Unknown platform") 111 | 112 | path_template = path_template.replace("", get_user_name()) 113 | 114 | # return correctly formatted path 115 | if path_type == PathType.WINDOWS: 116 | return str(pathlib.PureWindowsPath(path_template)) 117 | elif path_type == PathType.LINUX: 118 | return str(pathlib.PurePosixPath(path_template)) 119 | else: 120 | raise RuntimeError("Unknown platform") 121 | 122 | 123 | def get_template_from_path(path: str) -> str: 124 | """Convert a normal path back to its template representation.""" 125 | # replace all path parts with the template tags 126 | path = path.replace("\\", "/") 127 | return path 128 | 129 | 130 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: 131 | """Convert a normal path to template and the convert it back to a normal path with given path type.""" 132 | path_template = get_template_from_path(path) 133 | path = get_path_from_template(path_template, path_type) 134 | return path 135 | 136 | 137 | def set_user_name_override(name: str) -> None: 138 | """Set the global username override value.""" 139 | global _user_name_override 140 | _user_name_override = name 141 | 142 | 143 | def get_user_name(): 144 | """Get the current user name.""" 145 | if _user_name_override is not None: 146 | return _user_name_override 147 | elif platform.system() == "Windows": 148 | return os.getlogin() 149 | elif platform.system() == "Linux": 150 | try: 151 | import pwd # pylint: disable=import-error 152 | return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member 153 | except: 154 | return "unknown" 155 | else: 156 | raise RuntimeError("Unknown platform") 157 | 158 | 159 | def _create_run_dir_local(submit_config: SubmitConfig) -> str: 160 | """Create a new run dir with increasing ID number at the start.""" 161 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) 162 | 163 | if not os.path.exists(run_dir_root): 164 | print("Creating the run dir root: {}".format(run_dir_root)) 165 | os.makedirs(run_dir_root) 166 | 167 | submit_config.run_id = _get_next_run_id_local(run_dir_root) 168 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) 169 | run_dir = os.path.join(run_dir_root, submit_config.run_name) 170 | 171 | if 
os.path.exists(run_dir): 172 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) 173 | 174 | print("Creating the run dir: {}".format(run_dir)) 175 | os.makedirs(run_dir) 176 | 177 | return run_dir 178 | 179 | 180 | def _get_next_run_id_local(run_dir_root: str) -> int: 181 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.""" 182 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] 183 | r = re.compile("^\\d+") # match one or more digits at the start of the string 184 | run_id = 0 185 | 186 | for dir_name in dir_names: 187 | m = r.match(dir_name) 188 | 189 | if m is not None: 190 | i = int(m.group()) 191 | run_id = max(run_id, i + 1) 192 | 193 | return run_id 194 | 195 | 196 | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: 197 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" 198 | print("Copying files to the run dir") 199 | files = [] 200 | 201 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) 202 | assert '.' in submit_config.run_func_name 203 | for _idx in range(submit_config.run_func_name.count('.') - 1): 204 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) 205 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) 206 | 207 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") 208 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) 209 | 210 | if submit_config.run_dir_extra_files is not None: 211 | files += submit_config.run_dir_extra_files 212 | 213 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] 214 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] 215 | 216 | util.copy_files_and_create_dirs(files) 217 | 218 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) 219 | 220 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: 221 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) 222 | 223 | 224 | def run_wrapper(submit_config: SubmitConfig) -> None: 225 | """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" 226 | is_local = submit_config.submit_target == SubmitTarget.LOCAL 227 | 228 | checker = None 229 | 230 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing 231 | if is_local: 232 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) 233 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) 234 | logger = util.Logger(file_name=None, should_flush=True) 235 | 236 | import dnnlib 237 | dnnlib.submit_config = submit_config 238 | 239 | try: 240 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) 241 | start_time = time.time() 242 | util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) 243 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() 
- start_time))) 244 | except: 245 | if is_local: 246 | raise 247 | else: 248 | traceback.print_exc() 249 | 250 | log_src = os.path.join(submit_config.run_dir, "log.txt") 251 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) 252 | shutil.copyfile(log_src, log_dst) 253 | finally: 254 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() 255 | 256 | dnnlib.submit_config = None 257 | logger.close() 258 | 259 | if checker is not None: 260 | checker.stop() 261 | 262 | 263 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: 264 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in the appropriate place.""" 265 | submit_config = copy.copy(submit_config) 266 | 267 | if submit_config.user_name is None: 268 | submit_config.user_name = get_user_name() 269 | 270 | submit_config.run_func_name = run_func_name 271 | submit_config.run_func_kwargs = run_func_kwargs 272 | 273 | assert submit_config.submit_target == SubmitTarget.LOCAL 274 | if submit_config.submit_target in {SubmitTarget.LOCAL}: 275 | run_dir = _create_run_dir_local(submit_config) 276 | 277 | submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) 278 | submit_config.run_dir = run_dir 279 | _populate_run_dir(run_dir, submit_config) 280 | 281 | if submit_config.print_info: 282 | print("\nSubmit config:\n") 283 | pprint.pprint(submit_config, indent=4, width=200, compact=False) 284 | print() 285 | 286 | if submit_config.ask_confirmation: 287 | if not util.ask_yes_no("Continue submitting the job?"): 288 | return 289 | 290 | run_wrapper(submit_config) 291 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Multi-resolution input data pipeline.""" 9 | 10 | import os 11 | import glob 12 | import numpy as np 13 | import tensorflow as tf 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | 17 | #---------------------------------------------------------------------------- 18 | # Parse individual image from a tfrecords file. 19 | 20 | def parse_tfrecord_tf(record): 21 | features = tf.parse_single_example(record, features={ 22 | 'shape': tf.FixedLenFeature([3], tf.int64), 23 | 'data': tf.FixedLenFeature([], tf.string)}) 24 | data = tf.decode_raw(features['data'], tf.uint8) 25 | return tf.reshape(data, features['shape']) 26 | 27 | def parse_tfrecord_np(record): 28 | ex = tf.train.Example() 29 | ex.ParseFromString(record) 30 | shape = ex.features.feature['shape'].int64_list.value # temporary pylint workaround # pylint: disable=no-member 31 | data = ex.features.feature['data'].bytes_list.value[0] # temporary pylint workaround # pylint: disable=no-member 32 | return np.frombuffer(data, np.uint8).reshape(shape) 33 | 34 | #---------------------------------------------------------------------------- 35 | # Dataset class that loads data from tfrecords files. 
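# The directory handed to TFRecordDataset below is expected to hold one
# *.tfrecords file per resolution (level of detail) plus an optional *.labels
# file containing a 2-D label array; images are stored as raw uint8 tensors
# in [channel, height, width] order.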
36 | 37 | class TFRecordDataset: 38 | def __init__(self, 39 | tfrecord_dir, # Directory containing a collection of tfrecords files. 40 | resolution = None, # Dataset resolution, None = autodetect. 41 | label_file = None, # Relative path of the labels file, None = autodetect. 42 | max_label_size = 0, # 0 = no labels, 'full' = full labels, <int> = N first label components. 43 | repeat = True, # Repeat dataset indefinitely. 44 | shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. 45 | prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching. 46 | buffer_mb = 256, # Read buffer size (megabytes). 47 | num_threads = 2): # Number of concurrent threads. 48 | 49 | self.tfrecord_dir = tfrecord_dir 50 | self.resolution = None 51 | self.resolution_log2 = None 52 | self.shape = [] # [channel, height, width] 53 | self.dtype = 'uint8' 54 | self.dynamic_range = [0, 255] 55 | self.label_file = label_file 56 | self.label_size = None # [component] 57 | self.label_dtype = None 58 | self._np_labels = None 59 | self._tf_minibatch_in = None 60 | self._tf_labels_var = None 61 | self._tf_labels_dataset = None 62 | self._tf_datasets = dict() 63 | self._tf_iterator = None 64 | self._tf_init_ops = dict() 65 | self._tf_minibatch_np = None 66 | self._cur_minibatch = -1 67 | self._cur_lod = -1 68 | 69 | # List tfrecords files and inspect their shapes. 70 | assert os.path.isdir(self.tfrecord_dir) 71 | tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords'))) 72 | assert len(tfr_files) >= 1 73 | tfr_shapes = [] 74 | for tfr_file in tfr_files: 75 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) 76 | for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt): 77 | tfr_shapes.append(parse_tfrecord_np(record).shape) 78 | break 79 | 80 | # Autodetect label filename. 81 | if self.label_file is None: 82 | guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels'))) 83 | if len(guess): 84 | self.label_file = guess[0] 85 | elif not os.path.isfile(self.label_file): 86 | guess = os.path.join(self.tfrecord_dir, self.label_file) 87 | if os.path.isfile(guess): 88 | self.label_file = guess 89 | 90 | # Determine shape and resolution. 91 | max_shape = max(tfr_shapes, key=np.prod) 92 | self.resolution = resolution if resolution is not None else max_shape[1] 93 | self.resolution_log2 = int(np.log2(self.resolution)) 94 | self.shape = [max_shape[0], self.resolution, self.resolution] 95 | tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes] 96 | assert all(shape[0] == max_shape[0] for shape in tfr_shapes) 97 | assert all(shape[1] == shape[2] for shape in tfr_shapes) 98 | assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods)) 99 | assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1)) 100 | 101 | # Load labels. 102 | assert max_label_size == 'full' or max_label_size >= 0 103 | self._np_labels = np.zeros([1<<20, 0], dtype=np.float32) 104 | if self.label_file is not None and max_label_size != 0: 105 | self._np_labels = np.load(self.label_file) 106 | assert self._np_labels.ndim == 2 107 | if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size: 108 | self._np_labels = self._np_labels[:, :max_label_size] 109 | self.label_size = self._np_labels.shape[1] 110 | self.label_dtype = self._np_labels.dtype.name 111 | 112 | # Build TF expressions. 
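# One tf.data pipeline is built per level of detail: each pipeline decodes a
# single *.tfrecords file, zips it with the shared label variable, and applies
# the shuffle/repeat/prefetch/batch options from the constructor; configure()
# below then switches a shared reinitializable iterator between the pipelines.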
113 | with tf.name_scope('Dataset'), tf.device('/cpu:0'): 114 | self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[]) 115 | self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var') 116 | self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var) 117 | for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods): 118 | if tfr_lod < 0: 119 | continue 120 | dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20) 121 | dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads) 122 | dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset)) 123 | bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize 124 | if shuffle_mb > 0: 125 | dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1) 126 | if repeat: 127 | dset = dset.repeat() 128 | if prefetch_mb > 0: 129 | dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1) 130 | dset = dset.batch(self._tf_minibatch_in) 131 | self._tf_datasets[tfr_lod] = dset 132 | self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes) 133 | self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()} 134 | 135 | # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf(). 136 | def configure(self, minibatch_size, lod=0): 137 | lod = int(np.floor(lod)) 138 | assert minibatch_size >= 1 and lod in self._tf_datasets 139 | if self._cur_minibatch != minibatch_size or self._cur_lod != lod: 140 | self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size}) 141 | self._cur_minibatch = minibatch_size 142 | self._cur_lod = lod 143 | 144 | # Get next minibatch as TensorFlow expressions. 145 | def get_minibatch_tf(self): # => images, labels 146 | return self._tf_iterator.get_next() 147 | 148 | # Get next minibatch as NumPy arrays. 149 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels 150 | self.configure(minibatch_size, lod) 151 | if self._tf_minibatch_np is None: 152 | self._tf_minibatch_np = self.get_minibatch_tf() 153 | return tflib.run(self._tf_minibatch_np) 154 | 155 | # Get random labels as TensorFlow expression. 156 | def get_random_labels_tf(self, minibatch_size): # => labels 157 | if self.label_size > 0: 158 | with tf.device('/cpu:0'): 159 | return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32)) 160 | return tf.zeros([minibatch_size, 0], self.label_dtype) 161 | 162 | # Get random labels as NumPy array. 163 | def get_random_labels_np(self, minibatch_size): # => labels 164 | if self.label_size > 0: 165 | return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])] 166 | return np.zeros([minibatch_size, 0], self.label_dtype) 167 | 168 | #---------------------------------------------------------------------------- 169 | # Base class for datasets that are generated on the fly. 
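# Subclasses override _generate_images() and _generate_labels() to synthesize
# minibatches directly in the TF graph; the base class only emits zeros. A
# minimal sketch of such a subclass (hypothetical, for illustration only):
#
#     class UniformNoiseDataset(SyntheticDataset):
#         def _generate_images(self, minibatch, lod, shape):
#             return tf.cast(tf.random_uniform([minibatch] + shape, 0, 256), self.dtype)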
170 | 171 | class SyntheticDataset: 172 | def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dynamic_range=[0,255], label_size=0, label_dtype='float32'): 173 | self.resolution = resolution 174 | self.resolution_log2 = int(np.log2(resolution)) 175 | self.shape = [num_channels, resolution, resolution] 176 | self.dtype = dtype 177 | self.dynamic_range = dynamic_range 178 | self.label_size = label_size 179 | self.label_dtype = label_dtype 180 | self._tf_minibatch_var = None 181 | self._tf_lod_var = None 182 | self._tf_minibatch_np = None 183 | self._tf_labels_np = None 184 | 185 | assert self.resolution == 2 ** self.resolution_log2 186 | with tf.name_scope('Dataset'): 187 | self._tf_minibatch_var = tf.Variable(np.int32(0), name='minibatch_var') 188 | self._tf_lod_var = tf.Variable(np.int32(0), name='lod_var') 189 | 190 | def configure(self, minibatch_size, lod=0): 191 | lod = int(np.floor(lod)) 192 | assert minibatch_size >= 1 and 0 <= lod <= self.resolution_log2 193 | tflib.set_vars({self._tf_minibatch_var: minibatch_size, self._tf_lod_var: lod}) 194 | 195 | def get_minibatch_tf(self): # => images, labels 196 | with tf.name_scope('SyntheticDataset'): 197 | shrink = tf.cast(2.0 ** tf.cast(self._tf_lod_var, tf.float32), tf.int32) 198 | shape = [self.shape[0], self.shape[1] // shrink, self.shape[2] // shrink] 199 | images = self._generate_images(self._tf_minibatch_var, self._tf_lod_var, shape) 200 | labels = self._generate_labels(self._tf_minibatch_var) 201 | return images, labels 202 | 203 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels 204 | self.configure(minibatch_size, lod) 205 | if self._tf_minibatch_np is None: 206 | self._tf_minibatch_np = self.get_minibatch_tf() 207 | return tflib.run(self._tf_minibatch_np) 208 | 209 | def get_random_labels_tf(self, minibatch_size): # => labels 210 | with tf.name_scope('SyntheticDataset'): 211 | return self._generate_labels(minibatch_size) 212 | 213 | def get_random_labels_np(self, minibatch_size): # => labels 214 | self.configure(minibatch_size) 215 | if self._tf_labels_np is None: 216 | self._tf_labels_np = self.get_random_labels_tf(minibatch_size) 217 | return tflib.run(self._tf_labels_np) 218 | 219 | def _generate_images(self, minibatch, lod, shape): # to be overridden by subclasses # pylint: disable=unused-argument 220 | return tf.zeros([minibatch] + shape, self.dtype) 221 | 222 | def _generate_labels(self, minibatch): # to be overridden by subclasses 223 | return tf.zeros([minibatch, self.label_size], self.label_dtype) 224 | 225 | #---------------------------------------------------------------------------- 226 | # Helper func for constructing a dataset object using the given options. 227 | 228 | def load_dataset(class_name='training.dataset.TFRecordDataset', data_dir=None, verbose=False, **kwargs): 229 | adjusted_kwargs = dict(kwargs) 230 | if 'tfrecord_dir' in adjusted_kwargs and data_dir is not None: 231 | adjusted_kwargs['tfrecord_dir'] = os.path.join(data_dir, adjusted_kwargs['tfrecord_dir']) 232 | if verbose: 233 | print('Streaming data using %s...' 
% class_name) 234 | dataset = dnnlib.util.get_obj_by_name(class_name)(**adjusted_kwargs) 235 | if verbose: 236 | print('Dataset shape =', np.int32(dataset.shape).tolist()) 237 | print('Dynamic range =', dataset.dynamic_range) 238 | print('Label size =', dataset.label_size) 239 | return dataset 240 | 241 | #---------------------------------------------------------------------------- 242 | --------------------------------------------------------------------------------
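The usual entry point to the pipeline in training/dataset.py is the load_dataset() helper; a minimal usage sketch (the data_dir and tfrecord_dir values are hypothetical):

    import dnnlib.tflib as tflib
    from training import dataset

    tflib.init_tf()  # create the default TF session
    training_set = dataset.load_dataset(data_dir='datasets', tfrecord_dir='ffhq', max_label_size='full', verbose=True)
    images, labels = training_set.get_minibatch_np(8)  # uint8 NCHW images and their labels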