├── result
└── 生成结果在这里
├── examples
├── edit_age.jpg
├── example1.png
├── example2.png
├── example3.png
├── edit_angle.jpg
├── edit_smile.jpg
├── 64_examples.jpg
├── edit_exposure.jpg
└── edit_gender.jpg
├── model
└── 模型下载后放在这里.txt
├── dnnlib
├── __pycache__
│ ├── util.cpython-36.pyc
│ └── __init__.cpython-36.pyc
├── tflib
│ ├── __pycache__
│ │ ├── tfutil.cpython-36.pyc
│ │ ├── __init__.cpython-36.pyc
│ │ ├── network.cpython-36.pyc
│ │ ├── autosummary.cpython-36.pyc
│ │ └── optimizer.cpython-36.pyc
│ ├── __init__.py
│ ├── autosummary.py
│ ├── tfutil.py
│ ├── optimizer.py
│ └── network.py
├── submission
│ ├── __pycache__
│ │ ├── submit.cpython-36.pyc
│ │ ├── __init__.cpython-36.pyc
│ │ └── run_context.cpython-36.pyc
│ ├── __init__.py
│ ├── _internal
│ │ └── run.py
│ ├── run_context.py
│ └── submit.py
├── __init__.py
└── util.py
├── generate_model.py
└── README.md
/result/生成结果在这里:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/examples/edit_age.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_age.jpg
--------------------------------------------------------------------------------
/examples/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/example1.png
--------------------------------------------------------------------------------
/examples/example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/example2.png
--------------------------------------------------------------------------------
/examples/example3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/example3.png
--------------------------------------------------------------------------------
/examples/edit_angle.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_angle.jpg
--------------------------------------------------------------------------------
/examples/edit_smile.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_smile.jpg
--------------------------------------------------------------------------------
/model/模型下载后放在这里.txt:
--------------------------------------------------------------------------------
1 | 百度网盘:
2 |
3 | 链接:https://pan.baidu.com/s/1_qeDNrmSrvzR3j-xqTa7zQ
4 | 提取码:rybs
5 |
6 |
7 |
--------------------------------------------------------------------------------
/examples/64_examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/64_examples.jpg
--------------------------------------------------------------------------------
/examples/edit_exposure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_exposure.jpg
--------------------------------------------------------------------------------
/examples/edit_gender.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/examples/edit_gender.jpg
--------------------------------------------------------------------------------
/dnnlib/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/network.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/__pycache__/submit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/submission/__pycache__/submit.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/submission/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/__pycache__/run_context.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/seeprettyface-generator-model/HEAD/dnnlib/submission/__pycache__/run_context.cpython-36.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import run_context
9 | from . import submit
10 |
--------------------------------------------------------------------------------
/dnnlib/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import autosummary
9 | from . import network
10 | from . import optimizer
11 | from . import tfutil
12 |
13 | from .tfutil import *
14 | from .network import Network
15 |
16 | from .optimizer import Optimizer
17 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import submission
9 |
10 | from .submission.run_context import RunContext
11 |
12 | from .submission.submit import SubmitTarget
13 | from .submission.submit import PathType
14 | from .submission.submit import SubmitConfig
15 | from .submission.submit import get_path_from_template
16 | from .submission.submit import submit_run
17 |
18 | from .util import EasyDict
19 |
# Package-level slot holding the SubmitConfig of the current run. It is None at
# import time; per the note below it is only meaningful inside the run function.
submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
21 |
--------------------------------------------------------------------------------
/dnnlib/submission/_internal/run.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for launching run functions in computing clusters.
9 |
10 | During the submit process, this file is copied to the appropriate run dir.
11 | When the job is launched in the cluster, this module is the first thing that
12 | is run inside the docker container.
13 | """
14 |
15 | import os
16 | import pickle
17 | import sys
18 |
19 | # PYTHONPATH should have been set so that the run_dir/src is in it
20 | import dnnlib
21 |
def main():
    """Load the pickled SubmitConfig from a run dir and launch the run wrapper.

    Reads three positional command-line arguments from ``sys.argv``:
    ``run_dir``, ``task_name`` and ``host_name``. The config is loaded from
    ``<run_dir>/submit_config.pkl``, updated with the task and host names, and
    handed to ``dnnlib.submission.submit.run_wrapper``.

    Raises:
        RuntimeError: If fewer than three arguments are given, or if the
            pickled SubmitConfig file does not exist in the run dir.
    """
    if len(sys.argv) < 4:
        raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")

    run_dir = str(sys.argv[1])
    task_name = str(sys.argv[2])
    host_name = str(sys.argv[3])

    submit_config_path = os.path.join(run_dir, "submit_config.pkl")

    # SubmitConfig should have been pickled to the run dir
    if not os.path.exists(submit_config_path):
        raise RuntimeError("SubmitConfig pickle file does not exist!")

    # NOTE(review): unpickling can execute arbitrary code, so the run dir must
    # be trusted. The context manager guarantees the handle is closed.
    with open(submit_config_path, "rb") as f:
        submit_config: dnnlib.SubmitConfig = pickle.load(f)
    dnnlib.submission.submit.set_user_name_override(submit_config.user_name)

    submit_config.task_name = task_name
    submit_config.host_name = host_name

    dnnlib.submission.submit.run_wrapper(submit_config)


if __name__ == "__main__":
    main()
46 |
--------------------------------------------------------------------------------
/generate_model.py:
--------------------------------------------------------------------------------
1 | # Thanks to StyleGAN provider —— Copyright (c) 2019, NVIDIA CORPORATION.
2 | #
3 | # This work is trained by Copyright(c) 2018, seeprettyface.com, BUPT_GWY.
4 |
5 | """Minimal script for generating an image using pre-trained StyleGAN generator."""
6 |
7 | import os
8 | import pickle
9 | import numpy as np
10 | import PIL.Image
11 | import dnnlib.tflib as tflib
12 |
# Keyword-argument bundle pairing a uint8 / NHWC output transform with a
# minibatch size of 8.
# NOTE(review): nothing else in this script, as shown, references this name.
synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)
14 |
def text_save(file, data): # save generate code, which can be modified to generate customized style
    """Write the first row of ``data`` to ``file``, one value per line.

    Args:
        file: Writable text file object.
        data: Sequence of rows (for example a 2-D array); only ``data[0]`` is
            written. An empty ``data`` writes nothing.
    """
    # len() is used instead of truthiness so that NumPy arrays are accepted.
    if len(data) == 0:
        return
    file.writelines(str(value) + '\n' for value in data[0])
19 |
def main():
    """Generate images with the pre-trained StyleGAN generator.

    Loads ``model/generator_model.pkl`` and, for each of ``generate_num``
    random latents, writes the latent to ``result/generate_code/NNNN.txt`` and
    the generated image to ``result/NNNN.png``.
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    model_path = 'model/generator_model.pkl'

    # Prepare result folder
    result_dir = 'result'
    code_dir = os.path.join(result_dir, 'generate_code')
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(code_dir, exist_ok=True)

    # NOTE(review): unpickling can execute arbitrary code; only load model
    # files obtained from a trusted source.
    with open(model_path, "rb") as f:
        _G, _D, Gs = pickle.load(f, encoding='latin1')

    # Print network details.
    Gs.print_layers()

    # Generate pictures
    generate_num = 20
    for i in range(generate_num):
        stem = str(i).zfill(4)

        # Generate latent.
        latents = np.random.randn(1, Gs.input_shape[1])

        # Save latent. The context manager closes the file even if image
        # generation or saving fails afterwards.
        txt_filename = os.path.join(code_dir, stem + '.txt')
        with open(txt_filename, 'w') as file:
            text_save(file, latents)

        # Generate image.
        fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
        images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

        # Save image.
        png_filename = os.path.join(result_dir, stem + '.png')
        PIL.Image.fromarray(images[0], 'RGB').save(png_filename)


if __name__ == "__main__":
    main()
