├── result └── 生成结果在这里 ├── examples ├── edit_age.jpg ├── example1.png ├── example2.png ├── example3.png ├── edit_angle.jpg ├── edit_smile.jpg ├── 64_examples.jpg ├── edit_exposure.jpg └── edit_gender.jpg ├── model └── 模型下载后放在这里.txt ├── dnnlib ├── __pycache__ │ ├── util.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── tflib │ ├── __pycache__ │ │ ├── tfutil.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── network.cpython-36.pyc │ │ ├── autosummary.cpython-36.pyc │ │ └── optimizer.cpython-36.pyc │ ├── __init__.py │ ├── autosummary.py │ ├── tfutil.py │ ├── optimizer.py │ └── network.py ├── submission │ ├── __pycache__ │ │ ├── submit.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── run_context.cpython-36.pyc │ ├── __init__.py │ ├── _internal │ │ └── run.py │ ├── run_context.py │ └── submit.py ├── __init__.py └── util.py ├── generate_model.py └── README.md /result/生成结果在这里: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/edit_age.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_age.jpg -------------------------------------------------------------------------------- /examples/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/example1.png -------------------------------------------------------------------------------- /examples/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/example2.png -------------------------------------------------------------------------------- /examples/example3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/example3.png -------------------------------------------------------------------------------- /examples/edit_angle.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_angle.jpg -------------------------------------------------------------------------------- /examples/edit_smile.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_smile.jpg -------------------------------------------------------------------------------- /model/模型下载后放在这里.txt: -------------------------------------------------------------------------------- 1 | 百度网盘: 2 | 3 | 链接:https://pan.baidu.com/s/1_qeDNrmSrvzR3j-xqTa7zQ 4 | 提取码:rybs 5 | 6 | 7 | -------------------------------------------------------------------------------- /examples/64_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/64_examples.jpg -------------------------------------------------------------------------------- /examples/edit_exposure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_exposure.jpg -------------------------------------------------------------------------------- /examples/edit_gender.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_gender.jpg 
-------------------------------------------------------------------------------- /dnnlib/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/submission/__pycache__/submit.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/submission/__pycache__/submit.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/submission/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/submission/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/submission/__pycache__/run_context.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/submission/__pycache__/run_context.cpython-36.pyc -------------------------------------------------------------------------------- /dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import run_context 9 | from . import submit 10 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import autosummary 9 | from . import network 10 | from . import optimizer 11 | from . import tfutil 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . 
import submission 9 | 10 | from .submission.run_context import RunContext 11 | 12 | from .submission.submit import SubmitTarget 13 | from .submission.submit import PathType 14 | from .submission.submit import SubmitConfig 15 | from .submission.submit import get_path_from_template 16 | from .submission.submit import submit_run 17 | 18 | from .util import EasyDict 19 | 20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. 21 | -------------------------------------------------------------------------------- /dnnlib/submission/_internal/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for launching run functions in computing clusters. 9 | 10 | During the submit process, this file is copied to the appropriate run dir. 11 | When the job is launched in the cluster, this module is the first thing that 12 | is run inside the docker container. 
13 | """ 14 | 15 | import os 16 | import pickle 17 | import sys 18 | 19 | # PYTHONPATH should have been set so that the run_dir/src is in it 20 | import dnnlib 21 | 22 | def main(): 23 | if not len(sys.argv) >= 4: 24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") 25 | 26 | run_dir = str(sys.argv[1]) 27 | task_name = str(sys.argv[2]) 28 | host_name = str(sys.argv[3]) 29 | 30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl") 31 | 32 | # SubmitConfig should have been pickled to the run dir 33 | if not os.path.exists(submit_config_path): 34 | raise RuntimeError("SubmitConfig pickle file does not exist!") 35 | 36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) 37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name) 38 | 39 | submit_config.task_name = task_name 40 | submit_config.host_name = host_name 41 | 42 | dnnlib.submission.submit.run_wrapper(submit_config) 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /generate_model.py: -------------------------------------------------------------------------------- 1 | # Thanks to StyleGAN provider —— Copyright (c) 2019, NVIDIA CORPORATION. 2 | # 3 | # This work is trained by Copyright(c) 2018, seeprettyface.com, BUPT_GWY. 4 | 5 | """Minimal script for generating an image using pre-trained StyleGAN generator.""" 6 | 7 | import os 8 | import pickle 9 | import numpy as np 10 | import PIL.Image 11 | import dnnlib.tflib as tflib 12 | 13 | synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8) 14 | 15 | def text_save(file, data): # save generate code, which can be modified to generate customized style 16 | for i in range(len(data[0])): 17 | s = str(data[0][i])+'\n' 18 | file.write(s) 19 | 20 | def main(): 21 | # Initialize TensorFlow. 
22 | tflib.init_tf() 23 | 24 | # Load pre-trained network. 25 | model_path = 'model/generator_model.pkl' 26 | 27 | # Prepare result folder 28 | result_dir = 'result' 29 | os.makedirs(result_dir, exist_ok=True) 30 | os.makedirs(result_dir + '/generate_code', exist_ok=True) 31 | 32 | with open(model_path, "rb") as f: 33 | _G, _D, Gs = pickle.load(f, encoding='latin1') 34 | 35 | # Print network details. 36 | Gs.print_layers() 37 | 38 | # Generate pictures 39 | generate_num = 20 40 | for i in range(generate_num): 41 | 42 | # Generate latent. 43 | latents = np.random.randn(1, Gs.input_shape[1]) 44 | 45 | # Save latent. 46 | txt_filename = os.path.join(result_dir, 'generate_code/' + str(i).zfill(4) + '.txt') 47 | file = open(txt_filename, 'w') 48 | text_save(file, latents) 49 | 50 | # Generate image. 51 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 52 | images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt) 53 | 54 | # Save image. 55 | png_filename = os.path.join(result_dir, str(i).zfill(4)+'.png') 56 | PIL.Image.fromarray(images[0], 'RGB').save(png_filename) 57 | 58 | # Close file. 59 | file.close() 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 超模脸生成器 2 |
3 |   更新:基于StyleGAN2制作的新版生成器消除了生成图片中的水滴状斑点和扭曲/损坏现象,图片质量大幅提升。点此查看新版
4 | --------------------------------------------------------------------------------------------------------------------


5 |   注明:之前做的一些有意思的人脸生成器,现在全部开源分享出来。它们的主要作用是生成各类型的人脸素材,供大家任意使用而无须担心人脸版权问题。在定制人脸方面,开源的全系列生成器包括:黄种人脸生成器、网红脸生成器、明星脸生成器、超模脸生成器、萌娃脸生成器;同时,人脸属性编辑器能够对上述所有生成器生成的人物进行调整和改变。
6 |   此项目已免费开源,可免费使用,模型版权归 www.seeprettyface.com 所有。
7 |


8 |   这是一个用StyleGAN训练出的超模人脸生成器,生成效果如下所示。


9 | 10 | # 生成示例 11 | 12 | ##   单张样本 13 |      ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/example1.png)

14 |      ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/example2.png)

15 |      ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/example3.png)

16 | 17 | ## 概览(有筛选) 18 | ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/64_examples.jpg) 19 |


20 | 更多生成样本可前往[这里](https://pan.baidu.com/s/1G5lTsk1TJPZMCHqudQqqYg)(提取码:2A5W)下载,这是一个包含1万张生成样本的超模脸数据集。


21 | 22 | # 超模脸属性编辑 23 |   人脸属性编辑支持在年龄、笑容、角度、性别和光照等23个维度上对生成人物进行调整(详情请参见[人脸属性编辑器](https://github.com/a312863063/seeprettyface-face_editor))。这里仅展示其中5种基本调整的示例。 24 | ## 笑容调整 25 | ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/edit_smile.jpg) 26 |

27 | ## 年龄调整 28 | ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/edit_age.jpg) 29 |

30 | ## 角度调整 31 | ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/edit_angle.jpg) 32 |

33 | ## 性别调整 34 | ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/edit_gender.jpg) 35 |

36 | ## 光照调整 37 | ![Image text](https://github.com/a312863063/seeprettyface-generator-model/blob/master/examples/edit_exposure.jpg) 38 |


39 | 40 | # 运行代码 41 | ## 环境配置 42 |   Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
43 |   64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
44 |   TensorFlow 1.10.0 or newer with GPU support.
45 |   NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.
46 | 47 | ## 运行步骤 48 |   1.按照model文件夹中txt文件给出的百度网盘地址下载模型,并以generator_model.pkl为文件名放入model文件夹(generate_model.py从model/generator_model.pkl读取模型)
49 |   2.运行generate_model.py,生成的20张图片将保存在result文件夹中,对应的潜码(latent)保存在result/generate_code文件夹中
50 |


51 | ## 了解技术原理 & 获取训练集:[点此进入](http://www.seeprettyface.com/) 52 | ![Image text](https://github.com/a312863063/seeprettyface/blob/master/EP001-01.png)


53 | 54 | ## 小小的赞助~ 55 |

56 | Sample 57 |

58 | 若对您有帮助可给予小小的赞助~ 59 |

60 |

61 |


62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helpers for managing the run/training loop.""" 9 | 10 | import datetime 11 | import json 12 | import os 13 | import pprint 14 | import time 15 | import types 16 | 17 | from typing import Any 18 | 19 | from . import submit 20 | 21 | 22 | class RunContext(object): 23 | """Helper class for managing the run/training loop. 24 | 25 | The context will hide the implementation details of a basic run/training loop. 26 | It will set things up properly, tell if run should be stopped, and then cleans up. 27 | User should call update periodically and use should_stop to determine if run should be stopped. 28 | 29 | Args: 30 | submit_config: The SubmitConfig that is used for the current run. 31 | config_module: The whole config module that is used for the current run. 32 | max_epoch: Optional cached value for the max_epoch variable used in update. 
33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): 36 | self.submit_config = submit_config 37 | self.should_stop_flag = False 38 | self.has_closed = False 39 | self.start_time = time.time() 40 | self.last_update_time = time.time() 41 | self.last_update_interval = 0.0 42 | self.max_epoch = max_epoch 43 | 44 | # pretty print the all the relevant content of the config module to a text file 45 | if config_module is not None: 46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: 47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} 48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) 49 | 50 | # write out details about the run to a text file 51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 54 | 55 | def __enter__(self) -> "RunContext": 56 | return self 57 | 58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 59 | self.close() 60 | 61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: 62 | """Do general housekeeping and keep the state of the context up-to-date. 
63 | Should be called often enough but not in a tight loop.""" 64 | assert not self.has_closed 65 | 66 | self.last_update_interval = time.time() - self.last_update_time 67 | self.last_update_time = time.time() 68 | 69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): 70 | self.should_stop_flag = True 71 | 72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | 99 | self.has_closed = True 100 | -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for adding automatically tracked values to Tensorboard. 9 | 10 | Autosummary creates an identity op that internally keeps track of the input 11 | values and automatically shows up in TensorBoard. The reported value 12 | represents an average over input components. The average is accumulated 13 | constantly over time and flushed when save_summaries() is called. 14 | 15 | Notes: 16 | - The output tensor must be used as an input for something else in the 17 | graph. Otherwise, the autosummary op will not get executed, and the average 18 | value will not get accumulated. 19 | - It is perfectly fine to include autosummaries with the same name in 20 | several places throughout the graph, even if they are executed concurrently. 21 | - It is ok to also pass in a python scalar or numpy array. In this case, it 22 | is added to the average immediately. 23 | """ 24 | 25 | from collections import OrderedDict 26 | import numpy as np 27 | import tensorflow as tf 28 | from tensorboard import summary as summary_lib 29 | from tensorboard.plugins.custom_scalar import layout_pb2 30 | 31 | from . import tfutil 32 | from .tfutil import TfExpression 33 | from .tfutil import TfExpressionEx 34 | 35 | _dtype = tf.float64 36 | _vars = OrderedDict() # name => [var, ...] 
37 | _immediate = OrderedDict() # name => update_op, update_value 38 | _finalized = False 39 | _merge_op = None 40 | 41 | 42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 43 | """Internal helper for creating autosummary accumulators.""" 44 | assert not _finalized 45 | name_id = name.replace("/", "_") 46 | v = tf.cast(value_expr, _dtype) 47 | 48 | if v.shape.is_fully_defined(): 49 | size = np.prod(tfutil.shape_to_list(v.shape)) 50 | size_expr = tf.constant(size, dtype=_dtype) 51 | else: 52 | size = None 53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 54 | 55 | if size == 1: 56 | if v.shape.ndims != 0: 57 | v = tf.reshape(v, []) 58 | v = [size_expr, v, tf.square(v)] 59 | else: 60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 62 | 63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 66 | 67 | if name in _vars: 68 | _vars[name].append(var) 69 | else: 70 | _vars[name] = [var] 71 | return update_op 72 | 73 | 74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: 75 | """Create a new autosummary. 76 | 77 | Args: 78 | name: Name to use in TensorBoard 79 | value: TensorFlow expression or python value to track 80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 
81 | 82 | Example use of the passthru mechanism: 83 | 84 | n = autosummary('l2loss', loss, passthru=n) 85 | 86 | This is a shorthand for the following code: 87 | 88 | with tf.control_dependencies([autosummary('l2loss', loss)]): 89 | n = tf.identity(n) 90 | """ 91 | tfutil.assert_tf_initialized() 92 | name_id = name.replace("/", "_") 93 | 94 | if tfutil.is_tf_expression(value): 95 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 96 | update_op = _create_var(name, value) 97 | with tf.control_dependencies([update_op]): 98 | return tf.identity(value if passthru is None else passthru) 99 | 100 | else: # python scalar or numpy array 101 | if name not in _immediate: 102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 103 | update_value = tf.placeholder(_dtype) 104 | update_op = _create_var(name, update_value) 105 | _immediate[name] = update_op, update_value 106 | 107 | update_op, update_value = _immediate[name] 108 | tfutil.run(update_op, {update_value: value}) 109 | return value if passthru is None else passthru 110 | 111 | 112 | def finalize_autosummaries() -> None: 113 | """Create the necessary ops to include autosummaries in TensorBoard report. 114 | Note: This should be done only once per graph. 115 | """ 116 | global _finalized 117 | tfutil.assert_tf_initialized() 118 | 119 | if _finalized: 120 | return None 121 | 122 | _finalized = True 123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 124 | 125 | # Create summary ops. 
126 | with tf.device(None), tf.control_dependencies(None): 127 | for name, vars_list in _vars.items(): 128 | name_id = name.replace("/", "_") 129 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 130 | moments = tf.add_n(vars_list) 131 | moments /= moments[0] 132 | with tf.control_dependencies([moments]): # read before resetting 133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 135 | mean = moments[1] 136 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 137 | tf.summary.scalar(name, mean) 138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 140 | 141 | # Group by category and chart name. 142 | cat_dict = OrderedDict() 143 | for series_name in sorted(_vars.keys()): 144 | p = series_name.split("/") 145 | cat = p[0] if len(p) >= 2 else "" 146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 147 | if cat not in cat_dict: 148 | cat_dict[cat] = OrderedDict() 149 | if chart not in cat_dict[cat]: 150 | cat_dict[cat][chart] = [] 151 | cat_dict[cat][chart].append(series_name) 152 | 153 | # Setup custom_scalar layout. 
154 | categories = [] 155 | for cat_name, chart_dict in cat_dict.items(): 156 | charts = [] 157 | for chart_name, series_names in chart_dict.items(): 158 | series = [] 159 | for series_name in series_names: 160 | series.append(layout_pb2.MarginChartContent.Series( 161 | value=series_name, 162 | lower="xCustomScalars/" + series_name + "/margin_lo", 163 | upper="xCustomScalars/" + series_name + "/margin_hi")) 164 | margin = layout_pb2.MarginChartContent(series=series) 165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 168 | return layout 169 | 170 | def save_summaries(file_writer, global_step=None): 171 | """Call FileWriter.add_summary() with all summaries in the default graph, 172 | automatically finalizing and merging them on the first call. 173 | """ 174 | global _merge_op 175 | tfutil.assert_tf_initialized() 176 | 177 | if _merge_op is None: 178 | layout = finalize_autosummaries() 179 | if layout is not None: 180 | file_writer.add_summary(layout) 181 | with tf.device(None), tf.control_dependencies(None): 182 | _merge_op = tf.summary.merge_all() 183 | 184 | file_writer.add_summary(_merge_op.eval(), global_step) 185 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
"""Miscellaneous helper utils for Tensorflow."""

import os
import numpy as np
import tensorflow as tf

from typing import Any, Iterable, List, Union

TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid Tensorflow expression."""

TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid Tensorflow expression."""


def run(*args, **kwargs) -> Any:
    """Run the specified ops in the default session.

    All arguments are forwarded unchanged to ``tf.Session.run``.

    Raises:
        RuntimeError: If no default session exists (call ``init_tf()`` first).
    """
    assert_tf_initialized()
    return tf.get_default_session().run(*args, **kwargs)


def is_tf_expression(x: Any) -> bool:
    """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
    return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))


def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
    """Convert a Tensorflow shape to a list of ints (``None`` for unknown dimensions)."""
    return [dim.value for dim in shape]


def flatten(x: TfExpressionEx) -> TfExpression:
    """Shortcut function for flattening a tensor into a rank-1 tensor."""
    with tf.name_scope("Flatten"):
        return tf.reshape(x, [-1])


def log2(x: TfExpressionEx) -> TfExpression:
    """Logarithm in base 2, computed as ln(x) * (1 / ln(2))."""
    with tf.name_scope("Log2"):
        return tf.log(x) * np.float32(1.0 / np.log(2.0))


def exp2(x: TfExpressionEx) -> TfExpression:
    """Exponent in base 2, computed as exp(x * ln(2))."""
    with tf.name_scope("Exp2"):
        return tf.exp(x * np.float32(np.log(2.0)))


