├── .gitignore ├── Dockerfile ├── LICENSE.txt ├── README.md ├── dataset_tools ├── create_from_images.py └── tfrecord_utils.py ├── dnnlib ├── __init__.py ├── submission │ ├── __init__.py │ ├── internal │ │ ├── __init__.py │ │ └── local.py │ ├── run_context.py │ └── submit.py ├── tflib │ ├── __init__.py │ ├── autosummary.py │ ├── custom_ops.py │ ├── network.py │ ├── ops │ │ ├── __init__.py │ │ ├── fused_bias_act.cu │ │ ├── fused_bias_act.py │ │ ├── upfirdn_2d.cu │ │ └── upfirdn_2d.py │ ├── optimizer.py │ └── tfutil.py └── util.py ├── imgs ├── demo.gif ├── example_image.jpg ├── example_mask.jpg └── grid-main.jpg ├── metrics ├── __init__.py ├── frechet_inception_distance.py ├── inception_discriminative_score.py ├── learned_perceptual_image_patch_similarity.py ├── metric_base.py └── metric_defaults.py ├── run_demo.py ├── run_generator.py ├── run_metrics.py ├── run_training.py └── training ├── __init__.py ├── co_mod_gan.py ├── dataset.py ├── loss.py ├── mask_generator.py ├── misc.py └── training_loop.py /.gitignore: -------------------------------------------------------------------------------- 1 | /results/ 2 | /images/ 3 | /models/ 4 | /datasets/ 5 | /test/ 6 | /training/_my_pm.pyd 7 | /.stylegan2-cache/ 8 | *.db 9 | *.pyc 10 | *.dll 11 | *.pkl 12 | *.ipynb 13 | *.zip 14 | *.tfrecords 15 | .vscode 16 | *.csv -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | FROM tensorflow/tensorflow:1.15.0-gpu-py3 8 | # Install the pinned extra dependencies in a single layer, without pip's download cache, to keep the image small. 9 | RUN pip install --no-cache-dir scipy==1.3.3 \ 10 | requests==2.22.0 \ 11 | Pillow==6.2.1 12 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Shengyu Zhao, Jonathan Cui, Yilun Sheng, Yue Dong, 2 | Xiao Liang, Eric I Chang, Yan Xu 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | 26 | 27 | ----------------------- LICENSE FOR stylegan2 ----------------------- 28 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 29 | 30 | 31 | Nvidia Source Code License-NC 32 | 33 | ======================================================================= 34 | 35 | 1. Definitions 36 | 37 | "Licensor" means any person or entity that distributes its Work. 38 | 39 | "Software" means the original work of authorship made available under 40 | this License. 41 | 42 | "Work" means the Software and any additions to or derivative works of 43 | the Software that are made available under this License. 44 | 45 | "Nvidia Processors" means any central processing unit (CPU), graphics 46 | processing unit (GPU), field-programmable gate array (FPGA), 47 | application-specific integrated circuit (ASIC) or any combination 48 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 49 | 50 | The terms "reproduce," "reproduction," "derivative works," and 51 | "distribution" have the meaning as provided under U.S. copyright law; 52 | provided, however, that for the purposes of this License, derivative 53 | works shall not include works that remain separable from, or merely 54 | link (or bind by name) to the interfaces of, the Work. 55 | 56 | Works, including the Software, are "made available" under this License 57 | by including in or with the Work either (a) a copyright notice 58 | referencing the applicability of this License to the Work, or (b) a 59 | copy of this License. 60 | 61 | 2. License Grants 62 | 63 | 2.1 Copyright Grant. Subject to the terms and conditions of this 64 | License, each Licensor grants to you a perpetual, worldwide, 65 | non-exclusive, royalty-free, copyright license to reproduce, 66 | prepare derivative works of, publicly display, publicly perform, 67 | sublicense and distribute its Work and any resulting derivative 68 | works in any form. 69 | 70 | 3. Limitations 71 | 72 | 3.1 Redistribution. 
You may reproduce or distribute the Work only 73 | if (a) you do so under this License, (b) you include a complete 74 | copy of this License with your distribution, and (c) you retain 75 | without modification any copyright, patent, trademark, or 76 | attribution notices that are present in the Work. 77 | 78 | 3.2 Derivative Works. You may specify that additional or different 79 | terms apply to the use, reproduction, and distribution of your 80 | derivative works of the Work ("Your Terms") only if (a) Your Terms 81 | provide that the use limitation in Section 3.3 applies to your 82 | derivative works, and (b) you identify the specific derivative 83 | works that are subject to Your Terms. Notwithstanding Your Terms, 84 | this License (including the redistribution requirements in Section 85 | 3.1) will continue to apply to the Work itself. 86 | 87 | 3.3 Use Limitation. The Work and any derivative works thereof only 88 | may be used or intended for use non-commercially. The Work or 89 | derivative works thereof may be used or intended for use by Nvidia 90 | or its affiliates commercially or non-commercially. As used herein, 91 | "non-commercially" means for research or evaluation purposes only. 92 | 93 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 94 | against any Licensor (including any claim, cross-claim or 95 | counterclaim in a lawsuit) to enforce any patents that you allege 96 | are infringed by any Work, then your rights under this License from 97 | such Licensor (including the grants in Sections 2.1 and 2.2) will 98 | terminate immediately. 99 | 100 | 3.5 Trademarks. This License does not grant any rights to use any 101 | Licensor's or its affiliates' names, logos, or trademarks, except 102 | as necessary to reproduce the notices described in this License. 103 | 104 | 3.6 Termination. 
If you violate any term of this License, then your 105 | rights under this License (including the grants in Sections 2.1 and 106 | 2.2) will terminate immediately. 107 | 108 | 4. Disclaimer of Warranty. 109 | 110 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 111 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 112 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 113 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 114 | THIS LICENSE. 115 | 116 | 5. Limitation of Liability. 117 | 118 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 119 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 120 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 121 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 122 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 123 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 124 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 125 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 126 | THE POSSIBILITY OF SUCH DAMAGES. 127 | 128 | ======================================================================= 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Large Scale Image Completion via Co-Modulated Generative Adversarial Networks, ICLR 2021 (Spotlight) 2 | 3 | ### [Demo (Unofficial)](https://www.microsoft.com/en-us/ai/ai-lab-CoModGAN) | [Paper](https://openreview.net/pdf?id=sSjqmfsk95O) 4 | 5 | 6 | 7 | **[NEW!]** Another [unofficial demo](https://www.microsoft.com/en-us/ai/ai-lab-CoModGAN) is available! 8 | 9 | **[NOTICE]** Our web demo will be shut down soon. Enjoy the last days!
10 | 11 | **[NEW!]** Time to play with our [interactive web demo](http://comodgan.ml)! 12 | 13 | *Numerous task-specific variants of conditional generative adversarial networks have been developed for image completion. Yet, a serious limitation remains that all existing algorithms tend to fail when handling **large-scale missing regions**. To overcome this challenge, we propose a generic new approach that bridges the gap between image-conditional and recent modulated unconditional generative architectures via **co-modulation** of both conditional and stochastic style representations. Also, due to the lack of good quantitative metrics for image completion, we propose the new **Paired/Unpaired Inception Discriminative Score (P-IDS/U-IDS)**, which robustly measures the perceptual fidelity of inpainted images compared to real images via linear separability in a feature space. Experiments demonstrate superior performance in terms of both quality and diversity over state-of-the-art methods in free-form image completion and easy generalization to image-to-image translation.* 14 | 15 | 16 | 17 | Large Scale Image Completion via Co-Modulated Generative Adversarial Networks
18 | [Shengyu Zhao](https://scholar.google.com/citations?user=gLCdw70AAAAJ), [Jonathan Cui](https://www.linkedin.com/in/jonathan-cui-110b211a6/), Yilun Sheng, Yue Dong, Xiao Liang, Eric I Chang, Yan Xu
19 | Tsinghua University and Microsoft Research
20 | [arXiv](http://arxiv.org/abs/2103.10428) | [OpenReview](https://openreview.net/pdf?id=sSjqmfsk95O) 21 | 22 | ## Overview 23 | 24 | This repo is built upon and has the same dependencies as the official [StyleGAN2 repo](https://github.com/NVlabs/stylegan2). We also provide a [Dockerfile](https://github.com/zsyzzsoft/co-mod-gan/blob/master/Dockerfile) for Docker users. This repo currently supports: 25 | - [x] Large scale image completion experiments on FFHQ and Places2 26 | - [x] Image-to-image translation experiments on Edges2Shoes and Edges2Handbags 27 | - [ ] Image-to-image translation experiments on COCO-Stuff 28 | - [x] Evaluation code of *Paired/Unpaired Inception Discriminative Score (P-IDS/U-IDS)* 29 | 30 | ## Datasets 31 | 32 | - FFHQ dataset (in TFRecords format) can be downloaded following the [StyleGAN2 repo](https://github.com/NVlabs/stylegan2). 33 | - Places2 dataset can be downloaded from [this website](http://places2.csail.mit.edu/download.html) (Places365-Challenge 2016 high-resolution images, [training set](http://data.csail.mit.edu/places/places365/train_large_places365challenge.tar) and [validation set](http://data.csail.mit.edu/places/places365/val_large.tar)). The raw images should be converted into TFRecords using `dataset_tools/create_from_images.py` with `--shuffle --compressed`. 34 | - Edges2Shoes and Edges2Handbags datasets can be downloaded following the [pix2pix repo](https://github.com/phillipi/pix2pix). The raw images should be converted into TFRecords using `dataset_tools/create_from_images.py` with `--shuffle --pix2pix`. 35 | - To prepare a custom dataset, please use `dataset_tools/create_from_images.py`, which will automatically center crop and resize your images to the specified resolution. You only need to specify `--val-image-dir` if you want a validation set for testing. 36 | 37 | ## Training 38 | 39 | The following script is for training on FFHQ. It will hold out 10k images for validation.
We recommend using 8 NVIDIA Tesla V100 GPUs for training. Training at 512x512 resolution takes about 1 week. 40 | 41 | ```bash 42 | python run_training.py --data-dir=DATA_DIR --dataset=DATASET --metrics=ids10k --mirror-augment --num-gpus=8 43 | ``` 44 | 45 | The following script is for training on Places2 at resolution 512x512 (resolution must be specified when training on compressed dataset), which has a validation set of 36500 images: 46 | 47 | ```bash 48 | python run_training.py --data-dir=DATA_DIR --dataset=DATASET --resolution=512 --metrics=ids36k5 --total-kimg 50000 --num-gpus=8 49 | ``` 50 | 51 | The following script is for training on Edges2Handbags (and similarly for Edges2Shoes): 52 | 53 | ```bash 54 | python run_training.py --data-dir=DATA_DIR --dataset=DATASET --metrics=fid200-rt-handbags --mirror-augment --num-gpus=8 55 | ``` 56 | 57 | ## Pre-Trained Models 58 | 59 | Our pre-trained models are available on [Google Drive](https://drive.google.com/drive/folders/1zSJj1ichgSA-4sECGm-fQ0Ww8aiwpkoO): 60 | 61 | | Model name & URL | Description | 62 | | ------------------------------------------------------------ | ------------------------------------------------------------ | 63 | | [co-mod-gan-ffhq-9-025000.pkl](https://drive.google.com/file/d/1b3XxfAmJ9k2vd73j-3nPMr_lvNMQOFGE/view?usp=sharing) | Large scale image completion on FFHQ (512x512) | 64 | | [co-mod-gan-ffhq-10-025000.pkl](https://drive.google.com/file/d/1M2dSxlJnCFNM6LblpB2nQCnaimgwaaKu/view?usp=sharing) | Large scale image completion on FFHQ (1024x1024) | 65 | | [co-mod-gan-places2-050000.pkl](https://drive.google.com/file/d/1dJa3DRWIkx6Ebr8Sc0v1FdvWf6wkd010/view?usp=sharing) | Large scale image completion on Places2 (512x512) | 66 | | [co-mod-gan-coco-stuff-025000.pkl](https://drive.google.com/file/d/1Ol9_pKMpfIHHwbdE7RFmJcCAzfj8hqxQ/view?usp=sharing) | Image-to-image translation on COCO-Stuff (labels to photos) (512x512) | 67 | | 
[co-mod-gan-edges2shoes-025000.pkl](https://drive.google.com/file/d/155p-_zAtL8RJSsKHAWrRaGxJVzT4NZKg/view?usp=sharing) | Image-to-image translation on edges2shoes (256x256) | 68 | | [co-mod-gan-edges2handbags-025000.pkl](https://drive.google.com/file/d/1nBIQaUs6fXRpEt1cweqQKtWVw5UZAqLi/view?usp=sharing) | Image-to-image translation on edges2handbags (256x256) | 69 | 70 | Use the following script to run the interactive demo locally: 71 | 72 | ```bash 73 | python run_demo.py -d DATA_DIR/DATASET -c CHECKPOINT_FILE(S) 74 | ``` 75 | 76 | or the following command as a minimal example of usage: 77 | 78 | ```bash 79 | python run_generator.py -c CHECKPOINT_FILE -i imgs/example_image.jpg -m imgs/example_mask.jpg -o imgs/example_output.jpg 80 | ``` 81 | 82 | ## Evaluation 83 | 84 | The following script is for evaluation: 85 | 86 | ```bash 87 | python run_metrics.py --data-dir=DATA_DIR --dataset=DATASET --network=CHECKPOINT_FILE(S) --metrics=METRIC(S) --num-gpus=1 88 | ``` 89 | 90 | Commonly used metrics are `ids10k` and `ids36k5` (for FFHQ and Places2 respectively), which will compute P-IDS and U-IDS together with FID. By default, masks are generated randomly for evaluation, or you may append the metric name with `-h0` ([0.0, 0.2]) to `-h4` ([0.8, 1.0]) to specify the range of masked ratio. 
91 | 92 | ## Citation 93 | 94 | If you find this code helpful, please cite our paper: 95 | ``` 96 | @inproceedings{zhao2021comodgan, 97 | title={Large Scale Image Completion via Co-Modulated Generative Adversarial Networks}, 98 | author={Zhao, Shengyu and Cui, Jonathan and Sheng, Yilun and Dong, Yue and Liang, Xiao and Chang, Eric I and Xu, Yan}, 99 | booktitle={International Conference on Learning Representations (ICLR)}, 100 | year={2021} 101 | } 102 | ``` -------------------------------------------------------------------------------- /dataset_tools/create_from_images.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import numpy as np 3 | import argparse 4 | from tqdm import tqdm 5 | import random 6 | import os 7 | import PIL.Image 8 | 9 | from tfrecord_utils import TFRecordExporter 10 | 11 | def worker(in_queue, out_queue, resolution, compressed, pix2pix): 12 | while True: 13 | fpath = in_queue.get() 14 | if compressed: 15 | assert not pix2pix 16 | if fpath.endswith('.jpg') or fpath.endswith('.JPG'): 17 | img = np.fromfile(fpath, dtype='uint8') 18 | else: 19 | img = None 20 | else: 21 | try: 22 | img = PIL.Image.open(fpath) 23 | except IOError: 24 | img = None 25 | else: 26 | img_size = min(img.size[0] // 2 if pix2pix else img.size[1], img.size[1]) 27 | left = (img.size[0] - (img_size * 2 if pix2pix else img_size)) // 2 28 | top = (img.size[1] - img_size) // 2 29 | img = img.crop((left, top, left + (img_size * 2 if pix2pix else img_size), top + img_size)) 30 | img = img.resize((resolution * 2 if pix2pix else resolution, resolution), PIL.Image.BILINEAR) 31 | img = np.asarray(img.convert('RGB')).transpose([2, 0, 1]) 32 | if pix2pix: 33 | img = np.concatenate(np.split(img, 2, axis=-1), axis=0) 34 | out_queue.put(img) 35 | 36 | def create_from_images(tfrecord_dir, val_image_dir, train_image_dir, resolution, num_channels, num_processes, shuffle, compressed, pix2pix): 37 | in_queue = mp.Queue() 
38 | out_queue = mp.Queue(num_processes * 8) 39 | 40 | worker_procs = [mp.Process(target=worker, args=(in_queue, out_queue, resolution, compressed, pix2pix)) for _ in range(num_processes)] 41 | for worker_proc in worker_procs: 42 | worker_proc.daemon = True 43 | worker_proc.start() 44 | 45 | print('Processes created.') 46 | 47 | with TFRecordExporter(tfrecord_dir, compressed=compressed) as tfr: 48 | tfr.set_shape([num_channels * 2 if pix2pix else num_channels, resolution, resolution]) 49 | 50 | if val_image_dir: 51 | print('Processing validation images...') 52 | flist = [] 53 | for root, _, files in os.walk(val_image_dir): 54 | print(root) 55 | flist.extend([os.path.join(root, fname) for fname in files]) 56 | tfr.set_num_val_images(len(flist)) 57 | if shuffle: 58 | random.shuffle(flist) 59 | for fpath in tqdm(flist): 60 | in_queue.put(fpath, block=False) 61 | for _ in tqdm(range(len(flist))): 62 | img = out_queue.get() 63 | if img is not None: 64 | tfr.add_image(img) 65 | 66 | if train_image_dir: 67 | print('Processing training images...') 68 | flist = [] 69 | for root, _, files in os.walk(train_image_dir): 70 | print(root) 71 | flist.extend([os.path.join(root, fname) for fname in files]) 72 | if shuffle: 73 | random.shuffle(flist) 74 | for fpath in tqdm(flist): 75 | in_queue.put(fpath, block=False) 76 | for _ in tqdm(range(len(flist))): 77 | img = out_queue.get() 78 | if img is not None: 79 | tfr.add_image(img) 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--tfrecord-dir', help='Output directory of generated TFRecord', required=True) 84 | parser.add_argument('--val-image-dir', help='Root directory of validation images', default=None) 85 | parser.add_argument('--train-image-dir', help='Root directory of training images', default=None) 86 | parser.add_argument('--resolution', help='Target resolution', type=int, default=512) 87 | parser.add_argument('--num-channels', help='Number of channels of images', type=int, default=3) 
88 | parser.add_argument('--num-processes', help='Number of parallel processes', type=int, default=8) 89 | parser.add_argument('--shuffle', default=False, action='store_true') 90 | parser.add_argument('--compressed', default=False, action='store_true') 91 | parser.add_argument('--pix2pix', default=False, action='store_true') 92 | 93 | args = parser.parse_args() 94 | create_from_images(**vars(args)) 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /dataset_tools/tfrecord_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Tool for creating TFRecords datasets.""" 8 | 9 | import os 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | class TFRecordExporter: 16 | def __init__(self, tfrecord_dir, compressed=False): 17 | self.tfrecord_dir = tfrecord_dir 18 | self.num_val_images = 0 19 | self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir)) 20 | self.shape = None 21 | self.resolution_log2 = None 22 | self.tfr_writer = None 23 | self.compressed = compressed 24 | 25 | if not os.path.isdir(self.tfrecord_dir): 26 | os.makedirs(self.tfrecord_dir) 27 | assert os.path.isdir(self.tfrecord_dir) 28 | 29 | def close(self): 30 | self.tfr_writer.close() 31 | self.tfr_writer = None 32 | 33 | def set_shape(self, shape): 34 | self.shape = shape 35 | self.resolution_log2 = int(np.log2(self.shape[1])) 36 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) 37 | tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % self.resolution_log2 38 | self.tfr_writer = 
tf.python_io.TFRecordWriter(tfr_file, tfr_opt) 39 | 40 | def set_num_val_images(self, num_val_images): 41 | self.num_val_images = num_val_images 42 | 43 | def add_image(self, img): 44 | if self.shape is None: 45 | self.set_shape(img.shape) 46 | if not self.compressed: 47 | assert list(self.shape) == list(img.shape) 48 | quant = np.rint(img).clip(0, 255).astype(np.uint8) if not self.compressed else img 49 | ex = tf.train.Example(features=tf.train.Features(feature={ 50 | 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=self.shape)), 51 | 'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()])), 52 | 'compressed': tf.train.Feature(int64_list=tf.train.Int64List(value=[self.compressed])), 53 | 'num_val_images': tf.train.Feature(int64_list=tf.train.Int64List(value=[self.num_val_images])), 54 | })) 55 | self.tfr_writer.write(ex.SerializeToString()) 56 | 57 | def __enter__(self): 58 | return self 59 | 60 | def __exit__(self, *args): 61 | self.close() 62 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from . 
import submission 8 | 9 | from .submission.run_context import RunContext 10 | 11 | from .submission.submit import SubmitTarget 12 | from .submission.submit import PathType 13 | from .submission.submit import SubmitConfig 14 | from .submission.submit import submit_run 15 | from .submission.submit import get_path_from_template 16 | from .submission.submit import convert_path 17 | from .submission.submit import make_run_dir_path 18 | 19 | from .util import EasyDict 20 | 21 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. 22 | -------------------------------------------------------------------------------- /dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from . import run_context 8 | from . import submit 9 | -------------------------------------------------------------------------------- /dnnlib/submission/internal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from . import local 8 | -------------------------------------------------------------------------------- /dnnlib/submission/internal/local.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | class TargetOptions(): 8 | def __init__(self): 9 | self.do_not_copy_source_files = False # exposed to users as SubmitConfig.local.do_not_copy_source_files 10 | 11 | class Target(): 12 | def __init__(self): 13 | pass 14 | 15 | def finalize_submit_config(self, submit_config, host_run_dir): 16 | print ('Local submit ', end='', flush=True) 17 | submit_config.run_dir = host_run_dir # the local target uses the host run dir as the run dir 18 | 19 | def submit(self, submit_config, host_run_dir): 20 | from ..submit import run_wrapper, convert_path 21 | print('- run_dir: %s' % convert_path(submit_config.run_dir), flush=True) 22 | return run_wrapper(submit_config) # local target: call run_wrapper directly in the current process 23 | -------------------------------------------------------------------------------- /dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Helpers for managing the run/training loop.""" 8 | 9 | import datetime 10 | import json 11 | import os 12 | import pprint 13 | import time 14 | import types 15 | 16 | from typing import Any 17 | 18 | from . import submit 19 | 20 | # Singleton RunContext 21 | _run_context = None 22 | 23 | class RunContext(object): 24 | """Helper class for managing the run/training loop. 25 | 26 | The context will hide the implementation details of a basic run/training loop. 27 | It will set things up properly, tell if run should be stopped, and then cleans up. 28 | User should call update periodically and use should_stop to determine if run should be stopped. 29 | 30 | Args: 31 | submit_config: The SubmitConfig that is used for the current run. 32 | config_module: (deprecated) The whole config module that is used for the current run.
33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None): 36 | global _run_context 37 | # Only a single RunContext can be alive 38 | assert _run_context is None 39 | _run_context = self # register this instance as the module-level singleton that RunContext.get() returns 40 | self.submit_config = submit_config 41 | self.should_stop_flag = False 42 | self.has_closed = False 43 | self.start_time = time.time() 44 | self.last_update_time = time.time() 45 | self.last_update_interval = 0.0 46 | self.progress_monitor_file_path = None # NOTE(review): not read anywhere in this module 47 | 48 | # vestigial config_module support just prints a warning 49 | if config_module is not None: 50 | print("RunContext.config_module parameter support has been removed.") 51 | 52 | # write out details about the run to a text file 53 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 54 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 55 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 56 | 57 | def __enter__(self) -> "RunContext": 58 | return self 59 | 60 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 61 | self.close() 62 | 63 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: # NOTE: loss, cur_epoch and max_epoch are accepted but currently unused 64 | """Do general housekeeping and keep the state of the context up-to-date.
65 | Should be called often enough but not in a tight loop.""" 66 | assert not self.has_closed 67 | 68 | self.last_update_interval = time.time() - self.last_update_time 69 | self.last_update_time = time.time() 70 | 71 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): # an abort.txt file in run_dir requests a stop; callers observe it via should_stop() 72 | self.should_stop_flag = True 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | self.has_closed = True 99 | 100 | # detach the global singleton 101 | global _run_context 102 | if _run_context is self: 103 | _run_context = None 104 | 105 | @staticmethod 106 | def get() -> "RunContext": 107 | import dnnlib 108 | if _run_context is not None: 109 | return _run_context 110 | return RunContext(dnnlib.submit_config) # no live context yet: create one from the package-level dnnlib.submit_config 111 | -------------------------------------------------------------------------------- /dnnlib/submission/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation.
# All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Submit a function to be run either locally or in a computing cluster."""

import copy
import inspect
import os
import pathlib
import pickle
import platform
import pprint
import re
import shutil
import sys
import time
import traceback

from enum import Enum

from .. import util
from ..util import EasyDict

from . import internal

class SubmitTarget(Enum):
    """The target where the function should be run.

    LOCAL: Run it locally.
    """
    LOCAL = 1


class PathType(Enum):
    """Determines in which format should a path be formatted.

    WINDOWS: Format with Windows style.
    LINUX: Format with Linux/Posix style.
    AUTO: Use current OS type to select either WINDOWS or LINUX.
    """
    WINDOWS = 1
    LINUX = 2
    AUTO = 3


class PlatformExtras:
    """A mixed bag of values used by dnnlib heuristics.

    Attributes:

        data_reader_buffer_size: Used by DataReader to size internal shared memory buffers.
        data_reader_process_count: Number of worker processes to spawn (zero for single thread operation)
    """
    def __init__(self):
        self.data_reader_buffer_size = 1<<30    # 1 GB
        self.data_reader_process_count = 0     # single threaded default


_user_name_override = None

class SubmitConfig(util.EasyDict):
    """Strongly typed config dict needed to submit runs.

    Attributes:
        run_dir_root: Path to the run dir root. Can be optionally templated with tags (the "<USERNAME>" tag is
            replaced by get_user_name()). Needs to always be run through get_path_from_template.
        run_desc: Description of the run. Will be used in the run dir and task name.
        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
        submit_target: Submit target enum value. Used to select where the run is actually launched.
        num_gpus: Number of GPUs used/requested for the run.
        print_info: Whether to print debug information when submitting.
        nvprof: Boolean flag, default False (presumably enables nvprof profiling; its consumer is not in this module).
        local.do_not_copy_source_files: Do not copy source files from the working directory to the run dir.
        datasets: List, empty by default (not read within this module).
        run_id: Automatically populated value during submit.
        run_name: Automatically populated value during submit.
        run_dir: Automatically populated value during submit.
        run_func_name: Automatically populated value during submit.
        run_func_kwargs: Automatically populated value during submit.
        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
        task_name: Automatically populated value during submit.
        host_name: Automatically populated value during submit.
        platform_extras: Automatically populated values during submit. Used by various dnnlib libraries such as the DataReader class.
    """

    def __init__(self):
        super().__init__()

        # run (set these)
        self.run_dir_root = ""  # should always be passed through get_path_from_template
        self.run_desc = ""
        self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode", "_cudacache"]
        self.run_dir_extra_files = []

        # submit (set these)
        self.submit_target = SubmitTarget.LOCAL
        self.num_gpus = 1
        self.print_info = False
        self.nvprof = False
        self.local = internal.local.TargetOptions()
        self.datasets = []

        # (automatically populated)
        self.run_id = None
        self.run_name = None
        self.run_dir = None
        self.run_func_name = None
        self.run_func_kwargs = None
        self.user_name = None
        self.task_name = None
        self.host_name = "localhost"
        self.platform_extras = PlatformExtras()


def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
    """Replace tags in the given path template and return either Windows or Linux formatted path.

    Supported tags:
        <USERNAME>: replaced with get_user_name().

    Raises:
        RuntimeError: if path_type is AUTO and the OS is neither Windows nor Linux.
    """
    # automatically select path type depending on running OS
    if path_type == PathType.AUTO:
        if platform.system() == "Windows":
            path_type = PathType.WINDOWS
        elif platform.system() == "Linux":
            path_type = PathType.LINUX
        else:
            raise RuntimeError("Unknown platform")

    # NOTE: the needle must be the tag itself; replacing "" would insert the
    # user name between every single character of the path.
    path_template = path_template.replace("<USERNAME>", get_user_name())

    # return correctly formatted path
    if path_type == PathType.WINDOWS:
        return str(pathlib.PureWindowsPath(path_template))
    elif path_type == PathType.LINUX:
        return str(pathlib.PurePosixPath(path_template))
    else:
        raise RuntimeError("Unknown platform")


def get_template_from_path(path: str) -> str:
    """Convert a normal path back to its template representation (backslashes become forward slashes)."""
    path = path.replace("\\", "/")
    return path

144 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: 145 | """Convert a normal path to template and the convert it back to a normal path with given path type.""" 146 | path_template = get_template_from_path(path) 147 | path = get_path_from_template(path_template, path_type) 148 | return path 149 | 150 | 151 | def set_user_name_override(name: str) -> None: 152 | """Set the global username override value.""" 153 | global _user_name_override 154 | _user_name_override = name 155 | 156 | 157 | def get_user_name(): 158 | """Get the current user name.""" 159 | if _user_name_override is not None: 160 | return _user_name_override 161 | elif platform.system() == "Windows": 162 | return os.getlogin() 163 | elif platform.system() == "Linux": 164 | try: 165 | import pwd 166 | return pwd.getpwuid(os.geteuid()).pw_name 167 | except: 168 | return "unknown" 169 | else: 170 | raise RuntimeError("Unknown platform") 171 | 172 | 173 | def make_run_dir_path(*paths): 174 | """Make a path/filename that resides under the current submit run_dir. 175 | 176 | Args: 177 | *paths: Path components to be passed to os.path.join 178 | 179 | Returns: 180 | A file/dirname rooted at submit_config.run_dir. If there's no 181 | submit_config or run_dir, the base directory is the current 182 | working directory. 
183 | 184 | E.g., `os.path.join(dnnlib.submit_config.run_dir, "output.txt"))` 185 | """ 186 | import dnnlib 187 | if (dnnlib.submit_config is None) or (dnnlib.submit_config.run_dir is None): 188 | return os.path.join(os.getcwd(), *paths) 189 | return os.path.join(dnnlib.submit_config.run_dir, *paths) 190 | 191 | 192 | def _create_run_dir_local(submit_config: SubmitConfig) -> str: 193 | """Create a new run dir with increasing ID number at the start.""" 194 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) 195 | 196 | if not os.path.exists(run_dir_root): 197 | os.makedirs(run_dir_root) 198 | 199 | submit_config.run_id = _get_next_run_id_local(run_dir_root) 200 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) 201 | run_dir = os.path.join(run_dir_root, submit_config.run_name) 202 | 203 | if os.path.exists(run_dir): 204 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) 205 | 206 | os.makedirs(run_dir) 207 | 208 | return run_dir 209 | 210 | 211 | def _get_next_run_id_local(run_dir_root: str) -> int: 212 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.""" 213 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] 214 | r = re.compile("^\\d+") # match one or more digits at the start of the string 215 | run_id = 0 216 | 217 | for dir_name in dir_names: 218 | m = r.match(dir_name) 219 | 220 | if m is not None: 221 | i = int(m.group()) 222 | run_id = max(run_id, i + 1) 223 | 224 | return run_id 225 | 226 | 227 | def _populate_run_dir(submit_config: SubmitConfig, run_dir: str) -> None: 228 | """Copy all necessary files into the run dir. 
Assumes that the dir exists, is local, and is writable.""" 229 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) 230 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: 231 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) 232 | 233 | if (submit_config.submit_target == SubmitTarget.LOCAL) and submit_config.local.do_not_copy_source_files: 234 | return 235 | 236 | files = [] 237 | 238 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) 239 | assert '.' in submit_config.run_func_name 240 | for _idx in range(submit_config.run_func_name.count('.') - 1): 241 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) 242 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) 243 | 244 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") 245 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) 246 | 247 | files += submit_config.run_dir_extra_files 248 | 249 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] 250 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "internal", "run.py"), os.path.join(run_dir, "run.py"))] 251 | 252 | util.copy_files_and_create_dirs(files) 253 | 254 | 255 | 256 | def run_wrapper(submit_config: SubmitConfig) -> None: 257 | """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" 258 | is_local = submit_config.submit_target == SubmitTarget.LOCAL 259 | 260 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing 261 | if is_local: 262 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) 263 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log 
writing is handled by run.sh) 264 | logger = util.Logger(file_name=None, should_flush=True) 265 | 266 | import dnnlib 267 | dnnlib.submit_config = submit_config 268 | 269 | exit_with_errcode = False 270 | try: 271 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) 272 | start_time = time.time() 273 | 274 | run_func_obj = util.get_obj_by_name(submit_config.run_func_name) 275 | assert callable(run_func_obj) 276 | sig = inspect.signature(run_func_obj) 277 | if 'submit_config' in sig.parameters: 278 | run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs) 279 | else: 280 | run_func_obj(**submit_config.run_func_kwargs) 281 | 282 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time))) 283 | except: 284 | if is_local: 285 | raise 286 | else: 287 | traceback.print_exc() 288 | 289 | log_src = os.path.join(submit_config.run_dir, "log.txt") 290 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) 291 | shutil.copyfile(log_src, log_dst) 292 | 293 | # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt 294 | exit_with_errcode = True 295 | finally: 296 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() 297 | 298 | dnnlib.RunContext.get().close() 299 | dnnlib.submit_config = None 300 | logger.close() 301 | 302 | # If we hit an error, get out of the script now and signal the error 303 | # to whatever process that started this script. 
304 | if exit_with_errcode: 305 | sys.exit(1) 306 | 307 | return submit_config 308 | 309 | 310 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: 311 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.""" 312 | submit_config = copy.deepcopy(submit_config) 313 | 314 | submit_target = submit_config.submit_target 315 | farm = None 316 | if submit_target == SubmitTarget.LOCAL: 317 | farm = internal.local.Target() 318 | assert farm is not None # unknown target 319 | 320 | # Disallow submitting jobs with zero num_gpus. 321 | if (submit_config.num_gpus is None) or (submit_config.num_gpus == 0): 322 | raise RuntimeError("submit_config.num_gpus must be set to a non-zero value") 323 | 324 | if submit_config.user_name is None: 325 | submit_config.user_name = get_user_name() 326 | 327 | submit_config.run_func_name = run_func_name 328 | submit_config.run_func_kwargs = run_func_kwargs 329 | 330 | #-------------------------------------------------------------------- 331 | # Prepare submission by populating the run dir 332 | #-------------------------------------------------------------------- 333 | host_run_dir = _create_run_dir_local(submit_config) 334 | 335 | submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) 336 | docker_valid_name_regex = "^[a-zA-Z0-9][a-zA-Z0-9_.-]+$" 337 | if not re.match(docker_valid_name_regex, submit_config.task_name): 338 | raise RuntimeError("Invalid task name. Probable reason: unacceptable characters in your submit_config.run_desc. 
Task name must be accepted by the following regex: " + docker_valid_name_regex + ", got " + submit_config.task_name) 339 | 340 | # Farm specific preparations for a submit 341 | farm.finalize_submit_config(submit_config, host_run_dir) 342 | _populate_run_dir(submit_config, host_run_dir) 343 | return farm.submit(submit_config, host_run_dir) 344 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from . import autosummary 8 | from . import network 9 | from . import optimizer 10 | from . import tfutil 11 | from . import custom_ops 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | 18 | from .custom_ops import get_plugin 19 | -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Helper for adding automatically tracked values to Tensorboard. 8 | 9 | Autosummary creates an identity op that internally keeps track of the input 10 | values and automatically shows up in TensorBoard. The reported value 11 | represents an average over input components. The average is accumulated 12 | constantly over time and flushed when save_summaries() is called. 
13 | 14 | Notes: 15 | - The output tensor must be used as an input for something else in the 16 | graph. Otherwise, the autosummary op will not get executed, and the average 17 | value will not get accumulated. 18 | - It is perfectly fine to include autosummaries with the same name in 19 | several places throughout the graph, even if they are executed concurrently. 20 | - It is ok to also pass in a python scalar or numpy array. In this case, it 21 | is added to the average immediately. 22 | """ 23 | 24 | from collections import OrderedDict 25 | import numpy as np 26 | import tensorflow as tf 27 | from tensorboard import summary as summary_lib 28 | from tensorboard.plugins.custom_scalar import layout_pb2 29 | 30 | from . import tfutil 31 | from .tfutil import TfExpression 32 | from .tfutil import TfExpressionEx 33 | 34 | # Enable "Custom scalars" tab in TensorBoard for advanced formatting. 35 | # Disabled by default to reduce tfevents file size. 36 | enable_custom_scalars = False 37 | 38 | _dtype = tf.float64 39 | _vars = OrderedDict() # name => [var, ...] 
40 | _immediate = OrderedDict() # name => update_op, update_value 41 | _finalized = False 42 | _merge_op = None 43 | 44 | 45 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 46 | """Internal helper for creating autosummary accumulators.""" 47 | assert not _finalized 48 | name_id = name.replace("/", "_") 49 | v = tf.cast(value_expr, _dtype) 50 | 51 | if v.shape.is_fully_defined(): 52 | size = np.prod(v.shape.as_list()) 53 | size_expr = tf.constant(size, dtype=_dtype) 54 | else: 55 | size = None 56 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 57 | 58 | if size == 1: 59 | if v.shape.ndims != 0: 60 | v = tf.reshape(v, []) 61 | v = [size_expr, v, tf.square(v)] 62 | else: 63 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 64 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 65 | 66 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 67 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 68 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 69 | 70 | if name in _vars: 71 | _vars[name].append(var) 72 | else: 73 | _vars[name] = [var] 74 | return update_op 75 | 76 | 77 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx: 78 | """Create a new autosummary. 79 | 80 | Args: 81 | name: Name to use in TensorBoard 82 | value: TensorFlow expression or python value to track 83 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 
84 | 85 | Example use of the passthru mechanism: 86 | 87 | n = autosummary('l2loss', loss, passthru=n) 88 | 89 | This is a shorthand for the following code: 90 | 91 | with tf.control_dependencies([autosummary('l2loss', loss)]): 92 | n = tf.identity(n) 93 | """ 94 | tfutil.assert_tf_initialized() 95 | name_id = name.replace("/", "_") 96 | 97 | if tfutil.is_tf_expression(value): 98 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 99 | condition = tf.convert_to_tensor(condition, name='condition') 100 | update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op) 101 | with tf.control_dependencies([update_op]): 102 | return tf.identity(value if passthru is None else passthru) 103 | 104 | else: # python scalar or numpy array 105 | assert not tfutil.is_tf_expression(passthru) 106 | assert not tfutil.is_tf_expression(condition) 107 | if condition: 108 | if name not in _immediate: 109 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 110 | update_value = tf.placeholder(_dtype) 111 | update_op = _create_var(name, update_value) 112 | _immediate[name] = update_op, update_value 113 | update_op, update_value = _immediate[name] 114 | tfutil.run(update_op, {update_value: value}) 115 | return value if passthru is None else passthru 116 | 117 | 118 | def finalize_autosummaries() -> None: 119 | """Create the necessary ops to include autosummaries in TensorBoard report. 120 | Note: This should be done only once per graph. 121 | """ 122 | global _finalized 123 | tfutil.assert_tf_initialized() 124 | 125 | if _finalized: 126 | return None 127 | 128 | _finalized = True 129 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 130 | 131 | # Create summary ops. 
132 | with tf.device(None), tf.control_dependencies(None): 133 | for name, vars_list in _vars.items(): 134 | name_id = name.replace("/", "_") 135 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 136 | moments = tf.add_n(vars_list) 137 | moments /= moments[0] 138 | with tf.control_dependencies([moments]): # read before resetting 139 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 140 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 141 | mean = moments[1] 142 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 143 | tf.summary.scalar(name, mean) 144 | if enable_custom_scalars: 145 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 146 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 147 | 148 | # Setup layout for custom scalars. 149 | layout = None 150 | if enable_custom_scalars: 151 | cat_dict = OrderedDict() 152 | for series_name in sorted(_vars.keys()): 153 | p = series_name.split("/") 154 | cat = p[0] if len(p) >= 2 else "" 155 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 156 | if cat not in cat_dict: 157 | cat_dict[cat] = OrderedDict() 158 | if chart not in cat_dict[cat]: 159 | cat_dict[cat][chart] = [] 160 | cat_dict[cat][chart].append(series_name) 161 | categories = [] 162 | for cat_name, chart_dict in cat_dict.items(): 163 | charts = [] 164 | for chart_name, series_names in chart_dict.items(): 165 | series = [] 166 | for series_name in series_names: 167 | series.append(layout_pb2.MarginChartContent.Series( 168 | value=series_name, 169 | lower="xCustomScalars/" + series_name + "/margin_lo", 170 | upper="xCustomScalars/" + series_name + "/margin_hi")) 171 | margin = layout_pb2.MarginChartContent(series=series) 172 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 173 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 174 | layout = 
summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 175 | return layout 176 | 177 | def save_summaries(file_writer, global_step=None): 178 | """Call FileWriter.add_summary() with all summaries in the default graph, 179 | automatically finalizing and merging them on the first call. 180 | """ 181 | global _merge_op 182 | tfutil.assert_tf_initialized() 183 | 184 | if _merge_op is None: 185 | layout = finalize_autosummaries() 186 | if layout is not None: 187 | file_writer.add_summary(layout) 188 | with tf.device(None), tf.control_dependencies(None): 189 | _merge_op = tf.summary.merge_all() 190 | 191 | file_writer.add_summary(_merge_op.eval(), global_step) 192 | -------------------------------------------------------------------------------- /dnnlib/tflib/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """TensorFlow custom ops builder. 8 | """ 9 | 10 | import os 11 | import re 12 | import uuid 13 | import hashlib 14 | import tempfile 15 | import shutil 16 | import tensorflow as tf 17 | from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module 18 | 19 | #---------------------------------------------------------------------------- 20 | # Global options. 21 | 22 | cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache') 23 | cuda_cache_version_tag = 'v1' 24 | do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe! 25 | verbose = True # Print status messages to stdout. 
26 | 27 | compiler_bindir_search_path = [ 28 | 'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64', 29 | 'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64', 30 | 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin', 31 | ] 32 | 33 | #---------------------------------------------------------------------------- 34 | # Internal helper funcs. 35 | 36 | def _find_compiler_bindir(): 37 | for compiler_path in compiler_bindir_search_path: 38 | if os.path.isdir(compiler_path): 39 | return compiler_path 40 | return None 41 | 42 | def _get_compute_cap(device): 43 | caps_str = device.physical_device_desc 44 | m = re.search('compute capability: (\\d+).(\\d+)', caps_str) 45 | major = m.group(1) 46 | minor = m.group(2) 47 | return (major, minor) 48 | 49 | def _get_cuda_gpu_arch_string(): 50 | gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] 51 | if len(gpus) == 0: 52 | raise RuntimeError('No GPU devices found') 53 | (major, minor) = _get_compute_cap(gpus[0]) 54 | return 'sm_%s%s' % (major, minor) 55 | 56 | def _run_cmd(cmd): 57 | with os.popen(cmd) as pipe: 58 | output = pipe.read() 59 | status = pipe.close() 60 | if status is not None: 61 | raise RuntimeError('NVCC returned an error. 
See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) 62 | 63 | def _prepare_nvcc_cli(opts): 64 | cmd = 'nvcc ' + opts.strip() 65 | cmd += ' --disable-warnings' 66 | include_path = r'C:\Program Files\tensorflow\include' if os.name == 'nt' else tf.sysconfig.get_include() 67 | cmd += ' --include-path "%s"' % include_path 68 | cmd += ' --include-path "%s"' % os.path.join(include_path, 'external', 'protobuf_archive', 'src') 69 | cmd += ' --include-path "%s"' % os.path.join(include_path, 'external', 'com_google_absl') 70 | cmd += ' --include-path "%s"' % os.path.join(include_path, 'external', 'eigen_archive') 71 | 72 | compiler_bindir = _find_compiler_bindir() 73 | if compiler_bindir is None: 74 | # Require that _find_compiler_bindir succeeds on Windows. Allow 75 | # nvcc to use whatever is the default on Linux. 76 | if os.name == 'nt': 77 | raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) 78 | else: 79 | cmd += ' --compiler-bindir "%s"' % compiler_bindir 80 | cmd += ' 2>&1' 81 | return cmd 82 | 83 | #---------------------------------------------------------------------------- 84 | # Main entry point. 85 | 86 | _plugin_cache = dict() 87 | 88 | def get_plugin(cuda_file): 89 | cuda_file_base = os.path.basename(cuda_file) 90 | cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) 91 | 92 | # Already in cache? 93 | if cuda_file in _plugin_cache: 94 | return _plugin_cache[cuda_file] 95 | 96 | # Setup plugin. 97 | if verbose: 98 | print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) 99 | try: 100 | # Hash CUDA source. 101 | md5 = hashlib.md5() 102 | with open(cuda_file, 'rb') as f: 103 | md5.update(f.read()) 104 | md5.update(b'\n') 105 | 106 | # Hash headers included by the CUDA code by running it through the preprocessor. 107 | if not do_not_hash_included_headers: 108 | if verbose: 109 | print('Preprocessing... 
', end='', flush=True) 110 | with tempfile.TemporaryDirectory() as tmp_dir: 111 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) 112 | _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) 113 | with open(tmp_file, 'rb') as f: 114 | bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros 115 | good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') 116 | for ln in f: 117 | if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas 118 | ln = ln.replace(bad_file_str, good_file_str) 119 | md5.update(ln) 120 | md5.update(b'\n') 121 | 122 | # Select compiler options. 123 | compile_opts = '' 124 | if os.name == 'nt': 125 | lib_path = r'C:\Program Files\tensorflow' 126 | compile_opts += '"%s"' % os.path.join(lib_path, 'python', '_pywrap_tensorflow_internal.lib') 127 | elif os.name == 'posix': 128 | lib_path = tf.sysconfig.get_lib() 129 | compile_opts += '"%s"' % os.path.join(lib_path, 'python', '_pywrap_tensorflow_internal.so') 130 | compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\'' 131 | else: 132 | assert False # not Windows or Linux, w00t? 133 | compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string() 134 | compile_opts += ' --use_fast_math' 135 | nvcc_cmd = _prepare_nvcc_cli(compile_opts) 136 | 137 | # Hash build configuration. 138 | md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') 139 | md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') 140 | md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') 141 | 142 | # Compile if not already compiled. 143 | bin_file_ext = '.dll' if os.name == 'nt' else '.so' 144 | bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) 145 | if not os.path.isfile(bin_file): 146 | if verbose: 147 | print('Compiling... 
', end='', flush=True) 148 | with tempfile.TemporaryDirectory() as tmp_dir: 149 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) 150 | _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) 151 | os.makedirs(cuda_cache_path, exist_ok=True) 152 | intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) 153 | shutil.copyfile(tmp_file, intermediate_file) 154 | os.rename(intermediate_file, bin_file) # atomic 155 | 156 | # Load. 157 | if verbose: 158 | print('Loading... ', end='', flush=True) 159 | plugin = tf.load_op_library(bin_file) 160 | 161 | # Add to cache. 162 | _plugin_cache[cuda_file] = plugin 163 | if verbose: 164 | print('Done.', flush=True) 165 | return plugin 166 | 167 | except: 168 | if verbose: 169 | print('Failed!', flush=True) 170 | raise 171 | 172 | #---------------------------------------------------------------------------- 173 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | # empty 8 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/fused_bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#define EIGEN_USE_GPU
#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include <stdio.h>

using namespace tensorflow;
using namespace tensorflow::shape_inference;

#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)

//------------------------------------------------------------------------
// CUDA kernel.

// Kernel launch parameters; copied by value into the kernel.
template <class T>
struct FusedBiasActKernelParams
{
    const T*    x;      // [sizeX]
    const T*    b;      // [sizeB] or NULL
    const T*    ref;    // [sizeX] or NULL
    T*          y;      // [sizeX]

    int         grad;   // 0 = forward, 1/2 = first/second-order gradient case (see switch below)
    int         axis;
    int         act;    // activation id, see the switch in FusedBiasActKernel
    float       alpha;
    float       gain;

    int         sizeX;
    int         sizeB;
    int         stepB;  // product of x's dims after `axis`; used to index b
    int         loopX;  // elements processed per thread
};

template <class T>
static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
{
    const float expRange        = 80.0f;
    const float halfExpRange    = 40.0f;
    const float seluScale       = 1.0507009873554804934193349852946f;
    const float seluAlpha       = 1.6732632423543772848170429916717f;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load and apply bias.
        float x = (float)p.x[xi];
        if (p.b)
            x += (float)p.b[(xi / p.stepB) % p.sizeB];
        float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
        // Undo the forward gain on ref (except for swish, act 9); logical AND, not bitwise.
        if (p.gain != 0.0f && p.act != 9)
            ref /= p.gain;

        // Evaluate activation func.
        float y;
        switch (p.act * 10 + p.grad)
        {
            // linear
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0f; break;

            // relu
            case 20: y = (x > 0.0f) ? x : 0.0f; break;
            case 21: y = (ref > 0.0f) ? x : 0.0f; break;
            case 22: y = 0.0f; break;

            // lrelu
            case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
            case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
            case 32: y = 0.0f; break;

            // tanh
            case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
            case 41: y = x * (1.0f - ref * ref); break;
            case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;

            // sigmoid
            case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
            case 51: y = x * ref * (1.0f - ref); break;
            case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;

            // elu
            case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
            case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
            case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;

            // selu
            case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
            case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
            case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;

            // softplus
            case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
            case 81: y = x * (1.0f - expf(-ref)); break;
            case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;

            // swish
            case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
            case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
            case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
        }

        // Apply gain and store.
        p.y[xi] = (T)(y * p.gain);
    }
}

//------------------------------------------------------------------------
// TensorFlow op.

template <class T>
struct FusedBiasActOp : public OpKernel
{
    FusedBiasActKernelParams<T> m_attribs;

    FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
        OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
        OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
        OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
    }

    void Compute(OpKernelContext* ctx)
    {
        FusedBiasActKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        const Tensor& x     = ctx->input(0); // [...]
        const Tensor& b     = ctx->input(1); // [sizeB] or [0]
        const Tensor& ref   = ctx->input(2); // x.shape or [0]
        p.x = x.flat<T>().data();
        p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
        p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
        OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
        OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
        OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
        OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));

        p.sizeX = (int)x.NumElements();
        p.sizeB = (int)b.NumElements();
        p.stepB = 1;
        for (int i = m_attribs.axis + 1; i < x.dims(); i++)
            p.stepB *= (int)x.dim_size(i);

        Tensor* y = NULL; // x.shape
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
        p.y = y->flat<T>().data();

        // Each block of 128 threads covers loopX * 128 consecutive-stride elements.
        p.loopX = 4;
        int blockSize = 4 * 32;
        int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
    }
};

REGISTER_OP("FusedBiasAct")
    .Input      ("x: T")
    .Input      ("b: T")
    .Input      ("ref: T")
    .Output     ("y: T")
    .Attr       ("T: {float, half}")
    .Attr       ("grad: int = 0")
    .Attr       ("axis: int = 1")
    .Attr       ("act: int = 0")
    .Attr       ("alpha: float = 0.0")
    .Attr       ("gain: float = 1.0");
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);

//------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dnnlib/tflib/ops/fused_bias_act.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Custom TensorFlow ops for efficient bias and activation.""" 8 | 9 | import os 10 | import numpy as np 11 | import tensorflow as tf 12 | from .. import custom_ops 13 | from ...util import EasyDict 14 | 15 | def _get_plugin(): 16 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | activation_funcs = { 21 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True), 22 | 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True), 23 | 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True), 24 | 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False), 25 | 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False), 26 | 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False), 27 | 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False), 28 | 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False), 29 | 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False), 30 | } 31 | 32 | 
#---------------------------------------------------------------------------- 33 | 34 | def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'): 35 | r"""Fused bias and activation function. 36 | 37 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 38 | and scales the result by `gain`. Each of the steps is optional. In most cases, 39 | the fused op is considerably more efficient than performing the same calculation 40 | using standard TensorFlow ops. It supports first and second order gradients, 41 | but not third order gradients. 42 | 43 | Args: 44 | x: Input activation tensor. Can have any shape, but if `b` is defined, the 45 | dimension corresponding to `axis`, as well as the rank, must be known. 46 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 47 | as `x`. The shape must be known, and it must match the dimension of `x` 48 | corresponding to `axis`. 49 | axis: The dimension in `x` corresponding to the elements of `b`. 50 | The value of `axis` is ignored if `b` is not specified. 51 | act: Name of the activation function to evaluate, or `"linear"` to disable. 52 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 53 | See `activation_funcs` for a full list. `None` is not allowed. 54 | alpha: Shape parameter for the activation function, or `None` to use the default. 55 | gain: Scaling factor for the output tensor, or `None` to use default. 56 | See `activation_funcs` for the default scaling of each activation function. 57 | If unsure, consider specifying `1.0`. 58 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 59 | 60 | Returns: 61 | Tensor of the same shape and datatype as `x`. 
62 | """ 63 | 64 | impl_dict = { 65 | 'ref': _fused_bias_act_ref, 66 | 'cuda': _fused_bias_act_cuda, 67 | } 68 | return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | def _fused_bias_act_ref(x, b, axis, act, alpha, gain): 73 | """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops.""" 74 | 75 | # Validate arguments. 76 | x = tf.convert_to_tensor(x) 77 | b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype) 78 | act_spec = activation_funcs[act] 79 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 80 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 81 | if alpha is None: 82 | alpha = act_spec.def_alpha 83 | if gain is None: 84 | gain = act_spec.def_gain 85 | 86 | # Add bias. 87 | if b.shape[0] != 0: 88 | x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)]) 89 | 90 | # Evaluate activation function. 91 | x = act_spec.func(x, alpha=alpha) 92 | 93 | # Scale by gain. 94 | if gain != 1: 95 | x *= gain 96 | return x 97 | 98 | #---------------------------------------------------------------------------- 99 | 100 | def _fused_bias_act_cuda(x, b, axis, act, alpha, gain): 101 | """Fast CUDA implementation of `fused_bias_act()` using custom ops.""" 102 | 103 | # Validate arguments. 104 | x = tf.convert_to_tensor(x) 105 | empty_tensor = tf.constant([], dtype=x.dtype) 106 | b = tf.convert_to_tensor(b) if b is not None else empty_tensor 107 | act_spec = activation_funcs[act] 108 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) 109 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank 110 | if alpha is None: 111 | alpha = act_spec.def_alpha 112 | if gain is None: 113 | gain = act_spec.def_gain 114 | 115 | # Special cases. 
116 | if act == 'linear' and b is None and gain == 1.0: 117 | return x 118 | if act_spec.cuda_idx is None: 119 | return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain) 120 | 121 | # CUDA kernel. 122 | cuda_kernel = _get_plugin().fused_bias_act 123 | cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain) 124 | 125 | # Forward pass: y = func(x, b). 126 | def func_y(x, b): 127 | y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs) 128 | y.set_shape(x.shape) 129 | return y 130 | 131 | # Backward pass: dx, db = grad(dy, x, y) 132 | def grad_dx(dy, x, y): 133 | ref = {'x': x, 'y': y}[act_spec.ref] 134 | dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs) 135 | dx.set_shape(x.shape) 136 | return dx 137 | def grad_db(dx): 138 | if b.shape[0] == 0: 139 | return empty_tensor 140 | db = dx 141 | if axis < x.shape.rank - 1: 142 | db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank))) 143 | if axis > 0: 144 | db = tf.reduce_sum(db, list(range(axis))) 145 | db.set_shape(b.shape) 146 | return db 147 | 148 | # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y) 149 | def grad2_d_dy(d_dx, d_db, x, y): 150 | ref = {'x': x, 'y': y}[act_spec.ref] 151 | d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs) 152 | d_dy.set_shape(x.shape) 153 | return d_dy 154 | def grad2_d_x(d_dx, d_db, x, y): 155 | ref = {'x': x, 'y': y}[act_spec.ref] 156 | d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs) 157 | d_x.set_shape(x.shape) 158 | return d_x 159 | 160 | # Fast version for piecewise-linear activation funcs. 
161 | @tf.custom_gradient 162 | def func_zero_2nd_grad(x, b): 163 | y = func_y(x, b) 164 | @tf.custom_gradient 165 | def grad(dy): 166 | dx = grad_dx(dy, x, y) 167 | db = grad_db(dx) 168 | def grad2(d_dx, d_db): 169 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 170 | return d_dy 171 | return (dx, db), grad2 172 | return y, grad 173 | 174 | # Slow version for general activation funcs. 175 | @tf.custom_gradient 176 | def func_nonzero_2nd_grad(x, b): 177 | y = func_y(x, b) 178 | def grad_wrap(dy): 179 | @tf.custom_gradient 180 | def grad_impl(dy, x): 181 | dx = grad_dx(dy, x, y) 182 | db = grad_db(dx) 183 | def grad2(d_dx, d_db): 184 | d_dy = grad2_d_dy(d_dx, d_db, x, y) 185 | d_x = grad2_d_x(d_dx, d_db, x, y) 186 | return d_dy, d_x 187 | return (dx, db), grad2 188 | return grad_impl(dy, x) 189 | return y, grad_wrap 190 | 191 | # Which version to use? 192 | if act_spec.zero_2nd_grad: 193 | return func_zero_2nd_grad(x, b) 194 | return func_nonzero_2nd_grad(x, b) 195 | 196 | #---------------------------------------------------------------------------- 197 | -------------------------------------------------------------------------------- /dnnlib/tflib/ops/upfirdn_2d.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #define EIGEN_USE_GPU 8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ 9 | #include "tensorflow/core/framework/op.h" 10 | #include "tensorflow/core/framework/op_kernel.h" 11 | #include "tensorflow/core/framework/shape_inference.h" 12 | #include 13 | 14 | using namespace tensorflow; 15 | using namespace tensorflow::shape_inference; 16 | 17 | //------------------------------------------------------------------------ 18 | // Helpers. 
#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)

// Floor division: rounds a / b toward negative infinity, whereas C++ `/`
// truncates toward zero. Only correct for b > 0; callers pass upx/upy, which
// UpFirDn2DOp validates to be >= 1.
static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
{
    int c = a / b;
    if (c * b > a)
        c--;
    return c;
}

//------------------------------------------------------------------------
// CUDA kernel params.

template <class T>
struct UpFirDn2DKernelParams
{
    const T*    x;          // [majorDim, inH, inW, minorDim]
    const T*    k;          // [kernelH, kernelW]
    T*          y;          // [majorDim, outH, outW, minorDim]

    int         upx;
    int         upy;
    int         downx;
    int         downy;
    int         padx0;
    int         padx1;
    int         pady0;
    int         pady1;

    int         majorDim;
    int         inH;
    int         inW;
    int         minorDim;
    int         kernelH;
    int         kernelW;
    int         outH;
    int         outW;
    int         loopMajor;
    int         loopX;
};

//------------------------------------------------------------------------
// General CUDA implementation for large filter kernels.

template <class T>
static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
{
    // Calculate thread index.
    int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorIdx / p.minorDim;
    minorIdx -= outY * p.minorDim;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
        return;

    // Setup Y receptive field.
    int midY = outY * p.downy + p.upy - 1 - p.pady0;
    int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
    int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
    int kernelY = midY + p.kernelH - (inY + 1) * p.upy;

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
    {
        // Setup X receptive field.
        int midX = outX * p.downx + p.upx - 1 - p.padx0;
        int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
        int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
        int kernelX = midX + p.kernelW - (inX + 1) * p.upx;

        // Initialize pointers.
        const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
        const T* kp = &p.k[kernelY * p.kernelW + kernelX];
        int xpx = p.minorDim;
        int kpx = -p.upx;
        int xpy = p.inW * p.minorDim;
        int kpy = -p.upy * p.kernelW;

        // Inner loop.
        float v = 0.0f;
        for (int y = 0; y < h; y++)
        {
            for (int x = 0; x < w; x++)
            {
                v += (float)(*xp) * (float)(*kp);
                xp += xpx;
                kp += kpx;
            }
            xp += xpy - w * xpx;
            kp += kpy - w * kpx;
        }

        // Store result.
        p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filter kernels.
// The resampling factors, the filter size and the output tile size are
// compile-time template parameters. The runtime filter (p.kernelW x p.kernelH)
// may be smaller than kernelW x kernelH, in which case it is zero-padded
// when copied into shared memory.

template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
{
    //assert(kernelW % upx == 0);
    //assert(kernelH % upy == 0);
    const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
    __shared__ volatile float sk[kernelH][kernelW];
    __shared__ volatile float sx[tileInH][tileInW];

    // Calculate tile index.
    int minorIdx = blockIdx.x;
    int tileOutY = minorIdx / p.minorDim;
    minorIdx -= tileOutY * p.minorDim;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorIdxBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
        return;

    // Load filter kernel (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
    {
        int ky = tapIdx / kernelW;
        int kx = tapIdx - ky * kernelW;
        float v = 0.0f;
        if (kx < p.kernelW & ky < p.kernelH)
            v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
        sk[ky][kx] = v;
    }

    // Loop over majorDim and outX.
    for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
    for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
    {
        // Load input pixels.
        int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
        int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
        int tileInX = floorDiv(tileMidX, upx);
        int tileInY = floorDiv(tileMidY, upy);
        __syncthreads();
        for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
        {
            int relInY = inIdx / tileInW;
            int relInX = inIdx - relInY * tileInW;
            int inX = relInX + tileInX;
            int inY = relInY + tileInY;
            float v = 0.0f;
            if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
                v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
            sx[relInY][relInX] = v;
        }

        // Loop over output pixels.
        __syncthreads();
        for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
        {
            int relOutY = outIdx / tileOutW;
            int relOutX = outIdx - relOutY * tileOutW;
            int outX = relOutX + tileOutX;
            int outY = relOutY + tileOutY;

            // Setup receptive field.
            int midX = tileMidX + relOutX * downx;
            int midY = tileMidY + relOutY * downy;
            int inX = floorDiv(midX, upx);
            int inY = floorDiv(midY, upy);
            int relInX = inX - tileInX;
            int relInY = inY - tileInY;
            int kernelX = (inX + 1) * upx - midX - 1; // flipped
            int kernelY = (inY + 1) * upy - midY - 1; // flipped

            // Inner loop.
            float v = 0.0f;
            #pragma unroll
            for (int y = 0; y < kernelH / upy; y++)
                #pragma unroll
                for (int x = 0; x < kernelW / upx; x++)
                    v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];

            // Store result.
            if (outX < p.outW & outY < p.outH)
                p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// TensorFlow op.

// GPU kernel for the "UpFirDn2D" op; instantiated for float and
// Eigen::half, matching the op's `T: {float, half}` attribute.
template <class T>
struct UpFirDn2DOp : public OpKernel
{
    UpFirDn2DKernelParams<T> m_attribs;

    // Reads the resampling/padding attributes once and requires up/down factors >= 1.
    UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
    {
        memset(&m_attribs, 0, sizeof(m_attribs));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
        OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
        OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
        OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
    }

    // Validates the inputs, allocates y, picks a kernel specialization and
    // launches it on the op's CUDA stream.
    void Compute(OpKernelContext* ctx)
    {
        UpFirDn2DKernelParams<T> p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();

        const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
        const Tensor& k = ctx->input(1); // [kernelH, kernelW]
        p.x = x.flat<T>().data();
        p.k = k.flat<T>().data();
        OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
        OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
        OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
        OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));

        p.majorDim  = (int)x.dim_size(0);
        p.inH       = (int)x.dim_size(1);
        p.inW       = (int)x.dim_size(2);
        p.minorDim  = (int)x.dim_size(3);
        p.kernelH   = (int)k.dim_size(0);
        p.kernelW   = (int)k.dim_size(1);
        OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));

        p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
        p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
        OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));

        Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
        TensorShape ys;
        ys.AddDim(p.majorDim);
        ys.AddDim(p.outH);
        ys.AddDim(p.outW);
        ys.AddDim(p.minorDim);
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
        p.y = y->flat<T>().data();
        OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));

        // Nothing to compute for an empty output (majorDim == 0 or minorDim == 0).
        // Returning here also avoids a zero-sized grid dimension below, which
        // cudaLaunchKernel rejects as an invalid configuration.
        if (y->NumElements() == 0)
            return;

        // Choose CUDA kernel to use. Later matches override earlier ones, so
        // the smallest specialization that fits the filter wins.
        void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
        int tileOutW = -1;
        int tileOutH = -1;
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>;  tileOutW = 32; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>;  tileOutW = 32; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>;  tileOutW = 32; tileOutH = 8; }
        if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>;  tileOutW = 32; tileOutH = 8; }

        // Choose launch params. loopMajor is chosen so that gridSize.z stays
        // at most 16384, below CUDA's gridDim.z limit.
        dim3 blockSize;
        dim3 gridSize;
        if (tileOutW > 0 && tileOutH > 0) // small
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 1;
            blockSize = dim3(32 * 8, 1, 1);
            gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }
        else // large
        {
            p.loopMajor = (p.majorDim - 1) / 16384 + 1;
            p.loopX = 4;
            blockSize = dim3(4, 32, 1);
            gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
        }

        // Launch CUDA kernel.
        void* args[] = {&p};
        OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
    }
};

REGISTER_OP("UpFirDn2D")
    .Input      ("x: T")
    .Input      ("k: T")
    .Output     ("y: T")
    .Attr       ("T: {float, half}")
    .Attr       ("upx: int = 1")
    .Attr       ("upy: int = 1")
    .Attr       ("downx: int = 1")
    .Attr       ("downy: int = 1")
    .Attr       ("padx0: int = 0")
    .Attr       ("padx1: int = 0")
    .Attr       ("pady0: int = 0")
    .Attr       ("pady1: int = 0");
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);

//------------------------------------------------------------------------

-------------------------------------------------------------------------------- /dnnlib/tflib/ops/upfirdn_2d.py: --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Custom TensorFlow ops for efficient resampling of 2D images."""