63 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 超模脸生成器
2 |
3 | 更新:基于StyleGAN2制作的新版生成器消除了生成图片中水滴斑点和扭曲/损坏现象的出现,质量大幅提升。点此查看新版。
4 | --------------------------------------------------------------------------------------------------------------------
5 | 注明:之前做的一些有意思的人脸生成器,现在全部开源分享出来。它的主要作用是可生成制作各类型的人脸素材,供我们任意使用且无须担心人脸版权的问题。在定制人脸上,开源的全系列生成器包括:黄种人脸生成器,网红脸生成器,明星脸生成器,超模脸生成器和萌娃脸生成器,同时人脸属性编辑器能够对所有这些生成器生成的人物进行调整和改变。
6 | 此项目已免费开源使用,模型版权拥有者为:www.seeprettyface.com 。
7 |
8 | 这是一个用StyleGAN训练出的超模人脸生成器,生成效果如下所示。
9 |
10 | # 生成示例
11 |
12 | ## 单张样本
13 | 
14 | 
15 | 
16 |
17 | ## 概览(有筛选)
18 | 
19 |
20 | 查看更多的生成样本可以前往[这里](https://pan.baidu.com/s/1G5lTsk1TJPZMCHqudQqqYg)(提取码:2A5W),是一个含有1万张生成样本的超模脸数据集。
21 |
22 | # 超模脸属性编辑
23 | 人脸属性编辑支持在年龄、笑容、角度、性别和光照等23个维度上对生成人物作出调整(详细了解请前往[人脸属性编辑器](https://github.com/a312863063/seeprettyface-face_editor)处)。这儿只展示5种基本调整示例。
24 | ## 笑容调整
25 | 
26 |
27 | ## 年龄调整
28 | 
29 |
30 | ## 角度调整
31 | 
32 |
33 | ## 性别调整
34 | 
35 |
36 | ## 光照调整
37 | 
38 |
39 |
40 | # 运行代码
41 | ## 环境配置
42 | * Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
43 | * 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
44 | * TensorFlow 1.10.0 or newer with GPU support.
45 | * NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.
46 |
47 | ## 运行步骤
48 | 1. 在 model 文件夹中按照 txt 文件里的地址下载模型,并将模型文件放在该文件夹内
49 | 2. 运行 generate_model.py
50 |
51 | ## 了解技术原理 & 获取训练集:[点此进入](http://www.seeprettyface.com/)
52 | 
53 |
54 | ## 小小的赞助~
55 |
56 |
57 |
58 | 若对您有帮助可给予小小的赞助~
59 |
60 |
61 |
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/dnnlib/submission/run_context.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helpers for managing the run/training loop."""
9 |
10 | import datetime
11 | import json
12 | import os
13 | import pprint
14 | import time
15 | import types
16 |
17 | from typing import Any
18 |
19 | from . import submit
20 |
21 |
class RunContext(object):
    """Helper class for managing the run/training loop.

    The context hides the implementation details of a basic run/training loop.
    It sets things up, tells whether the run should be stopped, and cleans up.
    Users should call update() periodically and use should_stop() to determine
    whether the run should be stopped.

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: The whole config module that is used for the current run.
        max_epoch: Optional cached value for the max_epoch variable used in update.
    """

    def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
        # One clock reading keeps start_time and last_update_time identical.
        now = time.time()
        self.submit_config = submit_config
        self.should_stop_flag = False
        self.has_closed = False
        self.start_time = now
        self.last_update_time = now
        self.last_update_interval = 0.0
        self.max_epoch = max_epoch

        # Pretty-print the relevant content of the config module to a text file:
        # public names only, skipping modules, functions, classes and SubmitConfig objects.
        if config_module is not None:
            with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
                filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
                pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)

        # Write out details about the run to a text file.
        self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
        self._write_run_txt()

    def _write_run_txt(self) -> None:
        """Overwrite run.txt in the run dir with the current run_txt_data."""
        with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def __enter__(self) -> "RunContext":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.
        Should be called often enough but not in a tight loop.

        The loss, cur_epoch and max_epoch arguments are accepted for interface
        compatibility; this implementation does not use them.
        """
        assert not self.has_closed

        # Use one timestamp so the interval and the stored time stay consistent.
        now = time.time()
        self.last_update_interval = now - self.last_update_time
        self.last_update_time = now

        # An "abort.txt" file in the run dir requests that the run stop.
        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.
        Calls after the first one are no-ops."""
        if not self.has_closed:
            # Update run.txt with the stopping time.
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            self._write_run_txt()

            self.has_closed = True
100 |
--------------------------------------------------------------------------------
/dnnlib/tflib/autosummary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for adding automatically tracked values to Tensorboard.
9 |
10 | Autosummary creates an identity op that internally keeps track of the input
11 | values and automatically shows up in TensorBoard. The reported value
12 | represents an average over input components. The average is accumulated
13 | constantly over time and flushed when save_summaries() is called.
14 |
15 | Notes:
16 | - The output tensor must be used as an input for something else in the
17 | graph. Otherwise, the autosummary op will not get executed, and the average
18 | value will not get accumulated.
19 | - It is perfectly fine to include autosummaries with the same name in
20 | several places throughout the graph, even if they are executed concurrently.
21 | - It is ok to also pass in a python scalar or numpy array. In this case, it
22 | is added to the average immediately.
23 | """
24 |
25 | from collections import OrderedDict
26 | import numpy as np
27 | import tensorflow as tf
28 | from tensorboard import summary as summary_lib
29 | from tensorboard.plugins.custom_scalar import layout_pb2
30 |
31 | from . import tfutil
32 | from .tfutil import TfExpression
33 | from .tfutil import TfExpressionEx
34 |
# Module-level autosummary state shared by the functions below.
# dtype that tracked values are cast to and accumulated in.
_dtype = tf.float64
_vars = OrderedDict() # name => [var, ...]
_immediate = OrderedDict() # name => update_op, update_value
# Set to True by finalize_autosummaries(); _create_var() asserts it is still False.
_finalized = False
# Cached result of tf.summary.merge_all(), created on the first save_summaries() call.
_merge_op = None
40 |
41 |
def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Internal helper for creating autosummary accumulators.

    Casts ``value_expr`` to the module dtype, reduces it to the three moments
    [element count, sum, sum of squares], and creates a non-trainable variable
    that accumulates them. The variable is registered under ``name`` in the
    module-level ``_vars`` table.

    Args:
        name: Autosummary name; "/" is replaced by "_" for the name scope.
        value_expr: TensorFlow expression to accumulate.

    Returns:
        The op that adds the moments to the accumulator variable (or assigns
        them if the variable is not initialized yet).
    """
    # New accumulators may not be added once the summaries have been finalized.
    assert not _finalized
    name_id = name.replace("/", "_")
    v = tf.cast(value_expr, _dtype)

    # Element count: a Python constant when the static shape is known,
    # otherwise computed in the graph from the dynamic shape.
    if v.shape.is_fully_defined():
        size = np.prod(tfutil.shape_to_list(v.shape))
        size_expr = tf.constant(size, dtype=_dtype)
    else:
        size = None
        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))

    if size == 1:
        # Single element: reshape to a scalar instead of reducing.
        if v.shape.ndims != 0:
            v = tf.reshape(v, [])
        v = [size_expr, v, tf.square(v)]
    else:
        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
    # Contribute zeros instead of the moments when the sum is not finite.
    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))

    with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
    # Accumulate into an initialized variable; otherwise assign the first value.
    update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))

    if name in _vars:
        _vars[name].append(var)
    else:
        _vars[name] = [var]
    return update_op
72 |
73 |
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Register a value to be tracked as an autosummary.

    Args:
        name: Name to use in TensorBoard
        value: TensorFlow expression or python value to track
        passthru: If given, this is returned instead of ``value``; for TF
            expressions the returned node carries the update as a side effect.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    which is shorthand for:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    scope_suffix = name.replace("/", "_")
    result = value if passthru is None else passthru

    if not tfutil.is_tf_expression(value):
        # Python scalar or numpy array: feed it through a cached placeholder
        # and run the update immediately.
        if name not in _immediate:
            with tfutil.absolute_name_scope("Autosummary/" + scope_suffix), tf.device(None), tf.control_dependencies(None):
                feed = tf.placeholder(_dtype)
                _immediate[name] = _create_var(name, feed), feed

        accumulate, feed = _immediate[name]
        tfutil.run(accumulate, {feed: value})
        return result

    # TF expression: attach the accumulator update to an identity op.
    with tf.name_scope("summary_" + scope_suffix), tf.device(value.device):
        accumulate = _create_var(name, value)
        with tf.control_dependencies([accumulate]):
            return tf.identity(result)
110 |
111 |
def finalize_autosummaries() -> None:
    """Create the necessary ops to include autosummaries in TensorBoard report.
    Note: This should be done only once per graph.

    On the first call this marks the module as finalized, initializes the
    accumulator variables, creates the scalar summaries (mean plus
    "xCustomScalars/<name>/margin_lo|hi" at mean -/+ std) for every registered
    name, and builds a custom-scalar layout grouping the series into
    categories and charts. Later calls return None without doing anything.

    NOTE(review): the signature is annotated ``-> None``, but the first call
    returns the layout object produced by ``summary_lib.custom_scalar_pb``.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                # Sum the accumulators sharing this name, then divide by the
                # element count so entries 1 and 2 become E[x] and E[x**2].
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]): # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Group by category and chart name.
    # "a/b/c" -> category "a", chart "b"; "a/b" -> category "a", chart "b";
    # "a" -> category "", chart "a".
    cat_dict = OrderedDict()
    for series_name in sorted(_vars.keys()):
        p = series_name.split("/")
        cat = p[0] if len(p) >= 2 else ""
        chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
        if cat not in cat_dict:
            cat_dict[cat] = OrderedDict()
        if chart not in cat_dict[cat]:
            cat_dict[cat][chart] = []
        cat_dict[cat][chart].append(series_name)

    # Setup custom_scalar layout.
    categories = []
    for cat_name, chart_dict in cat_dict.items():
        charts = []
        for chart_name, series_names in chart_dict.items():
            series = []
            for series_name in series_names:
                series.append(layout_pb2.MarginChartContent.Series(
                    value=series_name,
                    lower="xCustomScalars/" + series_name + "/margin_lo",
                    upper="xCustomScalars/" + series_name + "/margin_hi"))
            margin = layout_pb2.MarginChartContent(series=series)
            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
        categories.append(layout_pb2.Category(title=cat_name, chart=charts))
    layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout
169 |
def save_summaries(file_writer, global_step=None):
    """Evaluate the merged summaries and hand them to FileWriter.add_summary().

    The first call finalizes the autosummaries, writes the custom-scalar layout
    (when one is returned) and builds the merged summary op, which is cached
    for later calls.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    first_call = _merge_op is None
    if first_call:
        custom_layout = finalize_autosummaries()
        if custom_layout is not None:
            file_writer.add_summary(custom_layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    serialized = _merge_op.eval()
    file_writer.add_summary(serialized, global_step)
185 |
--------------------------------------------------------------------------------
/dnnlib/tflib/tfutil.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous helper utils for Tensorflow."""
9 |
10 | import os
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from typing import Any, Iterable, List, Union
15 |
# Union of the TensorFlow graph object types treated as an expression.
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid Tensorflow expression."""

# TfExpression widened with plain Python numbers and NumPy arrays.
TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid Tensorflow expression."""
21 |
22 |
def run(*args, **kwargs) -> Any:
    """Forward the arguments to ``run()`` of the default session and return its result."""
    assert_tf_initialized()
    session = tf.get_default_session()
    return session.run(*args, **kwargs)
27 |
28 |
def is_tf_expression(x: Any) -> bool:
    """Return True when ``x`` is a TensorFlow Tensor, Variable, or Operation."""
    expression_types = (tf.Tensor, tf.Variable, tf.Operation)
    return isinstance(x, expression_types)
32 |
33 |
def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
    """Return the ``.value`` of every dimension in ``shape`` as a list."""
    sizes = [d.value for d in shape]
    return sizes
37 |
38 |
def flatten(x: TfExpressionEx) -> TfExpression:
    """Reshape ``x`` into a 1-D tensor inside a "Flatten" name scope."""
    with tf.name_scope("Flatten"):
        flat = tf.reshape(x, [-1])
    return flat
43 |
44 |
def log2(x: TfExpressionEx) -> TfExpression:
    """Base-2 logarithm, computed as the natural log scaled by 1/ln(2)."""
    inv_ln2 = np.float32(1.0 / np.log(2.0))
    with tf.name_scope("Log2"):
        return tf.log(x) * inv_ln2
49 |
50 |
def exp2(x: TfExpressionEx) -> TfExpression:
    """Base-2 exponential, computed as exp(x * ln(2))."""
    ln2 = np.float32(np.log(2.0))
    with tf.name_scope("Exp2"):
        return tf.exp(x * ln2)
55 |
56 |
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
    """Interpolate linearly from ``a`` (t = 0) to ``b`` (t = 1)."""
    with tf.name_scope("Lerp"):
        delta = b - a
        return a + delta * t
61 |
62 |
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
    """Interpolate linearly from ``a`` to ``b`` with ``t`` clipped to [0, 1]."""
    with tf.name_scope("LerpClip"):
        clipped_t = tf.clip_by_value(t, 0.0, 1.0)
        return a + (b - a) * clipped_t
67 |
68 |
def absolute_name_scope(scope: str) -> tf.name_scope:
    """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
    # The trailing slash is what requests the scope name verbatim.
    absolute = scope + "/"
    return tf.name_scope(absolute)
72 |
73 |
def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
    target_scope = tf.VariableScope(name=scope, **kwargs)
    return tf.variable_scope(target_scope, auxiliary_name_scope=False)
77 |
78 |
79 | def _sanitize_tf_config(config_dict: dict = None) -> dict:
80 | # Defaults.
81 | cfg = dict()
82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
87 |
88 | # User overrides.
89 | if config_dict is not None:
90 | cfg.update(config_dict)
91 | return cfg
92 |
93 |
def init_tf(config_dict: dict = None) -> None:
    """Create the default TensorFlow session unless one already exists.

    Applies the random seeds and "env." settings from the sanitized config
    before creating the session.
    """
    # Nothing to do when a default session is already in place.
    if tf.get_default_session() is not None:
        return

    settings = _sanitize_tf_config(config_dict)

    # Seed NumPy first so that an "auto" TensorFlow seed is derived from it.
    numpy_seed = settings["rnd.np_random_seed"]
    if numpy_seed is not None:
        np.random.seed(numpy_seed)

    graph_seed = settings["rnd.tf_random_seed"]
    if graph_seed == "auto":
        graph_seed = np.random.randint(1 << 31)
    if graph_seed is not None:
        tf.set_random_seed(graph_seed)

    # Export every "env.<NAME>" entry as an environment variable.
    for dotted_key, setting in settings.items():
        parts = dotted_key.split(".")
        if parts[0] != "env":
            continue
        assert len(parts) == 2
        os.environ[parts[1]] = str(setting)

    # Create default TensorFlow session.
    create_session(settings, force_as_default=True)
120 |
121 |
def assert_tf_initialized():
    """Raise RuntimeError unless a default TensorFlow session is available."""
    session = tf.get_default_session()
    if session is None:
        raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
126 |
127 |
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
    """Build a tf.Session configured from a dict of dotted setting names.

    Args:
        config_dict: Optional overrides merged over the default settings. Keys
            whose first component is "rnd" or "env" are skipped here; every
            other key is treated as an attribute path into tf.ConfigProto.
        force_as_default: When True, the session's as_default() context is
            entered and left open, so the session becomes the default one.

    Returns:
        The newly created session.
    """
    cfg = _sanitize_tf_config(config_dict)
    config_proto = tf.ConfigProto()

    # Translate e.g. "gpu_options.allow_growth" into config_proto.gpu_options.allow_growth = value.
    for dotted_key, value in cfg.items():
        path = dotted_key.split(".")
        if path[0] in ["rnd", "env"]:
            continue
        target = config_proto
        for attr in path[:-1]:
            target = getattr(target, attr)
        setattr(target, path[-1], value)

    session = tf.Session(config=config_proto)
    if force_as_default:
        # pylint: disable=protected-access
        default_ctx = session.as_default()
        session._default_session = default_ctx
        default_ctx.enforce_nesting = False
        default_ctx.__enter__()  # pylint: disable=no-member

    return session
150 |
151 |
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Run the initializers of those variables that are not initialized yet.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()

    Args:
        target_vars: Variables to consider. Defaults to tf.global_variables().
    """
    assert_tf_initialized()
    if target_vars is None:
        target_vars = tf.global_variables()

    candidates = []  # variables that had no "IsVariableInitialized" op yet
    probes = []      # the matching initialization checks, in the same order

    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        for var in target_vars:
            assert is_tf_expression(var)
            probe_name = var.name.replace(":0", "/IsVariableInitialized:0")

            try:
                tf.get_default_graph().get_tensor_by_name(probe_name)
            except KeyError:
                # No check op exists for this variable, so it may be uninitialized.
                candidates.append(var)
                with absolute_name_scope(var.name.split(":")[0]):
                    probes.append(tf.is_variable_initialized(var))

    flags = run(probes)
    pending = [var for var, is_inited in zip(candidates, flags) if not is_inited]
    run([var.initializer for var in pending])
180 |
181 |
def set_vars(var_to_value_dict: dict) -> None:
    """Assign new values to the given tf.Variables in a single run call.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])

    Each variable gets one "setter" assign op fed through a "new_value"
    placeholder; an existing setter is looked up by name and reused.

    Args:
        var_to_value_dict: Mapping from variable to the value to store in it.
    """
    assert_tf_initialized()
    setters = []
    feeds = {}

    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)
        setter_name = var.name.replace(":0", "/setter:0")

        try:
            setter = tf.get_default_graph().get_tensor_by_name(setter_name)  # reuse an existing op
        except KeyError:
            with absolute_name_scope(var.name.split(":")[0]):
                with tf.control_dependencies(None):  # ignore surrounding control_dependencies
                    new_value = tf.placeholder(var.dtype, var.shape, "new_value")
                    setter = tf.assign(var, new_value, name="setter")  # first use: create the setter

        setters.append(setter)
        # The second input of the assign op is its value placeholder.
        feeds[setter.op.inputs[1]] = value

    run(setters, feeds)