def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
    """Linear interpolation: returns ``a`` at t=0 and ``b`` at t=1 (t is not clipped)."""
    with tf.name_scope("Lerp"):
        return a + (b - a) * t


def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
    """Linear interpolation with ``t`` clipped to [0, 1] first."""
    with tf.name_scope("LerpClip"):
        return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)


def absolute_name_scope(scope: str) -> tf.name_scope:
    """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
    # A trailing "/" makes TensorFlow treat the scope name as absolute rather than nested.
    return tf.name_scope(scope + "/")


def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
    return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)


def _sanitize_tf_config(config_dict: dict = None) -> dict:
    """Return the default TF config dict, overridden by entries from ``config_dict``.

    Keys are dotted paths: ``rnd.*`` are random seeds, ``env.*`` are environment
    variables, and everything else is a field path inside ``tf.ConfigProto``.
    """
    # Defaults.
    cfg = dict()
    cfg["rnd.np_random_seed"] = None  # Random seed for NumPy. None = keep as is.
    cfg["rnd.tf_random_seed"] = "auto"  # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
    cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1"  # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
    cfg["graph_options.place_pruned_graph"] = True  # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
    cfg["gpu_options.allow_growth"] = True  # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.

    # User overrides.
    if config_dict is not None:
        cfg.update(config_dict)
    return cfg


def init_tf(config_dict: dict = None) -> None:
    """Initialize TensorFlow session using good default settings.

    Does nothing if a default session already exists. Otherwise seeds NumPy and
    TensorFlow, exports ``env.*`` entries as environment variables, and creates a
    session that stays the default session.
    """
    # Skip if already initialized.
    if tf.get_default_session() is not None:
        return

    # Setup config dict and random seeds.
    cfg = _sanitize_tf_config(config_dict)
    np_random_seed = cfg["rnd.np_random_seed"]
    if np_random_seed is not None:
        np.random.seed(np_random_seed)
    tf_random_seed = cfg["rnd.tf_random_seed"]
    if tf_random_seed == "auto":
        tf_random_seed = np.random.randint(1 << 31)
    if tf_random_seed is not None:
        tf.set_random_seed(tf_random_seed)

    # Setup environment variables.
    for key, value in list(cfg.items()):
        fields = key.split(".")
        if fields[0] == "env":
            assert len(fields) == 2
            os.environ[fields[1]] = str(value)

    # Create default TensorFlow session.
    create_session(cfg, force_as_default=True)


def assert_tf_initialized():
    """Check that TensorFlow session has been initialized.

    Raises:
        RuntimeError: If there is no default session.
    """
    if tf.get_default_session() is None:
        raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")


def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
    """Create tf.Session based on config dict.

    Args:
        config_dict: Overrides for ``_sanitize_tf_config``; non-``rnd``/``env`` keys are
            written as dotted attribute paths into ``tf.ConfigProto``.
        force_as_default: If True, enter the session's default context and never exit it,
            so it remains the default session.
    """
    # Setup TensorFlow config proto.
    cfg = _sanitize_tf_config(config_dict)
    config_proto = tf.ConfigProto()
    for key, value in cfg.items():
        fields = key.split(".")
        if fields[0] not in ["rnd", "env"]:
            obj = config_proto
            for field in fields[:-1]:
                obj = getattr(obj, field)
            setattr(obj, fields[-1], value)

    # Create session.
    session = tf.Session(config=config_proto)
    if force_as_default:
        # pylint: disable=protected-access
        # The context manager is entered manually and kept on the session object so it
        # is never exited; disabling nesting enforcement avoids errors from that.
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__()  # pylint: disable=no-member

    return session


def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Initialize all tf.Variables that have not already been initialized.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()

    Args:
        target_vars: Variables to check; defaults to ``tf.global_variables()``.
    """
    assert_tf_initialized()
    if target_vars is None:
        target_vars = tf.global_variables()

    test_vars = []
    test_ops = []

    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        for var in target_vars:
            assert is_tf_expression(var)

            try:
                # Reuse the IsVariableInitialized op if a previous call already created it.
                tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                test_vars.append(var)

                with absolute_name_scope(var.name.split(":")[0]):
                    test_ops.append(tf.is_variable_initialized(var))

    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
    run([var.initializer for var in init_vars])


def set_vars(var_to_value_dict: dict) -> None:
    """Set the values of given tf.Variables.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
    """
    assert_tf_initialized()
    ops = []
    feed_dict = {}

    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)

        try:
            setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0"))  # look for existing op
        except KeyError:
            with absolute_name_scope(var.name.split(":")[0]):
                with tf.control_dependencies(None):  # ignore surrounding control_dependencies
                    setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter")  # create new setter

        ops.append(setter)
        # Input 1 of the assign op is the "new_value" placeholder created above.
        feed_dict[setter.op.inputs[1]] = value

    run(ops, feed_dict)


def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
    """Create tf.Variable with large initial value without bloating the tf graph.

    The variable is created from zeros and then filled through a fed placeholder,
    so the array is not embedded in the graph as a constant.
    """
    assert_tf_initialized()
    assert isinstance(initial_value, np.ndarray)
    zeros = tf.zeros(initial_value.shape, initial_value.dtype)
    var = tf.Variable(zeros, *args, **kwargs)
    set_vars({var: initial_value})
    return var


