├── LISENCE.txt
├── README.md
├── dnnlib
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── util.cpython-37.pyc
│   ├── submission
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── run_context.cpython-37.pyc
│   │   │   └── submit.cpython-37.pyc
│   │   ├── _internal
│   │   │   └── run.py
│   │   ├── run_context.py
│   │   └── submit.py
│   ├── tflib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── autosummary.cpython-37.pyc
│   │   │   ├── network.cpython-37.pyc
│   │   │   ├── optimizer.cpython-37.pyc
│   │   │   └── tfutil.cpython-37.pyc
│   │   ├── autosummary.py
│   │   ├── network.py
│   │   ├── optimizer.py
│   │   └── tfutil.py
│   └── util.py
├── encoder
│   ├── generator_model.py
│   ├── model.py
│   ├── perceptual_model.py
│   └── resnet.py
├── input
│   ├── test1.jpg
│   ├── test2.jpeg
│   ├── test3.jpg
│   └── test4.jpeg
├── main.py
├── networks
│   └── download_weights.txt
├── pics
│   ├── architecture.png
│   ├── example_2kids.jpg
│   ├── example_2wanghong.png
│   ├── examples_mix.jpg
│   ├── multi-model-solution.png
│   ├── preview.jpg
│   ├── single_input.png
│   └── single_output.png
├── project_image.py
├── project_image_without_optimizer.py
└── tools
    ├── face_alignment.py
    ├── functions.py
    └── landmarks_detector.py
/LISENCE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, www.seeprettyface.com. All rights reserved.
2 |
3 |
4 | Attribution-NonCommercial 4.0 International
5 |
6 | =======================================================================
7 |
8 | Creative Commons Corporation ("Creative Commons") is not a law firm and
9 | does not provide legal services or legal advice. Distribution of
10 | Creative Commons public licenses does not create a lawyer-client or
11 | other relationship. Creative Commons makes its licenses and related
12 | information available on an "as-is" basis. Creative Commons gives no
13 | warranties regarding its licenses, any material licensed under their
14 | terms and conditions, or any related information. Creative Commons
15 | disclaims all liability for damages resulting from their use to the
16 | fullest extent possible.
17 |
18 | Using Creative Commons Public Licenses
19 |
20 | Creative Commons public licenses provide a standard set of terms and
21 | conditions that creators and other rights holders may use to share
22 | original works of authorship and other material subject to copyright
23 | and certain other rights specified in the public license below. The
24 | following considerations are for informational purposes only, are not
25 | exhaustive, and do not form part of our licenses.
26 |
27 | Considerations for licensors: Our public licenses are
28 | intended for use by those authorized to give the public
29 | permission to use material in ways otherwise restricted by
30 | copyright and certain other rights. Our licenses are
31 | irrevocable. Licensors should read and understand the terms
32 | and conditions of the license they choose before applying it.
33 | Licensors should also secure all rights necessary before
34 | applying our licenses so that the public can reuse the
35 | material as expected. Licensors should clearly mark any
36 | material not subject to the license. This includes other CC-
37 | licensed material, or material used under an exception or
38 | limitation to copyright. More considerations for licensors:
39 | wiki.creativecommons.org/Considerations_for_licensors
40 |
41 | Considerations for the public: By using one of our public
42 | licenses, a licensor grants the public permission to use the
43 | licensed material under specified terms and conditions. If
44 | the licensor's permission is not necessary for any reason--for
45 | example, because of any applicable exception or limitation to
46 | copyright--then that use is not regulated by the license. Our
47 | licenses grant only permissions under copyright and certain
48 | other rights that a licensor has authority to grant. Use of
49 | the licensed material may still be restricted for other
50 | reasons, including because others have copyright or other
51 | rights in the material. A licensor may make special requests,
52 | such as asking that all changes be marked or described.
53 | Although not required by our licenses, you are encouraged to
54 | respect those requests where reasonable. More_considerations
55 | for the public:
56 | wiki.creativecommons.org/Considerations_for_licensees
57 |
58 | =======================================================================
59 |
60 | Creative Commons Attribution-NonCommercial 4.0 International Public
61 | License
62 |
63 | By exercising the Licensed Rights (defined below), You accept and agree
64 | to be bound by the terms and conditions of this Creative Commons
65 | Attribution-NonCommercial 4.0 International Public License ("Public
66 | License"). To the extent this Public License may be interpreted as a
67 | contract, You are granted the Licensed Rights in consideration of Your
68 | acceptance of these terms and conditions, and the Licensor grants You
69 | such rights in consideration of benefits the Licensor receives from
70 | making the Licensed Material available under these terms and
71 | conditions.
72 |
73 |
74 | Section 1 -- Definitions.
75 |
76 | a. Adapted Material means material subject to Copyright and Similar
77 | Rights that is derived from or based upon the Licensed Material
78 | and in which the Licensed Material is translated, altered,
79 | arranged, transformed, or otherwise modified in a manner requiring
80 | permission under the Copyright and Similar Rights held by the
81 | Licensor. For purposes of this Public License, where the Licensed
82 | Material is a musical work, performance, or sound recording,
83 | Adapted Material is always produced where the Licensed Material is
84 | synched in timed relation with a moving image.
85 |
86 | b. Adapter's License means the license You apply to Your Copyright
87 | and Similar Rights in Your contributions to Adapted Material in
88 | accordance with the terms and conditions of this Public License.
89 |
90 | c. Copyright and Similar Rights means copyright and/or similar rights
91 | closely related to copyright including, without limitation,
92 | performance, broadcast, sound recording, and Sui Generis Database
93 | Rights, without regard to how the rights are labeled or
94 | categorized. For purposes of this Public License, the rights
95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
96 | Rights.
97 | d. Effective Technological Measures means those measures that, in the
98 | absence of proper authority, may not be circumvented under laws
99 | fulfilling obligations under Article 11 of the WIPO Copyright
100 | Treaty adopted on December 20, 1996, and/or similar international
101 | agreements.
102 |
103 | e. Exceptions and Limitations means fair use, fair dealing, and/or
104 | any other exception or limitation to Copyright and Similar Rights
105 | that applies to Your use of the Licensed Material.
106 |
107 | f. Licensed Material means the artistic or literary work, database,
108 | or other material to which the Licensor applied this Public
109 | License.
110 |
111 | g. Licensed Rights means the rights granted to You subject to the
112 | terms and conditions of this Public License, which are limited to
113 | all Copyright and Similar Rights that apply to Your use of the
114 | Licensed Material and that the Licensor has authority to license.
115 |
116 | h. Licensor means the individual(s) or entity(ies) granting rights
117 | under this Public License.
118 |
119 | i. NonCommercial means not primarily intended for or directed towards
120 | commercial advantage or monetary compensation. For purposes of
121 | this Public License, the exchange of the Licensed Material for
122 | other material subject to Copyright and Similar Rights by digital
123 | file-sharing or similar means is NonCommercial provided there is
124 | no payment of monetary compensation in connection with the
125 | exchange.
126 |
127 | j. Share means to provide material to the public by any means or
128 | process that requires permission under the Licensed Rights, such
129 | as reproduction, public display, public performance, distribution,
130 | dissemination, communication, or importation, and to make material
131 | available to the public including in ways that members of the
132 | public may access the material from a place and at a time
133 | individually chosen by them.
134 |
135 | k. Sui Generis Database Rights means rights other than copyright
136 | resulting from Directive 96/9/EC of the European Parliament and of
137 | the Council of 11 March 1996 on the legal protection of databases,
138 | as amended and/or succeeded, as well as other essentially
139 | equivalent rights anywhere in the world.
140 |
141 | l. You means the individual or entity exercising the Licensed Rights
142 | under this Public License. Your has a corresponding meaning.
143 |
144 |
145 | Section 2 -- Scope.
146 |
147 | a. License grant.
148 |
149 | 1. Subject to the terms and conditions of this Public License,
150 | the Licensor hereby grants You a worldwide, royalty-free,
151 | non-sublicensable, non-exclusive, irrevocable license to
152 | exercise the Licensed Rights in the Licensed Material to:
153 |
154 | a. reproduce and Share the Licensed Material, in whole or
155 | in part, for NonCommercial purposes only; and
156 |
157 | b. produce, reproduce, and Share Adapted Material for
158 | NonCommercial purposes only.
159 |
160 | 2. Exceptions and Limitations. For the avoidance of doubt, where
161 | Exceptions and Limitations apply to Your use, this Public
162 | License does not apply, and You do not need to comply with
163 | its terms and conditions.
164 |
165 | 3. Term. The term of this Public License is specified in Section
166 | 6(a).
167 |
168 | 4. Media and formats; technical modifications allowed. The
169 | Licensor authorizes You to exercise the Licensed Rights in
170 | all media and formats whether now known or hereafter created,
171 | and to make technical modifications necessary to do so. The
172 | Licensor waives and/or agrees not to assert any right or
173 | authority to forbid You from making technical modifications
174 | necessary to exercise the Licensed Rights, including
175 | technical modifications necessary to circumvent Effective
176 | Technological Measures. For purposes of this Public License,
177 | simply making modifications authorized by this Section 2(a)
178 | (4) never produces Adapted Material.
179 |
180 | 5. Downstream recipients.
181 |
182 | a. Offer from the Licensor -- Licensed Material. Every
183 | recipient of the Licensed Material automatically
184 | receives an offer from the Licensor to exercise the
185 | Licensed Rights under the terms and conditions of this
186 | Public License.
187 |
188 | b. No downstream restrictions. You may not offer or impose
189 | any additional or different terms or conditions on, or
190 | apply any Effective Technological Measures to, the
191 | Licensed Material if doing so restricts exercise of the
192 | Licensed Rights by any recipient of the Licensed
193 | Material.
194 |
195 | 6. No endorsement. Nothing in this Public License constitutes or
196 | may be construed as permission to assert or imply that You
197 | are, or that Your use of the Licensed Material is, connected
198 | with, or sponsored, endorsed, or granted official status by,
199 | the Licensor or others designated to receive attribution as
200 | provided in Section 3(a)(1)(A)(i).
201 |
202 | b. Other rights.
203 |
204 | 1. Moral rights, such as the right of integrity, are not
205 | licensed under this Public License, nor are publicity,
206 | privacy, and/or other similar personality rights; however, to
207 | the extent possible, the Licensor waives and/or agrees not to
208 | assert any such rights held by the Licensor to the limited
209 | extent necessary to allow You to exercise the Licensed
210 | Rights, but not otherwise.
211 |
212 | 2. Patent and trademark rights are not licensed under this
213 | Public License.
214 |
215 | 3. To the extent possible, the Licensor waives any right to
216 | collect royalties from You for the exercise of the Licensed
217 | Rights, whether directly or through a collecting society
218 | under any voluntary or waivable statutory or compulsory
219 | licensing scheme. In all other cases the Licensor expressly
220 | reserves any right to collect such royalties, including when
221 | the Licensed Material is used other than for NonCommercial
222 | purposes.
223 |
224 |
225 | Section 3 -- License Conditions.
226 |
227 | Your exercise of the Licensed Rights is expressly made subject to the
228 | following conditions.
229 |
230 | a. Attribution.
231 |
232 | 1. If You Share the Licensed Material (including in modified
233 | form), You must:
234 |
235 | a. retain the following if it is supplied by the Licensor
236 | with the Licensed Material:
237 |
238 | i. identification of the creator(s) of the Licensed
239 | Material and any others designated to receive
240 | attribution, in any reasonable manner requested by
241 | the Licensor (including by pseudonym if
242 | designated);
243 |
244 | ii. a copyright notice;
245 |
246 | iii. a notice that refers to this Public License;
247 |
248 | iv. a notice that refers to the disclaimer of
249 | warranties;
250 |
251 | v. a URI or hyperlink to the Licensed Material to the
252 | extent reasonably practicable;
253 |
254 | b. indicate if You modified the Licensed Material and
255 | retain an indication of any previous modifications; and
256 |
257 | c. indicate the Licensed Material is licensed under this
258 | Public License, and include the text of, or the URI or
259 | hyperlink to, this Public License.
260 |
261 | 2. You may satisfy the conditions in Section 3(a)(1) in any
262 | reasonable manner based on the medium, means, and context in
263 | which You Share the Licensed Material. For example, it may be
264 | reasonable to satisfy the conditions by providing a URI or
265 | hyperlink to a resource that includes the required
266 | information.
267 |
268 | 3. If requested by the Licensor, You must remove any of the
269 | information required by Section 3(a)(1)(A) to the extent
270 | reasonably practicable.
271 |
272 | 4. If You Share Adapted Material You produce, the Adapter's
273 | License You apply must not prevent recipients of the Adapted
274 | Material from complying with this Public License.
275 |
276 |
277 | Section 4 -- Sui Generis Database Rights.
278 |
279 | Where the Licensed Rights include Sui Generis Database Rights that
280 | apply to Your use of the Licensed Material:
281 |
282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
283 | to extract, reuse, reproduce, and Share all or a substantial
284 | portion of the contents of the database for NonCommercial purposes
285 | only;
286 |
287 | b. if You include all or a substantial portion of the database
288 | contents in a database in which You have Sui Generis Database
289 | Rights, then the database in which You have Sui Generis Database
290 | Rights (but not its individual contents) is Adapted Material; and
291 |
292 | c. You must comply with the conditions in Section 3(a) if You Share
293 | all or a substantial portion of the contents of the database.
294 |
295 | For the avoidance of doubt, this Section 4 supplements and does not
296 | replace Your obligations under this Public License where the Licensed
297 | Rights include other Copyright and Similar Rights.
298 |
299 |
300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
301 |
302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
312 |
313 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
322 |
323 | c. The disclaimer of warranties and limitation of liability provided
324 | above shall be interpreted in a manner that, to the extent
325 | possible, most closely approximates an absolute disclaimer and
326 | waiver of all liability.
327 |
328 |
329 | Section 6 -- Term and Termination.
330 |
331 | a. This Public License applies for the term of the Copyright and
332 | Similar Rights licensed here. However, if You fail to comply with
333 | this Public License, then Your rights under this Public License
334 | terminate automatically.
335 |
336 | b. Where Your right to use the Licensed Material has terminated under
337 | Section 6(a), it reinstates:
338 |
339 | 1. automatically as of the date the violation is cured, provided
340 | it is cured within 30 days of Your discovery of the
341 | violation; or
342 |
343 | 2. upon express reinstatement by the Licensor.
344 |
345 | For the avoidance of doubt, this Section 6(b) does not affect any
346 | right the Licensor may have to seek remedies for Your violations
347 | of this Public License.
348 |
349 | c. For the avoidance of doubt, the Licensor may also offer the
350 | Licensed Material under separate terms or conditions or stop
351 | distributing the Licensed Material at any time; however, doing so
352 | will not terminate this Public License.
353 |
354 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
355 | License.
356 |
357 |
358 | Section 7 -- Other Terms and Conditions.
359 |
360 | a. The Licensor shall not be bound by any additional or different
361 | terms or conditions communicated by You unless expressly agreed.
362 |
363 | b. Any arrangements, understandings, or agreements regarding the
364 | Licensed Material not stated herein are separate from and
365 | independent of the terms and conditions of this Public License.
366 |
367 |
368 | Section 8 -- Interpretation.
369 |
370 | a. For the avoidance of doubt, this Public License does not, and
371 | shall not be interpreted to, reduce, limit, restrict, or impose
372 | conditions on any use of the Licensed Material that could lawfully
373 | be made without permission under this Public License.
374 |
375 | b. To the extent possible, if any provision of this Public License is
376 | deemed unenforceable, it shall be automatically reformed to the
377 | minimum extent necessary to make it enforceable. If the provision
378 | cannot be reformed, it shall be severed from this Public License
379 | without affecting the enforceability of the remaining terms and
380 | conditions.
381 |
382 | c. No term or condition of this Public License will be waived and no
383 | failure to comply consented to unless expressly agreed to by the
384 | Licensor.
385 |
386 | d. Nothing in this Public License constitutes or may be interpreted
387 | as a limitation upon, or waiver of, any privileges and immunities
388 | that apply to the Licensor or You, including from the legal
389 | processes of any jurisdiction or authority.
390 |
391 | =======================================================================
392 |
393 | Creative Commons is not a party to its public
394 | licenses. Notwithstanding, Creative Commons may elect to apply one of
395 | its public licenses to material it publishes and in those instances
396 | will be considered the "Licensor." The text of the Creative Commons
397 | public licenses is dedicated to the public domain under the CC0 Public
398 | Domain Dedication. Except for the limited purpose of indicating that
399 | material is shared under a Creative Commons public license or as
400 | otherwise permitted by the Creative Commons policies published at
401 | creativecommons.org/policies, Creative Commons does not authorize the
402 | use of the trademark "Creative Commons" or any other trademark or logo
403 | of Creative Commons without its prior written consent including,
404 | without limitation, in connection with any unauthorized modifications
405 | to any of its public licenses or any other arrangements,
406 | understandings, or agreements concerning use of licensed material. For
407 | the avoidance of doubt, this paragraph does not form part of the
408 | public licenses.
409 |
410 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Better model, better performance
2 | Model-Swap-Face_v2 has been released; feel free to use it as a reference.
3 |
4 | Note: This project documents some of my exploration into digital models, aiming to tap the practical commercial value of generative techniques by cutting costs and improving efficiency. It provides only an end-to-end, single-model head-synthesis scheme: it preserves the expression of the input model while generating a new, more stylistically appealing model. If you want a scheme that supports choosing among multiple model identities, please see my research notes.
5 |
6 | # Preview of Results
7 | ## Single-image input and output
8 |
9 |
10 |
11 | Input
12 |
13 |
14 |
15 | Model-style output
16 |
17 | ## Multi-image comparison
18 |
19 |
20 |
21 | Preview of multiple conversion results
22 |
23 | ## Replacement results
24 | The examples below show the generated image pasted back into the original photo; this step introduces additional post-processing.
25 |
26 |
27 |
28 | Child style (left: input, right: output)
29 |
30 |
31 |
32 | Internet-celebrity ("wanghong") style (left: input, right: output)
33 |
34 |
35 |
36 | Multiple styles (row 1: input, rows 2-5: output)
37 |
38 |
39 | # Inference framework
40 |
41 |
42 |
43 |
44 |
45 | # Usage
46 |
47 | ## Requirements
48 | * Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
49 | * 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
50 | * TensorFlow 1.10.0 or newer with GPU support.
51 | * One or more high-end NVIDIA GPUs with at least 11GB of DRAM. We recommend NVIDIA DGX-1 with 8 Tesla V100 GPUs.
52 | * NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.
53 |
54 | ## Running
55 | 1. Download the model files into the ```networks``` folder as described in ```networks/download_weights.txt```.
56 | 2. Configure ```main.py``` and run ```python main.py```.
57 |
58 |
59 | # Multi-model selection scheme
60 |
61 |
62 |
63 | The multi-model selection scheme supports a more varied choice of models; for the implementation, please refer to my research notes.
64 |
65 |
66 |
67 | # Acknowledgements
68 | Parts of the code are borrowed from Puzer and Pbaylies; thanks for sharing.
69 |
70 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import submission
9 |
10 | from .submission.run_context import RunContext
11 |
12 | from .submission.submit import SubmitTarget
13 | from .submission.submit import PathType
14 | from .submission.submit import SubmitConfig
15 | from .submission.submit import get_path_from_template
16 | from .submission.submit import submit_run
17 |
18 | from .util import EasyDict
19 |
20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
21 |
--------------------------------------------------------------------------------
/dnnlib/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import run_context
9 | from . import submit
10 |
--------------------------------------------------------------------------------
/dnnlib/submission/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/submission/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/__pycache__/run_context.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/submission/__pycache__/run_context.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/__pycache__/submit.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/submission/__pycache__/submit.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/submission/_internal/run.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for launching run functions in computing clusters.
9 |
10 | During the submit process, this file is copied to the appropriate run dir.
11 | When the job is launched in the cluster, this module is the first thing that
12 | is run inside the docker container.
13 | """
14 |
15 | import os
16 | import pickle
17 | import sys
18 |
19 | # PYTHONPATH should have been set so that the run_dir/src is in it
20 | import dnnlib
21 |
22 | def main():
23 | if not len(sys.argv) >= 4:
24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")
25 |
26 | run_dir = str(sys.argv[1])
27 | task_name = str(sys.argv[2])
28 | host_name = str(sys.argv[3])
29 |
30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl")
31 |
32 | # SubmitConfig should have been pickled to the run dir
33 | if not os.path.exists(submit_config_path):
34 | raise RuntimeError("SubmitConfig pickle file does not exist!")
35 |
36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb"))
37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name)
38 |
39 | submit_config.task_name = task_name
40 | submit_config.host_name = host_name
41 |
42 | dnnlib.submission.submit.run_wrapper(submit_config)
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
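
For orientation, a minimal, hypothetical sketch (not part of the repo) of how this launcher ends up being invoked once submit has copied `run.py` and `submit_config.pkl` into a run dir; the run dir path, task name, and host name below are placeholders matching the three arguments parsed by `main()` above.

```python
# Hypothetical sketch: launching the copied run.py from inside its run dir.
# The three positional arguments mirror main(): run_dir, task_name, host_name.
import subprocess

run_dir = "results/00000-example"  # placeholder run dir created by submit
subprocess.check_call(
    ["python", "run.py", run_dir, "example-task", "localhost"],
    cwd=run_dir,
)
```
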
/dnnlib/submission/run_context.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helpers for managing the run/training loop."""
9 |
10 | import datetime
11 | import json
12 | import os
13 | import pprint
14 | import time
15 | import types
16 |
17 | from typing import Any
18 |
19 | from . import submit
20 |
21 |
22 | class RunContext(object):
23 | """Helper class for managing the run/training loop.
24 |
25 | The context will hide the implementation details of a basic run/training loop.
26 | It will set things up properly, tell if the run should be stopped, and then clean up.
27 | The user should call update periodically and use should_stop to determine if the run should be stopped.
28 |
29 | Args:
30 | submit_config: The SubmitConfig that is used for the current run.
31 | config_module: The whole config module that is used for the current run.
32 | max_epoch: Optional cached value for the max_epoch variable used in update.
33 | """
34 |
35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
36 | self.submit_config = submit_config
37 | self.should_stop_flag = False
38 | self.has_closed = False
39 | self.start_time = time.time()
40 | self.last_update_time = time.time()
41 | self.last_update_interval = 0.0
42 | self.max_epoch = max_epoch
43 |
44 | # pretty print the all the relevant content of the config module to a text file
45 | if config_module is not None:
46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)
49 |
50 | # write out details about the run to a text file
51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
54 |
55 | def __enter__(self) -> "RunContext":
56 | return self
57 |
58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
59 | self.close()
60 |
61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
62 | """Do general housekeeping and keep the state of the context up-to-date.
63 | Should be called often enough but not in a tight loop."""
64 | assert not self.has_closed
65 |
66 | self.last_update_interval = time.time() - self.last_update_time
67 | self.last_update_time = time.time()
68 |
69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
70 | self.should_stop_flag = True
71 |
72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch
73 |
74 | def should_stop(self) -> bool:
75 | """Tell whether a stopping condition has been triggered one way or another."""
76 | return self.should_stop_flag
77 |
78 | def get_time_since_start(self) -> float:
79 | """How much time has passed since the creation of the context."""
80 | return time.time() - self.start_time
81 |
82 | def get_time_since_last_update(self) -> float:
83 | """How much time has passed since the last call to update."""
84 | return time.time() - self.last_update_time
85 |
86 | def get_last_update_interval(self) -> float:
87 | """How much time passed between the previous two calls to update."""
88 | return self.last_update_interval
89 |
90 | def close(self) -> None:
91 | """Close the context and clean up.
92 | Should only be called once."""
93 | if not self.has_closed:
94 | # update the run.txt with stopping time
95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
98 |
99 | self.has_closed = True
100 |
--------------------------------------------------------------------------------
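
A minimal usage sketch for `RunContext` (illustrative only, not part of the repo): it exercises just the methods defined above, and the run dir and epoch count are hypothetical.

```python
# Illustrative sketch: driving a training loop with RunContext.
# Assumes a SubmitConfig whose run_dir already exists (normally created by submit_run).
import os
import dnnlib

submit_config = dnnlib.SubmitConfig()
submit_config.run_dir = "results/00000-example"   # hypothetical run dir
submit_config.task_name = "example-task"
os.makedirs(submit_config.run_dir, exist_ok=True)

with dnnlib.RunContext(submit_config, max_epoch=10) as ctx:
    for epoch in range(10):
        if ctx.should_stop():          # becomes True once abort.txt appears in run_dir
            break
        loss = 0.0                     # placeholder for a real training step
        ctx.update(loss=loss, cur_epoch=epoch, max_epoch=10)
```
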
/dnnlib/submission/submit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Submit a function to be run either locally or in a computing cluster."""
9 |
10 | import copy
11 | import io
12 | import os
13 | import pathlib
14 | import pickle
15 | import platform
16 | import pprint
17 | import re
18 | import shutil
19 | import time
20 | import traceback
21 |
22 | import zipfile
23 |
24 | from enum import Enum
25 |
26 | from .. import util
27 | from ..util import EasyDict
28 |
29 |
30 | class SubmitTarget(Enum):
31 | """The target where the function should be run.
32 |
33 | LOCAL: Run it locally.
34 | """
35 | LOCAL = 1
36 |
37 |
38 | class PathType(Enum):
39 | """Determines in which format should a path be formatted.
40 |
41 | WINDOWS: Format with Windows style.
42 | LINUX: Format with Linux/Posix style.
43 | AUTO: Use current OS type to select either WINDOWS or LINUX.
44 | """
45 | WINDOWS = 1
46 | LINUX = 2
47 | AUTO = 3
48 |
49 |
50 | _user_name_override = None
51 |
52 |
53 | class SubmitConfig(util.EasyDict):
54 | """Strongly typed config dict needed to submit runs.
55 |
56 | Attributes:
57 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
58 | run_desc: Description of the run. Will be used in the run dir and task name.
59 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
60 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
61 | submit_target: Submit target enum value. Used to select where the run is actually launched.
62 | num_gpus: Number of GPUs used/requested for the run.
63 | print_info: Whether to print debug information when submitting.
64 | ask_confirmation: Whether to ask a confirmation before submitting.
65 | run_id: Automatically populated value during submit.
66 | run_name: Automatically populated value during submit.
67 | run_dir: Automatically populated value during submit.
68 | run_func_name: Automatically populated value during submit.
69 | run_func_kwargs: Automatically populated value during submit.
70 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
71 | task_name: Automatically populated value during submit.
72 | host_name: Automatically populated value during submit.
73 | """
74 |
75 | def __init__(self):
76 | super().__init__()
77 |
78 | # run (set these)
79 | self.run_dir_root = "" # should always be passed through get_path_from_template
80 | self.run_desc = ""
81 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
82 | self.run_dir_extra_files = None
83 |
84 | # submit (set these)
85 | self.submit_target = SubmitTarget.LOCAL
86 | self.num_gpus = 1
87 | self.print_info = False
88 | self.ask_confirmation = False
89 |
90 | # (automatically populated)
91 | self.run_id = None
92 | self.run_name = None
93 | self.run_dir = None
94 | self.run_func_name = None
95 | self.run_func_kwargs = None
96 | self.user_name = None
97 | self.task_name = None
98 | self.host_name = "localhost"
99 |
100 |
101 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
102 | """Replace tags in the given path template and return either Windows or Linux formatted path."""
103 | # automatically select path type depending on running OS
104 | if path_type == PathType.AUTO:
105 | if platform.system() == "Windows":
106 | path_type = PathType.WINDOWS
107 | elif platform.system() == "Linux":
108 | path_type = PathType.LINUX
109 | else:
110 | raise RuntimeError("Unknown platform")
111 |
112 | path_template = path_template.replace("<USERNAME>", get_user_name())
113 |
114 | # return correctly formatted path
115 | if path_type == PathType.WINDOWS:
116 | return str(pathlib.PureWindowsPath(path_template))
117 | elif path_type == PathType.LINUX:
118 | return str(pathlib.PurePosixPath(path_template))
119 | else:
120 | raise RuntimeError("Unknown platform")
121 |
122 |
123 | def get_template_from_path(path: str) -> str:
124 | """Convert a normal path back to its template representation."""
125 | # replace all path parts with the template tags
126 | path = path.replace("\\", "/")
127 | return path
128 |
129 |
130 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
131 | """Convert a normal path to template and the convert it back to a normal path with given path type."""
132 | path_template = get_template_from_path(path)
133 | path = get_path_from_template(path_template, path_type)
134 | return path
135 |
136 |
137 | def set_user_name_override(name: str) -> None:
138 | """Set the global username override value."""
139 | global _user_name_override
140 | _user_name_override = name
141 |
142 |
143 | def get_user_name():
144 | """Get the current user name."""
145 | if _user_name_override is not None:
146 | return _user_name_override
147 | elif platform.system() == "Windows":
148 | return os.getlogin()
149 | elif platform.system() == "Linux":
150 | try:
151 | import pwd # pylint: disable=import-error
152 | return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member
153 | except:
154 | return "unknown"
155 | else:
156 | raise RuntimeError("Unknown platform")
157 |
158 |
159 | def _create_run_dir_local(submit_config: SubmitConfig) -> str:
160 | """Create a new run dir with increasing ID number at the start."""
161 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)
162 |
163 | if not os.path.exists(run_dir_root):
164 | print("Creating the run dir root: {}".format(run_dir_root))
165 | os.makedirs(run_dir_root)
166 |
167 | submit_config.run_id = _get_next_run_id_local(run_dir_root)
168 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
169 | run_dir = os.path.join(run_dir_root, submit_config.run_name)
170 |
171 | if os.path.exists(run_dir):
172 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir))
173 |
174 | print("Creating the run dir: {}".format(run_dir))
175 | os.makedirs(run_dir)
176 |
177 | return run_dir
178 |
179 |
180 | def _get_next_run_id_local(run_dir_root: str) -> int:
181 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
182 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
183 | r = re.compile("^\\d+") # match one or more digits at the start of the string
184 | run_id = 0
185 |
186 | for dir_name in dir_names:
187 | m = r.match(dir_name)
188 |
189 | if m is not None:
190 | i = int(m.group())
191 | run_id = max(run_id, i + 1)
192 |
193 | return run_id
194 |
195 |
196 | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
197 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
198 | print("Copying files to the run dir")
199 | files = []
200 |
201 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
202 | assert '.' in submit_config.run_func_name
203 | for _idx in range(submit_config.run_func_name.count('.') - 1):
204 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
205 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)
206 |
207 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
208 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)
209 |
210 | if submit_config.run_dir_extra_files is not None:
211 | files += submit_config.run_dir_extra_files
212 |
213 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
214 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]
215 |
216 | util.copy_files_and_create_dirs(files)
217 |
218 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb"))
219 |
220 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
221 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
222 |
223 |
224 | def run_wrapper(submit_config: SubmitConfig) -> None:
225 | """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
226 | is_local = submit_config.submit_target == SubmitTarget.LOCAL
227 |
228 | checker = None
229 |
230 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
231 | if is_local:
232 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
233 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
234 | logger = util.Logger(file_name=None, should_flush=True)
235 |
236 | import dnnlib
237 | dnnlib.submit_config = submit_config
238 |
239 | try:
240 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
241 | start_time = time.time()
242 | util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
243 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
244 | except:
245 | if is_local:
246 | raise
247 | else:
248 | traceback.print_exc()
249 |
250 | log_src = os.path.join(submit_config.run_dir, "log.txt")
251 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
252 | shutil.copyfile(log_src, log_dst)
253 | finally:
254 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()
255 |
256 | dnnlib.submit_config = None
257 | logger.close()
258 |
259 | if checker is not None:
260 | checker.stop()
261 |
262 |
263 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
264 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
265 | submit_config = copy.copy(submit_config)
266 |
267 | if submit_config.user_name is None:
268 | submit_config.user_name = get_user_name()
269 |
270 | submit_config.run_func_name = run_func_name
271 | submit_config.run_func_kwargs = run_func_kwargs
272 |
273 | assert submit_config.submit_target == SubmitTarget.LOCAL
274 | if submit_config.submit_target in {SubmitTarget.LOCAL}:
275 | run_dir = _create_run_dir_local(submit_config)
276 |
277 | submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
278 | submit_config.run_dir = run_dir
279 | _populate_run_dir(run_dir, submit_config)
280 |
281 | if submit_config.print_info:
282 | print("\nSubmit config:\n")
283 | pprint.pprint(submit_config, indent=4, width=200, compact=False)
284 | print()
285 |
286 | if submit_config.ask_confirmation:
287 | if not util.ask_yes_no("Continue submitting the job?"):
288 | return
289 |
290 | run_wrapper(submit_config)
291 |
--------------------------------------------------------------------------------
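
A hedged usage sketch for `submit_run` (not part of the repo): the module path `myproject.train.run_training` and its keyword argument are hypothetical placeholders; everything else follows the `SubmitConfig` fields and the local submit path shown above.

```python
# Illustrative sketch: submitting a run function locally.
# submit_run copies the sources into results/NNNNN-example and then calls the
# function named by run_func_name with submit_config plus the extra kwargs.
import dnnlib

submit_config = dnnlib.SubmitConfig()
submit_config.run_dir_root = "results"                  # run dirs are created under here
submit_config.run_desc = "example"                      # becomes part of the run dir name
submit_config.submit_target = dnnlib.SubmitTarget.LOCAL
submit_config.num_gpus = 1

# "myproject.train.run_training" and total_kimg are hypothetical placeholders.
dnnlib.submit_run(submit_config, "myproject.train.run_training", total_kimg=25000)
```
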
/dnnlib/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import autosummary
9 | from . import network
10 | from . import optimizer
11 | from . import tfutil
12 |
13 | from .tfutil import *
14 | from .network import Network
15 |
16 | from .optimizer import Optimizer
17 |
--------------------------------------------------------------------------------
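
As a rough sketch of how the `Network` objects re-exported here are typically restored (illustrative only): this assumes `tfutil` exposes `init_tf()` as in the NVIDIA StyleGAN reference code, and that the pickle under `networks/` follows the usual `(G, D, Gs)` convention; the file name is a placeholder for whatever `networks/download_weights.txt` points to.

```python
# Illustrative sketch, with the assumptions noted above: restoring a pickled Network.
import pickle
import dnnlib.tflib as tflib

tflib.init_tf()                                  # assumed helper from tfutil (StyleGAN-style)
with open("networks/model.pkl", "rb") as f:      # placeholder file name
    G, D, Gs = pickle.load(f)                    # Gs: dnnlib.tflib.Network (generator)
```
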
/dnnlib/tflib/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/network.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc
--------------------------------------------------------------------------------
/dnnlib/tflib/autosummary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for adding automatically tracked values to Tensorboard.
9 |
10 | Autosummary creates an identity op that internally keeps track of the input
11 | values and automatically shows up in TensorBoard. The reported value
12 | represents an average over input components. The average is accumulated
13 | constantly over time and flushed when save_summaries() is called.
14 |
15 | Notes:
16 | - The output tensor must be used as an input for something else in the
17 | graph. Otherwise, the autosummary op will not get executed, and the average
18 | value will not get accumulated.
19 | - It is perfectly fine to include autosummaries with the same name in
20 | several places throughout the graph, even if they are executed concurrently.
21 | - It is ok to also pass in a python scalar or numpy array. In this case, it
22 | is added to the average immediately.
23 | """
24 |
25 | from collections import OrderedDict
26 | import numpy as np
27 | import tensorflow as tf
28 | from tensorboard import summary as summary_lib
29 | from tensorboard.plugins.custom_scalar import layout_pb2
30 |
31 | from . import tfutil
32 | from .tfutil import TfExpression
33 | from .tfutil import TfExpressionEx
34 |
35 | _dtype = tf.float64
36 | _vars = OrderedDict() # name => [var, ...]
37 | _immediate = OrderedDict() # name => update_op, update_value
38 | _finalized = False
39 | _merge_op = None
40 |
41 |
42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
43 | """Internal helper for creating autosummary accumulators."""
44 | assert not _finalized
45 | name_id = name.replace("/", "_")
46 | v = tf.cast(value_expr, _dtype)
47 |
48 | if v.shape.is_fully_defined():
49 | size = np.prod(tfutil.shape_to_list(v.shape))
50 | size_expr = tf.constant(size, dtype=_dtype)
51 | else:
52 | size = None
53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
54 |
55 | if size == 1:
56 | if v.shape.ndims != 0:
57 | v = tf.reshape(v, [])
58 | v = [size_expr, v, tf.square(v)]
59 | else:
60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
62 |
63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
66 |
67 | if name in _vars:
68 | _vars[name].append(var)
69 | else:
70 | _vars[name] = [var]
71 | return update_op
72 |
73 |
74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
75 | """Create a new autosummary.
76 |
77 | Args:
78 | name: Name to use in TensorBoard
79 | value: TensorFlow expression or python value to track
80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
81 |
82 | Example use of the passthru mechanism:
83 |
84 | n = autosummary('l2loss', loss, passthru=n)
85 |
86 | This is a shorthand for the following code:
87 |
88 | with tf.control_dependencies([autosummary('l2loss', loss)]):
89 | n = tf.identity(n)
90 | """
91 | tfutil.assert_tf_initialized()
92 | name_id = name.replace("/", "_")
93 |
94 | if tfutil.is_tf_expression(value):
95 | with tf.name_scope("summary_" + name_id), tf.device(value.device):
96 | update_op = _create_var(name, value)
97 | with tf.control_dependencies([update_op]):
98 | return tf.identity(value if passthru is None else passthru)
99 |
100 | else: # python scalar or numpy array
101 | if name not in _immediate:
102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
103 | update_value = tf.placeholder(_dtype)
104 | update_op = _create_var(name, update_value)
105 | _immediate[name] = update_op, update_value
106 |
107 | update_op, update_value = _immediate[name]
108 | tfutil.run(update_op, {update_value: value})
109 | return value if passthru is None else passthru
110 |
111 |
112 | def finalize_autosummaries() -> None:
113 | """Create the necessary ops to include autosummaries in TensorBoard report.
114 | Note: This should be done only once per graph.
115 | """
116 | global _finalized
117 | tfutil.assert_tf_initialized()
118 |
119 | if _finalized:
120 | return None
121 |
122 | _finalized = True
123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
124 |
125 | # Create summary ops.
126 | with tf.device(None), tf.control_dependencies(None):
127 | for name, vars_list in _vars.items():
128 | name_id = name.replace("/", "_")
129 | with tfutil.absolute_name_scope("Autosummary/" + name_id):
130 | moments = tf.add_n(vars_list)
131 | moments /= moments[0]
132 | with tf.control_dependencies([moments]): # read before resetting
133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
135 | mean = moments[1]
136 | std = tf.sqrt(moments[2] - tf.square(moments[1]))
137 | tf.summary.scalar(name, mean)
138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
140 |
141 | # Group by category and chart name.
142 | cat_dict = OrderedDict()
143 | for series_name in sorted(_vars.keys()):
144 | p = series_name.split("/")
145 | cat = p[0] if len(p) >= 2 else ""
146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
147 | if cat not in cat_dict:
148 | cat_dict[cat] = OrderedDict()
149 | if chart not in cat_dict[cat]:
150 | cat_dict[cat][chart] = []
151 | cat_dict[cat][chart].append(series_name)
152 |
153 | # Setup custom_scalar layout.
154 | categories = []
155 | for cat_name, chart_dict in cat_dict.items():
156 | charts = []
157 | for chart_name, series_names in chart_dict.items():
158 | series = []
159 | for series_name in series_names:
160 | series.append(layout_pb2.MarginChartContent.Series(
161 | value=series_name,
162 | lower="xCustomScalars/" + series_name + "/margin_lo",
163 | upper="xCustomScalars/" + series_name + "/margin_hi"))
164 | margin = layout_pb2.MarginChartContent(series=series)
165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts))
167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
168 | return layout
169 |
170 | def save_summaries(file_writer, global_step=None):
171 | """Call FileWriter.add_summary() with all summaries in the default graph,
172 | automatically finalizing and merging them on the first call.
173 | """
174 | global _merge_op
175 | tfutil.assert_tf_initialized()
176 |
177 | if _merge_op is None:
178 | layout = finalize_autosummaries()
179 | if layout is not None:
180 | file_writer.add_summary(layout)
181 | with tf.device(None), tf.control_dependencies(None):
182 | _merge_op = tf.summary.merge_all()
183 |
184 | file_writer.add_summary(_merge_op.eval(), global_step)
185 |
--------------------------------------------------------------------------------
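
A small, hypothetical end-to-end example of the autosummary flow described in the module docstring above (TF 1.x graph mode); it assumes `tfutil` provides `init_tf()` and `run()` as used elsewhere in tflib, and the summary name and log directory are placeholders.

```python
# Illustrative sketch: track a value with autosummary and flush it to TensorBoard.
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib.autosummary import autosummary, save_summaries

tflib.init_tf()                                   # assumed tfutil helper (StyleGAN-style)
loss = tf.constant(0.5)                           # stand-in for a real loss tensor
loss = autosummary("Loss/total", loss)            # identity op with a tracking side effect
tflib.run(loss)                                   # the op must be executed to accumulate

file_writer = tf.summary.FileWriter("logs")       # placeholder log directory
save_summaries(file_writer, global_step=0)
file_writer.close()
```
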
/dnnlib/tflib/network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for managing networks."""
9 |
10 | import types
11 | import inspect
12 | import re
13 | import uuid
14 | import sys
15 | import numpy as np
16 | import tensorflow as tf
17 |
18 | from collections import OrderedDict
19 | from typing import Any, List, Tuple, Union
20 |
21 | from . import tfutil
22 | from .. import util
23 |
24 | from .tfutil import TfExpression, TfExpressionEx
25 |
26 | _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
27 | _import_module_src = dict() # Source code for temporary modules created during pickle import.
28 |
29 |
30 | def import_handler(handler_func):
31 | """Function decorator for declaring custom import handlers."""
32 | _import_handlers.append(handler_func)
33 | return handler_func
34 |
35 |
36 | class Network:
37 | """Generic network abstraction.
38 |
39 | Acts as a convenience wrapper for a parameterized network construction
40 | function, providing several utility methods and convenient access to
41 | the inputs/outputs/weights.
42 |
43 | Network objects can be safely pickled and unpickled for long-term
44 | archival purposes. The pickling works reliably as long as the underlying
45 | network construction function is defined in a standalone Python module
46 | that has no side effects or application-specific imports.
47 |
48 | Args:
49 | name: Network name. Used to select TensorFlow name and variable scopes.
50 | func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
51 | static_kwargs: Keyword arguments to be passed in to the network construction function.
52 |
53 | Attributes:
54 | name: User-specified name, defaults to build func name if None.
55 | scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
56 | static_kwargs: Arguments passed to the user-supplied build func.
57 | components: Container for sub-networks. Passed to the build func, and retained between calls.
58 | num_inputs: Number of input tensors.
59 | num_outputs: Number of output tensors.
60 | input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
61 | output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
62 | input_shape: Short-hand for input_shapes[0].
63 | output_shape: Short-hand for output_shapes[0].
64 | input_templates: Input placeholders in the template graph.
65 | output_templates: Output tensors in the template graph.
66 | input_names: Name string for each input.
67 | output_names: Name string for each output.
68 | own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
69 | vars: All variables (local_name => var).
70 | trainables: All trainable variables (local_name => var).
71 | var_global_to_local: Mapping from variable global names to local names.
72 | """
73 |
74 | def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
75 | tfutil.assert_tf_initialized()
76 | assert isinstance(name, str) or name is None
77 | assert func_name is not None
78 | assert isinstance(func_name, str) or util.is_top_level_function(func_name)
79 | assert util.is_pickleable(static_kwargs)
80 |
81 | self._init_fields()
82 | self.name = name
83 | self.static_kwargs = util.EasyDict(static_kwargs)
84 |
85 | # Locate the user-specified network build function.
86 | if util.is_top_level_function(func_name):
87 | func_name = util.get_top_level_function_name(func_name)
88 | module, self._build_func_name = util.get_module_from_obj_name(func_name)
89 | self._build_func = util.get_obj_from_module(module, self._build_func_name)
90 | assert callable(self._build_func)
91 |
92 | # Dig up source code for the module containing the build function.
93 | self._build_module_src = _import_module_src.get(module, None)
94 | if self._build_module_src is None:
95 | self._build_module_src = inspect.getsource(module)
96 |
97 | # Init TensorFlow graph.
98 | self._init_graph()
99 | self.reset_own_vars()
100 |
101 | def _init_fields(self) -> None:
102 | self.name = None
103 | self.scope = None
104 | self.static_kwargs = util.EasyDict()
105 | self.components = util.EasyDict()
106 | self.num_inputs = 0
107 | self.num_outputs = 0
108 | self.input_shapes = [[]]
109 | self.output_shapes = [[]]
110 | self.input_shape = []
111 | self.output_shape = []
112 | self.input_templates = []
113 | self.output_templates = []
114 | self.input_names = []
115 | self.output_names = []
116 | self.own_vars = OrderedDict()
117 | self.vars = OrderedDict()
118 | self.trainables = OrderedDict()
119 | self.var_global_to_local = OrderedDict()
120 |
121 | self._build_func = None # User-supplied build function that constructs the network.
122 | self._build_func_name = None # Name of the build function.
123 | self._build_module_src = None # Full source code of the module containing the build function.
124 | self._run_cache = dict() # Cached graph data for Network.run().
125 |
126 | def _init_graph(self) -> None:
127 | # Collect inputs.
128 | self.input_names = []
129 |
130 | for param in inspect.signature(self._build_func).parameters.values():
131 | if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
132 | self.input_names.append(param.name)
133 |
134 | self.num_inputs = len(self.input_names)
135 | assert self.num_inputs >= 1
136 |
137 | # Choose name and scope.
138 | if self.name is None:
139 | self.name = self._build_func_name
140 | assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
141 | with tf.name_scope(None):
142 | self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)
143 |
144 | # Finalize build func kwargs.
145 | build_kwargs = dict(self.static_kwargs)
146 | build_kwargs["is_template_graph"] = True
147 | build_kwargs["components"] = self.components
148 |
149 | # Build template graph.
150 | with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
151 | assert tf.get_variable_scope().name == self.scope
152 | assert tf.get_default_graph().get_name_scope() == self.scope
153 | with tf.control_dependencies(None): # ignore surrounding control dependencies
154 | self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
155 | out_expr = self._build_func(*self.input_templates, **build_kwargs)
156 |
157 | # Collect outputs.
158 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
159 | self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
160 | self.num_outputs = len(self.output_templates)
161 | assert self.num_outputs >= 1
162 | assert all(tfutil.is_tf_expression(t) for t in self.output_templates)
163 |
164 | # Perform sanity checks.
165 | if any(t.shape.ndims is None for t in self.input_templates):
166 | raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
167 | if any(t.shape.ndims is None for t in self.output_templates):
168 | raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
169 | if any(not isinstance(comp, Network) for comp in self.components.values()):
170 | raise ValueError("Components of a Network must be Networks themselves.")
171 | if len(self.components) != len(set(comp.name for comp in self.components.values())):
172 | raise ValueError("Components of a Network must have unique names.")
173 |
174 | # List inputs and outputs.
175 | self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
176 | self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
177 | self.input_shape = self.input_shapes[0]
178 | self.output_shape = self.output_shapes[0]
179 | self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
180 |
181 | # List variables.
182 | self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
183 | self.vars = OrderedDict(self.own_vars)
184 | self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
185 | self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
186 | self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
187 |
188 | def reset_own_vars(self) -> None:
189 | """Re-initialize all variables of this network, excluding sub-networks."""
190 | tfutil.run([var.initializer for var in self.own_vars.values()])
191 |
192 | def reset_vars(self) -> None:
193 | """Re-initialize all variables of this network, including sub-networks."""
194 | tfutil.run([var.initializer for var in self.vars.values()])
195 |
196 | def reset_trainables(self) -> None:
197 | """Re-initialize all trainable variables of this network, including sub-networks."""
198 | tfutil.run([var.initializer for var in self.trainables.values()])
199 |
200 | def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
201 | """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
202 | assert len(in_expr) == self.num_inputs
203 | assert not all(expr is None for expr in in_expr)
204 |
205 | # Finalize build func kwargs.
206 | build_kwargs = dict(self.static_kwargs)
207 | build_kwargs.update(dynamic_kwargs)
208 | build_kwargs["is_template_graph"] = False
209 | build_kwargs["components"] = self.components
210 |
211 | # Build TensorFlow graph to evaluate the network.
212 | with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
213 | assert tf.get_variable_scope().name == self.scope
214 | valid_inputs = [expr for expr in in_expr if expr is not None]
215 | final_inputs = []
216 | for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
217 | if expr is not None:
218 | expr = tf.identity(expr, name=name)
219 | else:
220 | expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
221 | final_inputs.append(expr)
222 | out_expr = self._build_func(*final_inputs, **build_kwargs)
223 |
224 | # Propagate input shapes back to the user-specified expressions.
225 | for expr, final in zip(in_expr, final_inputs):
226 | if isinstance(expr, tf.Tensor):
227 | expr.set_shape(final.shape)
228 |
229 | # Express outputs in the desired format.
230 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
231 | if return_as_list:
232 | out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
233 | return out_expr
234 |
235 | def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
236 | """Get the local name of a given variable, without any surrounding name scopes."""
237 | assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
238 | global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
239 | return self.var_global_to_local[global_name]
240 |
241 | def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
242 | """Find variable by local or global name."""
243 | assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
244 | return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
245 |
246 | def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
247 | """Get the value of a given variable as NumPy array.
248 | Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
249 | return self.find_var(var_or_local_name).eval()
250 |
251 | def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
252 | """Set the value of a given variable based on the given NumPy array.
253 | Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
254 | tfutil.set_vars({self.find_var(var_or_local_name): new_value})
255 |
256 | def __getstate__(self) -> dict:
257 | """Pickle export."""
258 | state = dict()
259 | state["version"] = 3
260 | state["name"] = self.name
261 | state["static_kwargs"] = dict(self.static_kwargs)
262 | state["components"] = dict(self.components)
263 | state["build_module_src"] = self._build_module_src
264 | state["build_func_name"] = self._build_func_name
265 | state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
266 | return state
267 |
268 | def __setstate__(self, state: dict) -> None:
269 | """Pickle import."""
270 | # pylint: disable=attribute-defined-outside-init
271 | tfutil.assert_tf_initialized()
272 | self._init_fields()
273 |
274 | # Execute custom import handlers.
275 | for handler in _import_handlers:
276 | state = handler(state)
277 |
278 | # Set basic fields.
279 | assert state["version"] in [2, 3]
280 | self.name = state["name"]
281 | self.static_kwargs = util.EasyDict(state["static_kwargs"])
282 | self.components = util.EasyDict(state.get("components", {}))
283 | self._build_module_src = state["build_module_src"]
284 | self._build_func_name = state["build_func_name"]
285 |
286 | # Create temporary module from the imported source code.
287 | module_name = "_tflib_network_import_" + uuid.uuid4().hex
288 | module = types.ModuleType(module_name)
289 | sys.modules[module_name] = module
290 | _import_module_src[module] = self._build_module_src
291 | exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
292 |
293 | # Locate network build function in the temporary module.
294 | self._build_func = util.get_obj_from_module(module, self._build_func_name)
295 | assert callable(self._build_func)
296 |
297 | # Init TensorFlow graph.
298 | self._init_graph()
299 | self.reset_own_vars()
300 | tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
301 |
302 | def clone(self, name: str = None, **new_static_kwargs) -> "Network":
303 | """Create a clone of this network with its own copy of the variables."""
304 | # pylint: disable=protected-access
305 | net = object.__new__(Network)
306 | net._init_fields()
307 | net.name = name if name is not None else self.name
308 | net.static_kwargs = util.EasyDict(self.static_kwargs)
309 | net.static_kwargs.update(new_static_kwargs)
310 | net._build_module_src = self._build_module_src
311 | net._build_func_name = self._build_func_name
312 | net._build_func = self._build_func
313 | net._init_graph()
314 | net.copy_vars_from(self)
315 | return net
316 |
317 | def copy_own_vars_from(self, src_net: "Network") -> None:
318 | """Copy the values of all variables from the given network, excluding sub-networks."""
319 | names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
320 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
321 |
322 | def copy_vars_from(self, src_net: "Network") -> None:
323 | """Copy the values of all variables from the given network, including sub-networks."""
324 | names = [name for name in self.vars.keys() if name in src_net.vars]
325 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
326 |
327 | def copy_trainables_from(self, src_net: "Network") -> None:
328 | """Copy the values of all trainable variables from the given network, including sub-networks."""
329 | names = [name for name in self.trainables.keys() if name in src_net.trainables]
330 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
331 |
332 | def copy_compatible_trainables_from(self, src_net: "Network") -> None:
333 | """Copy the compatible values of all trainable variables from the given network, including sub-networks"""
334 | names = []
335 | for name in self.trainables.keys():
336 | if name not in src_net.trainables:
337 | print("Not restoring (not present): {}".format(name))
338 | elif self.trainables[name].shape != src_net.trainables[name].shape:
339 | print("Not restoring (different shape): {}".format(name))
340 |
341 | if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape:
342 | names.append(name)
343 |
344 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
345 |
346 | def apply_swa(self, src_net, epoch):
347 |         """Perform stochastic weight averaging on the compatible values of all trainable variables from the given network, including sub-networks: each call updates vars to (epoch * vars + src_vars) / (epoch + 1), i.e., a running mean over epochs."""
348 | names = []
349 | for name in self.trainables.keys():
350 | if name not in src_net.trainables:
351 | print("Not restoring (not present): {}".format(name))
352 | elif self.trainables[name].shape != src_net.trainables[name].shape:
353 | print("Not restoring (different shape): {}".format(name))
354 |
355 | if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape:
356 | names.append(name)
357 |
358 | scale_new_data = 1.0 / (epoch + 1)
359 | scale_moving_average = (1.0 - scale_new_data)
360 | tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * scale_new_data + self.vars[name] * scale_moving_average) for name in names}))
361 |
362 | def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
363 | """Create new network with the given parameters, and copy all variables from this network."""
364 | if new_name is None:
365 | new_name = self.name
366 | static_kwargs = dict(self.static_kwargs)
367 | static_kwargs.update(new_static_kwargs)
368 | net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
369 | net.copy_vars_from(self)
370 | return net
371 |
372 | def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
373 | """Construct a TensorFlow op that updates the variables of this network
374 | to be slightly closer to those of the given network."""
375 | with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
376 | ops = []
377 | for name, var in self.vars.items():
378 | if name in src_net.vars:
379 | cur_beta = beta if name in self.trainables else beta_nontrainable
380 | new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
381 | ops.append(var.assign(new_value))
382 | return tf.group(*ops)
383 |
384 | def run(self,
385 | *in_arrays: Tuple[Union[np.ndarray, None], ...],
386 | input_transform: dict = None,
387 | output_transform: dict = None,
388 | return_as_list: bool = False,
389 | print_progress: bool = False,
390 | minibatch_size: int = None,
391 | num_gpus: int = 1,
392 | assume_frozen: bool = False,
393 | custom_inputs=None,
394 | **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
395 | """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
396 |
397 | Args:
398 | input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
399 | The dict must contain a 'func' field that points to a top-level function. The function is called with the input
400 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
401 | output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
402 | The dict must contain a 'func' field that points to a top-level function. The function is called with the output
403 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
404 | return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
405 | print_progress: Print progress to the console? Useful for very large input arrays.
406 | minibatch_size: Maximum minibatch size to use, None = disable batching.
407 | num_gpus: Number of GPUs to use.
408 |             assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
409 | dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
410 |             custom_inputs: Allows custom TensorFlow expressions to be used as inputs in place of the default placeholders.
411 | """
412 | assert len(in_arrays) == self.num_inputs
413 | assert not all(arr is None for arr in in_arrays)
414 | assert input_transform is None or util.is_top_level_function(input_transform["func"])
415 | assert output_transform is None or util.is_top_level_function(output_transform["func"])
416 | output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
417 | num_items = in_arrays[0].shape[0]
418 | if minibatch_size is None:
419 | minibatch_size = num_items
420 |
421 | # Construct unique hash key from all arguments that affect the TensorFlow graph.
422 | key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
423 | def unwind_key(obj):
424 | if isinstance(obj, dict):
425 | return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
426 | if callable(obj):
427 | return util.get_top_level_function_name(obj)
428 | return obj
429 | key = repr(unwind_key(key))
430 |
431 | # Build graph.
432 | if key not in self._run_cache:
433 | with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
434 | if custom_inputs is not None:
435 | with tf.device("/gpu:0"):
436 | in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
437 | in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
438 | else:
439 | with tf.device("/cpu:0"):
440 | in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
441 | in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
442 |
443 | out_split = []
444 | for gpu in range(num_gpus):
445 | with tf.device("/gpu:%d" % gpu):
446 | net_gpu = self.clone() if assume_frozen else self
447 | in_gpu = in_split[gpu]
448 |
449 | if input_transform is not None:
450 | in_kwargs = dict(input_transform)
451 | in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
452 | in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
453 |
454 | assert len(in_gpu) == self.num_inputs
455 | out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
456 |
457 | if output_transform is not None:
458 | out_kwargs = dict(output_transform)
459 | out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
460 | out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
461 |
462 | assert len(out_gpu) == self.num_outputs
463 | out_split.append(out_gpu)
464 |
465 | with tf.device("/cpu:0"):
466 | out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
467 | self._run_cache[key] = in_expr, out_expr
468 |
469 | # Run minibatches.
470 | in_expr, out_expr = self._run_cache[key]
471 | out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]
472 |
473 | for mb_begin in range(0, num_items, minibatch_size):
474 | if print_progress:
475 | print("\r%d / %d" % (mb_begin, num_items), end="")
476 |
477 | mb_end = min(mb_begin + minibatch_size, num_items)
478 | mb_num = mb_end - mb_begin
479 | mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
480 | mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
481 |
482 | for dst, src in zip(out_arrays, mb_out):
483 | dst[mb_begin: mb_end] = src
484 |
485 | # Done.
486 | if print_progress:
487 | print("\r%d / %d" % (num_items, num_items))
488 |
489 | if not return_as_list:
490 | out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
491 | return out_arrays
492 |
493 | def list_ops(self) -> List[TfExpression]:
494 | include_prefix = self.scope + "/"
495 | exclude_prefix = include_prefix + "_"
496 | ops = tf.get_default_graph().get_operations()
497 | ops = [op for op in ops if op.name.startswith(include_prefix)]
498 | ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
499 | return ops
500 |
501 | def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
502 | """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
503 | individual layers of the network. Mainly intended to be used for reporting."""
504 | layers = []
505 |
506 | def recurse(scope, parent_ops, parent_vars, level):
507 | # Ignore specific patterns.
508 | if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
509 | return
510 |
511 | # Filter ops and vars by scope.
512 | global_prefix = scope + "/"
513 | local_prefix = global_prefix[len(self.scope) + 1:]
514 | cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
515 | cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
516 | if not cur_ops and not cur_vars:
517 | return
518 |
519 | # Filter out all ops related to variables.
520 | for var in [op for op in cur_ops if op.type.startswith("Variable")]:
521 | var_prefix = var.name + "/"
522 | cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
523 |
524 | # Scope does not contain ops as immediate children => recurse deeper.
525 | contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
526 | if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
527 | visited = set()
528 | for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
529 | token = rel_name.split("/")[0]
530 | if token not in visited:
531 | recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
532 | visited.add(token)
533 | return
534 |
535 | # Report layer.
536 | layer_name = scope[len(self.scope) + 1:]
537 | layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
538 | layer_trainables = [var for _name, var in cur_vars if var.trainable]
539 | layers.append((layer_name, layer_output, layer_trainables))
540 |
541 | recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
542 | return layers
543 |
544 | def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
545 | """Print a summary table of the network structure."""
546 | rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
547 | rows += [["---"] * 4]
548 | total_params = 0
549 |
550 | for layer_name, layer_output, layer_trainables in self.list_layers():
551 | num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
552 | weights = [var for var in layer_trainables if var.name.endswith("/weight:0") or var.name.endswith("/weight_1:0")]
553 | weights.sort(key=lambda x: len(x.name))
554 | if len(weights) == 0 and len(layer_trainables) == 1:
555 | weights = layer_trainables
556 | total_params += num_params
557 |
558 | if not hide_layers_with_no_params or num_params != 0:
559 | num_params_str = str(num_params) if num_params > 0 else "-"
560 | output_shape_str = str(layer_output.shape)
561 | weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
562 | rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
563 |
564 | rows += [["---"] * 4]
565 | rows += [["Total", str(total_params), "", ""]]
566 |
567 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
568 | print()
569 | for row in rows:
570 | print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
571 | print()
572 |
573 | def setup_weight_histograms(self, title: str = None) -> None:
574 | """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
575 | if title is None:
576 | title = self.name
577 |
578 | with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
579 | for local_name, var in self.trainables.items():
580 | if "/" in local_name:
581 | p = local_name.split("/")
582 | name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
583 | else:
584 | name = title + "_toplevel/" + local_name
585 |
586 | tf.summary.histogram(name, var)
587 |
588 | #----------------------------------------------------------------------------
589 | # Backwards-compatible emulation of legacy output transformation in Network.run().
590 |
591 | _print_legacy_warning = True
592 |
593 | def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
594 | global _print_legacy_warning
595 | legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
596 | if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
597 | return output_transform, dynamic_kwargs
598 |
599 | if _print_legacy_warning:
600 | _print_legacy_warning = False
601 | print()
602 | print("WARNING: Old-style output transformations in Network.run() are deprecated.")
603 | print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
604 | print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
605 | print()
606 | assert output_transform is None
607 |
608 | new_kwargs = dict(dynamic_kwargs)
609 | new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
610 | new_transform["func"] = _legacy_output_transform_func
611 | return new_transform, new_kwargs
612 |
613 | def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
614 | if out_mul != 1.0:
615 | expr = [x * out_mul for x in expr]
616 |
617 | if out_add != 0.0:
618 | expr = [x + out_add for x in expr]
619 |
620 | if out_shrink > 1:
621 | ksize = [1, 1, out_shrink, out_shrink]
622 | expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
623 |
624 | if out_dtype is not None:
625 | if tf.as_dtype(out_dtype).is_integer:
626 | expr = [tf.round(x) for x in expr]
627 | expr = [tf.saturate_cast(x, out_dtype) for x in expr]
628 | return expr
629 |
--------------------------------------------------------------------------------
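Usage sketch for dnnlib/tflib/network.py (above): a minimal, hypothetical example of wrapping a build function in a Network, evaluating it on NumPy data, and maintaining a moving-average clone. simple_net, its hidden_dim argument, and all shapes are made up for illustration; in a real script the build function should live at module scope in a standalone file so that pickling works, and Network.run() builds its evaluation graph on /gpu:0, so a GPU is assumed as elsewhere in this repository.

    import numpy as np
    import tensorflow as tf
    from dnnlib.tflib import tfutil, network

    def simple_net(x, hidden_dim=8, **_kwargs):          # hypothetical build function
        x.set_shape([None, 4])                           # input shapes must be defined (see _init_graph)
        w = tf.get_variable("weight", shape=[4, hidden_dim])
        return tf.matmul(x, w)

    tfutil.init_tf()
    net = network.Network("SimpleNet", func_name=simple_net, hidden_dim=8)
    net.print_layers()
    out = net.run(np.random.randn(2, 4).astype(np.float32))      # NumPy array of shape (2, 8)

    avg = net.clone("SimpleNetAvg")                               # independent copy of the variables
    update_op = avg.setup_as_moving_average_of(net, beta=0.99)    # nudges avg toward net when run
    tfutil.run(update_op)
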
/dnnlib/tflib/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper wrapper for a Tensorflow optimizer."""
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | from collections import OrderedDict
14 | from typing import List, Union
15 |
16 | from . import autosummary
17 | from . import tfutil
18 | from .. import util
19 |
20 | from .tfutil import TfExpression, TfExpressionEx
21 |
22 | try:
23 | # TensorFlow 1.13
24 | from tensorflow.python.ops import nccl_ops
25 | except:
26 | # Older TensorFlow versions
27 | import tensorflow.contrib.nccl as nccl_ops
28 |
29 | class Optimizer:
30 | """A Wrapper for tf.train.Optimizer.
31 |
32 | Automatically takes care of:
33 | - Gradient averaging for multi-GPU training.
34 | - Dynamic loss scaling and typecasts for FP16 training.
35 | - Ignoring corrupted gradients that contain NaNs/Infs.
36 | - Reporting statistics.
37 | - Well-chosen default settings.
38 | """
39 |
40 | def __init__(self,
41 | name: str = "Train",
42 | tf_optimizer: str = "tf.train.AdamOptimizer",
43 | learning_rate: TfExpressionEx = 0.001,
44 | use_loss_scaling: bool = False,
45 | loss_scaling_init: float = 64.0,
46 | loss_scaling_inc: float = 0.0005,
47 | loss_scaling_dec: float = 1.0,
48 | **kwargs):
49 |
50 | # Init fields.
51 | self.name = name
52 | self.learning_rate = tf.convert_to_tensor(learning_rate)
53 | self.id = self.name.replace("/", ".")
54 | self.scope = tf.get_default_graph().unique_name(self.id)
55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer)
56 | self.optimizer_kwargs = dict(kwargs)
57 | self.use_loss_scaling = use_loss_scaling
58 | self.loss_scaling_init = loss_scaling_init
59 | self.loss_scaling_inc = loss_scaling_inc
60 | self.loss_scaling_dec = loss_scaling_dec
61 | self._grad_shapes = None # [shape, ...]
62 | self._dev_opt = OrderedDict() # device => optimizer
63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...]
64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor)
65 | self._updates_applied = False
66 |
67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
68 | """Register the gradients of the given loss function with respect to the given variables.
69 | Intended to be called once per GPU."""
70 | assert not self._updates_applied
71 |
72 | # Validate arguments.
73 | if isinstance(trainable_vars, dict):
74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
75 |
76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
78 |
79 | if self._grad_shapes is None:
80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]
81 |
82 | assert len(trainable_vars) == len(self._grad_shapes)
83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))
84 |
85 | dev = loss.device
86 |
87 | assert all(var.device == dev for var in trainable_vars)
88 |
89 | # Register device and compute gradients.
90 | with tf.name_scope(self.id + "_grad"), tf.device(dev):
91 | if dev not in self._dev_opt:
92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
93 | assert callable(self.optimizer_class)
94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
95 | self._dev_grads[dev] = []
96 |
97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage
99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros
100 | self._dev_grads[dev].append(grads)
101 |
102 | def apply_updates(self) -> tf.Operation:
103 | """Construct training op to update the registered variables based on their gradients."""
104 | tfutil.assert_tf_initialized()
105 | assert not self._updates_applied
106 | self._updates_applied = True
107 | devices = list(self._dev_grads.keys())
108 | total_grads = sum(len(grads) for grads in self._dev_grads.values())
109 | assert len(devices) >= 1 and total_grads >= 1
110 | ops = []
111 |
112 | with tfutil.absolute_name_scope(self.scope):
113 | # Cast gradients to FP32 and calculate partial sum within each device.
114 | dev_grads = OrderedDict() # device => [(grad, var), ...]
115 |
116 | for dev_idx, dev in enumerate(devices):
117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
118 | sums = []
119 |
120 | for gv in zip(*self._dev_grads[dev]):
121 | assert all(v is gv[0][1] for g, v in gv)
122 | g = [tf.cast(g, tf.float32) for g, v in gv]
123 | g = g[0] if len(g) == 1 else tf.add_n(g)
124 | sums.append((g, gv[0][1]))
125 |
126 | dev_grads[dev] = sums
127 |
128 | # Sum gradients across devices.
129 | if len(devices) > 1:
130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None):
131 | for var_idx, grad_shape in enumerate(self._grad_shapes):
132 | g = [dev_grads[dev][var_idx][0] for dev in devices]
133 |
134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors
135 | g = nccl_ops.all_sum(g)
136 |
137 | for dev, gg in zip(devices, g):
138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])
139 |
140 | # Apply updates separately on each device.
141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
143 | # Scale gradients as needed.
144 | if self.use_loss_scaling or total_grads > 1:
145 | with tf.name_scope("Scale"):
146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
147 | coef = self.undo_loss_scaling(coef)
148 | grads = [(g * coef, v) for g, v in grads]
149 |
150 | # Check for overflows.
151 | with tf.name_scope("CheckOverflow"):
152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))
153 |
154 | # Update weights and adjust loss scaling.
155 | with tf.name_scope("UpdateWeights"):
156 | # pylint: disable=cell-var-from-loop
157 | opt = self._dev_opt[dev]
158 | ls_var = self.get_loss_scaling_var(dev)
159 |
160 | if not self.use_loss_scaling:
161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
162 | else:
163 | ops.append(tf.cond(grad_ok,
164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))
166 |
167 | # Report statistics on the last device.
168 | if dev == devices[-1]:
169 | with tf.name_scope("Statistics"):
170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))
172 |
173 | if self.use_loss_scaling:
174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))
175 |
176 | # Initialize variables and group everything into a single op.
177 | self.reset_optimizer_state()
178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))
179 |
180 | return tf.group(*ops, name="TrainingOp")
181 |
182 | def reset_optimizer_state(self) -> None:
183 | """Reset internal state of the underlying optimizer."""
184 | tfutil.assert_tf_initialized()
185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])
186 |
187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
188 | """Get or create variable representing log2 of the current dynamic loss scaling factor."""
189 | if not self.use_loss_scaling:
190 | return None
191 |
192 | if device not in self._dev_ls_var:
193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")
195 |
196 | return self._dev_ls_var[device]
197 |
198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
199 | """Apply dynamic loss scaling for the given expression."""
200 | assert tfutil.is_tf_expression(value)
201 |
202 | if not self.use_loss_scaling:
203 | return value
204 |
205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
206 |
207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
208 | """Undo the effect of dynamic loss scaling for the given expression."""
209 | assert tfutil.is_tf_expression(value)
210 |
211 | if not self.use_loss_scaling:
212 | return value
213 |
214 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
215 |
--------------------------------------------------------------------------------
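Usage sketch for dnnlib/tflib/optimizer.py (above): a minimal, single-device example with a made-up least-squares loss. The placeholder, variable, and minibatch shapes are illustrative only; on a multi-GPU setup register_gradients() would be called once per GPU tower before apply_updates().

    import numpy as np
    import tensorflow as tf
    from dnnlib.tflib import tfutil, optimizer

    tfutil.init_tf()
    x = tf.placeholder(tf.float32, [None, 4])
    w = tf.get_variable("w", shape=[4, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

    opt = optimizer.Optimizer(name="Train", learning_rate=0.001)   # wraps tf.train.AdamOptimizer by default
    opt.register_gradients(loss, [w])                              # once per GPU tower
    train_op = opt.apply_updates()                                 # grouped update + statistics op

    tfutil.init_uninitialized_vars()
    tfutil.run(train_op, {x: np.random.randn(8, 4).astype(np.float32)})
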
/dnnlib/tflib/tfutil.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous helper utils for Tensorflow."""
9 |
10 | import os
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from typing import Any, Iterable, List, Union
15 |
16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
17 | """A type that represents a valid Tensorflow expression."""
18 |
19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
20 | """A type that can be converted to a valid Tensorflow expression."""
21 |
22 |
23 | def run(*args, **kwargs) -> Any:
24 | """Run the specified ops in the default session."""
25 | assert_tf_initialized()
26 | return tf.get_default_session().run(*args, **kwargs)
27 |
28 |
29 | def is_tf_expression(x: Any) -> bool:
30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
32 |
33 |
34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
35 | """Convert a Tensorflow shape to a list of ints."""
36 | return [dim.value for dim in shape]
37 |
38 |
39 | def flatten(x: TfExpressionEx) -> TfExpression:
40 | """Shortcut function for flattening a tensor."""
41 | with tf.name_scope("Flatten"):
42 | return tf.reshape(x, [-1])
43 |
44 |
45 | def log2(x: TfExpressionEx) -> TfExpression:
46 | """Logarithm in base 2."""
47 | with tf.name_scope("Log2"):
48 | return tf.log(x) * np.float32(1.0 / np.log(2.0))
49 |
50 |
51 | def exp2(x: TfExpressionEx) -> TfExpression:
52 | """Exponent in base 2."""
53 | with tf.name_scope("Exp2"):
54 | return tf.exp(x * np.float32(np.log(2.0)))
55 |
56 |
57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
58 | """Linear interpolation."""
59 | with tf.name_scope("Lerp"):
60 | return a + (b - a) * t
61 |
62 |
63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
64 | """Linear interpolation with clip."""
65 | with tf.name_scope("LerpClip"):
66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
67 |
68 |
69 | def absolute_name_scope(scope: str) -> tf.name_scope:
70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
71 | return tf.name_scope(scope + "/")
72 |
73 |
74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
77 |
78 |
79 | def _sanitize_tf_config(config_dict: dict = None) -> dict:
80 | # Defaults.
81 | cfg = dict()
82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
87 |
88 | # User overrides.
89 | if config_dict is not None:
90 | cfg.update(config_dict)
91 | return cfg
92 |
93 |
94 | def init_tf(config_dict: dict = None) -> None:
95 | """Initialize TensorFlow session using good default settings."""
96 |     # Reset the default graph if a session has already been created.
97 | if tf.get_default_session() is not None:
98 | tf.reset_default_graph()
99 |
100 | # Setup config dict and random seeds.
101 | cfg = _sanitize_tf_config(config_dict)
102 | np_random_seed = cfg["rnd.np_random_seed"]
103 | if np_random_seed is not None:
104 | np.random.seed(np_random_seed)
105 | tf_random_seed = cfg["rnd.tf_random_seed"]
106 | if tf_random_seed == "auto":
107 | tf_random_seed = np.random.randint(1 << 31)
108 | if tf_random_seed is not None:
109 | tf.set_random_seed(tf_random_seed)
110 |
111 | # Setup environment variables.
112 | for key, value in list(cfg.items()):
113 | fields = key.split(".")
114 | if fields[0] == "env":
115 | assert len(fields) == 2
116 | os.environ[fields[1]] = str(value)
117 |
118 | # Create default TensorFlow session.
119 | create_session(cfg, force_as_default=True)
120 |
121 |
122 | def assert_tf_initialized():
123 | """Check that TensorFlow session has been initialized."""
124 | if tf.get_default_session() is None:
125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
126 |
127 |
128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
129 | """Create tf.Session based on config dict."""
130 | # Setup TensorFlow config proto.
131 | cfg = _sanitize_tf_config(config_dict)
132 | config_proto = tf.ConfigProto()
133 | for key, value in cfg.items():
134 | fields = key.split(".")
135 | if fields[0] not in ["rnd", "env"]:
136 | obj = config_proto
137 | for field in fields[:-1]:
138 | obj = getattr(obj, field)
139 | setattr(obj, fields[-1], value)
140 |
141 | # Create session.
142 | session = tf.Session(config=config_proto)
143 | if force_as_default:
144 | # pylint: disable=protected-access
145 | session._default_session = session.as_default()
146 | session._default_session.enforce_nesting = False
147 | session._default_session.__enter__() # pylint: disable=no-member
148 |
149 | return session
150 |
151 |
152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
153 | """Initialize all tf.Variables that have not already been initialized.
154 |
155 | Equivalent to the following, but more efficient and does not bloat the tf graph:
156 | tf.variables_initializer(tf.report_uninitialized_variables()).run()
157 | """
158 | assert_tf_initialized()
159 | if target_vars is None:
160 | target_vars = tf.global_variables()
161 |
162 | test_vars = []
163 | test_ops = []
164 |
165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
166 | for var in target_vars:
167 | assert is_tf_expression(var)
168 |
169 | try:
170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
171 | except KeyError:
172 | # Op does not exist => variable may be uninitialized.
173 | test_vars.append(var)
174 |
175 | with absolute_name_scope(var.name.split(":")[0]):
176 | test_ops.append(tf.is_variable_initialized(var))
177 |
178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
179 | run([var.initializer for var in init_vars])
180 |
181 |
182 | def set_vars(var_to_value_dict: dict) -> None:
183 | """Set the values of given tf.Variables.
184 |
185 | Equivalent to the following, but more efficient and does not bloat the tf graph:
186 |         tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
187 | """
188 | assert_tf_initialized()
189 | ops = []
190 | feed_dict = {}
191 |
192 | for var, value in var_to_value_dict.items():
193 | assert is_tf_expression(var)
194 |
195 | try:
196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
197 | except KeyError:
198 | with absolute_name_scope(var.name.split(":")[0]):
199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
201 |
202 | ops.append(setter)
203 | feed_dict[setter.op.inputs[1]] = value
204 |
205 | run(ops, feed_dict)
206 |
207 |
208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
209 | """Create tf.Variable with large initial value without bloating the tf graph."""
210 | assert_tf_initialized()
211 | assert isinstance(initial_value, np.ndarray)
212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype)
213 | var = tf.Variable(zeros, *args, **kwargs)
214 | set_vars({var: initial_value})
215 | return var
216 |
217 |
218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
220 | Can be used as an input transformation for Network.run().
221 | """
222 | images = tf.cast(images, tf.float32)
223 | if nhwc_to_nchw:
224 | images = tf.transpose(images, [0, 3, 1, 2])
225 |     return images * ((drange[1] - drange[0]) / 255) + drange[0]
226 |
227 |
228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1, uint8_cast=True):
229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
230 | Can be used as an output transformation for Network.run().
231 | """
232 | images = tf.cast(images, tf.float32)
233 | if shrink > 1:
234 | ksize = [1, 1, shrink, shrink]
235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
236 | if nchw_to_nhwc:
237 | images = tf.transpose(images, [0, 2, 3, 1])
238 | scale = 255 / (drange[1] - drange[0])
239 | images = images * scale + (0.5 - drange[0] * scale)
240 | if uint8_cast:
241 | images = tf.saturate_cast(images, tf.uint8)
242 | return images
243 |
--------------------------------------------------------------------------------
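Usage sketch for dnnlib/tflib/tfutil.py (above): initializing the default session with explicit overrides of the config defaults, then round-tripping a dummy NCHW image batch through the two conversion helpers (the same helpers are meant to be passed to Network.run() as input_transform / output_transform). The seed, shapes, and value range are arbitrary.

    import numpy as np
    from dnnlib.tflib import tfutil

    tfutil.init_tf({"rnd.np_random_seed": 1000, "gpu_options.allow_growth": True})

    # uint8 NCHW batch -> float32 in [-1, 1] -> back to uint8.
    uint8_images = np.random.randint(0, 256, size=(1, 3, 4, 4), dtype=np.uint8)
    floats = tfutil.convert_images_from_uint8(uint8_images, drange=[-1, 1])
    restored = tfutil.convert_images_to_uint8(floats, drange=[-1, 1])
    print(tfutil.run(restored).shape)                              # (1, 3, 4, 4)
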
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import uuid
27 |
28 | from distutils.util import strtobool
29 | from typing import Any, List, Tuple, Union
30 |
31 |
32 | # Util classes
33 | # ------------------------------------------------------------------------------------------
34 |
35 |
36 | class EasyDict(dict):
37 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
38 |
39 | def __getattr__(self, name: str) -> Any:
40 | try:
41 | return self[name]
42 | except KeyError:
43 | raise AttributeError(name)
44 |
45 | def __setattr__(self, name: str, value: Any) -> None:
46 | self[name] = value
47 |
48 | def __delattr__(self, name: str) -> None:
49 | del self[name]
50 |
51 |
52 | class Logger(object):
53 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
54 |
55 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
56 | self.file = None
57 |
58 | if file_name is not None:
59 | self.file = open(file_name, file_mode)
60 |
61 | self.should_flush = should_flush
62 | self.stdout = sys.stdout
63 | self.stderr = sys.stderr
64 |
65 | sys.stdout = self
66 | sys.stderr = self
67 |
68 | def __enter__(self) -> "Logger":
69 | return self
70 |
71 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
72 | self.close()
73 |
74 | def write(self, text: str) -> None:
75 | """Write text to stdout (and a file) and optionally flush."""
76 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
77 | return
78 |
79 | if self.file is not None:
80 | self.file.write(text)
81 |
82 | self.stdout.write(text)
83 |
84 | if self.should_flush:
85 | self.flush()
86 |
87 | def flush(self) -> None:
88 | """Flush written text to both stdout and a file, if open."""
89 | if self.file is not None:
90 | self.file.flush()
91 |
92 | self.stdout.flush()
93 |
94 | def close(self) -> None:
95 | """Flush, close possible files, and remove stdout/stderr mirroring."""
96 | self.flush()
97 |
98 | # if using multiple loggers, prevent closing in wrong order
99 | if sys.stdout is self:
100 | sys.stdout = self.stdout
101 | if sys.stderr is self:
102 | sys.stderr = self.stderr
103 |
104 | if self.file is not None:
105 | self.file.close()
106 |
107 |
108 | # Small util functions
109 | # ------------------------------------------------------------------------------------------
110 |
111 |
112 | def format_time(seconds: Union[int, float]) -> str:
113 | """Convert the seconds to human readable string with days, hours, minutes and seconds."""
114 | s = int(np.rint(seconds))
115 |
116 | if s < 60:
117 | return "{0}s".format(s)
118 | elif s < 60 * 60:
119 | return "{0}m {1:02}s".format(s // 60, s % 60)
120 | elif s < 24 * 60 * 60:
121 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
122 | else:
123 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
124 |
125 |
126 | def ask_yes_no(question: str) -> bool:
127 | """Ask the user the question until the user inputs a valid answer."""
128 | while True:
129 | try:
130 | print("{0} [y/n]".format(question))
131 | return strtobool(input().lower())
132 | except ValueError:
133 | pass
134 |
135 |
136 | def tuple_product(t: Tuple) -> Any:
137 | """Calculate the product of the tuple elements."""
138 | result = 1
139 |
140 | for v in t:
141 | result *= v
142 |
143 | return result
144 |
145 |
146 | _str_to_ctype = {
147 | "uint8": ctypes.c_ubyte,
148 | "uint16": ctypes.c_uint16,
149 | "uint32": ctypes.c_uint32,
150 | "uint64": ctypes.c_uint64,
151 | "int8": ctypes.c_byte,
152 | "int16": ctypes.c_int16,
153 | "int32": ctypes.c_int32,
154 | "int64": ctypes.c_int64,
155 | "float32": ctypes.c_float,
156 | "float64": ctypes.c_double
157 | }
158 |
159 |
160 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
161 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
162 | type_str = None
163 |
164 | if isinstance(type_obj, str):
165 | type_str = type_obj
166 | elif hasattr(type_obj, "__name__"):
167 | type_str = type_obj.__name__
168 | elif hasattr(type_obj, "name"):
169 | type_str = type_obj.name
170 | else:
171 | raise RuntimeError("Cannot infer type name from input")
172 |
173 | assert type_str in _str_to_ctype.keys()
174 |
175 | my_dtype = np.dtype(type_str)
176 | my_ctype = _str_to_ctype[type_str]
177 |
178 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
179 |
180 | return my_dtype, my_ctype
181 |
182 |
183 | def is_pickleable(obj: Any) -> bool:
184 | try:
185 | with io.BytesIO() as stream:
186 | pickle.dump(obj, stream)
187 | return True
188 | except:
189 | return False
190 |
191 |
192 | # Functionality to import modules/objects by name, and call functions by name
193 | # ------------------------------------------------------------------------------------------
194 |
195 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
196 | """Searches for the underlying module behind the name to some python object.
197 | Returns the module and the object name (original name with module part removed)."""
198 |
199 | # allow convenience shorthands, substitute them by full names
200 | obj_name = re.sub("^np.", "numpy.", obj_name)
201 | obj_name = re.sub("^tf.", "tensorflow.", obj_name)
202 |
203 | # list alternatives for (module_name, local_obj_name)
204 | parts = obj_name.split(".")
205 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
206 |
207 | # try each alternative in turn
208 | for module_name, local_obj_name in name_pairs:
209 | try:
210 | module = importlib.import_module(module_name) # may raise ImportError
211 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
212 | return module, local_obj_name
213 | except:
214 | pass
215 |
216 | # maybe some of the modules themselves contain errors?
217 | for module_name, _local_obj_name in name_pairs:
218 | try:
219 | importlib.import_module(module_name) # may raise ImportError
220 | except ImportError:
221 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
222 | raise
223 |
224 | # maybe the requested attribute is missing?
225 | for module_name, local_obj_name in name_pairs:
226 | try:
227 | module = importlib.import_module(module_name) # may raise ImportError
228 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
229 | except ImportError:
230 | pass
231 |
232 | # we are out of luck, but we have no idea why
233 | raise ImportError(obj_name)
234 |
235 |
236 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
237 | """Traverses the object name and returns the last (rightmost) python object."""
238 | if obj_name == '':
239 | return module
240 | obj = module
241 | for part in obj_name.split("."):
242 | obj = getattr(obj, part)
243 | return obj
244 |
245 |
246 | def get_obj_by_name(name: str) -> Any:
247 | """Finds the python object with the given name."""
248 | module, obj_name = get_module_from_obj_name(name)
249 | return get_obj_from_module(module, obj_name)
250 |
251 |
252 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
253 | """Finds the python object with the given name and calls it as a function."""
254 | assert func_name is not None
255 | func_obj = get_obj_by_name(func_name)
256 | assert callable(func_obj)
257 | return func_obj(*args, **kwargs)
258 |
259 |
260 | def get_module_dir_by_obj_name(obj_name: str) -> str:
261 | """Get the directory path of the module containing the given object name."""
262 | module, _ = get_module_from_obj_name(obj_name)
263 | return os.path.dirname(inspect.getfile(module))
264 |
265 |
266 | def is_top_level_function(obj: Any) -> bool:
267 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
268 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
269 |
270 |
271 | def get_top_level_function_name(obj: Any) -> str:
272 | """Return the fully-qualified name of a top-level function."""
273 | assert is_top_level_function(obj)
274 | return obj.__module__ + "." + obj.__name__
275 |
276 |
277 | # File system helpers
278 | # ------------------------------------------------------------------------------------------
279 |
280 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
281 | """List all files recursively in a given directory while ignoring given file and directory names.
282 | Returns list of tuples containing both absolute and relative paths."""
283 | assert os.path.isdir(dir_path)
284 | base_name = os.path.basename(os.path.normpath(dir_path))
285 |
286 | if ignores is None:
287 | ignores = []
288 |
289 | result = []
290 |
291 | for root, dirs, files in os.walk(dir_path, topdown=True):
292 | for ignore_ in ignores:
293 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
294 |
295 | # dirs need to be edited in-place
296 | for d in dirs_to_remove:
297 | dirs.remove(d)
298 |
299 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
300 |
301 | absolute_paths = [os.path.join(root, f) for f in files]
302 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
303 |
304 | if add_base_to_relative:
305 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
306 |
307 | assert len(absolute_paths) == len(relative_paths)
308 | result += zip(absolute_paths, relative_paths)
309 |
310 | return result
311 |
312 |
313 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
314 | """Takes in a list of tuples of (src, dst) paths and copies files.
315 | Will create all necessary directories."""
316 | for file in files:
317 | target_dir_name = os.path.dirname(file[1])
318 |
319 | # will create all intermediate-level directories
320 | if not os.path.exists(target_dir_name):
321 | os.makedirs(target_dir_name)
322 |
323 | shutil.copyfile(file[0], file[1])
324 |
325 |
326 | # URL helpers
327 | # ------------------------------------------------------------------------------------------
328 |
329 | def is_url(obj: Any) -> bool:
330 | """Determine whether the given object is a valid URL string."""
331 | if not isinstance(obj, str) or not "://" in obj:
332 | return False
333 | try:
334 | res = requests.compat.urlparse(obj)
335 | if not res.scheme or not res.netloc or not "." in res.netloc:
336 | return False
337 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
338 | if not res.scheme or not res.netloc or not "." in res.netloc:
339 | return False
340 | except:
341 | return False
342 | return True
343 |
344 |
345 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
346 | """Download the given URL and return a binary-mode file object to access the data."""
347 | if not is_url(url) and os.path.isfile(url):
348 | return open(url, 'rb')
349 |
350 | assert is_url(url)
351 | assert num_attempts >= 1
352 |
353 | # Lookup from cache.
354 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
355 | if cache_dir is not None:
356 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
357 | if len(cache_files) == 1:
358 | return open(cache_files[0], "rb")
359 |
360 | # Download.
361 | url_name = None
362 | url_data = None
363 | with requests.Session() as session:
364 | if verbose:
365 | print("Downloading %s ..." % url, end="", flush=True)
366 | for attempts_left in reversed(range(num_attempts)):
367 | try:
368 | with session.get(url) as res:
369 | res.raise_for_status()
370 | if len(res.content) == 0:
371 | raise IOError("No data received")
372 |
373 | if len(res.content) < 8192:
374 | content_str = res.content.decode("utf-8")
375 | if "download_warning" in res.headers.get("Set-Cookie", ""):
376 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
377 | if len(links) == 1:
378 | url = requests.compat.urljoin(url, links[0])
379 | raise IOError("Google Drive virus checker nag")
380 | if "Google Drive - Quota exceeded" in content_str:
381 | raise IOError("Google Drive quota exceeded")
382 |
383 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
384 | url_name = match[1] if match else url
385 | url_data = res.content
386 | if verbose:
387 | print(" done")
388 | break
389 | except:
390 | if not attempts_left:
391 | if verbose:
392 | print(" failed")
393 | raise
394 | if verbose:
395 | print(".", end="", flush=True)
396 |
397 | # Save to cache.
398 | if cache_dir is not None:
399 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
400 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
401 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
402 | os.makedirs(cache_dir, exist_ok=True)
403 | with open(temp_file, "wb") as f:
404 | f.write(url_data)
405 | os.replace(temp_file, cache_file) # atomic
406 |
407 | # Return data as file object.
408 | return io.BytesIO(url_data)
409 |
--------------------------------------------------------------------------------
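A minimal usage sketch for the helpers above (the URL and cache directory below are illustrative placeholders, not files shipped with this repository):

    import dnnlib.util as util

    # Resolve the shorthand "np.zeros" to numpy.zeros and call it by name.
    arr = util.call_func_by_name((2, 3), func_name="np.zeros")

    # Fetch a remote file (or a local/cached copy) as a binary file object.
    with util.open_url("https://example.com/model.pkl", cache_dir="cache") as f:
        data = f.read()
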
/encoder/generator_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import tensorflow as tf
3 | import numpy as np
4 | import dnnlib.tflib as tflib
5 | from functools import partial
6 |
7 |
8 | def create_stub(name, batch_size):
9 | return tf.constant(0, dtype='float32', shape=(batch_size, 0))
10 |
11 |
12 | def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size = 1):
13 | if tiled_dlatent:
14 | low_dim_dlatent = tf.get_variable('learnable_dlatents',
15 | shape=(batch_size, tile_size, 512),
16 | dtype='float32',
17 | initializer=tf.initializers.random_normal())
18 | return tf.tile(low_dim_dlatent, [1, model_scale // tile_size, 1])
19 | else:
20 | return tf.get_variable('learnable_dlatents',
21 | shape=(batch_size, model_scale, 512),
22 | dtype='float32',
23 | initializer=tf.initializers.random_normal())
24 |
25 |
26 | class Generator:
27 | def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False, initial=True):
28 | self.batch_size = batch_size
29 | self.tiled_dlatent=tiled_dlatent
30 | self.model_scale = int(2*(math.log(model_res, 2)-1)) # For example, 1024 -> 18
31 | if tiled_dlatent:
32 | self.initial_dlatents = np.zeros((self.batch_size, 1, 512))
33 | if initial:
34 | model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)),
35 | randomize_noise=randomize_noise, minibatch_size=self.batch_size,
36 | custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True, model_scale=self.model_scale),
37 | partial(create_stub, batch_size=batch_size)],
38 | structure='fixed')
39 | else:
40 | self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512))
41 | if initial:
42 | if custom_input is not None:
43 | model.components.synthesis.run(self.initial_dlatents,
44 | randomize_noise=randomize_noise, minibatch_size=self.batch_size,
45 | custom_inputs=[partial(custom_input.eval(), batch_size=batch_size), partial(create_stub, batch_size=batch_size)],
46 | structure='fixed')
47 | else:
48 | model.components.synthesis.run(self.initial_dlatents,
49 | randomize_noise=randomize_noise, minibatch_size=self.batch_size,
50 | custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale),
51 | partial(create_stub, batch_size=batch_size)],
52 | structure='fixed')
53 | self.dlatent_avg_def = model.get_var('dlatent_avg')
54 | self.reset_dlatent_avg()
55 | self.sess = tf.get_default_session()
56 | self.graph = tf.get_default_graph()
57 |
58 | self.dlatent_variable = next(v for v in tf.global_variables() if 'learnable_dlatents' in v.name)
59 | self._assign_dlatent_ph = tf.placeholder(tf.float32, name="assign_dlatent_ph")
60 |         self._assign_dlatent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph)
61 | self.set_dlatents(self.initial_dlatents)
62 |
63 | def get_tensor(name):
64 | try:
65 | return self.graph.get_tensor_by_name(name)
66 | except KeyError:
67 | return None
68 |
69 | self.generator_output = get_tensor('G_synthesis_1/_Run/concat:0')
70 | if self.generator_output is None:
71 | self.generator_output = get_tensor('G_synthesis_1/_Run/concat/concat:0')
72 | if self.generator_output is None:
73 | self.generator_output = get_tensor('G_synthesis_1/_Run/concat_1/concat:0')
74 | # If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist in the graph.
75 | if self.generator_output is None:
76 | self.generator_output = get_tensor('G_synthesis/_Run/concat:0')
77 | if self.generator_output is None:
78 | self.generator_output = get_tensor('G_synthesis/_Run/concat/concat:0')
79 | if self.generator_output is None:
80 | self.generator_output = get_tensor('G_synthesis/_Run/concat_1/concat:0')
81 | if self.generator_output is None:
82 | for op in self.graph.get_operations():
83 | print(op)
84 | raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output")
85 | self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
86 | self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)
87 |
88 | # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782
89 | # (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper,
90 | # so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.)
91 | clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold)
92 | clipped_values = tf.where(clipping_mask, tf.random_normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
93 | self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)
94 |
95 | def reset_dlatents(self):
96 | self.set_dlatents(self.initial_dlatents)
97 |
98 | def set_dlatents(self, dlatents):
99 | if self.tiled_dlatent:
100 | if (dlatents.shape != (self.batch_size, 1, 512)) and (dlatents.shape[1] != 512):
101 | dlatents = np.mean(dlatents, axis=1, keepdims=True)
102 | if (dlatents.shape != (self.batch_size, 1, 512)):
103 | dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], 1, 512))])
104 | assert (dlatents.shape == (self.batch_size, 1, 512))
105 | else:
106 | if (dlatents.shape[1] > self.model_scale):
107 | dlatents = dlatents[:,:self.model_scale,:]
108 | if (isinstance(dlatents.shape[0], int)):
109 | if (dlatents.shape != (self.batch_size, self.model_scale, 512)):
110 | dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], self.model_scale, 512))])
111 | assert (dlatents.shape == (self.batch_size, self.model_scale, 512))
112 |                 self.sess.run([self._assign_dlatent], {self._assign_dlatent_ph: dlatents})
113 | return
114 | else:
115 |                 self._assign_dlatent = tf.assign(self.dlatent_variable, dlatents)
116 | return
117 |         self.sess.run([self._assign_dlatent], {self._assign_dlatent_ph: dlatents})
118 |
119 | def stochastic_clip_dlatents(self):
120 | self.sess.run(self.stochastic_clip_op)
121 |
122 | def get_dlatents(self):
123 | return self.sess.run(self.dlatent_variable)
124 |
125 | def get_dlatent_avg(self):
126 | return self.dlatent_avg
127 |
128 | def set_dlatent_avg(self, dlatent_avg):
129 | self.dlatent_avg = dlatent_avg
130 |
131 | def reset_dlatent_avg(self):
132 | self.dlatent_avg = self.dlatent_avg_def
133 |
134 | def generate_images(self, dlatents=None):
135 | if dlatents is not None:
136 | self.set_dlatents(dlatents)
137 | return self.sess.run(self.generated_image_uint8)
138 |
--------------------------------------------------------------------------------
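A condensed sketch of how project_image.py (further below) drives this class; the StyleGAN pickle path matches the one used there and is assumed to have been downloaded already:

    import pickle
    import numpy as np
    import dnnlib.tflib as tflib
    from encoder.generator_model import Generator

    tflib.init_tf()
    with open('networks/karras2019stylegan-ffhq-1024x1024.pkl', 'rb') as f:
        _G, _D, Gs = pickle.load(f)

    generator = Generator(Gs, batch_size=1, model_res=1024)
    generator.set_dlatents(np.zeros((1, 18, 512)))  # 18 = model_scale for 1024px models
    img = generator.generate_images()[0]            # uint8 array of shape (1024, 1024, 3)
    generator.reset_dlatents()
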
/encoder/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 |
10 | from encoder.resnet import Resnet18
11 | # from modules.bn import InPlaceABNSync as BatchNorm2d
12 |
13 |
14 | class ConvBNReLU(nn.Module):
15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16 | super(ConvBNReLU, self).__init__()
17 | self.conv = nn.Conv2d(in_chan,
18 | out_chan,
19 | kernel_size=ks,
20 | stride=stride,
21 | padding=padding,
22 | bias=False)
23 | self.bn = nn.BatchNorm2d(out_chan)
24 | self.init_weight()
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | x = F.relu(self.bn(x))
29 | return x
30 |
31 | def init_weight(self):
32 | for ly in self.children():
33 | if isinstance(ly, nn.Conv2d):
34 | nn.init.kaiming_normal_(ly.weight, a=1)
35 | if not ly.bias is None:
36 | nn.init.constant_(ly.bias, 0)
37 |
38 |
39 | class BiSeNetOutput(nn.Module):
40 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
41 | super(BiSeNetOutput, self).__init__()
42 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
43 | self.conv_out = nn.Conv2d(
44 | mid_chan, n_classes, kernel_size=1, bias=False)
45 | self.init_weight()
46 |
47 | def forward(self, x):
48 | x = self.conv(x)
49 | x = self.conv_out(x)
50 | return x
51 |
52 | def init_weight(self):
53 | for ly in self.children():
54 | if isinstance(ly, nn.Conv2d):
55 | nn.init.kaiming_normal_(ly.weight, a=1)
56 | if not ly.bias is None:
57 | nn.init.constant_(ly.bias, 0)
58 |
59 | def get_params(self):
60 | wd_params, nowd_params = [], []
61 | for name, module in self.named_modules():
62 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
63 | wd_params.append(module.weight)
64 | if not module.bias is None:
65 | nowd_params.append(module.bias)
66 | elif isinstance(module, nn.BatchNorm2d):
67 | nowd_params += list(module.parameters())
68 | return wd_params, nowd_params
69 |
70 |
71 | class AttentionRefinementModule(nn.Module):
72 | def __init__(self, in_chan, out_chan, *args, **kwargs):
73 | super(AttentionRefinementModule, self).__init__()
74 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
75 | self.conv_atten = nn.Conv2d(
76 | out_chan, out_chan, kernel_size=1, bias=False)
77 | self.bn_atten = nn.BatchNorm2d(out_chan)
78 | self.sigmoid_atten = nn.Sigmoid()
79 | self.init_weight()
80 |
81 | def forward(self, x):
82 | feat = self.conv(x)
83 | atten = F.avg_pool2d(feat, feat.size()[2:])
84 | atten = self.conv_atten(atten)
85 | atten = self.bn_atten(atten)
86 | atten = self.sigmoid_atten(atten)
87 | out = torch.mul(feat, atten)
88 | return out
89 |
90 | def init_weight(self):
91 | for ly in self.children():
92 | if isinstance(ly, nn.Conv2d):
93 | nn.init.kaiming_normal_(ly.weight, a=1)
94 | if not ly.bias is None:
95 | nn.init.constant_(ly.bias, 0)
96 |
97 |
98 | class ContextPath(nn.Module):
99 | def __init__(self, *args, **kwargs):
100 | super(ContextPath, self).__init__()
101 | self.resnet = Resnet18()
102 | self.arm16 = AttentionRefinementModule(256, 128)
103 | self.arm32 = AttentionRefinementModule(512, 128)
104 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
105 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
106 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
107 |
108 | self.init_weight()
109 |
110 | def forward(self, x):
111 | H0, W0 = x.size()[2:]
112 | feat8, feat16, feat32 = self.resnet(x)
113 | H8, W8 = feat8.size()[2:]
114 | H16, W16 = feat16.size()[2:]
115 | H32, W32 = feat32.size()[2:]
116 |
117 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
118 | avg = self.conv_avg(avg)
119 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
120 |
121 | feat32_arm = self.arm32(feat32)
122 | feat32_sum = feat32_arm + avg_up
123 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
124 | feat32_up = self.conv_head32(feat32_up)
125 |
126 | feat16_arm = self.arm16(feat16)
127 | feat16_sum = feat16_arm + feat32_up
128 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
129 | feat16_up = self.conv_head16(feat16_up)
130 |
131 | return feat8, feat16_up, feat32_up # x8, x8, x16
132 |
133 | def init_weight(self):
134 | for ly in self.children():
135 | if isinstance(ly, nn.Conv2d):
136 | nn.init.kaiming_normal_(ly.weight, a=1)
137 | if not ly.bias is None:
138 | nn.init.constant_(ly.bias, 0)
139 |
140 | def get_params(self):
141 | wd_params, nowd_params = [], []
142 | for name, module in self.named_modules():
143 | if isinstance(module, (nn.Linear, nn.Conv2d)):
144 | wd_params.append(module.weight)
145 | if not module.bias is None:
146 | nowd_params.append(module.bias)
147 | elif isinstance(module, nn.BatchNorm2d):
148 | nowd_params += list(module.parameters())
149 | return wd_params, nowd_params
150 |
151 |
152 | # SpatialPath is not used; it is replaced by the ResNet feature map of the same spatial size
153 | class SpatialPath(nn.Module):
154 | def __init__(self, *args, **kwargs):
155 | super(SpatialPath, self).__init__()
156 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
157 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
158 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
159 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
160 | self.init_weight()
161 |
162 | def forward(self, x):
163 | feat = self.conv1(x)
164 | feat = self.conv2(feat)
165 | feat = self.conv3(feat)
166 | feat = self.conv_out(feat)
167 | return feat
168 |
169 | def init_weight(self):
170 | for ly in self.children():
171 | if isinstance(ly, nn.Conv2d):
172 | nn.init.kaiming_normal_(ly.weight, a=1)
173 | if not ly.bias is None:
174 | nn.init.constant_(ly.bias, 0)
175 |
176 | def get_params(self):
177 | wd_params, nowd_params = [], []
178 | for name, module in self.named_modules():
179 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
180 | wd_params.append(module.weight)
181 | if not module.bias is None:
182 | nowd_params.append(module.bias)
183 | elif isinstance(module, nn.BatchNorm2d):
184 | nowd_params += list(module.parameters())
185 | return wd_params, nowd_params
186 |
187 |
188 | class FeatureFusionModule(nn.Module):
189 | def __init__(self, in_chan, out_chan, *args, **kwargs):
190 | super(FeatureFusionModule, self).__init__()
191 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
192 | self.conv1 = nn.Conv2d(out_chan,
193 | out_chan//4,
194 | kernel_size=1,
195 | stride=1,
196 | padding=0,
197 | bias=False)
198 | self.conv2 = nn.Conv2d(out_chan//4,
199 | out_chan,
200 | kernel_size=1,
201 | stride=1,
202 | padding=0,
203 | bias=False)
204 | self.relu = nn.ReLU(inplace=True)
205 | self.sigmoid = nn.Sigmoid()
206 | self.init_weight()
207 |
208 | def forward(self, fsp, fcp):
209 | fcat = torch.cat([fsp, fcp], dim=1)
210 | feat = self.convblk(fcat)
211 | atten = F.avg_pool2d(feat, feat.size()[2:])
212 | atten = self.conv1(atten)
213 | atten = self.relu(atten)
214 | atten = self.conv2(atten)
215 | atten = self.sigmoid(atten)
216 | feat_atten = torch.mul(feat, atten)
217 | feat_out = feat_atten + feat
218 | return feat_out
219 |
220 | def init_weight(self):
221 | for ly in self.children():
222 | if isinstance(ly, nn.Conv2d):
223 | nn.init.kaiming_normal_(ly.weight, a=1)
224 | if not ly.bias is None:
225 | nn.init.constant_(ly.bias, 0)
226 |
227 | def get_params(self):
228 | wd_params, nowd_params = [], []
229 | for name, module in self.named_modules():
230 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
231 | wd_params.append(module.weight)
232 | if not module.bias is None:
233 | nowd_params.append(module.bias)
234 | elif isinstance(module, nn.BatchNorm2d):
235 | nowd_params += list(module.parameters())
236 | return wd_params, nowd_params
237 |
238 |
239 | class BiSeNet(nn.Module):
240 | def __init__(self, n_classes, *args, **kwargs):
241 | super(BiSeNet, self).__init__()
242 | self.cp = ContextPath()
243 | # here self.sp is deleted
244 | self.ffm = FeatureFusionModule(256, 256)
245 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
246 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
247 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
248 | self.init_weight()
249 |
250 | def forward(self, x):
251 | H, W = x.size()[2:]
252 | feat_res8, feat_cp8, feat_cp16 = self.cp(
253 | x) # here return res3b1 feature
254 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
255 | feat_fuse = self.ffm(feat_sp, feat_cp8)
256 |
257 | feat_out = self.conv_out(feat_fuse)
258 | feat_out16 = self.conv_out16(feat_cp8)
259 | feat_out32 = self.conv_out32(feat_cp16)
260 |
261 | feat_out = F.interpolate(
262 | feat_out, (H, W), mode='bilinear', align_corners=True)
263 | feat_out16 = F.interpolate(
264 | feat_out16, (H, W), mode='bilinear', align_corners=True)
265 | feat_out32 = F.interpolate(
266 | feat_out32, (H, W), mode='bilinear', align_corners=True)
267 | return feat_out, feat_out16, feat_out32
268 |
269 | def init_weight(self):
270 | for ly in self.children():
271 | if isinstance(ly, nn.Conv2d):
272 | nn.init.kaiming_normal_(ly.weight, a=1)
273 | if not ly.bias is None:
274 | nn.init.constant_(ly.bias, 0)
275 |
276 | def get_params(self):
277 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
278 | for name, child in self.named_children():
279 | child_wd_params, child_nowd_params = child.get_params()
280 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
281 | lr_mul_wd_params += child_wd_params
282 | lr_mul_nowd_params += child_nowd_params
283 | else:
284 | wd_params += child_wd_params
285 | nowd_params += child_nowd_params
286 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
287 |
288 |
289 | if __name__ == "__main__":
290 | net = BiSeNet(19)
291 | net.cuda()
292 | net.eval()
293 | in_ten = torch.randn(16, 3, 640, 480).cuda()
294 | out, out16, out32 = net(in_ten)
295 | print(out.shape)
296 |
297 | net.get_params()
298 |
--------------------------------------------------------------------------------
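A short sketch of how main.py (further below) uses this network to obtain a per-pixel parsing map; the checkpoint path is the one main.py expects, and aligned_image_np stands in for an aligned RGB face crop:

    import torch
    import torchvision.transforms as transforms
    from encoder.model import BiSeNet

    net = BiSeNet(n_classes=19).cuda().eval()
    net.load_state_dict(torch.load('networks/79999_iter.pth'))

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = to_tensor(aligned_image_np).unsqueeze(0).cuda()       # HxWx3 uint8 -> 1x3xHxW float
    out = net(img)[0]                                           # fused (main) output head
    parsing = out.detach().squeeze(0).cpu().numpy().argmax(0)   # per-pixel class labels
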
/encoder/perceptual_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import tensorflow as tf
3 | #import tensorflow_probability as tfp
4 | #tf.enable_eager_execution()
5 |
6 | import os
7 | import bz2
8 | import PIL.Image
9 | from PIL import ImageFilter
10 | import numpy as np
11 | from keras.models import Model
12 | from keras.utils import get_file
13 | from keras.applications.vgg16 import VGG16, preprocess_input
14 | import keras.backend as K
15 | import traceback
16 | import dnnlib.tflib as tflib
17 |
18 | def load_image(image, image_size=256, sharpen=False):
19 | loaded_images = list()
20 | img = image.convert('RGB')
21 | if image_size is not None:
22 | img = img.resize((image_size, image_size), PIL.Image.LANCZOS)
23 | if (sharpen):
24 | img = img.filter(ImageFilter.DETAIL)
25 | img = np.array(img)
26 | img = np.expand_dims(img, 0)
27 | loaded_images.append(img)
28 | loaded_images = np.vstack(loaded_images)
29 | return loaded_images
30 |
31 | def tf_custom_adaptive_loss(a,b):
32 | from adaptive import lossfun
33 | shape = a.get_shape().as_list()
34 | dim = np.prod(shape[1:])
35 | a = tf.reshape(a, [-1, dim])
36 | b = tf.reshape(b, [-1, dim])
37 | loss, _, _ = lossfun(b-a, var_suffix='1')
38 | return tf.math.reduce_mean(loss)
39 |
40 | def tf_custom_adaptive_rgb_loss(a,b):
41 | from adaptive import image_lossfun
42 | loss, _, _ = image_lossfun(b-a, color_space='RGB', representation='PIXEL')
43 | return tf.math.reduce_mean(loss)
44 |
45 | def tf_custom_l1_loss(img1,img2):
46 | return tf.math.reduce_mean(tf.math.abs(img2-img1), axis=None)
47 |
48 | def tf_custom_logcosh_loss(img1,img2):
49 | return tf.math.reduce_mean(tf.keras.losses.logcosh(img1,img2))
50 |
51 | def create_stub(batch_size):
52 | return tf.constant(0, dtype='float32', shape=(batch_size, 0))
53 |
54 | def unpack_bz2(src_path):
55 | data = bz2.BZ2File(src_path).read()
56 | dst_path = src_path[:-4]
57 | with open(dst_path, 'wb') as fp:
58 | fp.write(data)
59 | return dst_path
60 |
61 | class PerceptualModel:
62 | def __init__(self, args, batch_size=1, perc_model=None, sess=None):
63 | self.sess = tf.get_default_session() if sess is None else sess
64 | K.set_session(self.sess)
65 | self.epsilon = 0.00000001
66 | self.lr = args.lr
67 | self.decay_rate = args.decay_rate
68 | self.decay_steps = args.decay_steps
69 | self.img_size = args.image_size
70 | self.layer = args.use_vgg_layer
71 | self.vgg_loss = args.use_vgg_loss
72 | if (self.layer <= 0 or self.vgg_loss <= self.epsilon):
73 | self.vgg_loss = None
74 | self.pixel_loss = args.use_pixel_loss
75 | if (self.pixel_loss <= self.epsilon):
76 | self.pixel_loss = None
77 | self.mssim_loss = args.use_mssim_loss
78 | if (self.mssim_loss <= self.epsilon):
79 | self.mssim_loss = None
80 | self.lpips_loss = args.use_lpips_loss
81 | if (self.lpips_loss <= self.epsilon):
82 | self.lpips_loss = None
83 | self.l1_penalty = args.use_l1_penalty
84 | if (self.l1_penalty <= self.epsilon):
85 | self.l1_penalty = None
86 | self.adaptive_loss = args.use_adaptive_loss
87 | self.sharpen_input = args.sharpen_input
88 | self.batch_size = batch_size
89 | if perc_model is not None and self.lpips_loss is not None:
90 | self.perc_model = perc_model
91 | else:
92 | self.perc_model = None
93 | self.ref_img = None
94 | self.ref_weight = None
95 | self.perceptual_model = None
96 | self.ref_img_features = None
97 | self.features_weight = None
98 | self.loss = None
99 | self.discriminator_loss = args.use_discriminator_loss
100 | if (self.discriminator_loss <= self.epsilon):
101 | self.discriminator_loss = None
102 | if self.discriminator_loss is not None:
103 | self.discriminator = None
104 | self.stub = create_stub(batch_size)
105 |
106 | def add_placeholder(self, var_name):
107 | var_val = getattr(self, var_name)
108 | setattr(self, var_name + "_placeholder", tf.placeholder(var_val.dtype, shape=var_val.get_shape()))
109 | setattr(self, var_name + "_op", var_val.assign(getattr(self, var_name + "_placeholder")))
110 |
111 | def assign_placeholder(self, var_name, var_val):
112 | self.sess.run(getattr(self, var_name + "_op"), {getattr(self, var_name + "_placeholder"): var_val})
113 |
114 | def build_perceptual_model(self, generator, discriminator=None):
115 | # Learning rate
116 | global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
117 | incremented_global_step = tf.assign_add(global_step, 1)
118 | self._reset_global_step = tf.assign(global_step, 0)
119 | self.learning_rate = tf.train.exponential_decay(self.lr, incremented_global_step,
120 | self.decay_steps, self.decay_rate, staircase=True)
121 | self.sess.run([self._reset_global_step])
122 |
123 | if self.discriminator_loss is not None:
124 | self.discriminator = discriminator
125 |
126 | generated_image_tensor = generator.generated_image
127 | generated_image = tf.image.resize_nearest_neighbor(generated_image_tensor,
128 | (self.img_size, self.img_size), align_corners=True)
129 |
130 | self.ref_img = tf.get_variable('ref_img', shape=generated_image.shape,
131 | dtype='float32', initializer=tf.initializers.zeros())
132 | self.ref_weight = tf.get_variable('ref_weight', shape=generated_image.shape,
133 | dtype='float32', initializer=tf.initializers.zeros())
134 | self.add_placeholder("ref_img")
135 | self.add_placeholder("ref_weight")
136 |
137 | if (self.vgg_loss is not None):
138 | vgg16 = VGG16(include_top=False, input_shape=(self.img_size, self.img_size, 3))
139 | self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output)
140 | generated_img_features = self.perceptual_model(preprocess_input(self.ref_weight * generated_image))
141 | self.ref_img_features = tf.get_variable('ref_img_features', shape=generated_img_features.shape,
142 | dtype='float32', initializer=tf.initializers.zeros())
143 | self.features_weight = tf.get_variable('features_weight', shape=generated_img_features.shape,
144 | dtype='float32', initializer=tf.initializers.zeros())
145 |             self.sess.run([self.ref_img_features.initializer, self.features_weight.initializer])
146 | self.add_placeholder("ref_img_features")
147 | self.add_placeholder("features_weight")
148 |
149 | if self.perc_model is not None and self.lpips_loss is not None:
150 | img1 = tflib.convert_images_from_uint8(self.ref_weight * self.ref_img, nhwc_to_nchw=True)
151 | img2 = tflib.convert_images_from_uint8(self.ref_weight * generated_image, nhwc_to_nchw=True)
152 |
153 | self.loss = 0
154 | # L1 loss on VGG16 features
155 | if (self.vgg_loss is not None):
156 | if self.adaptive_loss:
157 | self.loss += self.vgg_loss * tf_custom_adaptive_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
158 | else:
159 | self.loss += self.vgg_loss * tf_custom_logcosh_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
160 | # + logcosh loss on image pixels
161 | if (self.pixel_loss is not None):
162 | if self.adaptive_loss:
163 | self.loss += self.pixel_loss * tf_custom_adaptive_rgb_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
164 | else:
165 | self.loss += self.pixel_loss * tf_custom_logcosh_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
166 |         # + MS-SSIM loss on image pixels
167 | if (self.mssim_loss is not None):
168 | self.loss += self.mssim_loss * tf.math.reduce_mean(1-tf.image.ssim_multiscale(self.ref_weight * self.ref_img, self.ref_weight * generated_image, 1))
169 | # + extra perceptual loss on image pixels
170 | if self.perc_model is not None and self.lpips_loss is not None:
171 | self.loss += self.lpips_loss * tf.math.reduce_mean(self.perc_model.get_output_for(img1, img2))
172 | # + L1 penalty on dlatent weights
173 | if self.l1_penalty is not None:
174 | self.loss += self.l1_penalty * 512 * tf.math.reduce_mean(tf.math.abs(generator.dlatent_variable-generator.get_dlatent_avg()))
175 | # discriminator loss (realism)
176 | if self.discriminator_loss is not None:
177 | self.loss += self.discriminator_loss * tf.math.reduce_mean(self.discriminator.get_output_for(tflib.convert_images_from_uint8(generated_image_tensor, nhwc_to_nchw=True), self.stub))
178 | # - discriminator_network.get_output_for(tflib.convert_images_from_uint8(ref_img, nhwc_to_nchw=True), stub)
179 |
180 |
181 | def generate_face_mask(self, im):
182 | from imutils import face_utils
183 | import cv2
184 | rects = self.detector(im, 1)
185 | # loop over the face detections
186 | for (j, rect) in enumerate(rects):
187 | """
188 | Determine the facial landmarks for the face region, then convert the facial landmark (x, y)-coordinates to a NumPy array
189 | """
190 | shape = self.predictor(im, rect)
191 | shape = face_utils.shape_to_np(shape)
192 |
193 | # we extract the face
194 | vertices = cv2.convexHull(shape)
195 | mask = np.zeros(im.shape[:2],np.uint8)
196 | cv2.fillConvexPoly(mask, vertices, 1)
197 | if self.use_grabcut:
198 | bgdModel = np.zeros((1,65),np.float64)
199 | fgdModel = np.zeros((1,65),np.float64)
200 |                 rect = (0, 0, im.shape[1], im.shape[0])  # (x, y, w, h) covering the whole image
201 | (x,y),radius = cv2.minEnclosingCircle(vertices)
202 | center = (int(x),int(y))
203 | radius = int(radius*self.scale_mask)
204 | mask = cv2.circle(mask,center,radius,cv2.GC_PR_FGD,-1)
205 | cv2.fillConvexPoly(mask, vertices, cv2.GC_FGD)
206 | cv2.grabCut(im,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_MASK)
207 | mask = np.where((mask==2)|(mask==0),0,1)
208 | return mask
209 |
210 | def set_reference_image(self, image):
211 | loaded_image = load_image(image, self.img_size, sharpen=self.sharpen_input)
212 | image_features = None
213 | if self.perceptual_model is not None:
214 | image_features = self.perceptual_model.predict_on_batch(preprocess_input(np.array(loaded_image)))
215 | weight_mask = np.ones(self.features_weight.shape)
216 | if image_features is not None:
217 | self.assign_placeholder("features_weight", weight_mask)
218 | self.assign_placeholder("ref_img_features", image_features)
219 | image_mask = np.ones(self.ref_weight.shape)
220 | self.assign_placeholder("ref_weight", image_mask)
221 | self.assign_placeholder("ref_img", loaded_image)
222 |
223 | def optimize(self, vars_to_optimize, iterations=200, use_optimizer='adam'):
224 | vars_to_optimize = vars_to_optimize if isinstance(vars_to_optimize, list) else [vars_to_optimize]
225 | if use_optimizer == 'lbfgs':
226 | optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.loss, var_list=vars_to_optimize, method='L-BFGS-B', options={'maxiter': iterations})
227 | else:
228 | if use_optimizer == 'ggt':
229 | optimizer = tf.contrib.opt.GGTOptimizer(learning_rate=self.learning_rate)
230 | else:
231 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
232 | min_op = optimizer.minimize(self.loss, var_list=[vars_to_optimize])
233 | self.sess.run(tf.variables_initializer(optimizer.variables()))
234 | fetch_ops = [min_op, self.loss, self.learning_rate]
235 | #min_op = optimizer.minimize(self.sess)
236 | #optim_results = tfp.optimizer.lbfgs_minimize(make_val_and_grad_fn(get_loss), initial_position=vars_to_optimize, num_correction_pairs=10, tolerance=1e-8)
237 | self.sess.run(self._reset_global_step)
238 | #self.sess.graph.finalize() # Graph is read-only after this statement.
239 | for _ in range(iterations):
240 | if use_optimizer == 'lbfgs':
241 | optimizer.minimize(self.sess, fetches=[vars_to_optimize, self.loss])
242 | yield {"loss":self.loss.eval()}
243 | else:
244 | _, loss, lr = self.sess.run(fetch_ops)
245 | yield {"loss":loss,"lr":lr}
246 |
--------------------------------------------------------------------------------
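A minimal sketch of the optimization loop this class is built for (project_image.py below contains the full version). Here args is assumed to be an argparse.Namespace carrying the fields read in __init__ above, with use_discriminator_loss set to 0 so no discriminator is needed; generator and image are assumed to exist as in the Generator sketch earlier:

    from encoder.perceptual_model import PerceptualModel

    perceptual_model = PerceptualModel(args, batch_size=1)
    perceptual_model.build_perceptual_model(generator)  # generator: encoder.generator_model.Generator
    perceptual_model.set_reference_image(image)         # image: a PIL.Image

    # optimize() is a Python generator yielding the loss (and learning rate) per step.
    for step in perceptual_model.optimize(generator.dlatent_variable, iterations=200):
        print(step["loss"])
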
/encoder/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 | if not module.bias is None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
/input/test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test1.jpg
--------------------------------------------------------------------------------
/input/test2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test2.jpeg
--------------------------------------------------------------------------------
/input/test3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test3.jpg
--------------------------------------------------------------------------------
/input/test4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test4.jpeg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | #import project_image_without_optimizer as projector  # much faster, but lower-quality results
4 | import project_image as projector
5 | from encoder.model import BiSeNet
6 | import torch
7 | import torchvision.transforms as transforms
8 | import numpy as np
9 | import PIL.Image
10 | from PIL import ImageFilter
11 | import cv2
12 | from tools.face_alignment import image_align
13 | from tools.landmarks_detector import LandmarksDetector
14 | from tools import functions
15 | from skimage.measure import label
16 |
17 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
18 | # Colors for all 20 parts
19 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
20 | [255, 0, 85], [255, 0, 170],
21 | [0, 255, 0], [85, 255, 0], [170, 255, 0],
22 | [0, 255, 85], [0, 255, 170],
23 | [0, 0, 255], [85, 0, 255], [170, 0, 255],
24 | [0, 85, 255], [0, 170, 255],
25 | [255, 255, 0], [255, 255, 85], [255, 255, 170],
26 | [255, 0, 255], [255, 85, 255], [255, 170, 255],
27 | [0, 255, 255], [85, 255, 255], [170, 255, 255]]
28 |
29 | im = np.array(im)
30 | vis_im = im.copy().astype(np.uint8)
31 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
32 | vis_parsing_anno = cv2.resize(
33 | vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
34 |
35 | vis_parsing_anno_color = np.zeros(
36 | (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
37 | mask = np.zeros(
38 | (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1]), dtype=np.uint8)
39 | num_of_class = np.max(vis_parsing_anno)
40 |
41 |     # keep the face-region parsing classes (1-5 and 10-13); everything else is background
42 |     index = np.where(((vis_parsing_anno >= 1) & (vis_parsing_anno <= 5)) |
43 |                      ((vis_parsing_anno >= 10) & (vis_parsing_anno <= 13)))
44 |     mask[index[0], index[1]] = 1
45 | 
46 | return mask
47 |
48 | def find_max_region(bw_img):  # keep the largest connected component of the parsing mask
49 | labeled_img, num = label(bw_img, background=0, return_num=True)
50 | max_label = 0
51 | max_num = 0
52 | for i in range(1, num + 1):
53 | if np.sum(labeled_img == i) > max_num:
54 | max_num = np.sum(labeled_img == i)
55 | max_label = i
56 | lcc = (labeled_img == max_label)
57 | return lcc
58 |
59 |
60 | def main():
61 | """
62 | Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
63 | """
64 | parser = argparse.ArgumentParser(description='Model Face Swap', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
65 |     parser.add_argument('--input_img', type=str, default='input/test1.jpg', help='Path to the input image for face swap')
66 | parser.add_argument('--output_dir', type=str, default='output/', help='Directory for storing changed images')
67 | parser.add_argument('--project_style', type=str, default='model', help='model/pop-star/kids/wanghong...')
68 | parser.add_argument('--record', type=bool, default=True, help='Recording process')
69 | parser.add_argument('--landmark_path', type=str, default='networks/shape_predictor_68_face_landmarks.dat', help='face landmark file path')
70 | parser.add_argument('--parsing_path', type=str, default='networks/79999_iter.pth', help='parsing model path')
71 | args, _ = parser.parse_known_args()
72 |
73 | landmarks_detector = LandmarksDetector(args.landmark_path)
74 | parse_net = BiSeNet(n_classes=19)
75 | parse_net.cuda()
76 | parse_net.load_state_dict(torch.load(args.parsing_path))
77 | parse_net.eval()
78 | to_tensor = transforms.Compose([
79 | transforms.ToTensor(),
80 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
81 | ])
82 |
83 | os.makedirs(args.output_dir, exist_ok=True)
84 | dst_path = os.path.join(args.output_dir, args.input_img.rsplit('/', 1)[1].split('.')[0])+'_to-'+args.project_style+'/'
85 | os.makedirs(dst_path, exist_ok=True)
86 | ori_img = cv2.imread(args.input_img)
87 | face_data = {'aligned_images': [], 'masks': [], 'crops': [], 'pads': [], 'quads': [], 'record_paths': []}
88 | print('Step1 - Face alignment and mask extraction...')
89 | for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(args.input_img), start=1):
90 | if i == 1:
91 | cv2.imwrite(dst_path + 'input.png', ori_img)
92 | if args.record:
93 | record_path = dst_path + 'face'+str(i) + '/'
94 | face_data['record_paths'].append(record_path)
95 | os.makedirs(record_path, exist_ok=True)
96 |
97 | # face aligned
98 | aligned_image, crop, pad, quad = image_align(args.input_img, face_landmarks, output_size=1024, x_scale=1, y_scale=1, em_scale=0.1)
99 | face_data['aligned_images'].append(aligned_image)
100 | face_data['crops'].append(crop)
101 | face_data['pads'].append(pad)
102 | face_data['quads'].append(quad)
103 | if args.record:
104 | aligned_image.save(record_path+'face_input.png', 'PNG')
105 |
106 | # mask extraction
107 | image_sharp = aligned_image.filter(ImageFilter.DETAIL)
108 |         aligned_image_np = np.array(image_sharp)
109 |         img = to_tensor(aligned_image_np)
110 | img = torch.unsqueeze(img, 0)
111 | img = img.cuda()
112 | out = parse_net(img)[0]
113 | parsing = out.detach().squeeze(0).cpu().numpy().argmax(0)
114 |         mask = vis_parsing_maps(aligned_image_np, parsing, stride=1)
115 | mask = find_max_region(mask)
116 | mask = (255 * mask).astype('uint8')
117 | mask = PIL.Image.fromarray(mask, 'L')
118 | face_data['masks'].append(mask)
119 | if args.record:
120 | mask.save(record_path+'face_mask.png', 'PNG')
121 |
122 | print('Step2 - Face projection and mixing back...')
123 | projected_images, dlatents = projector.project(face_data['aligned_images'], face_data['masks'], args.project_style)
124 | merged_image = ori_img
125 | for projected_image, dlatent, crop, quad, pad, record_path, mask in zip(projected_images, dlatents,
126 | face_data['crops'], face_data['quads'], face_data['pads'], face_data['record_paths'], face_data['masks']):
127 | if args.record:
128 | projected_image.save(record_path+'face_output.png', 'PNG')
129 | np.save(record_path+'dlatent.npy', dlatent)
130 | merged_image = functions.merge_image(merged_image, projected_image, mask, crop, quad, pad)
131 | cv2.imwrite(dst_path+'output.png', merged_image)
132 |
133 | if __name__ == "__main__":
134 | main()
135 |
--------------------------------------------------------------------------------
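Typical invocation (paths follow the argparse defaults above; the weight files referenced under networks/ must be downloaded first, see networks/download_weights.txt below):

    python main.py --input_img input/test1.jpg --project_style model --output_dir output/
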
/networks/download_weights.txt:
--------------------------------------------------------------------------------
1 | Packaged model weights download:
2 | Link: https://pan.baidu.com/s/1yr_QNpHrXvq4PegMZBzDGA
3 | Extraction code: v4xt
4 |
5 |
--------------------------------------------------------------------------------
/pics/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/architecture.png
--------------------------------------------------------------------------------
/pics/example_2kids.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/example_2kids.jpg
--------------------------------------------------------------------------------
/pics/example_2wanghong.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/example_2wanghong.png
--------------------------------------------------------------------------------
/pics/examples_mix.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/examples_mix.jpg
--------------------------------------------------------------------------------
/pics/multi-model-solution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/multi-model-solution.png
--------------------------------------------------------------------------------
/pics/preview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/preview.jpg
--------------------------------------------------------------------------------
/pics/single_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/single_input.png
--------------------------------------------------------------------------------
/pics/single_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/single_output.png
--------------------------------------------------------------------------------
/project_image.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import argparse
4 | import pickle
5 | from tqdm import tqdm
6 | import PIL.Image
7 | from PIL import ImageFilter
8 | import numpy as np
9 | import dnnlib
10 | import dnnlib.tflib as tflib
11 | import tensorflow as tf
12 | from encoder.generator_model import Generator
13 | from encoder.perceptual_model import PerceptualModel, load_image
14 | #from tensorflow.keras.models import load_model
15 | from keras.models import load_model
16 | from keras.applications.resnet50 import preprocess_input
17 |
18 | def split_to_batches(l, n):
19 | for i in range(0, len(l), n):
20 | yield l[i:i + n]
21 |
22 | def str2bool(v):
23 | if isinstance(v, bool):
24 | return v
25 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
26 | return True
27 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
28 | return False
29 | else:
30 | raise argparse.ArgumentTypeError('Boolean value expected.')
31 |
32 | def project(images, masks, projector_name):
33 | parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
34 | parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
35 | parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
36 | parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
37 | parser.add_argument('--optimizer', default='ggt', help='Optimization algorithm used for optimizing dlatents')
38 |
39 | # Perceptual model params
40 | parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
41 | parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int)
42 | parser.add_argument('--lr', default=0.25, help='Learning rate for perceptual model', type=float)
43 | parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float)
44 | parser.add_argument('--iterations', default=1000, help='Number of optimization steps for each batch', type=int)
45 | parser.add_argument('--decay_steps', default=4, help='Decay steps for learning rate decay (as a percent of iterations)', type=float)
46 | parser.add_argument('--early_stopping', default=True, help='Stop early once training stabilizes', type=str2bool, nargs='?', const=True)
47 | parser.add_argument('--early_stopping_threshold', default=0.5, help='Stop after this threshold has been reached', type=float)
48 | parser.add_argument('--early_stopping_patience', default=10, help='Number of iterations to wait below threshold', type=int)
49 | parser.add_argument('--load_resnet', default='networks/finetuned_resnet.h5', help='Model to load for ResNet approximation of dlatents')
50 |     parser.add_argument('--use_preprocess_input', default=True, help='Call preprocess_input() first before using feed forward net', type=str2bool, nargs='?', const=True)
51 | parser.add_argument('--use_best_loss', default=True, help='Output the lowest loss value found as the solution', type=str2bool, nargs='?', const=True)
52 | parser.add_argument('--average_best_loss', default=0.25, help='Do a running weighted average with the previous best dlatents found', type=float)
53 | parser.add_argument('--sharpen_input', default=True, help='Sharpen the input images', type=str2bool, nargs='?', const=True)
54 |
55 | # Loss function options
56 | parser.add_argument('--use_vgg_loss', default=0.4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.', type=float)
57 | parser.add_argument('--use_vgg_layer', default=9, help='Pick which VGG layer to use.', type=int)
58 | parser.add_argument('--use_pixel_loss', default=1.5, help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float)
59 |     parser.add_argument('--use_mssim_loss', default=200, help='Use MS-SSIM perceptual loss; 0 to disable, > 0 to scale.', type=float)
60 | parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.', type=float)
61 | parser.add_argument('--use_l1_penalty', default=0.5, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.', type=float)
62 | parser.add_argument('--use_discriminator_loss', default=0.5, help='Use trained discriminator to evaluate realism.', type=float)
63 | parser.add_argument('--use_adaptive_loss', default=False, help='Use the adaptive robust loss function from Google Research for pixel and VGG feature loss.', type=str2bool, nargs='?', const=True)
64 |
65 | # Generator params
66 | parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=str2bool, nargs='?', const=True)
67 | parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=str2bool, nargs='?', const=True)
68 |     parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of dlatent values outside of this threshold', type=float)
69 |
70 | # Masking params
71 | parser.add_argument('--composite_blur', default=8, help='Size of blur filter to smoothly composite the images', type=int)
72 |
73 | args, other_args = parser.parse_known_args()
74 | args.decay_steps *= 0.01 * args.iterations # Calculate steps as a percent of total iterations
75 |
76 | # Initialize generator and perceptual model
77 | tflib.init_tf()
78 | with open('networks/karras2019stylegan-ffhq-1024x1024.pkl','rb') as f:
79 | generator_network, discriminator_network, Gs_network = pickle.load(f)
80 |
81 | generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise)
82 | if (args.dlatent_avg != ''):
83 | generator.set_dlatent_avg(np.load(args.dlatent_avg))
84 |
85 | perc_model = None
86 | if (args.use_lpips_loss > 0.00000001):
87 | with open('networks/vgg16_zhang_perceptual.pkl', 'rb') as f:
88 | perc_model = pickle.load(f)
89 | perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)
90 | perceptual_model.build_perceptual_model(generator, discriminator_network)
91 |
92 | ff_model = None
93 |
94 | # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
95 | best_dlatents = []
96 | for image, mask in zip(images, masks):
97 | perceptual_model.set_reference_image(image)
98 | dlatents = None
99 | if (ff_model is None):
100 | if os.path.exists(args.load_resnet):
101 | print("Loading ResNet Model:")
102 | ff_model = load_model(args.load_resnet)
103 | if (ff_model is not None): # predict initial dlatents with ResNet model
104 | if (args.use_preprocess_input):
105 | dlatents = ff_model.predict(preprocess_input(load_image(image, image_size=args.resnet_image_size)))
106 | else:
107 | dlatents = ff_model.predict(load_image(image,image_size=args.resnet_image_size))
108 | if dlatents is not None:
109 | generator.set_dlatents(dlatents)
110 | op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations, use_optimizer=args.optimizer)
111 | pbar = tqdm(op, leave=False, total=args.iterations)
112 | best_loss = None
113 | best_dlatent = None
114 | avg_loss_count = 0
115 | if args.early_stopping:
116 | avg_loss = prev_loss = None
117 | for loss_dict in pbar:
118 | if args.early_stopping: # early stopping feature
119 | if prev_loss is not None:
120 | if avg_loss is not None:
121 | avg_loss = 0.5 * avg_loss + (prev_loss - loss_dict["loss"])
122 | if avg_loss < args.early_stopping_threshold: # count while under threshold; else reset
123 | avg_loss_count += 1
124 | else:
125 | avg_loss_count = 0
126 | if avg_loss_count > args.early_stopping_patience: # stop once threshold is reached
127 | break
128 | else:
129 | avg_loss = prev_loss - loss_dict["loss"]
130 |             pbar.set_description(" Optimizing dlatent: " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
131 | if best_loss is None or loss_dict["loss"] < best_loss:
132 | if best_dlatent is None or args.average_best_loss <= 0.00000001:
133 | best_dlatent = generator.get_dlatents()
134 | else:
135 | best_dlatent = 0.25 * best_dlatent + 0.75 * generator.get_dlatents()
136 | if args.use_best_loss:
137 | generator.set_dlatents(best_dlatent)
138 | best_loss = loss_dict["loss"]
139 | generator.stochastic_clip_dlatents()
140 | prev_loss = loss_dict["loss"]
141 | if not args.use_best_loss:
142 | best_loss = prev_loss
143 | best_dlatents.append(best_dlatent)
144 | print("\n Optimizing dlatent Best Loss {:.4f}".format(best_loss))
145 |
146 | # Using Projector to generate images
147 | tflib.init_tf()
148 | with open('networks/projector_'+projector_name+'.pkl', 'rb') as f:
149 | Gs_network = pickle.load(f)
150 | generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold,
151 | tiled_dlatent=args.tile_dlatents, model_res=args.model_res,
152 | randomize_noise=args.randomize_noise)
153 | imgs = []
154 | for best_dlatent, image, mask in zip(best_dlatents, images, masks):
155 | generator.set_dlatents(best_dlatent)
156 | img_array = generator.generate_images()[0]
157 | generator.reset_dlatents()
158 |
159 | # Merge images with new face
160 | width, height = image.size
161 | mask = mask.resize((width, height))
162 | mask = mask.filter(ImageFilter.GaussianBlur(args.composite_blur))
163 | mask = np.array(mask) / 255
164 | mask = np.expand_dims(mask, axis=-1)
165 | img_array = mask * np.array(img_array) + (1.0 - mask) * np.array(image)
166 | img_array = img_array.astype(np.uint8)
167 | img = PIL.Image.fromarray(img_array, 'RGB')
168 | imgs.append(img)
169 |
170 | return imgs, best_dlatents
--------------------------------------------------------------------------------
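
The optimization loop in project_image.py above keeps the best dlatent seen so far and stops early once the smoothed per-step loss improvement stays below --early_stopping_threshold for more than --early_stopping_patience consecutive steps. The standalone sketch below (illustrative only: made-up loss values and a hypothetical stop_step() helper) restates that stopping rule outside the TensorFlow loop:

# Hedged restatement of the early-stopping rule used in project_image.py.
def stop_step(losses, threshold=0.02, patience=10):
    prev_loss = avg_loss = None
    avg_loss_count = 0
    for step, loss in enumerate(losses):
        if prev_loss is not None:
            if avg_loss is None:
                avg_loss = prev_loss - loss
            else:
                avg_loss = 0.5 * avg_loss + (prev_loss - loss)  # smoothed improvement, as above
                avg_loss_count = avg_loss_count + 1 if avg_loss < threshold else 0
                if avg_loss_count > patience:
                    return step  # the real loop breaks out of the tqdm iterator here
        prev_loss = loss
    return None

print(stop_step([10.0 / (i + 1) for i in range(200)]))  # toy, monotonically shrinking losses
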
/project_image_without_optimizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import argparse
4 | import pickle
5 | from tqdm import tqdm
6 | import PIL.Image
7 | from PIL import ImageFilter
8 | import numpy as np
9 | import dnnlib.tflib as tflib
10 | from encoder.generator_model import Generator
11 | from encoder.perceptual_model import load_image
12 | from keras.models import load_model
13 | from keras.applications.resnet50 import preprocess_input
14 |
15 | def split_to_batches(l, n):
16 | for i in range(0, len(l), n):
17 | yield l[i:i + n]
18 |
19 | def str2bool(v):
20 | if isinstance(v, bool):
21 | return v
22 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
23 | return True
24 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
25 | return False
26 | else:
27 | raise argparse.ArgumentTypeError('Boolean value expected.')
28 |
29 | def project(images, masks, projector_name):
30 | parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
31 | parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
32 | parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
33 | parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
34 |
35 | # Perceptual model params
36 | parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
37 | parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int)
38 | parser.add_argument('--load_resnet', default='networks/finetuned_resnet.h5', help='Model to load for ResNet approximation of dlatents')
39 |     parser.add_argument('--use_preprocess_input', default=True, help='Call preprocess_input() first before using feed forward net', type=str2bool, nargs='?', const=True)
40 |
41 | # Generator params
42 | parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=str2bool, nargs='?', const=True)
43 | parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=str2bool, nargs='?', const=True)
44 | parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of gradient values outside of this threshold', type=float)
45 |
46 | # Masking params
47 | parser.add_argument('--composite_blur', default=8, help='Size of blur filter to smoothly composite the images', type=int)
48 |
49 | args, other_args = parser.parse_known_args()
50 |
51 | # Initialize generator and encoder model
52 | tflib.init_tf()
53 | with open('networks/projector_'+projector_name+'.pkl', 'rb') as f:
54 | projector = pickle.load(f)
55 | generator = Generator(projector, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise)
56 | if (args.dlatent_avg != ''):
57 | generator.set_dlatent_avg(np.load(args.dlatent_avg))
58 | print(" Loading ResNet Model...")
59 | ff_model = load_model(args.load_resnet)
60 |
61 | # Find the dlatent of the image
62 | dlatents = []
63 | imgs = []
64 | for image, mask in zip(images, masks):
65 | if (args.use_preprocess_input):
66 |             dlatent = ff_model.predict(preprocess_input(load_image(image, image_size=args.resnet_image_size)))
67 |         else:
68 |             dlatent = ff_model.predict(load_image(image, image_size=args.resnet_image_size))
69 | if dlatent is not None:
70 | generator.set_dlatents(dlatent)
71 |
72 | # Using Projector to generate images
73 | generator.set_dlatents(dlatent)
74 | generated_images = generator.generate_images()
75 |
76 | # Merge images with new face
77 | img_array = generated_images[0]
78 | ori_img = image
79 | width, height = ori_img.size
80 | mask = mask.resize((width, height))
81 | mask = mask.filter(ImageFilter.GaussianBlur(args.composite_blur))
82 | mask = np.array(mask) / 255
83 | mask = np.expand_dims(mask, axis=-1)
84 | img_array = mask * np.array(img_array) + (1.0 - mask) * np.array(ori_img)
85 | img_array = img_array.astype(np.uint8)
86 | img = PIL.Image.fromarray(img_array, 'RGB')
87 |
88 | imgs.append(img)
89 | dlatents.append(dlatent)
90 |
91 | return imgs, dlatents
92 |
93 |
--------------------------------------------------------------------------------
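
A minimal driver for the optimizer-free projector above might look like the following sketch. It assumes networks/finetuned_resnet.h5 and a matching networks/projector_<name>.pkl are already in place, and that the input face has already been aligned and masked; the file names and projector name are placeholders rather than files shipped with the repo, and main.py remains the real entry point:

# Illustrative usage sketch (paths and projector name are placeholders).
import PIL.Image
import project_image_without_optimizer as pio

face = PIL.Image.open('aligned_face.png').convert('RGB')  # aligned 1024x1024 crop
mask = PIL.Image.open('face_mask.png').convert('L')       # white where the generated face is kept

imgs, dlatents = pio.project([face], [mask], 'your_projector')  # loads networks/projector_your_projector.pkl
imgs[0].save('projected_face.png')
print(dlatents[0].shape)  # dlatent batch predicted by the fine-tuned ResNet
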
/tools/face_alignment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.ndimage
3 | import os
4 | import PIL.Image
5 | from PIL import ImageDraw
6 |
7 |
8 | def image_align(src_file, face_landmarks, output_size=1024, transform_size=4096, enable_padding=True, x_scale=1, y_scale=1, em_scale=0.1):
9 | # Align function from FFHQ dataset pre-processing step
10 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
11 |
12 | lm = np.array(face_landmarks)
13 | lm_chin = lm[0 : 17] # left-right
14 | lm_eyebrow_left = lm[17 : 22] # left-right
15 | lm_eyebrow_right = lm[22 : 27] # left-right
16 | lm_nose = lm[27 : 31] # top-down
17 | lm_nostrils = lm[31 : 36] # top-down
18 | lm_eye_left = lm[36 : 42] # left-clockwise
19 | lm_eye_right = lm[42 : 48] # left-clockwise
20 | lm_mouth_outer = lm[48 : 60] # left-clockwise
21 | lm_mouth_inner = lm[60 : 68] # left-clockwise
22 |
23 | # Calculate auxiliary vectors.
24 | eye_left = np.mean(lm_eye_left, axis=0)
25 | eye_right = np.mean(lm_eye_right, axis=0)
26 | eye_avg = (eye_left + eye_right) * 0.5
27 | eye_to_eye = eye_right - eye_left
28 | mouth_left = lm_mouth_outer[0]
29 | mouth_right = lm_mouth_outer[6]
30 | mouth_avg = (mouth_left + mouth_right) * 0.5
31 | eye_to_mouth = mouth_avg - eye_avg
32 |
33 | # Choose oriented crop rectangle.
34 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
35 | x /= np.hypot(*x)
36 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
37 | x *= x_scale
38 | y = np.flipud(x) * [-y_scale, y_scale]
39 | c = eye_avg + eye_to_mouth * em_scale
40 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
41 | qsize = np.hypot(*x) * 2
42 |
43 | # Load in-the-wild image.
44 | if not os.path.isfile(src_file):
45 |         print('\nCannot find source image: ' + src_file)
46 | return
47 | img_bg = PIL.Image.open(src_file).convert('RGBA').convert('RGB')
48 |
49 | # Shrink.
50 | img = img_bg.copy()
51 | shrink = int(np.floor(qsize / output_size * 0.5))
52 | if shrink > 1:
53 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
54 | img = img.resize(rsize, PIL.Image.ANTIALIAS)
55 | quad /= shrink
56 | qsize /= shrink
57 |
58 | # Crop.
59 | border = max(int(np.rint(qsize * 0.1)), 3)
60 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
61 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
62 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
63 | img = img.crop(crop)
64 | bg_draw = ImageDraw.ImageDraw(img_bg)
65 | bg_draw.rectangle(crop, fill='white')
66 | quad -= crop[0:2]
67 |
68 | # Pad.
69 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
70 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
71 | if enable_padding and max(pad) > border - 4:
72 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
73 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
74 |
75 | h, w, _ = img.shape
76 | y, x, _ = np.ogrid[:h, :w, :1]
77 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
78 | blur = qsize * 0.02
79 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
80 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
81 | img = np.uint8(np.clip(np.rint(img), 0, 255))
82 | img = PIL.Image.fromarray(img, 'RGB')
83 | quad += pad[:2]
84 |
85 | # Transform.
86 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
87 |
88 | if output_size < transform_size:
89 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
90 |
91 | # Return aligned image.
92 |     return img, crop, pad, quad
--------------------------------------------------------------------------------
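
image_align() expects the 68 dlib landmark points of one face and returns the aligned crop together with the crop/pad/quad geometry needed later to paste the face back into the original photo. A small sketch pairing it with tools/landmarks_detector.py follows; the predictor .dat path is an assumption and must be downloaded separately:

# Minimal sketch: detect landmarks, then produce an FFHQ-style aligned crop.
from tools.landmarks_detector import LandmarksDetector
from tools.face_alignment import image_align

detector = LandmarksDetector('shape_predictor_68_face_landmarks.dat')  # placeholder path
for face_landmarks in detector.get_landmarks('input/test1.jpg'):
    aligned, crop, pad, quad = image_align('input/test1.jpg', face_landmarks, output_size=1024)
    aligned.save('test1_aligned.png')
    break  # keep only the first detected face for this example
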
/tools/functions.py:
--------------------------------------------------------------------------------
1 | import PIL.Image as Image
2 | import cv2
3 | import numpy as np
4 | import math
5 |
6 | def rotate(img, degree):
7 | height, width = img.shape[:2]
8 | heightNew = round(width * math.fabs(math.sin(math.radians(degree))) + height * math.fabs(math.cos(math.radians(degree))))
9 | widthNew = round(height * math.fabs(math.sin(math.radians(degree))) + width * math.fabs(math.cos(math.radians(degree))))
10 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
11 | matRotation[0, 2] += (widthNew - width) / 2
12 | matRotation[1, 2] += (heightNew - height) / 2
13 | imgRotation = cv2.warpAffine(img, matRotation, (widthNew, heightNew))
14 | return imgRotation
15 |
16 | def merge_image(bg_img, fg_img, mask, crop, quad, pad):
17 | bg_img_ori = bg_img.copy()
18 | bg_img_alpha = cv2.cvtColor(bg_img, cv2.COLOR_BGR2BGRA)
19 | fg_img = cv2.cvtColor(np.asarray(fg_img), cv2.COLOR_RGB2BGR)
20 | mask = np.asarray(mask)
21 | line = int(round(max(quad[2][0]-quad[0][0], quad[3][0]-quad[1][0])))
22 | radian = math.atan((quad[1][0]-quad[0][0])/(quad[1][1]-quad[0][1]))
23 | degree = math.degrees(radian)
24 | fg_img = rotate(fg_img, degree)
25 | fg_img = cv2.resize(fg_img, (line, line), interpolation=cv2.INTER_NEAREST)
26 | mask = rotate(mask, degree)
27 | mask = cv2.resize(mask, (line, line), interpolation=cv2.INTER_NEAREST)
28 | x1 = int(round(crop[0]-pad[0]+min([quad[0][0], quad[1][0], quad[2][0], quad[3][0]])))
29 |     y1 = int(round(crop[1]-pad[1]+min([quad[0][1], quad[1][1], quad[2][1], quad[3][1]])))  # pad[1] is the top padding added by image_align
30 | x2 = x1+line
31 | y2 = y1+line
32 | if x1 < 0:
33 | fg_img = fg_img[:, -x1:]
34 | mask = mask[:, -x1:]
35 | x1 = 0
36 | if y1 < 0:
37 | fg_img = fg_img[-y1:, :]
38 | mask = mask[-y1:, :]
39 | y1 = 0
40 | if x2 > bg_img.shape[1]:
41 | fg_img = fg_img[:, :-(x2-bg_img.shape[1])]
42 | mask = mask[:, :-(x2-bg_img.shape[1])]
43 | x2 = bg_img.shape[1]
44 | if y2 > bg_img.shape[0]:
45 | fg_img = fg_img[:-(y2 - bg_img.shape[0]), :]
46 | mask = mask[:-(y2 - bg_img.shape[0]), :]
47 | y2 = bg_img.shape[0]
48 | #alpha = cv2.erode(mask / 255.0, np.ones((3,3), np.uint8), iterations = 1)
49 | alpha = cv2.GaussianBlur(mask / 255.0, (5,5), 0)
50 | bg_img[y1:y2, x1:x2, 0] = (1. - alpha) * bg_img[y1:y2, x1:x2, 0] + alpha * fg_img[..., 0]
51 | bg_img[y1:y2, x1:x2, 1] = (1. - alpha) * bg_img[y1:y2, x1:x2, 1] + alpha * fg_img[..., 1]
52 | bg_img[y1:y2, x1:x2, 2] = (1. - alpha) * bg_img[y1:y2, x1:x2, 2] + alpha * fg_img[..., 2]
53 | bg_img[y1:y2, x1:x2] = cv2.fastNlMeansDenoisingColored(bg_img[y1:y2, x1:x2], None, 3.0, 3.0, 7, 21)
54 |
55 | # Seamlessly clone src into dst and put the results in output
56 |     height, width, channels = bg_img_ori.shape  # OpenCV images are (rows, cols, channels)
57 |     center = (width // 2, height // 2)  # (x, y) center expected by seamlessClone
58 | mask = 255 * np.ones(bg_img.shape, bg_img.dtype)
59 | normal_clone = cv2.seamlessClone(bg_img, bg_img_ori, mask, center, cv2.NORMAL_CLONE)
60 |
61 | return normal_clone
62 |
63 |
64 | def generate_face_mask(im, landmarks_detector):
65 | from imutils import face_utils
66 | rects = landmarks_detector.detector(im, 1)
67 | # loop over the face detections
68 | for (j, rect) in enumerate(rects):
69 | """
70 | Determine the facial landmarks for the face region, then convert the facial landmark (x, y)-coordinates to a NumPy array
71 | """
72 | shape = landmarks_detector.shape_predictor(im, rect)
73 | shape = face_utils.shape_to_np(shape)
74 |
75 | # we extract the face
76 | vertices = cv2.convexHull(shape)
77 | mask = np.zeros(im.shape[:2],np.uint8)
78 | cv2.fillConvexPoly(mask, vertices, 1)
79 | bgdModel = np.zeros((1,65),np.float64)
80 | fgdModel = np.zeros((1,65),np.float64)
81 |         rect = (0, 0, im.shape[1], im.shape[0])  # full-image rect; ignored when grabCut runs with GC_INIT_WITH_MASK
82 | (x,y),radius = cv2.minEnclosingCircle(vertices)
83 | center = (int(x), int(y))
84 | radius = int(radius*1.4)
85 | mask = cv2.circle(mask,center,radius,cv2.GC_PR_FGD,-1)
86 | cv2.fillConvexPoly(mask, vertices, cv2.GC_FGD)
87 | cv2.grabCut(im,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_MASK)
88 |         mask = np.where((mask==2)|(mask==0),0,1).astype(np.uint8)  # 8-bit mask so cv2.rectangle below can draw on it
89 | cv2.rectangle(mask, (0, 0), (mask.shape[1], mask.shape[0]), 0, thickness=10)
90 | return mask
91 |
92 |
93 | def generate_face_mask_without_hair(im, landmarks_detector, ie_polys=None):
94 | # get the mask of the image with only face area
95 | rects = landmarks_detector.detector(im, 1)
96 | image_landmarks = np.matrix([[p.x, p.y] for p in landmarks_detector.shape_predictor(im, rects[0]).parts()])
97 | if image_landmarks.shape[0] != 68:
98 | raise Exception(
99 | 'get_image_hull_mask works only with 68 landmarks')
100 | int_lmrks = np.array(image_landmarks, dtype=np.int)
101 |
102 | # hull_mask = np.zeros(image_shape[0:2]+(1,), dtype=np.float32)
103 | hull_mask = np.full(im.shape[0:2] + (1,), 0, dtype=np.float32)
104 |
105 | cv2.fillConvexPoly(hull_mask, cv2.convexHull(
106 | np.concatenate((int_lmrks[0:9],
107 | int_lmrks[17:18]))), (1,))
108 |
109 | cv2.fillConvexPoly(hull_mask, cv2.convexHull(
110 | np.concatenate((int_lmrks[8:17],
111 | int_lmrks[26:27]))), (1,))
112 |
113 | cv2.fillConvexPoly(hull_mask, cv2.convexHull(
114 | np.concatenate((int_lmrks[17:20],
115 | int_lmrks[8:9]))), (1,))
116 |
117 | cv2.fillConvexPoly(hull_mask, cv2.convexHull(
118 | np.concatenate((int_lmrks[24:27],
119 | int_lmrks[8:9]))), (1,))
120 |
121 | cv2.fillConvexPoly(hull_mask, cv2.convexHull(
122 | np.concatenate((int_lmrks[19:25],
123 | int_lmrks[8:9],
124 | ))), (1,))
125 |
126 | cv2.fillConvexPoly(hull_mask, cv2.convexHull(
127 | np.concatenate((int_lmrks[17:22],
128 | int_lmrks[27:28],
129 | int_lmrks[31:36],
130 | int_lmrks[8:9]
131 | ))), (1,))
132 |
133 | cv2.fillConvexPoly(hull_mask, cv2.convexHull(
134 | np.concatenate((int_lmrks[22:27],
135 | int_lmrks[27:28],
136 | int_lmrks[31:36],
137 | int_lmrks[8:9]
138 | ))), (1,))
139 |
140 | # nose
141 | cv2.fillConvexPoly(
142 | hull_mask, cv2.convexHull(int_lmrks[27:36]), (1,))
143 |
144 | if ie_polys is not None:
145 | ie_polys.overlay_mask(hull_mask)
146 | hull_mask = hull_mask.squeeze()
147 | return hull_mask
148 |
--------------------------------------------------------------------------------
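
merge_image() above pastes a generated face back into the original photo using the crop/pad/quad geometry returned by image_align() and finishes with OpenCV seamless cloning, while generate_face_mask() builds the grabCut-refined mask it blends with. The wiring below is a hedged sketch of how these helpers fit together; main.py is the authoritative driver, the predictor path is an assumption, and the 'new face' is only a placeholder where the output of project() would normally go:

# Hedged end-to-end sketch: align, (project), then merge back into the photo.
import cv2
import numpy as np
from tools.landmarks_detector import LandmarksDetector
from tools.face_alignment import image_align
from tools.functions import generate_face_mask, merge_image

src = 'input/test1.jpg'
detector = LandmarksDetector('shape_predictor_68_face_landmarks.dat')  # placeholder path
landmarks = next(detector.get_landmarks(src))
aligned, crop, pad, quad = image_align(src, landmarks)

new_face = aligned  # placeholder: normally the output of project() on the aligned crop

face_mask = generate_face_mask(np.array(aligned), detector)  # 0/1 mask over the aligned crop
face_mask = (face_mask * 255).astype(np.uint8)               # merge_image expects a 0-255 mask
result = merge_image(cv2.imread(src), new_face, face_mask, crop, quad, pad)
cv2.imwrite('test1_merged.png', result)
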
/tools/landmarks_detector.py:
--------------------------------------------------------------------------------
1 | import dlib
2 |
3 |
4 | class LandmarksDetector:
5 | def __init__(self, predictor_model_path):
6 | """
7 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
8 | """
9 | self.detector = dlib.get_frontal_face_detector() # cnn_face_detection_model_v1 also can be used
10 | self.shape_predictor = dlib.shape_predictor(predictor_model_path)
11 |
12 | def get_landmarks(self, image):
13 | img = dlib.load_rgb_image(image)
14 | dets = self.detector(img, 1)
15 |
16 | for detection in dets:
17 | try:
18 | face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()]
19 | yield face_landmarks
20 |             except Exception:
21 | print("Exception in get_landmarks()!")
22 |
--------------------------------------------------------------------------------
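
The constructor comment above notes that dlib's cnn_face_detection_model_v1 could be used in place of the HOG frontal detector. A hedged sketch of that swap follows (both .dat model files are assumptions and have to be downloaded separately; note that CNN detections wrap their rectangle in .rect):

# Sketch of the CNN-based detector alternative mentioned in LandmarksDetector.
import dlib

cnn_detector = dlib.cnn_face_detection_model_v1('mmod_human_face_detector.dat')  # placeholder path
shape_predictor = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')  # placeholder path

img = dlib.load_rgb_image('input/test1.jpg')
for det in cnn_detector(img, 1):
    landmarks = [(p.x, p.y) for p in shape_predictor(img, det.rect).parts()]
    print(len(landmarks))  # 68 points, same format that get_landmarks() yields
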