import os
import numpy as np
import tensorflow as tf
from .. import custom_ops

def _get_plugin():
    """Return the custom-op plugin built from the `.cu` file next to this module."""
    return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')

#----------------------------------------------------------------------------

def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
    r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.

    Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
    and performs the following operations for each image, batched across
    `majorDim` and `minorDim`:

    1. Pad the image with zeros by the specified number of pixels on each side
       (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
       corresponds to cropping the image.

    2. Upsample the image by inserting the zeros after each pixel (`upx`, `upy`).

    3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
       image so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by throwing away pixels (`downx`, `downy`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard TensorFlow ops. It supports gradients of arbitrary order.

    Args:
        x:      Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
        k:      2D FIR filter of the shape `[firH, firW]`.
        upx:    Integer upsampling factor along the X-axis (default: 1).
        upy:    Integer upsampling factor along the Y-axis (default: 1).
        downx:  Integer downsampling factor along the X-axis (default: 1).
        downy:  Integer downsampling factor along the Y-axis (default: 1).
        padx0:  Number of pixels to pad on the left side (default: 0).
        padx1:  Number of pixels to pad on the right side (default: 0).
        pady0:  Number of pixels to pad on the top side (default: 0).
        pady1:  Number of pixels to pad on the bottom side (default: 0).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
    """

    impl_dict = {
        'ref':  _upfirdn_2d_ref,
        'cuda': _upfirdn_2d_cuda,
    }
    return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)

#----------------------------------------------------------------------------

def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    assert x.shape.rank == 4
    inH = x.shape[1].value
    inW = x.shape[2].value
    minorDim = _shape(x, 3)
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    # Upsample (insert zeros).
    x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
    x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
    x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])

    # Pad (crop if negative).
    x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
    x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]

    # Convolve with filter.
    x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
    w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
    x = tf.transpose(x, [0, 2, 3, 1])
    x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NHWC')
    x = tf.transpose(x, [0, 3, 1, 2])
    x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
    x = tf.transpose(x, [0, 2, 3, 1])

    # Downsample (throw away pixels).
    return x[:, ::downy, ::downx, :]

#----------------------------------------------------------------------------

def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""

    x = tf.convert_to_tensor(x)
    k = np.asarray(k, dtype=np.float32)
    majorDim, inH, inW, minorDim = x.shape.as_list()
    kernelH, kernelW = k.shape
    assert inW >= 1 and inH >= 1
    assert kernelW >= 1 and kernelH >= 1
    assert isinstance(upx, int) and isinstance(upy, int)
    assert isinstance(downx, int) and isinstance(downy, int)
    assert isinstance(padx0, int) and isinstance(padx1, int)
    assert isinstance(pady0, int) and isinstance(pady1, int)

    outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
    outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
    assert outW >= 1 and outH >= 1

    # The gradient is another upfirdn_2d call with the flipped filter, the
    # up/down factors swapped, and the padding computed below.
    kc = tf.constant(k, dtype=x.dtype)
    gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
    gpadx0 = kernelW - padx0 - 1
    gpady0 = kernelH - pady0 - 1
    gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
    gpady1 = inH * upy - outH * downy + pady0 - upy + 1

    @tf.custom_gradient
    def func(x):
        y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
        y.set_shape([majorDim, outH, outW, minorDim])
        @tf.custom_gradient
        def grad(dy):
            dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
            dx.set_shape([majorDim, inH, inW, minorDim])
            return dx, func
        return y, grad
    return func(x)

#----------------------------------------------------------------------------

def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
    r"""Filter a batch of 2D images with the given FIR filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and filters each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero.

    Args:
        x:           Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:           FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
        gain:        Scaling factor for signal magnitude (default: 1.0).
        data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:        Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """

    k = _setup_kernel(k) * gain
    p = k.shape[0] - 1
    return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)

#----------------------------------------------------------------------------

def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and upsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the upsampling factor.

    Args:
        x:           Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k:           FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
                     The default is `[1] * factor`, which corresponds to nearest-neighbor
                     upsampling.
        factor:      Integer upsampling factor (default: 2).
        gain:        Scaling factor for signal magnitude (default: 1.0).
        data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
        impl:        Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or
        `[N, H * factor, W * factor, C]`, and same datatype as `x`.
    """

    assert isinstance(factor, int) and factor >= 1
    if k is None:
        k = [1] * factor
    k = _setup_kernel(k) * (gain * (factor ** 2))
    p = k.shape[0] - factor
    return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl)

#----------------------------------------------------------------------------

def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
    and downsamples each image with the given filter. The filter is normalized so that
    if the input pixels are constant, they will be scaled by the specified `gain`.
    Pixels outside the image are assumed to be zero, and the filter is padded with
    zeros so that its shape is a multiple of the downsampling factor.
212 | 213 | Args: 214 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 215 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). 216 | The default is `[1] * factor`, which corresponds to average pooling. 217 | factor: Integer downsampling factor (default: 2). 218 | gain: Scaling factor for signal magnitude (default: 1.0). 219 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 220 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 221 | 222 | Returns: 223 | Tensor of the shape `[N, C, H // factor, W // factor]` or 224 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 225 | """ 226 | 227 | assert isinstance(factor, int) and factor >= 1 228 | if k is None: 229 | k = [1] * factor 230 | k = _setup_kernel(k) * gain 231 | p = k.shape[0] - factor 232 | return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl) 233 | 234 | #---------------------------------------------------------------------------- 235 | 236 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'): 237 | r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 238 | 239 | Padding is performed only once at the beginning, not between the operations. 240 | The fused op is considerably more efficient than performing the same calculation 241 | using standard TensorFlow ops. It supports gradients of arbitrary order. 242 | 243 | Args: 244 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 245 | w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. 246 | Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. 247 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). 248 | The default is `[1] * factor`, which corresponds to nearest-neighbor 249 | upsampling. 250 | factor: Integer upsampling factor (default: 2). 251 | gain: Scaling factor for signal magnitude (default: 1.0). 
252 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 253 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 254 | 255 | Returns: 256 | Tensor of the shape `[N, C, H * factor, W * factor]` or 257 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 258 | """ 259 | 260 | assert isinstance(factor, int) and factor >= 1 261 | 262 | # Check weight shape. 263 | w = tf.convert_to_tensor(w) 264 | assert w.shape.rank == 4 265 | convH = w.shape[0].value 266 | convW = w.shape[1].value 267 | inC = _shape(w, 2) 268 | outC = _shape(w, 3) 269 | assert convW == convH 270 | 271 | # Setup filter kernel. 272 | if k is None: 273 | k = [1] * factor 274 | k = _setup_kernel(k) * (gain * (factor ** 2)) 275 | p = (k.shape[0] - factor) - (convW - 1) 276 | 277 | # Determine data dimensions. 278 | if data_format == 'NCHW': 279 | stride = [1, 1, factor, factor] 280 | output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW] 281 | num_groups = _shape(x, 1) // inC 282 | else: 283 | stride = [1, factor, factor, 1] 284 | output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC] 285 | num_groups = _shape(x, 3) // inC 286 | 287 | # Transpose weights. 288 | w = tf.reshape(w, [convH, convW, inC, num_groups, -1]) 289 | w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2]) 290 | w = tf.reshape(w, [convH, convW, -1, num_groups * inC]) 291 | 292 | # Execute. 293 | x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format) 294 | return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl) 295 | 296 | #---------------------------------------------------------------------------- 297 | 298 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'): 299 | r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 
300 | 301 | Padding is performed only once at the beginning, not between the operations. 302 | The fused op is considerably more efficient than performing the same calculation 303 | using standard TensorFlow ops. It supports gradients of arbitrary order. 304 | 305 | Args: 306 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 307 | w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. 308 | Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. 309 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). 310 | The default is `[1] * factor`, which corresponds to average pooling. 311 | factor: Integer downsampling factor (default: 2). 312 | gain: Scaling factor for signal magnitude (default: 1.0). 313 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 314 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 315 | 316 | Returns: 317 | Tensor of the shape `[N, C, H // factor, W // factor]` or 318 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 319 | """ 320 | 321 | assert isinstance(factor, int) and factor >= 1 322 | w = tf.convert_to_tensor(w) 323 | convH, convW, _inC, _outC = w.shape.as_list() 324 | assert convW == convH 325 | if k is None: 326 | k = [1] * factor 327 | k = _setup_kernel(k) * gain 328 | p = (k.shape[0] - factor) + (convW - 1) 329 | if data_format == 'NCHW': 330 | s = [1, 1, factor, factor] 331 | else: 332 | s = [1, factor, factor, 1] 333 | x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl) 334 | return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format) 335 | 336 | #---------------------------------------------------------------------------- 337 | # Internal helper funcs. 
338 | 339 | def _shape(tf_expr, dim_idx): 340 | if tf_expr.shape.rank is not None: 341 | dim = tf_expr.shape[dim_idx].value 342 | if dim is not None: 343 | return dim 344 | return tf.shape(tf_expr)[dim_idx] 345 | 346 | def _setup_kernel(k): 347 | k = np.asarray(k, dtype=np.float32) 348 | if k.ndim == 1: 349 | k = np.outer(k, k) 350 | k /= np.sum(k) 351 | assert k.ndim == 2 352 | assert k.shape[0] == k.shape[1] 353 | return k 354 | 355 | def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'): 356 | assert data_format in ['NCHW', 'NHWC'] 357 | assert x.shape.rank == 4 358 | y = x 359 | if data_format == 'NCHW': 360 | y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1]) 361 | y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl) 362 | if data_format == 'NCHW': 363 | y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)]) 364 | return y 365 | 366 | #---------------------------------------------------------------------------- 367 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Miscellaneous helper utils for Tensorflow.""" 8 | 9 | import os 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | # Silence deprecation warnings from TensorFlow 1.13 onwards 14 | import logging 15 | logging.getLogger('tensorflow').setLevel(logging.ERROR) 16 | import tensorflow.contrib # requires TensorFlow 1.x! 
17 | tf.contrib = tensorflow.contrib 18 | 19 | from typing import Any, Iterable, List, Union 20 | 21 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 22 | """A type that represents a valid Tensorflow expression.""" 23 | 24 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 25 | """A type that can be converted to a valid Tensorflow expression.""" 26 | 27 | 28 | def run(*args, **kwargs) -> Any: 29 | """Run the specified ops in the default session.""" 30 | assert_tf_initialized() 31 | return tf.get_default_session().run(*args, **kwargs) 32 | 33 | 34 | def is_tf_expression(x: Any) -> bool: 35 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 36 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 37 | 38 | 39 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 40 | """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code.""" 41 | return [dim.value for dim in shape] 42 | 43 | 44 | def flatten(x: TfExpressionEx) -> TfExpression: 45 | """Shortcut function for flattening a tensor.""" 46 | with tf.name_scope("Flatten"): 47 | return tf.reshape(x, [-1]) 48 | 49 | 50 | def log2(x: TfExpressionEx) -> TfExpression: 51 | """Logarithm in base 2.""" 52 | with tf.name_scope("Log2"): 53 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 54 | 55 | 56 | def exp2(x: TfExpressionEx) -> TfExpression: 57 | """Exponent in base 2.""" 58 | with tf.name_scope("Exp2"): 59 | return tf.exp(x * np.float32(np.log(2.0))) 60 | 61 | 62 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 63 | """Linear interpolation.""" 64 | with tf.name_scope("Lerp"): 65 | return a + (b - a) * t 66 | 67 | 68 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 69 | """Linear interpolation with clip.""" 70 | with tf.name_scope("LerpClip"): 71 | return a + 
(b - a) * tf.clip_by_value(t, 0.0, 1.0) 72 | 73 | 74 | def absolute_name_scope(scope: str) -> tf.name_scope: 75 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 76 | return tf.name_scope(scope + "/") 77 | 78 | 79 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 80 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 81 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 82 | 83 | 84 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 85 | # Defaults. 86 | cfg = dict() 87 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 88 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 89 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 90 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 91 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 92 | 93 | # Remove defaults for environment variables that are already set. 94 | for key in list(cfg): 95 | fields = key.split(".") 96 | if fields[0] == "env": 97 | assert len(fields) == 2 98 | if fields[1] in os.environ: 99 | del cfg[key] 100 | 101 | # User overrides. 102 | if config_dict is not None: 103 | cfg.update(config_dict) 104 | return cfg 105 | 106 | 107 | def init_tf(config_dict: dict = None) -> None: 108 | """Initialize TensorFlow session using good default settings.""" 109 | # Skip if already initialized. 110 | if tf.get_default_session() is not None: 111 | return 112 | 113 | # Setup config dict and random seeds. 
114 | cfg = _sanitize_tf_config(config_dict) 115 | np_random_seed = cfg["rnd.np_random_seed"] 116 | if np_random_seed is not None: 117 | np.random.seed(np_random_seed) 118 | tf_random_seed = cfg["rnd.tf_random_seed"] 119 | if tf_random_seed == "auto": 120 | tf_random_seed = np.random.randint(1 << 31) 121 | if tf_random_seed is not None: 122 | tf.set_random_seed(tf_random_seed) 123 | 124 | # Setup environment variables. 125 | for key, value in cfg.items(): 126 | fields = key.split(".") 127 | if fields[0] == "env": 128 | assert len(fields) == 2 129 | os.environ[fields[1]] = str(value) 130 | 131 | # Create default TensorFlow session. 132 | create_session(cfg, force_as_default=True) 133 | 134 | 135 | def assert_tf_initialized(): 136 | """Check that TensorFlow session has been initialized.""" 137 | if tf.get_default_session() is None: 138 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 139 | 140 | 141 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 142 | """Create tf.Session based on config dict.""" 143 | # Setup TensorFlow config proto. 144 | cfg = _sanitize_tf_config(config_dict) 145 | config_proto = tf.ConfigProto() 146 | for key, value in cfg.items(): 147 | fields = key.split(".") 148 | if fields[0] not in ["rnd", "env"]: 149 | obj = config_proto 150 | for field in fields[:-1]: 151 | obj = getattr(obj, field) 152 | setattr(obj, fields[-1], value) 153 | 154 | # Create session. 155 | session = tf.Session(config=config_proto) 156 | if force_as_default: 157 | # pylint: disable=protected-access 158 | session._default_session = session.as_default() 159 | session._default_session.enforce_nesting = False 160 | session._default_session.__enter__() 161 | return session 162 | 163 | 164 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 165 | """Initialize all tf.Variables that have not already been initialized. 
166 | 167 | Equivalent to the following, but more efficient and does not bloat the tf graph: 168 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 169 | """ 170 | assert_tf_initialized() 171 | if target_vars is None: 172 | target_vars = tf.global_variables() 173 | 174 | test_vars = [] 175 | test_ops = [] 176 | 177 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 178 | for var in target_vars: 179 | assert is_tf_expression(var) 180 | 181 | try: 182 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 183 | except KeyError: 184 | # Op does not exist => variable may be uninitialized. 185 | test_vars.append(var) 186 | 187 | with absolute_name_scope(var.name.split(":")[0]): 188 | test_ops.append(tf.is_variable_initialized(var)) 189 | 190 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 191 | run([var.initializer for var in init_vars]) 192 | 193 | 194 | def set_vars(var_to_value_dict: dict) -> None: 195 | """Set the values of given tf.Variables. 
196 | 197 | Equivalent to the following, but more efficient and does not bloat the tf graph: 198 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 199 | """ 200 | assert_tf_initialized() 201 | ops = [] 202 | feed_dict = {} 203 | 204 | for var, value in var_to_value_dict.items(): 205 | assert is_tf_expression(var) 206 | 207 | try: 208 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 209 | except KeyError: 210 | with absolute_name_scope(var.name.split(":")[0]): 211 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 212 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 213 | 214 | ops.append(setter) 215 | feed_dict[setter.op.inputs[1]] = value 216 | 217 | run(ops, feed_dict) 218 | 219 | 220 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 221 | """Create tf.Variable with large initial value without bloating the tf graph.""" 222 | assert_tf_initialized() 223 | assert isinstance(initial_value, np.ndarray) 224 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 225 | var = tf.Variable(zeros, *args, **kwargs) 226 | set_vars({var: initial_value}) 227 | return var 228 | 229 | 230 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 231 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 232 | Can be used as an input transformation for Network.run(). 233 | """ 234 | images = tf.cast(images, tf.float32) 235 | if nhwc_to_nchw: 236 | images = tf.transpose(images, [0, 3, 1, 2]) 237 | return images * ((drange[1] - drange[0]) / 255) + drange[0] 238 | 239 | 240 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): 241 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 
242 | Can be used as an output transformation for Network.run(). 243 | """ 244 | images = tf.cast(images, tf.float32) 245 | if shrink > 1: 246 | ksize = [1, 1, shrink, shrink] 247 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 248 | if nchw_to_nhwc: 249 | images = tf.transpose(images, [0, 2, 3, 1]) 250 | scale = 255 / (drange[1] - drange[0]) 251 | images = images * scale + (0.5 - drange[0] * scale) 252 | return tf.saturate_cast(images, tf.uint8) 253 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Miscellaneous utility classes and functions.""" 8 | 9 | import ctypes 10 | import fnmatch 11 | import importlib 12 | import inspect 13 | import numpy as np 14 | import os 15 | import shutil 16 | import sys 17 | import types 18 | import io 19 | import pickle 20 | import re 21 | import requests 22 | import html 23 | import hashlib 24 | import glob 25 | import uuid 26 | 27 | from distutils.util import strtobool 28 | from typing import Any, List, Tuple, Union 29 | 30 | 31 | # Util classes 32 | # ------------------------------------------------------------------------------------------ 33 | 34 | 35 | class EasyDict(dict): 36 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 37 | 38 | def __getattr__(self, name: str) -> Any: 39 | try: 40 | return self[name] 41 | except KeyError: 42 | raise AttributeError(name) 43 | 44 | def __setattr__(self, name: str, value: Any) -> None: 45 | self[name] = value 46 | 47 | def __delattr__(self, name: str) -> None: 48 | del self[name] 49 | 50 | 51 | class 
Logger(object): 52 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 53 | 54 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 55 | self.file = None 56 | 57 | if file_name is not None: 58 | self.file = open(file_name, file_mode) 59 | 60 | self.should_flush = should_flush 61 | self.stdout = sys.stdout 62 | self.stderr = sys.stderr 63 | 64 | sys.stdout = self 65 | sys.stderr = self 66 | 67 | def __enter__(self) -> "Logger": 68 | return self 69 | 70 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 71 | self.close() 72 | 73 | def write(self, text: str) -> None: 74 | """Write text to stdout (and a file) and optionally flush.""" 75 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 76 | return 77 | 78 | if self.file is not None: 79 | self.file.write(text) 80 | 81 | self.stdout.write(text) 82 | 83 | if self.should_flush: 84 | self.flush() 85 | 86 | def flush(self) -> None: 87 | """Flush written text to both stdout and a file, if open.""" 88 | if self.file is not None: 89 | self.file.flush() 90 | 91 | self.stdout.flush() 92 | 93 | def close(self) -> None: 94 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 95 | self.flush() 96 | 97 | # if using multiple loggers, prevent closing in wrong order 98 | if sys.stdout is self: 99 | sys.stdout = self.stdout 100 | if sys.stderr is self: 101 | sys.stderr = self.stderr 102 | 103 | if self.file is not None: 104 | self.file.close() 105 | 106 | 107 | # Small util functions 108 | # ------------------------------------------------------------------------------------------ 109 | 110 | 111 | def format_time(seconds: Union[int, float]) -> str: 112 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 113 | s = int(np.rint(seconds)) 114 | 115 | if s < 60: 116 | return 
"{0}s".format(s) 117 | elif s < 60 * 60: 118 | return "{0}m {1:02}s".format(s // 60, s % 60) 119 | elif s < 24 * 60 * 60: 120 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 121 | else: 122 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 123 | 124 | 125 | def ask_yes_no(question: str) -> bool: 126 | """Ask the user the question until the user inputs a valid answer.""" 127 | while True: 128 | try: 129 | print("{0} [y/n]".format(question)) 130 | return strtobool(input().lower()) 131 | except ValueError: 132 | pass 133 | 134 | 135 | def tuple_product(t: Tuple) -> Any: 136 | """Calculate the product of the tuple elements.""" 137 | result = 1 138 | 139 | for v in t: 140 | result *= v 141 | 142 | return result 143 | 144 | 145 | _str_to_ctype = { 146 | "uint8": ctypes.c_ubyte, 147 | "uint16": ctypes.c_uint16, 148 | "uint32": ctypes.c_uint32, 149 | "uint64": ctypes.c_uint64, 150 | "int8": ctypes.c_byte, 151 | "int16": ctypes.c_int16, 152 | "int32": ctypes.c_int32, 153 | "int64": ctypes.c_int64, 154 | "float32": ctypes.c_float, 155 | "float64": ctypes.c_double 156 | } 157 | 158 | 159 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 160 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 161 | type_str = None 162 | 163 | if isinstance(type_obj, str): 164 | type_str = type_obj 165 | elif hasattr(type_obj, "__name__"): 166 | type_str = type_obj.__name__ 167 | elif hasattr(type_obj, "name"): 168 | type_str = type_obj.name 169 | else: 170 | raise RuntimeError("Cannot infer type name from input") 171 | 172 | assert type_str in _str_to_ctype.keys() 173 | 174 | my_dtype = np.dtype(type_str) 175 | my_ctype = _str_to_ctype[type_str] 176 | 177 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 178 | 179 | return my_dtype, my_ctype 180 | 181 | 182 | def is_pickleable(obj: Any) -> bool: 183 
| try: 184 | with io.BytesIO() as stream: 185 | pickle.dump(obj, stream) 186 | return True 187 | except: 188 | return False 189 | 190 | 191 | # Functionality to import modules/objects by name, and call functions by name 192 | # ------------------------------------------------------------------------------------------ 193 | 194 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 195 | """Searches for the underlying module behind the name to some python object. 196 | Returns the module and the object name (original name with module part removed).""" 197 | 198 | # allow convenience shorthands, substitute them by full names 199 | obj_name = re.sub("^np.", "numpy.", obj_name) 200 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 201 | 202 | # list alternatives for (module_name, local_obj_name) 203 | parts = obj_name.split(".") 204 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 205 | 206 | # try each alternative in turn 207 | for module_name, local_obj_name in name_pairs: 208 | try: 209 | module = importlib.import_module(module_name) # may raise ImportError 210 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 211 | return module, local_obj_name 212 | except: 213 | pass 214 | 215 | # maybe some of the modules themselves contain errors? 216 | for module_name, _local_obj_name in name_pairs: 217 | try: 218 | importlib.import_module(module_name) # may raise ImportError 219 | except ImportError: 220 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 221 | raise 222 | 223 | # maybe the requested attribute is missing? 
224 | for module_name, local_obj_name in name_pairs: 225 | try: 226 | module = importlib.import_module(module_name) # may raise ImportError 227 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 228 | except ImportError: 229 | pass 230 | 231 | # we are out of luck, but we have no idea why 232 | raise ImportError(obj_name) 233 | 234 | 235 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 236 | """Traverses the object name and returns the last (rightmost) python object.""" 237 | if obj_name == '': 238 | return module 239 | obj = module 240 | for part in obj_name.split("."): 241 | obj = getattr(obj, part) 242 | return obj 243 | 244 | 245 | def get_obj_by_name(name: str) -> Any: 246 | """Finds the python object with the given name.""" 247 | module, obj_name = get_module_from_obj_name(name) 248 | return get_obj_from_module(module, obj_name) 249 | 250 | 251 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 252 | """Finds the python object with the given name and calls it as a function.""" 253 | assert func_name is not None 254 | func_obj = get_obj_by_name(func_name) 255 | assert callable(func_obj) 256 | return func_obj(*args, **kwargs) 257 | 258 | 259 | def get_module_dir_by_obj_name(obj_name: str) -> str: 260 | """Get the directory path of the module containing the given object name.""" 261 | module, _ = get_module_from_obj_name(obj_name) 262 | return os.path.dirname(inspect.getfile(module)) 263 | 264 | 265 | def is_top_level_function(obj: Any) -> bool: 266 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 267 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 268 | 269 | 270 | def get_top_level_function_name(obj: Any) -> str: 271 | """Return the fully-qualified name of a top-level function.""" 272 | assert is_top_level_function(obj) 273 | return obj.__module__ + "." 
+ obj.__name__ 274 | 275 | 276 | # File system helpers 277 | # ------------------------------------------------------------------------------------------ 278 | 279 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 280 | """List all files recursively in a given directory while ignoring given file and directory names. 281 | Returns list of tuples containing both absolute and relative paths.""" 282 | assert os.path.isdir(dir_path) 283 | base_name = os.path.basename(os.path.normpath(dir_path)) 284 | 285 | if ignores is None: 286 | ignores = [] 287 | 288 | result = [] 289 | 290 | for root, dirs, files in os.walk(dir_path, topdown=True): 291 | for ignore_ in ignores: 292 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 293 | 294 | # dirs need to be edited in-place 295 | for d in dirs_to_remove: 296 | dirs.remove(d) 297 | 298 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 299 | 300 | absolute_paths = [os.path.join(root, f) for f in files] 301 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 302 | 303 | if add_base_to_relative: 304 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 305 | 306 | assert len(absolute_paths) == len(relative_paths) 307 | result += zip(absolute_paths, relative_paths) 308 | 309 | return result 310 | 311 | 312 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 313 | """Takes in a list of tuples of (src, dst) paths and copies files. 
314 | Will create all necessary directories.""" 315 | for file in files: 316 | target_dir_name = os.path.dirname(file[1]) 317 | 318 | # will create all intermediate-level directories 319 | if not os.path.exists(target_dir_name): 320 | os.makedirs(target_dir_name) 321 | 322 | shutil.copyfile(file[0], file[1]) 323 | 324 | 325 | # URL helpers 326 | # ------------------------------------------------------------------------------------------ 327 | 328 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 329 | """Determine whether the given object is a valid URL string.""" 330 | if not isinstance(obj, str) or not "://" in obj: 331 | return False 332 | if allow_file_urls and obj.startswith('file:///'): 333 | return True 334 | try: 335 | res = requests.compat.urlparse(obj) 336 | if not res.scheme or not res.netloc or not "." in res.netloc: 337 | return False 338 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 339 | if not res.scheme or not res.netloc or not "." in res.netloc: 340 | return False 341 | except: 342 | return False 343 | return True 344 | 345 | 346 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any: 347 | """Download the given URL and return a binary-mode file object to access the data.""" 348 | assert is_url(url, allow_file_urls=True) 349 | assert num_attempts >= 1 350 | 351 | # Handle file URLs. 352 | if url.startswith('file:///'): 353 | return open(url[len('file:///'):], "rb") 354 | 355 | # Lookup from cache. 356 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 357 | if cache_dir is not None: 358 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 359 | if len(cache_files) == 1: 360 | return open(cache_files[0], "rb") 361 | 362 | # Download. 363 | url_name = None 364 | url_data = None 365 | with requests.Session() as session: 366 | if verbose: 367 | print("Downloading %s ..." 
% url, end="", flush=True) 368 | for attempts_left in reversed(range(num_attempts)): 369 | try: 370 | with session.get(url) as res: 371 | res.raise_for_status() 372 | if len(res.content) == 0: 373 | raise IOError("No data received") 374 | 375 | if len(res.content) < 8192: 376 | content_str = res.content.decode("utf-8") 377 | if "download_warning" in res.headers.get("Set-Cookie", ""): 378 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 379 | if len(links) == 1: 380 | url = requests.compat.urljoin(url, links[0]) 381 | raise IOError("Google Drive virus checker nag") 382 | if "Google Drive - Quota exceeded" in content_str: 383 | raise IOError("Google Drive download quota exceeded -- please try again later") 384 | 385 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 386 | url_name = match[1] if match else url 387 | url_data = res.content 388 | if verbose: 389 | print(" done") 390 | break 391 | except: 392 | if not attempts_left: 393 | if verbose: 394 | print(" failed") 395 | raise 396 | if verbose: 397 | print(".", end="", flush=True) 398 | 399 | # Save to cache. 400 | if cache_dir is not None: 401 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 402 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 403 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 404 | os.makedirs(cache_dir, exist_ok=True) 405 | with open(temp_file, "wb") as f: 406 | f.write(url_data) 407 | os.replace(temp_file, cache_file) # atomic 408 | 409 | # Return data as file object. 
410 | return io.BytesIO(url_data) 411 | -------------------------------------------------------------------------------- /imgs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/demo.gif -------------------------------------------------------------------------------- /imgs/example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/example_image.jpg -------------------------------------------------------------------------------- /imgs/example_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/example_mask.jpg -------------------------------------------------------------------------------- /imgs/grid-main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/grid-main.jpg -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | # empty 8 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Frechet Inception Distance (FID).""" 8 | 9 | import os 10 | import numpy as np 11 | import scipy 12 | import tensorflow as tf 13 | import dnnlib.tflib as tflib 14 | 15 | from metrics import metric_base 16 | from training import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class FID(metric_base.MetricBase): 21 | def __init__(self, num_images, minibatch_per_gpu, ref_train=False, ref_samples=None, **kwargs): 22 | super().__init__(**kwargs) 23 | self.num_images = num_images 24 | self.minibatch_per_gpu = minibatch_per_gpu 25 | self.ref_train = ref_train 26 | self.ref_samples = num_images if ref_samples is None else ref_samples 27 | 28 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 29 | minibatch_size = num_gpus * self.minibatch_per_gpu 30 | inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl 31 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 32 | 33 | # Calculate statistics for reals. 
34 | cache_file = self._get_cache_file_for_reals(num_images=self.ref_samples) 35 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 36 | if os.path.isfile(cache_file): 37 | mu_real, sigma_real = misc.load_pkl(cache_file) 38 | else: 39 | real_activations = np.empty([self.ref_samples, inception.output_shape[1]], dtype=np.float32) 40 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size, is_training=self.ref_train)): 41 | begin = idx * minibatch_size 42 | end = min(begin + minibatch_size, self.ref_samples) 43 | real_activations[begin:end] = inception.run(images[:end-begin, :3], num_gpus=num_gpus, assume_frozen=True) 44 | if end == self.ref_samples: 45 | break 46 | mu_real = np.mean(real_activations, axis=0) 47 | sigma_real = np.cov(real_activations, rowvar=False) 48 | misc.save_pkl((mu_real, sigma_real), cache_file) 49 | 50 | # Construct TensorFlow graph. 51 | self._configure(self.minibatch_per_gpu) 52 | result_expr = [] 53 | for gpu_idx in range(num_gpus): 54 | with tf.device('/gpu:%d' % gpu_idx): 55 | Gs_clone = Gs.clone() 56 | inception_clone = inception.clone() 57 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 58 | reals, labels = self._get_minibatch_tf() 59 | reals = tflib.convert_images_from_uint8(reals) 60 | masks = self._get_random_masks_tf() 61 | images = Gs_clone.get_output_for(latents, labels, reals, masks, **Gs_kwargs) 62 | images = images[:, :3] 63 | images = tflib.convert_images_to_uint8(images) 64 | result_expr.append(inception_clone.get_output_for(images)) 65 | 66 | # Calculate statistics for fakes. 67 | for begin in range(0, self.num_images, minibatch_size): 68 | self._report_progress(begin, self.num_images) 69 | end = min(begin + minibatch_size, self.num_images) 70 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] 71 | mu_fake = np.mean(activations, axis=0) 72 | sigma_fake = np.cov(activations, rowvar=False) 73 | 74 | # Calculate FID. 
75 | m = np.square(mu_fake - mu_real).sum() 76 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member 77 | dist = m + np.trace(sigma_fake + sigma_real - 2*s) 78 | self._report_result(np.real(dist)) 79 | 80 | #---------------------------------------------------------------------------- 81 | -------------------------------------------------------------------------------- /metrics/inception_discriminative_score.py: -------------------------------------------------------------------------------- 1 | # Large Scale Image Completion via Co-Modulated Generative Adversarial Networks 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://openreview.net/pdf?id=sSjqmfsk95O 4 | 5 | """Paired/Unpaired Inception Discriminative Score (P-IDS/U-IDS).""" 6 | 7 | import os 8 | from tqdm import tqdm 9 | import numpy as np 10 | import scipy 11 | import sklearn.svm 12 | import tensorflow as tf 13 | import dnnlib 14 | import dnnlib.tflib as tflib 15 | 16 | from metrics import metric_base 17 | from training import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | class IDS(metric_base.MetricBase): 22 | def __init__(self, num_images, minibatch_per_gpu, hole_range=[0,1], **kwargs): 23 | super().__init__(**kwargs) 24 | self.num_images = num_images 25 | self.minibatch_per_gpu = minibatch_per_gpu 26 | self.hole_range = hole_range 27 | 28 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 29 | minibatch_size = num_gpus * self.minibatch_per_gpu 30 | inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl 31 | real_activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 32 | fake_activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) 33 | 34 | # Construct TensorFlow graph. 
35 | self._configure(self.minibatch_per_gpu, hole_range=self.hole_range) 36 | real_img_expr = [] 37 | fake_img_expr = [] 38 | real_result_expr = [] 39 | fake_result_expr = [] 40 | for gpu_idx in range(num_gpus): 41 | with tf.device('/gpu:%d' % gpu_idx): 42 | Gs_clone = Gs.clone() 43 | inception_clone = inception.clone() 44 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 45 | reals, labels = self._get_minibatch_tf() 46 | reals_tf = tflib.convert_images_from_uint8(reals) 47 | masks = self._get_random_masks_tf() 48 | fakes = Gs_clone.get_output_for(latents, labels, reals_tf, masks, **Gs_kwargs) 49 | fakes = tflib.convert_images_to_uint8(fakes[:, :3]) 50 | reals = tflib.convert_images_to_uint8(reals_tf[:, :3]) 51 | real_img_expr.append(reals) 52 | fake_img_expr.append(fakes) 53 | real_result_expr.append(inception_clone.get_output_for(reals)) 54 | fake_result_expr.append(inception_clone.get_output_for(fakes)) 55 | 56 | for begin in tqdm(range(0, self.num_images, minibatch_size)): 57 | self._report_progress(begin, self.num_images) 58 | end = min(begin + minibatch_size, self.num_images) 59 | real_results, fake_results = tflib.run([real_result_expr, fake_result_expr]) 60 | real_activations[begin:end] = np.concatenate(real_results, axis=0)[:end-begin] 61 | fake_activations[begin:end] = np.concatenate(fake_results, axis=0)[:end-begin] 62 | 63 | # Calculate FID conviniently. 
64 | mu_real = np.mean(real_activations, axis=0) 65 | sigma_real = np.cov(real_activations, rowvar=False) 66 | mu_fake = np.mean(fake_activations, axis=0) 67 | sigma_fake = np.cov(fake_activations, rowvar=False) 68 | m = np.square(mu_fake - mu_real).sum() 69 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) 70 | dist = m + np.trace(sigma_fake + sigma_real - 2*s) 71 | self._report_result(np.real(dist), suffix='-FID') 72 | 73 | svm = sklearn.svm.LinearSVC(dual=False) 74 | svm_inputs = np.concatenate([real_activations, fake_activations]) 75 | svm_targets = np.array([1] * real_activations.shape[0] + [0] * fake_activations.shape[0]) 76 | svm.fit(svm_inputs, svm_targets) 77 | self._report_result(1 - svm.score(svm_inputs, svm_targets), suffix='-U') 78 | real_outputs = svm.decision_function(real_activations) 79 | fake_outputs = svm.decision_function(fake_activations) 80 | self._report_result(np.mean(fake_outputs > real_outputs), suffix='-P') 81 | 82 | #---------------------------------------------------------------------------- 83 | -------------------------------------------------------------------------------- /metrics/learned_perceptual_image_patch_similarity.py: -------------------------------------------------------------------------------- 1 | """Learned Perceptual Image Patch Similarity (LPIPS).""" 2 | 3 | import os 4 | import numpy as np 5 | import scipy 6 | import tensorflow as tf 7 | import dnnlib.tflib as tflib 8 | 9 | from metrics import metric_base 10 | from training import misc 11 | 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | class LPIPS(metric_base.MetricBase): 16 | def __init__(self, num_pairs=2000, minibatch_per_gpu=8, **kwargs): 17 | super().__init__(**kwargs) 18 | self.num_pairs = num_pairs 19 | self.minibatch_per_gpu = minibatch_per_gpu 20 | 21 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 22 | minibatch_size = num_gpus * self.minibatch_per_gpu 23 | 24 | graph_def = 
tf.GraphDef() 25 | with misc.open_file_or_url('http://rail.eecs.berkeley.edu/models/lpips/net-lin_alex_v0.1.pb') as f: 26 | graph_def.ParseFromString(f.read()) 27 | 28 | # Construct TensorFlow graph. 29 | self._configure(self.minibatch_per_gpu) 30 | result_expr = [] 31 | for gpu_idx in range(num_gpus): 32 | def auto_gpu(opr): 33 | if opr.type in ['SparseToDense', 'Tile', 'GatherV2', 'Pack']: 34 | return '/cpu:0' 35 | else: 36 | return '/gpu:%d' % gpu_idx 37 | with tf.device(auto_gpu): 38 | Gs_clone = Gs.clone() 39 | reals, labels = self._get_minibatch_tf() 40 | reals = tflib.convert_images_from_uint8(reals) 41 | masks = self._get_random_masks_tf() 42 | 43 | latents0 = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 44 | fakes0 = Gs_clone.get_output_for(latents0, labels, reals, masks, **Gs_kwargs)[:, :3, :, :] 45 | fakes0 = tf.clip_by_value(fakes0, -1.0, 1.0) 46 | 47 | latents1 = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) 48 | fakes1 = Gs_clone.get_output_for(latents1, labels, reals, masks, **Gs_kwargs)[:, :3, :, :] 49 | fakes1 = tf.clip_by_value(fakes1, -1.0, 1.0) 50 | 51 | distance, = tf.import_graph_def( 52 | graph_def, 53 | input_map={'0:0': fakes0, '1:0': fakes1}, 54 | return_elements = ['Reshape_10'] 55 | ) 56 | result_expr.append(distance.outputs) 57 | 58 | # Run metric 59 | results = [] 60 | for begin in range(0, self.num_pairs, minibatch_size): 61 | self._report_progress(begin, self.num_pairs) 62 | res = tflib.run(result_expr) 63 | results.append(np.reshape(res, [-1])) 64 | results = np.concatenate(results) 65 | self._report_result(np.mean(results)) 66 | self._report_result(np.std(results), suffix='-var') 67 | 68 | #---------------------------------------------------------------------------- 69 | -------------------------------------------------------------------------------- /metrics/metric_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, 
NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Common definitions for GAN metrics.""" 8 | 9 | import os 10 | import time 11 | import hashlib 12 | import numpy as np 13 | import tensorflow as tf 14 | import dnnlib 15 | import dnnlib.tflib as tflib 16 | 17 | from training import misc 18 | from training import dataset 19 | 20 | #---------------------------------------------------------------------------- 21 | # Base class for metrics. 22 | 23 | class MetricBase: 24 | def __init__(self, name): 25 | self.name = name 26 | self._dataset_obj = None 27 | self._progress_lo = None 28 | self._progress_hi = None 29 | self._progress_max = None 30 | self._progress_sec = None 31 | self._progress_time = None 32 | self._reset() 33 | 34 | def close(self): 35 | self._reset() 36 | 37 | def _reset(self, network_pkl=None, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None): 38 | if self._dataset_obj is not None: 39 | self._dataset_obj.close() 40 | 41 | self._network_pkl = network_pkl 42 | self._data_dir = data_dir 43 | self._dataset_args = dataset_args 44 | self._dataset_obj = None 45 | self._mirror_augment = mirror_augment 46 | self._eval_time = 0 47 | self._results = [] 48 | 49 | if (dataset_args is None or mirror_augment is None) and run_dir is not None: 50 | run_config = misc.parse_config_for_previous_run(run_dir) 51 | self._dataset_args = dict(run_config['dataset']) 52 | self._dataset_args['shuffle_mb'] = 0 53 | self._mirror_augment = run_config['train'].get('mirror_augment', False) 54 | 55 | def configure_progress_reports(self, plo, phi, pmax, psec=15): 56 | self._progress_lo = plo 57 | self._progress_hi = phi 58 | self._progress_max = pmax 59 | self._progress_sec = psec 60 | 61 | def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None, 
num_gpus=1, tf_config=None, log_results=True, 62 | num_repeats=1, Gs_kwargs=dict(is_validation=True), resume_with_new_nets=False, truncations=[None]): 63 | self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment) 64 | with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager 65 | self._report_progress(0, 1) 66 | _G, _D, Gs = misc.load_pkl(self._network_pkl) 67 | 68 | if resume_with_new_nets: 69 | dataset = self._get_dataset_obj() 70 | G = dnnlib.tflib.Network('G', num_channels=dataset.shape[0], resolution=dataset.shape[1], label_size=dataset.label_size, 71 | func_name='training.co_mod_gan.G_main', pix2pix=dataset.pix2pix) 72 | Gs_new = G.clone('Gs') 73 | Gs_new.copy_vars_from(Gs) 74 | Gs = Gs_new 75 | 76 | for t in truncations: 77 | print('truncation={}'.format(t)) 78 | self._results = [] 79 | time_begin = time.time() 80 | 81 | Gs_kwargs.update(truncation_psi_val=t) 82 | self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus) 83 | self._report_progress(1, 1) 84 | 85 | if num_repeats > 1: 86 | records = [dnnlib.EasyDict(value=[res.value], suffix=res.suffix, fmt=res.fmt) for res in self._results] 87 | for i in range(1, num_repeats): 88 | print(self.get_result_str().strip()) 89 | self._results = [] 90 | self._report_progress(0, 1) 91 | self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus) 92 | self._report_progress(1, 1) 93 | for rec, res in zip(records, self._results): 94 | rec.value.append(res.value) 95 | 96 | self._results = [] 97 | for rec in records: 98 | self._report_result(np.mean(rec.value), rec.suffix, rec.fmt) 99 | self._report_result(np.std(rec.value), rec.suffix + '-std', rec.fmt) 100 | 101 | self._eval_time = time.time() - time_begin # pylint: disable=attribute-defined-outside-init 102 | 103 | if log_results: 104 | if run_dir is not None: 105 | log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name) 106 | with 
dnnlib.util.Logger(log_file, 'a'): 107 | print(self.get_result_str().strip()) 108 | else: 109 | print(self.get_result_str().strip()) 110 | 111 | def get_result_str(self): 112 | network_name = os.path.splitext(os.path.basename(self._network_pkl))[0] 113 | if len(network_name) > 29: 114 | network_name = '...' + network_name[-26:] 115 | result_str = '%-30s' % network_name 116 | result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time) 117 | for res in self._results: 118 | result_str += ' ' + self.name + res.suffix + ' ' 119 | result_str += res.fmt % res.value 120 | return result_str 121 | 122 | def update_autosummaries(self): 123 | for res in self._results: 124 | tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value) 125 | 126 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 127 | raise NotImplementedError # to be overridden by subclasses 128 | 129 | def _report_result(self, value, suffix='', fmt='%-10.4f'): 130 | self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)] 131 | 132 | def _report_progress(self, pcur, pmax, status_str=''): 133 | if self._progress_lo is None or self._progress_hi is None or self._progress_max is None: 134 | return 135 | t = time.time() 136 | if self._progress_sec is not None and self._progress_time is not None and t < self._progress_time + self._progress_sec: 137 | return 138 | self._progress_time = t 139 | val = self._progress_lo + (pcur / pmax) * (self._progress_hi - self._progress_lo) 140 | dnnlib.RunContext.get().update(status_str, int(val), self._progress_max) 141 | 142 | def _get_cache_file_for_reals(self, extension='pkl', **kwargs): 143 | all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment) 144 | all_args.update(self._dataset_args) 145 | all_args.update(kwargs) 146 | md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')) 147 | dataset_name = self._dataset_args.get('tfrecord_dir', None) or self._dataset_args.get('h5_file', None) 148 | 
dataset_name = os.path.splitext(os.path.basename(dataset_name))[0] 149 | return os.path.join('.stylegan2-cache', '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension)) 150 | 151 | def _get_dataset_obj(self): 152 | if self._dataset_obj is None: 153 | self._dataset_obj = dataset.load_dataset(data_dir=self._data_dir, **self._dataset_args) 154 | return self._dataset_obj 155 | 156 | def _iterate_reals(self, minibatch_size, is_training=True): 157 | dataset_obj = self._get_dataset_obj() 158 | while True: 159 | if is_training: 160 | images, _labels = dataset_obj.get_minibatch_np(minibatch_size) 161 | else: 162 | images, _labels = dataset_obj.get_minibatch_val_np(minibatch_size) 163 | if self._mirror_augment: 164 | images = misc.apply_mirror_augment(images) 165 | yield images 166 | 167 | def _configure(self, minibatch_size, hole_range=[0, 1]): 168 | return self._get_dataset_obj().configure(minibatch_size, hole_range=hole_range) 169 | 170 | def _get_minibatch_tf(self): 171 | return self._get_dataset_obj().get_minibatch_val_tf() 172 | 173 | def _get_random_masks_tf(self): 174 | return self._get_dataset_obj().get_random_masks_tf() 175 | 176 | def _get_random_labels_tf(self, minibatch_size): 177 | return self._get_dataset_obj().get_random_labels_tf(minibatch_size) 178 | 179 | #---------------------------------------------------------------------------- 180 | # Group of multiple metrics. 
181 | # Runs a list of metrics together; each kwargs dict is passed to dnnlib.util.call_func_by_name (presumably 'func_name' plus constructor args, as in metric_defaults). 182 | class MetricGroup: 183 | def __init__(self, metric_kwarg_list): 184 | self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list] 185 | # Forward the same run() arguments to every metric, in list order. 186 | def run(self, *args, **kwargs): 187 | for metric in self.metrics: 188 | metric.run(*args, **kwargs) 189 | # Space-separated concatenation of each metric's result string. 190 | def get_result_str(self): 191 | return ' '.join(metric.get_result_str() for metric in self.metrics) 192 | # Push every metric's latest results through its own update_autosummaries(). 193 | def update_autosummaries(self): 194 | for metric in self.metrics: 195 | metric.update_autosummaries() 196 | 197 | #---------------------------------------------------------------------------- 198 | # Dummy metric for debugging purposes. 199 | # Reports a constant 0.0 without touching the network or the dataset. 200 | class DummyMetric(MetricBase): 201 | def _evaluate(self, Gs, Gs_kwargs, num_gpus): 202 | _ = Gs, Gs_kwargs, num_gpus 203 | self._report_result(0.0) 204 | 205 | #---------------------------------------------------------------------------- 206 | -------------------------------------------------------------------------------- /metrics/metric_defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | """Default metric definitions.""" 8 | 9 | from dnnlib import EasyDict 10 | 11 | #---------------------------------------------------------------------------- 12 | 13 | metric_defaults = EasyDict([(args.name, args) for args in [ 14 | EasyDict(name='fid200-rt-shoes', func_name='metrics.frechet_inception_distance.FID', num_images=200, minibatch_per_gpu=1, ref_train=True, ref_samples=49825), 15 | EasyDict(name='fid200-rt-handbags', func_name='metrics.frechet_inception_distance.FID', num_images=200, minibatch_per_gpu=1, ref_train=True, ref_samples=138567), 16 | EasyDict(name='fid5k', func_name='metrics.frechet_inception_distance.FID', num_images=5000, minibatch_per_gpu=8), 17 | EasyDict(name='fid10k', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8), 18 | EasyDict(name='fid10k-b1', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=1), 19 | EasyDict(name='fid10k-h0', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.0, 0.2]), 20 | EasyDict(name='fid10k-h1', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.2, 0.4]), 21 | EasyDict(name='fid10k-h2', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.4, 0.6]), 22 | EasyDict(name='fid10k-h3', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.6, 0.8]), 23 | EasyDict(name='fid10k-h4', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.8, 1.0]), 24 | EasyDict(name='fid36k5', func_name='metrics.frechet_inception_distance.FID',num_images=36500, minibatch_per_gpu=8), 25 | EasyDict(name='fid36k5-h0', func_name='metrics.frechet_inception_distance.FID', num_images=36500, 
minibatch_per_gpu=8, hole_range=[0.0, 0.2]), 26 | EasyDict(name='fid36k5-h1', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.2, 0.4]), 27 | EasyDict(name='fid36k5-h2', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.4, 0.6]), 28 | EasyDict(name='fid36k5-h3', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.6, 0.8]), 29 | EasyDict(name='fid36k5-h4', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.8, 1.0]), 30 | EasyDict(name='fid50k', func_name='metrics.frechet_inception_distance.FID', num_images=50000, minibatch_per_gpu=8), 31 | EasyDict(name='ids5k', func_name='metrics.inception_discriminator_score.IDS', num_images=5000, minibatch_per_gpu=8), 32 | EasyDict(name='ids10k', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8), 33 | EasyDict(name='ids10k-b1', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=1), 34 | EasyDict(name='ids10k-h0', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.0, 0.2]), 35 | EasyDict(name='ids10k-h1', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.2, 0.4]), 36 | EasyDict(name='ids10k-h2', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.4, 0.6]), 37 | EasyDict(name='ids10k-h3', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.6, 0.8]), 38 | EasyDict(name='ids10k-h4', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.8, 1.0]), 39 | EasyDict(name='ids36k5', func_name='metrics.inception_discriminative_score.IDS',num_images=36500, 
minibatch_per_gpu=8), 40 | EasyDict(name='ids36k5-h0', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.0, 0.2]), 41 | EasyDict(name='ids36k5-h1', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.2, 0.4]), 42 | EasyDict(name='ids36k5-h2', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.4, 0.6]), 43 | EasyDict(name='ids36k5-h3', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.6, 0.8]), 44 | EasyDict(name='ids36k5-h4', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.8, 1.0]), 45 | EasyDict(name='ids50k', func_name='metrics.inception_discriminative_score.IDS', num_images=50000, minibatch_per_gpu=8), 46 | EasyDict(name='lpips2k', func_name='metrics.learned_perceptual_image_patch_similarity.LPIPS', num_pairs=2000, minibatch_per_gpu=8), 47 | ]]) 48 | 49 | #---------------------------------------------------------------------------- 50 | -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('-c', '--checkpoint', required=True) 6 | parser.add_argument('-d', '--data-dir', required=True) 7 | parser.add_argument('-s', '--save-dir', default='images') 8 | parser.add_argument('-w', '--window-size', type=int, default=512) 9 | args = parser.parse_args() 10 | 11 | import tkinter as tk 12 | from PIL import Image, ImageTk, ImageDraw 13 | import numpy as np 14 | import cv2 15 | import hashlib 16 | 17 | import dnnlib 18 | 19 | from training import misc, dataset 20 | 21 | 22 | class App(tk.Tk): 23 | def __init__(self): 24 | super().__init__() 25 | self.state = -1 26 | 
self.canvas = tk.Canvas(self, bg='gray', height=args.window_size, width=args.window_size) 27 | self.canvas.bind("", self.L_press) 28 | self.canvas.bind("", self.L_release) 29 | self.canvas.bind("", self.L_move) 30 | self.canvas.bind("", self.R_press) 31 | self.canvas.bind("", self.R_release) 32 | self.canvas.bind("", self.R_move) 33 | self.canvas.bind("", self.key_down) 34 | self.canvas.bind("", self.key_up) 35 | self.canvas.pack() 36 | 37 | self.canvas.focus_set() 38 | self.canvas_image = self.canvas.create_image(0, 0, anchor='nw') 39 | 40 | dnnlib.tflib.init_tf() 41 | self.dataset = dataset.load_dataset(tfrecord_dir=args.data_dir, verbose=True, shuffle_mb=0) 42 | 43 | self.networks = [] 44 | self.truncations = [] 45 | self.model_names = [] 46 | for ckpt in args.checkpoint.split(','): 47 | if ':' in ckpt: 48 | ckpt, truncation = ckpt.split(':') 49 | truncation = float(truncation) 50 | else: 51 | truncation = None 52 | 53 | _, _, Gs = misc.load_pkl(ckpt) 54 | 55 | self.networks.append(Gs) 56 | self.truncations.append(truncation) 57 | self.model_names.append(os.path.basename(os.path.splitext(ckpt)[0])) 58 | 59 | self.key_list = ['q', 'w', 'e', 'r', 't', 'y', 'u', 'i', 'o', 'p'][:len(self.networks)] 60 | self.image_id = -1 61 | 62 | self.new_image() 63 | self.display() 64 | 65 | def generate(self, idx=0): 66 | self.cur_idx = idx 67 | latent = np.random.randn(1, *self.networks[idx].input_shape[1:]) 68 | real = misc.adjust_dynamic_range(self.real_image, [0, 255], [-1, 1]) 69 | fake = self.networks[idx].run(latent, self.label, real, self.mask, truncation_psi=self.truncations[idx]) 70 | self.fake_image = misc.adjust_dynamic_range(fake, [-1, 1], [0, 255]).clip(0, 255).astype(np.uint8) 71 | 72 | def new_image(self): 73 | self.image_id += 1 74 | self.save_count = 0 75 | self.real_image, self.label = self.dataset.get_minibatch_val_np(1) 76 | self.resolution = self.real_image.shape[-1] 77 | self.mask = np.ones((1, 1, self.resolution, self.resolution), np.uint8) 78 | 
self.mask_history = [self.mask] 79 | 80 | def display(self, state=0): 81 | if state != self.state: 82 | self.last_state = self.state 83 | self.state = state 84 | self.image = self.real_image if self.state == 1 else self.fake_image if self.state == 2 else self.real_image * self.mask 85 | self.image_for_display = np.transpose(self.image[0, :3], (1, 2, 0)) 86 | self.image_for_display_resized = cv2.resize(self.image_for_display, (args.window_size, args.window_size)) 87 | self.tkimage = ImageTk.PhotoImage(image=Image.fromarray(self.image_for_display_resized)) 88 | self.canvas.itemconfig(self.canvas_image, image=self.tkimage) 89 | 90 | def save_image(self): 91 | folder_name = os.path.join(args.save_dir, '-'.join([os.path.basename(args.data_dir), str(self.image_id), hashlib.sha1(self.mask.tostring()).hexdigest()[:6]])) 92 | if not os.path.exists(folder_name): 93 | os.makedirs(folder_name) 94 | self.save_count = 0 95 | for img, name in [[self.real_image, 'real'], [self.real_image * self.mask, 'masked']]: 96 | cv2.imwrite(os.path.join(folder_name, name + '.jpg'), np.transpose(img[0, :3], (1, 2, 0))[..., ::-1]) 97 | if self.state == 2: 98 | cv2.imwrite(os.path.join(folder_name, '-'.join([self.model_names[self.cur_idx], str(self.save_count)]) + '.jpg'), self.image_for_display[..., ::-1]) 99 | self.save_count += 1 100 | 101 | def get_pos(self, event): 102 | return (int(event.x * self.resolution / args.window_size), int(event.y * self.resolution / args.window_size)) 103 | 104 | def L_press(self, event): 105 | self.last_pos = self.get_pos(event) 106 | 107 | def L_move(self, event): 108 | a = self.last_pos 109 | b = self.get_pos(event) 110 | width = 30 111 | img = Image.fromarray(self.mask[0, 0]) 112 | draw = ImageDraw.Draw(img) 113 | draw.line([a, b], fill=0, width=width) 114 | draw.ellipse((b[0] - width // 2, b[1] - width // 2, b[0] + width // 2, b[1] + width // 2), fill=0) 115 | self.mask = np.array(img)[np.newaxis, np.newaxis, ...] 
116 | self.display() 117 | self.last_pos = b 118 | 119 | def L_release(self, event): 120 | self.L_move(event) 121 | self.mask_history.append(self.mask) 122 | 123 | def R_press(self, event): 124 | self.last_pos = self.get_pos(event) 125 | 126 | def R_move(self, event): 127 | a = self.last_pos 128 | b = self.get_pos(event) 129 | self.mask = self.mask_history[-1].copy() 130 | self.mask[0, 0, max(min(a[1], b[1]), 0): max(a[1], b[1]), max(min(a[0], b[0]), 0): max(a[0], b[0])] = 0 131 | self.display() 132 | 133 | def R_release(self, event): 134 | self.R_move(event) 135 | self.mask_history.append(self.mask) 136 | 137 | def key_down(self, event): 138 | if event.keysym == 'z': 139 | if len(self.mask_history) > 1: 140 | self.mask_history.pop() 141 | self.mask = self.mask_history[-1] 142 | self.display() 143 | elif event.keysym == 'space': 144 | self.generate() 145 | self.display(2) 146 | elif event.keysym in self.key_list: 147 | self.generate(self.key_list.index(event.keysym)) 148 | self.display(2) 149 | elif event.keysym == 's': 150 | self.save_image() 151 | elif event.keysym == 'Return': 152 | self.new_image() 153 | self.display() 154 | elif event.keysym == '1': 155 | self.display(1) 156 | elif event.keysym == '2': 157 | self.display(0) 158 | 159 | def key_up(self, event): 160 | if event.keysym in ['1', '2']: 161 | self.display(self.last_state) 162 | 163 | def main(): 164 | app = App() 165 | app.mainloop() 166 | 167 | if __name__ == "__main__": 168 | main() -------------------------------------------------------------------------------- /run_generator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import PIL.Image 4 | 5 | from dnnlib import tflib 6 | from training import misc 7 | 8 | def generate(checkpoint, image, mask, output, truncation): 9 | real = np.asarray(PIL.Image.open(image)).transpose([2, 0, 1]) 10 | real = misc.adjust_dynamic_range(real, [0, 255], [-1, 1]) 11 | mask = 
np.asarray(PIL.Image.open(mask).convert('1'), dtype=np.float32)[np.newaxis] 12 | 13 | tflib.init_tf() 14 | _, _, Gs = misc.load_pkl(checkpoint) 15 | latent = np.random.randn(1, *Gs.input_shape[1:]) 16 | fake = Gs.run(latent, None, real[np.newaxis], mask[np.newaxis], truncation_psi=truncation)[0] 17 | fake = misc.adjust_dynamic_range(fake, [-1, 1], [0, 255]) 18 | fake = fake.clip(0, 255).astype(np.uint8).transpose([1, 2, 0]) 19 | fake = PIL.Image.fromarray(fake) 20 | fake.save(output) 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('-c', '--checkpoint', help='Network checkpoint path', required=True) 25 | parser.add_argument('-i', '--image', help='Original image path', required=True) 26 | parser.add_argument('-m', '--mask', help='Mask path', required=True) 27 | parser.add_argument('-o', '--output', help='Output (inpainted) image path', required=True) 28 | parser.add_argument('-t', '--truncation', help='Truncation psi for the trade-off between quality and diversity. Defaults to 1.', default=None) 29 | 30 | args = parser.parse_args() 31 | generate(**vars(args)) 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /run_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import argparse 8 | import os 9 | import sys 10 | 11 | import dnnlib 12 | import dnnlib.tflib as tflib 13 | 14 | from metrics import metric_base 15 | from metrics.metric_defaults import metric_defaults 16 | from training import misc 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def run(network_pkls, metrics, dataset, data_dir, mirror_augment, num_repeats, truncation, resume_with_new_nets): 21 | tflib.init_tf() 22 | dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0) 23 | num_gpus = dnnlib.submit_config.num_gpus 24 | truncations = [float(t) for t in truncation.split(',')] if truncation is not None else [None] 25 | 26 | for network_pkl in network_pkls.split(','): 27 | print('Evaluating metrics "%s" for "%s"...' % (','.join(metrics), network_pkl)) 28 | metric_group = metric_base.MetricGroup([metric_defaults[metric] for metric in metrics]) 29 | metric_group.run(network_pkl, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment, 30 | num_gpus=num_gpus, num_repeats=num_repeats, resume_with_new_nets=resume_with_new_nets, truncations=truncations) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _str_to_bool(v): 35 | if isinstance(v, bool): 36 | return v 37 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 38 | return True 39 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 40 | return False 41 | else: 42 | raise argparse.ArgumentTypeError('Boolean value expected.') 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser( 48 | description='Run CoModGAN metrics.', 49 | formatter_class=argparse.RawDescriptionHelpFormatter 50 | ) 51 | parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', 
default='results', metavar='DIR') 52 | parser.add_argument('--network', help='Network pickle filename', dest='network_pkls', required=True) 53 | parser.add_argument('--metrics', help='Metrics to compute (default: %(default)s)', default='ids10k', type=lambda x: x.split(',')) 54 | parser.add_argument('--dataset', help='Training dataset', required=True) 55 | parser.add_argument('--data-dir', help='Dataset root directory', required=True) 56 | parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, type=_str_to_bool, metavar='BOOL') 57 | parser.add_argument('--num-gpus', help='Number of GPUs to use', type=int, default=1, metavar='N') 58 | parser.add_argument('--num-repeats', type=int, default=1) 59 | parser.add_argument('--truncation', type=str, default=None) 60 | parser.add_argument('--resume-with-new-nets', default=False, action='store_true') 61 | 62 | args = parser.parse_args() 63 | 64 | if not os.path.exists(args.data_dir): 65 | print ('Error: dataset root directory does not exist.') 66 | sys.exit(1) 67 | 68 | kwargs = vars(args) 69 | sc = dnnlib.SubmitConfig() 70 | sc.num_gpus = kwargs.pop('num_gpus') 71 | sc.submit_target = dnnlib.SubmitTarget.LOCAL 72 | sc.local.do_not_copy_source_files = True 73 | sc.run_dir_root = kwargs.pop('result_dir') 74 | sc.run_desc = 'run-metrics' 75 | dnnlib.submit_run(sc, 'run_metrics.run', **kwargs) 76 | 77 | #---------------------------------------------------------------------------- 78 | 79 | if __name__ == "__main__": 80 | main() 81 | 82 | #---------------------------------------------------------------------------- 83 | -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

import argparse
import copy
import os
import sys

import dnnlib
from dnnlib import EasyDict

from metrics.metric_defaults import metric_defaults

#----------------------------------------------------------------------------

def run(dataset, data_dir, result_dir, num_gpus, total_kimg, mirror_augment, metrics, resume, resume_with_new_nets, disable_style_mod, disable_cond_mod):
    """Assemble the CoModGAN training configuration and submit it via dnnlib.submit_run.

    `metrics` is a list of names looked up in metric_defaults. `resume` is an optional
    pickle path whose basename must end in '-<kimg>.pkl' (the kimg is parsed from it).
    """

    train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
    G = EasyDict(func_name='training.co_mod_gan.G_main') # Options for generator network.
    D = EasyDict(func_name='training.co_mod_gan.D_co_mod_gan') # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_masked_logistic_ns_l1') # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_masked_logistic_r1') # Options for discriminator loss.
    sched = EasyDict() # Options for TrainingSchedule.
    grid = EasyDict(size='8k', layout='random') # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig() # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'co-mod-gan'

    desc += '-' + os.path.basename(dataset)
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    # Resume kimg is taken from the trailing '-<number>' of the pickle name;
    # NOTE(review): any other naming raises ValueError here.
    if resume is not None:
        resume_kimg = int(os.path.basename(resume).replace('.pkl', '').split('-')[-1])
    else:
        resume_kimg = 0

    if disable_style_mod:
        G.style_mod = False

    if disable_cond_mod:
        G.cond_mod = False

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.update(resume_pkl=resume, resume_kimg=resume_kimg, resume_with_new_nets=resume_with_new_nets)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)

#----------------------------------------------------------------------------

def _str_to_bool(v):
    # argparse `type=` helper: accepts yes/no, true/false, t/f, y/n, 1/0 (case-insensitive).
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def _parse_comma_sep(s):
    # argparse `type=` helper: None, 'none' (any case) or '' -> []; otherwise split on commas.
    if s is None or s.lower() == 'none' or s == '':
        return []
    return s.split(',')

#----------------------------------------------------------------------------

_examples = '''examples:

# Train CoModGAN using the FFHQ dataset
python %(prog)s --data-dir=~/datasets --dataset=ffhq --metrics=ids10k --num-gpus=8

'''

def main():
    # Parse the command line, validate the data dir and metric names, then start training.
    parser = argparse.ArgumentParser(
        description='Train CoModGAN.',
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
    parser.add_argument('--data-dir', help='Dataset root directory', required=True)
    parser.add_argument('--dataset', help='Training dataset', required=True)
    parser.add_argument('--num-gpus', help='Number of GPUs (default: %(default)s)', default=1, type=int, metavar='N')
    parser.add_argument('--total-kimg', help='Training length in thousands of images (default: %(default)s)', metavar='KIMG', default=25000, type=int)
    parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
    parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='ids10k', type=_parse_comma_sep)
    parser.add_argument('--resume', default=None)
    parser.add_argument('--resume-with-new-nets', default=False, action='store_true')
    parser.add_argument('--disable-style-mod', default=False, action='store_true')
    parser.add_argument('--disable-cond-mod', default=False, action='store_true')

    args = parser.parse_args()

    if not os.path.exists(args.data_dir):
        print ('Error: dataset root directory does not exist.')
        sys.exit(1)

    for metric in args.metrics:
        if metric not in metric_defaults:
            print ('Error: unknown metric \'%s\'' % metric)
            sys.exit(1)

    run(**vars(args))

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------

-------------------------------------------------------------------------------- /training/__init__.py: --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

# empty
-------------------------------------------------------------------------------- /training/dataset.py: --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Multi-resolution input data pipeline."""