206 |
207 |
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
    """Create a tf.Variable holding `initial_value` without bloating the tf graph.

    The variable is declared with a zeros initializer of matching shape and
    dtype, and the actual data is stored afterwards through set_vars().

    Args:
        initial_value: NumPy array with the desired contents.
        *args, **kwargs: Forwarded to the tf.Variable constructor.

    Returns:
        The created variable.
    """
    assert_tf_initialized()
    assert isinstance(initial_value, np.ndarray)
    zero_init = tf.zeros(initial_value.shape, initial_value.dtype)
    variable = tf.Variable(zero_init, *args, **kwargs)
    set_vars({variable: initial_value})
    return variable
216 |
217 |
def convert_images_from_uint8(images, drange=(-1, 1), nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
    Can be used as an input transformation for Network.run().

    Pixel value 0 maps to drange[0] and 255 maps to drange[1], which makes this
    the inverse of the scaling applied by convert_images_to_uint8().

    Args:
        images: Image tensor; it is cast to float32 first.
        drange: Two-element (low, high) output range.
        nhwc_to_nchw: When True, axes are permuted with [0, 3, 1, 2] before scaling.

    Returns:
        The rescaled float32 tensor.
    """
    images = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        images = tf.transpose(images, [0, 3, 1, 2])
    # Scale first, then shift: subtracting drange[0] before scaling (as the
    # previous code did) maps 0 to a value near 0 instead of to drange[0].
    scale = (drange[1] - drange[0]) / 255
    return images * scale + drange[0]
226 |
227 |
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
    """Map float images in the range `drange` to uint8.
    Can be used as an output transformation for Network.run().

    Args:
        images: Image tensor; it is cast to float32 first.
        drange: Two-element [low, high] input range mapped onto 0..255.
        nchw_to_nhwc: When True, axes are permuted with [0, 2, 3, 1] after the
            optional downscaling.
        shrink: When greater than 1, images are average-pooled by this factor
            using data_format="NCHW".

    Returns:
        The uint8 tensor produced by a saturating cast.
    """
    images = tf.cast(images, tf.float32)
    if shrink > 1:
        window = [1, 1, shrink, shrink]
        images = tf.nn.avg_pool(images, ksize=window, strides=window, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        images = tf.transpose(images, [0, 2, 3, 1])
    gain = 255 / (drange[1] - drange[0])
    bias = 0.5 - drange[0] * gain  # the extra 0.5 is added before the saturating cast
    return tf.saturate_cast(images * gain + bias, tf.uint8)
241 |
--------------------------------------------------------------------------------
/dnnlib/tflib/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper wrapper for a Tensorflow optimizer."""
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | from collections import OrderedDict
14 | from typing import List, Union
15 |
16 | from . import autosummary
17 | from . import tfutil
18 | from .. import util
19 |
20 | from .tfutil import TfExpression, TfExpressionEx
21 |
22 | try:
23 | # TensorFlow 1.13
24 | from tensorflow.python.ops import nccl_ops
25 | except:
26 | # Older TensorFlow versions
27 | import tensorflow.contrib.nccl as nccl_ops
28 |
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Usage as visible from the methods below: call register_gradients() one or
    more times, then call apply_updates() exactly once to obtain the training op.
    """

    def __init__(self,
                 name: str = "Train",
                 tf_optimizer: str = "tf.train.AdamOptimizer",
                 learning_rate: TfExpressionEx = 0.001,
                 use_loss_scaling: bool = False,
                 loss_scaling_init: float = 64.0,
                 loss_scaling_inc: float = 0.0005,
                 loss_scaling_dec: float = 1.0,
                 **kwargs):
        """Store the settings; no optimizer instance is created until gradients are registered.

        Args:
            name: Base name; "/" is replaced by "." to form the id used in scopes and summaries.
            tf_optimizer: Name of the optimizer class, resolved via util.get_obj_by_name.
            learning_rate: Learning rate; converted to a tensor.
            use_loss_scaling: Enables dynamic loss scaling.
            loss_scaling_init: Initial value of the loss scaling variable. The variable
                is applied as an exponent of two, so this is log2 of the scale factor.
            loss_scaling_inc: Amount added to the variable after a step with finite gradients.
            loss_scaling_dec: Amount subtracted from the variable after a step with non-finite gradients.
            **kwargs: Extra keyword arguments passed to the optimizer class constructor.
        """

        # Init fields.
        self.name = name
        self.learning_rate = tf.convert_to_tensor(learning_rate)
        self.id = self.name.replace("/", ".")
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec
        self._grad_shapes = None  # [shape, ...]
        self._dev_opt = OrderedDict()  # device => optimizer
        self._dev_grads = OrderedDict()  # device => [[(grad, var), ...], ...]
        self._dev_ls_var = OrderedDict()  # device => variable (log2 of loss scaling factor)
        self._updates_applied = False

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU.

        Args:
            loss: Loss expression; its device decides where the gradients are computed.
            trainable_vars: List of variables, or a dict whose values are the variables.
                Every call must pass variables with the same shapes, in the same order,
                and all of them must live on the same device as `loss`.
        """
        assert not self._updates_applied

        # Validate arguments.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars

        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])

        # The first call fixes the expected variable shapes; later calls must match them.
        if self._grad_shapes is None:
            self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]

        assert len(trainable_vars) == len(self._grad_shapes)
        assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))

        dev = loss.device

        assert all(var.device == dev for var in trainable_vars)

        # Register device and compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(dev):
            # One optimizer instance per device, created on first use.
            if dev not in self._dev_opt:
                opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
                assert callable(self.optimizer_class)
                self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
                self._dev_grads[dev] = []

            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage
            grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]  # replace disconnected gradients with zeros
            self._dev_grads[dev].append(grads)

    def apply_updates(self) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        May be called only once, after at least one register_gradients() call.

        Returns:
            A single grouped op named "TrainingOp" that contains the per-device
            updates and the statistics reporting ops.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        devices = list(self._dev_grads.keys())
        total_grads = sum(len(grads) for grads in self._dev_grads.values())  # number of register_gradients() calls
        assert len(devices) >= 1 and total_grads >= 1
        ops = []

        with tfutil.absolute_name_scope(self.scope):
            # Cast gradients to FP32 and calculate partial sum within each device.
            dev_grads = OrderedDict()  # device => [(grad, var), ...]

            for dev_idx, dev in enumerate(devices):
                with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
                    sums = []

                    # zip(*...) groups the entries of all registrations on this device by variable.
                    for gv in zip(*self._dev_grads[dev]):
                        assert all(v is gv[0][1] for g, v in gv)
                        g = [tf.cast(g, tf.float32) for g, v in gv]
                        g = g[0] if len(g) == 1 else tf.add_n(g)
                        sums.append((g, gv[0][1]))

                    dev_grads[dev] = sums

            # Sum gradients across devices.
            if len(devices) > 1:
                with tf.name_scope("SumAcrossGPUs"), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        g = [dev_grads[dev][var_idx][0] for dev in devices]

                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors
                            g = nccl_ops.all_sum(g)

                        # Put the (possibly summed) gradient back next to each device's own variable.
                        for dev, gg in zip(devices, g):
                            dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])

            # Apply updates separately on each device.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
                    # Scale gradients as needed.
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope("Scale"):
                            # Dividing by total_grads turns the sums into averages;
                            # undo_loss_scaling folds the inverse loss scale into the same coefficient.
                            coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # Check for overflows.
                    with tf.name_scope("CheckOverflow"):
                        grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))

                    # Update weights and adjust loss scaling.
                    with tf.name_scope("UpdateWeights"):
                        # pylint: disable=cell-var-from-loop
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)  # None when loss scaling is disabled

                        if not self.use_loss_scaling:
                            # Skip the update entirely when any gradient is non-finite.
                            ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
                        else:
                            # Finite gradients: raise the scale and apply. Otherwise: lower the scale and skip the update.
                            ops.append(tf.cond(grad_ok,
                                               lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
                                               lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))

                    # Report statistics on the last device.
                    if dev == devices[-1]:
                        with tf.name_scope("Statistics"):
                            ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                            ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))

                            if self.use_loss_scaling:
                                ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))

            # Initialize variables and group everything into a single op.
            self.reset_optimizer_state()
            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))

            return tf.group(*ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer.

        Runs the initializer of every variable reported by each per-device
        optimizer's variables().
        """
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Returns None when loss scaling is disabled. Otherwise one variable per
        device is created on first request and cached for later calls.
        """
        if not self.use_loss_scaling:
            return None

        if device not in self._dev_ls_var:
            with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
                self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")

        return self._dev_ls_var[device]

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression.

        Multiplies `value` by 2 ** ls_var, where ls_var is the loss scaling
        variable of the device `value` lives on. Returns `value` unchanged when
        loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression.

        Multiplies `value` by 2 ** (-ls_var) for the device `value` lives on.
        Returns `value` unchanged when loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type
215 |
--------------------------------------------------------------------------------
/dnnlib/submission/submit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Submit a function to be run either locally or in a computing cluster."""
9 |
10 | import copy
11 | import io
12 | import os
13 | import pathlib
14 | import pickle
15 | import platform
16 | import pprint
17 | import re
18 | import shutil
19 | import time
20 | import traceback
21 |
22 | import zipfile
23 |
24 | from enum import Enum
25 |
26 | from .. import util
27 | from ..util import EasyDict
28 |
29 |
class SubmitTarget(Enum):
    """Enumerates the places a submitted function can be executed.

    LOCAL: Execute on the local machine.
    """
    LOCAL = 1
36 |
37 |
class PathType(Enum):
    """Selects the style a path is formatted in.

    WINDOWS: Windows style formatting.
    LINUX: Linux/Posix style formatting.
    AUTO: Pick WINDOWS or LINUX according to the operating system in use.
    """
    WINDOWS = 1
    LINUX = 2
    AUTO = 3
48 |
49 |
# User name forced via set_user_name_override(); while None, get_user_name() detects the name itself.
_user_name_override = None
51 |
52 |
class SubmitConfig(util.EasyDict):
    """Settings bundle describing how a run is submitted.

    Attributes set by the caller:
        run_dir_root: Root directory for run dirs. May contain template tags, so always resolve it with get_path_from_template.
        run_desc: Short description of the run; becomes part of the run dir name and the task name.
        run_dir_ignore: File patterns that are skipped when files are copied into the run dir.
        run_dir_extra_files: Optional list of (abs_path, rel_path) tuples; rel_path is relative to the src directory inside the run dir.
        submit_target: SubmitTarget value selecting where the run is launched.
        num_gpus: Number of GPUs used/requested for the run.
        print_info: Whether debug information is printed when submitting.
        ask_confirmation: Whether the user is asked for confirmation before submitting.

    Attributes populated automatically during submit:
        run_id, run_name, run_dir, run_func_name, run_func_kwargs, task_name, host_name.
        user_name: Also populated automatically, unless the caller has set it, in which case that value is kept.
    """

    def __init__(self):
        super().__init__()

        # run (set these)
        self.update(
            run_dir_root="",  # should always be passed through get_path_from_template
            run_desc="",
            run_dir_ignore=["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"],
            run_dir_extra_files=None,
        )

        # submit (set these)
        self.update(
            submit_target=SubmitTarget.LOCAL,
            num_gpus=1,
            print_info=False,
            ask_confirmation=False,
        )

        # (automatically populated)
        self.update(
            run_id=None,
            run_name=None,
            run_dir=None,
            run_func_name=None,
            run_func_kwargs=None,
            user_name=None,
            task_name=None,
            host_name="localhost",
        )
99 |
100 |
def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
    """Replace tags in the given path template and return either Windows or Linux formatted path.

    The only tag handled is "<USERNAME>", which is replaced by get_user_name().

    Args:
        path_template: Path that may contain template tags.
        path_type: Desired output format. AUTO resolves to WINDOWS or LINUX
            depending on platform.system().

    Returns:
        The path formatted with pathlib.PureWindowsPath or pathlib.PurePosixPath.

    Raises:
        RuntimeError: If AUTO is requested on a platform other than Windows or
            Linux, or if `path_type` is not a known PathType.
    """
    # automatically select path type depending on running OS
    if path_type == PathType.AUTO:
        system = platform.system()
        if system == "Windows":
            path_type = PathType.WINDOWS
        elif system == "Linux":
            path_type = PathType.LINUX
        else:
            raise RuntimeError("Unknown platform")

    # The search string must be the tag itself: replacing the empty string
    # would insert the user name between every character of the path.
    path_template = path_template.replace("<USERNAME>", get_user_name())

    # return correctly formatted path
    if path_type == PathType.WINDOWS:
        return str(pathlib.PureWindowsPath(path_template))
    elif path_type == PathType.LINUX:
        return str(pathlib.PurePosixPath(path_template))
    else:
        raise RuntimeError("Unknown platform")
121 |
122 |
def get_template_from_path(path: str) -> str:
    """Turn a regular path into its template form.

    The only transformation performed is replacing backslashes with forward slashes.
    """
    return path.replace("\\", "/")
128 |
129 |
def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
    """Reformat a regular path in the requested style.

    The path is first turned into a template with get_template_from_path and
    then expanded again with get_path_from_template using `path_type`.
    """
    return get_path_from_template(get_template_from_path(path), path_type)
135 |
136 |
def set_user_name_override(name: str) -> None:
    """Store `name` in the module-level user name override."""
    global _user_name_override
    _user_name_override = name
141 |
142 |
def get_user_name():
    """Return the name of the current user.

    A value set through the module-level override takes precedence. Otherwise
    Windows uses os.getlogin() and Linux looks up the effective uid via `pwd`,
    falling back to "unknown" when that lookup fails.

    Raises:
        RuntimeError: On any other platform.
    """
    if _user_name_override is not None:
        return _user_name_override

    if platform.system() == "Windows":
        return os.getlogin()

    if platform.system() != "Linux":
        raise RuntimeError("Unknown platform")

    try:
        import pwd  # pylint: disable=import-error
        return pwd.getpwuid(os.geteuid()).pw_name  # pylint: disable=no-member
    except:
        return "unknown"
157 |
158 |
def _create_run_dir_local(submit_config: SubmitConfig) -> str:
    """Create a fresh run dir whose name starts with the next free run ID.

    The run dir root is created if it is missing. `submit_config.run_id` and
    `submit_config.run_name` are filled in as a side effect.

    Returns:
        Path of the newly created run dir.

    Raises:
        RuntimeError: If the computed run dir already exists.
    """
    root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)

    if not os.path.exists(root):
        print("Creating the run dir root: {}".format(root))
        os.makedirs(root)

    next_id = _get_next_run_id_local(root)
    submit_config.run_id = next_id
    submit_config.run_name = "{0:05d}-{1}".format(next_id, submit_config.run_desc)
    run_dir = os.path.join(root, submit_config.run_name)

    if os.path.exists(run_dir):
        raise RuntimeError("The run dir already exists! ({0})".format(run_dir))

    print("Creating the run dir: {}".format(run_dir))
    os.makedirs(run_dir)

    return run_dir
178 |
179 |
180 | def _get_next_run_id_local(run_dir_root: str) -> int:
181 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
182 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
183 | r = re.compile("^\\d+") # match one or more digits at the start of the string
184 | run_id = 0
185 |
186 | for dir_name in dir_names:
187 | m = r.match(dir_name)
188 |
189 | if m is not None:
190 | i = int(m.group())
191 | run_id = max(run_id, i + 1)
192 |
193 | return run_id
194 |
195 |
def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
    """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.

    Copied into <run_dir>/src:
      - the source tree that contains the run function's top-level package,
      - the dnnlib package,
      - any `submit_config.run_dir_extra_files`.
    Additionally the launcher script run.py is copied into <run_dir>, and the
    submit config is saved as submit_config.pkl (pickle) and submit_config.txt
    (pretty-printed).
    """
    print("Copying files to the run dir")
    files = []

    # Walk up from the run function's module dir, one level per package component in its dotted name.
    run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
    assert '.' in submit_config.run_func_name
    for _idx in range(submit_config.run_func_name.count('.') - 1):
        run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
    files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)

    dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
    files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)

    if submit_config.run_dir_extra_files is not None:
        files += submit_config.run_dir_extra_files

    # Re-root every relative destination under <run_dir>/src.
    files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
    files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]

    util.copy_files_and_create_dirs(files)

    # Use a context manager so the pickle file is flushed and closed deterministically.
    with open(os.path.join(run_dir, "submit_config.pkl"), "wb") as f:
        pickle.dump(submit_config, f)

    with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
        pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
222 |
223 |
def run_wrapper(submit_config: SubmitConfig) -> None:
    """Execute the configured run function with logging and error handling around it.

    While the function runs, `dnnlib.submit_config` points at `submit_config`.
    A file "_finished.txt" is created in the run dir whether or not the
    function succeeded. For local runs an exception is re-raised; otherwise the
    traceback is printed and log.txt is copied next to the run dir as
    "<run_name>-error.txt".
    """
    is_local = submit_config.submit_target == SubmitTarget.LOCAL

    checker = None

    if is_local:
        # Local run: mirror stdout/stderr into log.txt inside the run dir and flush eagerly.
        log_path = os.path.join(submit_config.run_dir, "log.txt")
        logger = util.Logger(file_name=log_path, file_mode="w", should_flush=True)
    else:
        # Cluster run: no log file here (log writing is handled by run.sh), just flush eagerly.
        logger = util.Logger(file_name=None, should_flush=True)

    import dnnlib
    dnnlib.submit_config = submit_config

    try:
        func_name = submit_config.run_func_name
        print("dnnlib: Running {0}() on {1}...".format(func_name, submit_config.host_name))
        start_time = time.time()
        util.call_func_by_name(func_name=func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        elapsed = util.format_time(time.time() - start_time)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, elapsed))
    except:
        if is_local:
            raise

        traceback.print_exc()

        error_name = "{0}-error.txt".format(submit_config.run_name)
        log_src = os.path.join(submit_config.run_dir, "log.txt")
        log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), error_name)
        shutil.copyfile(log_src, log_dst)
    finally:
        # Marker file signalling that the run has ended, successfully or not.
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

    dnnlib.submit_config = None
    logger.close()

    if checker is not None:
        checker.stop()
261 |
262 |
def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Prepare a run dir for the given function and launch it.

    Works on a shallow copy of `submit_config`, so the caller's object is not
    reassigned. Only SubmitTarget.LOCAL is supported.

    Args:
        submit_config: Submission settings.
        run_func_name: Dotted name of the function to run.
        **run_func_kwargs: Keyword arguments passed on to that function.
    """
    submit_config = copy.copy(submit_config)

    # Keep a user name supplied by the caller; otherwise detect it.
    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    assert submit_config.submit_target == SubmitTarget.LOCAL
    if submit_config.submit_target == SubmitTarget.LOCAL:
        run_dir = _create_run_dir_local(submit_config)

    submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
    submit_config.run_dir = run_dir
    _populate_run_dir(run_dir, submit_config)

    if submit_config.print_info:
        print("\nSubmit config:\n")
        pprint.pprint(submit_config, indent=4, width=200, compact=False)
        print()

    # Give the user a chance to abort before anything is executed.
    if submit_config.ask_confirmation and not util.ask_yes_no("Continue submitting the job?"):
        return

    run_wrapper(submit_config)
291 |
--------------------------------------------------------------------------------
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import uuid
27 |
28 | from distutils.util import strtobool
29 | from typing import Any, List, Tuple, Union
30 |
31 |
32 | # Util classes
33 | # ------------------------------------------------------------------------------------------
34 |
35 |
class EasyDict(dict):
    """A dict whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails; map a missing key to AttributeError.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment stores an item instead of an instance attribute.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the item of the same name.
        del self[name]
50 |
51 |
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.

    Creating an instance replaces sys.stdout and sys.stderr with the instance;
    close() (or leaving the `with` block) restores them. close() may be called
    more than once.
    """

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        """Start mirroring.

        Args:
            file_name: Optional path of a log file that receives a copy of all output.
            file_mode: Mode used to open the log file.
            should_flush: Whether every write is followed by a flush.
        """
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Remember the streams being replaced so that close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: str) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            # Forget the closed file so that a later close()/flush()/write()
            # does not touch it and raise ValueError.
            self.file = None