def convert_images_from_uint8(images, drange=(-1, 1), nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
    Can be used as an input transformation for Network.run().

    Pixel value 0 maps to ``drange[0]`` and 255 maps to ``drange[1]``.
    """
    images = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        images = tf.transpose(images, [0, 3, 1, 2])
    # Affine map [0, 255] -> [drange[0], drange[1]].
    return images * ((drange[1] - drange[0]) / 255) + drange[0]


def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False, shrink=1):
    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
    Can be used as an output transformation for Network.run().

    ``drange[0]`` maps to 0 and ``drange[1]`` maps to 255. If ``shrink > 1``, images are
    average-pooled by that factor, assuming NCHW layout at that point.
    """
    images = tf.cast(images, tf.float32)
    if shrink > 1:
        ksize = [1, 1, shrink, shrink]
        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        images = tf.transpose(images, [0, 2, 3, 1])
    scale = 255 / (drange[1] - drange[0])
    # The +0.5 offset rounds to nearest; saturate_cast clamps out-of-range values.
    images = images * scale + (0.5 - drange[0] * scale)
    return tf.saturate_cast(images, tf.uint8)

# --------------------------------------------------------------------------------
# /dnnlib/tflib/optimizer.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helper wrapper for a Tensorflow optimizer."""

import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import List, Union

from . import autosummary
from . import tfutil
from .. import util

from .tfutil import TfExpression, TfExpressionEx

try:
    # TensorFlow 1.13
    from tensorflow.python.ops import nccl_ops
except ImportError:
    # Older TensorFlow versions
    import tensorflow.contrib.nccl as nccl_ops


class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Usage: call ``register_gradients`` once per device, then ``apply_updates`` once.
    """

    def __init__(self,
        name: str = "Train",
        tf_optimizer: str = "tf.train.AdamOptimizer",
        learning_rate: TfExpressionEx = 0.001,
        use_loss_scaling: bool = False,
        loss_scaling_init: float = 64.0,
        loss_scaling_inc: float = 0.0005,
        loss_scaling_dec: float = 1.0,
        **kwargs):
        """Create the wrapper.

        Args:
            name: Name used for scopes and statistics keys.
            tf_optimizer: Dotted name of the optimizer class to instantiate per device.
            learning_rate: Learning rate passed to each underlying optimizer.
            use_loss_scaling: Enable dynamic loss scaling.
            loss_scaling_init: Initial value of the loss scaling variable (log2 units).
            loss_scaling_inc: Amount added to the log2 scale after a finite step.
            loss_scaling_dec: Amount subtracted from the log2 scale after an overflow.
            **kwargs: Extra keyword arguments for the underlying optimizer.
        """
        # Init fields.
        self.name = name
        self.learning_rate = tf.convert_to_tensor(learning_rate)
        self.id = self.name.replace("/", ".")
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec
        self._grad_shapes = None  # [shape, ...]
        self._dev_opt = OrderedDict()  # device => optimizer
        self._dev_grads = OrderedDict()  # device => [[(grad, var), ...], ...]
        self._dev_ls_var = OrderedDict()  # device => variable (log2 of loss scaling factor)
        self._updates_applied = False

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU.

        The loss and all variables must live on the same device, and every call must
        pass variables with the same shapes in the same order.
        """
        assert not self._updates_applied

        # Validate arguments.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars

        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])

        if self._grad_shapes is None:
            self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]

        assert len(trainable_vars) == len(self._grad_shapes)
        assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))

        dev = loss.device

        assert all(var.device == dev for var in trainable_vars)

        # Register device and compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(dev):
            if dev not in self._dev_opt:
                opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
                assert callable(self.optimizer_class)
                self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
                self._dev_grads[dev] = []

            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage
            grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]  # replace disconnected gradients with zeros
            self._dev_grads[dev].append(grads)

    def apply_updates(self) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        Gradients are summed per device, all-reduced across devices via NCCL, averaged
        over the total number of registrations, and applied only when all are finite.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        devices = list(self._dev_grads.keys())
        total_grads = sum(len(grads) for grads in self._dev_grads.values())
        assert len(devices) >= 1 and total_grads >= 1
        ops = []

        with tfutil.absolute_name_scope(self.scope):
            # Cast gradients to FP32 and calculate partial sum within each device.
            dev_grads = OrderedDict()  # device => [(grad, var), ...]

            for dev_idx, dev in enumerate(devices):
                with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
                    sums = []

                    for gv in zip(*self._dev_grads[dev]):
                        assert all(v is gv[0][1] for g, v in gv)
                        g = [tf.cast(g, tf.float32) for g, v in gv]
                        g = g[0] if len(g) == 1 else tf.add_n(g)
                        sums.append((g, gv[0][1]))

                    dev_grads[dev] = sums

            # Sum gradients across devices.
            if len(devices) > 1:
                with tf.name_scope("SumAcrossGPUs"), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        g = [dev_grads[dev][var_idx][0] for dev in devices]

                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors
                            g = nccl_ops.all_sum(g)

                        for dev, gg in zip(devices, g):
                            dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])

            # Apply updates separately on each device.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
                    # Scale gradients as needed.
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope("Scale"):
                            coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # Check for overflows.
                    with tf.name_scope("CheckOverflow"):
                        grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))

                    # Update weights and adjust loss scaling.
                    with tf.name_scope("UpdateWeights"):
                        # pylint: disable=cell-var-from-loop
                        # tf.cond calls the branch lambdas immediately, so capturing loop
                        # variables here is safe.
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)

                        if not self.use_loss_scaling:
                            ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
                        else:
                            ops.append(tf.cond(grad_ok,
                                lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
                                lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))

                    # Report statistics on the last device.
                    if dev == devices[-1]:
                        with tf.name_scope("Statistics"):
                            ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                            ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))

                            if self.use_loss_scaling:
                                ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))

            # Initialize variables and group everything into a single op.
            self.reset_optimizer_state()
            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))

            return tf.group(*ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer by re-running its variable initializers."""
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Returns None when loss scaling is disabled.
        """
        if not self.use_loss_scaling:
            return None

        if device not in self._dev_ls_var:
            with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
                self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")

        return self._dev_ls_var[device]

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression (multiply by 2**ls_var)."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression (multiply by 2**-ls_var)."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type

# --------------------------------------------------------------------------------
# /dnnlib/submission/submit.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Submit a function to be run either locally or in a computing cluster."""

import copy
import io
import os
import pathlib
import pickle
import platform
import pprint
import re
import shutil
import time
import traceback

import zipfile

from enum import Enum

from .. import util
from ..util import EasyDict


class SubmitTarget(Enum):
    """The target where the function should be run.

    LOCAL: Run it locally.
    """
    LOCAL = 1


class PathType(Enum):
    """Determines in which format should a path be formatted.

    WINDOWS: Format with Windows style.
    LINUX: Format with Linux/Posix style.
    AUTO: Use current OS type to select either WINDOWS or LINUX.
    """
    WINDOWS = 1
    LINUX = 2
    AUTO = 3


_user_name_override = None


class SubmitConfig(util.EasyDict):
    """Strongly typed config dict needed to submit runs.

    Attributes:
        run_dir_root: Path to the run dir root. Can be optionally templated with tags (e.g. <USERNAME>). Needs to always be run through get_path_from_template.
        run_desc: Description of the run. Will be used in the run dir and task name.
        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
        submit_target: Submit target enum value. Used to select where the run is actually launched.
        num_gpus: Number of GPUs used/requested for the run.
        print_info: Whether to print debug information when submitting.
        ask_confirmation: Whether to ask a confirmation before submitting.
        run_id: Automatically populated value during submit.
        run_name: Automatically populated value during submit.
        run_dir: Automatically populated value during submit.
        run_func_name: Automatically populated value during submit.
        run_func_kwargs: Automatically populated value during submit.
        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
        task_name: Automatically populated value during submit.
        host_name: Automatically populated value during submit.
    """

    def __init__(self):
        super().__init__()

        # run (set these)
        self.run_dir_root = ""  # should always be passed through get_path_from_template
        self.run_desc = ""
        self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
        self.run_dir_extra_files = None

        # submit (set these)
        self.submit_target = SubmitTarget.LOCAL
        self.num_gpus = 1
        self.print_info = False
        self.ask_confirmation = False

        # (automatically populated)
        self.run_id = None
        self.run_name = None
        self.run_dir = None
        self.run_func_name = None
        self.run_func_kwargs = None
        self.user_name = None
        self.task_name = None
        self.host_name = "localhost"


def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
    """Replace tags in the given path template and return either Windows or Linux formatted path.

    Supported tags: ``<USERNAME>`` (replaced with ``get_user_name()``).

    Raises:
        RuntimeError: If the platform (for AUTO) or path type is not recognized.
    """
    # automatically select path type depending on running OS
    if path_type == PathType.AUTO:
        if platform.system() == "Windows":
            path_type = PathType.WINDOWS
        elif platform.system() == "Linux":
            path_type = PathType.LINUX
        else:
            raise RuntimeError("Unknown platform")

    # Replace only the explicit tag; replacing "" would insert the name between every character.
    path_template = path_template.replace("<USERNAME>", get_user_name())

    # return correctly formatted path
    if path_type == PathType.WINDOWS:
        return str(pathlib.PureWindowsPath(path_template))
    elif path_type == PathType.LINUX:
        return str(pathlib.PurePosixPath(path_template))
    else:
        raise RuntimeError("Unknown platform")


def get_template_from_path(path: str) -> str:
    """Convert a normal path back to its template representation."""
    # Only separators are normalized (backslashes -> forward slashes); no tags are re-inserted.
    path = path.replace("\\", "/")
    return path


def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
    """Convert a normal path to template and the convert it back to a normal path with given path type."""
    path_template = get_template_from_path(path)
    path = get_path_from_template(path_template, path_type)
    return path


def set_user_name_override(name: str) -> None:
    """Set the global username override value."""
    global _user_name_override
    _user_name_override = name


def get_user_name():
    """Get the current user name.

    Returns the override if one was set; on Linux falls back to "unknown" when the
    user database lookup fails.

    Raises:
        RuntimeError: On platforms other than Windows and Linux (without an override).
    """
    if _user_name_override is not None:
        return _user_name_override
    elif platform.system() == "Windows":
        return os.getlogin()
    elif platform.system() == "Linux":
        try:
            import pwd  # pylint: disable=import-error
            return pwd.getpwuid(os.geteuid()).pw_name  # pylint: disable=no-member
        except Exception:
            return "unknown"
    else:
        raise RuntimeError("Unknown platform")


def _create_run_dir_local(submit_config: SubmitConfig) -> str:
    """Create a new run dir with increasing ID number at the start.

    Also populates ``submit_config.run_id`` and ``submit_config.run_name``.
    """
    run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)

    if not os.path.exists(run_dir_root):
        print("Creating the run dir root: {}".format(run_dir_root))
        os.makedirs(run_dir_root)

    submit_config.run_id = _get_next_run_id_local(run_dir_root)
    submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
    run_dir = os.path.join(run_dir_root, submit_config.run_name)

    if os.path.exists(run_dir):
        raise RuntimeError("The run dir already exists! ({0})".format(run_dir))

    print("Creating the run dir: {}".format(run_dir))
    os.makedirs(run_dir)

    return run_dir