import os
import glob
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib

from .mask_generator import tf_mask_generator

#----------------------------------------------------------------------------
# Dataset class that loads data from tfrecords files.

class TFRecordDataset:
    """Streams uint8 CHW images and labels from a directory of .tfrecords files.

    Each .tfrecords file is treated as one level of detail. The first
    `num_val_images` records of each file form the validation split and the
    remaining records the training split. A separate tf.data pipeline yields
    random inpainting masks (see tf_mask_generator).
    """
    def __init__(self,
        tfrecord_dir, # Directory containing a collection of tfrecords files.
        resolution = None, # Dataset resolution, None = autodetect.
        label_file = None, # Relative path of the labels file, None = autodetect.
        max_label_size = 0, # 0 = no labels, 'full' = full labels, an int N = first N label components.
        max_images = None, # Maximum number of images to use, None = use all images.
        repeat = True, # Repeat dataset indefinitely?
        shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching.
        buffer_mb = 256, # Read buffer size (megabytes).
        num_threads = 2, # Number of concurrent threads.
        num_val_images = 10000, # Records per file reserved for validation (overridden by a 'num_val_images' feature in the data).
        compressed = False): # Records hold encoded images (also switched on by a 'compressed' feature in the data).

        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = [] # [channels, height, width]
        self.dtype = 'uint8'
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None # components
        self.label_dtype = None
        self.pix2pix = False
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_val_datasets = dict()
        self._tf_iterator = None
        self._tf_val_iterator = None
        self._tf_init_ops = dict()
        self._tf_val_init_ops = dict()
        self._tf_minibatch_np = None
        self._tf_minibatch_val_np = None
        self._tf_masks_iterator_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1
        self._hole_range = -1

        # List tfrecords files and inspect their shapes.
        assert os.path.isdir(self.tfrecord_dir)
        tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))
        assert len(tfr_files) >= 1
        tfr_shapes = []
        for tfr_file in tfr_files:
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            # Only the first record of each file is inspected (note the break).
            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
                ex = tf.train.Example()
                ex.ParseFromString(record)
                features = ex.features.feature
                if 'compressed' in features and features['compressed'].int64_list.value[0]:
                    compressed = True
                if 'num_val_images' in features:
                    num_val_images = features['num_val_images'].int64_list.value[0]
                tfr_shapes.append(features['shape'].int64_list.value)
                break

        # Determine shape and resolution.
        max_shape = max(tfr_shapes, key=np.prod)

        # More than 3 channels marks a pix2pix-style dataset.
        if max_shape[0] > 3:
            self.pix2pix = True

        self.resolution = resolution if resolution is not None else max_shape[1]
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))

        # Autodetect label filename.
        if self.label_file is None:
            guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Load labels.
        # Without labels a huge zero-width array is used so zipping with any dataset length works.
        assert max_label_size == 'full' or max_label_size >= 0
        self._np_labels = np.zeros([1<<30, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        if max_images is not None and self._np_labels.shape[0] > max_images:
            self._np_labels = self._np_labels[:max_images]
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions.
        with tf.name_scope('Dataset'), tf.device('/cpu:0'):
            self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
            self._tf_hole_range = tf.placeholder(tf.float32, name='hole_range', shape=[2])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)
            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
                if tfr_lod < 0:
                    continue
                dset_raw = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
                if max_images is not None:
                    dset_raw = dset_raw.take(max_images)
                # Split each file: first num_val_images records -> validation, the rest -> training.
                for tf_datasets, dset in [(self._tf_val_datasets, dset_raw.take(num_val_images)), (self._tf_datasets, dset_raw.skip(num_val_images))]:
                    if compressed:
                        dset = dset.map(self.parse_and_decode_tfrecord_tf, num_parallel_calls=num_threads)
                    else:
                        dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads)
                    dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
                    bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
                    if shuffle_mb > 0:
                        dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
                    if repeat:
                        dset = dset.repeat()
                    if prefetch_mb > 0:
                        dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
                    dset = dset.batch(self._tf_minibatch_in)
                    tf_datasets[tfr_lod] = dset
            self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()}
            self._tf_val_iterator = tf.data.Iterator.from_structure(self._tf_val_datasets[0].output_types, self._tf_val_datasets[0].output_shapes)
            self._tf_val_init_ops = {lod: self._tf_val_iterator.make_initializer(dset) for lod, dset in self._tf_val_datasets.items()}

            self._tf_masks_dataset = tf_mask_generator(self.resolution, self._tf_hole_range).batch(self._tf_minibatch_in).prefetch(64)
            self._tf_masks_iterator = self._tf_masks_dataset.make_initializable_iterator()

    def close(self):
        pass

    # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf().
    # Re-initializes the train/val iterators (and, when hole_range is not None, the mask
    # iterator) whenever the minibatch size, lod or hole range differs from the current one.
    # NOTE(review): get_minibatch_np()/get_minibatch_val_np() call this with the default
    # hole_range=[0,1], which resets any other hole range set via get_random_masks_np().
    def configure(self, minibatch_size, lod=0, hole_range=[0,1]):
        lod = int(np.floor(lod))
        assert minibatch_size >= 1 and lod in self._tf_datasets
        if self._cur_minibatch != minibatch_size or self._cur_lod != lod or (hole_range is not None and self._hole_range != hole_range):
            self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
            self._tf_val_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
            self._cur_minibatch = minibatch_size
            self._cur_lod = lod
            if hole_range is not None:
                self._tf_masks_iterator.initializer.run({self._tf_minibatch_in: minibatch_size, self._tf_hole_range: hole_range})
                self._hole_range = hole_range

    # Get next minibatch as TensorFlow expressions.
    def get_minibatch_tf(self): # => images, labels
        return self._tf_iterator.get_next()

    # Same, from the validation split.
    def get_minibatch_val_tf(self): # => images, labels
        return self._tf_val_iterator.get_next()

    # Get next minibatch as NumPy arrays.
    def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels
        self.configure(minibatch_size, lod)
        with tf.name_scope('Dataset'):
            if self._tf_minibatch_np is None:
                self._tf_minibatch_np = self.get_minibatch_tf()
            return tflib.run(self._tf_minibatch_np)

    # Same, from the validation split.
    def get_minibatch_val_np(self, minibatch_size, lod=0): # => images, labels
        self.configure(minibatch_size, lod)
        with tf.name_scope('Dataset'):
            if self._tf_minibatch_val_np is None:
                self._tf_minibatch_val_np = self.get_minibatch_val_tf()
            return tflib.run(self._tf_minibatch_val_np)

    # Get next batch of random masks as a TensorFlow expression.
    def get_random_masks_tf(self): # => masks
        return self._tf_masks_iterator.get_next()

    # Get next batch of random masks as a NumPy array, with hole ratios inside hole_range.
    def get_random_masks_np(self, minibatch_size, hole_range=[0,1]):
        self.configure(minibatch_size, hole_range=hole_range)
        with tf.name_scope('Dataset'):
            if self._tf_masks_iterator_np is None:
                self._tf_masks_iterator_np = self.get_random_masks_tf()
            return tflib.run(self._tf_masks_iterator_np)

    # Get random labels as TensorFlow expression.
    def get_random_labels_tf(self, minibatch_size): # => labels
        with tf.name_scope('Dataset'):
            if self.label_size > 0:
                with tf.device('/cpu:0'):
                    return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32))
            return tf.zeros([minibatch_size, 0], self.label_dtype)

    # Get random labels as NumPy array.
    def get_random_labels_np(self, minibatch_size): # => labels
        if self.label_size > 0:
            return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])]
        return np.zeros([minibatch_size, 0], self.label_dtype)

    # Parse individual image from a tfrecords file into TensorFlow expression.
    @staticmethod
    def parse_tfrecord_tf(record):
        features = tf.parse_single_example(record, features={
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string)})
        data = tf.decode_raw(features['data'], tf.uint8)
        return tf.reshape(data, features['shape'])

    # Parse a record whose 'data' is an encoded image: decode, crop/pad to the stored
    # height/width, convert HWC -> CHW and broadcast to the stored shape.
    @staticmethod
    def parse_and_decode_tfrecord_tf(record):
        features = tf.parse_single_example(record, features={
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string)})
        shape = tf.cast(features['shape'], 'int32')
        data = tf.image.decode_image(features['data'])
        data = tf.image.resize_with_crop_or_pad(data, shape[1], shape[2])
        return tf.broadcast_to(tf.transpose(data, [2, 0, 1]), shape)