106 |
107 |
108 | # Small util functions
109 | # ------------------------------------------------------------------------------------------
110 |
111 |
def format_time(seconds: Union[int, float]) -> str:
    """Format a duration in seconds as a short human readable string.

    The value is rounded to whole seconds first. Depending on its size the
    result looks like "42s", "3m 07s", "5h 02m 09s" or "2d 04h 30m" (seconds
    are dropped once days are shown).
    """
    total = int(np.rint(seconds))
    if total < 60:
        return "{0}s".format(total)

    total_minutes, secs = divmod(total, 60)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(total_minutes, secs)

    total_hours, mins = divmod(total_minutes, 60)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(total_hours, mins, secs)

    days, hours = divmod(total_hours, 24)
    return "{0}d {1:02}h {2:02}m".format(days, hours, mins)
124 |
125 |
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Accepted answers (case-insensitive) are y, yes, t, true, on, 1 for True and
    n, no, f, false, off, 0 for False. This is the same vocabulary as
    distutils.util.strtobool, but parsed locally so that the function does not
    depend on the deprecated distutils package and returns a real bool.

    Returns:
        True for an affirmative answer, False for a negative one.
    """
    affirmative = ("y", "yes", "t", "true", "on", "1")
    negative = ("n", "no", "f", "false", "off", "0")

    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()

        if answer in affirmative:
            return True
        if answer in negative:
            return False
        # Anything else is not a valid answer: ask again.
134 |
135 |
def tuple_product(t: Tuple) -> Any:
    """Multiply all elements of `t` together; an empty tuple yields 1."""
    product = 1
    for factor in t:
        product *= factor
    return product
144 |
145 |
146 | _str_to_ctype = {
147 | "uint8": ctypes.c_ubyte,
148 | "uint16": ctypes.c_uint16,
149 | "uint32": ctypes.c_uint32,
150 | "uint64": ctypes.c_uint64,
151 | "int8": ctypes.c_byte,
152 | "int16": ctypes.c_int16,
153 | "int32": ctypes.c_int32,
154 | "int64": ctypes.c_int64,
155 | "float32": ctypes.c_float,
156 | "float64": ctypes.c_double
157 | }
158 |
159 |
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Return the NumPy dtype and the ctypes type matching a type name.

    Args:
        type_obj: Either the type name as a string, or an object whose
            `__name__` attribute (preferred) or `name` attribute holds it.

    Returns:
        A (numpy dtype, ctypes type) pair of equal size in bytes.

    Raises:
        RuntimeError: If no type name can be derived from `type_obj`.
    """
    if isinstance(type_obj, str):
        type_str = type_obj
    else:
        # Try the supported name-carrying attributes in order of preference.
        for attr in ("__name__", "name"):
            if hasattr(type_obj, attr):
                type_str = getattr(type_obj, attr)
                break
        else:
            raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    my_ctype = _str_to_ctype[type_str]
    my_dtype = np.dtype(type_str)

    # Both representations must occupy the same number of bytes.
    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)

    return my_dtype, my_ctype
