├── LISENCE.txt
├── README.md
├── dnnlib
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── util.cpython-37.pyc
│   ├── submission
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── run_context.cpython-37.pyc
│   │   │   └── submit.cpython-37.pyc
│   │   ├── _internal
│   │   │   └── run.py
│   │   ├── run_context.py
│   │   └── submit.py
│   ├── tflib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── autosummary.cpython-37.pyc
│   │   │   ├── network.cpython-37.pyc
│   │   │   ├── optimizer.cpython-37.pyc
│   │   │   └── tfutil.cpython-37.pyc
│   │   ├── autosummary.py
│   │   ├── network.py
│   │   ├── optimizer.py
│   │   └── tfutil.py
│   └── util.py
├── encoder
│   ├── generator_model.py
│   ├── model.py
│   ├── perceptual_model.py
│   └── resnet.py
├── input
│   ├── test1.jpg
│   ├── test2.jpeg
│   ├── test3.jpg
│   └── test4.jpeg
├── main.py
├── networks
│   └── download_weights.txt
├── pics
│   ├── architecture.png
│   ├── example_2kids.jpg
│   ├── example_2wanghong.png
│   ├── examples_mix.jpg
│   ├── multi-model-solution.png
│   ├── preview.jpg
│   ├── single_input.png
│   └── single_output.png
├── project_image.py
├── project_image_without_optimizer.py
└── tools
    ├── face_alignment.py
    ├── functions.py
    └── landmarks_detector.py

/LISENCE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, www.seeprettyface.com. All rights reserved. 2 | 3 | 4 | Attribution-NonCommercial 4.0 International 5 | 6 | ======================================================================= 7 | 8 | Creative Commons Corporation ("Creative Commons") is not a law firm and 9 | does not provide legal services or legal advice. Distribution of 10 | Creative Commons public licenses does not create a lawyer-client or 11 | other relationship. Creative Commons makes its licenses and related 12 | information available on an "as-is" basis. Creative Commons gives no 13 | warranties regarding its licenses, any material licensed under their 14 | terms and conditions, or any related information. Creative Commons 15 | disclaims all liability for damages resulting from their use to the 16 | fullest extent possible. 17 | 18 | Using Creative Commons Public Licenses 19 | 20 | Creative Commons public licenses provide a standard set of terms and 21 | conditions that creators and other rights holders may use to share 22 | original works of authorship and other material subject to copyright 23 | and certain other rights specified in the public license below. The 24 | following considerations are for informational purposes only, are not 25 | exhaustive, and do not form part of our licenses. 26 | 27 | Considerations for licensors: Our public licenses are 28 | intended for use by those authorized to give the public 29 | permission to use material in ways otherwise restricted by 30 | copyright and certain other rights. Our licenses are 31 | irrevocable. Licensors should read and understand the terms 32 | and conditions of the license they choose before applying it. 33 | Licensors should also secure all rights necessary before 34 | applying our licenses so that the public can reuse the 35 | material as expected. Licensors should clearly mark any 36 | material not subject to the license. This includes other CC- 37 | licensed material, or material used under an exception or 38 | limitation to copyright.
More considerations for licensors: 39 | wiki.creativecommons.org/Considerations_for_licensors 40 | 41 | Considerations for the public: By using one of our public 42 | licenses, a licensor grants the public permission to use the 43 | licensed material under specified terms and conditions. If 44 | the licensor's permission is not necessary for any reason--for 45 | example, because of any applicable exception or limitation to 46 | copyright--then that use is not regulated by the license. Our 47 | licenses grant only permissions under copyright and certain 48 | other rights that a licensor has authority to grant. Use of 49 | the licensed material may still be restricted for other 50 | reasons, including because others have copyright or other 51 | rights in the material. A licensor may make special requests, 52 | such as asking that all changes be marked or described. 53 | Although not required by our licenses, you are encouraged to 54 | respect those requests where reasonable. More_considerations 55 | for the public: 56 | wiki.creativecommons.org/Considerations_for_licensees 57 | 58 | ======================================================================= 59 | 60 | Creative Commons Attribution-NonCommercial 4.0 International Public 61 | License 62 | 63 | By exercising the Licensed Rights (defined below), You accept and agree 64 | to be bound by the terms and conditions of this Creative Commons 65 | Attribution-NonCommercial 4.0 International Public License ("Public 66 | License"). To the extent this Public License may be interpreted as a 67 | contract, You are granted the Licensed Rights in consideration of Your 68 | acceptance of these terms and conditions, and the Licensor grants You 69 | such rights in consideration of benefits the Licensor receives from 70 | making the Licensed Material available under these terms and 71 | conditions. 72 | 73 | 74 | Section 1 -- Definitions. 75 | 76 | a. Adapted Material means material subject to Copyright and Similar 77 | Rights that is derived from or based upon the Licensed Material 78 | and in which the Licensed Material is translated, altered, 79 | arranged, transformed, or otherwise modified in a manner requiring 80 | permission under the Copyright and Similar Rights held by the 81 | Licensor. For purposes of this Public License, where the Licensed 82 | Material is a musical work, performance, or sound recording, 83 | Adapted Material is always produced where the Licensed Material is 84 | synched in timed relation with a moving image. 85 | 86 | b. Adapter's License means the license You apply to Your Copyright 87 | and Similar Rights in Your contributions to Adapted Material in 88 | accordance with the terms and conditions of this Public License. 89 | 90 | c. Copyright and Similar Rights means copyright and/or similar rights 91 | closely related to copyright including, without limitation, 92 | performance, broadcast, sound recording, and Sui Generis Database 93 | Rights, without regard to how the rights are labeled or 94 | categorized. For purposes of this Public License, the rights 95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 96 | Rights. 97 | d. Effective Technological Measures means those measures that, in the 98 | absence of proper authority, may not be circumvented under laws 99 | fulfilling obligations under Article 11 of the WIPO Copyright 100 | Treaty adopted on December 20, 1996, and/or similar international 101 | agreements. 102 | 103 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 104 | any other exception or limitation to Copyright and Similar Rights 105 | that applies to Your use of the Licensed Material. 106 | 107 | f. Licensed Material means the artistic or literary work, database, 108 | or other material to which the Licensor applied this Public 109 | License. 110 | 111 | g. Licensed Rights means the rights granted to You subject to the 112 | terms and conditions of this Public License, which are limited to 113 | all Copyright and Similar Rights that apply to Your use of the 114 | Licensed Material and that the Licensor has authority to license. 115 | 116 | h. Licensor means the individual(s) or entity(ies) granting rights 117 | under this Public License. 118 | 119 | i. NonCommercial means not primarily intended for or directed towards 120 | commercial advantage or monetary compensation. For purposes of 121 | this Public License, the exchange of the Licensed Material for 122 | other material subject to Copyright and Similar Rights by digital 123 | file-sharing or similar means is NonCommercial provided there is 124 | no payment of monetary compensation in connection with the 125 | exchange. 126 | 127 | j. Share means to provide material to the public by any means or 128 | process that requires permission under the Licensed Rights, such 129 | as reproduction, public display, public performance, distribution, 130 | dissemination, communication, or importation, and to make material 131 | available to the public including in ways that members of the 132 | public may access the material from a place and at a time 133 | individually chosen by them. 134 | 135 | k. Sui Generis Database Rights means rights other than copyright 136 | resulting from Directive 96/9/EC of the European Parliament and of 137 | the Council of 11 March 1996 on the legal protection of databases, 138 | as amended and/or succeeded, as well as other essentially 139 | equivalent rights anywhere in the world. 140 | 141 | l. You means the individual or entity exercising the Licensed Rights 142 | under this Public License. Your has a corresponding meaning. 143 | 144 | 145 | Section 2 -- Scope. 146 | 147 | a. License grant. 148 | 149 | 1. Subject to the terms and conditions of this Public License, 150 | the Licensor hereby grants You a worldwide, royalty-free, 151 | non-sublicensable, non-exclusive, irrevocable license to 152 | exercise the Licensed Rights in the Licensed Material to: 153 | 154 | a. reproduce and Share the Licensed Material, in whole or 155 | in part, for NonCommercial purposes only; and 156 | 157 | b. produce, reproduce, and Share Adapted Material for 158 | NonCommercial purposes only. 159 | 160 | 2. Exceptions and Limitations. For the avoidance of doubt, where 161 | Exceptions and Limitations apply to Your use, this Public 162 | License does not apply, and You do not need to comply with 163 | its terms and conditions. 164 | 165 | 3. Term. The term of this Public License is specified in Section 166 | 6(a). 167 | 168 | 4. Media and formats; technical modifications allowed. The 169 | Licensor authorizes You to exercise the Licensed Rights in 170 | all media and formats whether now known or hereafter created, 171 | and to make technical modifications necessary to do so. 
The 172 | Licensor waives and/or agrees not to assert any right or 173 | authority to forbid You from making technical modifications 174 | necessary to exercise the Licensed Rights, including 175 | technical modifications necessary to circumvent Effective 176 | Technological Measures. For purposes of this Public License, 177 | simply making modifications authorized by this Section 2(a) 178 | (4) never produces Adapted Material. 179 | 180 | 5. Downstream recipients. 181 | 182 | a. Offer from the Licensor -- Licensed Material. Every 183 | recipient of the Licensed Material automatically 184 | receives an offer from the Licensor to exercise the 185 | Licensed Rights under the terms and conditions of this 186 | Public License. 187 | 188 | b. No downstream restrictions. You may not offer or impose 189 | any additional or different terms or conditions on, or 190 | apply any Effective Technological Measures to, the 191 | Licensed Material if doing so restricts exercise of the 192 | Licensed Rights by any recipient of the Licensed 193 | Material. 194 | 195 | 6. No endorsement. Nothing in this Public License constitutes or 196 | may be construed as permission to assert or imply that You 197 | are, or that Your use of the Licensed Material is, connected 198 | with, or sponsored, endorsed, or granted official status by, 199 | the Licensor or others designated to receive attribution as 200 | provided in Section 3(a)(1)(A)(i). 201 | 202 | b. Other rights. 203 | 204 | 1. Moral rights, such as the right of integrity, are not 205 | licensed under this Public License, nor are publicity, 206 | privacy, and/or other similar personality rights; however, to 207 | the extent possible, the Licensor waives and/or agrees not to 208 | assert any such rights held by the Licensor to the limited 209 | extent necessary to allow You to exercise the Licensed 210 | Rights, but not otherwise. 211 | 212 | 2. Patent and trademark rights are not licensed under this 213 | Public License. 214 | 215 | 3. To the extent possible, the Licensor waives any right to 216 | collect royalties from You for the exercise of the Licensed 217 | Rights, whether directly or through a collecting society 218 | under any voluntary or waivable statutory or compulsory 219 | licensing scheme. In all other cases the Licensor expressly 220 | reserves any right to collect such royalties, including when 221 | the Licensed Material is used other than for NonCommercial 222 | purposes. 223 | 224 | 225 | Section 3 -- License Conditions. 226 | 227 | Your exercise of the Licensed Rights is expressly made subject to the 228 | following conditions. 229 | 230 | a. Attribution. 231 | 232 | 1. If You Share the Licensed Material (including in modified 233 | form), You must: 234 | 235 | a. retain the following if it is supplied by the Licensor 236 | with the Licensed Material: 237 | 238 | i. identification of the creator(s) of the Licensed 239 | Material and any others designated to receive 240 | attribution, in any reasonable manner requested by 241 | the Licensor (including by pseudonym if 242 | designated); 243 | 244 | ii. a copyright notice; 245 | 246 | iii. a notice that refers to this Public License; 247 | 248 | iv. a notice that refers to the disclaimer of 249 | warranties; 250 | 251 | v. a URI or hyperlink to the Licensed Material to the 252 | extent reasonably practicable; 253 | 254 | b. indicate if You modified the Licensed Material and 255 | retain an indication of any previous modifications; and 256 | 257 | c. 
indicate the Licensed Material is licensed under this 258 | Public License, and include the text of, or the URI or 259 | hyperlink to, this Public License. 260 | 261 | 2. You may satisfy the conditions in Section 3(a)(1) in any 262 | reasonable manner based on the medium, means, and context in 263 | which You Share the Licensed Material. For example, it may be 264 | reasonable to satisfy the conditions by providing a URI or 265 | hyperlink to a resource that includes the required 266 | information. 267 | 268 | 3. If requested by the Licensor, You must remove any of the 269 | information required by Section 3(a)(1)(A) to the extent 270 | reasonably practicable. 271 | 272 | 4. If You Share Adapted Material You produce, the Adapter's 273 | License You apply must not prevent recipients of the Adapted 274 | Material from complying with this Public License. 275 | 276 | 277 | Section 4 -- Sui Generis Database Rights. 278 | 279 | Where the Licensed Rights include Sui Generis Database Rights that 280 | apply to Your use of the Licensed Material: 281 | 282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 283 | to extract, reuse, reproduce, and Share all or a substantial 284 | portion of the contents of the database for NonCommercial purposes 285 | only; 286 | 287 | b. if You include all or a substantial portion of the database 288 | contents in a database in which You have Sui Generis Database 289 | Rights, then the database in which You have Sui Generis Database 290 | Rights (but not its individual contents) is Adapted Material; and 291 | 292 | c. You must comply with the conditions in Section 3(a) if You Share 293 | all or a substantial portion of the contents of the database. 294 | 295 | For the avoidance of doubt, this Section 4 supplements and does not 296 | replace Your obligations under this Public License where the Licensed 297 | Rights include other Copyright and Similar Rights. 298 | 299 | 300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 301 | 302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 312 | 313 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 322 | 323 | c. 
The disclaimer of warranties and limitation of liability provided 324 | above shall be interpreted in a manner that, to the extent 325 | possible, most closely approximates an absolute disclaimer and 326 | waiver of all liability. 327 | 328 | 329 | Section 6 -- Term and Termination. 330 | 331 | a. This Public License applies for the term of the Copyright and 332 | Similar Rights licensed here. However, if You fail to comply with 333 | this Public License, then Your rights under this Public License 334 | terminate automatically. 335 | 336 | b. Where Your right to use the Licensed Material has terminated under 337 | Section 6(a), it reinstates: 338 | 339 | 1. automatically as of the date the violation is cured, provided 340 | it is cured within 30 days of Your discovery of the 341 | violation; or 342 | 343 | 2. upon express reinstatement by the Licensor. 344 | 345 | For the avoidance of doubt, this Section 6(b) does not affect any 346 | right the Licensor may have to seek remedies for Your violations 347 | of this Public License. 348 | 349 | c. For the avoidance of doubt, the Licensor may also offer the 350 | Licensed Material under separate terms or conditions or stop 351 | distributing the Licensed Material at any time; however, doing so 352 | will not terminate this Public License. 353 | 354 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 355 | License. 356 | 357 | 358 | Section 7 -- Other Terms and Conditions. 359 | 360 | a. The Licensor shall not be bound by any additional or different 361 | terms or conditions communicated by You unless expressly agreed. 362 | 363 | b. Any arrangements, understandings, or agreements regarding the 364 | Licensed Material not stated herein are separate from and 365 | independent of the terms and conditions of this Public License. 366 | 367 | 368 | Section 8 -- Interpretation. 369 | 370 | a. For the avoidance of doubt, this Public License does not, and 371 | shall not be interpreted to, reduce, limit, restrict, or impose 372 | conditions on any use of the Licensed Material that could lawfully 373 | be made without permission under this Public License. 374 | 375 | b. To the extent possible, if any provision of this Public License is 376 | deemed unenforceable, it shall be automatically reformed to the 377 | minimum extent necessary to make it enforceable. If the provision 378 | cannot be reformed, it shall be severed from this Public License 379 | without affecting the enforceability of the remaining terms and 380 | conditions. 381 | 382 | c. No term or condition of this Public License will be waived and no 383 | failure to comply consented to unless expressly agreed to by the 384 | Licensor. 385 | 386 | d. Nothing in this Public License constitutes or may be interpreted 387 | as a limitation upon, or waiver of, any privileges and immunities 388 | that apply to the Licensor or You, including from the legal 389 | processes of any jurisdiction or authority. 390 | 391 | ======================================================================= 392 | 393 | Creative Commons is not a party to its public 394 | licenses. Notwithstanding, Creative Commons may elect to apply one of 395 | its public licenses to material it publishes and in those instances 396 | will be considered the "Licensor." The text of the Creative Commons 397 | public licenses is dedicated to the public domain under the CC0 Public 398 | Domain Dedication. 
Except for the limited purpose of indicating that 399 | material is shared under a Creative Commons public license or as 400 | otherwise permitted by the Creative Commons policies published at 401 | creativecommons.org/policies, Creative Commons does not authorize the 402 | use of the trademark "Creative Commons" or any other trademark or logo 403 | of Creative Commons without its prior written consent including, 404 | without limitation, in connection with any unauthorized modifications 405 | to any of its public licenses or any other arrangements, 406 | understandings, or agreements concerning use of licensed material. For 407 | the avoidance of doubt, this paragraph does not form part of the 408 | public licenses. 409 | 410 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

# Better Model, Better Performance

Model-Swap-Face_v2 is now available; feel free to use it as a reference.

Note: This project documents some of my explorations around digital (virtual) models, with the goal of uncovering practical commercial value in generative techniques by cutting production costs and improving efficiency. What is shown here is an end-to-end pipeline for a single target model: it synthesizes a new, more stylistically appealing model face while preserving the expression of the input face. If you are interested in a scheme that supports choosing among multiple target models, see my research notes.

# Preview

## Single image: input and output

*(sample image)*
Input

*(sample image)*
Model-style output

## Multi-image comparison

*(sample image)*
Preview of multiple style conversions

## Paste-back results

This section shows the generated face pasted back into the original photo; additional post-processing is involved.

*(sample image)*
Child style (left: input, right: output)

*(sample image)*
Internet-celebrity style (left: input, right: output)

*(sample image)*
Multiple styles (row 1: input, rows 2-5: outputs)

# Inference framework

*(sample image)*

# Usage

## Requirements

* Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
* 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
* TensorFlow 1.10.0 or newer with GPU support.
* One or more high-end NVIDIA GPUs with at least 11GB of DRAM. We recommend NVIDIA DGX-1 with 8 Tesla V100 GPUs.
* NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.

## How to run

1. Download the model weights into the ```networks``` folder, as listed in ```networks/download_weights.txt```.
2. Configure ```main.py```, then run ```python main.py``` (see the quick-start sketch at the end of this README).

# Multi-model selection scheme

*(sample image)*

The multi-model selection scheme supports a wider choice of target models; for the implementation, see my research notes.

# Acknowledgements

Parts of the code are borrowed from Puzer and pbaylies. Thanks for sharing.
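# Appendix: quick-start sketch

The two steps under "How to run" are the whole workflow. The block below spells them out as a copy-pasteable sketch: only `python main.py` comes directly from this README, while the clone URL, the pip line, and the folder conventions are assumptions inferred from the repository layout (`networks/`, `input/`) and the requirements list above, so adjust them to your setup.

```bash
# Quick start (illustrative; everything except "python main.py" is an assumption).
git clone https://github.com/a312863063/Model-Swap-Face.git
cd Model-Swap-Face

# Install the requirements listed above (64-bit Python 3.6, numpy >= 1.14.3,
# TensorFlow 1.10+ with GPU support, CUDA 9.0+, cuDNN 7.3.1+).
pip install "numpy>=1.14.3" "tensorflow-gpu>=1.10,<2.0"

# 1. Put the downloaded model weights into ./networks
#    (links are listed in networks/download_weights.txt).
# 2. Put the photos you want to convert into ./input
#    (test1.jpg ... test4.jpeg ship with the repository).
# 3. Edit main.py to select the weights/style you want, then run:
python main.py

# Under the hood, main.py appears to align each face (tools/face_alignment.py),
# project it into the generator's latent space (encoder/, project_image.py),
# and render the model-style face with the chosen weights.
```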
70 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import submission 9 | 10 | from .submission.run_context import RunContext 11 | 12 | from .submission.submit import SubmitTarget 13 | from .submission.submit import PathType 14 | from .submission.submit import SubmitConfig 15 | from .submission.submit import get_path_from_template 16 | from .submission.submit import submit_run 17 | 18 | from .util import EasyDict 19 | 20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. 21 | -------------------------------------------------------------------------------- /dnnlib/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import run_context 9 | from . 
import submit 10 | -------------------------------------------------------------------------------- /dnnlib/submission/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/submission/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/submission/__pycache__/run_context.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/submission/__pycache__/run_context.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/submission/__pycache__/submit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/submission/__pycache__/submit.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/submission/_internal/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for launching run functions in computing clusters. 9 | 10 | During the submit process, this file is copied to the appropriate run dir. 11 | When the job is launched in the cluster, this module is the first thing that 12 | is run inside the docker container. 13 | """ 14 | 15 | import os 16 | import pickle 17 | import sys 18 | 19 | # PYTHONPATH should have been set so that the run_dir/src is in it 20 | import dnnlib 21 | 22 | def main(): 23 | if not len(sys.argv) >= 4: 24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") 25 | 26 | run_dir = str(sys.argv[1]) 27 | task_name = str(sys.argv[2]) 28 | host_name = str(sys.argv[3]) 29 | 30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl") 31 | 32 | # SubmitConfig should have been pickled to the run dir 33 | if not os.path.exists(submit_config_path): 34 | raise RuntimeError("SubmitConfig pickle file does not exist!") 35 | 36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) 37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name) 38 | 39 | submit_config.task_name = task_name 40 | submit_config.host_name = host_name 41 | 42 | dnnlib.submission.submit.run_wrapper(submit_config) 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helpers for managing the run/training loop.""" 9 | 10 | import datetime 11 | import json 12 | import os 13 | import pprint 14 | import time 15 | import types 16 | 17 | from typing import Any 18 | 19 | from . import submit 20 | 21 | 22 | class RunContext(object): 23 | """Helper class for managing the run/training loop. 24 | 25 | The context will hide the implementation details of a basic run/training loop. 26 | It will set things up properly, tell if run should be stopped, and then cleans up. 27 | User should call update periodically and use should_stop to determine if run should be stopped. 28 | 29 | Args: 30 | submit_config: The SubmitConfig that is used for the current run. 31 | config_module: The whole config module that is used for the current run. 32 | max_epoch: Optional cached value for the max_epoch variable used in update. 33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): 36 | self.submit_config = submit_config 37 | self.should_stop_flag = False 38 | self.has_closed = False 39 | self.start_time = time.time() 40 | self.last_update_time = time.time() 41 | self.last_update_interval = 0.0 42 | self.max_epoch = max_epoch 43 | 44 | # pretty print the all the relevant content of the config module to a text file 45 | if config_module is not None: 46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: 47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} 48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) 49 | 50 | # write out details about the run to a text file 51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 54 | 55 | def __enter__(self) -> "RunContext": 56 | return self 57 | 58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 59 | self.close() 60 | 61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: 62 | """Do general housekeeping and keep the state of the context up-to-date. 
63 | Should be called often enough but not in a tight loop.""" 64 | assert not self.has_closed 65 | 66 | self.last_update_interval = time.time() - self.last_update_time 67 | self.last_update_time = time.time() 68 | 69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): 70 | self.should_stop_flag = True 71 | 72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | 99 | self.has_closed = True 100 | -------------------------------------------------------------------------------- /dnnlib/submission/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Submit a function to be run either locally or in a computing cluster.""" 9 | 10 | import copy 11 | import io 12 | import os 13 | import pathlib 14 | import pickle 15 | import platform 16 | import pprint 17 | import re 18 | import shutil 19 | import time 20 | import traceback 21 | 22 | import zipfile 23 | 24 | from enum import Enum 25 | 26 | from .. import util 27 | from ..util import EasyDict 28 | 29 | 30 | class SubmitTarget(Enum): 31 | """The target where the function should be run. 32 | 33 | LOCAL: Run it locally. 34 | """ 35 | LOCAL = 1 36 | 37 | 38 | class PathType(Enum): 39 | """Determines in which format should a path be formatted. 40 | 41 | WINDOWS: Format with Windows style. 42 | LINUX: Format with Linux/Posix style. 43 | AUTO: Use current OS type to select either WINDOWS or LINUX. 44 | """ 45 | WINDOWS = 1 46 | LINUX = 2 47 | AUTO = 3 48 | 49 | 50 | _user_name_override = None 51 | 52 | 53 | class SubmitConfig(util.EasyDict): 54 | """Strongly typed config dict needed to submit runs. 55 | 56 | Attributes: 57 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. 58 | run_desc: Description of the run. Will be used in the run dir and task name. 59 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. 60 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. 
rel_path root will be the src directory inside the run dir. 61 | submit_target: Submit target enum value. Used to select where the run is actually launched. 62 | num_gpus: Number of GPUs used/requested for the run. 63 | print_info: Whether to print debug information when submitting. 64 | ask_confirmation: Whether to ask a confirmation before submitting. 65 | run_id: Automatically populated value during submit. 66 | run_name: Automatically populated value during submit. 67 | run_dir: Automatically populated value during submit. 68 | run_func_name: Automatically populated value during submit. 69 | run_func_kwargs: Automatically populated value during submit. 70 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. 71 | task_name: Automatically populated value during submit. 72 | host_name: Automatically populated value during submit. 73 | """ 74 | 75 | def __init__(self): 76 | super().__init__() 77 | 78 | # run (set these) 79 | self.run_dir_root = "" # should always be passed through get_path_from_template 80 | self.run_desc = "" 81 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] 82 | self.run_dir_extra_files = None 83 | 84 | # submit (set these) 85 | self.submit_target = SubmitTarget.LOCAL 86 | self.num_gpus = 1 87 | self.print_info = False 88 | self.ask_confirmation = False 89 | 90 | # (automatically populated) 91 | self.run_id = None 92 | self.run_name = None 93 | self.run_dir = None 94 | self.run_func_name = None 95 | self.run_func_kwargs = None 96 | self.user_name = None 97 | self.task_name = None 98 | self.host_name = "localhost" 99 | 100 | 101 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: 102 | """Replace tags in the given path template and return either Windows or Linux formatted path.""" 103 | # automatically select path type depending on running OS 104 | if path_type == PathType.AUTO: 105 | if platform.system() == "Windows": 106 | path_type = PathType.WINDOWS 107 | elif platform.system() == "Linux": 108 | path_type = PathType.LINUX 109 | else: 110 | raise RuntimeError("Unknown platform") 111 | 112 | path_template = path_template.replace("", get_user_name()) 113 | 114 | # return correctly formatted path 115 | if path_type == PathType.WINDOWS: 116 | return str(pathlib.PureWindowsPath(path_template)) 117 | elif path_type == PathType.LINUX: 118 | return str(pathlib.PurePosixPath(path_template)) 119 | else: 120 | raise RuntimeError("Unknown platform") 121 | 122 | 123 | def get_template_from_path(path: str) -> str: 124 | """Convert a normal path back to its template representation.""" 125 | # replace all path parts with the template tags 126 | path = path.replace("\\", "/") 127 | return path 128 | 129 | 130 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: 131 | """Convert a normal path to template and the convert it back to a normal path with given path type.""" 132 | path_template = get_template_from_path(path) 133 | path = get_path_from_template(path_template, path_type) 134 | return path 135 | 136 | 137 | def set_user_name_override(name: str) -> None: 138 | """Set the global username override value.""" 139 | global _user_name_override 140 | _user_name_override = name 141 | 142 | 143 | def get_user_name(): 144 | """Get the current user name.""" 145 | if _user_name_override is not None: 146 | return _user_name_override 147 | elif platform.system() == "Windows": 148 | return 
os.getlogin() 149 | elif platform.system() == "Linux": 150 | try: 151 | import pwd # pylint: disable=import-error 152 | return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member 153 | except: 154 | return "unknown" 155 | else: 156 | raise RuntimeError("Unknown platform") 157 | 158 | 159 | def _create_run_dir_local(submit_config: SubmitConfig) -> str: 160 | """Create a new run dir with increasing ID number at the start.""" 161 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) 162 | 163 | if not os.path.exists(run_dir_root): 164 | print("Creating the run dir root: {}".format(run_dir_root)) 165 | os.makedirs(run_dir_root) 166 | 167 | submit_config.run_id = _get_next_run_id_local(run_dir_root) 168 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) 169 | run_dir = os.path.join(run_dir_root, submit_config.run_name) 170 | 171 | if os.path.exists(run_dir): 172 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) 173 | 174 | print("Creating the run dir: {}".format(run_dir)) 175 | os.makedirs(run_dir) 176 | 177 | return run_dir 178 | 179 | 180 | def _get_next_run_id_local(run_dir_root: str) -> int: 181 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.""" 182 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] 183 | r = re.compile("^\\d+") # match one or more digits at the start of the string 184 | run_id = 0 185 | 186 | for dir_name in dir_names: 187 | m = r.match(dir_name) 188 | 189 | if m is not None: 190 | i = int(m.group()) 191 | run_id = max(run_id, i + 1) 192 | 193 | return run_id 194 | 195 | 196 | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: 197 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" 198 | print("Copying files to the run dir") 199 | files = [] 200 | 201 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) 202 | assert '.' 
in submit_config.run_func_name 203 | for _idx in range(submit_config.run_func_name.count('.') - 1): 204 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) 205 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) 206 | 207 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") 208 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) 209 | 210 | if submit_config.run_dir_extra_files is not None: 211 | files += submit_config.run_dir_extra_files 212 | 213 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] 214 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] 215 | 216 | util.copy_files_and_create_dirs(files) 217 | 218 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) 219 | 220 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: 221 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) 222 | 223 | 224 | def run_wrapper(submit_config: SubmitConfig) -> None: 225 | """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" 226 | is_local = submit_config.submit_target == SubmitTarget.LOCAL 227 | 228 | checker = None 229 | 230 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing 231 | if is_local: 232 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) 233 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) 234 | logger = util.Logger(file_name=None, should_flush=True) 235 | 236 | import dnnlib 237 | dnnlib.submit_config = submit_config 238 | 239 | try: 240 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) 241 | start_time = time.time() 242 | util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) 243 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time))) 244 | except: 245 | if is_local: 246 | raise 247 | else: 248 | traceback.print_exc() 249 | 250 | log_src = os.path.join(submit_config.run_dir, "log.txt") 251 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) 252 | shutil.copyfile(log_src, log_dst) 253 | finally: 254 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() 255 | 256 | dnnlib.submit_config = None 257 | logger.close() 258 | 259 | if checker is not None: 260 | checker.stop() 261 | 262 | 263 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: 264 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.""" 265 | submit_config = copy.copy(submit_config) 266 | 267 | if submit_config.user_name is None: 268 | submit_config.user_name = get_user_name() 269 | 270 | submit_config.run_func_name = run_func_name 271 | submit_config.run_func_kwargs = run_func_kwargs 272 | 273 | assert submit_config.submit_target == SubmitTarget.LOCAL 274 | if submit_config.submit_target in {SubmitTarget.LOCAL}: 275 | run_dir = _create_run_dir_local(submit_config) 276 | 277 
| submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) 278 | submit_config.run_dir = run_dir 279 | _populate_run_dir(run_dir, submit_config) 280 | 281 | if submit_config.print_info: 282 | print("\nSubmit config:\n") 283 | pprint.pprint(submit_config, indent=4, width=200, compact=False) 284 | print() 285 | 286 | if submit_config.ask_confirmation: 287 | if not util.ask_yes_no("Continue submitting the job?"): 288 | return 289 | 290 | run_wrapper(submit_config) 291 | -------------------------------------------------------------------------------- /dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import autosummary 9 | from . import network 10 | from . import optimizer 11 | from . import tfutil 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/network.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for adding automatically tracked values to Tensorboard. 9 | 10 | Autosummary creates an identity op that internally keeps track of the input 11 | values and automatically shows up in TensorBoard. The reported value 12 | represents an average over input components. The average is accumulated 13 | constantly over time and flushed when save_summaries() is called. 14 | 15 | Notes: 16 | - The output tensor must be used as an input for something else in the 17 | graph. Otherwise, the autosummary op will not get executed, and the average 18 | value will not get accumulated. 19 | - It is perfectly fine to include autosummaries with the same name in 20 | several places throughout the graph, even if they are executed concurrently. 21 | - It is ok to also pass in a python scalar or numpy array. In this case, it 22 | is added to the average immediately. 23 | """ 24 | 25 | from collections import OrderedDict 26 | import numpy as np 27 | import tensorflow as tf 28 | from tensorboard import summary as summary_lib 29 | from tensorboard.plugins.custom_scalar import layout_pb2 30 | 31 | from . import tfutil 32 | from .tfutil import TfExpression 33 | from .tfutil import TfExpressionEx 34 | 35 | _dtype = tf.float64 36 | _vars = OrderedDict() # name => [var, ...] 37 | _immediate = OrderedDict() # name => update_op, update_value 38 | _finalized = False 39 | _merge_op = None 40 | 41 | 42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 43 | """Internal helper for creating autosummary accumulators.""" 44 | assert not _finalized 45 | name_id = name.replace("/", "_") 46 | v = tf.cast(value_expr, _dtype) 47 | 48 | if v.shape.is_fully_defined(): 49 | size = np.prod(tfutil.shape_to_list(v.shape)) 50 | size_expr = tf.constant(size, dtype=_dtype) 51 | else: 52 | size = None 53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 54 | 55 | if size == 1: 56 | if v.shape.ndims != 0: 57 | v = tf.reshape(v, []) 58 | v = [size_expr, v, tf.square(v)] 59 | else: 60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 62 | 63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 66 | 67 | if name in _vars: 68 | _vars[name].append(var) 69 | else: 70 | _vars[name] = [var] 71 | return update_op 72 | 73 | 74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: 75 | """Create a new autosummary. 76 | 77 | Args: 78 | name: Name to use in TensorBoard 79 | value: TensorFlow expression or python value to track 80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 
81 | 82 | Example use of the passthru mechanism: 83 | 84 | n = autosummary('l2loss', loss, passthru=n) 85 | 86 | This is a shorthand for the following code: 87 | 88 | with tf.control_dependencies([autosummary('l2loss', loss)]): 89 | n = tf.identity(n) 90 | """ 91 | tfutil.assert_tf_initialized() 92 | name_id = name.replace("/", "_") 93 | 94 | if tfutil.is_tf_expression(value): 95 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 96 | update_op = _create_var(name, value) 97 | with tf.control_dependencies([update_op]): 98 | return tf.identity(value if passthru is None else passthru) 99 | 100 | else: # python scalar or numpy array 101 | if name not in _immediate: 102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 103 | update_value = tf.placeholder(_dtype) 104 | update_op = _create_var(name, update_value) 105 | _immediate[name] = update_op, update_value 106 | 107 | update_op, update_value = _immediate[name] 108 | tfutil.run(update_op, {update_value: value}) 109 | return value if passthru is None else passthru 110 | 111 | 112 | def finalize_autosummaries() -> None: 113 | """Create the necessary ops to include autosummaries in TensorBoard report. 114 | Note: This should be done only once per graph. 115 | """ 116 | global _finalized 117 | tfutil.assert_tf_initialized() 118 | 119 | if _finalized: 120 | return None 121 | 122 | _finalized = True 123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 124 | 125 | # Create summary ops. 126 | with tf.device(None), tf.control_dependencies(None): 127 | for name, vars_list in _vars.items(): 128 | name_id = name.replace("/", "_") 129 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 130 | moments = tf.add_n(vars_list) 131 | moments /= moments[0] 132 | with tf.control_dependencies([moments]): # read before resetting 133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 135 | mean = moments[1] 136 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 137 | tf.summary.scalar(name, mean) 138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 140 | 141 | # Group by category and chart name. 142 | cat_dict = OrderedDict() 143 | for series_name in sorted(_vars.keys()): 144 | p = series_name.split("/") 145 | cat = p[0] if len(p) >= 2 else "" 146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 147 | if cat not in cat_dict: 148 | cat_dict[cat] = OrderedDict() 149 | if chart not in cat_dict[cat]: 150 | cat_dict[cat][chart] = [] 151 | cat_dict[cat][chart].append(series_name) 152 | 153 | # Setup custom_scalar layout. 
154 | categories = [] 155 | for cat_name, chart_dict in cat_dict.items(): 156 | charts = [] 157 | for chart_name, series_names in chart_dict.items(): 158 | series = [] 159 | for series_name in series_names: 160 | series.append(layout_pb2.MarginChartContent.Series( 161 | value=series_name, 162 | lower="xCustomScalars/" + series_name + "/margin_lo", 163 | upper="xCustomScalars/" + series_name + "/margin_hi")) 164 | margin = layout_pb2.MarginChartContent(series=series) 165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 168 | return layout 169 | 170 | def save_summaries(file_writer, global_step=None): 171 | """Call FileWriter.add_summary() with all summaries in the default graph, 172 | automatically finalizing and merging them on the first call. 173 | """ 174 | global _merge_op 175 | tfutil.assert_tf_initialized() 176 | 177 | if _merge_op is None: 178 | layout = finalize_autosummaries() 179 | if layout is not None: 180 | file_writer.add_summary(layout) 181 | with tf.device(None), tf.control_dependencies(None): 182 | _merge_op = tf.summary.merge_all() 183 | 184 | file_writer.add_summary(_merge_op.eval(), global_step) 185 | -------------------------------------------------------------------------------- /dnnlib/tflib/network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for managing networks.""" 9 | 10 | import types 11 | import inspect 12 | import re 13 | import uuid 14 | import sys 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | from collections import OrderedDict 19 | from typing import Any, List, Tuple, Union 20 | 21 | from . import tfutil 22 | from .. import util 23 | 24 | from .tfutil import TfExpression, TfExpressionEx 25 | 26 | _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. 27 | _import_module_src = dict() # Source code for temporary modules created during pickle import. 28 | 29 | 30 | def import_handler(handler_func): 31 | """Function decorator for declaring custom import handlers.""" 32 | _import_handlers.append(handler_func) 33 | return handler_func 34 | 35 | 36 | class Network: 37 | """Generic network abstraction. 38 | 39 | Acts as a convenience wrapper for a parameterized network construction 40 | function, providing several utility methods and convenient access to 41 | the inputs/outputs/weights. 42 | 43 | Network objects can be safely pickled and unpickled for long-term 44 | archival purposes. The pickling works reliably as long as the underlying 45 | network construction function is defined in a standalone Python module 46 | that has no side effects or application-specific imports. 47 | 48 | Args: 49 | name: Network name. Used to select TensorFlow name and variable scopes. 50 | func_name: Fully qualified name of the underlying network construction function, or a top-level function object. 51 | static_kwargs: Keyword arguments to be passed in to the network construction function. 
52 | 53 | Attributes: 54 | name: User-specified name, defaults to build func name if None. 55 | scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name. 56 | static_kwargs: Arguments passed to the user-supplied build func. 57 | components: Container for sub-networks. Passed to the build func, and retained between calls. 58 | num_inputs: Number of input tensors. 59 | num_outputs: Number of output tensors. 60 | input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension. 61 | output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension. 62 | input_shape: Short-hand for input_shapes[0]. 63 | output_shape: Short-hand for output_shapes[0]. 64 | input_templates: Input placeholders in the template graph. 65 | output_templates: Output tensors in the template graph. 66 | input_names: Name string for each input. 67 | output_names: Name string for each output. 68 | own_vars: Variables defined by this network (local_name => var), excluding sub-networks. 69 | vars: All variables (local_name => var). 70 | trainables: All trainable variables (local_name => var). 71 | var_global_to_local: Mapping from variable global names to local names. 72 | """ 73 | 74 | def __init__(self, name: str = None, func_name: Any = None, **static_kwargs): 75 | tfutil.assert_tf_initialized() 76 | assert isinstance(name, str) or name is None 77 | assert func_name is not None 78 | assert isinstance(func_name, str) or util.is_top_level_function(func_name) 79 | assert util.is_pickleable(static_kwargs) 80 | 81 | self._init_fields() 82 | self.name = name 83 | self.static_kwargs = util.EasyDict(static_kwargs) 84 | 85 | # Locate the user-specified network build function. 86 | if util.is_top_level_function(func_name): 87 | func_name = util.get_top_level_function_name(func_name) 88 | module, self._build_func_name = util.get_module_from_obj_name(func_name) 89 | self._build_func = util.get_obj_from_module(module, self._build_func_name) 90 | assert callable(self._build_func) 91 | 92 | # Dig up source code for the module containing the build function. 93 | self._build_module_src = _import_module_src.get(module, None) 94 | if self._build_module_src is None: 95 | self._build_module_src = inspect.getsource(module) 96 | 97 | # Init TensorFlow graph. 98 | self._init_graph() 99 | self.reset_own_vars() 100 | 101 | def _init_fields(self) -> None: 102 | self.name = None 103 | self.scope = None 104 | self.static_kwargs = util.EasyDict() 105 | self.components = util.EasyDict() 106 | self.num_inputs = 0 107 | self.num_outputs = 0 108 | self.input_shapes = [[]] 109 | self.output_shapes = [[]] 110 | self.input_shape = [] 111 | self.output_shape = [] 112 | self.input_templates = [] 113 | self.output_templates = [] 114 | self.input_names = [] 115 | self.output_names = [] 116 | self.own_vars = OrderedDict() 117 | self.vars = OrderedDict() 118 | self.trainables = OrderedDict() 119 | self.var_global_to_local = OrderedDict() 120 | 121 | self._build_func = None # User-supplied build function that constructs the network. 122 | self._build_func_name = None # Name of the build function. 123 | self._build_module_src = None # Full source code of the module containing the build function. 124 | self._run_cache = dict() # Cached graph data for Network.run(). 125 | 126 | def _init_graph(self) -> None: 127 | # Collect inputs. 
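# Illustrative construction sketch (the build-function path "training.networks_stylegan.G_style"
# and its keyword argument are placeholders; any top-level build function works):
#
#   import dnnlib.tflib as tflib
#   tflib.init_tf()
#   G = tflib.Network("G", func_name="training.networks_stylegan.G_style", resolution=1024)
#   G.print_layers()   # inspect the template graph built by _init_graph() below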
128 | self.input_names = [] 129 | 130 | for param in inspect.signature(self._build_func).parameters.values(): 131 | if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty: 132 | self.input_names.append(param.name) 133 | 134 | self.num_inputs = len(self.input_names) 135 | assert self.num_inputs >= 1 136 | 137 | # Choose name and scope. 138 | if self.name is None: 139 | self.name = self._build_func_name 140 | assert re.match("^[A-Za-z0-9_.\\-]*$", self.name) 141 | with tf.name_scope(None): 142 | self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True) 143 | 144 | # Finalize build func kwargs. 145 | build_kwargs = dict(self.static_kwargs) 146 | build_kwargs["is_template_graph"] = True 147 | build_kwargs["components"] = self.components 148 | 149 | # Build template graph. 150 | with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes 151 | assert tf.get_variable_scope().name == self.scope 152 | assert tf.get_default_graph().get_name_scope() == self.scope 153 | with tf.control_dependencies(None): # ignore surrounding control dependencies 154 | self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names] 155 | out_expr = self._build_func(*self.input_templates, **build_kwargs) 156 | 157 | # Collect outputs. 158 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) 159 | self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) 160 | self.num_outputs = len(self.output_templates) 161 | assert self.num_outputs >= 1 162 | assert all(tfutil.is_tf_expression(t) for t in self.output_templates) 163 | 164 | # Perform sanity checks. 165 | if any(t.shape.ndims is None for t in self.input_templates): 166 | raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.") 167 | if any(t.shape.ndims is None for t in self.output_templates): 168 | raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.") 169 | if any(not isinstance(comp, Network) for comp in self.components.values()): 170 | raise ValueError("Components of a Network must be Networks themselves.") 171 | if len(self.components) != len(set(comp.name for comp in self.components.values())): 172 | raise ValueError("Components of a Network must have unique names.") 173 | 174 | # List inputs and outputs. 175 | self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates] 176 | self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates] 177 | self.input_shape = self.input_shapes[0] 178 | self.output_shape = self.output_shapes[0] 179 | self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates] 180 | 181 | # List variables. 
182 | self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/")) 183 | self.vars = OrderedDict(self.own_vars) 184 | self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items()) 185 | self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable) 186 | self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items()) 187 | 188 | def reset_own_vars(self) -> None: 189 | """Re-initialize all variables of this network, excluding sub-networks.""" 190 | tfutil.run([var.initializer for var in self.own_vars.values()]) 191 | 192 | def reset_vars(self) -> None: 193 | """Re-initialize all variables of this network, including sub-networks.""" 194 | tfutil.run([var.initializer for var in self.vars.values()]) 195 | 196 | def reset_trainables(self) -> None: 197 | """Re-initialize all trainable variables of this network, including sub-networks.""" 198 | tfutil.run([var.initializer for var in self.trainables.values()]) 199 | 200 | def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]: 201 | """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).""" 202 | assert len(in_expr) == self.num_inputs 203 | assert not all(expr is None for expr in in_expr) 204 | 205 | # Finalize build func kwargs. 206 | build_kwargs = dict(self.static_kwargs) 207 | build_kwargs.update(dynamic_kwargs) 208 | build_kwargs["is_template_graph"] = False 209 | build_kwargs["components"] = self.components 210 | 211 | # Build TensorFlow graph to evaluate the network. 212 | with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name): 213 | assert tf.get_variable_scope().name == self.scope 214 | valid_inputs = [expr for expr in in_expr if expr is not None] 215 | final_inputs = [] 216 | for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes): 217 | if expr is not None: 218 | expr = tf.identity(expr, name=name) 219 | else: 220 | expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name) 221 | final_inputs.append(expr) 222 | out_expr = self._build_func(*final_inputs, **build_kwargs) 223 | 224 | # Propagate input shapes back to the user-specified expressions. 225 | for expr, final in zip(in_expr, final_inputs): 226 | if isinstance(expr, tf.Tensor): 227 | expr.set_shape(final.shape) 228 | 229 | # Express outputs in the desired format. 
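# Illustrative sketch of wiring get_output_for() into a larger graph (assumes "G" is a
# two-input generator network, "labels" an existing tensor, and that the build function
# accepts an "is_training" kwarg; all of these are placeholders):
#
#   latents = tf.random_normal([minibatch_size] + G.input_shape[1:])
#   fake_images = G.get_output_for(latents, labels, is_training=True)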
230 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) 231 | if return_as_list: 232 | out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) 233 | return out_expr 234 | 235 | def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str: 236 | """Get the local name of a given variable, without any surrounding name scopes.""" 237 | assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str) 238 | global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name 239 | return self.var_global_to_local[global_name] 240 | 241 | def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression: 242 | """Find variable by local or global name.""" 243 | assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str) 244 | return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name 245 | 246 | def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray: 247 | """Get the value of a given variable as NumPy array. 248 | Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.""" 249 | return self.find_var(var_or_local_name).eval() 250 | 251 | def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None: 252 | """Set the value of a given variable based on the given NumPy array. 253 | Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.""" 254 | tfutil.set_vars({self.find_var(var_or_local_name): new_value}) 255 | 256 | def __getstate__(self) -> dict: 257 | """Pickle export.""" 258 | state = dict() 259 | state["version"] = 3 260 | state["name"] = self.name 261 | state["static_kwargs"] = dict(self.static_kwargs) 262 | state["components"] = dict(self.components) 263 | state["build_module_src"] = self._build_module_src 264 | state["build_func_name"] = self._build_func_name 265 | state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values())))) 266 | return state 267 | 268 | def __setstate__(self, state: dict) -> None: 269 | """Pickle import.""" 270 | # pylint: disable=attribute-defined-outside-init 271 | tfutil.assert_tf_initialized() 272 | self._init_fields() 273 | 274 | # Execute custom import handlers. 275 | for handler in _import_handlers: 276 | state = handler(state) 277 | 278 | # Set basic fields. 279 | assert state["version"] in [2, 3] 280 | self.name = state["name"] 281 | self.static_kwargs = util.EasyDict(state["static_kwargs"]) 282 | self.components = util.EasyDict(state.get("components", {})) 283 | self._build_module_src = state["build_module_src"] 284 | self._build_func_name = state["build_func_name"] 285 | 286 | # Create temporary module from the imported source code. 287 | module_name = "_tflib_network_import_" + uuid.uuid4().hex 288 | module = types.ModuleType(module_name) 289 | sys.modules[module_name] = module 290 | _import_module_src[module] = self._build_module_src 291 | exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used 292 | 293 | # Locate network build function in the temporary module. 294 | self._build_func = util.get_obj_from_module(module, self._build_func_name) 295 | assert callable(self._build_func) 296 | 297 | # Init TensorFlow graph. 
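# Illustrative pickling sketch: __getstate__/__setstate__ below make whole networks
# pickleable. The file names and the (G, D, Gs) tuple layout are assumptions that match
# the official StyleGAN snapshots:
#
#   import pickle
#   with open("network-snapshot.pkl", "rb") as f:
#       G, D, Gs = pickle.load(f)     # graphs are rebuilt via __setstate__
#   with open("export.pkl", "wb") as f:
#       pickle.dump((G, D, Gs), f)    # variable values are captured via __getstate__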
298 | self._init_graph() 299 | self.reset_own_vars() 300 | tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]}) 301 | 302 | def clone(self, name: str = None, **new_static_kwargs) -> "Network": 303 | """Create a clone of this network with its own copy of the variables.""" 304 | # pylint: disable=protected-access 305 | net = object.__new__(Network) 306 | net._init_fields() 307 | net.name = name if name is not None else self.name 308 | net.static_kwargs = util.EasyDict(self.static_kwargs) 309 | net.static_kwargs.update(new_static_kwargs) 310 | net._build_module_src = self._build_module_src 311 | net._build_func_name = self._build_func_name 312 | net._build_func = self._build_func 313 | net._init_graph() 314 | net.copy_vars_from(self) 315 | return net 316 | 317 | def copy_own_vars_from(self, src_net: "Network") -> None: 318 | """Copy the values of all variables from the given network, excluding sub-networks.""" 319 | names = [name for name in self.own_vars.keys() if name in src_net.own_vars] 320 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) 321 | 322 | def copy_vars_from(self, src_net: "Network") -> None: 323 | """Copy the values of all variables from the given network, including sub-networks.""" 324 | names = [name for name in self.vars.keys() if name in src_net.vars] 325 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) 326 | 327 | def copy_trainables_from(self, src_net: "Network") -> None: 328 | """Copy the values of all trainable variables from the given network, including sub-networks.""" 329 | names = [name for name in self.trainables.keys() if name in src_net.trainables] 330 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) 331 | 332 | def copy_compatible_trainables_from(self, src_net: "Network") -> None: 333 | """Copy the compatible values of all trainable variables from the given network, including sub-networks""" 334 | names = [] 335 | for name in self.trainables.keys(): 336 | if name not in src_net.trainables: 337 | print("Not restoring (not present): {}".format(name)) 338 | elif self.trainables[name].shape != src_net.trainables[name].shape: 339 | print("Not restoring (different shape): {}".format(name)) 340 | 341 | if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape: 342 | names.append(name) 343 | 344 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) 345 | 346 | def apply_swa(self, src_net, epoch): 347 | """Perform stochastic weight averaging on the compatible values of all trainable variables from the given network, including sub-networks""" 348 | names = [] 349 | for name in self.trainables.keys(): 350 | if name not in src_net.trainables: 351 | print("Not restoring (not present): {}".format(name)) 352 | elif self.trainables[name].shape != src_net.trainables[name].shape: 353 | print("Not restoring (different shape): {}".format(name)) 354 | 355 | if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape: 356 | names.append(name) 357 | 358 | scale_new_data = 1.0 / (epoch + 1) 359 | scale_moving_average = (1.0 - scale_new_data) 360 | tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * scale_new_data + self.vars[name] * scale_moving_average) for name in names})) 361 | 362 | def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network": 363 | """Create new network with the given 
parameters, and copy all variables from this network.""" 364 | if new_name is None: 365 | new_name = self.name 366 | static_kwargs = dict(self.static_kwargs) 367 | static_kwargs.update(new_static_kwargs) 368 | net = Network(name=new_name, func_name=new_func_name, **static_kwargs) 369 | net.copy_vars_from(self) 370 | return net 371 | 372 | def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation: 373 | """Construct a TensorFlow op that updates the variables of this network 374 | to be slightly closer to those of the given network.""" 375 | with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"): 376 | ops = [] 377 | for name, var in self.vars.items(): 378 | if name in src_net.vars: 379 | cur_beta = beta if name in self.trainables else beta_nontrainable 380 | new_value = tfutil.lerp(src_net.vars[name], var, cur_beta) 381 | ops.append(var.assign(new_value)) 382 | return tf.group(*ops) 383 | 384 | def run(self, 385 | *in_arrays: Tuple[Union[np.ndarray, None], ...], 386 | input_transform: dict = None, 387 | output_transform: dict = None, 388 | return_as_list: bool = False, 389 | print_progress: bool = False, 390 | minibatch_size: int = None, 391 | num_gpus: int = 1, 392 | assume_frozen: bool = False, 393 | custom_inputs=None, 394 | **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]: 395 | """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s). 396 | 397 | Args: 398 | input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network. 399 | The dict must contain a 'func' field that points to a top-level function. The function is called with the input 400 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. 401 | output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network. 402 | The dict must contain a 'func' field that points to a top-level function. The function is called with the output 403 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. 404 | return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs. 405 | print_progress: Print progress to the console? Useful for very large input arrays. 406 | minibatch_size: Maximum minibatch size to use, None = disable batching. 407 | num_gpus: Number of GPUs to use. 408 | assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls. 409 | dynamic_kwargs: Additional keyword arguments to be passed into the network build function. 410 | custom_inputs: Optional list of builder functions (one per input) used to construct custom input tensors in place of the default placeholders. 411 | """ 412 | assert len(in_arrays) == self.num_inputs 413 | assert not all(arr is None for arr in in_arrays) 414 | assert input_transform is None or util.is_top_level_function(input_transform["func"]) 415 | assert output_transform is None or util.is_top_level_function(output_transform["func"]) 416 | output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs) 417 | num_items = in_arrays[0].shape[0] 418 | if minibatch_size is None: 419 | minibatch_size = num_items 420 | 421 | # Construct unique hash key from all arguments that affect the TensorFlow graph.
422 | key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs) 423 | def unwind_key(obj): 424 | if isinstance(obj, dict): 425 | return [(key, unwind_key(value)) for key, value in sorted(obj.items())] 426 | if callable(obj): 427 | return util.get_top_level_function_name(obj) 428 | return obj 429 | key = repr(unwind_key(key)) 430 | 431 | # Build graph. 432 | if key not in self._run_cache: 433 | with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None): 434 | if custom_inputs is not None: 435 | with tf.device("/gpu:0"): 436 | in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)] 437 | in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) 438 | else: 439 | with tf.device("/cpu:0"): 440 | in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names] 441 | in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) 442 | 443 | out_split = [] 444 | for gpu in range(num_gpus): 445 | with tf.device("/gpu:%d" % gpu): 446 | net_gpu = self.clone() if assume_frozen else self 447 | in_gpu = in_split[gpu] 448 | 449 | if input_transform is not None: 450 | in_kwargs = dict(input_transform) 451 | in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs) 452 | in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu) 453 | 454 | assert len(in_gpu) == self.num_inputs 455 | out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs) 456 | 457 | if output_transform is not None: 458 | out_kwargs = dict(output_transform) 459 | out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs) 460 | out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu) 461 | 462 | assert len(out_gpu) == self.num_outputs 463 | out_split.append(out_gpu) 464 | 465 | with tf.device("/cpu:0"): 466 | out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)] 467 | self._run_cache[key] = in_expr, out_expr 468 | 469 | # Run minibatches. 470 | in_expr, out_expr = self._run_cache[key] 471 | out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr] 472 | 473 | for mb_begin in range(0, num_items, minibatch_size): 474 | if print_progress: 475 | print("\r%d / %d" % (mb_begin, num_items), end="") 476 | 477 | mb_end = min(mb_begin + minibatch_size, num_items) 478 | mb_num = mb_end - mb_begin 479 | mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)] 480 | mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in))) 481 | 482 | for dst, src in zip(out_arrays, mb_out): 483 | dst[mb_begin: mb_end] = src 484 | 485 | # Done. 
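# Illustrative run() sketch (assumes "Gs" is a loaded two-input generator; passing None for
# the label input substitutes zeros, and convert_images_to_uint8 is defined in tfutil.py):
#
#   latents = np.random.randn(4, *Gs.input_shape[1:]).astype(np.float32)
#   images = Gs.run(latents, None, minibatch_size=4,
#                   output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
#   # images: uint8 array of shape [4, height, width, channels]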
486 | if print_progress: 487 | print("\r%d / %d" % (num_items, num_items)) 488 | 489 | if not return_as_list: 490 | out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays) 491 | return out_arrays 492 | 493 | def list_ops(self) -> List[TfExpression]: 494 | include_prefix = self.scope + "/" 495 | exclude_prefix = include_prefix + "_" 496 | ops = tf.get_default_graph().get_operations() 497 | ops = [op for op in ops if op.name.startswith(include_prefix)] 498 | ops = [op for op in ops if not op.name.startswith(exclude_prefix)] 499 | return ops 500 | 501 | def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]: 502 | """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to 503 | individual layers of the network. Mainly intended to be used for reporting.""" 504 | layers = [] 505 | 506 | def recurse(scope, parent_ops, parent_vars, level): 507 | # Ignore specific patterns. 508 | if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]): 509 | return 510 | 511 | # Filter ops and vars by scope. 512 | global_prefix = scope + "/" 513 | local_prefix = global_prefix[len(self.scope) + 1:] 514 | cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]] 515 | cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]] 516 | if not cur_ops and not cur_vars: 517 | return 518 | 519 | # Filter out all ops related to variables. 520 | for var in [op for op in cur_ops if op.type.startswith("Variable")]: 521 | var_prefix = var.name + "/" 522 | cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)] 523 | 524 | # Scope does not contain ops as immediate children => recurse deeper. 525 | contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops) 526 | if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1: 527 | visited = set() 528 | for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]: 529 | token = rel_name.split("/")[0] 530 | if token not in visited: 531 | recurse(global_prefix + token, cur_ops, cur_vars, level + 1) 532 | visited.add(token) 533 | return 534 | 535 | # Report layer. 
536 | layer_name = scope[len(self.scope) + 1:] 537 | layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1] 538 | layer_trainables = [var for _name, var in cur_vars if var.trainable] 539 | layers.append((layer_name, layer_output, layer_trainables)) 540 | 541 | recurse(self.scope, self.list_ops(), list(self.vars.items()), 0) 542 | return layers 543 | 544 | def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None: 545 | """Print a summary table of the network structure.""" 546 | rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]] 547 | rows += [["---"] * 4] 548 | total_params = 0 549 | 550 | for layer_name, layer_output, layer_trainables in self.list_layers(): 551 | num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables) 552 | weights = [var for var in layer_trainables if var.name.endswith("/weight:0") or var.name.endswith("/weight_1:0")] 553 | weights.sort(key=lambda x: len(x.name)) 554 | if len(weights) == 0 and len(layer_trainables) == 1: 555 | weights = layer_trainables 556 | total_params += num_params 557 | 558 | if not hide_layers_with_no_params or num_params != 0: 559 | num_params_str = str(num_params) if num_params > 0 else "-" 560 | output_shape_str = str(layer_output.shape) 561 | weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-" 562 | rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]] 563 | 564 | rows += [["---"] * 4] 565 | rows += [["Total", str(total_params), "", ""]] 566 | 567 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 568 | print() 569 | for row in rows: 570 | print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths))) 571 | print() 572 | 573 | def setup_weight_histograms(self, title: str = None) -> None: 574 | """Construct summary ops to include histograms of all trainable parameters in TensorBoard.""" 575 | if title is None: 576 | title = self.name 577 | 578 | with tf.name_scope(None), tf.device(None), tf.control_dependencies(None): 579 | for local_name, var in self.trainables.items(): 580 | if "/" in local_name: 581 | p = local_name.split("/") 582 | name = title + "_" + p[-1] + "/" + "_".join(p[:-1]) 583 | else: 584 | name = title + "_toplevel/" + local_name 585 | 586 | tf.summary.histogram(name, var) 587 | 588 | #---------------------------------------------------------------------------- 589 | # Backwards-compatible emulation of legacy output transformation in Network.run(). 
590 | 591 | _print_legacy_warning = True 592 | 593 | def _handle_legacy_output_transforms(output_transform, dynamic_kwargs): 594 | global _print_legacy_warning 595 | legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"] 596 | if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs): 597 | return output_transform, dynamic_kwargs 598 | 599 | if _print_legacy_warning: 600 | _print_legacy_warning = False 601 | print() 602 | print("WARNING: Old-style output transformations in Network.run() are deprecated.") 603 | print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'") 604 | print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.") 605 | print() 606 | assert output_transform is None 607 | 608 | new_kwargs = dict(dynamic_kwargs) 609 | new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs} 610 | new_transform["func"] = _legacy_output_transform_func 611 | return new_transform, new_kwargs 612 | 613 | def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None): 614 | if out_mul != 1.0: 615 | expr = [x * out_mul for x in expr] 616 | 617 | if out_add != 0.0: 618 | expr = [x + out_add for x in expr] 619 | 620 | if out_shrink > 1: 621 | ksize = [1, 1, out_shrink, out_shrink] 622 | expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr] 623 | 624 | if out_dtype is not None: 625 | if tf.as_dtype(out_dtype).is_integer: 626 | expr = [tf.round(x) for x in expr] 627 | expr = [tf.saturate_cast(x, out_dtype) for x in expr] 628 | return expr 629 | -------------------------------------------------------------------------------- /dnnlib/tflib/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper wrapper for a Tensorflow optimizer.""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from collections import OrderedDict 14 | from typing import List, Union 15 | 16 | from . import autosummary 17 | from . import tfutil 18 | from .. import util 19 | 20 | from .tfutil import TfExpression, TfExpressionEx 21 | 22 | try: 23 | # TensorFlow 1.13 24 | from tensorflow.python.ops import nccl_ops 25 | except: 26 | # Older TensorFlow versions 27 | import tensorflow.contrib.nccl as nccl_ops 28 | 29 | class Optimizer: 30 | """A Wrapper for tf.train.Optimizer. 31 | 32 | Automatically takes care of: 33 | - Gradient averaging for multi-GPU training. 34 | - Dynamic loss scaling and typecasts for FP16 training. 35 | - Ignoring corrupted gradients that contain NaNs/Infs. 36 | - Reporting statistics. 37 | - Well-chosen default settings. 38 | """ 39 | 40 | def __init__(self, 41 | name: str = "Train", 42 | tf_optimizer: str = "tf.train.AdamOptimizer", 43 | learning_rate: TfExpressionEx = 0.001, 44 | use_loss_scaling: bool = False, 45 | loss_scaling_init: float = 64.0, 46 | loss_scaling_inc: float = 0.0005, 47 | loss_scaling_dec: float = 1.0, 48 | **kwargs): 49 | 50 | # Init fields. 
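# Illustrative training-setup sketch ("G", "num_gpus", and the per-GPU loss builder are
# placeholders): register gradients once per GPU, then build a single update op:
#
#   opt = Optimizer(name="TrainG", learning_rate=0.002)
#   for gpu in range(num_gpus):
#       with tf.device("/gpu:%d" % gpu):
#           loss = build_loss_for_gpu(gpu)   # hypothetical helper
#           opt.register_gradients(tf.reduce_mean(loss), G.trainables)
#   train_op = opt.apply_updates()
#   tfutil.run(train_op)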
51 | self.name = name 52 | self.learning_rate = tf.convert_to_tensor(learning_rate) 53 | self.id = self.name.replace("/", ".") 54 | self.scope = tf.get_default_graph().unique_name(self.id) 55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer) 56 | self.optimizer_kwargs = dict(kwargs) 57 | self.use_loss_scaling = use_loss_scaling 58 | self.loss_scaling_init = loss_scaling_init 59 | self.loss_scaling_inc = loss_scaling_inc 60 | self.loss_scaling_dec = loss_scaling_dec 61 | self._grad_shapes = None # [shape, ...] 62 | self._dev_opt = OrderedDict() # device => optimizer 63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] 64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) 65 | self._updates_applied = False 66 | 67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: 68 | """Register the gradients of the given loss function with respect to the given variables. 69 | Intended to be called once per GPU.""" 70 | assert not self._updates_applied 71 | 72 | # Validate arguments. 73 | if isinstance(trainable_vars, dict): 74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars 75 | 76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) 78 | 79 | if self._grad_shapes is None: 80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] 81 | 82 | assert len(trainable_vars) == len(self._grad_shapes) 83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) 84 | 85 | dev = loss.device 86 | 87 | assert all(var.device == dev for var in trainable_vars) 88 | 89 | # Register device and compute gradients. 90 | with tf.name_scope(self.id + "_grad"), tf.device(dev): 91 | if dev not in self._dev_opt: 92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) 93 | assert callable(self.optimizer_class) 94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) 95 | self._dev_grads[dev] = [] 96 | 97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) 98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage 99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros 100 | self._dev_grads[dev].append(grads) 101 | 102 | def apply_updates(self) -> tf.Operation: 103 | """Construct training op to update the registered variables based on their gradients.""" 104 | tfutil.assert_tf_initialized() 105 | assert not self._updates_applied 106 | self._updates_applied = True 107 | devices = list(self._dev_grads.keys()) 108 | total_grads = sum(len(grads) for grads in self._dev_grads.values()) 109 | assert len(devices) >= 1 and total_grads >= 1 110 | ops = [] 111 | 112 | with tfutil.absolute_name_scope(self.scope): 113 | # Cast gradients to FP32 and calculate partial sum within each device. 114 | dev_grads = OrderedDict() # device => [(grad, var), ...] 
115 | 116 | for dev_idx, dev in enumerate(devices): 117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): 118 | sums = [] 119 | 120 | for gv in zip(*self._dev_grads[dev]): 121 | assert all(v is gv[0][1] for g, v in gv) 122 | g = [tf.cast(g, tf.float32) for g, v in gv] 123 | g = g[0] if len(g) == 1 else tf.add_n(g) 124 | sums.append((g, gv[0][1])) 125 | 126 | dev_grads[dev] = sums 127 | 128 | # Sum gradients across devices. 129 | if len(devices) > 1: 130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None): 131 | for var_idx, grad_shape in enumerate(self._grad_shapes): 132 | g = [dev_grads[dev][var_idx][0] for dev in devices] 133 | 134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors 135 | g = nccl_ops.all_sum(g) 136 | 137 | for dev, gg in zip(devices, g): 138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) 139 | 140 | # Apply updates separately on each device. 141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()): 142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): 143 | # Scale gradients as needed. 144 | if self.use_loss_scaling or total_grads > 1: 145 | with tf.name_scope("Scale"): 146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef") 147 | coef = self.undo_loss_scaling(coef) 148 | grads = [(g * coef, v) for g, v in grads] 149 | 150 | # Check for overflows. 151 | with tf.name_scope("CheckOverflow"): 152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) 153 | 154 | # Update weights and adjust loss scaling. 155 | with tf.name_scope("UpdateWeights"): 156 | # pylint: disable=cell-var-from-loop 157 | opt = self._dev_opt[dev] 158 | ls_var = self.get_loss_scaling_var(dev) 159 | 160 | if not self.use_loss_scaling: 161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) 162 | else: 163 | ops.append(tf.cond(grad_ok, 164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), 165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) 166 | 167 | # Report statistics on the last device. 168 | if dev == devices[-1]: 169 | with tf.name_scope("Statistics"): 170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) 171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) 172 | 173 | if self.use_loss_scaling: 174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) 175 | 176 | # Initialize variables and group everything into a single op. 
177 | self.reset_optimizer_state() 178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) 179 | 180 | return tf.group(*ops, name="TrainingOp") 181 | 182 | def reset_optimizer_state(self) -> None: 183 | """Reset internal state of the underlying optimizer.""" 184 | tfutil.assert_tf_initialized() 185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) 186 | 187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: 188 | """Get or create variable representing log2 of the current dynamic loss scaling factor.""" 189 | if not self.use_loss_scaling: 190 | return None 191 | 192 | if device not in self._dev_ls_var: 193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): 194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") 195 | 196 | return self._dev_ls_var[device] 197 | 198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression: 199 | """Apply dynamic loss scaling for the given expression.""" 200 | assert tfutil.is_tf_expression(value) 201 | 202 | if not self.use_loss_scaling: 203 | return value 204 | 205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) 206 | 207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression: 208 | """Undo the effect of dynamic loss scaling for the given expression.""" 209 | assert tfutil.is_tf_expression(value) 210 | 211 | if not self.use_loss_scaling: 212 | return value 213 | 214 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type 215 | -------------------------------------------------------------------------------- /dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous helper utils for Tensorflow.""" 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from typing import Any, Iterable, List, Union 15 | 16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 17 | """A type that represents a valid Tensorflow expression.""" 18 | 19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 20 | """A type that can be converted to a valid Tensorflow expression.""" 21 | 22 | 23 | def run(*args, **kwargs) -> Any: 24 | """Run the specified ops in the default session.""" 25 | assert_tf_initialized() 26 | return tf.get_default_session().run(*args, **kwargs) 27 | 28 | 29 | def is_tf_expression(x: Any) -> bool: 30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 32 | 33 | 34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 35 | """Convert a Tensorflow shape to a list of ints.""" 36 | return [dim.value for dim in shape] 37 | 38 | 39 | def flatten(x: TfExpressionEx) -> TfExpression: 40 | """Shortcut function for flattening a tensor.""" 41 | with tf.name_scope("Flatten"): 42 | return tf.reshape(x, [-1]) 43 | 44 | 45 | def log2(x: TfExpressionEx) -> TfExpression: 46 | """Logarithm in base 2.""" 47 | with tf.name_scope("Log2"): 48 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 49 | 50 | 51 | def exp2(x: TfExpressionEx) -> TfExpression: 52 | """Exponent in base 2.""" 53 | with tf.name_scope("Exp2"): 54 | return tf.exp(x * np.float32(np.log(2.0))) 55 | 56 | 57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 58 | """Linear interpolation.""" 59 | with tf.name_scope("Lerp"): 60 | return a + (b - a) * t 61 | 62 | 63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 64 | """Linear interpolation with clip.""" 65 | with tf.name_scope("LerpClip"): 66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 67 | 68 | 69 | def absolute_name_scope(scope: str) -> tf.name_scope: 70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 71 | return tf.name_scope(scope + "/") 72 | 73 | 74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 77 | 78 | 79 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 80 | # Defaults. 81 | cfg = dict() 82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 87 | 88 | # User overrides. 
89 | if config_dict is not None: 90 | cfg.update(config_dict) 91 | return cfg 92 | 93 | 94 | def init_tf(config_dict: dict = None) -> None: 95 | """Initialize TensorFlow session using good default settings.""" 96 | # Skip if already initialized. 97 | if tf.get_default_session() is not None: 98 | tf.reset_default_graph() 99 | 100 | # Setup config dict and random seeds. 101 | cfg = _sanitize_tf_config(config_dict) 102 | np_random_seed = cfg["rnd.np_random_seed"] 103 | if np_random_seed is not None: 104 | np.random.seed(np_random_seed) 105 | tf_random_seed = cfg["rnd.tf_random_seed"] 106 | if tf_random_seed == "auto": 107 | tf_random_seed = np.random.randint(1 << 31) 108 | if tf_random_seed is not None: 109 | tf.set_random_seed(tf_random_seed) 110 | 111 | # Setup environment variables. 112 | for key, value in list(cfg.items()): 113 | fields = key.split(".") 114 | if fields[0] == "env": 115 | assert len(fields) == 2 116 | os.environ[fields[1]] = str(value) 117 | 118 | # Create default TensorFlow session. 119 | create_session(cfg, force_as_default=True) 120 | 121 | 122 | def assert_tf_initialized(): 123 | """Check that TensorFlow session has been initialized.""" 124 | if tf.get_default_session() is None: 125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 126 | 127 | 128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 129 | """Create tf.Session based on config dict.""" 130 | # Setup TensorFlow config proto. 131 | cfg = _sanitize_tf_config(config_dict) 132 | config_proto = tf.ConfigProto() 133 | for key, value in cfg.items(): 134 | fields = key.split(".") 135 | if fields[0] not in ["rnd", "env"]: 136 | obj = config_proto 137 | for field in fields[:-1]: 138 | obj = getattr(obj, field) 139 | setattr(obj, fields[-1], value) 140 | 141 | # Create session. 142 | session = tf.Session(config=config_proto) 143 | if force_as_default: 144 | # pylint: disable=protected-access 145 | session._default_session = session.as_default() 146 | session._default_session.enforce_nesting = False 147 | session._default_session.__enter__() # pylint: disable=no-member 148 | 149 | return session 150 | 151 | 152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 153 | """Initialize all tf.Variables that have not already been initialized. 154 | 155 | Equivalent to the following, but more efficient and does not bloat the tf graph: 156 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 157 | """ 158 | assert_tf_initialized() 159 | if target_vars is None: 160 | target_vars = tf.global_variables() 161 | 162 | test_vars = [] 163 | test_ops = [] 164 | 165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 166 | for var in target_vars: 167 | assert is_tf_expression(var) 168 | 169 | try: 170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 171 | except KeyError: 172 | # Op does not exist => variable may be uninitialized. 173 | test_vars.append(var) 174 | 175 | with absolute_name_scope(var.name.split(":")[0]): 176 | test_ops.append(tf.is_variable_initialized(var)) 177 | 178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 179 | run([var.initializer for var in init_vars]) 180 | 181 | 182 | def set_vars(var_to_value_dict: dict) -> None: 183 | """Set the values of given tf.Variables. 
184 | 185 | Equivalent to the following, but more efficient and does not bloat the tf graph: 186 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 187 | """ 188 | assert_tf_initialized() 189 | ops = [] 190 | feed_dict = {} 191 | 192 | for var, value in var_to_value_dict.items(): 193 | assert is_tf_expression(var) 194 | 195 | try: 196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 197 | except KeyError: 198 | with absolute_name_scope(var.name.split(":")[0]): 199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 201 | 202 | ops.append(setter) 203 | feed_dict[setter.op.inputs[1]] = value 204 | 205 | run(ops, feed_dict) 206 | 207 | 208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 209 | """Create tf.Variable with large initial value without bloating the tf graph.""" 210 | assert_tf_initialized() 211 | assert isinstance(initial_value, np.ndarray) 212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 213 | var = tf.Variable(zeros, *args, **kwargs) 214 | set_vars({var: initial_value}) 215 | return var 216 | 217 | 218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 220 | Can be used as an input transformation for Network.run(). 221 | """ 222 | images = tf.cast(images, tf.float32) 223 | if nhwc_to_nchw: 224 | images = tf.transpose(images, [0, 3, 1, 2]) 225 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255) 226 | 227 | 228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1, uint8_cast=True): 229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 230 | Can be used as an output transformation for Network.run(). 231 | """ 232 | images = tf.cast(images, tf.float32) 233 | if shrink > 1: 234 | ksize = [1, 1, shrink, shrink] 235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 236 | if nchw_to_nhwc: 237 | images = tf.transpose(images, [0, 2, 3, 1]) 238 | scale = 255 / (drange[1] - drange[0]) 239 | images = images * scale + (0.5 - drange[0] * scale) 240 | if uint8_cast: 241 | images = tf.saturate_cast(images, tf.uint8) 242 | return images 243 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import uuid 27 | 28 | from distutils.util import strtobool 29 | from typing import Any, List, Tuple, Union 30 | 31 | 32 | # Util classes 33 | # ------------------------------------------------------------------------------------------ 34 | 35 | 36 | class EasyDict(dict): 37 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 38 | 39 | def __getattr__(self, name: str) -> Any: 40 | try: 41 | return self[name] 42 | except KeyError: 43 | raise AttributeError(name) 44 | 45 | def __setattr__(self, name: str, value: Any) -> None: 46 | self[name] = value 47 | 48 | def __delattr__(self, name: str) -> None: 49 | del self[name] 50 | 51 | 52 | class Logger(object): 53 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 54 | 55 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 56 | self.file = None 57 | 58 | if file_name is not None: 59 | self.file = open(file_name, file_mode) 60 | 61 | self.should_flush = should_flush 62 | self.stdout = sys.stdout 63 | self.stderr = sys.stderr 64 | 65 | sys.stdout = self 66 | sys.stderr = self 67 | 68 | def __enter__(self) -> "Logger": 69 | return self 70 | 71 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 72 | self.close() 73 | 74 | def write(self, text: str) -> None: 75 | """Write text to stdout (and a file) and optionally flush.""" 76 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 77 | return 78 | 79 | if self.file is not None: 80 | self.file.write(text) 81 | 82 | self.stdout.write(text) 83 | 84 | if self.should_flush: 85 | self.flush() 86 | 87 | def flush(self) -> None: 88 | """Flush written text to both stdout and a file, if open.""" 89 | if self.file is not None: 90 | self.file.flush() 91 | 92 | self.stdout.flush() 93 | 94 | def close(self) -> None: 95 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 96 | self.flush() 97 | 98 | # if using multiple loggers, prevent closing in wrong order 99 | if sys.stdout is self: 100 | sys.stdout = self.stdout 101 | if sys.stderr is self: 102 | sys.stderr = self.stderr 103 | 104 | if self.file is not None: 105 | self.file.close() 106 | 107 | 108 | # Small util functions 109 | # ------------------------------------------------------------------------------------------ 110 | 111 | 112 | def format_time(seconds: Union[int, float]) -> str: 113 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 114 | s = int(np.rint(seconds)) 115 | 116 | if s < 60: 117 | return "{0}s".format(s) 118 | elif s < 60 * 60: 119 | return "{0}m {1:02}s".format(s // 60, s % 60) 120 | elif s < 24 * 60 * 60: 121 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 122 | else: 123 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 124 | 125 | 126 | def ask_yes_no(question: str) -> bool: 127 | """Ask the user the question until the user inputs a valid answer.""" 128 | while True: 129 | try: 130 | print("{0} 
[y/n]".format(question)) 131 | return strtobool(input().lower()) 132 | except ValueError: 133 | pass 134 | 135 | 136 | def tuple_product(t: Tuple) -> Any: 137 | """Calculate the product of the tuple elements.""" 138 | result = 1 139 | 140 | for v in t: 141 | result *= v 142 | 143 | return result 144 | 145 | 146 | _str_to_ctype = { 147 | "uint8": ctypes.c_ubyte, 148 | "uint16": ctypes.c_uint16, 149 | "uint32": ctypes.c_uint32, 150 | "uint64": ctypes.c_uint64, 151 | "int8": ctypes.c_byte, 152 | "int16": ctypes.c_int16, 153 | "int32": ctypes.c_int32, 154 | "int64": ctypes.c_int64, 155 | "float32": ctypes.c_float, 156 | "float64": ctypes.c_double 157 | } 158 | 159 | 160 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 161 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 162 | type_str = None 163 | 164 | if isinstance(type_obj, str): 165 | type_str = type_obj 166 | elif hasattr(type_obj, "__name__"): 167 | type_str = type_obj.__name__ 168 | elif hasattr(type_obj, "name"): 169 | type_str = type_obj.name 170 | else: 171 | raise RuntimeError("Cannot infer type name from input") 172 | 173 | assert type_str in _str_to_ctype.keys() 174 | 175 | my_dtype = np.dtype(type_str) 176 | my_ctype = _str_to_ctype[type_str] 177 | 178 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 179 | 180 | return my_dtype, my_ctype 181 | 182 | 183 | def is_pickleable(obj: Any) -> bool: 184 | try: 185 | with io.BytesIO() as stream: 186 | pickle.dump(obj, stream) 187 | return True 188 | except: 189 | return False 190 | 191 | 192 | # Functionality to import modules/objects by name, and call functions by name 193 | # ------------------------------------------------------------------------------------------ 194 | 195 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 196 | """Searches for the underlying module behind the name to some python object. 197 | Returns the module and the object name (original name with module part removed).""" 198 | 199 | # allow convenience shorthands, substitute them by full names 200 | obj_name = re.sub("^np.", "numpy.", obj_name) 201 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 202 | 203 | # list alternatives for (module_name, local_obj_name) 204 | parts = obj_name.split(".") 205 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 206 | 207 | # try each alternative in turn 208 | for module_name, local_obj_name in name_pairs: 209 | try: 210 | module = importlib.import_module(module_name) # may raise ImportError 211 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 212 | return module, local_obj_name 213 | except: 214 | pass 215 | 216 | # maybe some of the modules themselves contain errors? 217 | for module_name, _local_obj_name in name_pairs: 218 | try: 219 | importlib.import_module(module_name) # may raise ImportError 220 | except ImportError: 221 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 222 | raise 223 | 224 | # maybe the requested attribute is missing? 
225 | for module_name, local_obj_name in name_pairs: 226 | try: 227 | module = importlib.import_module(module_name) # may raise ImportError 228 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 229 | except ImportError: 230 | pass 231 | 232 | # we are out of luck, but we have no idea why 233 | raise ImportError(obj_name) 234 | 235 | 236 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 237 | """Traverses the object name and returns the last (rightmost) python object.""" 238 | if obj_name == '': 239 | return module 240 | obj = module 241 | for part in obj_name.split("."): 242 | obj = getattr(obj, part) 243 | return obj 244 | 245 | 246 | def get_obj_by_name(name: str) -> Any: 247 | """Finds the python object with the given name.""" 248 | module, obj_name = get_module_from_obj_name(name) 249 | return get_obj_from_module(module, obj_name) 250 | 251 | 252 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 253 | """Finds the python object with the given name and calls it as a function.""" 254 | assert func_name is not None 255 | func_obj = get_obj_by_name(func_name) 256 | assert callable(func_obj) 257 | return func_obj(*args, **kwargs) 258 | 259 | 260 | def get_module_dir_by_obj_name(obj_name: str) -> str: 261 | """Get the directory path of the module containing the given object name.""" 262 | module, _ = get_module_from_obj_name(obj_name) 263 | return os.path.dirname(inspect.getfile(module)) 264 | 265 | 266 | def is_top_level_function(obj: Any) -> bool: 267 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 268 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 269 | 270 | 271 | def get_top_level_function_name(obj: Any) -> str: 272 | """Return the fully-qualified name of a top-level function.""" 273 | assert is_top_level_function(obj) 274 | return obj.__module__ + "." + obj.__name__ 275 | 276 | 277 | # File system helpers 278 | # ------------------------------------------------------------------------------------------ 279 | 280 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 281 | """List all files recursively in a given directory while ignoring given file and directory names. 282 | Returns list of tuples containing both absolute and relative paths.""" 283 | assert os.path.isdir(dir_path) 284 | base_name = os.path.basename(os.path.normpath(dir_path)) 285 | 286 | if ignores is None: 287 | ignores = [] 288 | 289 | result = [] 290 | 291 | for root, dirs, files in os.walk(dir_path, topdown=True): 292 | for ignore_ in ignores: 293 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 294 | 295 | # dirs need to be edited in-place 296 | for d in dirs_to_remove: 297 | dirs.remove(d) 298 | 299 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 300 | 301 | absolute_paths = [os.path.join(root, f) for f in files] 302 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 303 | 304 | if add_base_to_relative: 305 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 306 | 307 | assert len(absolute_paths) == len(relative_paths) 308 | result += zip(absolute_paths, relative_paths) 309 | 310 | return result 311 | 312 | 313 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 314 | """Takes in a list of tuples of (src, dst) paths and copies files. 
315 | Will create all necessary directories.""" 316 | for file in files: 317 | target_dir_name = os.path.dirname(file[1]) 318 | 319 | # will create all intermediate-level directories 320 | if not os.path.exists(target_dir_name): 321 | os.makedirs(target_dir_name) 322 | 323 | shutil.copyfile(file[0], file[1]) 324 | 325 | 326 | # URL helpers 327 | # ------------------------------------------------------------------------------------------ 328 | 329 | def is_url(obj: Any) -> bool: 330 | """Determine whether the given object is a valid URL string.""" 331 | if not isinstance(obj, str) or not "://" in obj: 332 | return False 333 | try: 334 | res = requests.compat.urlparse(obj) 335 | if not res.scheme or not res.netloc or not "." in res.netloc: 336 | return False 337 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 338 | if not res.scheme or not res.netloc or not "." in res.netloc: 339 | return False 340 | except: 341 | return False 342 | return True 343 | 344 | 345 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any: 346 | """Download the given URL and return a binary-mode file object to access the data.""" 347 | if not is_url(url) and os.path.isfile(url): 348 | return open(url, 'rb') 349 | 350 | assert is_url(url) 351 | assert num_attempts >= 1 352 | 353 | # Lookup from cache. 354 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 355 | if cache_dir is not None: 356 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 357 | if len(cache_files) == 1: 358 | return open(cache_files[0], "rb") 359 | 360 | # Download. 361 | url_name = None 362 | url_data = None 363 | with requests.Session() as session: 364 | if verbose: 365 | print("Downloading %s ..." % url, end="", flush=True) 366 | for attempts_left in reversed(range(num_attempts)): 367 | try: 368 | with session.get(url) as res: 369 | res.raise_for_status() 370 | if len(res.content) == 0: 371 | raise IOError("No data received") 372 | 373 | if len(res.content) < 8192: 374 | content_str = res.content.decode("utf-8") 375 | if "download_warning" in res.headers.get("Set-Cookie", ""): 376 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 377 | if len(links) == 1: 378 | url = requests.compat.urljoin(url, links[0]) 379 | raise IOError("Google Drive virus checker nag") 380 | if "Google Drive - Quota exceeded" in content_str: 381 | raise IOError("Google Drive quota exceeded") 382 | 383 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 384 | url_name = match[1] if match else url 385 | url_data = res.content 386 | if verbose: 387 | print(" done") 388 | break 389 | except: 390 | if not attempts_left: 391 | if verbose: 392 | print(" failed") 393 | raise 394 | if verbose: 395 | print(".", end="", flush=True) 396 | 397 | # Save to cache. 398 | if cache_dir is not None: 399 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 400 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 401 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 402 | os.makedirs(cache_dir, exist_ok=True) 403 | with open(temp_file, "wb") as f: 404 | f.write(url_data) 405 | os.replace(temp_file, cache_file) # atomic 406 | 407 | # Return data as file object. 
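# (Illustrative note, not part of the original source: depending on the branch taken,
# the caller receives an ordinary file handle (local path or cache hit) or an
# in-memory io.BytesIO object (fresh download). A hypothetical usage sketch, with a
# placeholder URL and cache directory:
#     with open_url("https://example.com/model.pkl", cache_dir="cache") as f:
#         model = pickle.load(f)
# )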
408 | return io.BytesIO(url_data) 409 | -------------------------------------------------------------------------------- /encoder/generator_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import tensorflow as tf 3 | import numpy as np 4 | import dnnlib.tflib as tflib 5 | from functools import partial 6 | 7 | 8 | def create_stub(name, batch_size): 9 | return tf.constant(0, dtype='float32', shape=(batch_size, 0)) 10 | 11 | 12 | def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size = 1): 13 | if tiled_dlatent: 14 | low_dim_dlatent = tf.get_variable('learnable_dlatents', 15 | shape=(batch_size, tile_size, 512), 16 | dtype='float32', 17 | initializer=tf.initializers.random_normal()) 18 | return tf.tile(low_dim_dlatent, [1, model_scale // tile_size, 1]) 19 | else: 20 | return tf.get_variable('learnable_dlatents', 21 | shape=(batch_size, model_scale, 512), 22 | dtype='float32', 23 | initializer=tf.initializers.random_normal()) 24 | 25 | 26 | class Generator: 27 | def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False, initial=True): 28 | self.batch_size = batch_size 29 | self.tiled_dlatent=tiled_dlatent 30 | self.model_scale = int(2*(math.log(model_res, 2)-1)) # For example, 1024 -> 18 31 | if tiled_dlatent: 32 | self.initial_dlatents = np.zeros((self.batch_size, 1, 512)) 33 | if initial: 34 | model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)), 35 | randomize_noise=randomize_noise, minibatch_size=self.batch_size, 36 | custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True, model_scale=self.model_scale), 37 | partial(create_stub, batch_size=batch_size)], 38 | structure='fixed') 39 | else: 40 | self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512)) 41 | if initial: 42 | if custom_input is not None: 43 | model.components.synthesis.run(self.initial_dlatents, 44 | randomize_noise=randomize_noise, minibatch_size=self.batch_size, 45 | custom_inputs=[partial(custom_input.eval(), batch_size=batch_size), partial(create_stub, batch_size=batch_size)], 46 | structure='fixed') 47 | else: 48 | model.components.synthesis.run(self.initial_dlatents, 49 | randomize_noise=randomize_noise, minibatch_size=self.batch_size, 50 | custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale), 51 | partial(create_stub, batch_size=batch_size)], 52 | structure='fixed') 53 | self.dlatent_avg_def = model.get_var('dlatent_avg') 54 | self.reset_dlatent_avg() 55 | self.sess = tf.get_default_session() 56 | self.graph = tf.get_default_graph() 57 | 58 | self.dlatent_variable = next(v for v in tf.global_variables() if 'learnable_dlatents' in v.name) 59 | self._assign_dlatent_ph = tf.placeholder(tf.float32, name="assign_dlatent_ph") 60 | self._assign_dlantent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph) 61 | self.set_dlatents(self.initial_dlatents) 62 | 63 | def get_tensor(name): 64 | try: 65 | return self.graph.get_tensor_by_name(name) 66 | except KeyError: 67 | return None 68 | 69 | self.generator_output = get_tensor('G_synthesis_1/_Run/concat:0') 70 | if self.generator_output is None: 71 | self.generator_output = get_tensor('G_synthesis_1/_Run/concat/concat:0') 72 | if self.generator_output is None: 73 | self.generator_output = get_tensor('G_synthesis_1/_Run/concat_1/concat:0') 74 
| # If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist in the graph. 75 | if self.generator_output is None: 76 | self.generator_output = get_tensor('G_synthesis/_Run/concat:0') 77 | if self.generator_output is None: 78 | self.generator_output = get_tensor('G_synthesis/_Run/concat/concat:0') 79 | if self.generator_output is None: 80 | self.generator_output = get_tensor('G_synthesis/_Run/concat_1/concat:0') 81 | if self.generator_output is None: 82 | for op in self.graph.get_operations(): 83 | print(op) 84 | raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output") 85 | self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False) 86 | self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8) 87 | 88 | # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782 89 | # (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper, 90 | # so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.) 91 | clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold) 92 | clipped_values = tf.where(clipping_mask, tf.random_normal(shape=self.dlatent_variable.shape), self.dlatent_variable) 93 | self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values) 94 | 95 | def reset_dlatents(self): 96 | self.set_dlatents(self.initial_dlatents) 97 | 98 | def set_dlatents(self, dlatents): 99 | if self.tiled_dlatent: 100 | if (dlatents.shape != (self.batch_size, 1, 512)) and (dlatents.shape[1] != 512): 101 | dlatents = np.mean(dlatents, axis=1, keepdims=True) 102 | if (dlatents.shape != (self.batch_size, 1, 512)): 103 | dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], 1, 512))]) 104 | assert (dlatents.shape == (self.batch_size, 1, 512)) 105 | else: 106 | if (dlatents.shape[1] > self.model_scale): 107 | dlatents = dlatents[:,:self.model_scale,:] 108 | if (isinstance(dlatents.shape[0], int)): 109 | if (dlatents.shape != (self.batch_size, self.model_scale, 512)): 110 | dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], self.model_scale, 512))]) 111 | assert (dlatents.shape == (self.batch_size, self.model_scale, 512)) 112 | self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents}) 113 | return 114 | else: 115 | self._assign_dlantent = tf.assign(self.dlatent_variable, dlatents) 116 | return 117 | self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents}) 118 | 119 | def stochastic_clip_dlatents(self): 120 | self.sess.run(self.stochastic_clip_op) 121 | 122 | def get_dlatents(self): 123 | return self.sess.run(self.dlatent_variable) 124 | 125 | def get_dlatent_avg(self): 126 | return self.dlatent_avg 127 | 128 | def set_dlatent_avg(self, dlatent_avg): 129 | self.dlatent_avg = dlatent_avg 130 | 131 | def reset_dlatent_avg(self): 132 | self.dlatent_avg = self.dlatent_avg_def 133 | 134 | def generate_images(self, dlatents=None): 135 | if dlatents is not None: 136 | self.set_dlatents(dlatents) 137 | return self.sess.run(self.generated_image_uint8) 138 | -------------------------------------------------------------------------------- /encoder/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | 
import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from encoder.resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size=ks, 20 | stride=stride, 21 | padding=padding, 22 | bias=False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: 36 | nn.init.constant_(ly.bias, 0) 37 | 38 | 39 | class BiSeNetOutput(nn.Module): 40 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 41 | super(BiSeNetOutput, self).__init__() 42 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 43 | self.conv_out = nn.Conv2d( 44 | mid_chan, n_classes, kernel_size=1, bias=False) 45 | self.init_weight() 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | x = self.conv_out(x) 50 | return x 51 | 52 | def init_weight(self): 53 | for ly in self.children(): 54 | if isinstance(ly, nn.Conv2d): 55 | nn.init.kaiming_normal_(ly.weight, a=1) 56 | if not ly.bias is None: 57 | nn.init.constant_(ly.bias, 0) 58 | 59 | def get_params(self): 60 | wd_params, nowd_params = [], [] 61 | for name, module in self.named_modules(): 62 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 63 | wd_params.append(module.weight) 64 | if not module.bias is None: 65 | nowd_params.append(module.bias) 66 | elif isinstance(module, nn.BatchNorm2d): 67 | nowd_params += list(module.parameters()) 68 | return wd_params, nowd_params 69 | 70 | 71 | class AttentionRefinementModule(nn.Module): 72 | def __init__(self, in_chan, out_chan, *args, **kwargs): 73 | super(AttentionRefinementModule, self).__init__() 74 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 75 | self.conv_atten = nn.Conv2d( 76 | out_chan, out_chan, kernel_size=1, bias=False) 77 | self.bn_atten = nn.BatchNorm2d(out_chan) 78 | self.sigmoid_atten = nn.Sigmoid() 79 | self.init_weight() 80 | 81 | def forward(self, x): 82 | feat = self.conv(x) 83 | atten = F.avg_pool2d(feat, feat.size()[2:]) 84 | atten = self.conv_atten(atten) 85 | atten = self.bn_atten(atten) 86 | atten = self.sigmoid_atten(atten) 87 | out = torch.mul(feat, atten) 88 | return out 89 | 90 | def init_weight(self): 91 | for ly in self.children(): 92 | if isinstance(ly, nn.Conv2d): 93 | nn.init.kaiming_normal_(ly.weight, a=1) 94 | if not ly.bias is None: 95 | nn.init.constant_(ly.bias, 0) 96 | 97 | 98 | class ContextPath(nn.Module): 99 | def __init__(self, *args, **kwargs): 100 | super(ContextPath, self).__init__() 101 | self.resnet = Resnet18() 102 | self.arm16 = AttentionRefinementModule(256, 128) 103 | self.arm32 = AttentionRefinementModule(512, 128) 104 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 105 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 106 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 107 | 108 | self.init_weight() 109 | 110 | def forward(self, x): 111 | H0, W0 = x.size()[2:] 112 | feat8, feat16, feat32 = self.resnet(x) 113 | H8, W8 = feat8.size()[2:] 114 | H16, W16 = feat16.size()[2:] 115 | H32, W32 = 
feat32.size()[2:] 116 | 117 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 118 | avg = self.conv_avg(avg) 119 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 120 | 121 | feat32_arm = self.arm32(feat32) 122 | feat32_sum = feat32_arm + avg_up 123 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 124 | feat32_up = self.conv_head32(feat32_up) 125 | 126 | feat16_arm = self.arm16(feat16) 127 | feat16_sum = feat16_arm + feat32_up 128 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 129 | feat16_up = self.conv_head16(feat16_up) 130 | 131 | return feat8, feat16_up, feat32_up # x8, x8, x16 132 | 133 | def init_weight(self): 134 | for ly in self.children(): 135 | if isinstance(ly, nn.Conv2d): 136 | nn.init.kaiming_normal_(ly.weight, a=1) 137 | if not ly.bias is None: 138 | nn.init.constant_(ly.bias, 0) 139 | 140 | def get_params(self): 141 | wd_params, nowd_params = [], [] 142 | for name, module in self.named_modules(): 143 | if isinstance(module, (nn.Linear, nn.Conv2d)): 144 | wd_params.append(module.weight) 145 | if not module.bias is None: 146 | nowd_params.append(module.bias) 147 | elif isinstance(module, nn.BatchNorm2d): 148 | nowd_params += list(module.parameters()) 149 | return wd_params, nowd_params 150 | 151 | 152 | # This is not used, since I replace this with the resnet feature with the same size 153 | class SpatialPath(nn.Module): 154 | def __init__(self, *args, **kwargs): 155 | super(SpatialPath, self).__init__() 156 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 157 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 158 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 159 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 160 | self.init_weight() 161 | 162 | def forward(self, x): 163 | feat = self.conv1(x) 164 | feat = self.conv2(feat) 165 | feat = self.conv3(feat) 166 | feat = self.conv_out(feat) 167 | return feat 168 | 169 | def init_weight(self): 170 | for ly in self.children(): 171 | if isinstance(ly, nn.Conv2d): 172 | nn.init.kaiming_normal_(ly.weight, a=1) 173 | if not ly.bias is None: 174 | nn.init.constant_(ly.bias, 0) 175 | 176 | def get_params(self): 177 | wd_params, nowd_params = [], [] 178 | for name, module in self.named_modules(): 179 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 180 | wd_params.append(module.weight) 181 | if not module.bias is None: 182 | nowd_params.append(module.bias) 183 | elif isinstance(module, nn.BatchNorm2d): 184 | nowd_params += list(module.parameters()) 185 | return wd_params, nowd_params 186 | 187 | 188 | class FeatureFusionModule(nn.Module): 189 | def __init__(self, in_chan, out_chan, *args, **kwargs): 190 | super(FeatureFusionModule, self).__init__() 191 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 192 | self.conv1 = nn.Conv2d(out_chan, 193 | out_chan//4, 194 | kernel_size=1, 195 | stride=1, 196 | padding=0, 197 | bias=False) 198 | self.conv2 = nn.Conv2d(out_chan//4, 199 | out_chan, 200 | kernel_size=1, 201 | stride=1, 202 | padding=0, 203 | bias=False) 204 | self.relu = nn.ReLU(inplace=True) 205 | self.sigmoid = nn.Sigmoid() 206 | self.init_weight() 207 | 208 | def forward(self, fsp, fcp): 209 | fcat = torch.cat([fsp, fcp], dim=1) 210 | feat = self.convblk(fcat) 211 | atten = F.avg_pool2d(feat, feat.size()[2:]) 212 | atten = self.conv1(atten) 213 | atten = self.relu(atten) 214 | atten = self.conv2(atten) 215 | atten = self.sigmoid(atten) 216 | feat_atten = torch.mul(feat, atten) 217 | feat_out = 
feat_atten + feat 218 | return feat_out 219 | 220 | def init_weight(self): 221 | for ly in self.children(): 222 | if isinstance(ly, nn.Conv2d): 223 | nn.init.kaiming_normal_(ly.weight, a=1) 224 | if not ly.bias is None: 225 | nn.init.constant_(ly.bias, 0) 226 | 227 | def get_params(self): 228 | wd_params, nowd_params = [], [] 229 | for name, module in self.named_modules(): 230 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 231 | wd_params.append(module.weight) 232 | if not module.bias is None: 233 | nowd_params.append(module.bias) 234 | elif isinstance(module, nn.BatchNorm2d): 235 | nowd_params += list(module.parameters()) 236 | return wd_params, nowd_params 237 | 238 | 239 | class BiSeNet(nn.Module): 240 | def __init__(self, n_classes, *args, **kwargs): 241 | super(BiSeNet, self).__init__() 242 | self.cp = ContextPath() 243 | # here self.sp is deleted 244 | self.ffm = FeatureFusionModule(256, 256) 245 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 246 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 247 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 248 | self.init_weight() 249 | 250 | def forward(self, x): 251 | H, W = x.size()[2:] 252 | feat_res8, feat_cp8, feat_cp16 = self.cp( 253 | x) # here return res3b1 feature 254 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 255 | feat_fuse = self.ffm(feat_sp, feat_cp8) 256 | 257 | feat_out = self.conv_out(feat_fuse) 258 | feat_out16 = self.conv_out16(feat_cp8) 259 | feat_out32 = self.conv_out32(feat_cp16) 260 | 261 | feat_out = F.interpolate( 262 | feat_out, (H, W), mode='bilinear', align_corners=True) 263 | feat_out16 = F.interpolate( 264 | feat_out16, (H, W), mode='bilinear', align_corners=True) 265 | feat_out32 = F.interpolate( 266 | feat_out32, (H, W), mode='bilinear', align_corners=True) 267 | return feat_out, feat_out16, feat_out32 268 | 269 | def init_weight(self): 270 | for ly in self.children(): 271 | if isinstance(ly, nn.Conv2d): 272 | nn.init.kaiming_normal_(ly.weight, a=1) 273 | if not ly.bias is None: 274 | nn.init.constant_(ly.bias, 0) 275 | 276 | def get_params(self): 277 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 278 | for name, child in self.named_children(): 279 | child_wd_params, child_nowd_params = child.get_params() 280 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 281 | lr_mul_wd_params += child_wd_params 282 | lr_mul_nowd_params += child_nowd_params 283 | else: 284 | wd_params += child_wd_params 285 | nowd_params += child_nowd_params 286 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 287 | 288 | 289 | if __name__ == "__main__": 290 | net = BiSeNet(19) 291 | net.cuda() 292 | net.eval() 293 | in_ten = torch.randn(16, 3, 640, 480).cuda() 294 | out, out16, out32 = net(in_ten) 295 | print(out.shape) 296 | 297 | net.get_params() 298 | -------------------------------------------------------------------------------- /encoder/perceptual_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | import tensorflow as tf 3 | #import tensorflow_probability as tfp 4 | #tf.enable_eager_execution() 5 | 6 | import os 7 | import bz2 8 | import PIL.Image 9 | from PIL import ImageFilter 10 | import numpy as np 11 | from keras.models import Model 12 | from keras.utils import get_file 13 | from keras.applications.vgg16 import VGG16, preprocess_input 14 | import 
keras.backend as K 15 | import traceback 16 | import dnnlib.tflib as tflib 17 | 18 | def load_image(image, image_size=256, sharpen=False): 19 | loaded_images = list() 20 | img = image.convert('RGB') 21 | if image_size is not None: 22 | img = img.resize((image_size, image_size), PIL.Image.LANCZOS) 23 | if (sharpen): 24 | img = img.filter(ImageFilter.DETAIL) 25 | img = np.array(img) 26 | img = np.expand_dims(img, 0) 27 | loaded_images.append(img) 28 | loaded_images = np.vstack(loaded_images) 29 | return loaded_images 30 | 31 | def tf_custom_adaptive_loss(a,b): 32 | from adaptive import lossfun 33 | shape = a.get_shape().as_list() 34 | dim = np.prod(shape[1:]) 35 | a = tf.reshape(a, [-1, dim]) 36 | b = tf.reshape(b, [-1, dim]) 37 | loss, _, _ = lossfun(b-a, var_suffix='1') 38 | return tf.math.reduce_mean(loss) 39 | 40 | def tf_custom_adaptive_rgb_loss(a,b): 41 | from adaptive import image_lossfun 42 | loss, _, _ = image_lossfun(b-a, color_space='RGB', representation='PIXEL') 43 | return tf.math.reduce_mean(loss) 44 | 45 | def tf_custom_l1_loss(img1,img2): 46 | return tf.math.reduce_mean(tf.math.abs(img2-img1), axis=None) 47 | 48 | def tf_custom_logcosh_loss(img1,img2): 49 | return tf.math.reduce_mean(tf.keras.losses.logcosh(img1,img2)) 50 | 51 | def create_stub(batch_size): 52 | return tf.constant(0, dtype='float32', shape=(batch_size, 0)) 53 | 54 | def unpack_bz2(src_path): 55 | data = bz2.BZ2File(src_path).read() 56 | dst_path = src_path[:-4] 57 | with open(dst_path, 'wb') as fp: 58 | fp.write(data) 59 | return dst_path 60 | 61 | class PerceptualModel: 62 | def __init__(self, args, batch_size=1, perc_model=None, sess=None): 63 | self.sess = tf.get_default_session() if sess is None else sess 64 | K.set_session(self.sess) 65 | self.epsilon = 0.00000001 66 | self.lr = args.lr 67 | self.decay_rate = args.decay_rate 68 | self.decay_steps = args.decay_steps 69 | self.img_size = args.image_size 70 | self.layer = args.use_vgg_layer 71 | self.vgg_loss = args.use_vgg_loss 72 | if (self.layer <= 0 or self.vgg_loss <= self.epsilon): 73 | self.vgg_loss = None 74 | self.pixel_loss = args.use_pixel_loss 75 | if (self.pixel_loss <= self.epsilon): 76 | self.pixel_loss = None 77 | self.mssim_loss = args.use_mssim_loss 78 | if (self.mssim_loss <= self.epsilon): 79 | self.mssim_loss = None 80 | self.lpips_loss = args.use_lpips_loss 81 | if (self.lpips_loss <= self.epsilon): 82 | self.lpips_loss = None 83 | self.l1_penalty = args.use_l1_penalty 84 | if (self.l1_penalty <= self.epsilon): 85 | self.l1_penalty = None 86 | self.adaptive_loss = args.use_adaptive_loss 87 | self.sharpen_input = args.sharpen_input 88 | self.batch_size = batch_size 89 | if perc_model is not None and self.lpips_loss is not None: 90 | self.perc_model = perc_model 91 | else: 92 | self.perc_model = None 93 | self.ref_img = None 94 | self.ref_weight = None 95 | self.perceptual_model = None 96 | self.ref_img_features = None 97 | self.features_weight = None 98 | self.loss = None 99 | self.discriminator_loss = args.use_discriminator_loss 100 | if (self.discriminator_loss <= self.epsilon): 101 | self.discriminator_loss = None 102 | if self.discriminator_loss is not None: 103 | self.discriminator = None 104 | self.stub = create_stub(batch_size) 105 | 106 | def add_placeholder(self, var_name): 107 | var_val = getattr(self, var_name) 108 | setattr(self, var_name + "_placeholder", tf.placeholder(var_val.dtype, shape=var_val.get_shape())) 109 | setattr(self, var_name + "_op", var_val.assign(getattr(self, var_name + "_placeholder"))) 110 | 111 | def 
assign_placeholder(self, var_name, var_val): 112 | self.sess.run(getattr(self, var_name + "_op"), {getattr(self, var_name + "_placeholder"): var_val}) 113 | 114 | def build_perceptual_model(self, generator, discriminator=None): 115 | # Learning rate 116 | global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step") 117 | incremented_global_step = tf.assign_add(global_step, 1) 118 | self._reset_global_step = tf.assign(global_step, 0) 119 | self.learning_rate = tf.train.exponential_decay(self.lr, incremented_global_step, 120 | self.decay_steps, self.decay_rate, staircase=True) 121 | self.sess.run([self._reset_global_step]) 122 | 123 | if self.discriminator_loss is not None: 124 | self.discriminator = discriminator 125 | 126 | generated_image_tensor = generator.generated_image 127 | generated_image = tf.image.resize_nearest_neighbor(generated_image_tensor, 128 | (self.img_size, self.img_size), align_corners=True) 129 | 130 | self.ref_img = tf.get_variable('ref_img', shape=generated_image.shape, 131 | dtype='float32', initializer=tf.initializers.zeros()) 132 | self.ref_weight = tf.get_variable('ref_weight', shape=generated_image.shape, 133 | dtype='float32', initializer=tf.initializers.zeros()) 134 | self.add_placeholder("ref_img") 135 | self.add_placeholder("ref_weight") 136 | 137 | if (self.vgg_loss is not None): 138 | vgg16 = VGG16(include_top=False, input_shape=(self.img_size, self.img_size, 3)) 139 | self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output) 140 | generated_img_features = self.perceptual_model(preprocess_input(self.ref_weight * generated_image)) 141 | self.ref_img_features = tf.get_variable('ref_img_features', shape=generated_img_features.shape, 142 | dtype='float32', initializer=tf.initializers.zeros()) 143 | self.features_weight = tf.get_variable('features_weight', shape=generated_img_features.shape, 144 | dtype='float32', initializer=tf.initializers.zeros()) 145 | self.sess.run([self.features_weight.initializer, self.features_weight.initializer]) 146 | self.add_placeholder("ref_img_features") 147 | self.add_placeholder("features_weight") 148 | 149 | if self.perc_model is not None and self.lpips_loss is not None: 150 | img1 = tflib.convert_images_from_uint8(self.ref_weight * self.ref_img, nhwc_to_nchw=True) 151 | img2 = tflib.convert_images_from_uint8(self.ref_weight * generated_image, nhwc_to_nchw=True) 152 | 153 | self.loss = 0 154 | # L1 loss on VGG16 features 155 | if (self.vgg_loss is not None): 156 | if self.adaptive_loss: 157 | self.loss += self.vgg_loss * tf_custom_adaptive_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features) 158 | else: 159 | self.loss += self.vgg_loss * tf_custom_logcosh_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features) 160 | # + logcosh loss on image pixels 161 | if (self.pixel_loss is not None): 162 | if self.adaptive_loss: 163 | self.loss += self.pixel_loss * tf_custom_adaptive_rgb_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image) 164 | else: 165 | self.loss += self.pixel_loss * tf_custom_logcosh_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image) 166 | # + MS-SIM loss on image pixels 167 | if (self.mssim_loss is not None): 168 | self.loss += self.mssim_loss * tf.math.reduce_mean(1-tf.image.ssim_multiscale(self.ref_weight * self.ref_img, self.ref_weight * generated_image, 1)) 169 | # + extra perceptual loss on image pixels 170 | if self.perc_model is not None and 
self.lpips_loss is not None: 171 | self.loss += self.lpips_loss * tf.math.reduce_mean(self.perc_model.get_output_for(img1, img2)) 172 | # + L1 penalty on dlatent weights 173 | if self.l1_penalty is not None: 174 | self.loss += self.l1_penalty * 512 * tf.math.reduce_mean(tf.math.abs(generator.dlatent_variable-generator.get_dlatent_avg())) 175 | # discriminator loss (realism) 176 | if self.discriminator_loss is not None: 177 | self.loss += self.discriminator_loss * tf.math.reduce_mean(self.discriminator.get_output_for(tflib.convert_images_from_uint8(generated_image_tensor, nhwc_to_nchw=True), self.stub)) 178 | # - discriminator_network.get_output_for(tflib.convert_images_from_uint8(ref_img, nhwc_to_nchw=True), stub) 179 | 180 | 181 | def generate_face_mask(self, im): 182 | from imutils import face_utils 183 | import cv2 184 | rects = self.detector(im, 1) 185 | # loop over the face detections 186 | for (j, rect) in enumerate(rects): 187 | """ 188 | Determine the facial landmarks for the face region, then convert the facial landmark (x, y)-coordinates to a NumPy array 189 | """ 190 | shape = self.predictor(im, rect) 191 | shape = face_utils.shape_to_np(shape) 192 | 193 | # we extract the face 194 | vertices = cv2.convexHull(shape) 195 | mask = np.zeros(im.shape[:2],np.uint8) 196 | cv2.fillConvexPoly(mask, vertices, 1) 197 | if self.use_grabcut: 198 | bgdModel = np.zeros((1,65),np.float64) 199 | fgdModel = np.zeros((1,65),np.float64) 200 | rect = (0,0,im.shape[1],im.shape[2]) 201 | (x,y),radius = cv2.minEnclosingCircle(vertices) 202 | center = (int(x),int(y)) 203 | radius = int(radius*self.scale_mask) 204 | mask = cv2.circle(mask,center,radius,cv2.GC_PR_FGD,-1) 205 | cv2.fillConvexPoly(mask, vertices, cv2.GC_FGD) 206 | cv2.grabCut(im,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_MASK) 207 | mask = np.where((mask==2)|(mask==0),0,1) 208 | return mask 209 | 210 | def set_reference_image(self, image): 211 | loaded_image = load_image(image, self.img_size, sharpen=self.sharpen_input) 212 | image_features = None 213 | if self.perceptual_model is not None: 214 | image_features = self.perceptual_model.predict_on_batch(preprocess_input(np.array(loaded_image))) 215 | weight_mask = np.ones(self.features_weight.shape) 216 | if image_features is not None: 217 | self.assign_placeholder("features_weight", weight_mask) 218 | self.assign_placeholder("ref_img_features", image_features) 219 | image_mask = np.ones(self.ref_weight.shape) 220 | self.assign_placeholder("ref_weight", image_mask) 221 | self.assign_placeholder("ref_img", loaded_image) 222 | 223 | def optimize(self, vars_to_optimize, iterations=200, use_optimizer='adam'): 224 | vars_to_optimize = vars_to_optimize if isinstance(vars_to_optimize, list) else [vars_to_optimize] 225 | if use_optimizer == 'lbfgs': 226 | optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.loss, var_list=vars_to_optimize, method='L-BFGS-B', options={'maxiter': iterations}) 227 | else: 228 | if use_optimizer == 'ggt': 229 | optimizer = tf.contrib.opt.GGTOptimizer(learning_rate=self.learning_rate) 230 | else: 231 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 232 | min_op = optimizer.minimize(self.loss, var_list=[vars_to_optimize]) 233 | self.sess.run(tf.variables_initializer(optimizer.variables())) 234 | fetch_ops = [min_op, self.loss, self.learning_rate] 235 | #min_op = optimizer.minimize(self.sess) 236 | #optim_results = tfp.optimizer.lbfgs_minimize(make_val_and_grad_fn(get_loss), initial_position=vars_to_optimize, num_correction_pairs=10, 
tolerance=1e-8) 237 | self.sess.run(self._reset_global_step) 238 | #self.sess.graph.finalize() # Graph is read-only after this statement. 239 | for _ in range(iterations): 240 | if use_optimizer == 'lbfgs': 241 | optimizer.minimize(self.sess, fetches=[vars_to_optimize, self.loss]) 242 | yield {"loss":self.loss.eval()} 243 | else: 244 | _, loss, lr = self.sess.run(fetch_ops) 245 | yield {"loss":loss,"lr":lr} 246 | -------------------------------------------------------------------------------- /encoder/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, 
module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /input/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test1.jpg -------------------------------------------------------------------------------- /input/test2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test2.jpeg -------------------------------------------------------------------------------- /input/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test3.jpg -------------------------------------------------------------------------------- /input/test4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/input/test4.jpeg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | #import project_image_without_optimizer as projector # much faster but effect worse 4 | import project_image as projector 5 | from encoder.model import BiSeNet 6 | import torch 7 | import torchvision.transforms as transforms 8 | import numpy as np 9 | import PIL.Image 10 | from PIL import ImageFilter 11 | import cv2 12 | from tools.face_alignment import image_align 13 | from tools.landmarks_detector import LandmarksDetector 14 | from tools import functions 15 | from skimage.measure import label 16 | 17 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): 18 | # Colors for all 20 parts 19 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], 20 | [255, 0, 85], [255, 0, 170], 21 | [0, 255, 0], [85, 255, 0], [170, 255, 0], 22 | [0, 255, 85], [0, 255, 170], 23 | [0, 0, 255], [85, 0, 255], [170, 0, 255], 24 | [0, 85, 255], [0, 170, 255], 25 | [255, 255, 0], [255, 255, 85], [255, 255, 170], 26 | [255, 0, 255], [255, 85, 255], [255, 170, 255], 27 | [0, 255, 255], [85, 255, 255], [170, 255, 255]] 28 | 29 | im = np.array(im) 30 | vis_im = im.copy().astype(np.uint8) 31 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 32 | vis_parsing_anno = cv2.resize( 33 | vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 34 | 35 | vis_parsing_anno_color = np.zeros( 36 | (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 37 | mask = np.zeros( 38 | (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1]), dtype=np.uint8) 39 | num_of_class = np.max(vis_parsing_anno) 40 | 41 | idx = 
11 42 | for pi in range(1, num_of_class + 1): 43 | index = np.where((vis_parsing_anno <= 5) & ( 44 | vis_parsing_anno >= 1) | ((vis_parsing_anno >= 10) & (vis_parsing_anno <= 13))) 45 | mask[index[0], index[1]] = 1 46 | return mask 47 | 48 | def find_max_region(bw_img): # find Maximum Connected Domain of parsing mask 49 | labeled_img, num = label(bw_img, background=0, return_num=True) 50 | max_label = 0 51 | max_num = 0 52 | for i in range(1, num + 1): 53 | if np.sum(labeled_img == i) > max_num: 54 | max_num = np.sum(labeled_img == i) 55 | max_label = i 56 | lcc = (labeled_img == max_label) 57 | return lcc 58 | 59 | 60 | def main(): 61 | """ 62 | Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step 63 | """ 64 | parser = argparse.ArgumentParser(description='Model Face Swap', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 65 | parser.add_argument('--input_img', type=str, default='input/test1.jpg', help='Directory with raw images for face swap') 66 | parser.add_argument('--output_dir', type=str, default='output/', help='Directory for storing changed images') 67 | parser.add_argument('--project_style', type=str, default='model', help='model/pop-star/kids/wanghong...') 68 | parser.add_argument('--record', type=bool, default=True, help='Recording process') 69 | parser.add_argument('--landmark_path', type=str, default='networks/shape_predictor_68_face_landmarks.dat', help='face landmark file path') 70 | parser.add_argument('--parsing_path', type=str, default='networks/79999_iter.pth', help='parsing model path') 71 | args, _ = parser.parse_known_args() 72 | 73 | landmarks_detector = LandmarksDetector(args.landmark_path) 74 | parse_net = BiSeNet(n_classes=19) 75 | parse_net.cuda() 76 | parse_net.load_state_dict(torch.load(args.parsing_path)) 77 | parse_net.eval() 78 | to_tensor = transforms.Compose([ 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 81 | ]) 82 | 83 | os.makedirs(args.output_dir, exist_ok=True) 84 | dst_path = os.path.join(args.output_dir, args.input_img.rsplit('/', 1)[1].split('.')[0])+'_to-'+args.project_style+'/' 85 | os.makedirs(dst_path, exist_ok=True) 86 | ori_img = cv2.imread(args.input_img) 87 | face_data = {'aligned_images': [], 'masks': [], 'crops': [], 'pads': [], 'quads': [], 'record_paths': []} 88 | print('Step1 - Face alignment and mask extraction...') 89 | for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(args.input_img), start=1): 90 | if i == 1: 91 | cv2.imwrite(dst_path + 'input.png', ori_img) 92 | if args.record: 93 | record_path = dst_path + 'face'+str(i) + '/' 94 | face_data['record_paths'].append(record_path) 95 | os.makedirs(record_path, exist_ok=True) 96 | 97 | # face aligned 98 | aligned_image, crop, pad, quad = image_align(args.input_img, face_landmarks, output_size=1024, x_scale=1, y_scale=1, em_scale=0.1) 99 | face_data['aligned_images'].append(aligned_image) 100 | face_data['crops'].append(crop) 101 | face_data['pads'].append(pad) 102 | face_data['quads'].append(quad) 103 | if args.record: 104 | aligned_image.save(record_path+'face_input.png', 'PNG') 105 | 106 | # mask extraction 107 | image_sharp = aligned_image.filter(ImageFilter.DETAIL) 108 | alinged_image_np = np.array(image_sharp) 109 | img = to_tensor(alinged_image_np) 110 | img = torch.unsqueeze(img, 0) 111 | img = img.cuda() 112 | out = parse_net(img)[0] 113 | parsing = out.detach().squeeze(0).cpu().numpy().argmax(0) 114 | mask = 
vis_parsing_maps(alinged_image_np, parsing, stride=1) 115 | mask = find_max_region(mask) 116 | mask = (255 * mask).astype('uint8') 117 | mask = PIL.Image.fromarray(mask, 'L') 118 | face_data['masks'].append(mask) 119 | if args.record: 120 | mask.save(record_path+'face_mask.png', 'PNG') 121 | 122 | print('Step2 - Face projection and mixing back...') 123 | projected_images, dlatents = projector.project(face_data['aligned_images'], face_data['masks'], args.project_style) 124 | merged_image = ori_img 125 | for projected_image, dlatent, crop, quad, pad, record_path, mask in zip(projected_images, dlatents, 126 | face_data['crops'], face_data['quads'], face_data['pads'], face_data['record_paths'], face_data['masks']): 127 | if args.record: 128 | projected_image.save(record_path+'face_output.png', 'PNG') 129 | np.save(record_path+'dlatent.npy', dlatent) 130 | merged_image = functions.merge_image(merged_image, projected_image, mask, crop, quad, pad) 131 | cv2.imwrite(dst_path+'output.png', merged_image) 132 | 133 | if __name__ == "__main__": 134 | main() 135 | -------------------------------------------------------------------------------- /networks/download_weights.txt: -------------------------------------------------------------------------------- 1 | Download link for the packaged model weights: 2 | Link: https://pan.baidu.com/s/1yr_QNpHrXvq4PegMZBzDGA 3 | Extraction code: v4xt 4 | 5 | -------------------------------------------------------------------------------- /pics/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/architecture.png -------------------------------------------------------------------------------- /pics/example_2kids.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/example_2kids.jpg -------------------------------------------------------------------------------- /pics/example_2wanghong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/example_2wanghong.png -------------------------------------------------------------------------------- /pics/examples_mix.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/examples_mix.jpg -------------------------------------------------------------------------------- /pics/multi-model-solution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/multi-model-solution.png -------------------------------------------------------------------------------- /pics/preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/preview.jpg -------------------------------------------------------------------------------- /pics/single_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/single_input.png
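(Usage sketch, not part of the repository files: based on the argument parser in main.py above, a typical invocation could look like the command below. The input path and style name shown are simply the script's defaults, and the weight files referenced in networks/download_weights.txt are assumed to have been downloaded into networks/ beforehand.)
python main.py --input_img input/test1.jpg --output_dir output/ --project_style model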
-------------------------------------------------------------------------------- /pics/single_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/a312863063/Model-Swap-Face/b38eb4e76fb83e6c960f559b4c398c6b8802e1d2/pics/single_output.png -------------------------------------------------------------------------------- /project_image.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import argparse 4 | import pickle 5 | from tqdm import tqdm 6 | import PIL.Image 7 | from PIL import ImageFilter 8 | import numpy as np 9 | import dnnlib 10 | import dnnlib.tflib as tflib 11 | import tensorflow as tf 12 | from encoder.generator_model import Generator 13 | from encoder.perceptual_model import PerceptualModel, load_image 14 | #from tensorflow.keras.models import load_model 15 | from keras.models import load_model 16 | from keras.applications.resnet50 import preprocess_input 17 | 18 | def split_to_batches(l, n): 19 | for i in range(0, len(l), n): 20 | yield l[i:i + n] 21 | 22 | def str2bool(v): 23 | if isinstance(v, bool): 24 | return v 25 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 26 | return True 27 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 28 | return False 29 | else: 30 | raise argparse.ArgumentTypeError('Boolean value expected.') 31 | 32 | def project(images, masks, projector_name): 33 | parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs') 35 | parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int) 36 | parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int) 37 | parser.add_argument('--optimizer', default='ggt', help='Optimization algorithm used for optimizing dlatents') 38 | 39 | # Perceptual model params 40 | parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int) 41 | parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int) 42 | parser.add_argument('--lr', default=0.25, help='Learning rate for perceptual model', type=float) 43 | parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float) 44 | parser.add_argument('--iterations', default=1000, help='Number of optimization steps for each batch', type=int) 45 | parser.add_argument('--decay_steps', default=4, help='Decay steps for learning rate decay (as a percent of iterations)', type=float) 46 | parser.add_argument('--early_stopping', default=True, help='Stop early once training stabilizes', type=str2bool, nargs='?', const=True) 47 | parser.add_argument('--early_stopping_threshold', default=0.5, help='Stop after this threshold has been reached', type=float) 48 | parser.add_argument('--early_stopping_patience', default=10, help='Number of iterations to wait below threshold', type=int) 49 | parser.add_argument('--load_resnet', default='networks/finetuned_resnet.h5', help='Model to load for ResNet approximation of dlatents') 50 | parser.add_argument('--use_preprocess_input', default=True, help='Call process_input() first before using feed forward net', type=str2bool, 
nargs='?', const=True) 51 | parser.add_argument('--use_best_loss', default=True, help='Output the lowest loss value found as the solution', type=str2bool, nargs='?', const=True) 52 | parser.add_argument('--average_best_loss', default=0.25, help='Do a running weighted average with the previous best dlatents found', type=float) 53 | parser.add_argument('--sharpen_input', default=True, help='Sharpen the input images', type=str2bool, nargs='?', const=True) 54 | 55 | # Loss function options 56 | parser.add_argument('--use_vgg_loss', default=0.4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.', type=float) 57 | parser.add_argument('--use_vgg_layer', default=9, help='Pick which VGG layer to use.', type=int) 58 | parser.add_argument('--use_pixel_loss', default=1.5, help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float) 59 | parser.add_argument('--use_mssim_loss', default=200, help='Use MS-SIM perceptual loss; 0 to disable, > 0 to scale.', type=float) 60 | parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.', type=float) 61 | parser.add_argument('--use_l1_penalty', default=0.5, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.', type=float) 62 | parser.add_argument('--use_discriminator_loss', default=0.5, help='Use trained discriminator to evaluate realism.', type=float) 63 | parser.add_argument('--use_adaptive_loss', default=False, help='Use the adaptive robust loss function from Google Research for pixel and VGG feature loss.', type=str2bool, nargs='?', const=True) 64 | 65 | # Generator params 66 | parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=str2bool, nargs='?', const=True) 67 | parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=str2bool, nargs='?', const=True) 68 | parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of gradient values outside of this threshold', type=float) 69 | 70 | # Masking params 71 | parser.add_argument('--composite_blur', default=8, help='Size of blur filter to smoothly composite the images', type=int) 72 | 73 | args, other_args = parser.parse_known_args() 74 | args.decay_steps *= 0.01 * args.iterations # Calculate steps as a percent of total iterations 75 | 76 | # Initialize generator and perceptual model 77 | tflib.init_tf() 78 | with open('networks/karras2019stylegan-ffhq-1024x1024.pkl','rb') as f: 79 | generator_network, discriminator_network, Gs_network = pickle.load(f) 80 | 81 | generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise) 82 | if (args.dlatent_avg != ''): 83 | generator.set_dlatent_avg(np.load(args.dlatent_avg)) 84 | 85 | perc_model = None 86 | if (args.use_lpips_loss > 0.00000001): 87 | with open('networks/vgg16_zhang_perceptual.pkl', 'rb') as f: 88 | perc_model = pickle.load(f) 89 | perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size) 90 | perceptual_model.build_perceptual_model(generator, discriminator_network) 91 | 92 | ff_model = None 93 | 94 | # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space 95 | best_dlatents = [] 96 | for image, mask in zip(images, masks): 97 | perceptual_model.set_reference_image(image) 98 | dlatents = None 99 | if (ff_model 
is None): 100 | if os.path.exists(args.load_resnet): 101 | print("Loading ResNet Model:") 102 | ff_model = load_model(args.load_resnet) 103 | if (ff_model is not None): # predict initial dlatents with ResNet model 104 | if (args.use_preprocess_input): 105 | dlatents = ff_model.predict(preprocess_input(load_image(image, image_size=args.resnet_image_size))) 106 | else: 107 | dlatents = ff_model.predict(load_image(image,image_size=args.resnet_image_size)) 108 | if dlatents is not None: 109 | generator.set_dlatents(dlatents) 110 | op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations, use_optimizer=args.optimizer) 111 | pbar = tqdm(op, leave=False, total=args.iterations) 112 | best_loss = None 113 | best_dlatent = None 114 | avg_loss_count = 0 115 | if args.early_stopping: 116 | avg_loss = prev_loss = None 117 | for loss_dict in pbar: 118 | if args.early_stopping: # early stopping feature 119 | if prev_loss is not None: 120 | if avg_loss is not None: 121 | avg_loss = 0.5 * avg_loss + (prev_loss - loss_dict["loss"]) 122 | if avg_loss < args.early_stopping_threshold: # count while under threshold; else reset 123 | avg_loss_count += 1 124 | else: 125 | avg_loss_count = 0 126 | if avg_loss_count > args.early_stopping_patience: # stop once threshold is reached 127 | break 128 | else: 129 | avg_loss = prev_loss - loss_dict["loss"] 130 | pbar.set_description(" Oprimizing dlatent: " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()])) 131 | if best_loss is None or loss_dict["loss"] < best_loss: 132 | if best_dlatent is None or args.average_best_loss <= 0.00000001: 133 | best_dlatent = generator.get_dlatents() 134 | else: 135 | best_dlatent = 0.25 * best_dlatent + 0.75 * generator.get_dlatents() 136 | if args.use_best_loss: 137 | generator.set_dlatents(best_dlatent) 138 | best_loss = loss_dict["loss"] 139 | generator.stochastic_clip_dlatents() 140 | prev_loss = loss_dict["loss"] 141 | if not args.use_best_loss: 142 | best_loss = prev_loss 143 | best_dlatents.append(best_dlatent) 144 | print("\n Optimizing dlatent Best Loss {:.4f}".format(best_loss)) 145 | 146 | # Using Projector to generate images 147 | tflib.init_tf() 148 | with open('networks/projector_'+projector_name+'.pkl', 'rb') as f: 149 | Gs_network = pickle.load(f) 150 | generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, 151 | tiled_dlatent=args.tile_dlatents, model_res=args.model_res, 152 | randomize_noise=args.randomize_noise) 153 | imgs = [] 154 | for best_dlatent, image, mask in zip(best_dlatents, images, masks): 155 | generator.set_dlatents(best_dlatent) 156 | img_array = generator.generate_images()[0] 157 | generator.reset_dlatents() 158 | 159 | # Merge images with new face 160 | width, height = image.size 161 | mask = mask.resize((width, height)) 162 | mask = mask.filter(ImageFilter.GaussianBlur(args.composite_blur)) 163 | mask = np.array(mask) / 255 164 | mask = np.expand_dims(mask, axis=-1) 165 | img_array = mask * np.array(img_array) + (1.0 - mask) * np.array(image) 166 | img_array = img_array.astype(np.uint8) 167 | img = PIL.Image.fromarray(img_array, 'RGB') 168 | imgs.append(img) 169 | 170 | return imgs, best_dlatents -------------------------------------------------------------------------------- /project_image_without_optimizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import argparse 4 | import pickle 5 | from tqdm import tqdm 6 | import PIL.Image 7 
| from PIL import ImageFilter 8 | import numpy as np 9 | import dnnlib.tflib as tflib 10 | from encoder.generator_model import Generator 11 | from encoder.perceptual_model import load_image 12 | from keras.models import load_model 13 | from keras.applications.resnet50 import preprocess_input 14 | 15 | def split_to_batches(l, n): 16 | for i in range(0, len(l), n): 17 | yield l[i:i + n] 18 | 19 | def str2bool(v): 20 | if isinstance(v, bool): 21 | return v 22 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 23 | return True 24 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 25 | return False 26 | else: 27 | raise argparse.ArgumentTypeError('Boolean value expected.') 28 | 29 | def project(images, masks, projector_name): 30 | parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs') 32 | parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int) 33 | parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int) 34 | 35 | # Perceptual model params 36 | parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int) 37 | parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int) 38 | parser.add_argument('--load_resnet', default='networks/finetuned_resnet.h5', help='Model to load for ResNet approximation of dlatents') 39 | parser.add_argument('--use_preprocess_input', default=True, help='Call process_input() first before using feed forward net', type=str2bool, nargs='?', const=True) 40 | 41 | # Generator params 42 | parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=str2bool, nargs='?', const=True) 43 | parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=str2bool, nargs='?', const=True) 44 | parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of gradient values outside of this threshold', type=float) 45 | 46 | # Masking params 47 | parser.add_argument('--composite_blur', default=8, help='Size of blur filter to smoothly composite the images', type=int) 48 | 49 | args, other_args = parser.parse_known_args() 50 | 51 | # Initialize generator and encoder model 52 | tflib.init_tf() 53 | with open('networks/projector_'+projector_name+'.pkl', 'rb') as f: 54 | projector = pickle.load(f) 55 | generator = Generator(projector, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise) 56 | if (args.dlatent_avg != ''): 57 | generator.set_dlatent_avg(np.load(args.dlatent_avg)) 58 | print(" Loading ResNet Model...") 59 | ff_model = load_model(args.load_resnet) 60 | 61 | # Find the dlatent of the image 62 | dlatents = [] 63 | imgs = [] 64 | for image, mask in zip(images, masks): 65 | if (args.use_preprocess_input): 66 | dlatent = ff_model.predict(preprocess_input((load_image(image, image_size=args.resnet_image_size)))) 67 | else: 68 | dlatent = ff_model.predict((load_image(image, image_size=args.resnet_image_size))) 69 | if dlatent is not None: 70 | generator.set_dlatents(dlatent) 71 | 72 | # Using Projector 
to generate images 73 | generator.set_dlatents(dlatent) 74 | generated_images = generator.generate_images() 75 | 76 | # Merge images with new face 77 | img_array = generated_images[0] 78 | ori_img = image 79 | width, height = ori_img.size 80 | mask = mask.resize((width, height)) 81 | mask = mask.filter(ImageFilter.GaussianBlur(args.composite_blur)) 82 | mask = np.array(mask) / 255 83 | mask = np.expand_dims(mask, axis=-1) 84 | img_array = mask * np.array(img_array) + (1.0 - mask) * np.array(ori_img) 85 | img_array = img_array.astype(np.uint8) 86 | img = PIL.Image.fromarray(img_array, 'RGB') 87 | 88 | imgs.append(img) 89 | dlatents.append(dlatent) 90 | 91 | return imgs, dlatents 92 | 93 | -------------------------------------------------------------------------------- /tools/face_alignment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | import os 4 | import PIL.Image 5 | from PIL import ImageDraw 6 | 7 | 8 | def image_align(src_file, face_landmarks, output_size=1024, transform_size=4096, enable_padding=True, x_scale=1, y_scale=1, em_scale=0.1): 9 | # Align function from FFHQ dataset pre-processing step 10 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 11 | 12 | lm = np.array(face_landmarks) 13 | lm_chin = lm[0 : 17] # left-right 14 | lm_eyebrow_left = lm[17 : 22] # left-right 15 | lm_eyebrow_right = lm[22 : 27] # left-right 16 | lm_nose = lm[27 : 31] # top-down 17 | lm_nostrils = lm[31 : 36] # top-down 18 | lm_eye_left = lm[36 : 42] # left-clockwise 19 | lm_eye_right = lm[42 : 48] # left-clockwise 20 | lm_mouth_outer = lm[48 : 60] # left-clockwise 21 | lm_mouth_inner = lm[60 : 68] # left-clockwise 22 | 23 | # Calculate auxiliary vectors. 24 | eye_left = np.mean(lm_eye_left, axis=0) 25 | eye_right = np.mean(lm_eye_right, axis=0) 26 | eye_avg = (eye_left + eye_right) * 0.5 27 | eye_to_eye = eye_right - eye_left 28 | mouth_left = lm_mouth_outer[0] 29 | mouth_right = lm_mouth_outer[6] 30 | mouth_avg = (mouth_left + mouth_right) * 0.5 31 | eye_to_mouth = mouth_avg - eye_avg 32 | 33 | # Choose oriented crop rectangle. 34 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 35 | x /= np.hypot(*x) 36 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 37 | x *= x_scale 38 | y = np.flipud(x) * [-y_scale, y_scale] 39 | c = eye_avg + eye_to_mouth * em_scale 40 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 41 | qsize = np.hypot(*x) * 2 42 | 43 | # Load in-the-wild image. 44 | if not os.path.isfile(src_file): 45 | print('\nCannot find source image. Please run "--wilds" before "--align".') 46 | return 47 | img_bg = PIL.Image.open(src_file).convert('RGBA').convert('RGB') 48 | 49 | # Shrink. 50 | img = img_bg.copy() 51 | shrink = int(np.floor(qsize / output_size * 0.5)) 52 | if shrink > 1: 53 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 54 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 55 | quad /= shrink 56 | qsize /= shrink 57 | 58 | # Crop. 
59 | border = max(int(np.rint(qsize * 0.1)), 3) 60 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 61 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 62 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 63 | img = img.crop(crop) 64 | bg_draw = ImageDraw.ImageDraw(img_bg) 65 | bg_draw.rectangle(crop, fill='white') 66 | quad -= crop[0:2] 67 | 68 | # Pad. 69 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 70 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 71 | if enable_padding and max(pad) > border - 4: 72 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 73 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 74 | 75 | h, w, _ = img.shape 76 | y, x, _ = np.ogrid[:h, :w, :1] 77 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 78 | blur = qsize * 0.02 79 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 80 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 81 | img = np.uint8(np.clip(np.rint(img), 0, 255)) 82 | img = PIL.Image.fromarray(img, 'RGB') 83 | quad += pad[:2] 84 | 85 | # Transform. 86 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 87 | 88 | if output_size < transform_size: 89 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 90 | 91 | # Return aligned image. 
92 | return img,crop,pad,quad -------------------------------------------------------------------------------- /tools/functions.py: -------------------------------------------------------------------------------- 1 | import PIL.Image as Image 2 | import cv2 3 | import numpy as np 4 | import math 5 | 6 | def rotate(img, degree): 7 | height, width = img.shape[:2] 8 | heightNew = round(width * math.fabs(math.sin(math.radians(degree))) + height * math.fabs(math.cos(math.radians(degree)))) 9 | widthNew = round(height * math.fabs(math.sin(math.radians(degree))) + width * math.fabs(math.cos(math.radians(degree)))) 10 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1) 11 | matRotation[0, 2] += (widthNew - width) / 2 12 | matRotation[1, 2] += (heightNew - height) / 2 13 | imgRotation = cv2.warpAffine(img, matRotation, (widthNew, heightNew)) 14 | return imgRotation 15 | 16 | def merge_image(bg_img, fg_img, mask, crop, quad, pad): 17 | bg_img_ori = bg_img.copy() 18 | bg_img_alpha = cv2.cvtColor(bg_img, cv2.COLOR_BGR2BGRA) 19 | fg_img = cv2.cvtColor(np.asarray(fg_img), cv2.COLOR_RGB2BGR) 20 | mask = np.asarray(mask) 21 | line = int(round(max(quad[2][0]-quad[0][0], quad[3][0]-quad[1][0]))) 22 | radian = math.atan((quad[1][0]-quad[0][0])/(quad[1][1]-quad[0][1])) 23 | degree = math.degrees(radian) 24 | fg_img = rotate(fg_img, degree) 25 | fg_img = cv2.resize(fg_img, (line, line), interpolation=cv2.INTER_NEAREST) 26 | mask = rotate(mask, degree) 27 | mask = cv2.resize(mask, (line, line), interpolation=cv2.INTER_NEAREST) 28 | x1 = int(round(crop[0]-pad[0]+min([quad[0][0], quad[1][0], quad[2][0], quad[3][0]]))) 29 | y1 = int(round(crop[1]-pad[0]+min([quad[0][1], quad[1][1], quad[2][1], quad[3][1]]))) 30 | x2 = x1+line 31 | y2 = y1+line 32 | if x1 < 0: 33 | fg_img = fg_img[:, -x1:] 34 | mask = mask[:, -x1:] 35 | x1 = 0 36 | if y1 < 0: 37 | fg_img = fg_img[-y1:, :] 38 | mask = mask[-y1:, :] 39 | y1 = 0 40 | if x2 > bg_img.shape[1]: 41 | fg_img = fg_img[:, :-(x2-bg_img.shape[1])] 42 | mask = mask[:, :-(x2-bg_img.shape[1])] 43 | x2 = bg_img.shape[1] 44 | if y2 > bg_img.shape[0]: 45 | fg_img = fg_img[:-(y2 - bg_img.shape[0]), :] 46 | mask = mask[:-(y2 - bg_img.shape[0]), :] 47 | y2 = bg_img.shape[0] 48 | #alpha = cv2.erode(mask / 255.0, np.ones((3,3), np.uint8), iterations = 1) 49 | alpha = cv2.GaussianBlur(mask / 255.0, (5,5), 0) 50 | bg_img[y1:y2, x1:x2, 0] = (1. - alpha) * bg_img[y1:y2, x1:x2, 0] + alpha * fg_img[..., 0] 51 | bg_img[y1:y2, x1:x2, 1] = (1. - alpha) * bg_img[y1:y2, x1:x2, 1] + alpha * fg_img[..., 1] 52 | bg_img[y1:y2, x1:x2, 2] = (1. 
- alpha) * bg_img[y1:y2, x1:x2, 2] + alpha * fg_img[..., 2] 53 | bg_img[y1:y2, x1:x2] = cv2.fastNlMeansDenoisingColored(bg_img[y1:y2, x1:x2], None, 3.0, 3.0, 7, 21) 54 | 55 | # Seamlessly clone src into dst and put the results in output 56 | width, height, channels = bg_img_ori.shape 57 | center = (height // 2, width // 2) 58 | mask = 255 * np.ones(bg_img.shape, bg_img.dtype) 59 | normal_clone = cv2.seamlessClone(bg_img, bg_img_ori, mask, center, cv2.NORMAL_CLONE) 60 | 61 | return normal_clone 62 | 63 | 64 | def generate_face_mask(im, landmarks_detector): 65 | from imutils import face_utils 66 | rects = landmarks_detector.detector(im, 1) 67 | # loop over the face detections 68 | for (j, rect) in enumerate(rects): 69 | """ 70 | Determine the facial landmarks for the face region, then convert the facial landmark (x, y)-coordinates to a NumPy array 71 | """ 72 | shape = landmarks_detector.shape_predictor(im, rect) 73 | shape = face_utils.shape_to_np(shape) 74 | 75 | # we extract the face 76 | vertices = cv2.convexHull(shape) 77 | mask = np.zeros(im.shape[:2],np.uint8) 78 | cv2.fillConvexPoly(mask, vertices, 1) 79 | bgdModel = np.zeros((1,65),np.float64) 80 | fgdModel = np.zeros((1,65),np.float64) 81 | rect = (0,0,im.shape[1],im.shape[2]) 82 | (x,y),radius = cv2.minEnclosingCircle(vertices) 83 | center = (int(x), int(y)) 84 | radius = int(radius*1.4) 85 | mask = cv2.circle(mask,center,radius,cv2.GC_PR_FGD,-1) 86 | cv2.fillConvexPoly(mask, vertices, cv2.GC_FGD) 87 | cv2.grabCut(im,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_MASK) 88 | mask = np.where((mask==2)|(mask==0),0,1) 89 | cv2.rectangle(mask, (0, 0), (mask.shape[1], mask.shape[0]), 0, thickness=10) 90 | return mask 91 | 92 | 93 | def generate_face_mask_without_hair(im, landmarks_detector, ie_polys=None): 94 | # get the mask of the image with only face area 95 | rects = landmarks_detector.detector(im, 1) 96 | image_landmarks = np.matrix([[p.x, p.y] for p in landmarks_detector.shape_predictor(im, rects[0]).parts()]) 97 | if image_landmarks.shape[0] != 68: 98 | raise Exception( 99 | 'get_image_hull_mask works only with 68 landmarks') 100 | int_lmrks = np.array(image_landmarks, dtype=np.int) 101 | 102 | # hull_mask = np.zeros(image_shape[0:2]+(1,), dtype=np.float32) 103 | hull_mask = np.full(im.shape[0:2] + (1,), 0, dtype=np.float32) 104 | 105 | cv2.fillConvexPoly(hull_mask, cv2.convexHull( 106 | np.concatenate((int_lmrks[0:9], 107 | int_lmrks[17:18]))), (1,)) 108 | 109 | cv2.fillConvexPoly(hull_mask, cv2.convexHull( 110 | np.concatenate((int_lmrks[8:17], 111 | int_lmrks[26:27]))), (1,)) 112 | 113 | cv2.fillConvexPoly(hull_mask, cv2.convexHull( 114 | np.concatenate((int_lmrks[17:20], 115 | int_lmrks[8:9]))), (1,)) 116 | 117 | cv2.fillConvexPoly(hull_mask, cv2.convexHull( 118 | np.concatenate((int_lmrks[24:27], 119 | int_lmrks[8:9]))), (1,)) 120 | 121 | cv2.fillConvexPoly(hull_mask, cv2.convexHull( 122 | np.concatenate((int_lmrks[19:25], 123 | int_lmrks[8:9], 124 | ))), (1,)) 125 | 126 | cv2.fillConvexPoly(hull_mask, cv2.convexHull( 127 | np.concatenate((int_lmrks[17:22], 128 | int_lmrks[27:28], 129 | int_lmrks[31:36], 130 | int_lmrks[8:9] 131 | ))), (1,)) 132 | 133 | cv2.fillConvexPoly(hull_mask, cv2.convexHull( 134 | np.concatenate((int_lmrks[22:27], 135 | int_lmrks[27:28], 136 | int_lmrks[31:36], 137 | int_lmrks[8:9] 138 | ))), (1,)) 139 | 140 | # nose 141 | cv2.fillConvexPoly( 142 | hull_mask, cv2.convexHull(int_lmrks[27:36]), (1,)) 143 | 144 | if ie_polys is not None: 145 | ie_polys.overlay_mask(hull_mask) 146 | hull_mask = 
hull_mask.squeeze() 147 | return hull_mask 148 | -------------------------------------------------------------------------------- /tools/landmarks_detector.py: -------------------------------------------------------------------------------- 1 | import dlib 2 | 3 | 4 | class LandmarksDetector: 5 | def __init__(self, predictor_model_path): 6 | """ 7 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file 8 | """ 9 | self.detector = dlib.get_frontal_face_detector() # cnn_face_detection_model_v1 also can be used 10 | self.shape_predictor = dlib.shape_predictor(predictor_model_path) 11 | 12 | def get_landmarks(self, image): 13 | img = dlib.load_rgb_image(image) 14 | dets = self.detector(img, 1) 15 | 16 | for detection in dets: 17 | try: 18 | face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()] 19 | yield face_landmarks 20 | except: 21 | print("Exception in get_landmarks()!") 22 | --------------------------------------------------------------------------------
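
Usage sketch for the two alignment tools above: LandmarksDetector yields 68-point landmarks for each face it finds, and image_align applies the FFHQ-style crop and warp to those landmarks. This is only a minimal illustration, assuming the dlib model shape_predictor_68_face_landmarks.dat has been downloaded into networks/ (see networks/download_weights.txt) and that input/test1.jpg contains at least one face.

    from tools.landmarks_detector import LandmarksDetector
    from tools.face_alignment import image_align

    # Path to the dlib 68-landmark predictor is an assumption; adjust to wherever the .dat file was downloaded
    detector = LandmarksDetector('networks/shape_predictor_68_face_landmarks.dat')

    # Align every detected face and save the 1024x1024 crops; image_align also returns the
    # crop, pad, and quad values that merge_image() in tools/functions.py expects
    for i, landmarks in enumerate(detector.get_landmarks('input/test1.jpg')):
        aligned, crop, pad, quad = image_align('input/test1.jpg', landmarks, output_size=1024)
        aligned.save('aligned_{}.png'.format(i))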