#----------------------------------------------------------------------------
# Helper func for constructing a dataset object using the given options.

def load_dataset(class_name=None, data_dir=None, verbose=False, **kwargs):
    # tfrecord_dir is resolved relative to data_dir when both are given.
    kwargs = dict(kwargs)
    if 'tfrecord_dir' in kwargs:
        if class_name is None:
            class_name = __name__ + '.TFRecordDataset'
        if data_dir is not None:
            kwargs['tfrecord_dir'] = os.path.join(data_dir, kwargs['tfrecord_dir'])

    assert class_name is not None
    if verbose:
        print('Streaming data using %s...' % class_name)
    dataset = dnnlib.util.get_obj_by_name(class_name)(**kwargs)
    if verbose:
        print('Dataset shape =', np.int32(dataset.shape).tolist())
        print('Dynamic range =', dataset.dynamic_range)
        print('Label size =', dataset.label_size)
    return dataset

#----------------------------------------------------------------------------
-------------------------------------------------------------------------------- /training/loss.py: --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Loss functions."""

import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib.autosummary import autosummary

def G_masked_logistic_ns_l1(G, D, opt, training_set, minibatch_size, reals, masks, l1_weight=0):
    """Generator loss: non-saturating logistic loss plus l1_weight * per-image mean |fake - real|.

    Returns (loss, None); no regularization term.
    """
    _ = opt
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fake_images_out = G.get_output_for(latents, labels, reals, masks, is_training=True)
    fake_scores_out = D.get_output_for(fake_images_out, labels, masks, is_training=True)
    logistic_loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
    logistic_loss = autosummary('Loss/logistic_loss', logistic_loss)
    l1_loss = tf.reduce_mean(tf.abs(fake_images_out - reals), axis=[1,2,3])
    l1_loss = autosummary('Loss/l1_loss', l1_loss)
    loss = logistic_loss + l1_loss * l1_weight
    return loss, None

def D_masked_logistic_r1(G, D, opt, training_set, minibatch_size, reals, labels, masks, gamma=10.0):
    """Discriminator loss: logistic loss on real/fake scores; returns (loss, reg).

    reg is the R1 penalty: (gamma / 2) * squared gradient norm of D's real scores w.r.t. reals.
    """
    _ = opt, training_set
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    fake_images_out = G.get_output_for(latents, labels, reals, masks, is_training=True)
    real_scores_out = D.get_output_for(reals, labels, masks, is_training=True)
    fake_scores_out = D.get_output_for(fake_images_out, labels, masks, is_training=True)
    real_scores_out = autosummary('Loss/scores/real', real_scores_out)
    fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
    loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
    loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out))

    with tf.name_scope('GradientPenalty'):
        real_grads = tf.gradients(tf.reduce_sum(real_scores_out), [reals])[0]
        gradient_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3])
        gradient_penalty = autosummary('Loss/gradient_penalty', gradient_penalty)
        reg = gradient_penalty * (gamma * 0.5)
    return loss, reg

#----------------------------------------------------------------------------
-------------------------------------------------------------------------------- /training/mask_generator.py: --------------------------------------------------------------------------------
import numpy as np
from PIL import Image, ImageDraw
import math
import random

import tensorflow as tf

def RandomBrush(
    max_tries,
    s,
    min_num_vertex = 4,
    max_num_vertex = 18,
    mean_angle = 2*math.pi / 5,
    angle_range = 2*math.pi / 15,
    min_width = 12,
    max_width = 48):
    """Draw random free-form brush strokes on an s x s canvas.

    Draws np.random.randint(max_tries) strokes (0 .. max_tries-1; numpy raises if
    max_tries <= 0). Returns a (s, s) uint8 array with 1 where strokes were drawn.
    """
    H, W = s, s
    average_radius = math.sqrt(H*H+W*W) / 8
    mask = Image.new('L', (W, H), 0)
    for _ in range(np.random.randint(max_tries)):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        vertex = []
        for i in range(num_vertex):
            if i % 2 == 0:
                angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))

        # NOTE(review): PIL's size is (width, height), so h/w are swapped here;
        # harmless only because the canvas is square.
        h, w = mask.size
        vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
        for i in range(num_vertex):
            r = np.clip(
                np.random.normal(loc=average_radius, scale=average_radius//2),
                0, 2*average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
            vertex.append((int(new_x), int(new_y)))

        draw = ImageDraw.Draw(mask)
        width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=1, width=width)
        for v in vertex:
            draw.ellipse((v[0] - width//2,
                          v[1] - width//2,
                          v[0] + width//2,
                          v[1] + width//2),
                         fill=1)
        # NOTE(review): Image.transpose returns a new image and the result is discarded,
        # so these branches only consume random numbers (removing them would change the
        # RNG stream). The effective flips are the np.flip calls below.
        if np.random.random() > 0.5:
            mask.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.random() > 0.5:
            mask.transpose(Image.FLIP_TOP_BOTTOM)
    mask = np.asarray(mask, np.uint8)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 0)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 1)
    return mask

def RandomMask(s, hole_range=[0,1]):
    """Sample a (1, s, s) float32 inpainting mask: 1 = keep, 0 = hole.

    Rectangles and brush strokes are combined; the mask is resampled until the
    hole ratio lies strictly between hole_range[0] and hole_range[1].
    NOTE(review): hole_range is dereferenced for `coef` before the `is not None`
    check below, so None is not actually supported. A small hole_range sum makes
    int(10 * coef) / int(5 * coef) zero, and np.random.randint(0) raises.
    """
    coef = min(hole_range[0] + hole_range[1], 1.0)
    while True:
        mask = np.ones((s, s), np.uint8)
        def Fill(max_size):
            # Zero out one random rectangle, allowed to hang partly off the canvas.
            w, h = np.random.randint(max_size), np.random.randint(max_size)
            ww, hh = w // 2, h // 2
            x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
            mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
        def MultiFill(max_tries, max_size):
            for _ in range(np.random.randint(max_tries)):
                Fill(max_size)
        MultiFill(int(10 * coef), s // 2)
        MultiFill(int(5 * coef), s)
        mask = np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
        hole_ratio = 1 - np.mean(mask)
        if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
            continue
        return mask[np.newaxis, ...].astype(np.float32)

def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
    # Stack batch_size independent RandomMask samples -> (batch_size, 1, s, s) float32.
    return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis = 0)

def tf_mask_generator(s, tf_hole_range):
    # Infinite tf.data.Dataset of (1, s, s) float32 masks; tf_hole_range is passed to the generator as its argument.
    def random_mask_generator(hole_range):
        while True:
            yield RandomMask(s, hole_range=hole_range)
    return tf.data.Dataset.from_generator(random_mask_generator, tf.float32, tf.TensorShape([1, s, s]), (tf_hole_range,))
-------------------------------------------------------------------------------- /training/misc.py: --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Miscellaneous utility functions."""