181 |
182 |
def is_pickleable(obj: Any) -> bool:
    """Return True if the object can be serialized with pickle.

    The object is dumped into an in-memory buffer that is discarded
    afterwards; any ordinary exception raised by the dump means "not
    pickleable".
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Deliberately not a bare except: KeyboardInterrupt and SystemExit
        # must still propagate instead of being reported as "not pickleable".
        return False
190 |
191 |
192 | # Functionality to import modules/objects by name, and call functions by name
193 | # ------------------------------------------------------------------------------------------
194 |
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed).

    Args:
        obj_name: Dotted name such as "package.module.attr". The shorthands
            "np." and "tf." at the start are expanded to "numpy." and
            "tensorflow.".

    Raises:
        ImportError: If a candidate module fails to import for a reason other
            than not existing, or if nothing matching the name is found.
        AttributeError: If a candidate module imports but the requested
            attribute is missing from it.
    """

    # allow convenience shorthands, substitute them by full names
    # (the dot is escaped so that only a literal "np." / "tf." prefix matches;
    # an unescaped dot would also rewrite names such as "tfutil.foo")
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module path first
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except Exception:
            # Best-effort probing: failures are diagnosed by the passes below.
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError as err:
            # A plain "module does not exist" is expected for most candidates;
            # anything else is a real error inside the module, so surface it.
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
234 |
235 |
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Resolve a dotted attribute path starting at the module.

    An empty path returns the module itself; otherwise each dot-separated
    part is looked up with getattr and the final object is returned.
    """
    target = module
    if obj_name != '':
        for attr_name in obj_name.split("."):
            target = getattr(target, attr_name)
    return target
244 |
245 |
def get_obj_by_name(name: str) -> Any:
    """Look up and return the python object that the dotted name refers to."""
    containing_module, local_name = get_module_from_obj_name(name)
    resolved = get_obj_from_module(containing_module, local_name)
    return resolved
250 |
251 |
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Resolve ``func_name`` to a callable and invoke it.

    All positional and remaining keyword arguments are forwarded to the
    callable, and its result is returned.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    result = target(*args, **kwargs)
    return result
258 |
259 |
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory of the file that defines the module holding ``obj_name``."""
    containing_module = get_module_from_obj_name(obj_name)[0]
    module_file = inspect.getfile(containing_module)
    return os.path.dirname(module_file)