def _get_next_run_id_local(run_dir_root: str) -> int:
    """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
    dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
    r = re.compile(r"^\d+")  # match one or more digits at the start of the string
    run_id = 0

    for dir_name in dir_names:
        m = r.match(dir_name)

        if m is not None:
            i = int(m.group())
            run_id = max(run_id, i + 1)

    return run_id


def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
    """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
    print("Copying files to the run dir")
    files = []

    # Walk up from the run function's module directory once per package level in its dotted name.
    run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
    assert '.' in submit_config.run_func_name
    for _idx in range(submit_config.run_func_name.count('.') - 1):
        run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
    files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)

    dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
    files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)

    if submit_config.run_dir_extra_files is not None:
        files += submit_config.run_dir_extra_files

    files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
    files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]

    util.copy_files_and_create_dirs(files)

    # Use a context manager so the pickle file handle is closed deterministically.
    with open(os.path.join(run_dir, "submit_config.pkl"), "wb") as f:
        pickle.dump(submit_config, f)

    with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
        pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)


def run_wrapper(submit_config: SubmitConfig) -> None:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc.

    Locally, exceptions propagate after cleanup; otherwise they are printed and the log
    is copied next to the run dir root as ``<run_name>-error.txt``.
    """
    is_local = submit_config.submit_target == SubmitTarget.LOCAL

    # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
    if is_local:
        logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
    else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
        logger = util.Logger(file_name=None, should_flush=True)

    import dnnlib
    dnnlib.submit_config = submit_config

    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except BaseException:  # same scope as a bare except: also handles KeyboardInterrupt
        if is_local:
            raise
        else:
            traceback.print_exc()

            log_src = os.path.join(submit_config.run_dir, "log.txt")
            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
            shutil.copyfile(log_src, log_dst)
    finally:
        # Marker file signalling that the run has ended (successfully or not).
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

    dnnlib.submit_config = None
    logger.close()


def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.

    The given config is shallow-copied; the caller's instance is not modified.
    """
    submit_config = copy.copy(submit_config)

    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    assert submit_config.submit_target == SubmitTarget.LOCAL
    if submit_config.submit_target in {SubmitTarget.LOCAL}:
        run_dir = _create_run_dir_local(submit_config)

        submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
        submit_config.run_dir = run_dir
        _populate_run_dir(run_dir, submit_config)

    if submit_config.print_info:
        print("\nSubmit config:\n")
        pprint.pprint(submit_config, indent=4, width=200, compact=False)
        print()

    if submit_config.ask_confirmation:
        if not util.ask_yes_no("Continue submitting the job?"):
            return

    run_wrapper(submit_config)

# --------------------------------------------------------------------------------
# /dnnlib/util.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Miscellaneous utility classes and functions."""

import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import uuid

from distutils.util import strtobool
from typing import Any, List, Tuple, Union


# Util classes
# ------------------------------------------------------------------------------------------


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.

    The redirection is installed in ``__init__`` and removed by ``close()``, which is
    safe to call more than once.
    """

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: str) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            # Forget the closed file so a second close()/flush() does not hit a closed handle.
            self.file = None


# Small util functions
# ------------------------------------------------------------------------------------------


def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))

    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)


def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer."""
    while True:
        try:
            print("{0} [y/n]".format(question))
            # strtobool returns 1/0; convert to a real bool to match the annotation.
            return bool(strtobool(input().lower()))
        except ValueError:
            pass


def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements (1 for an empty tuple)."""
    result = 1

    for v in t:
        result *= v

    return result


_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    type_str = None

    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]

    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)

    return my_dtype, my_ctype


def is_pickleable(obj: Any) -> bool:
    """Return True if ``obj`` can be pickled, False if pickling raises an exception."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        return False


