├── .gitignore
├── Dockerfile
├── LICENSE.txt
├── README.md
├── dataset_tools
│   ├── create_from_images.py
│   └── tfrecord_utils.py
├── dnnlib
│   ├── __init__.py
│   ├── submission
│   │   ├── __init__.py
│   │   ├── internal
│   │   │   ├── __init__.py
│   │   │   └── local.py
│   │   ├── run_context.py
│   │   └── submit.py
│   ├── tflib
│   │   ├── __init__.py
│   │   ├── autosummary.py
│   │   ├── custom_ops.py
│   │   ├── network.py
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── fused_bias_act.cu
│   │   │   ├── fused_bias_act.py
│   │   │   ├── upfirdn_2d.cu
│   │   │   └── upfirdn_2d.py
│   │   ├── optimizer.py
│   │   └── tfutil.py
│   └── util.py
├── imgs
│   ├── demo.gif
│   ├── example_image.jpg
│   ├── example_mask.jpg
│   └── grid-main.jpg
├── metrics
│   ├── __init__.py
│   ├── frechet_inception_distance.py
│   ├── inception_discriminative_score.py
│   ├── learned_perceptual_image_patch_similarity.py
│   ├── metric_base.py
│   └── metric_defaults.py
├── run_demo.py
├── run_generator.py
├── run_metrics.py
├── run_training.py
└── training
    ├── __init__.py
    ├── co_mod_gan.py
    ├── dataset.py
    ├── loss.py
    ├── mask_generator.py
    ├── misc.py
    └── training_loop.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /results/
2 | /images/
3 | /models/
4 | /datasets/
5 | /test/
6 | /training/_my_pm.pyd
7 | /.stylegan2-cache/
8 | *.db
9 | *.pyc
10 | *.dll
11 | *.pkl
12 | *.ipynb
13 | *.zip
14 | *.tfrecords
15 | .vscode
16 | *.csv
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | FROM tensorflow/tensorflow:1.15.0-gpu-py3
8 |
9 | RUN pip install scipy==1.3.3
10 | RUN pip install requests==2.22.0
11 | RUN pip install Pillow==6.2.1
12 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020, Shengyu Zhao, Jonathan Cui, Yilun Sheng, Yue Dong,
2 | Xiao Liang, Eric I Chang, Yan Xu
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
26 |
27 | ----------------------- LICENSE FOR stylegan2 -----------------------
28 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
29 |
30 |
31 | Nvidia Source Code License-NC
32 |
33 | =======================================================================
34 |
35 | 1. Definitions
36 |
37 | "Licensor" means any person or entity that distributes its Work.
38 |
39 | "Software" means the original work of authorship made available under
40 | this License.
41 |
42 | "Work" means the Software and any additions to or derivative works of
43 | the Software that are made available under this License.
44 |
45 | "Nvidia Processors" means any central processing unit (CPU), graphics
46 | processing unit (GPU), field-programmable gate array (FPGA),
47 | application-specific integrated circuit (ASIC) or any combination
48 | thereof designed, made, sold, or provided by Nvidia or its affiliates.
49 |
50 | The terms "reproduce," "reproduction," "derivative works," and
51 | "distribution" have the meaning as provided under U.S. copyright law;
52 | provided, however, that for the purposes of this License, derivative
53 | works shall not include works that remain separable from, or merely
54 | link (or bind by name) to the interfaces of, the Work.
55 |
56 | Works, including the Software, are "made available" under this License
57 | by including in or with the Work either (a) a copyright notice
58 | referencing the applicability of this License to the Work, or (b) a
59 | copy of this License.
60 |
61 | 2. License Grants
62 |
63 | 2.1 Copyright Grant. Subject to the terms and conditions of this
64 | License, each Licensor grants to you a perpetual, worldwide,
65 | non-exclusive, royalty-free, copyright license to reproduce,
66 | prepare derivative works of, publicly display, publicly perform,
67 | sublicense and distribute its Work and any resulting derivative
68 | works in any form.
69 |
70 | 3. Limitations
71 |
72 | 3.1 Redistribution. You may reproduce or distribute the Work only
73 | if (a) you do so under this License, (b) you include a complete
74 | copy of this License with your distribution, and (c) you retain
75 | without modification any copyright, patent, trademark, or
76 | attribution notices that are present in the Work.
77 |
78 | 3.2 Derivative Works. You may specify that additional or different
79 | terms apply to the use, reproduction, and distribution of your
80 | derivative works of the Work ("Your Terms") only if (a) Your Terms
81 | provide that the use limitation in Section 3.3 applies to your
82 | derivative works, and (b) you identify the specific derivative
83 | works that are subject to Your Terms. Notwithstanding Your Terms,
84 | this License (including the redistribution requirements in Section
85 | 3.1) will continue to apply to the Work itself.
86 |
87 | 3.3 Use Limitation. The Work and any derivative works thereof only
88 | may be used or intended for use non-commercially. The Work or
89 | derivative works thereof may be used or intended for use by Nvidia
90 | or its affiliates commercially or non-commercially. As used herein,
91 | "non-commercially" means for research or evaluation purposes only.
92 |
93 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
94 | against any Licensor (including any claim, cross-claim or
95 | counterclaim in a lawsuit) to enforce any patents that you allege
96 | are infringed by any Work, then your rights under this License from
97 | such Licensor (including the grants in Sections 2.1 and 2.2) will
98 | terminate immediately.
99 |
100 | 3.5 Trademarks. This License does not grant any rights to use any
101 | Licensor's or its affiliates' names, logos, or trademarks, except
102 | as necessary to reproduce the notices described in this License.
103 |
104 | 3.6 Termination. If you violate any term of this License, then your
105 | rights under this License (including the grants in Sections 2.1 and
106 | 2.2) will terminate immediately.
107 |
108 | 4. Disclaimer of Warranty.
109 |
110 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
111 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
112 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
113 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
114 | THIS LICENSE.
115 |
116 | 5. Limitation of Liability.
117 |
118 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
119 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
120 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
121 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
122 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
123 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
124 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
125 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
126 | THE POSSIBILITY OF SUCH DAMAGES.
127 |
128 | =======================================================================
129 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Large Scale Image Completion via Co-Modulated Generative Adversarial Networks, ICLR 2021 (Spotlight)
2 |
3 | ### [Demo (Unofficial)](https://www.microsoft.com/en-us/ai/ai-lab-CoModGAN) | [Paper](https://openreview.net/pdf?id=sSjqmfsk95O)
4 |
5 |
6 |
7 | **[NEW!]** Another [unofficial demo](https://www.microsoft.com/en-us/ai/ai-lab-CoModGAN) is available!
8 |
9 | **[NOTICE]** Our web demo will be shut down soon. Enjoy the last days!
10 |
11 | **[NEW!]** Time to play with our [interactive web demo](http://comodgan.ml)!
12 |
13 | *Numerous task-specific variants of conditional generative adversarial networks have been developed for image completion. Yet, a serious limitation remains that all existing algorithms tend to fail when handling **large-scale missing regions**. To overcome this challenge, we propose a generic new approach that bridges the gap between image-conditional and recent modulated unconditional generative architectures via **co-modulation** of both conditional and stochastic style representations. Also, due to the lack of good quantitative metrics for image completion, we propose the new **Paired/Unpaired Inception Discriminative Score (P-IDS/U-IDS)**, which robustly measures the perceptual fidelity of inpainted images compared to real images via linear separability in a feature space. Experiments demonstrate superior performance in terms of both quality and diversity over state-of-the-art methods in free-form image completion and easy generalization to image-to-image translation.*
14 |
15 |
16 |
17 | Large Scale Image Completion via Co-Modulated Generative Adversarial Networks
18 | [Shengyu Zhao](https://scholar.google.com/citations?user=gLCdw70AAAAJ), [Jonathan Cui](https://www.linkedin.com/in/jonathan-cui-110b211a6/), Yilun Sheng, Yue Dong, Xiao Liang, Eric I Chang, Yan Xu
19 | Tsinghua University and Microsoft Research
20 | [arXiv](http://arxiv.org/abs/2103.10428) | [OpenReview](https://openreview.net/pdf?id=sSjqmfsk95O)
21 |
22 | ## Overview
23 |
24 | This repo is built on top of the official [StyleGAN2 repo](https://github.com/NVlabs/stylegan2) and shares its dependencies. We also provide a [Dockerfile](https://github.com/zsyzzsoft/co-mod-gan/blob/master/Dockerfile) for Docker users. This repo currently supports:
25 | - [x] Large scale image completion experiments on FFHQ and Places2
26 | - [x] Image-to-image translation experiments on Edges2Shoes and Edges2Handbags
27 | - [ ] Image-to-image translation experiments on COCO-Stuff
28 | - [x] Evaluation code of *Paired/Unpaired Inception Discriminative Score (P-IDS/U-IDS)*
29 |
30 | ## Datasets
31 |
32 | - FFHQ dataset (in TFRecords format) can be downloaded following the [StyleGAN2 repo](https://github.com/NVlabs/stylegan2).
33 | - Places2 dataset can be downloaded from [this website](http://places2.csail.mit.edu/download.html) (Places365-Challenge 2016 high-resolution images, [training set](http://data.csail.mit.edu/places/places365/train_large_places365challenge.tar) and [validation set](http://data.csail.mit.edu/places/places365/val_large.tar)). The raw images should be converted into TFRecords using `dataset_tools/create_from_images.py` with `--shuffle --compressed`.
34 | - Edges2Shoes and Edges2Handbags datasets can be downloaded following the [pix2pix repo](https://github.com/phillipi/pix2pix). The raw images should be converted into TFRecords using `dataset_tools/create_from_images.py` with `--shuffle --pix2pix`.
35 | - To prepare a custom dataset, please use `dataset_tools/create_from_images.py`, which will automatically center-crop and resize your images to the specified resolution. You only need to specify `--val-image-dir` for testing purposes.
36 |
37 | ## Training
38 |
39 | The following script is for training on FFHQ. It holds out 10k images for validation. We recommend using 8 NVIDIA Tesla V100 GPUs for training. Training at 512x512 resolution takes about 1 week.
40 |
41 | ```bash
42 | python run_training.py --data-dir=DATA_DIR --dataset=DATASET --metrics=ids10k --mirror-augment --num-gpus=8
43 | ```
44 |
45 | The following script is for training on Places2 at resolution 512x512 (the resolution must be specified when training on a compressed dataset), which has a validation set of 36500 images:
46 |
47 | ```bash
48 | python run_training.py --data-dir=DATA_DIR --dataset=DATASET --resolution=512 --metrics=ids36k5 --total-kimg 50000 --num-gpus=8
49 | ```
50 |
51 | The following script is for training on Edges2Handbags (and similarly for Edges2Shoes):
52 |
53 | ```bash
54 | python run_training.py --data-dir=DATA_DIR --dataset=DATASET --metrics=fid200-rt-handbags --mirror-augment --num-gpus=8
55 | ```
56 |
57 | ## Pre-Trained Models
58 |
59 | Our pre-trained models are available on [Google Drive](https://drive.google.com/drive/folders/1zSJj1ichgSA-4sECGm-fQ0Ww8aiwpkoO):
60 |
61 | | Model name & URL | Description |
62 | | ------------------------------------------------------------ | ------------------------------------------------------------ |
63 | | [co-mod-gan-ffhq-9-025000.pkl](https://drive.google.com/file/d/1b3XxfAmJ9k2vd73j-3nPMr_lvNMQOFGE/view?usp=sharing) | Large scale image completion on FFHQ (512x512) |
64 | | [co-mod-gan-ffhq-10-025000.pkl](https://drive.google.com/file/d/1M2dSxlJnCFNM6LblpB2nQCnaimgwaaKu/view?usp=sharing) | Large scale image completion on FFHQ (1024x1024) |
65 | | [co-mod-gan-places2-050000.pkl](https://drive.google.com/file/d/1dJa3DRWIkx6Ebr8Sc0v1FdvWf6wkd010/view?usp=sharing) | Large scale image completion on Places2 (512x512) |
66 | | [co-mod-gan-coco-stuff-025000.pkl](https://drive.google.com/file/d/1Ol9_pKMpfIHHwbdE7RFmJcCAzfj8hqxQ/view?usp=sharing) | Image-to-image translation on COCO-Stuff (labels to photos) (512x512) |
67 | | [co-mod-gan-edges2shoes-025000.pkl](https://drive.google.com/file/d/155p-_zAtL8RJSsKHAWrRaGxJVzT4NZKg/view?usp=sharing) | Image-to-image translation on edges2shoes (256x256) |
68 | | [co-mod-gan-edges2handbags-025000.pkl](https://drive.google.com/file/d/1nBIQaUs6fXRpEt1cweqQKtWVw5UZAqLi/view?usp=sharing) | Image-to-image translation on edges2handbags (256x256) |
69 |
70 | Use the following script to run the interactive demo locally:
71 |
72 | ```bash
73 | python run_demo.py -d DATA_DIR/DATASET -c CHECKPOINT_FILE(S)
74 | ```
75 |
76 | or use the following command as a minimal example of usage:
77 |
78 | ```bash
79 | python run_generator.py -c CHECKPOINT_FILE -i imgs/example_image.jpg -m imgs/example_mask.jpg -o imgs/example_output.jpg
80 | ```
81 |
82 | ## Evaluation
83 |
84 | The following script is for evaluation:
85 |
86 | ```bash
87 | python run_metrics.py --data-dir=DATA_DIR --dataset=DATASET --network=CHECKPOINT_FILE(S) --metrics=METRIC(S) --num-gpus=1
88 | ```
89 |
90 | Commonly used metrics are `ids10k` and `ids36k5` (for FFHQ and Places2 respectively), which compute P-IDS and U-IDS together with FID. By default, masks are generated randomly for evaluation, or you may append the metric name with `-h0` ([0.0, 0.2]) to `-h4` ([0.8, 1.0]) to specify the range of the masked ratio.
91 |
92 | ## Citation
93 |
94 | If you find this code helpful, please cite our paper:
95 | ```
96 | @inproceedings{zhao2021comodgan,
97 |   title={Large Scale Image Completion via Co-Modulated Generative Adversarial Networks},
98 |   author={Zhao, Shengyu and Cui, Jonathan and Sheng, Yilun and Dong, Yue and Liang, Xiao and Chang, Eric I and Xu, Yan},
99 |   booktitle={International Conference on Learning Representations (ICLR)},
100 |   year={2021}
101 | }
102 | ```
--------------------------------------------------------------------------------
/dataset_tools/create_from_images.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 | import numpy as np
3 | import argparse
4 | from tqdm import tqdm
5 | import random
6 | import os
7 | import PIL.Image
8 |
9 | from tfrecord_utils import TFRecordExporter
10 |
11 | def worker(in_queue, out_queue, resolution, compressed, pix2pix):
12 |     while True:
13 |         fpath = in_queue.get()
14 |         if compressed:
15 |             assert not pix2pix
16 |             if fpath.endswith('.jpg') or fpath.endswith('.JPG'):
17 |                 img = np.fromfile(fpath, dtype='uint8')
18 |             else:
19 |                 img = None
20 |         else:
21 |             try:
22 |                 img = PIL.Image.open(fpath)
23 |             except IOError:
24 |                 img = None
25 |             else:
26 |                 img_size = min(img.size[0] // 2 if pix2pix else img.size[0], img.size[1])
27 |                 left = (img.size[0] - (img_size * 2 if pix2pix else img_size)) // 2
28 |                 top = (img.size[1] - img_size) // 2
29 |                 img = img.crop((left, top, left + (img_size * 2 if pix2pix else img_size), top + img_size))
30 |                 img = img.resize((resolution * 2 if pix2pix else resolution, resolution), PIL.Image.BILINEAR)
31 |                 img = np.asarray(img.convert('RGB')).transpose([2, 0, 1])
32 |                 if pix2pix:
33 |                     img = np.concatenate(np.split(img, 2, axis=-1), axis=0)
34 |         out_queue.put(img)
35 |
36 | def create_from_images(tfrecord_dir, val_image_dir, train_image_dir, resolution, num_channels, num_processes, shuffle, compressed, pix2pix):
37 |     in_queue = mp.Queue()
38 |     out_queue = mp.Queue(num_processes * 8)
39 |
40 |     worker_procs = [mp.Process(target=worker, args=(in_queue, out_queue, resolution, compressed, pix2pix)) for _ in range(num_processes)]
41 |     for worker_proc in worker_procs:
42 |         worker_proc.daemon = True
43 |         worker_proc.start()
44 |
45 |     print('Processes created.')
46 |
47 |     with TFRecordExporter(tfrecord_dir, compressed=compressed) as tfr:
48 |         tfr.set_shape([num_channels * 2 if pix2pix else num_channels, resolution, resolution])
49 |
50 |         if val_image_dir:
51 |             print('Processing validation images...')
52 |             flist = []
53 |             for root, _, files in os.walk(val_image_dir):
54 |                 print(root)
55 |                 flist.extend([os.path.join(root, fname) for fname in files])
56 |             tfr.set_num_val_images(len(flist))
57 |             if shuffle:
58 |                 random.shuffle(flist)
59 |             for fpath in tqdm(flist):
60 |                 in_queue.put(fpath, block=False)
61 |             for _ in tqdm(range(len(flist))):
62 |                 img = out_queue.get()
63 |                 if img is not None:
64 |                     tfr.add_image(img)
65 |
66 |         if train_image_dir:
67 |             print('Processing training images...')
68 |             flist = []
69 |             for root, _, files in os.walk(train_image_dir):
70 |                 print(root)
71 |                 flist.extend([os.path.join(root, fname) for fname in files])
72 |             if shuffle:
73 |                 random.shuffle(flist)
74 |             for fpath in tqdm(flist):
75 |                 in_queue.put(fpath, block=False)
76 |             for _ in tqdm(range(len(flist))):
77 |                 img = out_queue.get()
78 |                 if img is not None:
79 |                     tfr.add_image(img)
80 |
81 | def main():
82 |     parser = argparse.ArgumentParser()
83 |     parser.add_argument('--tfrecord-dir', help='Output directory of generated TFRecord', required=True)
84 |     parser.add_argument('--val-image-dir', help='Root directory of validation images', default=None)
85 |     parser.add_argument('--train-image-dir', help='Root directory of training images', default=None)
86 |     parser.add_argument('--resolution', help='Target resolution', type=int, default=512)
87 |     parser.add_argument('--num-channels', help='Number of channels of images', type=int, default=3)
88 |     parser.add_argument('--num-processes', help='Number of parallel processes', type=int, default=8)
89 |     parser.add_argument('--shuffle', default=False, action='store_true')
90 |     parser.add_argument('--compressed', default=False, action='store_true')
91 |     parser.add_argument('--pix2pix', default=False, action='store_true')
92 |
93 |     args = parser.parse_args()
94 |     create_from_images(**vars(args))
95 |
96 | if __name__ == "__main__":
97 |     main()
98 |
--------------------------------------------------------------------------------
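
The script above is normally driven from the command line, but the same conversion can be invoked programmatically. A minimal sketch (run from inside `dataset_tools/`; all paths are placeholders) mirroring the README's Places2 recipe of `--shuffle --compressed`:

```python
# Equivalent to:
#   python create_from_images.py --tfrecord-dir=../datasets/places2 \
#       --train-image-dir=... --val-image-dir=... --shuffle --compressed
from create_from_images import create_from_images

create_from_images(
    tfrecord_dir='../datasets/places2',      # output TFRecords land here
    val_image_dir='/path/to/val_large',      # exported first, counted via set_num_val_images()
    train_image_dir='/path/to/train_large',  # exported after the validation split
    resolution=512,        # with compressed=True the raw JPEG bytes are stored as-is,
    num_channels=3,        # so the resolution must be given again at training time
    num_processes=8,
    shuffle=True,
    compressed=True,
    pix2pix=False,
)
```
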
/dataset_tools/tfrecord_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Tool for creating TFRecords datasets."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | class TFRecordExporter:
16 |     def __init__(self, tfrecord_dir, compressed=False):
17 |         self.tfrecord_dir = tfrecord_dir
18 |         self.num_val_images = 0
19 |         self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
20 |         self.shape = None
21 |         self.resolution_log2 = None
22 |         self.tfr_writer = None
23 |         self.compressed = compressed
24 |
25 |         if not os.path.isdir(self.tfrecord_dir):
26 |             os.makedirs(self.tfrecord_dir)
27 |         assert os.path.isdir(self.tfrecord_dir)
28 |
29 |     def close(self):
30 |         self.tfr_writer.close()
31 |         self.tfr_writer = None
32 |
33 |     def set_shape(self, shape):
34 |         self.shape = shape
35 |         self.resolution_log2 = int(np.log2(self.shape[1]))
36 |         tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
37 |         tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % self.resolution_log2
38 |         self.tfr_writer = tf.python_io.TFRecordWriter(tfr_file, tfr_opt)
39 |
40 |     def set_num_val_images(self, num_val_images):
41 |         self.num_val_images = num_val_images
42 |
43 |     def add_image(self, img):
44 |         if self.shape is None:
45 |             self.set_shape(img.shape)
46 |         if not self.compressed:
47 |             assert list(self.shape) == list(img.shape)
48 |         quant = np.rint(img).clip(0, 255).astype(np.uint8) if not self.compressed else img
49 |         ex = tf.train.Example(features=tf.train.Features(feature={
50 |             'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=self.shape)),
51 |             'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()])),
52 |             'compressed': tf.train.Feature(int64_list=tf.train.Int64List(value=[self.compressed])),
53 |             'num_val_images': tf.train.Feature(int64_list=tf.train.Int64List(value=[self.num_val_images])),
54 |         }))
55 |         self.tfr_writer.write(ex.SerializeToString())
56 |
57 |     def __enter__(self):
58 |         return self
59 |
60 |     def __exit__(self, *args):
61 |         self.close()
62 |
--------------------------------------------------------------------------------
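
For orientation, records written by `add_image()` can be read back with matching TF 1.x parsing code. A minimal sketch, assuming an uncompressed dataset; the file name is illustrative:

```python
import tensorflow as tf  # TF 1.x, as used by the exporter above

def parse_record(record_bytes):
    # Mirrors the feature layout written by TFRecordExporter.add_image().
    feats = tf.parse_single_example(record_bytes, features={
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature([], tf.string),
    })
    img = tf.decode_raw(feats['data'], tf.uint8)  # uncompressed datasets store raw uint8 pixels
    return tf.reshape(img, feats['shape'])        # CHW layout, as exported

dataset = tf.data.TFRecordDataset('datasets/ffhq/ffhq-r10.tfrecords').map(parse_record)
```
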
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | from . import submission
8 |
9 | from .submission.run_context import RunContext
10 |
11 | from .submission.submit import SubmitTarget
12 | from .submission.submit import PathType
13 | from .submission.submit import SubmitConfig
14 | from .submission.submit import submit_run
15 | from .submission.submit import get_path_from_template
16 | from .submission.submit import convert_path
17 | from .submission.submit import make_run_dir_path
18 |
19 | from .util import EasyDict
20 |
21 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
22 |
--------------------------------------------------------------------------------
/dnnlib/submission/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | from . import run_context
8 | from . import submit
9 |
--------------------------------------------------------------------------------
/dnnlib/submission/internal/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | from . import local
8 |
--------------------------------------------------------------------------------
/dnnlib/submission/internal/local.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | class TargetOptions():
8 |     def __init__(self):
9 |         self.do_not_copy_source_files = False
10 |
11 | class Target():
12 |     def __init__(self):
13 |         pass
14 |
15 |     def finalize_submit_config(self, submit_config, host_run_dir):
16 |         print('Local submit ', end='', flush=True)
17 |         submit_config.run_dir = host_run_dir
18 |
19 |     def submit(self, submit_config, host_run_dir):
20 |         from ..submit import run_wrapper, convert_path
21 |         print('- run_dir: %s' % convert_path(submit_config.run_dir), flush=True)
22 |         return run_wrapper(submit_config)
23 |
--------------------------------------------------------------------------------
/dnnlib/submission/run_context.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Helpers for managing the run/training loop."""
8 |
9 | import datetime
10 | import json
11 | import os
12 | import pprint
13 | import time
14 | import types
15 |
16 | from typing import Any
17 |
18 | from . import submit
19 |
20 | # Singleton RunContext
21 | _run_context = None
22 |
23 | class RunContext(object):
24 | """Helper class for managing the run/training loop.
25 |
26 | The context will hide the implementation details of a basic run/training loop.
27 | It will set things up properly, tell if run should be stopped, and then cleans up.
28 | User should call update periodically and use should_stop to determine if run should be stopped.
29 |
30 | Args:
31 | submit_config: The SubmitConfig that is used for the current run.
32 | config_module: (deprecated) The whole config module that is used for the current run.
33 | """
34 |
35 |     def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None):
36 |         global _run_context
37 |         # Only a single RunContext can be alive
38 |         assert _run_context is None
39 |         _run_context = self
40 |         self.submit_config = submit_config
41 |         self.should_stop_flag = False
42 |         self.has_closed = False
43 |         self.start_time = time.time()
44 |         self.last_update_time = time.time()
45 |         self.last_update_interval = 0.0
46 |         self.progress_monitor_file_path = None
47 |
48 |         # vestigial config_module support just prints a warning
49 |         if config_module is not None:
50 |             print("RunContext.config_module parameter support has been removed.")
51 |
52 |         # write out details about the run to a text file
53 |         self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
54 |         with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
55 |             pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
56 |
57 |     def __enter__(self) -> "RunContext":
58 |         return self
59 |
60 |     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
61 |         self.close()
62 |
63 |     def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
64 |         """Do general housekeeping and keep the state of the context up-to-date.
65 |         Should be called often enough but not in a tight loop."""
66 |         assert not self.has_closed
67 |
68 |         self.last_update_interval = time.time() - self.last_update_time
69 |         self.last_update_time = time.time()
70 |
71 |         if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
72 |             self.should_stop_flag = True
73 |
74 |     def should_stop(self) -> bool:
75 |         """Tell whether a stopping condition has been triggered one way or another."""
76 |         return self.should_stop_flag
77 |
78 |     def get_time_since_start(self) -> float:
79 |         """How much time has passed since the creation of the context."""
80 |         return time.time() - self.start_time
81 |
82 |     def get_time_since_last_update(self) -> float:
83 |         """How much time has passed since the last call to update."""
84 |         return time.time() - self.last_update_time
85 |
86 |     def get_last_update_interval(self) -> float:
87 |         """How much time passed between the previous two calls to update."""
88 |         return self.last_update_interval
89 |
90 |     def close(self) -> None:
91 |         """Close the context and clean up.
92 |         Should only be called once."""
93 |         if not self.has_closed:
94 |             # update the run.txt with stopping time
95 |             self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
96 |             with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
97 |                 pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
98 |             self.has_closed = True
99 |
100 |             # detach the global singleton
101 |             global _run_context
102 |             if _run_context is self:
103 |                 _run_context = None
104 |
105 |     @staticmethod
106 |     def get():
107 |         import dnnlib
108 |         if _run_context is not None:
109 |             return _run_context
110 |         return RunContext(dnnlib.submit_config)
111 |
--------------------------------------------------------------------------------
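
As a usage reference, a minimal sketch of the loop protocol this class expects; it assumes a populated SubmitConfig whose run_dir already exists:

```python
import dnnlib

def example_loop(submit_config, total_steps=1000):
    with dnnlib.RunContext(submit_config) as ctx:  # writes run.txt on entry
        for step in range(total_steps):
            if ctx.should_stop():                  # set once abort.txt appears in run_dir
                break
            # ... one training step ...
            ctx.update(loss=0.0, cur_epoch=step, max_epoch=total_steps)
        print('elapsed: %.1f s' % ctx.get_time_since_start())
```
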
/dnnlib/submission/submit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Submit a function to be run either locally or in a computing cluster."""
8 |
9 | import copy
10 | import inspect
11 | import os
12 | import pathlib
13 | import pickle
14 | import platform
15 | import pprint
16 | import re
17 | import shutil
18 | import sys
19 | import time
20 | import traceback
21 |
22 | from enum import Enum
23 |
24 | from .. import util
25 | from ..util import EasyDict
26 |
27 | from . import internal
28 |
29 | class SubmitTarget(Enum):
30 | """The target where the function should be run.
31 |
32 | LOCAL: Run it locally.
33 | """
34 | LOCAL = 1
35 |
36 |
37 | class PathType(Enum):
38 | """Determines in which format should a path be formatted.
39 |
40 | WINDOWS: Format with Windows style.
41 | LINUX: Format with Linux/Posix style.
42 | AUTO: Use current OS type to select either WINDOWS or LINUX.
43 | """
44 | WINDOWS = 1
45 | LINUX = 2
46 | AUTO = 3
47 |
48 |
49 | class PlatformExtras:
50 | """A mixed bag of values used by dnnlib heuristics.
51 |
52 | Attributes:
53 |
54 | data_reader_buffer_size: Used by DataReader to size internal shared memory buffers.
55 | data_reader_process_count: Number of worker processes to spawn (zero for single thread operation)
56 | """
57 | def __init__(self):
58 | self.data_reader_buffer_size = 1<<30 # 1 GB
59 | self.data_reader_process_count = 0 # single threaded default
60 |
61 |
62 | _user_name_override = None
63 |
64 | class SubmitConfig(util.EasyDict):
65 | """Strongly typed config dict needed to submit runs.
66 |
67 | Attributes:
68 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
69 | run_desc: Description of the run. Will be used in the run dir and task name.
70 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
71 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
72 | submit_target: Submit target enum value. Used to select where the run is actually launched.
73 | num_gpus: Number of GPUs used/requested for the run.
74 | print_info: Whether to print debug information when submitting.
75 | local.do_not_copy_source_files: Do not copy source files from the working directory to the run dir.
76 | run_id: Automatically populated value during submit.
77 | run_name: Automatically populated value during submit.
78 | run_dir: Automatically populated value during submit.
79 | run_func_name: Automatically populated value during submit.
80 | run_func_kwargs: Automatically populated value during submit.
81 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
82 | task_name: Automatically populated value during submit.
83 | host_name: Automatically populated value during submit.
84 | platform_extras: Automatically populated values during submit. Used by various dnnlib libraries such as the DataReader class.
85 | """
86 |
87 | def __init__(self):
88 | super().__init__()
89 |
90 | # run (set these)
91 | self.run_dir_root = "" # should always be passed through get_path_from_template
92 | self.run_desc = ""
93 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode", "_cudacache"]
94 | self.run_dir_extra_files = []
95 |
96 | # submit (set these)
97 | self.submit_target = SubmitTarget.LOCAL
98 | self.num_gpus = 1
99 | self.print_info = False
100 | self.nvprof = False
101 | self.local = internal.local.TargetOptions()
102 | self.datasets = []
103 |
104 | # (automatically populated)
105 | self.run_id = None
106 | self.run_name = None
107 | self.run_dir = None
108 | self.run_func_name = None
109 | self.run_func_kwargs = None
110 | self.user_name = None
111 | self.task_name = None
112 | self.host_name = "localhost"
113 | self.platform_extras = PlatformExtras()
114 |
115 |
116 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
117 | """Replace tags in the given path template and return either Windows or Linux formatted path."""
118 | # automatically select path type depending on running OS
119 | if path_type == PathType.AUTO:
120 | if platform.system() == "Windows":
121 | path_type = PathType.WINDOWS
122 | elif platform.system() == "Linux":
123 | path_type = PathType.LINUX
124 | else:
125 | raise RuntimeError("Unknown platform")
126 |
127 |     path_template = path_template.replace("&lt;USERNAME&gt;", get_user_name())
128 |
129 |     # return correctly formatted path
130 |     if path_type == PathType.WINDOWS:
131 |         return str(pathlib.PureWindowsPath(path_template))
132 |     elif path_type == PathType.LINUX:
133 |         return str(pathlib.PurePosixPath(path_template))
134 |     else:
135 |         raise RuntimeError("Unknown platform")
136 |
137 |
138 | def get_template_from_path(path: str) -> str:
139 |     """Convert a normal path back to its template representation."""
140 |     path = path.replace("\\", "/")
141 |     return path
142 |
143 |
144 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
145 | """Convert a normal path to template and the convert it back to a normal path with given path type."""
146 |     path_template = get_template_from_path(path)
147 |     path = get_path_from_template(path_template, path_type)
148 |     return path
149 |
150 |
151 | def set_user_name_override(name: str) -> None:
152 | """Set the global username override value."""
153 | global _user_name_override
154 | _user_name_override = name
155 |
156 |
157 | def get_user_name():
158 | """Get the current user name."""
159 | if _user_name_override is not None:
160 | return _user_name_override
161 | elif platform.system() == "Windows":
162 | return os.getlogin()
163 | elif platform.system() == "Linux":
164 | try:
165 | import pwd
166 | return pwd.getpwuid(os.geteuid()).pw_name
167 | except:
168 | return "unknown"
169 | else:
170 | raise RuntimeError("Unknown platform")
171 |
172 |
173 | def make_run_dir_path(*paths):
174 | """Make a path/filename that resides under the current submit run_dir.
175 |
176 | Args:
177 | *paths: Path components to be passed to os.path.join
178 |
179 | Returns:
180 | A file/dirname rooted at submit_config.run_dir. If there's no
181 | submit_config or run_dir, the base directory is the current
182 | working directory.
183 |
184 | E.g., `os.path.join(dnnlib.submit_config.run_dir, "output.txt"))`
185 | """
186 | import dnnlib
187 | if (dnnlib.submit_config is None) or (dnnlib.submit_config.run_dir is None):
188 | return os.path.join(os.getcwd(), *paths)
189 | return os.path.join(dnnlib.submit_config.run_dir, *paths)
190 |
191 |
192 | def _create_run_dir_local(submit_config: SubmitConfig) -> str:
193 | """Create a new run dir with increasing ID number at the start."""
194 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)
195 |
196 | if not os.path.exists(run_dir_root):
197 | os.makedirs(run_dir_root)
198 |
199 | submit_config.run_id = _get_next_run_id_local(run_dir_root)
200 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
201 | run_dir = os.path.join(run_dir_root, submit_config.run_name)
202 |
203 | if os.path.exists(run_dir):
204 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir))
205 |
206 | os.makedirs(run_dir)
207 |
208 | return run_dir
209 |
210 |
211 | def _get_next_run_id_local(run_dir_root: str) -> int:
212 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
213 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
214 | r = re.compile("^\\d+") # match one or more digits at the start of the string
215 | run_id = 0
216 |
217 | for dir_name in dir_names:
218 | m = r.match(dir_name)
219 |
220 | if m is not None:
221 | i = int(m.group())
222 | run_id = max(run_id, i + 1)
223 |
224 | return run_id
225 |
226 |
227 | def _populate_run_dir(submit_config: SubmitConfig, run_dir: str) -> None:
228 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
229 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb"))
230 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
231 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
232 |
233 | if (submit_config.submit_target == SubmitTarget.LOCAL) and submit_config.local.do_not_copy_source_files:
234 | return
235 |
236 | files = []
237 |
238 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
239 | assert '.' in submit_config.run_func_name
240 | for _idx in range(submit_config.run_func_name.count('.') - 1):
241 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
242 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)
243 |
244 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
245 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)
246 |
247 | files += submit_config.run_dir_extra_files
248 |
249 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
250 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "internal", "run.py"), os.path.join(run_dir, "run.py"))]
251 |
252 | util.copy_files_and_create_dirs(files)
253 |
254 |
255 |
256 | def run_wrapper(submit_config: SubmitConfig) -> None:
257 | """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
258 | is_local = submit_config.submit_target == SubmitTarget.LOCAL
259 |
260 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
261 | if is_local:
262 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
263 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
264 | logger = util.Logger(file_name=None, should_flush=True)
265 |
266 | import dnnlib
267 | dnnlib.submit_config = submit_config
268 |
269 | exit_with_errcode = False
270 | try:
271 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
272 | start_time = time.time()
273 |
274 | run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
275 | assert callable(run_func_obj)
276 | sig = inspect.signature(run_func_obj)
277 | if 'submit_config' in sig.parameters:
278 | run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279 | else:
280 | run_func_obj(**submit_config.run_func_kwargs)
281 |
282 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
283 | except:
284 | if is_local:
285 | raise
286 | else:
287 | traceback.print_exc()
288 |
289 | log_src = os.path.join(submit_config.run_dir, "log.txt")
290 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
291 | shutil.copyfile(log_src, log_dst)
292 |
293 | # Defer sys.exit(1) to happen after we close the logs and create a _finished.txt
294 | exit_with_errcode = True
295 | finally:
296 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()
297 |
298 | dnnlib.RunContext.get().close()
299 | dnnlib.submit_config = None
300 | logger.close()
301 |
302 | # If we hit an error, get out of the script now and signal the error
303 | # to whatever process that started this script.
304 | if exit_with_errcode:
305 | sys.exit(1)
306 |
307 | return submit_config
308 |
309 |
310 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
311 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
312 | submit_config = copy.deepcopy(submit_config)
313 |
314 | submit_target = submit_config.submit_target
315 | farm = None
316 | if submit_target == SubmitTarget.LOCAL:
317 | farm = internal.local.Target()
318 | assert farm is not None # unknown target
319 |
320 | # Disallow submitting jobs with zero num_gpus.
321 | if (submit_config.num_gpus is None) or (submit_config.num_gpus == 0):
322 | raise RuntimeError("submit_config.num_gpus must be set to a non-zero value")
323 |
324 | if submit_config.user_name is None:
325 | submit_config.user_name = get_user_name()
326 |
327 | submit_config.run_func_name = run_func_name
328 | submit_config.run_func_kwargs = run_func_kwargs
329 |
330 | #--------------------------------------------------------------------
331 | # Prepare submission by populating the run dir
332 | #--------------------------------------------------------------------
333 | host_run_dir = _create_run_dir_local(submit_config)
334 |
335 | submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
336 | docker_valid_name_regex = "^[a-zA-Z0-9][a-zA-Z0-9_.-]+$"
337 | if not re.match(docker_valid_name_regex, submit_config.task_name):
338 | raise RuntimeError("Invalid task name. Probable reason: unacceptable characters in your submit_config.run_desc. Task name must be accepted by the following regex: " + docker_valid_name_regex + ", got " + submit_config.task_name)
339 |
340 | # Farm specific preparations for a submit
341 | farm.finalize_submit_config(submit_config, host_run_dir)
342 | _populate_run_dir(submit_config, host_run_dir)
343 | return farm.submit(submit_config, host_run_dir)
344 |
--------------------------------------------------------------------------------
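
Putting the pieces together, a hedged sketch of a local submission in the style of the repo's `run_*.py` scripts; the dotted `run_func_name` and the forwarded kwarg are illustrative rather than this repo's exact wiring:

```python
import dnnlib

submit_config = dnnlib.SubmitConfig()
submit_config.run_dir_root = 'results'  # runs become results/00000-example, 00001-example, ...
submit_config.run_desc = 'example'
submit_config.submit_target = dnnlib.SubmitTarget.LOCAL
submit_config.num_gpus = 1

# submit_run() deep-copies the config, allocates the next run dir,
# copies the source tree into it, and calls the named function there.
dnnlib.submit_run(submit_config, 'training.training_loop.training_loop',
                  some_kwarg=123)  # hypothetical kwarg, forwarded to the run function
```
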
/dnnlib/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | from . import autosummary
8 | from . import network
9 | from . import optimizer
10 | from . import tfutil
11 | from . import custom_ops
12 |
13 | from .tfutil import *
14 | from .network import Network
15 |
16 | from .optimizer import Optimizer
17 |
18 | from .custom_ops import get_plugin
19 |
--------------------------------------------------------------------------------
/dnnlib/tflib/autosummary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Helper for adding automatically tracked values to Tensorboard.
8 |
9 | Autosummary creates an identity op that internally keeps track of the input
10 | values and automatically shows up in TensorBoard. The reported value
11 | represents an average over input components. The average is accumulated
12 | constantly over time and flushed when save_summaries() is called.
13 |
14 | Notes:
15 | - The output tensor must be used as an input for something else in the
16 | graph. Otherwise, the autosummary op will not get executed, and the average
17 | value will not get accumulated.
18 | - It is perfectly fine to include autosummaries with the same name in
19 | several places throughout the graph, even if they are executed concurrently.
20 | - It is ok to also pass in a python scalar or numpy array. In this case, it
21 | is added to the average immediately.
22 | """
23 |
24 | from collections import OrderedDict
25 | import numpy as np
26 | import tensorflow as tf
27 | from tensorboard import summary as summary_lib
28 | from tensorboard.plugins.custom_scalar import layout_pb2
29 |
30 | from . import tfutil
31 | from .tfutil import TfExpression
32 | from .tfutil import TfExpressionEx
33 |
34 | # Enable "Custom scalars" tab in TensorBoard for advanced formatting.
35 | # Disabled by default to reduce tfevents file size.
36 | enable_custom_scalars = False
37 |
38 | _dtype = tf.float64
39 | _vars = OrderedDict() # name => [var, ...]
40 | _immediate = OrderedDict() # name => update_op, update_value
41 | _finalized = False
42 | _merge_op = None
43 |
44 |
45 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
46 | """Internal helper for creating autosummary accumulators."""
47 | assert not _finalized
48 | name_id = name.replace("/", "_")
49 | v = tf.cast(value_expr, _dtype)
50 |
51 | if v.shape.is_fully_defined():
52 | size = np.prod(v.shape.as_list())
53 | size_expr = tf.constant(size, dtype=_dtype)
54 | else:
55 | size = None
56 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
57 |
58 | if size == 1:
59 | if v.shape.ndims != 0:
60 | v = tf.reshape(v, [])
61 | v = [size_expr, v, tf.square(v)]
62 | else:
63 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
64 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
65 |
66 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
67 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
68 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
69 |
70 | if name in _vars:
71 | _vars[name].append(var)
72 | else:
73 | _vars[name] = [var]
74 | return update_op
75 |
76 |
77 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
78 | """Create a new autosummary.
79 |
80 | Args:
81 | name: Name to use in TensorBoard
82 | value: TensorFlow expression or python value to track
83 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
84 |
85 | Example use of the passthru mechanism:
86 |
87 | n = autosummary('l2loss', loss, passthru=n)
88 |
89 | This is a shorthand for the following code:
90 |
91 | with tf.control_dependencies([autosummary('l2loss', loss)]):
92 | n = tf.identity(n)
93 | """
94 | tfutil.assert_tf_initialized()
95 | name_id = name.replace("/", "_")
96 |
97 | if tfutil.is_tf_expression(value):
98 | with tf.name_scope("summary_" + name_id), tf.device(value.device):
99 | condition = tf.convert_to_tensor(condition, name='condition')
100 | update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
101 | with tf.control_dependencies([update_op]):
102 | return tf.identity(value if passthru is None else passthru)
103 |
104 | else: # python scalar or numpy array
105 | assert not tfutil.is_tf_expression(passthru)
106 | assert not tfutil.is_tf_expression(condition)
107 | if condition:
108 | if name not in _immediate:
109 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
110 | update_value = tf.placeholder(_dtype)
111 | update_op = _create_var(name, update_value)
112 | _immediate[name] = update_op, update_value
113 | update_op, update_value = _immediate[name]
114 | tfutil.run(update_op, {update_value: value})
115 | return value if passthru is None else passthru
116 |
117 |
118 | def finalize_autosummaries() -> None:
119 | """Create the necessary ops to include autosummaries in TensorBoard report.
120 | Note: This should be done only once per graph.
121 | """
122 | global _finalized
123 | tfutil.assert_tf_initialized()
124 |
125 | if _finalized:
126 | return None
127 |
128 | _finalized = True
129 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
130 |
131 | # Create summary ops.
132 | with tf.device(None), tf.control_dependencies(None):
133 | for name, vars_list in _vars.items():
134 | name_id = name.replace("/", "_")
135 | with tfutil.absolute_name_scope("Autosummary/" + name_id):
136 | moments = tf.add_n(vars_list)
137 | moments /= moments[0]
138 | with tf.control_dependencies([moments]): # read before resetting
139 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
140 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
141 | mean = moments[1]
142 | std = tf.sqrt(moments[2] - tf.square(moments[1]))
143 | tf.summary.scalar(name, mean)
144 | if enable_custom_scalars:
145 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
146 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
147 |
148 | # Setup layout for custom scalars.
149 | layout = None
150 | if enable_custom_scalars:
151 | cat_dict = OrderedDict()
152 | for series_name in sorted(_vars.keys()):
153 | p = series_name.split("/")
154 | cat = p[0] if len(p) >= 2 else ""
155 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
156 | if cat not in cat_dict:
157 | cat_dict[cat] = OrderedDict()
158 | if chart not in cat_dict[cat]:
159 | cat_dict[cat][chart] = []
160 | cat_dict[cat][chart].append(series_name)
161 | categories = []
162 | for cat_name, chart_dict in cat_dict.items():
163 | charts = []
164 | for chart_name, series_names in chart_dict.items():
165 | series = []
166 | for series_name in series_names:
167 | series.append(layout_pb2.MarginChartContent.Series(
168 | value=series_name,
169 | lower="xCustomScalars/" + series_name + "/margin_lo",
170 | upper="xCustomScalars/" + series_name + "/margin_hi"))
171 | margin = layout_pb2.MarginChartContent(series=series)
172 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
173 | categories.append(layout_pb2.Category(title=cat_name, chart=charts))
174 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
175 | return layout
176 |
177 | def save_summaries(file_writer, global_step=None):
178 | """Call FileWriter.add_summary() with all summaries in the default graph,
179 | automatically finalizing and merging them on the first call.
180 | """
181 | global _merge_op
182 | tfutil.assert_tf_initialized()
183 |
184 | if _merge_op is None:
185 | layout = finalize_autosummaries()
186 | if layout is not None:
187 | file_writer.add_summary(layout)
188 | with tf.device(None), tf.control_dependencies(None):
189 | _merge_op = tf.summary.merge_all()
190 |
191 | file_writer.add_summary(_merge_op.eval(), global_step)
192 |
--------------------------------------------------------------------------------
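
The accumulator math above reduces to three running moments per summary name. A small numpy sketch of how `finalize_autosummaries()` recovers the mean and (population) std from `[sum(1), sum(x), sum(x**2)]`:

```python
import numpy as np

acc = np.zeros(3)               # [sum(1), sum(x), sum(x**2)], as in _create_var()
for x in [1.0, 2.0, 3.0, 4.0]:  # values reported under one summary name
    acc += [1.0, x, x * x]

m = acc / acc[0]                # moments /= moments[0]
mean = m[1]                     # 2.5
std = np.sqrt(m[2] - m[1] ** 2) # sqrt(7.5 - 6.25) ≈ 1.118
```
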
/dnnlib/tflib/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """TensorFlow custom ops builder.
8 | """
9 |
10 | import os
11 | import re
12 | import uuid
13 | import hashlib
14 | import tempfile
15 | import shutil
16 | import tensorflow as tf
17 | from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
18 |
19 | #----------------------------------------------------------------------------
20 | # Global options.
21 |
22 | cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
23 | cuda_cache_version_tag = 'v1'
24 | do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
25 | verbose = True # Print status messages to stdout.
26 |
27 | compiler_bindir_search_path = [
28 |     'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
29 |     'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
30 |     'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
31 | ]
32 |
33 | #----------------------------------------------------------------------------
34 | # Internal helper funcs.
35 |
36 | def _find_compiler_bindir():
37 |     for compiler_path in compiler_bindir_search_path:
38 |         if os.path.isdir(compiler_path):
39 |             return compiler_path
40 |     return None
41 |
42 | def _get_compute_cap(device):
43 |     caps_str = device.physical_device_desc
44 |     m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
45 |     major = m.group(1)
46 |     minor = m.group(2)
47 |     return (major, minor)
48 |
49 | def _get_cuda_gpu_arch_string():
50 |     gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
51 |     if len(gpus) == 0:
52 |         raise RuntimeError('No GPU devices found')
53 |     (major, minor) = _get_compute_cap(gpus[0])
54 |     return 'sm_%s%s' % (major, minor)
55 |
56 | def _run_cmd(cmd):
57 |     with os.popen(cmd) as pipe:
58 |         output = pipe.read()
59 |         status = pipe.close()
60 |     if status is not None:
61 |         raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
62 |
63 | def _prepare_nvcc_cli(opts):
64 |     cmd = 'nvcc ' + opts.strip()
65 |     cmd += ' --disable-warnings'
66 |     include_path = r'C:\Program Files\tensorflow\include' if os.name == 'nt' else tf.sysconfig.get_include()
67 |     cmd += ' --include-path "%s"' % include_path
68 |     cmd += ' --include-path "%s"' % os.path.join(include_path, 'external', 'protobuf_archive', 'src')
69 |     cmd += ' --include-path "%s"' % os.path.join(include_path, 'external', 'com_google_absl')
70 |     cmd += ' --include-path "%s"' % os.path.join(include_path, 'external', 'eigen_archive')
71 |
72 |     compiler_bindir = _find_compiler_bindir()
73 |     if compiler_bindir is None:
74 |         # Require that _find_compiler_bindir succeeds on Windows. Allow
75 |         # nvcc to use whatever is the default on Linux.
76 |         if os.name == 'nt':
77 |             raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
78 |     else:
79 |         cmd += ' --compiler-bindir "%s"' % compiler_bindir
80 |     cmd += ' 2>&1'
81 |     return cmd
82 |
83 | #----------------------------------------------------------------------------
84 | # Main entry point.
85 |
86 | _plugin_cache = dict()
87 |
88 | def get_plugin(cuda_file):
89 | cuda_file_base = os.path.basename(cuda_file)
90 | cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
91 |
92 | # Already in cache?
93 | if cuda_file in _plugin_cache:
94 | return _plugin_cache[cuda_file]
95 |
96 | # Setup plugin.
97 | if verbose:
98 | print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
99 | try:
100 | # Hash CUDA source.
101 | md5 = hashlib.md5()
102 | with open(cuda_file, 'rb') as f:
103 | md5.update(f.read())
104 | md5.update(b'\n')
105 |
106 | # Hash headers included by the CUDA code by running it through the preprocessor.
107 | if not do_not_hash_included_headers:
108 | if verbose:
109 | print('Preprocessing... ', end='', flush=True)
110 | with tempfile.TemporaryDirectory() as tmp_dir:
111 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
112 | _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
113 | with open(tmp_file, 'rb') as f:
114 | bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
115 | good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
116 | for ln in f:
117 | if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
118 | ln = ln.replace(bad_file_str, good_file_str)
119 | md5.update(ln)
120 | md5.update(b'\n')
121 |
122 | # Select compiler options.
123 | compile_opts = ''
124 | if os.name == 'nt':
125 | lib_path = r'C:\Program Files\tensorflow'
126 | compile_opts += '"%s"' % os.path.join(lib_path, 'python', '_pywrap_tensorflow_internal.lib')
127 | elif os.name == 'posix':
128 | lib_path = tf.sysconfig.get_lib()
129 | compile_opts += '"%s"' % os.path.join(lib_path, 'python', '_pywrap_tensorflow_internal.so')
130 | compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
131 | else:
132 | assert False # not Windows or Linux, w00t?
133 | compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
134 | compile_opts += ' --use_fast_math'
135 | nvcc_cmd = _prepare_nvcc_cli(compile_opts)
136 |
137 | # Hash build configuration.
138 | md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
139 | md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
140 | md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
141 |
142 | # Compile if not already compiled.
143 | bin_file_ext = '.dll' if os.name == 'nt' else '.so'
144 | bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
145 | if not os.path.isfile(bin_file):
146 | if verbose:
147 | print('Compiling... ', end='', flush=True)
148 | with tempfile.TemporaryDirectory() as tmp_dir:
149 | tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
150 | _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
151 | os.makedirs(cuda_cache_path, exist_ok=True)
152 | intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
153 | shutil.copyfile(tmp_file, intermediate_file)
154 | os.rename(intermediate_file, bin_file) # atomic
155 |
156 | # Load.
157 | if verbose:
158 | print('Loading... ', end='', flush=True)
159 | plugin = tf.load_op_library(bin_file)
160 |
161 | # Add to cache.
162 | _plugin_cache[cuda_file] = plugin
163 | if verbose:
164 | print('Done.', flush=True)
165 | return plugin
166 |
167 | except:
168 | if verbose:
169 | print('Failed!', flush=True)
170 | raise
171 |
172 | #----------------------------------------------------------------------------
173 |
--------------------------------------------------------------------------------
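Illustrative usage of the plugin builder above -- a minimal sketch, not from the repository; it assumes TensorFlow 1.x, a visible CUDA GPU, and nvcc on PATH. get_plugin() compiles the .cu source on first call, caches the shared library under _cudacache keyed by an MD5 of the source, headers, and build configuration, and serves repeat calls from _plugin_cache:

    from dnnlib.tflib import custom_ops

    # First call preprocesses, compiles, and caches; later calls are instant.
    plugin = custom_ops.get_plugin('dnnlib/tflib/ops/fused_bias_act.cu')
    print(plugin.fused_bias_act)  # op wrapper exported by the loaded library

--------------------------------------------------------------------------------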
/dnnlib/tflib/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | # empty
8 |
--------------------------------------------------------------------------------
/dnnlib/tflib/ops/fused_bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #define EIGEN_USE_GPU
8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
9 | #include "tensorflow/core/framework/op.h"
10 | #include "tensorflow/core/framework/op_kernel.h"
11 | #include "tensorflow/core/framework/shape_inference.h"
12 | #include <stdio.h>
13 |
14 | using namespace tensorflow;
15 | using namespace tensorflow::shape_inference;
16 |
17 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
18 |
19 | //------------------------------------------------------------------------
20 | // CUDA kernel.
21 |
22 | template <class T>
23 | struct FusedBiasActKernelParams
24 | {
25 | const T* x; // [sizeX]
26 | const T* b; // [sizeB] or NULL
27 | const T* ref; // [sizeX] or NULL
28 | T* y; // [sizeX]
29 |
30 | int grad;
31 | int axis;
32 | int act;
33 | float alpha;
34 | float gain;
35 |
36 | int sizeX;
37 | int sizeB;
38 | int stepB;
39 | int loopX;
40 | };
41 |
42 | template <class T>
43 | static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams p)
44 | {
45 | const float expRange = 80.0f;
46 | const float halfExpRange = 40.0f;
47 | const float seluScale = 1.0507009873554804934193349852946f;
48 | const float seluAlpha = 1.6732632423543772848170429916717f;
49 |
50 | // Loop over elements.
51 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
52 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
53 | {
54 | // Load and apply bias.
55 | float x = (float)p.x[xi];
56 | if (p.b)
57 | x += (float)p.b[(xi / p.stepB) % p.sizeB];
58 | float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
59 | if (p.gain != 0.0f & p.act != 9)
60 | ref /= p.gain;
61 |
62 | // Evaluate activation func.
63 | float y;
64 | switch (p.act * 10 + p.grad)
65 | {
66 | // linear
67 | default:
68 | case 10: y = x; break;
69 | case 11: y = x; break;
70 | case 12: y = 0.0f; break;
71 |
72 | // relu
73 | case 20: y = (x > 0.0f) ? x : 0.0f; break;
74 | case 21: y = (ref > 0.0f) ? x : 0.0f; break;
75 | case 22: y = 0.0f; break;
76 |
77 | // lrelu
78 | case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
79 | case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
80 | case 32: y = 0.0f; break;
81 |
82 | // tanh
83 | case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
84 | case 41: y = x * (1.0f - ref * ref); break;
85 | case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;
86 |
87 | // sigmoid
88 | case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
89 | case 51: y = x * ref * (1.0f - ref); break;
90 | case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;
91 |
92 | // elu
93 | case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
94 | case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
95 | case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;
96 |
97 | // selu
98 | case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
99 | case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
100 | case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;
101 |
102 | // softplus
103 | case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
104 | case 81: y = x * (1.0f - expf(-ref)); break;
105 | case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;
106 |
107 | // swish
108 | case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
109 | case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
110 | case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
111 | }
112 |
113 | // Apply gain and store.
114 | p.y[xi] = (T)(y * p.gain);
115 | }
116 | }
117 |
118 | //------------------------------------------------------------------------
119 | // TensorFlow op.
120 |
121 | template <class T>
122 | struct FusedBiasActOp : public OpKernel
123 | {
124 |     FusedBiasActKernelParams<T> m_attribs;
125 |
126 | FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
127 | {
128 | memset(&m_attribs, 0, sizeof(m_attribs));
129 | OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
130 | OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
131 | OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
132 | OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
133 | OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
134 | OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
135 | OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
136 | OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
137 | }
138 |
139 | void Compute(OpKernelContext* ctx)
140 | {
141 |         FusedBiasActKernelParams<T> p = m_attribs;
142 |         cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
143 |
144 | const Tensor& x = ctx->input(0); // [...]
145 | const Tensor& b = ctx->input(1); // [sizeB] or [0]
146 | const Tensor& ref = ctx->input(2); // x.shape or [0]
147 |         p.x = x.flat<T>().data();
148 |         p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
149 |         p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
150 | OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
151 | OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
152 | OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
153 | OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
154 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
155 |
156 | p.sizeX = (int)x.NumElements();
157 | p.sizeB = (int)b.NumElements();
158 | p.stepB = 1;
159 | for (int i = m_attribs.axis + 1; i < x.dims(); i++)
160 | p.stepB *= (int)x.dim_size(i);
161 |
162 | Tensor* y = NULL; // x.shape
163 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
164 |         p.y = y->flat<T>().data();
165 |
166 | p.loopX = 4;
167 | int blockSize = 4 * 32;
168 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
169 | void* args[] = {&p};
170 |         OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
171 | }
172 | };
173 |
174 | REGISTER_OP("FusedBiasAct")
175 | .Input ("x: T")
176 | .Input ("b: T")
177 | .Input ("ref: T")
178 | .Output ("y: T")
179 | .Attr ("T: {float, half}")
180 | .Attr ("grad: int = 0")
181 | .Attr ("axis: int = 1")
182 | .Attr ("act: int = 0")
183 | .Attr ("alpha: float = 0.0")
184 | .Attr ("gain: float = 1.0");
185 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
186 | REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
187 |
188 | //------------------------------------------------------------------------
189 |
--------------------------------------------------------------------------------
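A note on the switch in the kernel above: it dispatches on p.act * 10 + p.grad, where act is the cuda_idx defined in fused_bias_act.py and grad selects the forward pass (0), first-order gradient (1), or second-order gradient (2). A hypothetical sketch of the encoding -- the helper name kernel_case is illustrative, not from the source:

    def kernel_case(cuda_idx, grad):
        return cuda_idx * 10 + grad

    assert kernel_case(3, 0) == 30  # lrelu, forward
    assert kernel_case(3, 1) == 31  # lrelu, first-order gradient
    assert kernel_case(3, 2) == 32  # lrelu, second-order gradient (zero for piecewise-linear funcs)

--------------------------------------------------------------------------------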
/dnnlib/tflib/ops/fused_bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Custom TensorFlow ops for efficient bias and activation."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 | from .. import custom_ops
13 | from ...util import EasyDict
14 |
15 | def _get_plugin():
16 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | activation_funcs = {
21 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
22 | 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
23 | 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
24 | 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
25 | 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
26 | 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
27 | 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
28 | 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
29 | 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
30 | }
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
35 | r"""Fused bias and activation function.
36 |
37 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
38 | and scales the result by `gain`. Each of the steps is optional. In most cases,
39 | the fused op is considerably more efficient than performing the same calculation
40 | using standard TensorFlow ops. It supports first and second order gradients,
41 | but not third order gradients.
42 |
43 | Args:
44 | x: Input activation tensor. Can have any shape, but if `b` is defined, the
45 | dimension corresponding to `axis`, as well as the rank, must be known.
46 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
47 | as `x`. The shape must be known, and it must match the dimension of `x`
48 | corresponding to `axis`.
49 | axis: The dimension in `x` corresponding to the elements of `b`.
50 | The value of `axis` is ignored if `b` is not specified.
51 | act: Name of the activation function to evaluate, or `"linear"` to disable.
52 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
53 | See `activation_funcs` for a full list. `None` is not allowed.
54 | alpha: Shape parameter for the activation function, or `None` to use the default.
55 | gain: Scaling factor for the output tensor, or `None` to use default.
56 | See `activation_funcs` for the default scaling of each activation function.
57 | If unsure, consider specifying `1.0`.
58 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
59 |
60 | Returns:
61 | Tensor of the same shape and datatype as `x`.
62 | """
63 |
64 | impl_dict = {
65 | 'ref': _fused_bias_act_ref,
66 | 'cuda': _fused_bias_act_cuda,
67 | }
68 | return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
69 |
70 | #----------------------------------------------------------------------------
71 |
72 | def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
73 | """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
74 |
75 | # Validate arguments.
76 | x = tf.convert_to_tensor(x)
77 | b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
78 | act_spec = activation_funcs[act]
79 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
80 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
81 | if alpha is None:
82 | alpha = act_spec.def_alpha
83 | if gain is None:
84 | gain = act_spec.def_gain
85 |
86 | # Add bias.
87 | if b.shape[0] != 0:
88 | x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
89 |
90 | # Evaluate activation function.
91 | x = act_spec.func(x, alpha=alpha)
92 |
93 | # Scale by gain.
94 | if gain != 1:
95 | x *= gain
96 | return x
97 |
98 | #----------------------------------------------------------------------------
99 |
100 | def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
101 | """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
102 |
103 | # Validate arguments.
104 | x = tf.convert_to_tensor(x)
105 | empty_tensor = tf.constant([], dtype=x.dtype)
106 | b = tf.convert_to_tensor(b) if b is not None else empty_tensor
107 | act_spec = activation_funcs[act]
108 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
109 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
110 | if alpha is None:
111 | alpha = act_spec.def_alpha
112 | if gain is None:
113 | gain = act_spec.def_gain
114 |
115 | # Special cases.
116 |     if act == 'linear' and b.shape[0] == 0 and gain == 1.0: # b is never None at this point; test for the empty bias tensor instead
117 | return x
118 | if act_spec.cuda_idx is None:
119 | return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
120 |
121 | # CUDA kernel.
122 | cuda_kernel = _get_plugin().fused_bias_act
123 | cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)
124 |
125 | # Forward pass: y = func(x, b).
126 | def func_y(x, b):
127 | y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
128 | y.set_shape(x.shape)
129 | return y
130 |
131 | # Backward pass: dx, db = grad(dy, x, y)
132 | def grad_dx(dy, x, y):
133 | ref = {'x': x, 'y': y}[act_spec.ref]
134 | dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
135 | dx.set_shape(x.shape)
136 | return dx
137 | def grad_db(dx):
138 | if b.shape[0] == 0:
139 | return empty_tensor
140 | db = dx
141 | if axis < x.shape.rank - 1:
142 | db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
143 | if axis > 0:
144 | db = tf.reduce_sum(db, list(range(axis)))
145 | db.set_shape(b.shape)
146 | return db
147 |
148 | # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
149 | def grad2_d_dy(d_dx, d_db, x, y):
150 | ref = {'x': x, 'y': y}[act_spec.ref]
151 | d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
152 | d_dy.set_shape(x.shape)
153 | return d_dy
154 | def grad2_d_x(d_dx, d_db, x, y):
155 | ref = {'x': x, 'y': y}[act_spec.ref]
156 | d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
157 | d_x.set_shape(x.shape)
158 | return d_x
159 |
160 | # Fast version for piecewise-linear activation funcs.
161 | @tf.custom_gradient
162 | def func_zero_2nd_grad(x, b):
163 | y = func_y(x, b)
164 | @tf.custom_gradient
165 | def grad(dy):
166 | dx = grad_dx(dy, x, y)
167 | db = grad_db(dx)
168 | def grad2(d_dx, d_db):
169 | d_dy = grad2_d_dy(d_dx, d_db, x, y)
170 | return d_dy
171 | return (dx, db), grad2
172 | return y, grad
173 |
174 | # Slow version for general activation funcs.
175 | @tf.custom_gradient
176 | def func_nonzero_2nd_grad(x, b):
177 | y = func_y(x, b)
178 | def grad_wrap(dy):
179 | @tf.custom_gradient
180 | def grad_impl(dy, x):
181 | dx = grad_dx(dy, x, y)
182 | db = grad_db(dx)
183 | def grad2(d_dx, d_db):
184 | d_dy = grad2_d_dy(d_dx, d_db, x, y)
185 | d_x = grad2_d_x(d_dx, d_db, x, y)
186 | return d_dy, d_x
187 | return (dx, db), grad2
188 | return grad_impl(dy, x)
189 | return y, grad_wrap
190 |
191 | # Which version to use?
192 | if act_spec.zero_2nd_grad:
193 | return func_zero_2nd_grad(x, b)
194 | return func_nonzero_2nd_grad(x, b)
195 |
196 | #----------------------------------------------------------------------------
197 |
--------------------------------------------------------------------------------
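Illustrative usage -- a minimal sketch, assuming TensorFlow 1.x graph mode; impl='ref' exercises the pure-TF path so no CUDA build is needed:

    import tensorflow as tf
    from dnnlib.tflib.ops.fused_bias_act import fused_bias_act

    x = tf.placeholder(tf.float32, [None, 64, 32, 32])         # [N, C, H, W]
    b = tf.get_variable('bias', shape=[64], initializer=tf.zeros_initializer())
    y = fused_bias_act(x, b, axis=1, act='lrelu', impl='ref')  # alpha=0.2, gain=sqrt(2) by default

--------------------------------------------------------------------------------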
/dnnlib/tflib/ops/upfirdn_2d.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #define EIGEN_USE_GPU
8 | #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
9 | #include "tensorflow/core/framework/op.h"
10 | #include "tensorflow/core/framework/op_kernel.h"
11 | #include "tensorflow/core/framework/shape_inference.h"
12 | #include <stdio.h>
13 |
14 | using namespace tensorflow;
15 | using namespace tensorflow::shape_inference;
16 |
17 | //------------------------------------------------------------------------
18 | // Helpers.
19 |
20 | #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
21 |
22 | static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
23 | {
24 | int c = a / b;
25 | if (c * b > a)
26 | c--;
27 | return c;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 | // CUDA kernel params.
32 |
33 | template <class T>
34 | struct UpFirDn2DKernelParams
35 | {
36 | const T* x; // [majorDim, inH, inW, minorDim]
37 | const T* k; // [kernelH, kernelW]
38 | T* y; // [majorDim, outH, outW, minorDim]
39 |
40 | int upx;
41 | int upy;
42 | int downx;
43 | int downy;
44 | int padx0;
45 | int padx1;
46 | int pady0;
47 | int pady1;
48 |
49 | int majorDim;
50 | int inH;
51 | int inW;
52 | int minorDim;
53 | int kernelH;
54 | int kernelW;
55 | int outH;
56 | int outW;
57 | int loopMajor;
58 | int loopX;
59 | };
60 |
61 | //------------------------------------------------------------------------
62 | // General CUDA implementation for large filter kernels.
63 |
64 | template <class T>
65 | static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
66 | {
67 | // Calculate thread index.
68 | int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
69 | int outY = minorIdx / p.minorDim;
70 | minorIdx -= outY * p.minorDim;
71 | int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
72 | int majorIdxBase = blockIdx.z * p.loopMajor;
73 | if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
74 | return;
75 |
76 | // Setup Y receptive field.
77 | int midY = outY * p.downy + p.upy - 1 - p.pady0;
78 | int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
79 | int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
80 | int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
81 |
82 | // Loop over majorDim and outX.
83 | for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
84 | for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
85 | {
86 | // Setup X receptive field.
87 | int midX = outX * p.downx + p.upx - 1 - p.padx0;
88 | int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
89 | int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
90 | int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
91 |
92 | // Initialize pointers.
93 | const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
94 | const T* kp = &p.k[kernelY * p.kernelW + kernelX];
95 | int xpx = p.minorDim;
96 | int kpx = -p.upx;
97 | int xpy = p.inW * p.minorDim;
98 | int kpy = -p.upy * p.kernelW;
99 |
100 | // Inner loop.
101 | float v = 0.0f;
102 | for (int y = 0; y < h; y++)
103 | {
104 | for (int x = 0; x < w; x++)
105 | {
106 | v += (float)(*xp) * (float)(*kp);
107 | xp += xpx;
108 | kp += kpx;
109 | }
110 | xp += xpy - w * xpx;
111 | kp += kpy - w * kpx;
112 | }
113 |
114 | // Store result.
115 | p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
116 | }
117 | }
118 |
119 | //------------------------------------------------------------------------
120 | // Specialized CUDA implementation for small filter kernels.
121 |
122 | template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
123 | static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
124 | {
125 | //assert(kernelW % upx == 0);
126 | //assert(kernelH % upy == 0);
127 | const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
128 | const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
129 | __shared__ volatile float sk[kernelH][kernelW];
130 | __shared__ volatile float sx[tileInH][tileInW];
131 |
132 | // Calculate tile index.
133 | int minorIdx = blockIdx.x;
134 | int tileOutY = minorIdx / p.minorDim;
135 | minorIdx -= tileOutY * p.minorDim;
136 | tileOutY *= tileOutH;
137 | int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
138 | int majorIdxBase = blockIdx.z * p.loopMajor;
139 | if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
140 | return;
141 |
142 | // Load filter kernel (flipped).
143 | for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
144 | {
145 | int ky = tapIdx / kernelW;
146 | int kx = tapIdx - ky * kernelW;
147 | float v = 0.0f;
148 | if (kx < p.kernelW & ky < p.kernelH)
149 | v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
150 | sk[ky][kx] = v;
151 | }
152 |
153 | // Loop over majorDim and outX.
154 | for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
155 | for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
156 | {
157 | // Load input pixels.
158 | int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
159 | int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
160 | int tileInX = floorDiv(tileMidX, upx);
161 | int tileInY = floorDiv(tileMidY, upy);
162 | __syncthreads();
163 | for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
164 | {
165 | int relInY = inIdx / tileInW;
166 | int relInX = inIdx - relInY * tileInW;
167 | int inX = relInX + tileInX;
168 | int inY = relInY + tileInY;
169 | float v = 0.0f;
170 | if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
171 | v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
172 | sx[relInY][relInX] = v;
173 | }
174 |
175 | // Loop over output pixels.
176 | __syncthreads();
177 | for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
178 | {
179 | int relOutY = outIdx / tileOutW;
180 | int relOutX = outIdx - relOutY * tileOutW;
181 | int outX = relOutX + tileOutX;
182 | int outY = relOutY + tileOutY;
183 |
184 | // Setup receptive field.
185 | int midX = tileMidX + relOutX * downx;
186 | int midY = tileMidY + relOutY * downy;
187 | int inX = floorDiv(midX, upx);
188 | int inY = floorDiv(midY, upy);
189 | int relInX = inX - tileInX;
190 | int relInY = inY - tileInY;
191 | int kernelX = (inX + 1) * upx - midX - 1; // flipped
192 | int kernelY = (inY + 1) * upy - midY - 1; // flipped
193 |
194 | // Inner loop.
195 | float v = 0.0f;
196 | #pragma unroll
197 | for (int y = 0; y < kernelH / upy; y++)
198 | #pragma unroll
199 | for (int x = 0; x < kernelW / upx; x++)
200 | v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
201 |
202 | // Store result.
203 | if (outX < p.outW & outY < p.outH)
204 | p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
205 | }
206 | }
207 | }
208 |
209 | //------------------------------------------------------------------------
210 | // TensorFlow op.
211 |
212 | template <class T>
213 | struct UpFirDn2DOp : public OpKernel
214 | {
215 |     UpFirDn2DKernelParams<T> m_attribs;
216 |
217 | UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
218 | {
219 | memset(&m_attribs, 0, sizeof(m_attribs));
220 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
221 | OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
222 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
223 | OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
224 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
225 | OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
226 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
227 | OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
228 | OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
229 | OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
230 | }
231 |
232 | void Compute(OpKernelContext* ctx)
233 | {
234 |         UpFirDn2DKernelParams<T> p = m_attribs;
235 |         cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
236 |
237 | const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
238 | const Tensor& k = ctx->input(1); // [kernelH, kernelW]
239 |         p.x = x.flat<T>().data();
240 |         p.k = k.flat<T>().data();
241 | OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
242 | OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
243 | OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
244 | OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
245 |
246 | p.majorDim = (int)x.dim_size(0);
247 | p.inH = (int)x.dim_size(1);
248 | p.inW = (int)x.dim_size(2);
249 | p.minorDim = (int)x.dim_size(3);
250 | p.kernelH = (int)k.dim_size(0);
251 | p.kernelW = (int)k.dim_size(1);
252 | OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
253 |
254 | p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
255 | p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
256 | OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
257 |
258 | Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
259 | TensorShape ys;
260 | ys.AddDim(p.majorDim);
261 | ys.AddDim(p.outH);
262 | ys.AddDim(p.outW);
263 | ys.AddDim(p.minorDim);
264 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
265 |         p.y = y->flat<T>().data();
266 | OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
267 |
268 | // Choose CUDA kernel to use.
269 |         void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
270 | int tileOutW = -1;
271 | int tileOutH = -1;
272 |         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
273 |         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
274 |         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
275 |         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
276 |         if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
277 |         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
278 |         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
279 |         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
280 |         if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
281 |         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; }
282 |         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; }
283 |         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; }
284 |         if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; }
285 |
286 | // Choose launch params.
287 | dim3 blockSize;
288 | dim3 gridSize;
289 | if (tileOutW > 0 && tileOutH > 0) // small
290 | {
291 | p.loopMajor = (p.majorDim - 1) / 16384 + 1;
292 | p.loopX = 1;
293 | blockSize = dim3(32 * 8, 1, 1);
294 | gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
295 | }
296 | else // large
297 | {
298 | p.loopMajor = (p.majorDim - 1) / 16384 + 1;
299 | p.loopX = 4;
300 | blockSize = dim3(4, 32, 1);
301 | gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
302 | }
303 |
304 | // Launch CUDA kernel.
305 | void* args[] = {&p};
306 | OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
307 | }
308 | };
309 |
310 | REGISTER_OP("UpFirDn2D")
311 | .Input ("x: T")
312 | .Input ("k: T")
313 | .Output ("y: T")
314 | .Attr ("T: {float, half}")
315 | .Attr ("upx: int = 1")
316 | .Attr ("upy: int = 1")
317 | .Attr ("downx: int = 1")
318 | .Attr ("downy: int = 1")
319 | .Attr ("padx0: int = 0")
320 | .Attr ("padx1: int = 0")
321 | .Attr ("pady0: int = 0")
322 | .Attr ("pady1: int = 0");
323 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
324 | REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
325 |
326 | //------------------------------------------------------------------------
327 |
--------------------------------------------------------------------------------
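A quick check of the output-size arithmetic used by UpFirDn2DOp above: the C++ expression (inW*upx + padx0 + padx1 - kernelW + downx) / downx agrees with the Python-side (inW*upx + padx0 + padx1 - kernelW) // downx + 1 whenever the shared numerator is non-negative. A hypothetical sketch -- the helper out_size is illustrative, not from the source:

    def out_size(in_size, up, down, pad0, pad1, kernel):
        return (in_size * up + pad0 + pad1 - kernel) // down + 1

    # Pads match upsample_2d()/downsample_2d() with k = [1, 3, 3, 1], factor = 2.
    assert out_size(32, up=2, down=1, pad0=2, pad1=1, kernel=4) == 64  # 2x upsample
    assert out_size(64, up=1, down=2, pad0=1, pad1=1, kernel=4) == 32  # 2x downsample

--------------------------------------------------------------------------------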
/dnnlib/tflib/ops/upfirdn_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Custom TensorFlow ops for efficient resampling of 2D images."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 | from .. import custom_ops
13 |
14 | def _get_plugin():
15 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
20 | r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.
21 |
22 | Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
23 | and performs the following operations for each image, batched across
24 | `majorDim` and `minorDim`:
25 |
26 | 1. Pad the image with zeros by the specified number of pixels on each side
27 | (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
28 | corresponds to cropping the image.
29 |
30 |     2. Upsample the image by inserting zeros after each pixel (`upx`, `upy`).
31 |
32 | 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
33 | image so that the footprint of all output pixels lies within the input image.
34 |
35 | 4. Downsample the image by throwing away pixels (`downx`, `downy`).
36 |
37 | This sequence of operations bears close resemblance to scipy.signal.upfirdn().
38 | The fused op is considerably more efficient than performing the same calculation
39 | using standard TensorFlow ops. It supports gradients of arbitrary order.
40 |
41 | Args:
42 | x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
43 | k: 2D FIR filter of the shape `[firH, firW]`.
44 | upx: Integer upsampling factor along the X-axis (default: 1).
45 | upy: Integer upsampling factor along the Y-axis (default: 1).
46 | downx: Integer downsampling factor along the X-axis (default: 1).
47 | downy: Integer downsampling factor along the Y-axis (default: 1).
48 | padx0: Number of pixels to pad on the left side (default: 0).
49 | padx1: Number of pixels to pad on the right side (default: 0).
50 | pady0: Number of pixels to pad on the top side (default: 0).
51 | pady1: Number of pixels to pad on the bottom side (default: 0).
52 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
53 |
54 | Returns:
55 | Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
56 | """
57 |
58 | impl_dict = {
59 | 'ref': _upfirdn_2d_ref,
60 | 'cuda': _upfirdn_2d_cuda,
61 | }
62 | return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
63 |
64 | #----------------------------------------------------------------------------
65 |
66 | def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
67 | """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""
68 |
69 | x = tf.convert_to_tensor(x)
70 | k = np.asarray(k, dtype=np.float32)
71 | assert x.shape.rank == 4
72 | inH = x.shape[1].value
73 | inW = x.shape[2].value
74 | minorDim = _shape(x, 3)
75 | kernelH, kernelW = k.shape
76 | assert inW >= 1 and inH >= 1
77 | assert kernelW >= 1 and kernelH >= 1
78 | assert isinstance(upx, int) and isinstance(upy, int)
79 | assert isinstance(downx, int) and isinstance(downy, int)
80 | assert isinstance(padx0, int) and isinstance(padx1, int)
81 | assert isinstance(pady0, int) and isinstance(pady1, int)
82 |
83 | # Upsample (insert zeros).
84 | x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
85 | x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
86 | x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])
87 |
88 | # Pad (crop if negative).
89 | x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
90 | x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]
91 |
92 | # Convolve with filter.
93 | x = tf.transpose(x, [0, 3, 1, 2])
94 | x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
95 | w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
96 | x = tf.transpose(x, [0, 2, 3, 1])
97 | x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NHWC')
98 | x = tf.transpose(x, [0, 3, 1, 2])
99 | x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
100 | x = tf.transpose(x, [0, 2, 3, 1])
101 |
102 | # Downsample (throw away pixels).
103 | return x[:, ::downy, ::downx, :]
104 |
105 | #----------------------------------------------------------------------------
106 |
107 | def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
108 | """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""
109 |
110 | x = tf.convert_to_tensor(x)
111 | k = np.asarray(k, dtype=np.float32)
112 | majorDim, inH, inW, minorDim = x.shape.as_list()
113 | kernelH, kernelW = k.shape
114 | assert inW >= 1 and inH >= 1
115 | assert kernelW >= 1 and kernelH >= 1
116 | assert isinstance(upx, int) and isinstance(upy, int)
117 | assert isinstance(downx, int) and isinstance(downy, int)
118 | assert isinstance(padx0, int) and isinstance(padx1, int)
119 | assert isinstance(pady0, int) and isinstance(pady1, int)
120 |
121 | outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
122 | outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
123 | assert outW >= 1 and outH >= 1
124 |
125 | kc = tf.constant(k, dtype=x.dtype)
126 | gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
127 | gpadx0 = kernelW - padx0 - 1
128 | gpady0 = kernelH - pady0 - 1
129 | gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
130 | gpady1 = inH * upy - outH * downy + pady0 - upy + 1
131 |
132 | @tf.custom_gradient
133 | def func(x):
134 | y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
135 | y.set_shape([majorDim, outH, outW, minorDim])
136 | @tf.custom_gradient
137 | def grad(dy):
138 | dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
139 | dx.set_shape([majorDim, inH, inW, minorDim])
140 | return dx, func
141 | return y, grad
142 | return func(x)
143 |
144 | #----------------------------------------------------------------------------
145 |
146 | def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
147 | r"""Filter a batch of 2D images with the given FIR filter.
148 |
149 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
150 | and filters each image with the given filter. The filter is normalized so that
151 | if the input pixels are constant, they will be scaled by the specified `gain`.
152 | Pixels outside the image are assumed to be zero.
153 |
154 | Args:
155 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
156 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
157 | gain: Scaling factor for signal magnitude (default: 1.0).
158 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
159 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
160 |
161 | Returns:
162 | Tensor of the same shape and datatype as `x`.
163 | """
164 |
165 | k = _setup_kernel(k) * gain
166 | p = k.shape[0] - 1
167 | return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
168 |
169 | #----------------------------------------------------------------------------
170 |
171 | def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
172 | r"""Upsample a batch of 2D images with the given filter.
173 |
174 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
175 | and upsamples each image with the given filter. The filter is normalized so that
176 | if the input pixels are constant, they will be scaled by the specified `gain`.
177 | Pixels outside the image are assumed to be zero, and the filter is padded with
178 | zeros so that its shape is a multiple of the upsampling factor.
179 |
180 | Args:
181 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
182 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
183 | The default is `[1] * factor`, which corresponds to nearest-neighbor
184 | upsampling.
185 | factor: Integer upsampling factor (default: 2).
186 | gain: Scaling factor for signal magnitude (default: 1.0).
187 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
188 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
189 |
190 | Returns:
191 | Tensor of the shape `[N, C, H * factor, W * factor]` or
192 | `[N, H * factor, W * factor, C]`, and same datatype as `x`.
193 | """
194 |
195 | assert isinstance(factor, int) and factor >= 1
196 | if k is None:
197 | k = [1] * factor
198 | k = _setup_kernel(k) * (gain * (factor ** 2))
199 | p = k.shape[0] - factor
200 | return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl)
201 |
202 | #----------------------------------------------------------------------------
203 |
204 | def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
205 | r"""Downsample a batch of 2D images with the given filter.
206 |
207 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
208 | and downsamples each image with the given filter. The filter is normalized so that
209 | if the input pixels are constant, they will be scaled by the specified `gain`.
210 | Pixels outside the image are assumed to be zero, and the filter is padded with
211 | zeros so that its shape is a multiple of the downsampling factor.
212 |
213 | Args:
214 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
215 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
216 | The default is `[1] * factor`, which corresponds to average pooling.
217 | factor: Integer downsampling factor (default: 2).
218 | gain: Scaling factor for signal magnitude (default: 1.0).
219 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
220 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
221 |
222 | Returns:
223 | Tensor of the shape `[N, C, H // factor, W // factor]` or
224 | `[N, H // factor, W // factor, C]`, and same datatype as `x`.
225 | """
226 |
227 | assert isinstance(factor, int) and factor >= 1
228 | if k is None:
229 | k = [1] * factor
230 | k = _setup_kernel(k) * gain
231 | p = k.shape[0] - factor
232 | return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
233 |
234 | #----------------------------------------------------------------------------
235 |
236 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
237 | r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
238 |
239 | Padding is performed only once at the beginning, not between the operations.
240 | The fused op is considerably more efficient than performing the same calculation
241 | using standard TensorFlow ops. It supports gradients of arbitrary order.
242 |
243 | Args:
244 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
245 | w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
246 |                       Grouped convolution can be performed by `inChannels = x.shape[1] // numGroups` when `data_format` is `'NCHW'`.
247 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
248 | The default is `[1] * factor`, which corresponds to nearest-neighbor
249 | upsampling.
250 | factor: Integer upsampling factor (default: 2).
251 | gain: Scaling factor for signal magnitude (default: 1.0).
252 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
253 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
254 |
255 | Returns:
256 | Tensor of the shape `[N, C, H * factor, W * factor]` or
257 | `[N, H * factor, W * factor, C]`, and same datatype as `x`.
258 | """
259 |
260 | assert isinstance(factor, int) and factor >= 1
261 |
262 | # Check weight shape.
263 | w = tf.convert_to_tensor(w)
264 | assert w.shape.rank == 4
265 | convH = w.shape[0].value
266 | convW = w.shape[1].value
267 | inC = _shape(w, 2)
268 | outC = _shape(w, 3)
269 | assert convW == convH
270 |
271 | # Setup filter kernel.
272 | if k is None:
273 | k = [1] * factor
274 | k = _setup_kernel(k) * (gain * (factor ** 2))
275 | p = (k.shape[0] - factor) - (convW - 1)
276 |
277 | # Determine data dimensions.
278 | if data_format == 'NCHW':
279 | stride = [1, 1, factor, factor]
280 | output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW]
281 | num_groups = _shape(x, 1) // inC
282 | else:
283 | stride = [1, factor, factor, 1]
284 | output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
285 | num_groups = _shape(x, 3) // inC
286 |
287 | # Transpose weights.
288 | w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
289 | w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
290 | w = tf.reshape(w, [convH, convW, -1, num_groups * inC])
291 |
292 | # Execute.
293 | x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
294 | return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl)
295 |
296 | #----------------------------------------------------------------------------
297 |
298 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
299 | r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
300 |
301 | Padding is performed only once at the beginning, not between the operations.
302 | The fused op is considerably more efficient than performing the same calculation
303 | using standard TensorFlow ops. It supports gradients of arbitrary order.
304 |
305 | Args:
306 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
307 | w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
308 |                       Grouped convolution can be performed by `inChannels = x.shape[1] // numGroups` when `data_format` is `'NCHW'`.
309 | k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
310 | The default is `[1] * factor`, which corresponds to average pooling.
311 | factor: Integer downsampling factor (default: 2).
312 | gain: Scaling factor for signal magnitude (default: 1.0).
313 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
314 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
315 |
316 | Returns:
317 | Tensor of the shape `[N, C, H // factor, W // factor]` or
318 | `[N, H // factor, W // factor, C]`, and same datatype as `x`.
319 | """
320 |
321 | assert isinstance(factor, int) and factor >= 1
322 | w = tf.convert_to_tensor(w)
323 | convH, convW, _inC, _outC = w.shape.as_list()
324 | assert convW == convH
325 | if k is None:
326 | k = [1] * factor
327 | k = _setup_kernel(k) * gain
328 | p = (k.shape[0] - factor) + (convW - 1)
329 | if data_format == 'NCHW':
330 | s = [1, 1, factor, factor]
331 | else:
332 | s = [1, factor, factor, 1]
333 | x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
334 | return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)
335 |
336 | #----------------------------------------------------------------------------
337 | # Internal helper funcs.
338 |
339 | def _shape(tf_expr, dim_idx):
340 | if tf_expr.shape.rank is not None:
341 | dim = tf_expr.shape[dim_idx].value
342 | if dim is not None:
343 | return dim
344 | return tf.shape(tf_expr)[dim_idx]
345 |
346 | def _setup_kernel(k):
347 | k = np.asarray(k, dtype=np.float32)
348 | if k.ndim == 1:
349 | k = np.outer(k, k)
350 | k /= np.sum(k)
351 | assert k.ndim == 2
352 | assert k.shape[0] == k.shape[1]
353 | return k
354 |
355 | def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
356 | assert data_format in ['NCHW', 'NHWC']
357 | assert x.shape.rank == 4
358 | y = x
359 | if data_format == 'NCHW':
360 | y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
361 | y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
362 | if data_format == 'NCHW':
363 | y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
364 | return y
365 |
366 | #----------------------------------------------------------------------------
367 |
--------------------------------------------------------------------------------
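Illustrative usage -- a minimal sketch, assuming TensorFlow 1.x graph mode; impl='ref' avoids the CUDA build. A 2x upsample followed by a 2x downsample with the [1, 3, 3, 1] binomial filter restores the original spatial shape:

    import tensorflow as tf
    from dnnlib.tflib.ops.upfirdn_2d import upsample_2d, downsample_2d

    x  = tf.placeholder(tf.float32, [None, 3, 128, 128])          # [N, C, H, W]
    hi = upsample_2d(x, k=[1, 3, 3, 1], factor=2, impl='ref')     # -> [N, 3, 256, 256]
    lo = downsample_2d(hi, k=[1, 3, 3, 1], factor=2, impl='ref')  # -> [N, 3, 128, 128]

--------------------------------------------------------------------------------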
/dnnlib/tflib/tfutil.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Miscellaneous helper utils for Tensorflow."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | # Silence deprecation warnings from TensorFlow 1.13 onwards
14 | import logging
15 | logging.getLogger('tensorflow').setLevel(logging.ERROR)
16 | import tensorflow.contrib # requires TensorFlow 1.x!
17 | tf.contrib = tensorflow.contrib
18 |
19 | from typing import Any, Iterable, List, Union
20 |
21 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
22 | """A type that represents a valid Tensorflow expression."""
23 |
24 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
25 | """A type that can be converted to a valid Tensorflow expression."""
26 |
27 |
28 | def run(*args, **kwargs) -> Any:
29 | """Run the specified ops in the default session."""
30 | assert_tf_initialized()
31 | return tf.get_default_session().run(*args, **kwargs)
32 |
33 |
34 | def is_tf_expression(x: Any) -> bool:
35 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
36 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
37 |
38 |
39 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
40 | """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
41 | return [dim.value for dim in shape]
42 |
43 |
44 | def flatten(x: TfExpressionEx) -> TfExpression:
45 | """Shortcut function for flattening a tensor."""
46 | with tf.name_scope("Flatten"):
47 | return tf.reshape(x, [-1])
48 |
49 |
50 | def log2(x: TfExpressionEx) -> TfExpression:
51 | """Logarithm in base 2."""
52 | with tf.name_scope("Log2"):
53 | return tf.log(x) * np.float32(1.0 / np.log(2.0))
54 |
55 |
56 | def exp2(x: TfExpressionEx) -> TfExpression:
57 | """Exponent in base 2."""
58 | with tf.name_scope("Exp2"):
59 | return tf.exp(x * np.float32(np.log(2.0)))
60 |
61 |
62 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
63 | """Linear interpolation."""
64 | with tf.name_scope("Lerp"):
65 | return a + (b - a) * t
66 |
67 |
68 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
69 | """Linear interpolation with clip."""
70 | with tf.name_scope("LerpClip"):
71 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
72 |
73 |
74 | def absolute_name_scope(scope: str) -> tf.name_scope:
75 | """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
76 | return tf.name_scope(scope + "/")
77 |
78 |
79 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
80 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
81 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
82 |
83 |
84 | def _sanitize_tf_config(config_dict: dict = None) -> dict:
85 | # Defaults.
86 | cfg = dict()
87 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
88 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
89 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
90 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
91 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
92 |
93 | # Remove defaults for environment variables that are already set.
94 | for key in list(cfg):
95 | fields = key.split(".")
96 | if fields[0] == "env":
97 | assert len(fields) == 2
98 | if fields[1] in os.environ:
99 | del cfg[key]
100 |
101 | # User overrides.
102 | if config_dict is not None:
103 | cfg.update(config_dict)
104 | return cfg
105 |
106 |
107 | def init_tf(config_dict: dict = None) -> None:
108 | """Initialize TensorFlow session using good default settings."""
109 | # Skip if already initialized.
110 | if tf.get_default_session() is not None:
111 | return
112 |
113 | # Setup config dict and random seeds.
114 | cfg = _sanitize_tf_config(config_dict)
115 | np_random_seed = cfg["rnd.np_random_seed"]
116 | if np_random_seed is not None:
117 | np.random.seed(np_random_seed)
118 | tf_random_seed = cfg["rnd.tf_random_seed"]
119 | if tf_random_seed == "auto":
120 | tf_random_seed = np.random.randint(1 << 31)
121 | if tf_random_seed is not None:
122 | tf.set_random_seed(tf_random_seed)
123 |
124 | # Setup environment variables.
125 | for key, value in cfg.items():
126 | fields = key.split(".")
127 | if fields[0] == "env":
128 | assert len(fields) == 2
129 | os.environ[fields[1]] = str(value)
130 |
131 | # Create default TensorFlow session.
132 | create_session(cfg, force_as_default=True)
133 |
134 |
135 | def assert_tf_initialized():
136 | """Check that TensorFlow session has been initialized."""
137 | if tf.get_default_session() is None:
138 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
139 |
140 |
141 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
142 | """Create tf.Session based on config dict."""
143 | # Setup TensorFlow config proto.
144 | cfg = _sanitize_tf_config(config_dict)
145 | config_proto = tf.ConfigProto()
146 | for key, value in cfg.items():
147 | fields = key.split(".")
148 | if fields[0] not in ["rnd", "env"]:
149 | obj = config_proto
150 | for field in fields[:-1]:
151 | obj = getattr(obj, field)
152 | setattr(obj, fields[-1], value)
153 |
154 | # Create session.
155 | session = tf.Session(config=config_proto)
156 | if force_as_default:
157 | # pylint: disable=protected-access
158 | session._default_session = session.as_default()
159 | session._default_session.enforce_nesting = False
160 | session._default_session.__enter__()
161 | return session
162 |
163 |
164 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
165 | """Initialize all tf.Variables that have not already been initialized.
166 |
167 | Equivalent to the following, but more efficient and does not bloat the tf graph:
168 | tf.variables_initializer(tf.report_uninitialized_variables()).run()
169 | """
170 | assert_tf_initialized()
171 | if target_vars is None:
172 | target_vars = tf.global_variables()
173 |
174 | test_vars = []
175 | test_ops = []
176 |
177 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
178 | for var in target_vars:
179 | assert is_tf_expression(var)
180 |
181 | try:
182 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
183 | except KeyError:
184 | # Op does not exist => variable may be uninitialized.
185 | test_vars.append(var)
186 |
187 | with absolute_name_scope(var.name.split(":")[0]):
188 | test_ops.append(tf.is_variable_initialized(var))
189 |
190 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
191 | run([var.initializer for var in init_vars])
192 |
193 |
194 | def set_vars(var_to_value_dict: dict) -> None:
195 | """Set the values of given tf.Variables.
196 |
197 | Equivalent to the following, but more efficient and does not bloat the tf graph:
198 |         tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
199 | """
200 | assert_tf_initialized()
201 | ops = []
202 | feed_dict = {}
203 |
204 | for var, value in var_to_value_dict.items():
205 | assert is_tf_expression(var)
206 |
207 | try:
208 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
209 | except KeyError:
210 | with absolute_name_scope(var.name.split(":")[0]):
211 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
212 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
213 |
214 | ops.append(setter)
215 | feed_dict[setter.op.inputs[1]] = value
216 |
217 | run(ops, feed_dict)
218 |
219 |
220 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
221 | """Create tf.Variable with large initial value without bloating the tf graph."""
222 | assert_tf_initialized()
223 | assert isinstance(initial_value, np.ndarray)
224 | zeros = tf.zeros(initial_value.shape, initial_value.dtype)
225 | var = tf.Variable(zeros, *args, **kwargs)
226 | set_vars({var: initial_value})
227 | return var
228 |
229 |
230 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
231 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
232 | Can be used as an input transformation for Network.run().
233 | """
234 | images = tf.cast(images, tf.float32)
235 | if nhwc_to_nchw:
236 | images = tf.transpose(images, [0, 3, 1, 2])
237 | return images * ((drange[1] - drange[0]) / 255) + drange[0]
238 |
239 |
240 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
241 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
242 | Can be used as an output transformation for Network.run().
243 | """
244 | images = tf.cast(images, tf.float32)
245 | if shrink > 1:
246 | ksize = [1, 1, shrink, shrink]
247 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
248 | if nchw_to_nhwc:
249 | images = tf.transpose(images, [0, 2, 3, 1])
250 | scale = 255 / (drange[1] - drange[0])
251 | images = images * scale + (0.5 - drange[0] * scale)
252 | return tf.saturate_cast(images, tf.uint8)
253 |
--------------------------------------------------------------------------------
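A minimal usage sketch of the session and variable helpers above, assuming a TensorFlow 1.x environment in which dnnlib.tflib re-exports tfutil (as the rest of this repo does); run() is the session-run shortcut defined earlier in the same file:

import numpy as np
import dnnlib.tflib as tflib

tflib.init_tf({'rnd.np_random_seed': 123})                    # default session + seeded RNGs
var = tflib.create_var_with_large_initial_value(np.ones([4, 4], np.float32))
tflib.set_vars({var: np.zeros([4, 4], np.float32)})           # reuses the cached 'setter' op
print(tflib.run(var).sum())                                   # -> 0.0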
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Miscellaneous utility classes and functions."""
8 |
9 | import ctypes
10 | import fnmatch
11 | import importlib
12 | import inspect
13 | import numpy as np
14 | import os
15 | import shutil
16 | import sys
17 | import types
18 | import io
19 | import pickle
20 | import re
21 | import requests
22 | import html
23 | import hashlib
24 | import glob
25 | import uuid
26 |
27 | from distutils.util import strtobool
28 | from typing import Any, List, Tuple, Union
29 |
30 |
31 | # Util classes
32 | # ------------------------------------------------------------------------------------------
33 |
34 |
35 | class EasyDict(dict):
36 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
37 |
38 | def __getattr__(self, name: str) -> Any:
39 | try:
40 | return self[name]
41 | except KeyError:
42 | raise AttributeError(name)
43 |
44 | def __setattr__(self, name: str, value: Any) -> None:
45 | self[name] = value
46 |
47 | def __delattr__(self, name: str) -> None:
48 | del self[name]
49 |
50 |
51 | class Logger(object):
52 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
53 |
54 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
55 | self.file = None
56 |
57 | if file_name is not None:
58 | self.file = open(file_name, file_mode)
59 |
60 | self.should_flush = should_flush
61 | self.stdout = sys.stdout
62 | self.stderr = sys.stderr
63 |
64 | sys.stdout = self
65 | sys.stderr = self
66 |
67 | def __enter__(self) -> "Logger":
68 | return self
69 |
70 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
71 | self.close()
72 |
73 | def write(self, text: str) -> None:
74 | """Write text to stdout (and a file) and optionally flush."""
75 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
76 | return
77 |
78 | if self.file is not None:
79 | self.file.write(text)
80 |
81 | self.stdout.write(text)
82 |
83 | if self.should_flush:
84 | self.flush()
85 |
86 | def flush(self) -> None:
87 | """Flush written text to both stdout and a file, if open."""
88 | if self.file is not None:
89 | self.file.flush()
90 |
91 | self.stdout.flush()
92 |
93 | def close(self) -> None:
94 | """Flush, close possible files, and remove stdout/stderr mirroring."""
95 | self.flush()
96 |
97 | # if using multiple loggers, prevent closing in wrong order
98 | if sys.stdout is self:
99 | sys.stdout = self.stdout
100 | if sys.stderr is self:
101 | sys.stderr = self.stderr
102 |
103 | if self.file is not None:
104 | self.file.close()
105 |
106 |
107 | # Small util functions
108 | # ------------------------------------------------------------------------------------------
109 |
110 |
111 | def format_time(seconds: Union[int, float]) -> str:
112 |     """Convert seconds to a human-readable string with days, hours, minutes, and seconds."""
113 | s = int(np.rint(seconds))
114 |
115 | if s < 60:
116 | return "{0}s".format(s)
117 | elif s < 60 * 60:
118 | return "{0}m {1:02}s".format(s // 60, s % 60)
119 | elif s < 24 * 60 * 60:
120 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
121 | else:
122 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
123 |
124 |
125 | def ask_yes_no(question: str) -> bool:
126 | """Ask the user the question until the user inputs a valid answer."""
127 | while True:
128 | try:
129 | print("{0} [y/n]".format(question))
130 | return strtobool(input().lower())
131 | except ValueError:
132 | pass
133 |
134 |
135 | def tuple_product(t: Tuple) -> Any:
136 | """Calculate the product of the tuple elements."""
137 | result = 1
138 |
139 | for v in t:
140 | result *= v
141 |
142 | return result
143 |
144 |
145 | _str_to_ctype = {
146 | "uint8": ctypes.c_ubyte,
147 | "uint16": ctypes.c_uint16,
148 | "uint32": ctypes.c_uint32,
149 | "uint64": ctypes.c_uint64,
150 | "int8": ctypes.c_byte,
151 | "int16": ctypes.c_int16,
152 | "int32": ctypes.c_int32,
153 | "int64": ctypes.c_int64,
154 | "float32": ctypes.c_float,
155 | "float64": ctypes.c_double
156 | }
157 |
158 |
159 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
160 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
161 | type_str = None
162 |
163 | if isinstance(type_obj, str):
164 | type_str = type_obj
165 | elif hasattr(type_obj, "__name__"):
166 | type_str = type_obj.__name__
167 | elif hasattr(type_obj, "name"):
168 | type_str = type_obj.name
169 | else:
170 | raise RuntimeError("Cannot infer type name from input")
171 |
172 | assert type_str in _str_to_ctype.keys()
173 |
174 | my_dtype = np.dtype(type_str)
175 | my_ctype = _str_to_ctype[type_str]
176 |
177 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
178 |
179 | return my_dtype, my_ctype
180 |
181 |
182 | def is_pickleable(obj: Any) -> bool:  # True if obj can be serialized with pickle
183 | try:
184 | with io.BytesIO() as stream:
185 | pickle.dump(obj, stream)
186 | return True
187 | except:
188 | return False
189 |
190 |
191 | # Functionality to import modules/objects by name, and call functions by name
192 | # ------------------------------------------------------------------------------------------
193 |
194 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
195 |     """Searches for the underlying module behind the name of some Python object.
196 | Returns the module and the object name (original name with module part removed)."""
197 |
198 | # allow convenience shorthands, substitute them by full names
199 |     obj_name = re.sub(r"^np\.", "numpy.", obj_name)
200 |     obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)
201 |
202 | # list alternatives for (module_name, local_obj_name)
203 | parts = obj_name.split(".")
204 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
205 |
206 | # try each alternative in turn
207 | for module_name, local_obj_name in name_pairs:
208 | try:
209 | module = importlib.import_module(module_name) # may raise ImportError
210 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
211 | return module, local_obj_name
212 | except:
213 | pass
214 |
215 | # maybe some of the modules themselves contain errors?
216 | for module_name, _local_obj_name in name_pairs:
217 | try:
218 | importlib.import_module(module_name) # may raise ImportError
219 | except ImportError:
220 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
221 | raise
222 |
223 | # maybe the requested attribute is missing?
224 | for module_name, local_obj_name in name_pairs:
225 | try:
226 | module = importlib.import_module(module_name) # may raise ImportError
227 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
228 | except ImportError:
229 | pass
230 |
231 | # we are out of luck, but we have no idea why
232 | raise ImportError(obj_name)
233 |
234 |
235 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
236 | """Traverses the object name and returns the last (rightmost) python object."""
237 | if obj_name == '':
238 | return module
239 | obj = module
240 | for part in obj_name.split("."):
241 | obj = getattr(obj, part)
242 | return obj
243 |
244 |
245 | def get_obj_by_name(name: str) -> Any:
246 | """Finds the python object with the given name."""
247 | module, obj_name = get_module_from_obj_name(name)
248 | return get_obj_from_module(module, obj_name)
249 |
250 |
251 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
252 | """Finds the python object with the given name and calls it as a function."""
253 | assert func_name is not None
254 | func_obj = get_obj_by_name(func_name)
255 | assert callable(func_obj)
256 | return func_obj(*args, **kwargs)
257 |
258 |
259 | def get_module_dir_by_obj_name(obj_name: str) -> str:
260 | """Get the directory path of the module containing the given object name."""
261 | module, _ = get_module_from_obj_name(obj_name)
262 | return os.path.dirname(inspect.getfile(module))
263 |
264 |
265 | def is_top_level_function(obj: Any) -> bool:
266 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
267 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
268 |
269 |
270 | def get_top_level_function_name(obj: Any) -> str:
271 | """Return the fully-qualified name of a top-level function."""
272 | assert is_top_level_function(obj)
273 | return obj.__module__ + "." + obj.__name__
274 |
275 |
276 | # File system helpers
277 | # ------------------------------------------------------------------------------------------
278 |
279 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
280 | """List all files recursively in a given directory while ignoring given file and directory names.
281 | Returns list of tuples containing both absolute and relative paths."""
282 | assert os.path.isdir(dir_path)
283 | base_name = os.path.basename(os.path.normpath(dir_path))
284 |
285 | if ignores is None:
286 | ignores = []
287 |
288 | result = []
289 |
290 | for root, dirs, files in os.walk(dir_path, topdown=True):
291 | for ignore_ in ignores:
292 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
293 |
294 | # dirs need to be edited in-place
295 | for d in dirs_to_remove:
296 | dirs.remove(d)
297 |
298 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
299 |
300 | absolute_paths = [os.path.join(root, f) for f in files]
301 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
302 |
303 | if add_base_to_relative:
304 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
305 |
306 | assert len(absolute_paths) == len(relative_paths)
307 | result += zip(absolute_paths, relative_paths)
308 |
309 | return result
310 |
311 |
312 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
313 | """Takes in a list of tuples of (src, dst) paths and copies files.
314 | Will create all necessary directories."""
315 | for file in files:
316 | target_dir_name = os.path.dirname(file[1])
317 |
318 | # will create all intermediate-level directories
319 | if not os.path.exists(target_dir_name):
320 | os.makedirs(target_dir_name)
321 |
322 | shutil.copyfile(file[0], file[1])
323 |
324 |
325 | # URL helpers
326 | # ------------------------------------------------------------------------------------------
327 |
328 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
329 | """Determine whether the given object is a valid URL string."""
330 |     if not isinstance(obj, str) or "://" not in obj:
331 | return False
332 | if allow_file_urls and obj.startswith('file:///'):
333 | return True
334 | try:
335 | res = requests.compat.urlparse(obj)
336 |         if not res.scheme or not res.netloc or "." not in res.netloc:
337 | return False
338 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
339 |         if not res.scheme or not res.netloc or "." not in res.netloc:
340 | return False
341 | except:
342 | return False
343 | return True
344 |
345 |
346 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
347 | """Download the given URL and return a binary-mode file object to access the data."""
348 | assert is_url(url, allow_file_urls=True)
349 | assert num_attempts >= 1
350 |
351 | # Handle file URLs.
352 | if url.startswith('file:///'):
353 | return open(url[len('file:///'):], "rb")
354 |
355 | # Lookup from cache.
356 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
357 | if cache_dir is not None:
358 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
359 | if len(cache_files) == 1:
360 | return open(cache_files[0], "rb")
361 |
362 | # Download.
363 | url_name = None
364 | url_data = None
365 | with requests.Session() as session:
366 | if verbose:
367 | print("Downloading %s ..." % url, end="", flush=True)
368 | for attempts_left in reversed(range(num_attempts)):
369 | try:
370 | with session.get(url) as res:
371 | res.raise_for_status()
372 | if len(res.content) == 0:
373 | raise IOError("No data received")
374 |
375 | if len(res.content) < 8192:
376 | content_str = res.content.decode("utf-8")
377 | if "download_warning" in res.headers.get("Set-Cookie", ""):
378 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
379 | if len(links) == 1:
380 | url = requests.compat.urljoin(url, links[0])
381 | raise IOError("Google Drive virus checker nag")
382 | if "Google Drive - Quota exceeded" in content_str:
383 | raise IOError("Google Drive download quota exceeded -- please try again later")
384 |
385 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
386 | url_name = match[1] if match else url
387 | url_data = res.content
388 | if verbose:
389 | print(" done")
390 | break
391 | except:
392 | if not attempts_left:
393 | if verbose:
394 | print(" failed")
395 | raise
396 | if verbose:
397 | print(".", end="", flush=True)
398 |
399 | # Save to cache.
400 | if cache_dir is not None:
401 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
402 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
403 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
404 | os.makedirs(cache_dir, exist_ok=True)
405 | with open(temp_file, "wb") as f:
406 | f.write(url_data)
407 | os.replace(temp_file, cache_file) # atomic
408 |
409 | # Return data as file object.
410 | return io.BytesIO(url_data)
411 |
--------------------------------------------------------------------------------
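A quick, pure-Python illustration of the utilities above (no TensorFlow required); the printed values follow directly from the function bodies:

from dnnlib import util

cfg = util.EasyDict(lr=0.002, beta1=0.0)
cfg.gamma = 10                                           # attribute syntax writes through to the dict
print(cfg['gamma'], cfg.lr)                              # -> 10 0.002
print(util.format_time(90061))                           # 1d 1h 1m 1s -> '1d 01h 01m'
print(util.call_func_by_name(2.0, func_name='np.sqrt'))  # 'np.' expands to 'numpy.' -> ~1.4142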
/imgs/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/demo.gif
--------------------------------------------------------------------------------
/imgs/example_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/example_image.jpg
--------------------------------------------------------------------------------
/imgs/example_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/example_mask.jpg
--------------------------------------------------------------------------------
/imgs/grid-main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zsyzzsoft/co-mod-gan/c8b9ffe30c950dfdfb9e86652a27b6e72304b3d5/imgs/grid-main.jpg
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | # empty
8 |
--------------------------------------------------------------------------------
/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Frechet Inception Distance (FID)."""
8 |
9 | import os
10 | import numpy as np
11 | import scipy
12 | import tensorflow as tf
13 | import dnnlib.tflib as tflib
14 |
15 | from metrics import metric_base
16 | from training import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class FID(metric_base.MetricBase):
21 | def __init__(self, num_images, minibatch_per_gpu, ref_train=False, ref_samples=None, **kwargs):
22 | super().__init__(**kwargs)
23 | self.num_images = num_images
24 | self.minibatch_per_gpu = minibatch_per_gpu
25 | self.ref_train = ref_train
26 | self.ref_samples = num_images if ref_samples is None else ref_samples
27 |
28 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
29 | minibatch_size = num_gpus * self.minibatch_per_gpu
30 | inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
31 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
32 |
33 | # Calculate statistics for reals.
34 | cache_file = self._get_cache_file_for_reals(num_images=self.ref_samples)
35 | os.makedirs(os.path.dirname(cache_file), exist_ok=True)
36 | if os.path.isfile(cache_file):
37 | mu_real, sigma_real = misc.load_pkl(cache_file)
38 | else:
39 | real_activations = np.empty([self.ref_samples, inception.output_shape[1]], dtype=np.float32)
40 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size, is_training=self.ref_train)):
41 | begin = idx * minibatch_size
42 | end = min(begin + minibatch_size, self.ref_samples)
43 | real_activations[begin:end] = inception.run(images[:end-begin, :3], num_gpus=num_gpus, assume_frozen=True)
44 | if end == self.ref_samples:
45 | break
46 | mu_real = np.mean(real_activations, axis=0)
47 | sigma_real = np.cov(real_activations, rowvar=False)
48 | misc.save_pkl((mu_real, sigma_real), cache_file)
49 |
50 | # Construct TensorFlow graph.
51 | self._configure(self.minibatch_per_gpu)
52 | result_expr = []
53 | for gpu_idx in range(num_gpus):
54 | with tf.device('/gpu:%d' % gpu_idx):
55 | Gs_clone = Gs.clone()
56 | inception_clone = inception.clone()
57 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
58 | reals, labels = self._get_minibatch_tf()
59 | reals = tflib.convert_images_from_uint8(reals)
60 | masks = self._get_random_masks_tf()
61 | images = Gs_clone.get_output_for(latents, labels, reals, masks, **Gs_kwargs)
62 | images = images[:, :3]
63 | images = tflib.convert_images_to_uint8(images)
64 | result_expr.append(inception_clone.get_output_for(images))
65 |
66 | # Calculate statistics for fakes.
67 | for begin in range(0, self.num_images, minibatch_size):
68 | self._report_progress(begin, self.num_images)
69 | end = min(begin + minibatch_size, self.num_images)
70 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
71 | mu_fake = np.mean(activations, axis=0)
72 | sigma_fake = np.cov(activations, rowvar=False)
73 |
74 | # Calculate FID.
75 | m = np.square(mu_fake - mu_real).sum()
76 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
77 | dist = m + np.trace(sigma_fake + sigma_real - 2*s)
78 | self._report_result(np.real(dist))
79 |
80 | #----------------------------------------------------------------------------
81 |
--------------------------------------------------------------------------------
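The closing FID computation above in isolation, run on dummy 64-dimensional activations instead of Inception features; a NumPy/SciPy-only sketch with illustrative shapes:

import numpy as np
import scipy.linalg

rng = np.random.RandomState(0)
real_act = rng.randn(500, 64)
fake_act = rng.randn(500, 64) + 0.1          # shifted distribution -> nonzero FID

mu_real, sigma_real = np.mean(real_act, axis=0), np.cov(real_act, rowvar=False)
mu_fake, sigma_fake = np.mean(fake_act, axis=0), np.cov(fake_act, rowvar=False)
m = np.square(mu_fake - mu_real).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)
print(np.real(m + np.trace(sigma_fake + sigma_real - 2 * s)))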
/metrics/inception_discriminative_score.py:
--------------------------------------------------------------------------------
1 | # Large Scale Image Completion via Co-Modulated Generative Adversarial Networks
2 | # Shengyu Zhao, Jonathan Cui, Yilun Sheng, Yue Dong, Xiao Liang, Eric I Chang, Yan Xu
3 | # https://openreview.net/pdf?id=sSjqmfsk95O
4 |
5 | """Paired/Unpaired Inception Discriminative Score (P-IDS/U-IDS)."""
6 |
7 | import os
8 | from tqdm import tqdm
9 | import numpy as np
10 | import scipy
11 | import sklearn.svm
12 | import tensorflow as tf
13 | import dnnlib
14 | import dnnlib.tflib as tflib
15 |
16 | from metrics import metric_base
17 | from training import misc
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | class IDS(metric_base.MetricBase):
22 | def __init__(self, num_images, minibatch_per_gpu, hole_range=[0,1], **kwargs):
23 | super().__init__(**kwargs)
24 | self.num_images = num_images
25 | self.minibatch_per_gpu = minibatch_per_gpu
26 | self.hole_range = hole_range
27 |
28 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
29 | minibatch_size = num_gpus * self.minibatch_per_gpu
30 | inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
31 | real_activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
32 | fake_activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
33 |
34 | # Construct TensorFlow graph.
35 | self._configure(self.minibatch_per_gpu, hole_range=self.hole_range)
36 | real_img_expr = []
37 | fake_img_expr = []
38 | real_result_expr = []
39 | fake_result_expr = []
40 | for gpu_idx in range(num_gpus):
41 | with tf.device('/gpu:%d' % gpu_idx):
42 | Gs_clone = Gs.clone()
43 | inception_clone = inception.clone()
44 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
45 | reals, labels = self._get_minibatch_tf()
46 | reals_tf = tflib.convert_images_from_uint8(reals)
47 | masks = self._get_random_masks_tf()
48 | fakes = Gs_clone.get_output_for(latents, labels, reals_tf, masks, **Gs_kwargs)
49 | fakes = tflib.convert_images_to_uint8(fakes[:, :3])
50 | reals = tflib.convert_images_to_uint8(reals_tf[:, :3])
51 | real_img_expr.append(reals)
52 | fake_img_expr.append(fakes)
53 | real_result_expr.append(inception_clone.get_output_for(reals))
54 | fake_result_expr.append(inception_clone.get_output_for(fakes))
55 |
56 | for begin in tqdm(range(0, self.num_images, minibatch_size)):
57 | self._report_progress(begin, self.num_images)
58 | end = min(begin + minibatch_size, self.num_images)
59 | real_results, fake_results = tflib.run([real_result_expr, fake_result_expr])
60 | real_activations[begin:end] = np.concatenate(real_results, axis=0)[:end-begin]
61 | fake_activations[begin:end] = np.concatenate(fake_results, axis=0)[:end-begin]
62 |
63 |         # Conveniently compute FID from the same activations.
64 | mu_real = np.mean(real_activations, axis=0)
65 | sigma_real = np.cov(real_activations, rowvar=False)
66 | mu_fake = np.mean(fake_activations, axis=0)
67 | sigma_fake = np.cov(fake_activations, rowvar=False)
68 | m = np.square(mu_fake - mu_real).sum()
69 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)
70 | dist = m + np.trace(sigma_fake + sigma_real - 2*s)
71 | self._report_result(np.real(dist), suffix='-FID')
72 |
73 | svm = sklearn.svm.LinearSVC(dual=False)
74 | svm_inputs = np.concatenate([real_activations, fake_activations])
75 | svm_targets = np.array([1] * real_activations.shape[0] + [0] * fake_activations.shape[0])
76 | svm.fit(svm_inputs, svm_targets)
77 | self._report_result(1 - svm.score(svm_inputs, svm_targets), suffix='-U')
78 | real_outputs = svm.decision_function(real_activations)
79 | fake_outputs = svm.decision_function(fake_activations)
80 | self._report_result(np.mean(fake_outputs > real_outputs), suffix='-P')
81 |
82 | #----------------------------------------------------------------------------
83 |
--------------------------------------------------------------------------------
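The P-IDS/U-IDS computation above, reduced to its core on dummy features: U-IDS is the training error of a linear SVM separating real from fake activations, and P-IDS is the fraction of pairs where the fake is scored as more "real" than its paired real image (NumPy and scikit-learn only; shapes are illustrative):

import numpy as np
import sklearn.svm

rng = np.random.RandomState(0)
real_act = rng.randn(500, 64)
fake_act = rng.randn(500, 64) + 0.5

svm = sklearn.svm.LinearSVC(dual=False)
svm_inputs = np.concatenate([real_act, fake_act])
svm_targets = np.array([1] * len(real_act) + [0] * len(fake_act))
svm.fit(svm_inputs, svm_targets)
print(1 - svm.score(svm_inputs, svm_targets))     # U-IDS: misclassification rate
print(np.mean(svm.decision_function(fake_act) > svm.decision_function(real_act)))  # P-IDS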
/metrics/learned_perceptual_image_patch_similarity.py:
--------------------------------------------------------------------------------
1 | """Learned Perceptual Image Patch Similarity (LPIPS)."""
2 |
3 | import os
4 | import numpy as np
5 | import scipy
6 | import tensorflow as tf
7 | import dnnlib.tflib as tflib
8 |
9 | from metrics import metric_base
10 | from training import misc
11 |
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | class LPIPS(metric_base.MetricBase):
16 | def __init__(self, num_pairs=2000, minibatch_per_gpu=8, **kwargs):
17 | super().__init__(**kwargs)
18 | self.num_pairs = num_pairs
19 | self.minibatch_per_gpu = minibatch_per_gpu
20 |
21 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
22 | minibatch_size = num_gpus * self.minibatch_per_gpu
23 |
24 | graph_def = tf.GraphDef()
25 | with misc.open_file_or_url('http://rail.eecs.berkeley.edu/models/lpips/net-lin_alex_v0.1.pb') as f:
26 | graph_def.ParseFromString(f.read())
27 |
28 | # Construct TensorFlow graph.
29 | self._configure(self.minibatch_per_gpu)
30 | result_expr = []
31 | for gpu_idx in range(num_gpus):
32 | def auto_gpu(opr):
33 | if opr.type in ['SparseToDense', 'Tile', 'GatherV2', 'Pack']:
34 | return '/cpu:0'
35 | else:
36 | return '/gpu:%d' % gpu_idx
37 | with tf.device(auto_gpu):
38 | Gs_clone = Gs.clone()
39 | reals, labels = self._get_minibatch_tf()
40 | reals = tflib.convert_images_from_uint8(reals)
41 | masks = self._get_random_masks_tf()
42 |
43 | latents0 = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
44 | fakes0 = Gs_clone.get_output_for(latents0, labels, reals, masks, **Gs_kwargs)[:, :3, :, :]
45 | fakes0 = tf.clip_by_value(fakes0, -1.0, 1.0)
46 |
47 | latents1 = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
48 | fakes1 = Gs_clone.get_output_for(latents1, labels, reals, masks, **Gs_kwargs)[:, :3, :, :]
49 | fakes1 = tf.clip_by_value(fakes1, -1.0, 1.0)
50 |
51 | distance, = tf.import_graph_def(
52 | graph_def,
53 | input_map={'0:0': fakes0, '1:0': fakes1},
54 |                     return_elements=['Reshape_10']
55 | )
56 | result_expr.append(distance.outputs)
57 |
58 | # Run metric
59 | results = []
60 | for begin in range(0, self.num_pairs, minibatch_size):
61 | self._report_progress(begin, self.num_pairs)
62 | res = tflib.run(result_expr)
63 | results.append(np.reshape(res, [-1]))
64 | results = np.concatenate(results)
65 | self._report_result(np.mean(results))
66 |         self._report_result(np.std(results), suffix='-std')
67 |
68 | #----------------------------------------------------------------------------
69 |
--------------------------------------------------------------------------------
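The frozen-graph pattern used above in miniature: load the LPIPS .pb once, splice two image batches into its inputs via input_map, and read back the per-pair distances. A TF 1.x sketch with illustrative shapes; the URL and tensor names are the ones used in _evaluate() above:

import tensorflow as tf
import dnnlib.tflib as tflib
from training import misc

tflib.init_tf()
graph_def = tf.GraphDef()
with misc.open_file_or_url('http://rail.eecs.berkeley.edu/models/lpips/net-lin_alex_v0.1.pb') as f:
    graph_def.ParseFromString(f.read())

img0 = tf.random_uniform([4, 3, 256, 256], -1.0, 1.0)   # NCHW, dynamic range [-1, 1]
img1 = tf.random_uniform([4, 3, 256, 256], -1.0, 1.0)
distance, = tf.import_graph_def(graph_def, input_map={'0:0': img0, '1:0': img1},
                                return_elements=['Reshape_10'])
print(tflib.run(distance.outputs[0]).mean())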
/metrics/metric_base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Common definitions for GAN metrics."""
8 |
9 | import os
10 | import time
11 | import hashlib
12 | import numpy as np
13 | import tensorflow as tf
14 | import dnnlib
15 | import dnnlib.tflib as tflib
16 |
17 | from training import misc
18 | from training import dataset
19 |
20 | #----------------------------------------------------------------------------
21 | # Base class for metrics.
22 |
23 | class MetricBase:
24 | def __init__(self, name):
25 | self.name = name
26 | self._dataset_obj = None
27 | self._progress_lo = None
28 | self._progress_hi = None
29 | self._progress_max = None
30 | self._progress_sec = None
31 | self._progress_time = None
32 | self._reset()
33 |
34 | def close(self):
35 | self._reset()
36 |
37 | def _reset(self, network_pkl=None, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None):
38 | if self._dataset_obj is not None:
39 | self._dataset_obj.close()
40 |
41 | self._network_pkl = network_pkl
42 | self._data_dir = data_dir
43 | self._dataset_args = dataset_args
44 | self._dataset_obj = None
45 | self._mirror_augment = mirror_augment
46 | self._eval_time = 0
47 | self._results = []
48 |
49 | if (dataset_args is None or mirror_augment is None) and run_dir is not None:
50 | run_config = misc.parse_config_for_previous_run(run_dir)
51 | self._dataset_args = dict(run_config['dataset'])
52 | self._dataset_args['shuffle_mb'] = 0
53 | self._mirror_augment = run_config['train'].get('mirror_augment', False)
54 |
55 | def configure_progress_reports(self, plo, phi, pmax, psec=15):
56 | self._progress_lo = plo
57 | self._progress_hi = phi
58 | self._progress_max = pmax
59 | self._progress_sec = psec
60 |
61 | def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True,
62 | num_repeats=1, Gs_kwargs=dict(is_validation=True), resume_with_new_nets=False, truncations=[None]):
63 | self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment)
64 | with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
65 | self._report_progress(0, 1)
66 | _G, _D, Gs = misc.load_pkl(self._network_pkl)
67 |
68 | if resume_with_new_nets:
69 | dataset = self._get_dataset_obj()
70 | G = dnnlib.tflib.Network('G', num_channels=dataset.shape[0], resolution=dataset.shape[1], label_size=dataset.label_size,
71 | func_name='training.co_mod_gan.G_main', pix2pix=dataset.pix2pix)
72 | Gs_new = G.clone('Gs')
73 | Gs_new.copy_vars_from(Gs)
74 | Gs = Gs_new
75 |
76 | for t in truncations:
77 | print('truncation={}'.format(t))
78 | self._results = []
79 | time_begin = time.time()
80 |
81 | Gs_kwargs.update(truncation_psi_val=t)
82 | self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
83 | self._report_progress(1, 1)
84 |
85 | if num_repeats > 1:
86 | records = [dnnlib.EasyDict(value=[res.value], suffix=res.suffix, fmt=res.fmt) for res in self._results]
87 | for i in range(1, num_repeats):
88 | print(self.get_result_str().strip())
89 | self._results = []
90 | self._report_progress(0, 1)
91 | self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
92 | self._report_progress(1, 1)
93 | for rec, res in zip(records, self._results):
94 | rec.value.append(res.value)
95 |
96 | self._results = []
97 | for rec in records:
98 | self._report_result(np.mean(rec.value), rec.suffix, rec.fmt)
99 | self._report_result(np.std(rec.value), rec.suffix + '-std', rec.fmt)
100 |
101 | self._eval_time = time.time() - time_begin # pylint: disable=attribute-defined-outside-init
102 |
103 | if log_results:
104 | if run_dir is not None:
105 | log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name)
106 | with dnnlib.util.Logger(log_file, 'a'):
107 | print(self.get_result_str().strip())
108 | else:
109 | print(self.get_result_str().strip())
110 |
111 | def get_result_str(self):
112 | network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
113 | if len(network_name) > 29:
114 | network_name = '...' + network_name[-26:]
115 | result_str = '%-30s' % network_name
116 | result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)
117 | for res in self._results:
118 | result_str += ' ' + self.name + res.suffix + ' '
119 | result_str += res.fmt % res.value
120 | return result_str
121 |
122 | def update_autosummaries(self):
123 | for res in self._results:
124 | tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)
125 |
126 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
127 | raise NotImplementedError # to be overridden by subclasses
128 |
129 | def _report_result(self, value, suffix='', fmt='%-10.4f'):
130 | self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
131 |
132 | def _report_progress(self, pcur, pmax, status_str=''):
133 | if self._progress_lo is None or self._progress_hi is None or self._progress_max is None:
134 | return
135 | t = time.time()
136 | if self._progress_sec is not None and self._progress_time is not None and t < self._progress_time + self._progress_sec:
137 | return
138 | self._progress_time = t
139 | val = self._progress_lo + (pcur / pmax) * (self._progress_hi - self._progress_lo)
140 | dnnlib.RunContext.get().update(status_str, int(val), self._progress_max)
141 |
142 | def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
143 | all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
144 | all_args.update(self._dataset_args)
145 | all_args.update(kwargs)
146 | md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
147 | dataset_name = self._dataset_args.get('tfrecord_dir', None) or self._dataset_args.get('h5_file', None)
148 | dataset_name = os.path.splitext(os.path.basename(dataset_name))[0]
149 | return os.path.join('.stylegan2-cache', '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))
150 |
151 | def _get_dataset_obj(self):
152 | if self._dataset_obj is None:
153 | self._dataset_obj = dataset.load_dataset(data_dir=self._data_dir, **self._dataset_args)
154 | return self._dataset_obj
155 |
156 | def _iterate_reals(self, minibatch_size, is_training=True):
157 | dataset_obj = self._get_dataset_obj()
158 | while True:
159 | if is_training:
160 | images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
161 | else:
162 | images, _labels = dataset_obj.get_minibatch_val_np(minibatch_size)
163 | if self._mirror_augment:
164 | images = misc.apply_mirror_augment(images)
165 | yield images
166 |
167 | def _configure(self, minibatch_size, hole_range=[0, 1]):
168 | return self._get_dataset_obj().configure(minibatch_size, hole_range=hole_range)
169 |
170 | def _get_minibatch_tf(self):
171 | return self._get_dataset_obj().get_minibatch_val_tf()
172 |
173 | def _get_random_masks_tf(self):
174 | return self._get_dataset_obj().get_random_masks_tf()
175 |
176 | def _get_random_labels_tf(self, minibatch_size):
177 | return self._get_dataset_obj().get_random_labels_tf(minibatch_size)
178 |
179 | #----------------------------------------------------------------------------
180 | # Group of multiple metrics.
181 |
182 | class MetricGroup:
183 | def __init__(self, metric_kwarg_list):
184 | self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list]
185 |
186 | def run(self, *args, **kwargs):
187 | for metric in self.metrics:
188 | metric.run(*args, **kwargs)
189 |
190 | def get_result_str(self):
191 | return ' '.join(metric.get_result_str() for metric in self.metrics)
192 |
193 | def update_autosummaries(self):
194 | for metric in self.metrics:
195 | metric.update_autosummaries()
196 |
197 | #----------------------------------------------------------------------------
198 | # Dummy metric for debugging purposes.
199 |
200 | class DummyMetric(MetricBase):
201 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
202 | _ = Gs, Gs_kwargs, num_gpus
203 | self._report_result(0.0)
204 |
205 | #----------------------------------------------------------------------------
206 |
--------------------------------------------------------------------------------
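New metrics plug in the same way DummyMetric does above: subclass MetricBase, override _evaluate(), and report values with _report_result(). A hypothetical toy metric (the class name and logic are illustrative, not part of the repo):

import numpy as np
from metrics import metric_base

class MeanLatentNorm(metric_base.MetricBase):
    """Toy metric: mean L2 norm of random latents; needs no dataset or GPU."""
    def _evaluate(self, Gs, Gs_kwargs, num_gpus):
        _ = Gs_kwargs, num_gpus
        latents = np.random.randn(64, *Gs.input_shape[1:])
        self._report_result(np.mean(np.linalg.norm(latents.reshape(64, -1), axis=1)))

To make it selectable from run_metrics.py, it would also need an EasyDict entry in metrics/metric_defaults.py.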
/metrics/metric_defaults.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Default metric definitions."""
8 |
9 | from dnnlib import EasyDict
10 |
11 | #----------------------------------------------------------------------------
12 |
13 | metric_defaults = EasyDict([(args.name, args) for args in [
14 | EasyDict(name='fid200-rt-shoes', func_name='metrics.frechet_inception_distance.FID', num_images=200, minibatch_per_gpu=1, ref_train=True, ref_samples=49825),
15 | EasyDict(name='fid200-rt-handbags', func_name='metrics.frechet_inception_distance.FID', num_images=200, minibatch_per_gpu=1, ref_train=True, ref_samples=138567),
16 | EasyDict(name='fid5k', func_name='metrics.frechet_inception_distance.FID', num_images=5000, minibatch_per_gpu=8),
17 | EasyDict(name='fid10k', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8),
18 | EasyDict(name='fid10k-b1', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=1),
19 | EasyDict(name='fid10k-h0', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.0, 0.2]),
20 | EasyDict(name='fid10k-h1', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.2, 0.4]),
21 | EasyDict(name='fid10k-h2', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.4, 0.6]),
22 | EasyDict(name='fid10k-h3', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.6, 0.8]),
23 | EasyDict(name='fid10k-h4', func_name='metrics.frechet_inception_distance.FID', num_images=10000, minibatch_per_gpu=8, hole_range=[0.8, 1.0]),
24 |     EasyDict(name='fid36k5', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8),
25 | EasyDict(name='fid36k5-h0', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.0, 0.2]),
26 | EasyDict(name='fid36k5-h1', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.2, 0.4]),
27 | EasyDict(name='fid36k5-h2', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.4, 0.6]),
28 | EasyDict(name='fid36k5-h3', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.6, 0.8]),
29 | EasyDict(name='fid36k5-h4', func_name='metrics.frechet_inception_distance.FID', num_images=36500, minibatch_per_gpu=8, hole_range=[0.8, 1.0]),
30 | EasyDict(name='fid50k', func_name='metrics.frechet_inception_distance.FID', num_images=50000, minibatch_per_gpu=8),
31 |     EasyDict(name='ids5k', func_name='metrics.inception_discriminative_score.IDS', num_images=5000, minibatch_per_gpu=8),
32 | EasyDict(name='ids10k', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8),
33 | EasyDict(name='ids10k-b1', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=1),
34 | EasyDict(name='ids10k-h0', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.0, 0.2]),
35 | EasyDict(name='ids10k-h1', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.2, 0.4]),
36 | EasyDict(name='ids10k-h2', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.4, 0.6]),
37 | EasyDict(name='ids10k-h3', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.6, 0.8]),
38 | EasyDict(name='ids10k-h4', func_name='metrics.inception_discriminative_score.IDS', num_images=10000, minibatch_per_gpu=8, hole_range=[0.8, 1.0]),
39 |     EasyDict(name='ids36k5', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8),
40 | EasyDict(name='ids36k5-h0', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.0, 0.2]),
41 | EasyDict(name='ids36k5-h1', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.2, 0.4]),
42 | EasyDict(name='ids36k5-h2', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.4, 0.6]),
43 | EasyDict(name='ids36k5-h3', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.6, 0.8]),
44 | EasyDict(name='ids36k5-h4', func_name='metrics.inception_discriminative_score.IDS', num_images=36500, minibatch_per_gpu=8, hole_range=[0.8, 1.0]),
45 | EasyDict(name='ids50k', func_name='metrics.inception_discriminative_score.IDS', num_images=50000, minibatch_per_gpu=8),
46 | EasyDict(name='lpips2k', func_name='metrics.learned_perceptual_image_patch_similarity.LPIPS', num_pairs=2000, minibatch_per_gpu=8),
47 | ]])
48 |
49 | #----------------------------------------------------------------------------
50 |
--------------------------------------------------------------------------------
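How this registry is consumed (cf. run_metrics.py): each entry's func_name is resolved by dnnlib.util.call_func_by_name and instantiated with the remaining keys as constructor kwargs. A sketch with placeholder paths:

from metrics import metric_base
from metrics.metric_defaults import metric_defaults

group = metric_base.MetricGroup([metric_defaults['fid10k']])   # instantiates FID(num_images=10000, ...)
# group.run('network-snapshot-025000.pkl', data_dir='datasets',
#           dataset_args=dict(tfrecord_dir='ffhq', shuffle_mb=0), num_gpus=1)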
/run_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('-c', '--checkpoint', required=True)
6 | parser.add_argument('-d', '--data-dir', required=True)
7 | parser.add_argument('-s', '--save-dir', default='images')
8 | parser.add_argument('-w', '--window-size', type=int, default=512)
9 | args = parser.parse_args()
10 |
11 | import tkinter as tk
12 | from PIL import Image, ImageTk, ImageDraw
13 | import numpy as np
14 | import cv2
15 | import hashlib
16 |
17 | import dnnlib
18 |
19 | from training import misc, dataset
20 |
21 |
22 | class App(tk.Tk):
23 | def __init__(self):
24 | super().__init__()
25 | self.state = -1
26 | self.canvas = tk.Canvas(self, bg='gray', height=args.window_size, width=args.window_size)
27 |         self.canvas.bind("<ButtonPress-1>", self.L_press)
28 |         self.canvas.bind("<ButtonRelease-1>", self.L_release)
29 |         self.canvas.bind("<B1-Motion>", self.L_move)
30 |         self.canvas.bind("<ButtonPress-3>", self.R_press)
31 |         self.canvas.bind("<ButtonRelease-3>", self.R_release)
32 |         self.canvas.bind("<B3-Motion>", self.R_move)
33 |         self.canvas.bind("<KeyPress>", self.key_down)
34 |         self.canvas.bind("<KeyRelease>", self.key_up)
35 | self.canvas.pack()
36 |
37 | self.canvas.focus_set()
38 | self.canvas_image = self.canvas.create_image(0, 0, anchor='nw')
39 |
40 | dnnlib.tflib.init_tf()
41 | self.dataset = dataset.load_dataset(tfrecord_dir=args.data_dir, verbose=True, shuffle_mb=0)
42 |
43 | self.networks = []
44 | self.truncations = []
45 | self.model_names = []
46 | for ckpt in args.checkpoint.split(','):
47 | if ':' in ckpt:
48 | ckpt, truncation = ckpt.split(':')
49 | truncation = float(truncation)
50 | else:
51 | truncation = None
52 |
53 | _, _, Gs = misc.load_pkl(ckpt)
54 |
55 | self.networks.append(Gs)
56 | self.truncations.append(truncation)
57 | self.model_names.append(os.path.basename(os.path.splitext(ckpt)[0]))
58 |
59 | self.key_list = ['q', 'w', 'e', 'r', 't', 'y', 'u', 'i', 'o', 'p'][:len(self.networks)]
60 | self.image_id = -1
61 |
62 | self.new_image()
63 | self.display()
64 |
65 | def generate(self, idx=0):
66 | self.cur_idx = idx
67 | latent = np.random.randn(1, *self.networks[idx].input_shape[1:])
68 | real = misc.adjust_dynamic_range(self.real_image, [0, 255], [-1, 1])
69 | fake = self.networks[idx].run(latent, self.label, real, self.mask, truncation_psi=self.truncations[idx])
70 | self.fake_image = misc.adjust_dynamic_range(fake, [-1, 1], [0, 255]).clip(0, 255).astype(np.uint8)
71 |
72 | def new_image(self):
73 | self.image_id += 1
74 | self.save_count = 0
75 | self.real_image, self.label = self.dataset.get_minibatch_val_np(1)
76 | self.resolution = self.real_image.shape[-1]
77 | self.mask = np.ones((1, 1, self.resolution, self.resolution), np.uint8)
78 | self.mask_history = [self.mask]
79 |
80 | def display(self, state=0):
81 | if state != self.state:
82 | self.last_state = self.state
83 | self.state = state
84 | self.image = self.real_image if self.state == 1 else self.fake_image if self.state == 2 else self.real_image * self.mask
85 | self.image_for_display = np.transpose(self.image[0, :3], (1, 2, 0))
86 | self.image_for_display_resized = cv2.resize(self.image_for_display, (args.window_size, args.window_size))
87 | self.tkimage = ImageTk.PhotoImage(image=Image.fromarray(self.image_for_display_resized))
88 | self.canvas.itemconfig(self.canvas_image, image=self.tkimage)
89 |
90 | def save_image(self):
91 |         folder_name = os.path.join(args.save_dir, '-'.join([os.path.basename(args.data_dir), str(self.image_id), hashlib.sha1(self.mask.tobytes()).hexdigest()[:6]]))
92 | if not os.path.exists(folder_name):
93 | os.makedirs(folder_name)
94 | self.save_count = 0
95 | for img, name in [[self.real_image, 'real'], [self.real_image * self.mask, 'masked']]:
96 | cv2.imwrite(os.path.join(folder_name, name + '.jpg'), np.transpose(img[0, :3], (1, 2, 0))[..., ::-1])
97 | if self.state == 2:
98 | cv2.imwrite(os.path.join(folder_name, '-'.join([self.model_names[self.cur_idx], str(self.save_count)]) + '.jpg'), self.image_for_display[..., ::-1])
99 | self.save_count += 1
100 |
101 | def get_pos(self, event):
102 | return (int(event.x * self.resolution / args.window_size), int(event.y * self.resolution / args.window_size))
103 |
104 | def L_press(self, event):
105 | self.last_pos = self.get_pos(event)
106 |
107 | def L_move(self, event):
108 | a = self.last_pos
109 | b = self.get_pos(event)
110 | width = 30
111 | img = Image.fromarray(self.mask[0, 0])
112 | draw = ImageDraw.Draw(img)
113 | draw.line([a, b], fill=0, width=width)
114 | draw.ellipse((b[0] - width // 2, b[1] - width // 2, b[0] + width // 2, b[1] + width // 2), fill=0)
115 | self.mask = np.array(img)[np.newaxis, np.newaxis, ...]
116 | self.display()
117 | self.last_pos = b
118 |
119 | def L_release(self, event):
120 | self.L_move(event)
121 | self.mask_history.append(self.mask)
122 |
123 | def R_press(self, event):
124 | self.last_pos = self.get_pos(event)
125 |
126 | def R_move(self, event):
127 | a = self.last_pos
128 | b = self.get_pos(event)
129 | self.mask = self.mask_history[-1].copy()
130 | self.mask[0, 0, max(min(a[1], b[1]), 0): max(a[1], b[1]), max(min(a[0], b[0]), 0): max(a[0], b[0])] = 0
131 | self.display()
132 |
133 | def R_release(self, event):
134 | self.R_move(event)
135 | self.mask_history.append(self.mask)
136 |
137 | def key_down(self, event):
138 | if event.keysym == 'z':
139 | if len(self.mask_history) > 1:
140 | self.mask_history.pop()
141 | self.mask = self.mask_history[-1]
142 | self.display()
143 | elif event.keysym == 'space':
144 | self.generate()
145 | self.display(2)
146 | elif event.keysym in self.key_list:
147 | self.generate(self.key_list.index(event.keysym))
148 | self.display(2)
149 | elif event.keysym == 's':
150 | self.save_image()
151 | elif event.keysym == 'Return':
152 | self.new_image()
153 | self.display()
154 | elif event.keysym == '1':
155 | self.display(1)
156 | elif event.keysym == '2':
157 | self.display(0)
158 |
159 | def key_up(self, event):
160 | if event.keysym in ['1', '2']:
161 | self.display(self.last_state)
162 |
163 | def main():
164 | app = App()
165 | app.mainloop()
166 |
167 | if __name__ == "__main__":
168 | main()
--------------------------------------------------------------------------------
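The mask-editing core of L_move() above, stand-alone: strokes are drawn into a one-channel PIL image and pulled back as the (1, 1, H, W) uint8 mask the generator consumes (1 = keep, 0 = hole). Sizes and coordinates are illustrative:

import numpy as np
from PIL import Image, ImageDraw

mask = np.ones((256, 256), np.uint8)
img = Image.fromarray(mask)
draw = ImageDraw.Draw(img)
draw.line([(20, 20), (200, 120)], fill=0, width=30)
draw.ellipse((185, 105, 215, 135), fill=0)      # round cap at the stroke end
mask = np.array(img)[np.newaxis, np.newaxis, ...]
print(mask.shape, int(mask.min()))              # -> (1, 1, 256, 256) 0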
/run_generator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import PIL.Image
4 |
5 | from dnnlib import tflib
6 | from training import misc
7 |
8 | def generate(checkpoint, image, mask, output, truncation):
9 | real = np.asarray(PIL.Image.open(image)).transpose([2, 0, 1])
10 | real = misc.adjust_dynamic_range(real, [0, 255], [-1, 1])
11 | mask = np.asarray(PIL.Image.open(mask).convert('1'), dtype=np.float32)[np.newaxis]
12 |
13 | tflib.init_tf()
14 | _, _, Gs = misc.load_pkl(checkpoint)
15 | latent = np.random.randn(1, *Gs.input_shape[1:])
16 | fake = Gs.run(latent, None, real[np.newaxis], mask[np.newaxis], truncation_psi=truncation)[0]
17 | fake = misc.adjust_dynamic_range(fake, [-1, 1], [0, 255])
18 | fake = fake.clip(0, 255).astype(np.uint8).transpose([1, 2, 0])
19 | fake = PIL.Image.fromarray(fake)
20 | fake.save(output)
21 |
22 | def main():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('-c', '--checkpoint', help='Network checkpoint path', required=True)
25 | parser.add_argument('-i', '--image', help='Original image path', required=True)
26 | parser.add_argument('-m', '--mask', help='Mask path', required=True)
27 | parser.add_argument('-o', '--output', help='Output (inpainted) image path', required=True)
28 |     parser.add_argument('-t', '--truncation', help='Truncation psi for the trade-off between quality and diversity. Defaults to 1.', type=float, default=None)
29 |
30 | args = parser.parse_args()
31 | generate(**vars(args))
32 |
33 | if __name__ == "__main__":
34 | main()
35 |
--------------------------------------------------------------------------------
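The same inpainting run can also be driven programmatically; the checkpoint name below is a placeholder for a downloaded model, and the example image and mask ship under imgs/:

from run_generator import generate

generate(checkpoint='co-mod-gan-ffhq-9-025000.pkl',   # placeholder checkpoint name
         image='imgs/example_image.jpg',
         mask='imgs/example_mask.jpg',
         output='output.jpg',
         truncation=None)                             # None = no truncation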
/run_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | import argparse
8 | import os
9 | import sys
10 |
11 | import dnnlib
12 | import dnnlib.tflib as tflib
13 |
14 | from metrics import metric_base
15 | from metrics.metric_defaults import metric_defaults
16 | from training import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def run(network_pkls, metrics, dataset, data_dir, mirror_augment, num_repeats, truncation, resume_with_new_nets):
21 | tflib.init_tf()
22 | dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0)
23 | num_gpus = dnnlib.submit_config.num_gpus
24 | truncations = [float(t) for t in truncation.split(',')] if truncation is not None else [None]
25 |
26 | for network_pkl in network_pkls.split(','):
27 | print('Evaluating metrics "%s" for "%s"...' % (','.join(metrics), network_pkl))
28 | metric_group = metric_base.MetricGroup([metric_defaults[metric] for metric in metrics])
29 | metric_group.run(network_pkl, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment,
30 | num_gpus=num_gpus, num_repeats=num_repeats, resume_with_new_nets=resume_with_new_nets, truncations=truncations)
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def _str_to_bool(v):
35 | if isinstance(v, bool):
36 | return v
37 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
38 | return True
39 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
40 | return False
41 | else:
42 | raise argparse.ArgumentTypeError('Boolean value expected.')
43 |
44 | #----------------------------------------------------------------------------
45 |
46 | def main():
47 | parser = argparse.ArgumentParser(
48 | description='Run CoModGAN metrics.',
49 | formatter_class=argparse.RawDescriptionHelpFormatter
50 | )
51 | parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
52 | parser.add_argument('--network', help='Network pickle filename', dest='network_pkls', required=True)
53 | parser.add_argument('--metrics', help='Metrics to compute (default: %(default)s)', default='ids10k', type=lambda x: x.split(','))
54 | parser.add_argument('--dataset', help='Training dataset', required=True)
55 | parser.add_argument('--data-dir', help='Dataset root directory', required=True)
56 | parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, type=_str_to_bool, metavar='BOOL')
57 | parser.add_argument('--num-gpus', help='Number of GPUs to use', type=int, default=1, metavar='N')
58 | parser.add_argument('--num-repeats', type=int, default=1)
59 | parser.add_argument('--truncation', type=str, default=None)
60 | parser.add_argument('--resume-with-new-nets', default=False, action='store_true')
61 |
62 | args = parser.parse_args()
63 |
64 | if not os.path.exists(args.data_dir):
65 |         print('Error: dataset root directory does not exist.')
66 | sys.exit(1)
67 |
68 | kwargs = vars(args)
69 | sc = dnnlib.SubmitConfig()
70 | sc.num_gpus = kwargs.pop('num_gpus')
71 | sc.submit_target = dnnlib.SubmitTarget.LOCAL
72 | sc.local.do_not_copy_source_files = True
73 | sc.run_dir_root = kwargs.pop('result_dir')
74 | sc.run_desc = 'run-metrics'
75 | dnnlib.submit_run(sc, 'run_metrics.run', **kwargs)
76 |
77 | #----------------------------------------------------------------------------
78 |
79 | if __name__ == "__main__":
80 | main()
81 |
82 | #----------------------------------------------------------------------------
83 |
--------------------------------------------------------------------------------
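A typical local invocation of this metrics runner might look as follows; the dataset name, snapshot path, and psi values below are illustrative placeholders, not files shipped with the repo:

    python run_metrics.py --data-dir=~/datasets --dataset=ffhq \
        --network=results/00000-co-mod-gan-ffhq-8gpu/network-snapshot-025000.pkl \
        --metrics=ids10k --num-gpus=4 --truncation=0.8,1.0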
/run_training.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | import argparse
8 | import copy
9 | import os
10 | import sys
11 |
12 | import dnnlib
13 | from dnnlib import EasyDict
14 |
15 | from metrics.metric_defaults import metric_defaults
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | def run(dataset, data_dir, result_dir, num_gpus, total_kimg, mirror_augment, metrics, resume, resume_with_new_nets, disable_style_mod, disable_cond_mod):
20 |
21 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
22 | G = EasyDict(func_name='training.co_mod_gan.G_main') # Options for generator network.
23 | D = EasyDict(func_name='training.co_mod_gan.D_co_mod_gan') # Options for discriminator network.
24 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
25 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
26 | G_loss = EasyDict(func_name='training.loss.G_masked_logistic_ns_l1') # Options for generator loss.
27 | D_loss = EasyDict(func_name='training.loss.D_masked_logistic_r1') # Options for discriminator loss.
28 | sched = EasyDict() # Options for TrainingSchedule.
29 | grid = EasyDict(size='8k', layout='random') # Options for setup_snapshot_image_grid().
30 | sc = dnnlib.SubmitConfig() # Options for dnnlib.submit_run().
31 | tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf().
32 |
33 | train.data_dir = data_dir
34 | train.total_kimg = total_kimg
35 | train.mirror_augment = mirror_augment
36 | train.image_snapshot_ticks = train.network_snapshot_ticks = 10
37 | sched.G_lrate_base = sched.D_lrate_base = 0.002
38 | sched.minibatch_size_base = 32
39 | sched.minibatch_gpu_base = 4
40 | D_loss.gamma = 10
41 | metrics = [metric_defaults[x] for x in metrics]
42 | desc = 'co-mod-gan'
43 |
44 | desc += '-' + os.path.basename(dataset)
45 | dataset_args = EasyDict(tfrecord_dir=dataset)
46 |
47 | assert num_gpus in [1, 2, 4, 8]
48 | sc.num_gpus = num_gpus
49 | desc += '-%dgpu' % num_gpus
50 |
51 | if resume is not None:
52 | resume_kimg = int(os.path.basename(resume).replace('.pkl', '').split('-')[-1])
53 | else:
54 | resume_kimg = 0
55 |
56 | if disable_style_mod:
57 | G.style_mod = False
58 |
59 | if disable_cond_mod:
60 | G.cond_mod = False
61 |
62 | sc.submit_target = dnnlib.SubmitTarget.LOCAL
63 | sc.local.do_not_copy_source_files = True
64 | kwargs = EasyDict(train)
65 | kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
66 | kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
67 | kwargs.update(resume_pkl=resume, resume_kimg=resume_kimg, resume_with_new_nets=resume_with_new_nets)
68 | kwargs.submit_config = copy.deepcopy(sc)
69 | kwargs.submit_config.run_dir_root = result_dir
70 | kwargs.submit_config.run_desc = desc
71 | dnnlib.submit_run(**kwargs)
72 |
73 | #----------------------------------------------------------------------------
74 |
75 | def _str_to_bool(v):
76 | if isinstance(v, bool):
77 | return v
78 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
79 | return True
80 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
81 | return False
82 | else:
83 | raise argparse.ArgumentTypeError('Boolean value expected.')
84 |
85 | def _parse_comma_sep(s):
86 | if s is None or s.lower() == 'none' or s == '':
87 | return []
88 | return s.split(',')
89 |
90 | #----------------------------------------------------------------------------
91 |
92 | _examples = '''examples:
93 |
94 | # Train CoModGAN using the FFHQ dataset
95 | python %(prog)s --data-dir=~/datasets --dataset=ffhq --metrics=ids10k --num-gpus=8
96 |
97 | '''
98 |
99 | def main():
100 | parser = argparse.ArgumentParser(
101 | description='Train CoModGAN.',
102 | epilog=_examples,
103 | formatter_class=argparse.RawDescriptionHelpFormatter
104 | )
105 | parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
106 | parser.add_argument('--data-dir', help='Dataset root directory', required=True)
107 | parser.add_argument('--dataset', help='Training dataset', required=True)
108 | parser.add_argument('--num-gpus', help='Number of GPUs (default: %(default)s)', default=1, type=int, metavar='N')
109 | parser.add_argument('--total-kimg', help='Training length in thousands of images (default: %(default)s)', metavar='KIMG', default=25000, type=int)
110 | parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
111 | parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='ids10k', type=_parse_comma_sep)
112 |     parser.add_argument('--resume', help='Network pickle to resume training from (default: train from scratch)', default=None)
113 |     parser.add_argument('--resume-with-new-nets', help='Construct new networks and copy in the weights from the resume pickle', default=False, action='store_true')
114 |     parser.add_argument('--disable-style-mod', help='Disable style modulation in the generator', default=False, action='store_true')
115 |     parser.add_argument('--disable-cond-mod', help='Disable conditional modulation in the generator', default=False, action='store_true')
116 |
117 | args = parser.parse_args()
118 |
119 | if not os.path.exists(args.data_dir):
120 |         print('Error: dataset root directory does not exist.')
121 | sys.exit(1)
122 |
123 | for metric in args.metrics:
124 | if metric not in metric_defaults:
125 |             print("Error: unknown metric '%s'" % metric)
126 | sys.exit(1)
127 |
128 | run(**vars(args))
129 |
130 | #----------------------------------------------------------------------------
131 |
132 | if __name__ == "__main__":
133 | main()
134 |
135 | #----------------------------------------------------------------------------
136 |
137 |
--------------------------------------------------------------------------------
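The resume logic above assumes snapshot pickles follow the training loop's 'network-snapshot-<kimg>.pkl' naming, so the starting kimg can be recovered from the filename. A minimal sketch (the path is a hypothetical example):

    import os

    resume = 'results/00000-co-mod-gan-ffhq-8gpu/network-snapshot-005000.pkl'  # hypothetical path
    resume_kimg = int(os.path.basename(resume).replace('.pkl', '').split('-')[-1])
    assert resume_kimg == 5000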
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | # empty
8 |
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Multi-resolution input data pipeline."""
8 |
9 | import os
10 | import glob
11 | import numpy as np
12 | import tensorflow as tf
13 | import dnnlib
14 | import dnnlib.tflib as tflib
15 |
16 | from .mask_generator import tf_mask_generator
17 |
18 | #----------------------------------------------------------------------------
19 | # Dataset class that loads data from tfrecords files.
20 |
21 | class TFRecordDataset:
22 | def __init__(self,
23 | tfrecord_dir, # Directory containing a collection of tfrecords files.
24 | resolution = None, # Dataset resolution, None = autodetect.
25 | label_file = None, # Relative path of the labels file, None = autodetect.
26 |         max_label_size  = 0,        # 0 = no labels, 'full' = full labels, <N> = first N label components.
27 | max_images = None, # Maximum number of images to use, None = use all images.
28 | repeat = True, # Repeat dataset indefinitely?
29 | shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling.
30 | prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching.
31 | buffer_mb = 256, # Read buffer size (megabytes).
32 | num_threads = 2, # Number of concurrent threads.
33 |         num_val_images  = 10000,    # Number of images reserved for the validation split.
34 |         compressed      = False):   # Are the images stored compressed (e.g. PNG/JPEG) in the tfrecords?
35 |
36 | self.tfrecord_dir = tfrecord_dir
37 | self.resolution = None
38 | self.resolution_log2 = None
39 | self.shape = [] # [channels, height, width]
40 | self.dtype = 'uint8'
41 | self.dynamic_range = [0, 255]
42 | self.label_file = label_file
43 | self.label_size = None # components
44 | self.label_dtype = None
45 | self.pix2pix = False
46 | self._np_labels = None
47 | self._tf_minibatch_in = None
48 | self._tf_labels_var = None
49 | self._tf_labels_dataset = None
50 | self._tf_datasets = dict()
51 | self._tf_val_datasets = dict()
52 | self._tf_iterator = None
53 | self._tf_val_iterator = None
54 | self._tf_init_ops = dict()
55 | self._tf_val_init_ops = dict()
56 | self._tf_minibatch_np = None
57 | self._tf_minibatch_val_np = None
58 | self._tf_masks_iterator_np = None
59 | self._cur_minibatch = -1
60 | self._cur_lod = -1
61 | self._hole_range = -1
62 |
63 | # List tfrecords files and inspect their shapes.
64 | assert os.path.isdir(self.tfrecord_dir)
65 | tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))
66 | assert len(tfr_files) >= 1
67 | tfr_shapes = []
68 | for tfr_file in tfr_files:
69 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
70 | for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
71 | ex = tf.train.Example()
72 | ex.ParseFromString(record)
73 | features = ex.features.feature
74 | if 'compressed' in features and features['compressed'].int64_list.value[0]:
75 | compressed = True
76 | if 'num_val_images' in features:
77 | num_val_images = features['num_val_images'].int64_list.value[0]
78 | tfr_shapes.append(features['shape'].int64_list.value)
79 | break
80 |
81 | # Determine shape and resolution.
82 | max_shape = max(tfr_shapes, key=np.prod)
83 |
84 | if max_shape[0] > 3:
85 | self.pix2pix = True
86 |
87 | self.resolution = resolution if resolution is not None else max_shape[1]
88 | self.resolution_log2 = int(np.log2(self.resolution))
89 | self.shape = [max_shape[0], self.resolution, self.resolution]
90 | tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]
91 | assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
92 | assert all(shape[1] == shape[2] for shape in tfr_shapes)
93 | assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))
94 |
95 | # Autodetect label filename.
96 | if self.label_file is None:
97 | guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
98 | if len(guess):
99 | self.label_file = guess[0]
100 | elif not os.path.isfile(self.label_file):
101 | guess = os.path.join(self.tfrecord_dir, self.label_file)
102 | if os.path.isfile(guess):
103 | self.label_file = guess
104 |
105 | # Load labels.
106 | assert max_label_size == 'full' or max_label_size >= 0
107 |         self._np_labels = np.zeros([1<<30, 0], dtype=np.float32) # Placeholder: huge row count, zero-width labels.
108 | if self.label_file is not None and max_label_size != 0:
109 | self._np_labels = np.load(self.label_file)
110 | assert self._np_labels.ndim == 2
111 | if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
112 | self._np_labels = self._np_labels[:, :max_label_size]
113 | if max_images is not None and self._np_labels.shape[0] > max_images:
114 | self._np_labels = self._np_labels[:max_images]
115 | self.label_size = self._np_labels.shape[1]
116 | self.label_dtype = self._np_labels.dtype.name
117 |
118 | # Build TF expressions.
119 | with tf.name_scope('Dataset'), tf.device('/cpu:0'):
120 | self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
121 | self._tf_hole_range = tf.placeholder(tf.float32, name='hole_range', shape=[2])
122 | self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')
123 | self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)
124 | for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
125 | if tfr_lod < 0:
126 | continue
127 | dset_raw = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
128 | if max_images is not None:
129 | dset_raw = dset_raw.take(max_images)
130 | for tf_datasets, dset in [(self._tf_val_datasets, dset_raw.take(num_val_images)), (self._tf_datasets, dset_raw.skip(num_val_images))]:
131 | if compressed:
132 | dset = dset.map(self.parse_and_decode_tfrecord_tf, num_parallel_calls=num_threads)
133 | else:
134 | dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads)
135 | dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
136 | bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
137 | if shuffle_mb > 0:
138 | dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
139 | if repeat:
140 | dset = dset.repeat()
141 | if prefetch_mb > 0:
142 | dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
143 | dset = dset.batch(self._tf_minibatch_in)
144 | tf_datasets[tfr_lod] = dset
145 | self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)
146 | self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()}
147 | self._tf_val_iterator = tf.data.Iterator.from_structure(self._tf_val_datasets[0].output_types, self._tf_val_datasets[0].output_shapes)
148 | self._tf_val_init_ops = {lod: self._tf_val_iterator.make_initializer(dset) for lod, dset in self._tf_val_datasets.items()}
149 |
150 | self._tf_masks_dataset = tf_mask_generator(self.resolution, self._tf_hole_range).batch(self._tf_minibatch_in).prefetch(64)
151 | self._tf_masks_iterator = self._tf_masks_dataset.make_initializable_iterator()
152 |
153 | def close(self):
154 | pass
155 |
156 | # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf().
157 | def configure(self, minibatch_size, lod=0, hole_range=[0,1]):
158 | lod = int(np.floor(lod))
159 | assert minibatch_size >= 1 and lod in self._tf_datasets
160 | if self._cur_minibatch != minibatch_size or self._cur_lod != lod or (hole_range is not None and self._hole_range != hole_range):
161 | self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
162 | self._tf_val_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
163 | self._cur_minibatch = minibatch_size
164 | self._cur_lod = lod
165 | if hole_range is not None:
166 | self._tf_masks_iterator.initializer.run({self._tf_minibatch_in: minibatch_size, self._tf_hole_range: hole_range})
167 | self._hole_range = hole_range
168 |
169 | # Get next minibatch as TensorFlow expressions.
170 | def get_minibatch_tf(self): # => images, labels
171 | return self._tf_iterator.get_next()
172 |
173 | def get_minibatch_val_tf(self): # => images, labels
174 | return self._tf_val_iterator.get_next()
175 |
176 | # Get next minibatch as NumPy arrays.
177 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels
178 | self.configure(minibatch_size, lod)
179 | with tf.name_scope('Dataset'):
180 | if self._tf_minibatch_np is None:
181 | self._tf_minibatch_np = self.get_minibatch_tf()
182 | return tflib.run(self._tf_minibatch_np)
183 |
184 | def get_minibatch_val_np(self, minibatch_size, lod=0): # => images, labels
185 | self.configure(minibatch_size, lod)
186 | with tf.name_scope('Dataset'):
187 | if self._tf_minibatch_val_np is None:
188 | self._tf_minibatch_val_np = self.get_minibatch_val_tf()
189 | return tflib.run(self._tf_minibatch_val_np)
190 |
191 |     # Get random masks as TensorFlow expressions.
192 |     def get_random_masks_tf(self): # => masks
193 | return self._tf_masks_iterator.get_next()
194 |
195 |     # Get random masks as NumPy arrays.
196 | def get_random_masks_np(self, minibatch_size, hole_range=[0,1]):
197 | self.configure(minibatch_size, hole_range=hole_range)
198 | with tf.name_scope('Dataset'):
199 | if self._tf_masks_iterator_np is None:
200 | self._tf_masks_iterator_np = self.get_random_masks_tf()
201 | return tflib.run(self._tf_masks_iterator_np)
202 |
203 | # Get random labels as TensorFlow expression.
204 | def get_random_labels_tf(self, minibatch_size): # => labels
205 | with tf.name_scope('Dataset'):
206 | if self.label_size > 0:
207 | with tf.device('/cpu:0'):
208 | return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32))
209 | return tf.zeros([minibatch_size, 0], self.label_dtype)
210 |
211 | # Get random labels as NumPy array.
212 | def get_random_labels_np(self, minibatch_size): # => labels
213 | if self.label_size > 0:
214 | return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])]
215 | return np.zeros([minibatch_size, 0], self.label_dtype)
216 |
217 | # Parse individual image from a tfrecords file into TensorFlow expression.
218 | @staticmethod
219 | def parse_tfrecord_tf(record):
220 | features = tf.parse_single_example(record, features={
221 | 'shape': tf.FixedLenFeature([3], tf.int64),
222 | 'data': tf.FixedLenFeature([], tf.string)})
223 | data = tf.decode_raw(features['data'], tf.uint8)
224 | return tf.reshape(data, features['shape'])
225 |
226 | @staticmethod
227 | def parse_and_decode_tfrecord_tf(record):
228 | features = tf.parse_single_example(record, features={
229 | 'shape': tf.FixedLenFeature([3], tf.int64),
230 | 'data': tf.FixedLenFeature([], tf.string)})
231 | shape = tf.cast(features['shape'], 'int32')
232 | data = tf.image.decode_image(features['data'])
233 | data = tf.image.resize_with_crop_or_pad(data, shape[1], shape[2])
234 | return tf.broadcast_to(tf.transpose(data, [2, 0, 1]), shape)
235 |
236 | #----------------------------------------------------------------------------
237 | # Helper func for constructing a dataset object using the given options.
238 |
239 | def load_dataset(class_name=None, data_dir=None, verbose=False, **kwargs):
240 | kwargs = dict(kwargs)
241 | if 'tfrecord_dir' in kwargs:
242 | if class_name is None:
243 | class_name = __name__ + '.TFRecordDataset'
244 | if data_dir is not None:
245 | kwargs['tfrecord_dir'] = os.path.join(data_dir, kwargs['tfrecord_dir'])
246 |
247 | assert class_name is not None
248 | if verbose:
249 | print('Streaming data using %s...' % class_name)
250 | dataset = dnnlib.util.get_obj_by_name(class_name)(**kwargs)
251 | if verbose:
252 | print('Dataset shape =', np.int32(dataset.shape).tolist())
253 | print('Dynamic range =', dataset.dynamic_range)
254 | print('Label size =', dataset.label_size)
255 | return dataset
256 |
257 | #----------------------------------------------------------------------------
258 |
--------------------------------------------------------------------------------
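A minimal usage sketch for the pipeline above, assuming a tfrecords dataset already exists under a hypothetical 'datasets/ffhq' directory and that a default TF session has been created via tflib.init_tf():

    import dnnlib.tflib as tflib
    from training import dataset

    tflib.init_tf()
    dset = dataset.load_dataset(data_dir='datasets', tfrecord_dir='ffhq', verbose=True, shuffle_mb=0)
    images, labels = dset.get_minibatch_np(4)                    # NCHW uint8 images plus labels
    masks = dset.get_random_masks_np(4, hole_range=[0.2, 0.8])   # (4, 1, res, res) float32; 1 = visible pixel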
/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Loss functions."""
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | import dnnlib.tflib as tflib
12 | from dnnlib.tflib.autosummary import autosummary
13 |
14 | def G_masked_logistic_ns_l1(G, D, opt, training_set, minibatch_size, reals, masks, l1_weight=0):
15 | _ = opt
16 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
17 | labels = training_set.get_random_labels_tf(minibatch_size)
18 | fake_images_out = G.get_output_for(latents, labels, reals, masks, is_training=True)
19 | fake_scores_out = D.get_output_for(fake_images_out, labels, masks, is_training=True)
20 | logistic_loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
21 | logistic_loss = autosummary('Loss/logistic_loss', logistic_loss)
22 | l1_loss = tf.reduce_mean(tf.abs(fake_images_out - reals), axis=[1,2,3])
23 | l1_loss = autosummary('Loss/l1_loss', l1_loss)
24 | loss = logistic_loss + l1_loss * l1_weight
25 | return loss, None
26 |
27 | def D_masked_logistic_r1(G, D, opt, training_set, minibatch_size, reals, labels, masks, gamma=10.0):
28 | _ = opt, training_set
29 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
30 | fake_images_out = G.get_output_for(latents, labels, reals, masks, is_training=True)
31 | real_scores_out = D.get_output_for(reals, labels, masks, is_training=True)
32 | fake_scores_out = D.get_output_for(fake_images_out, labels, masks, is_training=True)
33 | real_scores_out = autosummary('Loss/scores/real', real_scores_out)
34 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
35 | loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
36 | loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out))
37 |
38 | with tf.name_scope('GradientPenalty'):
39 | real_grads = tf.gradients(tf.reduce_sum(real_scores_out), [reals])[0]
40 | gradient_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3])
41 | gradient_penalty = autosummary('Loss/gradient_penalty', gradient_penalty)
42 | reg = gradient_penalty * (gamma * 0.5)
43 | return loss, reg
44 |
45 | #----------------------------------------------------------------------------
46 |
--------------------------------------------------------------------------------
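The inline comments in the losses above rely on the standard softplus identities; a quick numeric sanity check (the test value is arbitrary):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    x = 0.7
    assert np.isclose(np.log1p(np.exp(-x)), -np.log(sigmoid(x)))      # softplus(-x) = -log(sigmoid(x))
    assert np.isclose(np.log1p(np.exp(x)), -np.log(1 - sigmoid(x)))   # softplus(x)  = -log(1 - sigmoid(x))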
/training/mask_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageDraw
3 | import math
4 | import random
5 |
6 | import tensorflow as tf
7 |
8 | def RandomBrush(
9 | max_tries,
10 | s,
11 | min_num_vertex = 4,
12 | max_num_vertex = 18,
13 | mean_angle = 2*math.pi / 5,
14 | angle_range = 2*math.pi / 15,
15 | min_width = 12,
16 | max_width = 48):
17 | H, W = s, s
18 | average_radius = math.sqrt(H*H+W*W) / 8
19 | mask = Image.new('L', (W, H), 0)
20 | for _ in range(np.random.randint(max_tries)):
21 | num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
22 | angle_min = mean_angle - np.random.uniform(0, angle_range)
23 | angle_max = mean_angle + np.random.uniform(0, angle_range)
24 | angles = []
25 | vertex = []
26 | for i in range(num_vertex):
27 | if i % 2 == 0:
28 | angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
29 | else:
30 | angles.append(np.random.uniform(angle_min, angle_max))
31 |
32 |         w, h = mask.size # PIL's Image.size is (width, height).
33 | vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
34 | for i in range(num_vertex):
35 | r = np.clip(
36 | np.random.normal(loc=average_radius, scale=average_radius//2),
37 | 0, 2*average_radius)
38 | new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
39 | new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
40 | vertex.append((int(new_x), int(new_y)))
41 |
42 | draw = ImageDraw.Draw(mask)
43 | width = int(np.random.uniform(min_width, max_width))
44 | draw.line(vertex, fill=1, width=width)
45 | for v in vertex:
46 | draw.ellipse((v[0] - width//2,
47 | v[1] - width//2,
48 | v[0] + width//2,
49 | v[1] + width//2),
50 | fill=1)
51 |         if np.random.random() > 0.5:
52 |             mask = mask.transpose(Image.FLIP_LEFT_RIGHT) # transpose() returns a new image; keep the result.
53 |         if np.random.random() > 0.5:
54 |             mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
55 | mask = np.asarray(mask, np.uint8)
56 | if np.random.random() > 0.5:
57 | mask = np.flip(mask, 0)
58 | if np.random.random() > 0.5:
59 | mask = np.flip(mask, 1)
60 | return mask
61 |
62 | def RandomMask(s, hole_range=[0,1]):
63 | coef = min(hole_range[0] + hole_range[1], 1.0)
64 | while True:
65 | mask = np.ones((s, s), np.uint8)
66 | def Fill(max_size):
67 | w, h = np.random.randint(max_size), np.random.randint(max_size)
68 | ww, hh = w // 2, h // 2
69 | x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
70 | mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
71 | def MultiFill(max_tries, max_size):
72 | for _ in range(np.random.randint(max_tries)):
73 | Fill(max_size)
74 | MultiFill(int(10 * coef), s // 2)
75 | MultiFill(int(5 * coef), s)
76 | mask = np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
77 | hole_ratio = 1 - np.mean(mask)
78 | if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
79 | continue
80 | return mask[np.newaxis, ...].astype(np.float32)
81 |
82 | def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
83 |     return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis=0)
84 |
85 | def tf_mask_generator(s, tf_hole_range):
86 | def random_mask_generator(hole_range):
87 | while True:
88 | yield RandomMask(s, hole_range=hole_range)
89 | return tf.data.Dataset.from_generator(random_mask_generator, tf.float32, tf.TensorShape([1, s, s]), (tf_hole_range,))
--------------------------------------------------------------------------------
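The mask sampler can be exercised standalone; a short sketch (resolution and hole range are arbitrary example values):

    import numpy as np
    from training.mask_generator import RandomMask, BatchRandomMask

    mask = RandomMask(512, hole_range=[0.2, 0.8])    # (1, 512, 512) float32; 1 = visible, 0 = hole
    assert 0.2 < 1.0 - np.mean(mask) < 0.8           # the rejection loop keeps the hole ratio in range
    batch = BatchRandomMask(4, 512)                  # (4, 1, 512, 512)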
/training/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Miscellaneous utility functions."""
8 |
9 | import os
10 | import pickle
11 | import numpy as np
12 | import PIL.Image
13 | import PIL.ImageFont
14 | import dnnlib
15 | from dnnlib import tflib
16 |
17 | import tensorflow as tf
18 |
19 | #----------------------------------------------------------------------------
20 | # Convenience wrappers for pickle that are able to load data produced by
21 | # older versions of the code, and from external URLs.
22 |
23 | def open_file_or_url(file_or_url):
24 | if dnnlib.util.is_url(file_or_url):
25 | return dnnlib.util.open_url(file_or_url, cache_dir='.stylegan2-cache')
26 | return open(file_or_url, 'rb')
27 |
28 | def load_pkl(file_or_url):
29 | with open_file_or_url(file_or_url) as file:
30 | return pickle.load(file, encoding='latin1')
31 |
32 | def save_pkl(obj, filename):
33 | with open(filename, 'wb') as file:
34 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
35 |
36 | #----------------------------------------------------------------------------
37 | # Image utils.
38 |
39 | def adjust_dynamic_range(data, drange_in, drange_out):
40 | if drange_in != drange_out:
41 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
42 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
43 | data = data * scale + bias
44 | return data
45 |
46 | def create_image_grid(images, grid_size=None, pix2pix=False):
47 | if pix2pix:
48 | images, _ = np.split(images, 2, axis=1)
49 | assert images.ndim == 3 or images.ndim == 4
50 | num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]
51 |
52 | if grid_size is not None:
53 | grid_w, grid_h = tuple(grid_size)
54 | else:
55 | grid_w = max(int(np.ceil(np.sqrt(num))), 1)
56 | grid_h = max((num - 1) // grid_w + 1, 1)
57 |
58 | grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype)
59 | for idx in range(num):
60 | x = (idx % grid_w) * img_w
61 | y = (idx // grid_w) * img_h
62 | grid[..., y : y + img_h, x : x + img_w] = images[idx]
63 | return grid
64 |
65 | def convert_to_pil_image(image, drange=[0,1]):
66 | assert image.ndim == 2 or image.ndim == 3
67 | if image.ndim == 3:
68 | if image.shape[0] == 1:
69 | image = image[0] # grayscale CHW => HW
70 | else:
71 | image = image.transpose(1, 2, 0) # CHW -> HWC
72 |
73 | if drange is not None:
74 | image = adjust_dynamic_range(image, drange, [0,255])
75 | image = np.rint(image).clip(0, 255).astype(np.uint8)
76 | fmt = 'RGB' if image.ndim == 3 else 'L'
77 | return PIL.Image.fromarray(image, fmt)
78 |
79 | def save_image_grid(images, filename, drange=[0,1], grid_size=None, pix2pix=False):
80 | convert_to_pil_image(create_image_grid(images, grid_size, pix2pix=pix2pix), drange).save(filename)
81 |
82 | def apply_mirror_augment(minibatch):
83 | mask = np.random.rand(minibatch.shape[0]) < 0.5
84 | minibatch = np.array(minibatch)
85 | minibatch[mask] = minibatch[mask, :, :, ::-1]
86 | return minibatch
87 |
88 | #----------------------------------------------------------------------------
89 | # Loading data from previous training runs.
90 |
91 | def parse_config_for_previous_run(run_dir):
92 | with open(os.path.join(run_dir, 'submit_config.pkl'), 'rb') as f:
93 | data = pickle.load(f)
94 | data = data.get('run_func_kwargs', {})
95 | return dict(train=data, dataset=data.get('dataset_args', {}))
96 |
97 | #----------------------------------------------------------------------------
98 | # Size and contents of the image snapshot grids that are exported
99 | # periodically during training.
100 |
101 | def setup_snapshot_image_grid(training_set,
102 |     size    = '1080p',      # '1080p' = to be viewed on 1080p display, '4k' = 4k display, '8k' = 8k display.
103 |     layout  = 'random'):    # 'random' = grid contents are selected randomly, 'row_per_class' / 'col_per_class' / 'class4x4' = class-conditional layouts.
104 |
105 | # Select size.
106 | gw = 1; gh = 1
107 | if size == '1080p':
108 | gw = np.clip(1920 // training_set.shape[2], 3, 32)
109 | gh = np.clip(1080 // training_set.shape[1], 2, 32)
110 | if size == '4k':
111 | gw = np.clip(3840 // training_set.shape[2], 7, 32)
112 | gh = np.clip(2160 // training_set.shape[1], 4, 32)
113 | if size == '8k':
114 | gw = np.clip(7680 // training_set.shape[2], 7, 32)
115 | gh = np.clip(4320 // training_set.shape[1], 4, 32)
116 |
117 | # Initialize data arrays.
118 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)
119 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
120 |
121 | # Random layout.
122 | if layout == 'random':
123 | reals[:], labels[:] = training_set.get_minibatch_val_np(gw * gh)
124 |
125 | # Class-conditional layouts.
126 | class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4])
127 | if layout in class_layouts:
128 | bw, bh = class_layouts[layout]
129 | nw = (gw - 1) // bw + 1
130 | nh = (gh - 1) // bh + 1
131 | blocks = [[] for _i in range(nw * nh)]
132 | for _iter in range(1000000):
133 | real, label = training_set.get_minibatch_val_np(1)
134 | idx = np.argmax(label[0])
135 | while idx < len(blocks) and len(blocks[idx]) >= bw * bh:
136 | idx += training_set.label_size
137 | if idx < len(blocks):
138 | blocks[idx].append((real, label))
139 | if all(len(block) >= bw * bh for block in blocks):
140 | break
141 | for i, block in enumerate(blocks):
142 | for j, (real, label) in enumerate(block):
143 | x = (i % nw) * bw + j % bw
144 | y = (i // nw) * bh + j // bw
145 | if x < gw and y < gh:
146 | reals[x + y * gw] = real[0]
147 | labels[x + y * gw] = label[0]
148 |
149 | masks = training_set.get_random_masks_np(gw * gh)
150 |
151 | return (gw, gh), reals, labels, masks
152 |
153 | #----------------------------------------------------------------------------
154 |
--------------------------------------------------------------------------------
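For reference, a short sketch of the image utilities above (array contents and the output filename are illustrative):

    import numpy as np
    from training import misc

    imgs = np.random.randint(0, 256, size=(6, 3, 64, 64), dtype=np.uint8)              # NCHW uint8
    scaled = misc.adjust_dynamic_range(imgs.astype(np.float32), [0, 255], [-1, 1])     # now in [-1, 1]
    misc.save_image_grid(imgs, 'grid.png', drange=[0, 255], grid_size=(3, 2))          # 3x2 image grid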