264 |
265 |
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.

    The check is: the object is callable and its ``__name__`` is a key in the
    namespace of the module named by its ``__module__``.

    Callables that lack ``__name__`` (for example ``functools.partial``
    objects) or whose module is not present in ``sys.modules`` yield False
    rather than raising.
    """
    if not callable(obj):
        return False

    name = getattr(obj, "__name__", None)
    module = sys.modules.get(getattr(obj, "__module__", None))
    if name is None or module is None:
        return False

    return name in module.__dict__
269 |
270 |
def get_top_level_function_name(obj: Any) -> str:
    """Build "<module>.<name>" for a top-level function.

    The object must pass is_top_level_function (enforced with an assert).
    """
    assert is_top_level_function(obj)
    return ".".join((obj.__module__, obj.__name__))
275 |
276 |
277 | # File system helpers
278 | # ------------------------------------------------------------------------------------------
279 |
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Walk a directory tree and collect its files, skipping ignored names.

    Args:
        dir_path: Existing directory to walk.
        ignores: fnmatch patterns; a directory or file whose name matches any
            of them is skipped (ignored directories are not descended into).
        add_base_to_relative: When True, prefix each relative path with the
            last component of ``dir_path``.

    Returns:
        A list of (path joined onto the walk root, path relative to
        ``dir_path``) tuples.
    """
    assert os.path.isdir(dir_path)
    patterns = [] if ignores is None else ignores
    base_name = os.path.basename(os.path.normpath(dir_path))

    def is_ignored(entry_name):
        return any(fnmatch.fnmatch(entry_name, pattern) for pattern in patterns)

    collected = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        # Prune in place so that os.walk does not descend into ignored directories.
        dirs[:] = [d for d in dirs if not is_ignored(d)]

        for file_name in files:
            if is_ignored(file_name):
                continue

            absolute_path = os.path.join(root, file_name)
            relative_path = os.path.relpath(absolute_path, dir_path)
            if add_base_to_relative:
                relative_path = os.path.join(base_name, relative_path)

            collected.append((absolute_path, relative_path))

    return collected
311 |
312 |
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories.

    Args:
        files: (source path, destination path) pairs, copied in order with
            shutil.copyfile.
    """
    for src_path, dst_path in files:
        target_dir_name = os.path.dirname(dst_path)

        # A destination without a directory part refers to the current
        # directory, which needs no creation (os.makedirs("") would raise).
        if target_dir_name:
            # will create all intermediate-level directories; exist_ok avoids
            # failing when the directory already exists or appears concurrently
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src_path, dst_path)
324 |
325 |
326 | # URL helpers
327 | # ------------------------------------------------------------------------------------------
328 |
def is_url(obj: Any) -> bool:
    """Determine whether the given object is a valid URL string.

    A value is accepted when it is a string containing "://" and both the
    value itself and its root URL (the value joined with "/") parse to a
    non-empty scheme and a network location that contains a dot.
    """
    # Standard-library functions; on Python 3 these are the same objects that
    # requests.compat re-exports, so the check does not need requests at all.
    from urllib.parse import urljoin, urlparse

    if not isinstance(obj, str) or "://" not in obj:
        return False
    try:
        res = urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = urlparse(urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        # Parsing errors mean "not a URL"; KeyboardInterrupt and SystemExit
        # are intentionally not swallowed.
        return False
    return True
343 |
344 |
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: URL to fetch; must satisfy is_url().
        cache_dir: Optional directory used as a download cache. Cached files
            are named "<md5 of url>_<sanitized file name>".
        num_attempts: Maximum number of download attempts (at least 1).
        verbose: Print progress messages to stdout.

    Returns:
        An open file in "rb" mode when exactly one cached copy exists,
        otherwise an io.BytesIO holding the downloaded bytes.

    Raises:
        Exception: Whatever the final failed attempt raised (for example
            IOError for empty responses or the Google Drive error pages).
    """
    assert is_url(url)
    assert num_attempts >= 1

    # Lookup from cache.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache_dir is not None:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            return open(cache_files[0], "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive HTML pages rather than the payload.
                    if len(res.content) < 8192:
                        # The text is only searched for markers, so undecodable
                        # bytes are replaced instead of failing the download of
                        # a small binary file.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Follow the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive quota exceeded")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # Not a bare except: Ctrl-C (KeyboardInterrupt) must abort the
                # download immediately instead of triggering another retry.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache_dir is not None:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic

    # Return data as file object.
    return io.BytesIO(url_data)
406 |
--------------------------------------------------------------------------------
/dnnlib/tflib/network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for managing networks."""
9 |
10 | import types
11 | import inspect
12 | import re
13 | import uuid
14 | import sys
15 | import numpy as np
16 | import tensorflow as tf
17 |
18 | from collections import OrderedDict
19 | from typing import Any, List, Tuple, Union
20 |
21 | from . import tfutil
22 | from .. import util
23 |
24 | from .tfutil import TfExpression, TfExpressionEx
25 |
26 | _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
27 | _import_module_src = dict() # Source code for temporary modules created during pickle import.
28 |
29 |
def import_handler(handler_func):
    """Decorator that registers a custom pickle-import handler.

    The function is appended to the module-level handler list and returned
    unchanged, so it stays usable under its own name.
    """
    _import_handlers.append(handler_func)
    return handler_func
34 |
35 |
36 | class Network:
37 | """Generic network abstraction.
38 |
39 | Acts as a convenience wrapper for a parameterized network construction
40 | function, providing several utility methods and convenient access to
41 | the inputs/outputs/weights.
42 |
43 | Network objects can be safely pickled and unpickled for long-term
44 | archival purposes. The pickling works reliably as long as the underlying
45 | network construction function is defined in a standalone Python module
46 | that has no side effects or application-specific imports.
47 |
48 | Args:
49 | name: Network name. Used to select TensorFlow name and variable scopes.
50 | func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
51 | static_kwargs: Keyword arguments to be passed in to the network construction function.
52 |
53 | Attributes:
54 | name: User-specified name, defaults to build func name if None.
55 | scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
56 | static_kwargs: Arguments passed to the user-supplied build func.
57 | components: Container for sub-networks. Passed to the build func, and retained between calls.
58 | num_inputs: Number of input tensors.
59 | num_outputs: Number of output tensors.
60 | input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
61 | output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
62 | input_shape: Short-hand for input_shapes[0].
63 | output_shape: Short-hand for output_shapes[0].
64 | input_templates: Input placeholders in the template graph.
65 | output_templates: Output tensors in the template graph.
66 | input_names: Name string for each input.
67 | output_names: Name string for each output.
68 | own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
69 | vars: All variables (local_name => var).
70 | trainables: All trainable variables (local_name => var).
71 | var_global_to_local: Mapping from variable global names to local names.
72 | """
73 |
def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
    """Create a network from a build function and initialize its own variables.

    Args:
        name: Optional network name.
        func_name: Dotted name of the build function, or a top-level function object.
        static_kwargs: Pickleable keyword arguments stored for the build function.
    """
    tfutil.assert_tf_initialized()
    assert name is None or isinstance(name, str)
    assert func_name is not None
    assert isinstance(func_name, str) or util.is_top_level_function(func_name)
    assert util.is_pickleable(static_kwargs)

    self._init_fields()
    self.name = name
    self.static_kwargs = util.EasyDict(static_kwargs)

    # Resolve the build function; function objects are first converted to their dotted name.
    build_func_path = func_name
    if util.is_top_level_function(build_func_path):
        build_func_path = util.get_top_level_function_name(build_func_path)
    module, self._build_func_name = util.get_module_from_obj_name(build_func_path)
    self._build_func = util.get_obj_from_module(module, self._build_func_name)
    assert callable(self._build_func)

    # Prefer source recorded during pickle import; otherwise read it from the module.
    self._build_module_src = _import_module_src.get(module, None)
    if self._build_module_src is None:
        self._build_module_src = inspect.getsource(module)

    # Construct the TensorFlow graph and initialize the variables owned by this network.
    self._init_graph()
    self.reset_own_vars()
100 |
def _init_fields(self) -> None:
    """Reset every attribute of the network to its empty default value."""
    # Identity and configuration.
    self.name = None
    self.scope = None
    self.static_kwargs = util.EasyDict()
    self.components = util.EasyDict()

    # Inputs and outputs.
    self.num_inputs = 0
    self.num_outputs = 0
    self.input_shapes = [[]]
    self.output_shapes = [[]]
    self.input_shape = []
    self.output_shape = []
    self.input_templates = []
    self.output_templates = []
    self.input_names = []
    self.output_names = []

    # Variable registries.
    self.own_vars = OrderedDict()
    self.vars = OrderedDict()
    self.trainables = OrderedDict()
    self.var_global_to_local = OrderedDict()

    # Private state.
    self._build_func = None  # User-supplied build function that constructs the network.
    self._build_func_name = None  # Name of the build function.
    self._build_module_src = None  # Full source code of the module containing the build function.
    self._run_cache = {}  # Cached graph data for Network.run().
125 |
def _init_graph(self) -> None:
    """Build the template graph and fill in the input/output/variable fields.

    Input names are derived from the build function's signature, a unique
    scope is reserved for the network, the build function is invoked once
    with placeholder inputs, and the resulting outputs, shapes and variables
    are recorded on the instance.

    Raises:
        ValueError: If input or output shapes are undefined, if a component
            is not a Network, or if component names are not unique.
    """
    # Collect inputs.
    self.input_names = []

    # Every positional-or-keyword parameter without a default is a network input.
    for param in inspect.signature(self._build_func).parameters.values():
        if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
            self.input_names.append(param.name)

    self.num_inputs = len(self.input_names)
    assert self.num_inputs >= 1

    # Choose name and scope.
    if self.name is None:
        self.name = self._build_func_name
    assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
    with tf.name_scope(None):
        self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs)
    build_kwargs["is_template_graph"] = True
    build_kwargs["components"] = self.components

    # Build template graph.
    with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
        assert tf.get_variable_scope().name == self.scope
        assert tf.get_default_graph().get_name_scope() == self.scope
        with tf.control_dependencies(None): # ignore surrounding control dependencies
            self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
            out_expr = self._build_func(*self.input_templates, **build_kwargs)

    # Collect outputs.
    # A single expression is wrapped into a one-element list; a tuple becomes a list.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
    self.num_outputs = len(self.output_templates)
    assert self.num_outputs >= 1
    assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

    # Perform sanity checks.
    if any(t.shape.ndims is None for t in self.input_templates):
        raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
    if any(t.shape.ndims is None for t in self.output_templates):
        raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
    if any(not isinstance(comp, Network) for comp in self.components.values()):
        raise ValueError("Components of a Network must be Networks themselves.")
    if len(self.components) != len(set(comp.name for comp in self.components.values())):
        raise ValueError("Components of a Network must have unique names.")

    # List inputs and outputs.
    self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
    self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
    self.input_shape = self.input_shapes[0]
    self.output_shape = self.output_shapes[0]
    # Output name = last path component of the tensor name, without the ":<index>" suffix.
    self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

    # List variables.
    # Local names drop the "<scope>/" prefix and the ":<index>" suffix of the global name.
    self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
    self.vars = OrderedDict(self.own_vars)
    # Variables of components are exposed under "<component name>/<local name>".
    self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
    self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
    # Keys are global variable names with the ":<index>" suffix removed.
    self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
187 |
def reset_own_vars(self) -> None:
    """Run the initializers of this network's own variables (sub-networks are left alone)."""
    initializers = [variable.initializer for variable in self.own_vars.values()]
    tfutil.run(initializers)