import os
import pickle
import numpy as np
import PIL.Image
import PIL.ImageFont
import dnnlib
from dnnlib import tflib

import tensorflow as tf

#----------------------------------------------------------------------------
# Convenience wrappers for pickle that are able to load data produced by
# older versions of the code, and from external URLs.

def open_file_or_url(file_or_url):
    # URLs are fetched through dnnlib.util.open_url with a local '.stylegan2-cache' directory.
    if dnnlib.util.is_url(file_or_url):
        return dnnlib.util.open_url(file_or_url, cache_dir='.stylegan2-cache')
    return open(file_or_url, 'rb')

def load_pkl(file_or_url):
    # NOTE: unpickling can execute arbitrary code — only load trusted files/URLs.
    # encoding='latin1' lets Python 3 read pickles written by Python 2.
    with open_file_or_url(file_or_url) as file:
        return pickle.load(file, encoding='latin1')

def save_pkl(obj, filename):
    with open(filename, 'wb') as file:
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)

#----------------------------------------------------------------------------
# Image utils.
def adjust_dynamic_range(data, drange_in, drange_out):
    """Linearly remap `data` from the value range `drange_in` to `drange_out`.

    Args:
        data: Array (or anything supporting `*` and `+` with NumPy float32
            scalars) whose values are expressed in `drange_in`.
        drange_in: `(lo, hi)` pair describing the current range of `data`.
        drange_out: `(lo, hi)` pair describing the target range.

    Returns:
        `data * scale + bias`, mapping `drange_in[0] -> drange_out[0]` and
        `drange_in[1] -> drange_out[1]`. When both ranges are equal, `data`
        itself is returned untouched.
    """
    # Compare via tuples: a plain `!=` treats `[0, 255]` and `(0, 255)` as
    # different (forcing a needless float conversion) and raises on
    # NumPy-array ranges ("truth value of an array ... is ambiguous").
    if tuple(drange_in) != tuple(drange_out):
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale
        data = data * scale + bias
    return data