# Functionality to import modules/objects by name, and call functions by name
#
------------------------------------------------------------------------------------------ 194 | 195 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 196 | """Searches for the underlying module behind the name to some python object. 197 | Returns the module and the object name (original name with module part removed).""" 198 | 199 | # allow convenience shorthands, substitute them by full names 200 | obj_name = re.sub("^np.", "numpy.", obj_name) 201 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 202 | 203 | # list alternatives for (module_name, local_obj_name) 204 | parts = obj_name.split(".") 205 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 206 | 207 | # try each alternative in turn 208 | for module_name, local_obj_name in name_pairs: 209 | try: 210 | module = importlib.import_module(module_name) # may raise ImportError 211 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 212 | return module, local_obj_name 213 | except: 214 | pass 215 | 216 | # maybe some of the modules themselves contain errors? 217 | for module_name, _local_obj_name in name_pairs: 218 | try: 219 | importlib.import_module(module_name) # may raise ImportError 220 | except ImportError: 221 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 222 | raise 223 | 224 | # maybe the requested attribute is missing? 
225 | for module_name, local_obj_name in name_pairs: 226 | try: 227 | module = importlib.import_module(module_name) # may raise ImportError 228 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 229 | except ImportError: 230 | pass 231 | 232 | # we are out of luck, but we have no idea why 233 | raise ImportError(obj_name) 234 | 235 | 236 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 237 | """Traverses the object name and returns the last (rightmost) python object.""" 238 | if obj_name == '': 239 | return module 240 | obj = module 241 | for part in obj_name.split("."): 242 | obj = getattr(obj, part) 243 | return obj 244 | 245 | 246 | def get_obj_by_name(name: str) -> Any: 247 | """Finds the python object with the given name.""" 248 | module, obj_name = get_module_from_obj_name(name) 249 | return get_obj_from_module(module, obj_name) 250 | 251 | 252 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 253 | """Finds the python object with the given name and calls it as a function.""" 254 | assert func_name is not None 255 | func_obj = get_obj_by_name(func_name) 256 | assert callable(func_obj) 257 | return func_obj(*args, **kwargs) 258 | 259 | 260 | def get_module_dir_by_obj_name(obj_name: str) -> str: 261 | """Get the directory path of the module containing the given object name.""" 262 | module, _ = get_module_from_obj_name(obj_name) 263 | return os.path.dirname(inspect.getfile(module)) 264 | 265 | 266 | def is_top_level_function(obj: Any) -> bool: 267 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 268 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 269 | 270 | 271 | def get_top_level_function_name(obj: Any) -> str: 272 | """Return the fully-qualified name of a top-level function.""" 273 | assert is_top_level_function(obj) 274 | return obj.__module__ + "." 
+ obj.__name__ 275 | 276 | 277 | # File system helpers 278 | # ------------------------------------------------------------------------------------------ 279 | 280 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 281 | """List all files recursively in a given directory while ignoring given file and directory names. 282 | Returns list of tuples containing both absolute and relative paths.""" 283 | assert os.path.isdir(dir_path) 284 | base_name = os.path.basename(os.path.normpath(dir_path)) 285 | 286 | if ignores is None: 287 | ignores = [] 288 | 289 | result = [] 290 | 291 | for root, dirs, files in os.walk(dir_path, topdown=True): 292 | for ignore_ in ignores: 293 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 294 | 295 | # dirs need to be edited in-place 296 | for d in dirs_to_remove: 297 | dirs.remove(d) 298 | 299 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 300 | 301 | absolute_paths = [os.path.join(root, f) for f in files] 302 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 303 | 304 | if add_base_to_relative: 305 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 306 | 307 | assert len(absolute_paths) == len(relative_paths) 308 | result += zip(absolute_paths, relative_paths) 309 | 310 | return result 311 | 312 | 313 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 314 | """Takes in a list of tuples of (src, dst) paths and copies files. 
315 | Will create all necessary directories.""" 316 | for file in files: 317 | target_dir_name = os.path.dirname(file[1]) 318 | 319 | # will create all intermediate-level directories 320 | if not os.path.exists(target_dir_name): 321 | os.makedirs(target_dir_name) 322 | 323 | shutil.copyfile(file[0], file[1]) 324 | 325 | 326 | # URL helpers 327 | # ------------------------------------------------------------------------------------------ 328 | 329 | def is_url(obj: Any) -> bool: 330 | """Determine whether the given object is a valid URL string.""" 331 | if not isinstance(obj, str) or not "://" in obj: 332 | return False 333 | try: 334 | res = requests.compat.urlparse(obj) 335 | if not res.scheme or not res.netloc or not "." in res.netloc: 336 | return False 337 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 338 | if not res.scheme or not res.netloc or not "." in res.netloc: 339 | return False 340 | except: 341 | return False 342 | return True 343 | 344 | 345 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any: 346 | """Download the given URL and return a binary-mode file object to access the data.""" 347 | assert is_url(url) 348 | assert num_attempts >= 1 349 | 350 | # Lookup from cache. 351 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 352 | if cache_dir is not None: 353 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 354 | if len(cache_files) == 1: 355 | return open(cache_files[0], "rb") 356 | 357 | # Download. 358 | url_name = None 359 | url_data = None 360 | with requests.Session() as session: 361 | if verbose: 362 | print("Downloading %s ..." 
% url, end="", flush=True) 363 | for attempts_left in reversed(range(num_attempts)): 364 | try: 365 | with session.get(url) as res: 366 | res.raise_for_status() 367 | if len(res.content) == 0: 368 | raise IOError("No data received") 369 | 370 | if len(res.content) < 8192: 371 | content_str = res.content.decode("utf-8") 372 | if "download_warning" in res.headers.get("Set-Cookie", ""): 373 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 374 | if len(links) == 1: 375 | url = requests.compat.urljoin(url, links[0]) 376 | raise IOError("Google Drive virus checker nag") 377 | if "Google Drive - Quota exceeded" in content_str: 378 | raise IOError("Google Drive quota exceeded") 379 | 380 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 381 | url_name = match[1] if match else url 382 | url_data = res.content 383 | if verbose: 384 | print(" done") 385 | break 386 | except: 387 | if not attempts_left: 388 | if verbose: 389 | print(" failed") 390 | raise 391 | if verbose: 392 | print(".", end="", flush=True) 393 | 394 | # Save to cache. 395 | if cache_dir is not None: 396 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 397 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 398 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 399 | os.makedirs(cache_dir, exist_ok=True) 400 | with open(temp_file, "wb") as f: 401 | f.write(url_data) 402 | os.replace(temp_file, cache_file) # atomic 403 | 404 | # Return data as file object. 405 | return io.BytesIO(url_data) 406 | -------------------------------------------------------------------------------- /dnnlib/tflib/network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Helper for managing networks."""

import types
import inspect
import re
import uuid
import sys
import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import Any, List, Tuple, Union

from . import tfutil
from .. import util

from .tfutil import TfExpression, TfExpressionEx

_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
_import_module_src = dict() # Source code for temporary modules created during pickle import.


def import_handler(handler_func):
    """Function decorator for declaring custom import handlers.

    Registered handlers are applied in order to the pickled state dict in
    Network.__setstate__ and must return the (possibly updated) state.
    """
    _import_handlers.append(handler_func)
    return handler_func