191 |
def reset_vars(self) -> None:
    """Run the initializers of all variables, those of sub-networks included."""
    initializers = [variable.initializer for variable in self.vars.values()]
    tfutil.run(initializers)
195 |
def reset_trainables(self) -> None:
    """Run the initializers of all trainable variables, those of sub-networks included."""
    initializers = [variable.initializer for variable in self.trainables.values()]
    tfutil.run(initializers)
199 |
def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
    """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).

    Args:
        in_expr: One expression per network input. Individual entries may be
            None (they are replaced by zeros), but not all of them.
        return_as_list: When True, always return a list of output expressions.
        dynamic_kwargs: Extra keyword arguments for the build function; they
            override static_kwargs entries of the same name.

    Returns:
        Whatever the build function returned (a single expression or a
        tuple), or a list of expressions when return_as_list is True.
    """
    assert len(in_expr) == self.num_inputs
    assert not all(expr is None for expr in in_expr)

    # Finalize build func kwargs.
    build_kwargs = dict(self.static_kwargs)
    build_kwargs.update(dynamic_kwargs)
    build_kwargs["is_template_graph"] = False
    build_kwargs["components"] = self.components

    # Build TensorFlow graph to evaluate the network.
    with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
        assert tf.get_variable_scope().name == self.scope
        valid_inputs = [expr for expr in in_expr if expr is not None]
        final_inputs = []
        for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
            if expr is not None:
                expr = tf.identity(expr, name=name)
            else:
                # Missing input: zeros with the leading dimension of the first
                # supplied input and the remaining dimensions of the template shape.
                expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
            final_inputs.append(expr)
        out_expr = self._build_func(*final_inputs, **build_kwargs)

    # Propagate input shapes back to the user-specified expressions.
    for expr, final in zip(in_expr, final_inputs):
        if isinstance(expr, tf.Tensor):
            expr.set_shape(final.shape)

    # Express outputs in the desired format.
    assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
    if return_as_list:
        out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
    return out_expr
234 |
def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
    """Get the local name of a given variable, without any surrounding name scopes.

    Args:
        var_or_global_name: A variable of this network, or its global name
            with or without the ":<index>" suffix.

    Returns:
        The local name registered in var_global_to_local.

    Raises:
        KeyError: If the name does not belong to this network.
    """
    assert isinstance(var_or_global_name, str) or tfutil.is_tf_expression(var_or_global_name)
    global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
    # The keys of var_global_to_local are stored without the ":<index>" suffix
    # that tensor/variable names carry, so strip it before the lookup.
    return self.var_global_to_local[global_name.split(":")[0]]
240 |
def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
    """Return the variable for a local name; a non-string argument is returned as is."""
    assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
    if isinstance(var_or_local_name, str):
        return self.vars[var_or_local_name]
    return var_or_local_name
245 |
def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
    """Evaluate a single variable and return its current value.
    Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
    variable = self.find_var(var_or_local_name)
    return variable.eval()
250 |
def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
    """Assign a new value to a single variable.
    Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
    variable = self.find_var(var_or_local_name)
    tfutil.set_vars({variable: new_value})
255 |
def __getstate__(self) -> dict:
    """Pickle export: return a plain dict describing the network.

    The dict holds a format version, the name, the static kwargs, the
    components, the build module source, the build function name, and the
    current values of the network's own variables.
    """
    var_names = list(self.own_vars.keys())
    var_values = tfutil.run(list(self.own_vars.values()))
    return {
        "version": 3,
        "name": self.name,
        "static_kwargs": dict(self.static_kwargs),
        "components": dict(self.components),
        "build_module_src": self._build_module_src,
        "build_func_name": self._build_func_name,
        "variables": list(zip(var_names, var_values)),
    }
267 |
def __setstate__(self, state: dict) -> None:
    """Pickle import.

    Rebuilds the network from the dict produced by __getstate__: restores
    the basic fields, recreates the build function by executing the stored
    module source in a fresh temporary module, rebuilds the graph, and loads
    the stored variable values.

    SECURITY NOTE: the stored module source is run with exec(), so
    unpickling a Network executes code contained in the pickle. Only load
    pickles from trusted sources.
    """
    # pylint: disable=attribute-defined-outside-init
    tfutil.assert_tf_initialized()
    self._init_fields()

    # Execute custom import handlers.
    # Each handler receives the state dict and returns the (possibly updated) state.
    for handler in _import_handlers:
        state = handler(state)

    # Set basic fields.
    assert state["version"] in [2, 3]
    self.name = state["name"]
    self.static_kwargs = util.EasyDict(state["static_kwargs"])
    self.components = util.EasyDict(state.get("components", {}))
    self._build_module_src = state["build_module_src"]
    self._build_func_name = state["build_func_name"]

    # Create temporary module from the imported source code.
    module_name = "_tflib_network_import_" + uuid.uuid4().hex
    module = types.ModuleType(module_name)
    sys.modules[module_name] = module
    # Remember the source so that Network.__init__ can find it for this module later.
    _import_module_src[module] = self._build_module_src
    exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used

    # Locate network build function in the temporary module.
    self._build_func = util.get_obj_from_module(module, self._build_func_name)
    assert callable(self._build_func)

    # Init TensorFlow graph.
    self._init_graph()
    self.reset_own_vars()
    tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
301 |
def clone(self, name: str = None, **new_static_kwargs) -> "Network":
    """Return a new Network built by the same build function, holding copies of this network's variable values.

    Args:
        name: Name for the clone; this network's name is reused when None.
        new_static_kwargs: Static kwargs that override or extend the existing ones.
    """
    # pylint: disable=protected-access
    duplicate = object.__new__(Network)
    duplicate._init_fields()
    duplicate.name = self.name if name is None else name

    merged_kwargs = util.EasyDict(self.static_kwargs)
    merged_kwargs.update(new_static_kwargs)
    duplicate.static_kwargs = merged_kwargs

    # Share the build function and its recorded source with the original.
    duplicate._build_module_src = self._build_module_src
    duplicate._build_func_name = self._build_func_name
    duplicate._build_func = self._build_func

    duplicate._init_graph()
    duplicate.copy_vars_from(self)
    return duplicate
316 |
def copy_own_vars_from(self, src_net: "Network") -> None:
    """Copy values from ``src_net`` for every own variable (sub-networks excluded) that both networks define."""
    shared_names = [var_name for var_name in self.own_vars.keys() if var_name in src_net.own_vars]
    dst_to_src = {self.vars[var_name]: src_net.vars[var_name] for var_name in shared_names}
    fetched_values = tfutil.run(dst_to_src)
    tfutil.set_vars(fetched_values)
321 |
def copy_vars_from(self, src_net: "Network") -> None:
    """Copy values from ``src_net`` for every variable (sub-networks included) that both networks define."""
    shared_names = [var_name for var_name in self.vars.keys() if var_name in src_net.vars]
    dst_to_src = {self.vars[var_name]: src_net.vars[var_name] for var_name in shared_names}
    fetched_values = tfutil.run(dst_to_src)
    tfutil.set_vars(fetched_values)
326 |
def copy_trainables_from(self, src_net: "Network") -> None:
    """Copy values from ``src_net`` for every trainable variable (sub-networks included) that both networks define."""
    shared_names = [var_name for var_name in self.trainables.keys() if var_name in src_net.trainables]
    dst_to_src = {self.vars[var_name]: src_net.vars[var_name] for var_name in shared_names}
    fetched_values = tfutil.run(dst_to_src)
    tfutil.set_vars(fetched_values)
331 |
def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
    """Build a Network with another build function and copy over the variables shared with this one.

    Args:
        new_func_name: Build function for the new network.
        new_name: Name of the new network; this network's name is used when None.
        new_static_kwargs: Static kwargs that override or extend the existing ones.
    """
    target_name = self.name if new_name is None else new_name
    merged_kwargs = dict(self.static_kwargs)
    merged_kwargs.update(new_static_kwargs)
    converted = Network(name=target_name, func_name=new_func_name, **merged_kwargs)
    converted.copy_vars_from(self)
    return converted
341 |
def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
    """Build one grouped op that moves this network's variables towards those of ``src_net``.

    For every variable present in both networks, the new value is
    tfutil.lerp(source, current, b), where b is ``beta`` for trainable
    variables and ``beta_nontrainable`` for the rest.
    """
    with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
        update_ops = []
        for var_name, dst_var in self.vars.items():
            if var_name not in src_net.vars:
                continue
            if var_name in self.trainables:
                blend = beta
            else:
                blend = beta_nontrainable
            blended_value = tfutil.lerp(src_net.vars[var_name], dst_var, blend)
            update_ops.append(dst_var.assign(blended_value))
        return tf.group(*update_ops)