def create_image_grid(images, grid_size=None, pix2pix=False):
    """Tile a batch of images into one large image.

    Args:
        images: Array of shape `[N, H, W]` or `[N, C, H, W]`.
        grid_size: Optional `(grid_w, grid_h)` in cells. When omitted, a
            near-square grid just large enough for all `N` images is used.
        pix2pix: If True, `images` is split in half along axis 1 and only
            the first half is tiled.

    Returns:
        Array of shape `[..., grid_h * H, grid_w * W]` with the dtype of
        `images`. Images fill the grid row by row; unused cells are zero.

    Raises:
        ValueError: If an explicit `grid_size` has fewer cells than images.
    """
    if pix2pix:
        # Keep only the first half of axis 1; the second half is discarded.
        images, _ = np.split(images, 2, axis=1)
    assert images.ndim == 3 or images.ndim == 4
    num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]

    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
        # Fail clearly instead of with an opaque broadcast error further down.
        if grid_w * grid_h < num:
            raise ValueError('grid_size %r has only %d cells for %d images'
                             % (tuple(grid_size), grid_w * grid_h, num))
    else:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)

    grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype)
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[..., y : y + img_h, x : x + img_w] = images[idx]
    return grid

def convert_to_pil_image(image, drange=(0, 1)):
    """Convert a single image array to a `PIL.Image`.

    Args:
        image: `[H, W]` array or `[C, H, W]` array. A `[1, H, W]` array is
            treated as grayscale; any other `C` is transposed to `[H, W, C]`
            and passed to PIL as `'RGB'` (so `C == 3` is presumably expected).
        drange: `(lo, hi)` value range of `image`, remapped to `[0, 255]`.
            Pass `None` when the values are already in `[0, 255]`.

    Returns:
        `PIL.Image.Image` in mode `'RGB'` (3-D result) or `'L'` (2-D result).
    """
    assert image.ndim == 2 or image.ndim == 3
    if image.ndim == 3:
        if image.shape[0] == 1:
            image = image[0] # grayscale CHW => HW
        else:
            image = image.transpose(1, 2, 0) # CHW -> HWC

    if drange is not None:
        image = adjust_dynamic_range(image, drange, [0, 255])
    # Quantize unconditionally so PIL always receives uint8, including when
    # the caller skipped the range remap with drange=None.
    image = np.rint(image).clip(0, 255).astype(np.uint8)
    fmt = 'RGB' if image.ndim == 3 else 'L'
    return PIL.Image.fromarray(image, fmt)