class Network:
    """Generic network abstraction.

    Acts as a convenience wrapper for a parameterized network construction
    function, providing several utility methods and convenient access to
    the inputs/outputs/weights.

    Network objects can be safely pickled and unpickled for long-term
    archival purposes. The pickling works reliably as long as the underlying
    network construction function is defined in a standalone Python module
    that has no side effects or application-specific imports.

    Args:
        name: Network name. Used to select TensorFlow name and variable scopes.
        func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
        static_kwargs: Keyword arguments to be passed in to the network construction function.

    Attributes:
        name: User-specified name, defaults to build func name if None.
        scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
        static_kwargs: Arguments passed to the user-supplied build func.
        components: Container for sub-networks. Passed to the build func, and retained between calls.
        num_inputs: Number of input tensors.
        num_outputs: Number of output tensors.
        input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
        output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
        input_shape: Short-hand for input_shapes[0].
        output_shape: Short-hand for output_shapes[0].
        input_templates: Input placeholders in the template graph.
        output_templates: Output tensors in the template graph.
        input_names: Name string for each input.
        output_names: Name string for each output.
        own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
        vars: All variables (local_name => var).
        trainables: All trainable variables (local_name => var).
        var_global_to_local: Mapping from variable global names to local names.
    """

    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
        tfutil.assert_tf_initialized()
        assert isinstance(name, str) or name is None
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Locate the user-specified network build function.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Dig up source code for the module containing the build function.
        # Modules created during pickle import have no file, so their source is looked up first.
        self._build_module_src = _import_module_src.get(module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(module)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()

    def _init_fields(self) -> None:
        """Reset every attribute to its empty default (used by __init__, __setstate__ and clone)."""
        self.name = None
        self.scope = None
        self.static_kwargs = util.EasyDict()
        self.components = util.EasyDict()
        self.num_inputs = 0
        self.num_outputs = 0
        self.input_shapes = [[]]
        self.output_shapes = [[]]
        self.input_shape = []
        self.output_shape = []
        self.input_templates = []
        self.output_templates = []
        self.input_names = []
        self.output_names = []
        self.own_vars = OrderedDict()
        self.vars = OrderedDict()
        self.trainables = OrderedDict()
        self.var_global_to_local = OrderedDict()

        self._build_func = None # User-supplied build function that constructs the network.
        self._build_func_name = None # Name of the build function.
        self._build_module_src = None # Full source code of the module containing the build function.
        self._run_cache = dict() # Cached graph data for Network.run().

    def _init_graph(self) -> None:
        """Build the template graph under a fresh unique scope and collect inputs, outputs and variables.

        Inputs are the build function's positional-or-keyword parameters without defaults.
        """
        # Collect inputs.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(None):
            self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(None): # ignore surrounding control dependencies
                self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                out_expr = self._build_func(*self.input_templates, **build_kwargs)

        # Collect outputs.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self.components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self.components) != len(set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

        # List variables.
        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())

    def reset_own_vars(self) -> None:
        """Re-initialize all variables of this network, excluding sub-networks."""
        tfutil.run([var.initializer for var in self.own_vars.values()])

    def reset_vars(self) -> None:
        """Re-initialize all variables of this network, including sub-networks."""
        tfutil.run([var.initializer for var in self.vars.values()])

    def reset_trainables(self) -> None:
        """Re-initialize all trainable variables of this network, including sub-networks."""
        tfutil.run([var.initializer for var in self.trainables.values()])

    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
        """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).

        Inputs passed as None are replaced by zeros whose minibatch size matches the first
        non-None input.
        """
        assert len(in_expr) == self.num_inputs
        assert not all(expr is None for expr in in_expr)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs.update(dynamic_kwargs)
        build_kwargs["is_template_graph"] = False
        build_kwargs["components"] = self.components

        # Build TensorFlow graph to evaluate the network.
        with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
            assert tf.get_variable_scope().name == self.scope
            valid_inputs = [expr for expr in in_expr if expr is not None]
            final_inputs = []
            for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
                if expr is not None:
                    expr = tf.identity(expr, name=name)
                else:
                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
                final_inputs.append(expr)
            out_expr = self._build_func(*final_inputs, **build_kwargs)

        # Propagate input shapes back to the user-specified expressions.
        for expr, final in zip(in_expr, final_inputs):
            if isinstance(expr, tf.Tensor):
                expr.set_shape(final.shape)

        # Express outputs in the desired format.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        return out_expr

    def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
        """Get the local name of a given variable, without any surrounding name scopes."""
        assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
        global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
        return self.var_global_to_local[global_name]

    def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
        """Find variable by local or global name."""
        assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
        return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name

    def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
        """Get the value of a given variable as NumPy array.
        Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
        return self.find_var(var_or_local_name).eval()

    def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
        """Set the value of a given variable based on the given NumPy array.
        Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
        tfutil.set_vars({self.find_var(var_or_local_name): new_value})

    def __getstate__(self) -> dict:
        """Pickle export."""
        state = dict()
        state["version"] = 3
        state["name"] = self.name
        state["static_kwargs"] = dict(self.static_kwargs)
        state["components"] = dict(self.components)
        state["build_module_src"] = self._build_module_src
        state["build_func_name"] = self._build_func_name
        state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
        return state

    def __setstate__(self, state: dict) -> None:
        """Pickle import."""
        # pylint: disable=attribute-defined-outside-init
        tfutil.assert_tf_initialized()
        self._init_fields()

        # Execute custom import handlers.
        for handler in _import_handlers:
            state = handler(state)

        # Set basic fields.
        assert state["version"] in [2, 3]
        self.name = state["name"]
        self.static_kwargs = util.EasyDict(state["static_kwargs"])
        self.components = util.EasyDict(state.get("components", {}))
        self._build_module_src = state["build_module_src"]
        self._build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code.
        # SECURITY: this executes Python source stored inside the pickle -- only load
        # network pickles from trusted sources.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = self._build_module_src
        exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used

        # Locate network build function in the temporary module.
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()
        tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})

    def clone(self, name: str = None, **new_static_kwargs) -> "Network":
        """Create a clone of this network with its own copy of the variables."""
        # pylint: disable=protected-access
        net = object.__new__(Network)
        net._init_fields()
        net.name = name if name is not None else self.name
        net.static_kwargs = util.EasyDict(self.static_kwargs)
        net.static_kwargs.update(new_static_kwargs)
        net._build_module_src = self._build_module_src
        net._build_func_name = self._build_func_name
        net._build_func = self._build_func
        net._init_graph()
        net.copy_vars_from(self)
        return net

    def copy_own_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, excluding sub-networks."""
        names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))

    def copy_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, including sub-networks."""
        names = [name for name in self.vars.keys() if name in src_net.vars]
        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))

    def copy_trainables_from(self, src_net: "Network") -> None:
        """Copy the values of all trainable variables from the given network, including sub-networks."""
        names = [name for name in self.trainables.keys() if name in src_net.trainables]
        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))

    def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
        """Create new network with the given parameters, and copy all variables from this network."""
        if new_name is None:
            new_name = self.name
        static_kwargs = dict(self.static_kwargs)
        static_kwargs.update(new_static_kwargs)
        net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
        net.copy_vars_from(self)
        return net

    def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
        """Construct a TensorFlow op that updates the variables of this network
        to be slightly closer to those of the given network."""
        with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
            ops = []
            for name, var in self.vars.items():
                if name in src_net.vars:
                    cur_beta = beta if name in self.trainables else beta_nontrainable
                    new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
                    ops.append(var.assign(new_value))
            return tf.group(*ops)

    def run(self,
            *in_arrays: Tuple[Union[np.ndarray, None], ...],
            input_transform: dict = None,
            output_transform: dict = None,
            return_as_list: bool = False,
            print_progress: bool = False,
            minibatch_size: int = None,
            num_gpus: int = 1,
            assume_frozen: bool = False,
            **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
        """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

        Args:
            input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
            print_progress: Print progress to the console? Useful for very large input arrays.
            minibatch_size: Maximum minibatch size to use, None = disable batching.
            num_gpus: Number of GPUs to use.
            assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
            dynamic_kwargs: Additional keyword arguments to be passed into the network build function.

        Inputs given as None are fed as zeros; the number of items is taken from the first
        input that is not None.
        """
        assert len(in_arrays) == self.num_inputs
        assert not all(arr is None for arr in in_arrays)
        assert input_transform is None or util.is_top_level_function(input_transform["func"])
        assert output_transform is None or util.is_top_level_function(output_transform["func"])
        output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
        # in_arrays[0] may legitimately be None, so use the first provided array.
        num_items = next(arr.shape[0] for arr in in_arrays if arr is not None)
        if minibatch_size is None:
            minibatch_size = num_items

        # Construct unique hash key from all arguments that affect the TensorFlow graph.
        key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
        def unwind_key(obj):
            if isinstance(obj, dict):
                return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
            if callable(obj):
                return util.get_top_level_function_name(obj)
            return obj
        key = repr(unwind_key(key))

        # Build graph.
        if key not in self._run_cache:
            with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
                with tf.device("/cpu:0"):
                    in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

                out_split = []
                for gpu in range(num_gpus):
                    with tf.device("/gpu:%d" % gpu):
                        net_gpu = self.clone() if assume_frozen else self
                        in_gpu = in_split[gpu]

                        if input_transform is not None:
                            in_kwargs = dict(input_transform)
                            in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                            in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                        assert len(in_gpu) == self.num_inputs
                        out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                        if output_transform is not None:
                            out_kwargs = dict(output_transform)
                            out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                            out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                        assert len(out_gpu) == self.num_outputs
                        out_split.append(out_gpu)

                with tf.device("/cpu:0"):
                    out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                    self._run_cache[key] = in_expr, out_expr

        # Run minibatches.
        in_expr, out_expr = self._run_cache[key]
        out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]

        for mb_begin in range(0, num_items, minibatch_size):
            if print_progress:
                print("\r%d / %d" % (mb_begin, num_items), end="")

            mb_end = min(mb_begin + minibatch_size, num_items)
            mb_num = mb_end - mb_begin
            mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
            mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

            for dst, src in zip(out_arrays, mb_out):
                dst[mb_begin: mb_end] = src

        # Done.
        if print_progress:
            print("\r%d / %d" % (num_items, num_items))

        if not return_as_list:
            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)

        return out_arrays

    def list_ops(self) -> List[TfExpression]:
        """Return the graph ops under this network's scope, excluding helper scopes starting with '<scope>/_' (e.g. _Run, _MovingAvg)."""
        include_prefix = self.scope + "/"
        exclude_prefix = include_prefix + "_"
        ops = tf.get_default_graph().get_operations()
        ops = [op for op in ops if op.name.startswith(include_prefix)]
        ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
        return ops

    def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
        """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
        individual layers of the network. Mainly intended to be used for reporting."""
        layers = []

        def recurse(scope, parent_ops, parent_vars, level):
            # Ignore specific patterns.
            if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
                return

            # Filter ops and vars by scope.
            global_prefix = scope + "/"
            local_prefix = global_prefix[len(self.scope) + 1:]
            cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
            cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
            if not cur_ops and not cur_vars:
                return

            # Filter out all ops related to variables.
            for var in [op for op in cur_ops if op.type.startswith("Variable")]:
                var_prefix = var.name + "/"
                cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]

            # Scope does not contain ops as immediate children => recurse deeper.
            contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
            if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
                visited = set()
                for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
                    token = rel_name.split("/")[0]
                    if token not in visited:
                        recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
                        visited.add(token)
                return

            # Report layer.
            layer_name = scope[len(self.scope) + 1:]
            layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
            layer_trainables = [var for _name, var in cur_vars if var.trainable]
            layers.append((layer_name, layer_output, layer_trainables))

        recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
        return layers

    def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
        """Print a summary table of the network structure."""
        rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
        rows += [["---"] * 4]
        total_params = 0

        for layer_name, layer_output, layer_trainables in self.list_layers():
            num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
            weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
            weights.sort(key=lambda x: len(x.name))
            if len(weights) == 0 and len(layer_trainables) == 1:
                weights = layer_trainables
            total_params += num_params

            if not hide_layers_with_no_params or num_params != 0:
                num_params_str = str(num_params) if num_params > 0 else "-"
                output_shape_str = str(layer_output.shape)
                weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
                rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]

        rows += [["---"] * 4]
        rows += [["Total", str(total_params), "", ""]]

        widths = [max(len(cell) for cell in column) for column in zip(*rows)]
        print()
        for row in rows:
            print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
        print()

    def setup_weight_histograms(self, title: str = None) -> None:
        """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
        if title is None:
            title = self.name

        with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
            for local_name, var in self.trainables.items():
                if "/" in local_name:
                    p = local_name.split("/")
                    name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
                else:
                    name = title + "_toplevel/" + local_name

                tf.summary.histogram(name, var)

    def save_my_model(self, export_dir: str = "./模型转换/saved_model") -> None:
        """Export the first input template / first output template as a TensorFlow SavedModel.

        Project-specific addition intended to make conversion to TensorFlow.js easier.
        The current default session is used and deliberately left open, so the network
        (and every other tflib call) keeps working after the export.

        Args:
            export_dir: Target directory passed to tf.saved_model.simple_save(); defaults to
                the path that used to be hard-coded here. NOTE(review): simple_save is
                expected to fail if the directory already exists -- confirm / clear it first.

        Raises:
            RuntimeError: if there is no default TensorFlow session.

        NOTE(review): an earlier version rewrote RefSwitch/AssignSub nodes for tfjs, but only
        on a serialized GraphDef copy that was never exported, so it had no effect. Any such
        rewrite must be applied to the exported model separately.
        """
        sess = tf.get_default_session()
        if sess is None:
            raise RuntimeError("save_my_model() requires a default TensorFlow session")
        # Do not use `with sess:` -- Session.__exit__ closes the session, which would break
        # all subsequent runs that rely on the default session.
        tf.saved_model.simple_save(sess, export_dir,
                                   inputs={self.input_names[0]: self.input_templates[0]},
                                   outputs={self.output_names[0]: self.output_templates[0]})

#----------------------------------------------------------------------------
# Backwards-compatible emulation of legacy output transformation in Network.run().
_print_legacy_warning = True

def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
    """Translate deprecated ``out_*`` keyword arguments of Network.run() into an output transform.

    Args:
        output_transform: The ``output_transform`` dict given by the caller, or None.
        dynamic_kwargs: The remaining keyword arguments given by the caller.

    Returns:
        A pair ``(output_transform, dynamic_kwargs)``. If none of the legacy keys
        (``out_mul``, ``out_add``, ``out_shrink``, ``out_dtype``) is present, both
        arguments are returned unchanged. Otherwise returns a new transform dict
        holding the legacy values plus ``func=_legacy_output_transform_func``,
        together with a copy of ``dynamic_kwargs`` without the legacy keys
        (the caller's dict is not modified).

    Raises:
        ValueError: If legacy keys are combined with an explicit ``output_transform``.
    """
    global _print_legacy_warning
    legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
    if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
        return output_transform, dynamic_kwargs

    # Print the deprecation warning only on the first legacy call.
    if _print_legacy_warning:
        _print_legacy_warning = False
        print()
        print("WARNING: Old-style output transformations in Network.run() are deprecated.")
        print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
        print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
        print()

    # This was an `assert`, which `python -O` strips; the explicit transform
    # would then have been silently replaced by the legacy one.
    if output_transform is not None:
        raise ValueError("Cannot combine output_transform with legacy out_mul/out_add/out_shrink/out_dtype arguments.")

    new_kwargs = dict(dynamic_kwargs)
    new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
    new_transform["func"] = _legacy_output_transform_func
    return new_transform, new_kwargs

def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
    """Apply the deprecated ``out_*`` output transformation to every expression in ``expr``.

    Steps run in this order, each skipped when its argument has the default value:

    1. Multiply by ``out_mul``.
    2. Add ``out_add``.
    3. If ``out_shrink > 1``, average-pool with ``ksize = stride = out_shrink``
       over the last two axes (``data_format="NCHW"``, ``padding="VALID"``).
    4. If ``out_dtype`` is given, ``tf.saturate_cast`` to it, rounding first
       when the target dtype is an integer type.

    Returns:
        A list with one transformed expression per input. It is always a list,
        even when no step applies.
    """
    expr = list(expr)
    if out_mul != 1.0:
        expr = [x * out_mul for x in expr]

    if out_add != 0.0:
        expr = [x + out_add for x in expr]

    if out_shrink > 1:
        ksize = [1, 1, out_shrink, out_shrink]
        expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]

    if out_dtype is not None:
        if tf.as_dtype(out_dtype).is_integer:
            expr = [tf.round(x) for x in expr]
        expr = [tf.saturate_cast(x, out_dtype) for x in expr]
    return expr