353 |
def run(self,
        *in_arrays: Tuple[Union[np.ndarray, None], ...],
        input_transform: dict = None,
        output_transform: dict = None,
        return_as_list: bool = False,
        print_progress: bool = False,
        minibatch_size: int = None,
        num_gpus: int = 1,
        assume_frozen: bool = False,
        **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
    """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

    Args:
        in_arrays:        One array per network input. Individual entries may be None (they are fed as zeros), but not all of them.
        input_transform:  A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                          The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                          TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                          The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                          TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        return_as_list:   True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress:   Print progress to the console? Useful for very large input arrays.
        minibatch_size:   Maximum minibatch size to use, None = disable batching.
        num_gpus:         Number of GPUs to use.
        assume_frozen:    Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
        dynamic_kwargs:   Additional keyword arguments to be passed into the network build function.
    """
    assert len(in_arrays) == self.num_inputs
    assert not all(arr is None for arr in in_arrays)
    assert input_transform is None or util.is_top_level_function(input_transform["func"])
    assert output_transform is None or util.is_top_level_function(output_transform["func"])
    output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
    # The first input may legitimately be None, so take the item count from
    # the first array that was actually supplied (one exists, see assert above).
    num_items = next(arr for arr in in_arrays if arr is not None).shape[0]
    if minibatch_size is None:
        minibatch_size = num_items

    # Construct unique hash key from all arguments that affect the TensorFlow graph.
    key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
    def unwind_key(obj):
        # Dicts become sorted (key, value) lists and callables become their
        # dotted name, so that repr() of the result is a stable cache key.
        if isinstance(obj, dict):
            return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
        if callable(obj):
            return util.get_top_level_function_name(obj)
        return obj
    key = repr(unwind_key(key))

    # Build graph.
    if key not in self._run_cache:
        with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
            with tf.device("/cpu:0"):
                in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                # One tuple of input slices per GPU.
                in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

            out_split = []
            for gpu in range(num_gpus):
                with tf.device("/gpu:%d" % gpu):
                    net_gpu = self.clone() if assume_frozen else self
                    in_gpu = in_split[gpu]

                    if input_transform is not None:
                        in_kwargs = dict(input_transform)
                        in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                        in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                    assert len(in_gpu) == self.num_inputs
                    out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                    if output_transform is not None:
                        out_kwargs = dict(output_transform)
                        out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                        out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                    assert len(out_gpu) == self.num_outputs
                    out_split.append(out_gpu)

            with tf.device("/cpu:0"):
                # Stitch the per-GPU results back together along the first axis.
                out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                self._run_cache[key] = in_expr, out_expr

    # Run minibatches.
    in_expr, out_expr = self._run_cache[key]
    out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]

    for mb_begin in range(0, num_items, minibatch_size):
        if print_progress:
            print("\r%d / %d" % (mb_begin, num_items), end="")

        mb_end = min(mb_begin + minibatch_size, num_items)
        mb_num = mb_end - mb_begin
        # Inputs given as None are fed as zeros shaped like the template input.
        mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
        mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

        for dst, src in zip(out_arrays, mb_out):
            dst[mb_begin: mb_end] = src

    # Done.
    if print_progress:
        print("\r%d / %d" % (num_items, num_items))

    if not return_as_list:
        out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)

    return out_arrays
456 |
def list_ops(self) -> List[TfExpression]:
    """Return the default graph's ops under this network's scope, skipping sub-scopes that start with an underscore."""
    scope_prefix = self.scope + "/"
    hidden_prefix = scope_prefix + "_"
    return [
        op for op in tf.get_default_graph().get_operations()
        if op.name.startswith(scope_prefix) and not op.name.startswith(hidden_prefix)
    ]
464 |
def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
    """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
    individual layers of the network. Mainly intended to be used for reporting."""
    layers = []

    def recurse(scope, parent_ops, parent_vars, level):
        """Descend into `scope`; append one entry to `layers` per scope that is treated as a layer.

        parent_ops / parent_vars are the ops and (local_name, var) pairs of
        the enclosing scope; level is the recursion depth (0 = network scope).
        """
        # Ignore specific patterns.
        if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
            return

        # Filter ops and vars by scope.
        global_prefix = scope + "/"
        # Same prefix relative to the network scope; this is how variable names are keyed.
        local_prefix = global_prefix[len(self.scope) + 1:]
        cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
        cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
        if not cur_ops and not cur_vars:
            return

        # Filter out all ops related to variables.
        for var in [op for op in cur_ops if op.type.startswith("Variable")]:
            var_prefix = var.name + "/"
            cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]

        # Scope does not contain ops as immediate children => recurse deeper.
        contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
        if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
            # Visit each distinct child scope once, in order of first appearance.
            visited = set()
            for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
                token = rel_name.split("/")[0]
                if token not in visited:
                    recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
                    visited.add(token)
            return

        # Report layer.
        layer_name = scope[len(self.scope) + 1:]
        # Layer output: first output of the last op in the scope, or the last variable when there are no ops.
        layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
        layer_trainables = [var for _name, var in cur_vars if var.trainable]
        layers.append((layer_name, layer_output, layer_trainables))

    recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
    return layers
507 |
def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
    """Print a plain-text table describing the layers returned by list_layers().

    Args:
        title: Text for the first header cell; self.name is used when None.
        hide_layers_with_no_params: When True, layers whose trainable parameter
            count is zero are left out of the table (the total is unaffected).
    """
    heading = self.name if title is None else title
    table = [[heading, "Params", "OutputShape", "WeightShape"], ["---"] * 4]
    grand_total = 0

    for name, output, trainables in self.list_layers():
        param_count = sum(np.prod(tfutil.shape_to_list(v.shape)) for v in trainables)

        # The weight shape column prefers variables named ".../weight:0" (shortest
        # name first); a layer's single trainable stands in when none exists.
        shape_sources = sorted(
            (v for v in trainables if v.name.endswith("/weight:0")),
            key=lambda v: len(v.name),
        )
        if not shape_sources and len(trainables) == 1:
            shape_sources = trainables
        grand_total += param_count

        if hide_layers_with_no_params and param_count == 0:
            continue

        table.append([
            name,
            str(param_count) if param_count > 0 else "-",
            str(output.shape),
            str(shape_sources[0].shape) if shape_sources else "-",
        ])

    table.append(["---"] * 4)
    table.append(["Total", str(grand_total), "", ""])

    # Left-align every cell to the width of the widest cell in its column.
    col_widths = [max(map(len, column)) for column in zip(*table)]
    print()
    for row in table:
        print(" ".join(cell.ljust(width) for cell, width in zip(row, col_widths)))
    print()
536 |
def setup_weight_histograms(self, title: str = None) -> None:
    """Create a TensorBoard histogram summary op for every entry in self.trainables.

    Args:
        title: Prefix used in the summary names; self.name is used when None.
    """
    prefix = self.name if title is None else title

    # Reset name scope, device and control dependencies while building the summary ops.
    with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
        for local_name, var in self.trainables.items():
            parent, slash, leaf = local_name.rpartition("/")
            if slash:
                # "a/b/weight" becomes "<prefix>_weight/a_b".
                summary_name = prefix + "_" + leaf + "/" + parent.replace("/", "_")
            else:
                summary_name = prefix + "_toplevel/" + local_name

            tf.summary.histogram(summary_name, var)
551 |
def save_my_model(self, export_dir: str = "./模型转换/saved_model") -> None:
    """Export the network as a TensorFlow SavedModel using the default session.

    This method was added so the model can be converted to TensorFlow.js more
    easily. The first input template and the first output template are exported
    under the first input name and first output name respectively.

    Args:
        export_dir: Directory passed to tf.saved_model.simple_save(). Defaults to
            the location the original hard-coded implementation wrote to.

    Raises:
        RuntimeError: If there is no default TensorFlow session.
    """
    # Alternative checkpoint-based export, kept for reference:
    # tf.train.Saver().save(tf.get_default_session(), 'networks/ckpt/my_net.ckpt')

    sess = tf.get_default_session()
    if sess is None:
        raise RuntimeError("save_my_model() requires a default TensorFlow session.")

    graph = tf.get_default_graph()  # Fetch the default graph.
    input_graph_def = graph.as_graph_def()  # Serialized GraphDef representing the current graph.

    # Patch up batch-norm related nodes, which are otherwise awkward to convert to TensorFlow.js.
    # NOTE(review): this edits only the local GraphDef obtained above, and that copy is
    # not passed to simple_save() below - confirm whether the rewrite was meant to be exported.
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Use the session directly instead of `with sess:`; leaving such a `with` block
    # closes the session, which would break any later use of it by the caller.
    tf.saved_model.simple_save(
        sess,
        export_dir,
        inputs={self.input_names[0]: self.input_templates[0]},
        outputs={self.output_names[0]: self.output_templates[0]},
    )
571 |
572 | #----------------------------------------------------------------------------
573 | # Backwards-compatible emulation of legacy output transformation in Network.run().
574 |
# One-shot flag: _handle_legacy_output_transforms() prints its deprecation warning
# only while this is True and sets it to False when it does so.
_print_legacy_warning = True
576 |
def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
    """Turn legacy out_* keyword arguments into an output_transform dict.

    When none of "out_mul", "out_add", "out_shrink" or "out_dtype" is present in
    dynamic_kwargs, both arguments are returned untouched. Otherwise a deprecation
    warning is printed (the first time only), output_transform must be None, and the
    result is (transform, remaining_kwargs): the legacy entries move into a new dict
    together with "func" = _legacy_output_transform_func, and remaining_kwargs is a
    copy of dynamic_kwargs without them.
    """
    global _print_legacy_warning
    legacy_names = ("out_mul", "out_add", "out_shrink", "out_dtype")
    present = [name for name in legacy_names if name in dynamic_kwargs]
    if not present:
        return output_transform, dynamic_kwargs

    if _print_legacy_warning:
        _print_legacy_warning = False
        warning_lines = (
            "",
            "WARNING: Old-style output transformations in Network.run() are deprecated.",
            "Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'",
            "instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.",
            "",
        )
        for line in warning_lines:
            print(line)
    assert output_transform is None

    # Work on a copy so the caller's kwargs dict is left unmodified.
    remaining = dict(dynamic_kwargs)
    transform = {name: remaining.pop(name) for name in present}
    transform["func"] = _legacy_output_transform_func
    return transform, remaining
596 |
def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
    """Apply the legacy out_* transformations to every expression in expr.

    The steps run in a fixed order, each one only when its argument differs from
    the default: multiply by out_mul, add out_add, average-pool with an
    out_shrink x out_shrink window (NCHW layout), then convert to out_dtype
    (rounding first when out_dtype is an integer type). When no step applies the
    input tuple is returned as-is; otherwise a list is returned.
    """
    def each(fn, tensors):
        return [fn(t) for t in tensors]

    outputs = expr

    if out_mul != 1.0:
        outputs = each(lambda t: t * out_mul, outputs)

    if out_add != 0.0:
        outputs = each(lambda t: t + out_add, outputs)

    if out_shrink > 1:
        window = [1, 1, out_shrink, out_shrink]
        outputs = each(
            lambda t: tf.nn.avg_pool(t, ksize=window, strides=window, padding="VALID", data_format="NCHW"),
            outputs,
        )

    if out_dtype is not None:
        if tf.as_dtype(out_dtype).is_integer:
            outputs = each(tf.round, outputs)
        outputs = each(lambda t: tf.saturate_cast(t, out_dtype), outputs)

    return outputs
613 |
--------------------------------------------------------------------------------