def save_image_grid(images, filename, drange=(0, 1), grid_size=None, pix2pix=False):
    """Tile `images` into a grid and save it to `filename`.

    `grid_size` and `pix2pix` are forwarded to `create_image_grid`;
    `drange` is forwarded to `convert_to_pil_image`.
    """
    convert_to_pil_image(create_image_grid(images, grid_size, pix2pix=pix2pix), drange).save(filename)

def apply_mirror_augment(minibatch):
    """Return a copy of `minibatch` with each item flipped with probability 0.5.

    The flip reverses the last axis, and the indexing assumes a 4-D batch
    (presumably `[N, C, H, W]`). Randomness comes from NumPy's global RNG.
    The caller's array is not modified.
    """
    mask = np.random.rand(minibatch.shape[0]) < 0.5
    minibatch = np.array(minibatch)  # copy so the input is left untouched
    minibatch[mask] = minibatch[mask, :, :, ::-1]
    return minibatch
#----------------------------------------------------------------------------
# Loading data from previous training runs.

def parse_config_for_previous_run(run_dir):
    """Read the training and dataset kwargs saved by a previous run.

    Args:
        run_dir: Directory of a previous run containing `submit_config.pkl`.

    Returns:
        dict with key `'train'` (the pickled config's `'run_func_kwargs'`
        entry, or `{}` if absent) and key `'dataset'` (that entry's
        `'dataset_args'` value, or `{}` if absent).
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at run directories you trust.
    with open(os.path.join(run_dir, 'submit_config.pkl'), 'rb') as f:
        data = pickle.load(f)
    data = data.get('run_func_kwargs', {})
    return dict(train=data, dataset=data.get('dataset_args', {}))

#----------------------------------------------------------------------------
# Size and contents of the image snapshot grids that are exported
# periodically during training.

def setup_snapshot_image_grid(training_set,
    size    = '1080p',      # '1080p' / '4k' / '8k' = to be viewed on a display of that resolution; any other value gives a 1x1 grid.
    layout  = 'random'):    # 'random' = grid contents are selected randomly; 'row_per_class' / 'col_per_class' / 'class4x4' = class-conditional blocks.
    """Choose the snapshot grid size and fill it with validation samples.

    Args:
        training_set: Dataset object. This function reads its `shape`
            (per-image shape, list or tuple; index 2 drives the grid width
            and index 1 the grid height), `dtype`, `label_size` and
            `label_dtype` attributes, and calls `get_minibatch_val_np(n)`
            (returning `(reals, labels)`) and `get_random_masks_np(n)`.
        size: Display preset used to pick the grid dimensions `(gw, gh)`.
            Each dimension is clipped to a preset minimum and to at most 32.
        layout: How cells are filled. `'random'` uses one minibatch of
            `gw * gh` samples as-is. For the class-conditional layouts,
            samples are grouped into blocks by the argmax of their label:
            `'row_per_class'` uses `gw x 1` blocks, `'col_per_class'` uses
            `1 x gh` blocks and `'class4x4'` uses `4 x 4` blocks. Any other
            value leaves `reals` and `labels` zero-filled.

    Returns:
        `((gw, gh), reals, labels, masks)` where `reals` has shape
        `[gw * gh] + shape`, `labels` has shape `[gw * gh, label_size]` and
        `masks` is whatever `get_random_masks_np(gw * gh)` returns.
    """
    # Per preset: (display width, display height, min gw, min gh).
    presets = {
        '1080p': (1920, 1080, 3, 2),
        '4k':    (3840, 2160, 7, 4),
        '8k':    (7680, 4320, 7, 4),
    }
    # Normalize to a list so tuple-valued shapes can be concatenated below.
    shape = list(training_set.shape)

    # Select size.
    if size in presets:
        disp_w, disp_h, min_gw, min_gh = presets[size]
        gw = np.clip(disp_w // shape[2], min_gw, 32)
        gh = np.clip(disp_h // shape[1], min_gh, 32)
    else:
        gw = 1; gh = 1

    # Initialize data arrays.
    reals = np.zeros([gw * gh] + shape, dtype=training_set.dtype)
    labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)

    # Random layout.
    if layout == 'random':
        reals[:], labels[:] = training_set.get_minibatch_val_np(gw * gh)

    # Class-conditional layouts.
    class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4])
    if layout in class_layouts:
        bw, bh = class_layouts[layout]
        nw = (gw - 1) // bw + 1
        nh = (gh - 1) // bh + 1
        blocks = [[] for _i in range(nw * nh)]
        # Draw single samples until every block is full (or the draw budget
        # runs out, in which case unfilled cells stay zero).
        for _iter in range(1000000):
            real, label = training_set.get_minibatch_val_np(1)
            idx = np.argmax(label[0])
            # A class whose block is already full spills into the block
            # `label_size` positions later, so a class can own several blocks.
            while idx < len(blocks) and len(blocks[idx]) >= bw * bh:
                idx += training_set.label_size
            if idx < len(blocks):
                blocks[idx].append((real, label))
                if all(len(block) >= bw * bh for block in blocks):
                    break
        # Scatter each block's samples row-major into its bw x bh tile.
        for i, block in enumerate(blocks):
            for j, (real, label) in enumerate(block):
                x = (i % nw) * bw + j % bw
                y = (i // nw) * bh + j // bw
                if x < gw and y < gh:
                    reals[x + y * gw] = real[0]
                    labels[x + y * gw] = label[0]

    masks = training_set.get_random_masks_np(gw * gh)

    return (gw, gh), reals, labels, masks

#----------------------------------------------------------------------------