186 | );
187 |
188 | //------------------------------------------------------------------------
189 |
--------------------------------------------------------------------------------
/dnnlib/tflib/ops/fused_bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Custom TensorFlow ops for efficient bias and activation."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 | from .. import custom_ops
13 | from ...util import EasyDict
14 |
15 | def _get_plugin():
16 | return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | activation_funcs = {
21 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
22 | 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
23 | 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
24 | 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
25 | 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
26 | 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
27 | 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
28 | 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
29 | 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
30 | }
31 |
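# Each entry bundles the plain TensorFlow op with its default shape parameter
# (def_alpha), its default output scaling (def_gain), the kernel index used by
# the CUDA plugin (cuda_idx), which tensor the gradient kernel re-reads (ref),
# and whether the second derivative vanishes almost everywhere (zero_2nd_grad).
# A minimal lookup sketch, with illustrative inputs:
#
#   spec = activation_funcs['lrelu']
#   y = spec.func(tf.constant([-1.0, 1.0]), alpha=spec.def_alpha) * spec.def_gain
#   # equivalent to tf.nn.leaky_relu(x, 0.2) * np.sqrt(2)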
32 | #----------------------------------------------------------------------------
33 |
34 | def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
35 | r"""Fused bias and activation function.
36 |
37 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
38 | and scales the result by `gain`. Each of the steps is optional. In most cases,
39 | the fused op is considerably more efficient than performing the same calculation
40 | using standard TensorFlow ops. It supports first and second order gradients,
41 | but not third order gradients.
42 |
43 | Args:
44 | x: Input activation tensor. Can have any shape, but if `b` is defined, the
45 | dimension corresponding to `axis`, as well as the rank, must be known.
46 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
47 | as `x`. The shape must be known, and it must match the dimension of `x`
48 | corresponding to `axis`.
49 | axis: The dimension in `x` corresponding to the elements of `b`.
50 | The value of `axis` is ignored if `b` is not specified.
51 | act: Name of the activation function to evaluate, or `"linear"` to disable.
52 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
53 | See `activation_funcs` for a full list. `None` is not allowed.
54 | alpha: Shape parameter for the activation function, or `None` to use the default.
55 | gain: Scaling factor for the output tensor, or `None` to use default.
56 | See `activation_funcs` for the default scaling of each activation function.
57 | If unsure, consider specifying `1.0`.
58 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
59 |
60 | Returns:
61 | Tensor of the same shape and datatype as `x`.
62 | """
63 |
64 | impl_dict = {
65 | 'ref': _fused_bias_act_ref,
66 | 'cuda': _fused_bias_act_cuda,
67 | }
68 | return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
69 |
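# A minimal usage sketch (shapes are illustrative assumptions); impl='ref'
# avoids compiling the CUDA plugin:
#
#   x = tf.random_normal([4, 128, 32, 32])                       # [N, C, H, W]
#   b = tf.zeros([128])                                          # one bias per channel (axis=1)
#   y = fused_bias_act(x, b=b, axis=1, act='lrelu', impl='ref')  # scaled by def_gain = sqrt(2)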
70 | #----------------------------------------------------------------------------
71 |
72 | def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
73 | """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
74 |
75 | # Validate arguments.
76 | x = tf.convert_to_tensor(x)
77 | b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
78 | act_spec = activation_funcs[act]
79 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
80 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
81 | if alpha is None:
82 | alpha = act_spec.def_alpha
83 | if gain is None:
84 | gain = act_spec.def_gain
85 |
86 | # Add bias.
87 | if b.shape[0] != 0:
88 | x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
89 |
90 | # Evaluate activation function.
91 | x = act_spec.func(x, alpha=alpha)
92 |
93 | # Scale by gain.
94 | if gain != 1:
95 | x *= gain
96 | return x
97 |
98 | #----------------------------------------------------------------------------
99 |
100 | def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
101 | """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
102 |
103 | # Validate arguments.
104 | x = tf.convert_to_tensor(x)
105 | empty_tensor = tf.constant([], dtype=x.dtype)
106 | b = tf.convert_to_tensor(b) if b is not None else empty_tensor
107 | act_spec = activation_funcs[act]
108 | assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
109 | assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
110 | if alpha is None:
111 | alpha = act_spec.def_alpha
112 | if gain is None:
113 | gain = act_spec.def_gain
114 |
115 | # Special cases.
116 | if act == 'linear' and b.shape[0] == 0 and gain == 1.0: # b is never None at this point; an empty tensor means no bias
117 | return x
118 | if act_spec.cuda_idx is None:
119 | return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
120 |
121 | # CUDA kernel.
122 | cuda_kernel = _get_plugin().fused_bias_act
123 | cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)
124 |
125 | # Forward pass: y = func(x, b).
126 | def func_y(x, b):
127 | y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
128 | y.set_shape(x.shape)
129 | return y
130 |
131 | # Backward pass: dx, db = grad(dy, x, y)
132 | def grad_dx(dy, x, y):
133 | ref = {'x': x, 'y': y}[act_spec.ref]
134 | dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
135 | dx.set_shape(x.shape)
136 | return dx
137 | def grad_db(dx):
138 | if b.shape[0] == 0:
139 | return empty_tensor
140 | db = dx
141 | if axis < x.shape.rank - 1:
142 | db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
143 | if axis > 0:
144 | db = tf.reduce_sum(db, list(range(axis)))
145 | db.set_shape(b.shape)
146 | return db
147 |
148 | # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
149 | def grad2_d_dy(d_dx, d_db, x, y):
150 | ref = {'x': x, 'y': y}[act_spec.ref]
151 | d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
152 | d_dy.set_shape(x.shape)
153 | return d_dy
154 | def grad2_d_x(d_dx, d_db, x, y):
155 | ref = {'x': x, 'y': y}[act_spec.ref]
156 | d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
157 | d_x.set_shape(x.shape)
158 | return d_x
159 |
160 | # Fast version for piecewise-linear activation funcs.
161 | @tf.custom_gradient
162 | def func_zero_2nd_grad(x, b):
163 | y = func_y(x, b)
164 | @tf.custom_gradient
165 | def grad(dy):
166 | dx = grad_dx(dy, x, y)
167 | db = grad_db(dx)
168 | def grad2(d_dx, d_db):
169 | d_dy = grad2_d_dy(d_dx, d_db, x, y)
170 | return d_dy
171 | return (dx, db), grad2
172 | return y, grad
173 |
174 | # Slow version for general activation funcs.
175 | @tf.custom_gradient
176 | def func_nonzero_2nd_grad(x, b):
177 | y = func_y(x, b)
178 | def grad_wrap(dy):
179 | @tf.custom_gradient
180 | def grad_impl(dy, x):
181 | dx = grad_dx(dy, x, y)
182 | db = grad_db(dx)
183 | def grad2(d_dx, d_db):
184 | d_dy = grad2_d_dy(d_dx, d_db, x, y)
185 | d_x = grad2_d_x(d_dx, d_db, x, y)
186 | return d_dy, d_x
187 | return (dx, db), grad2
188 | return grad_impl(dy, x)
189 | return y, grad_wrap
190 |
191 | # Which version to use?
192 | if act_spec.zero_2nd_grad:
193 | return func_zero_2nd_grad(x, b)
194 | return func_nonzero_2nd_grad(x, b)
195 |
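# Both wrappers above support differentiating twice; the fast path applies to
# piecewise-linear activations because their second derivative is zero almost
# everywhere. A sketch exercising both gradient orders through the reference
# implementation (shapes illustrative):
#
#   x = tf.random_normal([2, 8])
#   y = fused_bias_act(x, b=tf.zeros([8]), axis=1, act='lrelu', impl='ref')
#   dx = tf.gradients(tf.reduce_sum(y), x)[0]    # first order
#   d2x = tf.gradients(tf.reduce_sum(dx), x)[0]  # second order: zero a.e. for lrelu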
196 | #----------------------------------------------------------------------------
197 |
--------------------------------------------------------------------------------
/dnnlib/tflib/tfutil.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Miscellaneous helper utils for Tensorflow."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | # Silence deprecation warnings from TensorFlow 1.13 onwards
14 | import logging
15 | logging.getLogger('tensorflow').setLevel(logging.ERROR)
16 | import tensorflow.contrib # requires TensorFlow 1.x!
17 | tf.contrib = tensorflow.contrib
18 |
19 | from typing import Any, Iterable, List, Union
20 |
21 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
22 | """A type that represents a valid Tensorflow expression."""
23 |
24 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
25 | """A type that can be converted to a valid Tensorflow expression."""
26 |
27 |
28 | def run(*args, **kwargs) -> Any:
29 | """Run the specified ops in the default session."""
30 | assert_tf_initialized()
31 | return tf.get_default_session().run(*args, **kwargs)
32 |
33 |
34 | def is_tf_expression(x: Any) -> bool:
35 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
36 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
37 |
38 |
39 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
40 | """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
41 | return [dim.value for dim in shape]
42 |
43 |
44 | def flatten(x: TfExpressionEx) -> TfExpression:
45 | """Shortcut function for flattening a tensor."""
46 | with tf.name_scope("Flatten"):
47 | return tf.reshape(x, [-1])
48 |
49 |
50 | def log2(x: TfExpressionEx) -> TfExpression:
51 | """Logarithm in base 2."""
52 | with tf.name_scope("Log2"):
53 | return tf.log(x) * np.float32(1.0 / np.log(2.0))
54 |
55 |
56 | def exp2(x: TfExpressionEx) -> TfExpression:
57 | """Exponent in base 2."""
58 | with tf.name_scope("Exp2"):
59 | return tf.exp(x * np.float32(np.log(2.0)))
60 |
61 |
62 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
63 | """Linear interpolation."""
64 | with tf.name_scope("Lerp"):
65 | return a + (b - a) * t
66 |
67 |
68 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
69 | """Linear interpolation with clip."""
70 | with tf.name_scope("LerpClip"):
71 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
72 |
73 |
74 | def absolute_name_scope(scope: str) -> tf.name_scope:
75 | """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
76 | return tf.name_scope(scope + "/")
77 |
78 |
79 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
80 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
81 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
82 |
83 |
84 | def _sanitize_tf_config(config_dict: dict = None) -> dict:
85 | # Defaults.
86 | cfg = dict()
87 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
88 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
89 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
90 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
91 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
92 |
93 | # Remove defaults for environment variables that are already set.
94 | for key in list(cfg):
95 | fields = key.split(".")
96 | if fields[0] == "env":
97 | assert len(fields) == 2
98 | if fields[1] in os.environ:
99 | del cfg[key]
100 |
101 | # User overrides.
102 | if config_dict is not None:
103 | cfg.update(config_dict)
104 | return cfg
105 |
106 |
107 | def init_tf(config_dict: dict = None) -> None:
108 | """Initialize TensorFlow session using good default settings."""
109 | # Skip if already initialized.
110 | if tf.get_default_session() is not None:
111 | return
112 |
113 | # Setup config dict and random seeds.
114 | cfg = _sanitize_tf_config(config_dict)
115 | np_random_seed = cfg["rnd.np_random_seed"]
116 | if np_random_seed is not None:
117 | np.random.seed(np_random_seed)
118 | tf_random_seed = cfg["rnd.tf_random_seed"]
119 | if tf_random_seed == "auto":
120 | tf_random_seed = np.random.randint(1 << 31)
121 | if tf_random_seed is not None:
122 | tf.set_random_seed(tf_random_seed)
123 |
124 | # Setup environment variables.
125 | for key, value in cfg.items():
126 | fields = key.split(".")
127 | if fields[0] == "env":
128 | assert len(fields) == 2
129 | os.environ[fields[1]] = str(value)
130 |
131 | # Create default TensorFlow session.
132 | create_session(cfg, force_as_default=True)
133 |
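# Typical call, overriding defaults via the dotted keys understood by
# _sanitize_tf_config() (the seed value is illustrative):
#
#   import dnnlib.tflib as tflib
#   tflib.init_tf({'rnd.np_random_seed': 1000})  # TF seed is then derived from the NumPy state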
134 |
135 | def assert_tf_initialized():
136 | """Check that TensorFlow session has been initialized."""
137 | if tf.get_default_session() is None:
138 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
139 |
140 |
141 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
142 | """Create tf.Session based on config dict."""
143 | # Setup TensorFlow config proto.
144 | cfg = _sanitize_tf_config(config_dict)
145 | config_proto = tf.ConfigProto()
146 | for key, value in cfg.items():
147 | fields = key.split(".")
148 | if fields[0] not in ["rnd", "env"]:
149 | obj = config_proto
150 | for field in fields[:-1]:
151 | obj = getattr(obj, field)
152 | setattr(obj, fields[-1], value)
153 |
154 | # Create session.
155 | session = tf.Session(config=config_proto)
156 | if force_as_default:
157 | # pylint: disable=protected-access
158 | session._default_session = session.as_default()
159 | session._default_session.enforce_nesting = False
160 | session._default_session.__enter__()
161 | return session
162 |
163 |
164 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
165 | """Initialize all tf.Variables that have not already been initialized.
166 |
167 | Equivalent to the following, but more efficient and does not bloat the tf graph:
168 | tf.variables_initializer(tf.report_uninitialized_variables()).run()
169 | """
170 | assert_tf_initialized()
171 | if target_vars is None:
172 | target_vars = tf.global_variables()
173 |
174 | test_vars = []
175 | test_ops = []
176 |
177 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
178 | for var in target_vars:
179 | assert is_tf_expression(var)
180 |
181 | try:
182 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
183 | except KeyError:
184 | # Op does not exist => variable may be uninitialized.
185 | test_vars.append(var)
186 |
187 | with absolute_name_scope(var.name.split(":")[0]):
188 | test_ops.append(tf.is_variable_initialized(var))
189 |
190 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
191 | run([var.initializer for var in init_vars])
192 |
193 |
194 | def set_vars(var_to_value_dict: dict) -> None:
195 | """Set the values of given tf.Variables.
196 |
197 | Equivalent to the following, but more efficient and does not bloat the tf graph:
198 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
199 | """
200 | assert_tf_initialized()
201 | ops = []
202 | feed_dict = {}
203 |
204 | for var, value in var_to_value_dict.items():
205 | assert is_tf_expression(var)
206 |
207 | try:
208 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
209 | except KeyError:
210 | with absolute_name_scope(var.name.split(":")[0]):
211 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
212 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
213 |
214 | ops.append(setter)
215 | feed_dict[setter.op.inputs[1]] = value
216 |
217 | run(ops, feed_dict)
218 |
219 |
220 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
221 | """Create tf.Variable with large initial value without bloating the tf graph."""
222 | assert_tf_initialized()
223 | assert isinstance(initial_value, np.ndarray)
224 | zeros = tf.zeros(initial_value.shape, initial_value.dtype)
225 | var = tf.Variable(zeros, *args, **kwargs)
226 | set_vars({var: initial_value})
227 | return var
228 |
229 |
230 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
231 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
232 | Can be used as an input transformation for Network.run().
233 | """
234 | images = tf.cast(images, tf.float32)
235 | if nhwc_to_nchw:
236 | images = tf.transpose(images, [0, 3, 1, 2])
237 | return images * ((drange[1] - drange[0]) / 255) + drange[0]
238 |
239 |
240 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
241 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
242 | Can be used as an output transformation for Network.run().
243 | """
244 | images = tf.cast(images, tf.float32)
245 | if shrink > 1:
246 | ksize = [1, 1, shrink, shrink]
247 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
248 | if nchw_to_nhwc:
249 | images = tf.transpose(images, [0, 2, 3, 1])
250 | scale = 255 / (drange[1] - drange[0])
251 | images = images * scale + (0.5 - drange[0] * scale)
252 | return tf.saturate_cast(images, tf.uint8)
253 |
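# As used in metrics/metric_base.py (_iterate_fakes), this is the usual output
# transform for Network.run(); Gs and latents are assumed to exist:
#
#   fmt = dict(func=convert_images_to_uint8, nchw_to_nhwc=True)
#   images = Gs.run(latents, None, output_transform=fmt)  # uint8, NHWC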
--------------------------------------------------------------------------------
/docs/license.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Nvidia Source Code License-NC
7 |
8 |
56 |
57 |
58 |
59 | Nvidia Source Code License-NC
60 |
61 |
62 |
63 | 1. Definitions
64 |
65 | “Licensor” means any person or entity that distributes its Work.
66 |
67 | “Software” means the original work of authorship made available under
68 | this License.
69 |
70 | “Work” means the Software and any additions to or derivative works of
71 | the Software that are made available under this License.
72 |
73 | “Nvidia Processors” means any central processing unit (CPU), graphics
74 | processing unit (GPU), field-programmable gate array (FPGA),
75 | application-specific integrated circuit (ASIC) or any combination
76 | thereof designed, made, sold, or provided by Nvidia or its affiliates.
77 |
78 | The terms “reproduce,” “reproduction,” “derivative works,” and
79 | “distribution” have the meaning as provided under U.S. copyright law;
80 | provided, however, that for the purposes of this License, derivative
81 | works shall not include works that remain separable from, or merely
82 | link (or bind by name) to the interfaces of, the Work.
83 |
84 | Works, including the Software, are “made available” under this License
85 | by including in or with the Work either (a) a copyright notice
86 | referencing the applicability of this License to the Work, or (b) a
87 | copy of this License.
88 |
89 | 2. License Grants
90 |
91 | 2.1 Copyright Grant. Subject to the terms and conditions of this
92 | License, each Licensor grants to you a perpetual, worldwide,
93 | non-exclusive, royalty-free, copyright license to reproduce,
94 | prepare derivative works of, publicly display, publicly perform,
95 | sublicense and distribute its Work and any resulting derivative
96 | works in any form.
97 |
98 | 3. Limitations
99 |
100 | 3.1 Redistribution. You may reproduce or distribute the Work only
101 | if (a) you do so under this License, (b) you include a complete
102 | copy of this License with your distribution, and (c) you retain
103 | without modification any copyright, patent, trademark, or
104 | attribution notices that are present in the Work.
105 |
106 | 3.2 Derivative Works. You may specify that additional or different
107 | terms apply to the use, reproduction, and distribution of your
108 | derivative works of the Work (“Your Terms”) only if (a) Your Terms
109 | provide that the use limitation in Section 3.3 applies to your
110 | derivative works, and (b) you identify the specific derivative
111 | works that are subject to Your Terms. Notwithstanding Your Terms,
112 | this License (including the redistribution requirements in Section
113 | 3.1) will continue to apply to the Work itself.
114 |
115 | 3.3 Use Limitation. The Work and any derivative works thereof only
116 | may be used or intended for use non-commercially. The Work or
117 | derivative works thereof may be used or intended for use by Nvidia
118 | or its affiliates commercially or non-commercially. As used herein,
119 | “non-commercially” means for research or evaluation purposes only.
120 |
121 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
122 | against any Licensor (including any claim, cross-claim or
123 | counterclaim in a lawsuit) to enforce any patents that you allege
124 | are infringed by any Work, then your rights under this License from
125 | such Licensor (including the grants in Sections 2.1 and 2.2) will
126 | terminate immediately.
127 |
128 | 3.5 Trademarks. This License does not grant any rights to use any
129 | Licensor’s or its affiliates’ names, logos, or trademarks, except
130 | as necessary to reproduce the notices described in this License.
131 |
132 | 3.6 Termination. If you violate any term of this License, then your
133 | rights under this License (including the grants in Sections 2.1 and
134 | 2.2) will terminate immediately.
135 |
136 | 4. Disclaimer of Warranty.
137 |
138 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY
139 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
140 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
141 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
142 | THIS LICENSE.
143 |
144 | 5. Limitation of Liability.
145 |
146 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
147 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
148 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
149 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
150 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
151 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
152 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
153 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
154 | THE POSSIBILITY OF SUCH DAMAGES.
155 |
156 |
157 |
158 |
159 |
160 |
161 |
--------------------------------------------------------------------------------
/docs/stylegan2-teaser-1024x256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/stylegan2/bf0fe0baba9fc7039eae0cac575c1778be1ce3e3/docs/stylegan2-teaser-1024x256.png
--------------------------------------------------------------------------------
/docs/stylegan2-training-curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/stylegan2/bf0fe0baba9fc7039eae0cac575c1778be1ce3e3/docs/stylegan2-training-curves.png
--------------------------------------------------------------------------------
/docs/versions.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | StyleGAN versions
7 |
8 |
45 |
46 |
47 |
48 | StyleGAN3 (2021)
49 |
54 |
55 | StyleGAN2-ADA (2020)
56 |
62 |
63 | StyleGAN2 (2019)
64 |
69 |
70 | StyleGAN (2018)
71 |
77 |
78 | Progressive GAN (2017)
79 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | # empty
8 |
--------------------------------------------------------------------------------
/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Frechet Inception Distance (FID)."""
8 |
9 | import os
10 | import numpy as np
11 | import scipy
12 | import tensorflow as tf
13 | import dnnlib.tflib as tflib
14 |
15 | from metrics import metric_base
16 | from training import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class FID(metric_base.MetricBase):
21 | def __init__(self, num_images, minibatch_per_gpu, **kwargs):
22 | super().__init__(**kwargs)
23 | self.num_images = num_images
24 | self.minibatch_per_gpu = minibatch_per_gpu
25 |
26 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
27 | minibatch_size = num_gpus * self.minibatch_per_gpu
28 | inception = misc.load_pkl('https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/inception_v3_features.pkl')
29 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
30 |
31 | # Calculate statistics for reals.
32 | cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
33 | os.makedirs(os.path.dirname(cache_file), exist_ok=True)
34 | if os.path.isfile(cache_file):
35 | mu_real, sigma_real = misc.load_pkl(cache_file)
36 | else:
37 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):
38 | begin = idx * minibatch_size
39 | end = min(begin + minibatch_size, self.num_images)
40 | activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)
41 | if end == self.num_images:
42 | break
43 | mu_real = np.mean(activations, axis=0)
44 | sigma_real = np.cov(activations, rowvar=False)
45 | misc.save_pkl((mu_real, sigma_real), cache_file)
46 |
47 | # Construct TensorFlow graph.
48 | result_expr = []
49 | for gpu_idx in range(num_gpus):
50 | with tf.device('/gpu:%d' % gpu_idx):
51 | Gs_clone = Gs.clone()
52 | inception_clone = inception.clone()
53 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
54 | labels = self._get_random_labels_tf(self.minibatch_per_gpu)
55 | images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
56 | images = tflib.convert_images_to_uint8(images)
57 | result_expr.append(inception_clone.get_output_for(images))
58 |
59 | # Calculate statistics for fakes.
60 | for begin in range(0, self.num_images, minibatch_size):
61 | self._report_progress(begin, self.num_images)
62 | end = min(begin + minibatch_size, self.num_images)
63 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
64 | mu_fake = np.mean(activations, axis=0)
65 | sigma_fake = np.cov(activations, rowvar=False)
66 |
67 | # Calculate FID.
68 | m = np.square(mu_fake - mu_real).sum()
69 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
70 | dist = m + np.trace(sigma_fake + sigma_real - 2*s)
71 | self._report_result(np.real(dist))
72 |
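# The block above evaluates the closed-form Frechet distance between the two
# fitted Gaussians:
#
#   FID = ||mu_fake - mu_real||^2 + Tr(sigma_fake + sigma_real - 2*sqrtm(sigma_fake . sigma_real))
#
# np.real() drops the negligible imaginary part that scipy.linalg.sqrtm() can
# introduce numerically.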
73 | #----------------------------------------------------------------------------
74 |
--------------------------------------------------------------------------------
/metrics/inception_score.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Inception Score (IS)."""
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | import dnnlib.tflib as tflib
12 |
13 | from metrics import metric_base
14 | from training import misc
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | class IS(metric_base.MetricBase):
19 | def __init__(self, num_images, num_splits, minibatch_per_gpu, **kwargs):
20 | super().__init__(**kwargs)
21 | self.num_images = num_images
22 | self.num_splits = num_splits
23 | self.minibatch_per_gpu = minibatch_per_gpu
24 |
25 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
26 | minibatch_size = num_gpus * self.minibatch_per_gpu
27 | inception = misc.load_pkl('https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/inception_v3_softmax.pkl')
28 | activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
29 |
30 | # Construct TensorFlow graph.
31 | result_expr = []
32 | for gpu_idx in range(num_gpus):
33 | with tf.device('/gpu:%d' % gpu_idx):
34 | Gs_clone = Gs.clone()
35 | inception_clone = inception.clone()
36 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
37 | labels = self._get_random_labels_tf(self.minibatch_per_gpu)
38 | images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
39 | images = tflib.convert_images_to_uint8(images)
40 | result_expr.append(inception_clone.get_output_for(images))
41 |
42 | # Calculate activations for fakes.
43 | for begin in range(0, self.num_images, minibatch_size):
44 | self._report_progress(begin, self.num_images)
45 | end = min(begin + minibatch_size, self.num_images)
46 | activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
47 |
48 | # Calculate IS.
49 | scores = []
50 | for i in range(self.num_splits):
51 | part = activations[i * self.num_images // self.num_splits : (i + 1) * self.num_images // self.num_splits]
52 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
53 | kl = np.mean(np.sum(kl, 1))
54 | scores.append(np.exp(kl))
55 | self._report_result(np.mean(scores), suffix='_mean')
56 | self._report_result(np.std(scores), suffix='_std')
57 |
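# Each split evaluates the standard Inception Score,
#
#   IS = exp( E_x [ KL( p(y|x) || p(y) ) ] )
#
# where p(y|x) are the Inception softmax outputs in `part` and p(y) is their
# mean over the split; the mean and std across num_splits splits are reported.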
58 | #----------------------------------------------------------------------------
59 |
--------------------------------------------------------------------------------
/metrics/linear_separability.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Linear Separability (LS)."""
8 |
9 | from collections import defaultdict
10 | import numpy as np
11 | import sklearn.svm
12 | import tensorflow as tf
13 | import dnnlib.tflib as tflib
14 |
15 | from metrics import metric_base
16 | from training import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | classifier_urls = [
21 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-00-male.pkl',
22 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-01-smiling.pkl',
23 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-02-attractive.pkl',
24 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-03-wavy-hair.pkl',
25 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-04-young.pkl',
26 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-05-5-o-clock-shadow.pkl',
27 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-06-arched-eyebrows.pkl',
28 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-07-bags-under-eyes.pkl',
29 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-08-bald.pkl',
30 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-09-bangs.pkl',
31 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-10-big-lips.pkl',
32 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-11-big-nose.pkl',
33 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-12-black-hair.pkl',
34 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-13-blond-hair.pkl',
35 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-14-blurry.pkl',
36 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-15-brown-hair.pkl',
37 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-16-bushy-eyebrows.pkl',
38 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-17-chubby.pkl',
39 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-18-double-chin.pkl',
40 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-19-eyeglasses.pkl',
41 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-20-goatee.pkl',
42 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-21-gray-hair.pkl',
43 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-22-heavy-makeup.pkl',
44 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-23-high-cheekbones.pkl',
45 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-24-mouth-slightly-open.pkl',
46 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-25-mustache.pkl',
47 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-26-narrow-eyes.pkl',
48 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-27-no-beard.pkl',
49 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-28-oval-face.pkl',
50 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-29-pale-skin.pkl',
51 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-30-pointy-nose.pkl',
52 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-31-receding-hairline.pkl',
53 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-32-rosy-cheeks.pkl',
54 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-33-sideburns.pkl',
55 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-34-straight-hair.pkl',
56 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-35-wearing-earrings.pkl',
57 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-36-wearing-hat.pkl',
58 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-37-wearing-lipstick.pkl',
59 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-38-wearing-necklace.pkl',
60 | 'https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/celebahq-classifier-39-wearing-necktie.pkl',
61 | ]
62 |
63 | #----------------------------------------------------------------------------
64 |
65 | def prob_normalize(p):
66 | p = np.asarray(p).astype(np.float32)
67 | assert len(p.shape) == 2
68 | return p / np.sum(p)
69 |
70 | def mutual_information(p):
71 | p = prob_normalize(p)
72 | px = np.sum(p, axis=1)
73 | py = np.sum(p, axis=0)
74 | result = 0.0
75 | for x in range(p.shape[0]):
76 | p_x = px[x]
77 | for y in range(p.shape[1]):
78 | p_xy = p[x][y]
79 | p_y = py[y]
80 | if p_xy > 0.0:
81 | result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output
82 | return result
83 |
84 | def entropy(p):
85 | p = prob_normalize(p)
86 | result = 0.0
87 | for x in range(p.shape[0]):
88 | for y in range(p.shape[1]):
89 | p_xy = p[x][y]
90 | if p_xy > 0.0:
91 | result -= p_xy * np.log2(p_xy)
92 | return result
93 |
94 | def conditional_entropy(p):
95 | # H(Y|X) where X corresponds to axis 0, Y to axis 1
96 | # i.e., how many bits of additional information are needed to determine where we are on axis 1 if we know where we are on axis 0?
97 | p = prob_normalize(p)
98 | y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y)
99 | return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up.
100 |
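# These helpers implement the usual identities over the joint table p, in bits:
#
#   I(X;Y) = sum_{x,y} p(x,y) * log2( p(x,y) / (p(x) * p(y)) )
#   H(p)   = -sum p * log2(p)   (joint entropy for the full table, H(Y) for a marginal)
#   H(Y|X) = H(Y) - I(X;Y)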
101 | #----------------------------------------------------------------------------
102 |
103 | class LS(metric_base.MetricBase):
104 | def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs):
105 | assert num_keep <= num_samples
106 | super().__init__(**kwargs)
107 | self.num_samples = num_samples
108 | self.num_keep = num_keep
109 | self.attrib_indices = attrib_indices
110 | self.minibatch_per_gpu = minibatch_per_gpu
111 |
112 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
113 | minibatch_size = num_gpus * self.minibatch_per_gpu
114 |
115 | # Construct TensorFlow graph for each GPU.
116 | result_expr = []
117 | for gpu_idx in range(num_gpus):
118 | with tf.device('/gpu:%d' % gpu_idx):
119 | Gs_clone = Gs.clone()
120 |
121 | # Generate images.
122 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
123 | labels = self._get_random_labels_tf(self.minibatch_per_gpu)
124 | dlatents = Gs_clone.components.mapping.get_output_for(latents, labels, **Gs_kwargs)
125 | images = Gs_clone.get_output_for(latents, None, **Gs_kwargs)
126 |
127 | # Downsample to 256x256. The attribute classifiers were built for 256x256.
128 | if images.shape[2] > 256:
129 | factor = images.shape[2] // 256
130 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
131 | images = tf.reduce_mean(images, axis=[3, 5])
132 |
133 | # Run classifier for each attribute.
134 | result_dict = dict(latents=latents, dlatents=dlatents[:,-1])
135 | for attrib_idx in self.attrib_indices:
136 | classifier = misc.load_pkl(classifier_urls[attrib_idx])
137 | logits = classifier.get_output_for(images, None)
138 | predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))
139 | result_dict[attrib_idx] = predictions
140 | result_expr.append(result_dict)
141 |
142 | # Sampling loop.
143 | results = []
144 | for begin in range(0, self.num_samples, minibatch_size):
145 | self._report_progress(begin, self.num_samples)
146 | results += tflib.run(result_expr)
147 | results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}
148 |
149 | # Calculate conditional entropy for each attribute.
150 | conditional_entropies = defaultdict(list)
151 | for attrib_idx in self.attrib_indices:
152 | # Prune the least confident samples.
153 | pruned_indices = list(range(self.num_samples))
154 | pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
155 | pruned_indices = pruned_indices[:self.num_keep]
156 |
157 | # Fit SVM to the remaining samples.
158 | svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
159 | for space in ['latents', 'dlatents']:
160 | svm_inputs = results[space][pruned_indices]
161 | try:
162 | svm = sklearn.svm.LinearSVC()
163 | svm.fit(svm_inputs, svm_targets)
164 | svm.score(svm_inputs, svm_targets)
165 | svm_outputs = svm.predict(svm_inputs)
166 | except Exception:
167 | svm_outputs = svm_targets # assume perfect prediction
168 |
169 | # Calculate conditional entropy.
170 | p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)]
171 | conditional_entropies[space].append(conditional_entropy(p))
172 |
173 | # Calculate separability scores.
174 | scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}
175 | self._report_result(scores['latents'], suffix='_z')
176 | self._report_result(scores['dlatents'], suffix='_w')
177 |
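# This is the separability score from the StyleGAN paper: the summed
# conditional entropies are exponentiated (base 2 here, since the entropies
# above are measured in bits):
#
#   score = 2 ^ ( sum_i H(Y_i | X_i) )
#
# where X_i is the SVM prediction and Y_i the classifier label for attribute i.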
178 | #----------------------------------------------------------------------------
179 |
--------------------------------------------------------------------------------
/metrics/metric_base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Common definitions for GAN metrics."""
8 |
9 | import os
10 | import time
11 | import hashlib
12 | import numpy as np
13 | import tensorflow as tf
14 | import dnnlib
15 | import dnnlib.tflib as tflib
16 |
17 | from training import misc
18 | from training import dataset
19 |
20 | #----------------------------------------------------------------------------
21 | # Base class for metrics.
22 |
23 | class MetricBase:
24 | def __init__(self, name):
25 | self.name = name
26 | self._dataset_obj = None
27 | self._progress_lo = None
28 | self._progress_hi = None
29 | self._progress_max = None
30 | self._progress_sec = None
31 | self._progress_time = None
32 | self._reset()
33 |
34 | def close(self):
35 | self._reset()
36 |
37 | def _reset(self, network_pkl=None, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None):
38 | if self._dataset_obj is not None:
39 | self._dataset_obj.close()
40 |
41 | self._network_pkl = network_pkl
42 | self._data_dir = data_dir
43 | self._dataset_args = dataset_args
44 | self._dataset_obj = None
45 | self._mirror_augment = mirror_augment
46 | self._eval_time = 0
47 | self._results = []
48 |
49 | if (dataset_args is None or mirror_augment is None) and run_dir is not None:
50 | run_config = misc.parse_config_for_previous_run(run_dir)
51 | self._dataset_args = dict(run_config['dataset'])
52 | self._dataset_args['shuffle_mb'] = 0
53 | self._mirror_augment = run_config['train'].get('mirror_augment', False)
54 |
55 | def configure_progress_reports(self, plo, phi, pmax, psec=15):
56 | self._progress_lo = plo
57 | self._progress_hi = phi
58 | self._progress_max = pmax
59 | self._progress_sec = psec
60 |
61 | def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True, Gs_kwargs=dict(is_validation=True)):
62 | self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment)
63 | time_begin = time.time()
64 | with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
65 | self._report_progress(0, 1)
66 | _G, _D, Gs = misc.load_pkl(self._network_pkl)
67 | self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
68 | self._report_progress(1, 1)
69 | self._eval_time = time.time() - time_begin # pylint: disable=attribute-defined-outside-init
70 |
71 | if log_results:
72 | if run_dir is not None:
73 | log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name)
74 | with dnnlib.util.Logger(log_file, 'a'):
75 | print(self.get_result_str().strip())
76 | else:
77 | print(self.get_result_str().strip())
78 |
79 | def get_result_str(self):
80 | network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
81 | if len(network_name) > 29:
82 | network_name = '...' + network_name[-26:]
83 | result_str = '%-30s' % network_name
84 | result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)
85 | for res in self._results:
86 | result_str += ' ' + self.name + res.suffix + ' '
87 | result_str += res.fmt % res.value
88 | return result_str
89 |
90 | def update_autosummaries(self):
91 | for res in self._results:
92 | tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)
93 |
94 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
95 | raise NotImplementedError # to be overridden by subclasses
96 |
97 | def _report_result(self, value, suffix='', fmt='%-10.4f'):
98 | self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
99 |
100 | def _report_progress(self, pcur, pmax, status_str=''):
101 | if self._progress_lo is None or self._progress_hi is None or self._progress_max is None:
102 | return
103 | t = time.time()
104 | if self._progress_sec is not None and self._progress_time is not None and t < self._progress_time + self._progress_sec:
105 | return
106 | self._progress_time = t
107 | val = self._progress_lo + (pcur / pmax) * (self._progress_hi - self._progress_lo)
108 | dnnlib.RunContext.get().update(status_str, int(val), self._progress_max)
109 |
110 | def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
111 | all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
112 | all_args.update(self._dataset_args)
113 | all_args.update(kwargs)
114 | md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
115 | dataset_name = self._dataset_args.get('tfrecord_dir', None) or self._dataset_args.get('h5_file', None)
116 | dataset_name = os.path.splitext(os.path.basename(dataset_name))[0]
117 | return os.path.join('.stylegan2-cache', '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))
118 |
119 | def _get_dataset_obj(self):
120 | if self._dataset_obj is None:
121 | self._dataset_obj = dataset.load_dataset(data_dir=self._data_dir, **self._dataset_args)
122 | return self._dataset_obj
123 |
124 | def _iterate_reals(self, minibatch_size):
125 | dataset_obj = self._get_dataset_obj()
126 | while True:
127 | images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
128 | if self._mirror_augment:
129 | images = misc.apply_mirror_augment(images)
130 | yield images
131 |
132 | def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
133 | while True:
134 | latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])
135 | fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
136 | images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)
137 | yield images
138 |
139 | def _get_random_labels_tf(self, minibatch_size):
140 | return self._get_dataset_obj().get_random_labels_tf(minibatch_size)
141 |
142 | #----------------------------------------------------------------------------
143 | # Group of multiple metrics.
144 |
145 | class MetricGroup:
146 | def __init__(self, metric_kwarg_list):
147 | self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list]
148 |
149 | def run(self, *args, **kwargs):
150 | for metric in self.metrics:
151 | metric.run(*args, **kwargs)
152 |
153 | def get_result_str(self):
154 | return ' '.join(metric.get_result_str() for metric in self.metrics)
155 |
156 | def update_autosummaries(self):
157 | for metric in self.metrics:
158 | metric.update_autosummaries()
159 |
160 | #----------------------------------------------------------------------------
161 | # Dummy metric for debugging purposes.
162 |
163 | class DummyMetric(MetricBase):
164 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
165 | _ = Gs, Gs_kwargs, num_gpus
166 | self._report_result(0.0)
167 |
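# A sketch of driving the metric plumbing end to end with the dummy metric;
# the pickle path and dataset arguments are illustrative placeholders:
#
#   m = DummyMetric(name='dummy')
#   m.run('network-snapshot.pkl', data_dir='datasets',
#         dataset_args=dict(tfrecord_dir='ffhq'), mirror_augment=False, num_gpus=1)
#   print(m.get_result_str())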
168 | #----------------------------------------------------------------------------
169 |
--------------------------------------------------------------------------------
/metrics/metric_defaults.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Default metric definitions."""
8 |
9 | from dnnlib import EasyDict
10 |
11 | #----------------------------------------------------------------------------
12 |
13 | metric_defaults = EasyDict([(args.name, args) for args in [
14 | EasyDict(name='fid50k', func_name='metrics.frechet_inception_distance.FID', num_images=50000, minibatch_per_gpu=8),
15 | EasyDict(name='is50k', func_name='metrics.inception_score.IS', num_images=50000, num_splits=10, minibatch_per_gpu=8),
16 | EasyDict(name='ppl_zfull', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')),
17 | EasyDict(name='ppl_wfull', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')),
18 | EasyDict(name='ppl_zend', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')),
19 | EasyDict(name='ppl_wend', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')),
20 | EasyDict(name='ppl2_wend', func_name='metrics.perceptual_path_length.PPL', num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, minibatch_per_gpu=4, Gs_overrides=dict(dtype='float32', mapping_dtype='float32')),
21 | EasyDict(name='ls', func_name='metrics.linear_separability.LS', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4),
22 | EasyDict(name='pr50k3', func_name='metrics.precision_recall.PR', num_images=50000, nhood_size=3, minibatch_per_gpu=8, row_batch_size=10000, col_batch_size=10000),
23 | ]])
24 |
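# Each entry carries its own name and fully-qualified constructor, so a metric
# can be instantiated straight from this table, mirroring metric_base.MetricGroup;
# the paths below are illustrative:
#
#   import dnnlib
#   metric = dnnlib.util.call_func_by_name(**metric_defaults['fid50k'])
#   metric.run('network-snapshot.pkl', data_dir='datasets',
#              dataset_args=dict(tfrecord_dir='ffhq'), mirror_augment=False)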
25 | #----------------------------------------------------------------------------
26 |
--------------------------------------------------------------------------------
/metrics/perceptual_path_length.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Perceptual Path Length (PPL)."""
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | import dnnlib.tflib as tflib
12 |
13 | from metrics import metric_base
14 | from training import misc
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | # Normalize batch of vectors.
19 | def normalize(v):
20 | return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
21 |
22 | # Spherical interpolation of a batch of vectors.
23 | def slerp(a, b, t):
24 | a = normalize(a)
25 | b = normalize(b)
26 | d = tf.reduce_sum(a * b, axis=-1, keepdims=True)
27 | p = t * tf.math.acos(d)
28 | c = normalize(b - d * a)
29 | d = a * tf.math.cos(p) + c * tf.math.sin(p)
30 | return normalize(d)
31 |
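# Standard slerp via an orthonormal basis: with d = <a,b> and theta = arccos(d),
# c is the unit component of b orthogonal to a, so a*cos(t*theta) + c*sin(t*theta)
# traces the great circle from a (t=0) to b (t=1).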
32 | #----------------------------------------------------------------------------
33 |
34 | class PPL(metric_base.MetricBase):
35 | def __init__(self, num_samples, epsilon, space, sampling, crop, minibatch_per_gpu, Gs_overrides, **kwargs):
36 | assert space in ['z', 'w']
37 | assert sampling in ['full', 'end']
38 | super().__init__(**kwargs)
39 | self.num_samples = num_samples
40 | self.epsilon = epsilon
41 | self.space = space
42 | self.sampling = sampling
43 | self.crop = crop
44 | self.minibatch_per_gpu = minibatch_per_gpu
45 | self.Gs_overrides = Gs_overrides
46 |
47 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
48 | Gs_kwargs = dict(Gs_kwargs)
49 | Gs_kwargs.update(self.Gs_overrides)
50 | minibatch_size = num_gpus * self.minibatch_per_gpu
51 |
52 | # Construct TensorFlow graph.
53 | distance_expr = []
54 | for gpu_idx in range(num_gpus):
55 | with tf.device('/gpu:%d' % gpu_idx):
56 | Gs_clone = Gs.clone()
57 | noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')]
58 |
59 | # Generate random latents and interpolation t-values.
60 | lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:])
61 | lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0)
62 | labels = tf.reshape(tf.tile(self._get_random_labels_tf(self.minibatch_per_gpu), [1, 2]), [self.minibatch_per_gpu * 2, -1])
63 |
64 | # Interpolate in W or Z.
65 | if self.space == 'w':
66 | dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, labels, **Gs_kwargs)
67 | dlat_t01 = tf.cast(dlat_t01, tf.float32)
68 | dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2]
69 | dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis])
70 | dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)
71 | dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape)
72 | else: # space == 'z'
73 | lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2]
74 | lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis])
75 | lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon)
76 | lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape)
77 | dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, labels, **Gs_kwargs)
78 |
79 | # Synthesize images.
80 | with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch
81 | images = Gs_clone.components.synthesis.get_output_for(dlat_e01, randomize_noise=False, **Gs_kwargs)
82 | images = tf.cast(images, tf.float32)
83 |
84 | # Crop only the face region.
85 | if self.crop:
86 | c = int(images.shape[2] // 8)
87 | images = images[:, :, c*3 : c*7, c*2 : c*6]
88 |
89 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
90 | factor = images.shape[2] // 256
91 | if factor > 1:
92 | images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
93 | images = tf.reduce_mean(images, axis=[3,5])
94 |
95 | # Scale dynamic range from [-1,1] to [0,255] for VGG.
96 | images = (images + 1) * (255 / 2)
97 |
98 | # Evaluate perceptual distance.
99 | img_e0, img_e1 = images[0::2], images[1::2]
100 | distance_measure = misc.load_pkl('https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/vgg16_zhang_perceptual.pkl')
101 | distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2))
102 |
103 | # Sampling loop.
104 | all_distances = []
105 | for begin in range(0, self.num_samples, minibatch_size):
106 | self._report_progress(begin, self.num_samples)
107 | all_distances += tflib.run(distance_expr)
108 | all_distances = np.concatenate(all_distances, axis=0)
109 |
110 | # Reject outliers.
111 | lo = np.percentile(all_distances, 1, interpolation='lower')
112 | hi = np.percentile(all_distances, 99, interpolation='higher')
113 | filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances)
114 | self._report_result(np.mean(filtered_distances))
115 |
115 |
116 | #----------------------------------------------------------------------------
117 |
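# In formula form, the reported value is
#
#   PPL = E_{z1,z2,t} [ (1/eps^2) * d( G(interp(z1,z2; t)), G(interp(z1,z2; t+eps)) ) ]
#
# with slerp in Z or lerp in W, d the VGG16-based perceptual distance loaded
# above, and the top and bottom 1% of samples rejected before averaging.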
--------------------------------------------------------------------------------
/metrics/precision_recall.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Precision/Recall (PR)."""
8 |
9 | import os
10 | import numpy as np
11 | import tensorflow as tf
12 | import dnnlib
13 | import dnnlib.tflib as tflib
14 |
15 | from metrics import metric_base
16 | from training import misc
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def batch_pairwise_distances(U, V):
21 |     """Compute pairwise distances between two batches of feature vectors."""
22 | with tf.variable_scope('pairwise_dist_block'):
23 | # Squared norms of each row in U and V.
24 | norm_u = tf.reduce_sum(tf.square(U), 1)
25 | norm_v = tf.reduce_sum(tf.square(V), 1)
26 |
27 |         # norm_u as a row vector and norm_v as a column vector.
28 | norm_u = tf.reshape(norm_u, [-1, 1])
29 | norm_v = tf.reshape(norm_v, [1, -1])
30 |
31 | # Pairwise squared Euclidean distances.
32 | D = tf.maximum(norm_u - 2*tf.matmul(U, V, False, True) + norm_v, 0.0)
33 |
34 | return D
35 |
36 | #----------------------------------------------------------------------------
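# A minimal NumPy sketch (not part of the original module) checking the identity
# behind batch_pairwise_distances: ||u - v||^2 = ||u||^2 - 2 u.v + ||v||^2.
# The shapes below are arbitrary test values, not taken from the repo.
import numpy as np
U = np.random.randn(4, 8).astype(np.float32)
V = np.random.randn(5, 8).astype(np.float32)
D_fast = np.maximum((U**2).sum(1)[:, None] - 2.0 * (U @ V.T) + (V**2).sum(1)[None, :], 0.0)
D_ref = ((U[:, None, :] - V[None, :, :]) ** 2).sum(-1)
assert np.allclose(D_fast, D_ref, atol=1e-4)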
37 |
38 | class DistanceBlock():
39 | """Distance block."""
40 | def __init__(self, num_features, num_gpus):
41 | self.num_features = num_features
42 | self.num_gpus = num_gpus
43 |
44 | # Initialize TF graph to calculate pairwise distances.
45 | with tf.device('/cpu:0'):
46 | self._features_batch1 = tf.placeholder(tf.float16, shape=[None, self.num_features])
47 | self._features_batch2 = tf.placeholder(tf.float16, shape=[None, self.num_features])
48 | features_split2 = tf.split(self._features_batch2, self.num_gpus, axis=0)
49 | distances_split = []
50 | for gpu_idx in range(self.num_gpus):
51 | with tf.device('/gpu:%d' % gpu_idx):
52 | distances_split.append(batch_pairwise_distances(self._features_batch1, features_split2[gpu_idx]))
53 | self._distance_block = tf.concat(distances_split, axis=1)
54 |
55 | def pairwise_distances(self, U, V):
56 | """Evaluate pairwise distances between two batches of feature vectors."""
57 | return self._distance_block.eval(feed_dict={self._features_batch1: U, self._features_batch2: V})
58 |
59 | #----------------------------------------------------------------------------
60 |
61 | class ManifoldEstimator():
62 | """Finds an estimate for the manifold of given feature vectors."""
63 | def __init__(self, distance_block, features, row_batch_size, col_batch_size, nhood_sizes, clamp_to_percentile=None):
64 |         """Estimate the manifold of the given feature vectors."""
65 | num_images = features.shape[0]
66 | self.nhood_sizes = nhood_sizes
67 | self.num_nhoods = len(nhood_sizes)
68 | self.row_batch_size = row_batch_size
69 | self.col_batch_size = col_batch_size
70 | self._ref_features = features
71 | self._distance_block = distance_block
72 |
73 | # Estimate manifold of features by calculating distances to kth nearest neighbor of each sample.
74 | self.D = np.zeros([num_images, self.num_nhoods], dtype=np.float16)
75 | distance_batch = np.zeros([row_batch_size, num_images], dtype=np.float16)
76 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
77 |
78 | for begin1 in range(0, num_images, row_batch_size):
79 | end1 = min(begin1 + row_batch_size, num_images)
80 | row_batch = features[begin1:end1]
81 |
82 | for begin2 in range(0, num_images, col_batch_size):
83 | end2 = min(begin2 + col_batch_size, num_images)
84 | col_batch = features[begin2:end2]
85 |
86 | # Compute distances between batches.
87 | distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(row_batch, col_batch)
88 |
89 | # Find the kth nearest neighbor from the current batch.
90 | self.D[begin1:end1, :] = np.partition(distance_batch[0:end1-begin1, :], seq, axis=1)[:, self.nhood_sizes]
91 |
92 | if clamp_to_percentile is not None:
93 | max_distances = np.percentile(self.D, clamp_to_percentile, axis=0)
94 |             self.D[self.D > max_distances] = 0
95 |
96 | def evaluate(self, eval_features, return_realism=False, return_neighbors=False):
97 | """Evaluate if new feature vectors are in the estimated manifold."""
98 | num_eval_images = eval_features.shape[0]
99 | num_ref_images = self.D.shape[0]
100 | distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float16)
101 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
102 | #max_realism_score = np.zeros([num_eval_images,], dtype=np.float32)
103 | realism_score = np.zeros([num_eval_images,], dtype=np.float32)
104 | nearest_indices = np.zeros([num_eval_images,], dtype=np.int32)
105 |
106 | for begin1 in range(0, num_eval_images, self.row_batch_size):
107 | end1 = min(begin1 + self.row_batch_size, num_eval_images)
108 | feature_batch = eval_features[begin1:end1]
109 |
110 | for begin2 in range(0, num_ref_images, self.col_batch_size):
111 | end2 = min(begin2 + self.col_batch_size, num_ref_images)
112 | ref_batch = self._ref_features[begin2:end2]
113 |
114 | distance_batch[0:end1-begin1, begin2:end2] = self._distance_block.pairwise_distances(feature_batch, ref_batch)
115 |
116 |             # From the minibatch of new feature vectors, determine whether they lie on the estimated manifold.
117 |             # If a feature vector falls inside the hypersphere of some reference sample, it lies on the estimated manifold.
118 |             # The radius of each hypersphere is the distance from its reference sample to that sample's kth nearest neighbor.
119 | samples_in_manifold = distance_batch[0:end1-begin1, :, None] <= self.D
120 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
121 |
122 | #max_realism_score[begin1:end1] = np.max(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1)
123 | #nearest_indices[begin1:end1] = np.argmax(self.D[:, 0] / (distance_batch[0:end1-begin1, :] + 1e-18), axis=1)
124 | nearest_indices[begin1:end1] = np.argmin(distance_batch[0:end1-begin1, :], axis=1)
125 | realism_score[begin1:end1] = self.D[nearest_indices[begin1:end1], 0] / np.min(distance_batch[0:end1-begin1, :], axis=1)
126 |
127 | if return_realism and return_neighbors:
128 | return batch_predictions, realism_score, nearest_indices
129 | elif return_realism:
130 | return batch_predictions, realism_score
131 | elif return_neighbors:
132 | return batch_predictions, nearest_indices
133 |
134 | return batch_predictions
135 |
136 | #----------------------------------------------------------------------------
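# A minimal NumPy sketch (not part of the original module) of the membership
# rule used by ManifoldEstimator.evaluate: a query point lies on the estimated
# manifold if it falls inside the k-NN hypersphere of at least one reference
# point. All values below are hypothetical.
import numpy as np
refs = np.random.randn(100, 16)
k = 3
d2 = ((refs[:, None, :] - refs[None, :, :]) ** 2).sum(-1)
radii2 = np.sort(d2, axis=1)[:, k]    # squared radius: distance to kth neighbor (column 0 is self)
query = np.random.randn(16)
dq2 = ((refs - query) ** 2).sum(-1)
in_manifold = bool((dq2 <= radii2).any())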
137 |
138 | def knn_precision_recall_features(ref_features, eval_features, feature_net, nhood_sizes,
139 | row_batch_size, col_batch_size, num_gpus):
140 | """Calculates k-NN precision and recall for two sets of feature vectors."""
141 | state = dnnlib.EasyDict()
142 | #num_images = ref_features.shape[0]
143 | num_features = feature_net.output_shape[1]
144 | state.ref_features = ref_features
145 | state.eval_features = eval_features
146 |
147 | # Initialize DistanceBlock and ManifoldEstimators.
148 | distance_block = DistanceBlock(num_features, num_gpus)
149 | state.ref_manifold = ManifoldEstimator(distance_block, state.ref_features, row_batch_size, col_batch_size, nhood_sizes)
150 | state.eval_manifold = ManifoldEstimator(distance_block, state.eval_features, row_batch_size, col_batch_size, nhood_sizes)
151 |
152 | # Evaluate precision and recall using k-nearest neighbors.
153 | #print('Evaluating k-NN precision and recall with %i samples...' % num_images)
154 | #start = time.time()
155 |
156 |     # Precision: how many points from eval_features lie within the ref_features manifold.
157 | state.precision, state.realism_scores, state.nearest_neighbors = state.ref_manifold.evaluate(state.eval_features, return_realism=True, return_neighbors=True)
158 | state.knn_precision = state.precision.mean(axis=0)
159 |
160 |     # Recall: how many points from ref_features lie within the eval_features manifold.
161 | state.recall = state.eval_manifold.evaluate(state.ref_features)
162 | state.knn_recall = state.recall.mean(axis=0)
163 |
164 | #elapsed_time = time.time() - start
165 | #print('Done evaluation in: %gs' % elapsed_time)
166 |
167 | return state
168 |
169 | #----------------------------------------------------------------------------
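# A minimal sketch (not part of the original module) of how the two manifold
# evaluations above reduce to precision and recall; the membership vectors are
# hypothetical stand-ins for evaluate() outputs.
import numpy as np
fakes_in_real_manifold = np.array([1, 0, 1, 1])  # precision: fakes covered by the real manifold
reals_in_fake_manifold = np.array([1, 1, 0, 0])  # recall: reals covered by the fake manifold
precision, recall = fakes_in_real_manifold.mean(), reals_in_fake_manifold.mean()  # 0.75, 0.5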
170 |
171 | class PR(metric_base.MetricBase):
172 | def __init__(self, num_images, nhood_size, minibatch_per_gpu, row_batch_size, col_batch_size, **kwargs):
173 | super().__init__(**kwargs)
174 | self.num_images = num_images
175 | self.nhood_size = nhood_size
176 | self.minibatch_per_gpu = minibatch_per_gpu
177 | self.row_batch_size = row_batch_size
178 | self.col_batch_size = col_batch_size
179 |
180 | def _evaluate(self, Gs, Gs_kwargs, num_gpus):
181 | minibatch_size = num_gpus * self.minibatch_per_gpu
182 | feature_net = misc.load_pkl('https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/vgg16.pkl')
183 |
184 | # Calculate features for reals.
185 | cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
186 | os.makedirs(os.path.dirname(cache_file), exist_ok=True)
187 | if os.path.isfile(cache_file):
188 | ref_features = misc.load_pkl(cache_file)
189 | else:
190 | ref_features = np.empty([self.num_images, feature_net.output_shape[1]], dtype=np.float32)
191 | for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):
192 | begin = idx * minibatch_size
193 | end = min(begin + minibatch_size, self.num_images)
194 | ref_features[begin:end] = feature_net.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)
195 | if end == self.num_images:
196 | break
197 | misc.save_pkl(ref_features, cache_file)
198 |
199 | # Construct TensorFlow graph.
200 | result_expr = []
201 | for gpu_idx in range(num_gpus):
202 | with tf.device('/gpu:%d' % gpu_idx):
203 | Gs_clone = Gs.clone()
204 | feature_net_clone = feature_net.clone()
205 | latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
206 | labels = self._get_random_labels_tf(self.minibatch_per_gpu)
207 | images = Gs_clone.get_output_for(latents, labels, **Gs_kwargs)
208 | images = tflib.convert_images_to_uint8(images)
209 | result_expr.append(feature_net_clone.get_output_for(images))
210 |
211 | # Calculate features for fakes.
212 | eval_features = np.empty([self.num_images, feature_net.output_shape[1]], dtype=np.float32)
213 | for begin in range(0, self.num_images, minibatch_size):
214 | self._report_progress(begin, self.num_images)
215 | end = min(begin + minibatch_size, self.num_images)
216 | eval_features[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
217 |
218 | # Calculate precision and recall.
219 | state = knn_precision_recall_features(ref_features=ref_features, eval_features=eval_features, feature_net=feature_net,
220 |             nhood_sizes=[self.nhood_size], row_batch_size=self.row_batch_size, col_batch_size=self.col_batch_size, num_gpus=num_gpus)
221 | self._report_result(state.knn_precision[0], suffix='_precision')
222 | self._report_result(state.knn_recall[0], suffix='_recall')
223 |
224 | #----------------------------------------------------------------------------
225 |
--------------------------------------------------------------------------------
/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """List of pre-trained StyleGAN2 networks; the gdrive: shorthand paths below resolve to NVIDIA CDN mirrors."""
8 |
9 | import pickle
10 | import dnnlib
11 | import dnnlib.tflib as tflib
12 |
13 | #----------------------------------------------------------------------------
14 | # StyleGAN2 Google Drive root: https://drive.google.com/open?id=1QHc-yF5C3DChRwSdZKcx1w6K8JvSxQi7
15 |
16 | gdrive_urls = {
17 | 'gdrive:networks/stylegan2-car-config-a.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-a.pkl',
18 | 'gdrive:networks/stylegan2-car-config-b.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-b.pkl',
19 | 'gdrive:networks/stylegan2-car-config-c.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-c.pkl',
20 | 'gdrive:networks/stylegan2-car-config-d.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-d.pkl',
21 | 'gdrive:networks/stylegan2-car-config-e.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-e.pkl',
22 | 'gdrive:networks/stylegan2-car-config-f.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-f.pkl',
23 | 'gdrive:networks/stylegan2-cat-config-a.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-a.pkl',
24 | 'gdrive:networks/stylegan2-cat-config-f.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl',
25 | 'gdrive:networks/stylegan2-church-config-a.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-church-config-a.pkl',
26 | 'gdrive:networks/stylegan2-church-config-f.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-church-config-f.pkl',
27 | 'gdrive:networks/stylegan2-ffhq-config-a.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-a.pkl',
28 | 'gdrive:networks/stylegan2-ffhq-config-b.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-b.pkl',
29 | 'gdrive:networks/stylegan2-ffhq-config-c.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-c.pkl',
30 | 'gdrive:networks/stylegan2-ffhq-config-d.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-d.pkl',
31 | 'gdrive:networks/stylegan2-ffhq-config-e.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-e.pkl',
32 | 'gdrive:networks/stylegan2-ffhq-config-f.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-f.pkl',
33 | 'gdrive:networks/stylegan2-horse-config-a.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-horse-config-a.pkl',
34 | 'gdrive:networks/stylegan2-horse-config-f.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-horse-config-f.pkl',
35 | 'gdrive:networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl',
36 | 'gdrive:networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl',
37 | 'gdrive:networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl',
38 | 'gdrive:networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl',
39 | 'gdrive:networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl',
40 | 'gdrive:networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl',
41 | 'gdrive:networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl',
42 | 'gdrive:networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl',
43 | 'gdrive:networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl',
44 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl',
45 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl',
46 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl',
47 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl',
48 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl',
49 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl',
50 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl',
51 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl',
52 | 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl',
53 | }
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | def get_path_or_url(path_or_gdrive_path):
58 | return gdrive_urls.get(path_or_gdrive_path, path_or_gdrive_path)
59 |
60 | #----------------------------------------------------------------------------
61 |
62 | _cached_networks = dict()
63 |
64 | def load_networks(path_or_gdrive_path):
65 | path_or_url = get_path_or_url(path_or_gdrive_path)
66 | if path_or_url in _cached_networks:
67 | return _cached_networks[path_or_url]
68 |
69 | if dnnlib.util.is_url(path_or_url):
70 | stream = dnnlib.util.open_url(path_or_url, cache_dir='.stylegan2-cache')
71 | else:
72 | stream = open(path_or_url, 'rb')
73 |
74 | tflib.init_tf()
75 | with stream:
76 | G, D, Gs = pickle.load(stream, encoding='latin1')
77 | _cached_networks[path_or_url] = G, D, Gs
78 | return G, D, Gs
79 |
80 | #----------------------------------------------------------------------------
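# A minimal usage sketch (not part of the original module). The gdrive:
# shorthand resolves through gdrive_urls; plain paths and URLs pass through
# unchanged. The printed shapes are illustrative for a 1024x1024 FFHQ model.
import pretrained_networks
_G, _D, Gs = pretrained_networks.load_networks('gdrive:networks/stylegan2-ffhq-config-f.pkl')
print(Gs.input_shape, Gs.output_shape)  # e.g. [None, 512] and [None, 3, 1024, 1024]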
81 |
--------------------------------------------------------------------------------
/projector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 | import dnnlib
10 | import dnnlib.tflib as tflib
11 |
12 | from training import misc
13 |
14 | #----------------------------------------------------------------------------
15 |
16 | class Projector:
17 | def __init__(self):
18 | self.num_steps = 1000
19 | self.dlatent_avg_samples = 10000
20 | self.initial_learning_rate = 0.1
21 | self.initial_noise_factor = 0.05
22 | self.lr_rampdown_length = 0.25
23 | self.lr_rampup_length = 0.05
24 | self.noise_ramp_length = 0.75
25 | self.regularize_noise_weight = 1e5
26 | self.verbose = False
27 | self.clone_net = True
28 |
29 | self._Gs = None
30 | self._minibatch_size = None
31 | self._dlatent_avg = None
32 | self._dlatent_std = None
33 | self._noise_vars = None
34 | self._noise_init_op = None
35 | self._noise_normalize_op = None
36 | self._dlatents_var = None
37 | self._noise_in = None
38 | self._dlatents_expr = None
39 | self._images_expr = None
40 | self._target_images_var = None
41 | self._lpips = None
42 | self._dist = None
43 | self._loss = None
44 | self._reg_sizes = None
45 | self._lrate_in = None
46 | self._opt = None
47 | self._opt_step = None
48 | self._cur_step = None
49 |
50 | def _info(self, *args):
51 | if self.verbose:
52 | print('Projector:', *args)
53 |
54 | def set_network(self, Gs, minibatch_size=1):
55 | assert minibatch_size == 1
56 | self._Gs = Gs
57 | self._minibatch_size = minibatch_size
58 | if self._Gs is None:
59 | return
60 | if self.clone_net:
61 | self._Gs = self._Gs.clone()
62 |
63 | # Find dlatent stats.
64 | self._info('Finding W midpoint and stddev using %d samples...' % self.dlatent_avg_samples)
65 | latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:])
66 | dlatent_samples = self._Gs.components.mapping.run(latent_samples, None)[:, :1, :] # [N, 1, 512]
67 | self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 1, 512]
68 | self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg) ** 2) / self.dlatent_avg_samples) ** 0.5
69 | self._info('std = %g' % self._dlatent_std)
70 |
71 | # Find noise inputs.
72 | self._info('Setting up noise inputs...')
73 | self._noise_vars = []
74 | noise_init_ops = []
75 | noise_normalize_ops = []
76 | while True:
77 | n = 'G_synthesis/noise%d' % len(self._noise_vars)
78 |             if n not in self._Gs.vars:
79 | break
80 | v = self._Gs.vars[n]
81 | self._noise_vars.append(v)
82 | noise_init_ops.append(tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32)))
83 | noise_mean = tf.reduce_mean(v)
84 | noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5
85 | noise_normalize_ops.append(tf.assign(v, (v - noise_mean) / noise_std))
86 | self._info(n, v)
87 | self._noise_init_op = tf.group(*noise_init_ops)
88 | self._noise_normalize_op = tf.group(*noise_normalize_ops)
89 |
90 | # Image output graph.
91 | self._info('Building image output graph...')
92 | self._dlatents_var = tf.Variable(tf.zeros([self._minibatch_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var')
93 | self._noise_in = tf.placeholder(tf.float32, [], name='noise_in')
94 | dlatents_noise = tf.random.normal(shape=self._dlatents_var.shape) * self._noise_in
95 | self._dlatents_expr = tf.tile(self._dlatents_var + dlatents_noise, [1, self._Gs.components.synthesis.input_shape[1], 1])
96 | self._images_expr = self._Gs.components.synthesis.get_output_for(self._dlatents_expr, randomize_noise=False)
97 |
98 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
99 | proc_images_expr = (self._images_expr + 1) * (255 / 2)
100 | sh = proc_images_expr.shape.as_list()
101 | if sh[2] > 256:
102 | factor = sh[2] // 256
103 | proc_images_expr = tf.reduce_mean(tf.reshape(proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5])
104 |
105 | # Loss graph.
106 | self._info('Building loss graph...')
107 | self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var')
108 | if self._lpips is None:
109 | self._lpips = misc.load_pkl('https://nvlabs-fi-cdn.nvidia.com/stylegan/networks/metrics/vgg16_zhang_perceptual.pkl')
110 | self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var)
111 | self._loss = tf.reduce_sum(self._dist)
112 |
113 | # Noise regularization graph.
114 | self._info('Building noise regularization graph...')
115 | reg_loss = 0.0
116 | for v in self._noise_vars:
117 | sz = v.shape[2]
118 | while True:
119 | reg_loss += tf.reduce_mean(v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(v * tf.roll(v, shift=1, axis=2))**2
120 | if sz <= 8:
121 | break # Small enough already
122 | v = tf.reshape(v, [1, 1, sz//2, 2, sz//2, 2]) # Downscale
123 | v = tf.reduce_mean(v, axis=[3, 5])
124 | sz = sz // 2
125 | self._loss += reg_loss * self.regularize_noise_weight
126 |
127 | # Optimizer.
128 | self._info('Setting up optimizer...')
129 | self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in')
130 | self._opt = dnnlib.tflib.Optimizer(learning_rate=self._lrate_in)
131 | self._opt.register_gradients(self._loss, [self._dlatents_var] + self._noise_vars)
132 | self._opt_step = self._opt.apply_updates()
133 |
134 | def run(self, target_images):
135 | # Run to completion.
136 | self.start(target_images)
137 | while self._cur_step < self.num_steps:
138 | self.step()
139 |
140 | # Collect results.
141 | pres = dnnlib.EasyDict()
142 | pres.dlatents = self.get_dlatents()
143 | pres.noises = self.get_noises()
144 | pres.images = self.get_images()
145 | return pres
146 |
147 | def start(self, target_images):
148 | assert self._Gs is not None
149 |
150 | # Prepare target images.
151 | self._info('Preparing target images...')
152 | target_images = np.asarray(target_images, dtype='float32')
153 | target_images = (target_images + 1) * (255 / 2)
154 | sh = target_images.shape
155 | assert sh[0] == self._minibatch_size
156 | if sh[2] > self._target_images_var.shape[2]:
157 | factor = sh[2] // self._target_images_var.shape[2]
158 | target_images = np.reshape(target_images, [-1, sh[1], sh[2] // factor, factor, sh[3] // factor, factor]).mean((3, 5))
159 |
160 | # Initialize optimization state.
161 | self._info('Initializing optimization state...')
162 | tflib.set_vars({self._target_images_var: target_images, self._dlatents_var: np.tile(self._dlatent_avg, [self._minibatch_size, 1, 1])})
163 | tflib.run(self._noise_init_op)
164 | self._opt.reset_optimizer_state()
165 | self._cur_step = 0
166 |
167 | def step(self):
168 | assert self._cur_step is not None
169 | if self._cur_step >= self.num_steps:
170 | return
171 | if self._cur_step == 0:
172 | self._info('Running...')
173 |
174 | # Hyperparameters.
175 | t = self._cur_step / self.num_steps
176 | noise_strength = self._dlatent_std * self.initial_noise_factor * max(0.0, 1.0 - t / self.noise_ramp_length) ** 2
177 | lr_ramp = min(1.0, (1.0 - t) / self.lr_rampdown_length)
178 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
179 | lr_ramp = lr_ramp * min(1.0, t / self.lr_rampup_length)
180 | learning_rate = self.initial_learning_rate * lr_ramp
181 |
182 | # Train.
183 | feed_dict = {self._noise_in: noise_strength, self._lrate_in: learning_rate}
184 | _, dist_value, loss_value = tflib.run([self._opt_step, self._dist, self._loss], feed_dict)
185 | tflib.run(self._noise_normalize_op)
186 |
187 | # Print status.
188 | self._cur_step += 1
189 | if self._cur_step == self.num_steps or self._cur_step % 10 == 0:
190 | self._info('%-8d%-12g%-12g' % (self._cur_step, dist_value, loss_value))
191 | if self._cur_step == self.num_steps:
192 | self._info('Done.')
193 |
194 | def get_cur_step(self):
195 | return self._cur_step
196 |
197 | def get_dlatents(self):
198 | return tflib.run(self._dlatents_expr, {self._noise_in: 0})
199 |
200 | def get_noises(self):
201 | return tflib.run(self._noise_vars)
202 |
203 | def get_images(self):
204 | return tflib.run(self._images_expr, {self._noise_in: 0})
205 |
206 | #----------------------------------------------------------------------------
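# A minimal sketch (not part of the original module) of the learning-rate
# schedule applied in Projector.step(): a cosine ramp-down over the final
# lr_rampdown_length fraction of steps and a linear warm-up over the first
# lr_rampup_length fraction. Defaults mirror the constructor above.
import numpy as np
def lr_schedule(step, num_steps=1000, base_lr=0.1, rampdown=0.25, rampup=0.05):
    t = step / num_steps
    ramp = min(1.0, (1.0 - t) / rampdown)
    ramp = 0.5 - 0.5 * np.cos(ramp * np.pi)  # cosine ease-out at the end
    ramp *= min(1.0, t / rampup)             # linear warm-up at the start
    return base_lr * ramp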
207 |
--------------------------------------------------------------------------------
/run_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | import argparse
8 | import numpy as np
9 | import PIL.Image
10 | import dnnlib
11 | import dnnlib.tflib as tflib
12 | import re
13 | import sys
14 |
15 | import pretrained_networks
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | def generate_images(network_pkl, seeds, truncation_psi):
20 | print('Loading networks from "%s"...' % network_pkl)
21 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
22 | noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
23 |
24 | Gs_kwargs = dnnlib.EasyDict()
25 | Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
26 | Gs_kwargs.randomize_noise = False
27 | if truncation_psi is not None:
28 | Gs_kwargs.truncation_psi = truncation_psi
29 |
30 | for seed_idx, seed in enumerate(seeds):
31 |         print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx + 1, len(seeds)))
32 | rnd = np.random.RandomState(seed)
33 | z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
34 | tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
35 | images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
36 | PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('seed%04d.png' % seed))
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_styles, minibatch_size=4):
41 | print('Loading networks from "%s"...' % network_pkl)
42 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
43 | w_avg = Gs.get_var('dlatent_avg') # [component]
44 |
45 | Gs_syn_kwargs = dnnlib.EasyDict()
46 | Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
47 | Gs_syn_kwargs.randomize_noise = False
48 | Gs_syn_kwargs.minibatch_size = minibatch_size
49 |
50 | print('Generating W vectors...')
51 | all_seeds = list(set(row_seeds + col_seeds))
52 | all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds]) # [minibatch, component]
53 | all_w = Gs.components.mapping.run(all_z, None) # [minibatch, layer, component]
54 | all_w = w_avg + (all_w - w_avg) * truncation_psi # [minibatch, layer, component]
55 | w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} # [layer, component]
56 |
57 | print('Generating images...')
58 | all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs) # [minibatch, height, width, channel]
59 | image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
60 |
61 | print('Generating style-mixed images...')
62 | for row_seed in row_seeds:
63 | for col_seed in col_seeds:
64 | w = w_dict[row_seed].copy()
65 | w[col_styles] = w_dict[col_seed][col_styles]
66 | image = Gs.components.synthesis.run(w[np.newaxis], **Gs_syn_kwargs)[0]
67 | image_dict[(row_seed, col_seed)] = image
68 |
69 | print('Saving images...')
70 | for (row_seed, col_seed), image in image_dict.items():
71 | PIL.Image.fromarray(image, 'RGB').save(dnnlib.make_run_dir_path('%d-%d.png' % (row_seed, col_seed)))
72 |
73 | print('Saving image grid...')
74 | _N, _C, H, W = Gs.output_shape
75 | canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
76 | for row_idx, row_seed in enumerate([None] + row_seeds):
77 | for col_idx, col_seed in enumerate([None] + col_seeds):
78 | if row_seed is None and col_seed is None:
79 | continue
80 | key = (row_seed, col_seed)
81 | if row_seed is None:
82 | key = (col_seed, col_seed)
83 | if col_seed is None:
84 | key = (row_seed, row_seed)
85 | canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
86 | canvas.save(dnnlib.make_run_dir_path('grid.png'))
87 |
88 | #----------------------------------------------------------------------------
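# A minimal NumPy sketch (not part of the original module) of the core mixing
# rule in style_mixing_example: copy the chosen layers of the column seed's W
# into the row seed's W. num_layers=18 assumes a 1024x1024 generator.
import numpy as np
num_layers, dim = 18, 512
w_row = np.random.randn(num_layers, dim)
w_col = np.random.randn(num_layers, dim)
col_styles = list(range(7))  # matches the --col-styles default '0-6'
w_mix = w_row.copy()
w_mix[col_styles] = w_col[col_styles]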
89 |
90 | def _parse_num_range(s):
91 |     '''Accept either a comma-separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
92 |
93 | range_re = re.compile(r'^(\d+)-(\d+)$')
94 | m = range_re.match(s)
95 | if m:
96 | return list(range(int(m.group(1)), int(m.group(2))+1))
97 | vals = s.split(',')
98 | return [int(x) for x in vals]
99 |
100 | #----------------------------------------------------------------------------
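# Expected behavior of _parse_num_range, shown as assertions (not part of the
# original module):
assert _parse_num_range('1-4') == [1, 2, 3, 4]
assert _parse_num_range('66,230,389') == [66, 230, 389]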
101 |
102 | _examples = '''examples:
103 |
104 | # Generate ffhq uncurated images (matches paper Figure 12)
105 | python %(prog)s generate-images --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --seeds=6600-6625 --truncation-psi=0.5
106 |
107 | # Generate ffhq curated images (matches paper Figure 11)
108 | python %(prog)s generate-images --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --seeds=66,230,389,1518 --truncation-psi=1.0
109 |
110 | # Generate uncurated car images (matches paper Figure 12)
111 | python %(prog)s generate-images --network=gdrive:networks/stylegan2-car-config-f.pkl --seeds=6000-6025 --truncation-psi=0.5
112 |
113 | # Generate style mixing example (matches style mixing video clip)
114 | python %(prog)s style-mixing-example --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --row-seeds=85,100,75,458,1500 --col-seeds=55,821,1789,293 --truncation-psi=1.0
115 | '''
116 |
117 | #----------------------------------------------------------------------------
118 |
119 | def main():
120 | parser = argparse.ArgumentParser(
121 | description='''StyleGAN2 generator.
122 |
123 | Run 'python %(prog)s --help' for subcommand help.''',
124 | epilog=_examples,
125 | formatter_class=argparse.RawDescriptionHelpFormatter
126 | )
127 |
128 | subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
129 |
130 | parser_generate_images = subparsers.add_parser('generate-images', help='Generate images')
131 | parser_generate_images.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
132 | parser_generate_images.add_argument('--seeds', type=_parse_num_range, help='List of random seeds', required=True)
133 | parser_generate_images.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5)
134 | parser_generate_images.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
135 |
136 | parser_style_mixing_example = subparsers.add_parser('style-mixing-example', help='Generate style mixing video')
137 | parser_style_mixing_example.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
138 | parser_style_mixing_example.add_argument('--row-seeds', type=_parse_num_range, help='Random seeds to use for image rows', required=True)
139 | parser_style_mixing_example.add_argument('--col-seeds', type=_parse_num_range, help='Random seeds to use for image columns', required=True)
140 | parser_style_mixing_example.add_argument('--col-styles', type=_parse_num_range, help='Style layer range (default: %(default)s)', default='0-6')
141 | parser_style_mixing_example.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', default=0.5)
142 | parser_style_mixing_example.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
143 |
144 | args = parser.parse_args()
145 | kwargs = vars(args)
146 | subcmd = kwargs.pop('command')
147 |
148 | if subcmd is None:
149 |         print('Error: missing subcommand. Re-run with --help for usage.')
150 | sys.exit(1)
151 |
152 | sc = dnnlib.SubmitConfig()
153 | sc.num_gpus = 1
154 | sc.submit_target = dnnlib.SubmitTarget.LOCAL
155 | sc.local.do_not_copy_source_files = True
156 | sc.run_dir_root = kwargs.pop('result_dir')
157 | sc.run_desc = subcmd
158 |
159 | func_name_map = {
160 | 'generate-images': 'run_generator.generate_images',
161 | 'style-mixing-example': 'run_generator.style_mixing_example'
162 | }
163 | dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs)
164 |
165 | #----------------------------------------------------------------------------
166 |
167 | if __name__ == "__main__":
168 | main()
169 |
170 | #----------------------------------------------------------------------------
171 |
--------------------------------------------------------------------------------
/run_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | import argparse
8 | import os
9 | import sys
10 |
11 | import dnnlib
12 | import dnnlib.tflib as tflib
13 |
14 | import pretrained_networks
15 | from metrics import metric_base
16 | from metrics.metric_defaults import metric_defaults
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def run(network_pkl, metrics, dataset, data_dir, mirror_augment):
21 | print('Evaluating metrics "%s" for "%s"...' % (','.join(metrics), network_pkl))
22 | tflib.init_tf()
23 | network_pkl = pretrained_networks.get_path_or_url(network_pkl)
24 | dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0)
25 | num_gpus = dnnlib.submit_config.num_gpus
26 | metric_group = metric_base.MetricGroup([metric_defaults[metric] for metric in metrics])
27 | metric_group.run(network_pkl, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=num_gpus)
28 |
29 | #----------------------------------------------------------------------------
30 |
31 | def _str_to_bool(v):
32 | if isinstance(v, bool):
33 | return v
34 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
35 | return True
36 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
37 | return False
38 | else:
39 | raise argparse.ArgumentTypeError('Boolean value expected.')
40 |
41 | #----------------------------------------------------------------------------
42 |
43 | _examples = '''examples:
44 |
45 | python %(prog)s --data-dir=~/datasets --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --metrics=fid50k,ppl_wend --dataset=ffhq --mirror-augment=true
46 |
47 | valid metrics:
48 |
49 | ''' + ', '.join(sorted([x for x in metric_defaults.keys()])) + '''
50 | '''
51 |
52 | def main():
53 | parser = argparse.ArgumentParser(
54 | description='Run StyleGAN2 metrics.',
55 | epilog=_examples,
56 | formatter_class=argparse.RawDescriptionHelpFormatter
57 | )
58 | parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
59 | parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
60 | parser.add_argument('--metrics', help='Metrics to compute (default: %(default)s)', default='fid50k', type=lambda x: x.split(','))
61 | parser.add_argument('--dataset', help='Training dataset', required=True)
62 | parser.add_argument('--data-dir', help='Dataset root directory', required=True)
63 | parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, type=_str_to_bool, metavar='BOOL')
64 | parser.add_argument('--num-gpus', help='Number of GPUs to use', type=int, default=1, metavar='N')
65 |
66 | args = parser.parse_args()
67 |
68 | if not os.path.exists(args.data_dir):
69 |         print('Error: dataset root directory does not exist.')
70 | sys.exit(1)
71 |
72 | kwargs = vars(args)
73 | sc = dnnlib.SubmitConfig()
74 | sc.num_gpus = kwargs.pop('num_gpus')
75 | sc.submit_target = dnnlib.SubmitTarget.LOCAL
76 | sc.local.do_not_copy_source_files = True
77 | sc.run_dir_root = kwargs.pop('result_dir')
78 | sc.run_desc = 'run-metrics'
79 | dnnlib.submit_run(sc, 'run_metrics.run', **kwargs)
80 |
81 | #----------------------------------------------------------------------------
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
86 | #----------------------------------------------------------------------------
87 |
--------------------------------------------------------------------------------
/run_projector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | import argparse
8 | import numpy as np
9 | import dnnlib
10 | import dnnlib.tflib as tflib
11 | import re
12 | import sys
13 |
14 | import projector
15 | import pretrained_networks
16 | from training import dataset
17 | from training import misc
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def project_image(proj, targets, png_prefix, num_snapshots):
22 | snapshot_steps = set(proj.num_steps - np.linspace(0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))
23 | misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1,1])
24 | proj.start(targets)
25 | while proj.get_cur_step() < proj.num_steps:
26 | print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps), end='', flush=True)
27 | proj.step()
28 | if proj.get_cur_step() in snapshot_steps:
29 | misc.save_image_grid(proj.get_images(), png_prefix + 'step%04d.png' % proj.get_cur_step(), drange=[-1,1])
30 | print('\r%-30s\r' % '', end='', flush=True)
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def project_generated_images(network_pkl, seeds, num_snapshots, truncation_psi):
35 | print('Loading networks from "%s"...' % network_pkl)
36 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
37 | proj = projector.Projector()
38 | proj.set_network(Gs)
39 | noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
40 |
41 | Gs_kwargs = dnnlib.EasyDict()
42 | Gs_kwargs.randomize_noise = False
43 | Gs_kwargs.truncation_psi = truncation_psi
44 |
45 | for seed_idx, seed in enumerate(seeds):
46 |         print('Projecting seed %d (%d/%d) ...' % (seed, seed_idx + 1, len(seeds)))
47 | rnd = np.random.RandomState(seed)
48 | z = rnd.randn(1, *Gs.input_shape[1:])
49 | tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})
50 | images = Gs.run(z, None, **Gs_kwargs)
51 | project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('seed%04d-' % seed), num_snapshots=num_snapshots)
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def project_real_images(network_pkl, dataset_name, data_dir, num_images, num_snapshots):
56 | print('Loading networks from "%s"...' % network_pkl)
57 | _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
58 | proj = projector.Projector()
59 | proj.set_network(Gs)
60 |
61 | print('Loading images from "%s"...' % dataset_name)
62 | dataset_obj = dataset.load_dataset(data_dir=data_dir, tfrecord_dir=dataset_name, max_label_size=0, repeat=False, shuffle_mb=0)
63 | assert dataset_obj.shape == Gs.output_shape[1:]
64 |
65 | for image_idx in range(num_images):
66 |         print('Projecting image %d/%d ...' % (image_idx + 1, num_images))
67 | images, _labels = dataset_obj.get_minibatch_np(1)
68 | images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
69 | project_image(proj, targets=images, png_prefix=dnnlib.make_run_dir_path('image%04d-' % image_idx), num_snapshots=num_snapshots)
70 |
71 | #----------------------------------------------------------------------------
72 |
73 | def _parse_num_range(s):
74 |     '''Accept either a comma-separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
75 |
76 | range_re = re.compile(r'^(\d+)-(\d+)$')
77 | m = range_re.match(s)
78 | if m:
79 | return list(range(int(m.group(1)), int(m.group(2))+1))
80 | vals = s.split(',')
81 | return [int(x) for x in vals]
82 |
83 | #----------------------------------------------------------------------------
84 |
85 | _examples = '''examples:
86 |
87 | # Project generated images
88 | python %(prog)s project-generated-images --network=gdrive:networks/stylegan2-car-config-f.pkl --seeds=0,1,5
89 |
90 | # Project real images
91 | python %(prog)s project-real-images --network=gdrive:networks/stylegan2-car-config-f.pkl --dataset=car --data-dir=~/datasets
92 |
93 | '''
94 |
95 | #----------------------------------------------------------------------------
96 |
97 | def main():
98 | parser = argparse.ArgumentParser(
99 | description='''StyleGAN2 projector.
100 |
101 | Run 'python %(prog)s --help' for subcommand help.''',
102 | epilog=_examples,
103 | formatter_class=argparse.RawDescriptionHelpFormatter
104 | )
105 |
106 | subparsers = parser.add_subparsers(help='Sub-commands', dest='command')
107 |
108 | project_generated_images_parser = subparsers.add_parser('project-generated-images', help='Project generated images')
109 | project_generated_images_parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
110 | project_generated_images_parser.add_argument('--seeds', type=_parse_num_range, help='List of random seeds', default=range(3))
111 | project_generated_images_parser.add_argument('--num-snapshots', type=int, help='Number of snapshots (default: %(default)s)', default=5)
112 | project_generated_images_parser.add_argument('--truncation-psi', type=float, help='Truncation psi (default: %(default)s)', default=1.0)
113 | project_generated_images_parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
114 |
115 | project_real_images_parser = subparsers.add_parser('project-real-images', help='Project real images')
116 | project_real_images_parser.add_argument('--network', help='Network pickle filename', dest='network_pkl', required=True)
117 | project_real_images_parser.add_argument('--data-dir', help='Dataset root directory', required=True)
118 | project_real_images_parser.add_argument('--dataset', help='Training dataset', dest='dataset_name', required=True)
119 | project_real_images_parser.add_argument('--num-snapshots', type=int, help='Number of snapshots (default: %(default)s)', default=5)
120 | project_real_images_parser.add_argument('--num-images', type=int, help='Number of images to project (default: %(default)s)', default=3)
121 | project_real_images_parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
122 |
123 | args = parser.parse_args()
124 | subcmd = args.command
125 | if subcmd is None:
126 |         print('Error: missing subcommand. Re-run with --help for usage.')
127 | sys.exit(1)
128 |
129 | kwargs = vars(args)
130 | sc = dnnlib.SubmitConfig()
131 | sc.num_gpus = 1
132 | sc.submit_target = dnnlib.SubmitTarget.LOCAL
133 | sc.local.do_not_copy_source_files = True
134 | sc.run_dir_root = kwargs.pop('result_dir')
135 | sc.run_desc = kwargs.pop('command')
136 |
137 | func_name_map = {
138 | 'project-generated-images': 'run_projector.project_generated_images',
139 | 'project-real-images': 'run_projector.project_real_images'
140 | }
141 | dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs)
142 |
143 | #----------------------------------------------------------------------------
144 |
145 | if __name__ == "__main__":
146 | main()
147 |
148 | #----------------------------------------------------------------------------
149 |
--------------------------------------------------------------------------------
/run_training.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | import argparse
8 | import copy
9 | import os
10 | import sys
11 |
12 | import dnnlib
13 | from dnnlib import EasyDict
14 |
15 | from metrics.metric_defaults import metric_defaults
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | _valid_configs = [
20 | # Table 1
21 | 'config-a', # Baseline StyleGAN
22 | 'config-b', # + Weight demodulation
23 | 'config-c', # + Lazy regularization
24 | 'config-d', # + Path length regularization
25 | 'config-e', # + No growing, new G & D arch.
26 | 'config-f', # + Large networks (default)
27 |
28 | # Table 2
29 | 'config-e-Gorig-Dorig', 'config-e-Gorig-Dresnet', 'config-e-Gorig-Dskip',
30 | 'config-e-Gresnet-Dorig', 'config-e-Gresnet-Dresnet', 'config-e-Gresnet-Dskip',
31 | 'config-e-Gskip-Dorig', 'config-e-Gskip-Dresnet', 'config-e-Gskip-Dskip',
32 | ]
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics):
37 | train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
38 | G = EasyDict(func_name='training.networks_stylegan2.G_main') # Options for generator network.
39 | D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2') # Options for discriminator network.
40 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
41 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
42 | G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg') # Options for generator loss.
43 | D_loss = EasyDict(func_name='training.loss.D_logistic_r1') # Options for discriminator loss.
44 | sched = EasyDict() # Options for TrainingSchedule.
45 | grid = EasyDict(size='8k', layout='random') # Options for setup_snapshot_image_grid().
46 | sc = dnnlib.SubmitConfig() # Options for dnnlib.submit_run().
47 | tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf().
48 |
49 | train.data_dir = data_dir
50 | train.total_kimg = total_kimg
51 | train.mirror_augment = mirror_augment
52 | train.image_snapshot_ticks = train.network_snapshot_ticks = 10
53 | sched.G_lrate_base = sched.D_lrate_base = 0.002
54 | sched.minibatch_size_base = 32
55 | sched.minibatch_gpu_base = 4
56 | D_loss.gamma = 10
57 | metrics = [metric_defaults[x] for x in metrics]
58 | desc = 'stylegan2'
59 |
60 | desc += '-' + dataset
61 | dataset_args = EasyDict(tfrecord_dir=dataset)
62 |
63 | assert num_gpus in [1, 2, 4, 8]
64 | sc.num_gpus = num_gpus
65 | desc += '-%dgpu' % num_gpus
66 |
67 | assert config_id in _valid_configs
68 | desc += '-' + config_id
69 |
70 | # Configs A-E: Shrink networks to match original StyleGAN.
71 | if config_id != 'config-f':
72 | G.fmap_base = D.fmap_base = 8 << 10
73 |
74 | # Config E: Set gamma to 100 and override G & D architecture.
75 | if config_id.startswith('config-e'):
76 | D_loss.gamma = 100
77 | if 'Gorig' in config_id: G.architecture = 'orig'
78 | if 'Gskip' in config_id: G.architecture = 'skip' # (default)
79 | if 'Gresnet' in config_id: G.architecture = 'resnet'
80 | if 'Dorig' in config_id: D.architecture = 'orig'
81 | if 'Dskip' in config_id: D.architecture = 'skip'
82 | if 'Dresnet' in config_id: D.architecture = 'resnet' # (default)
83 |
84 | # Configs A-D: Enable progressive growing and switch to networks that support it.
85 | if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
86 | sched.lod_initial_resolution = 8
87 | sched.G_lrate_base = sched.D_lrate_base = 0.001
88 | sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
89 | sched.minibatch_size_base = 32 # (default)
90 | sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
91 | sched.minibatch_gpu_base = 4 # (default)
92 | sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
93 | G.synthesis_func = 'G_synthesis_stylegan_revised'
94 | D.func_name = 'training.networks_stylegan2.D_stylegan'
95 |
96 | # Configs A-C: Disable path length regularization.
97 | if config_id in ['config-a', 'config-b', 'config-c']:
98 | G_loss = EasyDict(func_name='training.loss.G_logistic_ns')
99 |
100 | # Configs A-B: Disable lazy regularization.
101 | if config_id in ['config-a', 'config-b']:
102 | train.lazy_regularization = False
103 |
104 | # Config A: Switch to original StyleGAN networks.
105 | if config_id == 'config-a':
106 | G = EasyDict(func_name='training.networks_stylegan.G_style')
107 | D = EasyDict(func_name='training.networks_stylegan.D_basic')
108 |
109 | if gamma is not None:
110 | D_loss.gamma = gamma
111 |
112 | sc.submit_target = dnnlib.SubmitTarget.LOCAL
113 | sc.local.do_not_copy_source_files = True
114 | kwargs = EasyDict(train)
115 | kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
116 | kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
117 | kwargs.submit_config = copy.deepcopy(sc)
118 | kwargs.submit_config.run_dir_root = result_dir
119 | kwargs.submit_config.run_desc = desc
120 | dnnlib.submit_run(**kwargs)
121 |
122 | #----------------------------------------------------------------------------
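# A minimal sketch (not part of the original module) of how a Table-2 config id
# decomposes under the rules above: config-e sets gamma=100, and the G/D
# architectures are parsed from the name (defaults: G skip, D resnet).
config_id = 'config-e-Gresnet-Dskip'
assert config_id.startswith('config-e')
G_arch = 'orig' if 'Gorig' in config_id else 'resnet' if 'Gresnet' in config_id else 'skip'
D_arch = 'orig' if 'Dorig' in config_id else 'skip' if 'Dskip' in config_id else 'resnet'
print(G_arch, D_arch)  # -> resnet skip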
123 |
124 | def _str_to_bool(v):
125 | if isinstance(v, bool):
126 | return v
127 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
128 | return True
129 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
130 | return False
131 | else:
132 | raise argparse.ArgumentTypeError('Boolean value expected.')
133 |
134 | def _parse_comma_sep(s):
135 | if s is None or s.lower() == 'none' or s == '':
136 | return []
137 | return s.split(',')
138 |
139 | #----------------------------------------------------------------------------
140 |
141 | _examples = '''examples:
142 |
143 | # Train StyleGAN2 using the FFHQ dataset
144 | python %(prog)s --num-gpus=8 --data-dir=~/datasets --config=config-f --dataset=ffhq --mirror-augment=true
145 |
146 | valid configs:
147 |
148 | ''' + ', '.join(_valid_configs) + '''
149 |
150 | valid metrics:
151 |
152 | ''' + ', '.join(sorted([x for x in metric_defaults.keys()])) + '''
153 |
154 | '''
155 |
156 | def main():
157 | parser = argparse.ArgumentParser(
158 | description='Train StyleGAN2.',
159 | epilog=_examples,
160 | formatter_class=argparse.RawDescriptionHelpFormatter
161 | )
162 | parser.add_argument('--result-dir', help='Root directory for run results (default: %(default)s)', default='results', metavar='DIR')
163 | parser.add_argument('--data-dir', help='Dataset root directory', required=True)
164 | parser.add_argument('--dataset', help='Training dataset', required=True)
165 |     parser.add_argument('--config', help='Training config (default: %(default)s)', default='config-f', dest='config_id', metavar='CONFIG')
166 | parser.add_argument('--num-gpus', help='Number of GPUs (default: %(default)s)', default=1, type=int, metavar='N')
167 | parser.add_argument('--total-kimg', help='Training length in thousands of images (default: %(default)s)', metavar='KIMG', default=25000, type=int)
168 | parser.add_argument('--gamma', help='R1 regularization weight (default is config dependent)', default=None, type=float)
169 | parser.add_argument('--mirror-augment', help='Mirror augment (default: %(default)s)', default=False, metavar='BOOL', type=_str_to_bool)
170 | parser.add_argument('--metrics', help='Comma-separated list of metrics or "none" (default: %(default)s)', default='fid50k', type=_parse_comma_sep)
171 |
172 | args = parser.parse_args()
173 |
174 | if not os.path.exists(args.data_dir):
175 |         print('Error: dataset root directory does not exist.')
176 | sys.exit(1)
177 |
178 | if args.config_id not in _valid_configs:
179 |         print('Error: --config value must be one of:', ', '.join(_valid_configs))
180 | sys.exit(1)
181 |
182 | for metric in args.metrics:
183 | if metric not in metric_defaults:
184 |             print('Error: unknown metric \'%s\'' % metric)
185 | sys.exit(1)
186 |
187 | run(**vars(args))
188 |
189 | #----------------------------------------------------------------------------
190 |
191 | if __name__ == "__main__":
192 | main()
193 |
194 | #----------------------------------------------------------------------------
195 |
196 |
--------------------------------------------------------------------------------
/test_nvcc.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <cstdio>
8 |
9 | void checkCudaError(cudaError_t err)
10 | {
11 | if (err != cudaSuccess)
12 | {
13 | printf("%s: %s\n", cudaGetErrorName(err), cudaGetErrorString(err));
14 | exit(1);
15 | }
16 | }
17 |
18 | __global__ void cudaKernel(void)
19 | {
20 | printf("GPU says hello.\n");
21 | }
22 |
23 | int main(void)
24 | {
25 | printf("CPU says hello.\n");
26 | checkCudaError(cudaLaunchKernel((void*)cudaKernel, 1, 1, NULL, 0, NULL));
27 | checkCudaError(cudaDeviceSynchronize());
28 | return 0;
29 | }
30 |
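// Build-and-run sketch (not part of the original file); assumes nvcc is on PATH:
//   nvcc test_nvcc.cu -o test_nvcc && ./test_nvcc
// Expected output: "CPU says hello." followed by "GPU says hello."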
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | # empty
8 |
--------------------------------------------------------------------------------
/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Multi-resolution input data pipeline."""
8 |
9 | import os
10 | import glob
11 | import numpy as np
12 | import tensorflow as tf
13 | import dnnlib
14 | import dnnlib.tflib as tflib
15 |
16 | #----------------------------------------------------------------------------
17 | # Dataset class that loads data from tfrecords files.
18 |
19 | class TFRecordDataset:
20 | def __init__(self,
21 | tfrecord_dir, # Directory containing a collection of tfrecords files.
22 | resolution = None, # Dataset resolution, None = autodetect.
23 | label_file = None, # Relative path of the labels file, None = autodetect.
24 |         max_label_size  = 0,        # 0 = no labels, 'full' = full labels, <N> = N first label components.
25 | max_images = None, # Maximum number of images to use, None = use all images.
26 | repeat = True, # Repeat dataset indefinitely?
27 | shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling.
28 | prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching.
29 | buffer_mb = 256, # Read buffer size (megabytes).
30 | num_threads = 2): # Number of concurrent threads.
31 |
32 | self.tfrecord_dir = tfrecord_dir
33 | self.resolution = None
34 | self.resolution_log2 = None
35 | self.shape = [] # [channels, height, width]
36 | self.dtype = 'uint8'
37 | self.dynamic_range = [0, 255]
38 | self.label_file = label_file
39 | self.label_size = None # components
40 | self.label_dtype = None
41 | self._np_labels = None
42 | self._tf_minibatch_in = None
43 | self._tf_labels_var = None
44 | self._tf_labels_dataset = None
45 | self._tf_datasets = dict()
46 | self._tf_iterator = None
47 | self._tf_init_ops = dict()
48 | self._tf_minibatch_np = None
49 | self._cur_minibatch = -1
50 | self._cur_lod = -1
51 |
52 | # List tfrecords files and inspect their shapes.
53 | assert os.path.isdir(self.tfrecord_dir)
54 | tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))
55 | assert len(tfr_files) >= 1
56 | tfr_shapes = []
57 | for tfr_file in tfr_files:
58 | tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
59 | for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):
60 | tfr_shapes.append(self.parse_tfrecord_np(record).shape)
61 | break
62 |
63 | # Autodetect label filename.
64 | if self.label_file is None:
65 | guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))
66 | if len(guess):
67 | self.label_file = guess[0]
68 | elif not os.path.isfile(self.label_file):
69 | guess = os.path.join(self.tfrecord_dir, self.label_file)
70 | if os.path.isfile(guess):
71 | self.label_file = guess
72 |
73 | # Determine shape and resolution.
74 | max_shape = max(tfr_shapes, key=np.prod)
75 | self.resolution = resolution if resolution is not None else max_shape[1]
76 | self.resolution_log2 = int(np.log2(self.resolution))
77 | self.shape = [max_shape[0], self.resolution, self.resolution]
78 | tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]
79 | assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
80 | assert all(shape[1] == shape[2] for shape in tfr_shapes)
81 | assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))
82 | assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))
83 |
84 | # Load labels.
85 | assert max_label_size == 'full' or max_label_size >= 0
86 | self._np_labels = np.zeros([1<<30, 0], dtype=np.float32)
87 | if self.label_file is not None and max_label_size != 0:
88 | self._np_labels = np.load(self.label_file)
89 | assert self._np_labels.ndim == 2
90 | if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:
91 | self._np_labels = self._np_labels[:, :max_label_size]
92 | if max_images is not None and self._np_labels.shape[0] > max_images:
93 | self._np_labels = self._np_labels[:max_images]
94 | self.label_size = self._np_labels.shape[1]
95 | self.label_dtype = self._np_labels.dtype.name
96 |
97 | # Build TF expressions.
98 | with tf.name_scope('Dataset'), tf.device('/cpu:0'):
99 | self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])
100 | self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')
101 | self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)
102 | for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):
103 | if tfr_lod < 0:
104 | continue
105 | dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)
106 | if max_images is not None:
107 | dset = dset.take(max_images)
108 | dset = dset.map(self.parse_tfrecord_tf, num_parallel_calls=num_threads)
109 | dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
110 | bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize
111 | if shuffle_mb > 0:
112 | dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
113 | if repeat:
114 | dset = dset.repeat()
115 | if prefetch_mb > 0:
116 | dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
117 | dset = dset.batch(self._tf_minibatch_in)
118 | self._tf_datasets[tfr_lod] = dset
119 | self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)
120 | self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()}
121 |
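    # Note: the shuffle_mb / prefetch_mb budgets above are converted into item counts by
    # ceiling division, ((mb << 20) - 1) // bytes_per_item + 1. For example, a 4096 MB
    # shuffle window over 1024x1024 RGB uint8 images (3 MiB each) shuffles within a
    # window of 1366 items.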
122 | def close(self):
123 | pass
124 |
125 | # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf().
126 | def configure(self, minibatch_size, lod=0):
127 | lod = int(np.floor(lod))
128 | assert minibatch_size >= 1 and lod in self._tf_datasets
129 | if self._cur_minibatch != minibatch_size or self._cur_lod != lod:
130 | self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})
131 | self._cur_minibatch = minibatch_size
132 | self._cur_lod = lod
133 |
134 | # Get next minibatch as TensorFlow expressions.
135 | def get_minibatch_tf(self): # => images, labels
136 | return self._tf_iterator.get_next()
137 |
138 | # Get next minibatch as NumPy arrays.
139 | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels
140 | self.configure(minibatch_size, lod)
141 | with tf.name_scope('Dataset'):
142 | if self._tf_minibatch_np is None:
143 | self._tf_minibatch_np = self.get_minibatch_tf()
144 | return tflib.run(self._tf_minibatch_np)
145 |
146 | # Get random labels as TensorFlow expression.
147 | def get_random_labels_tf(self, minibatch_size): # => labels
148 | with tf.name_scope('Dataset'):
149 | if self.label_size > 0:
150 | with tf.device('/cpu:0'):
151 | return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32))
152 | return tf.zeros([minibatch_size, 0], self.label_dtype)
153 |
154 | # Get random labels as NumPy array.
155 | def get_random_labels_np(self, minibatch_size): # => labels
156 | if self.label_size > 0:
157 | return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])]
158 | return np.zeros([minibatch_size, 0], self.label_dtype)
159 |
160 | # Parse individual image from a tfrecords file into TensorFlow expression.
161 | @staticmethod
162 | def parse_tfrecord_tf(record):
163 | features = tf.parse_single_example(record, features={
164 | 'shape': tf.FixedLenFeature([3], tf.int64),
165 | 'data': tf.FixedLenFeature([], tf.string)})
166 | data = tf.decode_raw(features['data'], tf.uint8)
167 | return tf.reshape(data, features['shape'])
168 |
169 | # Parse individual image from a tfrecords file into NumPy array.
170 | @staticmethod
171 | def parse_tfrecord_np(record):
172 | ex = tf.train.Example()
173 | ex.ParseFromString(record)
174 | shape = ex.features.feature['shape'].int64_list.value # pylint: disable=no-member
175 | data = ex.features.feature['data'].bytes_list.value[0] # pylint: disable=no-member
176 |         return np.frombuffer(data, np.uint8).reshape(shape)
177 |
178 | #----------------------------------------------------------------------------
179 | # Helper func for constructing a dataset object using the given options.
180 |
181 | def load_dataset(class_name=None, data_dir=None, verbose=False, **kwargs):
182 | kwargs = dict(kwargs)
183 | if 'tfrecord_dir' in kwargs:
184 | if class_name is None:
185 | class_name = __name__ + '.TFRecordDataset'
186 | if data_dir is not None:
187 | kwargs['tfrecord_dir'] = os.path.join(data_dir, kwargs['tfrecord_dir'])
188 |
189 | assert class_name is not None
190 | if verbose:
191 | print('Streaming data using %s...' % class_name)
192 | dataset = dnnlib.util.get_obj_by_name(class_name)(**kwargs)
193 | if verbose:
194 | print('Dataset shape =', np.int32(dataset.shape).tolist())
195 | print('Dynamic range =', dataset.dynamic_range)
196 | print('Label size =', dataset.label_size)
197 | return dataset
198 |
199 | #----------------------------------------------------------------------------
200 |
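A usage sketch for the loader above (the directory names and minibatch size are assumptions;
tflib.init_tf() creates the default session that tflib.run() executes in):

    import dnnlib.tflib as tflib
    from training import dataset

    tflib.init_tf()
    # Passing 'tfrecord_dir' makes load_dataset() default to TFRecordDataset and join the
    # path onto data_dir.
    dset = dataset.load_dataset(data_dir='datasets', tfrecord_dir='ffhq', verbose=True)
    images, labels = dset.get_minibatch_np(8)  # uint8 images [8, C, H, W], labels [8, label_size]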
--------------------------------------------------------------------------------
/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Loss functions."""
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | import dnnlib.tflib as tflib
12 | from dnnlib.tflib.autosummary import autosummary
13 |
14 | #----------------------------------------------------------------------------
15 | # Logistic loss from the paper
16 | # "Generative Adversarial Nets", Goodfellow et al. 2014
17 |
18 | def G_logistic(G, D, opt, training_set, minibatch_size):
19 | _ = opt
20 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
21 | labels = training_set.get_random_labels_tf(minibatch_size)
22 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
23 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
24 | loss = -tf.nn.softplus(fake_scores_out) # log(1-sigmoid(fake_scores_out)) # pylint: disable=invalid-unary-operand-type
25 | return loss, None
26 |
27 | def G_logistic_ns(G, D, opt, training_set, minibatch_size):
28 | _ = opt
29 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
30 | labels = training_set.get_random_labels_tf(minibatch_size)
31 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
32 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
33 | loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
34 | return loss, None
35 |
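# Note: G_logistic is the saturating generator loss, minimizing log(1 - sigmoid(D(G(z)))),
# whereas G_logistic_ns is the non-saturating variant, minimizing -log(sigmoid(D(G(z)))),
# which provides stronger gradients early in training when the discriminator wins easily.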
36 | def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels):
37 | _ = opt, training_set
38 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
39 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
40 | real_scores_out = D.get_output_for(reals, labels, is_training=True)
41 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
42 | real_scores_out = autosummary('Loss/scores/real', real_scores_out)
43 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
44 | loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
45 | loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type
46 | return loss, None
47 |
48 | #----------------------------------------------------------------------------
49 | # R1 and R2 regularizers from the paper
50 | # "Which Training Methods for GANs do actually Converge?", Mescheder et al. 2018
51 |
52 | def D_logistic_r1(G, D, opt, training_set, minibatch_size, reals, labels, gamma=10.0):
53 | _ = opt, training_set
54 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
55 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
56 | real_scores_out = D.get_output_for(reals, labels, is_training=True)
57 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
58 | real_scores_out = autosummary('Loss/scores/real', real_scores_out)
59 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
60 | loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
61 | loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type
62 |
63 | with tf.name_scope('GradientPenalty'):
64 | real_grads = tf.gradients(tf.reduce_sum(real_scores_out), [reals])[0]
65 | gradient_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3])
66 | gradient_penalty = autosummary('Loss/gradient_penalty', gradient_penalty)
67 | reg = gradient_penalty * (gamma * 0.5)
68 | return loss, reg
69 |
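# Note: D_logistic_r1 penalizes the squared gradient norm of D on real images,
# reg = (gamma / 2) * E[||dD/dx_real||^2], and D_logistic_r2 below applies the same penalty
# on fake images. The regularizer is returned separately from the loss so the training loop
# can evaluate it at a lower cadence (lazy regularization).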
70 | def D_logistic_r2(G, D, opt, training_set, minibatch_size, reals, labels, gamma=10.0):
71 | _ = opt, training_set
72 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
73 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
74 | real_scores_out = D.get_output_for(reals, labels, is_training=True)
75 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
76 | real_scores_out = autosummary('Loss/scores/real', real_scores_out)
77 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
78 | loss = tf.nn.softplus(fake_scores_out) # -log(1-sigmoid(fake_scores_out))
79 | loss += tf.nn.softplus(-real_scores_out) # -log(sigmoid(real_scores_out)) # pylint: disable=invalid-unary-operand-type
80 |
81 | with tf.name_scope('GradientPenalty'):
82 | fake_grads = tf.gradients(tf.reduce_sum(fake_scores_out), [fake_images_out])[0]
83 | gradient_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3])
84 | gradient_penalty = autosummary('Loss/gradient_penalty', gradient_penalty)
85 | reg = gradient_penalty * (gamma * 0.5)
86 | return loss, reg
87 |
88 | #----------------------------------------------------------------------------
89 | # WGAN loss from the paper
90 | # "Wasserstein Generative Adversarial Networks", Arjovsky et al. 2017
91 |
92 | def G_wgan(G, D, opt, training_set, minibatch_size):
93 | _ = opt
94 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
95 | labels = training_set.get_random_labels_tf(minibatch_size)
96 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
97 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
98 | loss = -fake_scores_out
99 | return loss, None
100 |
101 | def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, wgan_epsilon=0.001):
102 | _ = opt, training_set
103 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
104 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
105 | real_scores_out = D.get_output_for(reals, labels, is_training=True)
106 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
107 | real_scores_out = autosummary('Loss/scores/real', real_scores_out)
108 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
109 | loss = fake_scores_out - real_scores_out
110 | with tf.name_scope('EpsilonPenalty'):
111 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
112 | loss += epsilon_penalty * wgan_epsilon
113 | return loss, None
114 |
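# Note: the wgan_epsilon term is a small drift penalty, wgan_epsilon * E[D(real)^2]. The WGAN
# objective is invariant to a constant offset in the critic's output; this term anchors the
# scores near zero.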
115 | #----------------------------------------------------------------------------
116 | # WGAN-GP loss from the paper
117 | # "Improved Training of Wasserstein GANs", Gulrajani et al. 2017
118 |
119 | def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, wgan_lambda=10.0, wgan_epsilon=0.001, wgan_target=1.0):
120 | _ = opt, training_set
121 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
122 | fake_images_out = G.get_output_for(latents, labels, is_training=True)
123 | real_scores_out = D.get_output_for(reals, labels, is_training=True)
124 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
125 | real_scores_out = autosummary('Loss/scores/real', real_scores_out)
126 | fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)
127 | loss = fake_scores_out - real_scores_out
128 | with tf.name_scope('EpsilonPenalty'):
129 | epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))
130 | loss += epsilon_penalty * wgan_epsilon
131 |
132 | with tf.name_scope('GradientPenalty'):
133 | mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)
134 | mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)
135 | mixed_scores_out = D.get_output_for(mixed_images_out, labels, is_training=True)
136 | mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out)
137 | mixed_grads = tf.gradients(tf.reduce_sum(mixed_scores_out), [mixed_images_out])[0]
138 | mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))
139 | mixed_norms = autosummary('Loss/mixed_norms', mixed_norms)
140 | gradient_penalty = tf.square(mixed_norms - wgan_target)
141 | reg = gradient_penalty * (wgan_lambda / (wgan_target**2))
142 | return loss, reg
143 |
144 | #----------------------------------------------------------------------------
145 | # Non-saturating logistic loss with path length regularizer from the paper
146 | # "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 2019
147 |
148 | def G_logistic_ns_pathreg(G, D, opt, training_set, minibatch_size, pl_minibatch_shrink=2, pl_decay=0.01, pl_weight=2.0):
149 | _ = opt
150 | latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
151 | labels = training_set.get_random_labels_tf(minibatch_size)
152 | fake_images_out, fake_dlatents_out = G.get_output_for(latents, labels, is_training=True, return_dlatents=True)
153 | fake_scores_out = D.get_output_for(fake_images_out, labels, is_training=True)
154 | loss = tf.nn.softplus(-fake_scores_out) # -log(sigmoid(fake_scores_out))
155 |
156 | # Path length regularization.
157 | with tf.name_scope('PathReg'):
158 |
159 | # Evaluate the regularization term using a smaller minibatch to conserve memory.
160 | if pl_minibatch_shrink > 1:
161 | pl_minibatch = minibatch_size // pl_minibatch_shrink
162 | pl_latents = tf.random_normal([pl_minibatch] + G.input_shapes[0][1:])
163 | pl_labels = training_set.get_random_labels_tf(pl_minibatch)
164 | fake_images_out, fake_dlatents_out = G.get_output_for(pl_latents, pl_labels, is_training=True, return_dlatents=True)
165 |
166 | # Compute |J*y|.
167 | pl_noise = tf.random_normal(tf.shape(fake_images_out)) / np.sqrt(np.prod(G.output_shape[2:]))
168 | pl_grads = tf.gradients(tf.reduce_sum(fake_images_out * pl_noise), [fake_dlatents_out])[0]
169 | pl_lengths = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(pl_grads), axis=2), axis=1))
170 | pl_lengths = autosummary('Loss/pl_lengths', pl_lengths)
171 |
172 | # Track exponential moving average of |J*y|.
173 | with tf.control_dependencies(None):
174 | pl_mean_var = tf.Variable(name='pl_mean', trainable=False, initial_value=0.0, dtype=tf.float32)
175 | pl_mean = pl_mean_var + pl_decay * (tf.reduce_mean(pl_lengths) - pl_mean_var)
176 | pl_update = tf.assign(pl_mean_var, pl_mean)
177 |
178 | # Calculate (|J*y|-a)^2.
179 | with tf.control_dependencies([pl_update]):
180 | pl_penalty = tf.square(pl_lengths - pl_mean)
181 | pl_penalty = autosummary('Loss/pl_penalty', pl_penalty)
182 |
183 | # Apply weight.
184 | #
185 | # Note: The division in pl_noise decreases the weight by num_pixels, and the reduce_mean
186 | # in pl_lengths decreases it by num_affine_layers. The effective weight then becomes:
187 | #
188 | # gamma_pl = pl_weight / num_pixels / num_affine_layers
189 | # = 2 / (r^2) / (log2(r) * 2 - 2)
190 | # = 1 / (r^2 * (log2(r) - 1))
191 |         #          = ln(2) / (r^2 * (ln(r) - ln(2)))
192 | #
193 | reg = pl_penalty * pl_weight
194 |
195 | return loss, reg
196 |
197 | #----------------------------------------------------------------------------
198 |
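A quick numeric check of the effective-weight derivation in G_logistic_ns_pathreg above
(r = 1024 is an illustrative resolution):

    import numpy as np

    r = 1024
    num_pixels = r * r
    num_affine_layers = int(np.log2(r)) * 2 - 2        # 18 for r = 1024
    gamma_pl = 2.0 / num_pixels / num_affine_layers    # pl_weight = 2
    assert np.isclose(gamma_pl, 1.0 / (r**2 * (np.log2(r) - 1)))
    assert np.isclose(gamma_pl, np.log(2) / (r**2 * (np.log(r) - np.log(2))))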
--------------------------------------------------------------------------------
/training/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://nvlabs.github.io/stylegan2/license.html
6 |
7 | """Miscellaneous utility functions."""
8 |
9 | import os
10 | import pickle
11 | import numpy as np
12 | import PIL.Image
13 | import PIL.ImageFont
14 | import dnnlib
15 |
16 | #----------------------------------------------------------------------------
17 | # Convenience wrappers for pickle that are able to load data produced by
18 | # older versions of the code, and from external URLs.
19 |
20 | def open_file_or_url(file_or_url):
21 | if dnnlib.util.is_url(file_or_url):
22 | return dnnlib.util.open_url(file_or_url, cache_dir='.stylegan2-cache')
23 | return open(file_or_url, 'rb')
24 |
25 | def load_pkl(file_or_url):
26 | with open_file_or_url(file_or_url) as file:
27 | return pickle.load(file, encoding='latin1')
28 |
29 | def save_pkl(obj, filename):
30 | with open(filename, 'wb') as file:
31 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
32 |
33 | #----------------------------------------------------------------------------
34 | # Image utils.
35 |
36 | def adjust_dynamic_range(data, drange_in, drange_out):
37 | if drange_in != drange_out:
38 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
39 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
40 | data = data * scale + bias
41 | return data
42 |
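# Worked example: mapping uint8 pixels in [0, 255] to the network range [-1, 1] gives
# scale = (1 - (-1)) / (255 - 0) = 2/255 and bias = -1, i.e. out = in * (2/255) - 1.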
43 | def create_image_grid(images, grid_size=None):
44 | assert images.ndim == 3 or images.ndim == 4
45 | num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]
46 |
47 | if grid_size is not None:
48 | grid_w, grid_h = tuple(grid_size)
49 | else:
50 | grid_w = max(int(np.ceil(np.sqrt(num))), 1)
51 | grid_h = max((num - 1) // grid_w + 1, 1)
52 |
53 | grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype)
54 | for idx in range(num):
55 | x = (idx % grid_w) * img_w
56 | y = (idx // grid_w) * img_h
57 | grid[..., y : y + img_h, x : x + img_w] = images[idx]
58 | return grid
59 |
60 | def convert_to_pil_image(image, drange=[0,1]):
61 | assert image.ndim == 2 or image.ndim == 3
62 | if image.ndim == 3:
63 | if image.shape[0] == 1:
64 | image = image[0] # grayscale CHW => HW
65 | else:
66 | image = image.transpose(1, 2, 0) # CHW -> HWC
67 |
68 | image = adjust_dynamic_range(image, drange, [0,255])
69 | image = np.rint(image).clip(0, 255).astype(np.uint8)
70 | fmt = 'RGB' if image.ndim == 3 else 'L'
71 | return PIL.Image.fromarray(image, fmt)
72 |
73 | def save_image_grid(images, filename, drange=[0,1], grid_size=None):
74 | convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)
75 |
76 | def apply_mirror_augment(minibatch):
77 | mask = np.random.rand(minibatch.shape[0]) < 0.5
78 | minibatch = np.array(minibatch)
79 | minibatch[mask] = minibatch[mask, :, :, ::-1]
80 | return minibatch
81 |
82 | #----------------------------------------------------------------------------
83 | # Loading data from previous training runs.
84 |
85 | def parse_config_for_previous_run(run_dir):
86 | with open(os.path.join(run_dir, 'submit_config.pkl'), 'rb') as f:
87 | data = pickle.load(f)
88 | data = data.get('run_func_kwargs', {})
89 | return dict(train=data, dataset=data.get('dataset_args', {}))
90 |
91 | #----------------------------------------------------------------------------
92 | # Size and contents of the image snapshot grids that are exported
93 | # periodically during training.
94 |
95 | def setup_snapshot_image_grid(training_set,
96 |     size    = '1080p',      # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display, '8k' = to be viewed on 8k display.
97 | layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
98 |
99 | # Select size.
100 | gw = 1; gh = 1
101 | if size == '1080p':
102 | gw = np.clip(1920 // training_set.shape[2], 3, 32)
103 | gh = np.clip(1080 // training_set.shape[1], 2, 32)
104 | if size == '4k':
105 | gw = np.clip(3840 // training_set.shape[2], 7, 32)
106 | gh = np.clip(2160 // training_set.shape[1], 4, 32)
107 | if size == '8k':
108 | gw = np.clip(7680 // training_set.shape[2], 7, 32)
109 | gh = np.clip(4320 // training_set.shape[1], 4, 32)
110 |
111 | # Initialize data arrays.
112 | reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)
113 | labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
114 |
115 | # Random layout.
116 | if layout == 'random':
117 | reals[:], labels[:] = training_set.get_minibatch_np(gw * gh)
118 |
119 | # Class-conditional layouts.
120 | class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4])
121 | if layout in class_layouts:
122 | bw, bh = class_layouts[layout]
123 | nw = (gw - 1) // bw + 1
124 | nh = (gh - 1) // bh + 1
125 | blocks = [[] for _i in range(nw * nh)]
126 | for _iter in range(1000000):
127 | real, label = training_set.get_minibatch_np(1)
128 | idx = np.argmax(label[0])
129 | while idx < len(blocks) and len(blocks[idx]) >= bw * bh:
130 | idx += training_set.label_size
131 | if idx < len(blocks):
132 | blocks[idx].append((real, label))
133 | if all(len(block) >= bw * bh for block in blocks):
134 | break
135 | for i, block in enumerate(blocks):
136 | for j, (real, label) in enumerate(block):
137 | x = (i % nw) * bw + j % bw
138 | y = (i // nw) * bh + j // bw
139 | if x < gw and y < gh:
140 | reals[x + y * gw] = real[0]
141 | labels[x + y * gw] = label[0]
142 |
143 | return (gw, gh), reals, labels
144 |
145 | #----------------------------------------------------------------------------
146 |
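A usage sketch for the image-grid helpers above (shapes and filename are assumptions):

    import numpy as np

    # 16 random CHW uint8 images arranged into a 4x4 grid and written to disk.
    images = np.random.randint(0, 256, size=[16, 3, 64, 64], dtype=np.uint8)
    save_image_grid(images, 'grid.png', drange=[0, 255], grid_size=(4, 4))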
--------------------------------------------------------------------------------